Files
memory-gateway/plugins/memory-gateway-agent/memory_gateway_plugin/tools.py

164 lines
5.1 KiB
Python

from __future__ import annotations
from typing import Any
from .client import MemoryGatewayClient
from .config import PluginConfig, load_config
from .safety import validate_memory_write
# Translate agent-facing feedback verbs into the gateway's canonical labels.
# The canonical labels also map to themselves, so they can be passed directly.
FEEDBACK_MAP = dict(
    # affirmations -> "useful"
    confirm="useful",
    correct="useful",
    useful="useful",
    # rejections -> "incorrect"
    delete="incorrect",
    reject="incorrect",
    incorrect="incorrect",
    # labels that are already canonical
    duplicate="duplicate",
    outdated="outdated",
    not_useful="not_useful",
)
def _client(client: MemoryGatewayClient | None = None) -> MemoryGatewayClient:
return client or MemoryGatewayClient()
def _context_payload(user_id: str, agent_id: str = "", workspace_id: str = "", session_id: str = "") -> dict[str, Any]:
payload: dict[str, Any] = {"user_id": user_id}
if agent_id:
payload["agent_id"] = agent_id
if workspace_id:
payload["workspace_id"] = workspace_id
if session_id:
payload["session_id"] = session_id
return payload
def memory_search(
    query: str,
    user_id: str,
    agent_id: str = "",
    workspace_id: str = "",
    session_id: str = "",
    namespaces: list[str] | None = None,
    memory_types: list[str] | None = None,
    tags: list[str] | None = None,
    limit: int = 5,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Search stored memories via the gateway client's ``search_memory``.

    A missing or whitespace-only ``query`` short-circuits with
    ``{"ok": False, "error": "query_required"}`` and no request is made.
    ``None`` filter lists are sent as empty lists; the query is stripped.
    """
    cleaned_query = query.strip() if query else ""
    if not cleaned_query:
        return {"ok": False, "error": "query_required"}

    request = {
        **_context_payload(user_id, agent_id, workspace_id, session_id),
        "query": cleaned_query,
        "namespaces": namespaces or [],
        "memory_types": memory_types or [],
        "tags": tags or [],
        "limit": limit,
    }
    return _client(client).search_memory(request)
def memory_append_episode(
    user_id: str,
    agent_id: str,
    session_id: str,
    content: str = "",
    episode_summary: str = "",
    workspace_id: str = "",
    source: str = "conversation",
    tags: list[str] | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Append one episode to a session through the gateway client.

    ``episode_summary`` takes precedence over ``content``; the chosen text is
    stripped and passed through ``validate_memory_write``. A rejected write
    returns ``{"ok": False, "error": "memory_write_rejected", "reason": ...}``
    without contacting the gateway. ``importance``/``confidence``, when given,
    are attached as ``importance_hint``/``confidence_hint`` entries under
    ``events`` (in that order); ``events`` is omitted if neither is given.
    """
    text = (episode_summary or content or "").strip()
    verdict = validate_memory_write(text)
    if not verdict["allowed"]:
        return {"ok": False, "error": "memory_write_rejected", "reason": verdict["reason"]}

    hints = []
    if importance is not None:
        hints.append({"type": "importance_hint", "value": importance})
    if confidence is not None:
        hints.append({"type": "confidence_hint", "value": confidence})

    request = _context_payload(user_id, agent_id, workspace_id, session_id)
    request["content"] = verdict["sanitized_content"]
    request["tags"] = tags or []
    request["source"] = source
    if hints:
        request["events"] = hints
    return _client(client).append_episode(request)
def memory_commit_session(
    user_id: str,
    agent_id: str,
    session_id: str,
    workspace_id: str = "",
    promote: bool = True,
    min_importance: float = 0.6,
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Ask the gateway to commit ``session_id``.

    The session id is passed both positionally to ``commit_session`` and
    inside the request body, alongside the ``promote`` flag and the
    ``min_importance`` threshold.
    """
    body = {
        **_context_payload(user_id, agent_id, workspace_id, session_id),
        "promote": promote,
        "min_importance": min_importance,
    }
    return _client(client).commit_session(session_id, body)
def memory_upsert(
    user_id: str,
    agent_id: str,
    content: str,
    workspace_id: str = "",
    namespace: str = "",
    memory_type: str = "fact",
    summary: str = "",
    tags: list[str] | None = None,
    importance: float = 0.5,
    confidence: float = 0.8,
    visibility: str = "private",
    source: str = "agent",
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Create or update a long-term memory record through the gateway client.

    ``content`` is checked by ``validate_memory_write``. A rejected write
    returns ``{"ok": False, "error": "memory_write_rejected", "reason": ...}``
    without contacting the gateway. Empty ``namespace``/``summary`` are sent
    as ``None`` and a missing ``tags`` list is sent as ``[]``. No session id
    is included in the request.
    """
    check = validate_memory_write(content)
    if not check["allowed"]:
        return {"ok": False, "error": "memory_write_rejected", "reason": check["reason"]}

    record = _context_payload(user_id, agent_id, workspace_id)
    record.update(
        namespace=namespace or None,
        memory_type=memory_type,
        content=check["sanitized_content"],
        summary=summary or None,
        tags=tags or [],
        importance=importance,
        confidence=confidence,
        visibility=visibility,
        source=source,
    )
    return _client(client).upsert_memory(record)
def memory_feedback(
    user_id: str,
    agent_id: str,
    memory_id: str,
    feedback: str,
    workspace_id: str = "",
    session_id: str = "",
    comment: str = "",
    client: MemoryGatewayClient | None = None,
) -> dict[str, Any]:
    """Send feedback about a stored memory to the gateway client.

    ``feedback`` may be any alias in ``FEEDBACK_MAP``. The lookup ignores case
    and surrounding whitespace. Values not in the map are forwarded unchanged.

    Returns:
        The result of ``send_feedback``. Returns
        ``{"ok": False, "error": "memory_id_required"}`` or
        ``{"ok": False, "error": "feedback_required"}`` without contacting
        the gateway when the corresponding argument is missing or blank.
    """
    # Guard first: a blank id would otherwise reach send_feedback as-is.
    if not memory_id or not memory_id.strip():
        return {"ok": False, "error": "memory_id_required"}
    if not feedback or not feedback.strip():
        return {"ok": False, "error": "feedback_required"}
    # Callers may send "Confirm" or " delete ", so normalise before the
    # lookup. Unknown values are passed through unchanged, as before.
    mapped_feedback = FEEDBACK_MAP.get(feedback.strip().lower(), feedback)
    payload = _context_payload(user_id, agent_id, workspace_id, session_id)
    payload.update({"feedback": mapped_feedback, "comment": comment or None})
    return _client(client).send_feedback(memory_id, payload)
def default_context(config: PluginConfig | None = None) -> dict[str, str]:
    """Return the default ``user_id``/``agent_id``/``workspace_id`` triple.

    Values are read from ``config`` when one is supplied. Otherwise they come
    from ``load_config()``. An explicit ``is None`` check is used so that a
    supplied config object is honoured even if its truth value is False.
    """
    cfg = config if config is not None else load_config()
    return {
        "user_id": cfg.default_user_id,
        "agent_id": cfg.default_agent_id,
        "workspace_id": cfg.default_workspace_id,
    }