Files
memory-gateway/core/repository.py

308 lines
9.7 KiB
Python

from __future__ import annotations
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from .db import connect
def utc_now() -> str:
    """Return the current time in UTC as an ISO-8601 string (``+00:00`` offset)."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
def _row_to_dict(row: Any | None) -> dict[str, Any] | None:
if row is None:
return None
return dict(row)
class MemoryRepository:
    """Data-access layer for users, uploaded resources, tombstones and overrides.

    Works against the ``users``, ``user_resources``, ``memory_tombstones`` and
    ``memory_overrides`` tables. Every public method opens its own connection
    via ``connect(self.db_path)`` and commits any write before returning; no
    connection is kept on the instance. Rows are returned as plain ``dict``s.
    """

    def __init__(self, db_path: Path) -> None:
        # Database location handed to ``connect`` on every call.
        self.db_path = db_path

    # ------------------------------------------------------------- helpers

    def _fetch_one(self, sql: str, params: Any = ()) -> dict[str, Any] | None:
        """Run a query and return its first row as a dict, or ``None``."""
        with connect(self.db_path) as conn:
            row = conn.execute(sql, params).fetchone()
            return _row_to_dict(row)

    def _fetch_all(self, sql: str, params: Any = ()) -> list[dict[str, Any]]:
        """Run a query and return every result row as a dict."""
        with connect(self.db_path) as conn:
            rows = conn.execute(sql, params).fetchall()
            return [dict(row) for row in rows]

    def _execute(self, sql: str, params: Any = ()) -> None:
        """Run a single write statement and commit it."""
        with connect(self.db_path) as conn:
            conn.execute(sql, params)
            conn.commit()

    # --------------------------------------------------------------- users

    def create_user(self, user_id: str, user_key: str) -> dict[str, Any]:
        """Create user ``user_id`` if it does not exist and return its row.

        If the user already exists its row is returned unchanged; the supplied
        ``user_key`` is not written in that case.

        Raises:
            RuntimeError: the freshly inserted row could not be read back.
        """
        existing = self.get_user(user_id)
        if existing is not None:
            return existing
        # NOTE(review): check-then-insert is not atomic. Two concurrent callers
        # with the same new id can both reach the INSERT; what the second one
        # sees depends on the users schema (presumably id is the primary key).
        now = utc_now()
        self._execute(
            """
            INSERT INTO users (id, user_key, created_at, updated_at)
            VALUES (?, ?, ?, ?)
            """,
            (user_id, user_key, now, now),
        )
        user = self.get_user(user_id)
        if user is None:
            raise RuntimeError("created user could not be read back")
        return user

    def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Return the ``users`` row with id ``user_id``, or ``None``."""
        return self._fetch_one("SELECT * FROM users WHERE id = ?", (user_id,))

    # ----------------------------------------------------------- resources

    def create_resource(self, **values: Any) -> dict[str, Any]:
        """Insert a ``user_resources`` row built from ``values`` and return it.

        ``created_at`` and ``updated_at`` default to now and ``deleted_at`` to
        ``None``; any of them may be overridden through ``values``. ``values``
        must include ``id``, which is used to read the inserted row back.

        Raises:
            ValueError: ``id`` is missing, or a key is not a plain identifier
                (keys are formatted into the statement as column names).
            RuntimeError: the inserted row could not be read back.
        """
        now = utc_now()
        payload = {
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **values,
        }
        # Validate before inserting; previously a missing id only surfaced as a
        # KeyError after the row had already been written.
        if "id" not in payload:
            raise ValueError("create_resource requires an 'id' value")
        # Column names cannot be bound as SQL parameters, so they are
        # interpolated; reject anything that is not a bare identifier.
        invalid = [key for key in payload if not key.isidentifier()]
        if invalid:
            raise ValueError(f"invalid column name(s) for user_resources: {invalid!r}")
        columns = ", ".join(payload)
        placeholders = ", ".join(f":{key}" for key in payload)
        self._execute(
            f"INSERT INTO user_resources ({columns}) VALUES ({placeholders})",
            payload,
        )
        resource = self.get_resource(str(payload["id"]))
        if resource is None:
            raise RuntimeError("created resource could not be read back")
        return resource

    def update_resource_status(
        self,
        resource_id: str,
        status: str,
        error_message: str | None = None,
    ) -> dict[str, Any] | None:
        """Set ``status`` and ``error_message`` on a resource and return its row.

        ``error_message`` is always written, so omitting it clears any earlier
        message. ``updated_at`` is bumped to now. Returns ``None`` if no
        resource has this id.
        """
        # NOTE(review): rows are matched by id only, so a soft-deleted resource
        # is updated too and its 'deleted' status can be overwritten (deleted_at
        # stays set). Confirm this is intended.
        self._execute(
            """
            UPDATE user_resources
            SET status = ?, error_message = ?, updated_at = ?
            WHERE id = ?
            """,
            (status, error_message, utc_now(), resource_id),
        )
        return self.get_resource(resource_id)

    def soft_delete_resource(
        self,
        resource_id: str,
        user_id: str | None = None,
    ) -> dict[str, Any] | None:
        """Mark a resource as deleted and return its row.

        Sets ``status='deleted'`` and stamps ``deleted_at``/``updated_at`` on a
        not-yet-deleted row; an already-deleted row is left untouched and
        returned as is. When ``user_id`` is given, only that user's resource is
        touched and returned.

        Returns:
            The resource row after the update, or ``None`` when it does not
            exist or (with ``user_id``) belongs to a different user.
        """
        now = utc_now()
        # ``where`` is assembled from constant fragments only; values are bound.
        where = "id = ? AND deleted_at IS NULL"
        params: tuple[Any, ...] = (now, now, resource_id)
        if user_id is not None:
            where += " AND user_id = ?"
            params += (user_id,)
        self._execute(
            f"""
            UPDATE user_resources
            SET status = 'deleted', deleted_at = ?, updated_at = ?
            WHERE {where}
            """,
            params,
        )
        resource = self.get_resource(resource_id)
        if user_id is not None and resource is not None and resource["user_id"] != user_id:
            # A user-scoped delete must not hand back another user's row.
            return None
        return resource

    def get_resource(self, resource_id: str) -> dict[str, Any] | None:
        """Return the resource with this id, including soft-deleted ones."""
        return self._fetch_one(
            "SELECT * FROM user_resources WHERE id = ?",
            (resource_id,),
        )

    def get_resource_for_user(
        self,
        resource_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return ``user_id``'s non-deleted resource with this id, or ``None``."""
        return self._fetch_one(
            """
            SELECT * FROM user_resources
            WHERE id = ? AND user_id = ? AND deleted_at IS NULL
            """,
            (resource_id, user_id),
        )

    def get_resource_by_session(self, session_id: str) -> dict[str, Any] | None:
        """Return a resource with this ``session_id``, including soft-deleted ones.

        The query has no ORDER BY, so if several rows share the session id,
        which one is returned is unspecified.
        """
        return self._fetch_one(
            "SELECT * FROM user_resources WHERE session_id = ?",
            (session_id,),
        )

    def get_resource_by_session_for_user(
        self,
        session_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return ``user_id``'s non-deleted resource for ``session_id``, or ``None``."""
        return self._fetch_one(
            """
            SELECT * FROM user_resources
            WHERE session_id = ? AND user_id = ? AND deleted_at IS NULL
            """,
            (session_id, user_id),
        )

    def find_active_resource_by_sha256(
        self,
        *,
        user_id: str,
        app_id: str,
        project_id: str,
        sha256: str,
    ) -> dict[str, Any] | None:
        """Return the oldest live resource with content hash ``sha256``.

        Only non-deleted rows of this user/app/project whose status is
        ``'ingesting'`` or ``'extracted'`` are considered. Presumably used to
        deduplicate uploads — verify against callers.
        """
        return self._fetch_one(
            """
            SELECT * FROM user_resources
            WHERE user_id = ?
              AND app_id = ?
              AND project_id = ?
              AND sha256 = ?
              AND deleted_at IS NULL
              AND status IN ('ingesting', 'extracted')
            ORDER BY created_at ASC
            LIMIT 1
            """,
            (user_id, app_id, project_id, sha256),
        )

    def list_resources(self, user_id: str) -> list[dict[str, Any]]:
        """Return all non-deleted resources of ``user_id``, newest first."""
        return self._fetch_all(
            """
            SELECT * FROM user_resources
            WHERE user_id = ? AND deleted_at IS NULL
            ORDER BY created_at DESC
            """,
            (user_id,),
        )

    def list_extracted_resources(
        self,
        user_id: str,
        app_id: str,
        project_id: str,
    ) -> list[dict[str, Any]]:
        """Return non-deleted ``'extracted'`` resources for one app/project, newest first."""
        return self._fetch_all(
            """
            SELECT * FROM user_resources
            WHERE user_id = ?
              AND app_id = ?
              AND project_id = ?
              AND deleted_at IS NULL
              AND status = 'extracted'
            ORDER BY created_at DESC
            """,
            (user_id, app_id, project_id),
        )

    # ---------------------------------------------------------- tombstones

    def add_tombstone(
        self,
        user_id: str,
        memory_id: str | None,
        session_id: str | None,
        reason: str | None,
    ) -> dict[str, Any]:
        """Insert a tombstone row for ``user_id`` and return ``{"id": <new id>}``.

        The id is ``"t_"`` followed by a random UUID4 hex string.
        """
        tombstone_id = f"t_{uuid.uuid4().hex}"
        self._execute(
            """
            INSERT INTO memory_tombstones
                (id, user_id, memory_id, session_id, reason, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (tombstone_id, user_id, memory_id, session_id, reason, utc_now()),
        )
        return {"id": tombstone_id}

    def get_tombstones(self, user_id: str) -> list[dict[str, Any]]:
        """Return every tombstone row of ``user_id`` (no particular order)."""
        return self._fetch_all(
            "SELECT * FROM memory_tombstones WHERE user_id = ?",
            (user_id,),
        )

    # ----------------------------------------------------------- overrides

    def upsert_override(
        self,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        override_text: str,
    ) -> dict[str, Any]:
        """Create or update the active override text for one memory.

        If an active override for ``(user_id, memory_id)`` exists, the newest
        one is updated in place; otherwise a new active override with id
        ``"o_"`` + UUID4 hex is inserted. Returns ``{"id": <override id>}``.
        """
        now = utc_now()
        # Lookup and write share one connection and one commit.
        # NOTE(review): the SELECT and the INSERT are not atomic; concurrent
        # calls could each insert an active override. The ORDER BY below
        # makes later updates target the newest one.
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT id FROM memory_overrides
                WHERE user_id = ? AND memory_id = ? AND is_active = TRUE
                ORDER BY created_at DESC
                LIMIT 1
                """,
                (user_id, memory_id),
            ).fetchone()
            if row:
                override_id = row["id"]
                conn.execute(
                    """
                    UPDATE memory_overrides
                    SET session_id = ?, override_text = ?, updated_at = ?
                    WHERE id = ?
                    """,
                    (session_id, override_text, now, override_id),
                )
            else:
                override_id = f"o_{uuid.uuid4().hex}"
                conn.execute(
                    """
                    INSERT INTO memory_overrides
                    (
                        id, user_id, memory_id, session_id, override_text,
                        is_active, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, TRUE, ?, ?)
                    """,
                    (
                        override_id,
                        user_id,
                        memory_id,
                        session_id,
                        override_text,
                        now,
                        now,
                    ),
                )
            conn.commit()
        return {"id": override_id}

    def get_active_overrides(self, user_id: str) -> list[dict[str, Any]]:
        """Return every active override row of ``user_id`` (no particular order)."""
        return self._fetch_all(
            """
            SELECT * FROM memory_overrides
            WHERE user_id = ? AND is_active = TRUE
            """,
            (user_id,),
        )