370 lines
16 KiB
Python
370 lines
16 KiB
Python
"""Application services for the generic Memory Gateway v1 API."""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
from .config import get_config
|
|
from .evermemos_client import EverMemOSError, EverMemOSClient
|
|
from .namespace import can_access_memory, default_namespace_for_context, user_long_term_namespace, visible_namespaces
|
|
from .openviking_client import get_openviking_client
|
|
from .repositories import MetadataRepository, repository
|
|
from .schemas import (
|
|
AccessContext,
|
|
AuditLog,
|
|
CommitSessionRequest,
|
|
CreateUserRequest,
|
|
EpisodeAppendRequest,
|
|
EpisodeRecord,
|
|
MemoryFeedbackRequest,
|
|
MemoryPatchRequest,
|
|
MemoryRecord,
|
|
MemorySearchRequest,
|
|
MemoryType,
|
|
MemoryUpsertRequest,
|
|
NamespaceInfo,
|
|
ProfileRecord,
|
|
SourceType,
|
|
UserRecord,
|
|
Visibility,
|
|
)
|
|
from .workers.evermemos_worker import EverMemOSWorker
|
|
|
|
|
|
class MemoryGatewayService:
|
|
def __init__(self, repo: MetadataRepository = repository, evermemos_client: EverMemOSClient | None = None) -> None:
    """Remember the collaborators this service delegates to.

    Args:
        repo: Metadata repository used for every read and write; defaults
            to the ``repository`` object imported from ``.repositories``.
        evermemos_client: Optional EverMemOS client. When it stays ``None``
            the methods that need one build ``EverMemOSClient()`` on demand.
    """
    self.evermemos_client, self.repo = evermemos_client, repo
|
|
|
|
def create_user(self, request: CreateUserRequest) -> UserRecord:
    """Build, store and audit a new user.

    If the request carries no ``user_id``, a throwaway ``UserRecord`` is
    constructed only to borrow the ``id`` it comes with.

    Args:
        request: Carries ``user_id`` (optional), ``display_name`` and
            ``preferences``.

    Returns:
        The stored ``UserRecord`` with ``profile_namespace`` set to
        ``user/<id>/profile``.
    """
    chosen_id = request.user_id or UserRecord(display_name=request.display_name).id
    record = UserRecord(id=chosen_id, display_name=request.display_name, preferences=request.preferences)
    record.profile_namespace = f"user/{record.id}/profile"

    self.repo.create_user(record)
    # The freshly created user is recorded as the actor of its own creation.
    self._audit("create_user", "user", record.id, namespace=record.profile_namespace, actor_user_id=record.id)
    return record
|
|
|
|
def get_user(self, user_id: str) -> UserRecord:
    """Return the stored user identified by ``user_id``.

    Raises:
        HTTPException: 404 ("User not found") when the repository returns
            ``None`` or any other falsy value.
    """
    found = self.repo.get_user(user_id)
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
|
|
|
def search_memory(self, request: MemorySearchRequest) -> dict:
    """Filter and rank the memories held in the local repository.

    A memory is kept when the caller may access it, it passes the optional
    namespace / memory-type / tag filters, and the lower-cased, stripped
    query is a substring of its content, summary and tags (an empty query
    matches everything).

    Returns:
        ``{"results": [...], "total": n}`` where ``results`` holds
        ``{"memory", "score"}`` entries sorted by descending score and cut
        to ``request.limit``, and ``total`` counts matches before the cut.
    """
    ctx = AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))
    needle = request.query.lower().strip()

    def is_match(memory) -> bool:
        # Access control is evaluated first, before any request filter.
        if not can_access_memory(ctx, memory):
            return False
        if request.namespaces and memory.namespace not in request.namespaces:
            return False
        if request.memory_types and memory.memory_type not in request.memory_types:
            return False
        if request.tags and not set(request.tags).intersection(memory.tags):
            return False
        searchable = " ".join([memory.content, memory.summary or "", " ".join(memory.tags)]).lower()
        return not needle or needle in searchable

    ranked = sorted(
        (
            {"memory": memory, "score": self._score(memory, needle)}
            for memory in self.repo.list_memories()
            if is_match(memory)
        ),
        key=lambda hit: hit["score"],
        reverse=True,
    )
    return {"results": ranked[: request.limit], "total": len(ranked)}
