"""Metadata repositories for Memory Gateway. SQLite is the default POC store. The in-memory implementation is retained for small isolated tests and for cases where persistence is explicitly disabled. """ from __future__ import annotations import json import sqlite3 from datetime import datetime, timezone from pathlib import Path from typing import Iterable, Optional, Protocol from .config import get_config from .schemas import AuditLog, EpisodeRecord, MemoryRecord, ProfileRecord, UserRecord class MetadataRepository(Protocol): def create_user(self, user: UserRecord) -> UserRecord: ... def get_user(self, user_id: str) -> Optional[UserRecord]: ... def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord: ... def get_memory(self, memory_id: str) -> Optional[MemoryRecord]: ... def delete_memory(self, memory_id: str) -> bool: ... def list_memories(self) -> Iterable[MemoryRecord]: ... def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord: ... def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]: ... def get_profile(self, user_id: str) -> Optional[ProfileRecord]: ... def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord: ... def add_audit(self, audit: AuditLog) -> AuditLog: ... def list_audit(self, limit: int = 100) -> list[AuditLog]: ... 
def _json_dump_model(model) -> str:
    """Serialize a pydantic model to a JSON string (non-ASCII kept as-is)."""
    return json.dumps(model.model_dump(mode="json"), ensure_ascii=False)


def _json_load_model(model_cls, payload: str):
    """Rehydrate a pydantic model of *model_cls* from a JSON string."""
    return model_cls.model_validate(json.loads(payload))


