from __future__ import annotations import uuid from datetime import datetime, timezone from pathlib import Path from typing import Any from .db import connect def utc_now() -> str: return datetime.now(timezone.utc).isoformat() def _row_to_dict(row: Any | None) -> dict[str, Any] | None: if row is None: return None return dict(row) class MemoryRepository: def __init__(self, db_path: Path) -> None: self.db_path = db_path def create_user(self, user_id: str, user_key: str) -> dict[str, Any]: existing = self.get_user(user_id) if existing is not None: return existing now = utc_now() with connect(self.db_path) as conn: conn.execute( """ INSERT INTO users (id, user_key, created_at, updated_at) VALUES (?, ?, ?, ?) """, (user_id, user_key, now, now), ) conn.commit() user = self.get_user(user_id) if user is None: raise RuntimeError("created user could not be read back") return user def get_user(self, user_id: str) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( "SELECT * FROM users WHERE id = ?", (user_id,), ).fetchone() return _row_to_dict(row) def create_resource(self, **values: Any) -> dict[str, Any]: now = utc_now() payload = { "created_at": now, "updated_at": now, "deleted_at": None, **values, } columns = ", ".join(payload) placeholders = ", ".join(f":{key}" for key in payload) with connect(self.db_path) as conn: conn.execute( f"INSERT INTO user_resources ({columns}) VALUES ({placeholders})", payload, ) conn.commit() resource = self.get_resource(str(payload["id"])) if resource is None: raise RuntimeError("created resource could not be read back") return resource def update_resource_status( self, resource_id: str, status: str, error_message: str | None = None, ) -> dict[str, Any] | None: with connect(self.db_path) as conn: conn.execute( """ UPDATE user_resources SET status = ?, error_message = ?, updated_at = ? WHERE id = ? """, (status, error_message, utc_now(), resource_id), ) conn.commit() return self.get_resource(resource_id) def soft_delete_resource( self, resource_id: str, user_id: str | None = None, ) -> dict[str, Any] | None: now = utc_now() where = "id = ? AND deleted_at IS NULL" params: tuple[Any, ...] = (now, now, resource_id) attachment_where = "resource_id = ? AND deleted_at IS NULL" attachment_params: tuple[Any, ...] = (now, resource_id) if user_id is not None: where += " AND user_id = ?" params = (now, now, resource_id, user_id) attachment_where += " AND user_id = ?" attachment_params = (now, resource_id, user_id) with connect(self.db_path) as conn: conn.execute( f""" UPDATE user_resources SET status = 'deleted', deleted_at = ?, updated_at = ? WHERE {where} """, params, ) conn.execute( f""" UPDATE memory_attachments SET deleted_at = ? WHERE {attachment_where} """, attachment_params, ) conn.commit() return self.get_resource(resource_id) def get_resource(self, resource_id: str) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( "SELECT * FROM user_resources WHERE id = ?", (resource_id,), ).fetchone() return _row_to_dict(row) def get_resource_for_user( self, resource_id: str, user_id: str, ) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( """ SELECT * FROM user_resources WHERE id = ? AND user_id = ? AND deleted_at IS NULL """, (resource_id, user_id), ).fetchone() return _row_to_dict(row) def get_resource_by_session(self, session_id: str) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( "SELECT * FROM user_resources WHERE session_id = ?", (session_id,), ).fetchone() return _row_to_dict(row) def get_resource_by_session_for_user( self, session_id: str, user_id: str, ) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( """ SELECT * FROM user_resources WHERE session_id = ? AND user_id = ? AND deleted_at IS NULL """, (session_id, user_id), ).fetchone() return _row_to_dict(row) def find_active_resource_by_sha256( self, *, user_id: str, app_id: str, project_id: str, sha256: str, ) -> dict[str, Any] | None: with connect(self.db_path) as conn: row = conn.execute( """ SELECT * FROM user_resources WHERE user_id = ? AND app_id = ? AND project_id = ? AND sha256 = ? AND deleted_at IS NULL AND status IN ('ingesting', 'extracted') ORDER BY created_at ASC LIMIT 1 """, (user_id, app_id, project_id, sha256), ).fetchone() return _row_to_dict(row) def list_resources(self, user_id: str) -> list[dict[str, Any]]: with connect(self.db_path) as conn: rows = conn.execute( """ SELECT * FROM user_resources WHERE user_id = ? AND deleted_at IS NULL ORDER BY created_at DESC """, (user_id,), ).fetchall() return [dict(row) for row in rows] def list_extracted_resources( self, user_id: str, app_id: str, project_id: str, ) -> list[dict[str, Any]]: with connect(self.db_path) as conn: rows = conn.execute( """ SELECT * FROM user_resources WHERE user_id = ? AND app_id = ? AND project_id = ? AND deleted_at IS NULL AND status = 'extracted' ORDER BY created_at DESC """, (user_id, app_id, project_id), ).fetchall() return [dict(row) for row in rows] def create_attachment(self, **values: Any) -> dict[str, Any]: attachment_id = str(values.get("id") or f"a_{uuid.uuid4().hex}") payload = { "id": attachment_id, "created_at": utc_now(), "deleted_at": None, **values, } with connect(self.db_path) as conn: conn.execute( """ INSERT OR IGNORE INTO memory_attachments ( id, user_id, app_id, project_id, session_id, resource_id, content_type, name, internal_uri, source, sha256, created_at, deleted_at ) VALUES ( :id, :user_id, :app_id, :project_id, :session_id, :resource_id, :content_type, :name, :internal_uri, :source, :sha256, :created_at, :deleted_at ) """, payload, ) row = conn.execute( """ SELECT * FROM memory_attachments WHERE user_id = ? AND session_id = ? AND internal_uri = ? """, ( payload["user_id"], payload["session_id"], payload["internal_uri"], ), ).fetchone() conn.commit() attachment = _row_to_dict(row) if attachment is None: raise RuntimeError("created attachment could not be read back") return attachment def list_attachments_for_session( self, user_id: str, session_id: str, ) -> list[dict[str, Any]]: with connect(self.db_path) as conn: rows = conn.execute( """ SELECT * FROM memory_attachments WHERE user_id = ? AND session_id = ? AND deleted_at IS NULL ORDER BY created_at ASC, id ASC """, (user_id, session_id), ).fetchall() return [dict(row) for row in rows] def add_tombstone( self, user_id: str, memory_id: str | None, session_id: str | None, reason: str | None, ) -> dict[str, Any]: tombstone_id = f"t_{uuid.uuid4().hex}" with connect(self.db_path) as conn: conn.execute( """ INSERT INTO memory_tombstones (id, user_id, memory_id, session_id, reason, created_at) VALUES (?, ?, ?, ?, ?, ?) """, (tombstone_id, user_id, memory_id, session_id, reason, utc_now()), ) conn.commit() return {"id": tombstone_id} def get_tombstones(self, user_id: str) -> list[dict[str, Any]]: with connect(self.db_path) as conn: rows = conn.execute( "SELECT * FROM memory_tombstones WHERE user_id = ?", (user_id,), ).fetchall() return [dict(row) for row in rows] def upsert_override( self, user_id: str, memory_id: str, session_id: str | None, override_text: str, ) -> dict[str, Any]: now = utc_now() with connect(self.db_path) as conn: row = conn.execute( """ SELECT id FROM memory_overrides WHERE user_id = ? AND memory_id = ? AND is_active = TRUE ORDER BY created_at DESC LIMIT 1 """, (user_id, memory_id), ).fetchone() if row: override_id = row["id"] conn.execute( """ UPDATE memory_overrides SET session_id = ?, override_text = ?, updated_at = ? WHERE id = ? """, (session_id, override_text, now, override_id), ) else: override_id = f"o_{uuid.uuid4().hex}" conn.execute( """ INSERT INTO memory_overrides ( id, user_id, memory_id, session_id, override_text, is_active, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, TRUE, ?, ?) """, ( override_id, user_id, memory_id, session_id, override_text, now, now, ), ) conn.commit() return {"id": override_id} def get_active_overrides(self, user_id: str) -> list[dict[str, Any]]: with connect(self.db_path) as conn: rows = conn.execute( """ SELECT * FROM memory_overrides WHERE user_id = ? AND is_active = TRUE """, (user_id,), ).fetchall() return [dict(row) for row in rows]