|
|
|
|
async def search_memory_with_openviking(self, request: MemorySearchRequest) -> dict:
    """Search local metadata first, then fan out to OpenViking for visible namespaces.

    OpenViking is queried only for namespaces that are both visible to the
    caller and (when the request names any) requested, and only when the
    query is not blank. Any exception raised while talking to OpenViking is
    audited as ``openviking_search_failed`` and otherwise ignored; hits
    collected before the failure are kept.

    Returns:
        A dict with the combined ``results`` (local hits followed by
        ``{"openviking", "score"}`` entries), ``total``, ``local_total``,
        ``openviking_total`` and ``searched_namespaces``.
    """
    ctx = AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))
    local = self.search_memory(request)

    visible = {entry.namespace for entry in visible_namespaces(ctx)}
    wanted = set(request.namespaces) if request.namespaces else visible
    allowed_namespaces = sorted(wanted & visible)

    remote_hits = []
    if allowed_namespaces and request.query.strip():
        try:
            client = await get_openviking_client()
            # Each namespace is asked for at least 1 and at most 10 hits.
            per_namespace_limit = max(1, min(request.limit, 10))
            for namespace in allowed_namespaces:
                response = await client.search(query=request.query, namespace=namespace, limit=per_namespace_limit)
                for hit in response.results:
                    # Copy the hit and tag it with where it came from.
                    remote_hits.append(dict(hit, namespace=namespace, source="openviking"))
        except Exception as exc:  # noqa: BLE001
            self._audit(
                "openviking_search_failed",
                "search",
                None,
                actor_user_id=request.user_id,
                actor_agent_id=request.agent_id,
                metadata={"error": str(exc)},
            )

    remote_total = len(remote_hits)
    self._audit(
        "memory_search",
        "memory",
        None,
        actor_user_id=request.user_id,
        actor_agent_id=request.agent_id,
        metadata={"query": request.query, "namespaces": allowed_namespaces, "openviking_results": remote_total},
    )

    wrapped_remote = [{"openviking": hit, "score": hit.get("score", 0)} for hit in remote_hits]
    return {
        "results": local["results"] + wrapped_remote,
        "total": local["total"] + remote_total,
        "local_total": local["total"],
        "openviking_total": remote_total,
        "searched_namespaces": allowed_namespaces,
    }
|
|
|
|
def upsert_memory(self, request: MemoryUpsertRequest) -> MemoryRecord:
    """Build a ``MemoryRecord`` from the request, store it and audit the write.

    The namespace is taken from the request when given; otherwise it is
    derived from the caller's context and the requested visibility.

    Returns:
        The ``MemoryRecord`` handed to the repository.
    """
    ctx = AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))
    namespace = request.namespace or default_namespace_for_context(ctx, request.visibility)

    # Fields copied verbatim from the request onto the record.
    copied_fields = (
        "user_id",
        "agent_id",
        "workspace_id",
        "session_id",
        "memory_type",
        "content",
        "summary",
        "tags",
        "importance",
        "confidence",
        "visibility",
        "source",
        "expires_at",
    )
    record = MemoryRecord(namespace=namespace, **{name: getattr(request, name) for name in copied_fields})

    self.repo.upsert_memory(record)
    self._audit("upsert_memory", "memory", record.id, namespace=record.namespace, actor_user_id=request.user_id, actor_agent_id=request.agent_id)
    return record
