Files
memory-gateway/memory_gateway/services.py
2026-05-05 16:18:31 +08:00

370 lines
16 KiB
Python

"""Application services for the generic Memory Gateway v1 API."""
from __future__ import annotations
from datetime import datetime, timezone
from fastapi import HTTPException, status
from .config import get_config
from .evermemos_client import EverMemOSError, EverMemOSClient
from .namespace import can_access_memory, default_namespace_for_context, user_long_term_namespace, visible_namespaces
from .openviking_client import get_openviking_client
from .repositories import MetadataRepository, repository
from .schemas import (
AccessContext,
AuditLog,
CommitSessionRequest,
CreateUserRequest,
EpisodeAppendRequest,
EpisodeRecord,
MemoryFeedbackRequest,
MemoryPatchRequest,
MemoryRecord,
MemorySearchRequest,
MemoryType,
MemoryUpsertRequest,
NamespaceInfo,
ProfileRecord,
SourceType,
UserRecord,
Visibility,
)
from .workers.evermemos_worker import EverMemOSWorker
class MemoryGatewayService:
    """Application-level operations behind the Memory Gateway v1 API.

    All persistence goes through the injected metadata repository. Most
    mutating operations, denied reads and the OpenViking-backed search also
    write an audit entry through :meth:`_audit`.
    """

    def __init__(self, repo: MetadataRepository = repository, evermemos_client: EverMemOSClient | None = None) -> None:
        """Bind the service to a repository and an optional EverMemOS client.

        Args:
            repo: Storage backend for users, memories, episodes and audit
                logs. Defaults to the module-level ``repository``.
            evermemos_client: Client used for external consolidation and
                health checks. When ``None``, a fresh ``EverMemOSClient()``
                is created at each point of use.
        """
        self.repo = repo
        self.evermemos_client = evermemos_client

    @staticmethod
    def _access_context(request) -> AccessContext:
        """Build an ``AccessContext`` from the identity fields of a request model."""
        return AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))

    def create_user(self, request: CreateUserRequest) -> UserRecord:
        """Create and store a user, then audit the creation.

        The profile namespace is set to ``user/<id>/profile``.
        """
        user = UserRecord(
            # When no id is supplied, reuse whatever id UserRecord assigns by
            # default (obtained from a throwaway instance).
            id=request.user_id or UserRecord(display_name=request.display_name).id,
            display_name=request.display_name,
            preferences=request.preferences,
        )
        user.profile_namespace = f"user/{user.id}/profile"
        self.repo.create_user(user)
        self._audit("create_user", "user", user.id, namespace=user.profile_namespace, actor_user_id=user.id)
        return user

    def get_user(self, user_id: str) -> UserRecord:
        """Return the stored user.

        Raises:
            HTTPException: 404 when the repository has no such user.
        """
        user = self.repo.get_user(user_id)
        if not user:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
        return user

    def search_memory(self, request: MemorySearchRequest) -> dict:
        """Search locally stored memories with a linear scan.

        A memory matches when the caller may access it, it passes the optional
        namespace / memory-type / tag filters, and the lower-cased query is a
        substring of its content, summary or tags. Matches are ranked by
        :meth:`_score`, highest first.

        Returns:
            ``{"results": [...], "total": n}`` where ``results`` holds at most
            ``request.limit`` entries of the form ``{"memory", "score"}`` and
            ``total`` counts all matches before truncation.
        """
        ctx = self._access_context(request)
        query = request.query.lower().strip()
        # Built once instead of once per memory; an empty set disables the filter.
        wanted_tags = set(request.tags or ())
        results = []
        for memory in self.repo.list_memories():
            if not can_access_memory(ctx, memory):
                continue
            if request.namespaces and memory.namespace not in request.namespaces:
                continue
            if request.memory_types and memory.memory_type not in request.memory_types:
                continue
            if wanted_tags and wanted_tags.isdisjoint(memory.tags):
                continue
            haystack = " ".join([memory.content, memory.summary or "", " ".join(memory.tags)]).lower()
            if query and query not in haystack:
                continue
            results.append({"memory": memory, "score": self._score(memory, query)})
        results.sort(key=lambda entry: entry["score"], reverse=True)
        return {"results": results[: request.limit], "total": len(results)}

    async def search_memory_with_openviking(self, request: MemorySearchRequest) -> dict:
        """Search local metadata first, then fan out to OpenViking for visible namespaces.

        OpenViking is queried only when the query is non-blank and at least one
        requested namespace is visible to the caller. Any exception raised
        during the fan-out stops it, keeps the hits gathered so far and is
        recorded as an ``openviking_search_failed`` audit entry; the call
        itself still succeeds.

        Returns:
            A dict with the local results followed by the OpenViking hits
            (``results``), the combined and per-source totals, and the list of
            namespaces that were eligible for the OpenViking search.
        """
        ctx = self._access_context(request)
        local = self.search_memory(request)
        visible = {info.namespace for info in visible_namespaces(ctx)}
        requested = set(request.namespaces) if request.namespaces else visible
        # Never query a namespace the caller cannot see, even if requested.
        allowed_namespaces = sorted(requested.intersection(visible))
        openviking_results = []
        if allowed_namespaces and request.query.strip():
            try:
                ov_client = await get_openviking_client()
                # Per-namespace limit is clamped to the range 1..10.
                per_namespace_limit = max(1, min(request.limit, 10))
                for namespace in allowed_namespaces:
                    response = await ov_client.search(
                        query=request.query,
                        namespace=namespace,
                        limit=per_namespace_limit,
                    )
                    for item in response.results:
                        hit = dict(item)
                        hit["namespace"] = namespace
                        hit["source"] = "openviking"
                        openviking_results.append(hit)
            except Exception as exc:  # noqa: BLE001
                # Best effort: the remote search must not fail the request.
                self._audit(
                    "openviking_search_failed",
                    "search",
                    None,
                    actor_user_id=request.user_id,
                    actor_agent_id=request.agent_id,
                    metadata={"error": str(exc)},
                )
        self._audit(
            "memory_search",
            "memory",
            None,
            actor_user_id=request.user_id,
            actor_agent_id=request.agent_id,
            metadata={"query": request.query, "namespaces": allowed_namespaces, "openviking_results": len(openviking_results)},
        )
        remote_entries = [{"openviking": hit, "score": hit.get("score", 0)} for hit in openviking_results]
        return {
            "results": local["results"] + remote_entries,
            "total": local["total"] + len(openviking_results),
            "local_total": local["total"],
            "openviking_total": len(openviking_results),
            "searched_namespaces": allowed_namespaces,
        }

    def upsert_memory(self, request: MemoryUpsertRequest) -> MemoryRecord:
        """Build a ``MemoryRecord`` from the request, persist it and audit the write.

        When the request names no namespace, the default namespace for the
        caller's context and the requested visibility is used.
        """
        ctx = self._access_context(request)
        namespace = request.namespace or default_namespace_for_context(ctx, request.visibility)
        memory = MemoryRecord(
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            session_id=request.session_id,
            namespace=namespace,
            memory_type=request.memory_type,
            content=request.content,
            summary=request.summary,
            tags=request.tags,
            importance=request.importance,
            confidence=request.confidence,
            visibility=request.visibility,
            source=request.source,
            expires_at=request.expires_at,
        )
        self.repo.upsert_memory(memory)
        self._audit("upsert_memory", "memory", memory.id, namespace=memory.namespace, actor_user_id=request.user_id, actor_agent_id=request.agent_id)
        return memory

    def get_memory(self, memory_id: str, ctx: AccessContext) -> MemoryRecord:
        """Return a memory the caller is allowed to access.

        Raises:
            HTTPException: 404 when the memory does not exist; 403 when
                ``can_access_memory`` rejects the caller (the denial is
                audited with ``decision="deny"``).
        """
        memory = self.repo.get_memory(memory_id)
        if not memory:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Memory not found")
        if not can_access_memory(ctx, memory):
            self._audit("get_memory", "memory", memory_id, namespace=memory.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id, decision="deny")
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Memory access denied")
        return memory

    def patch_memory(self, memory_id: str, ctx: AccessContext, patch: MemoryPatchRequest) -> MemoryRecord:
        """Apply the explicitly set fields of ``patch`` to an accessible memory.

        Also refreshes ``updated_at`` (UTC), increments ``version``, persists
        the record and audits the change. Access errors come from
        :meth:`get_memory`.
        """
        memory = self.get_memory(memory_id, ctx)
        # exclude_unset: fields the client did not send are left untouched.
        for field_name, new_value in patch.model_dump(exclude_unset=True).items():
            setattr(memory, field_name, new_value)
        memory.updated_at = datetime.now(timezone.utc)
        memory.version += 1
        self.repo.upsert_memory(memory)
        self._audit("patch_memory", "memory", memory.id, namespace=memory.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id)
        return memory

    def delete_memory(self, memory_id: str, ctx: AccessContext) -> dict:
        """Delete an accessible memory and audit the deletion.

        Returns:
            ``{"deleted": <repository result>, "id": memory_id}``.
        """
        memory = self.get_memory(memory_id, ctx)
        deleted = self.repo.delete_memory(memory_id)
        self._audit("delete_memory", "memory", memory_id, namespace=memory.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id)
        return {"deleted": deleted, "id": memory_id}

    def append_episode(self, request: EpisodeAppendRequest) -> EpisodeRecord:
        """Store an episode for a session and audit the append.

        A missing session id becomes ``"default"``; a missing namespace becomes
        the caller's default private namespace.
        """
        ctx = self._access_context(request)
        episode = EpisodeRecord(
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            session_id=request.session_id or "default",
            namespace=request.namespace or default_namespace_for_context(ctx, Visibility.PRIVATE),
            content=request.content,
            events=request.events,
            tags=request.tags,
            source=request.source,
            expires_at=request.expires_at,
        )
        self.repo.append_episode(episode)
        self._audit("append_episode", "episode", episode.id, namespace=episode.namespace, actor_user_id=request.user_id, actor_agent_id=request.agent_id)
        return episode

    def commit_session(self, session_id: str, request: CommitSessionRequest) -> dict:
        """Close out a session and optionally promote its episodes to memories.

        With ``request.promote`` set, consolidation runs through the external
        EverMemOS backend when it is enabled in the configuration, otherwise
        through the local worker. If the external call raises
        ``EverMemOSError`` and ``fallback_to_local`` is enabled, the local
        worker is used instead.

        Returns:
            Without a consolidation result: session id, episode count, an empty
            ``promoted`` list and the backend label. Otherwise the backend
            label, the external error (if any) and the consolidation fields.
            The backend label is one of ``"disabled"``, ``"external"``,
            ``"local-fallback"`` or ``"local-disabled"``.

        Raises:
            HTTPException: 502 when EverMemOS fails and local fallback is
                disabled (the failure is audited with ``decision="deny"``).
        """
        episodes = self.repo.list_session_episodes(session_id)
        backend = "disabled"
        error: str | None = None
        result = None
        if request.promote:
            ctx = AccessContext(
                user_id=request.user_id,
                agent_id=request.agent_id,
                workspace_id=request.workspace_id,
                session_id=session_id,
            )
            target_namespace = request.target_namespace or user_long_term_namespace(request.user_id)
            config = get_config().evermemos
            if config.enabled:
                try:
                    # NOTE(review): this forwards every stored memory, not only
                    # those visible to ctx - confirm that is intended.
                    external_result = (self.evermemos_client or EverMemOSClient()).consolidate_session(
                        session_id=session_id,
                        ctx=ctx,
                        episodes=episodes,
                        existing_memories=list(self.repo.list_memories()),
                        min_importance=request.min_importance,
                        target_namespace=target_namespace,
                    )
                    result = self._persist_external_consolidation(external_result, ctx, session_id)
                    backend = "external"
                except EverMemOSError as exc:
                    error = str(exc)
                    if not config.fallback_to_local:
                        self._audit(
                            "evermemos_commit_failed",
                            "session",
                            session_id,
                            actor_user_id=request.user_id,
                            actor_agent_id=request.agent_id,
                            decision="deny",
                            metadata={"error": error},
                        )
                        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"EverMemOS failed: {error}") from exc
                    result = self._commit_session_locally(session_id, ctx, request)
                    backend = "local-fallback"
            else:
                result = self._commit_session_locally(session_id, ctx, request)
                backend = "local-disabled"
        self._audit("commit_session", "session", session_id, actor_user_id=request.user_id, actor_agent_id=request.agent_id)
        # Explicit None check: do not depend on the truthiness of the result object.
        if result is None:
            return {"session_id": session_id, "episodes": len(episodes), "promoted": [], "evermemos_backend": backend}
        return {
            "evermemos_backend": backend,
            "evermemos_error": error,
            "session_id": session_id,
            "episodes": result.episodes,
            "candidates": result.candidates,
            "promoted": result.promoted,
            "duplicates": result.duplicates,
            "conflicts": result.conflicts,
            "review_drafts": result.review_drafts,
        }

    def evermemos_health(self) -> dict:
        """Report EverMemOS health, or ``{"status": "disabled", ...}`` when it is turned off."""
        config = get_config().evermemos
        if not config.enabled:
            return {"status": "disabled", "url": config.url}
        return (self.evermemos_client or EverMemOSClient()).health()

    def _commit_session_locally(self, session_id: str, ctx: AccessContext, request: CommitSessionRequest):
        """Consolidate the session with the in-process ``EverMemOSWorker``."""
        worker = EverMemOSWorker(self.repo)
        return worker.consolidate_session(
            session_id=session_id,
            ctx=ctx,
            min_importance=request.min_importance,
            target_namespace=request.target_namespace or user_long_term_namespace(request.user_id),
        )

    def _persist_external_consolidation(self, external_result: dict, ctx: AccessContext, session_id: str):
        """Turn an external consolidation payload into a ``ConsolidationResult``.

        Promoted items are persisted through the repository; candidate items
        are only collected. Every promoted memory is also listed among the
        candidates unless a candidate with the same id is already present.
        Missing *and* explicitly null list fields are treated as empty.
        """
        from .workers.evermemos_worker import ConsolidationResult

        result = ConsolidationResult(
            session_id=session_id,
            # A missing or zero count falls back to the locally stored episodes.
            episodes=external_result.get("episodes") or len(self.repo.list_session_episodes(session_id)),
            duplicates=external_result.get("duplicates") or [],
            conflicts=external_result.get("conflicts") or [],
            review_drafts=external_result.get("review_drafts") or [],
        )
        for item in external_result.get("candidates") or []:
            memory = self._memory_from_external(item, ctx, session_id)
            if memory:
                result.candidates.append(memory)
        # Track candidate ids in a set so the promoted loop stays linear.
        candidate_ids = {candidate.id for candidate in result.candidates}
        for item in external_result.get("promoted") or []:
            memory = self._memory_from_external(item, ctx, session_id)
            if not memory:
                continue
            self.repo.upsert_memory(memory)
            result.promoted.append(memory)
            if memory.id not in candidate_ids:
                candidate_ids.add(memory.id)
                result.candidates.append(memory)
        return result

    def _memory_from_external(self, item: dict, ctx: AccessContext, session_id: str) -> MemoryRecord | None:
        """Validate one external item into a ``MemoryRecord``.

        Fields absent from the item are filled from the caller's context and
        fixed defaults. Content falls back to ``text`` and then ``summary``
        when it is missing or null.

        Returns:
            The validated record, or ``None`` when ``item`` is not a dict or
            ends up with empty content.
        """
        if not isinstance(item, dict):
            return None
        data = dict(item)
        data.setdefault("user_id", ctx.user_id)
        data.setdefault("agent_id", ctx.agent_id)
        data.setdefault("workspace_id", ctx.workspace_id)
        data.setdefault("session_id", session_id)
        data.setdefault("namespace", default_namespace_for_context(ctx, Visibility.PRIVATE))
        data.setdefault("memory_type", MemoryType.SUMMARY.value)
        content = data.get("content")
        if content is None:
            # Covers a missing key and an explicit null alike; a plain
            # setdefault would keep the null and break the slice below.
            content = data.get("text") or data.get("summary") or ""
        if not content:
            return None
        data["content"] = content
        if "summary" not in data:
            data["summary"] = content[:180]
        data.setdefault("tags", ["evermemos-external"])
        data.setdefault("importance", 0.7)
        data.setdefault("confidence", 0.65)
        data.setdefault("visibility", Visibility.PRIVATE.value)
        data.setdefault("source", SourceType.EVERMEMOS.value)
        return MemoryRecord.model_validate(data)

    def get_profile(self, user_id: str) -> ProfileRecord:
        """Return the stored profile.

        Raises:
            HTTPException: 404 when the repository has no profile for the user.
        """
        profile = self.repo.get_profile(user_id)
        if not profile:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Profile not found")
        return profile

    def add_feedback(self, memory_id: str, request: MemoryFeedbackRequest) -> dict:
        """Record feedback on an accessible memory as an audit entry.

        The feedback is written only to the audit log (action
        ``feedback:<value>``); the memory itself is not modified here. Access
        errors come from :meth:`get_memory`.
        """
        ctx = self._access_context(request)
        memory = self.get_memory(memory_id, ctx)
        self._audit(
            f"feedback:{request.feedback}",
            "memory",
            memory.id,
            namespace=memory.namespace,
            actor_user_id=request.user_id,
            actor_agent_id=request.agent_id,
            metadata={"comment": request.comment},
        )
        return {"status": "ok", "memory_id": memory_id, "feedback": request.feedback}

    def list_namespaces(self, ctx: AccessContext) -> list[NamespaceInfo]:
        """Return the namespaces visible to the given context."""
        return visible_namespaces(ctx)

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return audit entries from the repository, passing ``limit`` through."""
        return self.repo.list_audit(limit)

    def _score(self, memory: MemoryRecord, query: str) -> float:
        """Rank a memory: lexical component plus importance plus confidence.

        The lexical component is 1.0 when the non-empty query occurs in the
        lower-cased content, otherwise 0.2.
        """
        lexical = 1.0 if query and query in memory.content.lower() else 0.2
        return lexical + memory.importance + memory.confidence

    def _audit(
        self,
        action: str,
        target_type: str,
        target_id: str | None,
        namespace: str | None = None,
        actor_user_id: str | None = None,
        actor_agent_id: str | None = None,
        decision: str = "allow",
        metadata: dict | None = None,
    ) -> None:
        """Append one ``AuditLog`` entry to the repository.

        ``metadata`` defaults to an empty dict when omitted.
        """
        self.repo.add_audit(
            AuditLog(
                actor_user_id=actor_user_id,
                actor_agent_id=actor_agent_id,
                action=action,
                target_type=target_type,
                target_id=target_id,
                namespace=namespace,
                decision=decision,  # type: ignore[arg-type]
                metadata=metadata or {},
            )
        )
# Module-level service instance built with the default repository and no
# explicit EverMemOS client.
service = MemoryGatewayService()