replace main with lightweight memory gateway
This commit is contained in:
282
core/repository.py
Normal file
282
core/repository.py
Normal file
@ -0,0 +1,282 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .db import connect
|
||||
|
||||
|
||||
def utc_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _row_to_dict(row: Any | None) -> dict[str, Any] | None:
|
||||
if row is None:
|
||||
return None
|
||||
return dict(row)
|
||||
|
||||
|
||||
class MemoryRepository:
|
||||
def __init__(self, db_path: Path) -> None:
|
||||
self.db_path = db_path
|
||||
|
||||
def create_user(self, user_id: str, user_key: str) -> dict[str, Any]:
|
||||
existing = self.get_user(user_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
now = utc_now()
|
||||
with connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, user_key, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, user_key, now, now),
|
||||
)
|
||||
conn.commit()
|
||||
user = self.get_user(user_id)
|
||||
if user is None:
|
||||
raise RuntimeError("created user could not be read back")
|
||||
return user
|
||||
|
||||
def get_user(self, user_id: str) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM users WHERE id = ?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def create_resource(self, **values: Any) -> dict[str, Any]:
|
||||
now = utc_now()
|
||||
payload = {
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
**values,
|
||||
}
|
||||
columns = ", ".join(payload)
|
||||
placeholders = ", ".join(f":{key}" for key in payload)
|
||||
with connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
f"INSERT INTO user_resources ({columns}) VALUES ({placeholders})",
|
||||
payload,
|
||||
)
|
||||
conn.commit()
|
||||
resource = self.get_resource(str(payload["id"]))
|
||||
if resource is None:
|
||||
raise RuntimeError("created resource could not be read back")
|
||||
return resource
|
||||
|
||||
def update_resource_status(
|
||||
self,
|
||||
resource_id: str,
|
||||
status: str,
|
||||
error_message: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE user_resources
|
||||
SET status = ?, error_message = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(status, error_message, utc_now(), resource_id),
|
||||
)
|
||||
conn.commit()
|
||||
return self.get_resource(resource_id)
|
||||
|
||||
def soft_delete_resource(
|
||||
self,
|
||||
resource_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
now = utc_now()
|
||||
where = "id = ? AND deleted_at IS NULL"
|
||||
params: tuple[Any, ...] = (now, now, resource_id)
|
||||
if user_id is not None:
|
||||
where += " AND user_id = ?"
|
||||
params = (now, now, resource_id, user_id)
|
||||
with connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
f"""
|
||||
UPDATE user_resources
|
||||
SET status = 'deleted', deleted_at = ?, updated_at = ?
|
||||
WHERE {where}
|
||||
""",
|
||||
params,
|
||||
)
|
||||
conn.commit()
|
||||
return self.get_resource(resource_id)
|
||||
|
||||
def get_resource(self, resource_id: str) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM user_resources WHERE id = ?",
|
||||
(resource_id,),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def get_resource_for_user(
|
||||
self,
|
||||
resource_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM user_resources
|
||||
WHERE id = ? AND user_id = ? AND deleted_at IS NULL
|
||||
""",
|
||||
(resource_id, user_id),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def get_resource_by_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM user_resources WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def get_resource_by_session_for_user(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM user_resources
|
||||
WHERE session_id = ? AND user_id = ? AND deleted_at IS NULL
|
||||
""",
|
||||
(session_id, user_id),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def list_resources(self, user_id: str) -> list[dict[str, Any]]:
|
||||
with connect(self.db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM user_resources
|
||||
WHERE user_id = ? AND deleted_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def list_extracted_resources(
|
||||
self,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
with connect(self.db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM user_resources
|
||||
WHERE user_id = ?
|
||||
AND app_id = ?
|
||||
AND project_id = ?
|
||||
AND deleted_at IS NULL
|
||||
AND status = 'extracted'
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
(user_id, app_id, project_id),
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def add_tombstone(
|
||||
self,
|
||||
user_id: str,
|
||||
memory_id: str | None,
|
||||
session_id: str | None,
|
||||
reason: str | None,
|
||||
) -> dict[str, Any]:
|
||||
tombstone_id = f"t_{uuid.uuid4().hex}"
|
||||
with connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_tombstones
|
||||
(id, user_id, memory_id, session_id, reason, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(tombstone_id, user_id, memory_id, session_id, reason, utc_now()),
|
||||
)
|
||||
conn.commit()
|
||||
return {"id": tombstone_id}
|
||||
|
||||
def get_tombstones(self, user_id: str) -> list[dict[str, Any]]:
|
||||
with connect(self.db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM memory_tombstones WHERE user_id = ?",
|
||||
(user_id,),
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def upsert_override(
|
||||
self,
|
||||
user_id: str,
|
||||
memory_id: str,
|
||||
session_id: str | None,
|
||||
override_text: str,
|
||||
) -> dict[str, Any]:
|
||||
now = utc_now()
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT id FROM memory_overrides
|
||||
WHERE user_id = ? AND memory_id = ? AND is_active = TRUE
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id, memory_id),
|
||||
).fetchone()
|
||||
if row:
|
||||
override_id = row["id"]
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE memory_overrides
|
||||
SET session_id = ?, override_text = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(session_id, override_text, now, override_id),
|
||||
)
|
||||
else:
|
||||
override_id = f"o_{uuid.uuid4().hex}"
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_overrides
|
||||
(
|
||||
id, user_id, memory_id, session_id, override_text,
|
||||
is_active, created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, TRUE, ?, ?)
|
||||
""",
|
||||
(
|
||||
override_id,
|
||||
user_id,
|
||||
memory_id,
|
||||
session_id,
|
||||
override_text,
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return {"id": override_id}
|
||||
|
||||
def get_active_overrides(self, user_id: str) -> list[dict[str, Any]]:
|
||||
with connect(self.db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM memory_overrides
|
||||
WHERE user_id = ? AND is_active = TRUE
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
Reference in New Issue
Block a user