Refine memory system user-key flow and search output
This commit is contained in:
@ -2,21 +2,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import hmac
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
ADMIN_ACCOUNT_ID = "admin"
|
||||
ADMIN_USER_ID = "admin"
|
||||
|
||||
|
||||
class OpenVikingUserKeyStore:
|
||||
def __init__(self, sqlite_path: str) -> None:
|
||||
self.sqlite_path = sqlite_path
|
||||
self._ensure_table()
|
||||
|
||||
def get_account_key(self, account_id: str) -> str | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT account_key FROM memory_system_openviking_accounts WHERE account_id = ?",
|
||||
(account_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT user_key FROM memory_system_openviking_users
|
||||
WHERE account_id = ?
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(account_id,),
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
|
||||
def account_key_matches(self, account_id: str, account_key: str) -> bool:
|
||||
expected = self.get_account_key(account_id)
|
||||
return bool(expected and hmac.compare_digest(expected, account_key))
|
||||
|
||||
def save_account_key(self, account_id: str, admin_user_id: str, account_key: str) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_system_openviking_accounts (account_id, admin_user_id, account_key, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(account_id) DO UPDATE SET
|
||||
admin_user_id = excluded.admin_user_id,
|
||||
account_key = excluded.account_key,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(account_id, admin_user_id, account_key, now, now),
|
||||
)
|
||||
|
||||
def get_user_key(self, user_id: str) -> str | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT user_key FROM memory_system_openviking_users WHERE user_id = ?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
row = conn.execute(
|
||||
"SELECT user_key FROM memory_system_openviking_users WHERE user_id = ?",
|
||||
(self._legacy_store_key(ADMIN_ACCOUNT_ID, user_id),),
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
|
||||
def save_user_key(self, user_id: str, user_key: str) -> None:
|
||||
@ -30,13 +76,109 @@ class OpenVikingUserKeyStore:
|
||||
user_key = excluded.user_key,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(user_id, user_id, user_key, now, now),
|
||||
(user_id, ADMIN_ACCOUNT_ID, user_key, now, now),
|
||||
)
|
||||
|
||||
def user_key_matches(self, user_id: str, user_key: str) -> bool:
|
||||
expected = self.get_user_key(user_id)
|
||||
return bool(expected and hmac.compare_digest(expected, user_key))
|
||||
|
||||
def save_session(self, user_id: str, session_id: str) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_system_openviking_sessions
|
||||
(user_id, session_id, latest_task_id, latest_archive_uri, created_at, updated_at)
|
||||
VALUES (?, ?, NULL, NULL, ?, ?)
|
||||
ON CONFLICT(user_id, session_id) DO UPDATE SET
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(user_id, session_id, now, now),
|
||||
)
|
||||
|
||||
def get_session(self, user_id: str, session_id: str) -> dict[str, str | None] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT user_id, session_id, latest_task_id, latest_archive_uri
|
||||
FROM memory_system_openviking_sessions
|
||||
WHERE user_id = ? AND session_id = ?
|
||||
""",
|
||||
(user_id, session_id),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return {
|
||||
"user_id": str(row[0]),
|
||||
"session_id": str(row[1]),
|
||||
"latest_task_id": str(row[2]) if row[2] is not None else None,
|
||||
"latest_archive_uri": str(row[3]) if row[3] is not None else None,
|
||||
}
|
||||
|
||||
def save_task(self, user_id: str, session_id: str, task_id: str, archive_uri: str | None) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_system_openviking_tasks
|
||||
(task_id, user_id, session_id, archive_uri, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(task_id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
session_id = excluded.session_id,
|
||||
archive_uri = excluded.archive_uri,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(task_id, user_id, session_id, archive_uri, now, now),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_system_openviking_sessions
|
||||
(user_id, session_id, latest_task_id, latest_archive_uri, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(user_id, session_id) DO UPDATE SET
|
||||
latest_task_id = excluded.latest_task_id,
|
||||
latest_archive_uri = excluded.latest_archive_uri,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(user_id, session_id, task_id, archive_uri, now, now),
|
||||
)
|
||||
|
||||
def get_task(self, task_id: str) -> dict[str, str | None] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT task_id, user_id, session_id, archive_uri
|
||||
FROM memory_system_openviking_tasks
|
||||
WHERE task_id = ?
|
||||
""",
|
||||
(task_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return {
|
||||
"task_id": str(row[0]),
|
||||
"user_id": str(row[1]),
|
||||
"session_id": str(row[2]),
|
||||
"archive_uri": str(row[3]) if row[3] is not None else None,
|
||||
}
|
||||
|
||||
def _ensure_table(self) -> None:
|
||||
path = Path(self.sqlite_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory_system_openviking_accounts (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
admin_user_id TEXT NOT NULL,
|
||||
account_key TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory_system_openviking_users (
|
||||
@ -48,6 +190,34 @@ class OpenVikingUserKeyStore:
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory_system_openviking_sessions (
|
||||
user_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
latest_task_id TEXT,
|
||||
latest_archive_uri TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, session_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory_system_openviking_tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
archive_uri TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return sqlite3.connect(self.sqlite_path)
|
||||
|
||||
def _legacy_store_key(self, account_id: str, user_id: str) -> str:
|
||||
return f"{account_id}:{user_id}"
|
||||
|
||||
Reference in New Issue
Block a user