Files
memory-gateway/memory_gateway/repositories.py
2026-05-05 16:18:31 +08:00

329 lines
13 KiB
Python

"""Metadata repositories for Memory Gateway.
SQLite is the default POC store. The in-memory implementation is retained for
small isolated tests and for cases where persistence is explicitly disabled.
"""
from __future__ import annotations

import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Optional, Protocol

from .config import get_config
from .schemas import AuditLog, EpisodeRecord, MemoryRecord, ProfileRecord, UserRecord
class MetadataRepository(Protocol):
    """Structural contract for metadata stores.

    Covers users, memories, episodes, profiles and audit logs; both
    ``InMemoryRepository`` and ``SQLiteRepository`` in this module satisfy it.
    """

    # -- users --
    def create_user(self, user: UserRecord) -> UserRecord: ...
    def get_user(self, user_id: str) -> Optional[UserRecord]: ...
    # -- memories --
    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord: ...
    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]: ...
    def delete_memory(self, memory_id: str) -> bool: ...
    def list_memories(self) -> Iterable[MemoryRecord]: ...
    # -- episodes --
    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord: ...
    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]: ...
    # -- profiles --
    def get_profile(self, user_id: str) -> Optional[ProfileRecord]: ...
    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord: ...
    # -- audit --
    def add_audit(self, audit: AuditLog) -> AuditLog: ...
    def list_audit(self, limit: int = 100) -> list[AuditLog]: ...
def _json_dump_model(model) -> str:
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False)
def _json_load_model(model_cls, payload: str):
return model_cls.model_validate(json.loads(payload))
class InMemoryRepository:
    """Dict-backed metadata store for isolated tests and persistence-free runs.

    Everything lives in plain dicts/lists; nothing survives process exit.
    """

    def __init__(self) -> None:
        self.users: dict[str, UserRecord] = {}
        self.memories: dict[str, MemoryRecord] = {}
        self.episodes: dict[str, EpisodeRecord] = {}
        self.profiles: dict[str, ProfileRecord] = {}
        self.audit_logs: list[AuditLog] = []

    def create_user(self, user: UserRecord) -> UserRecord:
        """Store *user*, stamp its timestamps, and ensure a default profile exists."""
        now = datetime.now(timezone.utc)
        # Fix: preserve a caller-supplied created_at instead of clobbering it,
        # matching SQLiteRepository.create_user's `created_at or now` behavior.
        user.created_at = user.created_at or now
        user.updated_at = now
        self.users[user.id] = user
        # setdefault: never overwrite an existing profile on repeat create.
        self.profiles.setdefault(
            user.id,
            ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile"),
        )
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Return the user or None when unknown."""
        return self.users.get(user_id)

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or replace a memory; on replace, bump version and keep created_at."""
        now = datetime.now(timezone.utc)
        existing = self.memories.get(memory.id)
        if existing:
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now
        self.memories[memory.id] = memory
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Return the memory or None when unknown."""
        return self.memories.get(memory_id)

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory; return True only when something was removed."""
        return self.memories.pop(memory_id, None) is not None

    def list_memories(self) -> Iterable[MemoryRecord]:
        """Return a snapshot list of all stored memories."""
        return list(self.memories.values())

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Store an episode keyed by its id (re-appending overwrites)."""
        self.episodes[episode.id] = episode
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Return all episodes for *session_id* in insertion order."""
        return [episode for episode in self.episodes.values() if episode.session_id == session_id]

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Return the profile or None when unknown."""
        return self.profiles.get(user_id)

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Store *profile*, refreshing its timestamp and bumping its version."""
        profile.updated_at = datetime.now(timezone.utc)
        profile.version += 1
        self.profiles[profile.user_id] = profile
        return profile

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Append an audit entry in arrival order."""
        self.audit_logs.append(audit)
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return up to *limit* most recent audit entries, oldest first.

        Fix: ``logs[-limit:]`` with limit == 0 is ``logs[0:]`` (the whole
        list), so non-positive limits must short-circuit to an empty result.
        """
        if limit <= 0:
            return []
        return self.audit_logs[-limit:]
