修改了nanobot,往Hermes agent的风格走,进度1/3

This commit is contained in:
2026-04-20 18:11:14 +08:00
parent cdfc222c9f
commit 36882a7d7b
261 changed files with 12659 additions and 604 deletions

View File

@ -0,0 +1,31 @@
"""Unified Beaver agent engine.
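Nothing is imported eagerly at the top level, so importing a submodule cannot trigger a circular import.
The same public names are still exported, but they are now loaded on demand.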
"""
from __future__ import annotations
from typing import Any
__all__ = ["AgentLoop", "AgentProfile", "AgentRunResult", "EngineLoader", "EngineLoadResult"]
def __getattr__(name: str) -> Any:
    """Resolve the package's exported names lazily (module-level hook).

    No submodule is imported when the package itself is imported; the
    owning submodule is loaded only when one of its names is requested.

    Args:
        name: Attribute requested on the package.

    Returns:
        The object exported under ``name``.

    Raises:
        AttributeError: ``name`` is not one of the lazily exported names.
    """
    loop_exports = ("AgentLoop", "AgentProfile", "AgentRunResult")

    if name in loop_exports:
        from .loop import AgentLoop, AgentProfile, AgentRunResult

        resolved = dict(zip(loop_exports, (AgentLoop, AgentProfile, AgentRunResult)))
        return resolved[name]

    if name == "EngineLoadResult":
        from .loader import EngineLoadResult

        return EngineLoadResult

    if name == "EngineLoader":
        from .loader import EngineLoader

        return EngineLoader

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,17 @@
"""Context assembly for agent runs."""
from .builder import (
ContextBuildInput,
ContextBuildResult,
ContextBuilder,
SessionContext,
SkillContext,
)
__all__ = [
"ContextBuildInput",
"ContextBuildResult",
"ContextBuilder",
"SessionContext",
"SkillContext",
]

View File

@ -0,0 +1,331 @@
"""Beaver 运行时上下文装配器。
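This module is the layer between `session` and `provider`, and its responsibilities are narrow:
1. Combine the static / semi-static context prepared before a run into one stable system prompt.
2. Arrange the "visible history" trimmed from the session event stream, plus the current user input, into messages the provider can consume directly.
3. During the tool loop, keep appending assistant/tool messages to the message list in one uniform format.
Why this layer has to exist on its own:
1. `AgentLoop` should not assemble prompts itself; otherwise it quickly grows into one huge file again.
2. The injection order of `memory`, `skills` and `session` must be fixed; otherwise model behaviour drifts.
3. Messages appended around the tool loop must share one format; otherwise provider compatibility problems appear easily.
The design goal of this version of the builder is "minimal but stable":
1. Serve the single-agent main path first.
2. Support frozen curated memory first, rather than live memory.
3. Skills follow the Hermes style of "explicit activation message injection"; no disk scanning happens here.
4. Leave injection points for future channel / gateway / team metadata without building complex logic early.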
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from beaver.memory.curated.snapshot import MemorySnapshot
@dataclass(slots=True)
class SkillContext:
"""单个已激活 skill 的最小表示。
这里故意不把 skill 设计成复杂对象,只保留 builder 真正关心的两部分:
- `name`:用于生成激活提示
- `content`skill 的完整正文
注意:按当前 Hermes 风格实现skill 正文不再塞进 system prompt而是转成显式消息注入。
"""
name: str
content: str
@dataclass(slots=True)
class SessionContext:
"""当前运行轮次的会话元数据。
这不是 session store 里的完整 record而是 prompt builder 关心的那一小部分:
- 哪个 session
- 来源是什么
- 当前使用什么 model
- 是否有 channel/chat/user 这类运行路由信息
把它单独抽出来的原因是:
1. builder 不应该知道 SQLite row 长什么样
2. 不同入口CLI/Web/Gateway都可以把自己的 metadata 收敛成同一种结构
"""
session_id: str | None = None
source: str | None = None
model: str | None = None
user_id: str | None = None
channel: str | None = None
chat_id: str | None = None
parent_session_id: str | None = None
@dataclass(slots=True)
class ContextBuildInput:
"""一次上下文构建所需的全部输入。
这个对象的作用不是“炫技式封装”,而是把主链里零散的数据显式收口。
这样一来,后面 `AgentLoop.process_direct()` 在组装参数时会更清晰,也更容易测试。
字段分组:
- 身份/基础段:`base_system_prompt`
- 会话可见历史:`history`
- 当前输入:`current_user_input`
- 冻结记忆:`memory_snapshot`
- 技能:`activated_skills`
- 运行元数据:`session_context` / `execution_context`
- 额外扩展:`extra_sections`
"""
base_system_prompt: str = ""
history: list[dict[str, Any]] = field(default_factory=list)
current_user_input: str | list[dict[str, Any]] | None = None
memory_snapshot: MemorySnapshot | None = None
activated_skills: list[SkillContext] = field(default_factory=list)
session_context: SessionContext | None = None
execution_context: str | None = None
extra_sections: list[str] = field(default_factory=list)
@dataclass(slots=True)
class ContextBuildResult:
"""一次上下文构建后的结果。
保留 `system_prompt` 的原因:
1. `SessionManager.update_system_prompt()` 需要把最终注入的 prompt snapshot 落盘
2. 调试时经常需要区分“system prompt 长什么样”和“messages 长什么样”
3. 后面如果做 prompt audit / replay也会直接复用这个结果
"""
system_prompt: str
messages: list[dict[str, Any]]
class ContextBuilder:
    """Assemble runtime inputs into a stable model context.

    The builder performs no I/O: it does not read the session store, the
    memory store or the skills directory. Its output depends only on the
    arguments it receives, which keeps it easy to unit-test.
    """

    def build_system_prompt(
        self,
        build_input: ContextBuildInput,
    ) -> str:
        """Build the system prompt.

        Sections are emitted in a fixed order and empty ones are skipped:

        1. base system prompt
        2. session metadata
        3. execution context
        4. frozen memory snapshot
        5. extra sections

        Activated skill bodies are deliberately not part of the system
        prompt; they are injected as explicit messages by
        ``build_skill_activation_messages``.

        Args:
            build_input: Inputs for this build.

        Returns:
            The non-empty sections joined with a ``---`` separator line.
        """
        sections: list[str] = []

        base_system_prompt = (build_input.base_system_prompt or "").strip()
        if base_system_prompt:
            sections.append(base_system_prompt)

        session_section = self._render_session_section(build_input.session_context)
        if session_section:
            sections.append(session_section)

        execution_context = (build_input.execution_context or "").strip()
        if execution_context:
            sections.append(f"# Execution Context\n\n{execution_context}")

        if build_input.memory_snapshot is not None:
            # Only the snapshot handed in is rendered; the builder never
            # reads a live memory store, so the prompt does not change if
            # memory is written in the middle of a session.
            snapshot_sections = build_input.memory_snapshot.as_prompt_sections()
            if snapshot_sections:
                sections.extend(snapshot_sections)

        sections.extend(
            cleaned
            for extra in build_input.extra_sections
            if (cleaned := (extra or "").strip())
        )
        return "\n\n---\n\n".join(sections)

    def build_messages(
        self,
        build_input: ContextBuildInput,
    ) -> ContextBuildResult:
        """Build the complete message list for one model call.

        Four steps:

        1. Render the final system prompt.
        2. Inject the full text of every activated skill as explicit messages.
        3. Append the history in its original order.
        4. Append the current user input, when present, as the last message.

        ``history`` is used as given: the builder does not trim the history
        window, it only drops ``system`` entries and copies the rest.

        Args:
            build_input: Inputs for this build.

        Returns:
            The system prompt together with the assembled messages.
        """
        system_prompt = self.build_system_prompt(build_input)
        messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        messages.extend(self.build_skill_activation_messages(build_input.activated_skills))

        for message in build_input.history:
            # The builder emits the single system message itself, so any
            # system entry coming from upstream history is skipped.
            if message.get("role") == "system":
                continue
            messages.append(dict(message))

        if build_input.current_user_input is not None:
            messages.append({"role": "user", "content": build_input.current_user_input})

        return ContextBuildResult(system_prompt=system_prompt, messages=messages)

    def add_tool_result(
        self,
        messages: list[dict[str, Any]],
        *,
        tool_call_id: str,
        tool_name: str,
        result: str,
    ) -> list[dict[str, Any]]:
        """Append one tool result message.

        Keeping this here means every execution path produces the same
        field names for tool messages.

        Args:
            messages: Message list to extend in place.
            tool_call_id: Identifier of the tool call being answered.
            tool_name: Name of the tool that produced the result.
            result: Tool output.

        Returns:
            The same ``messages`` list, to allow chained appends.
        """
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call_id,
                "name": tool_name,
                "content": result,
            }
        )
        return messages

    def add_assistant_message(
        self,
        messages: list[dict[str, Any]],
        *,
        content: str | None,
        tool_calls: list[dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
    ) -> list[dict[str, Any]]:
        """Append one assistant message.

        Two details matter:

        1. The ``content`` key is always written, even when it is ``None``,
           because some providers require it on assistant messages that
           carry ``tool_calls``.
        2. ``reasoning_content`` is attached only when it is non-empty; it is
           an extension field and must not leak into ordinary provider paths.

        Args:
            messages: Message list to extend in place.
            content: Assistant text, possibly ``None``.
            tool_calls: Serialized tool calls; omitted when empty.
            reasoning_content: Reasoning text; omitted when ``None`` or empty.

        Returns:
            The same ``messages`` list, to allow chained appends.
        """
        message: dict[str, Any] = {
            "role": "assistant",
            "content": content,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        # Truthiness check (not ``is not None``) so that an empty string does
        # not add the extension field, as the contract above promises.
        if reasoning_content:
            message["reasoning_content"] = reasoning_content
        messages.append(message)
        return messages

    def _render_session_section(self, session_context: SessionContext | None) -> str | None:
        """Render session metadata as a readable system prompt section.

        Args:
            session_context: Metadata of the current run, or ``None``.

        Returns:
            A ``# Current Session`` section with one ``Label: value`` row per
            populated field, or ``None`` when there is nothing to show.
        """
        if session_context is None:
            return None

        labelled_values = (
            ("Session ID", session_context.session_id),
            ("Source", session_context.source),
            ("Model", session_context.model),
            ("User ID", session_context.user_id),
            ("Channel", session_context.channel),
            ("Chat ID", session_context.chat_id),
            ("Parent Session ID", session_context.parent_session_id),
        )
        rows = [f"{label}: {value}" for label, value in labelled_values if value]
        if not rows:
            return None
        return "# Current Session\n\n" + "\n".join(rows)

    def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
        """Turn activated skills into explicit activation messages.

        Each skill with non-blank content becomes one ``user`` message made of
        an activation notice followed by the stripped skill text, so the model
        receives the complete guidance directly.

        Args:
            activated_skills: Skills activated for this run.

        Returns:
            One message per skill whose content is not blank.
        """
        messages: list[dict[str, str]] = []
        for skill in activated_skills:
            content = (skill.content or "").strip()
            if not content:
                continue
            messages.append(
                {
                    "role": "user",
                    "content": (
                        f'[SYSTEM: The "{skill.name}" skill is active for this run. '
                        "Follow its instructions as active guidance unless the user overrides them.]\n\n"
                        f"{content}"
                    ),
                }
            )
        return messages

View File

@ -0,0 +1,154 @@
"""Centralized runtime loading for Beaver agents."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable
from beaver.engine.context import ContextBuilder
from beaver.engine.session import SessionManager
from beaver.memory.curated.store import MemoryStore
from beaver.services.memory_service import MemoryService
from beaver.skills import SkillAssembler, SkillsLoader
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
@dataclass(slots=True)
class EngineLoadResult:
"""描述当前 agent runtime 已经装好的依赖。
这里同时保留两类字段:
1. `tools/skills/memory_stores/permissions`
- 便于做状态展示、调试、轻量测试
2. `session_manager/tool_registry/...`
- 供真正的运行时主链直接使用
"""
workspace: Path
tools: list[str] = field(default_factory=list)
skills: list[str] = field(default_factory=list)
memory_stores: list[str] = field(default_factory=list)
permissions: list[str] = field(default_factory=list)
session_manager: SessionManager | None = None
curated_memory_store: MemoryStore | None = None
memory_service: MemoryService | None = None
tool_registry: ToolRegistry | None = None
tool_executor: ToolExecutor | None = None
context_builder: ContextBuilder | None = None
skills_loader: SkillsLoader | None = None
skill_assembler: SkillAssembler | None = None
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
closed: bool = False
def register_closeable(self, name: str, close_fn: Callable[[], None]) -> None:
"""登记一个由 runtime 统一关闭的资源。"""
self.closeables.append((name, close_fn))
def close(self) -> None:
"""按后进先出顺序关闭 runtime 资源。
这一步先保持同步、最小、可组合:
1. 只管理已经明确需要关闭的资源
2. 暂不引入 async shutdown 协议
3. 为后续 Web/Gateway lifespan 留统一入口
"""
if self.closed:
return
errors: list[tuple[str, BaseException]] = []
for name, close_fn in reversed(self.closeables):
try:
close_fn()
except BaseException as exc: # pragma: no cover - defensive cleanup path
errors.append((name, exc))
self.closed = True
if errors:
parts = ", ".join(f"{name}: {exc}" for name, exc in errors)
raise RuntimeError(f"Runtime shutdown failed for {parts}")
class EngineLoader:
    """Load the shared runtime capabilities for a Beaver agent.

    ``load()` wires together a session manager, the curated memory store and
    service, a tool registry with the built-in tools, a tool executor, a
    context builder, and the skills loader / assembler. Every dependency can
    be injected through the constructor; anything not injected is created
    with defaults rooted at ``workspace``.
    """

    def __init__(
        self,
        *,
        workspace: str | Path | None = None,
        session_manager: SessionManager | None = None,
        curated_memory_store: MemoryStore | None = None,
        memory_service: MemoryService | None = None,
        tool_registry: ToolRegistry | None = None,
        context_builder: ContextBuilder | None = None,
        skills_loader: SkillsLoader | None = None,
        skill_assembler: SkillAssembler | None = None,
    ) -> None:
        """Store the workspace and any injected dependencies.

        Args:
            workspace: Root directory of the runtime; the current working
                directory is used when omitted.
            session_manager: Optional pre-built session manager.
            curated_memory_store: Optional pre-built curated memory store.
            memory_service: Optional pre-built memory service.
            tool_registry: Optional pre-built tool registry. When given, the
                built-in tools are NOT registered on it.
            context_builder: Optional pre-built context builder.
            skills_loader: Optional pre-built skills loader.
            skill_assembler: Optional pre-built skill assembler.
        """
        self.workspace = Path(workspace or Path.cwd())
        self._session_manager = session_manager
        self._curated_memory_store = curated_memory_store
        self._memory_service = memory_service
        self._tool_registry = tool_registry
        self._context_builder = context_builder
        self._skills_loader = skills_loader
        self._skill_assembler = skill_assembler

    def load(self) -> EngineLoadResult:
        """Assemble the runtime objects needed by the main path.

        Injected dependencies are selected with ``is None`` checks rather
        than truthiness: an injected object that happens to be falsy (for
        example an empty container-like registry) must still be used instead
        of being silently replaced by a fresh default.

        Returns:
            The loaded runtime. A session manager created here (not injected)
            is registered as a closeable on the result.
        """
        workspace = self.workspace

        session_manager = self._session_manager
        if session_manager is None:
            session_manager = SessionManager(workspace)

        curated_root = workspace / "memory" / "curated"
        curated_memory_store = self._curated_memory_store
        if curated_memory_store is None:
            curated_memory_store = MemoryStore(curated_root)

        memory_service = self._memory_service
        if memory_service is None:
            memory_service = MemoryService(curated_root, store=curated_memory_store)
        memory_service.initialize()

        tool_registry = self._tool_registry
        if tool_registry is None:
            tool_registry = ToolRegistry()

        skills_loader = self._skills_loader
        if skills_loader is None:
            skills_loader = SkillsLoader(workspace)

        if self._tool_registry is None:
            # Only a registry created here receives the minimal built-in tool
            # set; an injected registry is left exactly as the caller built it.
            tool_registry.register_many(
                [
                    ObjectBackedTool(EchoTool()),
                    ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
                    ObjectBackedTool(SkillViewTool(loader=skills_loader)),
                    ObjectBackedTool(SessionSearchTool(db=session_manager)),
                ]
            )

        context_builder = self._context_builder
        if context_builder is None:
            context_builder = ContextBuilder()

        tool_executor = ToolExecutor(tool_registry)

        skill_assembler = self._skill_assembler
        if skill_assembler is None:
            skill_assembler = SkillAssembler(skills_loader)

        result = EngineLoadResult(
            workspace=workspace,
            tools=[spec.name for spec in tool_registry.list_specs()],
            skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
            memory_stores=["curated"],
            permissions=[],
            session_manager=session_manager,
            # The store actually used by the service, which may differ from
            # the locally created one when a memory service was injected.
            curated_memory_store=memory_service.get_store(),
            memory_service=memory_service,
            tool_registry=tool_registry,
            tool_executor=tool_executor,
            context_builder=context_builder,
            skills_loader=skills_loader,
            skill_assembler=skill_assembler,
        )
        # Only resources created by this loader are closed by the runtime;
        # an injected session manager stays owned by its caller.
        if self._session_manager is None:
            result.register_closeable("session_manager", session_manager.close)
        return result

View File

@ -0,0 +1,689 @@
"""Unified agent loop used by all Beaver agents."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4
from beaver.engine.context import ContextBuildInput, SessionContext
from beaver.engine.providers import ProviderBundle, make_provider_bundle
from beaver.tools import ToolContext
from .loader import EngineLoader, EngineLoadResult
@dataclass(slots=True)
class AgentProfile:
"""Runtime profile for a Beaver agent instance."""
name: str = "default"
system_prompt: str = ""
default_model: str = "gpt-4.1-mini"
max_tokens: int = 4096
temperature: float = 0.2
max_tool_iterations: int = 8
@dataclass(slots=True)
class AgentRunResult:
"""一次 direct run 的最小结果结构。"""
session_id: str
run_id: str
output_text: str
finish_reason: str
tool_iterations: int
provider_name: str | None = None
model: str | None = None
usage: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class _DirectRunRequest:
"""运行循环中的单个 direct task。"""
task: str
kwargs: dict[str, Any]
future: asyncio.Future[AgentRunResult]
class AgentLoop:
    """Single execution kernel shared by root agents and delegated agents."""

    def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
        """Create a loop; the runtime itself is loaded lazily by ``boot()``.

        Args:
            profile: Agent profile; a default ``AgentProfile`` when omitted.
            loader: Runtime loader; a default ``EngineLoader`` when omitted.
        """
        self.profile = profile or AgentProfile()
        self.loader = loader or EngineLoader()
        self.loaded: EngineLoadResult | None = None
        self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
        self._running = False
        self._stop_requested = False

    def boot(self) -> EngineLoadResult:
        """Load the shared runtime once and return it on every call."""
        runtime = self.loaded
        if runtime is None:
            runtime = self.loaded = self.loader.load()
        return runtime

    @property
    def is_running(self) -> bool:
        """Whether ``run()`` is currently active."""
        return self._running

    async def run(self) -> None:
        """Run the minimal loop that consumes submitted direct tasks in order.

        The loop is a single serial consumer: each queued request is executed
        through ``_process_direct_impl`` and its outcome is delivered through
        the request's future. A ``None`` item is the stop sentinel and ends
        the loop only after ``stop()`` was requested.

        Raises:
            RuntimeError: The loop is already running.
        """
        if self._running:
            raise RuntimeError("AgentLoop.run() is already active")

        self.boot()
        queue: asyncio.Queue[_DirectRunRequest | None] = asyncio.Queue()
        self._run_queue = queue
        self._running = True
        self._stop_requested = False

        try:
            while True:
                request = await queue.get()
                if request is None:
                    if self._stop_requested:
                        break
                    continue
                # The submitter gave up before the task was picked up.
                if request.future.cancelled():
                    continue
                try:
                    outcome = await self._process_direct_impl(request.task, **request.kwargs)
                except asyncio.CancelledError:
                    # Propagate cancellation to the waiter, then to our caller.
                    if not request.future.done():
                        request.future.cancel()
                    raise
                except Exception as exc:  # pragma: no cover - defensive queue path
                    if not request.future.done():
                        request.future.set_exception(exc)
                else:
                    if not request.future.done():
                        request.future.set_result(outcome)
        finally:
            # Fail whatever is still queued so that no submitter waits forever.
            while not queue.empty():
                leftover = queue.get_nowait()
                if isinstance(leftover, _DirectRunRequest) and not leftover.future.done():
                    leftover.future.set_exception(
                        RuntimeError("AgentLoop.run() stopped before processing the queued task")
                    )
            self._running = False
            self._stop_requested = False
            self._run_queue = None

    async def stop(self) -> None:
        """Ask the run loop to stop.

        New submissions are rejected from this point on and a sentinel is
        queued behind the already queued work. The runtime is not closed.
        Does nothing when the loop is not running.
        """
        queue = self._run_queue
        if not self._running or queue is None:
            return
        self._stop_requested = True
        await queue.put(None)

    async def submit_direct(
        self,
        task: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Submit one direct task to the running loop and wait for its result.

        Args:
            task: Task text.
            **kwargs: Keyword arguments forwarded to the run implementation.

        Returns:
            The result produced for this task.

        Raises:
            RuntimeError: ``run()`` is not active, or ``stop()`` was requested.
        """
        queue = self._run_queue
        if not self._running or queue is None:
            raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
        if self._stop_requested:
            raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

        waiter: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
        await queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=waiter))
        return await waiter

    def close(self) -> None:
        """Close the runtime held by this loop.

        ``boot()`` builds the runtime and ``close()`` releases the resources
        it holds. The reference to the runtime is dropped even if closing
        raises.

        Raises:
            RuntimeError: The run loop is still active.
        """
        if self._running:
            raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")
        runtime = self.loaded
        if runtime is None:
            return
        try:
            runtime.close()
        finally:
            self.loaded = None

    async def process_direct(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Execute one direct run outside of the run loop.

        All arguments are forwarded unchanged to ``_process_direct_impl``.

        Raises:
            RuntimeError: ``run()`` is active; tasks must then be submitted
                through ``submit_direct()``.
        """
        if self._running:
            raise RuntimeError(
                "AgentLoop.process_direct() is disabled while run() is active; "
                "submit tasks via submit_direct() instead."
            )
        forwarded: dict[str, Any] = {
            "session_id": session_id,
            "source": source,
            "user_id": user_id,
            "title": title,
            "execution_context": execution_context,
            "model": model,
            "provider_name": provider_name,
            "api_key": api_key,
            "api_base": api_base,
            "extra_headers": extra_headers,
            "routing": routing,
            "fallback_target": fallback_target,
            "auxiliary_target": auxiliary_target,
            "embedding_target": embedding_target,
            "embedding_model": embedding_model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "max_tool_iterations": max_tool_iterations,
            "provider_bundle": provider_bundle,
        }
        return await self._process_direct_impl(task, **forwarded)

    async def _process_direct_impl(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Execute one direct run; shared by ``process_direct`` and ``run``.

        Steps: make sure the session exists, assemble the prompt from the
        memory snapshot, history and activated skills, call the provider,
        run the tool loop while the model keeps requesting tools, and record
        every message plus usage in the session.

        Exceptions raised after the ``run_started`` event are not
        propagated: they are recorded in the session and returned as a
        result whose ``finish_reason`` is ``"error"``.
        """
        loaded = self.boot()
        (
            session_manager,
            memory_service,
            context_builder,
            tool_registry,
            tool_executor,
            skill_assembler,
        ) = (
            self._require_loaded(dependency)
            for dependency in (
                "session_manager",
                "memory_service",
                "context_builder",
                "tool_registry",
                "tool_executor",
                "skill_assembler",
            )
        )

        resolved_session_id = session_id or uuid4().hex
        resolved_run_id = uuid4().hex
        resolved_model = model or self.profile.default_model
        resolved_max_tokens = max_tokens or self.profile.max_tokens
        resolved_temperature = self.profile.temperature if temperature is None else temperature
        resolved_max_tool_iterations = (
            self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
        )

        # Live memory state is refreshed through the service before each run.
        memory_service.reload_for_new_run()
        session_manager.ensure_session(
            resolved_session_id,
            source=source,
            model=resolved_model,
            title=title,
            user_id=user_id,
        )

        def log_event(role: str, event_type: str, *, model: str | None, **fields: Any) -> None:
            """Append one session event carrying this run's common fields."""
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role=role,
                event_type=event_type,
                source=source,
                title=title,
                model=model,
                user_id=user_id,
                **fields,
            )

        log_event(
            "system",
            "run_started",
            model=resolved_model,
            event_payload={
                "source": source,
                "model": resolved_model,
                "agent_name": self.profile.name,
            },
            content=task,
            context_visible=False,
        )

        user_message_recorded = False
        iterations = 0
        final_usage: dict[str, Any] = {}
        final_provider_name: str | None = provider_name
        final_model: str | None = resolved_model

        try:
            bundle = provider_bundle or make_provider_bundle(
                model=resolved_model,
                provider_name=provider_name,
                api_key=api_key,
                api_base=api_base,
                extra_headers=extra_headers,
                routing=routing,
                fallback_target=fallback_target,
                auxiliary_target=auxiliary_target,
                embedding_target=embedding_target,
                embedding_model=embedding_model or "text-embedding-v4",
            )

            # Skill selection uses the auxiliary provider/model when present.
            skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
            if bundle.auxiliary_runtime is not None:
                skill_selector_model = bundle.auxiliary_runtime.model
            else:
                skill_selector_model = bundle.main_runtime.model
            assembled_skills = await skill_assembler.assemble(
                task_description=task,
                provider=skill_selector_provider,
                model=skill_selector_model,
                embedding_runtime=bundle.embedding_runtime,
            )

            skill_activation_messages = context_builder.build_skill_activation_messages(
                assembled_skills.activated_skills
            )
            if skill_activation_messages:
                log_event(
                    "system",
                    "skill_activation_snapshotted",
                    model=resolved_model,
                    event_payload={
                        "activation_messages": skill_activation_messages,
                    },
                    content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
                    context_visible=False,
                )

            build_input = ContextBuildInput(
                base_system_prompt=self.profile.system_prompt,
                history=session_manager.get_history(resolved_session_id),
                current_user_input=task,
                memory_snapshot=memory_service.get_snapshot(),
                activated_skills=assembled_skills.activated_skills,
                session_context=SessionContext(
                    session_id=resolved_session_id,
                    source=source,
                    model=resolved_model,
                    user_id=user_id,
                ),
                execution_context=execution_context,
            )
            context_result = context_builder.build_messages(build_input)

            session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
            log_event(
                "system",
                "system_prompt_snapshotted",
                model=resolved_model,
                event_payload={
                    "source": source,
                    "model": resolved_model,
                    "system_prompt_length": len(context_result.system_prompt),
                },
                content=context_result.system_prompt,
                context_visible=False,
            )
            log_event("user", "user_message_added", model=resolved_model, content=task)
            user_message_recorded = True

            provider = bundle.main_provider
            messages = list(context_result.messages)
            tool_schemas = tool_registry.export_provider_schemas()
            tool_context = ToolContext(
                workspace=str(loaded.workspace),
                session_id=resolved_session_id,
                user_id=user_id,
                services={
                    "session_manager": session_manager,
                    "memory_service": memory_service,
                    "memory_store": memory_service.get_store(),
                    "tool_registry": tool_registry,
                },
                metadata={
                    "source": source,
                    "agent_name": self.profile.name,
                },
            )

            final_text = ""
            final_finish_reason = "stop"
            final_provider_name = bundle.main_runtime.provider_name
            final_model = bundle.main_runtime.model

            while True:
                response = await provider.chat(
                    messages=messages,
                    tools=tool_schemas,
                    model=final_model,
                    max_tokens=resolved_max_tokens,
                    temperature=resolved_temperature,
                )
                final_provider_name = response.provider_name or final_provider_name
                final_model = response.model or final_model
                final_usage = self._merge_usage(final_usage, response.usage or {})
                self._record_usage(session_manager, resolved_session_id, response.usage or {})

                assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
                log_event(
                    "assistant",
                    "assistant_message_added",
                    model=final_model,
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    finish_reason=response.finish_reason,
                    reasoning=response.reasoning_content,
                )
                context_builder.add_assistant_message(
                    messages,
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    reasoning_content=response.reasoning_content,
                )

                if not response.has_tool_calls:
                    final_text = response.content or ""
                    final_finish_reason = response.finish_reason or "stop"
                    break

                if iterations >= resolved_max_tool_iterations:
                    # NOTE(review): the assistant message recorded just above
                    # still carries tool_calls that never receive tool
                    # results - confirm that history replay copes with this.
                    final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
                    final_finish_reason = "max_tool_iterations"
                    log_event(
                        "assistant",
                        "assistant_message_added",
                        model=final_model,
                        content=final_text,
                        finish_reason=final_finish_reason,
                    )
                    context_builder.add_assistant_message(
                        messages,
                        content=final_text,
                    )
                    break

                iterations += 1
                for tool_call in response.tool_calls:
                    result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
                    log_event(
                        "tool",
                        "tool_result_recorded",
                        model=final_model,
                        event_payload={
                            "success": result.success,
                            "error": result.error,
                        },
                        content=result.content,
                        tool_name=result.tool_name,
                        tool_call_id=tool_call.id,
                    )
                    context_builder.add_tool_result(
                        messages,
                        tool_call_id=tool_call.id,
                        tool_name=result.tool_name,
                        result=result.content,
                    )

            log_event(
                "system",
                "run_completed",
                model=final_model,
                event_payload={
                    "finish_reason": final_finish_reason,
                    "tool_iterations": iterations,
                },
                content=final_text,
                finish_reason=final_finish_reason,
                context_visible=False,
            )
            return AgentRunResult(
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                output_text=final_text,
                finish_reason=final_finish_reason,
                tool_iterations=iterations,
                provider_name=final_provider_name,
                model=final_model,
                usage=final_usage,
            )
        except Exception as exc:
            # Keep the transcript complete: record the user turn if the
            # failure happened before it was written.
            if not user_message_recorded:
                log_event("user", "user_message_added", model=resolved_model, content=task)
            return self._build_error_result(
                session_manager=session_manager,
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                source=source,
                title=title,
                user_id=user_id,
                model=final_model or resolved_model,
                message=f"Run failed before completion: {exc}",
                tool_iterations=iterations,
                provider_name=final_provider_name,
                usage=final_usage,
            )

    def _require_loaded(self, field_name: str) -> Any:
        """Return a runtime dependency by attribute name.

        Raises:
            RuntimeError: The loaded runtime holds ``None`` for that field.
        """
        dependency = getattr(self.boot(), field_name)
        if dependency is None:
            raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
        return dependency

    @staticmethod
    def _serialize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
        """Convert tool call objects into ``function``-type tool call dicts."""
        return [
            {
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.name,
                    "arguments": call.arguments,
                },
            }
            for call in tool_calls
        ]

    @staticmethod
    def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
        """Map provider usage onto the session usage counters.

        ``input_tokens`` / ``output_tokens`` are preferred; when those keys
        are absent ``prompt_tokens`` / ``completion_tokens`` are used. Missing
        or falsy values count as 0. An empty ``usage`` records nothing.
        """
        if not usage:
            return
        raw_input = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        raw_output = usage.get("output_tokens", usage.get("completion_tokens", 0))
        raw_reasoning = usage.get("reasoning_tokens", 0)
        session_manager.update_usage(
            session_id,
            input_tokens=int(raw_input or 0),
            output_tokens=int(raw_output or 0),
            reasoning_tokens=int(raw_reasoning or 0),
        )

    @staticmethod
    def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
        """Fold one provider call's usage into the run total.

        Numeric values are added to the existing numeric value (or to 0);
        any other value replaces what was stored. ``total`` is not mutated.
        """
        merged = dict(total)
        for key, value in delta.items():
            previous = merged.get(key, 0)
            both_numeric = isinstance(value, (int, float)) and isinstance(previous, (int, float))
            merged[key] = previous + value if both_numeric else value
        return merged

    @staticmethod
    def _build_error_result(
        *,
        session_manager: Any,
        session_id: str,
        run_id: str,
        source: str,
        title: str | None,
        user_id: str | None,
        model: str | None,
        message: str,
        tool_iterations: int,
        provider_name: str | None,
        usage: dict[str, Any],
    ) -> AgentRunResult:
        """Record a failed run as a traceable assistant error turn.

        Writes an assistant message holding the error text followed by a
        hidden ``run_failed`` system event, then returns a result whose
        ``finish_reason`` is ``"error"``.
        """
        shared: dict[str, Any] = {
            "finish_reason": "error",
            "source": source,
            "title": title,
            "model": model,
            "user_id": user_id,
        }
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="assistant",
            event_type="assistant_message_added",
            content=message,
            **shared,
        )
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="system",
            event_type="run_failed",
            event_payload={
                "tool_iterations": tool_iterations,
                "provider_name": provider_name,
            },
            content=message,
            context_visible=False,
            **shared,
        )
        return AgentRunResult(
            session_id=session_id,
            run_id=run_id,
            output_text=message,
            finish_reason="error",
            tool_iterations=tool_iterations,
            provider_name=provider_name,
            model=model,
            usage=usage,
        )

View File

@ -0,0 +1,33 @@
"""LLM provider adapters."""
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .chain import FallbackProviderChain
from .factory import (
ProviderBundle,
ProviderRoutingConfig,
ProviderRuntime,
ProviderTarget,
build_provider_runtime,
make_aux_provider,
make_fallback_provider,
make_main_provider,
make_provider_bundle,
make_provider_from_runtime,
)
__all__ = [
"FallbackProviderChain",
"LLMProvider",
"LLMResponse",
"ProviderBundle",
"ProviderRoutingConfig",
"ProviderRuntime",
"ProviderTarget",
"ToolCallRequest",
"build_provider_runtime",
"make_aux_provider",
"make_fallback_provider",
"make_main_provider",
"make_provider_bundle",
"make_provider_from_runtime",
]

View File

@ -0,0 +1,173 @@
"""Native Anthropic Messages API provider."""
from __future__ import annotations
import json
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import anthropic
except ModuleNotFoundError: # pragma: no cover
anthropic = None # type: ignore[assignment]
class AnthropicProvider(LLMProvider):
    """Provider backed by Anthropic's native Messages API.

    Requests go straight to the Anthropic SDK instead of being forced through
    an OpenAI-compatible path. Every failure (missing SDK, transport or API
    error) is reported as an ``LLMResponse`` with ``finish_reason="error"``
    rather than raised.
    """

    def __init__(
        self,
        api_key: str | None = None,
        default_model: str = "claude-sonnet-4-5",
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is created lazily."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached async client, creating it on first use.

        Raises:
            RuntimeError: if the optional ``anthropic`` package is missing.
        """
        if anthropic is None:
            raise RuntimeError("anthropic package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # An explicit ``timeout=None`` switches the SDK's default timeout
            # off entirely, so only override it when one was configured.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = anthropic.AsyncAnthropic(**client_kwargs)
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat turn and normalize the reply into ``LLMResponse``.

        Args:
            messages: OpenAI-style message dicts (system/user/assistant/tool).
            tools: Optional OpenAI-style tool definitions.
            model: Model override; falls back to ``default_model``.
            max_tokens: Completion budget, clamped to at least 1.
            temperature: Sampling temperature.

        Returns:
            The normalized response; ``finish_reason`` carries Anthropic's
            ``stop_reason`` unchanged, or ``"error"`` on failure.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        # Same empty-content cleanup the other providers apply before sending.
        system_prompt, anthropic_messages = _convert_messages(self.sanitize_empty_content(messages))
        kwargs: dict[str, Any] = {
            "model": model or self.default_model,
            "messages": anthropic_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        # Leave ``system`` out instead of sending an empty prompt.
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = _convert_tools(tools)

        try:
            response = await client.messages.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        content_parts: list[str] = []
        tool_calls: list[ToolCallRequest] = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCallRequest(
                        id=block.id,
                        name=block.name,
                        arguments=block.input,
                    )
                )

        usage_payload = {}
        if getattr(response, "usage", None):
            usage_payload = {
                "input_tokens": getattr(response.usage, "input_tokens", 0),
                "output_tokens": getattr(response.usage, "output_tokens", 0),
            }

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            finish_reason=getattr(response, "stop_reason", "stop") or "stop",
            usage=usage_payload,
            provider_name="anthropic",
            model=model or self.default_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` receives no override."""
        return self.default_model
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
converted: list[dict[str, Any]] = []
for message in messages:
role = message.get("role")
if role == "system":
content = message.get("content")
system_prompt = content if isinstance(content, str) else ""
continue
if role == "tool":
converted.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message.get("tool_call_id"),
"content": message.get("content") or "",
}
],
}
)
continue
if role == "assistant" and message.get("tool_calls"):
content_blocks: list[dict[str, Any]] = []
if message.get("content"):
content_blocks.append({"type": "text", "text": message["content"]})
for tool_call in message.get("tool_calls", []):
function = tool_call.get("function", tool_call)
arguments = function.get("arguments")
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
content_blocks.append(
{
"type": "tool_use",
"id": tool_call.get("id"),
"name": function.get("name"),
"input": arguments or {},
}
)
converted.append({"role": "assistant", "content": content_blocks})
continue
content = message.get("content")
if isinstance(content, list):
blocks = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
blocks.append({"type": "text", "text": item.get("text", "")})
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
else:
converted.append({"role": role, "content": content or ""})
return system_prompt, converted
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
if not fn.get("name"):
continue
converted.append(
{
"name": fn["name"],
"description": fn.get("description") or "",
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
}
)
return converted

View File

@ -0,0 +1,98 @@
"""Beaver provider 子系统的统一契约。"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class ToolCallRequest:
    """A single tool invocation requested by the model."""

    # Call identifier as reported by the provider.
    id: str
    # Name of the tool the model wants to run.
    name: str
    # Arguments for the call, already decoded by the provider adapter.
    arguments: dict[str, Any]
@dataclass(slots=True)
class LLMResponse:
    """Provider-agnostic model response."""

    # Assistant text; providers set it to None when no text was produced.
    content: str | None
    # Tool invocations requested in this turn (empty when there are none).
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Why generation ended; providers report failures as "error".
    finish_reason: str = "stop"
    # Token accounting; the key names depend on the provider.
    usage: dict[str, Any] = field(default_factory=dict)
    # Optional reasoning text when the provider exposes it.
    reasoning_content: str | None = None
    # Name of the provider that produced the response, when known.
    provider_name: str | None = None
    # Model that produced the response, when known.
    model: str | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Return True when the response carries at least one tool call."""
        return bool(self.tool_calls)
class LLMProvider(ABC):
"""所有 provider 实现必须遵守的统一接口。"""
def __init__(
self,
api_key: str | None = None,
api_base: str | None = None,
request_timeout_seconds: float | None = None,
) -> None:
self.api_key = api_key
self.api_base = api_base
self.request_timeout_seconds = (
max(1.0, float(request_timeout_seconds))
if request_timeout_seconds is not None
else None
)
@staticmethod
def sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""清理 provider 普遍不接受的空 content。"""
result: list[dict[str, Any]] = []
for message in messages:
content = message.get("content")
if isinstance(content, str) and content == "":
clean = dict(message)
clean["content"] = None if (message.get("role") == "assistant" and message.get("tool_calls")) else "(empty)"
result.append(clean)
continue
if isinstance(content, list):
filtered = [
item
for item in content
if not (
isinstance(item, dict)
and item.get("type") in ("text", "input_text", "output_text")
and not item.get("text")
)
]
if len(filtered) != len(content):
clean = dict(message)
clean["content"] = filtered or "(empty)"
if message.get("role") == "assistant" and message.get("tool_calls") and not filtered:
clean["content"] = None
result.append(clean)
continue
result.append(message)
return result
@abstractmethod
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> LLMResponse:
"""统一聊天接口。"""
@abstractmethod
def get_default_model(self) -> str:
"""返回 provider 的默认模型名。"""

View File

@ -0,0 +1,145 @@
"""Provider chain helpers.
这里先实现最小可用的 fallback chain：
- 每次调用都先尝试主 provider
- 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback
- fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路
这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。
"""
from __future__ import annotations
from .base import LLMProvider, LLMResponse
from .runtime import ProviderRuntime
class FallbackProviderChain(LLMProvider):
    """Expose a primary provider and an optional fallback as one ``LLMProvider``."""

    def __init__(
        self,
        primary_runtime: ProviderRuntime,
        primary_provider: LLMProvider,
        fallback_runtime: ProviderRuntime | None = None,
        fallback_provider: LLMProvider | None = None,
    ) -> None:
        """Wire up both routes; base-class settings mirror the primary runtime."""
        super().__init__(
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
        )
        self.primary_runtime = primary_runtime
        self.primary_provider = primary_provider
        self.fallback_runtime = fallback_runtime
        self.fallback_provider = fallback_provider
        # These fields only remember which route the most recent chat used
        # (for debugging and tests). Routing always restarts from the primary
        # on every call and must never stick to the fallback across calls.
        self._remember_route(primary_runtime, primary_provider, used_fallback=False)

    def _remember_route(
        self,
        runtime: ProviderRuntime,
        provider: LLMProvider,
        *,
        used_fallback: bool,
    ) -> None:
        """Record the route taken by the call currently in progress."""
        self._last_provider = provider
        self._last_runtime = runtime
        self._last_call_used_fallback = used_fallback

    @property
    def fallback_activated(self) -> bool:
        """Whether the most recent chat actually used the fallback."""
        return self._last_call_used_fallback

    @property
    def active_runtime(self) -> ProviderRuntime:
        """Runtime that served the most recent chat."""
        return self._last_runtime

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Try the primary provider; on an error response retry on the fallback.

        The fallback always runs with its own runtime model, ignoring the
        ``model`` override given for the primary attempt.
        """
        shared = {
            "messages": messages,
            "tools": tools,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        self._remember_route(self.primary_runtime, self.primary_provider, used_fallback=False)
        first_attempt = await self._safe_chat(
            self.primary_provider,
            self.primary_runtime,
            model=model or self.primary_runtime.model,
            **shared,
        )
        first_attempt = self._decorate_response(first_attempt, self.primary_runtime)

        if self._should_activate_fallback(first_attempt):
            assert self.fallback_provider is not None
            assert self.fallback_runtime is not None
            self._remember_route(self.fallback_runtime, self.fallback_provider, used_fallback=True)
            second_attempt = await self._safe_chat(
                self.fallback_provider,
                self.fallback_runtime,
                model=self.fallback_runtime.model,
                **shared,
            )
            return self._decorate_response(second_attempt, self.fallback_runtime)

        return first_attempt

    def get_default_model(self) -> str:
        """Return the primary runtime's model."""
        return self.primary_runtime.model

    def _should_activate_fallback(self, response: LLMResponse) -> bool:
        """Return True when a fallback is configured and ``response`` is an error."""
        if self.fallback_provider is None or self.fallback_runtime is None:
            return False
        return response.finish_reason == "error"

    @staticmethod
    async def _safe_chat(
        provider: LLMProvider,
        runtime: ProviderRuntime,
        *,
        messages: list[dict],
        tools: list[dict] | None,
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> LLMResponse:
        """Call ``provider.chat`` and turn raised exceptions into error responses.

        This keeps the fallback trigger independent of whether each provider
        remembers to catch its own exceptions.
        """
        try:
            outcome = await provider.chat(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as exc:
            outcome = LLMResponse(
                content=f"Error: {exc}",
                finish_reason="error",
                provider_name=runtime.provider_name,
                model=runtime.model,
            )
        return outcome

    @staticmethod
    def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse:
        """Fill in provider name and model from ``runtime`` when the response lacks them."""
        for attribute in ("provider_name", "model"):
            if getattr(response, attribute) is None:
                setattr(response, attribute, getattr(runtime, attribute))
        return response

View File

@ -0,0 +1,274 @@
"""OpenAI Codex Responses provider."""
from __future__ import annotations
import asyncio
import hashlib
import json
from typing import Any, AsyncGenerator
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import httpx
except ModuleNotFoundError: # pragma: no cover
httpx = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
from oauth_cli_kit import get_token as get_codex_token
except ModuleNotFoundError: # pragma: no cover
get_codex_token = None # type: ignore[assignment]
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "beaver"
class OpenAICodexProvider(LLMProvider):
    """Provider that calls the Codex Responses endpoint with an OAuth token."""

    def __init__(
        self,
        default_model: str = "openai-codex/gpt-5.1-codex",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store the default model; no API key or base URL is used here."""
        super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one streamed Responses request and collect the result.

        ``max_tokens`` and ``temperature`` are accepted for interface
        compatibility but are not included in the request body.

        Returns:
            The normalized response. Missing dependencies, token retrieval
            failures and request failures all yield ``finish_reason="error"``.
        """
        if httpx is None or get_codex_token is None:
            return LLMResponse(content="Error: codex dependencies are not installed", finish_reason="error", provider_name="openai_codex")

        resolved_model = model or self.default_model
        system_prompt, input_items = _convert_messages(messages)

        body: dict[str, Any] = {
            "model": _strip_model_prefix(resolved_model),
            "store": False,
            "stream": True,
            "instructions": system_prompt,
            "input": input_items,
            "text": {"verbosity": "medium"},
            "include": ["reasoning.encrypted_content"],
            "prompt_cache_key": _prompt_cache_key(messages),
            "tool_choice": "auto",
            "parallel_tool_calls": True,
        }
        if tools:
            body["tools"] = _convert_tools(tools)

        try:
            # Fetching the token is part of the guarded section so an auth
            # failure is reported like any other request failure, not raised.
            # The token lookup is synchronous, hence the worker thread.
            token = await asyncio.to_thread(get_codex_token)
            headers = _build_headers(token.account_id, token.access)
            content, tool_calls, finish_reason = await _request_codex(
                DEFAULT_CODEX_URL,
                headers,
                body,
                verify=True,
                timeout_seconds=self.request_timeout_seconds or 600.0,
            )
        except Exception as exc:
            return LLMResponse(content=f"Error calling Codex: {exc}", finish_reason="error", provider_name="openai_codex")

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            provider_name="openai_codex",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` receives no override."""
        return self.default_model
def _strip_model_prefix(model: str) -> str:
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
return model.split("/", 1)[1]
return model
def _build_headers(account_id: str, token: str) -> dict[str, str]:
    """Assemble the HTTP headers for a streamed Codex Responses request."""
    headers = {
        "Authorization": f"Bearer {token}",
        "chatgpt-account-id": account_id,
    }
    # Static headers: API opt-in, client identification and SSE negotiation.
    headers.update(
        {
            "OpenAI-Beta": "responses=experimental",
            "originator": DEFAULT_ORIGINATOR,
            "User-Agent": "beaver (python)",
            "accept": "text/event-stream",
            "content-type": "application/json",
        }
    )
    return headers
async def _request_codex(
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
    verify: bool,
    timeout_seconds: float,
) -> tuple[str, list[ToolCallRequest], str]:
    """POST ``body`` to ``url`` as a streamed request and consume the SSE reply.

    Raises:
        RuntimeError: when the endpoint answers with a non-200 status; the
            message contains the status code and a truncated response body.
    """
    http_client = httpx.AsyncClient(timeout=timeout_seconds, verify=verify)
    async with http_client as client, client.stream("POST", url, headers=headers, json=body) as response:
        if response.status_code == 200:
            return await _consume_sse(response)
        raw_body = await response.aread()
        raise RuntimeError(_friendly_error(response.status_code, raw_body.decode("utf-8", "ignore")))
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append(
{
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
}
)
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Translate OpenAI chat-style messages into Responses API input items.

    Args:
        messages: Message dicts with ``role``/``content``; assistant messages
            may carry ``tool_calls`` and tool messages a ``tool_call_id``.

    Returns:
        ``(instructions, input_items)``. The instructions are the content of
        the last ``system`` message when it is a string, otherwise ``""``.
        Messages with any other role than user/assistant/tool are ignored.
    """
    system_prompt = ""
    input_items: list[dict[str, Any]] = []

    for index, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")

        if role == "system":
            system_prompt = content if isinstance(content, str) else ""
            continue

        if role == "user":
            input_items.append(_convert_user_message(content))
            continue

        if role == "assistant":
            if isinstance(content, str) and content:
                input_items.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": content}],
                        "status": "completed",
                        "id": f"msg_{index}",
                    }
                )
            for position, tool_call in enumerate(message.get("tool_calls", []) or []):
                fn = tool_call.get("function") or {}
                call_id, item_id = _split_tool_call_id(tool_call.get("id"))
                # Generated ids must differ between several calls of the same
                # message; the first call keeps the plain ``{index}`` suffix.
                suffix = f"{index}" if position == 0 else f"{index}_{position}"
                arguments = fn.get("arguments") or "{}"
                # ``arguments`` is sent as a JSON string; serialize anything
                # that arrives already decoded.
                if not isinstance(arguments, str):
                    arguments = json.dumps(arguments, ensure_ascii=False)
                input_items.append(
                    {
                        "type": "function_call",
                        "id": item_id or f"fc_{suffix}",
                        "call_id": call_id or f"call_{suffix}",
                        "name": fn.get("name"),
                        "arguments": arguments,
                    }
                )
            continue

        if role == "tool":
            call_id, _ = _split_tool_call_id(message.get("tool_call_id"))
            output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
            input_items.append(
                {
                    "type": "function_call_output",
                    "call_id": call_id,
                    "output": output_text,
                }
            )

    return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: Any) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [item[5:].strip() for item in buffer if item.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(response: Any) -> tuple[str | None, list[ToolCallRequest], str]:
    """Fold a Responses SSE stream into text, tool calls and a finish reason.

    Function calls are registered when their output item is added, extended by
    ``response.function_call_arguments.*`` events and finalized by
    ``response.output_item.done``. Arguments are decoded only after the stream
    ends, because the ``added`` event usually carries no arguments yet.

    Returns:
        ``(content, tool_calls, finish_reason)``. ``content`` is ``None`` when
        no text was streamed. ``finish_reason`` is the status reported by
        ``response.completed`` (``"stop"`` if that event never arrives).

    Raises:
        RuntimeError: when the stream reports ``error`` or ``response.failed``.
    """
    content_parts: list[str] = []
    ordered_calls: list[dict[str, Any]] = []  # stream order
    calls_by_key: dict[str, dict[str, Any]] = {}  # item id and call id -> entry
    finish_reason = "stop"

    def _register(item: dict[str, Any]) -> None:
        """Create or update the entry for a ``function_call`` output item."""
        entry = None
        for key in (item.get("id"), item.get("call_id")):
            if key and key in calls_by_key:
                entry = calls_by_key[key]
                break
        if entry is None:
            entry = {"call_id": None, "id": None, "name": None, "arguments": ""}
            ordered_calls.append(entry)
        for field_name in ("call_id", "id", "name"):
            if item.get(field_name):
                entry[field_name] = item[field_name]
        # Complete arguments on the item replace whatever was accumulated.
        if item.get("arguments"):
            entry["arguments"] = item["arguments"]
        for key in (entry["id"], entry["call_id"]):
            if key:
                calls_by_key[key] = entry

    def _lookup(event: dict[str, Any]) -> dict[str, Any] | None:
        """Find the entry an argument event refers to (by item id or call id)."""
        for key in (event.get("item_id"), event.get("call_id")):
            if key and key in calls_by_key:
                return calls_by_key[key]
        return None

    async for event in _iter_sse(response):
        event_type = event.get("type")
        if event_type == "response.output_text.delta":
            content_parts.append(event.get("delta") or "")
        elif event_type in ("response.output_item.added", "response.output_item.done"):
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                _register(item)
        elif event_type == "response.function_call_arguments.delta":
            entry = _lookup(event)
            if entry is not None and isinstance(entry["arguments"], str):
                entry["arguments"] += event.get("delta") or ""
        elif event_type == "response.function_call_arguments.done":
            entry = _lookup(event)
            if entry is not None and event.get("arguments"):
                entry["arguments"] = event["arguments"]
        elif event_type == "response.completed":
            finish_reason = (event.get("response") or {}).get("status") or "completed"
        elif event_type in ("error", "response.failed"):
            # Surface stream-level failures instead of returning an empty reply.
            detail = event.get("message") or (event.get("response") or {}).get("error") or event_type
            raise RuntimeError(f"Codex response failed: {detail}")

    tool_calls: list[ToolCallRequest] = []
    for entry in ordered_calls:
        raw_arguments = entry["arguments"] or "{}"
        if isinstance(raw_arguments, str):
            try:
                arguments = json.loads(raw_arguments)
            except json.JSONDecodeError:
                arguments = {}
        else:
            arguments = raw_arguments
        tool_calls.append(
            ToolCallRequest(
                # Combined "call_id|item_id" form.
                id=f"{entry['call_id'] or 'call'}|{entry['id'] or ''}",
                name=entry["name"] or "",
                arguments=arguments,
            )
        )

    return "".join(content_parts) or None, tool_calls, finish_reason
def _friendly_error(status_code: int, body: str) -> str:
return f"Codex API error ({status_code}): {body[:400]}"

View File

@ -0,0 +1,106 @@
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
from __future__ import annotations
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import json_repair
except ModuleNotFoundError: # pragma: no cover
json_repair = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
from openai import AsyncOpenAI
except ModuleNotFoundError: # pragma: no cover
AsyncOpenAI = None # type: ignore[assignment]
class CustomProvider(LLMProvider):
    """Provider that talks directly to any OpenAI-compatible endpoint.

    Every failure (missing SDK, transport or API error, empty reply) is
    reported as an ``LLMResponse`` with ``finish_reason="error"`` rather than
    raised.
    """

    def __init__(
        self,
        api_key: str = "no-key",
        api_base: str = "http://localhost:8000/v1",
        default_model: str = "default",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is created lazily."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached async client, creating it on first use.

        Raises:
            RuntimeError: if the optional ``openai`` package is missing.
        """
        if AsyncOpenAI is None:
            raise RuntimeError("openai package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # An explicit ``timeout=None`` switches the SDK's default timeout
            # off entirely, so only override it when one was configured.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = AsyncOpenAI(**client_kwargs)
        return self._client

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode tool-call arguments without letting bad JSON abort the turn.

        Non-string values are returned unchanged. When decoding fails, the
        result is the same error-marker dict that ``LiteLLMProvider`` emits,
        so both providers report broken arguments in the same way.
        """
        if not isinstance(raw_arguments, str):
            return raw_arguments
        try:
            if json_repair is not None:
                return json_repair.loads(raw_arguments)
            import json  # local import: this module has no top-level json import

            return json.loads(raw_arguments)
        except Exception as exc:
            return {
                "__beaver_tool_argument_parse_error__": str(exc),
                "__raw_arguments__": raw_arguments,
            }

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat completion and normalize the reply into ``LLMResponse``.

        Args:
            messages: OpenAI-style message dicts; empty content is sanitized.
            tools: Optional tool definitions, forwarded with ``tool_choice="auto"``.
            model: Model override; falls back to ``default_model``.
            max_tokens: Completion budget, clamped to at least 1.
            temperature: Sampling temperature.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            # Report a missing dependency the same way as a request failure.
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        kwargs: dict[str, Any] = {
            "model": model or self.default_model,
            "messages": self.sanitize_empty_content(messages),
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        if tools:
            kwargs.update(tools=tools, tool_choice="auto")

        try:
            response = await client.chat.completions.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        if not getattr(response, "choices", None):
            # Some compatible servers answer with an empty ``choices`` list.
            return LLMResponse(content="Error: provider returned no choices", finish_reason="error", provider_name="custom")

        choice = response.choices[0]
        message = choice.message
        parsed_tool_calls = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in message.tool_calls or []
        ]

        usage = getattr(response, "usage", None)
        usage_payload = {}
        if usage is not None:
            usage_payload = {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }

        return LLMResponse(
            content=message.content,
            tool_calls=parsed_tool_calls,
            finish_reason=choice.finish_reason or "stop",
            usage=usage_payload,
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name="custom",
            model=model or self.default_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` receives no override."""
        return self.default_model

View File

@ -0,0 +1,235 @@
"""Provider runtime 的统一工厂入口。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .anthropic import AnthropicProvider
from .base import LLMProvider
from .chain import FallbackProviderChain
from .codex import OpenAICodexProvider
from .custom import CustomProvider
from .litellm import LiteLLMProvider
from .runtime import (
ProviderRoutingConfig,
ProviderRuntime,
ProviderTarget,
normalize_provider_target,
resolve_auxiliary_runtime,
resolve_embedding_runtime,
resolve_fallback_runtime,
resolve_provider_runtime,
)
@dataclass(slots=True)
class ProviderBundle:
    """The set of providers one run needs.

    Groups the common routes in a single object:
    - ``main``: the primary conversation provider
    - ``fallback``: backup provider used when the main route fails
    - ``auxiliary``: helper tasks such as search summaries, compression and
      memory flushes
    An embedding runtime is carried alongside them, without a provider instance.
    """

    main_runtime: ProviderRuntime
    main_provider: LLMProvider
    fallback_runtime: ProviderRuntime | None = None
    fallback_provider: LLMProvider | None = None
    auxiliary_runtime: ProviderRuntime | None = None
    auxiliary_provider: LLMProvider | None = None
    embedding_runtime: ProviderRuntime | None = None
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
    """Build a provider runtime by forwarding all keyword arguments to ``resolve_provider_runtime``."""
    return resolve_provider_runtime(**kwargs)
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
    """Instantiate the concrete provider selected by ``runtime.spec.provider_impl``.

    ``"custom"``, ``"codex"`` and ``"anthropic"`` map to their dedicated
    providers; every other implementation name goes through LiteLLM.
    """
    implementation = runtime.spec.provider_impl
    model_name = runtime.default_model or runtime.model
    timeout = runtime.request_timeout_seconds

    if implementation == "custom":
        # The custom endpoint always needs a key and base URL, so use the local defaults.
        return CustomProvider(
            api_key=runtime.api_key or "no-key",
            api_base=runtime.api_base or "http://localhost:8000/v1",
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "codex":
        return OpenAICodexProvider(
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "anthropic":
        return AnthropicProvider(
            api_key=runtime.api_key,
            default_model=model_name,
            api_base=runtime.api_base,
            request_timeout_seconds=timeout,
        )

    return LiteLLMProvider(
        api_key=runtime.api_key,
        api_base=runtime.api_base,
        default_model=model_name,
        provider_name=runtime.provider_name,
        extra_headers=runtime.extra_headers,
        request_timeout_seconds=timeout,
        routing=runtime.routing,
    )
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the main conversation provider, wrapped in a fallback chain when possible.

    ``fallback_target`` (or, when that is absent or ``None``, ``fallback_model``)
    is removed from ``kwargs``; the remaining keyword arguments are forwarded
    to ``build_provider_runtime``.
    """
    fallback_target = kwargs.pop("fallback_target", None)
    if fallback_target is None:
        fallback_target = kwargs.pop("fallback_model", None)

    runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **kwargs,
    )
    primary = make_provider_from_runtime(runtime)

    fallback_pair = make_fallback_provider(runtime, fallback_target)
    if fallback_pair is not None:
        backup_runtime, backup_provider = fallback_pair
        return runtime, FallbackProviderChain(runtime, primary, backup_runtime, backup_provider)
    return runtime, primary
def make_fallback_provider(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
) -> tuple[ProviderRuntime, LLMProvider] | None:
    """Build the fallback runtime and provider, or return ``None`` if none resolves.

    An explicit ``fallback_target`` wins; otherwise the target recorded on the
    primary runtime is used.
    """
    target = fallback_target or primary_runtime.fallback_target
    resolved = resolve_fallback_runtime(primary_runtime, target)
    return None if resolved is None else (resolved, make_provider_from_runtime(resolved))
def make_aux_provider(
    main_runtime: ProviderRuntime | None = None,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    task_name: str = "auxiliary",
    **kwargs: Any,
) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider used for auxiliary tasks.

    With a ``main_runtime`` the auxiliary runtime is resolved relative to it.
    Without one, the target (or the loose keyword arguments standing in for
    it) must name at least a model.

    Raises:
        ValueError: if there is no ``main_runtime`` and no model in the target.
    """
    if target is None and kwargs:
        # Loose keyword arguments act as the target description.
        target = kwargs

    if main_runtime is not None:
        aux_runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name)
        return aux_runtime, make_provider_from_runtime(aux_runtime)

    normalized = normalize_provider_target(target)
    if normalized is None or not normalized.model:
        raise ValueError("Auxiliary provider without main_runtime requires at least a model")

    aux_runtime = build_provider_runtime(
        model=normalized.model,
        provider_name=normalized.provider_name,
        api_key=normalized.api_key,
        api_base=normalized.api_base,
        request_timeout_seconds=normalized.request_timeout_seconds,
        extra_headers=normalized.extra_headers,
        routing=normalized.routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
    return aux_runtime, make_provider_from_runtime(aux_runtime)
def make_embedding_runtime(
    main_runtime: ProviderRuntime,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Build the embedding-only runtime by delegating to ``resolve_embedding_runtime``.

    Only a runtime is produced here; no provider instance is created.
    """
    return resolve_embedding_runtime(main_runtime, target=target, default_model=default_model)
def make_provider_bundle(
    *,
    auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary_task_name: str = "auxiliary",
    embedding_target: ProviderTarget | dict[str, Any] | None = None,
    embedding_model: str = "text-embedding-v4",
    **kwargs: Any,
) -> ProviderBundle:
    """Build the main, fallback and auxiliary provider routes in one call.

    ``fallback_target`` (or, when that is absent or ``None``, ``fallback_model``)
    is taken out of the keyword arguments; the rest configures the main
    runtime. The auxiliary route is only built when ``auxiliary_target`` is
    given, and the embedding runtime is always resolved from the main runtime.
    """
    runtime_kwargs = dict(kwargs)
    fallback_target = runtime_kwargs.pop("fallback_target", None)
    if fallback_target is None:
        fallback_target = runtime_kwargs.pop("fallback_model", None)

    main_runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **runtime_kwargs,
    )
    primary_provider = make_provider_from_runtime(main_runtime)

    # Without a fallback the main provider is the primary one as-is;
    # otherwise it becomes a chain that retries on the fallback.
    fallback_runtime = None
    fallback_provider = None
    main_provider: LLMProvider = primary_provider
    fallback_pair = make_fallback_provider(main_runtime, fallback_target)
    if fallback_pair is not None:
        fallback_runtime, fallback_provider = fallback_pair
        main_provider = FallbackProviderChain(main_runtime, primary_provider, fallback_runtime, fallback_provider)

    if auxiliary_target is not None:
        auxiliary_runtime, auxiliary_provider = make_aux_provider(
            main_runtime,
            target=auxiliary_target,
            task_name=auxiliary_task_name,
        )
    else:
        auxiliary_runtime, auxiliary_provider = None, None

    return ProviderBundle(
        main_runtime=main_runtime,
        main_provider=main_provider,
        fallback_runtime=fallback_runtime,
        fallback_provider=fallback_provider,
        auxiliary_runtime=auxiliary_runtime,
        auxiliary_provider=auxiliary_provider,
        embedding_runtime=make_embedding_runtime(
            main_runtime,
            target=embedding_target,
            default_model=embedding_model,
        ),
    )
__all__ = [
"ProviderBundle",
"ProviderRoutingConfig",
"ProviderRuntime",
"ProviderTarget",
"build_provider_runtime",
"make_aux_provider",
"make_embedding_runtime",
"make_fallback_provider",
"make_main_provider",
"make_provider_bundle",
"make_provider_from_runtime",
]

View File

@ -0,0 +1,230 @@
"""LiteLLM provider implementation for multi-provider support."""
from __future__ import annotations
from contextlib import contextmanager
import json
import os
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .registry import find_by_model, find_gateway
from .runtime import ProviderRoutingConfig
try: # pragma: no cover - optional dependency
import json_repair
except ModuleNotFoundError: # pragma: no cover
json_repair = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
import litellm
from litellm import acompletion
except ModuleNotFoundError: # pragma: no cover
litellm = None # type: ignore[assignment]
acompletion = None # type: ignore[assignment]
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class LiteLLMProvider(LLMProvider):
"""通过 LiteLLM 统一访问大多数 provider。"""
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
    default_model: str = "anthropic/claude-opus-4-5",
    extra_headers: dict[str, str] | None = None,
    provider_name: str | None = None,
    request_timeout_seconds: float | None = None,
    routing: ProviderRoutingConfig | None = None,
) -> None:
    """Store request settings and look up the gateway spec for this provider.

    When LiteLLM is importable, its module-wide ``suppress_debug_info`` and
    ``drop_params`` switches are turned on as a side effect.
    """
    super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
    self.provider_name = provider_name
    self.routing = routing
    self.default_model = default_model
    self.extra_headers = extra_headers if extra_headers else {}
    self._gateway = find_gateway(provider_name, api_key, api_base)

    if litellm is None:
        return
    litellm.suppress_debug_info = True
    litellm.drop_params = True
def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
"""为当前请求生成 LiteLLM 依赖的临时环境变量。
LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime
之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。
"""
if not api_key:
return {}
spec = self._gateway or find_by_model(model)
if spec is None or not spec.env_key:
return {}
overrides: dict[str, str] = {spec.env_key: api_key}
effective_base = api_base or spec.default_api_base
for env_name, env_value in spec.env_extras:
resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
overrides[env_name] = resolved
return overrides
@contextmanager
def _temporary_env(self, overrides: dict[str, str]):
"""只在当前请求期间注入 provider 需要的环境变量。"""
if not overrides:
yield
return
sentinel = object()
previous: dict[str, object] = {}
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
try:
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = str(old_value)
def _resolve_model(self, model: str) -> str:
if self._gateway:
prefix = self._gateway.litellm_prefix
resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
if prefix and not resolved.startswith(f"{prefix}/"):
resolved = f"{prefix}/{resolved}"
return resolved
spec = find_by_model(model)
if spec and spec.litellm_prefix:
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
model = f"{spec.litellm_prefix}/{model}"
return model
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Strip message keys outside ``_ALLOWED_MSG_KEYS``.

    Assistant messages that end up without a ``content`` key get
    ``content=None`` so the key is always present for that role.
    """
    cleaned: list[dict[str, Any]] = []
    for raw in messages:
        kept: dict[str, Any] = {}
        for key, value in raw.items():
            if key in _ALLOWED_MSG_KEYS:
                kept[key] = value
        if kept.get("role") == "assistant":
            kept.setdefault("content", None)
        cleaned.append(kept)
    return cleaned
def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
    """Merge the first matching per-model override set into ``kwargs`` in place.

    Patterns from the spec matched by ``original_model`` are compared as
    substrings of the lower-cased model name; only the first hit is applied.
    """
    spec = find_by_model(original_model)
    if spec is None:
        return

    lowered = original_model.lower()
    for pattern, overrides in spec.model_overrides:
        if pattern not in lowered:
            continue
        kwargs.update(overrides)
        break
def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
if self.provider_name != "openrouter" or self.routing is None:
return
provider_payload: dict[str, Any] = {}
if self.routing.sort:
provider_payload["sort"] = self.routing.sort
if self.routing.only:
provider_payload["only"] = self.routing.only
if self.routing.ignore:
provider_payload["ignore"] = self.routing.ignore
if self.routing.order:
provider_payload["order"] = self.routing.order
if self.routing.require_parameters:
provider_payload["require_parameters"] = True
if self.routing.data_collection:
provider_payload["data_collection"] = self.routing.data_collection
if provider_payload:
kwargs["provider"] = provider_payload
async def chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
) -> LLMResponse:
    """Send one chat-completion request through LiteLLM and normalise the reply.

    Args:
        messages: Conversation messages; they are sanitised before sending.
        tools: Optional tool definitions. When given, ``tool_choice`` is "auto".
        model: Model to use; falls back to ``self.default_model``.
        max_tokens: Completion token limit, clamped to at least 1.
        temperature: Sampling temperature (a per-model override may replace it).

    Returns:
        An ``LLMResponse``. A missing LiteLLM install, a request that raises,
        or a reply without any choices is reported with
        ``finish_reason="error"`` instead of raising.
    """
    if acompletion is None:
        return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)

    original_model = model or self.default_model
    resolved_model = self._resolve_model(original_model)
    sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))

    kwargs: dict[str, Any] = {
        "model": resolved_model,
        "messages": sanitized_messages,
        "max_tokens": max(1, max_tokens),
        "temperature": temperature,
    }
    if self.api_key:
        kwargs["api_key"] = self.api_key
    if self.api_base:
        kwargs["api_base"] = self.api_base
    if self.extra_headers:
        kwargs["extra_headers"] = self.extra_headers
    if tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = "auto"

    self._apply_model_overrides(original_model, kwargs)
    self._apply_openrouter_routing(kwargs)

    env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)
    try:
        with self._temporary_env(env_overrides):
            response = await acompletion(**kwargs)
    except Exception as exc:
        return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)

    # A reply without choices would otherwise raise IndexError below; report it
    # the same way as any other request failure.
    choices = getattr(response, "choices", None)
    if not choices:
        return LLMResponse(
            content="Error: provider returned no choices",
            finish_reason="error",
            provider_name=self.provider_name,
            model=resolved_model,
        )

    choice = choices[0]
    message = choice.message

    tool_calls: list[ToolCallRequest] = []
    for tool_call in message.tool_calls or []:
        raw_arguments = tool_call.function.arguments
        if isinstance(raw_arguments, str):
            try:
                if json_repair is not None:
                    arguments = json_repair.loads(raw_arguments)
                else:
                    arguments = json.loads(raw_arguments)
            except Exception as exc:
                # Do not fail the whole request because one tool call carries
                # broken arguments. The original comment states that the
                # ToolExecutor later turns this marker into a standard tool failure.
                arguments = {
                    "__beaver_tool_argument_parse_error__": str(exc),
                    "__raw_arguments__": raw_arguments,
                }
        else:
            arguments = raw_arguments
        tool_calls.append(
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=arguments,
            )
        )

    usage = getattr(response, "usage", None)
    usage_payload = {}
    if usage is not None:
        usage_payload = {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }

    return LLMResponse(
        content=getattr(message, "content", None),
        tool_calls=tool_calls,
        finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
        usage=usage_payload,
        reasoning_content=getattr(message, "reasoning_content", None),
        provider_name=self.provider_name or "litellm",
        model=resolved_model,
    )
def get_default_model(self) -> str:
    """Return the model name used when a call does not specify one."""
    return self.default_model

View File

@ -0,0 +1,249 @@
"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class ProviderSpec:
"""单个 provider 的元数据定义。"""
name: str
keywords: tuple[str, ...]
env_key: str
display_name: str = ""
litellm_prefix: str = ""
skip_prefixes: tuple[str, ...] = ()
env_extras: tuple[tuple[str, str], ...] = ()
is_gateway: bool = False
is_local: bool = False
detect_by_key_prefix: str = ""
detect_by_base_keyword: str = ""
default_api_base: str = ""
strip_model_prefix: bool = False
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
is_oauth: bool = False
is_direct: bool = False
supports_prompt_caching: bool = False
api_mode: str = "chat_completions"
provider_impl: str = "litellm"
@property
def label(self) -> str:
return self.display_name or self.name.title()
# Registry of all known providers.
# Order matters: find_by_name, find_by_model and find_gateway all scan this
# tuple front to back and return the first spec that matches.
PROVIDERS: tuple[ProviderSpec, ...] = (
# Entry with no keywords and no env key; resolve_provider_runtime falls back to
# it when no spec matches but an api_base was supplied.
ProviderSpec(
name="custom",
keywords=(),
env_key="",
display_name="Custom",
is_direct=True,
provider_impl="custom",
api_mode="chat_completions",
),
# --- Gateways (is_gateway=True) ---
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
litellm_prefix="openrouter",
is_gateway=True,
detect_by_key_prefix="sk-or-",
detect_by_base_keyword="openrouter",
default_api_base="https://openrouter.ai/api/v1",
supports_prompt_caching=True,
),
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
env_key="OPENAI_API_KEY",
display_name="AiHubMix",
litellm_prefix="openai",
is_gateway=True,
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True,
),
ProviderSpec(
name="siliconflow",
keywords=("siliconflow",),
env_key="OPENAI_API_KEY",
display_name="SiliconFlow",
litellm_prefix="openai",
is_gateway=True,
detect_by_base_keyword="siliconflow",
default_api_base="https://api.siliconflow.cn/v1",
),
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
env_key="OPENAI_API_KEY",
display_name="VolcEngine",
litellm_prefix="volcengine",
is_gateway=True,
detect_by_base_keyword="volces",
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
),
# --- Providers that are neither gateway nor local (matched by keyword) ---
ProviderSpec(
name="anthropic",
keywords=("anthropic", "claude"),
env_key="ANTHROPIC_API_KEY",
display_name="Anthropic",
supports_prompt_caching=True,
api_mode="anthropic_messages",
provider_impl="anthropic",
),
ProviderSpec(
name="openai",
keywords=("openai", "gpt"),
env_key="OPENAI_API_KEY",
display_name="OpenAI",
),
ProviderSpec(
name="openai_codex",
keywords=("openai-codex", "codex"),
env_key="",
display_name="OpenAI Codex",
is_oauth=True,
detect_by_base_keyword="codex",
default_api_base="https://chatgpt.com/backend-api",
api_mode="codex_responses",
provider_impl="codex",
),
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
env_key="",
display_name="Github Copilot",
litellm_prefix="github_copilot",
skip_prefixes=("github_copilot/",),
is_oauth=True,
),
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek",
litellm_prefix="deepseek",
skip_prefixes=("deepseek/",),
),
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
litellm_prefix="gemini",
skip_prefixes=("gemini/",),
),
ProviderSpec(
name="zhipu",
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
litellm_prefix="zai",
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
),
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
litellm_prefix="dashscope",
skip_prefixes=("dashscope/", "openrouter/"),
),
ProviderSpec(
name="moonshot",
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
litellm_prefix="moonshot",
skip_prefixes=("moonshot/", "openrouter/"),
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
default_api_base="https://api.moonshot.ai/v1",
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
),
ProviderSpec(
name="minimax",
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
litellm_prefix="minimax",
skip_prefixes=("minimax/", "openrouter/"),
default_api_base="https://api.minimax.io/v1",
),
# --- Local deployment (is_local=True) ---
ProviderSpec(
name="vllm",
keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local",
litellm_prefix="hosted_vllm",
is_local=True,
),
ProviderSpec(
name="groq",
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
litellm_prefix="groq",
skip_prefixes=("groq/",),
),
)
def find_by_name(name: str) -> ProviderSpec | None:
    """Return the first registered spec whose ``name`` equals ``name``, else ``None``."""
    return next((candidate for candidate in PROVIDERS if candidate.name == name), None)
def find_by_model(model: str) -> ProviderSpec | None:
    """Match a model name to a provider spec.

    Resolution order:

    1. Explicit ``prefix/`` whose normalised form (lower-cased, ``-`` -> ``_``)
       equals a spec ``name``. All specs take part, so ``openrouter/...``
       resolves to the openrouter gateway.
    2. Explicit ``prefix/`` equal to a spec's ``litellm_prefix`` (e.g.
       ``hosted_vllm/...`` -> vllm, ``zai/...`` -> zhipu). This runs only after
       step 1 failed for every spec, so a spec that merely reuses another
       provider's name as its LiteLLM prefix (aihubmix uses ``openai``) cannot
       shadow that provider.
    3. Keyword match against specs that are neither gateway nor local.

    Returns:
        The matching spec, or ``None`` when nothing matches.
    """
    model_lower = model.lower()

    # An explicit prefix has the highest priority and is not restricted to
    # standard providers.
    prefix, separator, _ = model_lower.partition("/")
    if separator and prefix:
        normalized_prefix = prefix.replace("-", "_")
        for spec in PROVIDERS:
            if spec.name == normalized_prefix:
                return spec
        for spec in PROVIDERS:
            if spec.litellm_prefix and spec.litellm_prefix.replace("-", "_") == normalized_prefix:
                return spec

    model_normalized = model_lower.replace("-", "_")
    for spec in PROVIDERS:
        if spec.is_gateway or spec.is_local:
            continue
        if any(
            keyword in model_lower or keyword.replace("-", "_") in model_normalized
            for keyword in spec.keywords
        ):
            return spec
    return None
def find_gateway(
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
) -> ProviderSpec | None:
    """Identify a gateway or local provider from the config key, API key or API base.

    A spec found by ``provider_name`` is returned only if it is flagged as
    gateway or local. Otherwise the registry is scanned in order and the first
    spec whose ``detect_by_key_prefix`` starts ``api_key`` or whose
    ``detect_by_base_keyword`` occurs in ``api_base`` is returned.
    """
    named = find_by_name(provider_name) if provider_name else None
    if named and (named.is_gateway or named.is_local):
        return named

    for candidate in PROVIDERS:
        key_prefix = candidate.detect_by_key_prefix
        if key_prefix and api_key and api_key.startswith(key_prefix):
            return candidate
        base_keyword = candidate.detect_by_base_keyword
        if base_keyword and api_base and base_keyword in api_base:
            return candidate
    return None

View File

@ -0,0 +1,408 @@
"""Hermes 风格的 provider runtime resolution。"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Any
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
@dataclass(slots=True)
class ProviderRoutingConfig:
"""OpenRouter provider routing 配置。"""
sort: str | None = None
only: list[str] = field(default_factory=list)
ignore: list[str] = field(default_factory=list)
order: list[str] = field(default_factory=list)
require_parameters: bool = False
data_collection: str | None = None
@dataclass(slots=True)
class ProviderTarget:
"""一次 provider 选路请求的标准化配置。
这层不是具体 runtime而是“调用方想要什么”
- 用哪个 provider
- 跑哪个 model
- 是否指定自定义 base_url
- 是否带额外 headers / routing
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
"""
provider_name: str | None = None
model: str | None = None
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
routing: ProviderRoutingConfig | None = None
@dataclass(slots=True)
class ProviderRuntime:
"""运行时真正使用的 provider 解析结果。"""
spec: ProviderSpec
model: str
provider_name: str
api_mode: str
api_key: str | None = None
api_base: str | None = None
default_model: str | None = None
request_timeout_seconds: float | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
routing: ProviderRoutingConfig | None = None
fallback_target: ProviderTarget | None = None
auxiliary: bool = False
role: str = "main"
source: str = "runtime"
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into one unified ``ProviderRuntime``.

    Spec selection order: a gateway/local spec detected by ``find_gateway``;
    else the spec named by ``provider_name``; else the spec matched from the
    model name. If nothing matches, the ``custom`` spec is used when an
    ``api_base`` was given.

    Raises:
        ValueError: no spec could be resolved and no ``api_base`` was supplied.
    """
    gateway = find_gateway(provider_name, api_key, api_base)

    spec = gateway
    if spec is None:
        spec = find_by_name(provider_name) if provider_name else find_by_model(model)
    if spec is None:
        if not api_base:
            raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")
        spec = find_by_name("custom")

    model_name = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    effective_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=model_name,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=effective_api_base,
        default_model=model_name,
        request_timeout_seconds=request_timeout_seconds,
        extra_headers=extra_headers or {},
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-style provider config into a ``ProviderTarget``.

    ``None`` and existing ``ProviderTarget`` instances are returned as-is. For
    dicts, several common spellings are accepted so CLI / config / gateway
    callers can plug in later:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    The first spelling wins unless its value is ``None``.
    """
    if target is None or isinstance(target, ProviderTarget):
        return target

    def _pick(primary: str, alias: str) -> Any:
        value = target.get(primary)
        return target.get(alias) if value is None else value

    provider = _pick("provider_name", "provider")
    base = _pick("api_base", "base_url")
    headers = _pick("extra_headers", "headers")
    timeout = _pick("request_timeout_seconds", "timeout")

    routing = target.get("routing")
    if isinstance(routing, dict):
        routing = ProviderRoutingConfig(**routing)

    return ProviderTarget(
        provider_name=provider,
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=base,
        extra_headers=dict(headers or {}),
        request_timeout_seconds=timeout,
        routing=routing,
    )
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve the fallback configuration into a runtime of its own.

    A Hermes-style fallback means "switch to another provider:model after the
    primary provider fails". This function only resolves that target; deciding
    when to activate it is left to the chain/factory layer above.

    Returns:
        The fallback runtime, or ``None`` when there is no target or the target
        names no model.
    """
    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = target.api_key
    base = target.api_base
    headers = dict(target.extra_headers)

    # Inherit the primary's credentials and headers only when the fallback does
    # not explicitly switch provider or base URL.
    stays_on_primary = provider == primary_runtime.provider_name and not base
    if stays_on_primary:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime dedicated to an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the primary provider
    - explicit ``provider + model``: use an independent provider
    - only ``model``: pick the provider from the model name
    """
    aux_tags: dict[str, Any] = {"auxiliary": True, "role": task_name}

    spec_target = normalize_provider_target(target)
    if spec_target is None:
        return _clone_runtime(primary_runtime, source="main_runtime", **aux_tags)

    requested = spec_target.provider_name
    has_base = bool(spec_target.api_base)
    timeout = spec_target.request_timeout_seconds or primary_runtime.request_timeout_seconds

    if requested in {None, "", "main"} and not has_base and not spec_target.model:
        return _clone_runtime(
            primary_runtime,
            source="main_runtime",
            routing=spec_target.routing or primary_runtime.routing,
            extra_headers=spec_target.extra_headers or primary_runtime.extra_headers,
            request_timeout_seconds=timeout,
            **aux_tags,
        )

    if requested == "main":
        return resolve_provider_runtime(
            model=spec_target.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=spec_target.api_key or primary_runtime.api_key,
            api_base=spec_target.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            extra_headers=spec_target.extra_headers or primary_runtime.extra_headers,
            routing=spec_target.routing or primary_runtime.routing,
            source="main_runtime",
            **aux_tags,
        )

    if requested in {"auto", None, ""} and not has_base and spec_target.model is None:
        return _clone_runtime(primary_runtime, source="auto->main", **aux_tags)

    model_name = spec_target.model or primary_runtime.model
    provider = spec_target.provider_name
    if provider in {"auto", "", None} and not has_base:
        # The first-stage `auto` implementation stays conservative:
        # - with an explicit model, match the provider from the model name
        # - if nothing matches, fall back to the primary provider
        matched = find_by_model(model_name)
        provider = matched.name if matched is not None else primary_runtime.provider_name

    key = spec_target.api_key
    base = spec_target.api_base
    headers = dict(spec_target.extra_headers)
    if provider == primary_runtime.provider_name and not base:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=model_name,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=spec_target.routing or primary_runtime.routing,
        source="auxiliary_config",
        **aux_tags,
    )
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The goal is to keep "which model / api_base / api_key do embeddings use"
    inside the provider layer, so retrieval code above does not reach directly
    for the main/auxiliary provider's credentials.

    Returns:
        The embedding runtime, or ``None`` when no target is given and the
        primary runtime is not OpenAI-compatible.

    Raises:
        ValueError: an explicit target resolves to a runtime that is not
            OpenAI-compatible.
    """
    primary_is_compatible = _supports_openai_embeddings(primary_runtime)

    spec_target = normalize_provider_target(target)
    if spec_target is None:
        # Without an explicit embedding config, inherit api_base/api_key only
        # when the primary runtime itself is OpenAI-compatible; otherwise make
        # no fuzzy guess.
        if not primary_is_compatible:
            return None
        return resolve_provider_runtime(
            model=default_model,
            provider_name="openai",
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
            extra_headers=dict(primary_runtime.extra_headers),
            routing=primary_runtime.routing,
            auxiliary=False,
            role="embedding",
            source="embedding_inherited",
        )

    provider = spec_target.provider_name
    if provider in {None, "", "main", "auto"}:
        provider = "custom" if spec_target.api_base else "openai"

    key = spec_target.api_key
    base = spec_target.api_base
    headers = dict(spec_target.extra_headers)
    if not base and primary_is_compatible:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    runtime = resolve_provider_runtime(
        model=spec_target.model or default_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=spec_target.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=headers,
        routing=spec_target.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(runtime):
        return runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
    """Tell whether ``runtime`` can serve OpenAI-compatible ``/v1/embeddings``.

    The embedding retriever currently supports only that API shape: the runtime
    must use the ``chat_completions`` API mode and a ``litellm`` or ``custom``
    provider implementation.
    """
    if runtime.api_mode != "chat_completions":
        return False
    return runtime.spec.provider_impl in {"litellm", "custom"}
def _clone_runtime(
    runtime: ProviderRuntime,
    **changes: Any,
) -> ProviderRuntime:
    """Copy ``runtime`` into a lightweight variant with ``changes`` applied.

    Used for cases such as ``provider=main`` to avoid running registry
    resolution a second time. Unless overridden, the clone gets its own copy of
    ``extra_headers`` and keeps the same ``routing`` and ``fallback_target``.
    """
    carried_over = {
        "extra_headers": dict(runtime.extra_headers),
        "routing": runtime.routing,
        "fallback_target": runtime.fallback_target,
    }
    return replace(runtime, **{**carried_over, **changes})
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
    """Apply the registry's prefix rules to ``model``.

    Gateway mode: optionally keep only the last ``/``-separated segment
    (``strip_model_prefix``), then add ``litellm_prefix`` unless already there.

    Standard mode: a leading ``<spec name>/`` is first rewritten to the
    canonical ``<litellm_prefix>/``. The prefix is then added unless the model
    already starts with it or with one of the spec's ``skip_prefixes``. Treating
    the canonical prefix itself as "already prefixed" prevents results such as
    ``hosted_vllm/hosted_vllm/...`` for specs whose ``skip_prefixes`` is empty.
    """
    prefix = spec.litellm_prefix

    if gateway_mode:
        resolved = model.split("/")[-1] if spec.strip_model_prefix else model
        if prefix and not resolved.startswith(f"{prefix}/"):
            resolved = f"{prefix}/{resolved}"
        return resolved

    if not prefix:
        return model

    resolved = _canonicalize_explicit_prefix(model, spec.name, prefix)
    if resolved.startswith((f"{prefix}/", *spec.skip_prefixes)):
        return resolved
    return f"{prefix}/{resolved}"


def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
    """Rewrite a leading ``<spec_name>/`` into ``<canonical_prefix>/``.

    The leading segment is compared after lower-casing and replacing ``-`` with
    ``_``. Models without a ``/`` or with a different leading segment are
    returned unchanged.
    """
    head, separator, remainder = model.partition("/")
    if not separator or head.lower().replace("-", "_") != spec_name:
        return model
    return f"{canonical_prefix}/{remainder}"

View File

@ -0,0 +1,2 @@
"""Runtime helper objects and execution context."""

View File

@ -0,0 +1,15 @@
"""Session state and persistence."""
from .manager import SessionManager
from .models import MessageRecord, SessionRecord, SessionUsage
from .search import SessionSearchService
from .store import SessionStore
# Names exported by this package; each one is imported above.
__all__ = [
"MessageRecord",
"SessionManager",
"SessionRecord",
"SessionSearchService",
"SessionStore",
"SessionUsage",
]

View File

@ -0,0 +1,143 @@
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from .models import MessageRecord
from .search import SessionSearchService
from .store import SessionStore
class SessionManager:
    """Unified session facade for AgentLoop, services and MCP tools."""

    def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
        """Create ``<workspace>/sessions`` if needed and open the session store.

        Args:
            workspace: Workspace root directory.
            db_path: Database file; defaults to ``<workspace>/sessions/state.db``.
        """
        self.workspace = Path(workspace)
        self.sessions_dir = self.workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
        self.store = SessionStore(self.db_path)
        self.search = SessionSearchService(self.store)

    def close(self) -> None:
        """Close the underlying store."""
        self.store.close()

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Delegate to ``SessionStore.ensure_session`` and return its result."""
        return self.store.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return the session as a dict, or ``None`` when the store has no record."""
        record = self.store.get_session_record(session_id)
        return record.to_dict() if record is not None else None

    def get_or_create(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> dict[str, Any]:
        """Ensure the session exists and return it as a dict.

        Raises:
            RuntimeError: the session cannot be read back after ``ensure_session``.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        session = self.get_session(session_id)
        if session is None:
            raise RuntimeError(f"Failed to create session {session_id!r}")
        return session

    def append_message(self, session_id: str, **kwargs: Any) -> int:
        """Forward to ``SessionStore.append_message`` and return its result."""
        return self.store.append_message(session_id, **kwargs)

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        The difference from ``get_messages_as_conversation()`` is important:

        - ``get_event_records()`` targets runtime / replay / audit and keeps
          hidden system events
        - ``get_messages_as_conversation()`` targets the prompt builder and
          exposes only events that may enter the context

        From phase 6 onwards a session is no longer just "chat message storage";
        it is converging on "external event stream + projected views on top".
        """
        return self.store.get_event_records(session_id)

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events belonging to one direct run / future bus run."""
        return self.store.get_run_event_records(session_id, run_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List all run_ids of the session in order of appearance."""
        return self.store.list_run_ids(session_id)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Forward to ``SessionStore.get_messages_as_conversation``."""
        return self.store.get_messages_as_conversation(session_id)

    def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return the visible history slice suitable for prompt injection.

        The full event stream is deliberately not exposed; this stays a
        "model-consumable history" projection:

        1. only events with ``context_visible=True`` are included
        2. the legacy window trimming is kept so main-chain behaviour does not
           change abruptly
        3. ``ContextBuilder`` explicitly consumes the "visible slice trimmed
           upstream"

        The window holds the last ``max_messages`` messages and is then advanced
        to the first ``user`` message in it; a window without any ``user``
        message is returned unchanged. A non-positive ``max_messages`` yields an
        empty list.
        """
        # `history[-0:]` is the whole list, so a zero (or negative) limit must
        # be handled explicitly instead of being fed into the slice.
        if max_messages <= 0:
            return []

        history = self.get_messages_as_conversation(session_id)
        window = history[-max_messages:]
        first_user = next(
            (index for index, message in enumerate(window) if message.get("role") == "user"),
            None,
        )
        return window if first_user is None else window[first_user:]

    def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Legacy name; returns the visible history slice."""
        return self.get_visible_history(session_id, max_messages=max_messages)

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Forward to ``SessionStore.update_system_prompt``."""
        self.store.update_system_prompt(session_id, system_prompt)

    def update_usage(self, session_id: str, **kwargs: Any) -> None:
        """Forward to ``SessionStore.update_usage``."""
        self.store.update_usage(session_id, **kwargs)

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Forward to ``SessionStore.end_session``."""
        self.store.end_session(session_id, end_reason)

    def reopen_session(self, session_id: str) -> None:
        """Forward to ``SessionStore.reopen_session``."""
        self.store.reopen_session(session_id)

    def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionSearchService.list_sessions_rich``."""
        return self.search.list_sessions_rich(**kwargs)

    def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionSearchService.search_messages``."""
        return self.search.search_messages(**kwargs)

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Forward to ``SessionSearchService.resolve_session_id``."""
        return self.search.resolve_session_id(session_id_or_prefix)

View File

@ -0,0 +1,211 @@
"""Beaver session 子系统的数据模型。
这层只定义数据结构,不放数据库读写逻辑。目的是把:
1. SQLite 行结构
2. 运行时会话对象
3. 对外暴露的 conversation message
三件事分开,避免后续所有地方都直接和裸字典耦合。
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class SessionUsage:
"""会话维度的 usage/cost 统计。"""
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
estimated_cost_usd: float = 0.0
actual_cost_usd: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cache_read_tokens": self.cache_read_tokens,
"cache_write_tokens": self.cache_write_tokens,
"reasoning_tokens": self.reasoning_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"actual_cost_usd": self.actual_cost_usd,
}
@dataclass(slots=True)
class MessageRecord:
"""单条会话事件的结构化表示。
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
1. 普通 user/assistant/tool 消息本身就是事件
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
"""
role: str
content: str | None = None
timestamp: float | None = None
message_id: int | None = None
run_id: str | None = None
event_type: str | None = None
event_payload: dict[str, Any] | None = None
context_visible: bool = True
tool_name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
finish_reason: str | None = None
reasoning: str | None = None
reasoning_details: Any | None = None
codex_reasoning_items: Any | None = None
def to_conversation_message(self) -> dict[str, Any]:
"""转成 provider / context builder 可直接消费的消息格式。"""
if not self.context_visible:
raise ValueError("Hidden session events cannot be converted into conversation messages")
payload: dict[str, Any] = {
"role": self.role,
"content": self.content,
}
if self.tool_name:
payload["tool_name"] = self.tool_name
if self.tool_calls:
payload["tool_calls"] = self.tool_calls
if self.tool_call_id:
payload["tool_call_id"] = self.tool_call_id
if self.finish_reason:
payload["finish_reason"] = self.finish_reason
if self.reasoning:
payload["reasoning"] = self.reasoning
if self.reasoning_details is not None:
payload["reasoning_details"] = self.reasoning_details
if self.codex_reasoning_items is not None:
payload["codex_reasoning_items"] = self.codex_reasoning_items
return payload
@classmethod
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
"""从 SQLite row/dict 恢复消息模型。"""
tool_calls = row.get("tool_calls")
if isinstance(tool_calls, str):
try:
tool_calls = json.loads(tool_calls)
except json.JSONDecodeError:
tool_calls = []
reasoning_details = row.get("reasoning_details")
if isinstance(reasoning_details, str):
try:
reasoning_details = json.loads(reasoning_details)
except json.JSONDecodeError:
reasoning_details = None
codex_reasoning_items = row.get("codex_reasoning_items")
if isinstance(codex_reasoning_items, str):
try:
codex_reasoning_items = json.loads(codex_reasoning_items)
except json.JSONDecodeError:
codex_reasoning_items = None
event_payload = row.get("event_payload")
if isinstance(event_payload, str):
try:
event_payload = json.loads(event_payload)
except json.JSONDecodeError:
event_payload = None
return cls(
message_id=row.get("id"),
run_id=row.get("run_id"),
role=row["role"],
content=row.get("content"),
event_type=row.get("event_type") or row.get("role"),
event_payload=event_payload,
context_visible=bool(row.get("context_visible", 1)),
tool_name=row.get("tool_name"),
tool_calls=tool_calls,
tool_call_id=row.get("tool_call_id"),
timestamp=row.get("timestamp"),
finish_reason=row.get("finish_reason"),
reasoning=row.get("reasoning"),
reasoning_details=reasoning_details,
codex_reasoning_items=codex_reasoning_items,
)
@dataclass(slots=True)
class SessionRecord:
    """Structured representation of a single session."""

    session_id: str
    source: str
    started_at: float
    last_active: float
    user_id: str | None = None
    title: str | None = None
    model: str | None = None
    system_prompt: str | None = None
    parent_session_id: str | None = None
    ended_at: float | None = None
    end_reason: str | None = None
    message_count: int = 0
    tool_call_count: int = 0
    preview: str | None = None
    usage: SessionUsage = field(default_factory=SessionUsage)

    def to_dict(self) -> dict[str, Any]:
        """Return a flat dict: session columns (id under ``"id"``) plus usage counters."""
        session_columns: dict[str, Any] = {
            "id": self.session_id,
            "source": self.source,
            "user_id": self.user_id,
            "title": self.title,
            "model": self.model,
            "system_prompt": self.system_prompt,
            "parent_session_id": self.parent_session_id,
            "started_at": self.started_at,
            "last_active": self.last_active,
            "ended_at": self.ended_at,
            "end_reason": self.end_reason,
            "message_count": self.message_count,
            "tool_call_count": self.tool_call_count,
            "preview": self.preview,
        }
        return {**session_columns, **self.usage.to_dict()}

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
        """Build a record from a row dict.

        ``id``, ``source``, ``started_at`` and ``last_active`` are required;
        every other column is optional. A falsy ``estimated_cost_usd`` becomes
        ``0.0``.
        """
        record = cls(
            session_id=row["id"],
            source=row["source"],
            started_at=row["started_at"],
            last_active=row["last_active"],
        )
        for optional in (
            "user_id",
            "title",
            "model",
            "system_prompt",
            "parent_session_id",
            "ended_at",
            "end_reason",
            "preview",
        ):
            setattr(record, optional, row.get(optional))
        record.message_count = row.get("message_count", 0)
        record.tool_call_count = row.get("tool_call_count", 0)
        record.usage = SessionUsage(
            input_tokens=row.get("input_tokens", 0),
            output_tokens=row.get("output_tokens", 0),
            cache_read_tokens=row.get("cache_read_tokens", 0),
            cache_write_tokens=row.get("cache_write_tokens", 0),
            reasoning_tokens=row.get("reasoning_tokens", 0),
            estimated_cost_usd=row.get("estimated_cost_usd", 0.0) or 0.0,
            actual_cost_usd=row.get("actual_cost_usd"),
        )
        return record

View File

@ -0,0 +1,151 @@
"""Beaver session 子系统的检索能力。"""
from __future__ import annotations
import re
import sqlite3
from typing import Any
from .store import SessionStore
class SessionSearchService:
"""围绕 `SessionStore` 提供 browsing / FTS / lineage 辅助能力。"""
def __init__(self, store: SessionStore) -> None:
    """Keep a reference to the session store that all queries run against."""
    self.store = store
@staticmethod
def _sanitize_fts5_query(query: str) -> str:
quoted_parts: list[str] = []
def preserve(match: re.Match[str]) -> str:
quoted_parts.append(match.group(0))
return f"\x00Q{len(quoted_parts) - 1}\x00"
sanitized = re.sub(r'"[^"]*"', preserve, query)
sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
sanitized = re.sub(r"\*+", "*", sanitized)
sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)
for index, quoted in enumerate(quoted_parts):
sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
return sanitized.strip()
def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
"""用完整 ID 或唯一前缀解析出目标 session_id。"""
exact = self.store.get_session_record(session_id_or_prefix)
if exact is not None:
return exact.session_id
escaped = (
session_id_or_prefix
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
rows = self.store._fetchall(
"""
SELECT id
FROM sessions
WHERE id LIKE ? ESCAPE '\\'
ORDER BY started_at DESC
LIMIT 2
""",
(f"{escaped}%",),
)
if len(rows) == 1:
return rows[0]["id"]
return None
def list_sessions_rich(
self,
*,
limit: int = 20,
offset: int = 0,
include_children: bool = False,
source: str | None = None,
exclude_sources: list[str] | None = None,
) -> list[dict[str, Any]]:
"""列出最近活跃的 session 及其摘要元数据。"""
clauses: list[str] = []
params: list[Any] = []
if not include_children:
clauses.append("parent_session_id IS NULL")
if source:
clauses.append("source = ?")
params.append(source)
if exclude_sources:
placeholders = ",".join("?" for _ in exclude_sources)
clauses.append(f"source NOT IN ({placeholders})")
params.extend(exclude_sources)
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
params.extend([limit, offset])
rows = self.store._fetchall(
f"""
SELECT *
FROM sessions
{where}
ORDER BY last_active DESC
LIMIT ? OFFSET ?
""",
tuple(params),
)
return rows
def search_messages(
self,
*,
query: str,
role_filter: list[str] | None = None,
exclude_sources: list[str] | None = None,
limit: int = 20,
offset: int = 0,
) -> list[dict[str, Any]]:
"""使用 FTS5 搜索 session transcript。"""
query = self._sanitize_fts5_query(query)
if not query:
return []
clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
params: list[Any] = [query]
if exclude_sources:
placeholders = ",".join("?" for _ in exclude_sources)
clauses.append(f"s.source NOT IN ({placeholders})")
params.extend(exclude_sources)
if role_filter:
placeholders = ",".join("?" for _ in role_filter)
clauses.append(f"m.role IN ({placeholders})")
params.extend(role_filter)
params.extend([limit, offset])
sql = f"""
SELECT
m.id,
m.session_id,
m.role,
s.source,
s.model,
s.started_at AS session_started,
snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
FROM messages_fts
JOIN messages m ON m.id = messages_fts.rowid
JOIN sessions s ON s.id = m.session_id
WHERE {' AND '.join(clauses)}
ORDER BY rank
LIMIT ? OFFSET ?
"""
try:
return self.store._fetchall(sql, tuple(params))
except sqlite3.Error as exc:
raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc

View File

@ -0,0 +1,467 @@
"""Beaver session 子系统的 SQLite 存储实现。
设计来源主要参考 Hermes-agent
1. SQLite 作为统一 session/transcript backend
2. WAL 模式支持多读单写
3. FTS5 支持跨 session 文本检索
4. `parent_session_id` 支持 lineage
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable, TypeVar
from .models import MessageRecord, SessionRecord
# Generic return type of the callback handed to SessionStore._execute_write.
T = TypeVar("T")
# Core schema: one row per session in `sessions` (metadata, usage counters,
# optional parent link) and one row per event in `messages`, plus the indexes
# used for listing sessions and reading a session/run in timestamp order.
# Every statement uses IF NOT EXISTS, so the script can be re-run on an
# existing database.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
user_id TEXT,
title TEXT,
model TEXT,
system_prompt TEXT,
parent_session_id TEXT,
started_at REAL NOT NULL,
last_active REAL NOT NULL,
ended_at REAL,
end_reason TEXT,
message_count INTEGER DEFAULT 0,
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
estimated_cost_usd REAL DEFAULT 0,
actual_cost_usd REAL,
preview TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id),
run_id TEXT,
role TEXT NOT NULL,
event_type TEXT,
event_payload TEXT,
context_visible INTEGER NOT NULL DEFAULT 1,
content TEXT,
tool_name TEXT,
tool_calls TEXT,
tool_call_id TEXT,
timestamp REAL NOT NULL,
finish_reason TEXT,
reasoning TEXT,
reasoning_details TEXT,
codex_reasoning_items TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
"""
# FTS5 index over `messages.content`. It is declared as an external-content
# table (content=messages, content_rowid=id), so the index has to be kept in
# sync explicitly; the triggers in FTS_TRIGGER_SQL do that.
FTS_TABLE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content,
content=messages,
content_rowid=id
);
"""
# Triggers that mirror `messages` into `messages_fts`. Existing triggers with
# these names are dropped first so the definitions below always replace them.
# Only rows with context_visible = 1 and non-NULL content are indexed; the
# delete/update triggers remove the old entry under the same condition.
FTS_TRIGGER_SQL = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
"""
class SessionStore:
"""SQLite-backed session store."""
def __init__(self, db_path: str | Path) -> None:
    """Open the SQLite database at ``db_path`` and initialise its schema.

    Missing parent directories are created. The connection is opened in
    autocommit mode (``isolation_level=None``) and shared across threads;
    access to it is serialised through ``self._lock``.
    """
    self.db_path = Path(db_path)
    self.db_path.parent.mkdir(parents=True, exist_ok=True)
    self._lock = threading.Lock()
    self._conn = sqlite3.connect(
        str(self.db_path),
        check_same_thread=False,
        isolation_level=None,
    )
    self._conn.row_factory = sqlite3.Row
    # WAL first, then foreign-key enforcement.
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        self._conn.execute(pragma)
    self._init_schema()
def _init_schema(self) -> None:
    """Create tables, indexes, the FTS5 index and its sync triggers.

    Older databases may have indexed hidden (``context_visible = 0``) events.
    The index is therefore rebuilt from the visible rows, but only when that
    is actually needed: when the FTS table was just created, or when the
    installed insert trigger is a legacy one without the visibility filter.

    The rebuild uses ``'delete-all'`` followed by a re-insert of the visible
    rows. Issuing per-row ``'delete'`` commands on every start-up is not safe:
    once the filtered triggers are in place hidden rows are never in the
    index, and FTS5 requires a ``'delete'`` to match an indexed entry.
    """
    with self._lock:
        self._conn.executescript(SCHEMA_SQL)

        fts_created = False
        try:
            self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
        except sqlite3.OperationalError:
            self._conn.executescript(FTS_TABLE_SQL)
            fts_created = True

        # Inspect the trigger *before* it is replaced below; afterwards the
        # legacy state can no longer be detected.
        trigger_row = self._conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'trigger' AND name = 'messages_fts_insert'"
        ).fetchone()
        legacy_triggers = (
            trigger_row is not None
            and "context_visible" not in (trigger_row[0] or "")
        )

        if fts_created or legacy_triggers:
            # Rebuild atomically. If this fails, the old triggers are still
            # installed and the rebuild is retried on the next start-up.
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                self._conn.execute(
                    "INSERT INTO messages_fts(messages_fts) VALUES('delete-all')"
                )
                self._conn.execute(
                    """
                    INSERT INTO messages_fts(rowid, content)
                    SELECT id, content
                    FROM messages
                    WHERE context_visible = 1 AND content IS NOT NULL
                    """
                )
                self._conn.commit()
            except BaseException:
                self._conn.rollback()
                raise

        self._conn.executescript(FTS_TRIGGER_SQL)
        self._conn.commit()
def close(self) -> None:
with self._lock:
self._conn.close()
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
with self._lock:
self._conn.execute("BEGIN IMMEDIATE")
try:
result = fn(self._conn)
self._conn.commit()
return result
except BaseException:
self._conn.rollback()
raise
def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
with self._lock:
row = self._conn.execute(sql, params).fetchone()
return dict(row) if row else None
def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
with self._lock:
rows = self._conn.execute(sql, params).fetchall()
return [dict(row) for row in rows]
def ensure_session(
self,
session_id: str,
*,
source: str = "unknown",
model: str | None = None,
title: str | None = None,
user_id: str | None = None,
parent_session_id: str | None = None,
) -> str:
"""确保 session 行存在;若不存在则创建,若存在则尽量补全缺失元数据。"""
now = time.time()
def _do(conn: sqlite3.Connection) -> str:
conn.execute(
"""
INSERT INTO sessions (
id, source, user_id, title, model, parent_session_id, started_at, last_active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
source = CASE
WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
ELSE sessions.source
END,
user_id = COALESCE(sessions.user_id, excluded.user_id),
title = COALESCE(sessions.title, excluded.title),
model = COALESCE(sessions.model, excluded.model),
parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
""",
(session_id, source, user_id, title, model, parent_session_id, now, now),
)
return session_id
return self._execute_write(_do)
def get_session_record(self, session_id: str) -> SessionRecord | None:
row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
return SessionRecord.from_row(row) if row else None
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
"""保存本 session 组装后的完整 system prompt snapshot。"""
def _do(conn: sqlite3.Connection) -> None:
conn.execute(
"""
UPDATE sessions
SET system_prompt = ?, last_active = ?
WHERE id = ?
""",
(system_prompt, time.time(), session_id),
)
self._execute_write(_do)
def update_usage(
self,
session_id: str,
*,
input_tokens: int = 0,
output_tokens: int = 0,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
reasoning_tokens: int = 0,
estimated_cost_usd: float = 0.0,
actual_cost_usd: float | None = None,
absolute: bool = False,
) -> None:
"""更新会话 usage。默认按增量累加。"""
if absolute:
sql = """
UPDATE sessions
SET input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = ?,
actual_cost_usd = ?,
last_active = ?
WHERE id = ?
"""
params = (
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
time.time(),
session_id,
)
else:
sql = """
UPDATE sessions
SET input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
cache_read_tokens = cache_read_tokens + ?,
cache_write_tokens = cache_write_tokens + ?,
reasoning_tokens = reasoning_tokens + ?,
estimated_cost_usd = estimated_cost_usd + ?,
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE COALESCE(actual_cost_usd, 0) + ?
END,
last_active = ?
WHERE id = ?
"""
params = (
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
actual_cost_usd,
time.time(),
session_id,
)
def _do(conn: sqlite3.Connection) -> None:
conn.execute(sql, params)
self._execute_write(_do)
def append_message(
self,
session_id: str,
*,
run_id: str | None = None,
role: str,
event_type: str | None = None,
event_payload: dict[str, Any] | None = None,
context_visible: bool = True,
content: str | None = None,
tool_name: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
tool_call_id: str | None = None,
finish_reason: str | None = None,
reasoning: str | None = None,
reasoning_details: Any | None = None,
codex_reasoning_items: Any | None = None,
source: str = "unknown",
title: str | None = None,
model: str | None = None,
user_id: str | None = None,
parent_session_id: str | None = None,
) -> int:
"""向指定 session 追加一条消息。"""
self.ensure_session(
session_id,
source=source,
model=model,
title=title,
user_id=user_id,
parent_session_id=parent_session_id,
)
now = time.time()
tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
event_payload_json = json.dumps(event_payload) if event_payload is not None else None
reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
preview = (content or "")[:120] if role == "user" and content else None
tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)
def _do(conn: sqlite3.Connection) -> int:
cursor = conn.execute(
"""
INSERT INTO messages (
session_id, run_id, role, event_type, event_payload, context_visible, content,
tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
reasoning_details, codex_reasoning_items
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
session_id,
run_id,
role,
event_type or role,
event_payload_json,
1 if context_visible else 0,
content,
tool_name,
tool_calls_json,
tool_call_id,
now,
finish_reason,
reasoning,
reasoning_details_json,
codex_items_json,
),
)
conn.execute(
"""
UPDATE sessions
SET last_active = ?,
message_count = message_count + 1,
tool_call_count = tool_call_count + ?,
model = COALESCE(model, ?),
preview = CASE
WHEN preview IS NULL AND ? IS NOT NULL THEN ?
ELSE preview
END
WHERE id = ?
""",
(now, tool_call_count, model, preview, preview, session_id),
)
return int(cursor.lastrowid)
return self._execute_write(_do)
def get_message_records(self, session_id: str) -> list[MessageRecord]:
rows = self._fetchall(
"""
SELECT *
FROM messages
WHERE session_id = ?
ORDER BY timestamp, id
""",
(session_id,),
)
return [MessageRecord.from_row(row) for row in rows]
def get_event_records(self, session_id: str) -> list[MessageRecord]:
"""返回当前 session 的完整事件流。
当前阶段里,事件流仍复用 `messages` 表承载,所以这里等价于读取全部 message records。
后面如果单独拆出 run/checkpoint/system event 表,上层 manager 仍可以继续保持这个接口不变。
"""
return self.get_message_records(session_id)
def list_run_ids(self, session_id: str) -> list[str]:
"""按时间顺序列出当前 session 中出现过的 run_id。"""
rows = self._fetchall(
"""
SELECT run_id
FROM messages
WHERE session_id = ? AND run_id IS NOT NULL
GROUP BY run_id
ORDER BY MIN(timestamp), MIN(id)
""",
(session_id,),
)
return [str(row["run_id"]) for row in rows if row.get("run_id")]
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
"""返回某一次 run 对应的事件片段。"""
rows = self._fetchall(
"""
SELECT *
FROM messages
WHERE session_id = ? AND run_id = ?
ORDER BY timestamp, id
""",
(session_id, run_id),
)
return [MessageRecord.from_row(row) for row in rows]
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []
for record in self.get_event_records(session_id):
if not record.context_visible:
continue
messages.append(record.to_conversation_message())
return messages
def end_session(self, session_id: str, end_reason: str) -> None:
def _do(conn: sqlite3.Connection) -> None:
conn.execute(
"""
UPDATE sessions
SET ended_at = ?, end_reason = ?, last_active = ?
WHERE id = ?
""",
(time.time(), end_reason, time.time(), session_id),
)
self._execute_write(_do)
def reopen_session(self, session_id: str) -> None:
def _do(conn: sqlite3.Connection) -> None:
conn.execute(
"""
UPDATE sessions
SET ended_at = NULL, end_reason = NULL, last_active = ?
WHERE id = ?
""",
(time.time(), session_id),
)
self._execute_write(_do)