# memory-gateway/plugins/memory-gateway-agent/__init__.py

from __future__ import annotations
from typing import Any
try:
from . import schemas, tools
from .memory_gateway_plugin.config import load_config
from .memory_gateway_plugin import lifecycle
from .memory_gateway_plugin.trace import trace_hook
except ImportError:
import sys
from pathlib import Path
_PLUGIN_ROOT = Path(__file__).resolve().parent
if str(_PLUGIN_ROOT) not in sys.path:
sys.path.insert(0, str(_PLUGIN_ROOT))
import schemas # type: ignore[no-redef]
import tools # type: ignore[no-redef]
from memory_gateway_plugin.config import load_config # type: ignore[no-redef]
from memory_gateway_plugin import lifecycle # type: ignore[no-redef]
from memory_gateway_plugin.trace import trace_hook # type: ignore[no-redef]
# Toolset name under which every Memory Gateway tool is registered by register().
TOOLSET = "memory_gateway"
# Most recent user message seen by pre_llm_call, keyed by session id.
# post_llm_call falls back to this value when its own kwargs carry no user
# message; on_session_end removes the entry for the session it closes.
# This is plain in-process state, so it is lost when the process exits.
_LAST_USER_MESSAGES: dict[str, str] = {}
def _context_from_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Build the lifecycle context dict from a hook's keyword arguments.

    Every field is filled from the first truthy value among its accepted
    kwarg aliases. When none of the aliases is set, the field falls back to
    the matching default from ``load_config()`` or to an empty value.

    Args:
        kwargs: Raw keyword arguments the host passed to the hook.

    Returns:
        A dict with user/agent/workspace/session identifiers, the user message,
        the assistant response, conversation history, model and platform.
    """
    cfg = load_config()

    def pick(aliases: tuple[str, ...], fallback: Any) -> Any:
        # Same semantics as ``kwargs.get(a) or kwargs.get(b) or fallback``.
        for alias in aliases:
            candidate = kwargs.get(alias)
            if candidate:
                return candidate
        return fallback

    return {
        "user_id": pick(("user_id", "user"), cfg.default_user_id),
        "agent_id": pick(("agent_id", "agent"), cfg.default_agent_id or "hermes_agent"),
        "workspace_id": pick(("workspace_id", "workspace"), cfg.default_workspace_id),
        "session_id": pick(("session_id", "task_id"), ""),
        "user_message": pick(("user_message", "prompt", "message"), ""),
        "assistant_response": pick(("assistant_response", "response"), ""),
        "conversation_history": pick(("conversation_history",), []),
        "model": pick(("model",), ""),
        "platform": pick(("platform",), ""),
    }
def on_session_start(**kwargs: Any) -> dict[str, Any]:
    """Acknowledge the start of a session without contacting the gateway.

    The session id is taken from ``session_id``, else ``task_id``, else ``""``.
    It is traced (with ``gateway_called=False``) and echoed back to the host.

    Returns:
        A status dict containing ``status``, ``memory_gateway`` and ``session_id``.
    """
    sid = kwargs.get("session_id") or kwargs.get("task_id") or ""
    trace_hook(
        "on_session_start",
        session_id=str(sid),
        gateway_action="",
        gateway_called=False,
        ok=True,
    )
    ack: dict[str, Any] = {"status": "ok", "memory_gateway": "session_initialized"}
    ack["session_id"] = sid
    return ack
# Upper bound on how many sessions keep a cached last user message. Sessions
# that never reach on_session_end would otherwise stay in the cache for the
# life of the process; the oldest entries are evicted first.
_MAX_REMEMBERED_SESSIONS = 1024


def _remember_user_message(session_id: str, message: str) -> None:
    """Cache *message* as the latest user message of *session_id*, bounded in size.

    The key is removed and re-inserted so this session becomes the newest
    entry (dicts preserve insertion order). The oldest sessions are then
    dropped until the cache fits within ``_MAX_REMEMBERED_SESSIONS``.
    """
    _LAST_USER_MESSAGES.pop(session_id, None)
    _LAST_USER_MESSAGES[session_id] = message
    while len(_LAST_USER_MESSAGES) > _MAX_REMEMBERED_SESSIONS:
        oldest = next(iter(_LAST_USER_MESSAGES))
        _LAST_USER_MESSAGES.pop(oldest, None)


def pre_llm_call(**kwargs: Any) -> dict[str, str]:
    """Fetch Memory Gateway context to inject before an LLM call.

    The user message is cached per session so post_llm_call can recover it,
    and ``lifecycle.on_conversation_start`` is invoked. The outcome is traced
    as a ``memory_search`` action.

    Args:
        **kwargs: Hook arguments from the host; see ``_context_from_kwargs``.

    Returns:
        ``{"context": "Relevant Memory Gateway context:\\n..."}`` when the
        lifecycle call succeeded and returned non-empty ``memory_context``;
        otherwise an empty dict, meaning nothing is injected.
    """
    context = _context_from_kwargs(kwargs)
    session_id = context.get("session_id")
    user_message = context.get("user_message")
    if session_id and user_message:
        # post_llm_call falls back to this when its kwargs omit the prompt.
        _remember_user_message(session_id, user_message)
    result = lifecycle.on_conversation_start(context)
    trace_hook(
        "pre_llm_call",
        # str() keeps the traced id type consistent with on_session_start.
        session_id=str(context.get("session_id", "")),
        gateway_action="memory_search",
        gateway_called=bool(result.get("raw")),
        ok=bool(result.get("ok")),
        reason=str(result.get("error") or result.get("reason") or ""),
    )
    memory_context = result.get("memory_context")
    if not result.get("ok") or not memory_context:
        return {}
    return {"context": "Relevant Memory Gateway context:\n" + memory_context}
def post_llm_call(**kwargs: Any) -> dict[str, Any] | None:
    """Hand the finished turn to ``lifecycle.after_user_message``.

    If the hook kwargs carry no user message but do carry a session id, the
    message cached for that session by pre_llm_call is used instead (or
    ``""`` if none was cached). The outcome is traced as an ``append_episode``
    action.

    Returns:
        ``None`` on success, otherwise ``{"memory_gateway_error": ...}`` holding
        the lifecycle's error, its reason, or ``"append_failed"``.
    """
    context = _context_from_kwargs(kwargs)
    sid = context.get("session_id")
    if sid and not context.get("user_message"):
        context["user_message"] = _LAST_USER_MESSAGES.get(sid, "")

    outcome = lifecycle.after_user_message(context)
    failure = outcome.get("error") or outcome.get("reason")
    trace_hook(
        "post_llm_call",
        session_id=context.get("session_id", ""),
        gateway_action="append_episode",
        gateway_called=bool(outcome.get("raw")),
        ok=bool(outcome.get("ok")),
        reason=str(failure or ""),
    )
    if not outcome.get("ok"):
        return {"memory_gateway_error": failure or "append_failed"}
    return None
def on_session_end(**kwargs: Any) -> dict[str, Any] | None:
    """Commit the session through ``lifecycle.on_session_end`` and drop its cache.

    The outcome is traced as a ``commit_session`` action. The user message
    cached for the session by pre_llm_call is always discarded, even if the
    lifecycle call or tracing raises, so ended sessions never linger in the
    cache.

    Returns:
        ``None`` on success, otherwise ``{"memory_gateway_error": ...}`` holding
        the lifecycle's error, its reason, or ``"commit_failed"``. This is
        consistent with post_llm_call and with the traced reason.
    """
    context = _context_from_kwargs(kwargs)
    session_id = context.get("session_id")
    try:
        result = lifecycle.on_session_end(context)
        trace_hook(
            "on_session_end",
            # str() keeps the traced id type consistent with on_session_start.
            session_id=str(context.get("session_id", "")),
            gateway_action="commit_session",
            gateway_called=bool(result.get("raw")),
            ok=bool(result.get("ok")),
            reason=str(result.get("error") or result.get("reason") or ""),
        )
    finally:
        # Cleanup must not depend on the commit or the trace succeeding.
        if session_id:
            _LAST_USER_MESSAGES.pop(session_id, None)
    if result.get("ok"):
        return None
    return {"memory_gateway_error": result.get("error") or result.get("reason") or "commit_failed"}
def register(ctx: Any) -> None:
    """Register this plugin's tools and lifecycle hooks with the host context.

    Every schema in ``schemas.TOOL_SCHEMAS`` is registered under ``TOOLSET``,
    paired with the handler of the same name from ``tools.HANDLERS``. Hooks
    are registered only when *ctx* exposes ``register_hook``.
    """
    for tool_name, tool_schema in schemas.TOOL_SCHEMAS.items():
        ctx.register_tool(
            name=tool_name,
            toolset=TOOLSET,
            schema=tool_schema,
            handler=tools.HANDLERS[tool_name],
        )

    # Hosts without hook support still get the tools.
    if not hasattr(ctx, "register_hook"):
        return

    hook_table = (
        ("on_session_start", on_session_start),
        ("pre_llm_call", pre_llm_call),
        ("post_llm_call", post_llm_call),
        ("on_session_end", on_session_end),
    )
    for hook_name, callback in hook_table:
        ctx.register_hook(hook_name, callback)