"""Metadata repositories for Memory Gateway.

SQLite is the default POC store. The in-memory implementation is retained for
small isolated tests and for cases where persistence is explicitly disabled.
"""
|
|
from __future__ import annotations

import contextlib
import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Optional, Protocol

from .config import get_config
from .schemas import AuditLog, EpisodeRecord, MemoryRecord, ProfileRecord, UserRecord
|
|
|
|
|
|
class MetadataRepository(Protocol):
    """Structural interface implemented by both metadata backends.

    ``InMemoryRepository`` and ``SQLiteRepository`` both satisfy this
    protocol; ``build_repository`` selects one at import time from config.
    All methods that take a record return the (possibly mutated) record.
    """

    # -- users ----------------------------------------------------------
    def create_user(self, user: UserRecord) -> UserRecord: ...

    def get_user(self, user_id: str) -> Optional[UserRecord]: ...

    # -- memories -------------------------------------------------------
    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord: ...

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]: ...

    def delete_memory(self, memory_id: str) -> bool: ...

    def list_memories(self) -> Iterable[MemoryRecord]: ...

    # -- episodes -------------------------------------------------------
    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord: ...

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]: ...

    # -- profiles -------------------------------------------------------
    def get_profile(self, user_id: str) -> Optional[ProfileRecord]: ...

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord: ...

    # -- audit ----------------------------------------------------------
    def add_audit(self, audit: AuditLog) -> AuditLog: ...

    def list_audit(self, limit: int = 100) -> list[AuditLog]: ...
|
|
|
|
|
|
def _json_dump_model(model) -> str:
|
|
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False)
|
|
|
|
|
|
def _json_load_model(model_cls, payload: str):
|
|
return model_cls.model_validate(json.loads(payload))
|
|
|
|
|
|
class InMemoryRepository:
    """Dict-backed metadata store for tests and persistence-disabled runs.

    Mirrors the observable behavior of the SQLite backend (version bumps,
    profile seeding, result ordering) so the two remain interchangeable
    behind ``MetadataRepository``.
    """

    def __init__(self) -> None:
        self.users: dict[str, UserRecord] = {}
        self.memories: dict[str, MemoryRecord] = {}
        self.episodes: dict[str, EpisodeRecord] = {}
        # Profiles are keyed by user_id, matching the SQLite primary key.
        self.profiles: dict[str, ProfileRecord] = {}
        # Append-only, so list order is chronological.
        self.audit_logs: list[AuditLog] = []

    def create_user(self, user: UserRecord) -> UserRecord:
        """Store a user and seed an empty profile if one is not present.

        Fix: an explicitly supplied ``created_at`` is preserved instead of
        being overwritten (parity with the SQLite backend).
        """
        now = datetime.now(timezone.utc)
        user.created_at = user.created_at or now
        user.updated_at = now
        self.users[user.id] = user
        self.profiles.setdefault(
            user.id,
            ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile"),
        )
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Return the user, or None when unknown."""
        return self.users.get(user_id)

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or replace a memory, bumping version and keeping created_at."""
        now = datetime.now(timezone.utc)
        existing = self.memories.get(memory.id)
        if existing:
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now
        self.memories[memory.id] = memory
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Return the memory, or None when unknown."""
        return self.memories.get(memory_id)

    def delete_memory(self, memory_id: str) -> bool:
        """Remove a memory; True when something was actually deleted."""
        return self.memories.pop(memory_id, None) is not None

    def list_memories(self) -> Iterable[MemoryRecord]:
        """Snapshot of all stored memories (insertion order)."""
        return list(self.memories.values())

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Store an episode keyed by its id."""
        self.episodes[episode.id] = episode
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Episodes for a session, oldest first.

        Fix: results are now sorted by ``created_at`` to match the SQLite
        backend's ``ORDER BY created_at ASC``.
        """
        matches = [e for e in self.episodes.values() if e.session_id == session_id]
        matches.sort(key=lambda e: e.created_at)
        return matches

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Return the profile, or None when unknown."""
        return self.profiles.get(user_id)

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Write a profile, bumping version and updated_at."""
        profile.updated_at = datetime.now(timezone.utc)
        profile.version += 1
        self.profiles[profile.user_id] = profile
        return profile

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Append an audit entry."""
        self.audit_logs.append(audit)
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """Return the most recent *limit* audit entries, newest first.

        Fixes: ``limit=0`` previously returned the ENTIRE log because
        ``list[-0:]`` slices the whole list; ordering now matches the SQLite
        backend's ``ORDER BY created_at DESC``.
        """
        if limit <= 0:
            return []
        return list(reversed(self.audit_logs[-limit:]))
