Files
beaver_project/app-instance/backend/beaver/engine/session/store.py
steven_li 3b0af173cc refactor(beaver): 移除Hermes相关引用和迁移代码,完善Beaver后端主线实现
移除了所有Hermes相关的命名引用,包括:
- 从.gitignore中清理相关构建缓存文件
- 将README中的beaver-home路径配置更新
- 完善backend/README.md文档说明Beaver后端主线实现
- 移除Hermes风格的相关注释和兼容性代码
- 清理nanobot环境变量兼容性处理
- 删除技能迁移和服务迁移相关功能代码
- 更新测试用例中相关命名和函数名

BREAKING CHANGE: 移除了Hermes迁移相关API和CLI命令,不再支持nanobot环境变量兼容性
2026-05-14 17:20:32 +08:00

560 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""SQLite storage implementation for the Beaver session subsystem.
Design goals:
1. SQLite is the single backend for sessions and transcripts
2. WAL mode supports multiple readers with a single writer
3. FTS5 supports text search across sessions
4. `parent_session_id` supports session lineage
This layer is only responsible for storing and fetching; complex search logic lives in `search.py`.
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable, TypeVar
from .models import MessageRecord, SessionRecord
T = TypeVar("T")
# Canonical schema, executed as one script by SessionStore._init_schema().
# - `sessions`: one row per session, holding metadata, usage counters and an
#   optional self-referencing `parent_session_id`.
# - `messages`: one row per message/event, tied to its session by `session_id`;
#   `context_visible` (default 1) marks whether the row is visible.
# Every statement uses IF NOT EXISTS, so running the script against an existing
# database leaves existing tables and indexes in place.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
user_id TEXT,
title TEXT,
model TEXT,
system_prompt TEXT,
parent_session_id TEXT,
started_at REAL NOT NULL,
last_active REAL NOT NULL,
ended_at REAL,
end_reason TEXT,
message_count INTEGER DEFAULT 0,
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
estimated_cost_usd REAL DEFAULT 0,
actual_cost_usd REAL,
preview TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id),
run_id TEXT,
role TEXT NOT NULL,
event_type TEXT,
event_payload TEXT,
context_visible INTEGER NOT NULL DEFAULT 1,
content TEXT,
tool_name TEXT,
tool_calls TEXT,
tool_call_id TEXT,
timestamp REAL NOT NULL,
finish_reason TEXT,
reasoning TEXT,
reasoning_details TEXT,
codex_reasoning_items TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
"""
# FTS5 full-text index over `messages.content`. It is declared as an
# external-content table (`content=messages`, `content_rowid=id`), so the FTS
# rowid is the message id and the text itself stays in `messages`.
FTS_TABLE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content,
content=messages,
content_rowid=id
);
"""
# Triggers that keep `messages_fts` in step with `messages`.
# The script drops any existing triggers of the same names before creating
# them, so re-running it replaces older definitions.
# Only rows with context_visible = 1 and non-NULL content are indexed:
# - insert: adds the new row to the index when it qualifies;
# - delete: issues an FTS5 'delete' for the old row when it qualified;
# - update: removes the old entry (if it qualified), then adds the new one
#   (if it qualifies), which also covers flips of `context_visible`.
FTS_TRIGGER_SQL = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
"""
class SessionStore:
"""SQLite-backed session store."""
def __init__(self, db_path: str | Path) -> None:
    """Open the SQLite database at *db_path* and prepare its schema.

    The parent directory is created when missing. The connection is put into
    WAL journal mode with foreign-key enforcement switched on, and
    `_init_schema()` is run before the constructor returns.

    Args:
        db_path: Location of the SQLite database file.

    Raises:
        sqlite3.Error: If configuring the connection or initialising the
            schema fails. The connection is closed before the error
            propagates.
    """
    self.db_path = Path(db_path)
    self.db_path.parent.mkdir(parents=True, exist_ok=True)
    self._lock = threading.Lock()
    # isolation_level=None: no implicit transactions; writes open their own
    # transaction explicitly (see _execute_write).
    # check_same_thread=False: the one connection may be used from several
    # threads; the methods of this class take self._lock around each use.
    self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
    try:
        self._conn.row_factory = sqlite3.Row
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._init_schema()
    except BaseException:
        # A half-constructed store is unreachable for the caller, so nobody
        # could ever call close() on it; release the connection here.
        self._conn.close()
        raise
def _init_schema(self) -> None:
    """Create the tables and indexes and make sure the FTS index is usable.

    Steps, all performed while holding `self._lock`:
    1. Run `SCHEMA_SQL` (idempotent `CREATE ... IF NOT EXISTS` statements).
    2. Probe `messages_fts`; when the probe or the trigger refresh fails, the
       whole FTS index is rebuilt from `messages` and the method returns.
    3. Remove index entries that belong to hidden messages
       (`context_visible = 0`), which older versions may have indexed.
       If that cleanup fails, the FTS index is rebuilt as well.
    """
    with self._lock:
        self._conn.executescript(SCHEMA_SQL)
        try:
            # Cheap probe: raises when messages_fts is missing or unreadable.
            self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
            self._conn.executescript(FTS_TRIGGER_SQL)
        except sqlite3.Error:
            self._rebuild_fts_index()
            return
        # FTS_TRIGGER_SQL only indexes rows with context_visible = 1, so on a
        # healthy database hidden rows have no index entry. The FTS5 'delete'
        # command expects to remove an entry that exists with the supplied
        # content; sending it for rows that were never indexed skews the
        # index statistics and can make later deletes fail. Limit the cleanup
        # to rowids the index actually holds (messages_fts_docsize has one
        # row per indexed document).
        try:
            self._conn.execute(
                """
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', id, content
FROM messages
WHERE context_visible = 0
AND content IS NOT NULL
AND id IN (SELECT id FROM messages_fts_docsize)
"""
            )
            self._conn.commit()
        except sqlite3.Error:
            # Covers a missing docsize shadow table or an inconsistent index:
            # fall back to recreating the derived index from scratch.
            self._rebuild_fts_index()
def _rebuild_fts_index(self) -> None:
    """Drop and recreate the FTS table and its triggers, then re-index.

    Only the derived `messages_fts` index and the three sync triggers are
    touched; rows in `sessions` and `messages` are left alone. After the
    objects are recreated, every message with `context_visible = 1` and
    non-NULL content is inserted into the index.

    This method does not acquire `self._lock` itself.
    """
    conn = self._conn
    teardown_script = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;
DROP TABLE IF EXISTS messages_fts;
"""
    backfill_sql = """
INSERT INTO messages_fts(rowid, content)
SELECT id, content
FROM messages
WHERE context_visible = 1 AND content IS NOT NULL
"""
    # Order matters: tear down, recreate the table, recreate the triggers,
    # and only then copy the visible messages into the fresh index.
    for script in (teardown_script, FTS_TABLE_SQL, FTS_TRIGGER_SQL):
        conn.executescript(script)
    conn.execute(backfill_sql)
    conn.commit()
