修改了nanobot,往Hermes agent的风格走,进度1/3

This commit is contained in:
2026-04-20 18:11:14 +08:00
parent cdfc222c9f
commit 36882a7d7b
261 changed files with 12659 additions and 604 deletions

View File

@ -0,0 +1,6 @@
"""Beaver backend package."""
__all__ = ["__version__"]
__version__ = "0.1.0"

View File

@ -0,0 +1,2 @@
"""Multi-agent coordination layer."""

View File

@ -0,0 +1,2 @@
"""Pluggable multi-agent backends."""

View File

@ -0,0 +1,20 @@
"""Backend interfaces for multi-agent execution."""
from dataclasses import dataclass
from typing import Protocol
@dataclass(slots=True)
class BackendResult:
"""Normalized result returned by a coordination backend."""
success: bool
summary: str
class CoordinationBackend(Protocol):
    """Structural interface satisfied by pluggable coordination backends."""

    def run(self, task: str) -> BackendResult:
        """Execute a team task and return a normalized result."""
        ...

View File

@ -0,0 +1,6 @@
"""Swarms backend wrapper for Beaver.
This package is intentionally local to Beaver's coordinator layer.
There is no `third_party/` directory in the new backend layout.
"""

View File

@ -0,0 +1,2 @@
"""Delegation orchestration."""

View File

@ -0,0 +1,2 @@
"""Execution control, retry, and aggregation."""

View File

@ -0,0 +1,2 @@
"""Team planning and execution-plan generation."""

View File

@ -0,0 +1,2 @@
"""Agent registry and descriptors."""

View File

@ -0,0 +1,2 @@
"""Team models and orchestration objects."""

View File

@ -0,0 +1,31 @@
"""Unified Beaver agent engine.
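No eager top-level imports are performed here, so that importing a submodule does not trigger a circular dependency.
The same exported names are still available to callers; they are simply loaded on demand.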
"""
from __future__ import annotations
from typing import Any
__all__ = ["AgentLoop", "AgentProfile", "AgentRunResult", "EngineLoader", "EngineLoadResult"]
def __getattr__(name: str) -> Any:
    """Resolve the package's public names lazily (module-level ``__getattr__``).

    The submodule that defines the requested name is imported only at the
    moment the attribute is first looked up.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported names.
    """
    if name in {"EngineLoader", "EngineLoadResult"}:
        from .loader import EngineLoader, EngineLoadResult

        exports = {
            "EngineLoader": EngineLoader,
            "EngineLoadResult": EngineLoadResult,
        }
    elif name in {"AgentLoop", "AgentProfile", "AgentRunResult"}:
        from .loop import AgentLoop, AgentProfile, AgentRunResult

        exports = {
            "AgentLoop": AgentLoop,
            "AgentProfile": AgentProfile,
            "AgentRunResult": AgentRunResult,
        }
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return exports[name]

View File

@ -0,0 +1,17 @@
"""Context assembly for agent runs."""
from .builder import (
ContextBuildInput,
ContextBuildResult,
ContextBuilder,
SessionContext,
SkillContext,
)
__all__ = [
"ContextBuildInput",
"ContextBuildResult",
"ContextBuilder",
"SessionContext",
"SkillContext",
]

View File

@ -0,0 +1,331 @@
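"""Beaver runtime context assembler.
This module is the middle layer between `session` and `provider`, with a narrow set of duties:
1. Join the static / semi-static context prepared before a run into one stable system prompt.
2. Turn the "visible history" trimmed from the session event stream, plus the current user input, into messages a provider can consume directly.
3. During the tool loop, keep appending assistant/tool messages to the message list in one uniform format.
Why this layer must exist on its own:
1. `AgentLoop` should not build prompts itself, otherwise it quickly grows into one huge file again.
2. The injection order of `memory`, `skills` and `session` must be fixed, otherwise model behaviour drifts.
3. The format of messages appended around the tool loop must be uniform, otherwise provider compatibility problems appear easily.
The design goal of this version of the builder is "minimal but stable":
1. Serve the single-agent main path first.
2. Support frozen curated memory first, rather than live memory.
3. Skills follow the Hermes style of "explicit activation message injection"; no disk scanning happens here.
4. Leave injection points for later channel / gateway / team metadata without implementing complex logic ahead of time.
"""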
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from beaver.memory.curated.snapshot import MemorySnapshot
@dataclass(slots=True)
class SkillContext:
"""单个已激活 skill 的最小表示。
这里故意不把 skill 设计成复杂对象,只保留 builder 真正关心的两部分:
- `name`:用于生成激活提示
- `content`skill 的完整正文
注意:按当前 Hermes 风格实现skill 正文不再塞进 system prompt而是转成显式消息注入。
"""
name: str
content: str
@dataclass(slots=True)
class SessionContext:
"""当前运行轮次的会话元数据。
这不是 session store 里的完整 record而是 prompt builder 关心的那一小部分:
- 哪个 session
- 来源是什么
- 当前使用什么 model
- 是否有 channel/chat/user 这类运行路由信息
把它单独抽出来的原因是:
1. builder 不应该知道 SQLite row 长什么样
2. 不同入口CLI/Web/Gateway都可以把自己的 metadata 收敛成同一种结构
"""
session_id: str | None = None
source: str | None = None
model: str | None = None
user_id: str | None = None
channel: str | None = None
chat_id: str | None = None
parent_session_id: str | None = None
@dataclass(slots=True)
class ContextBuildInput:
"""一次上下文构建所需的全部输入。
这个对象的作用不是“炫技式封装”,而是把主链里零散的数据显式收口。
这样一来,后面 `AgentLoop.process_direct()` 在组装参数时会更清晰,也更容易测试。
字段分组:
- 身份/基础段:`base_system_prompt`
- 会话可见历史:`history`
- 当前输入:`current_user_input`
- 冻结记忆:`memory_snapshot`
- 技能:`activated_skills`
- 运行元数据:`session_context` / `execution_context`
- 额外扩展:`extra_sections`
"""
base_system_prompt: str = ""
history: list[dict[str, Any]] = field(default_factory=list)
current_user_input: str | list[dict[str, Any]] | None = None
memory_snapshot: MemorySnapshot | None = None
activated_skills: list[SkillContext] = field(default_factory=list)
session_context: SessionContext | None = None
execution_context: str | None = None
extra_sections: list[str] = field(default_factory=list)
@dataclass(slots=True)
class ContextBuildResult:
"""一次上下文构建后的结果。
保留 `system_prompt` 的原因:
1. `SessionManager.update_system_prompt()` 需要把最终注入的 prompt snapshot 落盘
2. 调试时经常需要区分“system prompt 长什么样”和“messages 长什么样”
3. 后面如果做 prompt audit / replay也会直接复用这个结果
"""
system_prompt: str
messages: list[dict[str, Any]]
class ContextBuilder:
    """Assemble runtime inputs into a stable system prompt and message list.

    This layer is deliberately free of I/O, database and network access:
    it does not read the session store, the memory store, or the skills
    directory. Its output is determined by its arguments alone, which keeps
    it easy to unit-test and to plug into the AgentLoop main path.
    """

    def build_system_prompt(
        self,
        build_input: ContextBuildInput,
    ) -> str:
        """Build the system prompt text.

        The section order is fixed:
        1. base system prompt
        2. session metadata
        3. execution context
        4. frozen memory snapshot
        5. extra sections

        Rationale: identity and global rules come first; session/execution
        describe the current run and rank above long-term memory; memory is
        read from a frozen snapshot so a mid-run memory write cannot change
        the prompt; activated skill bodies are sent as explicit messages
        instead, so the system prompt does not keep growing.

        Returns:
            The non-empty sections joined by ``"\\n\\n---\\n\\n"`` (an empty
            string when there are none).
        """
        sections: list[str] = []

        base_system_prompt = (build_input.base_system_prompt or "").strip()
        if base_system_prompt:
            sections.append(base_system_prompt)

        session_section = self._render_session_section(build_input.session_context)
        if session_section:
            sections.append(session_section)

        execution_context = (build_input.execution_context or "").strip()
        if execution_context:
            sections.append(f"# Execution Context\n\n{execution_context}")

        if build_input.memory_snapshot is not None:
            # Only the frozen snapshot is read here, never a live memory store;
            # otherwise a memory write during the session would make the system
            # prompt disagree with what the session started with.
            snapshot_sections = build_input.memory_snapshot.as_prompt_sections()
            if snapshot_sections:
                sections.extend(snapshot_sections)

        for extra in build_input.extra_sections:
            cleaned = (extra or "").strip()
            if cleaned:
                sections.append(cleaned)

        return "\n\n---\n\n".join(sections)

    def build_messages(
        self,
        build_input: ContextBuildInput,
    ) -> ContextBuildResult:
        """Build the complete message list for one model call.

        Four steps are performed:
        1. Generate the final system prompt.
        2. Inject the full body of each activated skill as an explicit message.
        3. Append the history messages in their original order.
        4. If there is current user input, append it as the last user message.

        `history` is treated as already trimmed by the caller; this method
        does not window it. It only skips entries whose role is ``"system"``
        and shallow-copies the rest.

        Returns:
            A `ContextBuildResult` holding the system prompt and the messages.
        """
        system_prompt = self.build_system_prompt(build_input)
        messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        messages.extend(self.build_skill_activation_messages(build_input.activated_skills))

        for message in build_input.history:
            # The builder owns the single system prompt. If the history already
            # contains system messages they are skipped to avoid a second one.
            if message.get("role") == "system":
                continue
            messages.append(dict(message))

        if build_input.current_user_input is not None:
            messages.append(
                {
                    "role": "user",
                    "content": build_input.current_user_input,
                }
            )

        return ContextBuildResult(
            system_prompt=system_prompt,
            messages=messages,
        )

    def add_tool_result(
        self,
        messages: list[dict[str, Any]],
        *,
        tool_call_id: str,
        tool_name: str,
        result: str,
    ) -> list[dict[str, Any]]:
        """Append one tool-result message to ``messages``.

        Tool messages are appended here rather than in `AgentLoop` so every
        execution path produces the same field names, and provider-specific
        differences can later be handled in one place.

        Returns:
            The same list object that was passed in (mutated in place), so
            calls can be chained.
        """
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call_id,
                "name": tool_name,
                "content": result,
            }
        )
        return messages

    def add_assistant_message(
        self,
        messages: list[dict[str, Any]],
        *,
        content: str | None,
        tool_calls: list[dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
    ) -> list[dict[str, Any]]:
        """Append an assistant message to ``messages``.

        Two details matter:
        1. The ``content`` key is always written, even when it is None,
           because some providers require it on assistant messages that
           carry ``tool_calls``.
        2. ``reasoning_content`` is attached only when it is non-empty; it is
           an extension field of reasoning models and should not leak into
           ordinary provider paths. An empty string is treated like None.

        Returns:
            The same list object that was passed in (mutated in place).
        """
        message: dict[str, Any] = {
            "role": "assistant",
            "content": content,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        if reasoning_content:
            message["reasoning_content"] = reasoning_content
        messages.append(message)
        return messages

    def _render_session_section(self, session_context: SessionContext | None) -> str | None:
        """Render the runtime session metadata as a readable prompt section.

        Only fields with a truthy value produce a row.

        Returns:
            The section text, or None when there is no context or no row.
        """
        if session_context is None:
            return None

        labelled_values = (
            ("Session ID", session_context.session_id),
            ("Source", session_context.source),
            ("Model", session_context.model),
            ("User ID", session_context.user_id),
            ("Channel", session_context.channel),
            ("Chat ID", session_context.chat_id),
            ("Parent Session ID", session_context.parent_session_id),
        )
        rows = [f"{label}: {value}" for label, value in labelled_values if value]
        if not rows:
            return None
        return "# Current Session\n\n" + "\n".join(rows)

    def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
        """Turn activated skills into explicit messages (Hermes style).

        Each skill whose stripped content is non-empty becomes one
        ``user``-role message consisting of an activation notice followed by
        the skill body; skills with empty content are skipped. The model thus
        receives the full guidance directly instead of a summary.
        """
        messages: list[dict[str, str]] = []
        for skill in activated_skills:
            content = (skill.content or "").strip()
            if not content:
                continue
            messages.append(
                {
                    "role": "user",
                    "content": (
                        f'[SYSTEM: The "{skill.name}" skill is active for this run. '
                        "Follow its instructions as active guidance unless the user overrides them.]\n\n"
                        f"{content}"
                    ),
                }
            )
        return messages

View File

@ -0,0 +1,154 @@
"""Centralized runtime loading for Beaver agents."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable
from beaver.engine.context import ContextBuilder
from beaver.engine.session import SessionManager
from beaver.memory.curated.store import MemoryStore
from beaver.services.memory_service import MemoryService
from beaver.skills import SkillAssembler, SkillsLoader
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
@dataclass(slots=True)
class EngineLoadResult:
"""描述当前 agent runtime 已经装好的依赖。
这里同时保留两类字段:
1. `tools/skills/memory_stores/permissions`
- 便于做状态展示、调试、轻量测试
2. `session_manager/tool_registry/...`
- 供真正的运行时主链直接使用
"""
workspace: Path
tools: list[str] = field(default_factory=list)
skills: list[str] = field(default_factory=list)
memory_stores: list[str] = field(default_factory=list)
permissions: list[str] = field(default_factory=list)
session_manager: SessionManager | None = None
curated_memory_store: MemoryStore | None = None
memory_service: MemoryService | None = None
tool_registry: ToolRegistry | None = None
tool_executor: ToolExecutor | None = None
context_builder: ContextBuilder | None = None
skills_loader: SkillsLoader | None = None
skill_assembler: SkillAssembler | None = None
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
closed: bool = False
def register_closeable(self, name: str, close_fn: Callable[[], None]) -> None:
"""登记一个由 runtime 统一关闭的资源。"""
self.closeables.append((name, close_fn))
def close(self) -> None:
"""按后进先出顺序关闭 runtime 资源。
这一步先保持同步、最小、可组合:
1. 只管理已经明确需要关闭的资源
2. 暂不引入 async shutdown 协议
3. 为后续 Web/Gateway lifespan 留统一入口
"""
if self.closed:
return
errors: list[tuple[str, BaseException]] = []
for name, close_fn in reversed(self.closeables):
try:
close_fn()
except BaseException as exc: # pragma: no cover - defensive cleanup path
errors.append((name, exc))
self.closed = True
if errors:
parts = ", ".join(f"{name}: {exc}" for name, exc in errors)
raise RuntimeError(f"Runtime shutdown failed for {parts}")
class EngineLoader:
    """Load the shared runtime capabilities for a Beaver agent.

    `load()` currently wires up: session manager, curated memory store and
    memory service, skills loader and assembler, context builder, the
    built-in tools, and the tool executor. Any of these can be injected
    through the constructor; what is not injected is created in `load()`.
    """

    def __init__(
        self,
        *,
        workspace: str | Path | None = None,
        session_manager: SessionManager | None = None,
        curated_memory_store: MemoryStore | None = None,
        memory_service: MemoryService | None = None,
        tool_registry: ToolRegistry | None = None,
        context_builder: ContextBuilder | None = None,
        skills_loader: SkillsLoader | None = None,
        skill_assembler: SkillAssembler | None = None,
    ) -> None:
        """Store the workspace and any pre-built dependencies.

        Args:
            workspace: Root directory of the runtime; the current working
                directory is used when it is None or empty.
            session_manager, curated_memory_store, memory_service,
            tool_registry, context_builder, skills_loader, skill_assembler:
                Optional pre-built objects; None means `load()` creates one.
        """
        self.workspace = Path(workspace or Path.cwd())
        self._session_manager = session_manager
        self._curated_memory_store = curated_memory_store
        self._memory_service = memory_service
        self._tool_registry = tool_registry
        self._context_builder = context_builder
        self._skills_loader = skills_loader
        self._skill_assembler = skill_assembler

    def load(self) -> EngineLoadResult:
        """Assemble the runtime objects needed by the main path.

        Injected dependencies are recognised with ``is None`` checks rather
        than truthiness, so an injected object that evaluates as false (for
        example an empty container-like registry) is still used as given.

        Built-in tools are registered only when this loader created the tool
        registry, and the session manager is registered for closing only when
        this loader created it.
        """
        workspace = self.workspace

        session_manager = self._session_manager
        if session_manager is None:
            session_manager = SessionManager(workspace)

        curated_root = workspace / "memory" / "curated"
        curated_memory_store = self._curated_memory_store
        if curated_memory_store is None:
            curated_memory_store = MemoryStore(curated_root)

        memory_service = self._memory_service
        if memory_service is None:
            memory_service = MemoryService(curated_root, store=curated_memory_store)
        memory_service.initialize()

        tool_registry = self._tool_registry
        if tool_registry is None:
            tool_registry = ToolRegistry()

        skills_loader = self._skills_loader
        if skills_loader is None:
            skills_loader = SkillsLoader(workspace)

        if self._tool_registry is None:
            # Register the minimal tool set required by the tool loop.
            tool_registry.register_many(
                [
                    ObjectBackedTool(EchoTool()),
                    ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
                    ObjectBackedTool(SkillViewTool(loader=skills_loader)),
                    ObjectBackedTool(SessionSearchTool(db=session_manager)),
                ]
            )

        context_builder = self._context_builder
        if context_builder is None:
            context_builder = ContextBuilder()

        tool_executor = ToolExecutor(tool_registry)

        skill_assembler = self._skill_assembler
        if skill_assembler is None:
            skill_assembler = SkillAssembler(skills_loader)

        result = EngineLoadResult(
            workspace=workspace,
            tools=[spec.name for spec in tool_registry.list_specs()],
            skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
            memory_stores=["curated"],
            permissions=[],
            session_manager=session_manager,
            curated_memory_store=memory_service.get_store(),
            memory_service=memory_service,
            tool_registry=tool_registry,
            tool_executor=tool_executor,
            context_builder=context_builder,
            skills_loader=skills_loader,
            skill_assembler=skill_assembler,
        )
        if self._session_manager is None:
            # Only a session manager created here is closed by the runtime.
            result.register_closeable("session_manager", session_manager.close)
        return result

View File

@ -0,0 +1,689 @@
"""Unified agent loop used by all Beaver agents."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4
from beaver.engine.context import ContextBuildInput, SessionContext
from beaver.engine.providers import ProviderBundle, make_provider_bundle
from beaver.tools import ToolContext
from .loader import EngineLoader, EngineLoadResult
@dataclass(slots=True)
class AgentProfile:
"""Runtime profile for a Beaver agent instance."""
name: str = "default"
system_prompt: str = ""
default_model: str = "gpt-4.1-mini"
max_tokens: int = 4096
temperature: float = 0.2
max_tool_iterations: int = 8
@dataclass(slots=True)
class AgentRunResult:
"""一次 direct run 的最小结果结构。"""
session_id: str
run_id: str
output_text: str
finish_reason: str
tool_iterations: int
provider_name: str | None = None
model: str | None = None
usage: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class _DirectRunRequest:
"""运行循环中的单个 direct task。"""
task: str
kwargs: dict[str, Any]
future: asyncio.Future[AgentRunResult]
class AgentLoop:
    """Single execution kernel shared by root agents and delegated agents."""

    def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
        """Store the profile and loader; nothing is loaded until `boot()`."""
        self.profile = profile or AgentProfile()
        self.loader = loader or EngineLoader()
        self.loaded: EngineLoadResult | None = None
        self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
        self._running = False
        self._stop_requested = False

    def boot(self) -> EngineLoadResult:
        """Load shared runtime capabilities once for this agent instance."""
        if self.loaded is None:
            self.loaded = self.loader.load()
        return self.loaded

    @property
    def is_running(self) -> bool:
        """Whether `run()` is currently active."""
        return self._running

    async def run(self) -> None:
        """Start the minimal run loop and consume submitted direct tasks in order.

        The first version is deliberately restrained:
        1. A single consumer processes tasks serially.
        2. Execution reuses the same implementation as `process_direct()`.
        3. No bus / worker / priority / retry is introduced.

        On exit, every request still waiting in the queue has its future
        failed with a ``RuntimeError`` and the loop state is reset.

        Raises:
            RuntimeError: if the loop is already running.
        """
        if self._running:
            raise RuntimeError("AgentLoop.run() is already active")

        self.boot()
        self._run_queue = asyncio.Queue()
        self._running = True
        self._stop_requested = False
        try:
            while True:
                item = await self._run_queue.get()
                if item is None:
                    # None is the wake-up sentinel that stop() enqueues.
                    if self._stop_requested:
                        break
                    continue
                if item.future.cancelled():
                    # The submitter gave up before the task was picked up.
                    continue
                try:
                    result = await self._process_direct_impl(item.task, **item.kwargs)
                except asyncio.CancelledError:
                    if not item.future.done():
                        item.future.cancel()
                    raise
                except Exception as exc:  # pragma: no cover - defensive queue path
                    if not item.future.done():
                        item.future.set_exception(exc)
                else:
                    if not item.future.done():
                        item.future.set_result(result)
        finally:
            if self._run_queue is not None:
                # Fail whatever is still queued so no submitter waits forever.
                while True:
                    try:
                        pending = self._run_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    if isinstance(pending, _DirectRunRequest) and not pending.future.done():
                        pending.future.set_exception(
                            RuntimeError("AgentLoop.run() stopped before processing the queued task")
                        )
            self._running = False
            self._stop_requested = False
            self._run_queue = None

    async def stop(self) -> None:
        """Ask the run loop to stop.

        First-version semantics:
        - no new tasks are accepted
        - the task already taken from the queue is allowed to finish
        - the runtime is not closed automatically

        Does nothing when the loop is not running.
        """
        if not self._running or self._run_queue is None:
            return
        self._stop_requested = True
        await self._run_queue.put(None)

    async def submit_direct(
        self,
        task: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Submit a direct task to the running loop and wait for its result.

        Args:
            task: Task text to run.
            **kwargs: Forwarded unchanged to the direct-run implementation.

        Raises:
            RuntimeError: if `run()` is not active, or `stop()` was requested.
        """
        if not self._running or self._run_queue is None:
            raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
        if self._stop_requested:
            raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

        future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
        await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
        return await future

    def close(self) -> None:
        """Close the runtime held by this loop.

        `boot()` establishes the runtime and `close()` releases the resources
        it holds. The loaded runtime is dropped even if closing it raises.

        Raises:
            RuntimeError: if the run loop is still active.
        """
        if self._running:
            raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")
        if self.loaded is None:
            return
        try:
            self.loaded.close()
        finally:
            self.loaded = None

    async def process_direct(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Run the minimal direct-run main path.

        The main path covers only:
        1. making sure the session exists
        2. building the prompt from frozen memory + history
        3. calling the provider
        4. entering the minimal tool loop when there are tool calls
        5. writing user/assistant/tool messages and usage back to the session

        Raises:
            RuntimeError: if `run()` is active; use `submit_direct()` then.
        """
        if self._running:
            raise RuntimeError(
                "AgentLoop.process_direct() is disabled while run() is active; "
                "submit tasks via submit_direct() instead."
            )
        return await self._process_direct_impl(
            task,
            session_id=session_id,
            source=source,
            user_id=user_id,
            title=title,
            execution_context=execution_context,
            model=model,
            provider_name=provider_name,
            api_key=api_key,
            api_base=api_base,
            extra_headers=extra_headers,
            routing=routing,
            fallback_target=fallback_target,
            auxiliary_target=auxiliary_target,
            embedding_target=embedding_target,
            embedding_model=embedding_model,
            max_tokens=max_tokens,
            temperature=temperature,
            max_tool_iterations=max_tool_iterations,
            provider_bundle=provider_bundle,
        )

    async def _process_direct_impl(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Internal implementation that actually executes one direct run.

        External callers go through `process_direct()`; the run loop calls
        this method directly. That split is what keeps outside callers from
        bypassing the queue while `run()` is active.

        Any ``Exception`` raised after the ``run_started`` event has been
        recorded is converted into an error result (finish reason
        ``"error"``) instead of propagating.

        Every provider request in the tool loop asks for the main runtime's
        model. The model name reported in a response is used for reporting
        and session records only; it is never sent back as the requested
        model of the next iteration.
        """
        loaded = self.boot()
        session_manager = self._require_loaded("session_manager")
        memory_service = self._require_loaded("memory_service")
        context_builder = self._require_loaded("context_builder")
        tool_registry = self._require_loaded("tool_registry")
        tool_executor = self._require_loaded("tool_executor")
        skill_assembler = self._require_loaded("skill_assembler")

        resolved_session_id = session_id or uuid4().hex
        resolved_run_id = uuid4().hex
        resolved_model = model or self.profile.default_model
        resolved_max_tokens = max_tokens or self.profile.max_tokens
        resolved_temperature = self.profile.temperature if temperature is None else temperature
        resolved_max_tool_iterations = (
            self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
        )

        # Refresh live state through MemoryService before every new run, so
        # memory policy stays inside the service instead of leaking into the loop.
        memory_service.reload_for_new_run()
        session_manager.ensure_session(
            resolved_session_id,
            source=source,
            model=resolved_model,
            title=title,
            user_id=user_id,
        )
        session_manager.append_message(
            resolved_session_id,
            run_id=resolved_run_id,
            role="system",
            event_type="run_started",
            event_payload={
                "source": source,
                "model": resolved_model,
                "agent_name": self.profile.name,
            },
            content=task,
            context_visible=False,
            source=source,
            title=title,
            model=resolved_model,
            user_id=user_id,
        )

        user_message_recorded = False
        iterations = 0
        final_usage: dict[str, Any] = {}
        final_provider_name: str | None = provider_name
        final_model: str | None = resolved_model
        try:
            bundle = provider_bundle or make_provider_bundle(
                model=resolved_model,
                provider_name=provider_name,
                api_key=api_key,
                api_base=api_base,
                extra_headers=extra_headers,
                routing=routing,
                fallback_target=fallback_target,
                auxiliary_target=auxiliary_target,
                embedding_target=embedding_target,
                embedding_model=embedding_model or "text-embedding-v4",
            )

            # Skill selection prefers the auxiliary provider when one exists.
            skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
            skill_selector_model = (
                bundle.auxiliary_runtime.model
                if bundle.auxiliary_runtime is not None
                else bundle.main_runtime.model
            )
            assembled_skills = await skill_assembler.assemble(
                task_description=task,
                provider=skill_selector_provider,
                model=skill_selector_model,
                embedding_runtime=bundle.embedding_runtime,
            )
            skill_activation_messages = context_builder.build_skill_activation_messages(
                assembled_skills.activated_skills
            )
            if skill_activation_messages:
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="system",
                    event_type="skill_activation_snapshotted",
                    event_payload={
                        "activation_messages": skill_activation_messages,
                    },
                    content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
                    context_visible=False,
                    source=source,
                    title=title,
                    model=resolved_model,
                    user_id=user_id,
                )

            build_input = ContextBuildInput(
                base_system_prompt=self.profile.system_prompt,
                history=session_manager.get_history(resolved_session_id),
                current_user_input=task,
                memory_snapshot=memory_service.get_snapshot(),
                activated_skills=assembled_skills.activated_skills,
                session_context=SessionContext(
                    session_id=resolved_session_id,
                    source=source,
                    model=resolved_model,
                    user_id=user_id,
                ),
                execution_context=execution_context,
            )
            context_result = context_builder.build_messages(build_input)

            session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="system_prompt_snapshotted",
                event_payload={
                    "source": source,
                    "model": resolved_model,
                    "system_prompt_length": len(context_result.system_prompt),
                },
                content=context_result.system_prompt,
                context_visible=False,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
            # The user message is recorded after the history was read above.
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="user",
                event_type="user_message_added",
                content=task,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
            user_message_recorded = True

            provider = bundle.main_provider
            messages = list(context_result.messages)
            tool_schemas = tool_registry.export_provider_schemas()
            tool_context = ToolContext(
                workspace=str(loaded.workspace),
                session_id=resolved_session_id,
                user_id=user_id,
                services={
                    "session_manager": session_manager,
                    "memory_service": memory_service,
                    "memory_store": memory_service.get_store(),
                    "tool_registry": tool_registry,
                },
                metadata={
                    "source": source,
                    "agent_name": self.profile.name,
                },
            )

            final_text = ""
            final_finish_reason = "stop"
            final_provider_name = bundle.main_runtime.provider_name
            # Model requested on every call of the loop; kept apart from
            # `final_model`, which tracks what the provider reports back.
            request_model = bundle.main_runtime.model
            final_model = request_model

            while True:
                response = await provider.chat(
                    messages=messages,
                    tools=tool_schemas,
                    model=request_model,
                    max_tokens=resolved_max_tokens,
                    temperature=resolved_temperature,
                )
                final_provider_name = response.provider_name or final_provider_name
                final_model = response.model or final_model
                final_usage = self._merge_usage(final_usage, response.usage or {})
                self._record_usage(session_manager, resolved_session_id, response.usage or {})

                assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="assistant",
                    event_type="assistant_message_added",
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    finish_reason=response.finish_reason,
                    reasoning=response.reasoning_content,
                    source=source,
                    title=title,
                    model=final_model,
                    user_id=user_id,
                )
                context_builder.add_assistant_message(
                    messages,
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    reasoning_content=response.reasoning_content,
                )

                if not response.has_tool_calls:
                    final_text = response.content or ""
                    final_finish_reason = response.finish_reason or "stop"
                    break

                if iterations >= resolved_max_tool_iterations:
                    # NOTE(review): the assistant message recorded just above
                    # still carries tool_calls that never get tool results;
                    # confirm that stored history in this shape is accepted
                    # when the session is replayed to a provider.
                    final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
                    final_finish_reason = "max_tool_iterations"
                    session_manager.append_message(
                        resolved_session_id,
                        run_id=resolved_run_id,
                        role="assistant",
                        event_type="assistant_message_added",
                        content=final_text,
                        finish_reason=final_finish_reason,
                        source=source,
                        title=title,
                        model=final_model,
                        user_id=user_id,
                    )
                    context_builder.add_assistant_message(
                        messages,
                        content=final_text,
                    )
                    break

                iterations += 1
                for tool_call in response.tool_calls:
                    result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
                    session_manager.append_message(
                        resolved_session_id,
                        run_id=resolved_run_id,
                        role="tool",
                        event_type="tool_result_recorded",
                        event_payload={
                            "success": result.success,
                            "error": result.error,
                        },
                        content=result.content,
                        tool_name=result.tool_name,
                        tool_call_id=tool_call.id,
                        source=source,
                        title=title,
                        model=final_model,
                        user_id=user_id,
                    )
                    context_builder.add_tool_result(
                        messages,
                        tool_call_id=tool_call.id,
                        tool_name=result.tool_name,
                        result=result.content,
                    )

            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="run_completed",
                event_payload={
                    "finish_reason": final_finish_reason,
                    "tool_iterations": iterations,
                },
                content=final_text,
                finish_reason=final_finish_reason,
                context_visible=False,
                source=source,
                title=title,
                model=final_model,
                user_id=user_id,
            )
            return AgentRunResult(
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                output_text=final_text,
                finish_reason=final_finish_reason,
                tool_iterations=iterations,
                provider_name=final_provider_name,
                model=final_model,
                usage=final_usage,
            )
        except Exception as exc:
            if not user_message_recorded:
                # Keep the user turn in the session even if the run failed
                # before it was written.
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="user",
                    event_type="user_message_added",
                    content=task,
                    source=source,
                    title=title,
                    model=resolved_model,
                    user_id=user_id,
                )
            return self._build_error_result(
                session_manager=session_manager,
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                source=source,
                title=title,
                user_id=user_id,
                model=final_model or resolved_model,
                message=f"Run failed before completion: {exc}",
                tool_iterations=iterations,
                provider_name=final_provider_name,
                usage=final_usage,
            )

    def _require_loaded(self, field_name: str) -> Any:
        """Return a loaded dependency by field name.

        Raises:
            RuntimeError: if the loader left that field as None.
        """
        loaded = self.boot()
        value = getattr(loaded, field_name)
        if value is None:
            raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
        return value

    @staticmethod
    def _serialize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
        """Convert tool-call objects into ``function``-type tool-call dicts.

        Each input object must expose ``id``, ``name`` and ``arguments``.
        """
        return [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.name,
                    "arguments": tool_call.arguments,
                },
            }
            for tool_call in tool_calls
        ]

    @staticmethod
    def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
        """Map provider usage onto the session usage fields.

        Only the most common fields are mapped for now:
        - prompt_tokens -> input_tokens
        - completion_tokens -> output_tokens

        Keys already named ``input_tokens`` / ``output_tokens`` take
        precedence. Nothing is recorded when ``usage`` is empty.
        """
        if not usage:
            return
        session_manager.update_usage(
            session_id,
            input_tokens=int(usage.get("input_tokens", usage.get("prompt_tokens", 0)) or 0),
            output_tokens=int(usage.get("output_tokens", usage.get("completion_tokens", 0)) or 0),
            reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
        )

    @staticmethod
    def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
        """Merge per-call provider usage into the cumulative usage of a run.

        Numeric values are added to an existing numeric total; any other
        value replaces what was stored. ``total`` itself is not modified.
        """
        merged = dict(total)
        for key, value in delta.items():
            if isinstance(value, (int, float)) and isinstance(merged.get(key, 0), (int, float)):
                merged[key] = merged.get(key, 0) + value
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _build_error_result(
        *,
        session_manager: Any,
        session_id: str,
        run_id: str,
        source: str,
        title: str | None,
        user_id: str | None,
        model: str | None,
        message: str,
        tool_iterations: int,
        provider_name: str | None,
        usage: dict[str, Any],
    ) -> AgentRunResult:
        """Turn an unhandled main-path exception into a traceable error turn.

        Records an assistant message and a hidden ``run_failed`` event, both
        with finish reason ``"error"``, then returns the matching result.
        """
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="assistant",
            event_type="assistant_message_added",
            content=message,
            finish_reason="error",
            source=source,
            title=title,
            model=model,
            user_id=user_id,
        )
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="system",
            event_type="run_failed",
            event_payload={
                "tool_iterations": tool_iterations,
                "provider_name": provider_name,
            },
            content=message,
            finish_reason="error",
            context_visible=False,
            source=source,
            title=title,
            model=model,
            user_id=user_id,
        )
        return AgentRunResult(
            session_id=session_id,
            run_id=run_id,
            output_text=message,
            finish_reason="error",
            tool_iterations=tool_iterations,
            provider_name=provider_name,
            model=model,
            usage=usage,
        )

View File

@ -0,0 +1,33 @@
"""LLM provider adapters."""
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .chain import FallbackProviderChain
from .factory import (
ProviderBundle,
ProviderRoutingConfig,
ProviderRuntime,
ProviderTarget,
build_provider_runtime,
make_aux_provider,
make_fallback_provider,
make_main_provider,
make_provider_bundle,
make_provider_from_runtime,
)
__all__ = [
"FallbackProviderChain",
"LLMProvider",
"LLMResponse",
"ProviderBundle",
"ProviderRoutingConfig",
"ProviderRuntime",
"ProviderTarget",
"ToolCallRequest",
"build_provider_runtime",
"make_aux_provider",
"make_fallback_provider",
"make_main_provider",
"make_provider_bundle",
"make_provider_from_runtime",
]

View File

@ -0,0 +1,173 @@
"""Native Anthropic Messages API provider."""
from __future__ import annotations
import json
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import anthropic
except ModuleNotFoundError: # pragma: no cover
anthropic = None # type: ignore[assignment]
class AnthropicProvider(LLMProvider):
    """Provider that talks to Anthropic's native Messages API.

    The native API is used on purpose instead of forcing requests through an
    OpenAI-compatible path.
    """

    def __init__(
        self,
        api_key: str | None = None,
        default_model: str = "claude-sonnet-4-5",
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is created lazily."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached async client, creating it on first use.

        Raises:
            RuntimeError: if the optional ``anthropic`` package is missing.
        """
        if anthropic is None:
            raise RuntimeError("anthropic package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # An explicit timeout=None tells the SDK to disable timeouts
            # altogether. Only forward a timeout when one was configured so
            # the SDK default stays in effect otherwise.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = anthropic.AsyncAnthropic(**client_kwargs)
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat request and normalize the reply.

        Failures (missing SDK, transport or API errors) are reported as an
        ``LLMResponse`` with ``finish_reason="error"`` rather than raised.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        system_prompt, anthropic_messages = _convert_messages(messages)
        kwargs: dict[str, Any] = {
            "model": model or self.default_model,
            "messages": anthropic_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        # ``system`` is optional; leave it out instead of sending an empty prompt.
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = _convert_tools(tools)

        try:
            response = await client.messages.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        content_parts: list[str] = []
        tool_calls: list[ToolCallRequest] = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCallRequest(
                        id=block.id,
                        name=block.name,
                        arguments=block.input,
                    )
                )

        usage_payload: dict[str, Any] = {}
        if getattr(response, "usage", None):
            usage_payload = {
                "input_tokens": getattr(response.usage, "input_tokens", 0),
                "output_tokens": getattr(response.usage, "output_tokens", 0),
            }

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            finish_reason=getattr(response, "stop_reason", "stop") or "stop",
            usage=usage_payload,
            provider_name="anthropic",
            model=model or self.default_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
converted: list[dict[str, Any]] = []
for message in messages:
role = message.get("role")
if role == "system":
content = message.get("content")
system_prompt = content if isinstance(content, str) else ""
continue
if role == "tool":
converted.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message.get("tool_call_id"),
"content": message.get("content") or "",
}
],
}
)
continue
if role == "assistant" and message.get("tool_calls"):
content_blocks: list[dict[str, Any]] = []
if message.get("content"):
content_blocks.append({"type": "text", "text": message["content"]})
for tool_call in message.get("tool_calls", []):
function = tool_call.get("function", tool_call)
arguments = function.get("arguments")
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
content_blocks.append(
{
"type": "tool_use",
"id": tool_call.get("id"),
"name": function.get("name"),
"input": arguments or {},
}
)
converted.append({"role": "assistant", "content": content_blocks})
continue
content = message.get("content")
if isinstance(content, list):
blocks = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
blocks.append({"type": "text", "text": item.get("text", "")})
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
else:
converted.append({"role": role, "content": content or ""})
return system_prompt, converted
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
if not fn.get("name"):
continue
converted.append(
{
"name": fn["name"],
"description": fn.get("description") or "",
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
}
)
return converted

View File

@ -0,0 +1,98 @@
"""Beaver provider 子系统的统一契约。"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class ToolCallRequest:
    """A single tool invocation requested by the model."""

    # Provider-assigned identifier of the call.
    id: str
    # Name of the tool the model wants to run.
    name: str
    # Arguments for the tool, as a mapping.
    arguments: dict[str, Any]
@dataclass(slots=True)
class LLMResponse:
"""统一的模型响应结构。"""
content: str | None
tool_calls: list[ToolCallRequest] = field(default_factory=list)
finish_reason: str = "stop"
usage: dict[str, Any] = field(default_factory=dict)
reasoning_content: str | None = None
provider_name: str | None = None
model: str | None = None
@property
def has_tool_calls(self) -> bool:
return bool(self.tool_calls)
class LLMProvider(ABC):
    """Common interface every provider implementation has to follow."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store credentials and the request timeout.

        A configured timeout is converted to ``float`` and raised to at least
        one second; ``None`` is kept as ``None``.
        """
        self.api_key = api_key
        self.api_base = api_base
        if request_timeout_seconds is None:
            self.request_timeout_seconds = None
        else:
            self.request_timeout_seconds = max(1.0, float(request_timeout_seconds))

    @staticmethod
    def sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Replace empty ``content`` values that providers commonly reject.

        - An empty string becomes ``"(empty)"``, or ``None`` for an assistant
          message that carries tool calls.
        - In list content, text-like items without text are removed; if nothing
          is left the same ``"(empty)"`` / ``None`` rule applies.

        Messages that need no change are passed through as the same objects;
        changed messages are shallow copies, so the input is never mutated.
        """
        text_item_types = ("text", "input_text", "output_text")

        def _is_blank_text_item(part: Any) -> bool:
            return (
                isinstance(part, dict)
                and part.get("type") in text_item_types
                and not part.get("text")
            )

        def _placeholder(original: dict[str, Any]) -> str | None:
            # Assistant turns that only carry tool calls must keep content=None.
            if original.get("role") == "assistant" and original.get("tool_calls"):
                return None
            return "(empty)"

        cleaned: list[dict[str, Any]] = []
        for original in messages:
            body = original.get("content")

            if isinstance(body, str) and not body:
                patched = dict(original)
                patched["content"] = _placeholder(original)
                cleaned.append(patched)
            elif isinstance(body, list):
                kept = [part for part in body if not _is_blank_text_item(part)]
                if len(kept) == len(body):
                    cleaned.append(original)
                else:
                    patched = dict(original)
                    patched["content"] = kept if kept else _placeholder(original)
                    cleaned.append(patched)
            else:
                cleaned.append(original)
        return cleaned

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Unified chat entry point."""

    @abstractmethod
    def get_default_model(self) -> str:
        """Return the provider's default model name."""

View File

@ -0,0 +1,145 @@
"""Provider chain helpers.
这里先实现最小可用的 fallback chain
- 每次调用都先尝试主 provider
- 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback
- fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路
这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。
"""
from __future__ import annotations
from .base import LLMProvider, LLMResponse
from .runtime import ProviderRuntime
class FallbackProviderChain(LLMProvider):
    """Present a primary provider plus an optional fallback as one ``LLMProvider``."""

    def __init__(
        self,
        primary_runtime: ProviderRuntime,
        primary_provider: LLMProvider,
        fallback_runtime: ProviderRuntime | None = None,
        fallback_provider: LLMProvider | None = None,
    ) -> None:
        """Keep both legs of the chain; base settings come from the primary runtime."""
        super().__init__(
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
        )
        self.primary_runtime = primary_runtime
        self.primary_provider = primary_provider
        self.fallback_runtime = fallback_runtime
        self.fallback_provider = fallback_provider

        # Bookkeeping only: which leg served the most recent chat call, for
        # debugging and tests. Routing itself restarts from the primary on
        # every call and must never stick to the fallback across calls.
        self._last_runtime = primary_runtime
        self._last_provider = primary_provider
        self._last_call_used_fallback = False

    @property
    def fallback_activated(self) -> bool:
        """Whether the most recent chat call ended up on the fallback."""
        return self._last_call_used_fallback

    @property
    def active_runtime(self) -> ProviderRuntime:
        """Runtime that served the most recent chat call."""
        return self._last_runtime

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Try the primary provider and retry once on the fallback if it errors.

        The fallback always runs with its own runtime's model; the ``model``
        argument only applies to the primary attempt.
        """
        self._last_provider = self.primary_provider
        self._last_runtime = self.primary_runtime
        self._last_call_used_fallback = False

        primary_reply = await self._safe_chat(
            self.primary_provider,
            self.primary_runtime,
            messages=messages,
            tools=tools,
            model=model or self.primary_runtime.model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        primary_reply = self._decorate_response(primary_reply, self.primary_runtime)
        if not self._should_activate_fallback(primary_reply):
            return primary_reply

        # Guaranteed by _should_activate_fallback; kept for type narrowing.
        assert self.fallback_provider is not None
        assert self.fallback_runtime is not None

        self._last_provider = self.fallback_provider
        self._last_runtime = self.fallback_runtime
        self._last_call_used_fallback = True

        fallback_reply = await self._safe_chat(
            self.fallback_provider,
            self.fallback_runtime,
            messages=messages,
            tools=tools,
            model=self.fallback_runtime.model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return self._decorate_response(fallback_reply, self.fallback_runtime)

    def get_default_model(self) -> str:
        """Return the primary runtime's model."""
        return self.primary_runtime.model

    def _should_activate_fallback(self, response: LLMResponse) -> bool:
        """Fallback applies only when one is configured and the reply is an error."""
        if self.fallback_provider is None or self.fallback_runtime is None:
            return False
        return response.finish_reason == "error"

    @staticmethod
    async def _safe_chat(
        provider: LLMProvider,
        runtime: ProviderRuntime,
        *,
        messages: list[dict],
        tools: list[dict] | None,
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> LLMResponse:
        """Call ``provider.chat`` and convert raised exceptions into an error reply.

        This keeps the fallback trigger independent of whether each provider
        remembers to catch its own exceptions.
        """
        try:
            reply = await provider.chat(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as exc:
            reply = LLMResponse(
                content=f"Error: {exc}",
                finish_reason="error",
                provider_name=runtime.provider_name,
                model=runtime.model,
            )
        return reply

    @staticmethod
    def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse:
        """Fill in provider name and model from ``runtime`` where the reply lacks them."""
        if response.provider_name is None:
            response.provider_name = runtime.provider_name
        if response.model is None:
            response.model = runtime.model
        return response

View File

@ -0,0 +1,274 @@
"""OpenAI Codex Responses provider."""
from __future__ import annotations
import asyncio
import hashlib
import json
from typing import Any, AsyncGenerator
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import httpx
except ModuleNotFoundError: # pragma: no cover
httpx = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
from oauth_cli_kit import get_token as get_codex_token
except ModuleNotFoundError: # pragma: no cover
get_codex_token = None # type: ignore[assignment]
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "beaver"
class OpenAICodexProvider(LLMProvider):
    """Call the Responses API using Codex OAuth credentials."""

    def __init__(
        self,
        default_model: str = "openai-codex/gpt-5.1-codex",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store the default model; credentials are fetched per request via OAuth."""
        super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one streamed request to the Codex endpoint and collect the reply.

        ``max_tokens`` and ``temperature`` are accepted for interface
        compatibility but are not forwarded in the request body.

        Every failure, including token retrieval and request construction, is
        reported as an ``LLMResponse`` with ``finish_reason="error"`` so callers
        (and a fallback chain) see one consistent error shape.
        """
        if httpx is None or get_codex_token is None:
            return LLMResponse(content="Error: codex dependencies are not installed", finish_reason="error", provider_name="openai_codex")

        resolved_model = model or self.default_model
        try:
            system_prompt, input_items = _convert_messages(messages)
            # Token lookup may block, so it runs in a worker thread. It can
            # also fail (e.g. no stored login), hence it lives inside the try.
            token = await asyncio.to_thread(get_codex_token)
            headers = _build_headers(token.account_id, token.access)
            body: dict[str, Any] = {
                "model": _strip_model_prefix(resolved_model),
                "store": False,
                "stream": True,
                "instructions": system_prompt,
                "input": input_items,
                "text": {"verbosity": "medium"},
                "include": ["reasoning.encrypted_content"],
                "prompt_cache_key": _prompt_cache_key(messages),
                "tool_choice": "auto",
                "parallel_tool_calls": True,
            }
            if tools:
                body["tools"] = _convert_tools(tools)

            content, tool_calls, finish_reason = await _request_codex(
                DEFAULT_CODEX_URL,
                headers,
                body,
                verify=True,
                timeout_seconds=self.request_timeout_seconds or 600.0,
            )
        except Exception as exc:
            return LLMResponse(content=f"Error calling Codex: {exc}", finish_reason="error", provider_name="openai_codex")

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            provider_name="openai_codex",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model
def _strip_model_prefix(model: str) -> str:
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
return model.split("/", 1)[1]
return model
def _build_headers(account_id: str, token: str) -> dict[str, str]:
    """Assemble the HTTP headers for a streamed Codex request."""
    header_pairs = [
        ("Authorization", f"Bearer {token}"),
        ("chatgpt-account-id", account_id),
        ("OpenAI-Beta", "responses=experimental"),
        ("originator", DEFAULT_ORIGINATOR),
        ("User-Agent", "beaver (python)"),
        ("accept", "text/event-stream"),
        ("content-type", "application/json"),
    ]
    return dict(header_pairs)
async def _request_codex(
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
    verify: bool,
    timeout_seconds: float,
) -> tuple[str, list[ToolCallRequest], str]:
    """POST ``body`` as a streamed request and return the consumed SSE result.

    Raises:
        RuntimeError: when the endpoint answers with a non-200 status; the
            message contains the status code and the start of the body.
    """
    http_client = httpx.AsyncClient(timeout=timeout_seconds, verify=verify)
    async with http_client as client, client.stream("POST", url, headers=headers, json=body) as response:
        if response.status_code == 200:
            return await _consume_sse(response)
        raw_body = await response.aread()
        raise RuntimeError(_friendly_error(response.status_code, raw_body.decode("utf-8", "ignore")))
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append(
{
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
}
)
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
input_items: list[dict[str, Any]] = []
for index, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(_convert_user_message(content))
continue
if role == "assistant":
if isinstance(content, str) and content:
input_items.append(
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed",
"id": f"msg_{index}",
}
)
for tool_call in message.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
input_items.append(
{
"type": "function_call",
"id": item_id or f"fc_{index}",
"call_id": call_id or f"call_{index}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
}
)
continue
if role == "tool":
call_id, _ = _split_tool_call_id(message.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append(
{
"type": "function_call_output",
"call_id": call_id,
"output": output_text,
}
)
return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: Any) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [item[5:].strip() for item in buffer if item.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(response: Any) -> tuple[str, list[ToolCallRequest], str]:
content_parts: list[str] = []
tool_calls: list[ToolCallRequest] = []
finish_reason = "stop"
async for event in _iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_text.delta":
delta = event.get("delta") or ""
content_parts.append(delta)
elif event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
raw_arguments = item.get("arguments") or "{}"
try:
arguments = json.loads(raw_arguments) if isinstance(raw_arguments, str) else raw_arguments
except json.JSONDecodeError:
arguments = {}
tool_calls.append(
ToolCallRequest(
id=f"{item.get('call_id', 'call')}|{item.get('id', '')}",
name=item.get("name", ""),
arguments=arguments,
)
)
elif event_type == "response.completed":
finish_reason = event.get("response", {}).get("status", "completed")
return "".join(content_parts) or None, tool_calls, finish_reason
def _friendly_error(status_code: int, body: str) -> str:
return f"Codex API error ({status_code}): {body[:400]}"

View File

@ -0,0 +1,106 @@
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
from __future__ import annotations
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
try: # pragma: no cover - optional dependency
import json_repair
except ModuleNotFoundError: # pragma: no cover
json_repair = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
from openai import AsyncOpenAI
except ModuleNotFoundError: # pragma: no cover
AsyncOpenAI = None # type: ignore[assignment]
class CustomProvider(LLMProvider):
    """Connect directly to any OpenAI-compatible endpoint."""

    def __init__(
        self,
        api_key: str = "no-key",
        api_base: str = "http://localhost:8000/v1",
        default_model: str = "default",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is created lazily."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached async client, creating it on first use.

        Raises:
            RuntimeError: if the optional ``openai`` package is missing.
        """
        if AsyncOpenAI is None:
            raise RuntimeError("openai package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # An explicit timeout=None tells the SDK to disable timeouts
            # altogether. Only forward a timeout when one was configured so
            # the SDK default stays in effect otherwise.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = AsyncOpenAI(**client_kwargs)
        return self._client

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
        """Decode tool-call arguments into a dict, tolerating malformed JSON.

        ``json_repair`` is preferred when installed. Anything that cannot be
        decoded, or that decodes to something other than a JSON object, yields
        ``{}`` so one bad tool call does not abort the whole response.
        """
        import json

        arguments = raw_arguments
        if isinstance(raw_arguments, str):
            try:
                if json_repair is not None:
                    arguments = json_repair.loads(raw_arguments)
                else:
                    arguments = json.loads(raw_arguments)
            except Exception:
                arguments = {}
        return arguments if isinstance(arguments, dict) else {}

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat-completions request and normalize the reply.

        Failures (missing SDK, transport or API errors, a reply without
        choices) are reported as an ``LLMResponse`` with
        ``finish_reason="error"`` rather than raised.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        kwargs: dict[str, Any] = {
            "model": model or self.default_model,
            "messages": self.sanitize_empty_content(messages),
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        if tools:
            kwargs.update(tools=tools, tool_choice="auto")

        try:
            response = await client.chat.completions.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        choices = getattr(response, "choices", None)
        if not choices:
            return LLMResponse(
                content="Error: provider returned no choices",
                finish_reason="error",
                provider_name="custom",
            )

        choice = choices[0]
        message = choice.message
        parsed_tool_calls = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in (message.tool_calls or [])
        ]

        usage = getattr(response, "usage", None)
        usage_payload: dict[str, Any] = {}
        if usage is not None:
            usage_payload = {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }

        return LLMResponse(
            content=message.content,
            tool_calls=parsed_tool_calls,
            finish_reason=choice.finish_reason or "stop",
            usage=usage_payload,
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name="custom",
            model=model or self.default_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model

View File

@ -0,0 +1,235 @@
"""Provider runtime 的统一工厂入口。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .anthropic import AnthropicProvider
from .base import LLMProvider
from .chain import FallbackProviderChain
from .codex import OpenAICodexProvider
from .custom import CustomProvider
from .litellm import LiteLLMProvider
from .runtime import (
ProviderRoutingConfig,
ProviderRuntime,
ProviderTarget,
normalize_provider_target,
resolve_auxiliary_runtime,
resolve_embedding_runtime,
resolve_fallback_runtime,
resolve_provider_runtime,
)
@dataclass(slots=True)
class ProviderBundle:
    """The set of providers needed for one run.

    The common chains are collected in one place:
    - ``main``: the primary conversation
    - ``fallback``: backup provider used after the main chain fails
    - ``auxiliary``: helper tasks such as search summaries, compression and
      memory flush
    """

    # Runtime and provider for the main conversation.
    main_runtime: ProviderRuntime
    main_provider: LLMProvider
    # Optional fallback leg.
    fallback_runtime: ProviderRuntime | None = None
    fallback_provider: LLMProvider | None = None
    # Optional auxiliary leg.
    auxiliary_runtime: ProviderRuntime | None = None
    auxiliary_provider: LLMProvider | None = None
    # Optional runtime for embeddings (runtime only, no provider object).
    embedding_runtime: ProviderRuntime | None = None
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
    """Build a unified provider runtime.

    All keyword arguments are forwarded unchanged to
    ``resolve_provider_runtime``.
    """
    return resolve_provider_runtime(**kwargs)
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
    """Instantiate the concrete provider selected by ``runtime.spec.provider_impl``.

    ``custom``, ``codex`` and ``anthropic`` map to their dedicated providers;
    every other value is served through LiteLLM.
    """
    implementation = runtime.spec.provider_impl
    model_name = runtime.default_model or runtime.model
    timeout = runtime.request_timeout_seconds

    if implementation == "custom":
        return CustomProvider(
            api_key=runtime.api_key or "no-key",
            api_base=runtime.api_base or "http://localhost:8000/v1",
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "codex":
        return OpenAICodexProvider(
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "anthropic":
        return AnthropicProvider(
            api_key=runtime.api_key,
            default_model=model_name,
            api_base=runtime.api_base,
            request_timeout_seconds=timeout,
        )

    return LiteLLMProvider(
        api_key=runtime.api_key,
        api_base=runtime.api_base,
        default_model=model_name,
        provider_name=runtime.provider_name,
        extra_headers=runtime.extra_headers,
        request_timeout_seconds=timeout,
        routing=runtime.routing,
    )
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider for the main conversation.

    ``fallback_target`` (or, when that is absent, ``fallback_model``) is taken
    out of ``kwargs``; the rest is forwarded to ``build_provider_runtime``.
    When a fallback can be resolved, the returned provider is a
    ``FallbackProviderChain`` wrapping the primary and the fallback.
    """
    fallback_target = kwargs.pop("fallback_target", None)
    if fallback_target is None:
        if "fallback_model" in kwargs:
            fallback_target = kwargs.pop("fallback_model")

    runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **kwargs,
    )
    primary = make_provider_from_runtime(runtime)

    fallback_pair = make_fallback_provider(runtime, fallback_target)
    if fallback_pair is not None:
        fallback_runtime, fallback_provider = fallback_pair
        return runtime, FallbackProviderChain(runtime, primary, fallback_runtime, fallback_provider)
    return runtime, primary
def make_fallback_provider(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
) -> tuple[ProviderRuntime, LLMProvider] | None:
    """Build the fallback provider for ``primary_runtime``.

    An explicit ``fallback_target`` wins over the target stored on the primary
    runtime. Returns ``None`` when no fallback runtime can be resolved.
    """
    effective_target = fallback_target or primary_runtime.fallback_target
    fallback_runtime = resolve_fallback_runtime(primary_runtime, effective_target)
    if fallback_runtime is not None:
        return fallback_runtime, make_provider_from_runtime(fallback_runtime)
    return None
def make_aux_provider(
    main_runtime: ProviderRuntime | None = None,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    task_name: str = "auxiliary",
    **kwargs: Any,
) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider used for auxiliary tasks.

    Extra keyword arguments act as the target when no ``target`` is given.
    With a ``main_runtime`` the auxiliary runtime is resolved relative to it;
    without one the target must name at least a model.

    Raises:
        ValueError: if there is no ``main_runtime`` and the target has no model.
    """
    if target is None and kwargs:
        target = kwargs

    if main_runtime is None:
        normalized = normalize_provider_target(target)
        if normalized is None or not normalized.model:
            raise ValueError("Auxiliary provider without main_runtime requires at least a model")
        aux_runtime = build_provider_runtime(
            model=normalized.model,
            provider_name=normalized.provider_name,
            api_key=normalized.api_key,
            api_base=normalized.api_base,
            request_timeout_seconds=normalized.request_timeout_seconds,
            extra_headers=normalized.extra_headers,
            routing=normalized.routing,
            auxiliary=True,
            role=task_name,
            source="auxiliary_config",
        )
    else:
        aux_runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name)

    return aux_runtime, make_provider_from_runtime(aux_runtime)
def make_embedding_runtime(
    main_runtime: ProviderRuntime,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Build the runtime dedicated to embeddings.

    Delegates to ``resolve_embedding_runtime`` with the same arguments and
    returns its result, which may be ``None``.
    """
    return resolve_embedding_runtime(main_runtime, target=target, default_model=default_model)
def make_provider_bundle(
    *,
    auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary_task_name: str = "auxiliary",
    embedding_target: ProviderTarget | dict[str, Any] | None = None,
    embedding_model: str = "text-embedding-v4",
    **kwargs: Any,
) -> ProviderBundle:
    """Build the main, fallback and auxiliary provider chains in one call.

    The remaining keyword arguments describe the main runtime.
    ``fallback_target`` (or ``fallback_model`` when that is absent) selects the
    fallback. The auxiliary provider is only built when ``auxiliary_target``
    is given; the embedding runtime is always resolved.
    """
    runtime_kwargs = {**kwargs}
    fallback_target = runtime_kwargs.pop("fallback_target", None)
    if fallback_target is None and "fallback_model" in kwargs:
        fallback_target = runtime_kwargs.pop("fallback_model")

    main_runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **runtime_kwargs,
    )
    primary_provider = make_provider_from_runtime(main_runtime)

    # Start from "no fallback" and upgrade to a chain when one resolves.
    main_provider: LLMProvider = primary_provider
    fallback_runtime = None
    fallback_provider = None
    fallback_pair = make_fallback_provider(main_runtime, fallback_target)
    if fallback_pair is not None:
        fallback_runtime, fallback_provider = fallback_pair
        main_provider = FallbackProviderChain(main_runtime, primary_provider, fallback_runtime, fallback_provider)

    auxiliary_runtime = None
    auxiliary_provider = None
    if auxiliary_target is not None:
        auxiliary_runtime, auxiliary_provider = make_aux_provider(
            main_runtime,
            target=auxiliary_target,
            task_name=auxiliary_task_name,
        )

    embedding_runtime = make_embedding_runtime(
        main_runtime,
        target=embedding_target,
        default_model=embedding_model,
    )

    return ProviderBundle(
        main_runtime=main_runtime,
        main_provider=main_provider,
        fallback_runtime=fallback_runtime,
        fallback_provider=fallback_provider,
        auxiliary_runtime=auxiliary_runtime,
        auxiliary_provider=auxiliary_provider,
        embedding_runtime=embedding_runtime,
    )
__all__ = [
"ProviderBundle",
"ProviderRoutingConfig",
"ProviderRuntime",
"ProviderTarget",
"build_provider_runtime",
"make_aux_provider",
"make_embedding_runtime",
"make_fallback_provider",
"make_main_provider",
"make_provider_bundle",
"make_provider_from_runtime",
]

View File

@ -0,0 +1,230 @@
"""LiteLLM provider implementation for multi-provider support."""
from __future__ import annotations
from contextlib import contextmanager
import json
import os
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .registry import find_by_model, find_gateway
from .runtime import ProviderRoutingConfig
try: # pragma: no cover - optional dependency
import json_repair
except ModuleNotFoundError: # pragma: no cover
json_repair = None # type: ignore[assignment]
try: # pragma: no cover - optional dependency
import litellm
from litellm import acompletion
except ModuleNotFoundError: # pragma: no cover
litellm = None # type: ignore[assignment]
acompletion = None # type: ignore[assignment]
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class LiteLLMProvider(LLMProvider):
"""通过 LiteLLM 统一访问大多数 provider。"""
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
    default_model: str = "anthropic/claude-opus-4-5",
    extra_headers: dict[str, str] | None = None,
    provider_name: str | None = None,
    request_timeout_seconds: float | None = None,
    routing: ProviderRoutingConfig | None = None,
) -> None:
    """Store provider settings and look up the gateway for this provider.

    Note that this also sets two module-level LiteLLM options
    (``suppress_debug_info`` and ``drop_params``) whenever LiteLLM is
    installed, which affects the LiteLLM module as a whole, not only this
    instance.
    """
    super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
    self.default_model = default_model
    self.extra_headers = extra_headers or {}
    self.routing = routing
    self.provider_name = provider_name
    # Result of find_gateway for this provider/key/base; may be falsy.
    self._gateway = find_gateway(provider_name, api_key, api_base)
    if litellm is not None:
        litellm.suppress_debug_info = True
        litellm.drop_params = True
def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
"""为当前请求生成 LiteLLM 依赖的临时环境变量。
LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime
之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。
"""
if not api_key:
return {}
spec = self._gateway or find_by_model(model)
if spec is None or not spec.env_key:
return {}
overrides: dict[str, str] = {spec.env_key: api_key}
effective_base = api_base or spec.default_api_base
for env_name, env_value in spec.env_extras:
resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
overrides[env_name] = resolved
return overrides
@contextmanager
def _temporary_env(self, overrides: dict[str, str]):
"""只在当前请求期间注入 provider 需要的环境变量。"""
if not overrides:
yield
return
sentinel = object()
previous: dict[str, object] = {}
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
try:
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = str(old_value)
def _resolve_model(self, model: str) -> str:
if self._gateway:
prefix = self._gateway.litellm_prefix
resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
if prefix and not resolved.startswith(f"{prefix}/"):
resolved = f"{prefix}/{resolved}"
return resolved
spec = find_by_model(model)
if spec and spec.litellm_prefix:
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
model = f"{spec.litellm_prefix}/{model}"
return model
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return copies of ``messages`` restricted to the keys in ``_ALLOWED_MSG_KEYS``.

    An assistant message that ends up without a ``content`` key gets an
    explicit ``content=None``. The input dictionaries are not modified.
    """
    cleaned: list[dict[str, Any]] = []
    for original in messages:
        filtered = {name: original[name] for name in original if name in _ALLOWED_MSG_KEYS}
        if filtered.get("role") == "assistant":
            filtered.setdefault("content", None)
        cleaned.append(filtered)
    return cleaned
def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
    """Merge the first matching per-model parameter override into ``kwargs``.

    The spec is looked up from the caller-supplied model name. Each entry of
    its ``model_overrides`` is a ``(substring, params)`` pair; the first pair
    whose substring occurs in the lower-cased model name is applied in place
    and the rest are ignored.

    Args:
        original_model: Model name before any routing prefix was added.
        kwargs: Request keyword arguments, updated in place.
    """
    spec = find_by_model(original_model)
    if spec is None:
        return
    lowered = original_model.lower()
    matched = next(
        (params for needle, params in spec.model_overrides if needle in lowered),
        None,
    )
    if matched is not None:
        kwargs.update(matched)
def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
if self.provider_name != "openrouter" or self.routing is None:
return
provider_payload: dict[str, Any] = {}
if self.routing.sort:
provider_payload["sort"] = self.routing.sort
if self.routing.only:
provider_payload["only"] = self.routing.only
if self.routing.ignore:
provider_payload["ignore"] = self.routing.ignore
if self.routing.order:
provider_payload["order"] = self.routing.order
if self.routing.require_parameters:
provider_payload["require_parameters"] = True
if self.routing.data_collection:
provider_payload["data_collection"] = self.routing.data_collection
if provider_payload:
kwargs["provider"] = provider_payload
async def chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
) -> LLMResponse:
    """Send one chat-completion request through LiteLLM and normalize the reply.

    Failures are reported as an ``LLMResponse`` with ``finish_reason="error"``
    instead of being raised: a missing LiteLLM install, any exception raised
    while calling ``acompletion``, and a reply that carries no choices all
    take that path.

    Args:
        messages: Conversation messages; they are sanitized before sending.
        tools: Optional tool definitions. When given, ``tool_choice`` is ``"auto"``.
        model: Model name; defaults to ``self.default_model``.
        max_tokens: Completion token limit, clamped to at least 1.
        temperature: Sampling temperature (a per-model override may replace it).

    Returns:
        The normalized response, including parsed tool calls and token usage.
    """
    if acompletion is None:
        return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)

    original_model = model or self.default_model
    resolved_model = self._resolve_model(original_model)
    sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
    kwargs: dict[str, Any] = {
        "model": resolved_model,
        "messages": sanitized_messages,
        "max_tokens": max(1, max_tokens),
        "temperature": temperature,
    }
    if self.api_key:
        kwargs["api_key"] = self.api_key
    if self.api_base:
        kwargs["api_base"] = self.api_base
    if self.extra_headers:
        kwargs["extra_headers"] = self.extra_headers
    if tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = "auto"
    self._apply_model_overrides(original_model, kwargs)
    self._apply_openrouter_routing(kwargs)

    env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)
    try:
        with self._temporary_env(env_overrides):
            response = await acompletion(**kwargs)
    except Exception as exc:
        # Boundary with the provider SDK: any failure becomes an error response.
        return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)

    # A reply without choices would otherwise raise IndexError below; report
    # it the same way as every other provider failure.
    choices = getattr(response, "choices", None)
    if not choices:
        return LLMResponse(
            content="Error: provider returned a response without choices",
            finish_reason="error",
            provider_name=self.provider_name,
            model=resolved_model,
        )
    choice = choices[0]
    message = choice.message

    tool_calls: list[ToolCallRequest] = []
    for tool_call in message.tool_calls or []:
        raw_arguments = tool_call.function.arguments
        if isinstance(raw_arguments, str):
            try:
                if json_repair is not None:
                    arguments = json_repair.loads(raw_arguments)
                else:
                    arguments = json.loads(raw_arguments)
            except Exception as exc:
                # Do not let one tool call with broken arguments blow up the
                # whole turn. The ToolExecutor downstream turns this marker
                # into a standard tool failure.
                arguments = {
                    "__beaver_tool_argument_parse_error__": str(exc),
                    "__raw_arguments__": raw_arguments,
                }
        else:
            arguments = raw_arguments
        tool_calls.append(
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=arguments,
            )
        )

    usage = getattr(response, "usage", None)
    usage_payload = {}
    if usage is not None:
        usage_payload = {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }
    return LLMResponse(
        content=getattr(message, "content", None),
        tool_calls=tool_calls,
        finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
        usage=usage_payload,
        reasoning_content=getattr(message, "reasoning_content", None),
        provider_name=self.provider_name or "litellm",
        model=resolved_model,
    )
def get_default_model(self) -> str:
    """Return the model name used when ``chat`` is called without ``model``."""
    return self.default_model

View File

@ -0,0 +1,249 @@
"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class ProviderSpec:
"""单个 provider 的元数据定义。"""
name: str
keywords: tuple[str, ...]
env_key: str
display_name: str = ""
litellm_prefix: str = ""
skip_prefixes: tuple[str, ...] = ()
env_extras: tuple[tuple[str, str], ...] = ()
is_gateway: bool = False
is_local: bool = False
detect_by_key_prefix: str = ""
detect_by_base_keyword: str = ""
default_api_base: str = ""
strip_model_prefix: bool = False
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
is_oauth: bool = False
is_direct: bool = False
supports_prompt_caching: bool = False
api_mode: str = "chat_completions"
provider_impl: str = "litellm"
@property
def label(self) -> str:
return self.display_name or self.name.title()
# Registry of known providers.
# Order matters: the lookup helpers in this module scan the tuple from the top
# and return the first entry that matches.
PROVIDERS: tuple[ProviderSpec, ...] = (
    # Catch-all for user-supplied endpoints.
    ProviderSpec(
        name="custom",
        keywords=(),
        env_key="",
        display_name="Custom",
        is_direct=True,
        provider_impl="custom",
        api_mode="chat_completions",
    ),
    # --- Gateways (is_gateway=True) ---
    ProviderSpec(
        name="openrouter",
        keywords=("openrouter",),
        env_key="OPENROUTER_API_KEY",
        display_name="OpenRouter",
        litellm_prefix="openrouter",
        is_gateway=True,
        detect_by_key_prefix="sk-or-",
        detect_by_base_keyword="openrouter",
        default_api_base="https://openrouter.ai/api/v1",
        supports_prompt_caching=True,
    ),
    ProviderSpec(
        name="aihubmix",
        keywords=("aihubmix",),
        env_key="OPENAI_API_KEY",
        display_name="AiHubMix",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="aihubmix",
        default_api_base="https://aihubmix.com/v1",
        strip_model_prefix=True,
    ),
    ProviderSpec(
        name="siliconflow",
        keywords=("siliconflow",),
        env_key="OPENAI_API_KEY",
        display_name="SiliconFlow",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="siliconflow",
        default_api_base="https://api.siliconflow.cn/v1",
    ),
    ProviderSpec(
        name="volcengine",
        keywords=("volcengine", "volces", "ark"),
        env_key="OPENAI_API_KEY",
        display_name="VolcEngine",
        litellm_prefix="volcengine",
        is_gateway=True,
        detect_by_base_keyword="volces",
        default_api_base="https://ark.cn-beijing.volces.com/api/v3",
    ),
    # --- Standard providers ---
    ProviderSpec(
        name="anthropic",
        keywords=("anthropic", "claude"),
        env_key="ANTHROPIC_API_KEY",
        display_name="Anthropic",
        supports_prompt_caching=True,
        api_mode="anthropic_messages",
        provider_impl="anthropic",
    ),
    ProviderSpec(
        name="openai",
        keywords=("openai", "gpt"),
        env_key="OPENAI_API_KEY",
        display_name="OpenAI",
    ),
    ProviderSpec(
        name="openai_codex",
        keywords=("openai-codex", "codex"),
        env_key="",
        display_name="OpenAI Codex",
        is_oauth=True,
        detect_by_base_keyword="codex",
        default_api_base="https://chatgpt.com/backend-api",
        api_mode="codex_responses",
        provider_impl="codex",
    ),
    ProviderSpec(
        name="github_copilot",
        keywords=("github_copilot", "copilot"),
        env_key="",
        display_name="Github Copilot",
        litellm_prefix="github_copilot",
        skip_prefixes=("github_copilot/",),
        is_oauth=True,
    ),
    ProviderSpec(
        name="deepseek",
        keywords=("deepseek",),
        env_key="DEEPSEEK_API_KEY",
        display_name="DeepSeek",
        litellm_prefix="deepseek",
        skip_prefixes=("deepseek/",),
    ),
    ProviderSpec(
        name="gemini",
        keywords=("gemini",),
        env_key="GEMINI_API_KEY",
        display_name="Gemini",
        litellm_prefix="gemini",
        skip_prefixes=("gemini/",),
    ),
    ProviderSpec(
        name="zhipu",
        keywords=("zhipu", "glm", "zai"),
        env_key="ZAI_API_KEY",
        display_name="Zhipu AI",
        litellm_prefix="zai",
        skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
        env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
    ),
    ProviderSpec(
        name="dashscope",
        keywords=("qwen", "dashscope"),
        env_key="DASHSCOPE_API_KEY",
        display_name="DashScope",
        litellm_prefix="dashscope",
        skip_prefixes=("dashscope/", "openrouter/"),
    ),
    ProviderSpec(
        name="moonshot",
        keywords=("moonshot", "kimi"),
        env_key="MOONSHOT_API_KEY",
        display_name="Moonshot",
        litellm_prefix="moonshot",
        skip_prefixes=("moonshot/", "openrouter/"),
        env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
        default_api_base="https://api.moonshot.ai/v1",
        model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
    ),
    ProviderSpec(
        name="minimax",
        keywords=("minimax",),
        env_key="MINIMAX_API_KEY",
        display_name="MiniMax",
        litellm_prefix="minimax",
        skip_prefixes=("minimax/", "openrouter/"),
        default_api_base="https://api.minimax.io/v1",
    ),
    # --- Local deployment (is_local=True) ---
    ProviderSpec(
        name="vllm",
        keywords=("vllm",),
        env_key="HOSTED_VLLM_API_KEY",
        display_name="vLLM/Local",
        litellm_prefix="hosted_vllm",
        is_local=True,
    ),
    ProviderSpec(
        name="groq",
        keywords=("groq",),
        env_key="GROQ_API_KEY",
        display_name="Groq",
        litellm_prefix="groq",
        skip_prefixes=("groq/",),
    ),
)
def find_by_name(name: str) -> ProviderSpec | None:
    """Return the first registry entry whose ``name`` equals ``name`` exactly, else ``None``."""
    return next((candidate for candidate in PROVIDERS if candidate.name == name), None)
def find_by_model(model: str) -> ProviderSpec | None:
    """Match a model name to a provider spec.

    Resolution order (case-insensitive, ``-`` and ``_`` treated alike):

    1. An explicit ``provider/`` prefix equal to a spec *name*. This is
       checked against every spec, not just standard ones, so that
       ``openrouter/...`` hits the gateway directly and
       ``github_copilot/...codex`` is not mistaken for ``openai_codex``.
    2. An explicit prefix equal to a spec's ``litellm_prefix``, so that
       ``hosted_vllm/...`` maps back to the local ``vllm`` spec.
    3. A keyword found anywhere in the model name, among specs that are
       neither gateways nor local.

    Exact names are tried before prefix aliases because several gateways
    reuse another provider's ``litellm_prefix`` (for example ``openai``); an
    alias must never shadow the provider that actually owns that name.

    Args:
        model: Model name, optionally prefixed with ``provider/``.

    Returns:
        The matching spec, or ``None`` when nothing matches.
    """
    model_lower = model.lower()
    model_normalized = model_lower.replace("-", "_")
    model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""

    if model_prefix:
        normalized_prefix = model_prefix.replace("-", "_")
        for spec in PROVIDERS:
            if spec.name == normalized_prefix:
                return spec
        for spec in PROVIDERS:
            if spec.litellm_prefix and spec.litellm_prefix.replace("-", "_") == normalized_prefix:
                return spec

    for spec in PROVIDERS:
        if spec.is_gateway or spec.is_local:
            continue
        if any(keyword in model_lower or keyword.replace("-", "_") in model_normalized for keyword in spec.keywords):
            return spec
    return None
def find_gateway(
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
) -> ProviderSpec | None:
    """Identify a gateway or local provider from config key, API key or base URL.

    An explicit ``provider_name`` wins, but only if the named spec is flagged
    as a gateway or as local. Otherwise the registry is scanned in order and
    the first spec whose key-prefix or base-URL hint matches is returned;
    this auto-detection pass does not look at the gateway/local flags.

    Returns:
        The detected spec, or ``None`` when nothing matches.
    """
    if provider_name:
        named = find_by_name(provider_name)
        if named is not None and (named.is_gateway or named.is_local):
            return named
    for candidate in PROVIDERS:
        key_prefix = candidate.detect_by_key_prefix
        if key_prefix and api_key and api_key.startswith(key_prefix):
            return candidate
        base_keyword = candidate.detect_by_base_keyword
        if base_keyword and api_base and base_keyword in api_base:
            return candidate
    return None

View File

@ -0,0 +1,408 @@
"""Hermes 风格的 provider runtime resolution。"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Any
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
@dataclass(slots=True)
class ProviderRoutingConfig:
    """OpenRouter provider-routing configuration.

    Each field is forwarded under the same key inside the request's
    ``provider`` payload when it holds a truthy value.
    """

    sort: str | None = None
    only: list[str] = field(default_factory=list)
    ignore: list[str] = field(default_factory=list)
    order: list[str] = field(default_factory=list)
    require_parameters: bool = False
    data_collection: str | None = None
@dataclass(slots=True)
class ProviderTarget:
    """Normalized description of one provider-selection request.

    This layer is not a concrete runtime; it captures what the caller wants:
    - which provider to use
    - which model to run
    - whether a custom base_url is specified
    - whether extra headers / routing are attached
    `resolve_provider_runtime()` later turns it into an instantiable runtime.
    """

    provider_name: str | None = None
    model: str | None = None
    api_key: str | None = None
    api_base: str | None = None
    extra_headers: dict[str, str] = field(default_factory=dict)
    request_timeout_seconds: float | None = None
    routing: ProviderRoutingConfig | None = None
@dataclass(slots=True)
class ProviderRuntime:
    """Resolved provider configuration that is actually used at run time."""

    # Registry entry the runtime was resolved to.
    spec: ProviderSpec
    # Model name after registry prefix rules were applied.
    model: str
    provider_name: str
    api_mode: str
    api_key: str | None = None
    api_base: str | None = None
    default_model: str | None = None
    request_timeout_seconds: float | None = None
    extra_headers: dict[str, str] = field(default_factory=dict)
    routing: ProviderRoutingConfig | None = None
    fallback_target: ProviderTarget | None = None
    # Bookkeeping labels set by the resolve_* helpers in this module.
    auxiliary: bool = False
    role: str = "main"
    source: str = "runtime"
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied configuration into a unified runtime.

    Spec selection, in order: a detected gateway/local provider, then the
    explicitly named provider, then a match on the model name. If none of
    these yields a spec, the ``custom`` spec is used when ``api_base`` is
    given.

    Raises:
        ValueError: No spec could be resolved and no ``api_base`` was given.
    """
    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    else:
        spec = find_by_name(provider_name) if provider_name else find_by_model(model)
    if spec is None:
        if not api_base:
            raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")
        spec = find_by_name("custom")

    final_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    final_api_base = api_base or spec.default_api_base or None
    headers = extra_headers or {}
    return ProviderRuntime(
        spec=spec,
        model=final_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=final_api_base,
        default_model=final_model,
        request_timeout_seconds=request_timeout_seconds,
        extra_headers=headers,
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-style provider config into one structure.

    Several common spellings are accepted so CLI / config / gateway callers
    can be wired in later; the first name of each pair wins when it is set:
    - `provider_name` or `provider`
    - `api_base` or `base_url`
    - `extra_headers` or `headers`
    - `request_timeout_seconds` or `timeout`

    ``None`` and existing ``ProviderTarget`` instances are returned unchanged.
    A ``routing`` dict is expanded into a ``ProviderRoutingConfig``.
    """
    if target is None or isinstance(target, ProviderTarget):
        return target

    def pick(primary: str, alias: str) -> Any:
        # Fall back to the alias only when the primary key is missing or None.
        value = target.get(primary)
        return target.get(alias) if value is None else value

    provider_name = pick("provider_name", "provider")
    api_base = pick("api_base", "base_url")
    headers = pick("extra_headers", "headers")
    timeout = pick("request_timeout_seconds", "timeout")

    routing = target.get("routing")
    if isinstance(routing, dict):
        routing = ProviderRoutingConfig(**routing)

    return ProviderTarget(
        provider_name=provider_name,
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=api_base,
        extra_headers=dict(headers or {}),
        request_timeout_seconds=timeout,
        routing=routing,
    )
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback configuration into an independent runtime.

    A Hermes-style fallback means "switch to another provider:model after the
    main provider fails". Only the resolution happens here; deciding when to
    activate the fallback is left to the chain/factory layer above.

    Returns:
        The fallback runtime, or ``None`` when no target or no model is given.
    """
    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    # Credentials and headers are inherited from the main chain only when the
    # fallback does not explicitly switch provider or base URL.
    inherits_primary = provider == primary_runtime.provider_name and not target.api_base
    if inherits_primary:
        api_key = target.api_key or primary_runtime.api_key
        api_base = primary_runtime.api_base
        headers = dict(target.extra_headers) or dict(primary_runtime.extra_headers)
    else:
        api_key = target.api_key
        api_base = target.api_base
        headers = dict(target.extra_headers)

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime dedicated to an auxiliary task.

    Three kinds of input are supported:
    - `None` / `provider=main`: reuse the main-chain provider
    - explicit `provider + model`: use an independent provider
    - `model` only: match a provider automatically from the model name

    The returned runtime always has ``auxiliary=True`` and ``role=task_name``.
    """
    normalized = normalize_provider_target(target)
    # No target at all: plain clone of the main runtime.
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )
    provider_name = normalized.provider_name
    # Main provider with no model/base override: clone, but let the target
    # override routing, headers and timeout.
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=normalized.routing or primary_runtime.routing,
            extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
            request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        )
    # Explicit "main" with a model and/or base override: re-resolve against
    # the main provider, filling every gap from the main runtime.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
            extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
            routing=normalized.routing or primary_runtime.routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )
    # "auto" with nothing else specified degrades to the main runtime.
    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )
    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = normalized.provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # The first-stage implementation of `auto` stays conservative:
        # - with an explicit model, match the provider from the model name
        # - if nothing matches, fall back to the main-chain provider
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name
    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)
    # Inherit credentials and headers only when staying on the main provider
    # without an explicit base URL.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)
    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=extra_headers,
        routing=normalized.routing or primary_runtime.routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The goal is to keep "which model / api_base / api_key do embeddings use"
    inside the provider layer, so retrieval code above never reaches into the
    main or auxiliary provider's credentials on its own.

    Returns:
        The embedding runtime, or ``None`` when no target is given and the
        main runtime does not pass ``_supports_openai_embeddings``.

    Raises:
        ValueError: An explicit target resolved to a runtime that does not
            pass ``_supports_openai_embeddings``.
    """
    config = normalize_provider_target(target)
    primary_compatible = _supports_openai_embeddings(primary_runtime)

    if config is None:
        # Without explicit embedding config, inherit api_base/api_key only
        # when the main chain itself is OpenAI-compatible; never guess.
        if not primary_compatible:
            return None
        return resolve_provider_runtime(
            model=default_model,
            provider_name="openai",
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
            extra_headers=dict(primary_runtime.extra_headers),
            routing=primary_runtime.routing,
            auxiliary=False,
            role="embedding",
            source="embedding_inherited",
        )

    provider = config.provider_name
    if provider in {None, "", "main", "auto"}:
        provider = "custom" if config.api_base else "openai"

    api_key = config.api_key
    api_base = config.api_base
    headers = dict(config.extra_headers)
    if not api_base and primary_compatible:
        api_key = api_key or primary_runtime.api_key
        api_base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    embedding_runtime = resolve_provider_runtime(
        model=config.model or default_model,
        provider_name=provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=config.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=headers,
        routing=config.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(embedding_runtime):
        return embedding_runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
"""当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。"""
return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"}
def _clone_runtime(
runtime: ProviderRuntime,
**changes: Any,
) -> ProviderRuntime:
"""基于现有 runtime 复制一个轻量变体。
用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。
"""
payload = {
"extra_headers": dict(runtime.extra_headers),
"routing": runtime.routing,
"fallback_target": runtime.fallback_target,
}
payload.update(changes)
return replace(runtime, **payload)
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
"""根据 registry 规则应用必要前缀。"""
resolved = model
if gateway_mode:
prefix = spec.litellm_prefix
if spec.strip_model_prefix:
resolved = resolved.split("/")[-1]
if prefix and not resolved.startswith(f"{prefix}/"):
resolved = f"{prefix}/{resolved}"
return resolved
if spec.litellm_prefix:
resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix)
if not any(resolved.startswith(item) for item in spec.skip_prefixes):
resolved = f"{spec.litellm_prefix}/{resolved}"
return resolved
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
if "/" not in model:
return model
prefix, remainder = model.split("/", 1)
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"

View File

@ -0,0 +1,2 @@
"""Runtime helper objects and execution context."""

View File

@ -0,0 +1,15 @@
"""Session state and persistence."""
from .manager import SessionManager
from .models import MessageRecord, SessionRecord, SessionUsage
from .search import SessionSearchService
from .store import SessionStore
__all__ = [
"MessageRecord",
"SessionManager",
"SessionRecord",
"SessionSearchService",
"SessionStore",
"SessionUsage",
]

View File

@ -0,0 +1,143 @@
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from .models import MessageRecord
from .search import SessionSearchService
from .store import SessionStore
class SessionManager:
    """Unified session facade used by AgentLoop / services / MCP tools.

    Nearly every method is a thin delegation to the underlying store or
    search service; the only logic that lives here is the visible-history
    window in ``get_visible_history``.
    """

    def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
        """Create the ``sessions`` directory if needed and open the backing store.

        Args:
            workspace: Workspace root; sessions live in its ``sessions`` subdirectory.
            db_path: Optional database path; defaults to ``state.db`` inside
                the sessions directory.
        """
        self.workspace = Path(workspace)
        self.sessions_dir = self.workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
        self.store = SessionStore(self.db_path)
        self.search = SessionSearchService(self.store)

    def close(self) -> None:
        """Close the underlying store."""
        self.store.close()

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Delegate to the store's ``ensure_session`` and return its result."""
        return self.store.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return the session as a plain dict, or ``None`` if the store has no record."""
        record = self.store.get_session_record(session_id)
        return record.to_dict() if record is not None else None

    def get_or_create(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> dict[str, Any]:
        """Ensure the session exists, then return it as a plain dict.

        Raises:
            RuntimeError: The session cannot be read back after ``ensure_session``.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        session = self.get_session(session_id)
        if session is None:
            raise RuntimeError(f"Failed to create session {session_id!r}")
        return session

    def append_message(self, session_id: str, **kwargs: Any) -> int:
        """Delegate to the store's ``append_message`` and return its result."""
        return self.store.append_message(session_id, **kwargs)

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        The difference from `get_messages_as_conversation()` is important:
        - `get_event_records()` serves runtime / replay / audit and keeps
          hidden system events
        - `get_messages_as_conversation()` serves the prompt builder and only
          exposes events that may enter the context

        From phase 6 onwards a session is no longer just "chat message
        storage"; it is converging on "external event stream + projected
        views on top".
        """
        return self.store.get_event_records(session_id)

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the event slice belonging to one direct run / future bus run."""
        return self.store.get_run_event_records(session_id, run_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List every run_id of the session in order of appearance."""
        return self.store.list_run_ids(session_id)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Delegate to the store's ``get_messages_as_conversation``."""
        return self.store.get_messages_as_conversation(session_id)

    def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return the visible history slice suitable for prompt injection.

        The full event stream is deliberately not exposed here; this keeps
        offering the "model-consumable history" projection:
        1. only events with `context_visible=True` are included
        2. the legacy window-trimming logic is kept so main-chain behavior
           does not change abruptly
        3. `ContextBuilder` explicitly consumes "the visible slice already
           trimmed upstream"

        The window is the last ``max_messages`` messages, further trimmed so
        it starts at the first ``user`` message inside the window (if any).

        Args:
            session_id: Session to read.
            max_messages: Window size. Zero or negative yields an empty list.
        """
        # `history[-0:]` would be the whole list and a negative size would
        # slice from the wrong end, so handle non-positive sizes up front.
        if max_messages <= 0:
            return []
        history = self.get_messages_as_conversation(session_id)
        window = history[-max_messages:]
        for index, message in enumerate(window):
            if message.get("role") == "user":
                return window[index:]
        return window

    def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
        """Legacy name kept for compatibility; returns the visible history slice."""
        return self.get_visible_history(session_id, max_messages=max_messages)

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Delegate to the store's ``update_system_prompt``."""
        self.store.update_system_prompt(session_id, system_prompt)

    def update_usage(self, session_id: str, **kwargs: Any) -> None:
        """Delegate to the store's ``update_usage``."""
        self.store.update_usage(session_id, **kwargs)

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Delegate to the store's ``end_session``."""
        self.store.end_session(session_id, end_reason)

    def reopen_session(self, session_id: str) -> None:
        """Delegate to the store's ``reopen_session``."""
        self.store.reopen_session(session_id)

    def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Delegate to the search service's ``list_sessions_rich``."""
        return self.search.list_sessions_rich(**kwargs)

    def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Delegate to the search service's ``search_messages``."""
        return self.search.search_messages(**kwargs)

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Delegate to the search service's ``resolve_session_id``."""
        return self.search.resolve_session_id(session_id_or_prefix)

View File

@ -0,0 +1,211 @@
"""Beaver session 子系统的数据模型。
这层只定义数据结构,不放数据库读写逻辑。目的是把:
1. SQLite 行结构
2. 运行时会话对象
3. 对外暴露的 conversation message
三件事分开,避免后续所有地方都直接和裸字典耦合。
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class SessionUsage:
"""会话维度的 usage/cost 统计。"""
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
estimated_cost_usd: float = 0.0
actual_cost_usd: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cache_read_tokens": self.cache_read_tokens,
"cache_write_tokens": self.cache_write_tokens,
"reasoning_tokens": self.reasoning_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"actual_cost_usd": self.actual_cost_usd,
}
@dataclass(slots=True)
class MessageRecord:
"""单条会话事件的结构化表示。
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
1. 普通 user/assistant/tool 消息本身就是事件
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
"""
role: str
content: str | None = None
timestamp: float | None = None
message_id: int | None = None
run_id: str | None = None
event_type: str | None = None
event_payload: dict[str, Any] | None = None
context_visible: bool = True
tool_name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
finish_reason: str | None = None
reasoning: str | None = None
reasoning_details: Any | None = None
codex_reasoning_items: Any | None = None
def to_conversation_message(self) -> dict[str, Any]:
"""转成 provider / context builder 可直接消费的消息格式。"""
if not self.context_visible:
raise ValueError("Hidden session events cannot be converted into conversation messages")
payload: dict[str, Any] = {
"role": self.role,
"content": self.content,
}
if self.tool_name:
payload["tool_name"] = self.tool_name
if self.tool_calls:
payload["tool_calls"] = self.tool_calls
if self.tool_call_id:
payload["tool_call_id"] = self.tool_call_id
if self.finish_reason:
payload["finish_reason"] = self.finish_reason
if self.reasoning:
payload["reasoning"] = self.reasoning
if self.reasoning_details is not None:
payload["reasoning_details"] = self.reasoning_details
if self.codex_reasoning_items is not None:
payload["codex_reasoning_items"] = self.codex_reasoning_items
return payload
@classmethod
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
"""从 SQLite row/dict 恢复消息模型。"""
tool_calls = row.get("tool_calls")
if isinstance(tool_calls, str):
try:
tool_calls = json.loads(tool_calls)
except json.JSONDecodeError:
tool_calls = []
reasoning_details = row.get("reasoning_details")
if isinstance(reasoning_details, str):
try:
reasoning_details = json.loads(reasoning_details)
except json.JSONDecodeError:
reasoning_details = None
codex_reasoning_items = row.get("codex_reasoning_items")
if isinstance(codex_reasoning_items, str):
try:
codex_reasoning_items = json.loads(codex_reasoning_items)
except json.JSONDecodeError:
codex_reasoning_items = None
event_payload = row.get("event_payload")
if isinstance(event_payload, str):
try:
event_payload = json.loads(event_payload)
except json.JSONDecodeError:
event_payload = None
return cls(
message_id=row.get("id"),
run_id=row.get("run_id"),
role=row["role"],
content=row.get("content"),
event_type=row.get("event_type") or row.get("role"),
event_payload=event_payload,
context_visible=bool(row.get("context_visible", 1)),
tool_name=row.get("tool_name"),
tool_calls=tool_calls,
tool_call_id=row.get("tool_call_id"),
timestamp=row.get("timestamp"),
finish_reason=row.get("finish_reason"),
reasoning=row.get("reasoning"),
reasoning_details=reasoning_details,
codex_reasoning_items=codex_reasoning_items,
)
@dataclass(slots=True)
class SessionRecord:
    """Structured representation of a single session."""

    session_id: str
    source: str
    started_at: float
    last_active: float
    user_id: str | None = None
    title: str | None = None
    model: str | None = None
    system_prompt: str | None = None
    parent_session_id: str | None = None
    ended_at: float | None = None
    end_reason: str | None = None
    message_count: int = 0
    tool_call_count: int = 0
    preview: str | None = None
    usage: SessionUsage = field(default_factory=SessionUsage)

    def to_dict(self) -> dict[str, Any]:
        """Return a flat dict: session fields (ID under ``"id"``) plus usage counters."""
        session_fields: dict[str, Any] = dict(
            id=self.session_id,
            source=self.source,
            user_id=self.user_id,
            title=self.title,
            model=self.model,
            system_prompt=self.system_prompt,
            parent_session_id=self.parent_session_id,
            started_at=self.started_at,
            last_active=self.last_active,
            ended_at=self.ended_at,
            end_reason=self.end_reason,
            message_count=self.message_count,
            tool_call_count=self.tool_call_count,
            preview=self.preview,
        )
        # Usage keys are merged last, so they win on any name clash.
        return {**session_fields, **self.usage.to_dict()}

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
        """Build a record from a row dict.

        ``id``, ``source``, ``started_at`` and ``last_active`` are required;
        every other column is optional.
        """
        usage = SessionUsage(
            input_tokens=row.get("input_tokens", 0),
            output_tokens=row.get("output_tokens", 0),
            cache_read_tokens=row.get("cache_read_tokens", 0),
            cache_write_tokens=row.get("cache_write_tokens", 0),
            reasoning_tokens=row.get("reasoning_tokens", 0),
            estimated_cost_usd=row.get("estimated_cost_usd", 0.0) or 0.0,
            actual_cost_usd=row.get("actual_cost_usd"),
        )
        return cls(
            session_id=row["id"],
            source=row["source"],
            user_id=row.get("user_id"),
            title=row.get("title"),
            model=row.get("model"),
            system_prompt=row.get("system_prompt"),
            parent_session_id=row.get("parent_session_id"),
            started_at=row["started_at"],
            last_active=row["last_active"],
            ended_at=row.get("ended_at"),
            end_reason=row.get("end_reason"),
            message_count=row.get("message_count", 0),
            tool_call_count=row.get("tool_call_count", 0),
            preview=row.get("preview"),
            usage=usage,
        )

View File

@ -0,0 +1,151 @@
"""Beaver session 子系统的检索能力。"""
from __future__ import annotations
import re
import sqlite3
from typing import Any
from .store import SessionStore
class SessionSearchService:
    """Browsing, FTS and id-resolution helpers layered on a ``SessionStore``."""

    def __init__(self, store: SessionStore) -> None:
        self.store = store

    @staticmethod
    def _sanitize_fts5_query(query: str) -> str:
        """Make free-form user input safe to pass to an FTS5 ``MATCH``.

        Balanced double-quoted phrases are kept verbatim. Outside of them,
        FTS5-special characters are blanked, runs of ``*`` are collapsed and
        leading ``*`` removed, dangling ``AND``/``OR``/``NOT`` at either end
        are dropped, and dotted/hyphenated terms are wrapped in quotes.
        May return an empty string.
        """
        quoted_parts: list[str] = []

        def preserve(match: re.Match[str]) -> str:
            # Swap each quoted phrase for a NUL-delimited placeholder so the
            # substitutions below cannot touch it.
            quoted_parts.append(match.group(0))
            return f"\x00Q{len(quoted_parts) - 1}\x00"

        sanitized = re.sub(r'"[^"]*"', preserve, query)
        # Blank out remaining special characters, including unmatched quotes.
        sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
        # Collapse "***" to "*", then drop any "*" that has no term before it.
        sanitized = re.sub(r"\*+", "*", sanitized)
        sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
        # Remove a boolean operator left dangling at the start or the end.
        sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
        sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
        # Quote terms such as "foo-bar" or "a.b.c" so they match as one phrase.
        sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)
        for index, quoted in enumerate(quoted_parts):
            sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
        return sanitized.strip()

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full session id or a unique id prefix.

        Returns the matching session id, or ``None`` when nothing matches,
        when the prefix is ambiguous, or when the input is empty.
        """
        # An empty string would become `LIKE '%'` below and "uniquely" match
        # the only session of a one-session database; it is never a valid id.
        if not session_id_or_prefix:
            return None
        exact = self.store.get_session_record(session_id_or_prefix)
        if exact is not None:
            return exact.session_id
        # Escape LIKE wildcards so the input is matched literally.
        escaped = (
            session_id_or_prefix
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        # LIMIT 2 is enough to tell "unique" from "ambiguous".
        rows = self.store._fetchall(
            """
            SELECT id
            FROM sessions
            WHERE id LIKE ? ESCAPE '\\'
            ORDER BY started_at DESC
            LIMIT 2
            """,
            (f"{escaped}%",),
        )
        if len(rows) == 1:
            return rows[0]["id"]
        return None

    def list_sessions_rich(
        self,
        *,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        source: str | None = None,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """List session rows ordered by ``last_active`` descending.

        Sessions with a parent are hidden unless ``include_children`` is
        true. ``source`` keeps a single source; ``exclude_sources`` drops the
        listed ones. Rows are returned as plain dicts of all session columns.
        """
        clauses: list[str] = []
        params: list[Any] = []
        if not include_children:
            clauses.append("parent_session_id IS NULL")
        if source:
            clauses.append("source = ?")
            params.append(source)
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
        params.extend([limit, offset])
        # Only fixed clause text and "?" placeholders are interpolated.
        rows = self.store._fetchall(
            f"""
            SELECT *
            FROM sessions
            {where}
            ORDER BY last_active DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params),
        )
        return rows

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Full-text search over context-visible messages using FTS5.

        Returns an empty list when the sanitized query is empty. Each hit
        carries the message id, session id, role, session source/model/start
        time and a highlighted snippet, ordered by FTS rank.

        Raises:
            RuntimeError: if SQLite rejects the query (chained from the
                original ``sqlite3.Error``).
        """
        query = self._sanitize_fts5_query(query)
        if not query:
            return []
        clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
        params: list[Any] = [query]
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"s.source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        if role_filter:
            placeholders = ",".join("?" for _ in role_filter)
            clauses.append(f"m.role IN ({placeholders})")
            params.extend(role_filter)
        params.extend([limit, offset])
        sql = f"""
            SELECT
                m.id,
                m.session_id,
                m.role,
                s.source,
                s.model,
                s.started_at AS session_started,
                snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
            FROM messages_fts
            JOIN messages m ON m.id = messages_fts.rowid
            JOIN sessions s ON s.id = m.session_id
            WHERE {' AND '.join(clauses)}
            ORDER BY rank
            LIMIT ? OFFSET ?
        """
        try:
            return self.store._fetchall(sql, tuple(params))
        except sqlite3.Error as exc:
            raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc

View File

@ -0,0 +1,467 @@
"""Beaver session 子系统的 SQLite 存储实现。
设计来源主要参考 Hermes-agent
1. SQLite 作为统一 session/transcript backend
2. WAL 模式支持多读单写
3. FTS5 支持跨 session 文本检索
4. `parent_session_id` 支持 lineage
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable, TypeVar
from .models import MessageRecord, SessionRecord
# Return type of the callback handed to `SessionStore._execute_write`.
T = TypeVar("T")
# Core relational schema: one row per session, one row per message/event,
# plus the indexes used for ordering by time, lineage lookups and per-run reads.
# Every statement is idempotent (IF NOT EXISTS), so the script can run on each start-up.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
user_id TEXT,
title TEXT,
model TEXT,
system_prompt TEXT,
parent_session_id TEXT,
started_at REAL NOT NULL,
last_active REAL NOT NULL,
ended_at REAL,
end_reason TEXT,
message_count INTEGER DEFAULT 0,
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
estimated_cost_usd REAL DEFAULT 0,
actual_cost_usd REAL,
preview TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id),
run_id TEXT,
role TEXT NOT NULL,
event_type TEXT,
event_payload TEXT,
context_visible INTEGER NOT NULL DEFAULT 1,
content TEXT,
tool_name TEXT,
tool_calls TEXT,
tool_call_id TEXT,
timestamp REAL NOT NULL,
finish_reason TEXT,
reasoning TEXT,
reasoning_details TEXT,
codex_reasoning_items TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
"""
# External-content FTS5 index over `messages.content`, keyed by `messages.id`.
FTS_TABLE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content,
content=messages,
content_rowid=id
);
"""
# Triggers that keep `messages_fts` in sync with `messages`. They are dropped and
# recreated each time the script runs, and only rows with context_visible = 1 and
# non-NULL content are added to or removed from the index.
FTS_TRIGGER_SQL = """
DROP TRIGGER IF EXISTS messages_fts_insert;
DROP TRIGGER IF EXISTS messages_fts_delete;
DROP TRIGGER IF EXISTS messages_fts_update;
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, content)
SELECT 'delete', old.id, old.content
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
INSERT INTO messages_fts(rowid, content)
SELECT new.id, new.content
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
END;
"""
class SessionStore:
    """SQLite-backed session store.

    One connection, opened with ``check_same_thread=False``, is shared by all
    callers and serialised through ``self._lock``. The connection is in
    autocommit mode (``isolation_level=None``); writes open an explicit
    transaction through ``_execute_write``.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open the database at *db_path* (creating parent directories) and ensure the schema.

        WAL journaling and foreign-key enforcement are enabled on the connection.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
        self._conn.row_factory = sqlite3.Row
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._init_schema()

    def _init_schema(self) -> None:
        """Create tables and indexes, the FTS table if missing, and (re)install the FTS triggers."""
        with self._lock:
            self._conn.executescript(SCHEMA_SQL)
            # Probe for the FTS table; create it only when the probe fails.
            try:
                self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
            except sqlite3.OperationalError:
                self._conn.executescript(FTS_TABLE_SQL)
            self._conn.executescript(FTS_TRIGGER_SQL)
            # Older versions may have written hidden events into the FTS index;
            # purge those noise entries during initialisation.
            # NOTE(review): this runs on every start-up, including for hidden rows
            # that the current triggers never indexed. SQLite's FTS5 documentation
            # requires 'delete' values to match what is in the index - confirm that
            # repeating it is safe, or gate it behind a one-time migration marker.
            self._conn.execute(
                """
                INSERT INTO messages_fts(messages_fts, rowid, content)
                SELECT 'delete', id, content
                FROM messages
                WHERE context_visible = 0 AND content IS NOT NULL
                """
            )
            self._conn.commit()

    def close(self) -> None:
        """Close the underlying connection."""
        with self._lock:
            self._conn.close()

    def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
        """Run *fn* inside a ``BEGIN IMMEDIATE`` transaction and return its result.

        The transaction is committed when *fn* returns and rolled back (then
        re-raised) on any exception, including ``BaseException`` subclasses.
        """
        with self._lock:
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                result = fn(self._conn)
                self._conn.commit()
                return result
            except BaseException:
                self._conn.rollback()
                raise

    def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
        """Execute *sql* and return the first row as a dict, or ``None`` if there is none."""
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return dict(row) if row else None

    def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
        """Execute *sql* and return every row as a dict."""
        with self._lock:
            rows = self._conn.execute(sql, params).fetchall()
        return [dict(row) for row in rows]

    def ensure_session(
        self,
        session_id: str,
        *,
        source: str = "unknown",
        model: str | None = None,
        title: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> str:
        """Make sure a session row exists and return *session_id*.

        A missing row is inserted with ``started_at`` and ``last_active`` set
        to now. For an existing row only missing metadata is filled in:
        ``source`` is replaced only while it is still ``'unknown'``, and the
        other columns only while they are NULL.
        """
        now = time.time()

        def _do(conn: sqlite3.Connection) -> str:
            conn.execute(
                """
                INSERT INTO sessions (
                    id, source, user_id, title, model, parent_session_id, started_at, last_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    source = CASE
                        WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
                        ELSE sessions.source
                    END,
                    user_id = COALESCE(sessions.user_id, excluded.user_id),
                    title = COALESCE(sessions.title, excluded.title),
                    model = COALESCE(sessions.model, excluded.model),
                    parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
                """,
                (session_id, source, user_id, title, model, parent_session_id, now, now),
            )
            return session_id

        return self._execute_write(_do)

    def get_session_record(self, session_id: str) -> SessionRecord | None:
        """Return the session with exactly this id, or ``None``."""
        row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
        return SessionRecord.from_row(row) if row else None

    def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
        """Store the assembled system prompt snapshot for the session and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET system_prompt = ?, last_active = ?
                WHERE id = ?
                """,
                (system_prompt, time.time(), session_id),
            )

        self._execute_write(_do)

    def update_usage(
        self,
        session_id: str,
        *,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        reasoning_tokens: int = 0,
        estimated_cost_usd: float = 0.0,
        actual_cost_usd: float | None = None,
        absolute: bool = False,
    ) -> None:
        """Update the session's usage counters and bump ``last_active``.

        By default the given values are added to the stored ones; in that mode
        ``actual_cost_usd=None`` leaves the stored actual cost untouched. With
        ``absolute=True`` every counter, including ``actual_cost_usd``, is
        overwritten with the given value.
        """
        if absolute:
            sql = """
                UPDATE sessions
                SET input_tokens = ?,
                    output_tokens = ?,
                    cache_read_tokens = ?,
                    cache_write_tokens = ?,
                    reasoning_tokens = ?,
                    estimated_cost_usd = ?,
                    actual_cost_usd = ?,
                    last_active = ?
                WHERE id = ?
            """
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )
        else:
            sql = """
                UPDATE sessions
                SET input_tokens = input_tokens + ?,
                    output_tokens = output_tokens + ?,
                    cache_read_tokens = cache_read_tokens + ?,
                    cache_write_tokens = cache_write_tokens + ?,
                    reasoning_tokens = reasoning_tokens + ?,
                    estimated_cost_usd = estimated_cost_usd + ?,
                    actual_cost_usd = CASE
                        WHEN ? IS NULL THEN actual_cost_usd
                        ELSE COALESCE(actual_cost_usd, 0) + ?
                    END,
                    last_active = ?
                WHERE id = ?
            """
            # actual_cost_usd is bound twice: once for the NULL test, once for the sum.
            params = (
                input_tokens,
                output_tokens,
                cache_read_tokens,
                cache_write_tokens,
                reasoning_tokens,
                estimated_cost_usd,
                actual_cost_usd,
                actual_cost_usd,
                time.time(),
                session_id,
            )

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(sql, params)

        self._execute_write(_do)

    def append_message(
        self,
        session_id: str,
        *,
        run_id: str | None = None,
        role: str,
        event_type: str | None = None,
        event_payload: dict[str, Any] | None = None,
        context_visible: bool = True,
        content: str | None = None,
        tool_name: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        finish_reason: str | None = None,
        reasoning: str | None = None,
        reasoning_details: Any | None = None,
        codex_reasoning_items: Any | None = None,
        source: str = "unknown",
        title: str | None = None,
        model: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> int:
        """Append one message/event to a session and return the new message id.

        The session row is created or completed first via ``ensure_session``.
        Structured arguments are stored as JSON text, ``event_type`` defaults
        to ``role``, and the session's counters, ``last_active``, ``model``
        (if still NULL) and ``preview`` are updated in the same transaction
        as the insert. ``preview`` is set once, from the first 120 characters
        of the first non-empty user message.
        """
        self.ensure_session(
            session_id,
            source=source,
            model=model,
            title=title,
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        now = time.time()
        tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
        event_payload_json = json.dumps(event_payload) if event_payload is not None else None
        reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
        codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
        preview = (content or "")[:120] if role == "user" and content else None
        # A non-list truthy value counts as a single tool call.
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)

        def _do(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                """
                INSERT INTO messages (
                    session_id, run_id, role, event_type, event_payload, context_visible, content,
                    tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
                    reasoning_details, codex_reasoning_items
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    session_id,
                    run_id,
                    role,
                    event_type or role,
                    event_payload_json,
                    1 if context_visible else 0,
                    content,
                    tool_name,
                    tool_calls_json,
                    tool_call_id,
                    now,
                    finish_reason,
                    reasoning,
                    reasoning_details_json,
                    codex_items_json,
                ),
            )
            conn.execute(
                """
                UPDATE sessions
                SET last_active = ?,
                    message_count = message_count + 1,
                    tool_call_count = tool_call_count + ?,
                    model = COALESCE(model, ?),
                    preview = CASE
                        WHEN preview IS NULL AND ? IS NOT NULL THEN ?
                        ELSE preview
                    END
                WHERE id = ?
                """,
                (now, tool_call_count, model, preview, preview, session_id),
            )
            return int(cursor.lastrowid)

        return self._execute_write(_do)

    def get_message_records(self, session_id: str) -> list[MessageRecord]:
        """Return every message of the session ordered by ``(timestamp, id)``."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ?
            ORDER BY timestamp, id
            """,
            (session_id,),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_event_records(self, session_id: str) -> list[MessageRecord]:
        """Return the complete event stream of the session.

        Events currently live in the ``messages`` table, so this is the same
        as ``get_message_records``. The separate method lets callers keep
        using this interface if run/checkpoint/system events later move to
        their own tables.
        """
        return self.get_message_records(session_id)

    def list_run_ids(self, session_id: str) -> list[str]:
        """List the distinct non-empty ``run_id`` values of the session, earliest first."""
        rows = self._fetchall(
            """
            SELECT run_id
            FROM messages
            WHERE session_id = ? AND run_id IS NOT NULL
            GROUP BY run_id
            ORDER BY MIN(timestamp), MIN(id)
            """,
            (session_id,),
        )
        return [str(row["run_id"]) for row in rows if row.get("run_id")]

    def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
        """Return the events of one run within the session, ordered by ``(timestamp, id)``."""
        rows = self._fetchall(
            """
            SELECT *
            FROM messages
            WHERE session_id = ? AND run_id = ?
            ORDER BY timestamp, id
            """,
            (session_id, run_id),
        )
        return [MessageRecord.from_row(row) for row in rows]

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the context-visible records converted with ``to_conversation_message()``."""
        messages: list[dict[str, Any]] = []
        for record in self.get_event_records(session_id):
            if not record.context_visible:
                continue
            messages.append(record.to_conversation_message())
        return messages

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark the session as ended, recording *end_reason*."""
        # One clock read, so ended_at and last_active hold the same instant.
        now = time.time()

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = ?, end_reason = ?, last_active = ?
                WHERE id = ?
                """,
                (now, end_reason, now, session_id),
            )

        self._execute_write(_do)

    def reopen_session(self, session_id: str) -> None:
        """Clear ``ended_at``/``end_reason`` and bump ``last_active``."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                UPDATE sessions
                SET ended_at = NULL, end_reason = NULL, last_active = ?
                WHERE id = ?
                """,
                (time.time(), session_id),
            )

        self._execute_write(_do)

View File

@ -0,0 +1,2 @@
"""Foundation layer for shared Beaver primitives."""

View File

@ -0,0 +1,2 @@
"""Configuration models and loaders."""

View File

@ -0,0 +1,2 @@
"""Shared error types."""

View File

@ -0,0 +1,5 @@
"""Event contracts and dispatch helpers."""
from .message_bus import InboundMessage, MessageBus, OutboundMessage
__all__ = ["InboundMessage", "MessageBus", "OutboundMessage"]

View File

@ -0,0 +1,72 @@
"""Minimal message bus for gateway-style host integration."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from uuid import uuid4
@dataclass(slots=True)
class InboundMessage:
    """A minimal inbound message accepted by the gateway bridge.

    Only ``channel`` and ``content`` are required; every other field has a default.
    """

    channel: str
    content: str
    session_id: str | None = None
    user_id: str | None = None
    title: str | None = None
    execution_context: str | None = None
    model: str | None = None
    provider_name: str | None = None
    embedding_model: str | None = None
    # Defaults to a freshly generated UUID4 string.
    message_id: str = field(default_factory=lambda: str(uuid4()))
    # Each instance gets its own empty dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Defaults to the timezone-aware UTC time at construction.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass(slots=True)
class OutboundMessage:
    """A minimal outbound message produced by the gateway bridge.

    ``channel``, ``content``, ``session_id`` and ``finish_reason`` must be
    supplied by the caller (``session_id`` may be ``None``).
    """

    channel: str
    content: str
    session_id: str | None
    finish_reason: str
    # Defaults to a freshly generated UUID4 string.
    message_id: str = field(default_factory=lambda: str(uuid4()))
    run_id: str | None = None
    provider_name: str | None = None
    model: str | None = None
    # Each instance gets its own empty dict by default.
    usage: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
    # Defaults to the timezone-aware UTC time at construction.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
class MessageBus:
    """In-process async bus made of one inbound and one outbound ``asyncio.Queue``.

    Both queues are created without a size limit.
    """

    def __init__(self) -> None:
        inbound_queue: asyncio.Queue[InboundMessage] = asyncio.Queue()
        outbound_queue: asyncio.Queue[OutboundMessage] = asyncio.Queue()
        self.inbound = inbound_queue
        self.outbound = outbound_queue

    async def publish_inbound(self, message: InboundMessage) -> None:
        """Enqueue *message* on the inbound queue."""
        queue = self.inbound
        await queue.put(message)

    async def consume_inbound(self) -> InboundMessage:
        """Wait for and return the next inbound message."""
        next_message = await self.inbound.get()
        return next_message

    async def publish_outbound(self, message: OutboundMessage) -> None:
        """Enqueue *message* on the outbound queue."""
        queue = self.outbound
        await queue.put(message)

    async def consume_outbound(self) -> OutboundMessage:
        """Wait for and return the next outbound message."""
        next_message = await self.outbound.get()
        return next_message

    @property
    def inbound_size(self) -> int:
        """Number of inbound messages currently queued."""
        pending = self.inbound.qsize()
        return pending

    @property
    def outbound_size(self) -> int:
        """Number of outbound messages currently queued."""
        pending = self.outbound.qsize()
        return pending

View File

@ -0,0 +1,2 @@
"""Shared data models."""

View File

@ -0,0 +1,2 @@
"""Common utility helpers."""

View File

@ -0,0 +1,2 @@
"""External integrations."""

View File

@ -0,0 +1,2 @@
"""A2A integration."""

View File

@ -0,0 +1,2 @@
"""MCP integration."""

View File

@ -0,0 +1,2 @@
"""Outlook integration."""

View File

@ -0,0 +1,2 @@
"""Provider-specific integrations."""

View File

@ -0,0 +1,2 @@
"""WhatsApp integration."""

View File

@ -0,0 +1,2 @@
"""Thin interface layer for Beaver."""

View File

@ -0,0 +1,2 @@
"""Channel interfaces."""

View File

@ -0,0 +1,2 @@
"""CLI interface."""

View File

@ -0,0 +1,59 @@
"""CLI entry for Beaver."""
try:
    import typer
except ModuleNotFoundError:  # pragma: no cover - fallback for skeleton-only environments

    class _FallbackTyper:
        """Stand-in for the ``typer`` module/app when typer is not installed.

        It lets this module import and keeps ``@app.command()``,
        ``typer.Option(...)`` and ``typer.echo(...)`` usable; actually running
        the app raises ``RuntimeError``.
        """

        def __init__(self, *_args, **_kwargs) -> None:
            return None

        def command(self):
            # Decorator factory that hands the function back unchanged.
            return lambda func: func

        def __call__(self) -> None:
            raise RuntimeError("typer is not installed")

        @staticmethod
        def echo(message: str) -> None:
            print(message)

        @staticmethod
        def Option(default=None, *_args, **_kwargs):
            # Only the default value matters without typer's CLI parsing.
            return default

    typer = _FallbackTyper()  # type: ignore[assignment]
from beaver.services.agent_service import AgentService
# Use a real Typer application when typer is importable; otherwise the fallback
# shim instance bound to `typer` doubles as the app object.
app = typer.Typer(help="Beaver backend CLI") if hasattr(typer, "Typer") else typer
@app.command()
def run(
    message: str | None = typer.Option(None, "--message", "-m", help="Run one direct Beaver request."),
    workspace: str | None = typer.Option(None, "--workspace", help="Workspace root for this run."),
) -> None:
    """Thin CLI wrapper around AgentService.

    The CLI holds no execution logic of its own. It only:
    1. parses the command-line arguments,
    2. calls ``AgentService``,
    3. prints the result.

    Without ``--message`` it just creates the loop and reports that the
    engine booted.
    """
    agent_service = AgentService(workspace=workspace)
    if message:
        outcome = agent_service.run_direct(message, source="cli")
        typer.echo(outcome.output_text)
    else:
        agent_service.create_loop()
        typer.echo("Beaver engine booted.")
def main() -> None:
    """Project script entrypoint: invoke the module-level ``app``."""
    app()

View File

@ -0,0 +1,2 @@
"""Gateway interface."""

View File

@ -0,0 +1,189 @@
"""Gateway entrypoint for Beaver.
当前阶段先不扩 bus / channels adapter只做最小消息桥接
1. 启动时托管 `AgentService.start()`
2. 常驻消费 `MessageBus.inbound`
3. 调 `service.submit_direct(...)`
4. 将结果写回 `MessageBus.outbound`
5. 退出时走 `AgentService.shutdown()`
"""
from __future__ import annotations
import asyncio
from pathlib import Path
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
from beaver.services.agent_service import AgentService
async def _publish_bridge_error(
    bus: MessageBus,
    inbound: InboundMessage,
    *,
    detail: str,
    finish_reason: str = "error",
) -> None:
    """Publish a structured outbound error message for a failed inbound one.

    The outbound message reuses the inbound ``message_id``, ``channel`` and
    ``session_id``; *detail* is used both as the content and as
    ``metadata["error"]``, next to a copy of the inbound metadata.
    """
    error_metadata = {"error": detail, "inbound_metadata": dict(inbound.metadata)}
    failure = OutboundMessage(
        message_id=inbound.message_id,
        channel=inbound.channel,
        session_id=inbound.session_id,
        content=detail,
        finish_reason=finish_reason,
        metadata=error_metadata,
    )
    await bus.publish_outbound(failure)
async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
    """Drain unprocessed inbound messages into explicit outbound errors.

    Each leftover message is answered with ``finish_reason="stopped"`` and
    *reason* as its detail, rather than being dropped silently.
    """
    inbound_queue = bus.inbound
    while True:
        try:
            leftover = inbound_queue.get_nowait()
        except asyncio.QueueEmpty:
            return
        await _publish_bridge_error(bus, leftover, detail=reason, finish_reason="stopped")
async def _await_bridge_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
"""等待 bridge 退出;超时则取消,避免 shutdown 被桥接层反向卡死。"""
try:
await asyncio.wait_for(task, timeout=timeout_seconds)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _bridge_inbound_to_runtime(
    service: AgentService,
    bus: MessageBus,
    stop_event: asyncio.Event,
) -> None:
    """Pump inbound messages through the agent service until *stop_event* is set.

    Each inbound message is submitted with ``service.submit_direct`` and the
    result is published as an outbound message. Failures become outbound
    error messages; a cancellation is reported as ``"cancelled"`` and then
    re-raised. Once the stop event is seen, messages still queued are flushed
    as ``"stopped"`` errors.
    """
    while not stop_event.is_set():
        # Poll with a short timeout so the stop event is re-checked regularly.
        try:
            inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=0.25)
        except asyncio.TimeoutError:
            continue
        try:
            result = await service.submit_direct(
                inbound.content,
                session_id=inbound.session_id,
                source=f"gateway:{inbound.channel}",
                user_id=inbound.user_id,
                title=inbound.title,
                execution_context=inbound.execution_context,
                model=inbound.model,
                provider_name=inbound.provider_name,
                embedding_model=inbound.embedding_model,
            )
        except asyncio.CancelledError:
            await _publish_bridge_error(
                bus,
                inbound,
                detail="Gateway stopped before completing the inbound message",
                finish_reason="cancelled",
            )
            raise
        except Exception as exc:  # pragma: no cover - defensive bridge path
            await _publish_bridge_error(bus, inbound, detail=str(exc))
            continue
        reply = OutboundMessage(
            message_id=inbound.message_id,
            channel=inbound.channel,
            session_id=result.session_id,
            run_id=result.run_id,
            content=result.output_text,
            finish_reason=result.finish_reason,
            provider_name=result.provider_name,
            model=result.model,
            usage=dict(result.usage),
            metadata={"inbound_metadata": dict(inbound.metadata)},
        )
        await bus.publish_outbound(reply)
    await _flush_pending_inbound(
        bus,
        reason="Gateway stopped before processing the inbound message",
    )
async def run_gateway(
    *,
    workspace: str | Path | None = None,
    service: AgentService | None = None,
    bus: MessageBus | None = None,
    manage_service_lifecycle: bool | None = None,
    stop_event: asyncio.Event | None = None,
    shutdown_timeout_seconds: float | None = 5.0,
    shutdown_force: bool = True,
) -> None:
    """Run the minimal gateway host and its message bridge until stopped.

    Default ownership semantics:
    - no ``service`` passed: the gateway creates one and manages its lifecycle;
    - external ``service`` passed: it is only used, not started or shut down.

    ``manage_service_lifecycle`` overrides that default when not ``None``.

    Raises:
        RuntimeError: if the service is not in running mode once start-up is done.
    """
    attached_service = service or AgentService(workspace=workspace)
    attached_bus = bus or MessageBus()
    if manage_service_lifecycle is None:
        owns_service = service is None
    else:
        owns_service = manage_service_lifecycle
    owned_stop_event = stop_event or asyncio.Event()

    started = False
    if owns_service:
        try:
            await attached_service.start()
        except Exception:
            # Release whatever the service acquired before failing to start.
            attached_service.close()
            raise
        started = True

    if not attached_service.is_running:
        raise RuntimeError(
            "Gateway requires AgentService running mode; start the injected service first "
            "or allow the gateway to manage its lifecycle."
        )

    bridge_task = asyncio.create_task(
        _bridge_inbound_to_runtime(attached_service, attached_bus, owned_stop_event)
    )
    try:
        await owned_stop_event.wait()
    finally:
        # Set the event even when the wait was interrupted, so the bridge stops.
        owned_stop_event.set()
        try:
            if owns_service and started:
                await attached_service.shutdown(
                    timeout_seconds=shutdown_timeout_seconds,
                    force=shutdown_force,
                )
        finally:
            await _await_bridge_shutdown(bridge_task)
def main() -> None:
    """Synchronous gateway entry point.

    Runs ``run_gateway()`` with its default arguments under ``asyncio.run``
    and exits quietly on ``KeyboardInterrupt``.
    """
    try:
        asyncio.run(run_gateway())
    except KeyboardInterrupt:
        pass

View File

@ -0,0 +1,2 @@
"""MCP server entrypoints."""

View File

@ -0,0 +1,210 @@
"""Beaver memory MCP server.
这个 server 用最精简的方式把两个内部能力暴露成 streamable-http MCP tools
1. `memory`
2. `session_search`
运行方式:
1. 直接用 Python
`python -m beaver.interfaces.mcp.memory_server --host 127.0.0.1 --port 8001`
2. 或者用 FastMCP CLI
`fastmcp run beaver/interfaces/mcp/memory_server.py:mcp --transport http --port 8001`
默认 MCP 路径是 `/mcp`FastMCP 的 HTTP transport 就是 streamable HTTP。
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
from beaver.engine.session import SessionManager
from beaver.memory.curated.store import MemoryStore
from beaver.tools.builtins.memory import memory_tool
from beaver.tools.builtins.session_search import session_search as run_session_search
try: # pragma: no cover - import guard for environments without fastmcp
from fastmcp import Context, FastMCP
from fastmcp.server.lifespan import lifespan
except ModuleNotFoundError: # pragma: no cover - handled at runtime in main()
FastMCP = None # type: ignore[assignment]
Context = Any # type: ignore[assignment]
lifespan = None # type: ignore[assignment]
def _require_fastmcp() -> None:
    """Raise ``RuntimeError`` unless the optional fastmcp imports succeeded."""
    if FastMCP is not None and lifespan is not None:
        return
    raise RuntimeError(
        "fastmcp is not installed. Install it with `pip install fastmcp` "
        "or via this project's dependencies."
    )
def _resolve_workspace_path(workspace: str | Path | None = None) -> Path:
"""决定 memory server 使用的 workspace 根目录。"""
if workspace is not None:
return Path(workspace).expanduser().resolve()
env_workspace = os.getenv("BEAVER_WORKSPACE")
if env_workspace:
return Path(env_workspace).expanduser().resolve()
return Path.cwd()
def _resolve_memory_dir(workspace: Path) -> Path:
"""curated memory 的默认目录。"""
return workspace / "memory" / "curated"
def _resolve_session_db_path(workspace: Path) -> Path:
"""session store 的默认路径。"""
return workspace / "sessions" / "state.db"
def create_memory_server(
    *,
    workspace: str | Path | None = None,
    memory_dir: str | Path | None = None,
    session_db_path: str | Path | None = None,
):
    """Create and return the FastMCP memory server instance.

    Paths not given explicitly are derived from the resolved workspace. The
    server registers a ``/health`` route and the ``memory`` and
    ``session_search`` tools; the shared stores are created in the server
    lifespan.

    Raises:
        RuntimeError: if fastmcp is not installed.
    """
    _require_fastmcp()
    workspace_path = _resolve_workspace_path(workspace)

    if memory_dir:
        resolved_memory_dir = Path(memory_dir).expanduser().resolve()
    else:
        resolved_memory_dir = _resolve_memory_dir(workspace_path)

    if session_db_path:
        resolved_session_db = Path(session_db_path).expanduser().resolve()
    else:
        resolved_session_db = _resolve_session_db_path(workspace_path)

    @lifespan
    async def memory_server_lifespan(_server):
        """Initialise the shared store and session DB for the server's lifetime."""
        store = MemoryStore(resolved_memory_dir)
        store.load_from_disk()
        session_manager = SessionManager(workspace=workspace_path, db_path=resolved_session_db)
        shared_state = {
            "workspace_path": workspace_path,
            "memory_dir": resolved_memory_dir,
            "session_db_path": resolved_session_db,
            "memory_store": store,
            "session_manager": session_manager,
        }
        try:
            yield shared_state
        finally:
            session_manager.close()

    server_instructions = (
        "Provides two MCP tools: `memory` for durable curated memory CRUD, "
        "and `session_search` for cross-session recall from transcript storage."
    )
    server = FastMCP(
        name="Beaver Memory Server",
        instructions=server_instructions,
        lifespan=memory_server_lifespan,
    )

    @server.custom_route("/health", methods=["GET"])
    async def health_check(_request):
        """Minimal health check, convenient for remote liveness probing."""
        from starlette.responses import JSONResponse

        health_payload = {
            "ok": True,
            "server": "beaver-memory",
            "transport": "streamable-http",
            "workspace": str(workspace_path),
            "memory_dir": str(resolved_memory_dir),
            "session_db_path": str(resolved_session_db),
        }
        return JSONResponse(health_payload)

    @server.tool()
    async def memory(
        action: str,
        target: str = "memory",
        content: str | None = None,
        old_text: str | None = None,
        ctx: Context | None = None,
    ) -> dict[str, Any]:
        """CRUD for curated memory."""
        if ctx is None:
            raise RuntimeError("FastMCP context is required.")
        shared_store = ctx.lifespan_context["memory_store"]
        raw_result = memory_tool(
            action=action,
            target=target,
            content=content,
            old_text=old_text,
            store=shared_store,
        )
        return json.loads(raw_result)

    @server.tool()
    async def session_search(
        query: str = "",
        role_filter: str | None = None,
        limit: int = 3,
        ctx: Context | None = None,
    ) -> dict[str, Any]:
        """Search prior sessions or browse recent ones."""
        if ctx is None:
            raise RuntimeError("FastMCP context is required.")
        shared_sessions = ctx.lifespan_context["session_manager"]
        raw_result = await run_session_search(
            query=query,
            role_filter=role_filter,
            limit=limit,
            db=shared_sessions,
            current_session_id=getattr(ctx, "session_id", None),
        )
        return json.loads(raw_result)

    return server
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the minimal command-line parser for the memory server."""
    parser = argparse.ArgumentParser(description="Run Beaver memory MCP server over streamable HTTP.")
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--workspace", {"default": None, "help": "Workspace root. Defaults to BEAVER_WORKSPACE or cwd."}),
        ("--memory-dir", {"default": None, "help": "Override curated memory directory."}),
        ("--session-db", {"default": None, "help": "Override session SQLite database path."}),
        ("--host", {"default": "127.0.0.1", "help": "HTTP bind host."}),
        ("--port", {"default": 8001, "type": int, "help": "HTTP bind port."}),
        ("--path", {"default": "/mcp", "help": "MCP endpoint path."}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
def main() -> None:
    """Parse the command line and start the memory server over streamable HTTP."""
    args = build_arg_parser().parse_args()
    server_options = {
        "workspace": args.workspace,
        "memory_dir": args.memory_dir,
        "session_db_path": args.session_db,
    }
    server = create_memory_server(**server_options)
    run_options = {
        "transport": "http",
        "host": args.host,
        "port": args.port,
        "path": args.path,
    }
    server.run(**run_options)
# Build the module-level server only when fastmcp imported successfully.
# The `fastmcp run ...memory_server.py:mcp` invocation from the module docstring
# refers to this `mcp` name.
if FastMCP is not None:
    mcp = create_memory_server()
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,2 @@
"""Web interface."""

View File

@ -0,0 +1,198 @@
"""FastAPI app factory for Beaver."""
from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from beaver.services.agent_service import AgentService
from .deps import get_agent_service
from .schemas import WebChatRequest, WebChatResponse, WebErrorResponse, WebStatusResponse
try:
    from fastapi import FastAPI, HTTPException, Request
except ModuleNotFoundError:  # pragma: no cover - fallback for skeleton-only environments

    class HTTPException(Exception):
        """Minimal fallback exception matching FastAPI's constructor shape."""

        def __init__(self, status_code: int, detail: str) -> None:
            super().__init__(detail)
            self.status_code, self.detail = status_code, detail

    class Request:  # type: ignore[override]
        """Fallback request shim used only for import-time compatibility."""

        def __init__(self, app: Any) -> None:
            self.app = app

    class FastAPI:  # type: ignore[override]
        """Small fallback shim so the package can import before dependencies are installed.

        Route decorators register nothing; they return the handler unchanged.
        """

        def __init__(self, *, title: str, lifespan: Callable[..., Any] | None = None) -> None:
            self.title = title
            self.lifespan = lifespan
            self.state = SimpleNamespace()

        def get(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
            def passthrough(handler: Callable[..., Any]) -> Callable[..., Any]:
                return handler

            return passthrough

        def post(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
            def passthrough(handler: Callable[..., Any]) -> Callable[..., Any]:
                return handler

            return passthrough
@asynccontextmanager
async def _app_lifespan(
    app: FastAPI,
    *,
    workspace: str | Path | None,
    service: AgentService | None,
    manage_service_lifecycle: bool | None,
    shutdown_timeout_seconds: float | None,
    shutdown_force: bool,
) -> AsyncIterator[None]:
    """Tie the web app's lifespan to an ``AgentService``.

    The service (the given one, or a new one for *workspace*) is stored on
    ``app.state.agent_service``. It is started on entry and shut down on exit
    only when this lifespan owns it: by default when no ``service`` was
    passed, or as dictated by ``manage_service_lifecycle`` when not ``None``.
    """
    attached_service = service or AgentService(workspace=workspace)
    if manage_service_lifecycle is None:
        owns_service = service is None
    else:
        owns_service = manage_service_lifecycle
    app.state.agent_service = attached_service

    started = False
    if owns_service:
        try:
            await attached_service.start()
        except Exception:
            # Release whatever the service acquired before failing to start.
            attached_service.close()
            raise
        started = True

    try:
        yield
    finally:
        if owns_service and started:
            await attached_service.shutdown(
                timeout_seconds=shutdown_timeout_seconds,
                force=shutdown_force,
            )
def create_app(
    *,
    workspace: str | Path | None = None,
    service: AgentService | None = None,
    manage_service_lifecycle: bool | None = None,
    shutdown_timeout_seconds: float | None = 5.0,
    shutdown_force: bool = True,
) -> FastAPI:
    """Create a Beaver web app hosted by AgentService running mode.

    Default ownership semantics:
    - No ``service`` passed: the app creates one itself and manages its lifecycle.
    - An external ``service`` passed: it is only attached; the app does not
      start or shut it down automatically.

    Pass ``manage_service_lifecycle=True/False`` explicitly to override that
    default.

    Args:
        workspace: Forwarded to ``AgentService(workspace=...)`` when the app
            has to create its own service.
        service: Optional pre-built service to attach.
        manage_service_lifecycle: Explicit ownership override; ``None`` keeps
            the default described above.
        shutdown_timeout_seconds: Passed as ``timeout_seconds`` to
            ``AgentService.shutdown()`` when the app owns the service.
        shutdown_force: Passed as ``force`` to ``AgentService.shutdown()``.

    Returns:
        The app with ``/api/ping`` and ``/api/chat`` registered.
    """
    app = FastAPI(
        title="Beaver Backend",
        # The lambda closes over this call's arguments so each app instance
        # gets its own lifespan configuration.
        lifespan=lambda fastapi_app: _app_lifespan(
            fastapi_app,
            workspace=workspace,
            service=service,
            manage_service_lifecycle=manage_service_lifecycle,
            shutdown_timeout_seconds=shutdown_timeout_seconds,
            shutdown_force=shutdown_force,
        ),
    )

    @app.get("/api/ping", response_model=WebStatusResponse)
    async def ping(request: Request) -> WebStatusResponse:
        """Report whether the attached service is running, direct, or idle."""
        agent_service = get_agent_service(request)
        running = agent_service.is_running
        return WebStatusResponse(
            status="ok",
            running=running,
            # "direct" = a loop exists but no background run task is active.
            mode="running" if running else ("direct" if agent_service.has_loop else "idle"),
        )

    @app.post(
        "/api/chat",
        response_model=WebChatResponse,
        responses={
            400: {"model": WebErrorResponse},
            409: {"model": WebErrorResponse},
            503: {"model": WebErrorResponse},
        },
    )
    async def chat(request: Request, payload: WebChatRequest) -> WebChatResponse:
        """Submit one chat message to the running service and return its result.

        Raises:
            HTTPException: 400 for a blank message or a ``ValueError`` from the
                service; 409 or 503 for a ``RuntimeError``, chosen from the
                error text (see below).
        """
        agent_service = get_agent_service(request)
        message = payload.message.strip()
        if not message:
            raise HTTPException(status_code=400, detail="'message' is required")
        # Provider targets are passed on as plain dicts (or None).
        fallback_target = _model_dump(payload.fallback_target)
        auxiliary_target = _model_dump(payload.auxiliary_target)
        embedding_target = _model_dump(payload.embedding_target)
        try:
            result = await agent_service.submit_direct(
                message,
                session_id=payload.session_id,
                source="web",
                user_id=payload.user_id,
                title=payload.title,
                execution_context=payload.execution_context,
                model=payload.model,
                provider_name=payload.provider_name,
                embedding_model=payload.embedding_model,
                temperature=payload.temperature,
                max_tokens=payload.max_tokens,
                max_tool_iterations=payload.max_tool_iterations,
                fallback_target=fallback_target,
                auxiliary_target=auxiliary_target,
                embedding_target=embedding_target,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except RuntimeError as exc:
            # The status code is derived from substrings of the error message;
            # the first check wins, and anything unrecognised maps to 503.
            # NOTE(review): this couples the HTTP contract to the wording of
            # errors raised elsewhere -- confirm those messages stay stable.
            detail = str(exc)
            if "requires an active run() loop" in detail or "not ready" in detail:
                status_code = 503
            elif "submit_direct" in detail or "running" in detail:
                status_code = 409
            else:
                status_code = 503
            raise HTTPException(status_code=status_code, detail=detail) from exc
        return WebChatResponse(
            session_id=result.session_id,
            run_id=result.run_id,
            output_text=result.output_text,
            finish_reason=result.finish_reason,
            tool_iterations=result.tool_iterations,
            provider_name=result.provider_name,
            model=result.model,
            usage=result.usage,
        )

    return app
def _model_dump(value: Any) -> dict[str, Any] | None:
"""兼容 Pydantic v1/v2 的最小导出辅助。"""
if value is None:
return None
if hasattr(value, "model_dump"):
return value.model_dump(exclude_none=True)
if hasattr(value, "dict"):
return value.dict(exclude_none=True)
return dict(value)

View File

@ -0,0 +1,27 @@
"""Web dependency wiring."""
from __future__ import annotations
from typing import Any
from beaver.services.agent_service import AgentService
try:
    from fastapi import HTTPException
except ModuleNotFoundError:  # pragma: no cover - fallback for skeleton-only environments
    # FastAPI is not installed: provide a stand-in with the same constructor
    # keywords so get_agent_service() can still raise it.

    class HTTPException(Exception):
        """Minimal fallback exception matching FastAPI's constructor shape."""

        def __init__(self, status_code: int, detail: str) -> None:
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail
def get_agent_service(request: Any) -> AgentService:
    """Return the AgentService that the hosting layer stored on the app state.

    Raises:
        HTTPException: 503 when ``request.app.state.agent_service`` is missing
            or is not an ``AgentService`` instance.
    """
    app_state = request.app.state
    candidate = getattr(app_state, "agent_service", None)
    if isinstance(candidate, AgentService):
        return candidate
    raise HTTPException(status_code=503, detail="AgentService is not ready")

View File

@ -0,0 +1,2 @@
"""Web routes."""

View File

@ -0,0 +1,11 @@
"""Web request and response schemas."""
from .chat import WebChatRequest, WebChatResponse, WebErrorResponse, WebProviderTarget, WebStatusResponse
__all__ = [
"WebChatRequest",
"WebChatResponse",
"WebErrorResponse",
"WebProviderTarget",
"WebStatusResponse",
]

View File

@ -0,0 +1,93 @@
"""Chat-related web schemas."""
from __future__ import annotations
from typing import Any
try:
    from pydantic import BaseModel, Field
except ModuleNotFoundError:  # pragma: no cover - fallback for skeleton-only environments

    class _DeferredDefault:
        """Class-level placeholder for a ``Field(default_factory=...)`` default.

        The factory is kept, not called, so that each instance can build its
        own fresh default value in ``BaseModel.__init__``.
        """

        __slots__ = ("factory",)

        def __init__(self, factory: Any) -> None:
            self.factory = factory

    class BaseModel:
        """Very small fallback shim used only so imports work without pydantic.

        It performs no validation or coercion. Only names annotated directly on
        the concrete class are populated; unknown keyword arguments are ignored.
        """

        def __init__(self, **kwargs: Any) -> None:
            annotations = getattr(self.__class__, "__annotations__", {})
            for name in annotations:
                if name in kwargs:
                    value = kwargs[name]
                else:
                    value = getattr(self.__class__, name, None)
                    if isinstance(value, _DeferredDefault):
                        # Build the default per instance; sharing one mutable
                        # object across instances would leak state between them.
                        value = value.factory()
                setattr(self, name, value)

        def model_dump(self, *, exclude_none: bool = False) -> dict[str, Any]:
            """Return the instance attributes as a dict, optionally dropping ``None`` values."""
            data = dict(self.__dict__)
            if exclude_none:
                data = {key: value for key, value in data.items() if value is not None}
            return data

    def Field(default: Any = None, **kwargs: Any) -> Any:
        """Fallback for ``pydantic.Field``: yields the class-level default.

        Only ``default`` and ``default_factory`` are honoured; every other
        keyword (``min_length`` and friends) is accepted and ignored. A
        ``default_factory`` is wrapped rather than called here, so it runs once
        per instance instead of once per class.
        """
        default_factory = kwargs.get("default_factory")
        if default_factory is not None:
            return _DeferredDefault(default_factory)
        return default
class WebProviderTarget(BaseModel):
    """Web-facing provider target shape.

    Kept close to the runtime ``ProviderTarget`` for now, but it exposes only
    the fields the web layer currently needs. If the provider layer grows new
    fields, add them here explicitly.
    """

    # Every field is optional and defaults to None.
    provider: str | None = None
    model: str | None = None
    api_key: str | None = None
    api_base: str | None = None
    extra_headers: dict[str, str] | None = None
class WebChatRequest(BaseModel):
    """Minimal formal chat request schema."""

    # The only required field; ``min_length=1`` is handed to ``Field``.
    message: str = Field(min_length=1)
    # Optional session / identity metadata.
    session_id: str | None = None
    user_id: str | None = None
    title: str | None = None
    execution_context: str | None = None
    # Optional per-request model and generation overrides.
    model: str | None = None
    provider_name: str | None = None
    embedding_model: str | None = None
    temperature: float | None = None
    max_tokens: int | None = None
    max_tool_iterations: int | None = None
    # Optional provider targets, each shaped like WebProviderTarget.
    fallback_target: WebProviderTarget | None = None
    auxiliary_target: WebProviderTarget | None = None
    embedding_target: WebProviderTarget | None = None
class WebChatResponse(BaseModel):
    """Minimal formal chat response schema."""

    session_id: str
    run_id: str
    output_text: str
    finish_reason: str
    tool_iterations: int
    provider_name: str | None = None
    model: str | None = None
    # Defaults to an empty dict built through ``default_factory``.
    usage: dict[str, Any] = Field(default_factory=dict)
class WebStatusResponse(BaseModel):
    """Status response of the web hosting layer."""

    status: str
    running: bool
    mode: str
class WebErrorResponse(BaseModel):
    """Uniform error response schema."""

    detail: str

View File

@ -0,0 +1,2 @@
"""Memory and experience stores."""

View File

@ -0,0 +1,11 @@
"""Curated long-term memory primitives."""
from .snapshot import MemorySnapshot, capture_memory_snapshot
from .store import MemoryStore, scan_memory_content
__all__ = [
"MemorySnapshot",
"MemoryStore",
"capture_memory_snapshot",
"scan_memory_content",
]

View File

@ -0,0 +1,52 @@
"""curated memory 的冻结快照工具。
这个文件很小,但职责非常关键:它把“长期记忆的 live state”和“当前会话注入 prompt
时使用的 frozen snapshot”明确分开。
设计目的:
1. 让调用侧显式意识到:system prompt 使用的是一份冻结视图
2. 避免后续 engine/context builder 直接偷读 live store,破坏 frozen snapshot 语义
3. 给 prompt 组装层一个简单、稳定、可测试的数据结构
"""
from __future__ import annotations
from dataclasses import dataclass
from .store import MemoryStore
@dataclass(frozen=True, slots=True)
class MemorySnapshot:
"""当前 session 使用的冻结记忆快照。
这里不是 memory store 本体,而是“给 prompt builder 的只读投影”。
一旦 capture 完成,这个对象就代表本 session 的注入视图,不应在会话中途被修改。
"""
memory_block: str | None
user_block: str | None
def as_prompt_sections(self) -> list[str]:
"""按稳定顺序返回可直接拼接进 prompt 的 section 列表。
顺序固定为:
1. user profile
2. agent memory
这样后续 context builder 的输出更稳定,测试也更容易写。
"""
return [section for section in (self.user_block, self.memory_block) if section]
def capture_memory_snapshot(store: MemoryStore) -> MemorySnapshot:
    """Build the session's frozen snapshot from a ``MemoryStore``.

    ``store.load_from_disk()`` is expected to have run at session start;
    without it the result is just an empty snapshot.
    """
    memory_section = store.format_for_system_prompt("memory")
    user_section = store.format_for_system_prompt("user")
    return MemorySnapshot(memory_block=memory_section, user_block=user_section)

View File

@ -0,0 +1,463 @@
"""Beaver 的精炼长期记忆存储层。
这个文件实现的是以 Hermes-agent 为基线的 curated memory 模型,目标不是
“把所有历史都存下来”,而是只保存跨会话仍然值得保留的稳定事实。
核心设计:
1. 只保留两个持久化记忆桶:
- ``memory``: agent 自己对环境、项目、工具 quirks 的长期备注
- ``user``: 对用户偏好、习惯、身份信息的长期理解
2. ``replace`` / ``remove`` 不使用 UUID,而是使用短语义片段做子串匹配。
这是为了适配 LLM 更擅长“记住一句话片段”而不是“追踪一个随机 ID”的现实。
3. 写入前先做安全扫描,避免把 prompt injection / secrets exfiltration
一类危险内容写入长期记忆,再在未来会话中反向污染 system prompt。
4. 写入协议严格遵守:
- scan
- lock
- reload
- validate
- atomic write
5. 本文件维护两份状态:
- live state: 当前内存中的真实条目,tool 写入后立刻变化
- frozen snapshot: 会话开始时冻结的一份 prompt 注入快照
其中最重要的一点是:本会话中新增的记忆会立刻写盘,但不会反向修改本会话
已经冻结的 system prompt。这样可以保住 prefix cache,也避免“会话中途 prompt
变了导致行为抖动”的问题。
"""
from __future__ import annotations
import os
import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any
try:
import fcntl
except ImportError: # pragma: no cover - Windows fallback
fcntl = None
try:
import msvcrt
except ImportError: # pragma: no cover - Unix platforms
msvcrt = None
# Separator placed between entries inside a memory file.
ENTRY_DELIMITER = "\n§\n"
# File names of the two memory buckets, resolved relative to the store root.
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
DEFAULT_USER_FILENAME = "USER.md"

# (regex, pattern id) pairs checked by scan_memory_content() with
# re.IGNORECASE. The id is what gets reported back when a pattern matches.
_MEMORY_THREAT_PATTERNS: list[tuple[str, str]] = [
    # Instruction-override / role-hijack phrasing.
    (r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
    (r"you\s+are\s+now\s+", "role_hijack"),
    (r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
    (r"system\s+prompt\s+override", "sys_prompt_override"),
    (r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
    (r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don't\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
    # Shell commands that reference secret-looking variables or files.
    (r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
    (r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
    (r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
    # References to SSH material and the Beaver env file.
    (r"authorized_keys", "ssh_backdoor"),
    (r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
    (r"\$HOME/\.beaver/\.env|\~/\.beaver/\.env", "beaver_env"),
]

# Code points rejected outright by scan_memory_content(): zero-width and
# bidirectional-control characters that are invisible when rendered.
_INVISIBLE_CHARS = {
    "\u200b",  # ZERO WIDTH SPACE
    "\u200c",  # ZERO WIDTH NON-JOINER
    "\u200d",  # ZERO WIDTH JOINER
    "\u2060",  # WORD JOINER
    "\ufeff",  # ZERO WIDTH NO-BREAK SPACE (BOM)
    "\u202a",  # LEFT-TO-RIGHT EMBEDDING
    "\u202b",  # RIGHT-TO-LEFT EMBEDDING
    "\u202c",  # POP DIRECTIONAL FORMATTING
    "\u202d",  # LEFT-TO-RIGHT OVERRIDE
    "\u202e",  # RIGHT-TO-LEFT OVERRIDE
}
def scan_memory_content(content: str) -> str | None:
    """Screen content before it is written, blocking obviously dangerous entries.

    This is not a complete security audit; it is the minimum gate in front of
    long-term memory. Entries are re-injected into future system prompts, so
    malicious text stored here is far riskier than ordinary transient context.

    Returns:
        A human-readable "Blocked: ..." reason, or ``None`` when the content
        passes both the invisible-character and the threat-pattern checks.
    """
    invisible = next((char for char in _INVISIBLE_CHARS if char in content), None)
    if invisible is not None:
        return (
            f"Blocked: content contains invisible unicode character "
            f"U+{ord(invisible):04X}."
        )

    matched_id = next(
        (
            pattern_id
            for pattern, pattern_id in _MEMORY_THREAT_PATTERNS
            if re.search(pattern, content, re.IGNORECASE)
        ),
        None,
    )
    if matched_id is not None:
        return (
            f"Blocked: content matches threat pattern '{matched_id}'. "
            "Memory entries are injected into future system prompts."
        )
    return None
class MemoryStore:
"""带容量上限的长期记忆存储。
这个类负责:
1. 从磁盘加载 `MEMORY.md` / `USER.md`
2. 在 session 启动时冻结 prompt snapshot
3. 为 `add / replace / remove` 提供安全写接口
4. 维护 live state 与 frozen snapshot 的边界
它不负责:
1. 自动从对话里抽取要记住的内容
2. session transcript 检索
3. skills 的学习和发布
"""
def __init__(
self,
root: str | Path,
*,
memory_char_limit: int = 2200,
user_char_limit: int = 1375,
) -> None:
self.root = Path(root)
self.memory_char_limit = memory_char_limit
self.user_char_limit = user_char_limit
self.memory_entries: list[str] = []
self.user_entries: list[str] = []
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
def load_from_disk(self) -> None:
"""从磁盘加载 live state并冻结当前 session 的 prompt snapshot。
调用时机应该是“会话启动时”,而不是每次工具写入后。
如果在每次写入后都重新 load 并更新 system prompt就会破坏 frozen snapshot
这个设计,导致本轮会话 prompt 前缀发生变化。
"""
self.root.mkdir(parents=True, exist_ok=True)
self.memory_entries = list(dict.fromkeys(self._read_file(self._path_for("memory"))))
self.user_entries = list(dict.fromkeys(self._read_file(self._path_for("user"))))
self._system_prompt_snapshot = {
"memory": self._render_block("memory", self.memory_entries),
"user": self._render_block("user", self.user_entries),
}
@contextmanager
def _file_lock(self, path: Path):
"""对目标记忆文件加排他锁。
锁文件使用 sibling `.lock` 文件,而不是直接锁业务文件本身。
原因是业务文件使用的是“临时文件写入 + os.replace 原子替换”,如果直接锁目标
文件,替换时会让锁语义和文件句柄关系变得更脆弱。
"""
lock_path = path.with_suffix(path.suffix + ".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
if fcntl is None and msvcrt is None:
yield
return
if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0):
lock_path.write_text(" ", encoding="utf-8")
fd = open(lock_path, "r+" if msvcrt else "a+", encoding="utf-8")
try:
if fcntl is not None:
fcntl.flock(fd, fcntl.LOCK_EX)
elif msvcrt is not None: # pragma: no cover - Windows fallback
fd.seek(0)
msvcrt.locking(fd.fileno(), msvcrt.LK_LOCK, 1)
yield
finally:
if fcntl is not None:
fcntl.flock(fd, fcntl.LOCK_UN)
elif msvcrt is not None: # pragma: no cover - Windows fallback
try:
fd.seek(0)
msvcrt.locking(fd.fileno(), msvcrt.LK_UNLCK, 1)
except OSError:
pass
fd.close()
def _path_for(self, target: str) -> Path:
"""根据目标桶返回实际文件路径。"""
if target == "user":
return self.root / DEFAULT_USER_FILENAME
return self.root / DEFAULT_MEMORY_FILENAME
def _entries_for(self, target: str) -> list[str]:
"""读取某个目标桶当前的 live entries。"""
if target == "user":
return self.user_entries
return self.memory_entries
def _set_entries(self, target: str, entries: list[str]) -> None:
"""更新某个目标桶在内存中的 live entries。"""
if target == "user":
self.user_entries = entries
else:
self.memory_entries = entries
def _char_limit(self, target: str) -> int:
"""返回目标桶的字符预算。
这里使用字符数而不是 token 数,是因为字符预算更稳定,也不依赖具体模型。
"""
return self.user_char_limit if target == "user" else self.memory_char_limit
def _char_count(self, target: str) -> int:
"""返回目标桶当前 live state 的字符占用。"""
entries = self._entries_for(target)
return len(ENTRY_DELIMITER.join(entries)) if entries else 0
def _reload_target(self, target: str) -> None:
"""在持锁状态下重新从磁盘读取目标桶。
这是并发安全协议里最关键的一步之一。
必须在拿到锁之后 reload才能确保当前进程不会覆盖掉其他并发会话刚刚写入
的最新内容。
"""
fresh = list(dict.fromkeys(self._read_file(self._path_for(target))))
self._set_entries(target, fresh)
def save_to_disk(self, target: str) -> None:
"""把当前 live entries 持久化到磁盘。"""
self.root.mkdir(parents=True, exist_ok=True)
self._write_file(self._path_for(target), self._entries_for(target))
def add(self, target: str, content: str) -> dict[str, Any]:
"""追加一条新的长期记忆。
规则:
1. 空内容拒绝
2. 安全扫描不通过拒绝
3. 精确重复拒绝
4. 超出字符预算拒绝
5. 否则追加并立即写盘
"""
content = content.strip()
if not content:
return {"success": False, "error": "Content cannot be empty."}
scan_error = scan_memory_content(content)
if scan_error:
return {"success": False, "error": scan_error}
with self._file_lock(self._path_for(target)):
self._reload_target(target)
entries = self._entries_for(target)
if content in entries:
return self._success_response(target, "Entry already exists (skipped duplicate).")
new_entries = entries + [content]
new_total = len(ENTRY_DELIMITER.join(new_entries))
limit = self._char_limit(target)
if new_total > limit:
current = self._char_count(target)
return {
"success": False,
"error": (
f"Memory at {current:,}/{limit:,} chars. "
f"Adding this entry ({len(content)} chars) would exceed the limit."
),
"current_entries": list(entries),
"usage": f"{current:,}/{limit:,}",
}
entries.append(content)
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry added.")
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
"""用新的内容替换一条已有记忆。
这里按 `old_text in entry` 做子串匹配,而不是要求调用方提供完整条目或 UUID。
如果命中多条且它们内容不同,会要求调用方给出更精确的片段,避免误替换。
"""
old_text = old_text.strip()
new_content = new_content.strip()
if not old_text:
return {"success": False, "error": "old_text cannot be empty."}
if not new_content:
return {
"success": False,
"error": "new_content cannot be empty. Use remove to delete entries.",
}
scan_error = scan_memory_content(new_content)
if scan_error:
return {"success": False, "error": scan_error}
with self._file_lock(self._path_for(target)):
self._reload_target(target)
entries = self._entries_for(target)
matches = [(index, entry) for index, entry in enumerate(entries) if old_text in entry]
if not matches:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if len(matches) > 1:
unique_texts = {entry for _, entry in matches}
if len(unique_texts) > 1:
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": [
entry[:80] + ("..." if len(entry) > 80 else "")
for _, entry in matches
],
}
index = matches[0][0]
candidate_entries = list(entries)
candidate_entries[index] = new_content
new_total = len(ENTRY_DELIMITER.join(candidate_entries))
limit = self._char_limit(target)
if new_total > limit:
return {
"success": False,
"error": (
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
"Shorten the new content or remove other entries first."
),
}
entries[index] = new_content
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry replaced.")
def remove(self, target: str, old_text: str) -> dict[str, Any]:
"""删除一条已有记忆。
删除和替换共享同样的匹配策略:优先服务于 LLM 可操作性,而不是数据库式的强 ID。
"""
old_text = old_text.strip()
if not old_text:
return {"success": False, "error": "old_text cannot be empty."}
with self._file_lock(self._path_for(target)):
self._reload_target(target)
entries = self._entries_for(target)
matches = [(index, entry) for index, entry in enumerate(entries) if old_text in entry]
if not matches:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if len(matches) > 1:
unique_texts = {entry for _, entry in matches}
if len(unique_texts) > 1:
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": [
entry[:80] + ("..." if len(entry) > 80 else "")
for _, entry in matches
],
}
entries.pop(matches[0][0])
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry removed.")
def format_for_system_prompt(self, target: str) -> str | None:
"""返回 session 启动时冻结下来的 prompt block。
这里明确返回的是 frozen snapshot而不是 live state。
所以如果 session 中途调用 `add()` 写入了新记忆,这里不会立刻变化。
"""
block = self._system_prompt_snapshot.get(target, "")
return block or None
def _success_response(self, target: str, message: str | None = None) -> dict[str, Any]:
"""统一生成 memory tool 的成功响应。
响应里返回 live entries 和占用信息,目的是让模型能“看到自己刚写进去什么”,
即使 system prompt 仍然保持冻结不变。
"""
current = self._char_count(target)
limit = self._char_limit(target)
percent = min(100, int((current / limit) * 100)) if limit > 0 else 0
payload: dict[str, Any] = {
"success": True,
"target": target,
"entries": list(self._entries_for(target)),
"entry_count": len(self._entries_for(target)),
"usage": f"{percent}% — {current:,}/{limit:,} chars",
}
if message:
payload["message"] = message
return payload
def _render_block(self, target: str, entries: list[str]) -> str:
"""把条目渲染成适合注入 system prompt 的块。"""
if not entries:
return ""
current = len(ENTRY_DELIMITER.join(entries))
limit = self._char_limit(target)
percent = min(100, int((current / limit) * 100)) if limit > 0 else 0
if target == "user":
header = f"USER PROFILE (who the user is) [{percent}% — {current:,}/{limit:,} chars]"
else:
header = f"MEMORY (your personal notes) [{percent}% — {current:,}/{limit:,} chars]"
separator = "" * 46
return f"{separator}\n{header}\n{separator}\n{ENTRY_DELIMITER.join(entries)}"
@staticmethod
def _read_file(path: Path) -> list[str]:
"""读取记忆文件并按 entry delimiter 拆分。
这里不额外加读锁,因为写入采用的是原子替换:读者只会看到旧完整文件或新完整文件,
不会看到半写入状态。
"""
if not path.exists():
return []
try:
raw = path.read_text(encoding="utf-8")
except OSError:
return []
if not raw.strip():
return []
return [entry for entry in (item.strip() for item in raw.split(ENTRY_DELIMITER)) if entry]
@staticmethod
def _write_file(path: Path, entries: list[str]) -> None:
"""以原子方式写入记忆文件。
这里不能直接 `open(path, "w")`,因为那会先截断原文件,再写新内容。
如果恰好此时别的进程正在读,就可能读到空文件或半成品。
正确方式是:
1. 在同目录创建临时文件
2. 写入并 fsync
3. 使用 `os.replace()` 原子替换
"""
content = ENTRY_DELIMITER.join(entries) if entries else ""
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
try:
with os.fdopen(fd, "w", encoding="utf-8") as handle:
handle.write(content)
handle.flush()
os.fsync(handle.fileno())
os.replace(tmp_path, path)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise

View File

@ -0,0 +1,2 @@
"""Reusable procedures."""

View File

@ -0,0 +1,2 @@
"""Run records."""

View File

@ -0,0 +1,5 @@
"""Session transcript search storage."""
from .transcript_store import TranscriptStore
__all__ = ["TranscriptStore"]

View File

@ -0,0 +1,46 @@
"""兼容层:过渡期把旧 transcript store 导向新的 session 子系统。
真正的主实现现在在:
1. `beaver.engine.session.store`
2. `beaver.engine.session.search`
3. `beaver.engine.session.manager`
保留这个文件只是为了避免已经写好的 MCP server / tool 导入立刻断掉。
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from beaver.engine.session.manager import SessionManager
class TranscriptStore:
    """Thin wrapper that keeps the old transcript-store interface working.

    Every method forwards to the ``SessionManager`` held in ``self.manager``.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Create the backing ``SessionManager`` for the database at *db_path*.

        The workspace is taken to be the grandparent directory when the
        database sits in a directory named ``sessions``, and the parent
        directory otherwise.
        """
        path = Path(db_path)
        workspace = path.parent.parent if path.parent.name == "sessions" else path.parent
        self.manager = SessionManager(workspace=workspace, db_path=path)

    def close(self) -> None:
        """Close the underlying session manager."""
        self.manager.close()

    def ensure_session(self, session_id: str, **kwargs: Any) -> str:
        """Forward to ``SessionManager.ensure_session``."""
        return self.manager.ensure_session(session_id, **kwargs)

    def append_message(self, session_id: str, **kwargs: Any) -> int:
        """Forward to ``SessionManager.append_message``."""
        return self.manager.append_message(session_id, **kwargs)

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Forward to ``SessionManager.get_session``."""
        return self.manager.get_session(session_id)

    def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionManager.list_sessions_rich``."""
        return self.manager.list_sessions_rich(**kwargs)

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Forward to ``SessionManager.get_messages_as_conversation``."""
        return self.manager.get_messages_as_conversation(session_id)

    def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
        """Forward to ``SessionManager.search_messages``."""
        return self.manager.search_messages(**kwargs)

View File

@ -0,0 +1,2 @@
"""Memory related to skill evolution."""

View File

@ -0,0 +1,2 @@
"""Storage backends for memory."""

View File

@ -0,0 +1,2 @@
"""Permission and governance layer."""

View File

@ -0,0 +1,2 @@
"""Execution guards."""

View File

@ -0,0 +1,2 @@
"""Permission policies."""

View File

@ -0,0 +1,2 @@
"""Agent permission profiles."""

View File

@ -0,0 +1,2 @@
"""Plugin system for Beaver."""

View File

@ -0,0 +1,2 @@
"""Plugin extension hooks."""

View File

@ -0,0 +1,2 @@
"""Plugin loading hooks."""

View File

@ -0,0 +1,2 @@
"""Plugin registry."""

View File

@ -0,0 +1,6 @@
"""Application services for Beaver."""
from .agent_service import AgentService
from .memory_service import MemoryService
__all__ = ["AgentService", "MemoryService"]

View File

@ -0,0 +1,2 @@
"""Administrative application service."""

View File

@ -0,0 +1,212 @@
"""Application service for agent entry.
这层的职责是把“接口层如何调用 AgentLoop”统一收口。
接口层以后不应该各自做这些事情:
1. 自己 new `AgentLoop`
2. 自己决定何时 `boot()`
3. 自己处理 direct run 的同步/异步包装
统一放在 `AgentService` 后,CLI / Web / Gateway 才能共享同一条运行主链。
"""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
class AgentService:
    """Unified agent entry point for the interface layers.

    Two calling modes are kept strictly apart:
    1. direct mode
       - no background run loop is started
       - call ``process_direct()`` / ``run_direct()`` directly
    2. running mode
       - ``await start()`` first
       - afterwards every external task must go through ``submit_direct()``
       - calling ``process_direct()`` directly is no longer allowed
    """

    def __init__(
        self,
        *,
        workspace: str | Path | None = None,
        profile: AgentProfile | None = None,
        loader: EngineLoader | None = None,
    ) -> None:
        """Store the profile and loader; the AgentLoop itself is created lazily.

        ``workspace`` is only used to build a default ``EngineLoader`` when no
        ``loader`` is supplied.
        """
        self.profile = profile or AgentProfile()
        self.loader = loader or EngineLoader(workspace=workspace)
        self._loop: AgentLoop | None = None
        # Background task running ``loop.run()``; set by start().
        self._run_task: asyncio.Task[None] | None = None

    def create_loop(self) -> AgentLoop:
        """Create (once) and return the AgentLoop used by this service.

        The loop is booted on first creation and cached afterwards.
        """
        if self._loop is None:
            self._loop = AgentLoop(profile=self.profile, loader=self.loader)
            self._loop.boot()
        return self._loop

    @property
    def has_loop(self) -> bool:
        """Whether this service has already created a loop."""
        return self._loop is not None

    @property
    def is_running(self) -> bool:
        """Whether this service is in running mode (run task exists and is not done)."""
        return self._run_task is not None and not self._run_task.done()

    def close(self) -> None:
        """Release the runtime held by this service.

        Raises:
            RuntimeError: If the background run task is still active.
        """
        if self._run_task is not None and not self._run_task.done():
            raise RuntimeError("AgentService.close() requires stop() before closing a running loop")
        self._run_task = None
        if self._loop is None:
            return
        try:
            self._loop.close()
        finally:
            # Drop the reference even if close() raised, so a later
            # create_loop() builds a fresh loop.
            self._loop = None

    async def start(self) -> None:
        """Start the background run loop and enter running mode.

        Once in running mode:
        - external tasks must be submitted through ``submit_direct()``
        - ``process_direct()`` may no longer be called directly

        Calling this while already running is a no-op.
        """
        if self._run_task is not None and not self._run_task.done():
            return
        loop = self.create_loop()
        self._run_task = asyncio.create_task(loop.run())
        # Yield to the event loop until the agent loop reports it is running.
        while not loop.is_running:
            if self._run_task.done():
                # The run task ended before reporting running; awaiting it
                # re-raises whatever it failed with.
                await self._run_task
                break
            await asyncio.sleep(0)

    async def _stop_impl(
        self,
        *,
        timeout_seconds: float | None = None,
        force: bool = False,
    ) -> None:
        """Internal stop routine with a graceful timeout and optional forced cancel.

        Raises:
            TimeoutError: If the run task did not finish within
                ``timeout_seconds`` and ``force`` is false.
        """
        if self._run_task is None:
            return
        run_task = self._run_task
        loop = self.create_loop()
        try:
            await loop.stop()
            if timeout_seconds is None:
                await run_task
            else:
                try:
                    # shield() keeps wait_for's timeout from cancelling the run
                    # task itself; cancellation is decided explicitly below.
                    await asyncio.wait_for(asyncio.shield(run_task), timeout=timeout_seconds)
                except asyncio.TimeoutError as exc:
                    if force:
                        run_task.cancel()
                        try:
                            await run_task
                        except asyncio.CancelledError:
                            pass
                    else:
                        raise TimeoutError(
                            f"AgentService.stop() timed out after {timeout_seconds} seconds while draining queued tasks"
                        ) from exc
        finally:
            # Clear the handle only once the task has really finished; after a
            # non-forced timeout it stays set and the service still counts as running.
            if run_task.done():
                self._run_task = None

    async def stop(
        self,
        *,
        timeout_seconds: float | None = None,
        force: bool = False,
    ) -> None:
        """Stop the background run loop and wait for it to exit.

        Args:
            timeout_seconds: Longest time to wait for a graceful drain;
                ``None`` waits indefinitely.
            force: Whether to cancel the run-loop task after a timeout.
        """
        await self._stop_impl(timeout_seconds=timeout_seconds, force=force)

    async def shutdown(
        self,
        *,
        timeout_seconds: float | None = None,
        force: bool = False,
    ) -> None:
        """Stop the run loop first, then release the runtime."""
        await self._stop_impl(timeout_seconds=timeout_seconds, force=force)
        self.close()

    async def process_direct(
        self,
        message: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Async entry point for a direct run.

        Only available in direct mode. Once the service has entered running
        mode through ``start()``, callers must use ``submit_direct()`` instead
        of bypassing the run queue.

        Raises:
            RuntimeError: If the service is currently running.
        """
        if self._run_task is not None and not self._run_task.done():
            raise RuntimeError(
                "AgentService.process_direct() is unavailable while the service is running; "
                "use 'await AgentService.submit_direct(...)' after start()."
            )
        loop = self.create_loop()
        return await loop.process_direct(message, **kwargs)

    async def submit_direct(
        self,
        message: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Submit a direct task to the loop in running mode.

        This is the only legitimate external task entry point after ``start()``.
        """
        loop = self.create_loop()
        return await loop.submit_direct(message, **kwargs)

    def run_direct(
        self,
        message: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Synchronous wrapper around a direct run.

        Mainly for the current CLI and simple scripts. The long-term direction
        is still for interfaces in direct mode to use
        ``await process_direct(...)``.

        Raises:
            RuntimeError: If called from inside an active event loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to use asyncio.run() below.
            pass
        else:
            raise RuntimeError(
                "AgentService.run_direct() cannot be used inside an active event loop; "
                "use 'await AgentService.process_direct(...)' instead."
            )
        return asyncio.run(self.process_direct(message, **kwargs))

View File

@ -0,0 +1,65 @@
"""Beaver memory 应用服务。
这层不是新的 memory 实现,而是对现有 `MemoryStore + MemorySnapshot` 的应用层包装。
目标只有三个:
1. 把“本轮运行前需要 refresh live state”这件事集中到一个地方
2. 把“给 context builder 的只能是 frozen snapshot”这条规则写死
3. 让 `AgentLoop` 不再直接操作 `MemoryStore` 细节
设计边界:
1. 记忆实际读写逻辑仍然在 `beaver.memory.curated.store.MemoryStore`
2. memory tool 仍然直接写 store
3. 本服务只负责 runtime 接入策略,不负责 CRUD 业务本身
"""
from __future__ import annotations
from pathlib import Path
from beaver.memory.curated.snapshot import MemorySnapshot, capture_memory_snapshot
from beaver.memory.curated.store import MemoryStore
class MemoryService:
    """Single access path from the runtime to curated memory."""

    def __init__(
        self,
        root: str | Path,
        *,
        store: MemoryStore | None = None,
    ) -> None:
        """Wrap the given store, or build a ``MemoryStore`` rooted at *root*."""
        self.root = Path(root)
        self.store = store if store else MemoryStore(self.root)
        self._snapshot: MemorySnapshot | None = None

    def _refresh_snapshot(self) -> None:
        """Reload the store from disk, then capture a new frozen snapshot."""
        self.store.load_from_disk()
        self._snapshot = capture_memory_snapshot(self.store)

    def initialize(self) -> None:
        """Load disk content once at startup and establish the first frozen snapshot."""
        self._refresh_snapshot()

    def reload_for_new_run(self) -> None:
        """Refresh live state before each new run.

        This is the core of the Hermes-style memory policy:
        - memory persisted by a tool in an earlier session is visible to the next run
        - memory written in the middle of a run does not alter that run's frozen snapshot
        """
        self._refresh_snapshot()

    def get_snapshot(self) -> MemorySnapshot:
        """Return the frozen snapshot to inject into the current run's system prompt."""
        snapshot = self._snapshot
        if snapshot is None:
            # Fallback for callers that skipped initialize()/reload_for_new_run():
            # capture a snapshot on first read.
            snapshot = capture_memory_snapshot(self.store)
            self._snapshot = snapshot
        return snapshot

    def get_store(self) -> MemoryStore:
        """Expose the underlying store to the tool layer that performs CRUD directly."""
        return self.store

View File

@ -0,0 +1,2 @@
"""Application service for skills."""

View File

@ -0,0 +1,10 @@
"""Application service for coordinated team runs."""
class TeamService:
    """Stand-in service for coordinated multi-agent runs."""

    def run(self, task: str) -> str:
        """Return a fixed placeholder summary that echoes *task*; no backend is invoked."""
        return "team run placeholder: {}".format(task)

View File

@ -0,0 +1,12 @@
"""Skill system for Beaver."""
from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever
from .catalog import SkillRecord, SkillsLoader
__all__ = [
"SkillAssembler",
"SkillAssemblyResult",
"SkillEmbeddingRetriever",
"SkillRecord",
"SkillsLoader",
]

View File

@ -0,0 +1,6 @@
"""Skill assembly for Beaver."""
from .embedding_retriever import SkillEmbeddingRetriever
from .task_assembler import SkillAssemblyResult, SkillAssembler
__all__ = ["SkillAssemblyResult", "SkillAssembler", "SkillEmbeddingRetriever"]

View File

@ -0,0 +1,188 @@
"""Embedding-based skill candidate retrieval.
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
阿里云百炼 `text-embedding-v4` 做最小语义召回:
1. 复用当前 provider 的 `api_key/api_base`
2. 先用 embedding 相似度召回一小批候选
3. 再交给上层 LLM selector 做最终技能选择
"""
from __future__ import annotations
import asyncio
import math
import os
import json
from urllib import request
from typing import Any
class SkillEmbeddingRetriever:
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
def __init__(
    self,
    *,
    api_key_env: str = "OPENAI_API_KEY",
    api_base_env: str = "OPENAI_API_BASE",
    model: str = "text-embedding-v4",
    timeout_seconds: float = 20.0,
) -> None:
    """Store the retriever configuration; no network call is made here.

    Args:
        api_key_env: Environment variable consulted for the API key when
            ``retrieve()`` is not given one explicitly.
        api_base_env: Environment variable consulted for the API base URL
            when ``retrieve()`` is not given one explicitly.
        model: Default embedding model name.
        timeout_seconds: Timeout applied to each embeddings HTTP request.
    """
    self.api_key_env = api_key_env
    self.api_base_env = api_base_env
    self.model = model
    self.timeout_seconds = timeout_seconds
async def retrieve(
self,
*,
query: str,
candidates: list[dict[str, str]],
top_k: int = 12,
api_key: str | None = None,
api_base: str | None = None,
model: str | None = None,
) -> list[dict[str, str]]:
"""按 embedding 相似度召回 top-k 候选。
如果没有可用的 API Key / base URL或者 embedding 调用失败,
当前阶段先退回到“全部候选交给 LLM selector”。
"""
if not candidates:
return []
resolved_api_key = api_key or os.getenv(self.api_key_env)
resolved_api_base = api_base or os.getenv(self.api_base_env)
if not resolved_api_key or not resolved_api_base:
return candidates
try:
query_embedding = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[query],
model=model or self.model,
)
candidate_texts = [self._candidate_text(item) for item in candidates]
candidate_embeddings = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=candidate_texts,
model=model or self.model,
)
except Exception:
return candidates
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
return candidates
query_vector = query_embedding[0]
scored: list[tuple[float, dict[str, str]]] = []
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
if not vector:
continue
scored.append((self._cosine_similarity(query_vector, vector), candidate))
scored.sort(key=lambda item: item[0], reverse=True)
return [item[1] for item in scored[:top_k]]
async def _embed_texts(
self,
*,
api_key: str,
api_base: str,
texts: list[str],
model: str,
) -> list[list[float]]:
"""调用 OpenAI-compatible embeddings 接口。
当前对齐的是你们实际在用的网关配置:
- `POST {api_base}/embeddings`
- `model=text-embedding-v4`
- `encoding_format=float`
"""
all_vectors: list[list[float]] = []
endpoint = self._normalize_embeddings_endpoint(api_base)
for start in range(0, len(texts), 10):
batch = texts[start:start + 10]
payload = await self._post_embeddings(
endpoint=endpoint,
api_key=api_key,
model=model,
texts=batch,
)
embeddings = payload.get("data") or []
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
return all_vectors
async def _post_embeddings(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
return await asyncio.to_thread(
self._post_embeddings_sync,
endpoint=endpoint,
api_key=api_key,
model=model,
texts=texts,
)
def _post_embeddings_sync(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
body = json.dumps(
{
"model": model,
"input": texts if len(texts) > 1 else texts[0],
"encoding_format": "float",
}
).encode("utf-8")
req = request.Request(
endpoint,
data=body,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
method="POST",
)
with request.urlopen(req, timeout=self.timeout_seconds) as response:
return json.loads(response.read().decode("utf-8"))
@staticmethod
def _candidate_text(candidate: dict[str, str]) -> str:
name = (candidate.get("name") or "").strip()
description = (candidate.get("description") or "").strip()
return f"{name}\n{description}".strip()
@staticmethod
def _normalize_embeddings_endpoint(api_base: str) -> str:
base = api_base.rstrip("/")
if base.endswith("/embeddings"):
return base
if base.endswith("/v1"):
return f"{base}/embeddings"
return f"{base}/v1/embeddings"
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
if not left or not right or len(left) != len(right):
return -1.0
dot = sum(a * b for a, b in zip(left, right, strict=False))
left_norm = math.sqrt(sum(a * a for a in left))
right_norm = math.sqrt(sum(b * b for b in right))
if left_norm == 0 or right_norm == 0:
return -1.0
return dot / (left_norm * right_norm)

View File

@ -0,0 +1,168 @@
"""LLM-driven skill assembler.
这层现在不再自己做规则打分,而是直接把:
1. task description
2. embedding 召回后的候选 skill 摘要
交给一个模型来决定本轮要激活哪些 skill。
当前目标非常克制:
- 输入尽量简单
- 输出只要 skill 名称
- 没有命中就返回空 skills
"""
from __future__ import annotations
from dataclasses import dataclass, field
import json
from typing import Any
from beaver.engine.context import SkillContext
from beaver.engine.providers.base import LLMProvider
from beaver.engine.providers.runtime import ProviderRuntime
from beaver.skills.catalog.loader import SkillsLoader
from beaver.skills.catalog.utils import strip_frontmatter
from .embedding_retriever import SkillEmbeddingRetriever
@dataclass(slots=True)
class SkillAssemblyResult:
    """Skills that one assembly pass decided to inject into the current run."""

    # Skills selected for this run; empty when nothing was selected.
    activated_skills: list[SkillContext] = field(default_factory=list)
class SkillAssembler:
    """Pick the skills for a run by asking an LLM to choose among candidates.

    The flow is: build candidate summaries from the loader, narrow them with
    the embedding retriever, let the model return skill names, then load the
    body of every selected skill.
    """

    def __init__(
        self,
        loader: SkillsLoader,
        retriever: SkillEmbeddingRetriever | None = None,
    ) -> None:
        """Keep the skills loader and the retriever (a default one if omitted)."""
        self.loader = loader
        self.retriever = retriever or SkillEmbeddingRetriever()

    async def assemble(
        self,
        *,
        task_description: str,
        provider: LLMProvider,
        model: str,
        embedding_runtime: ProviderRuntime | None = None,
        top_k: int = 12,
    ) -> SkillAssemblyResult:
        """Select and load the skills to activate for ``task_description``.

        Args:
            task_description: Text describing the task of this run.
            provider: Chat provider used for the selection call.
            model: Model name passed to the selection call.
            embedding_runtime: Optional runtime whose ``api_key``, ``api_base``
                and ``model`` are forwarded to the retriever.
            top_k: Maximum number of candidates the retriever should keep.

        Returns:
            A result holding the activated skills; empty when there are no
            candidates, the model selects nothing, or the selected skills have
            no body.
        """
        candidates = self.loader.build_selection_candidates()
        if not candidates:
            return SkillAssemblyResult()

        has_runtime = embedding_runtime is not None
        candidates = await self.retriever.retrieve(
            query=task_description,
            candidates=candidates,
            top_k=top_k,
            api_key=embedding_runtime.api_key if has_runtime else None,
            api_base=embedding_runtime.api_base if has_runtime else None,
            model=embedding_runtime.model if has_runtime else None,
        )
        if not candidates:
            return SkillAssemblyResult()

        selected_names = await self._select_skill_names(
            task_description=task_description,
            candidates=candidates,
            provider=provider,
            model=model,
        )
        if not selected_names:
            return SkillAssemblyResult()

        activated_skills: list[SkillContext] = []
        for name in selected_names:
            raw_content = self.loader.load_skill(name)
            content = strip_frontmatter(raw_content).strip() if raw_content else ""
            if not content:
                # Skip skills that vanished or have an empty body.
                continue
            activated_skills.append(SkillContext(name=name, content=content))
        return SkillAssemblyResult(activated_skills=activated_skills)

    async def _select_skill_names(
        self,
        *,
        task_description: str,
        candidates: list[dict[str, str]],
        provider: LLMProvider,
        model: str,
    ) -> list[str]:
        """Ask the model which candidates to activate.

        Returns:
            The selected names in the model's output order, de-duplicated and
            restricted to names present in ``candidates``. An empty list is
            returned when the response is an error, empty, or unparseable.
        """
        candidate_summary = self._render_candidates(candidates)
        candidate_names = {item["name"] for item in candidates}
        messages = [
            {
                "role": "system",
                "content": (
                    "You select Beaver skills for a single run. "
                    "Given a task description and candidate skill summaries, "
                    "return only a JSON array of skill names to activate. "
                    "Do not invent names. If nothing matches, return []."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Task description:\n{task_description}\n\n"
                    f"Candidate skills:\n{candidate_summary}\n\n"
                    "Return only JSON, for example: [\"skill-a\", \"skill-b\"]"
                ),
            },
        ]
        response = await provider.chat(
            messages=messages,
            tools=None,
            model=model,
            max_tokens=512,
            temperature=0,
        )
        if response.finish_reason == "error" or not response.content:
            return []

        # Keep only names that really exist among the candidates, preserving
        # the model's order and dropping duplicates.
        filtered: list[str] = []
        for name in self._parse_selected_names(response.content):
            if name in candidate_names and name not in filtered:
                filtered.append(name)
        return filtered

    @staticmethod
    def _render_candidates(candidates: list[dict[str, str]]) -> str:
        """Render candidates as one ``- name: description`` bullet per line."""
        return "\n".join(f"- {item['name']}: {item['description']}" for item in candidates)

    @staticmethod
    def _parse_selected_names(content: str) -> list[str]:
        """Extract skill names from the model's reply.

        Accepted shapes: a JSON array of strings, the same inside a Markdown
        code fence, a JSON object carrying the array under ``skills``,
        ``selected_skills``, ``activated_skills`` or ``selected``, or a JSON
        array embedded in surrounding prose. Anything else yields ``[]``.
        """
        cleaned = content.strip()
        if cleaned.startswith("```"):
            lines = cleaned.splitlines()
            if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
                cleaned = "\n".join(lines[1:-1]).strip()

        try:
            payload: Any = json.loads(cleaned)
        except json.JSONDecodeError:
            # Models often add chatter or a one-line fence around the array;
            # salvage the first well-formed JSON array instead of giving up.
            payload = SkillAssembler._extract_first_json_array(cleaned)

        if isinstance(payload, dict):
            for key in ("skills", "selected_skills", "activated_skills", "selected"):
                value = payload.get(key)
                if isinstance(value, list):
                    payload = value
                    break
        if not isinstance(payload, list):
            return []
        return [item.strip() for item in payload if isinstance(item, str) and item.strip()]

    @staticmethod
    def _extract_first_json_array(text: str) -> Any:
        """Return the first JSON array found inside ``text``, or ``None``."""
        decoder = json.JSONDecoder()
        start = text.find("[")
        while start != -1:
            try:
                value, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                value = None
            if isinstance(value, list):
                return value
            start = text.find("[", start + 1)
        return None

View File

@ -0,0 +1,2 @@
"""Built-in skill payloads."""

View File

@ -0,0 +1,5 @@
"""Skill catalog and indexing."""
from .loader import SkillRecord, SkillsLoader
# Public names of the catalog package, re-exported from `.loader`.
__all__ = ["SkillRecord", "SkillsLoader"]

View File

@ -0,0 +1,281 @@
"""Beaver skills catalog loader。
第一版目标非常明确:
1. 扫描技能目录
2. 读取 `SKILL.md`
3. 解析前置元数据
4. 生成可注入上下文的正文与索引
这层不负责:
1. 动态选择本轮应该启用哪些 skill
2. skill review / publishing
3. skill 自动学习
这些决策属于 resolver 或更高层工作流。
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .utils import (
check_requirements,
escape_xml,
get_missing_requirements,
parse_frontmatter,
parse_skill_metadata_blob,
strip_frontmatter,
)
@dataclass(slots=True)
class SkillRecord:
"""单个 skill 的目录级元数据。"""
name: str
path: Path
source: str
class SkillsLoader:
"""从 workspace/builtin 目录中发现并读取 skills。"""
def __init__(
self,
workspace: str | Path,
*,
builtin_skills_dir: str | Path | None = None,
extra_dirs: list[str | Path] | None = None,
) -> None:
self.workspace = Path(workspace)
self.workspace_skills = self.workspace / "skills"
self.builtin_skills = Path(builtin_skills_dir) if builtin_skills_dir is not None else Path(__file__).resolve().parent.parent / "builtin"
self.extra_dirs = [Path(item) for item in (extra_dirs or [])]
def list_skills(self, *, filter_unavailable: bool = True) -> list[SkillRecord]:
"""列出当前可见的 skills。
优先级:
1. workspace
2. extra/plugin 目录
3. builtin
重名 skill 只保留优先级更高的那一个。
"""
ordered_roots: list[tuple[str, Path]] = [
("workspace", self.workspace_skills),
*[("plugin", path) for path in self.extra_dirs],
("builtin", self.builtin_skills),
]
found: dict[str, SkillRecord] = {}
for source, root in ordered_roots:
if not root.exists():
continue
for skill_dir in root.iterdir():
skill_file = skill_dir / "SKILL.md"
if not skill_dir.is_dir() or not skill_file.exists():
continue
name = skill_dir.name
if name in found:
continue
record = SkillRecord(name=name, path=skill_file, source=source)
if filter_unavailable and not self._record_available(record):
continue
found[name] = record
return list(found.values())
def load_skill(self, name: str) -> str | None:
"""按名称加载 skill 原始内容。"""
record = self._find_record(name)
if record is None:
return None
return record.path.read_text(encoding="utf-8")
def get_skill_record(self, name: str) -> SkillRecord | None:
"""按名称返回 skill record。"""
return self._find_record(name)
def get_skill_metadata(self, name: str) -> dict[str, Any] | None:
"""读取 skill frontmatter 元数据。"""
content = self.load_skill(name)
if content is None:
return None
metadata, _ = parse_frontmatter(content)
return metadata
def load_skills_for_context(self, skill_names: list[str]) -> str:
"""加载指定 skills 的正文,并整理成上下文块。"""
sections: list[str] = []
for name in skill_names:
content = self.load_skill(name)
if not content:
continue
body = strip_frontmatter(content).strip()
if not body:
continue
sections.append(f"## {name}\n\n{body}")
return "\n\n".join(sections)
def build_skills_summary(self) -> str:
"""构建可注入 system prompt 的 skills index。
虽然函数名还沿用 `summary`,但当前语义已经更接近 Hermes 的 skills index
- 这里只告诉模型“系统里有哪些 skill 可用”
- 不负责把 skill 正文塞进 system prompt
- 真正激活的 skill 正文由 resolver/builder 走显式消息注入
"""
skills = self.list_skills(filter_unavailable=False)
if not skills:
return ""
lines = ["<skills>"]
for record in skills:
frontmatter = self.get_skill_metadata(record.name) or {}
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
available = check_requirements(meta_blob)
description = frontmatter.get("description") or record.name
load_hint = f'Use skill_view(name="{record.name}") to load the full skill.'
lines.append(f' <skill available="{str(available).lower()}">')
lines.append(f" <name>{escape_xml(record.name)}</name>")
lines.append(f" <description>{escape_xml(description)}</description>")
lines.append(f" <load_hint>{escape_xml(load_hint)}</load_hint>")
support_files = self.list_skill_supporting_files(record.name)
if support_files:
lines.append(" <supporting_files>")
for file_path in support_files[:12]:
lines.append(f" <file>{escape_xml(file_path)}</file>")
if len(support_files) > 12:
lines.append(" <file>...additional files omitted...</file>")
lines.append(" </supporting_files>")
if not available:
missing = get_missing_requirements(meta_blob)
if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>")
lines.append(" </skill>")
lines.append("</skills>")
return "\n".join(lines)
def build_selection_candidates(self) -> list[dict[str, str]]:
"""构建给 LLM selector 使用的候选 skill 摘要。
这里刻意保持精简,只给:
- `name`
- `description`
选择器的任务只是“从候选里挑名字”,不是直接阅读完整 skill 正文。
真正激活后的 skill 正文仍然在后续阶段按需加载。
"""
candidates: list[dict[str, str]] = []
for record in self.list_skills(filter_unavailable=True):
frontmatter = self.get_skill_metadata(record.name) or {}
description = str(frontmatter.get("description") or "").strip()
if not description:
raw_content = self.load_skill(record.name) or ""
body = strip_frontmatter(raw_content).strip()
if body:
description = " ".join(body.splitlines()[:3])[:240].strip()
candidates.append(
{
"name": record.name,
"description": description or record.name,
}
)
return candidates
def list_skill_supporting_files(self, name: str) -> list[str]:
"""列出 skill 目录下可按需查看的支持文件相对路径。"""
record = self._find_record(name)
if record is None:
return []
skill_dir = record.path.parent
results: list[str] = []
for subdir in ("references", "templates", "scripts", "assets"):
root = skill_dir / subdir
if not root.exists():
continue
for file in sorted(root.rglob("*")):
if file.is_file() and not file.is_symlink():
results.append(str(file.relative_to(skill_dir)))
return results
def view_skill(self, name: str, file_path: str | None = None) -> tuple[str, str] | None:
"""读取 skill 正文或其支持文件。
返回 `(display_name, content)`
- `display_name` 用于提示当前读取的是 skill 本体还是某个支持文件
- `content` 为实际文本内容
"""
record = self._find_record(name)
if record is None:
return None
if not self._record_available(record):
frontmatter = self.get_skill_metadata(name) or {}
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
missing = get_missing_requirements(meta_blob)
detail = f" Missing requirements: {missing}." if missing else ""
raise ValueError(f"Skill '{name}' is currently unavailable.{detail}")
skill_dir = record.path.parent
if not file_path:
return ("SKILL.md", self._read_text_file(record.path, display_name="SKILL.md"))
candidate = (skill_dir / file_path).resolve()
try:
candidate.relative_to(skill_dir.resolve())
except ValueError as exc:
raise ValueError("Requested skill file must stay within the skill directory") from exc
if not candidate.exists() or not candidate.is_file():
raise FileNotFoundError(f"Skill file '{file_path}' does not exist")
display_name = str(candidate.relative_to(skill_dir))
return (display_name, self._read_text_file(candidate, display_name=display_name))
def get_always_skills(self) -> list[str]:
"""返回标记为 always 的可用 skill 名称。"""
result: list[str] = []
for record in self.list_skills(filter_unavailable=True):
frontmatter = self.get_skill_metadata(record.name) or {}
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
if meta_blob.get("always") or str(frontmatter.get("always", "")).lower() == "true":
result.append(record.name)
return result
def _find_record(self, name: str) -> SkillRecord | None:
for record in self.list_skills(filter_unavailable=False):
if record.name == name:
return record
return None
def _record_available(self, record: SkillRecord) -> bool:
content = record.path.read_text(encoding="utf-8")
frontmatter, _ = parse_frontmatter(content)
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
return check_requirements(meta_blob)
@staticmethod
def _read_text_file(path: Path, *, display_name: str) -> str:
try:
return path.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
raise ValueError(
f"Skill file '{display_name}' is not UTF-8 text and cannot be viewed with skill_view."
) from exc
def _skill_available(self, name: str) -> bool:
record = self._find_record(name)
if record is None:
return False
return self._record_available(record)

View File

@ -0,0 +1,122 @@
"""Skills catalog 的公共辅助函数。
这里专门放“解析和校验 skill 文件”的纯函数,避免 `loader.py` 里同时承担:
1. 目录扫描
2. frontmatter 解析
3. requirements 校验
4. 文本裁剪/格式化
把这些细节拆出来之后skills catalog 的边界会更清楚,后面无论是 reviews、publisher
还是 runtime resolver都可以复用同一套元数据解析规则。
"""
from __future__ import annotations
import json
import os
import re
import shutil
from typing import Any
def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
    """Parse the minimal frontmatter at the top of a Markdown file.

    Only the most common flat form is supported (no YAML parser is used):

    ```md
    ---
    key: value
    key2: value2
    ---
    body...
    ```

    Both LF and CRLF line endings are accepted, and a leading UTF-8 byte
    order mark is ignored.

    Args:
        content: Full text of the Markdown file.

    Returns:
        ``(metadata, body)``. Without a recognisable frontmatter block the
        metadata is empty and ``content`` is returned unchanged; otherwise the
        body is the stripped text following the closing ``---``.
    """
    # Files saved by some editors start with a BOM that would hide the fence.
    text = content[1:] if content.startswith("\ufeff") else content
    if not text.startswith("---"):
        return {}, content

    match = re.match(r"^---[ \t]*\r?\n(.*?)\r?\n---[ \t]*(?:\r?\n)?", text, re.DOTALL)
    if match is None:
        return {}, content

    metadata: dict[str, str] = {}
    for line in match.group(1).splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        # Values may be wrapped in single or double quotes.
        metadata[key.strip()] = value.strip().strip('"\'')
    body = text[match.end():].strip()
    return metadata, body
def strip_frontmatter(content: str) -> str:
    """Return only the skill body, with any frontmatter block removed."""
    return parse_frontmatter(content)[1]
def parse_skill_metadata_blob(raw: str) -> dict[str, Any]:
    """Decode the JSON extension config stored in the ``metadata`` field.

    For compatibility with older layouts the config may be nested under a
    ``nanobot`` or ``openclaw`` key (checked in that order); otherwise the
    top-level object itself is used. Fields of interest to the first version
    are ``always`` and ``requires``.

    Returns:
        The selected JSON object, or ``{}`` when ``raw`` is not valid JSON,
        is not an object, or the selected section is not an object.
    """
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        decoded = None
    if not isinstance(decoded, dict):
        return {}

    if "nanobot" in decoded:
        section = decoded["nanobot"]
    elif "openclaw" in decoded:
        section = decoded["openclaw"]
    else:
        section = decoded

    if isinstance(section, dict):
        return section
    return {}
def _requirement_names(value: Any) -> list[str]:
    """Normalise a ``bins``/``env`` entry to a list of names.

    ``None`` means "nothing required"; a bare string is a single name (not a
    sequence of characters); any other iterable is taken item by item.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    try:
        return [str(item) for item in value]
    except TypeError:
        return [str(value)]


def _missing_requirements(metadata: dict[str, Any]) -> list[str]:
    """List the unmet requirements as ``CLI: <bin>`` / ``ENV: <name>`` labels."""
    requires = metadata.get("requires", {})
    if not isinstance(requires, dict):
        return []

    missing: list[str] = []
    for binary in _requirement_names(requires.get("bins")):
        if not shutil.which(binary):
            missing.append(f"CLI: {binary}")
    for env_name in _requirement_names(requires.get("env")):
        if not os.environ.get(env_name):
            missing.append(f"ENV: {env_name}")
    return missing


def check_requirements(metadata: dict[str, Any]) -> bool:
    """Return whether the skill's minimal requirements are satisfied.

    ``metadata["requires"]`` may list executables under ``bins`` (looked up
    on ``PATH``) and environment variables under ``env`` (must be set and
    non-empty). A missing or non-dict ``requires`` counts as satisfied.
    """
    return not _missing_requirements(metadata)


def get_missing_requirements(metadata: dict[str, Any]) -> str:
    """Return a short comma-separated description of the unmet requirements.

    The result is an empty string when nothing is missing.
    """
    return ", ".join(_missing_requirements(metadata))
def escape_xml(value: str) -> str:
    """Apply the minimal XML escaping needed by the skills summary.

    ``&``, ``<`` and ``>`` are replaced by their entities in a single pass.
    """
    entities = str.maketrans({"&": "&amp;", "<": "&lt;", ">": "&gt;"})
    return value.translate(entities)

View File

@ -0,0 +1,2 @@
"""Draft skills generated before review."""

View File

@ -0,0 +1,2 @@
"""Skill publishing and version switching."""

View File

@ -0,0 +1,5 @@
"""Runtime skill resolution."""
from .runtime import ResolvedSkillSet, RuntimeSkillResolver
# Public names of the resolver package, re-exported from `.runtime`.
__all__ = ["ResolvedSkillSet", "RuntimeSkillResolver"]

View File

@ -0,0 +1,50 @@
"""Runtime skill resolver。
这层负责回答一个运行时问题:
“这一次调用,哪些 skill 要被激活,并以什么形式注入上下文?”
第一版保持保守,目前只使用一类来源:
1. `always` skills
不在这里做复杂的语义匹配或自动推荐。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from beaver.engine.context import SkillContext
from beaver.skills.catalog.loader import SkillsLoader
from beaver.skills.catalog.utils import strip_frontmatter
@dataclass(slots=True)
class ResolvedSkillSet:
    """Final set of skills resolved for a single run."""

    # Skills to activate for this run; empty when none were resolved.
    activated_skills: list[SkillContext] = field(default_factory=list)
class RuntimeSkillResolver:
    """Resolve which skills are actually activated for the current run.

    At present the only source is the loader's ``always`` skills.
    """

    def __init__(self, loader: SkillsLoader) -> None:
        """Keep the loader used to look up and read skills."""
        self.loader = loader

    def resolve(
        self,
    ) -> ResolvedSkillSet:
        """Load every ``always`` skill that has a non-empty body.

        Names are de-duplicated while keeping the loader's order; skills that
        cannot be loaded or whose body is empty after removing the frontmatter
        are left out.
        """
        ordered_names = list(dict.fromkeys(self.loader.get_always_skills()))

        resolved = ResolvedSkillSet()
        for skill_name in ordered_names:
            raw = self.loader.load_skill(skill_name)
            if not raw:
                continue
            body = strip_frontmatter(raw).strip()
            if body:
                resolved.activated_skills.append(SkillContext(name=skill_name, content=body))
        return resolved

View File

@ -0,0 +1,2 @@
"""Skill review workflow."""

View File

@ -0,0 +1,2 @@
"""Built-in Beaver templates."""

View File

@ -0,0 +1,15 @@
"""Tool system for Beaver."""
from .base import BaseTool, ObjectBackedTool, ToolContext, ToolResult, ToolSpec
from .registry import ToolRegistry
from .runtime import ToolExecutor
# Public names of the tools package, re-exported from the submodules imported above.
__all__ = [
    "BaseTool",
    "ObjectBackedTool",
    "ToolContext",
    "ToolExecutor",
    "ToolRegistry",
    "ToolResult",
    "ToolSpec",
]

View File

@ -0,0 +1,175 @@
"""Beaver 工具系统的统一契约。
这一层的目标不是实现具体工具,而是把 runtime 真正依赖的最小接口定死。
我们需要统一回答 4 个问题:
1. 一个工具长什么样
2. tool schema 怎么导出给 provider
3. 工具执行结果长什么样
4. tool loop 执行时,可以把哪些运行时依赖传给工具
这层故意保持很薄:
- 不绑定 MCP
- 不绑定 memory/session
- 不绑定具体 provider
这样内建工具、MCP 工具、未来插件工具都可以收敛到同一套契约上。
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import json
from typing import Any
@dataclass(slots=True)
class ToolSpec:
"""单个工具对外暴露的描述信息。
这份信息主要服务两个场景:
1. 导出给 provider 的 function schema
2. 在 registry 中做列出、查找、调试
"""
name: str
description: str
input_schema: dict[str, Any]
def to_provider_schema(self) -> dict[str, Any]:
"""导出为 OpenAI-compatible function tool schema。"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.input_schema,
},
}
@dataclass(slots=True)
class ToolContext:
"""一次工具执行时可用的运行时上下文。
这不是“所有系统对象的大杂烩”,而是当前工具执行阶段最常用的公共入口。
后面主链接进来时,可以把 session manager / memory store / workspace 等从这里传入。
"""
workspace: str | None = None
session_id: str | None = None
user_id: str | None = None
services: dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
def get(self, key: str, default: Any = None) -> Any:
"""优先从 services 中取依赖,方便工具侧少写样板代码。"""
return self.services.get(key, default)
@dataclass(slots=True)
class ToolResult:
    """Normalized result of a tool execution.

    A single return structure means that:
    1. the tool loop can record logs and failure details more easily;
    2. the string content fed back to the provider is always available;
    3. the data structure is already fixed when tool auditing is added later.
    """

    success: bool  # whether the call succeeded
    content: str  # textual result of the call
    tool_name: str  # name of the tool that produced this result
    error: str | None = None  # error description, if any
    raw_output: Any | None = None  # un-normalized backend output, if kept
class BaseTool(ABC):
    """Abstract base class that every tool implementation must follow."""

    @property
    @abstractmethod
    def spec(self) -> ToolSpec:
        """Return the tool's metadata."""

    @abstractmethod
    async def invoke(self, arguments: dict[str, Any], context: ToolContext) -> ToolResult:
        """Execute one tool call with the given arguments and runtime context."""
class ObjectBackedTool(BaseTool):
    """Adapter exposing an existing lightweight tool object as a ``BaseTool``.

    The backend keeps its own business logic; it only needs a ``name``,
    optionally ``description`` and ``parameters``, and an async
    ``execute(**kwargs)`` method. This class just aligns the interface.
    """

    def __init__(self, backend: Any) -> None:
        """Wrap ``backend`` and derive the tool spec from its attributes."""
        self.backend = backend
        fallback_schema = {"type": "object", "properties": {}}
        self._spec = ToolSpec(
            name=str(backend.name),
            description=str(getattr(backend, "description", "")),
            input_schema=dict(getattr(backend, "parameters", fallback_schema)),
        )

    @property
    def spec(self) -> ToolSpec:
        """Return the spec built from the backend at construction time."""
        return self._spec

    async def invoke(self, arguments: dict[str, Any], context: ToolContext) -> ToolResult:
        """Call the backend and wrap its output in a ``ToolResult``.

        Any exception raised while preparing, executing or normalizing the
        call is converted into a failed result instead of propagating.
        """
        try:
            # Work on a copy so the caller's argument dict is never mutated.
            kwargs = dict(arguments)
            self._inject_runtime_context(kwargs, context)
            raw = await self.backend.execute(**kwargs)
            normalized = self._normalize_output(raw)
            return ToolResult(
                success=normalized["success"],
                content=normalized["content"],
                tool_name=self.spec.name,
                error=normalized.get("error"),
                raw_output=raw,
            )
        except Exception as exc:
            failure_text = f"Tool {self.spec.name} failed: {exc}"
            return ToolResult(
                success=False,
                content=failure_text,
                tool_name=self.spec.name,
                error=str(exc),
            )

    def _inject_runtime_context(self, arguments: dict[str, Any], context: ToolContext) -> None:
        """Inject a small amount of runtime context into the backend arguments.

        Only ``current_session_id`` is injected, and only when the caller did
        not supply it and the backend exposes an attribute of that name; the
        ``ToolContext`` object itself is never passed to the backend.
        """
        if "current_session_id" in arguments:
            return
        if hasattr(self.backend, "current_session_id"):
            arguments["current_session_id"] = context.session_id

    @staticmethod
    def _normalize_output(content: Any) -> dict[str, Any]:
        """Map a backend return value to ``success``/``content``/``error``.

        A JSON string that decodes to an object with a ``success`` field is
        respected (its ``error`` field is passed along); every other value is
        treated as plain successful text, non-strings being converted with
        ``str``.
        """
        if not isinstance(content, str):
            return {"success": True, "content": str(content)}

        try:
            decoded = json.loads(content)
        except json.JSONDecodeError:
            decoded = None

        if isinstance(decoded, dict) and "success" in decoded:
            return {
                "success": bool(decoded.get("success")),
                "content": content,
                "error": decoded.get("error"),
            }
        return {"success": True, "content": content}

View File

@ -0,0 +1,17 @@
"""Built-in Beaver tools."""
from .echo import EchoTool, echo_tool
from .memory import MemoryTool, memory_tool
from .skill_view import SkillViewTool, skill_view
from .session_search import SessionSearchTool, session_search
# Public names of the builtin tools package: tool classes and their function forms.
__all__ = [
    "EchoTool",
    "MemoryTool",
    "SkillViewTool",
    "SessionSearchTool",
    "echo_tool",
    "memory_tool",
    "skill_view",
    "session_search",
]

View File

@ -0,0 +1,43 @@
"""最小调试工具:把输入原样回显。
它的价值不是业务能力,而是运行时验证:
当你只想确认 tool loop 是否能走通时,`echo` 是最便宜、最确定的测试工具。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
# Human-readable description of the echo tool.
ECHO_TOOL_DESCRIPTION = "Echo the provided text back to the agent. Useful for verifying tool calling."

# JSON-schema style parameter definition: one required string field, `text`.
ECHO_TOOL_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "text": {
            "type": "string",
            "description": "The text to echo back.",
        }
    },
    "required": ["text"],
}
def echo_tool(*, text: str) -> str:
    """Return ``text`` unchanged."""
    echoed = text
    return echoed
@dataclass(slots=True)
class EchoTool:
    """Minimal built-in tool for the runtime: echoes its ``text`` argument."""

    name: str = "echo"
    description: str = ECHO_TOOL_DESCRIPTION
    parameters: dict[str, Any] = field(default_factory=lambda: dict(ECHO_TOOL_PARAMETERS))

    async def execute(self, **kwargs: Any) -> str:
        """Echo the ``text`` keyword argument.

        Raises:
            ValueError: If ``text`` is missing or not a string.
        """
        payload = kwargs.get("text")
        if isinstance(payload, str):
            return echo_tool(text=payload)
        raise ValueError("echo tool requires a string field 'text'")

Some files were not shown because too many files have changed in this diff Show More