|
|
|
|
def get_memory(self, memory_id: str, ctx: AccessContext) -> MemoryRecord:
    """Fetch one memory and enforce access control for the caller.

    Raises:
        HTTPException: 404 when the repository has no such memory; 403 when
            ``can_access_memory`` rejects the caller. The 403 case is
            audited with ``decision="deny"`` before raising.
    """
    memory = self.repo.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Memory not found")

    if can_access_memory(ctx, memory):
        return memory

    self._audit("get_memory", "memory", memory_id, namespace=memory.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id, decision="deny")
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Memory access denied")
|
|
|
|
def patch_memory(self, memory_id: str, ctx: AccessContext, patch: MemoryPatchRequest) -> MemoryRecord:
    """Apply a partial update to a memory the caller can access.

    Only fields explicitly set on ``patch`` are written. ``updated_at`` is
    refreshed to the current UTC time and ``version`` is incremented before
    the record is stored and the change audited.

    Raises:
        HTTPException: Propagated from ``get_memory``.
    """
    target = self.get_memory(memory_id, ctx)

    for field_name, new_value in patch.model_dump(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    target.updated_at = datetime.now(timezone.utc)
    target.version += 1

    self.repo.upsert_memory(target)
    self._audit("patch_memory", "memory", target.id, namespace=target.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id)
    return target
|
|
|
|
def delete_memory(self, memory_id: str, ctx: AccessContext) -> dict:
    """Delete a memory the caller can access and audit the deletion.

    Returns:
        ``{"deleted": <repository result>, "id": memory_id}``.

    Raises:
        HTTPException: Propagated from ``get_memory``.
    """
    # get_memory doubles as the existence and permission check.
    target = self.get_memory(memory_id, ctx)
    outcome = self.repo.delete_memory(memory_id)
    self._audit("delete_memory", "memory", memory_id, namespace=target.namespace, actor_user_id=ctx.user_id, actor_agent_id=ctx.agent_id)
    return dict(deleted=outcome, id=memory_id)
|
|
|
|
def append_episode(self, request: EpisodeAppendRequest) -> EpisodeRecord:
    """Record one episode for a session and audit the write.

    A missing ``session_id`` is stored as ``"default"``; a missing namespace
    falls back to the caller's default namespace for private visibility.

    Returns:
        The ``EpisodeRecord`` handed to the repository.
    """
    ctx = AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))

    # Fields copied verbatim from the request onto the record.
    copied_fields = ("user_id", "agent_id", "workspace_id", "content", "events", "tags", "source", "expires_at")
    record = EpisodeRecord(
        session_id=request.session_id or "default",
        namespace=request.namespace or default_namespace_for_context(ctx, Visibility.PRIVATE),
        **{name: getattr(request, name) for name in copied_fields},
    )

    self.repo.append_episode(record)
    self._audit("append_episode", "episode", record.id, namespace=record.namespace, actor_user_id=request.user_id, actor_agent_id=request.agent_id)
    return record
|
|
|
|
def commit_session(self, session_id: str, request: CommitSessionRequest) -> dict:
    """Commit a session and optionally promote its episodes into memories.

    When ``request.promote`` is false nothing is consolidated and the
    backend is reported as ``"disabled"``. Otherwise consolidation runs:

    * ``"external"`` - EverMemOS is enabled in config and succeeded;
    * ``"local-fallback"`` - EverMemOS raised ``EverMemOSError`` and the
      config allows falling back to the local worker;
    * ``"local-disabled"`` - EverMemOS is not enabled in config.

    Args:
        session_id: Session whose episodes are committed.
        request: Carries ``promote``, the caller identity, the optional
            ``target_namespace`` and ``min_importance``.

    Returns:
        Without a consolidation result: ``session_id``, the episode count,
        an empty ``promoted`` list and ``evermemos_backend``. With one: the
        backend, any EverMemOS error text, and the result's ``episodes``,
        ``candidates``, ``promoted``, ``duplicates``, ``conflicts`` and
        ``review_drafts``.

    Raises:
        HTTPException: 502 when EverMemOS fails and ``fallback_to_local``
            is off (audited as ``evermemos_commit_failed`` first).
    """
    episodes = self.repo.list_session_episodes(session_id)
    backend = "disabled"
    error: str | None = None
    result = None

    if request.promote:
        ctx = AccessContext(
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            session_id=session_id,
        )
        target_namespace = request.target_namespace or user_long_term_namespace(request.user_id)
        config = get_config().evermemos
        if config.enabled:
            try:
                client = self.evermemos_client or EverMemOSClient()
                # Only hand the external service memories this caller is
                # allowed to read. Sending the whole store would ship other
                # callers' memories out of the gateway and let them surface
                # in the duplicates/conflicts returned below.
                accessible_memories = [
                    memory for memory in self.repo.list_memories() if can_access_memory(ctx, memory)
                ]
                external_result = client.consolidate_session(
                    session_id=session_id,
                    ctx=ctx,
                    episodes=episodes,
                    existing_memories=accessible_memories,
                    min_importance=request.min_importance,
                    target_namespace=target_namespace,
                )
                result = self._persist_external_consolidation(external_result, ctx, session_id)
                backend = "external"
            except EverMemOSError as exc:
                error = str(exc)
                if not config.fallback_to_local:
                    self._audit(
                        "evermemos_commit_failed",
                        "session",
                        session_id,
                        actor_user_id=request.user_id,
                        actor_agent_id=request.agent_id,
                        decision="deny",
                        metadata={"error": error},
                    )
                    raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"EverMemOS failed: {error}") from exc
                # Fallback is allowed: consolidate locally and keep the
                # error text so it is reported alongside the result.
                result = self._commit_session_locally(session_id, ctx, request)
                backend = "local-fallback"
        else:
            result = self._commit_session_locally(session_id, ctx, request)
            backend = "local-disabled"

    self._audit("commit_session", "session", session_id, actor_user_id=request.user_id, actor_agent_id=request.agent_id)

    if not result:
        return {"session_id": session_id, "episodes": len(episodes), "promoted": [], "evermemos_backend": backend}
    return {
        "evermemos_backend": backend,
        "evermemos_error": error,
        "session_id": session_id,
        "episodes": result.episodes,
        "candidates": result.candidates,
        "promoted": result.promoted,
        "duplicates": result.duplicates,
        "conflicts": result.conflicts,
        "review_drafts": result.review_drafts,
    }
