Add Memory Gateway agent plugin
This commit is contained in:
@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
from . import lifecycle
|
||||
from .tools import memory_append_episode, memory_commit_session, memory_feedback, memory_search, memory_upsert
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"register",
|
||||
"memory_search",
|
||||
"memory_append_episode",
|
||||
"memory_commit_session",
|
||||
"memory_upsert",
|
||||
"memory_feedback",
|
||||
]
|
||||
|
||||
|
||||
TOOLS: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"memory_search": memory_search,
|
||||
"memory_append_episode": memory_append_episode,
|
||||
"memory_commit_session": memory_commit_session,
|
||||
"memory_upsert": memory_upsert,
|
||||
"memory_feedback": memory_feedback,
|
||||
}
|
||||
|
||||
|
||||
def _try_call(target: Any, method_names: list[str], *args: Any, **kwargs: Any) -> bool:
|
||||
for name in method_names:
|
||||
method = getattr(target, name, None)
|
||||
if callable(method):
|
||||
try:
|
||||
method(*args, **kwargs)
|
||||
return True
|
||||
except TypeError:
|
||||
try:
|
||||
method(args[0], args[1])
|
||||
return True
|
||||
except Exception as exc:
|
||||
_logger.debug("[_try_call] %s(%s, %s) failed: %s", name, args, kwargs, exc)
|
||||
return False
|
||||
except Exception as exc:
|
||||
_logger.debug("[_try_call] %s(%s, %s) failed: %s", name, args, kwargs, exc)
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def register(ctx: Any) -> dict[str, Any]:
|
||||
registered_tools: list[str] = []
|
||||
registered_hooks: list[str] = []
|
||||
|
||||
for name, func in TOOLS.items():
|
||||
if _try_call(ctx, ["register_tool", "add_tool", "tool"], name, func):
|
||||
registered_tools.append(name)
|
||||
|
||||
hook_map = {
|
||||
"pre_llm_call": lifecycle.on_conversation_start,
|
||||
"post_llm_call": lifecycle.after_user_message,
|
||||
"session_end": lifecycle.on_session_end,
|
||||
"after_task_complete": lifecycle.after_task_complete,
|
||||
}
|
||||
for name, func in hook_map.items():
|
||||
if _try_call(ctx, ["register_hook", "add_hook", "hook"], name, func):
|
||||
registered_hooks.append(name)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"mode": "tools-and-hooks" if registered_hooks else "tools-only" if registered_tools else "manual",
|
||||
"registered_tools": registered_tools,
|
||||
"registered_hooks": registered_hooks,
|
||||
}
|
||||
|
||||
101
plugins/memory-gateway-agent/memory_gateway_plugin/client.py
Normal file
101
plugins/memory-gateway-agent/memory_gateway_plugin/client.py
Normal file
@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
from .config import PluginConfig, load_config
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _short_error(value: Any, max_chars: int = 500) -> str:
|
||||
text = str(value).replace("\n", " ").strip()
|
||||
return text[:max_chars]
|
||||
|
||||
|
||||
class MemoryGatewayClient:
|
||||
def __init__(self, config: PluginConfig | None = None) -> None:
|
||||
self.config = config or load_config()
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.config.api_key:
|
||||
headers["X-API-Key"] = self.config.api_key
|
||||
return headers
|
||||
|
||||
def _post(self, endpoint: str, payload: dict[str, Any], retries: int = 3, backoff: float = 1.0) -> dict[str, Any]:
|
||||
url = self.config.gateway_url.rstrip("/") + endpoint
|
||||
body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(retries):
|
||||
request = urllib.request.Request(url, data=body, headers=self._headers(), method="POST")
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=self.config.timeout) as response:
|
||||
raw = response.read().decode("utf-8")
|
||||
data = json.loads(raw) if raw else {}
|
||||
return {
|
||||
"ok": True,
|
||||
"status_code": getattr(response, "status", 200),
|
||||
"endpoint": endpoint,
|
||||
"data": data,
|
||||
}
|
||||
except urllib.error.HTTPError as exc:
|
||||
# Typically, client errors (4xx) shouldn't be retried unless specifically handled.
|
||||
# Since HTTPError is a subclass of URLError, we catch it first.
|
||||
if exc.code < 500 and exc.code != 429:
|
||||
try:
|
||||
body_text = exc.read().decode("utf-8")
|
||||
except Exception:
|
||||
body_text = exc.reason
|
||||
_logger.error(f"HTTPError in _post to {endpoint}: {exc.code} {body_text}")
|
||||
return {
|
||||
"ok": False,
|
||||
"status_code": exc.code,
|
||||
"endpoint": endpoint,
|
||||
"error": _short_error(body_text),
|
||||
}
|
||||
last_error = exc
|
||||
except (urllib.error.URLError, TimeoutError, OSError) as exc:
|
||||
last_error = exc
|
||||
except Exception as exc:
|
||||
_logger.error("Unexpected error in _post to %s: %s", endpoint, exc, exc_info=True)
|
||||
return {
|
||||
"ok": False,
|
||||
"status_code": None,
|
||||
"endpoint": endpoint,
|
||||
"error": _short_error(exc),
|
||||
}
|
||||
|
||||
if attempt < retries - 1:
|
||||
time.sleep(backoff * (2 ** attempt))
|
||||
|
||||
# Exhausted retries
|
||||
error_msg = str(last_error) if last_error else "Max retries exceeded"
|
||||
_logger.error("Failed _post to %s after %d attempts. Last error: %s", endpoint, retries, last_error)
|
||||
return {
|
||||
"ok": False,
|
||||
"status_code": None,
|
||||
"endpoint": endpoint,
|
||||
"error": error_msg,
|
||||
}
|
||||
|
||||
def search_memory(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self._post("/v1/memory/search", payload)
|
||||
|
||||
def append_episode(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self._post("/v1/episodes", payload)
|
||||
|
||||
def commit_session(self, session_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self._post(f"/v1/sessions/{session_id}/commit", payload)
|
||||
|
||||
def upsert_memory(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self._post("/v1/memory", payload)
|
||||
|
||||
def send_feedback(self, memory_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self._post(f"/v1/memory/{memory_id}/feedback", payload)
|
||||
|
||||
50
plugins/memory-gateway-agent/memory_gateway_plugin/config.py
Normal file
50
plugins/memory-gateway-agent/memory_gateway_plugin/config.py
Normal file
@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PluginConfig:
|
||||
gateway_url: str = "http://127.0.0.1:1934"
|
||||
api_key: str = ""
|
||||
default_user_id: str = ""
|
||||
default_agent_id: str = ""
|
||||
default_workspace_id: str = ""
|
||||
auto_search: bool = True
|
||||
auto_append_episode: bool = True
|
||||
auto_commit_session: bool = False
|
||||
review_mode: bool = True
|
||||
timeout: int = 30
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "PluginConfig":
|
||||
try:
|
||||
timeout_val = int(os.environ.get("MEMORY_GATEWAY_TIMEOUT", "30"))
|
||||
except ValueError:
|
||||
timeout_val = 30
|
||||
|
||||
return cls(
|
||||
gateway_url=os.environ.get("MEMORY_GATEWAY_URL", cls.gateway_url).rstrip("/"),
|
||||
api_key=os.environ.get("MEMORY_GATEWAY_API_KEY", ""),
|
||||
default_user_id=os.environ.get("MEMORY_GATEWAY_DEFAULT_USER_ID", ""),
|
||||
default_agent_id=os.environ.get("MEMORY_GATEWAY_DEFAULT_AGENT_ID", ""),
|
||||
default_workspace_id=os.environ.get("MEMORY_GATEWAY_DEFAULT_WORKSPACE_ID", ""),
|
||||
auto_search=_env_bool("MEMORY_GATEWAY_AUTO_SEARCH", True),
|
||||
auto_append_episode=_env_bool("MEMORY_GATEWAY_AUTO_APPEND_EPISODE", True),
|
||||
auto_commit_session=_env_bool("MEMORY_GATEWAY_AUTO_COMMIT_SESSION", False),
|
||||
review_mode=_env_bool("MEMORY_GATEWAY_REVIEW_MODE", True),
|
||||
timeout=timeout_val,
|
||||
)
|
||||
|
||||
|
||||
def load_config() -> PluginConfig:
|
||||
return PluginConfig.from_env()
|
||||
|
||||
109
plugins/memory-gateway-agent/memory_gateway_plugin/lifecycle.py
Normal file
109
plugins/memory-gateway-agent/memory_gateway_plugin/lifecycle.py
Normal file
@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .client import MemoryGatewayClient
|
||||
from .config import PluginConfig, load_config
|
||||
from .policy import build_episode_summary, should_append_episode, should_commit_session, should_search_memory
|
||||
from .tools import memory_append_episode, memory_commit_session, memory_search
|
||||
|
||||
|
||||
def _get(context: dict[str, Any], key: str, default: str = "") -> str:
|
||||
value = context.get(key, default)
|
||||
return "" if value is None else str(value)
|
||||
|
||||
|
||||
def compact_memory_context(search_result: dict[str, Any], limit: int = 5) -> str:
|
||||
if not search_result.get("ok"):
|
||||
return ""
|
||||
data = search_result.get("data", {})
|
||||
rows = []
|
||||
for item in data.get("results", [])[:limit]:
|
||||
memory = item.get("memory") or item.get("openviking") or {}
|
||||
summary = memory.get("summary") or memory.get("abstract") or memory.get("content") or ""
|
||||
namespace = memory.get("namespace", "")
|
||||
memory_id = memory.get("id") or memory.get("uri") or ""
|
||||
if summary:
|
||||
rows.append(f"- {memory_id} [{namespace}]: {summary[:240]}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
def on_conversation_start(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
|
||||
cfg = config or load_config()
|
||||
user_message = _get(context, "user_message") or _get(context, "query")
|
||||
if not should_search_memory(user_message, context, cfg):
|
||||
return {"ok": True, "memory_context": ""}
|
||||
user_id = _get(context, "user_id", cfg.default_user_id)
|
||||
if not user_id:
|
||||
return {"ok": False, "error": "user_id_required"}
|
||||
try:
|
||||
limit_val = int(context.get("limit", 5))
|
||||
except (ValueError, TypeError):
|
||||
limit_val = 5
|
||||
|
||||
result = memory_search(
|
||||
query=user_message,
|
||||
user_id=user_id,
|
||||
agent_id=_get(context, "agent_id", cfg.default_agent_id),
|
||||
workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
|
||||
session_id=_get(context, "session_id"),
|
||||
limit=limit_val,
|
||||
client=client,
|
||||
)
|
||||
return {"ok": result.get("ok", False), "memory_context": compact_memory_context(result), "raw": result}
|
||||
|
||||
|
||||
def after_user_message(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
|
||||
cfg = config or load_config()
|
||||
user_message = _get(context, "user_message")
|
||||
assistant_response = _get(context, "assistant_response")
|
||||
if not should_append_episode(user_message, assistant_response, context, cfg):
|
||||
return {"ok": True, "appended": False, "reason": "policy_skip"}
|
||||
user_id = _get(context, "user_id", cfg.default_user_id)
|
||||
session_id = _get(context, "session_id")
|
||||
if not user_id or not session_id:
|
||||
return {"ok": False, "error": "user_id_and_session_id_required"}
|
||||
summary = build_episode_summary(user_message, assistant_response, context)
|
||||
result = memory_append_episode(
|
||||
user_id=user_id,
|
||||
agent_id=_get(context, "agent_id", cfg.default_agent_id),
|
||||
workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
|
||||
session_id=session_id,
|
||||
episode_summary=summary,
|
||||
tags=["plugin-candidate"],
|
||||
client=client,
|
||||
)
|
||||
return {"ok": result.get("ok", False), "appended": result.get("ok", False), "raw": result}
|
||||
|
||||
|
||||
def after_task_complete(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
|
||||
return _maybe_commit(context, client, config)
|
||||
|
||||
|
||||
def on_session_end(context: dict[str, Any], client: MemoryGatewayClient | None = None, config: PluginConfig | None = None) -> dict[str, Any]:
|
||||
return _maybe_commit(context, client, config)
|
||||
|
||||
|
||||
def _maybe_commit(context: dict[str, Any], client: MemoryGatewayClient | None, config: PluginConfig | None) -> dict[str, Any]:
|
||||
cfg = config or load_config()
|
||||
if not should_commit_session(context, cfg):
|
||||
return {"ok": True, "committed": False, "reason": "auto_commit_disabled"}
|
||||
user_id = _get(context, "user_id", cfg.default_user_id)
|
||||
session_id = _get(context, "session_id")
|
||||
if not user_id or not session_id:
|
||||
return {"ok": False, "error": "user_id_and_session_id_required"}
|
||||
try:
|
||||
min_importance_val = float(context.get("min_importance", 0.6))
|
||||
except (ValueError, TypeError):
|
||||
min_importance_val = 0.6
|
||||
|
||||
result = memory_commit_session(
|
||||
user_id=user_id,
|
||||
agent_id=_get(context, "agent_id", cfg.default_agent_id),
|
||||
workspace_id=_get(context, "workspace_id", cfg.default_workspace_id),
|
||||
session_id=session_id,
|
||||
min_importance=min_importance_val,
|
||||
client=client,
|
||||
)
|
||||
return {"ok": result.get("ok", False), "committed": result.get("ok", False), "raw": result}
|
||||
|
||||
86
plugins/memory-gateway-agent/memory_gateway_plugin/output.py
Normal file
86
plugins/memory-gateway-agent/memory_gateway_plugin/output.py
Normal file
@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
SENSITIVE_KEYS = ("api_key", "apikey", "authorization", "token", "cookie", "secret", "password", "x-api-key")
|
||||
|
||||
|
||||
def debug_raw_enabled() -> bool:
|
||||
return os.environ.get("MEMORY_GATEWAY_PLUGIN_DEBUG_RAW", "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def short_id(value: Any, prefix: int = 8, suffix: int = 4) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
if len(text) <= prefix + suffix + 3:
|
||||
return text
|
||||
return f"{text[:prefix]}...{text[-suffix:]}"
|
||||
|
||||
|
||||
import re
|
||||
|
||||
def redact(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
key: ("<redacted>" if key.lower() in SENSITIVE_KEYS else redact(item))
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [redact(item) for item in value]
|
||||
if isinstance(value, str):
|
||||
lowered = value.lower()
|
||||
sensitive_markers = ("api_key=", "password=", "token=", "bearer ", "cookie:", "private key")
|
||||
if any(marker in lowered for marker in sensitive_markers):
|
||||
return "<redacted>"
|
||||
return value
|
||||
|
||||
|
||||
def summarize_data(data: Any) -> Any:
|
||||
if debug_raw_enabled():
|
||||
return redact(data)
|
||||
if isinstance(data, list):
|
||||
return {"count": len(data)}
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if "results" in data:
|
||||
return {
|
||||
"count": len(data.get("results") or []),
|
||||
"total": data.get("total"),
|
||||
"local_total": data.get("local_total"),
|
||||
"openviking_total": data.get("openviking_total"),
|
||||
"searched_namespaces": data.get("searched_namespaces", []),
|
||||
}
|
||||
if "id" in data:
|
||||
return {
|
||||
"id": short_id(data.get("id")),
|
||||
"namespace": data.get("namespace"),
|
||||
"memory_type": data.get("memory_type"),
|
||||
"source": data.get("source"),
|
||||
}
|
||||
if "memory_id" in data:
|
||||
return {"status": data.get("status"), "memory_id": short_id(data.get("memory_id")), "feedback": data.get("feedback")}
|
||||
if "promoted" in data or "consolidation" in data:
|
||||
return {
|
||||
"status": data.get("status"),
|
||||
"promoted_count": len(data.get("promoted") or []),
|
||||
"archived_count": len(data.get("archived_episode_ids") or []),
|
||||
"consolidation_status": (data.get("consolidation") or {}).get("status") if isinstance(data.get("consolidation"), dict) else None,
|
||||
}
|
||||
allowed = {"ok", "status", "gateway", "service", "version", "healthy", "endpoint", "status_code", "error", "count"}
|
||||
return {key: redact(value) for key, value in data.items() if key in allowed}
|
||||
|
||||
|
||||
def summarize_result(result: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"ok": bool(result.get("ok")),
|
||||
"endpoint": result.get("endpoint"),
|
||||
"status_code": result.get("status_code"),
|
||||
"error": redact(result.get("error", "")),
|
||||
"data": summarize_data(result.get("data")),
|
||||
}
|
||||
|
||||
|
||||
def dumps_safe(payload: Any, *, indent: int = 2) -> str:
|
||||
return json.dumps(redact(payload), ensure_ascii=False, indent=indent, default=str)
|
||||
59
plugins/memory-gateway-agent/memory_gateway_plugin/policy.py
Normal file
59
plugins/memory-gateway-agent/memory_gateway_plugin/policy.py
Normal file
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .config import PluginConfig, load_config
|
||||
from .safety import validate_memory_write
|
||||
|
||||
|
||||
REMEMBER_RE = re.compile(r"记住|请保存|remember this|save this|keep in memory", re.I)
|
||||
STABLE_SIGNAL_RE = re.compile(
|
||||
r"偏好|长期|约束|架构决策|决策|结论|workflow|工作流|preference|constraint|decision|always|以后都|project fact",
|
||||
re.I,
|
||||
)
|
||||
SMALL_TALK_RE = re.compile(r"^\s*(你好|hi|hello|谢谢|thanks|ok|好的|收到|再见)[。.!!\s\w]*$", re.I)
|
||||
|
||||
|
||||
def should_search_memory(user_message: str, context: dict[str, Any] | None = None, config: PluginConfig | None = None) -> bool:
|
||||
cfg = config or load_config()
|
||||
if not cfg.auto_search:
|
||||
return False
|
||||
return bool(user_message and user_message.strip())
|
||||
|
||||
|
||||
def should_append_episode(
|
||||
user_message: str,
|
||||
assistant_response: str = "",
|
||||
context: dict[str, Any] | None = None,
|
||||
config: PluginConfig | None = None,
|
||||
) -> bool:
|
||||
cfg = config or load_config()
|
||||
if not cfg.auto_append_episode:
|
||||
return False
|
||||
combined = "\n".join(part for part in [user_message, assistant_response] if part)
|
||||
if not combined.strip() or SMALL_TALK_RE.match(combined.strip()):
|
||||
return False
|
||||
if not validate_memory_write(combined)["allowed"]:
|
||||
return False
|
||||
return bool(REMEMBER_RE.search(combined) or STABLE_SIGNAL_RE.search(combined))
|
||||
|
||||
|
||||
def build_episode_summary(user_message: str, assistant_response: str = "", context: dict[str, Any] | None = None) -> str:
|
||||
parts = []
|
||||
if REMEMBER_RE.search(user_message or ""):
|
||||
parts.append(f"用户明确要求记住:{user_message.strip()}")
|
||||
elif user_message:
|
||||
parts.append(f"用户输入中的可复用信息:{user_message.strip()}")
|
||||
if assistant_response and STABLE_SIGNAL_RE.search(assistant_response):
|
||||
parts.append(f"助手结论:{assistant_response.strip()}")
|
||||
summary = " ".join(parts).strip()
|
||||
return summary[:1000]
|
||||
|
||||
|
||||
def should_commit_session(context: dict[str, Any] | None = None, config: PluginConfig | None = None) -> bool:
|
||||
cfg = config or load_config()
|
||||
if cfg.auto_commit_session:
|
||||
return True
|
||||
return bool((context or {}).get("force_commit"))
|
||||
|
||||
97
plugins/memory-gateway-agent/memory_gateway_plugin/safety.py
Normal file
97
plugins/memory-gateway-agent/memory_gateway_plugin/safety.py
Normal file
@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
SECRET_PATTERNS = [
|
||||
r"\bpassword\s*[:=]",
|
||||
r"\bapi[_-]?key\s*[:=]",
|
||||
r"\btoken\s*[:=]",
|
||||
r"\bsecret\s*[:=]",
|
||||
r"\bbearer\s+[a-z0-9._\-]{12,}",
|
||||
r"\bcookie\s*[:=]",
|
||||
r"\bsession[_ -]?id\s*[:=]",
|
||||
r"-----BEGIN [A-Z ]*PRIVATE KEY-----",
|
||||
r"\bssh-rsa\s+[a-z0-9+/=]{40,}",
|
||||
r"\bone[- ]?time (?:password|code)\b",
|
||||
r"\botp\s*[:=]?\s*\d{4,8}\b",
|
||||
r"\b验证码\s*[::]?\s*\d{4,8}\b",
|
||||
]
|
||||
|
||||
CHAT_LINE_RE = re.compile(r"^\s*(user|assistant|system|用户|助手|模型|human|ai)\s*[::]", re.I)
|
||||
LOG_LINE_RE = re.compile(r"\b(ERROR|WARN|INFO|DEBUG|TRACE)\b|^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}")
|
||||
CHAIN_OF_THOUGHT_RE = re.compile(r"chain[- ]of[- ]thought|逐步推理|隐藏推理|internal reasoning", re.I)
|
||||
|
||||
|
||||
def detect_secret(content: str) -> tuple[bool, str]:
|
||||
for pattern in SECRET_PATTERNS:
|
||||
if re.search(pattern, content, re.I):
|
||||
return True, "secret_like_content"
|
||||
return False, ""
|
||||
|
||||
|
||||
def detect_raw_transcript(content: str) -> tuple[bool, str]:
|
||||
lines = [line for line in content.splitlines() if line.strip()]
|
||||
chat_lines = sum(1 for line in lines if CHAT_LINE_RE.search(line))
|
||||
if chat_lines >= 4:
|
||||
return True, "raw_chat_transcript"
|
||||
if "完整原始对话" in content or "full transcript" in content.lower():
|
||||
return True, "raw_chat_transcript"
|
||||
return False, ""
|
||||
|
||||
|
||||
def detect_large_log(content: str) -> tuple[bool, str]:
|
||||
lines = [line for line in content.splitlines() if line.strip()]
|
||||
log_lines = sum(1 for line in lines if LOG_LINE_RE.search(line))
|
||||
if len(content) > 4000 or len(lines) > 40 or log_lines >= 8:
|
||||
return True, "large_or_raw_log"
|
||||
return False, ""
|
||||
|
||||
|
||||
def detect_low_value_memory(content: str) -> tuple[bool, str]:
|
||||
normalized = re.sub(r"\s+", " ", content).strip().lower()
|
||||
stable_signal = re.search(r"记住|偏好|长期|决策|结论|约束|preference|remember|decision|constraint", normalized, re.I)
|
||||
if stable_signal:
|
||||
return False, ""
|
||||
if len(normalized) < 12:
|
||||
return True, "too_short"
|
||||
small_talk = {
|
||||
"hi",
|
||||
"hello",
|
||||
"thanks",
|
||||
"thank you",
|
||||
"ok",
|
||||
"好的",
|
||||
"谢谢",
|
||||
"你好",
|
||||
"收到",
|
||||
"再见",
|
||||
}
|
||||
if normalized in small_talk:
|
||||
return True, "small_talk"
|
||||
return False, ""
|
||||
|
||||
|
||||
def sanitize_memory_content(content: str) -> str:
|
||||
sanitized = content.strip()
|
||||
sanitized = re.sub(r"\b(password|api[_-]?key|token|secret)\s*[:=]\s*\S+", r"\1=<redacted>", sanitized, flags=re.I)
|
||||
sanitized = re.sub(r"\bbearer\s+[a-z0-9._\-]{12,}", "Bearer <redacted>", sanitized, flags=re.I)
|
||||
sanitized = re.sub(r"-----BEGIN [A-Z ]*PRIVATE KEY-----.*?-----END [A-Z ]*PRIVATE KEY-----", "<redacted-private-key>", sanitized, flags=re.I | re.S)
|
||||
return sanitized
|
||||
|
||||
|
||||
def validate_memory_write(content: str, *, allow_low_value: bool = False) -> dict[str, Any]:
|
||||
if not content or not content.strip():
|
||||
return {"allowed": False, "reason": "empty_content", "sanitized_content": ""}
|
||||
checks = [detect_secret, detect_raw_transcript, detect_large_log]
|
||||
for check in checks:
|
||||
blocked, reason = check(content)
|
||||
if blocked:
|
||||
return {"allowed": False, "reason": reason, "sanitized_content": ""}
|
||||
if CHAIN_OF_THOUGHT_RE.search(content):
|
||||
return {"allowed": False, "reason": "chain_of_thought", "sanitized_content": ""}
|
||||
low_value, reason = detect_low_value_memory(content)
|
||||
if low_value and not allow_low_value:
|
||||
return {"allowed": False, "reason": reason, "sanitized_content": ""}
|
||||
return {"allowed": True, "reason": "ok", "sanitized_content": sanitize_memory_content(content)}
|
||||
@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentContext:
|
||||
user_id: str
|
||||
agent_id: str = ""
|
||||
workspace_id: str = ""
|
||||
session_id: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
163
plugins/memory-gateway-agent/memory_gateway_plugin/tools.py
Normal file
163
plugins/memory-gateway-agent/memory_gateway_plugin/tools.py
Normal file
@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .client import MemoryGatewayClient
|
||||
from .config import PluginConfig, load_config
|
||||
from .safety import validate_memory_write
|
||||
|
||||
|
||||
FEEDBACK_MAP = {
|
||||
"confirm": "useful",
|
||||
"correct": "useful",
|
||||
"useful": "useful",
|
||||
"delete": "incorrect",
|
||||
"reject": "incorrect",
|
||||
"incorrect": "incorrect",
|
||||
"duplicate": "duplicate",
|
||||
"outdated": "outdated",
|
||||
"not_useful": "not_useful",
|
||||
}
|
||||
|
||||
|
||||
def _client(client: MemoryGatewayClient | None = None) -> MemoryGatewayClient:
|
||||
return client or MemoryGatewayClient()
|
||||
|
||||
|
||||
def _context_payload(user_id: str, agent_id: str = "", workspace_id: str = "", session_id: str = "") -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"user_id": user_id}
|
||||
if agent_id:
|
||||
payload["agent_id"] = agent_id
|
||||
if workspace_id:
|
||||
payload["workspace_id"] = workspace_id
|
||||
if session_id:
|
||||
payload["session_id"] = session_id
|
||||
return payload
|
||||
|
||||
|
||||
def memory_search(
|
||||
query: str,
|
||||
user_id: str,
|
||||
agent_id: str = "",
|
||||
workspace_id: str = "",
|
||||
session_id: str = "",
|
||||
namespaces: list[str] | None = None,
|
||||
memory_types: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
limit: int = 5,
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if not query or not query.strip():
|
||||
return {"ok": False, "error": "query_required"}
|
||||
payload = _context_payload(user_id, agent_id, workspace_id, session_id)
|
||||
payload.update(
|
||||
{
|
||||
"query": query.strip(),
|
||||
"namespaces": namespaces or [],
|
||||
"memory_types": memory_types or [],
|
||||
"tags": tags or [],
|
||||
"limit": limit,
|
||||
}
|
||||
)
|
||||
return _client(client).search_memory(payload)
|
||||
|
||||
|
||||
def memory_append_episode(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
content: str = "",
|
||||
episode_summary: str = "",
|
||||
workspace_id: str = "",
|
||||
source: str = "conversation",
|
||||
tags: list[str] | None = None,
|
||||
importance: float | None = None,
|
||||
confidence: float | None = None,
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
candidate = (episode_summary or content or "").strip()
|
||||
validation = validate_memory_write(candidate)
|
||||
if not validation["allowed"]:
|
||||
return {"ok": False, "error": "memory_write_rejected", "reason": validation["reason"]}
|
||||
payload = _context_payload(user_id, agent_id, workspace_id, session_id)
|
||||
payload.update({"content": validation["sanitized_content"], "tags": tags or [], "source": source})
|
||||
if importance is not None:
|
||||
payload["events"] = [{"type": "importance_hint", "value": importance}]
|
||||
if confidence is not None:
|
||||
payload.setdefault("events", []).append({"type": "confidence_hint", "value": confidence})
|
||||
return _client(client).append_episode(payload)
|
||||
|
||||
|
||||
def memory_commit_session(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
workspace_id: str = "",
|
||||
promote: bool = True,
|
||||
min_importance: float = 0.6,
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = _context_payload(user_id, agent_id, workspace_id, session_id)
|
||||
payload.update({"promote": promote, "min_importance": min_importance})
|
||||
return _client(client).commit_session(session_id, payload)
|
||||
|
||||
|
||||
def memory_upsert(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
workspace_id: str = "",
|
||||
namespace: str = "",
|
||||
memory_type: str = "fact",
|
||||
summary: str = "",
|
||||
tags: list[str] | None = None,
|
||||
importance: float = 0.5,
|
||||
confidence: float = 0.8,
|
||||
visibility: str = "private",
|
||||
source: str = "agent",
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
validation = validate_memory_write(content)
|
||||
if not validation["allowed"]:
|
||||
return {"ok": False, "error": "memory_write_rejected", "reason": validation["reason"]}
|
||||
payload = _context_payload(user_id, agent_id, workspace_id)
|
||||
payload.update(
|
||||
{
|
||||
"namespace": namespace or None,
|
||||
"memory_type": memory_type,
|
||||
"content": validation["sanitized_content"],
|
||||
"summary": summary or None,
|
||||
"tags": tags or [],
|
||||
"importance": importance,
|
||||
"confidence": confidence,
|
||||
"visibility": visibility,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
return _client(client).upsert_memory(payload)
|
||||
|
||||
|
||||
def memory_feedback(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
memory_id: str,
|
||||
feedback: str,
|
||||
workspace_id: str = "",
|
||||
session_id: str = "",
|
||||
comment: str = "",
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
mapped_feedback = FEEDBACK_MAP.get(feedback, feedback)
|
||||
payload = _context_payload(user_id, agent_id, workspace_id, session_id)
|
||||
payload.update({"feedback": mapped_feedback, "comment": comment or None})
|
||||
return _client(client).send_feedback(memory_id, payload)
|
||||
|
||||
|
||||
def default_context(config: PluginConfig | None = None) -> dict[str, str]:
|
||||
cfg = config or load_config()
|
||||
return {
|
||||
"user_id": cfg.default_user_id,
|
||||
"agent_id": cfg.default_agent_id,
|
||||
"workspace_id": cfg.default_workspace_id,
|
||||
}
|
||||
|
||||
45
plugins/memory-gateway-agent/memory_gateway_plugin/trace.py
Normal file
45
plugins/memory-gateway-agent/memory_gateway_plugin/trace.py
Normal file
@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .output import redact, short_id
|
||||
|
||||
|
||||
def trace_enabled() -> bool:
|
||||
return os.environ.get("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def trace_path() -> Path:
|
||||
return Path(__file__).resolve().parents[1] / ".tmp" / "hook_trace.log"
|
||||
|
||||
|
||||
def trace_hook(
|
||||
hook_name: str,
|
||||
*,
|
||||
session_id: str = "",
|
||||
gateway_action: str = "",
|
||||
gateway_called: bool = False,
|
||||
ok: bool | None = None,
|
||||
audit_delta: int | None = None,
|
||||
reason: str = "",
|
||||
) -> None:
|
||||
if not trace_enabled():
|
||||
return
|
||||
path = trace_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"hook": hook_name,
|
||||
"session_id": short_id(session_id),
|
||||
"gateway_action": gateway_action,
|
||||
"gateway_called": gateway_called,
|
||||
"ok": ok,
|
||||
"audit_delta": audit_delta,
|
||||
"reason": reason[:160] if reason else "",
|
||||
}
|
||||
with path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(redact(payload), ensure_ascii=False, default=str) + "\n")
|
||||
Reference in New Issue
Block a user