376 lines
12 KiB
Python
376 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from .db import connect
|
|
|
|
|
|
def utc_now() -> str:
    """Return the current time in UTC as an ISO-8601 string.

    The result is timezone-aware, so it always carries a ``+00:00`` offset,
    e.g. ``2024-01-01T12:00:00.123456+00:00``.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
def _row_to_dict(row: Any | None) -> dict[str, Any] | None:
    """Copy a database row into a plain ``dict``.

    Args:
        row: A mapping-like row (anything ``dict()`` accepts), or ``None``.

    Returns:
        A new ``dict`` with the row's contents, or ``None`` when ``row`` is
        ``None`` (for example when ``fetchone()`` matched nothing).
    """
    return dict(row) if row is not None else None
|
|
|
|
|
|
class MemoryRepository:
    """Data-access layer over the memory database.

    Wraps the ``users``, ``user_resources``, ``memory_attachments``,
    ``memory_tombstones`` and ``memory_overrides`` tables. Every method opens
    its own connection through ``connect(self.db_path)``. Rows are returned as
    plain ``dict`` objects (or lists of them), never as driver row objects.

    NOTE(review): the methods rely on ``connect`` to close the connection when
    the ``with`` block exits. A bare ``sqlite3.Connection`` used as a context
    manager only commits/rolls back and does not close; confirm that
    ``connect`` handles closing.
    """

    def __init__(self, db_path: Path) -> None:
        # Database location passed to ``connect`` by every method.
        self.db_path = db_path

    # ------------------------------------------------------------------ users

    def create_user(self, user_id: str, user_key: str) -> dict[str, Any]:
        """Create the user ``user_id`` unless it already exists.

        An existing user is returned unchanged; ``user_key`` is neither
        compared nor updated in that case.

        Returns:
            The stored ``users`` row as a dict.

        Raises:
            RuntimeError: if the freshly inserted row cannot be read back.
        """
        existing = self.get_user(user_id)
        if existing is not None:
            return existing
        # NOTE(review): lookup-then-insert is not atomic. A concurrent insert of
        # the same id between the two steps would surface as a driver
        # integrity error if ``users.id`` is unique.
        now = utc_now()
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO users (id, user_key, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                """,
                (user_id, user_key, now, now),
            )
            conn.commit()
        user = self.get_user(user_id)
        if user is None:
            raise RuntimeError("created user could not be read back")
        return user

    def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Return the ``users`` row whose ``id`` is ``user_id``, or ``None``."""
        with connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT * FROM users WHERE id = ?",
                (user_id,),
            ).fetchone()
        return _row_to_dict(row)

    # -------------------------------------------------------------- resources

    def create_resource(self, **values: Any) -> dict[str, Any]:
        """Insert a ``user_resources`` row built from ``values``.

        Keyword names are used directly as column names. ``created_at`` and
        ``updated_at`` default to the current UTC time and ``deleted_at`` to
        ``None``; ``values`` may override any of them. ``values`` must contain
        ``id``, which is used to read the new row back.

        Returns:
            The stored resource row as a dict.

        Raises:
            KeyError: if ``id`` is missing. This is raised before anything is
                written.
            ValueError: if a keyword is not a plain identifier. Column names
                are interpolated into the SQL text and cannot be bound.
            RuntimeError: if the inserted row cannot be read back.
        """
        now = utc_now()
        payload = {
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **values,
        }
        # Resolve the id first, so a missing id fails before a row is committed.
        resource_id = str(payload["id"])
        invalid = [key for key in payload if not key.isidentifier()]
        if invalid:
            raise ValueError(
                f"invalid column name(s) for user_resources: {invalid!r}"
            )
        columns = ", ".join(payload)
        placeholders = ", ".join(f":{key}" for key in payload)
        with connect(self.db_path) as conn:
            conn.execute(
                f"INSERT INTO user_resources ({columns}) VALUES ({placeholders})",
                payload,
            )
            conn.commit()
        resource = self.get_resource(resource_id)
        if resource is None:
            raise RuntimeError("created resource could not be read back")
        return resource

    def update_resource_status(
        self,
        resource_id: str,
        status: str,
        error_message: str | None = None,
    ) -> dict[str, Any] | None:
        """Set ``status`` and ``error_message`` on a resource that is not soft-deleted.

        ``error_message`` is always written, so omitting it clears any
        previous message. Soft-deleted rows are skipped. This prevents a late
        status update from overwriting ``status = 'deleted'`` while
        ``deleted_at`` stays set.

        Returns:
            The resource row as it stands afterwards (soft-deleted rows
            included), or ``None`` if no row has that id.
        """
        with connect(self.db_path) as conn:
            conn.execute(
                """
                UPDATE user_resources
                SET status = ?, error_message = ?, updated_at = ?
                WHERE id = ? AND deleted_at IS NULL
                """,
                (status, error_message, utc_now(), resource_id),
            )
            conn.commit()
        return self.get_resource(resource_id)

    def soft_delete_resource(
        self,
        resource_id: str,
        user_id: str | None = None,
    ) -> dict[str, Any] | None:
        """Soft-delete a resource and its attachments.

        On the resource, the method sets ``status = 'deleted'`` and stamps
        ``deleted_at`` and ``updated_at``. On the resource's
        ``memory_attachments`` rows that are not yet deleted, it stamps
        ``deleted_at``. Both updates are committed by a single ``commit()``.
        A resource that is already deleted is left untouched. When
        ``user_id`` is given, both updates are restricted to that user's rows.

        Returns:
            The resource row after the update (also when it was already
            deleted). Returns ``None`` if the row does not exist, or if
            ``user_id`` is given and the row belongs to a different user.
        """
        now = utc_now()
        resource_clauses = ["id = ?", "deleted_at IS NULL"]
        resource_params: list[Any] = [now, now, resource_id]
        attachment_clauses = ["resource_id = ?", "deleted_at IS NULL"]
        attachment_params: list[Any] = [now, resource_id]
        if user_id is not None:
            resource_clauses.append("user_id = ?")
            resource_params.append(user_id)
            attachment_clauses.append("user_id = ?")
            attachment_params.append(user_id)
        # Only the fixed clause strings above reach the SQL text; all values are bound.
        resource_where = " AND ".join(resource_clauses)
        attachment_where = " AND ".join(attachment_clauses)
        with connect(self.db_path) as conn:
            conn.execute(
                f"""
                UPDATE user_resources
                SET status = 'deleted', deleted_at = ?, updated_at = ?
                WHERE {resource_where}
                """,
                tuple(resource_params),
            )
            conn.execute(
                f"""
                UPDATE memory_attachments
                SET deleted_at = ?
                WHERE {attachment_where}
                """,
                tuple(attachment_params),
            )
            conn.commit()
        resource = self.get_resource(resource_id)
        if (
            resource is not None
            and user_id is not None
            and resource["user_id"] != user_id
        ):
            # get_resource is not user-scoped, so never hand another user's row back.
            return None
        return resource

    def get_resource(self, resource_id: str) -> dict[str, Any] | None:
        """Return the resource with this id, soft-deleted rows included, or ``None``."""
        with connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT * FROM user_resources WHERE id = ?",
                (resource_id,),
            ).fetchone()
        return _row_to_dict(row)

    def get_resource_for_user(
        self,
        resource_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return ``user_id``'s non-deleted resource with this id, or ``None``."""
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT * FROM user_resources
                WHERE id = ? AND user_id = ? AND deleted_at IS NULL
                """,
                (resource_id, user_id),
            ).fetchone()
        return _row_to_dict(row)

    def get_resource_by_session(self, session_id: str) -> dict[str, Any] | None:
        """Return a resource whose ``session_id`` matches, soft-deleted rows included.

        NOTE(review): the query has no ``ORDER BY``. If several rows share a
        ``session_id``, which one is returned is unspecified.
        """
        with connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT * FROM user_resources WHERE session_id = ?",
                (session_id,),
            ).fetchone()
        return _row_to_dict(row)

    def get_resource_by_session_for_user(
        self,
        session_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return ``user_id``'s non-deleted resource for ``session_id``, or ``None``."""
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT * FROM user_resources
                WHERE session_id = ? AND user_id = ? AND deleted_at IS NULL
                """,
                (session_id, user_id),
            ).fetchone()
        return _row_to_dict(row)

    def find_active_resource_by_sha256(
        self,
        *,
        user_id: str,
        app_id: str,
        project_id: str,
        sha256: str,
    ) -> dict[str, Any] | None:
        """Return the oldest active resource with this content hash in a project.

        "Active" means not soft-deleted and with status ``'ingesting'`` or
        ``'extracted'``. Returns ``None`` when there is no match.
        """
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT * FROM user_resources
                WHERE user_id = ?
                  AND app_id = ?
                  AND project_id = ?
                  AND sha256 = ?
                  AND deleted_at IS NULL
                  AND status IN ('ingesting', 'extracted')
                ORDER BY created_at ASC
                LIMIT 1
                """,
                (user_id, app_id, project_id, sha256),
            ).fetchone()
        return _row_to_dict(row)

    def list_resources(self, user_id: str) -> list[dict[str, Any]]:
        """Return all of ``user_id``'s non-deleted resources, newest first."""
        with connect(self.db_path) as conn:
            rows = conn.execute(
                """
                SELECT * FROM user_resources
                WHERE user_id = ? AND deleted_at IS NULL
                ORDER BY created_at DESC
                """,
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]

    def list_extracted_resources(
        self,
        user_id: str,
        app_id: str,
        project_id: str,
    ) -> list[dict[str, Any]]:
        """Return non-deleted ``'extracted'`` resources in one project, newest first."""
        with connect(self.db_path) as conn:
            rows = conn.execute(
                """
                SELECT * FROM user_resources
                WHERE user_id = ?
                  AND app_id = ?
                  AND project_id = ?
                  AND deleted_at IS NULL
                  AND status = 'extracted'
                ORDER BY created_at DESC
                """,
                (user_id, app_id, project_id),
            ).fetchall()
        return [dict(row) for row in rows]

    # ------------------------------------------------------------ attachments

    def create_attachment(self, **values: Any) -> dict[str, Any]:
        """Insert a ``memory_attachments`` row unless the insert is ignored.

        ``values`` must supply every bound column except ``id``,
        ``created_at`` and ``deleted_at``: ``user_id``, ``app_id``,
        ``project_id``, ``session_id``, ``resource_id``, ``content_type``,
        ``name``, ``internal_uri``, ``source`` and ``sha256``.

        Defaults:
            * A missing, ``None`` or empty ``id`` is replaced by a generated
              ``a_<hex>`` id.
            * ``created_at`` defaults to now.
            * ``deleted_at`` defaults to ``None``.

        The statement is ``INSERT OR IGNORE``. When a row is inserted, that
        row is returned. When the insert is ignored, the row matching
        ``(user_id, session_id, internal_uri)`` is returned instead.

        Raises:
            RuntimeError: if no row can be read back.
        """
        attachment_id = str(values.get("id") or f"a_{uuid.uuid4().hex}")
        payload = {
            "created_at": utc_now(),
            "deleted_at": None,
            **values,
            # Applied last so an explicit id=None/"" cannot override the generated id.
            "id": attachment_id,
        }
        with connect(self.db_path) as conn:
            cursor = conn.execute(
                """
                INSERT OR IGNORE INTO memory_attachments (
                    id, user_id, app_id, project_id, session_id, resource_id,
                    content_type, name, internal_uri, source, sha256,
                    created_at, deleted_at
                ) VALUES (
                    :id, :user_id, :app_id, :project_id, :session_id, :resource_id,
                    :content_type, :name, :internal_uri, :source, :sha256,
                    :created_at, :deleted_at
                )
                """,
                payload,
            )
            if cursor.rowcount == 1:
                # Our row was written: read that exact row back, not some
                # other row that happens to share the natural key.
                row = conn.execute(
                    "SELECT * FROM memory_attachments WHERE id = ?",
                    (attachment_id,),
                ).fetchone()
            else:
                # Insert ignored; presumably a constraint on this natural key
                # matched an existing row (TODO confirm against schema).
                row = conn.execute(
                    """
                    SELECT * FROM memory_attachments
                    WHERE user_id = ? AND session_id = ? AND internal_uri = ?
                    """,
                    (
                        payload["user_id"],
                        payload["session_id"],
                        payload["internal_uri"],
                    ),
                ).fetchone()
            conn.commit()
        attachment = _row_to_dict(row)
        if attachment is None:
            raise RuntimeError("created attachment could not be read back")
        return attachment

    def list_attachments_for_session(
        self,
        user_id: str,
        session_id: str,
    ) -> list[dict[str, Any]]:
        """Return a session's non-deleted attachments, oldest first (ties by id)."""
        with connect(self.db_path) as conn:
            rows = conn.execute(
                """
                SELECT * FROM memory_attachments
                WHERE user_id = ? AND session_id = ? AND deleted_at IS NULL
                ORDER BY created_at ASC, id ASC
                """,
                (user_id, session_id),
            ).fetchall()
        return [dict(row) for row in rows]

    # ------------------------------------------------------ tombstones/overrides

    def add_tombstone(
        self,
        user_id: str,
        memory_id: str | None,
        session_id: str | None,
        reason: str | None,
    ) -> dict[str, Any]:
        """Record a ``memory_tombstones`` row and return ``{"id": <new id>}``."""
        tombstone_id = f"t_{uuid.uuid4().hex}"
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO memory_tombstones
                    (id, user_id, memory_id, session_id, reason, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (tombstone_id, user_id, memory_id, session_id, reason, utc_now()),
            )
            conn.commit()
        return {"id": tombstone_id}

    def get_tombstones(self, user_id: str) -> list[dict[str, Any]]:
        """Return every tombstone recorded for ``user_id`` (no ordering applied)."""
        with connect(self.db_path) as conn:
            rows = conn.execute(
                "SELECT * FROM memory_tombstones WHERE user_id = ?",
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]

    def upsert_override(
        self,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        override_text: str,
    ) -> dict[str, Any]:
        """Update or create the active override text for ``(user_id, memory_id)``.

        If an active override exists, the most recently created one gets the
        new ``session_id``, ``override_text`` and ``updated_at``. Otherwise a
        new active override is inserted.

        Returns:
            ``{"id": <override id>}``, the id of the updated or inserted row.
        """
        now = utc_now()
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT id FROM memory_overrides
                WHERE user_id = ? AND memory_id = ? AND is_active = TRUE
                ORDER BY created_at DESC
                LIMIT 1
                """,
                (user_id, memory_id),
            ).fetchone()
            if row:
                override_id = row["id"]
                conn.execute(
                    """
                    UPDATE memory_overrides
                    SET session_id = ?, override_text = ?, updated_at = ?
                    WHERE id = ?
                    """,
                    (session_id, override_text, now, override_id),
                )
            else:
                override_id = f"o_{uuid.uuid4().hex}"
                conn.execute(
                    """
                    INSERT INTO memory_overrides
                    (
                        id, user_id, memory_id, session_id, override_text,
                        is_active, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, TRUE, ?, ?)
                    """,
                    (
                        override_id,
                        user_id,
                        memory_id,
                        session_id,
                        override_text,
                        now,
                        now,
                    ),
                )
            conn.commit()
        return {"id": override_id}

    def get_active_overrides(self, user_id: str) -> list[dict[str, Any]]:
        """Return all of ``user_id``'s overrides with ``is_active = TRUE``."""
        with connect(self.db_path) as conn:
            rows = conn.execute(
                """
                SELECT * FROM memory_overrides
                WHERE user_id = ? AND is_active = TRUE
                """,
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]
|