修改了nanobot,往Hermes agent的风格走,进度1/3

This commit is contained in:
2026-04-20 18:11:14 +08:00
parent cdfc222c9f
commit 36882a7d7b
261 changed files with 12659 additions and 604 deletions

View File

@ -0,0 +1,15 @@
"""Session state and persistence."""
from .manager import SessionManager
from .models import MessageRecord, SessionRecord, SessionUsage
from .search import SessionSearchService
from .store import SessionStore
# Public API of the session package: every name imported above, listed alphabetically.
__all__ = [
"MessageRecord",
"SessionManager",
"SessionRecord",
"SessionSearchService",
"SessionStore",
"SessionUsage",
]

View File

@ -0,0 +1,143 @@
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from .models import MessageRecord
from .search import SessionSearchService
from .store import SessionStore
class SessionManager:
    """Unified session facade used by AgentLoop, services and MCP tools.

    The manager owns one ``SessionStore`` (persistence) and one
    ``SessionSearchService`` (browsing / search) and forwards calls to them.
    """

    def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
        """Create ``<workspace>/sessions`` and open the session database.

        Args:
            workspace: Root directory of the agent workspace.
            db_path: Explicit database file; defaults to
                ``<workspace>/sessions/state.db`` when ``None``.
        """
        self.workspace = Path(workspace)
        self.sessions_dir = self.workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
        self.store = SessionStore(self.db_path)
        self.search = SessionSearchService(self.store)

    def close(self) -> None:
        """Close the underlying store."""
        self.store.close()

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Forward to ``SessionStore.ensure_session`` and return its result."""
        return self.store.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return the session as a plain dict, or ``None`` when the store has no record."""
        record = self.store.get_session_record(session_id)
        return record.to_dict() if record is not None else None

    def get_or_create(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> dict[str, Any]:
        """Ensure the session exists, then return it as a dict.

        Raises:
            RuntimeError: If the session still cannot be read back after
                ``ensure_session``.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        session = self.get_session(session_id)
        if session is None:
            raise RuntimeError(f"Failed to create session {session_id!r}")
        return session

    def append_message(self, session_id: str, **kwargs: Any) -> int:
        """Forward to ``SessionStore.append_message`` and return its result."""
        return self.store.append_message(session_id, **kwargs)

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        The difference from ``get_messages_as_conversation()`` matters:

        - ``get_event_records()`` targets runtime / replay / audit and keeps
          hidden system events.
        - ``get_messages_as_conversation()`` targets the prompt builder and
          only exposes events that may enter the model context.

        From phase 6 onwards a session is no longer just "chat message
        storage"; it is converging on "external event stream + projected
        views on top".
        """
        return self.store.get_event_records(session_id)

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events belonging to one direct run / future bus run."""
        return self.store.get_run_event_records(session_id, run_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List every ``run_id`` of the session in order of first appearance."""
        return self.store.list_run_ids(session_id)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Forward to ``SessionStore.get_messages_as_conversation``."""
        return self.store.get_messages_as_conversation(session_id)

    def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return the slice of visible history suitable for prompt injection.

        This deliberately exposes a "model-consumable history" projection
        rather than the raw event stream:

        1. Only the messages returned by ``get_messages_as_conversation``
           are considered.
        2. The legacy window-trimming logic is kept so the main chain's
           behaviour does not change abruptly.
        3. ``ContextBuilder`` consumes an already-trimmed visible segment.

        Args:
            session_id: Session whose history is read.
            max_messages: Size of the trailing window. Values ``<= 0`` yield
                an empty list.

        Returns:
            The last ``max_messages`` messages, with any messages preceding
            the first ``user`` message of that window dropped. When the
            window contains no ``user`` message it is returned unchanged.
        """
        if max_messages <= 0:
            # ``history[-0:]`` is the *entire* list and a negative window would
            # drop the oldest messages instead of limiting the newest ones.
            return []
        history = self.get_messages_as_conversation(session_id)
        window = history[-max_messages:]
        for index, message in enumerate(window):
            if message.get("role") == "user":
                return window[index:]
        return window

    def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Legacy alias of ``get_visible_history``."""
        return self.get_visible_history(session_id, max_messages=max_messages)

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Forward to ``SessionStore.update_system_prompt``."""
        self.store.update_system_prompt(session_id, system_prompt)

    def update_usage(self, session_id: str, **kwargs: Any) -> None:
        """Forward to ``SessionStore.update_usage``."""
        self.store.update_usage(session_id, **kwargs)

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Forward to ``SessionStore.end_session``."""
        self.store.end_session(session_id, end_reason)

    def reopen_session(self, session_id: str) -> None:
        """Forward to ``SessionStore.reopen_session``."""
        self.store.reopen_session(session_id)

    def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionSearchService.list_sessions_rich``."""
        return self.search.list_sessions_rich(**kwargs)

    def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionSearchService.search_messages``."""
        return self.search.search_messages(**kwargs)

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Forward to ``SessionSearchService.resolve_session_id``."""
        return self.search.resolve_session_id(session_id_or_prefix)

View File

@ -0,0 +1,211 @@
"""Beaver session 子系统的数据模型。
这层只定义数据结构,不放数据库读写逻辑。目的是把:
1. SQLite 行结构
2. 运行时会话对象
3. 对外暴露的 conversation message
三件事分开,避免后续所有地方都直接和裸字典耦合。
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class SessionUsage:
"""会话维度的 usage/cost 统计。"""
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
estimated_cost_usd: float = 0.0
actual_cost_usd: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cache_read_tokens": self.cache_read_tokens,
"cache_write_tokens": self.cache_write_tokens,
"reasoning_tokens": self.reasoning_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"actual_cost_usd": self.actual_cost_usd,
}
@dataclass(slots=True)
class MessageRecord:
"""单条会话事件的结构化表示。
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
1. 普通 user/assistant/tool 消息本身就是事件
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
"""
role: str
content: str | None = None
timestamp: float | None = None
message_id: int | None = None
run_id: str | None = None
event_type: str | None = None
event_payload: dict[str, Any] | None = None
context_visible: bool = True
tool_name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
finish_reason: str | None = None
reasoning: str | None = None
reasoning_details: Any | None = None
codex_reasoning_items: Any | None = None
def to_conversation_message(self) -> dict[str, Any]:
"""转成 provider / context builder 可直接消费的消息格式。"""
if not self.context_visible:
raise ValueError("Hidden session events cannot be converted into conversation messages")
payload: dict[str, Any] = {
"role": self.role,
"content": self.content,
}
if self.tool_name:
payload["tool_name"] = self.tool_name
if self.tool_calls:
payload["tool_calls"] = self.tool_calls
if self.tool_call_id:
payload["tool_call_id"] = self.tool_call_id
if self.finish_reason:
payload["finish_reason"] = self.finish_reason
if self.reasoning:
payload["reasoning"] = self.reasoning
if self.reasoning_details is not None:
payload["reasoning_details"] = self.reasoning_details
if self.codex_reasoning_items is not None:
payload["codex_reasoning_items"] = self.codex_reasoning_items
return payload
@classmethod
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
"""从 SQLite row/dict 恢复消息模型。"""
tool_calls = row.get("tool_calls")
if isinstance(tool_calls, str):
try:
tool_calls = json.loads(tool_calls)
except json.JSONDecodeError:
tool_calls = []
reasoning_details = row.get("reasoning_details")
if isinstance(reasoning_details, str):
try:
reasoning_details = json.loads(reasoning_details)
except json.JSONDecodeError:
reasoning_details = None
codex_reasoning_items = row.get("codex_reasoning_items")
if isinstance(codex_reasoning_items, str):
try:
codex_reasoning_items = json.loads(codex_reasoning_items)
except json.JSONDecodeError:
codex_reasoning_items = None
event_payload = row.get("event_payload")
if isinstance(event_payload, str):
try:
event_payload = json.loads(event_payload)
except json.JSONDecodeError:
event_payload = None
return cls(
message_id=row.get("id"),
run_id=row.get("run_id"),
role=row["role"],
content=row.get("content"),
event_type=row.get("event_type") or row.get("role"),
event_payload=event_payload,
context_visible=bool(row.get("context_visible", 1)),
tool_name=row.get("tool_name"),
tool_calls=tool_calls,
tool_call_id=row.get("tool_call_id"),
timestamp=row.get("timestamp"),
finish_reason=row.get("finish_reason"),
reasoning=row.get("reasoning"),
reasoning_details=reasoning_details,
codex_reasoning_items=codex_reasoning_items,
)
@dataclass(slots=True)
class SessionRecord:
    """Structured representation of a single session."""

    session_id: str
    source: str
    started_at: float
    last_active: float
    user_id: str | None = None
    title: str | None = None
    model: str | None = None
    system_prompt: str | None = None
    parent_session_id: str | None = None
    ended_at: float | None = None
    end_reason: str | None = None
    message_count: int = 0
    tool_call_count: int = 0
    preview: str | None = None
    usage: SessionUsage = field(default_factory=SessionUsage)

    def to_dict(self) -> dict[str, Any]:
        """Return a flat dict: session fields (``session_id`` under ``"id"``) merged with the usage counters."""
        payload = {
            "id": self.session_id,
            "source": self.source,
            "user_id": self.user_id,
            "title": self.title,
            "model": self.model,
            "system_prompt": self.system_prompt,
            "parent_session_id": self.parent_session_id,
            "started_at": self.started_at,
            "last_active": self.last_active,
            "ended_at": self.ended_at,
            "end_reason": self.end_reason,
            "message_count": self.message_count,
            "tool_call_count": self.tool_call_count,
            "preview": self.preview,
        }
        payload.update(self.usage.to_dict())
        return payload

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
        """Rebuild a record from a ``sessions`` row that was converted to a dict.

        Numeric counters that are missing *or* ``None`` (SQL NULL) are
        normalised to zero so the ``int`` / ``float`` fields never hold
        ``None``. ``actual_cost_usd`` keeps ``None`` as a meaningful value.

        Raises:
            KeyError: If ``id``, ``source``, ``started_at`` or
                ``last_active`` is absent from the row.
        """
        return cls(
            session_id=row["id"],
            source=row["source"],
            user_id=row.get("user_id"),
            title=row.get("title"),
            model=row.get("model"),
            system_prompt=row.get("system_prompt"),
            parent_session_id=row.get("parent_session_id"),
            started_at=row["started_at"],
            last_active=row["last_active"],
            ended_at=row.get("ended_at"),
            end_reason=row.get("end_reason"),
            # ``dict.get(key, 0)`` only covers a missing key; ``or 0`` also
            # covers a key that is present with a NULL value.
            message_count=row.get("message_count") or 0,
            tool_call_count=row.get("tool_call_count") or 0,
            preview=row.get("preview"),
            usage=SessionUsage(
                input_tokens=row.get("input_tokens") or 0,
                output_tokens=row.get("output_tokens") or 0,
                cache_read_tokens=row.get("cache_read_tokens") or 0,
                cache_write_tokens=row.get("cache_write_tokens") or 0,
                reasoning_tokens=row.get("reasoning_tokens") or 0,
                estimated_cost_usd=row.get("estimated_cost_usd") or 0.0,
                actual_cost_usd=row.get("actual_cost_usd"),
            ),
        )

View File

@ -0,0 +1,151 @@
"""Beaver session 子系统的检索能力。"""
from __future__ import annotations
import re
import sqlite3
from typing import Any
from .store import SessionStore
class SessionSearchService:
    """Browsing / FTS / lineage helpers built around ``SessionStore``."""

    def __init__(self, store: SessionStore) -> None:
        """Keep a reference to the *store* through which every query is executed."""
        self.store = store

    @staticmethod
    def _sanitize_fts5_query(query: str) -> str:
        """Rewrite free-form text so it is safer to pass to an FTS5 ``MATCH``.

        Steps, in order:

        1. Balanced double-quoted phrases are set aside and restored verbatim
           at the end.
        2. The characters ``+ { } ( ) " ^`` are replaced by spaces.
        3. Runs of ``*`` collapse to one; a ``*`` that starts a token is dropped.
        4. A leading ``AND`` / ``OR`` / ``NOT`` and a trailing one are removed
           (case-insensitive).
        5. Words joined by ``.`` or ``-`` are wrapped in double quotes.

        Returns:
            The sanitised query, stripped of surrounding whitespace. It may be
            empty.
        """
        quoted_parts: list[str] = []

        def preserve(match: re.Match[str]) -> str:
            # Swap the phrase for a NUL-delimited placeholder so the rewrites
            # below cannot touch it.
            quoted_parts.append(match.group(0))
            return f"\x00Q{len(quoted_parts) - 1}\x00"

        sanitized = re.sub(r'"[^"]*"', preserve, query)
        sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
        sanitized = re.sub(r"\*+", "*", sanitized)
        sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
        sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
        sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
        sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)
        for index, quoted in enumerate(quoted_parts):
            sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
        return sanitized.strip()

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full session id or a unique id prefix to a session id.

        Returns:
            The matching id; ``None`` when nothing matches, when the prefix is
            ambiguous, or when the prefix is empty.
        """
        exact = self.store.get_session_record(session_id_or_prefix)
        if exact is not None:
            return exact.session_id
        if not session_id_or_prefix:
            # An empty prefix becomes the pattern ``%`` which matches every
            # row, so it would "resolve" to the only session when exactly one
            # exists.
            return None
        # Escape LIKE metacharacters so the input is matched literally.
        escaped = (
            session_id_or_prefix
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        # LIMIT 2 is enough to tell "unique" from "ambiguous".
        rows = self.store._fetchall(
            """
SELECT id
FROM sessions
WHERE id LIKE ? ESCAPE '\\'
ORDER BY started_at DESC
LIMIT 2
""",
            (f"{escaped}%",),
        )
        if len(rows) == 1:
            return rows[0]["id"]
        return None

    def list_sessions_rich(
        self,
        *,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        source: str | None = None,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """List sessions ordered by most recent activity, with their metadata.

        Args:
            limit: Maximum number of rows.
            offset: Number of rows to skip.
            include_children: When false, only sessions without a
                ``parent_session_id`` are returned.
            source: Keep only sessions with this source.
            exclude_sources: Drop sessions whose source is in this list.

        Returns:
            Full ``sessions`` rows as dicts.
        """
        clauses: list[str] = []
        params: list[Any] = []
        if not include_children:
            clauses.append("parent_session_id IS NULL")
        if source:
            clauses.append("source = ?")
            params.append(source)
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
        params.extend([limit, offset])
        # Only fixed clause text and ``?`` placeholders are interpolated;
        # every value is bound as a parameter.
        rows = self.store._fetchall(
            f"""
SELECT *
FROM sessions
{where}
ORDER BY last_active DESC
LIMIT ? OFFSET ?
""",
            tuple(params),
        )
        return rows

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search session transcripts with FTS5.

        Only messages with ``context_visible = 1`` are returned.

        Args:
            query: Free-form search text; it is sanitised first.
            role_filter: Keep only messages whose role is in this list.
            exclude_sources: Drop messages of sessions with these sources.
            limit: Maximum number of hits.
            offset: Number of hits to skip.

        Returns:
            One dict per hit with ``id``, ``session_id``, ``role``,
            ``source``, ``model``, ``session_started`` and ``snippet``;
            an empty list when the sanitised query is empty.

        Raises:
            RuntimeError: If SQLite raises ``sqlite3.Error`` while running the
                search.
        """
        query = self._sanitize_fts5_query(query)
        if not query:
            return []
        clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
        params: list[Any] = [query]
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"s.source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        if role_filter:
            placeholders = ",".join("?" for _ in role_filter)
            clauses.append(f"m.role IN ({placeholders})")
            params.extend(role_filter)
        params.extend([limit, offset])
        sql = f"""
SELECT
m.id,
m.session_id,
m.role,
s.source,
s.model,
s.started_at AS session_started,
snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
FROM messages_fts
JOIN messages m ON m.id = messages_fts.rowid
JOIN sessions s ON s.id = m.session_id
WHERE {' AND '.join(clauses)}
ORDER BY rank
LIMIT ? OFFSET ?
"""
        try:
            return self.store._fetchall(sql, tuple(params))
        except sqlite3.Error as exc:
            raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc

View File

@ -0,0 +1,467 @@
"""Beaver session 子系统的 SQLite 存储实现。
设计来源主要参考 Hermes-agent
1. SQLite 作为统一 session/transcript backend
2. WAL 模式支持多读单写
3. FTS5 支持跨 session 文本检索
4. `parent_session_id` 支持 lineage
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable, TypeVar
from .models import MessageRecord, SessionRecord
# Return type of the callable handed to ``SessionStore._execute_write``.
T = TypeVar("T")
# Base schema: the ``sessions`` and ``messages`` tables plus their indexes.
# Every statement uses IF NOT EXISTS, so the script can be run repeatedly.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
user_id TEXT,
title TEXT,
model TEXT,
system_prompt TEXT,
parent_session_id TEXT,
started_at REAL NOT NULL,
last_active REAL NOT NULL,
ended_at REAL,
end_reason TEXT,
message_count INTEGER DEFAULT 0,
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
estimated_cost_usd REAL DEFAULT 0,
actual_cost_usd REAL,
preview TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id),
run_id TEXT,
role TEXT NOT NULL,
event_type TEXT,
event_payload TEXT,
context_visible INTEGER NOT NULL DEFAULT 1,
content TEXT,
tool_name TEXT,
tool_calls TEXT,
tool_call_id TEXT,
timestamp REAL NOT NULL,
finish_reason TEXT,
reasoning TEXT,
reasoning_details TEXT,
codex_reasoning_items TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
"""
# FTS5 index over ``messages.content``, declared as an external-content table
# whose rowid is ``messages.id``.
FTS_TABLE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content,
content=messages,
content_rowid=id
);
"""
# Triggers that mirror INSERT / DELETE / UPDATE on ``messages`` into
# ``messages_fts``. Rows with context_visible != 1 or NULL content are skipped.
# Existing triggers are dropped first, so these definitions replace older ones.
FTS_TRIGGER_SQL = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
"""
class SessionStore:
    """SQLite-backed session store.

    One connection is shared by all callers; every statement is executed
    while holding ``self._lock``.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open the database at *db_path* (creating parent directories) and set up the schema."""
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        # isolation_level=None: the sqlite3 module opens no implicit
        # transactions; writes begin theirs explicitly in _execute_write.
        self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
        self._conn.row_factory = sqlite3.Row
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._init_schema()

    def _init_schema(self) -> None:
        """Create tables, indexes, the FTS table and its triggers if needed."""
        with self._lock:
            self._conn.executescript(SCHEMA_SQL)
            try:
                # Probe for the FTS table; create it only when the probe fails.
                self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
            except sqlite3.OperationalError:
                self._conn.executescript(FTS_TABLE_SQL)
            self._conn.executescript(FTS_TRIGGER_SQL)
            # Older versions may have written hidden events into the FTS
            # index; purge those noise entries during initialisation.
            # NOTE(review): this runs on every start, including for hidden
            # rows that the current triggers never indexed. SQLite's FTS5 docs
            # warn that a 'delete' command whose values are not in the index
            # can leave the index inconsistent - confirm whether this should
            # be a one-time migration instead.
            self._conn.execute(
                """
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', id, content
FROM messages
WHERE context_visible = 0 AND content IS NOT NULL
"""
            )
            self._conn.commit()

    def close(self) -> None:
        """Close the SQLite connection."""
        with self._lock:
            self._conn.close()

    def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
        """Run *fn* inside a ``BEGIN IMMEDIATE`` transaction and return its result.

        The transaction is committed when *fn* returns and rolled back when
        it raises; the exception is then re-raised.
        """
        with self._lock:
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                result = fn(self._conn)
                self._conn.commit()
                return result
            except BaseException:
                self._conn.rollback()
                raise

    def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
        """Run a query and return the first row as a dict, or ``None`` when there is no row."""
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return dict(row) if row else None

    def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
        """Run a query and return all rows as dicts."""
        with self._lock:
            rows = self._conn.execute(sql, params).fetchall()
        return [dict(row) for row in rows]

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Ensure the session row exists and return *session_id*.

        A missing row is created. For an existing row only metadata that is
        still unset is filled in: ``source`` is replaced only while it is
        ``'unknown'``, and the other columns only while they are NULL.
        """
        now = time.time()

        def _do(conn: sqlite3.Connection) -> str:
            conn.execute(
                """
INSERT INTO sessions (
id, source, user_id, title, model, parent_session_id, started_at, last_active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
source = CASE
WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
ELSE sessions.source
END,
user_id = COALESCE(sessions.user_id, excluded.user_id),
title = COALESCE(sessions.title, excluded.title),
model = COALESCE(sessions.model, excluded.model),
parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
""",
                (session_id, source, user_id, title, model, parent_session_id, now, now),
            )
            return session_id

        return self._execute_write(_do)

    def get_session_record(self, session_id: str) -> SessionRecord | None:
        """Return the session as a ``SessionRecord``, or ``None`` when no row matches."""
        row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
        return SessionRecord.from_row(row) if row else None

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Save the fully assembled system prompt snapshot of the session and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
UPDATE sessions
SET system_prompt = ?, last_active = ?
WHERE id = ?
""",
                (system_prompt, time.time(), session_id),
            )

        self._execute_write(_do)

    def update_usage(
        self,
        session_id: str,
        *,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        reasoning_tokens: int = 0,
        estimated_cost_usd: float = 0.0,
        actual_cost_usd: float | None = None,
        absolute: bool = False,
    ) -> None:
        """Update the usage counters of a session and bump ``last_active``.

        By default the given values are added to the stored ones; an
        ``actual_cost_usd`` of ``None`` then leaves the stored value alone.
        With ``absolute=True`` the stored values are overwritten, including
        ``actual_cost_usd``.
        """
        if absolute:
            sql = """
UPDATE sessions
SET input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = ?,
actual_cost_usd = ?,
last_active = ?
WHERE id = ?
"""
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )
        else:
            sql = """
UPDATE sessions
SET input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
cache_read_tokens = cache_read_tokens + ?,
cache_write_tokens = cache_write_tokens + ?,
reasoning_tokens = reasoning_tokens + ?,
estimated_cost_usd = estimated_cost_usd + ?,
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE COALESCE(actual_cost_usd, 0) + ?
END,
last_active = ?
WHERE id = ?
"""
            # actual_cost_usd is bound twice: once for the NULL test and once
            # for the addition.
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(sql, params)

        self._execute_write(_do)

    def append_message(
        self,
        session_id: str,
        *,
        run_id: str | None = None,
        role: str,
        event_type: str | None = None,
        event_payload: dict[str, Any] | None = None,
        context_visible: bool = True,
        content: str | None = None,
        tool_name: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        finish_reason: str | None = None,
        reasoning: str | None = None,
        reasoning_details: Any | None = None,
        codex_reasoning_items: Any | None = None,
        source: str = "unknown",
        title: str | None = None,
        model: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> int:
        """Append one message / event to the session and return its row id.

        The session row is ensured first (a separate transaction). The insert
        and the session bookkeeping then share one transaction:
        ``message_count`` grows by one, ``tool_call_count`` by the number of
        tool calls, ``model`` is filled while NULL, and ``preview`` is set
        from the first 120 characters of the first non-empty user message.
        Structured arguments are stored as JSON text; ``event_type`` defaults
        to ``role``.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        now = time.time()
        tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
        event_payload_json = json.dumps(event_payload) if event_payload is not None else None
        reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
        codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
        preview = content[:120] if role == "user" and content else None
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)

        def _do(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                """
INSERT INTO messages (
session_id, run_id, role, event_type, event_payload, context_visible, content,
tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
reasoning_details, codex_reasoning_items
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
                (
                    session_id,
                    run_id,
                    role,
                    event_type or role,
                    event_payload_json,
                    1 if context_visible else 0,
                    content,
                    tool_name,
                    tool_calls_json,
                    tool_call_id,
                    now,
                    finish_reason,
                    reasoning,
                    reasoning_details_json,
                    codex_items_json,
                ),
            )
            conn.execute(
                """
UPDATE sessions
SET last_active = ?,
message_count = message_count + 1,
tool_call_count = tool_call_count + ?,
model = COALESCE(model, ?),
preview = CASE
WHEN preview IS NULL AND ? IS NOT NULL THEN ?
ELSE preview
END
WHERE id = ?
""",
                (now, tool_call_count, model, preview, preview, session_id),
            )
            return int(cursor.lastrowid)

        return self._execute_write(_do)

    def get_message_records(self, session_id: str) -> list[MessageRecord]:
        """Return every row of the session as ``MessageRecord``, ordered by timestamp then id."""
        rows = self._fetchall(
            """
SELECT *
FROM messages
WHERE session_id = ?
ORDER BY timestamp, id
""",
            (session_id,),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        At this stage the event stream is still carried by the ``messages``
        table, so this is equivalent to reading all message records. If run /
        checkpoint / system events later move to their own tables, the
        manager above can keep this interface unchanged.
        """
        return self.get_message_records(session_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List the ``run_id`` values seen in the session, in chronological order of first use."""
        rows = self._fetchall(
            """
SELECT run_id
FROM messages
WHERE session_id = ? AND run_id IS NOT NULL
GROUP BY run_id
ORDER BY MIN(timestamp), MIN(id)
""",
            (session_id,),
        )
        return [str(row["run_id"]) for row in rows if row.get("run_id")]

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events that belong to one run, ordered by timestamp then id."""
        rows = self._fetchall(
            """
SELECT *
FROM messages
WHERE session_id = ? AND run_id = ?
ORDER BY timestamp, id
""",
            (session_id, run_id),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the context-visible events as conversation message dicts, hidden events skipped."""
        return [
            record.to_conversation_message()
            for record in self.get_event_records(session_id)
            if record.context_visible
        ]

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark the session as ended with *end_reason*.

        ``ended_at`` and ``last_active`` receive the same timestamp.
        """

        def _do(conn: sqlite3.Connection) -> None:
            # Read the clock once so both columns describe the same instant.
            now = time.time()
            conn.execute(
                """
UPDATE sessions
SET ended_at = ?, end_reason = ?, last_active = ?
WHERE id = ?
""",
                (now, end_reason, now, session_id),
            )

        self._execute_write(_do)

    def reopen_session(self, session_id: str) -> None:
        """Clear ``ended_at`` / ``end_reason`` and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
UPDATE sessions
SET ended_at = NULL, end_reason = NULL, last_active = ?
WHERE id = ?
""",
                (time.time(), session_id),
            )

        self._execute_write(_do)