Add Memory Gateway agent plugin
This commit is contained in:
120
plugins/memory-gateway-agent/__init__.py
Normal file
120
plugins/memory-gateway-agent/__init__.py
Normal file
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import schemas, tools
|
||||
from .memory_gateway_plugin.config import load_config
|
||||
from .memory_gateway_plugin import lifecycle
|
||||
from .memory_gateway_plugin.trace import trace_hook
|
||||
except ImportError:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_PLUGIN_ROOT = Path(__file__).resolve().parent
|
||||
if str(_PLUGIN_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PLUGIN_ROOT))
|
||||
|
||||
import schemas # type: ignore[no-redef]
|
||||
import tools # type: ignore[no-redef]
|
||||
from memory_gateway_plugin.config import load_config # type: ignore[no-redef]
|
||||
from memory_gateway_plugin import lifecycle # type: ignore[no-redef]
|
||||
from memory_gateway_plugin.trace import trace_hook # type: ignore[no-redef]
|
||||
|
||||
TOOLSET = "memory_gateway"
|
||||
_LAST_USER_MESSAGES: dict[str, str] = {}
|
||||
|
||||
|
||||
def _context_from_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
cfg = load_config()
|
||||
return {
|
||||
"user_id": kwargs.get("user_id") or kwargs.get("user") or cfg.default_user_id,
|
||||
"agent_id": kwargs.get("agent_id") or kwargs.get("agent") or cfg.default_agent_id or "hermes_agent",
|
||||
"workspace_id": kwargs.get("workspace_id") or kwargs.get("workspace") or cfg.default_workspace_id,
|
||||
"session_id": kwargs.get("session_id") or kwargs.get("task_id") or "",
|
||||
"user_message": kwargs.get("user_message") or kwargs.get("prompt") or kwargs.get("message") or "",
|
||||
"assistant_response": kwargs.get("assistant_response") or kwargs.get("response") or "",
|
||||
"conversation_history": kwargs.get("conversation_history") or [],
|
||||
"model": kwargs.get("model") or "",
|
||||
"platform": kwargs.get("platform") or "",
|
||||
}
|
||||
|
||||
|
||||
def on_session_start(**kwargs: Any) -> dict[str, Any]:
|
||||
session_id = kwargs.get("session_id") or kwargs.get("task_id") or ""
|
||||
trace_hook("on_session_start", session_id=str(session_id), gateway_action="", gateway_called=False, ok=True)
|
||||
return {
|
||||
"status": "ok",
|
||||
"memory_gateway": "session_initialized",
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
|
||||
def pre_llm_call(**kwargs: Any) -> dict[str, str]:
|
||||
context = _context_from_kwargs(kwargs)
|
||||
if context.get("session_id") and context.get("user_message"):
|
||||
_LAST_USER_MESSAGES[context["session_id"]] = context["user_message"]
|
||||
result = lifecycle.on_conversation_start(context)
|
||||
trace_hook(
|
||||
"pre_llm_call",
|
||||
session_id=context.get("session_id", ""),
|
||||
gateway_action="memory_search",
|
||||
gateway_called=bool(result.get("raw")),
|
||||
ok=bool(result.get("ok")),
|
||||
reason=str(result.get("error") or result.get("reason") or ""),
|
||||
)
|
||||
if not result.get("ok") or not result.get("memory_context"):
|
||||
return {}
|
||||
return {"context": "Relevant Memory Gateway context:\n" + result["memory_context"]}
|
||||
|
||||
|
||||
def post_llm_call(**kwargs: Any) -> dict[str, Any] | None:
|
||||
context = _context_from_kwargs(kwargs)
|
||||
if not context.get("user_message") and context.get("session_id"):
|
||||
context["user_message"] = _LAST_USER_MESSAGES.get(context["session_id"], "")
|
||||
result = lifecycle.after_user_message(context)
|
||||
trace_hook(
|
||||
"post_llm_call",
|
||||
session_id=context.get("session_id", ""),
|
||||
gateway_action="append_episode",
|
||||
gateway_called=bool(result.get("raw")),
|
||||
ok=bool(result.get("ok")),
|
||||
reason=str(result.get("error") or result.get("reason") or ""),
|
||||
)
|
||||
if result.get("ok"):
|
||||
return None
|
||||
return {"memory_gateway_error": result.get("error") or result.get("reason") or "append_failed"}
|
||||
|
||||
|
||||
def on_session_end(**kwargs: Any) -> dict[str, Any] | None:
|
||||
context = _context_from_kwargs(kwargs)
|
||||
result = lifecycle.on_session_end(context)
|
||||
trace_hook(
|
||||
"on_session_end",
|
||||
session_id=context.get("session_id", ""),
|
||||
gateway_action="commit_session",
|
||||
gateway_called=bool(result.get("raw")),
|
||||
ok=bool(result.get("ok")),
|
||||
reason=str(result.get("error") or result.get("reason") or ""),
|
||||
)
|
||||
if context.get("session_id"):
|
||||
_LAST_USER_MESSAGES.pop(context["session_id"], None)
|
||||
if result.get("ok"):
|
||||
return None
|
||||
return {"memory_gateway_error": result.get("error") or "commit_failed"}
|
||||
|
||||
|
||||
def register(ctx: Any) -> None:
|
||||
for name, schema in schemas.TOOL_SCHEMAS.items():
|
||||
ctx.register_tool(
|
||||
name=name,
|
||||
toolset=TOOLSET,
|
||||
schema=schema,
|
||||
handler=tools.HANDLERS[name],
|
||||
)
|
||||
|
||||
if hasattr(ctx, "register_hook"):
|
||||
ctx.register_hook("on_session_start", on_session_start)
|
||||
ctx.register_hook("pre_llm_call", pre_llm_call)
|
||||
ctx.register_hook("post_llm_call", post_llm_call)
|
||||
ctx.register_hook("on_session_end", on_session_end)
|
||||
Reference in New Issue
Block a user