class InMemoryRepository:
    """Dict-backed store for isolated tests / persistence-disabled runs.

    Not thread-safe; all state is lost when the process exits.
    """

    def __init__(self) -> None:
        self.users: dict[str, UserRecord] = {}
        self.memories: dict[str, MemoryRecord] = {}
        self.episodes: dict[str, EpisodeRecord] = {}
        self.profiles: dict[str, ProfileRecord] = {}
        self.audit_logs: list[AuditLog] = []

    def create_user(self, user: UserRecord) -> UserRecord:
        """Store *user* and create an empty profile if one does not exist."""
        now = datetime.now(timezone.utc)
        # Preserve a caller-supplied created_at for parity with
        # SQLiteRepository; previously it was unconditionally overwritten.
        user.created_at = user.created_at or now
        user.updated_at = now
        self.users[user.id] = user
        self.profiles.setdefault(
            user.id,
            ProfileRecord(
                user_id=user.id,
                namespace=user.profile_namespace or f"user/{user.id}/profile",
            ),
        )
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Return the user with *user_id*, or ``None`` if unknown."""
        return self.users.get(user_id)

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or update *memory*, bumping version and timestamps on update."""
        now = datetime.now(timezone.utc)
        existing = self.memories.get(memory.id)
        if existing:
            # On update: advance the version, keep the original creation time.
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now
        self.memories[memory.id] = memory
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Return the memory with *memory_id*, or ``None`` if unknown."""
        return self.memories.get(memory_id)

    def delete_memory(self, memory_id: str) -> bool:
        """Delete the memory with *memory_id*; ``True`` if a record was removed."""
        return self.memories.pop(memory_id, None) is not None

    def list_memories(self) -> Iterable[MemoryRecord]:
        """Return a snapshot list of every stored memory."""
        return list(self.memories.values())

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Store *episode* keyed by its id and return it."""
        self.episodes[episode.id] = episode
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Return all episodes whose session_id matches *session_id*."""
        return [ep for ep in self.episodes.values() if ep.session_id == session_id]

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Return the profile for *user_id*, or ``None`` if unknown."""
        return self.profiles.get(user_id)

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Store *profile*, stamping updated_at and bumping its version."""
        profile.updated_at = datetime.now(timezone.utc)
        profile.version += 1
        self.profiles[profile.user_id] = profile
        return profile

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Append *audit* to the in-memory log and return it."""
        self.audit_logs.append(audit)
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return up to *limit* most recent audit entries, newest first.

        Newest-first ordering matches SQLiteRepository.list_audit. A
        non-positive *limit* yields an empty list; the previous slice-based
        code returned the ENTIRE log for limit == 0 (``xs[-0:]`` is ``xs[0:]``).
        """
        if limit <= 0:
            return []
        return list(reversed(self.audit_logs[-limit:]))


class SQLiteRepository:
    """SQLite-backed repository; each call uses a short-lived connection.

    Records are stored as JSON payloads next to a few indexed columns used
    for filtering; the JSON payload is the source of truth on reads.
    """

    def __init__(self, db_path: str | Path) -> None:
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_schema()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection with Row access by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _run(self, op):
        """Run ``op(conn)`` inside a transaction and always close the connection.

        ``with conn:`` only commits/rolls back — it does NOT close the
        connection, so the previous per-method pattern leaked one connection
        per call.
        """
        conn = self._connect()
        try:
            with conn:
                return op(conn)
        finally:
            conn.close()

    def _init_schema(self) -> None:
        """Create all tables and indexes if they do not already exist."""

        def op(conn: sqlite3.Connection) -> None:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS profiles (
                    user_id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT,
                    namespace TEXT NOT NULL,
                    memory_type TEXT NOT NULL,
                    visibility TEXT NOT NULL,
                    importance REAL NOT NULL,
                    confidence REAL NOT NULL,
                    expires_at TEXT,
                    archived_at TEXT,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id);
                CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);
                CREATE INDEX IF NOT EXISTS idx_memories_workspace ON memories(workspace_id);
                CREATE TABLE IF NOT EXISTS episodes (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT NOT NULL,
                    namespace TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_episodes_session ON episodes(session_id);
                CREATE TABLE IF NOT EXISTS audit_logs (
                    id TEXT PRIMARY KEY,
                    actor_user_id TEXT,
                    actor_agent_id TEXT,
                    action TEXT NOT NULL,
                    target_type TEXT NOT NULL,
                    target_id TEXT,
                    namespace TEXT,
                    decision TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_logs(created_at);
                """
            )

        self._run(op)

    def create_user(self, user: UserRecord) -> UserRecord:
        """Insert (or replace) *user* and ensure a profile row exists.

        Unlike the previous version, an existing profile is left untouched, so
        re-creating a user cannot wipe accumulated profile data (this mirrors
        InMemoryRepository's ``setdefault`` behaviour).
        """
        now = datetime.now(timezone.utc)
        user.created_at = user.created_at or now
        user.updated_at = now

        def op(conn: sqlite3.Connection) -> None:
            conn.execute(
                "INSERT OR REPLACE INTO users(id, payload, updated_at) VALUES (?, ?, ?)",
                (user.id, _json_dump_model(user), user.updated_at.isoformat()),
            )

        self._run(op)
        if self.get_profile(user.id) is None:
            profile = ProfileRecord(
                user_id=user.id,
                namespace=user.profile_namespace or f"user/{user.id}/profile",
            )
            profile.updated_at = now
            self._write_profile(profile)
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Return the user with *user_id*, or ``None`` if unknown."""

        def op(conn: sqlite3.Connection):
            row = conn.execute("SELECT payload FROM users WHERE id = ?", (user_id,)).fetchone()
            return _json_load_model(UserRecord, row["payload"]) if row else None

        return self._run(op)

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or update *memory*, bumping version and timestamps on update."""
        existing = self.get_memory(memory.id)
        now = datetime.now(timezone.utc)
        if existing:
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now

        def op(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                INSERT OR REPLACE INTO memories(
                    id, user_id, agent_id, workspace_id, session_id, namespace,
                    memory_type, visibility, importance, confidence, expires_at,
                    archived_at, payload, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory.id,
                    memory.user_id,
                    memory.agent_id,
                    memory.workspace_id,
                    memory.session_id,
                    memory.namespace,
                    memory.memory_type.value,
                    memory.visibility.value,
                    memory.importance,
                    memory.confidence,
                    memory.expires_at.isoformat() if memory.expires_at else None,
                    memory.archived_at.isoformat() if memory.archived_at else None,
                    _json_dump_model(memory),
                    memory.updated_at.isoformat(),
                ),
            )

        self._run(op)
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Return the memory with *memory_id*, or ``None`` if unknown."""

        def op(conn: sqlite3.Connection):
            row = conn.execute(
                "SELECT payload FROM memories WHERE id = ?", (memory_id,)
            ).fetchone()
            return _json_load_model(MemoryRecord, row["payload"]) if row else None

        return self._run(op)

    def delete_memory(self, memory_id: str) -> bool:
        """Delete the memory with *memory_id*; ``True`` if a row was removed."""

        def op(conn: sqlite3.Connection) -> bool:
            cursor = conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
            return cursor.rowcount > 0

        return self._run(op)

    def list_memories(self) -> Iterable[MemoryRecord]:
        """Return every stored memory, deserialized from its JSON payload."""

        def op(conn: sqlite3.Connection) -> list[MemoryRecord]:
            rows = conn.execute("SELECT payload FROM memories").fetchall()
            return [_json_load_model(MemoryRecord, row["payload"]) for row in rows]

        return self._run(op)

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Store *episode* (replacing any row with the same id) and return it."""

        def op(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                INSERT OR REPLACE INTO episodes(
                    id, user_id, agent_id, workspace_id, session_id, namespace,
                    payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    episode.id,
                    episode.user_id,
                    episode.agent_id,
                    episode.workspace_id,
                    episode.session_id,
                    episode.namespace,
                    _json_dump_model(episode),
                    episode.created_at.isoformat(),
                ),
            )

        self._run(op)
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Return the episodes for *session_id*, oldest first."""

        def op(conn: sqlite3.Connection) -> list[EpisodeRecord]:
            rows = conn.execute(
                "SELECT payload FROM episodes WHERE session_id = ? ORDER BY created_at ASC",
                (session_id,),
            ).fetchall()
            return [_json_load_model(EpisodeRecord, row["payload"]) for row in rows]

        return self._run(op)

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Return the profile for *user_id*, or ``None`` if unknown."""

        def op(conn: sqlite3.Connection):
            row = conn.execute(
                "SELECT payload FROM profiles WHERE user_id = ?", (user_id,)
            ).fetchone()
            return _json_load_model(ProfileRecord, row["payload"]) if row else None

        return self._run(op)

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Write *profile*, stamping updated_at and bumping its version.

        The version bump matches InMemoryRepository.upsert_profile, which the
        previous SQLite implementation omitted.
        """
        profile.updated_at = datetime.now(timezone.utc)
        profile.version += 1
        self._write_profile(profile)
        return profile

    def _write_profile(self, profile: ProfileRecord) -> None:
        """Persist *profile* as-is (no version/timestamp side effects)."""

        def op(conn: sqlite3.Connection) -> None:
            conn.execute(
                "INSERT OR REPLACE INTO profiles(user_id, payload, updated_at) VALUES (?, ?, ?)",
                (profile.user_id, _json_dump_model(profile), profile.updated_at.isoformat()),
            )

        self._run(op)

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Record *audit* and return it."""

        def op(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                INSERT OR REPLACE INTO audit_logs(
                    id, actor_user_id, actor_agent_id, action, target_type,
                    target_id, namespace, decision, payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    audit.id,
                    audit.actor_user_id,
                    audit.actor_agent_id,
                    audit.action,
                    audit.target_type,
                    audit.target_id,
                    audit.namespace,
                    audit.decision,
                    _json_dump_model(audit),
                    audit.created_at.isoformat(),
                ),
            )

        self._run(op)
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return up to *limit* most recent audit entries, newest first.

        A non-positive *limit* yields an empty list: SQLite would otherwise
        interpret ``LIMIT -1`` as "no limit".
        """
        if limit <= 0:
            return []

        def op(conn: sqlite3.Connection) -> list[AuditLog]:
            rows = conn.execute(
                "SELECT payload FROM audit_logs ORDER BY created_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return [_json_load_model(AuditLog, row["payload"]) for row in rows]

        return self._run(op)


def build_repository() -> MetadataRepository:
    """Select the metadata backend from configuration (SQLite by default)."""
    config = get_config()
    if config.storage.backend == "memory":
        return InMemoryRepository()
    return SQLiteRepository(config.storage.sqlite_path)


# Module-level singleton used by the rest of the gateway.
repository = build_repository()