移除了所有Hermes相关的命名引用,包括: - 从.gitignore中清理相关构建缓存文件 - 将README中的beaver-home路径配置更新 - 完善backend/README.md文档说明Beaver后端主线实现 - 移除Hermes风格的相关注释和兼容性代码 - 清理nanobot环境变量兼容性处理 - 删除技能迁移和服务迁移相关功能代码 - 更新测试用例中相关命名和函数名 BREAKING CHANGE: 移除了Hermes迁移相关API和CLI命令,不再支持nanobot环境变量兼容性
560 lines
19 KiB
Python
560 lines
19 KiB
Python
"""SQLite storage implementation for the Beaver session subsystem.

Design goals:
1. SQLite as the unified session/transcript backend
2. WAL mode for multiple readers with a single writer
3. FTS5 for cross-session text search
4. `parent_session_id` for session lineage

This layer is only responsible for storing and fetching; complex retrieval
logic lives in `search.py`.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import sqlite3
|
||
import threading
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Any, Callable, TypeVar
|
||
|
||
from .models import MessageRecord, SessionRecord
|
||
|
||
T = TypeVar("T")
|
||
|
||
# Canonical schema: one row per session plus a per-session message/event log.
# Every statement uses IF NOT EXISTS, so the script can be executed on each
# start-up without disturbing existing data.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    source TEXT NOT NULL,
    user_id TEXT,
    title TEXT,
    model TEXT,
    system_prompt TEXT,
    parent_session_id TEXT,
    started_at REAL NOT NULL,
    last_active REAL NOT NULL,
    ended_at REAL,
    end_reason TEXT,
    message_count INTEGER DEFAULT 0,
    tool_call_count INTEGER DEFAULT 0,
    input_tokens INTEGER DEFAULT 0,
    output_tokens INTEGER DEFAULT 0,
    cache_read_tokens INTEGER DEFAULT 0,
    cache_write_tokens INTEGER DEFAULT 0,
    reasoning_tokens INTEGER DEFAULT 0,
    estimated_cost_usd REAL DEFAULT 0,
    actual_cost_usd REAL,
    preview TEXT,
    FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);

CREATE TABLE IF NOT EXISTS messages (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL REFERENCES sessions(id),
    run_id TEXT,
    role TEXT NOT NULL,
    event_type TEXT,
    event_payload TEXT,
    context_visible INTEGER NOT NULL DEFAULT 1,
    content TEXT,
    tool_name TEXT,
    tool_calls TEXT,
    tool_call_id TEXT,
    timestamp REAL NOT NULL,
    finish_reason TEXT,
    reasoning TEXT,
    reasoning_details TEXT,
    codex_reasoning_items TEXT
);

CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
"""
|
||
|
||
# External-content FTS5 table: the index covers `messages.content` and is keyed
# by `messages.id`. Because the text itself stays in `messages`, the index has
# to be kept in sync explicitly (see the trigger script in this module).
FTS_TABLE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
    content,
    content=messages,
    content_rowid=id
);
"""
|
||
|
||
# Triggers that mirror `messages` into `messages_fts`.
# Only rows with context_visible = 1 and non-NULL content are indexed, so
# hidden events never show up in full-text search. The triggers are dropped
# first so that an older definition is always replaced by this one.
FTS_TRIGGER_SQL = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;

CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
    INSERT INTO messages_fts(rowid, content)
    SELECT new.id, new.content
    WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;

CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
    INSERT INTO messages_fts(messages_fts, rowid, content)
    SELECT 'delete', old.id, old.content
    WHERE old.context_visible = 1 AND old.content IS NOT NULL;
END;

CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
    INSERT INTO messages_fts(messages_fts, rowid, content)
    SELECT 'delete', old.id, old.content
    WHERE old.context_visible = 1 AND old.content IS NOT NULL;
    INSERT INTO messages_fts(rowid, content)
    SELECT new.id, new.content
    WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
"""
|
||
|
||
|
||
class SessionStore:
    """SQLite-backed session store.

    The store owns a single connection (opened with ``check_same_thread=False``)
    and guards every use of it with one lock. Writes go through
    `_execute_write`, which wraps the callback in a ``BEGIN IMMEDIATE``
    transaction; reads go through `_fetchone` / `_fetchall`.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open (or create) the database at *db_path* and initialise the schema.

        The parent directory is created if missing. If any PRAGMA or the schema
        initialisation fails, the connection is closed before the error
        propagates so a failed constructor does not leak an open handle.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        # isolation_level=None puts the connection in autocommit mode;
        # transactions are opened explicitly in `_execute_write`.
        self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA foreign_keys=ON")
            self._init_schema()
        except BaseException:
            self._conn.close()
            raise

    def _init_schema(self) -> None:
        """Create tables and indexes and make sure the FTS index is usable.

        If `messages_fts` cannot be queried, or (re)installing the triggers
        fails, the whole FTS index is rebuilt from `messages`. Otherwise hidden
        rows are purged from the index; a failure of that purge also falls
        back to a full rebuild.
        """
        with self._lock:
            self._conn.executescript(SCHEMA_SQL)
            try:
                # Probe that the FTS table exists, then refresh the triggers.
                self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
                self._conn.executescript(FTS_TRIGGER_SQL)
            except sqlite3.Error:
                self._rebuild_fts_index()
                return

            # Older versions may have written hidden events into FTS as well;
            # clean out those noise entries while initialising.
            # NOTE(review): this runs on every start-up, including for hidden
            # rows that were never indexed. The FTS5 documentation warns that a
            # 'delete' whose values do not match the index contents can leave
            # the index inconsistent -- confirm whether this should become a
            # one-time migration.
            try:
                self._conn.execute(
                    """
                    INSERT INTO messages_fts(messages_fts, rowid, content)
                    SELECT 'delete', id, content
                    FROM messages
                    WHERE context_visible = 0 AND content IS NOT NULL
                    """
                )
                self._conn.commit()
            except sqlite3.Error:
                self._rebuild_fts_index()

    def _rebuild_fts_index(self) -> None:
        """Recreate the derived FTS index without touching canonical session rows."""
        # This method does not take `self._lock` itself; `_init_schema` calls it
        # while already holding the (non-reentrant) lock.
        self._conn.executescript(
            """
            DROP TRIGGER IF EXISTS messages_fts_insert;
            DROP TRIGGER IF EXISTS messages_fts_delete;
            DROP TRIGGER IF EXISTS messages_fts_update;
            DROP TABLE IF EXISTS messages_fts;
            """
        )
        self._conn.executescript(FTS_TABLE_SQL)
        self._conn.executescript(FTS_TRIGGER_SQL)
        # Re-index only what the triggers would have indexed: visible rows
        # with non-NULL content.
        self._conn.execute(
            """
            INSERT INTO messages_fts(rowid, content)
            SELECT id, content
            FROM messages
            WHERE context_visible = 1 AND content IS NOT NULL
            """
        )
        self._conn.commit()

    def close(self) -> None:
        """Close the underlying connection while holding the store lock."""
        with self._lock:
            self._conn.close()

    def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
        """Run *fn* inside one ``BEGIN IMMEDIATE`` transaction and return its result.

        The transaction is committed when *fn* returns. On any exception
        (``BaseException``, so interrupts are covered too) it is rolled back
        and the exception is re-raised.
        """
        with self._lock:
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                result = fn(self._conn)
                self._conn.commit()
                return result
            except BaseException:
                self._conn.rollback()
                raise

    def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
        """Execute *sql* and return the first row as a dict, or ``None`` if there is none."""
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return dict(row) if row else None

    def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
        """Execute *sql* and return every row as a dict."""
        with self._lock:
            rows = self._conn.execute(sql, params).fetchall()
        return [dict(row) for row in rows]

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Make sure the session row exists and return *session_id*.

        A missing row is created with ``started_at`` and ``last_active`` set to
        the current time. For an existing row only missing metadata is filled
        in: NULL columns take the supplied value, and ``source`` is replaced
        only while it is still ``'unknown'``. Existing values are never
        overwritten, and the timestamps of an existing row are left untouched.
        """
        now = time.time()

        def _do(conn: sqlite3.Connection) -> str:
            conn.execute(
                """
                INSERT INTO sessions (
                    id, source, user_id, title, model, parent_session_id, started_at, last_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    source = CASE
                        WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
                        ELSE sessions.source
                    END,
                    user_id = COALESCE(sessions.user_id, excluded.user_id),
                    title = COALESCE(sessions.title, excluded.title),
                    model = COALESCE(sessions.model, excluded.model),
                    parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
                """,
                (session_id, source, user_id, title, model, parent_session_id, now, now),
            )
            return session_id

        return self._execute_write(_do)

    def get_session_record(self, session_id: str) -> SessionRecord | None:
        """Return the session row as a `SessionRecord`, or ``None`` if it does not exist."""
        row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
        return SessionRecord.from_row(row) if row else None

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Store the fully assembled system prompt snapshot for this session.

        Also bumps ``last_active``. The UPDATE matches no rows, and therefore
        does nothing, if the session does not exist.
        """

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET system_prompt = ?, last_active = ?
                WHERE id = ?
                """,
                (system_prompt, time.time(), session_id),
            )

        self._execute_write(_do)

    def update_usage(
        self,
        session_id: str,
        *,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        reasoning_tokens: int = 0,
        estimated_cost_usd: float = 0.0,
        actual_cost_usd: float | None = None,
        absolute: bool = False,
    ) -> None:
        """Update the usage counters of a session; increments by default.

        With ``absolute=False`` every value is added to the stored counter, and
        ``actual_cost_usd=None`` leaves the stored actual cost unchanged. With
        ``absolute=True`` every column is overwritten with the supplied value,
        including ``actual_cost_usd`` (which becomes NULL when ``None``).
        ``last_active`` is bumped in both modes.
        """
        if absolute:
            sql = """
                UPDATE sessions
                SET input_tokens = ?,
                    output_tokens = ?,
                    cache_read_tokens = ?,
                    cache_write_tokens = ?,
                    reasoning_tokens = ?,
                    estimated_cost_usd = ?,
                    actual_cost_usd = ?,
                    last_active = ?
                WHERE id = ?
            """
            params: tuple[Any, ...] = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )
        else:
            sql = """
                UPDATE sessions
                SET input_tokens = input_tokens + ?,
                    output_tokens = output_tokens + ?,
                    cache_read_tokens = cache_read_tokens + ?,
                    cache_write_tokens = cache_write_tokens + ?,
                    reasoning_tokens = reasoning_tokens + ?,
                    estimated_cost_usd = estimated_cost_usd + ?,
                    actual_cost_usd = CASE
                        WHEN ? IS NULL THEN actual_cost_usd
                        ELSE COALESCE(actual_cost_usd, 0) + ?
                    END,
                    last_active = ?
                WHERE id = ?
            """
            # actual_cost_usd is bound twice: once for the IS NULL test and
            # once for the addition.
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(sql, params)

        self._execute_write(_do)

    def append_message(
        self,
        session_id: str,
        *,
        run_id: str | None = None,
        role: str,
        event_type: str | None = None,
        event_payload: dict[str, Any] | None = None,
        context_visible: bool = True,
        content: str | None = None,
        tool_name: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        finish_reason: str | None = None,
        reasoning: str | None = None,
        reasoning_details: Any | None = None,
        codex_reasoning_items: Any | None = None,
        source: str = "unknown",
        title: str | None = None,
        model: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> int:
        """Append one message to a session and return the new message row id.

        The session row is created or completed first via `ensure_session`.
        Structured fields (``tool_calls``, ``event_payload``,
        ``reasoning_details``, ``codex_reasoning_items``) are stored as JSON
        text, and ``event_type`` falls back to *role* when not given. The
        session's ``last_active``, ``message_count`` and ``tool_call_count``
        are updated in the same transaction as the insert. ``preview`` is set
        once, from the first 120 characters of the first non-empty user
        message.
        """
        # NOTE(review): `ensure_session` commits its own transaction before the
        # insert below starts, so the two steps are not atomic together.
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        now = time.time()
        tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
        event_payload_json = json.dumps(event_payload) if event_payload is not None else None
        reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
        codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
        preview = (content or "")[:120] if role == "user" and content else None
        # A list counts one per entry; any other truthy value counts as a single call.
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)

        def _do(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                """
                INSERT INTO messages (
                    session_id, run_id, role, event_type, event_payload, context_visible, content,
                    tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
                    reasoning_details, codex_reasoning_items
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    session_id,
                    run_id,
                    role,
                    event_type or role,
                    event_payload_json,
                    1 if context_visible else 0,
                    content,
                    tool_name,
                    tool_calls_json,
                    tool_call_id,
                    now,
                    finish_reason,
                    reasoning,
                    reasoning_details_json,
                    codex_items_json,
                ),
            )
            # `preview` is bound twice: once for the IS NOT NULL test and once
            # as the value to store.
            conn.execute(
                """
                UPDATE sessions
                SET last_active = ?,
                    message_count = message_count + 1,
                    tool_call_count = tool_call_count + ?,
                    model = COALESCE(model, ?),
                    preview = CASE
                        WHEN preview IS NULL AND ? IS NOT NULL THEN ?
                        ELSE preview
                    END
                WHERE id = ?
                """,
                (now, tool_call_count, model, preview, preview, session_id),
            )
            return int(cursor.lastrowid)

        return self._execute_write(_do)

    def get_message_records(self, session_id: str) -> list[MessageRecord]:
        """Return every message of the session, ordered by timestamp and then id."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ?
            ORDER BY timestamp, id
            """,
            (session_id,),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        At this stage the event stream is still carried by the `messages`
        table, so this is equivalent to reading all message records. If
        run/checkpoint/system events later move to their own tables, the
        manager layer above can keep using this interface unchanged.
        """
        return self.get_message_records(session_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List the run ids seen in the session, ordered by first appearance.

        Runs are ordered by their earliest message timestamp, then by their
        smallest message id. NULL and empty run ids are skipped.
        """
        rows = self._fetchall(
            """
            SELECT run_id
            FROM messages
            WHERE session_id = ? AND run_id IS NOT NULL
            GROUP BY run_id
            ORDER BY MIN(timestamp), MIN(id)
            """,
            (session_id,),
        )
        return [str(row["run_id"]) for row in rows if row.get("run_id")]

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the slice of events that belongs to a single run."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ? AND run_id = ?
            ORDER BY timestamp, id
            """,
            (session_id, run_id),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def update_latest_assistant_event_payload(
        self,
        session_id: str,
        run_id: str,
        updates: dict[str, Any],
    ) -> None:
        """Merge payload fields into the latest visible assistant message for a run.

        The target is the newest visible row of the run with role
        ``'assistant'`` and event type ``'assistant_message_added'``. Nothing
        happens when *updates* is empty or no such row exists. A stored payload
        that is not valid JSON, or not a JSON object, is treated as empty
        before merging.
        """
        if not updates:
            return

        def _do(conn: sqlite3.Connection) -> None:
            row = conn.execute(
                """
                SELECT id, event_payload
                FROM messages
                WHERE session_id = ?
                  AND run_id = ?
                  AND role = 'assistant'
                  AND event_type = 'assistant_message_added'
                  AND context_visible = 1
                ORDER BY timestamp DESC, id DESC
                LIMIT 1
                """,
                (session_id, run_id),
            ).fetchone()
            if row is None:
                return
            payload: dict[str, Any] = {}
            if row["event_payload"]:
                try:
                    parsed = json.loads(row["event_payload"])
                    if isinstance(parsed, dict):
                        payload = parsed
                except json.JSONDecodeError:
                    payload = {}
            payload.update(updates)
            conn.execute(
                """
                UPDATE messages
                SET event_payload = ?
                WHERE id = ?
                """,
                (json.dumps(payload, ensure_ascii=False, sort_keys=True), row["id"]),
            )

        self._execute_write(_do)

    def set_run_context_visible(self, session_id: str, run_id: str, visible: bool) -> None:
        """Set the context-visibility flag of every event in one run.

        Only rows whose current flag differs from *visible* are updated, so
        rows already in the requested state are left alone.
        """
        flag = 1 if visible else 0

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE messages
                SET context_visible = ?
                WHERE session_id = ?
                  AND run_id = ?
                  AND context_visible != ?
                """,
                (flag, session_id, run_id, flag),
            )

        self._execute_write(_do)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the context-visible events of the session as conversation messages."""
        messages: list[dict[str, Any]] = []
        for record in self.get_event_records(session_id):
            if not record.context_visible:
                continue
            messages.append(record.to_conversation_message())
        return messages

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark the session as ended with *end_reason*.

        ``ended_at`` and ``last_active`` are written from one timestamp so the
        two columns agree for this event.
        """

        def _do(conn: sqlite3.Connection) -> None:
            now = time.time()
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = ?, end_reason = ?, last_active = ?
                WHERE id = ?
                """,
                (now, end_reason, now, session_id),
            )

        self._execute_write(_do)

    def reopen_session(self, session_id: str) -> None:
        """Clear ``ended_at`` / ``end_reason`` of the session and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = NULL, end_reason = NULL, last_active = ?
                WHERE id = ?
                """,
                (time.time(), session_id),
            )

        self._execute_write(_do)
|