class SQLiteRepository:
    """SQLite-backed metadata store (the default POC backend).

    Each table keeps a handful of queryable columns plus the full record
    serialized as JSON in ``payload``; reads always rehydrate from ``payload``.
    """

    def __init__(self, db_path: str | Path) -> None:
        self.db_path = Path(db_path)
        # Ensure the containing directory exists before the first connect.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_schema()

    def _connect(self) -> sqlite3.Connection:
        """Open a fresh connection with name-based row access."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _session(self):
        """Yield a connection inside a transaction and always close it.

        Fix: ``with sqlite3.connect(...)`` only commits/rolls back on exit —
        it never closes the connection — so the previous per-call pattern
        leaked one open connection (and file handle) per operation.
        """
        conn = self._connect()
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _init_schema(self) -> None:
        """Create all tables and indexes if they do not exist yet (idempotent)."""
        with self._session() as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS profiles (
                    user_id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT,
                    namespace TEXT NOT NULL,
                    memory_type TEXT NOT NULL,
                    visibility TEXT NOT NULL,
                    importance REAL NOT NULL,
                    confidence REAL NOT NULL,
                    expires_at TEXT,
                    archived_at TEXT,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id);
                CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);
                CREATE INDEX IF NOT EXISTS idx_memories_workspace ON memories(workspace_id);
                CREATE TABLE IF NOT EXISTS episodes (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT NOT NULL,
                    namespace TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_episodes_session ON episodes(session_id);
                CREATE TABLE IF NOT EXISTS audit_logs (
                    id TEXT PRIMARY KEY,
                    actor_user_id TEXT,
                    actor_agent_id TEXT,
                    action TEXT NOT NULL,
                    target_type TEXT NOT NULL,
                    target_id TEXT,
                    namespace TEXT,
                    decision TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_logs(created_at);
                """
            )

    def create_user(self, user: UserRecord) -> UserRecord:
        """Persist *user* (insert-or-replace) and seed its default profile."""
        now = datetime.now(timezone.utc)
        user.created_at = user.created_at or now
        user.updated_at = now
        with self._session() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO users(id, payload, updated_at) VALUES (?, ?, ?)",
                (user.id, _json_dump_model(user), user.updated_at.isoformat()),
            )
        self.upsert_profile(ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile"))
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Return the user rehydrated from its JSON payload, or None."""
        with self._session() as conn:
            row = conn.execute("SELECT payload FROM users WHERE id = ?", (user_id,)).fetchone()
            return _json_load_model(UserRecord, row["payload"]) if row else None

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or replace a memory; on replace, bump version and keep created_at."""
        existing = self.get_memory(memory.id)
        now = datetime.now(timezone.utc)
        if existing:
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now
        with self._session() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO memories(
                    id, user_id, agent_id, workspace_id, session_id, namespace,
                    memory_type, visibility, importance, confidence, expires_at,
                    archived_at, payload, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory.id,
                    memory.user_id,
                    memory.agent_id,
                    memory.workspace_id,
                    memory.session_id,
                    memory.namespace,
                    memory.memory_type.value,
                    memory.visibility.value,
                    memory.importance,
                    memory.confidence,
                    memory.expires_at.isoformat() if memory.expires_at else None,
                    memory.archived_at.isoformat() if memory.archived_at else None,
                    _json_dump_model(memory),
                    memory.updated_at.isoformat(),
                ),
            )
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Return the memory rehydrated from its JSON payload, or None."""
        with self._session() as conn:
            row = conn.execute("SELECT payload FROM memories WHERE id = ?", (memory_id,)).fetchone()
            return _json_load_model(MemoryRecord, row["payload"]) if row else None

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory row; return True only when a row was removed."""
        with self._session() as conn:
            cursor = conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
            return cursor.rowcount > 0

    def list_memories(self) -> Iterable[MemoryRecord]:
        """Return all memories rehydrated from their JSON payloads."""
        with self._session() as conn:
            rows = conn.execute("SELECT payload FROM memories").fetchall()
            return [_json_load_model(MemoryRecord, row["payload"]) for row in rows]

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Persist an episode (insert-or-replace by id)."""
        with self._session() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO episodes(
                    id, user_id, agent_id, workspace_id, session_id, namespace, payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    episode.id,
                    episode.user_id,
                    episode.agent_id,
                    episode.workspace_id,
                    episode.session_id,
                    episode.namespace,
                    _json_dump_model(episode),
                    episode.created_at.isoformat(),
                ),
            )
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Return the session's episodes ordered oldest-first by created_at."""
        with self._session() as conn:
            rows = conn.execute(
                "SELECT payload FROM episodes WHERE session_id = ? ORDER BY created_at ASC",
                (session_id,),
            ).fetchall()
            return [_json_load_model(EpisodeRecord, row["payload"]) for row in rows]

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Return the profile rehydrated from its JSON payload, or None."""
        with self._session() as conn:
            row = conn.execute("SELECT payload FROM profiles WHERE user_id = ?", (user_id,)).fetchone()
            return _json_load_model(ProfileRecord, row["payload"]) if row else None

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Persist *profile* (insert-or-replace), refreshing its timestamp.

        NOTE(review): unlike InMemoryRepository.upsert_profile, this does not
        bump ``profile.version`` — confirm whether callers rely on either
        behavior before unifying.
        """
        profile.updated_at = datetime.now(timezone.utc)
        with self._session() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO profiles(user_id, payload, updated_at) VALUES (?, ?, ?)",
                (profile.user_id, _json_dump_model(profile), profile.updated_at.isoformat()),
            )
        return profile

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Persist an audit entry (insert-or-replace by id)."""
        with self._session() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO audit_logs(
                    id, actor_user_id, actor_agent_id, action, target_type, target_id,
                    namespace, decision, payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    audit.id,
                    audit.actor_user_id,
                    audit.actor_agent_id,
                    audit.action,
                    audit.target_type,
                    audit.target_id,
                    audit.namespace,
                    audit.decision,
                    _json_dump_model(audit),
                    audit.created_at.isoformat(),
                ),
            )
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return up to *limit* audit entries, newest first.

        Guard non-positive limits: SQLite treats a negative LIMIT as
        "no limit", which would return the entire table.
        """
        if limit <= 0:
            return []
        with self._session() as conn:
            rows = conn.execute(
                "SELECT payload FROM audit_logs ORDER BY created_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return [_json_load_model(AuditLog, row["payload"]) for row in rows]
def build_repository() -> MetadataRepository:
    """Select the metadata backend from configuration.

    Returns the dict-backed store when the backend is explicitly set to
    ``"memory"``; any other value falls through to the persistent SQLite
    store at the configured path.
    """
    storage = get_config().storage
    if storage.backend == "memory":
        return InMemoryRepository()
    return SQLiteRepository(storage.sqlite_path)


# Module-level singleton constructed at import time; the rest of the
# gateway imports and shares this instance.
repository = build_repository()