"""Memory Gateway plugin entry point.

``register`` installs the Memory Gateway tools and, when the host context
supports it, the lifecycle hooks defined here. Each hook normalizes its keyword
arguments into a context dict, delegates to ``memory_gateway_plugin.lifecycle``
and records a ``trace_hook`` entry.
"""

from __future__ import annotations

from typing import Any

try:
    from . import schemas, tools
    from .memory_gateway_plugin.config import load_config
    from .memory_gateway_plugin import lifecycle
    from .memory_gateway_plugin.trace import trace_hook
except ImportError:
    # Typically reached when this file is loaded as a top-level module, so the
    # relative imports have no parent package: put this file's directory on
    # sys.path and import the sibling modules absolutely instead.
    import sys
    from pathlib import Path

    _PLUGIN_ROOT = Path(__file__).resolve().parent
    if str(_PLUGIN_ROOT) not in sys.path:
        sys.path.insert(0, str(_PLUGIN_ROOT))
    import schemas  # type: ignore[no-redef]
    import tools  # type: ignore[no-redef]
    from memory_gateway_plugin.config import load_config  # type: ignore[no-redef]
    from memory_gateway_plugin import lifecycle  # type: ignore[no-redef]
    from memory_gateway_plugin.trace import trace_hook  # type: ignore[no-redef]

TOOLSET = "memory_gateway"

# Maximum number of sessions whose last user message is cached at once. Entries
# are normally removed by on_session_end; this cap keeps the cache bounded when
# that hook never runs for a session.
_MAX_TRACKED_SESSIONS = 1024

# session_id -> most recent user message seen by pre_llm_call. post_llm_call
# falls back to it when its own kwargs carry no user message. Dict insertion
# order doubles as recency order for eviction (oldest first).
_LAST_USER_MESSAGES: dict[str, str] = {}


def _context_from_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Normalize hook keyword arguments into the context dict passed to ``lifecycle``.

    Each field takes the first truthy value among its accepted aliases; identity
    fields then fall back to the configured defaults, everything else to an
    empty value.
    """
    cfg = load_config()
    return {
        "user_id": kwargs.get("user_id") or kwargs.get("user") or cfg.default_user_id,
        "agent_id": (
            kwargs.get("agent_id")
            or kwargs.get("agent")
            or cfg.default_agent_id
            or "hermes_agent"
        ),
        "workspace_id": (
            kwargs.get("workspace_id") or kwargs.get("workspace") or cfg.default_workspace_id
        ),
        "session_id": kwargs.get("session_id") or kwargs.get("task_id") or "",
        "user_message": (
            kwargs.get("user_message") or kwargs.get("prompt") or kwargs.get("message") or ""
        ),
        "assistant_response": kwargs.get("assistant_response") or kwargs.get("response") or "",
        "conversation_history": kwargs.get("conversation_history") or [],
        "model": kwargs.get("model") or "",
        "platform": kwargs.get("platform") or "",
    }


def _result_reason(result: dict[str, Any]) -> str:
    """Return a lifecycle result's "error", else its "reason", as a string ("" if neither)."""
    return str(result.get("error") or result.get("reason") or "")


def _trace_result(
    hook: str, context: dict[str, Any], action: str, result: dict[str, Any]
) -> None:
    """Record a ``trace_hook`` entry for the lifecycle call *hook* made as *action*."""
    trace_hook(
        hook,
        session_id=str(context.get("session_id", "")),
        gateway_action=action,
        # Reported as "called" only when the result carries a non-empty "raw" payload.
        gateway_called=bool(result.get("raw")),
        ok=bool(result.get("ok")),
        reason=_result_reason(result),
    )


def _error_payload(result: dict[str, Any], fallback: str) -> dict[str, Any] | None:
    """Return ``None`` for a successful result, else ``{"memory_gateway_error": ...}``.

    The error value is the result's "error", then its "reason", then *fallback*.
    """
    if result.get("ok"):
        return None
    return {"memory_gateway_error": result.get("error") or result.get("reason") or fallback}