|
|
|
|
|
|
class SQLiteRepository:
    """SQLite-backed metadata store (default POC backend).

    Each record is persisted as a full JSON ``payload`` column plus a few
    denormalized columns used for indexing and filtering. Every operation
    runs on a short-lived connection opened through ``_tx`` so file handles
    are always released.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open/create the database at *db_path* and ensure the schema exists."""
        self.db_path = Path(db_path)
        # sqlite3 will not create intermediate directories on its own.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_schema()

    def _connect(self) -> sqlite3.Connection:
        """Open a fresh connection with name-based row access."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextlib.contextmanager
    def _tx(self):
        """Yield a connection inside a transaction and always close it.

        Fix: ``with sqlite3.connect(...)`` commits or rolls back but does
        NOT close the connection, so the previous code leaked one handle per
        repository call until garbage collection.
        """
        conn = self._connect()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_schema(self) -> None:
        """Create tables and indexes if missing; safe to run on every startup."""
        with self._tx() as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS profiles (
                    user_id TEXT PRIMARY KEY,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT,
                    namespace TEXT NOT NULL,
                    memory_type TEXT NOT NULL,
                    visibility TEXT NOT NULL,
                    importance REAL NOT NULL,
                    confidence REAL NOT NULL,
                    expires_at TEXT,
                    archived_at TEXT,
                    payload TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id);
                CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);
                CREATE INDEX IF NOT EXISTS idx_memories_workspace ON memories(workspace_id);
                CREATE TABLE IF NOT EXISTS episodes (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    agent_id TEXT,
                    workspace_id TEXT,
                    session_id TEXT NOT NULL,
                    namespace TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_episodes_session ON episodes(session_id);
                CREATE TABLE IF NOT EXISTS audit_logs (
                    id TEXT PRIMARY KEY,
                    actor_user_id TEXT,
                    actor_agent_id TEXT,
                    action TEXT NOT NULL,
                    target_type TEXT NOT NULL,
                    target_id TEXT,
                    namespace TEXT,
                    decision TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_logs(created_at);
                """
            )

    def create_user(self, user: UserRecord) -> UserRecord:
        """Insert or replace a user and seed an empty profile only when absent.

        Fix: re-creating an existing user previously overwrote their profile
        unconditionally; it is now preserved, mirroring the in-memory
        backend's ``setdefault`` seeding.
        """
        now = datetime.now(timezone.utc)
        user.created_at = user.created_at or now
        user.updated_at = now
        with self._tx() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO users(id, payload, updated_at) VALUES (?, ?, ?)",
                (user.id, _json_dump_model(user), user.updated_at.isoformat()),
            )
        if self.get_profile(user.id) is None:
            self._write_profile(
                ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile")
            )
        return user

    def get_user(self, user_id: str) -> Optional[UserRecord]:
        """Load a user by id, or None when no row exists."""
        with self._tx() as conn:
            row = conn.execute("SELECT payload FROM users WHERE id = ?", (user_id,)).fetchone()
            return _json_load_model(UserRecord, row["payload"]) if row else None

    def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
        """Insert or replace a memory, bumping version and keeping created_at."""
        existing = self.get_memory(memory.id)
        now = datetime.now(timezone.utc)
        if existing:
            memory.version = existing.version + 1
            memory.created_at = existing.created_at
        memory.updated_at = now
        with self._tx() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO memories(
                    id, user_id, agent_id, workspace_id, session_id, namespace,
                    memory_type, visibility, importance, confidence, expires_at,
                    archived_at, payload, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory.id,
                    memory.user_id,
                    memory.agent_id,
                    memory.workspace_id,
                    memory.session_id,
                    memory.namespace,
                    memory.memory_type.value,
                    memory.visibility.value,
                    memory.importance,
                    memory.confidence,
                    memory.expires_at.isoformat() if memory.expires_at else None,
                    memory.archived_at.isoformat() if memory.archived_at else None,
                    _json_dump_model(memory),
                    memory.updated_at.isoformat(),
                ),
            )
        return memory

    def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
        """Load a memory by id, or None when no row exists."""
        with self._tx() as conn:
            row = conn.execute("SELECT payload FROM memories WHERE id = ?", (memory_id,)).fetchone()
            return _json_load_model(MemoryRecord, row["payload"]) if row else None

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory; True when a row was actually removed."""
        with self._tx() as conn:
            cursor = conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
            return cursor.rowcount > 0

    def list_memories(self) -> Iterable[MemoryRecord]:
        """All stored memories, deserialized from their JSON payloads."""
        with self._tx() as conn:
            rows = conn.execute("SELECT payload FROM memories").fetchall()
            return [_json_load_model(MemoryRecord, row["payload"]) for row in rows]

    def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
        """Persist an episode (idempotent on id via INSERT OR REPLACE)."""
        with self._tx() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO episodes(
                    id, user_id, agent_id, workspace_id, session_id, namespace, payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    episode.id,
                    episode.user_id,
                    episode.agent_id,
                    episode.workspace_id,
                    episode.session_id,
                    episode.namespace,
                    _json_dump_model(episode),
                    episode.created_at.isoformat(),
                ),
            )
        return episode

    def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
        """Episodes for a session, oldest first."""
        with self._tx() as conn:
            rows = conn.execute(
                "SELECT payload FROM episodes WHERE session_id = ? ORDER BY created_at ASC",
                (session_id,),
            ).fetchall()
            return [_json_load_model(EpisodeRecord, row["payload"]) for row in rows]

    def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
        """Load a profile by user id, or None when no row exists."""
        with self._tx() as conn:
            row = conn.execute("SELECT payload FROM profiles WHERE user_id = ?", (user_id,)).fetchone()
            return _json_load_model(ProfileRecord, row["payload"]) if row else None

    def _write_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Persist a profile row without bumping its version."""
        # NOTE(review): assumes updated_at may be unset on a freshly built
        # record — fall back to now so isoformat() cannot fail; confirm the
        # ProfileRecord default against the schemas module.
        profile.updated_at = profile.updated_at or datetime.now(timezone.utc)
        with self._tx() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO profiles(user_id, payload, updated_at) VALUES (?, ?, ?)",
                (profile.user_id, _json_dump_model(profile), profile.updated_at.isoformat()),
            )
        return profile

    def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
        """Write a profile, bumping version and updated_at.

        Fix: the version bump was previously missing here, diverging from
        the in-memory backend's ``upsert_profile``.
        """
        profile.updated_at = datetime.now(timezone.utc)
        profile.version += 1
        return self._write_profile(profile)

    def add_audit(self, audit: AuditLog) -> AuditLog:
        """Persist an audit entry."""
        with self._tx() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO audit_logs(
                    id, actor_user_id, actor_agent_id, action, target_type, target_id,
                    namespace, decision, payload, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    audit.id,
                    audit.actor_user_id,
                    audit.actor_agent_id,
                    audit.action,
                    audit.target_type,
                    audit.target_id,
                    audit.namespace,
                    audit.decision,
                    _json_dump_model(audit),
                    audit.created_at.isoformat(),
                ),
            )
        return audit

    def list_audit(self, limit: int = 100) -> list[AuditLog]:
        """The most recent *limit* audit entries, newest first."""
        with self._tx() as conn:
            rows = conn.execute(
                "SELECT payload FROM audit_logs ORDER BY created_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return [_json_load_model(AuditLog, row["payload"]) for row in rows]
|
|
|
|
|
|
def build_repository() -> MetadataRepository:
    """Instantiate the metadata backend selected by runtime configuration."""
    storage = get_config().storage
    if storage.backend == "memory":
        return InMemoryRepository()
    return SQLiteRepository(storage.sqlite_path)
|
|
|
|
|
|
# Module-level singleton constructed once at import time; the rest of the
# gateway imports this instance rather than building its own repository.
repository = build_repository()
|
|
|