|
|
|
|
def evermemos_health(self) -> dict:
    """Report the health of the EverMemOS backend.

    Returns:
        ``{"status": "disabled", "url": ...}`` when EverMemOS is not enabled
        in config; otherwise whatever the client's ``health()`` returns.
    """
    settings = get_config().evermemos
    if settings.enabled:
        client = self.evermemos_client or EverMemOSClient()
        return client.health()
    return {"status": "disabled", "url": settings.url}
|
|
|
|
def _commit_session_locally(self, session_id: str, ctx: AccessContext, request: CommitSessionRequest):
    """Consolidate a session with the in-process ``EverMemOSWorker``.

    The target namespace is the one named in the request, or the user's
    long-term namespace when the request names none.

    Returns:
        Whatever ``EverMemOSWorker.consolidate_session`` returns.
    """
    local_worker = EverMemOSWorker(self.repo)
    namespace = request.target_namespace or user_long_term_namespace(request.user_id)
    return local_worker.consolidate_session(
        session_id=session_id,
        ctx=ctx,
        min_importance=request.min_importance,
        target_namespace=namespace,
    )
|
|
|
|
def _persist_external_consolidation(self, external_result: dict, ctx: AccessContext, session_id: str):
    """Turn an EverMemOS consolidation payload into a ``ConsolidationResult``.

    Entries under ``"candidates"`` are converted and collected; entries
    under ``"promoted"`` are converted, written to the repository, and also
    added to the candidates unless a candidate with the same ``id`` is
    already present. Entries that cannot be converted are skipped.

    A key that is present but set to ``None`` (JSON ``null``) is treated
    like a missing key, so a sparse payload cannot raise ``TypeError`` or
    leak ``None`` into the list fields.

    Args:
        external_result: Payload returned by the external consolidation.
        ctx: Caller context used to fill in fields the payload omits.
        session_id: Session the consolidation belongs to.

    Returns:
        The populated ``ConsolidationResult``.
    """
    from .workers.evermemos_worker import ConsolidationResult

    result = ConsolidationResult(
        session_id=session_id,
        # A missing or zero episode count falls back to the stored episodes.
        episodes=external_result.get("episodes") or len(self.repo.list_session_episodes(session_id)),
        duplicates=external_result.get("duplicates") or [],
        conflicts=external_result.get("conflicts") or [],
        review_drafts=external_result.get("review_drafts") or [],
    )

    for item in external_result.get("candidates") or []:
        memory = self._memory_from_external(item, ctx, session_id)
        if memory:
            result.candidates.append(memory)

    # Track candidate ids in a set so the membership test below stays O(1).
    candidate_ids = {candidate.id for candidate in result.candidates}
    for item in external_result.get("promoted") or []:
        memory = self._memory_from_external(item, ctx, session_id)
        if memory:
            self.repo.upsert_memory(memory)
            result.promoted.append(memory)
            if memory.id not in candidate_ids:
                result.candidates.append(memory)
                candidate_ids.add(memory.id)
    return result
