Add generic memory gateway v1
This commit is contained in:
328
memory_gateway/repositories.py
Normal file
328
memory_gateway/repositories.py
Normal file
@ -0,0 +1,328 @@
|
||||
"""Metadata repositories for Memory Gateway.
|
||||
|
||||
SQLite is the default POC store. The in-memory implementation is retained for
|
||||
small isolated tests and for cases where persistence is explicitly disabled.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Protocol
|
||||
|
||||
from .config import get_config
|
||||
from .schemas import AuditLog, EpisodeRecord, MemoryRecord, ProfileRecord, UserRecord
|
||||
|
||||
|
||||
class MetadataRepository(Protocol):
|
||||
def create_user(self, user: UserRecord) -> UserRecord: ...
|
||||
def get_user(self, user_id: str) -> Optional[UserRecord]: ...
|
||||
def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord: ...
|
||||
def get_memory(self, memory_id: str) -> Optional[MemoryRecord]: ...
|
||||
def delete_memory(self, memory_id: str) -> bool: ...
|
||||
def list_memories(self) -> Iterable[MemoryRecord]: ...
|
||||
def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord: ...
|
||||
def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]: ...
|
||||
def get_profile(self, user_id: str) -> Optional[ProfileRecord]: ...
|
||||
def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord: ...
|
||||
def add_audit(self, audit: AuditLog) -> AuditLog: ...
|
||||
def list_audit(self, limit: int = 100) -> list[AuditLog]: ...
|
||||
|
||||
|
||||
def _json_dump_model(model) -> str:
|
||||
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False)
|
||||
|
||||
|
||||
def _json_load_model(model_cls, payload: str):
|
||||
return model_cls.model_validate(json.loads(payload))
|
||||
|
||||
|
||||
class InMemoryRepository:
|
||||
def __init__(self) -> None:
|
||||
self.users: dict[str, UserRecord] = {}
|
||||
self.memories: dict[str, MemoryRecord] = {}
|
||||
self.episodes: dict[str, EpisodeRecord] = {}
|
||||
self.profiles: dict[str, ProfileRecord] = {}
|
||||
self.audit_logs: list[AuditLog] = []
|
||||
|
||||
def create_user(self, user: UserRecord) -> UserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
user.created_at = now
|
||||
user.updated_at = now
|
||||
self.users[user.id] = user
|
||||
self.profiles.setdefault(
|
||||
user.id,
|
||||
ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile"),
|
||||
)
|
||||
return user
|
||||
|
||||
def get_user(self, user_id: str) -> Optional[UserRecord]:
|
||||
return self.users.get(user_id)
|
||||
|
||||
def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
existing = self.memories.get(memory.id)
|
||||
if existing:
|
||||
memory.version = existing.version + 1
|
||||
memory.created_at = existing.created_at
|
||||
memory.updated_at = now
|
||||
self.memories[memory.id] = memory
|
||||
return memory
|
||||
|
||||
def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
|
||||
return self.memories.get(memory_id)
|
||||
|
||||
def delete_memory(self, memory_id: str) -> bool:
|
||||
return self.memories.pop(memory_id, None) is not None
|
||||
|
||||
def list_memories(self) -> Iterable[MemoryRecord]:
|
||||
return list(self.memories.values())
|
||||
|
||||
def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
|
||||
self.episodes[episode.id] = episode
|
||||
return episode
|
||||
|
||||
def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
|
||||
return [episode for episode in self.episodes.values() if episode.session_id == session_id]
|
||||
|
||||
def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
|
||||
return self.profiles.get(user_id)
|
||||
|
||||
def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
|
||||
profile.updated_at = datetime.now(timezone.utc)
|
||||
profile.version += 1
|
||||
self.profiles[profile.user_id] = profile
|
||||
return profile
|
||||
|
||||
def add_audit(self, audit: AuditLog) -> AuditLog:
|
||||
self.audit_logs.append(audit)
|
||||
return audit
|
||||
|
||||
def list_audit(self, limit: int = 100) -> list[AuditLog]:
|
||||
return self.audit_logs[-limit:]
|
||||
|
||||
|
||||
class SQLiteRepository:
|
||||
def __init__(self, db_path: str | Path) -> None:
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_schema()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS profiles (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
workspace_id TEXT,
|
||||
session_id TEXT,
|
||||
namespace TEXT NOT NULL,
|
||||
memory_type TEXT NOT NULL,
|
||||
visibility TEXT NOT NULL,
|
||||
importance REAL NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
expires_at TEXT,
|
||||
archived_at TEXT,
|
||||
payload TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_workspace ON memories(workspace_id);
|
||||
CREATE TABLE IF NOT EXISTS episodes (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
workspace_id TEXT,
|
||||
session_id TEXT NOT NULL,
|
||||
namespace TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_episodes_session ON episodes(session_id);
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
actor_user_id TEXT,
|
||||
actor_agent_id TEXT,
|
||||
action TEXT NOT NULL,
|
||||
target_type TEXT NOT NULL,
|
||||
target_id TEXT,
|
||||
namespace TEXT,
|
||||
decision TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_logs(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
def create_user(self, user: UserRecord) -> UserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
user.created_at = user.created_at or now
|
||||
user.updated_at = now
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO users(id, payload, updated_at) VALUES (?, ?, ?)",
|
||||
(user.id, _json_dump_model(user), user.updated_at.isoformat()),
|
||||
)
|
||||
self.upsert_profile(ProfileRecord(user_id=user.id, namespace=user.profile_namespace or f"user/{user.id}/profile"))
|
||||
return user
|
||||
|
||||
def get_user(self, user_id: str) -> Optional[UserRecord]:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT payload FROM users WHERE id = ?", (user_id,)).fetchone()
|
||||
return _json_load_model(UserRecord, row["payload"]) if row else None
|
||||
|
||||
def upsert_memory(self, memory: MemoryRecord) -> MemoryRecord:
|
||||
existing = self.get_memory(memory.id)
|
||||
now = datetime.now(timezone.utc)
|
||||
if existing:
|
||||
memory.version = existing.version + 1
|
||||
memory.created_at = existing.created_at
|
||||
memory.updated_at = now
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO memories(
|
||||
id, user_id, agent_id, workspace_id, session_id, namespace,
|
||||
memory_type, visibility, importance, confidence, expires_at,
|
||||
archived_at, payload, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
memory.id,
|
||||
memory.user_id,
|
||||
memory.agent_id,
|
||||
memory.workspace_id,
|
||||
memory.session_id,
|
||||
memory.namespace,
|
||||
memory.memory_type.value,
|
||||
memory.visibility.value,
|
||||
memory.importance,
|
||||
memory.confidence,
|
||||
memory.expires_at.isoformat() if memory.expires_at else None,
|
||||
memory.archived_at.isoformat() if memory.archived_at else None,
|
||||
_json_dump_model(memory),
|
||||
memory.updated_at.isoformat(),
|
||||
),
|
||||
)
|
||||
return memory
|
||||
|
||||
def get_memory(self, memory_id: str) -> Optional[MemoryRecord]:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT payload FROM memories WHERE id = ?", (memory_id,)).fetchone()
|
||||
return _json_load_model(MemoryRecord, row["payload"]) if row else None
|
||||
|
||||
def delete_memory(self, memory_id: str) -> bool:
|
||||
with self._connect() as conn:
|
||||
cursor = conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def list_memories(self) -> Iterable[MemoryRecord]:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute("SELECT payload FROM memories").fetchall()
|
||||
return [_json_load_model(MemoryRecord, row["payload"]) for row in rows]
|
||||
|
||||
def append_episode(self, episode: EpisodeRecord) -> EpisodeRecord:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO episodes(
|
||||
id, user_id, agent_id, workspace_id, session_id, namespace, payload, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
episode.id,
|
||||
episode.user_id,
|
||||
episode.agent_id,
|
||||
episode.workspace_id,
|
||||
episode.session_id,
|
||||
episode.namespace,
|
||||
_json_dump_model(episode),
|
||||
episode.created_at.isoformat(),
|
||||
),
|
||||
)
|
||||
return episode
|
||||
|
||||
def list_session_episodes(self, session_id: str) -> list[EpisodeRecord]:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT payload FROM episodes WHERE session_id = ? ORDER BY created_at ASC",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
return [_json_load_model(EpisodeRecord, row["payload"]) for row in rows]
|
||||
|
||||
def get_profile(self, user_id: str) -> Optional[ProfileRecord]:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT payload FROM profiles WHERE user_id = ?", (user_id,)).fetchone()
|
||||
return _json_load_model(ProfileRecord, row["payload"]) if row else None
|
||||
|
||||
def upsert_profile(self, profile: ProfileRecord) -> ProfileRecord:
|
||||
profile.updated_at = datetime.now(timezone.utc)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO profiles(user_id, payload, updated_at) VALUES (?, ?, ?)",
|
||||
(profile.user_id, _json_dump_model(profile), profile.updated_at.isoformat()),
|
||||
)
|
||||
return profile
|
||||
|
||||
def add_audit(self, audit: AuditLog) -> AuditLog:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO audit_logs(
|
||||
id, actor_user_id, actor_agent_id, action, target_type, target_id,
|
||||
namespace, decision, payload, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
audit.id,
|
||||
audit.actor_user_id,
|
||||
audit.actor_agent_id,
|
||||
audit.action,
|
||||
audit.target_type,
|
||||
audit.target_id,
|
||||
audit.namespace,
|
||||
audit.decision,
|
||||
_json_dump_model(audit),
|
||||
audit.created_at.isoformat(),
|
||||
),
|
||||
)
|
||||
return audit
|
||||
|
||||
def list_audit(self, limit: int = 100) -> list[AuditLog]:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT payload FROM audit_logs ORDER BY created_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [_json_load_model(AuditLog, row["payload"]) for row in rows]
|
||||
|
||||
|
||||
def build_repository() -> MetadataRepository:
|
||||
config = get_config()
|
||||
if config.storage.backend == "memory":
|
||||
return InMemoryRepository()
|
||||
return SQLiteRepository(config.storage.sqlite_path)
|
||||
|
||||
|
||||
repository = build_repository()
|
||||
|
||||
Reference in New Issue
Block a user