修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Session state and persistence."""
|
||||
|
||||
from .manager import SessionManager
|
||||
from .models import MessageRecord, SessionRecord, SessionUsage
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
# Public API of the session package: the names re-exported from the
# manager / models / search / store submodules imported above.
__all__ = [
    "MessageRecord",
    "SessionManager",
    "SessionRecord",
    "SessionSearchService",
    "SessionStore",
    "SessionUsage",
]
|
||||
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .models import MessageRecord
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionManager:
    """Unified session facade used by AgentLoop / services / MCP tools.

    The manager owns one ``SessionStore`` (persistence) and one
    ``SessionSearchService`` (browsing / full-text search) and forwards
    calls to them, so callers never touch the store directly.
    """

    def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
        """Open (or create) the session database for *workspace*.

        Args:
            workspace: Root directory; a ``sessions`` sub-directory is created in it.
            db_path: Explicit database file. Defaults to ``<workspace>/sessions/state.db``.
        """
        self.workspace = Path(workspace)
        self.sessions_dir = self.workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
        self.store = SessionStore(self.db_path)
        self.search = SessionSearchService(self.store)

    def close(self) -> None:
        """Close the underlying store."""
        self.store.close()

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Make sure a session row exists and return its id (delegates to the store)."""
        return self.store.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return the session as a plain dict, or ``None`` when it does not exist."""
        record = self.store.get_session_record(session_id)
        return record.to_dict() if record is not None else None

    def get_or_create(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> dict[str, Any]:
        """Ensure the session exists, then return it as a dict.

        Raises:
            RuntimeError: If the session still cannot be read back after ``ensure_session``.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        session = self.get_session(session_id)
        if session is None:
            raise RuntimeError(f"Failed to create session {session_id!r}")
        return session

    def append_message(self, session_id: str, **kwargs: Any) -> int:
        """Append one message/event to the session and return the store's result."""
        return self.store.append_message(session_id, **kwargs)

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of a session.

        This differs from ``get_messages_as_conversation()`` in an important way:

        - ``get_event_records()`` targets runtime / replay / audit and keeps
          hidden system events.
        - ``get_messages_as_conversation()`` targets the prompt builder and only
          exposes events that may enter the model context.

        The session is no longer just a chat-message store; it is converging
        towards "external event stream + projected views on top".
        """
        return self.store.get_event_records(session_id)

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events that belong to one run (direct run / future bus run)."""
        return self.store.get_run_event_records(session_id, run_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List every run_id of the session in order of first appearance."""
        return self.store.list_run_ids(session_id)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the context-visible messages in conversation-dict form."""
        return self.store.get_messages_as_conversation(session_id)

    def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return the slice of visible history suitable for prompt injection.

        The full event stream is deliberately not exposed here; this is the
        "model-consumable history" projection:

        1. Only context-visible events are included.
        2. The legacy windowing logic is kept: take the last ``max_messages``
           messages, then drop leading messages until the first ``user``
           message so the window does not start mid-turn. If the window holds
           no ``user`` message it is returned unchanged.
        3. The context builder therefore consumes an already-trimmed slice.

        Args:
            session_id: Session to read.
            max_messages: Maximum number of trailing messages to consider.
                Values ``<= 0`` yield an empty list.
        """
        # ``history[-0:]`` is the whole list, so a zero (or negative) budget
        # must be handled explicitly instead of being fed to the slice.
        if max_messages <= 0:
            return []

        history = self.get_messages_as_conversation(session_id)
        window = history[-max_messages:]
        for index, message in enumerate(window):
            if message.get("role") == "user":
                return window[index:]
        return window

    def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Backward-compatible alias for ``get_visible_history``."""
        return self.get_visible_history(session_id, max_messages=max_messages)

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Store the system prompt snapshot for the session."""
        self.store.update_system_prompt(session_id, system_prompt)

    def update_usage(self, session_id: str, **kwargs: Any) -> None:
        """Forward usage/cost counters to the store."""
        self.store.update_usage(session_id, **kwargs)

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark the session as ended with the given reason."""
        self.store.end_session(session_id, end_reason)

    def reopen_session(self, session_id: str) -> None:
        """Clear the ended marker of the session."""
        self.store.reopen_session(session_id)

    def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
        """List sessions via the search service."""
        return self.search.list_sessions_rich(**kwargs)

    def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Full-text search over transcripts via the search service."""
        return self.search.search_messages(**kwargs)

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full id or unique prefix to a session id via the search service."""
        return self.search.resolve_session_id(session_id_or_prefix)
|
||||
211
app-instance/backend/beaver/engine/session/models.py
Normal file
211
app-instance/backend/beaver/engine/session/models.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""Beaver session 子系统的数据模型。
|
||||
|
||||
这层只定义数据结构,不放数据库读写逻辑。目的是把:
|
||||
1. SQLite 行结构
|
||||
2. 运行时会话对象
|
||||
3. 对外暴露的 conversation message
|
||||
|
||||
三件事分开,避免后续所有地方都直接和裸字典耦合。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionUsage:
|
||||
"""会话维度的 usage/cost 统计。"""
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
actual_cost_usd: float | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"cache_write_tokens": self.cache_write_tokens,
|
||||
"reasoning_tokens": self.reasoning_tokens,
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"actual_cost_usd": self.actual_cost_usd,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MessageRecord:
|
||||
"""单条会话事件的结构化表示。
|
||||
|
||||
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
|
||||
1. 普通 user/assistant/tool 消息本身就是事件
|
||||
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
|
||||
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: str | None = None
|
||||
timestamp: float | None = None
|
||||
message_id: int | None = None
|
||||
run_id: str | None = None
|
||||
event_type: str | None = None
|
||||
event_payload: dict[str, Any] | None = None
|
||||
context_visible: bool = True
|
||||
tool_name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_call_id: str | None = None
|
||||
finish_reason: str | None = None
|
||||
reasoning: str | None = None
|
||||
reasoning_details: Any | None = None
|
||||
codex_reasoning_items: Any | None = None
|
||||
|
||||
def to_conversation_message(self) -> dict[str, Any]:
|
||||
"""转成 provider / context builder 可直接消费的消息格式。"""
|
||||
|
||||
if not self.context_visible:
|
||||
raise ValueError("Hidden session events cannot be converted into conversation messages")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
}
|
||||
if self.tool_name:
|
||||
payload["tool_name"] = self.tool_name
|
||||
if self.tool_calls:
|
||||
payload["tool_calls"] = self.tool_calls
|
||||
if self.tool_call_id:
|
||||
payload["tool_call_id"] = self.tool_call_id
|
||||
if self.finish_reason:
|
||||
payload["finish_reason"] = self.finish_reason
|
||||
if self.reasoning:
|
||||
payload["reasoning"] = self.reasoning
|
||||
if self.reasoning_details is not None:
|
||||
payload["reasoning_details"] = self.reasoning_details
|
||||
if self.codex_reasoning_items is not None:
|
||||
payload["codex_reasoning_items"] = self.codex_reasoning_items
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
|
||||
"""从 SQLite row/dict 恢复消息模型。"""
|
||||
|
||||
tool_calls = row.get("tool_calls")
|
||||
if isinstance(tool_calls, str):
|
||||
try:
|
||||
tool_calls = json.loads(tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
tool_calls = []
|
||||
|
||||
reasoning_details = row.get("reasoning_details")
|
||||
if isinstance(reasoning_details, str):
|
||||
try:
|
||||
reasoning_details = json.loads(reasoning_details)
|
||||
except json.JSONDecodeError:
|
||||
reasoning_details = None
|
||||
|
||||
codex_reasoning_items = row.get("codex_reasoning_items")
|
||||
if isinstance(codex_reasoning_items, str):
|
||||
try:
|
||||
codex_reasoning_items = json.loads(codex_reasoning_items)
|
||||
except json.JSONDecodeError:
|
||||
codex_reasoning_items = None
|
||||
|
||||
event_payload = row.get("event_payload")
|
||||
if isinstance(event_payload, str):
|
||||
try:
|
||||
event_payload = json.loads(event_payload)
|
||||
except json.JSONDecodeError:
|
||||
event_payload = None
|
||||
|
||||
return cls(
|
||||
message_id=row.get("id"),
|
||||
run_id=row.get("run_id"),
|
||||
role=row["role"],
|
||||
content=row.get("content"),
|
||||
event_type=row.get("event_type") or row.get("role"),
|
||||
event_payload=event_payload,
|
||||
context_visible=bool(row.get("context_visible", 1)),
|
||||
tool_name=row.get("tool_name"),
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=row.get("tool_call_id"),
|
||||
timestamp=row.get("timestamp"),
|
||||
finish_reason=row.get("finish_reason"),
|
||||
reasoning=row.get("reasoning"),
|
||||
reasoning_details=reasoning_details,
|
||||
codex_reasoning_items=codex_reasoning_items,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class SessionRecord:
    """Structured representation of a single session row."""

    session_id: str
    source: str
    started_at: float
    last_active: float
    user_id: str | None = None
    title: str | None = None
    model: str | None = None
    system_prompt: str | None = None
    parent_session_id: str | None = None
    ended_at: float | None = None
    end_reason: str | None = None
    message_count: int = 0
    tool_call_count: int = 0
    preview: str | None = None
    usage: SessionUsage = field(default_factory=SessionUsage)

    def to_dict(self) -> dict[str, Any]:
        """Flatten the record into one dict.

        The session id is exposed under the key ``"id"`` and the usage
        counters are merged into the top level of the result.
        """
        payload = {
            "id": self.session_id,
            "source": self.source,
            "user_id": self.user_id,
            "title": self.title,
            "model": self.model,
            "system_prompt": self.system_prompt,
            "parent_session_id": self.parent_session_id,
            "started_at": self.started_at,
            "last_active": self.last_active,
            "ended_at": self.ended_at,
            "end_reason": self.end_reason,
            "message_count": self.message_count,
            "tool_call_count": self.tool_call_count,
            "preview": self.preview,
        }
        payload.update(self.usage.to_dict())
        return payload

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
        """Build a record from a ``sessions`` row given as a dict.

        Counter columns that are missing *or* NULL are read as zero, so the
        ``int``/``float`` fields never end up holding ``None``.

        Raises:
            KeyError: If ``id``, ``source``, ``started_at`` or ``last_active`` is missing.
        """

        def counter(key: str) -> int:
            # ``row.get(key, 0)`` would still return None for a NULL column.
            return row.get(key) or 0

        return cls(
            session_id=row["id"],
            source=row["source"],
            user_id=row.get("user_id"),
            title=row.get("title"),
            model=row.get("model"),
            system_prompt=row.get("system_prompt"),
            parent_session_id=row.get("parent_session_id"),
            started_at=row["started_at"],
            last_active=row["last_active"],
            ended_at=row.get("ended_at"),
            end_reason=row.get("end_reason"),
            message_count=counter("message_count"),
            tool_call_count=counter("tool_call_count"),
            preview=row.get("preview"),
            usage=SessionUsage(
                input_tokens=counter("input_tokens"),
                output_tokens=counter("output_tokens"),
                cache_read_tokens=counter("cache_read_tokens"),
                cache_write_tokens=counter("cache_write_tokens"),
                reasoning_tokens=counter("reasoning_tokens"),
                estimated_cost_usd=row.get("estimated_cost_usd") or 0.0,
                actual_cost_usd=row.get("actual_cost_usd"),
            ),
        )
|
||||
151
app-instance/backend/beaver/engine/session/search.py
Normal file
151
app-instance/backend/beaver/engine/session/search.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Beaver session 子系统的检索能力。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionSearchService:
    """Browsing / FTS / lineage helpers built on top of a ``SessionStore``."""

    def __init__(self, store: SessionStore) -> None:
        self.store = store

    @staticmethod
    def _sanitize_fts5_query(query: str) -> str:
        """Rewrite free-form user text into a query FTS5 is less likely to reject.

        Steps, in order:

        1. Set balanced ``"quoted phrases"`` aside behind placeholders.
        2. Replace the FTS5 special characters ``+ { } ( ) " ^`` with spaces.
        3. Collapse runs of ``*`` and drop a ``*`` that does not follow a term.
        4. Remove a dangling ``AND`` / ``OR`` / ``NOT`` at the start or end.
        5. Wrap dotted / hyphenated terms (``my-app.config``) in double quotes.
        6. Put the preserved phrases back.

        Returns the stripped result, which may be empty.
        """
        quoted_parts: list[str] = []

        def preserve(match: re.Match[str]) -> str:
            quoted_parts.append(match.group(0))
            return f"\x00Q{len(quoted_parts) - 1}\x00"

        sanitized = re.sub(r'"[^"]*"', preserve, query)
        sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
        sanitized = re.sub(r"\*+", "*", sanitized)
        sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
        sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
        sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
        sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)

        for index, quoted in enumerate(quoted_parts):
            sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
        return sanitized.strip()

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full session id or a unique id prefix.

        Returns the matching session id, or ``None`` when nothing matches,
        when the prefix is ambiguous, or when the prefix is empty.
        """
        exact = self.store.get_session_record(session_id_or_prefix)
        if exact is not None:
            return exact.session_id

        # An empty prefix would become the pattern "%" and match every
        # session, silently "resolving" to the only one when exactly one
        # exists. It identifies nothing, so treat it as no match.
        if not session_id_or_prefix:
            return None

        # Escape LIKE wildcards so "_" and "%" in the prefix match literally.
        escaped = (
            session_id_or_prefix
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        rows = self.store._fetchall(
            """
            SELECT id
            FROM sessions
            WHERE id LIKE ? ESCAPE '\\'
            ORDER BY started_at DESC
            LIMIT 2
            """,
            (f"{escaped}%",),
        )
        # LIMIT 2 is enough to tell "unique" apart from "ambiguous".
        if len(rows) == 1:
            return rows[0]["id"]
        return None

    def list_sessions_rich(
        self,
        *,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        source: str | None = None,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """List sessions ordered by most recent activity.

        Args:
            limit: Maximum number of rows.
            offset: Number of rows to skip.
            include_children: When false, only sessions without a parent are returned.
            source: Keep only sessions with this source.
            exclude_sources: Drop sessions whose source is in this list.

        Returns:
            The full ``sessions`` rows as dicts.
        """
        clauses: list[str] = []
        params: list[Any] = []

        if not include_children:
            clauses.append("parent_session_id IS NULL")
        if source:
            clauses.append("source = ?")
            params.append(source)
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"source NOT IN ({placeholders})")
            params.extend(exclude_sources)

        # Only "?" placeholders and fixed clause text are interpolated;
        # every value travels as a bound parameter.
        where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
        params.extend([limit, offset])
        rows = self.store._fetchall(
            f"""
            SELECT *
            FROM sessions
            {where}
            ORDER BY last_active DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params),
        )
        return rows

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search session transcripts with FTS5.

        The query is sanitized first; an empty sanitized query yields ``[]``.
        Only context-visible messages are returned.

        Args:
            query: Free-form search text.
            role_filter: Keep only messages whose role is in this list.
            exclude_sources: Drop messages of sessions with these sources.
            limit: Maximum number of hits.
            offset: Number of hits to skip.

        Raises:
            RuntimeError: If SQLite rejects the query (original error chained).
        """
        query = self._sanitize_fts5_query(query)
        if not query:
            return []

        clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
        params: list[Any] = [query]

        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"s.source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        if role_filter:
            placeholders = ",".join("?" for _ in role_filter)
            clauses.append(f"m.role IN ({placeholders})")
            params.extend(role_filter)

        params.extend([limit, offset])
        sql = f"""
            SELECT
                m.id,
                m.session_id,
                m.role,
                s.source,
                s.model,
                s.started_at AS session_started,
                snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
            FROM messages_fts
            JOIN messages m ON m.id = messages_fts.rowid
            JOIN sessions s ON s.id = m.session_id
            WHERE {' AND '.join(clauses)}
            ORDER BY rank
            LIMIT ? OFFSET ?
        """

        try:
            return self.store._fetchall(sql, tuple(params))
        except sqlite3.Error as exc:
            raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc
|
||||
467
app-instance/backend/beaver/engine/session/store.py
Normal file
467
app-instance/backend/beaver/engine/session/store.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Beaver session 子系统的 SQLite 存储实现。
|
||||
|
||||
设计来源主要参考 Hermes-agent:
|
||||
1. SQLite 作为统一 session/transcript backend
|
||||
2. WAL 模式支持多读单写
|
||||
3. FTS5 支持跨 session 文本检索
|
||||
4. `parent_session_id` 支持 lineage
|
||||
|
||||
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from .models import MessageRecord, SessionRecord
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
title TEXT,
|
||||
model TEXT,
|
||||
system_prompt TEXT,
|
||||
parent_session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
last_active REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
end_reason TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
estimated_cost_usd REAL DEFAULT 0,
|
||||
actual_cost_usd REAL,
|
||||
preview TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id),
|
||||
run_id TEXT,
|
||||
role TEXT NOT NULL,
|
||||
event_type TEXT,
|
||||
event_payload TEXT,
|
||||
context_visible INTEGER NOT NULL DEFAULT 1,
|
||||
content TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
tool_call_id TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
finish_reason TEXT,
|
||||
reasoning TEXT,
|
||||
reasoning_details TEXT,
|
||||
codex_reasoning_items TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
|
||||
"""
|
||||
|
||||
FTS_TABLE_SQL = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||
content,
|
||||
content=messages,
|
||||
content_rowid=id
|
||||
);
|
||||
"""
|
||||
|
||||
FTS_TRIGGER_SQL = """
|
||||
DROP TRIGGER IF EXISTS messages_fts_insert;
|
||||
DROP TRIGGER IF EXISTS messages_fts_delete;
|
||||
DROP TRIGGER IF EXISTS messages_fts_update;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
"""
|
||||
|
||||
|
||||
class SessionStore:
    """SQLite-backed session store.

    One connection is shared by all callers (``check_same_thread=False``) and
    every access is serialized through ``self._lock``. The connection runs in
    autocommit mode (``isolation_level=None``); writes open their own
    ``BEGIN IMMEDIATE`` transaction in ``_execute_write``.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open the database at *db_path* (creating parent directories) and set up the schema."""
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
        self._conn.row_factory = sqlite3.Row
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._init_schema()

    def _init_schema(self) -> None:
        """Create tables, the FTS index and its triggers, then clean legacy FTS rows."""
        with self._lock:
            self._conn.executescript(SCHEMA_SQL)
            try:
                # Probe for the FTS table; create it only when it is missing.
                self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
            except sqlite3.OperationalError:
                self._conn.executescript(FTS_TABLE_SQL)
            # The trigger script drops and recreates the triggers, so it is
            # run on every start to replace older definitions.
            self._conn.executescript(FTS_TRIGGER_SQL)
            self._purge_hidden_fts_rows()
            self._conn.commit()

    def _purge_hidden_fts_rows(self) -> None:
        """Remove hidden events that an older version wrote into the FTS index.

        Must be called with ``self._lock`` held.

        Only rows that are actually present in the index are deleted. Issuing
        an FTS5 ``'delete'`` for a row that was never indexed (which is every
        hidden row written through the current triggers) makes the index
        bookkeeping drift on each start-up and can be reported by SQLite as a
        corrupt index, so the cleanup is restricted to rows that have an
        entry in the ``messages_fts_docsize`` shadow table.
        """
        has_docsize = self._conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'messages_fts_docsize'"
        ).fetchone()
        if has_docsize is None:
            # Without the shadow table there is no safe way to tell which
            # rows are indexed; leave the index untouched.
            return

        # Read first, then write: the deletes below modify the shadow table.
        stale = self._conn.execute(
            """
            SELECT m.id, m.content
            FROM messages m
            JOIN messages_fts_docsize d ON d.id = m.id
            WHERE m.context_visible = 0 AND m.content IS NOT NULL
            """
        ).fetchall()
        if stale:
            self._conn.executemany(
                "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES ('delete', ?, ?)",
                [(row[0], row[1]) for row in stale],
            )

    def close(self) -> None:
        """Close the SQLite connection."""
        with self._lock:
            self._conn.close()

    def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
        """Run *fn* inside one ``BEGIN IMMEDIATE`` transaction under the lock.

        Commits and returns ``fn``'s result on success; rolls back and
        re-raises on any exception.
        """
        with self._lock:
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                result = fn(self._conn)
                self._conn.commit()
                return result
            except BaseException:
                self._conn.rollback()
                raise

    def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
        """Run a query and return the first row as a dict, or ``None``."""
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return dict(row) if row else None

    def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
        """Run a query and return all rows as dicts."""
        with self._lock:
            rows = self._conn.execute(sql, params).fetchall()
        return [dict(row) for row in rows]

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Make sure the session row exists and return ``session_id``.

        A missing row is created with ``started_at = last_active = now``.
        For an existing row only missing metadata is filled in: ``source``
        is replaced only while it is still ``'unknown'``, and the other
        fields are set only when currently NULL.
        """
        now = time.time()

        def _do(conn: sqlite3.Connection) -> str:
            conn.execute(
                """
                INSERT INTO sessions (
                    id, source, user_id, title, model, parent_session_id, started_at, last_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    source = CASE
                        WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
                        ELSE sessions.source
                    END,
                    user_id = COALESCE(sessions.user_id, excluded.user_id),
                    title = COALESCE(sessions.title, excluded.title),
                    model = COALESCE(sessions.model, excluded.model),
                    parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
                """,
                (session_id, source, user_id, title, model, parent_session_id, now, now),
            )
            return session_id

        return self._execute_write(_do)

    def get_session_record(self, session_id: str) -> SessionRecord | None:
        """Load one session as a ``SessionRecord``, or ``None`` when it does not exist."""
        row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
        return SessionRecord.from_row(row) if row else None

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Store the fully assembled system prompt snapshot and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET system_prompt = ?, last_active = ?
                WHERE id = ?
                """,
                (system_prompt, time.time(), session_id),
            )

        self._execute_write(_do)

    def update_usage(
        self,
        session_id: str,
        *,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        reasoning_tokens: int = 0,
        estimated_cost_usd: float = 0.0,
        actual_cost_usd: float | None = None,
        absolute: bool = False,
    ) -> None:
        """Update the usage counters of a session and bump ``last_active``.

        By default the given values are added to the stored counters; with
        ``absolute=True`` they overwrite them. In incremental mode a
        ``None`` ``actual_cost_usd`` leaves the stored value untouched,
        otherwise it is added to the stored value (NULL counted as 0).
        """
        if absolute:
            sql = """
                UPDATE sessions
                SET input_tokens = ?,
                    output_tokens = ?,
                    cache_read_tokens = ?,
                    cache_write_tokens = ?,
                    reasoning_tokens = ?,
                    estimated_cost_usd = ?,
                    actual_cost_usd = ?,
                    last_active = ?
                WHERE id = ?
            """
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )
        else:
            sql = """
                UPDATE sessions
                SET input_tokens = input_tokens + ?,
                    output_tokens = output_tokens + ?,
                    cache_read_tokens = cache_read_tokens + ?,
                    cache_write_tokens = cache_write_tokens + ?,
                    reasoning_tokens = reasoning_tokens + ?,
                    estimated_cost_usd = estimated_cost_usd + ?,
                    actual_cost_usd = CASE
                        WHEN ? IS NULL THEN actual_cost_usd
                        ELSE COALESCE(actual_cost_usd, 0) + ?
                    END,
                    last_active = ?
                WHERE id = ?
            """
            # actual_cost_usd is bound twice: once for the NULL test and
            # once for the addition.
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(sql, params)

        self._execute_write(_do)

    def append_message(
        self,
        session_id: str,
        *,
        run_id: str | None = None,
        role: str,
        event_type: str | None = None,
        event_payload: dict[str, Any] | None = None,
        context_visible: bool = True,
        content: str | None = None,
        tool_name: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        finish_reason: str | None = None,
        reasoning: str | None = None,
        reasoning_details: Any | None = None,
        codex_reasoning_items: Any | None = None,
        source: str = "unknown",
        title: str | None = None,
        model: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> int:
        """Append one message/event to a session and return its row id.

        The session row is created first if needed (using ``source``,
        ``title``, ``model``, ``user_id`` and ``parent_session_id``).
        Structured fields are stored as JSON text, ``event_type`` defaults
        to the role, and the session's ``last_active``, ``message_count``
        and ``tool_call_count`` are updated in the same transaction as the
        insert. The first non-empty ``user`` content (cut to 120
        characters) becomes the session preview.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        now = time.time()
        tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
        event_payload_json = json.dumps(event_payload) if event_payload is not None else None
        reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
        codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
        preview = content[:120] if role == "user" and content else None
        # A non-list but truthy value counts as a single tool call.
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)

        def _do(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                """
                INSERT INTO messages (
                    session_id, run_id, role, event_type, event_payload, context_visible, content,
                    tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
                    reasoning_details, codex_reasoning_items
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    session_id,
                    run_id,
                    role,
                    event_type or role,
                    event_payload_json,
                    1 if context_visible else 0,
                    content,
                    tool_name,
                    tool_calls_json,
                    tool_call_id,
                    now,
                    finish_reason,
                    reasoning,
                    reasoning_details_json,
                    codex_items_json,
                ),
            )
            conn.execute(
                """
                UPDATE sessions
                SET last_active = ?,
                    message_count = message_count + 1,
                    tool_call_count = tool_call_count + ?,
                    model = COALESCE(model, ?),
                    preview = CASE
                        WHEN preview IS NULL AND ? IS NOT NULL THEN ?
                        ELSE preview
                    END
                WHERE id = ?
                """,
                (now, tool_call_count, model, preview, preview, session_id),
            )
            return int(cursor.lastrowid)

        return self._execute_write(_do)

    def get_message_records(self, session_id: str) -> list[MessageRecord]:
        """Return every row of the session, hidden ones included, ordered by timestamp then id."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ?
            ORDER BY timestamp, id
            """,
            (session_id,),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        For now the event stream lives in the ``messages`` table, so this is
        the same as reading all message records. The separate name lets the
        storage layout change later without changing the manager-facing API.
        """
        return self.get_message_records(session_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List the run ids seen in the session, ordered by their first event."""
        rows = self._fetchall(
            """
            SELECT run_id
            FROM messages
            WHERE session_id = ? AND run_id IS NOT NULL
            GROUP BY run_id
            ORDER BY MIN(timestamp), MIN(id)
            """,
            (session_id,),
        )
        return [str(row["run_id"]) for row in rows if row.get("run_id")]

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events of one run, ordered by timestamp then id."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ? AND run_id = ?
            ORDER BY timestamp, id
            """,
            (session_id, run_id),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the context-visible events as conversation message dicts, in order."""
        return [
            record.to_conversation_message()
            for record in self.get_event_records(session_id)
            if record.context_visible
        ]

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark the session as ended, recording the reason."""
        # One timestamp for both columns so ended_at and last_active agree.
        now = time.time()

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = ?, end_reason = ?, last_active = ?
                WHERE id = ?
                """,
                (now, end_reason, now, session_id),
            )

        self._execute_write(_do)

    def reopen_session(self, session_id: str) -> None:
        """Clear ``ended_at`` / ``end_reason`` and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = NULL, end_reason = NULL, last_active = ?
                WHERE id = ?
                """,
                (time.time(), session_id),
            )

        self._execute_write(_do)
|
||||
Reference in New Issue
Block a user