def _remember_user_message(session_id: str, message: str) -> None:
    """Cache *message* as the latest user turn for *session_id*, evicting the oldest entries."""
    # Re-insert so this session becomes the most recent in insertion order.
    _LAST_USER_MESSAGES.pop(session_id, None)
    _LAST_USER_MESSAGES[session_id] = message
    while len(_LAST_USER_MESSAGES) > _MAX_TRACKED_SESSIONS:
        del _LAST_USER_MESSAGES[next(iter(_LAST_USER_MESSAGES))]


def on_session_start(**kwargs: Any) -> dict[str, Any]:
    """Acknowledge a new session: records a trace entry but makes no gateway call."""
    session_id = kwargs.get("session_id") or kwargs.get("task_id") or ""
    trace_hook(
        "on_session_start",
        session_id=str(session_id),
        gateway_action="",
        gateway_called=False,
        ok=True,
    )
    return {
        "status": "ok",
        "memory_gateway": "session_initialized",
        "session_id": session_id,
    }


def pre_llm_call(**kwargs: Any) -> dict[str, str]:
    """Look up memory context via ``lifecycle.on_conversation_start`` before the model call.

    Returns:
        ``{"context": ...}`` holding the gateway's memory context under a header
        line, or ``{}`` when the lookup failed or returned no context.
    """
    context = _context_from_kwargs(kwargs)
    if context.get("session_id") and context.get("user_message"):
        # Cache the prompt so post_llm_call can use it if its own kwargs lack it.
        _remember_user_message(context["session_id"], context["user_message"])
    result = lifecycle.on_conversation_start(context)
    _trace_result("pre_llm_call", context, "memory_search", result)
    if not result.get("ok") or not result.get("memory_context"):
        return {}
    return {"context": "Relevant Memory Gateway context:\n" + result["memory_context"]}


def post_llm_call(**kwargs: Any) -> dict[str, Any] | None:
    """Hand the finished turn to ``lifecycle.after_user_message`` (traced as "append_episode").

    Returns:
        ``None`` on success, otherwise ``{"memory_gateway_error": ...}``.
    """
    context = _context_from_kwargs(kwargs)
    if not context.get("user_message") and context.get("session_id"):
        # Fall back to the prompt cached by pre_llm_call for this session.
        context["user_message"] = _LAST_USER_MESSAGES.get(context["session_id"], "")
    result = lifecycle.after_user_message(context)
    _trace_result("post_llm_call", context, "append_episode", result)
    return _error_payload(result, "append_failed")


def on_session_end(**kwargs: Any) -> dict[str, Any] | None:
    """Commit the session via ``lifecycle.on_session_end`` and drop its cached prompt.

    Returns:
        ``None`` on success, otherwise ``{"memory_gateway_error": ...}`` using the
        result's "error", then "reason", then ``"commit_failed"``.
    """
    context = _context_from_kwargs(kwargs)
    try:
        result = lifecycle.on_session_end(context)
        _trace_result("on_session_end", context, "commit_session", result)
    finally:
        # Release the cached prompt even if the commit or the trace raises.
        if context.get("session_id"):
            _LAST_USER_MESSAGES.pop(context["session_id"], None)
    return _error_payload(result, "commit_failed")


def register(ctx: Any) -> None:
    """Register every Memory Gateway tool on *ctx*, plus the lifecycle hooks if supported."""
    for name, schema in schemas.TOOL_SCHEMAS.items():
        ctx.register_tool(
            name=name,
            toolset=TOOLSET,
            schema=schema,
            handler=tools.HANDLERS[name],
        )
    # Hook registration is optional: contexts without register_hook get tools only.
    if not hasattr(ctx, "register_hook"):
        return
    hooks = (
        ("on_session_start", on_session_start),
        ("pre_llm_call", pre_llm_call),
        ("post_llm_call", post_llm_call),
        ("on_session_end", on_session_end),
    )
    for hook_name, callback in hooks:
        ctx.register_hook(hook_name, callback)