"""Beaver session 子系统的 SQLite 存储实现。 设计目标: 1. SQLite 作为统一 session/transcript backend 2. WAL 模式支持多读单写 3. FTS5 支持跨 session 文本检索 4. `parent_session_id` 支持 lineage 这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。 """ from __future__ import annotations import json import os import sqlite3 import threading import time from pathlib import Path from typing import Any, Callable, TypeVar from .models import MessageRecord, SessionRecord T = TypeVar("T") SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, source TEXT NOT NULL, user_id TEXT, title TEXT, model TEXT, system_prompt TEXT, parent_session_id TEXT, started_at REAL NOT NULL, last_active REAL NOT NULL, ended_at REAL, end_reason TEXT, message_count INTEGER DEFAULT 0, tool_call_count INTEGER DEFAULT 0, input_tokens INTEGER DEFAULT 0, output_tokens INTEGER DEFAULT 0, cache_read_tokens INTEGER DEFAULT 0, cache_write_tokens INTEGER DEFAULT 0, reasoning_tokens INTEGER DEFAULT 0, estimated_cost_usd REAL DEFAULT 0, actual_cost_usd REAL, preview TEXT, FOREIGN KEY (parent_session_id) REFERENCES sessions(id) ); CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT NOT NULL REFERENCES sessions(id), run_id TEXT, role TEXT NOT NULL, event_type TEXT, event_payload TEXT, context_visible INTEGER NOT NULL DEFAULT 1, content TEXT, tool_name TEXT, tool_calls TEXT, tool_call_id TEXT, timestamp REAL NOT NULL, finish_reason TEXT, reasoning TEXT, reasoning_details TEXT, codex_reasoning_items TEXT ); CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC); CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC); CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id); CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id); CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id); """ FTS_TABLE_SQL = """ CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5( content, content=messages, content_rowid=id ); """ FTS_TRIGGER_SQL = """ DROP TRIGGER IF EXISTS messages_fts_insert; DROP TRIGGER IF EXISTS messages_fts_delete; DROP TRIGGER IF EXISTS messages_fts_update; CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN INSERT INTO messages_fts(rowid, content) SELECT new.id, new.content WHERE new.context_visible = 1 AND new.content IS NOT NULL; END; CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN INSERT INTO messages_fts(messages_fts, rowid, content) SELECT 'delete', old.id, old.content WHERE old.context_visible = 1 AND old.content IS NOT NULL; END; CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN INSERT INTO messages_fts(messages_fts, rowid, content) SELECT 'delete', old.id, old.content WHERE old.context_visible = 1 AND old.content IS NOT NULL; INSERT INTO messages_fts(rowid, content) SELECT new.id, new.content WHERE new.context_visible = 1 AND new.content IS NOT NULL; END; """ def _sqlite_journal_mode() -> str: requested = os.getenv("BEAVER_SQLITE_JOURNAL_MODE", "DELETE").strip().upper() allowed = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "OFF", "WAL"} return requested if requested in allowed else "DELETE" class SessionStore: """SQLite-backed session store.""" def __init__(self, db_path: str | Path) -> None: self.db_path = Path(db_path) self.db_path.parent.mkdir(parents=True, exist_ok=True) self._lock = threading.Lock() self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None) self._conn.row_factory = sqlite3.Row self._conn.execute("PRAGMA mmap_size=0") self._conn.execute("PRAGMA busy_timeout=5000") self._conn.execute(f"PRAGMA journal_mode={_sqlite_journal_mode()}") self._conn.execute("PRAGMA foreign_keys=ON") self._init_schema() def _init_schema(self) -> None: with self._lock: self._conn.executescript(SCHEMA_SQL) try: self._conn.execute("SELECT * FROM messages_fts LIMIT 0") self._conn.executescript(FTS_TRIGGER_SQL) except sqlite3.Error: self._rebuild_fts_index() return # 旧版本可能把 hidden 事件也写进了 FTS;初始化时顺手清掉这些噪声项。 try: self._conn.execute( """ INSERT INTO messages_fts(messages_fts, rowid, content) SELECT 'delete', id, content FROM messages WHERE context_visible = 0 AND content IS NOT NULL """ ) self._conn.commit() except sqlite3.Error: self._rebuild_fts_index() def _rebuild_fts_index(self) -> None: """Recreate the derived FTS index without touching canonical session rows.""" self._conn.executescript( """ DROP TRIGGER IF EXISTS messages_fts_insert; DROP TRIGGER IF EXISTS messages_fts_delete; DROP TRIGGER IF EXISTS messages_fts_update; DROP TABLE IF EXISTS messages_fts; """ ) self._conn.executescript(FTS_TABLE_SQL) self._conn.executescript(FTS_TRIGGER_SQL) self._conn.execute( """ INSERT INTO messages_fts(rowid, content) SELECT id, content FROM messages WHERE context_visible = 1 AND content IS NOT NULL """ ) self._conn.commit() def close(self) -> None: with self._lock: self._conn.close() def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T: with self._lock: self._conn.execute("BEGIN IMMEDIATE") try: result = fn(self._conn) self._conn.commit() return result except BaseException: self._conn.rollback() raise def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None: with self._lock: row = self._conn.execute(sql, params).fetchone() return dict(row) if row else None def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]: with self._lock: rows = self._conn.execute(sql, params).fetchall() return [dict(row) for row in rows] def ensure_session( self, session_id: str, *, source: str = "unknown", model: str | None = None, title: str | None = None, user_id: str | None = None, parent_session_id: str | None = None, ) -> str: """确保 session 行存在;若不存在则创建,若存在则尽量补全缺失元数据。""" now = time.time() def _do(conn: sqlite3.Connection) -> str: conn.execute( """ INSERT INTO sessions ( id, source, user_id, title, model, parent_session_id, started_at, last_active ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET source = CASE WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source ELSE sessions.source END, user_id = COALESCE(sessions.user_id, excluded.user_id), title = COALESCE(sessions.title, excluded.title), model = COALESCE(sessions.model, excluded.model), parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id) """, (session_id, source, user_id, title, model, parent_session_id, now, now), ) return session_id return self._execute_write(_do) def get_session_record(self, session_id: str) -> SessionRecord | None: row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,)) return SessionRecord.from_row(row) if row else None def update_system_prompt(self, session_id: str, system_prompt: str) -> None: """保存本 session 组装后的完整 system prompt snapshot。""" def _do(conn: sqlite3.Connection) -> None: conn.execute( """ UPDATE sessions SET system_prompt = ?, last_active = ? WHERE id = ? """, (system_prompt, time.time(), session_id), ) self._execute_write(_do) def update_usage( self, session_id: str, *, input_tokens: int = 0, output_tokens: int = 0, cache_read_tokens: int = 0, cache_write_tokens: int = 0, reasoning_tokens: int = 0, estimated_cost_usd: float = 0.0, actual_cost_usd: float | None = None, absolute: bool = False, ) -> None: """更新会话 usage。默认按增量累加。""" if absolute: sql = """ UPDATE sessions SET input_tokens = ?, output_tokens = ?, cache_read_tokens = ?, cache_write_tokens = ?, reasoning_tokens = ?, estimated_cost_usd = ?, actual_cost_usd = ?, last_active = ? WHERE id = ? """ params = ( input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, reasoning_tokens, estimated_cost_usd, actual_cost_usd, time.time(), session_id, ) else: sql = """ UPDATE sessions SET input_tokens = input_tokens + ?, output_tokens = output_tokens + ?, cache_read_tokens = cache_read_tokens + ?, cache_write_tokens = cache_write_tokens + ?, reasoning_tokens = reasoning_tokens + ?, estimated_cost_usd = estimated_cost_usd + ?, actual_cost_usd = CASE WHEN ? IS NULL THEN actual_cost_usd ELSE COALESCE(actual_cost_usd, 0) + ? END, last_active = ? WHERE id = ? """ params = ( input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, reasoning_tokens, estimated_cost_usd, actual_cost_usd, actual_cost_usd, time.time(), session_id, ) def _do(conn: sqlite3.Connection) -> None: conn.execute(sql, params) self._execute_write(_do) def append_message( self, session_id: str, *, run_id: str | None = None, role: str, event_type: str | None = None, event_payload: dict[str, Any] | None = None, context_visible: bool = True, content: str | None = None, tool_name: str | None = None, tool_calls: list[dict[str, Any]] | None = None, tool_call_id: str | None = None, finish_reason: str | None = None, reasoning: str | None = None, reasoning_details: Any | None = None, codex_reasoning_items: Any | None = None, source: str = "unknown", title: str | None = None, model: str | None = None, user_id: str | None = None, parent_session_id: str | None = None, ) -> int: """向指定 session 追加一条消息。""" self.ensure_session( session_id, source=source, model=model, title=title, user_id=user_id, parent_session_id=parent_session_id, ) now = time.time() tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None event_payload_json = json.dumps(event_payload) if event_payload is not None else None reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None preview = (content or "")[:120] if role == "user" and content else None tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0) def _do(conn: sqlite3.Connection) -> int: cursor = conn.execute( """ INSERT INTO messages ( session_id, run_id, role, event_type, event_payload, context_visible, content, tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning, reasoning_details, codex_reasoning_items ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( session_id, run_id, role, event_type or role, event_payload_json, 1 if context_visible else 0, content, tool_name, tool_calls_json, tool_call_id, now, finish_reason, reasoning, reasoning_details_json, codex_items_json, ), ) conn.execute( """ UPDATE sessions SET last_active = ?, message_count = message_count + 1, tool_call_count = tool_call_count + ?, model = COALESCE(model, ?), preview = CASE WHEN preview IS NULL AND ? IS NOT NULL THEN ? ELSE preview END WHERE id = ? """, (now, tool_call_count, model, preview, preview, session_id), ) return int(cursor.lastrowid) return self._execute_write(_do) def get_message_records(self, session_id: str) -> list[MessageRecord]: rows = self._fetchall( """ SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id """, (session_id,), ) return [MessageRecord.from_row(row) for row in rows] def get_event_records(self, session_id: str) -> list[MessageRecord]: """返回当前 session 的完整事件流。 当前阶段里,事件流仍复用 `messages` 表承载,所以这里等价于读取全部 message records。 后面如果单独拆出 run/checkpoint/system event 表,上层 manager 仍可以继续保持这个接口不变。 """ return self.get_message_records(session_id) def list_run_ids(self, session_id: str) -> list[str]: """按时间顺序列出当前 session 中出现过的 run_id。""" rows = self._fetchall( """ SELECT run_id FROM messages WHERE session_id = ? AND run_id IS NOT NULL GROUP BY run_id ORDER BY MIN(timestamp), MIN(id) """, (session_id,), ) return [str(row["run_id"]) for row in rows if row.get("run_id")] def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]: """返回某一次 run 对应的事件片段。""" rows = self._fetchall( """ SELECT * FROM messages WHERE session_id = ? AND run_id = ? ORDER BY timestamp, id """, (session_id, run_id), ) return [MessageRecord.from_row(row) for row in rows] def update_latest_assistant_event_payload( self, session_id: str, run_id: str, updates: dict[str, Any], ) -> None: """Merge payload fields into the latest visible assistant message for a run.""" if not updates: return def _do(conn: sqlite3.Connection) -> None: row = conn.execute( """ SELECT id, event_payload FROM messages WHERE session_id = ? AND run_id = ? AND role = 'assistant' AND event_type = 'assistant_message_added' AND context_visible = 1 ORDER BY timestamp DESC, id DESC LIMIT 1 """, (session_id, run_id), ).fetchone() if row is None: return payload: dict[str, Any] = {} if row["event_payload"]: try: parsed = json.loads(row["event_payload"]) if isinstance(parsed, dict): payload = parsed except json.JSONDecodeError: payload = {} payload.update(updates) conn.execute( """ UPDATE messages SET event_payload = ? WHERE id = ? """, (json.dumps(payload, ensure_ascii=False, sort_keys=True), row["id"]), ) self._execute_write(_do) def set_run_context_visible(self, session_id: str, run_id: str, visible: bool) -> None: """Set context visibility for all currently visible events in one run.""" def _do(conn: sqlite3.Connection) -> None: conn.execute( """ UPDATE messages SET context_visible = ? WHERE session_id = ? AND run_id = ? AND context_visible != ? """, (1 if visible else 0, session_id, run_id, 1 if visible else 0), ) self._execute_write(_do) def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] for record in self.get_event_records(session_id): if not record.context_visible: continue messages.append(record.to_conversation_message()) return messages def end_session(self, session_id: str, end_reason: str) -> None: def _do(conn: sqlite3.Connection) -> None: conn.execute( """ UPDATE sessions SET ended_at = ?, end_reason = ?, last_active = ? WHERE id = ? """, (time.time(), end_reason, time.time(), session_id), ) self._execute_write(_do) def reopen_session(self, session_id: str) -> None: def _do(conn: sqlite3.Connection) -> None: conn.execute( """ UPDATE sessions SET ended_at = NULL, end_reason = NULL, last_active = ? WHERE id = ? """, (time.time(), session_id), ) self._execute_write(_do)