121 lines
4.7 KiB
Python
121 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import schemas, tools
|
|
from .memory_gateway_plugin.config import load_config
|
|
from .memory_gateway_plugin import lifecycle
|
|
from .memory_gateway_plugin.trace import trace_hook
|
|
except ImportError:
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
_PLUGIN_ROOT = Path(__file__).resolve().parent
|
|
if str(_PLUGIN_ROOT) not in sys.path:
|
|
sys.path.insert(0, str(_PLUGIN_ROOT))
|
|
|
|
import schemas # type: ignore[no-redef]
|
|
import tools # type: ignore[no-redef]
|
|
from memory_gateway_plugin.config import load_config # type: ignore[no-redef]
|
|
from memory_gateway_plugin import lifecycle # type: ignore[no-redef]
|
|
from memory_gateway_plugin.trace import trace_hook # type: ignore[no-redef]
|
|
|
|
# Toolset name under which every Memory Gateway tool is registered.
TOOLSET = "memory_gateway"

# Per-session cache of the latest user message seen before an LLM call,
# keyed by session id. It lets the post-call hook recover the message when
# the host does not pass it again; entries are removed when a session ends.
_LAST_USER_MESSAGES: dict[str, str] = dict()
|
|
|
|
|
|
def _context_from_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Normalise raw hook keyword arguments into a lifecycle context dict.

    Each field takes the first truthy value among its accepted aliases
    (e.g. ``user_id`` then ``user``). When no alias is truthy, the field falls
    back to the matching default from ``load_config()`` (identity fields) or
    to an empty value. ``agent_id`` additionally falls back to
    ``"hermes_agent"`` when the config default is also falsy.

    Args:
        kwargs: Keyword arguments the host passed to a hook.

    Returns:
        A dict with the keys ``user_id``, ``agent_id``, ``workspace_id``,
        ``session_id``, ``user_message``, ``assistant_response``,
        ``conversation_history``, ``model`` and ``platform``, in that order.
    """
    settings = load_config()

    def first_truthy(*aliases: str) -> Any:
        # Equivalent to ``kwargs.get(a) or kwargs.get(b) ...``: the first
        # truthy alias wins; otherwise None, so the caller's ``or <fallback>``
        # supplies the default (and config attributes are read only if needed).
        for alias in aliases:
            candidate = kwargs.get(alias)
            if candidate:
                return candidate
        return None

    return {
        "user_id": first_truthy("user_id", "user") or settings.default_user_id,
        "agent_id": (
            first_truthy("agent_id", "agent")
            or settings.default_agent_id
            or "hermes_agent"
        ),
        "workspace_id": (
            first_truthy("workspace_id", "workspace") or settings.default_workspace_id
        ),
        "session_id": first_truthy("session_id", "task_id") or "",
        "user_message": first_truthy("user_message", "prompt", "message") or "",
        "assistant_response": first_truthy("assistant_response", "response") or "",
        "conversation_history": first_truthy("conversation_history") or [],
        "model": first_truthy("model") or "",
        "platform": first_truthy("platform") or "",
    }
|
|
|
|
|
|
def on_session_start(**kwargs: Any) -> dict[str, Any]:
    """Hook run when the host starts a session.

    Records an ``on_session_start`` trace entry (no gateway call is made
    here) and acknowledges the session.

    Args:
        **kwargs: Hook keyword arguments; ``session_id`` (falling back to
            ``task_id``) identifies the session.

    Returns:
        A status dict echoing the resolved session id (``""`` when absent).
    """
    sid = kwargs.get("session_id") or kwargs.get("task_id") or ""
    trace_hook(
        "on_session_start",
        session_id=str(sid),
        gateway_action="",
        gateway_called=False,
        ok=True,
    )
    return dict(
        status="ok",
        memory_gateway="session_initialized",
        session_id=sid,
    )
|
|
|
|
|
|
def pre_llm_call(**kwargs: Any) -> dict[str, str]:
    """Hook run before each LLM call: fetch relevant memory to inject as context.

    Caches the current user message per session so ``post_llm_call`` can
    recover it, asks ``lifecycle.on_conversation_start`` for memory, and
    records a ``memory_search`` trace entry.

    Returns:
        ``{"context": ...}`` (a header line followed by the lifecycle's
        ``memory_context``) when the lookup succeeded and produced context;
        otherwise an empty dict.
    """
    ctx = _context_from_kwargs(kwargs)
    sid = ctx.get("session_id")
    message = ctx.get("user_message")
    if sid and message:
        _LAST_USER_MESSAGES[sid] = message

    outcome = lifecycle.on_conversation_start(ctx)
    trace_hook(
        "pre_llm_call",
        session_id=ctx.get("session_id", ""),
        gateway_action="memory_search",
        gateway_called=bool(outcome.get("raw")),
        ok=bool(outcome.get("ok")),
        reason=str(outcome.get("error") or outcome.get("reason") or ""),
    )

    # Only inject context when the lookup succeeded AND returned something.
    if outcome.get("ok") and outcome.get("memory_context"):
        return {"context": "Relevant Memory Gateway context:\n" + outcome["memory_context"]}
    return {}
|
|
|
|
|
|
def post_llm_call(**kwargs: Any) -> dict[str, Any] | None:
    """Hook run after each LLM call: append the exchange as a memory episode.

    If the host did not pass the user message again, the one cached by
    ``pre_llm_call`` for the same session is used. Records an
    ``append_episode`` trace entry.

    Returns:
        None on success; otherwise ``{"memory_gateway_error": ...}`` holding
        the lifecycle's ``error``, else its ``reason``, else
        ``"append_failed"``.
    """
    ctx = _context_from_kwargs(kwargs)
    sid = ctx.get("session_id")
    if sid and not ctx.get("user_message"):
        ctx["user_message"] = _LAST_USER_MESSAGES.get(sid, "")

    outcome = lifecycle.after_user_message(ctx)
    failure = outcome.get("error") or outcome.get("reason")
    trace_hook(
        "post_llm_call",
        session_id=ctx.get("session_id", ""),
        gateway_action="append_episode",
        gateway_called=bool(outcome.get("raw")),
        ok=bool(outcome.get("ok")),
        reason=str(failure or ""),
    )

    if outcome.get("ok"):
        return None
    return {"memory_gateway_error": failure or "append_failed"}
|
|
|
|
|
|
def on_session_end(**kwargs: Any) -> dict[str, Any] | None:
    """Hook run when a session ends: commit the session via the lifecycle.

    Calls ``lifecycle.on_session_end``, records a ``commit_session`` trace
    entry, and drops the session's cached user message from
    ``_LAST_USER_MESSAGES``. The cache entry is dropped even if the commit
    or the trace call raises, so finished sessions never linger in the cache.

    Returns:
        None on success; otherwise ``{"memory_gateway_error": ...}`` holding
        the lifecycle's ``error``, else its ``reason``, else
        ``"commit_failed"``. This is the same precedence the trace entry and
        ``post_llm_call`` use.
    """
    context = _context_from_kwargs(kwargs)
    try:
        result = lifecycle.on_session_end(context)
        failure = result.get("error") or result.get("reason")
        trace_hook(
            "on_session_end",
            session_id=context.get("session_id", ""),
            gateway_action="commit_session",
            gateway_called=bool(result.get("raw")),
            ok=bool(result.get("ok")),
            reason=str(failure or ""),
        )
    finally:
        # Always release the per-session cache entry, even on exceptions,
        # otherwise a failing commit would leak it for the process lifetime.
        session_id = context.get("session_id")
        if session_id:
            _LAST_USER_MESSAGES.pop(session_id, None)

    if result.get("ok"):
        return None
    # Previously only ``error`` was surfaced, so a failure reported via
    # ``reason`` was hidden behind the generic "commit_failed".
    return {"memory_gateway_error": failure or "commit_failed"}
|
|
|
|
|
|
def register(ctx: Any) -> None:
    """Register the Memory Gateway tools and lifecycle hooks with the host.

    Every schema in ``schemas.TOOL_SCHEMAS`` is registered under the
    ``TOOLSET`` toolset, paired with the handler looked up by the same name
    in ``tools.HANDLERS``. Lifecycle hooks are registered only when the host
    context exposes ``register_hook``.

    Args:
        ctx: Host registration context providing ``register_tool`` and,
            optionally, ``register_hook``.
    """
    for tool_name, tool_schema in schemas.TOOL_SCHEMAS.items():
        ctx.register_tool(
            name=tool_name,
            toolset=TOOLSET,
            schema=tool_schema,
            handler=tools.HANDLERS[tool_name],
        )

    # Hosts without hook support get tools only.
    if not hasattr(ctx, "register_hook"):
        return

    hook_table = (
        ("on_session_start", on_session_start),
        ("pre_llm_call", pre_llm_call),
        ("post_llm_call", post_llm_call),
        ("on_session_end", on_session_end),
    )
    for hook_name, callback in hook_table:
        ctx.register_hook(hook_name, callback)
|