Add Memory Gateway agent plugin

This commit is contained in:
2026-05-06 16:10:04 +08:00
parent e65731a273
commit c44af407d4
48 changed files with 3111 additions and 0 deletions

View File

@ -0,0 +1,74 @@
from __future__ import annotations
import logging
from typing import Any, Callable
from . import lifecycle
from .tools import memory_append_episode, memory_commit_session, memory_feedback, memory_search, memory_upsert
# Module-level logger used for registration diagnostics.
_logger = logging.getLogger(__name__)
# Public API: the register() entry point plus the raw tool callables.
__all__ = [
    "register",
    "memory_search",
    "memory_append_episode",
    "memory_commit_session",
    "memory_upsert",
    "memory_feedback",
]
# Tool-name -> callable mapping offered to the host agent during register().
TOOLS: dict[str, Callable[..., dict[str, Any]]] = {
    "memory_search": memory_search,
    "memory_append_episode": memory_append_episode,
    "memory_commit_session": memory_commit_session,
    "memory_upsert": memory_upsert,
    "memory_feedback": memory_feedback,
}
def _try_call(target: Any, method_names: list[str], *args: Any, **kwargs: Any) -> bool:
for name in method_names:
method = getattr(target, name, None)
if callable(method):
try:
method(*args, **kwargs)
return True
except TypeError:
try:
method(args[0], args[1])
return True
except Exception as exc:
_logger.debug("[_try_call] %s(%s, %s) failed: %s", name, args, kwargs, exc)
return False
except Exception as exc:
_logger.debug("[_try_call] %s(%s, %s) failed: %s", name, args, kwargs, exc)
return False
return False
def register(ctx: Any) -> dict[str, Any]:
    """Attach memory tools and lifecycle hooks to the host agent context.

    Uses duck typing against *ctx*: several candidate registration method
    names are tried for tools and for hooks.  Returns a summary dict with
    the achieved integration mode and what was registered.
    """
    tools_ok: list[str] = []
    hooks_ok: list[str] = []
    tool_methods = ["register_tool", "add_tool", "tool"]
    hook_methods = ["register_hook", "add_hook", "hook"]
    for tool_name, tool_func in TOOLS.items():
        if _try_call(ctx, tool_methods, tool_name, tool_func):
            tools_ok.append(tool_name)
    hooks = {
        "pre_llm_call": lifecycle.on_conversation_start,
        "post_llm_call": lifecycle.after_user_message,
        "session_end": lifecycle.on_session_end,
        "after_task_complete": lifecycle.after_task_complete,
    }
    for hook_name, hook_func in hooks.items():
        if _try_call(ctx, hook_methods, hook_name, hook_func):
            hooks_ok.append(hook_name)
    if hooks_ok:
        mode = "tools-and-hooks"
    elif tools_ok:
        mode = "tools-only"
    else:
        mode = "manual"
    return {
        "ok": True,
        "mode": mode,
        "registered_tools": tools_ok,
        "registered_hooks": hooks_ok,
    }

View File