def close(self) -> None:
    """Close the underlying SQLite connection while holding the store lock."""
    self._lock.acquire()
    try:
        self._conn.close()
    finally:
        self._lock.release()
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
    """Run *fn* inside one `BEGIN IMMEDIATE` transaction and return its result.

    The store lock is held for the whole call. When *fn* and the commit both
    succeed, the value returned by *fn* is handed back. Any exception raised
    by *fn* or by the commit rolls the transaction back and is re-raised
    unchanged.
    """
    with self._lock:
        connection = self._conn
        connection.execute("BEGIN IMMEDIATE")
        try:
            outcome = fn(connection)
            connection.commit()
        except BaseException:
            connection.rollback()
            raise
        return outcome
def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
    """Run a query under the store lock and return its first row as a dict.

    Returns `None` when the query produces no row.
    """
    with self._lock:
        cursor = self._conn.execute(sql, params)
        first = cursor.fetchone()
    if not first:
        return None
    return dict(first)
def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
    """Run a query under the store lock and return every row as a dict.

    Returns an empty list when the query produces no rows.
    """
    with self._lock:
        fetched = self._conn.execute(sql, params).fetchall()
    return list(map(dict, fetched))
def ensure_session(
    self,
    session_id: str,
    *,
    source: str = "unknown",
    model: str | None = None,
    title: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> str:
    """Make sure a `sessions` row exists for *session_id* and return the id.

    A missing row is inserted with `started_at` and `last_active` set to the
    current time. For an existing row only gaps are filled in:
    `source` is replaced when the stored value is 'unknown' and the new one
    is not, and `user_id`, `title`, `model` and `parent_session_id` are set
    only when the stored value is NULL. Values already stored are never
    overwritten.
    """
    created_at = time.time()
    upsert_sql = """
INSERT INTO sessions (
id, source, user_id, title, model, parent_session_id, started_at, last_active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
source = CASE
WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
ELSE sessions.source
END,
user_id = COALESCE(sessions.user_id, excluded.user_id),
title = COALESCE(sessions.title, excluded.title),
model = COALESCE(sessions.model, excluded.model),
parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
"""
    bindings = (
        session_id,
        source,
        user_id,
        title,
        model,
        parent_session_id,
        created_at,
        created_at,
    )

    def _upsert(conn: sqlite3.Connection) -> str:
        conn.execute(upsert_sql, bindings)
        return session_id

    return self._execute_write(_upsert)
def get_session_record(self, session_id: str) -> SessionRecord | None:
    """Load one session row as a `SessionRecord`, or `None` if it does not exist."""
    found = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
    if not found:
        return None
    return SessionRecord.from_row(found)
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
    """Store the fully assembled system prompt snapshot for a session.

    Also refreshes the session's `last_active` timestamp. Nothing happens
    when no row matches *session_id*.
    """
    statement = """
UPDATE sessions
SET system_prompt = ?, last_active = ?
WHERE id = ?
"""

    def _store_prompt(conn: sqlite3.Connection) -> None:
        # The clock is read when the write actually runs.
        touched_at = time.time()
        conn.execute(statement, (system_prompt, touched_at, session_id))

    self._execute_write(_store_prompt)
def update_usage(
    self,
    session_id: str,
    *,
    input_tokens: int = 0,
    output_tokens: int = 0,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    reasoning_tokens: int = 0,
    estimated_cost_usd: float = 0.0,
    actual_cost_usd: float | None = None,
    absolute: bool = False,
) -> None:
    """Update the usage counters of a session and refresh `last_active`.

    By default the given values are added to the stored counters. In that
    mode `actual_cost_usd=None` leaves the stored actual cost untouched,
    while a number is added to it (a stored NULL counts as 0).

    With `absolute=True` every counter is overwritten with the given value,
    including `actual_cost_usd`, which is then stored as NULL when `None`.
    """
    # The first six bindings are identical in both modes.
    counters = (
        input_tokens,
        output_tokens,
        cache_read_tokens,
        cache_write_tokens,
        reasoning_tokens,
        estimated_cost_usd,
    )
    if not absolute:
        statement = """
UPDATE sessions
SET input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
cache_read_tokens = cache_read_tokens + ?,
cache_write_tokens = cache_write_tokens + ?,
reasoning_tokens = reasoning_tokens + ?,
estimated_cost_usd = estimated_cost_usd + ?,
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE COALESCE(actual_cost_usd, 0) + ?
END,
last_active = ?
WHERE id = ?
"""
        # actual_cost_usd is bound twice: once for the NULL test, once for the sum.
        bindings = (*counters, actual_cost_usd, actual_cost_usd, time.time(), session_id)
    else:
        statement = """
UPDATE sessions
SET input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = ?,
actual_cost_usd = ?,
last_active = ?
WHERE id = ?
"""
        bindings = (*counters, actual_cost_usd, time.time(), session_id)

    def _apply(conn: sqlite3.Connection) -> None:
        conn.execute(statement, bindings)

    self._execute_write(_apply)
def append_message(
    self,
    session_id: str,
    *,
    run_id: str | None = None,
    role: str,
    event_type: str | None = None,
    event_payload: dict[str, Any] | None = None,
    context_visible: bool = True,
    content: str | None = None,
    tool_name: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    tool_call_id: str | None = None,
    finish_reason: str | None = None,
    reasoning: str | None = None,
    reasoning_details: Any | None = None,
    codex_reasoning_items: Any | None = None,
    source: str = "unknown",
    title: str | None = None,
    model: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> int:
    """Append one message row to a session and return the new message id.

    The session row is ensured first via `ensure_session()` (using `source`,
    `model`, `title`, `user_id` and `parent_session_id`). The message insert
    and the session counter update then run together in one write.

    Details visible in the stored data:
    - `event_type` falls back to `role` when empty or not given.
    - `event_payload`, `tool_calls`, `reasoning_details` and
      `codex_reasoning_items` are stored as JSON text, or NULL when `None`.
    - The session's `message_count` grows by one and `tool_call_count` by
      the number of tool calls in this message.
    - The session's `preview` is set to the first 120 characters of the
      content of a user message, but only while the preview is still NULL.
    - The session's `model` is filled in only while it is still NULL.
    """
    self.ensure_session(
        session_id,
        source=source,
        model=model,
        title=title,
        user_id=user_id,
        parent_session_id=parent_session_id,
    )
    stamp = time.time()

    def _to_json(value: Any) -> str | None:
        # NULL stays NULL; everything else is serialised with default settings.
        return None if value is None else json.dumps(value)

    tool_calls_json = _to_json(tool_calls)
    event_payload_json = _to_json(event_payload)
    reasoning_details_json = _to_json(reasoning_details)
    codex_items_json = _to_json(codex_reasoning_items)

    # Only non-empty user messages can seed the session preview.
    preview = content[:120] if (role == "user" and content) else None

    if isinstance(tool_calls, list):
        added_tool_calls = len(tool_calls)
    elif tool_calls:
        # A truthy non-list value is counted as a single tool call.
        added_tool_calls = 1
    else:
        added_tool_calls = 0

    message_values = (
        session_id,
        run_id,
        role,
        event_type or role,
        event_payload_json,
        1 if context_visible else 0,
        content,
        tool_name,
        tool_calls_json,
        tool_call_id,
        stamp,
        finish_reason,
        reasoning,
        reasoning_details_json,
        codex_items_json,
    )
    session_values = (stamp, added_tool_calls, model, preview, preview, session_id)

    def _insert(conn: sqlite3.Connection) -> int:
        inserted = conn.execute(
            """
INSERT INTO messages (
session_id, run_id, role, event_type, event_payload, context_visible, content,
tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
reasoning_details, codex_reasoning_items
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
            message_values,
        )
        conn.execute(
            """
UPDATE sessions
SET last_active = ?,
message_count = message_count + 1,
tool_call_count = tool_call_count + ?,
model = COALESCE(model, ?),
preview = CASE
WHEN preview IS NULL AND ? IS NOT NULL THEN ?
ELSE preview
END
WHERE id = ?
""",
            session_values,
        )
        return int(inserted.lastrowid)

    return self._execute_write(_insert)
def get_message_records(self, session_id: str) -> list[MessageRecord]:
    """Return every message row of a session, ordered by timestamp, then id."""
    query = """
SELECT *
FROM messages
WHERE session_id = ?
ORDER BY timestamp, id
"""
    records: list[MessageRecord] = []
    for row in self._fetchall(query, (session_id,)):
        records.append(MessageRecord.from_row(row))
    return records
def get_event_records(self, session_id: str) -> list[MessageRecord]:
    """Return the complete event stream of a session.

    At this stage events are still stored in the `messages` table, so this is
    the same as reading all message records. If run/checkpoint/system events
    later move into their own tables, callers can keep using this method.
    """
    events = self.get_message_records(session_id)
    return events
def list_run_ids(self, session_id: str) -> list[str]:
    """List the distinct run ids of a session, ordered by first appearance.

    Runs are ordered by the earliest timestamp (then the lowest message id)
    among their messages. Rows without a run id are ignored, and empty run
    ids are dropped from the result.
    """
    query = """
SELECT run_id
FROM messages
WHERE session_id = ? AND run_id IS NOT NULL
GROUP BY run_id
ORDER BY MIN(timestamp), MIN(id)
"""
    run_ids: list[str] = []
    for row in self._fetchall(query, (session_id,)):
        if not row.get("run_id"):
            continue
        run_ids.append(str(row["run_id"]))
    return run_ids
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
    """Return the events of one run in a session, ordered by timestamp, then id."""
    query = """
SELECT *
FROM messages
WHERE session_id = ? AND run_id = ?
ORDER BY timestamp, id
"""
    matching = self._fetchall(query, (session_id, run_id))
    return list(map(MessageRecord.from_row, matching)) if matching else []
def update_latest_assistant_event_payload(
    self,
    session_id: str,
    run_id: str,
    updates: dict[str, Any],
) -> None:
    """Merge *updates* into the payload of a run's latest visible assistant message.

    The target is the newest row (by timestamp, then id) of the run with
    role 'assistant', event type 'assistant_message_added' and
    `context_visible = 1`. Its stored payload is parsed as JSON; a payload
    that is empty, not valid JSON or not a JSON object is treated as `{}`.
    Keys from *updates* overwrite existing keys, and the result is written
    back as JSON with sorted keys.

    Does nothing when *updates* is empty or no matching row exists.
    """
    if not updates:
        return

    select_sql = """
SELECT id, event_payload
FROM messages
WHERE session_id = ?
AND run_id = ?
AND role = 'assistant'
AND event_type = 'assistant_message_added'
AND context_visible = 1
ORDER BY timestamp DESC, id DESC
LIMIT 1
"""
    update_sql = """
UPDATE messages
SET event_payload = ?
WHERE id = ?
"""

    def _merge(conn: sqlite3.Connection) -> None:
        target = conn.execute(select_sql, (session_id, run_id)).fetchone()
        if target is None:
            return
        merged: dict[str, Any] = {}
        stored = target["event_payload"]
        if stored:
            try:
                decoded = json.loads(stored)
            except json.JSONDecodeError:
                decoded = None
            # Anything other than a JSON object is discarded.
            if isinstance(decoded, dict):
                merged = decoded
        merged.update(updates)
        serialized = json.dumps(merged, ensure_ascii=False, sort_keys=True)
        conn.execute(update_sql, (serialized, target["id"]))

    self._execute_write(_merge)
def set_run_context_visible(self, session_id: str, run_id: str, visible: bool) -> None:
    """Set the context visibility of every event of one run.

    Only rows whose current `context_visible` differs from the requested
    value are updated; rows already in the target state are not rewritten.
    """
    flag = 1 if visible else 0
    statement = """
UPDATE messages
SET context_visible = ?
WHERE session_id = ?
AND run_id = ?
AND context_visible != ?
"""

    def _toggle(conn: sqlite3.Connection) -> None:
        conn.execute(statement, (flag, session_id, run_id, flag))

    self._execute_write(_toggle)
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
    """Return the context-visible events of a session as conversation messages.

    Events are taken in the order `get_event_records()` yields them; events
    whose `context_visible` flag is falsy are skipped, and the remaining ones
    are converted with `to_conversation_message()`.
    """
    return [
        event.to_conversation_message()
        for event in self.get_event_records(session_id)
        if event.context_visible
    ]
def end_session(self, session_id: str, end_reason: str) -> None:
    """Mark a session as ended.

    Stores `end_reason` and sets both `ended_at` and `last_active` to the
    same timestamp. Nothing happens when no row matches *session_id*.

    Args:
        session_id: Id of the session to close.
        end_reason: Text stored in the `end_reason` column.
    """

    def _do(conn: sqlite3.Connection) -> None:
        # Read the clock once so that ended_at and last_active describe the
        # same instant instead of two slightly different ones.
        ended_at = time.time()
        conn.execute(
            """
UPDATE sessions
SET ended_at = ?, end_reason = ?, last_active = ?
WHERE id = ?
""",
            (ended_at, end_reason, ended_at, session_id),
        )

    self._execute_write(_do)
def reopen_session(self, session_id: str) -> None:
    """Clear the end markers of a session and refresh `last_active`.

    `ended_at` and `end_reason` are reset to NULL. Nothing happens when no
    row matches *session_id*.
    """
    statement = """
UPDATE sessions
SET ended_at = NULL, end_reason = NULL, last_active = ?
WHERE id = ?
"""

    def _reopen(conn: sqlite3.Connection) -> None:
        reopened_at = time.time()
        conn.execute(statement, (reopened_at, session_id))

    self._execute_write(_reopen)