|
|
|
|
def _memory_from_external(self, item: dict, ctx: AccessContext, session_id: str) -> MemoryRecord | None:
|
|
if not isinstance(item, dict):
|
|
return None
|
|
data = dict(item)
|
|
data.setdefault("user_id", ctx.user_id)
|
|
data.setdefault("agent_id", ctx.agent_id)
|
|
data.setdefault("workspace_id", ctx.workspace_id)
|
|
data.setdefault("session_id", session_id)
|
|
data.setdefault("namespace", default_namespace_for_context(ctx, Visibility.PRIVATE))
|
|
data.setdefault("memory_type", MemoryType.SUMMARY.value)
|
|
data.setdefault("content", data.get("text") or data.get("summary") or "")
|
|
data.setdefault("summary", data.get("content", "")[:180])
|
|
data.setdefault("tags", ["evermemos-external"])
|
|
data.setdefault("importance", 0.7)
|
|
data.setdefault("confidence", 0.65)
|
|
data.setdefault("visibility", Visibility.PRIVATE.value)
|
|
data.setdefault("source", SourceType.EVERMEMOS.value)
|
|
if not data["content"]:
|
|
return None
|
|
return MemoryRecord.model_validate(data)
|
|
|
|
def get_profile(self, user_id: str) -> ProfileRecord:
    """Return the stored profile for ``user_id``.

    Raises:
        HTTPException: 404 ("Profile not found") when the repository
            returns ``None`` or any other falsy value.
    """
    found = self.repo.get_profile(user_id)
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Profile not found")
|
|
|
|
def add_feedback(self, memory_id: str, request: MemoryFeedbackRequest) -> dict:
    """Record feedback on a memory as an audit entry.

    The feedback is written only to the audit log, under the action
    ``feedback:<value>`` with the comment as metadata; the memory itself is
    not modified here.

    Returns:
        ``{"status": "ok", "memory_id": ..., "feedback": ...}``.

    Raises:
        HTTPException: Propagated from ``get_memory``.
    """
    ctx = AccessContext(**request.model_dump(include={"user_id", "agent_id", "workspace_id", "session_id"}))
    # Resolving the memory first enforces existence and access rights.
    target = self.get_memory(memory_id, ctx)

    action = f"feedback:{request.feedback}"
    self._audit(
        action,
        "memory",
        target.id,
        namespace=target.namespace,
        actor_user_id=request.user_id,
        actor_agent_id=request.agent_id,
        metadata={"comment": request.comment},
    )
    return dict(status="ok", memory_id=memory_id, feedback=request.feedback)
|
|
|
|
def list_namespaces(self, ctx: AccessContext) -> list[NamespaceInfo]:
    """Return the namespaces ``visible_namespaces`` reports for ``ctx``."""
    namespaces = visible_namespaces(ctx)
    return namespaces
|
|
|
|
def list_audit(self, limit: int = 100) -> list[AuditLog]:
    """Return audit entries from the repository.

    Args:
        limit: Passed straight through to the repository's ``list_audit``;
            defaults to 100.
    """
    entries = self.repo.list_audit(limit)
    return entries
|
|
|
|
def _score(self, memory: MemoryRecord, query: str) -> float:
|
|
lexical = 1.0 if query and query in memory.content.lower() else 0.2
|
|
return lexical + memory.importance + memory.confidence
|
|
|
|
def _audit(
    self,
    action: str,
    target_type: str,
    target_id: str | None,
    namespace: str | None = None,
    actor_user_id: str | None = None,
    actor_agent_id: str | None = None,
    decision: str = "allow",
    metadata: dict | None = None,
) -> None:
    """Append one ``AuditLog`` entry to the repository.

    Args:
        action: Name of the audited operation.
        target_type: Kind of object the action touched.
        target_id: Identifier of that object, if there is one.
        namespace: Namespace the target lives in, if known.
        actor_user_id: User on whose behalf the action ran.
        actor_agent_id: Agent that performed the action.
        decision: Outcome recorded for the entry; ``"allow"`` by default.
        metadata: Extra details; ``None`` (or any falsy value) is stored as
            an empty dict.
    """
    entry = AuditLog(
        actor_user_id=actor_user_id,
        actor_agent_id=actor_agent_id,
        action=action,
        target_type=target_type,
        target_id=target_id,
        namespace=namespace,
        decision=decision,  # type: ignore[arg-type]
        metadata=metadata or {},
    )
    self.repo.add_audit(entry)
|
|
|
|
|
|
# Module-level service instance, constructed with the default repository and
# no explicit EverMemOS client.
service = MemoryGatewayService()
|