@ -0,0 +1,101 @@
from __future__ import annotations
import json
import logging
import time
import urllib.error
import urllib.request
from typing import Any
from .config import PluginConfig, load_config
_logger = logging.getLogger(__name__)
def _short_error(value: Any, max_chars: int = 500) -> str:
text = str(value).replace("\n", " ").strip()
return text[:max_chars]
class MemoryGatewayClient:
    """Thin HTTP client for the Memory Gateway REST API.

    All requests are JSON POSTs.  Calls never raise to the caller; every
    outcome is reported as a dict of the form
    ``{"ok": bool, "status_code": ..., "endpoint": ..., "data"/"error": ...}``.
    """

    def __init__(self, config: PluginConfig | None = None) -> None:
        # Fall back to environment-derived configuration when none is given.
        self.config = config or load_config()

    def _headers(self) -> dict[str, str]:
        """Build request headers; the API-key header is only sent when configured."""
        headers = {"Content-Type": "application/json"}
        if self.config.api_key:
            headers["X-API-Key"] = self.config.api_key
        return headers

    def _post(self, endpoint: str, payload: dict[str, Any], retries: int = 3, backoff: float = 1.0) -> dict[str, Any]:
        """POST *payload* to *endpoint* with exponential-backoff retries.

        Retries on network errors, timeouts, HTTP 5xx and 429; other HTTP
        client errors (4xx) return immediately without retrying.  Unexpected
        exceptions abort immediately.  Never raises.
        """
        url = self.config.gateway_url.rstrip("/") + endpoint
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        last_error: Exception | None = None
        for attempt in range(retries):
            request = urllib.request.Request(url, data=body, headers=self._headers(), method="POST")
            try:
                with urllib.request.urlopen(request, timeout=self.config.timeout) as response:
                    raw = response.read().decode("utf-8")
                    data = json.loads(raw) if raw else {}
                    return {
                        "ok": True,
                        "status_code": getattr(response, "status", 200),
                        "endpoint": endpoint,
                        "data": data,
                    }
            except urllib.error.HTTPError as exc:
                # Typically, client errors (4xx) shouldn't be retried unless specifically handled.
                # Since HTTPError is a subclass of URLError, we catch it first.
                if exc.code < 500 and exc.code != 429:
                    try:
                        body_text = exc.read().decode("utf-8")
                    except Exception:
                        # Response body unavailable; fall back to the HTTP reason phrase.
                        body_text = exc.reason
                    _logger.error(f"HTTPError in _post to {endpoint}: {exc.code} {body_text}")
                    return {
                        "ok": False,
                        "status_code": exc.code,
                        "endpoint": endpoint,
                        "error": _short_error(body_text),
                    }
                # 5xx / 429: remember the error and fall through to retry.
                last_error = exc
            except (urllib.error.URLError, TimeoutError, OSError) as exc:
                # Transient transport-level failure: retry after backoff.
                last_error = exc
            except Exception as exc:
                # Anything else is a programming/serialization error — do not retry.
                _logger.error("Unexpected error in _post to %s: %s", endpoint, exc, exc_info=True)
                return {
                    "ok": False,
                    "status_code": None,
                    "endpoint": endpoint,
                    "error": _short_error(exc),
                }
            if attempt < retries - 1:
                # Exponential backoff: backoff, 2*backoff, 4*backoff, ...
                time.sleep(backoff * (2 ** attempt))
        # Exhausted retries
        error_msg = str(last_error) if last_error else "Max retries exceeded"
        _logger.error("Failed _post to %s after %d attempts. Last error: %s", endpoint, retries, last_error)
        return {
            "ok": False,
            "status_code": None,
            "endpoint": endpoint,
            "error": error_msg,
        }

    def search_memory(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Search stored memories (POST /v1/memory/search)."""
        return self._post("/v1/memory/search", payload)

    def append_episode(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Append an episode to the session buffer (POST /v1/episodes)."""
        return self._post("/v1/episodes", payload)

    def commit_session(self, session_id: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Commit a session's buffered episodes (POST /v1/sessions/{id}/commit)."""
        return self._post(f"/v1/sessions/{session_id}/commit", payload)

    def upsert_memory(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create or update a memory record (POST /v1/memory)."""
        return self._post("/v1/memory", payload)

    def send_feedback(self, memory_id: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Send feedback about a memory (POST /v1/memory/{id}/feedback)."""
        return self._post(f"/v1/memory/{memory_id}/feedback", payload)

View File

@ -0,0 +1,50 @@
from __future__ import annotations
import os
from dataclasses import dataclass
def _env_bool(name: str, default: bool) -> bool:
value = os.environ.get(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
@dataclass(frozen=True)
class PluginConfig:
    """Immutable plugin settings, typically sourced from the environment."""

    gateway_url: str = "http://127.0.0.1:1934"
    api_key: str = ""
    default_user_id: str = ""
    default_agent_id: str = ""
    default_workspace_id: str = ""
    auto_search: bool = True
    auto_append_episode: bool = True
    auto_commit_session: bool = False
    review_mode: bool = True
    timeout: int = 30

    @classmethod
    def from_env(cls) -> "PluginConfig":
        """Build a config from MEMORY_GATEWAY_* environment variables."""
        env = os.environ.get
        try:
            timeout = int(env("MEMORY_GATEWAY_TIMEOUT", "30"))
        except ValueError:
            # Malformed value: fall back to the default.
            timeout = 30
        return cls(
            gateway_url=env("MEMORY_GATEWAY_URL", cls.gateway_url).rstrip("/"),
            api_key=env("MEMORY_GATEWAY_API_KEY", ""),
            default_user_id=env("MEMORY_GATEWAY_DEFAULT_USER_ID", ""),
            default_agent_id=env("MEMORY_GATEWAY_DEFAULT_AGENT_ID", ""),
            default_workspace_id=env("MEMORY_GATEWAY_DEFAULT_WORKSPACE_ID", ""),
            auto_search=_env_bool("MEMORY_GATEWAY_AUTO_SEARCH", True),
            auto_append_episode=_env_bool("MEMORY_GATEWAY_AUTO_APPEND_EPISODE", True),
            auto_commit_session=_env_bool("MEMORY_GATEWAY_AUTO_COMMIT_SESSION", False),
            review_mode=_env_bool("MEMORY_GATEWAY_REVIEW_MODE", True),
            timeout=timeout,
        )
def load_config() -> PluginConfig:
    """Return a fresh PluginConfig built from the current environment."""
    config = PluginConfig.from_env()
    return config

View File

@ -0,0 +1,109 @@
from __future__ import annotations
from typing import Any
from .client import MemoryGatewayClient
from .config import PluginConfig, load_config
from .policy import build_episode_summary, should_append_episode, should_commit_session, should_search_memory
from .tools import memory_append_episode, memory_commit_session, memory_search
def _get(context: dict[str, Any], key: str, default: str = "") -> str:
value = context.get(key, default)
return "" if value is None else str(value)
def compact_memory_context(search_result: dict[str, Any], limit: int = 5) -> str:
    """Render a successful search result as short '- id [ns]: summary' lines.

    Returns "" for failed results or when no entry carries usable text.
    Summaries are truncated to 240 characters per line.
    """
    if not search_result.get("ok"):
        return ""
    results = search_result.get("data", {}).get("results", [])
    lines: list[str] = []
    for entry in results[:limit]:
        record = entry.get("memory") or entry.get("openviking") or {}
        text = record.get("summary") or record.get("abstract") or record.get("content") or ""
        if not text:
            continue
        identifier = record.get("id") or record.get("uri") or ""
        namespace = record.get("namespace", "")
        lines.append(f"- {identifier} [{namespace}]: {text[:240]}")
    return "\n".join(lines)
def on_conversation_start(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
    """Pre-LLM hook: fetch memories relevant to the incoming user message.

    Returns {"ok": ..., "memory_context": <compact text>, "raw": <result>};
    an empty memory_context when policy skips the search.
    """
    cfg = config or load_config()
    message = _get(context, "user_message") or _get(context, "query")
    if not should_search_memory(message, context, cfg):
        return {"ok": True, "memory_context": ""}
    user_id = _get(context, "user_id", cfg.default_user_id)
    if not user_id:
        return {"ok": False, "error": "user_id_required"}
    try:
        limit = int(context.get("limit", 5))
    except (ValueError, TypeError):
        limit = 5
    result = memory_search(
        query=message,
        user_id=user_id,
        agent_id=_get(context, "agent_id", cfg.default_agent_id),
        workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
        session_id=_get(context, "session_id"),
        limit=limit,
        client=client,
    )
    succeeded = result.get("ok", False)
    return {"ok": succeeded, "memory_context": compact_memory_context(result), "raw": result}
def after_user_message(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
    """Post-LLM hook: capture the exchange as an episode when policy allows."""
    cfg = config or load_config()
    message = _get(context, "user_message")
    response = _get(context, "assistant_response")
    if not should_append_episode(message, response, context, cfg):
        return {"ok": True, "appended": False, "reason": "policy_skip"}
    user_id = _get(context, "user_id", cfg.default_user_id)
    session_id = _get(context, "session_id")
    if not (user_id and session_id):
        return {"ok": False, "error": "user_id_and_session_id_required"}
    summary = build_episode_summary(message, response, context)
    result = memory_append_episode(
        user_id=user_id,
        agent_id=_get(context, "agent_id", cfg.default_agent_id),
        workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
        session_id=session_id,
        episode_summary=summary,
        tags=["plugin-candidate"],
        client=client,
    )
    ok = result.get("ok", False)
    return {"ok": ok, "appended": ok, "raw": result}
def after_task_complete(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
    """Hook: optionally commit the session's episodes when a task finishes."""
    return _maybe_commit(context, client, config)
def on_session_end(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
    """Hook: optionally commit the session's episodes when the session ends."""
    return _maybe_commit(context, client, config)
def _maybe_commit(context: dict[str, Any], client: MemoryGatewayClient | None, config: PluginConfig | None) -> dict[str, Any]:
    """Shared commit path for the task-complete and session-end hooks."""
    cfg = config or load_config()
    if not should_commit_session(context, cfg):
        return {"ok": True, "committed": False, "reason": "auto_commit_disabled"}
    user_id = _get(context, "user_id", cfg.default_user_id)
    session_id = _get(context, "session_id")
    if not (user_id and session_id):
        return {"ok": False, "error": "user_id_and_session_id_required"}
    try:
        min_importance = float(context.get("min_importance", 0.6))
    except (ValueError, TypeError):
        min_importance = 0.6
    result = memory_commit_session(
        user_id=user_id,
        agent_id=_get(context, "agent_id", cfg.default_agent_id),
        workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
        session_id=session_id,
        min_importance=min_importance,
        client=client,
    )
    ok = result.get("ok", False)
    return {"ok": ok, "committed": ok, "raw": result}

View File

@ -0,0 +1,86 @@
from __future__ import annotations
import json
import os
from typing import Any
# Payload keys (compared lowercase) whose values redact() replaces with "<redacted>".
SENSITIVE_KEYS = ("api_key", "apikey", "authorization", "token", "cookie", "secret", "password", "x-api-key")
def debug_raw_enabled() -> bool:
    """True when MEMORY_GATEWAY_PLUGIN_DEBUG_RAW is set to a truthy value."""
    flag = os.environ.get("MEMORY_GATEWAY_PLUGIN_DEBUG_RAW", "")
    return flag.strip().lower() in {"1", "true", "yes", "on"}
def short_id(value: Any, prefix: int = 8, suffix: int = 4) -> str:
    """Abbreviate long identifiers as 'head...tail'; short ones pass through."""
    text = str(value) if value is not None else ""
    if len(text) <= prefix + suffix + 3:
        return text
    head, tail = text[:prefix], text[-suffix:]
    return f"{head}...{tail}"
import re
def redact(value: Any) -> Any:
    """Recursively mask secret-looking material in nested payloads.

    Dicts: values under keys in SENSITIVE_KEYS become "<redacted>"; other
    values recurse.  Lists recurse element-wise.  Strings containing
    credential markers are replaced wholesale.  All other values (ints,
    floats, bools, None) pass through unchanged — the explicit trailing
    return guarantees scalars such as status codes are never silently
    turned into None.
    """
    if isinstance(value, dict):
        return {
            key: ("<redacted>" if key.lower() in SENSITIVE_KEYS else redact(item))
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [redact(item) for item in value]
    if isinstance(value, str):
        lowered = value.lower()
        sensitive_markers = ("api_key=", "password=", "token=", "bearer ", "cookie:", "private key")
        if any(marker in lowered for marker in sensitive_markers):
            return "<redacted>"
        return value
    # Non-container, non-string scalars are safe as-is.
    return value
def summarize_data(data: Any) -> Any:
    """Reduce a gateway response body to a compact, shape-aware summary.

    With the debug-raw flag enabled the (redacted) payload is returned
    verbatim.  Otherwise the dict's keys select a summary shape — branch
    order matters, the first matching key wins: search results, a single
    memory record, a feedback ack, then a commit/consolidation report.
    Unknown dicts are filtered down to a small allow-list of keys.
    """
    if debug_raw_enabled():
        return redact(data)
    if isinstance(data, list):
        return {"count": len(data)}
    if not isinstance(data, dict):
        return data
    # Search response: report counts, not the hits themselves.
    if "results" in data:
        return {
            "count": len(data.get("results") or []),
            "total": data.get("total"),
            "local_total": data.get("local_total"),
            "openviking_total": data.get("openviking_total"),
            "searched_namespaces": data.get("searched_namespaces", []),
        }
    # Single memory record: identify it without echoing its content.
    if "id" in data:
        return {
            "id": short_id(data.get("id")),
            "namespace": data.get("namespace"),
            "memory_type": data.get("memory_type"),
            "source": data.get("source"),
        }
    # Feedback acknowledgement.
    if "memory_id" in data:
        return {"status": data.get("status"), "memory_id": short_id(data.get("memory_id")), "feedback": data.get("feedback")}
    # Session commit / consolidation report: counts only.
    if "promoted" in data or "consolidation" in data:
        return {
            "status": data.get("status"),
            "promoted_count": len(data.get("promoted") or []),
            "archived_count": len(data.get("archived_episode_ids") or []),
            "consolidation_status": (data.get("consolidation") or {}).get("status") if isinstance(data.get("consolidation"), dict) else None,
        }
    # Fallback: keep only benign, well-known keys (still redacted).
    allowed = {"ok", "status", "gateway", "service", "version", "healthy", "endpoint", "status_code", "error", "count"}
    return {key: redact(value) for key, value in data.items() if key in allowed}
def summarize_result(result: dict[str, Any]) -> dict[str, Any]:
    """Condense a gateway call result into a compact, redacted summary dict."""
    summary: dict[str, Any] = {
        "ok": bool(result.get("ok")),
        "endpoint": result.get("endpoint"),
        "status_code": result.get("status_code"),
    }
    summary["error"] = redact(result.get("error", ""))
    summary["data"] = summarize_data(result.get("data"))
    return summary
def dumps_safe(payload: Any, *, indent: int = 2) -> str:
    """JSON-encode *payload* after redaction; non-serializable values via str()."""
    sanitized = redact(payload)
    return json.dumps(sanitized, ensure_ascii=False, indent=indent, default=str)

View File

@ -0,0 +1,59 @@
from __future__ import annotations
import re
from typing import Any
from .config import PluginConfig, load_config
from .safety import validate_memory_write
# Explicit "please remember this" requests (English + Chinese).
REMEMBER_RE = re.compile(r"记住|请保存|remember this|save this|keep in memory", re.I)
# Signals of durable, reusable information: preferences, decisions, constraints.
STABLE_SIGNAL_RE = re.compile(
    r"偏好|长期|约束|架构决策|决策|结论|workflow|工作流|preference|constraint|decision|always|以后都|project fact",
    re.I,
)
# Whole-message greetings/acknowledgements that carry no memorable content.
SMALL_TALK_RE = re.compile(r"^\s*(你好|hi|hello|谢谢|thanks|ok|好的|收到|再见)[。.!\s\w]*$", re.I)
def should_search_memory(user_message: str, context: dict[str, Any] | None = None, config: PluginConfig | None = None) -> bool:
    """Decide whether to run an automatic memory search for this message."""
    cfg = config or load_config()
    if not cfg.auto_search:
        return False
    has_content = bool(user_message) and bool(user_message.strip())
    return has_content
def should_append_episode(
    user_message: str,
    assistant_response: str = "",
    context: dict[str, Any] | None = None,
    config: PluginConfig | None = None,
) -> bool:
    """Gate automatic episode capture behind config, small-talk and safety checks."""
    cfg = config or load_config()
    if not cfg.auto_append_episode:
        return False
    pieces = [part for part in (user_message, assistant_response) if part]
    combined = "\n".join(pieces)
    stripped = combined.strip()
    if not stripped or SMALL_TALK_RE.match(stripped):
        return False
    if not validate_memory_write(combined)["allowed"]:
        return False
    has_signal = REMEMBER_RE.search(combined) or STABLE_SIGNAL_RE.search(combined)
    return bool(has_signal)
def build_episode_summary(user_message: str, assistant_response: str = "", context: dict[str, Any] | None = None) -> str:
    """Compose a <=1000-character episode summary from the message pair."""
    pieces: list[str] = []
    if REMEMBER_RE.search(user_message or ""):
        pieces.append(f"用户明确要求记住:{user_message.strip()}")
    elif user_message:
        pieces.append(f"用户输入中的可复用信息:{user_message.strip()}")
    if assistant_response and STABLE_SIGNAL_RE.search(assistant_response):
        pieces.append(f"助手结论:{assistant_response.strip()}")
    return " ".join(pieces).strip()[:1000]
def should_commit_session(context: dict[str, Any] | None = None, config: PluginConfig | None = None) -> bool:
    """Commit when auto-commit is enabled, or when the context forces it."""
    cfg = config or load_config()
    if cfg.auto_commit_session:
        return True
    ctx = context or {}
    return bool(ctx.get("force_commit"))

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import re
from typing import Any
# Case-insensitive patterns that indicate credential-like content.
SECRET_PATTERNS = [
    r"\bpassword\s*[:=]",
    r"\bapi[_-]?key\s*[:=]",
    r"\btoken\s*[:=]",
    r"\bsecret\s*[:=]",
    r"\bbearer\s+[a-z0-9._\-]{12,}",
    r"\bcookie\s*[:=]",
    r"\bsession[_ -]?id\s*[:=]",
    r"-----BEGIN [A-Z ]*PRIVATE KEY-----",
    r"\bssh-rsa\s+[a-z0-9+/=]{40,}",
    r"\bone[- ]?time (?:password|code)\b",
    r"\botp\s*[:=]?\s*\d{4,8}\b",
    r"\b验证码\s*[:]?\s*\d{4,8}\b",
]
# Lines that look like a role-prefixed chat turn ("user:", "assistant:", ...).
CHAT_LINE_RE = re.compile(r"^\s*(user|assistant|system|用户|助手|模型|human|ai)\s*[:]", re.I)
# Lines that look like application log output (level keyword or ISO timestamp).
LOG_LINE_RE = re.compile(r"\b(ERROR|WARN|INFO|DEBUG|TRACE)\b|^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}")
# Markers of hidden/internal model reasoning that must never be persisted.
CHAIN_OF_THOUGHT_RE = re.compile(r"chain[- ]of[- ]thought|逐步推理|隐藏推理|internal reasoning", re.I)


def detect_secret(content: str) -> tuple[bool, str]:
    """Return (True, "secret_like_content") when any secret pattern matches."""
    matched = any(re.search(pattern, content, re.I) for pattern in SECRET_PATTERNS)
    if matched:
        return True, "secret_like_content"
    return False, ""
def detect_raw_transcript(content: str) -> tuple[bool, str]:
    """Flag content that looks like a pasted multi-turn chat transcript."""
    nonblank = [line for line in content.splitlines() if line.strip()]
    role_lines = 0
    for line in nonblank:
        if CHAT_LINE_RE.search(line):
            role_lines += 1
    if role_lines >= 4:
        return True, "raw_chat_transcript"
    if "完整原始对话" in content or "full transcript" in content.lower():
        return True, "raw_chat_transcript"
    return False, ""
def detect_large_log(content: str) -> tuple[bool, str]:
    """Flag oversized content, or content dominated by log-like lines."""
    nonblank = [line for line in content.splitlines() if line.strip()]
    log_hits = sum(1 for line in nonblank if LOG_LINE_RE.search(line))
    too_big = len(content) > 4000 or len(nonblank) > 40 or log_hits >= 8
    if too_big:
        return True, "large_or_raw_log"
    return False, ""
def detect_low_value_memory(content: str) -> tuple[bool, str]:
    """Return (True, reason) for content too trivial to store as memory.

    Content carrying a stability signal (preferences, decisions, explicit
    "remember" requests) is never treated as low value.  Otherwise pure
    small-talk phrases are rejected as "small_talk" and very short content
    as "too_short".
    """
    normalized = re.sub(r"\s+", " ", content).strip().lower()
    stable_signal = re.search(r"记住|偏好|长期|决策|结论|约束|preference|remember|decision|constraint", normalized, re.I)
    if stable_signal:
        return False, ""
    small_talk = {
        "hi",
        "hello",
        "thanks",
        "thank you",
        "ok",
        "好的",
        "谢谢",
        "你好",
        "收到",
        "再见",
    }
    # Check small talk BEFORE length: every small-talk phrase is shorter than
    # 12 characters, so the original length-first ordering made this branch
    # unreachable and mislabelled greetings as "too_short".
    if normalized in small_talk:
        return True, "small_talk"
    if len(normalized) < 12:
        return True, "too_short"
    return False, ""
def sanitize_memory_content(content: str) -> str:
    """Strip surrounding whitespace and mask credential-like substrings."""
    cleaned = content.strip()
    cleaned = re.sub(r"\b(password|api[_-]?key|token|secret)\s*[:=]\s*\S+", r"\1=<redacted>", cleaned, flags=re.I)
    cleaned = re.sub(r"\bbearer\s+[a-z0-9._\-]{12,}", "Bearer <redacted>", cleaned, flags=re.I)
    key_block = r"-----BEGIN [A-Z ]*PRIVATE KEY-----.*?-----END [A-Z ]*PRIVATE KEY-----"
    cleaned = re.sub(key_block, "<redacted-private-key>", cleaned, flags=re.I | re.S)
    return cleaned
def validate_memory_write(content: str, *, allow_low_value: bool = False) -> dict[str, Any]:
    """Run every safety gate; return {allowed, reason, sanitized_content}."""
    def _reject(reason: str) -> dict[str, Any]:
        # Blocked writes never carry sanitized content.
        return {"allowed": False, "reason": reason, "sanitized_content": ""}
    if not content or not content.strip():
        return _reject("empty_content")
    for detector in (detect_secret, detect_raw_transcript, detect_large_log):
        blocked, reason = detector(content)
        if blocked:
            return _reject(reason)
    if CHAIN_OF_THOUGHT_RE.search(content):
        return _reject("chain_of_thought")
    low_value, reason = detect_low_value_memory(content)
    if low_value and not allow_low_value:
        return _reject(reason)
    return {"allowed": True, "reason": "ok", "sanitized_content": sanitize_memory_content(content)}

View File

@ -0,0 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
@dataclass
class AgentContext:
    """Identity and scoping bundle for a single agent conversation.

    Attributes:
        user_id: Owning user; the only required field.
        agent_id: Optional agent identifier.
        workspace_id: Optional workspace scope.
        session_id: Optional session identifier.
        metadata: Free-form extra data.
    """

    user_id: str
    agent_id: str = ""
    workspace_id: str = ""
    session_id: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

View File

@ -0,0 +1,163 @@
from __future__ import annotations
from typing import Any
from .client import MemoryGatewayClient
from .config import PluginConfig, load_config
from .safety import validate_memory_write
# Normalizes agent-facing feedback verbs to the gateway's feedback vocabulary.
# Unknown values are passed through unchanged by memory_feedback().
FEEDBACK_MAP = {
    "confirm": "useful",
    "correct": "useful",
    "useful": "useful",
    "delete": "incorrect",
    "reject": "incorrect",
    "incorrect": "incorrect",
    "duplicate": "duplicate",
    "outdated": "outdated",
    "not_useful": "not_useful",
}
def _client(client: MemoryGatewayClient | None = None) -> MemoryGatewayClient:
    """Return the provided client, or construct a default-configured one."""
    if client:
        return client
    return MemoryGatewayClient()
def _context_payload(user_id: str, agent_id: str = "", workspace_id: str = "", session_id: str = "") -> dict[str, Any]:
payload: dict[str, Any] = {"user_id": user_id}
if agent_id:
payload["agent_id"] = agent_id
if workspace_id:
payload["workspace_id"] = workspace_id
if session_id:
payload["session_id"] = session_id
return payload
def memory_search(
    query: str,
    user_id: str,
    agent_id: str = "",
    workspace_id: str = "",
    session_id: str = "",
    namespaces: list[str] | None = None,
    memory_types: list[str] | None = None,
    tags: list[str] | None = None,
    limit: int = 5,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Search stored memories via POST /v1/memory/search.

    Empty optional filters are sent as empty lists; the query is stripped
    before sending.  A blank *query* short-circuits with
    {"ok": False, "error": "query_required"}.  Returns the client's
    result envelope.
    """
    if not query or not query.strip():
        return {"ok": False, "error": "query_required"}
    payload = _context_payload(user_id, agent_id, workspace_id, session_id)
    payload.update(
        {
            "query": query.strip(),
            "namespaces": namespaces or [],
            "memory_types": memory_types or [],
            "tags": tags or [],
            "limit": limit,
        }
    )
    return _client(client).search_memory(payload)
def memory_append_episode(
    user_id: str,
    agent_id: str,
    session_id: str,
    content: str = "",
    episode_summary: str = "",
    workspace_id: str = "",
    source: str = "conversation",
    tags: list[str] | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Append one episode to the session buffer via POST /v1/episodes.

    *episode_summary* takes precedence over *content*; the chosen text is
    safety-validated and sanitized before sending.  Returns
    {"ok": False, "error": "memory_write_rejected", "reason": ...} when
    the safety gate blocks the write.  Importance/confidence are forwarded
    as hint events rather than top-level payload fields.
    """
    candidate = (episode_summary or content or "").strip()
    validation = validate_memory_write(candidate)
    if not validation["allowed"]:
        return {"ok": False, "error": "memory_write_rejected", "reason": validation["reason"]}
    payload = _context_payload(user_id, agent_id, workspace_id, session_id)
    payload.update({"content": validation["sanitized_content"], "tags": tags or [], "source": source})
    if importance is not None:
        # Hints ride along as events; the gateway decides how to apply them.
        payload["events"] = [{"type": "importance_hint", "value": importance}]
    if confidence is not None:
        payload.setdefault("events", []).append({"type": "confidence_hint", "value": confidence})
    return _client(client).append_episode(payload)
def memory_commit_session(
    user_id: str,
    agent_id: str,
    session_id: str,
    workspace_id: str = "",
    promote: bool = True,
    min_importance: float = 0.6,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Commit a session's buffered episodes, optionally promoting them."""
    payload = _context_payload(user_id, agent_id, workspace_id, session_id)
    payload["promote"] = promote
    payload["min_importance"] = min_importance
    return _client(client).commit_session(session_id, payload)
def memory_upsert(
    user_id: str,
    agent_id: str,
    content: str,
    workspace_id: str = "",
    namespace: str = "",
    memory_type: str = "fact",
    summary: str = "",
    tags: list[str] | None = None,
    importance: float = 0.5,
    confidence: float = 0.8,
    visibility: str = "private",
    source: str = "agent",
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Create or update a long-term memory record via POST /v1/memory.

    *content* is safety-validated and sanitized before sending; rejected
    writes return {"ok": False, "error": "memory_write_rejected", ...}.
    Empty *namespace*/*summary* are sent as None so the gateway applies
    its own defaults.
    """
    validation = validate_memory_write(content)
    if not validation["allowed"]:
        return {"ok": False, "error": "memory_write_rejected", "reason": validation["reason"]}
    payload = _context_payload(user_id, agent_id, workspace_id)
    payload.update(
        {
            "namespace": namespace or None,
            "memory_type": memory_type,
            "content": validation["sanitized_content"],
            "summary": summary or None,
            "tags": tags or [],
            "importance": importance,
            "confidence": confidence,
            "visibility": visibility,
            "source": source,
        }
    )
    return _client(client).upsert_memory(payload)
def memory_feedback(
    user_id: str,
    agent_id: str,
    memory_id: str,
    feedback: str,
    workspace_id: str = "",
    session_id: str = "",
    comment: str = "",
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Send user feedback about a stored memory to the gateway.

    Feedback verbs are normalized through FEEDBACK_MAP; unknown values
    are passed through unchanged.
    """
    normalized = FEEDBACK_MAP.get(feedback, feedback)
    payload = _context_payload(user_id, agent_id, workspace_id, session_id)
    payload["feedback"] = normalized
    payload["comment"] = comment or None
    return _client(client).send_feedback(memory_id, payload)
def default_context(config: PluginConfig | None = None) -> dict[str, str]:
    """Expose the configured default identity values as a context dict."""
    cfg = config or load_config()
    pairs = (
        ("user_id", cfg.default_user_id),
        ("agent_id", cfg.default_agent_id),
        ("workspace_id", cfg.default_workspace_id),
    )
    return dict(pairs)

View File

@ -0,0 +1,45 @@
from __future__ import annotations
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from .output import redact, short_id
def trace_enabled() -> bool:
    """True when MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS is set to a truthy value."""
    flag = os.environ.get("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "")
    return flag.strip().lower() in {"1", "true", "yes", "on"}
def trace_path() -> Path:
    """Location of the hook trace log: <package parent>/.tmp/hook_trace.log."""
    base = Path(__file__).resolve().parents[1]
    return base / ".tmp" / "hook_trace.log"
def trace_hook(
    hook_name: str,
    *,
    session_id: str = "",
    gateway_action: str = "",
    gateway_called: bool = False,
    ok: bool | None = None,
    audit_delta: int | None = None,
    reason: str = "",
) -> None:
    """Append one JSON line describing a hook invocation to the trace log.

    No-op unless tracing is enabled via the environment flag.  Session ids
    are abbreviated and the record is redacted before being written.
    """
    if not trace_enabled():
        return
    log_file = trace_path()
    log_file.parent.mkdir(parents=True, exist_ok=True)
    record: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "hook": hook_name,
        "session_id": short_id(session_id),
        "gateway_action": gateway_action,
        "gateway_called": gateway_called,
        "ok": ok,
        "audit_delta": audit_delta,
        "reason": reason[:160] if reason else "",
    }
    line = json.dumps(redact(record), ensure_ascii=False, default=str)
    with log_file.open("a", encoding="utf-8") as handle:
        handle.write(line + "\n")