修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
@ -1,13 +0,0 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info
|
||||
dist/
|
||||
build/
|
||||
.git
|
||||
.env
|
||||
.assets
|
||||
node_modules/
|
||||
bridge/dist/
|
||||
workspace/
|
||||
201
app-instance/backend/.gitignore
vendored
201
app-instance/backend/.gitignore
vendored
@ -1,201 +0,0 @@
|
||||
<<<<<<< HEAD
|
||||
.assets
|
||||
.env
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
docs/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.pyw
|
||||
*.pyz
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
tests/
|
||||
=======
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
>>>>>>> origin/main
|
||||
@ -1,753 +0,0 @@
|
||||
# A2A Multi-Agent 改造方案
|
||||
|
||||
## 1. 需求目标
|
||||
|
||||
当前 `spawn`/`sub-agent` 只有一种执行方式: 创建一个本地后台 subagent 去完成任务。
|
||||
|
||||
这次需求要改成:
|
||||
|
||||
1. 调用 `sub-agent` 时,不一定新建本地 subagent。
|
||||
2. 先从“已添加的 Agent”里找可用目标。
|
||||
3. 再从 skills 中声明的 `agent cards` 里找可用目标。
|
||||
4. 通过 A2A 协议把任务发给对应 agent。
|
||||
5. 支持一个任务发给多个 agent,形成 `agent group`,最后回到主 agent 汇总。
|
||||
6. 保持现有 `spawn(task, label)` 兼容,不破坏已有行为。
|
||||
|
||||
结论先说:
|
||||
|
||||
- 最合适的做法不是继续把能力堆进 `SubagentManager`。
|
||||
- 应该把“本地 subagent 执行”升级为“统一委派层”。
|
||||
- `spawn` 工具继续保留,但语义从“创建 subagent”扩展为“委派给合适的 agent / agent group”。
|
||||
|
||||
## 2. 当前代码现状
|
||||
|
||||
### 2.1 当前触发链路
|
||||
|
||||
现有链路很单一:
|
||||
|
||||
1. `AgentLoop` 初始化 `SubagentManager`
|
||||
- 位置: `nanobot/agent/loop.py:88-114`
|
||||
2. `AgentLoop._register_default_tools()` 注册 `SpawnTool`
|
||||
- 位置: `nanobot/agent/loop.py:116-138`
|
||||
3. LLM 调用 `spawn(task, label)`
|
||||
4. `SpawnTool.execute()` 直接转发给 `SubagentManager.spawn()`
|
||||
- 位置: `nanobot/agent/tools/spawn.py:67-76`
|
||||
5. `SubagentManager.spawn()` 创建本地 asyncio 后台任务
|
||||
- 位置: `nanobot/agent/subagent.py:64-93`
|
||||
6. `_run_subagent()` 用一个受限工具集运行本地子代理
|
||||
- 位置: `nanobot/agent/subagent.py:95-195`
|
||||
7. `_announce_result()` 把结果包装成 `channel="system"` 的消息回投主消息总线
|
||||
- 位置: `nanobot/agent/subagent.py:197-230`
|
||||
8. `AgentLoop._process_message()` 接到 `system` 消息,再整理成用户可见回复
|
||||
- 位置: `nanobot/agent/loop.py:331-347`
|
||||
|
||||
### 2.2 当前已经有但没接入调度链路的能力
|
||||
|
||||
仓库里已经有两类“候选 agent 信息”,但没有进入实际调度:
|
||||
|
||||
1. Plugin agents
|
||||
- `PluginLoader.find_agent()` 已能找 agent
|
||||
- 位置: `nanobot/agent/plugins.py:83-91`
|
||||
- `build_agents_summary()` 也已能汇总 agent 信息
|
||||
- 位置: `nanobot/agent/plugins.py:100-121`
|
||||
- 但当前 `AgentLoop` / `ContextBuilder` 并没有用它做调度
|
||||
|
||||
2. Skills
|
||||
- `SkillsLoader` 已能枚举 / 读取 skill
|
||||
- 位置: `nanobot/agent/skills.py:32-249`
|
||||
- 但 skill 目前只被当作 prompt 资源,不会暴露成“可路由 agent”
|
||||
|
||||
### 2.3 当前缺口
|
||||
|
||||
当前缺少这几层:
|
||||
|
||||
1. 统一的 `Agent Registry`
|
||||
2. A2A `agent card` 发现与缓存
|
||||
3. A2A client 调用层
|
||||
4. 统一的委派器,负责在“本地 subagent / plugin agent / skill agent card / agent group”之间做路由
|
||||
5. group 级别的状态管理和结果聚合
|
||||
|
||||
## 3. 推荐总方案
|
||||
|
||||
推荐采用“保留 `spawn` 工具名,重构内部执行层”的方案。
|
||||
|
||||
### 3.1 核心思路
|
||||
|
||||
把当前:
|
||||
|
||||
- `SpawnTool -> SubagentManager -> 本地 subagent`
|
||||
|
||||
改成:
|
||||
|
||||
- `SpawnTool -> DelegationManager -> AgentResolver -> Executor(local/plugin/a2a/group)`
|
||||
|
||||
也就是:
|
||||
|
||||
1. `spawn` 不再等价于“必须创建 subagent”。
|
||||
2. `spawn` 变成“委派任务”。
|
||||
3. 真正执行方式由委派层动态决定。
|
||||
|
||||
### 3.2 为什么这样最合适
|
||||
|
||||
如果直接继续扩 `SubagentManager`,很快会出现这些问题:
|
||||
|
||||
1. 一个类同时负责本地 LLM 运行、A2A 网络调用、agent card 发现、group 并发、结果聚合。
|
||||
2. 后续要支持 plugin agent、本地 named agent、A2A streaming 时会越来越乱。
|
||||
3. 当前 `SubagentManager` 的职责本来就已经比较明确: “本地后台 subagent 执行器”。
|
||||
|
||||
所以更合理的拆法是:
|
||||
|
||||
1. `SubagentManager` 保留或下沉为 `LocalSubagentExecutor`
|
||||
2. 新增 `DelegationManager` 作为统一入口
|
||||
3. 新增 `AgentRegistry` / `AgentResolver`
|
||||
4. 新增 `A2AClient`
|
||||
|
||||
## 4. 推荐模块拆分
|
||||
|
||||
### 4.1 新增 `DelegationManager`
|
||||
|
||||
建议新文件:
|
||||
|
||||
- `nanobot/agent/delegation.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 接收 `spawn` 请求
|
||||
2. 根据参数和任务内容选择目标 agent
|
||||
3. 决定执行方式
|
||||
4. 对 group 做并发调度
|
||||
5. 统一把结果回投主消息总线
|
||||
|
||||
建议接口:
|
||||
|
||||
```python
|
||||
class DelegationManager:
|
||||
async def dispatch(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
target: str | None = None,
|
||||
targets: list[str] | None = None,
|
||||
strategy: str = "auto",
|
||||
origin_channel: str = "cli",
|
||||
origin_chat_id: str = "direct",
|
||||
) -> str: ...
|
||||
```
|
||||
|
||||
### 4.2 保留本地执行器
|
||||
|
||||
当前 `nanobot/agent/subagent.py` 的 `_run_subagent()` 逻辑可以保留,但角色改为:
|
||||
|
||||
- `LocalSubagentExecutor`
|
||||
|
||||
也可以第一版不重命名文件,只把里面逻辑拆成:
|
||||
|
||||
1. `spawn_local()`
|
||||
2. `_run_local_subagent()`
|
||||
3. `_announce_local_result()`
|
||||
|
||||
这样可以最小改动落地。
|
||||
|
||||
### 4.3 新增 `AgentRegistry`
|
||||
|
||||
建议新文件:
|
||||
|
||||
- `nanobot/agent/agent_registry.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 汇总所有可调度 agent
|
||||
2. 统一输出规范化 descriptor
|
||||
3. 维护优先级和去重逻辑
|
||||
|
||||
统一后的 agent 来源:
|
||||
|
||||
1. workspace 中“已添加的 agent”
|
||||
2. plugin agents
|
||||
3. skill frontmatter 里声明的 `agent_cards`
|
||||
4. 必要时 fallback 到本地 `local-subagent`
|
||||
|
||||
建议统一 descriptor:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentDescriptor:
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
source: str # workspace | plugin | skill | builtin
|
||||
kind: str # local_prompt | a2a_remote | local_fallback
|
||||
protocol: str | None # a2a | None
|
||||
plugin_name: str | None = None
|
||||
skill_name: str | None = None
|
||||
model: str | None = None
|
||||
endpoint: str | None = None
|
||||
card_url: str | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
```
|
||||
|
||||
### 4.4 新增 A2A client 层
|
||||
|
||||
建议新目录:
|
||||
|
||||
- `nanobot/a2a/client.py`
|
||||
- `nanobot/a2a/cards.py`
|
||||
- `nanobot/a2a/models.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 获取 agent card
|
||||
2. 解析 card 能力
|
||||
3. 对远端 agent 发 JSON-RPC 请求
|
||||
4. 处理同步返回 / task 轮询 / streaming 兼容
|
||||
|
||||
## 5. 代码插入点
|
||||
|
||||
## 5.1 `nanobot/agent/loop.py`
|
||||
|
||||
### 插入点 A: `__init__`
|
||||
|
||||
当前:
|
||||
|
||||
- `self.subagents = SubagentManager(...)`
|
||||
- 位置: `nanobot/agent/loop.py:88-102`
|
||||
|
||||
建议改成:
|
||||
|
||||
1. 初始化 `PluginLoader`
|
||||
2. 初始化 `AgentRegistry`
|
||||
3. 初始化 `DelegationManager`
|
||||
4. `DelegationManager` 内部持有 `LocalSubagentExecutor` / `A2AExecutor`
|
||||
|
||||
推荐形态:
|
||||
|
||||
```python
|
||||
self.plugins = PluginLoader(workspace)
|
||||
self.agent_registry = AgentRegistry(workspace, plugins=self.plugins, ...)
|
||||
self.delegation = DelegationManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
registry=self.agent_registry,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### 插入点 B: `_register_default_tools`
|
||||
|
||||
当前:
|
||||
|
||||
- 注册 `SpawnTool(manager=self.subagents)`
|
||||
- 位置: `nanobot/agent/loop.py:130-134`
|
||||
|
||||
建议改成:
|
||||
|
||||
```python
|
||||
self.tools.register(SpawnTool(manager=self.delegation))
|
||||
```
|
||||
|
||||
### 插入点 C: `_set_tool_context`
|
||||
|
||||
当前会给 `spawn` 工具写 origin context:
|
||||
|
||||
- 位置: `nanobot/agent/loop.py:165-192`
|
||||
|
||||
这里逻辑可以继续保留,不需要大改,因为 A2A / group 结果最终也要回到原会话。
|
||||
|
||||
## 5.2 `nanobot/agent/tools/spawn.py`
|
||||
|
||||
当前 `SpawnTool` 参数只有:
|
||||
|
||||
- `task`
|
||||
- `label`
|
||||
|
||||
位置:
|
||||
|
||||
- schema: `nanobot/agent/tools/spawn.py:49-65`
|
||||
- execute: `nanobot/agent/tools/spawn.py:67-76`
|
||||
|
||||
建议扩成:
|
||||
|
||||
```python
|
||||
{
|
||||
"task": "string",
|
||||
"label": "string?",
|
||||
"target": "string?",
|
||||
"targets": "string[]?",
|
||||
"strategy": "auto|local|plugin|a2a|group"
|
||||
}
|
||||
```
|
||||
|
||||
兼容规则:
|
||||
|
||||
1. 老调用只传 `task/label` 时,等价于 `strategy="auto"`
|
||||
2. `target` 表示单目标
|
||||
3. `targets` 表示 group
|
||||
4. `strategy="local"` 强制走本地 subagent
|
||||
5. `strategy="a2a"` 强制只找 A2A 目标
|
||||
|
||||
## 5.3 `nanobot/agent/context.py`
|
||||
|
||||
当前 prompt 中只注入:
|
||||
|
||||
1. bootstrap
|
||||
2. memory
|
||||
3. skills summary
|
||||
|
||||
位置:
|
||||
|
||||
- `build_system_prompt()`: `nanobot/agent/context.py:38-76`
|
||||
|
||||
建议新增一段:
|
||||
|
||||
- `# Available Agents`
|
||||
|
||||
由 `AgentRegistry.build_agents_summary()` 生成,内容只放:
|
||||
|
||||
1. agent id / name
|
||||
2. 简短 description
|
||||
3. source
|
||||
4. protocol
|
||||
5. 是否支持 group / streaming
|
||||
|
||||
目标是让主 agent 知道:
|
||||
|
||||
1. 当前有哪些现成 agent 可用
|
||||
2. 什么时候应该 `spawn(target=...)`
|
||||
3. 哪些是 skill 暴露出来的 A2A agent
|
||||
|
||||
## 5.4 `nanobot/agent/skills.py`
|
||||
|
||||
这是 skill agent cards 的关键入口。
|
||||
|
||||
当前 skill frontmatter 已支持 `metadata` 字段,并会解析其中的 JSON:
|
||||
|
||||
- `_parse_nanobot_metadata()`: `nanobot/agent/skills.py:190-196`
|
||||
- `_get_skill_meta()`: `nanobot/agent/skills.py:209-212`
|
||||
|
||||
最推荐的做法不是去扫 `SKILL.md` 正文里的自由文本,而是约定 skill frontmatter 的 `metadata.nanobot.agent_cards`。
|
||||
|
||||
建议新增:
|
||||
|
||||
```python
|
||||
def list_skill_agent_cards(self) -> list[dict[str, Any]]: ...
|
||||
```
|
||||
|
||||
推荐 skill 写法:
|
||||
|
||||
```md
|
||||
---
|
||||
name: github-research
|
||||
description: GitHub research helper
|
||||
metadata: '{"nanobot":{"agent_cards":[{"id":"repo-analyst","url":"https://example.com/.well-known/agent-card","tags":["github","research"],"auth_env":"REPO_AGENT_TOKEN"}]}}'
|
||||
---
|
||||
```
|
||||
|
||||
为什么推荐这样做:
|
||||
|
||||
1. 当前 frontmatter 解析已经存在
|
||||
2. 不需要引入新的 skill 文件格式
|
||||
3. 不需要解析自由文本
|
||||
4. skill 打包/上传链路也不需要大改
|
||||
|
||||
## 5.5 `nanobot/agent/plugins.py`
|
||||
|
||||
当前 plugin agents 已能加载:
|
||||
|
||||
- `find_agent()`: `nanobot/agent/plugins.py:83-91`
|
||||
- `_load_agents()`: `nanobot/agent/plugins.py:210-229`
|
||||
|
||||
建议:
|
||||
|
||||
1. `AgentRegistry` 直接复用 `PluginLoader`
|
||||
2. plugin agent 作为“本地可执行 agent”来源之一
|
||||
|
||||
这里不建议把 plugin agent 强行转成 A2A。
|
||||
|
||||
更合理的处理是:
|
||||
|
||||
1. plugin agent 本地执行
|
||||
2. skill agent cards 远程 A2A 调用
|
||||
3. workspace 手动添加的 agent 也可走 A2A
|
||||
|
||||
## 5.6 `nanobot/config/schema.py`
|
||||
|
||||
当前 `ToolsConfig` 只有:
|
||||
|
||||
- `web`
|
||||
- `exec`
|
||||
- `restrict_to_workspace`
|
||||
- `mcp_servers`
|
||||
|
||||
位置:
|
||||
|
||||
- `nanobot/config/schema.py:337-347`
|
||||
|
||||
建议新增:
|
||||
|
||||
```python
|
||||
class A2AConfig(Base):
|
||||
enabled: bool = True
|
||||
timeout_seconds: int = 30
|
||||
poll_interval_seconds: int = 2
|
||||
card_cache_ttl_seconds: int = 300
|
||||
max_parallel_agents: int = 4
|
||||
allow_skill_cards: bool = True
|
||||
allow_workspace_agents: bool = True
|
||||
allowed_hosts: list[str] = Field(default_factory=list)
|
||||
```
|
||||
|
||||
然后挂到:
|
||||
|
||||
```python
|
||||
class ToolsConfig(Base):
|
||||
...
|
||||
a2a: A2AConfig = Field(default_factory=A2AConfig)
|
||||
```
|
||||
|
||||
## 5.7 `nanobot/web/server.py`
|
||||
|
||||
当前 web API 有:
|
||||
|
||||
- `/api/skills`
|
||||
- `/api/plugins`
|
||||
|
||||
位置:
|
||||
|
||||
- skills: `nanobot/web/server.py:702-843`
|
||||
- plugins: `nanobot/web/server.py:1000-1037`
|
||||
|
||||
建议新增:
|
||||
|
||||
1. `GET /api/agents`
|
||||
- 返回统一后的 agent registry
|
||||
2. `POST /api/agents`
|
||||
- 添加 workspace agent card
|
||||
3. `DELETE /api/agents/{id}`
|
||||
- 删除 workspace agent
|
||||
4. `POST /api/agents/refresh`
|
||||
- 刷新 card cache
|
||||
|
||||
这样“已添加的 Agent”才有明确的持久化来源。
|
||||
|
||||
## 6. 推荐的数据来源优先级
|
||||
|
||||
为了行为稳定,推荐 resolver 按以下优先级匹配:
|
||||
|
||||
1. workspace 手动添加的 agent
|
||||
2. plugin agents
|
||||
3. skill metadata 里的 agent cards
|
||||
4. fallback 到本地 subagent
|
||||
|
||||
原因:
|
||||
|
||||
1. workspace 手动添加通常是用户明确希望接入的 agent
|
||||
2. plugin agent 是本地稳定能力
|
||||
3. skill card 往往是外部资源,可信度和可用性最弱
|
||||
4. 本地 subagent 最后兜底,保证老行为不失效
|
||||
|
||||
## 7. A2A 协议接入建议
|
||||
|
||||
## 7.1 Agent Card 发现
|
||||
|
||||
建议支持 3 种入口:
|
||||
|
||||
1. 显式 `card_url`
|
||||
2. `base_url + /.well-known/agent-card`
|
||||
3. fallback `base_url + /.well-known/agent.json`
|
||||
|
||||
这样做的原因是:
|
||||
|
||||
1. 当前 A2A 文档和样例在 card 路径上存在新旧写法并存
|
||||
2. 兼容性会更好
|
||||
|
||||
## 7.2 RPC 调用兼容层
|
||||
|
||||
建议客户端优先尝试:
|
||||
|
||||
1. `tasks/send`
|
||||
2. 不支持时 fallback `message/send`
|
||||
|
||||
后续可选支持:
|
||||
|
||||
1. `tasks/sendSubscribe`
|
||||
2. `message/sendStream`
|
||||
3. `tasks/get`
|
||||
4. `tasks/cancel`
|
||||
|
||||
推荐第一期先做:
|
||||
|
||||
1. 非流式发任务
|
||||
2. 如果返回 `Task` 状态不是最终态,就轮询 `tasks/get`
|
||||
|
||||
这样能最小代价先打通。
|
||||
|
||||
## 7.3 发送给远端 agent 的上下文范围
|
||||
|
||||
不要把主会话完整 history 直接发给远端 agent。
|
||||
|
||||
建议第一版只发送:
|
||||
|
||||
1. 任务目标
|
||||
2. 必要的结构化说明
|
||||
3. 主 agent 整理好的最小上下文
|
||||
|
||||
原因:
|
||||
|
||||
1. 当前本地 subagent 也不共享主会话历史
|
||||
2. 外部 A2A agent 不可信时,最小化数据泄漏面
|
||||
3. 避免 token 膨胀
|
||||
|
||||
## 8. agent group 设计
|
||||
|
||||
## 8.1 什么时候触发 group
|
||||
|
||||
建议第一版只支持两种触发:
|
||||
|
||||
1. 用户明确指定多个 agent
|
||||
2. LLM 在工具调用里显式传 `targets=[...]`
|
||||
|
||||
不建议第一版做“自动拆成多个 agent 并行”的强自动化。
|
||||
|
||||
原因:
|
||||
|
||||
1. 容易失控
|
||||
2. 很难解释为什么调了这些 agent
|
||||
3. 对成本和网络调用不可控
|
||||
|
||||
## 8.2 group 执行链路
|
||||
|
||||
推荐链路:
|
||||
|
||||
1. `SpawnTool.execute()` 收到 `targets`
|
||||
2. `DelegationManager.dispatch()` 创建 `group_run_id`
|
||||
3. `AgentRegistry` 解析出每个 target 的 descriptor
|
||||
4. 按 executor 类型并发执行
|
||||
5. `asyncio.gather(..., return_exceptions=True)` 收集结果
|
||||
6. 统一做 group aggregation
|
||||
7. `_announce_group_result()` 回投主消息总线
|
||||
8. 主 agent 再生成最终用户回复
|
||||
|
||||
## 8.3 group 结果聚合
|
||||
|
||||
建议 group 执行器输出结构化结果:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentRunResult:
|
||||
agent_id: str
|
||||
status: str # ok | error | timeout | cancelled
|
||||
summary: str
|
||||
raw: dict[str, Any] | None = None
|
||||
```
|
||||
|
||||
group 最终回投内容建议类似:
|
||||
|
||||
```text
|
||||
[Agent group 'repo-check' completed]
|
||||
|
||||
Members:
|
||||
- researcher: ok
|
||||
- reviewer: ok
|
||||
- planner: error
|
||||
|
||||
Results:
|
||||
...
|
||||
|
||||
Summarize this naturally for the user. Mention disagreements if any.
|
||||
```
|
||||
|
||||
这样能继续复用当前 `system -> main agent -> user` 的输出模式。
|
||||
|
||||
## 9. 推荐触发方式
|
||||
|
||||
## 9.1 用户显式触发
|
||||
|
||||
用户说法示例:
|
||||
|
||||
1. “把这个任务交给 `github-reviewer`”
|
||||
2. “让 `researcher` 和 `reviewer` 一起处理”
|
||||
3. “如果有现成 agent 就不要新建 subagent”
|
||||
|
||||
这时主 agent 应调用:
|
||||
|
||||
```json
|
||||
{
|
||||
"task": "...",
|
||||
"target": "github-reviewer"
|
||||
}
|
||||
```
|
||||
|
||||
或者:
|
||||
|
||||
```json
|
||||
{
|
||||
"task": "...",
|
||||
"targets": ["researcher", "reviewer"],
|
||||
"strategy": "group"
|
||||
}
|
||||
```
|
||||
|
||||
## 9.2 模型自主触发
|
||||
|
||||
当主 agent 判断:
|
||||
|
||||
1. 任务独立可并行
|
||||
2. 已有 agent 专长明显更匹配
|
||||
3. 任务耗时长,适合后台执行
|
||||
|
||||
则调用 `spawn`,但不再默认认为一定是“新建本地 subagent”。
|
||||
|
||||
## 9.3 自动回退
|
||||
|
||||
如果没有找到匹配 agent:
|
||||
|
||||
1. `strategy=auto` -> fallback 本地 subagent
|
||||
2. `strategy=a2a` -> 直接返回未找到
|
||||
3. `strategy=group` 且部分目标不存在 -> 明确报错或只跑已解析目标,建议第一版严格报错
|
||||
|
||||
## 10. workspace 中“已添加 agent”的建议存储
|
||||
|
||||
建议新增:
|
||||
|
||||
- `workspace/agents/registry.json`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "github-reviewer",
|
||||
"name": "GitHub Reviewer",
|
||||
"description": "Review GitHub repository changes",
|
||||
"protocol": "a2a",
|
||||
"base_url": "https://reviewer.example.com/a2a",
|
||||
"card_url": "https://reviewer.example.com/.well-known/agent-card",
|
||||
"auth_env": "GITHUB_REVIEWER_TOKEN",
|
||||
"enabled": true,
|
||||
"tags": ["github", "review"]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
为什么不用直接塞进 `config.json`:
|
||||
|
||||
1. 这是 workspace 维度资源,不是全局运行参数
|
||||
2. web API 做增删改查更方便
|
||||
3. 不要求用户每次改 agent 都改配置再重启
|
||||
|
||||
## 11. 推荐实施顺序
|
||||
|
||||
### Phase 1: 打通单 agent 路由
|
||||
|
||||
目标:
|
||||
|
||||
1. 引入 `AgentRegistry`
|
||||
2. `spawn` 支持 `target`
|
||||
3. 支持 workspace agent 和 skill agent card
|
||||
4. 支持 A2A 单点调用
|
||||
5. 找不到时 fallback 本地 subagent
|
||||
|
||||
### Phase 2: 接入 plugin agent 本地执行
|
||||
|
||||
目标:
|
||||
|
||||
1. plugin agent 进入统一 registry
|
||||
2. plugin agent 可作为 `target`
|
||||
3. 本地 prompt-based agent 与 A2A remote agent 共存
|
||||
|
||||
### Phase 3: group 并发和聚合
|
||||
|
||||
目标:
|
||||
|
||||
1. `targets=[...]`
|
||||
2. 并发执行
|
||||
3. group 级状态跟踪
|
||||
4. 聚合后回投主 agent
|
||||
|
||||
### Phase 4: web 管理接口
|
||||
|
||||
目标:
|
||||
|
||||
1. `/api/agents`
|
||||
2. 添加 / 删除 / 刷新 agent
|
||||
3. 前端展示 unified registry
|
||||
|
||||
## 12. 兼容性要求
|
||||
|
||||
这次改造一定要保留以下兼容性:
|
||||
|
||||
1. 旧的 `spawn(task, label)` 调用仍然可用
|
||||
2. 没有 A2A agent 时,行为和现在一致
|
||||
3. skill 没写 `agent_cards` 时,skill 仍只是普通 skill
|
||||
4. plugin agent 不参与调度时,现有 plugin 机制不受影响
|
||||
|
||||
## 13. 风险点
|
||||
|
||||
### 13.1 A2A 规范新旧写法并存
|
||||
|
||||
从当前公开文档和样例看,存在这些并行写法:
|
||||
|
||||
1. card 路径: `/.well-known/agent-card` 和 `/.well-known/agent.json`
|
||||
2. RPC 方法: `tasks/send` 和 `message/send`
|
||||
|
||||
所以客户端必须做兼容适配,不能写死一种。
|
||||
|
||||
### 13.2 外部 agent 的安全边界
|
||||
|
||||
需要限制:
|
||||
|
||||
1. 白名单 host
|
||||
2. 超时
|
||||
3. card cache TTL
|
||||
4. skill card 是否允许自动启用
|
||||
|
||||
### 13.3 远端 agent 无法直接访问本地 workspace
|
||||
|
||||
这意味着:
|
||||
|
||||
1. 不能把“去读本地文件然后处理”原样发给远端 A2A agent
|
||||
2. 主 agent 需要先整理出必要上下文
|
||||
3. 第一版最好只做文本级委派
|
||||
|
||||
## 14. 我建议的落地结论
|
||||
|
||||
如果要控制改动面,又要保证后续可扩展,推荐最终采用下面这个结构:
|
||||
|
||||
```text
|
||||
AgentLoop
|
||||
-> SpawnTool
|
||||
-> DelegationManager
|
||||
-> AgentRegistry / AgentResolver
|
||||
-> LocalSubagentExecutor
|
||||
-> PluginAgentExecutor
|
||||
-> A2AExecutor
|
||||
-> AgentGroupExecutor
|
||||
-> announce_result() -> MessageBus(system) -> AgentLoop -> user
|
||||
```
|
||||
|
||||
也就是说:
|
||||
|
||||
1. `spawn` 工具保留
|
||||
2. `SubagentManager` 不再是唯一执行器
|
||||
3. `DelegationManager` 成为真正总入口
|
||||
4. skills 里的 `agent_cards` 用 frontmatter metadata 承载
|
||||
5. workspace agent 单独持久化
|
||||
6. group 通过并发 executor + 汇总消息实现
|
||||
|
||||
这是当前仓库里最稳妥、最符合现有架构的改法。
|
||||
|
||||
## 15. 外部参考
|
||||
|
||||
以下是我写这个方案时核对的 A2A 资料:
|
||||
|
||||
1. A2A Protocol Development Guide: https://a2aprotocol.ai/docs/guide/a2a-typescript-guide.html
|
||||
2. Python A2A Tutorial: https://a2aprotocol.ai/docs/guide/python-a2a-tutorial-20250513
|
||||
|
||||
注意:
|
||||
|
||||
1. 当前公开文档里既能看到 `tasks/send`,也能看到 `message/send`
|
||||
2. agent card 路径也能看到 `agent-card` 与 `agent.json` 两种写法
|
||||
3. 所以实现时建议做兼容层,不要只押一种命名
|
||||
@ -1,5 +0,0 @@
|
||||
We provide QR codes for joining the HKUDS discussion groups on **WeChat** and **Feishu**.
|
||||
|
||||
You can join by scanning the QR codes below:
|
||||
|
||||
<img src="https://github.com/HKUDS/.github/blob/main/profile/QR.png" alt="WeChat QR Code" width="400"/>
|
||||
@ -1,43 +0,0 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
# Install Node.js 20 for the WhatsApp bridge
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
apt-get purge -y gnupg && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies first (cached layer)
|
||||
COPY pyproject.toml README.md LICENSE ./
|
||||
RUN mkdir -p nanobot bridge && touch nanobot/__init__.py && \
|
||||
uv pip install --system --no-cache . && \
|
||||
rm -rf nanobot bridge
|
||||
|
||||
# Copy the full source and install
|
||||
COPY nanobot/ nanobot/
|
||||
COPY bridge/ bridge/
|
||||
COPY third_party/swarms/ third_party/swarms/
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Build the WhatsApp bridge
|
||||
WORKDIR /app/bridge
|
||||
RUN git config --global url."https://github.com/".insteadOf "ssh://git@github.com/" && \
|
||||
git config --global url."https://github.com/".insteadOf "git@github.com:" && \
|
||||
npm install && npm run build
|
||||
WORKDIR /app
|
||||
|
||||
# Create config directory
|
||||
RUN mkdir -p /root/.nanobot
|
||||
|
||||
# Gateway default port
|
||||
EXPOSE 18790
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@ -1,470 +1,34 @@
|
||||
# Boardware Genius Backend
|
||||
# Beaver Backend
|
||||
|
||||
这是 `Boardware Genius` 的后端服务仓库;当前技术命令和包名仍沿用 `nanobot`,但产品品牌按 `Boardware Genius` 表述:
|
||||
这是新的 `Beaver` 后端代码骨架。
|
||||
|
||||
- `nanobot web`:单用户 FastAPI 后端,供独立前端或 `/docs` 调试使用
|
||||
- `nanobot gateway`:常驻 worker,负责渠道接入、cron、heartbeat
|
||||
- MCP 动态工具接入
|
||||
- Outlook 集成:通过外部 `BW_Outlook_Mcp` 服务接入 Microsoft Graph / Exchange EWS
|
||||
- 工作区文件、技能、插件、代理、MCP 管理等 Web API
|
||||
旧实现已保留在 [backend-old](/home/ivan/xuan/nano_project/app-instance/backend-old),新目录用于按 [change.md](/home/ivan/xuan/nano_project/app-instance/backend/change.md) 的蓝图逐步重建后端。
|
||||
|
||||
如果你后续要把它打包成 Docker 丢到服务器,这份 README 就是给开发和部署同事看的执行文档。
|
||||
当前阶段目标:
|
||||
|
||||
## 这套仓库现在是什么
|
||||
1. 先建立新的目录边界和包结构。
|
||||
2. 明确 `beaver` 作为统一命名。
|
||||
3. 以统一 `engine` 为核心,后续让所有 agent 共享同一套运行内核。
|
||||
|
||||
这不是一个自带前端静态页面的全栈仓库,而是后端仓库:
|
||||
## 当前结构
|
||||
|
||||
- Web 模式启动的是 FastAPI API 服务
|
||||
- Gateway 模式启动的是常驻 agent / channel / cron 进程
|
||||
- WhatsApp 相关逻辑依赖 `bridge/` 里的 Node 20 bridge
|
||||
- Outlook 不是仓库内置模块,而是通过外部 `BW_Outlook_Mcp` 仓库接进来
|
||||
- `beaver/foundation`:底层公共设施
|
||||
- `beaver/engine`:统一 agent 内核
|
||||
- `beaver/coordinator`:多 agent 协调层
|
||||
- `beaver/tools`:工具系统
|
||||
- `beaver/skills`:技能系统
|
||||
- `beaver/memory`:记忆与经验沉淀
|
||||
- `beaver/permissions`:权限与治理
|
||||
- `beaver/services`:应用服务层
|
||||
- `beaver/interfaces`:CLI / Web / Gateway / Channels 薄入口
|
||||
- `beaver/integrations`:外部系统与协议集成
|
||||
|
||||
更细的执行链路可以看 [workflow.md](./workflow.md)。
|
||||
## 说明
|
||||
|
||||
## 目录结构
|
||||
这个目录当前还是第一版骨架,不等于完成迁移。
|
||||
|
||||
```text
|
||||
.
|
||||
├── nanobot/ # Python 主体:CLI、agent、web、channels、config、MCP
|
||||
├── bridge/ # WhatsApp bridge(Node 20)
|
||||
├── tests/ # 测试
|
||||
├── Dockerfile # 当前镜像构建文件
|
||||
├── docker-compose.yml # 当前自带 compose 示例(偏 gateway / CLI)
|
||||
└── workflow.md # 运行链路说明
|
||||
```
|
||||
后续迁移原则:
|
||||
|
||||
## 运行模式
|
||||
|
||||
| 命令 | 用途 | 默认端口 | 适合谁 |
|
||||
| --- | --- | --- | --- |
|
||||
| `nanobot agent` | 本地单轮 / 交互调试 | 无 | 开发排查 |
|
||||
| `nanobot web` | 启动 FastAPI 后端 | `18080` | 独立前端、接口调试、单用户使用 |
|
||||
| `nanobot gateway` | 启动常驻 worker | 无固定 HTTP 入口 | Telegram/Slack/Email/cron/heartbeat |
|
||||
| `nanobot status` | 查看配置和 provider 状态 | 无 | 开发、运维 |
|
||||
|
||||
注意:
|
||||
|
||||
- 如果你是给 Web 前端提供后端,请启动 `nanobot web`,不要误用 `gateway`
|
||||
- `gateway` 当前不是对外 Web API 服务
|
||||
- `web` 和 `gateway` 都会碰到同一份 workspace / cron / MCP 状态,通常不要在同一份数据目录上无脑同时跑两套
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python `>=3.11`
|
||||
- 推荐使用 `uv`
|
||||
- 如果要构建 WhatsApp bridge 或使用仓库自带 Dockerfile,需要 Node.js `20`
|
||||
|
||||
本地开发最省事的方式:
|
||||
|
||||
```bash
|
||||
uv sync --extra dev
|
||||
```
|
||||
|
||||
如果你不用 `uv`,也可以:
|
||||
|
||||
```bash
|
||||
python3 -m venv .venv
|
||||
. .venv/bin/activate
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## 本地快速启动
|
||||
|
||||
### 1. 初始化配置
|
||||
|
||||
```bash
|
||||
nanobot onboard
|
||||
```
|
||||
|
||||
初始化后默认会生成:
|
||||
|
||||
- 配置文件:`~/.nanobot/config.json`
|
||||
- 工作区:`~/.nanobot/workspace`
|
||||
|
||||
### 2. 填最小配置
|
||||
|
||||
下面是一份适合服务器环境的最小示例,重点是:
|
||||
|
||||
- 用绝对路径的 workspace
|
||||
- 建议打开 `restrictToWorkspace`
|
||||
- 先用 API Key provider,少踩 OAuth 交互坑
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "/root/.nanobot/workspace",
|
||||
"model": "openai/gpt-5"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openai": {
|
||||
"apiKey": "sk-xxxx"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"restrictToWorkspace": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
如果你不是跑在容器里,把 `/root/.nanobot/workspace` 换成你自己的绝对路径。
|
||||
|
||||
### 3. 检查配置
|
||||
|
||||
```bash
|
||||
nanobot status
|
||||
```
|
||||
|
||||
### 4. 本地调试 agent
|
||||
|
||||
```bash
|
||||
nanobot agent -m "你好"
|
||||
```
|
||||
|
||||
### 5. 启动 Web 后端
|
||||
|
||||
```bash
|
||||
nanobot web --host 0.0.0.0 --port 18080
|
||||
```
|
||||
|
||||
启动后可直接访问:
|
||||
|
||||
- `http://127.0.0.1:18080/docs`
|
||||
- `http://127.0.0.1:18080/api/ping`
|
||||
|
||||
## Web API 能力概览
|
||||
|
||||
当前 `nanobot web` 提供的 API 大致包括:
|
||||
|
||||
- 聊天与流式输出
|
||||
- 会话管理
|
||||
- cron 任务管理
|
||||
- skills / plugins / agents 管理
|
||||
- 工作区文件浏览、上传、下载、删除
|
||||
- MCP server 管理与测试
|
||||
- Outlook 集成状态、连接测试、连接/断开、Overview、Message Detail
|
||||
|
||||
如果你有独立前端,这个后端就是给前端接的;如果没有前端,也可以直接走 `/docs` 调试。
|
||||
|
||||
## Outlook MCP 集成
|
||||
|
||||
这是当前仓库里最容易部署时踩坑的一块。
|
||||
|
||||
### 关系先说清楚
|
||||
|
||||
当前后端不会自己实现 Outlook 协议,它依赖外部仓库 `BW_Outlook_Mcp`:
|
||||
|
||||
- 后端代码位置:`nanobot/web/outlook.py`
|
||||
- 默认查找逻辑:
|
||||
1. 先看环境变量 `NANOBOT_OUTLOOK_MCP_ROOT`
|
||||
2. 再看与本仓库同级目录的 `../BW_Outlook_Mcp`
|
||||
3. 如果以上都没有,就尝试直接执行 PATH 里的 `bw-outlook-mcp`
|
||||
|
||||
也就是说,部署同事必须额外把 `BW_Outlook_Mcp` 这个仓库准备好,或者把它直接安装进镜像。
|
||||
|
||||
### 推荐的两种接法
|
||||
|
||||
#### 方案 A:把 `BW_Outlook_Mcp` 安装进同一个 Python 环境
|
||||
|
||||
这是生产环境更稳的方案。
|
||||
|
||||
部署同事需要:
|
||||
|
||||
```bash
|
||||
git clone <你们的 BW_Outlook_Mcp 仓库地址> /srv/BW_Outlook_Mcp
|
||||
cd /srv/BW_Outlook_Mcp
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
安装完成后,容器或宿主机里能直接执行:
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp --help
|
||||
```
|
||||
|
||||
这样 Boardware Genius 就会直接用 PATH 里的 `bw-outlook-mcp`,不依赖额外挂载路径。
|
||||
|
||||
#### 方案 B:把 `BW_Outlook_Mcp` 作为外部目录挂进来
|
||||
|
||||
这是开发或临时部署更方便的方案。
|
||||
|
||||
部署同事需要至少做到两件事:
|
||||
|
||||
1. 把 `BW_Outlook_Mcp` 仓库拉到服务器
|
||||
2. 让这个目录里存在一个可执行的 `bw-outlook-mcp`
|
||||
|
||||
最简单的约定是:
|
||||
|
||||
```bash
|
||||
git clone <你们的 BW_Outlook_Mcp 仓库地址> /srv/BW_Outlook_Mcp
|
||||
cd /srv/BW_Outlook_Mcp
|
||||
python3 -m venv .venv
|
||||
. .venv/bin/activate
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
然后给 Boardware Genius 设置:
|
||||
|
||||
```bash
|
||||
export NANOBOT_OUTLOOK_MCP_ROOT=/srv/BW_Outlook_Mcp
|
||||
```
|
||||
|
||||
因为当前后端会优先寻找:
|
||||
|
||||
```text
|
||||
$NANOBOT_OUTLOOK_MCP_ROOT/.venv/bin/bw-outlook-mcp
|
||||
```
|
||||
|
||||
如果你挂了仓库目录但里面没有 `.venv/bin/bw-outlook-mcp`,那就必须确保 `bw-outlook-mcp` 已经在容器 PATH 里。
|
||||
|
||||
### Outlook 的认证和配置
|
||||
|
||||
`BW_Outlook_Mcp` 本身支持两套后端:
|
||||
|
||||
- `graph`:Microsoft 365 / Exchange Online
|
||||
- `ews`:本地或回迁后的 Exchange Server
|
||||
|
||||
#### Graph 登录
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth login-graph \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--client-id YOUR_CLIENT_ID \
|
||||
--tenant-id YOUR_TENANT_ID
|
||||
```
|
||||
|
||||
#### EWS 配置
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth setup-ews \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--email you@example.com \
|
||||
--username your_username \
|
||||
--domain example.com \
|
||||
--server mail.example.com
|
||||
```
|
||||
|
||||
如果你已经有固定 EWS URL,也可以改用:
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth setup-ews \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--email you@example.com \
|
||||
--username your_username \
|
||||
--service-endpoint https://mail.example.com/EWS/Exchange.asmx
|
||||
```
|
||||
|
||||
#### 查看状态
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth status --workspace /root/.nanobot/workspace
|
||||
```
|
||||
|
||||
### Outlook 状态文件会落在哪里
|
||||
|
||||
所有 Outlook 相关状态默认都落在 workspace 下:
|
||||
|
||||
```text
|
||||
<workspace>/state/bw_outlook_mcp/
|
||||
├── config.json
|
||||
├── secrets.json
|
||||
├── graph_token_cache.bin
|
||||
├── delta_store.json
|
||||
└── idempotency.sqlite3
|
||||
```
|
||||
|
||||
所以 Docker 部署时,不要只挂配置文件;要把整份 `~/.nanobot` 或至少 workspace 做持久化。
|
||||
|
||||
### Nanobot 里如何注册 Outlook MCP
|
||||
|
||||
如果你通过 Web 接口完成 Outlook 连接,后端会自动把 MCP server 注册到配置里。
|
||||
|
||||
手工写配置时,结构类似这样:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcpServers": {
|
||||
"outlook": {
|
||||
"command": "bw-outlook-mcp",
|
||||
"args": ["serve", "--workspace", "/root/.nanobot/workspace"],
|
||||
"sensitive": true,
|
||||
"toolTimeout": 60
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
这里一定要用绝对路径,不要写 `~/.nanobot/workspace`。
|
||||
|
||||
### 可选的 Outlook 环境变量
|
||||
|
||||
| 变量 | 作用 |
|
||||
| --- | --- |
|
||||
| `NANOBOT_OUTLOOK_MCP_ROOT` | 指向外部 `BW_Outlook_Mcp` 仓库目录 |
|
||||
| `NANOBOT_OUTLOOK_MCP_COMMAND` | 强制指定 `bw-outlook-mcp` 可执行文件 |
|
||||
| `NANOBOT_OUTLOOK_MCP_EXTRA_ARGS` | 给 `bw-outlook-mcp serve` 追加参数 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_DOMAIN` | Web 连接表单的默认域名 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_EWS_URL` | Web 连接表单默认 EWS 地址 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_EWS_SERVER` | Web 连接表单默认 Exchange 主机 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_TIMEZONE` | Web 连接表单默认时区 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_AUTODISCOVER` | Web 连接表单默认是否启用 autodiscover |
|
||||
|
||||
## Docker 部署
|
||||
|
||||
### 先说结论
|
||||
|
||||
服务器部署时,最重要的是持久化这份目录:
|
||||
|
||||
```text
|
||||
/root/.nanobot
|
||||
```
|
||||
|
||||
因为它里面不只是 `config.json`,还包括:
|
||||
|
||||
- workspace
|
||||
- sessions
|
||||
- cron 状态
|
||||
- Web 登录信息
|
||||
- Outlook 状态与 token 缓存
|
||||
|
||||
### 构建镜像
|
||||
|
||||
```bash
|
||||
docker build -t nanobot-backend:latest .
|
||||
```
|
||||
|
||||
### 首次初始化
|
||||
|
||||
第一次跑容器时,先执行一次:
|
||||
|
||||
```bash
|
||||
docker run --rm \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
onboard
|
||||
```
|
||||
|
||||
然后去编辑宿主机上的:
|
||||
|
||||
```text
|
||||
/srv/nanobot/data/config.json
|
||||
```
|
||||
|
||||
或者先进去执行:
|
||||
|
||||
```bash
|
||||
docker run --rm -it \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
status
|
||||
```
|
||||
|
||||
### 作为 Web 后端启动
|
||||
|
||||
如果你是给前端项目配后端,推荐这样跑:
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
--name nanobot-web \
|
||||
-p 18080:18080 \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
-e NANOBOT_OUTLOOK_MCP_ROOT=/opt/BW_Outlook_Mcp \
|
||||
-v /srv/BW_Outlook_Mcp:/opt/BW_Outlook_Mcp \
|
||||
nanobot-backend:latest \
|
||||
web --host 0.0.0.0 --port 18080
|
||||
```
|
||||
|
||||
如果你已经把 `bw-outlook-mcp` 安装进镜像了,就不需要挂 `/srv/BW_Outlook_Mcp`,也不需要 `NANOBOT_OUTLOOK_MCP_ROOT`。
|
||||
|
||||
### 作为 Gateway/Worker 启动
|
||||
|
||||
如果你要接 Telegram / Slack / Email / cron 之类的常驻能力,再跑 gateway:
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
--name nanobot-gateway \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
gateway
|
||||
```
|
||||
|
||||
### 推荐的服务器 compose 片段
|
||||
|
||||
仓库自带的 [docker-compose.yml](./docker-compose.yml) 更偏本地 gateway/CLI 示例。
|
||||
如果你是部署 Web 后端到服务器,更建议单独写成这样:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
nanobot-web:
|
||||
image: nanobot-backend:latest
|
||||
container_name: nanobot-web
|
||||
command: ["web", "--host", "0.0.0.0", "--port", "18080"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18080:18080"
|
||||
volumes:
|
||||
- /srv/nanobot/data:/root/.nanobot
|
||||
- /srv/BW_Outlook_Mcp:/opt/BW_Outlook_Mcp
|
||||
environment:
|
||||
NANOBOT_OUTLOOK_MCP_ROOT: /opt/BW_Outlook_Mcp
|
||||
```
|
||||
|
||||
如果你想把 Outlook 依赖做得更稳,推荐直接把 `BW_Outlook_Mcp` 安装进镜像,而不是运行时挂载仓库。
|
||||
|
||||
## 部署给同事时,至少要交代这几件事
|
||||
|
||||
1. 这是后端仓库,不带前端静态页面,前端请单独部署
|
||||
2. Web API 用 `nanobot web` 启动,不是 `gateway`
|
||||
3. 数据目录必须持久化到 `/root/.nanobot`
|
||||
4. 如果要 Outlook,必须额外拉取 `BW_Outlook_Mcp`
|
||||
5. Outlook 有两种接法:装进镜像,或者挂外部仓库并设置 `NANOBOT_OUTLOOK_MCP_ROOT`
|
||||
6. Outlook 的状态文件也在 workspace 里,删容器不挂卷就会丢
|
||||
|
||||
## 常用命令
|
||||
|
||||
```bash
|
||||
nanobot onboard
|
||||
nanobot status
|
||||
nanobot agent -m "你好"
|
||||
nanobot web --host 0.0.0.0 --port 18080
|
||||
nanobot gateway
|
||||
nanobot provider login openai-codex
|
||||
```
|
||||
|
||||
## 开发备注
|
||||
|
||||
- `workflow.md` 记录了当前代码实际运行链路,和旧版 README 更接近“真实代码”
|
||||
- `nanobot/web/outlook.py` 是当前 Outlook 集成入口
|
||||
- `tests/` 里有 Web API、Email、Docker 相关测试
|
||||
- 如果要上服务器,建议在配置里显式打开 `tools.restrictToWorkspace=true`
|
||||
|
||||
## 排错
|
||||
|
||||
### Web 启动了,但 Outlook 相关接口报错
|
||||
|
||||
优先检查:
|
||||
|
||||
- `bw-outlook-mcp` 是否能在当前容器里执行
|
||||
- `NANOBOT_OUTLOOK_MCP_ROOT` 是否指向正确目录
|
||||
- 如果走目录挂载模式,目录里是否真的有 `.venv/bin/bw-outlook-mcp`
|
||||
|
||||
### MCP 注册了,但工具没有出现
|
||||
|
||||
检查:
|
||||
|
||||
- `config.json` 里的 `tools.mcpServers`
|
||||
- `nanobot web` 或 `nanobot agent` 启动时是否用了同一份 `~/.nanobot`
|
||||
- Outlook MCP 是否能单独执行 `bw-outlook-mcp auth status --workspace ...`
|
||||
|
||||
### Docker 里配置改了没生效
|
||||
|
||||
优先检查你挂载的是不是整份:
|
||||
|
||||
```text
|
||||
/srv/nanobot/data:/root/.nanobot
|
||||
```
|
||||
|
||||
不是只挂了某一个文件。
|
||||
1. 不再新增 `nanobot` 命名。
|
||||
2. 不在新目录中保留 `third_party/`。
|
||||
3. 所有 agent 最终都复用 `beaver.engine`。
|
||||
|
||||
@ -1,264 +0,0 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in Boardware Genius, please report it by:
|
||||
|
||||
1. **DO NOT** open a public GitHub issue
|
||||
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
|
||||
3. Include:
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Potential impact
|
||||
- Suggested fix (if any)
|
||||
|
||||
We aim to respond to security reports within 48 hours.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. API Key Management
|
||||
|
||||
**CRITICAL**: Never commit API keys to version control.
|
||||
|
||||
```bash
|
||||
# ✅ Good: Store in config file with restricted permissions
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
|
||||
# ❌ Bad: Hardcoding keys in code or committing them
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- Store API keys in `~/.nanobot/config.json` with file permissions set to `0600`
|
||||
- Consider using environment variables for sensitive keys
|
||||
- Use OS keyring/credential manager for production deployments
|
||||
- Rotate API keys regularly
|
||||
- Use separate API keys for development and production
|
||||
|
||||
### 2. Channel Access Control
|
||||
|
||||
**IMPORTANT**: Always configure `allowFrom` lists for production use.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["123456789", "987654321"]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"allowFrom": ["+1234567890"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- Empty `allowFrom` list will **ALLOW ALL** users (open by default for personal use)
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
|
||||
### 3. Shell Command Execution
|
||||
|
||||
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
|
||||
|
||||
- ✅ Review all tool usage in agent logs
|
||||
- ✅ Understand what commands the agent is running
|
||||
- ✅ Use a dedicated user account with limited privileges
|
||||
- ✅ Never run Boardware Genius as root
|
||||
- ❌ Don't disable security checks
|
||||
- ❌ Don't run on systems with sensitive data without careful review
|
||||
|
||||
**Blocked patterns:**
|
||||
- `rm -rf /` - Root filesystem deletion
|
||||
- Fork bombs
|
||||
- Filesystem formatting (`mkfs.*`)
|
||||
- Raw disk writes
|
||||
- Other destructive operations
|
||||
|
||||
### 4. File System Access
|
||||
|
||||
File operations have path traversal protection, but:
|
||||
|
||||
- ✅ Run Boardware Genius with a dedicated user account
|
||||
- ✅ Use filesystem permissions to protect sensitive directories
|
||||
- ✅ Regularly audit file operations in logs
|
||||
- ❌ Don't give unrestricted access to sensitive files
|
||||
|
||||
### 5. Network Security
|
||||
|
||||
**API Calls:**
|
||||
- All external API calls use HTTPS by default
|
||||
- Timeouts are configured to prevent hanging requests
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
|
||||
**Critical**: Keep dependencies updated!
|
||||
|
||||
```bash
|
||||
# Check for vulnerable dependencies
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# Update to latest secure versions
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
For Node.js dependencies (WhatsApp bridge):
|
||||
```bash
|
||||
cd bridge
|
||||
npm audit
|
||||
npm audit fix
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Keep `litellm` updated to the latest version for security fixes
|
||||
- We've updated `ws` to `>=8.17.1` to fix DoS vulnerability
|
||||
- Run `pip-audit` or `npm audit` regularly
|
||||
- Subscribe to security advisories for Boardware Genius and its dependencies
|
||||
|
||||
### 7. Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Isolate the Environment**
|
||||
```bash
|
||||
# Run in a container or VM
|
||||
docker run --rm -it python:3.11
|
||||
pip install nanobot-ai
|
||||
```
|
||||
|
||||
2. **Use a Dedicated User**
|
||||
```bash
|
||||
sudo useradd -m -s /bin/bash nanobot
|
||||
sudo -u nanobot nanobot gateway
|
||||
```
|
||||
|
||||
3. **Set Proper Permissions**
|
||||
```bash
|
||||
chmod 700 ~/.nanobot
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
chmod 700 ~/.nanobot/whatsapp-auth
|
||||
```
|
||||
|
||||
4. **Enable Logging**
|
||||
```bash
|
||||
# Configure log monitoring
|
||||
tail -f ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
|
||||
5. **Use Rate Limiting**
|
||||
- Configure rate limits on your API providers
|
||||
- Monitor usage for anomalies
|
||||
- Set spending limits on LLM APIs
|
||||
|
||||
6. **Regular Updates**
|
||||
```bash
|
||||
# Check for updates weekly
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
### 8. Development vs Production
|
||||
|
||||
**Development:**
|
||||
- Use separate API keys
|
||||
- Test with non-sensitive data
|
||||
- Enable verbose logging
|
||||
- Use a test Telegram bot
|
||||
|
||||
**Production:**
|
||||
- Use dedicated API keys with spending limits
|
||||
- Restrict file system access
|
||||
- Enable audit logging
|
||||
- Regular security reviews
|
||||
- Monitor for unusual activity
|
||||
|
||||
### 9. Data Privacy
|
||||
|
||||
- **Logs may contain sensitive information** - secure log files appropriately
|
||||
- **LLM providers see your prompts** - review their privacy policies
|
||||
- **Chat history is stored locally** - protect the `~/.nanobot` directory
|
||||
- **API keys are in plain text** - use OS keyring for production
|
||||
|
||||
### 10. Incident Response
|
||||
|
||||
If you suspect a security breach:
|
||||
|
||||
1. **Immediately revoke compromised API keys**
|
||||
2. **Review logs for unauthorized access**
|
||||
```bash
|
||||
grep "Access denied" ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
3. **Check for unexpected file modifications**
|
||||
4. **Rotate all credentials**
|
||||
5. **Update to latest version**
|
||||
6. **Report the incident** to maintainers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Built-in Security Controls
|
||||
|
||||
✅ **Input Validation**
|
||||
- Path traversal protection on file operations
|
||||
- Dangerous command pattern detection
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control
|
||||
- Failed authentication attempt logging
|
||||
- Open by default (configure allowFrom for production use)
|
||||
|
||||
✅ **Resource Protection**
|
||||
- Command execution timeouts (60s default)
|
||||
- Output truncation (10KB limit)
|
||||
- HTTP request timeouts (10-30s)
|
||||
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
⚠️ **Current Security Limitations:**
|
||||
|
||||
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
|
||||
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
|
||||
3. **No Session Management** - No automatic session expiry
|
||||
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
|
||||
5. **No Audit Trail** - Limited security event logging (enhance as needed)
|
||||
|
||||
## Security Checklist
|
||||
|
||||
Before deploying Boardware Genius:
|
||||
|
||||
- [ ] API keys stored securely (not in code)
|
||||
- [ ] Config file permissions set to 0600
|
||||
- [ ] `allowFrom` lists configured for all channels
|
||||
- [ ] Running as non-root user
|
||||
- [ ] File system permissions properly restricted
|
||||
- [ ] Dependencies updated to latest secure versions
|
||||
- [ ] Logs monitored for security events
|
||||
- [ ] Rate limits configured on API providers
|
||||
- [ ] Backup and disaster recovery plan in place
|
||||
- [ ] Security review of custom skills/tools
|
||||
|
||||
## Updates
|
||||
|
||||
**Last Updated**: 2026-02-03
|
||||
|
||||
For the latest security updates and announcements, check:
|
||||
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
|
||||
- Release Notes: https://github.com/HKUDS/nanobot/releases
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
6
app-instance/backend/beaver/__init__.py
Normal file
6
app-instance/backend/beaver/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Beaver backend package."""
|
||||
|
||||
__all__ = ["__version__"]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
2
app-instance/backend/beaver/coordinator/__init__.py
Normal file
2
app-instance/backend/beaver/coordinator/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Multi-agent coordination layer."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Pluggable multi-agent backends."""
|
||||
|
||||
20
app-instance/backend/beaver/coordinator/backends/base.py
Normal file
20
app-instance/backend/beaver/coordinator/backends/base.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Backend interfaces for multi-agent execution."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BackendResult:
|
||||
"""Normalized result returned by a coordination backend."""
|
||||
|
||||
success: bool
|
||||
summary: str
|
||||
|
||||
|
||||
class CoordinationBackend(Protocol):
|
||||
"""Protocol implemented by pluggable coordination backends."""
|
||||
|
||||
def run(self, task: str) -> BackendResult:
|
||||
"""Execute a team task and return a normalized result."""
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
"""Swarms backend wrapper for Beaver.
|
||||
|
||||
This package is intentionally local to Beaver's coordinator layer.
|
||||
There is no `third_party/` directory in the new backend layout.
|
||||
"""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Delegation orchestration."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Execution control, retry, and aggregation."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Team planning and execution-plan generation."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Agent registry and descriptors."""
|
||||
|
||||
2
app-instance/backend/beaver/coordinator/team/__init__.py
Normal file
2
app-instance/backend/beaver/coordinator/team/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Team models and orchestration objects."""
|
||||
|
||||
31
app-instance/backend/beaver/engine/__init__.py
Normal file
31
app-instance/backend/beaver/engine/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Unified Beaver agent engine.
|
||||
|
||||
这里不做顶层 eager import,避免子模块导入时触发循环依赖。
|
||||
对外仍然保留同样的导出名称,但改成按需加载。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["AgentLoop", "AgentProfile", "AgentRunResult", "EngineLoader", "EngineLoadResult"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "EngineLoader":
|
||||
from .loader import EngineLoader
|
||||
|
||||
return EngineLoader
|
||||
if name == "EngineLoadResult":
|
||||
from .loader import EngineLoadResult
|
||||
|
||||
return EngineLoadResult
|
||||
if name in {"AgentLoop", "AgentProfile", "AgentRunResult"}:
|
||||
from .loop import AgentLoop, AgentProfile, AgentRunResult
|
||||
|
||||
return {
|
||||
"AgentLoop": AgentLoop,
|
||||
"AgentProfile": AgentProfile,
|
||||
"AgentRunResult": AgentRunResult,
|
||||
}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
17
app-instance/backend/beaver/engine/context/__init__.py
Normal file
17
app-instance/backend/beaver/engine/context/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Context assembly for agent runs."""
|
||||
|
||||
from .builder import (
|
||||
ContextBuildInput,
|
||||
ContextBuildResult,
|
||||
ContextBuilder,
|
||||
SessionContext,
|
||||
SkillContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContextBuildInput",
|
||||
"ContextBuildResult",
|
||||
"ContextBuilder",
|
||||
"SessionContext",
|
||||
"SkillContext",
|
||||
]
|
||||
331
app-instance/backend/beaver/engine/context/builder.py
Normal file
331
app-instance/backend/beaver/engine/context/builder.py
Normal file
@ -0,0 +1,331 @@
|
||||
"""Beaver 运行时上下文装配器。
|
||||
|
||||
这个模块是 `session` 和 `provider` 之间的中间层,职责非常明确:
|
||||
|
||||
1. 把运行前已经准备好的静态/半静态上下文拼成一份稳定的 system prompt
|
||||
2. 把从 session 事件流里裁剪出的“可见历史”和当前用户输入整理成 provider 可直接消费的 messages
|
||||
3. 在 tool loop 中,持续把 assistant/tool 消息按统一格式追加回消息数组
|
||||
|
||||
为什么这层必须单独存在:
|
||||
|
||||
1. `AgentLoop` 不应该自己拼 prompt,否则很快又会长成一个大文件
|
||||
2. `memory`、`skills`、`session` 的注入顺序需要固定,否则模型行为会漂移
|
||||
3. tool loop 前后追加消息的格式必须统一,否则不同 provider 很容易出兼容问题
|
||||
|
||||
这一版 builder 的设计目标是“最小但稳定”:
|
||||
|
||||
1. 先服务单 agent 主链
|
||||
2. 先支持 frozen curated memory,而不是 live memory
|
||||
3. skills 按 Hermes 风格支持“显式激活消息注入”,不在这里做磁盘扫描
|
||||
4. 为后续 channel / gateway / team metadata 预留注入位,但不提前做复杂逻辑
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SkillContext:
|
||||
"""单个已激活 skill 的最小表示。
|
||||
|
||||
这里故意不把 skill 设计成复杂对象,只保留 builder 真正关心的两部分:
|
||||
|
||||
- `name`:用于生成激活提示
|
||||
- `content`:skill 的完整正文
|
||||
|
||||
注意:按当前 Hermes 风格实现,skill 正文不再塞进 system prompt,而是转成显式消息注入。
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionContext:
|
||||
"""当前运行轮次的会话元数据。
|
||||
|
||||
这不是 session store 里的完整 record,而是 prompt builder 关心的那一小部分:
|
||||
- 哪个 session
|
||||
- 来源是什么
|
||||
- 当前使用什么 model
|
||||
- 是否有 channel/chat/user 这类运行路由信息
|
||||
|
||||
把它单独抽出来的原因是:
|
||||
1. builder 不应该知道 SQLite row 长什么样
|
||||
2. 不同入口(CLI/Web/Gateway)都可以把自己的 metadata 收敛成同一种结构
|
||||
"""
|
||||
|
||||
session_id: str | None = None
|
||||
source: str | None = None
|
||||
model: str | None = None
|
||||
user_id: str | None = None
|
||||
channel: str | None = None
|
||||
chat_id: str | None = None
|
||||
parent_session_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContextBuildInput:
|
||||
"""一次上下文构建所需的全部输入。
|
||||
|
||||
这个对象的作用不是“炫技式封装”,而是把主链里零散的数据显式收口。
|
||||
这样一来,后面 `AgentLoop.process_direct()` 在组装参数时会更清晰,也更容易测试。
|
||||
|
||||
字段分组:
|
||||
- 身份/基础段:`base_system_prompt`
|
||||
- 会话可见历史:`history`
|
||||
- 当前输入:`current_user_input`
|
||||
- 冻结记忆:`memory_snapshot`
|
||||
- 技能:`activated_skills`
|
||||
- 运行元数据:`session_context` / `execution_context`
|
||||
- 额外扩展:`extra_sections`
|
||||
"""
|
||||
|
||||
base_system_prompt: str = ""
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
current_user_input: str | list[dict[str, Any]] | None = None
|
||||
memory_snapshot: MemorySnapshot | None = None
|
||||
activated_skills: list[SkillContext] = field(default_factory=list)
|
||||
session_context: SessionContext | None = None
|
||||
execution_context: str | None = None
|
||||
extra_sections: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContextBuildResult:
|
||||
"""一次上下文构建后的结果。
|
||||
|
||||
保留 `system_prompt` 的原因:
|
||||
1. `SessionManager.update_system_prompt()` 需要把最终注入的 prompt snapshot 落盘
|
||||
2. 调试时经常需要区分“system prompt 长什么样”和“messages 长什么样”
|
||||
3. 后面如果做 prompt audit / replay,也会直接复用这个结果
|
||||
"""
|
||||
|
||||
system_prompt: str
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""负责把运行时输入装配成稳定上下文。
|
||||
|
||||
这一层故意保持“无 IO、无数据库、无网络”:
|
||||
- 不直接读 session store
|
||||
- 不直接读 memory store
|
||||
- 不直接扫描 skills 目录
|
||||
|
||||
这样 builder 的行为只由输入决定,便于单测,也便于后面并到真正的 AgentLoop 主链里。
|
||||
"""
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
build_input: ContextBuildInput,
|
||||
) -> str:
|
||||
"""构建 system prompt。
|
||||
|
||||
顺序固定非常重要,当前约定是:
|
||||
|
||||
1. base system prompt
|
||||
2. session metadata
|
||||
3. execution context
|
||||
4. frozen memory snapshot
|
||||
5. extra sections
|
||||
|
||||
这样设计的原因:
|
||||
- 身份与总规则要最靠前
|
||||
- session/execution 是本轮运行语境,优先级高于长期记忆
|
||||
- memory 必须是 frozen snapshot,避免中途写 memory 后 prompt 失真
|
||||
- activated skill 正文按 Hermes 风格放到显式消息里,避免 system prompt 持续膨胀
|
||||
"""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
base_system_prompt = (build_input.base_system_prompt or "").strip()
|
||||
if base_system_prompt:
|
||||
sections.append(base_system_prompt)
|
||||
|
||||
session_section = self._render_session_section(build_input.session_context)
|
||||
if session_section:
|
||||
sections.append(session_section)
|
||||
|
||||
execution_context = (build_input.execution_context or "").strip()
|
||||
if execution_context:
|
||||
sections.append(f"# Execution Context\n\n{execution_context}")
|
||||
|
||||
if build_input.memory_snapshot is not None:
|
||||
# 这里明确只读 frozen snapshot,而不是去读 live memory store。
|
||||
# 否则一旦当前会话中途写 memory,system prompt 语义就会和会话开头不一致。
|
||||
snapshot_sections = build_input.memory_snapshot.as_prompt_sections()
|
||||
if snapshot_sections:
|
||||
sections.extend(snapshot_sections)
|
||||
|
||||
for extra in build_input.extra_sections:
|
||||
cleaned = (extra or "").strip()
|
||||
if cleaned:
|
||||
sections.append(cleaned)
|
||||
|
||||
return "\n\n---\n\n".join(sections)
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
build_input: ContextBuildInput,
|
||||
) -> ContextBuildResult:
|
||||
"""构建一次模型调用的完整 messages。
|
||||
|
||||
这里做三件事:
|
||||
1. 先生成最终 system prompt
|
||||
2. 按 Hermes 风格,把已激活 skill 的完整正文作为显式消息注入
|
||||
3. 把历史消息按原顺序接到后面
|
||||
4. 如果存在当前用户输入,则把本轮输入追加为最后一条 user message
|
||||
|
||||
注意:
|
||||
- `history` 默认被视为“已经由 session/context 上游从完整事件流中裁剪好的可见结构”
|
||||
- builder 不负责裁剪历史窗口,这件事应由 session/loop 上层决定
|
||||
- builder 只做最小格式统一
|
||||
"""
|
||||
|
||||
system_prompt = self.build_system_prompt(build_input)
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
messages.extend(self.build_skill_activation_messages(build_input.activated_skills))
|
||||
|
||||
for message in build_input.history:
|
||||
# 当前 builder 自己负责生成唯一的 system prompt。
|
||||
# 如果上游 history 已经混入 system 消息,这里要主动跳过,避免双 system。
|
||||
if message.get("role") == "system":
|
||||
continue
|
||||
messages.append(dict(message))
|
||||
|
||||
if build_input.current_user_input is not None:
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": build_input.current_user_input,
|
||||
}
|
||||
)
|
||||
|
||||
return ContextBuildResult(
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""向消息数组追加一条 tool result。
|
||||
|
||||
为什么这个函数放在 builder,而不是塞回 `AgentLoop`:
|
||||
- tool message 的结构必须和 provider 兼容
|
||||
- 统一在这里追加,可以避免不同执行路径拼出不同字段名
|
||||
- 后面如果要兼容更多 provider 差异,也只改这一层
|
||||
|
||||
这里返回原 list 本身,保持旧项目的“可链式追加”习惯。
|
||||
"""
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result,
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""向消息数组追加 assistant 消息。
|
||||
|
||||
这里有两个实现细节非常重要:
|
||||
|
||||
1. 无论 `content` 是否为空,都显式写入 `content` 键
|
||||
原因是部分 provider 在 assistant 带 `tool_calls` 时仍要求消息里存在 `content`
|
||||
|
||||
2. `reasoning_content` 只有在非空时才附带
|
||||
因为这属于思考模型扩展字段,不应污染普通 provider 路径
|
||||
"""
|
||||
|
||||
message: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
message["reasoning_content"] = reasoning_content
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
def _render_session_section(self, session_context: SessionContext | None) -> str | None:
|
||||
"""把运行时 session metadata 渲染成一个可读 section。
|
||||
|
||||
这一段的目标不是让模型“记住所有数据库字段”,而是给它足够的当前运行语境。
|
||||
常见用途包括:
|
||||
- 知道当前来自 CLI 还是 Web/Gateway
|
||||
- 知道当前使用什么 model
|
||||
- 知道当前 channel/chat_id,便于后续多渠道行为约束
|
||||
"""
|
||||
|
||||
if session_context is None:
|
||||
return None
|
||||
|
||||
rows: list[str] = []
|
||||
if session_context.session_id:
|
||||
rows.append(f"Session ID: {session_context.session_id}")
|
||||
if session_context.source:
|
||||
rows.append(f"Source: {session_context.source}")
|
||||
if session_context.model:
|
||||
rows.append(f"Model: {session_context.model}")
|
||||
if session_context.user_id:
|
||||
rows.append(f"User ID: {session_context.user_id}")
|
||||
if session_context.channel:
|
||||
rows.append(f"Channel: {session_context.channel}")
|
||||
if session_context.chat_id:
|
||||
rows.append(f"Chat ID: {session_context.chat_id}")
|
||||
if session_context.parent_session_id:
|
||||
rows.append(f"Parent Session ID: {session_context.parent_session_id}")
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
return "# Current Session\n\n" + "\n".join(rows)
|
||||
|
||||
def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
|
||||
"""按 Hermes 风格把已激活 skill 转成显式消息。
|
||||
|
||||
关键区别:
|
||||
- system prompt 只保留轻量 skills index
|
||||
- 真正生效的 skill 正文通过额外消息块显式加载
|
||||
|
||||
这样模型不需要“从摘要里猜怎么读到正文”,而是直接拿到完整指导内容。
|
||||
"""
|
||||
|
||||
messages: list[dict[str, str]] = []
|
||||
for skill in activated_skills:
|
||||
content = (skill.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f'[SYSTEM: The "{skill.name}" skill is active for this run. '
|
||||
"Follow its instructions as active guidance unless the user overrides them.]\n\n"
|
||||
f"{content}"
|
||||
),
|
||||
}
|
||||
)
|
||||
return messages
|
||||
154
app-instance/backend/beaver/engine/loader.py
Normal file
154
app-instance/backend/beaver/engine/loader.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""Centralized runtime loading for Beaver agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from beaver.engine.context import ContextBuilder
|
||||
from beaver.engine.session import SessionManager
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.skills import SkillAssembler, SkillsLoader
|
||||
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
|
||||
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class EngineLoadResult:
|
||||
"""描述当前 agent runtime 已经装好的依赖。
|
||||
|
||||
这里同时保留两类字段:
|
||||
1. `tools/skills/memory_stores/permissions`
|
||||
- 便于做状态展示、调试、轻量测试
|
||||
2. `session_manager/tool_registry/...`
|
||||
- 供真正的运行时主链直接使用
|
||||
"""
|
||||
|
||||
workspace: Path
|
||||
tools: list[str] = field(default_factory=list)
|
||||
skills: list[str] = field(default_factory=list)
|
||||
memory_stores: list[str] = field(default_factory=list)
|
||||
permissions: list[str] = field(default_factory=list)
|
||||
session_manager: SessionManager | None = None
|
||||
curated_memory_store: MemoryStore | None = None
|
||||
memory_service: MemoryService | None = None
|
||||
tool_registry: ToolRegistry | None = None
|
||||
tool_executor: ToolExecutor | None = None
|
||||
context_builder: ContextBuilder | None = None
|
||||
skills_loader: SkillsLoader | None = None
|
||||
skill_assembler: SkillAssembler | None = None
|
||||
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
|
||||
closed: bool = False
|
||||
|
||||
def register_closeable(self, name: str, close_fn: Callable[[], None]) -> None:
|
||||
"""登记一个由 runtime 统一关闭的资源。"""
|
||||
|
||||
self.closeables.append((name, close_fn))
|
||||
|
||||
def close(self) -> None:
|
||||
"""按后进先出顺序关闭 runtime 资源。
|
||||
|
||||
这一步先保持同步、最小、可组合:
|
||||
1. 只管理已经明确需要关闭的资源
|
||||
2. 暂不引入 async shutdown 协议
|
||||
3. 为后续 Web/Gateway lifespan 留统一入口
|
||||
"""
|
||||
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
errors: list[tuple[str, BaseException]] = []
|
||||
for name, close_fn in reversed(self.closeables):
|
||||
try:
|
||||
close_fn()
|
||||
except BaseException as exc: # pragma: no cover - defensive cleanup path
|
||||
errors.append((name, exc))
|
||||
self.closed = True
|
||||
|
||||
if errors:
|
||||
parts = ", ".join(f"{name}: {exc}" for name, exc in errors)
|
||||
raise RuntimeError(f"Runtime shutdown failed for {parts}")
|
||||
|
||||
|
||||
class EngineLoader:
|
||||
"""为任意 Beaver agent 装载共享 runtime 能力。
|
||||
|
||||
当前先做“最小可运行主链”需要的装配:
|
||||
- session manager
|
||||
- curated memory store
|
||||
- context builder
|
||||
- built-in tools
|
||||
- tool executor
|
||||
|
||||
等主链跑稳后,再把 skills、权限、MCP、delegation 逐步加进来。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
session_manager: SessionManager | None = None,
|
||||
curated_memory_store: MemoryStore | None = None,
|
||||
memory_service: MemoryService | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
context_builder: ContextBuilder | None = None,
|
||||
skills_loader: SkillsLoader | None = None,
|
||||
skill_assembler: SkillAssembler | None = None,
|
||||
) -> None:
|
||||
self.workspace = Path(workspace or Path.cwd())
|
||||
self._session_manager = session_manager
|
||||
self._curated_memory_store = curated_memory_store
|
||||
self._memory_service = memory_service
|
||||
self._tool_registry = tool_registry
|
||||
self._context_builder = context_builder
|
||||
self._skills_loader = skills_loader
|
||||
self._skill_assembler = skill_assembler
|
||||
|
||||
def load(self) -> EngineLoadResult:
|
||||
"""装配当前主链需要的最小 runtime 对象。"""
|
||||
|
||||
workspace = self.workspace
|
||||
session_manager = self._session_manager or SessionManager(workspace)
|
||||
|
||||
curated_root = workspace / "memory" / "curated"
|
||||
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
|
||||
memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store)
|
||||
memory_service.initialize()
|
||||
|
||||
tool_registry = self._tool_registry or ToolRegistry()
|
||||
skills_loader = self._skills_loader or SkillsLoader(workspace)
|
||||
if self._tool_registry is None:
|
||||
# 这里先注册最小工具集,满足主链的 tool loop。
|
||||
tool_registry.register_many(
|
||||
[
|
||||
ObjectBackedTool(EchoTool()),
|
||||
ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
|
||||
ObjectBackedTool(SkillViewTool(loader=skills_loader)),
|
||||
ObjectBackedTool(SessionSearchTool(db=session_manager)),
|
||||
]
|
||||
)
|
||||
|
||||
context_builder = self._context_builder or ContextBuilder()
|
||||
tool_executor = ToolExecutor(tool_registry)
|
||||
skill_assembler = self._skill_assembler or SkillAssembler(skills_loader)
|
||||
|
||||
result = EngineLoadResult(
|
||||
workspace=workspace,
|
||||
tools=[spec.name for spec in tool_registry.list_specs()],
|
||||
skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
|
||||
memory_stores=["curated"],
|
||||
permissions=[],
|
||||
session_manager=session_manager,
|
||||
curated_memory_store=memory_service.get_store(),
|
||||
memory_service=memory_service,
|
||||
tool_registry=tool_registry,
|
||||
tool_executor=tool_executor,
|
||||
context_builder=context_builder,
|
||||
skills_loader=skills_loader,
|
||||
skill_assembler=skill_assembler,
|
||||
)
|
||||
if self._session_manager is None:
|
||||
result.register_closeable("session_manager", session_manager.close)
|
||||
return result
|
||||
689
app-instance/backend/beaver/engine/loop.py
Normal file
689
app-instance/backend/beaver/engine/loop.py
Normal file
@ -0,0 +1,689 @@
|
||||
"""Unified agent loop used by all Beaver agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from beaver.engine.context import ContextBuildInput, SessionContext
|
||||
from beaver.engine.providers import ProviderBundle, make_provider_bundle
|
||||
from beaver.tools import ToolContext
|
||||
|
||||
from .loader import EngineLoader, EngineLoadResult
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentProfile:
|
||||
"""Runtime profile for a Beaver agent instance."""
|
||||
|
||||
name: str = "default"
|
||||
system_prompt: str = ""
|
||||
default_model: str = "gpt-4.1-mini"
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.2
|
||||
max_tool_iterations: int = 8
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunResult:
|
||||
"""一次 direct run 的最小结果结构。"""
|
||||
|
||||
session_id: str
|
||||
run_id: str
|
||||
output_text: str
|
||||
finish_reason: str
|
||||
tool_iterations: int
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _DirectRunRequest:
|
||||
"""运行循环中的单个 direct task。"""
|
||||
|
||||
task: str
|
||||
kwargs: dict[str, Any]
|
||||
future: asyncio.Future[AgentRunResult]
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Single execution kernel shared by root agents and delegated agents."""
|
||||
|
||||
def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
|
||||
self.profile = profile or AgentProfile()
|
||||
self.loader = loader or EngineLoader()
|
||||
self.loaded: EngineLoadResult | None = None
|
||||
self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
|
||||
self._running = False
|
||||
self._stop_requested = False
|
||||
|
||||
def boot(self) -> EngineLoadResult:
|
||||
"""Load shared runtime capabilities once for this agent instance."""
|
||||
if self.loaded is None:
|
||||
self.loaded = self.loader.load()
|
||||
return self.loaded
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
async def run(self) -> None:
|
||||
"""启动最小运行循环,顺序消费提交进来的 direct tasks。
|
||||
|
||||
第一版故意保持克制:
|
||||
1. 只做单消费者串行消费
|
||||
2. 真正执行仍复用 `process_direct()`
|
||||
3. 不引入 bus / worker / priority / retry
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError("AgentLoop.run() is already active")
|
||||
|
||||
self.boot()
|
||||
self._run_queue = asyncio.Queue()
|
||||
self._running = True
|
||||
self._stop_requested = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await self._run_queue.get()
|
||||
if item is None:
|
||||
if self._stop_requested:
|
||||
break
|
||||
continue
|
||||
|
||||
if item.future.cancelled():
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await self._process_direct_impl(item.task, **item.kwargs)
|
||||
except asyncio.CancelledError:
|
||||
if not item.future.done():
|
||||
item.future.cancel()
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive queue path
|
||||
if not item.future.done():
|
||||
item.future.set_exception(exc)
|
||||
else:
|
||||
if not item.future.done():
|
||||
item.future.set_result(result)
|
||||
finally:
|
||||
if self._run_queue is not None:
|
||||
while True:
|
||||
try:
|
||||
pending = self._run_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if isinstance(pending, _DirectRunRequest) and not pending.future.done():
|
||||
pending.future.set_exception(
|
||||
RuntimeError("AgentLoop.run() stopped before processing the queued task")
|
||||
)
|
||||
self._running = False
|
||||
self._stop_requested = False
|
||||
self._run_queue = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止运行循环。
|
||||
|
||||
第一版语义:
|
||||
- 不再接收新任务
|
||||
- 当前已经取出的任务允许收尾
|
||||
- 不自动 close runtime
|
||||
"""
|
||||
|
||||
if not self._running or self._run_queue is None:
|
||||
return
|
||||
self._stop_requested = True
|
||||
await self._run_queue.put(None)
|
||||
|
||||
async def submit_direct(
|
||||
self,
|
||||
task: str,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResult:
|
||||
"""向运行中的 loop 提交一个 direct task,并等待结果。"""
|
||||
|
||||
if not self._running or self._run_queue is None:
|
||||
raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
|
||||
if self._stop_requested:
|
||||
raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")
|
||||
|
||||
future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
|
||||
await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
|
||||
return await future
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭当前 loop 持有的 runtime。
|
||||
|
||||
第 6 阶段先把生命周期最小骨架立住:
|
||||
- `boot()` 负责建立 runtime
|
||||
- `close()` 负责释放由 runtime 持有的资源
|
||||
- 之后再在此基础上扩 `run()/stop()/shutdown hooks`
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")
|
||||
if self.loaded is None:
|
||||
return
|
||||
try:
|
||||
self.loaded.close()
|
||||
finally:
|
||||
self.loaded = None
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: Any = None,
|
||||
fallback_target: dict[str, Any] | None = None,
|
||||
auxiliary_target: dict[str, Any] | None = None,
|
||||
embedding_target: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tool_iterations: int | None = None,
|
||||
provider_bundle: ProviderBundle | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""跑通最小 direct run 主链。
|
||||
|
||||
当前主链刻意保持克制,只解决这些事情:
|
||||
1. 确保 session 存在
|
||||
2. 用 frozen memory + history 组 prompt
|
||||
3. 调 provider
|
||||
4. 若有 tool calls,则进入最小 tool loop
|
||||
5. 把 user/assistant/tool 消息和 usage 写回 session
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError(
|
||||
"AgentLoop.process_direct() is disabled while run() is active; "
|
||||
"submit tasks via submit_direct() instead."
|
||||
)
|
||||
return await self._process_direct_impl(
|
||||
task,
|
||||
session_id=session_id,
|
||||
source=source,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
execution_context=execution_context,
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
routing=routing,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
embedding_model=embedding_model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
max_tool_iterations=max_tool_iterations,
|
||||
provider_bundle=provider_bundle,
|
||||
)
|
||||
|
||||
async def _process_direct_impl(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: Any = None,
|
||||
fallback_target: dict[str, Any] | None = None,
|
||||
auxiliary_target: dict[str, Any] | None = None,
|
||||
embedding_target: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tool_iterations: int | None = None,
|
||||
provider_bundle: ProviderBundle | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""真正执行一轮 direct run 的内部实现。
|
||||
|
||||
规则:
|
||||
- 外部直接调用时走 `process_direct()`
|
||||
- 运行循环内部消费时走 `_process_direct_impl()`
|
||||
- 这样才能保证 run 模式下外部不能绕过队列直接执行
|
||||
"""
|
||||
|
||||
loaded = self.boot()
|
||||
session_manager = self._require_loaded("session_manager")
|
||||
memory_service = self._require_loaded("memory_service")
|
||||
context_builder = self._require_loaded("context_builder")
|
||||
tool_registry = self._require_loaded("tool_registry")
|
||||
tool_executor = self._require_loaded("tool_executor")
|
||||
skill_assembler = self._require_loaded("skill_assembler")
|
||||
|
||||
resolved_session_id = session_id or uuid4().hex
|
||||
resolved_run_id = uuid4().hex
|
||||
resolved_model = model or self.profile.default_model
|
||||
resolved_max_tokens = max_tokens or self.profile.max_tokens
|
||||
resolved_temperature = self.profile.temperature if temperature is None else temperature
|
||||
resolved_max_tool_iterations = (
|
||||
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
||||
)
|
||||
|
||||
# 每次新运行开始前都通过 MemoryService 刷新 live state。
|
||||
# 这样 memory policy 会收口在 service,而不是散在 loop 里。
|
||||
memory_service.reload_for_new_run()
|
||||
|
||||
session_manager.ensure_session(
|
||||
resolved_session_id,
|
||||
source=source,
|
||||
model=resolved_model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="run_started",
|
||||
event_payload={
|
||||
"source": source,
|
||||
"model": resolved_model,
|
||||
"agent_name": self.profile.name,
|
||||
},
|
||||
content=task,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
user_message_recorded = False
|
||||
iterations = 0
|
||||
final_usage: dict[str, Any] = {}
|
||||
final_provider_name: str | None = provider_name
|
||||
final_model: str | None = resolved_model
|
||||
try:
|
||||
bundle = provider_bundle or make_provider_bundle(
|
||||
model=resolved_model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
routing=routing,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
embedding_model=embedding_model or "text-embedding-v4",
|
||||
)
|
||||
skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
|
||||
skill_selector_model = (
|
||||
bundle.auxiliary_runtime.model
|
||||
if bundle.auxiliary_runtime is not None
|
||||
else bundle.main_runtime.model
|
||||
)
|
||||
assembled_skills = await skill_assembler.assemble(
|
||||
task_description=task,
|
||||
provider=skill_selector_provider,
|
||||
model=skill_selector_model,
|
||||
embedding_runtime=bundle.embedding_runtime,
|
||||
)
|
||||
skill_activation_messages = context_builder.build_skill_activation_messages(
|
||||
assembled_skills.activated_skills
|
||||
)
|
||||
|
||||
if skill_activation_messages:
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="skill_activation_snapshotted",
|
||||
event_payload={
|
||||
"activation_messages": skill_activation_messages,
|
||||
},
|
||||
content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
build_input = ContextBuildInput(
|
||||
base_system_prompt=self.profile.system_prompt,
|
||||
history=session_manager.get_history(resolved_session_id),
|
||||
current_user_input=task,
|
||||
memory_snapshot=memory_service.get_snapshot(),
|
||||
activated_skills=assembled_skills.activated_skills,
|
||||
session_context=SessionContext(
|
||||
session_id=resolved_session_id,
|
||||
source=source,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
context_result = context_builder.build_messages(build_input)
|
||||
session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="system_prompt_snapshotted",
|
||||
event_payload={
|
||||
"source": source,
|
||||
"model": resolved_model,
|
||||
"system_prompt_length": len(context_result.system_prompt),
|
||||
},
|
||||
content=context_result.system_prompt,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="user",
|
||||
event_type="user_message_added",
|
||||
content=task,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
user_message_recorded = True
|
||||
|
||||
provider = bundle.main_provider
|
||||
messages = list(context_result.messages)
|
||||
tool_schemas = tool_registry.export_provider_schemas()
|
||||
tool_context = ToolContext(
|
||||
workspace=str(loaded.workspace),
|
||||
session_id=resolved_session_id,
|
||||
user_id=user_id,
|
||||
services={
|
||||
"session_manager": session_manager,
|
||||
"memory_service": memory_service,
|
||||
"memory_store": memory_service.get_store(),
|
||||
"tool_registry": tool_registry,
|
||||
},
|
||||
metadata={
|
||||
"source": source,
|
||||
"agent_name": self.profile.name,
|
||||
},
|
||||
)
|
||||
|
||||
final_text = ""
|
||||
final_finish_reason = "stop"
|
||||
final_provider_name = bundle.main_runtime.provider_name
|
||||
final_model = bundle.main_runtime.model
|
||||
|
||||
while True:
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
tools=tool_schemas,
|
||||
model=final_model,
|
||||
max_tokens=resolved_max_tokens,
|
||||
temperature=resolved_temperature,
|
||||
)
|
||||
final_provider_name = response.provider_name or final_provider_name
|
||||
final_model = response.model or final_model
|
||||
final_usage = self._merge_usage(final_usage, response.usage or {})
|
||||
self._record_usage(session_manager, resolved_session_id, response.usage or {})
|
||||
|
||||
assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=response.content,
|
||||
tool_calls=assistant_tool_calls or None,
|
||||
finish_reason=response.finish_reason,
|
||||
reasoning=response.reasoning_content,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_assistant_message(
|
||||
messages,
|
||||
content=response.content,
|
||||
tool_calls=assistant_tool_calls or None,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
final_text = response.content or ""
|
||||
final_finish_reason = response.finish_reason or "stop"
|
||||
break
|
||||
|
||||
if iterations >= resolved_max_tool_iterations:
|
||||
final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
|
||||
final_finish_reason = "max_tool_iterations"
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_assistant_message(
|
||||
messages,
|
||||
content=final_text,
|
||||
)
|
||||
break
|
||||
|
||||
iterations += 1
|
||||
for tool_call in response.tool_calls:
|
||||
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="tool",
|
||||
event_type="tool_result_recorded",
|
||||
event_payload={
|
||||
"success": result.success,
|
||||
"error": result.error,
|
||||
},
|
||||
content=result.content,
|
||||
tool_name=result.tool_name,
|
||||
tool_call_id=tool_call.id,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_tool_result(
|
||||
messages,
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=result.tool_name,
|
||||
result=result.content,
|
||||
)
|
||||
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="run_completed",
|
||||
event_payload={
|
||||
"finish_reason": final_finish_reason,
|
||||
"tool_iterations": iterations,
|
||||
},
|
||||
content=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return AgentRunResult(
|
||||
session_id=resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
output_text=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
tool_iterations=iterations,
|
||||
provider_name=final_provider_name,
|
||||
model=final_model,
|
||||
usage=final_usage,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not user_message_recorded:
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="user",
|
||||
event_type="user_message_added",
|
||||
content=task,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return self._build_error_result(
|
||||
session_manager=session_manager,
|
||||
session_id=resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
source=source,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
model=final_model or resolved_model,
|
||||
message=f"Run failed before completion: {exc}",
|
||||
tool_iterations=iterations,
|
||||
provider_name=final_provider_name,
|
||||
usage=final_usage,
|
||||
)
|
||||
|
||||
def _require_loaded(self, field_name: str) -> Any:
|
||||
loaded = self.boot()
|
||||
value = getattr(loaded, field_name)
|
||||
if value is None:
|
||||
raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _serialize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
|
||||
payload: list[dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
payload.append(
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.name,
|
||||
"arguments": tool_call.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
|
||||
"""把 provider usage 映射到 session usage 字段。
|
||||
|
||||
这里先做最常见字段的最小映射:
|
||||
- prompt_tokens -> input_tokens
|
||||
- completion_tokens -> output_tokens
|
||||
|
||||
后面如果 provider 层补了更细的 cache/reasoning/cost,再往这里扩。
|
||||
"""
|
||||
|
||||
if not usage:
|
||||
return
|
||||
session_manager.update_usage(
|
||||
session_id,
|
||||
input_tokens=int(usage.get("input_tokens", usage.get("prompt_tokens", 0)) or 0),
|
||||
output_tokens=int(usage.get("output_tokens", usage.get("completion_tokens", 0)) or 0),
|
||||
reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
|
||||
"""把多轮 provider usage 合并成一次 run 的累计 usage。"""
|
||||
|
||||
merged = dict(total)
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, (int, float)) and isinstance(merged.get(key, 0), (int, float)):
|
||||
merged[key] = merged.get(key, 0) + value
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _build_error_result(
|
||||
*,
|
||||
session_manager: Any,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
source: str,
|
||||
title: str | None,
|
||||
user_id: str | None,
|
||||
model: str | None,
|
||||
message: str,
|
||||
tool_iterations: int,
|
||||
provider_name: str | None,
|
||||
usage: dict[str, Any],
|
||||
) -> AgentRunResult:
|
||||
"""把主链中的未处理异常收口成可追踪的 assistant error turn。"""
|
||||
|
||||
session_manager.append_message(
|
||||
session_id,
|
||||
run_id=run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=message,
|
||||
finish_reason="error",
|
||||
source=source,
|
||||
title=title,
|
||||
model=model,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
session_id,
|
||||
run_id=run_id,
|
||||
role="system",
|
||||
event_type="run_failed",
|
||||
event_payload={
|
||||
"tool_iterations": tool_iterations,
|
||||
"provider_name": provider_name,
|
||||
},
|
||||
content=message,
|
||||
finish_reason="error",
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return AgentRunResult(
|
||||
session_id=session_id,
|
||||
run_id=run_id,
|
||||
output_text=message,
|
||||
finish_reason="error",
|
||||
tool_iterations=tool_iterations,
|
||||
provider_name=provider_name,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""LLM provider adapters."""
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .chain import FallbackProviderChain
|
||||
from .factory import (
|
||||
ProviderBundle,
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
build_provider_runtime,
|
||||
make_aux_provider,
|
||||
make_fallback_provider,
|
||||
make_main_provider,
|
||||
make_provider_bundle,
|
||||
make_provider_from_runtime,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FallbackProviderChain",
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"ProviderBundle",
|
||||
"ProviderRoutingConfig",
|
||||
"ProviderRuntime",
|
||||
"ProviderTarget",
|
||||
"ToolCallRequest",
|
||||
"build_provider_runtime",
|
||||
"make_aux_provider",
|
||||
"make_fallback_provider",
|
||||
"make_main_provider",
|
||||
"make_provider_bundle",
|
||||
"make_provider_from_runtime",
|
||||
]
|
||||
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Native Anthropic Messages API provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import anthropic
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
anthropic = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""使用 Anthropic 原生 Messages API,而不是强行走 OpenAI-compatible path。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
default_model: str = "claude-sonnet-4-5",
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self._client = None
|
||||
|
||||
def _client_or_raise(self):
|
||||
if anthropic is None:
|
||||
raise RuntimeError("anthropic package is not installed")
|
||||
if self._client is None:
|
||||
self._client = anthropic.AsyncAnthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
timeout=self.request_timeout_seconds,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
try:
|
||||
client = self._client_or_raise()
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")
|
||||
|
||||
system_prompt, anthropic_messages = _convert_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"system": system_prompt or "",
|
||||
"messages": anthropic_messages,
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = _convert_tools(tools)
|
||||
|
||||
try:
|
||||
response = await client.messages.create(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_parts.append(block.text)
|
||||
elif block.type == "tool_use":
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
arguments=block.input,
|
||||
)
|
||||
)
|
||||
usage_payload = {}
|
||||
if getattr(response, "usage", None):
|
||||
usage_payload = {
|
||||
"input_tokens": getattr(response.usage, "input_tokens", 0),
|
||||
"output_tokens": getattr(response.usage, "output_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=getattr(response, "stop_reason", "stop") or "stop",
|
||||
usage=usage_payload,
|
||||
provider_name="anthropic",
|
||||
model=model or self.default_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
if role == "system":
|
||||
content = message.get("content")
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
if role == "tool":
|
||||
converted.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.get("tool_call_id"),
|
||||
"content": message.get("content") or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
continue
|
||||
if role == "assistant" and message.get("tool_calls"):
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if message.get("content"):
|
||||
content_blocks.append({"type": "text", "text": message["content"]})
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
function = tool_call.get("function", tool_call)
|
||||
arguments = function.get("arguments")
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call.get("id"),
|
||||
"name": function.get("name"),
|
||||
"input": arguments or {},
|
||||
}
|
||||
)
|
||||
converted.append({"role": "assistant", "content": content_blocks})
|
||||
continue
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
blocks = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
blocks.append({"type": "text", "text": item.get("text", "")})
|
||||
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
|
||||
else:
|
||||
converted.append({"role": role, "content": content or ""})
|
||||
return system_prompt, converted
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
if not fn.get("name"):
|
||||
continue
|
||||
converted.append(
|
||||
{
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description") or "",
|
||||
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Beaver provider 子系统的统一契约。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallRequest:
|
||||
"""模型返回的一次工具调用请求。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMResponse:
|
||||
"""统一的模型响应结构。"""
|
||||
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""所有 provider 实现必须遵守的统一接口。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.request_timeout_seconds = (
|
||||
max(1.0, float(request_timeout_seconds))
|
||||
if request_timeout_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""清理 provider 普遍不接受的空 content。"""
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if isinstance(content, str) and content == "":
|
||||
clean = dict(message)
|
||||
clean["content"] = None if (message.get("role") == "assistant" and message.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item
|
||||
for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(message)
|
||||
clean["content"] = filtered or "(empty)"
|
||||
if message.get("role") == "assistant" and message.get("tool_calls") and not filtered:
|
||||
clean["content"] = None
|
||||
result.append(clean)
|
||||
continue
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""统一聊天接口。"""
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""返回 provider 的默认模型名。"""
|
||||
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Provider chain helpers.
|
||||
|
||||
这里先实现最小可用的 fallback chain:
|
||||
- 每次调用都先尝试主 provider
|
||||
- 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback
|
||||
- fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路
|
||||
|
||||
这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import LLMProvider, LLMResponse
|
||||
from .runtime import ProviderRuntime
|
||||
|
||||
|
||||
class FallbackProviderChain(LLMProvider):
|
||||
"""把 primary/fallback provider 封装成一个统一的 LLMProvider。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
primary_runtime: ProviderRuntime,
|
||||
primary_provider: LLMProvider,
|
||||
fallback_runtime: ProviderRuntime | None = None,
|
||||
fallback_provider: LLMProvider | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
api_key=primary_runtime.api_key,
|
||||
api_base=primary_runtime.api_base,
|
||||
request_timeout_seconds=primary_runtime.request_timeout_seconds,
|
||||
)
|
||||
self.primary_runtime = primary_runtime
|
||||
self.primary_provider = primary_provider
|
||||
self.fallback_runtime = fallback_runtime
|
||||
self.fallback_provider = fallback_provider
|
||||
# 这里只记录“最近一次 chat 实际用了哪条链”,用于调试和测试。
|
||||
# 真正的选路决策必须按调用粒度重新从 primary 开始,不能跨调用粘住 fallback。
|
||||
self._last_runtime = primary_runtime
|
||||
self._last_provider = primary_provider
|
||||
self._last_call_used_fallback = False
|
||||
|
||||
@property
|
||||
def fallback_activated(self) -> bool:
|
||||
"""最近一次 chat 是否实际用到了 fallback。"""
|
||||
|
||||
return self._last_call_used_fallback
|
||||
|
||||
@property
|
||||
def active_runtime(self) -> ProviderRuntime:
|
||||
"""最近一次 chat 实际使用的 runtime。"""
|
||||
|
||||
return self._last_runtime
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
self._last_provider = self.primary_provider
|
||||
self._last_runtime = self.primary_runtime
|
||||
self._last_call_used_fallback = False
|
||||
|
||||
response = await self._safe_chat(
|
||||
self.primary_provider,
|
||||
self.primary_runtime,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model or self.primary_runtime.model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
response = self._decorate_response(response, self.primary_runtime)
|
||||
if not self._should_activate_fallback(response):
|
||||
return response
|
||||
|
||||
assert self.fallback_provider is not None
|
||||
assert self.fallback_runtime is not None
|
||||
|
||||
self._last_provider = self.fallback_provider
|
||||
self._last_runtime = self.fallback_runtime
|
||||
self._last_call_used_fallback = True
|
||||
|
||||
response = await self._safe_chat(
|
||||
self.fallback_provider,
|
||||
self.fallback_runtime,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=self.fallback_runtime.model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
return self._decorate_response(response, self.fallback_runtime)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.primary_runtime.model
|
||||
|
||||
def _should_activate_fallback(self, response: LLMResponse) -> bool:
|
||||
return (
|
||||
self.fallback_provider is not None
|
||||
and self.fallback_runtime is not None
|
||||
and response.finish_reason == "error"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _safe_chat(
|
||||
provider: LLMProvider,
|
||||
runtime: ProviderRuntime,
|
||||
*,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> LLMResponse:
|
||||
"""把 provider 抛出的异常也收敛成统一 error response。
|
||||
|
||||
这样 fallback 的触发条件就不依赖“每个 provider 都记得自己 catch 异常”。
|
||||
"""
|
||||
|
||||
try:
|
||||
return await provider.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
provider_name=runtime.provider_name,
|
||||
model=runtime.model,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse:
|
||||
if response.provider_name is None:
|
||||
response.provider_name = runtime.provider_name
|
||||
if response.model is None:
|
||||
response.model = runtime.model
|
||||
return response
|
||||
@ -1,4 +1,4 @@
|
||||
"""OpenAI Codex Responses Provider."""
|
||||
"""OpenAI Codex Responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -7,24 +7,30 @@ import hashlib
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
try: # pragma: no cover - optional dependency
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
get_codex_token = None # type: ignore[assignment]
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
DEFAULT_ORIGINATOR = "beaver"
|
||||
|
||||
|
||||
class OpenAICodexProvider(LLMProvider):
|
||||
"""Use Codex OAuth to call the Responses API."""
|
||||
"""使用 Codex OAuth 调用 Responses API。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_model: str = "openai-codex/gpt-5.1-codex",
|
||||
request_timeout_seconds: float | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
|
||||
@ -36,14 +42,15 @@ class OpenAICodexProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
if httpx is None or get_codex_token is None:
|
||||
return LLMResponse(content="Error: codex dependencies are not installed", finish_reason="error", provider_name="openai_codex")
|
||||
|
||||
resolved_model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": _strip_model_prefix(model),
|
||||
"model": _strip_model_prefix(resolved_model),
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"instructions": system_prompt,
|
||||
@ -54,42 +61,27 @@ class OpenAICodexProvider(LLMProvider):
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
|
||||
url = DEFAULT_CODEX_URL
|
||||
|
||||
try:
|
||||
try:
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
verify=True,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
except Exception as e:
|
||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||
raise
|
||||
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
verify=False,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Codex: {str(e)}",
|
||||
finish_reason="error",
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL,
|
||||
headers,
|
||||
body,
|
||||
verify=True,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling Codex: {exc}", finish_reason="error", provider_name="openai_codex")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
provider_name="openai_codex",
|
||||
model=resolved_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -107,7 +99,7 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": DEFAULT_ORIGINATOR,
|
||||
"User-Agent": "nanobot (python)",
|
||||
"User-Agent": "beaver (python)",
|
||||
"accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
@ -129,7 +121,6 @@ async def _request_codex(
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
@ -137,33 +128,30 @@ def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
for index, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Handle text first.
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append(
|
||||
{
|
||||
@ -171,28 +159,24 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed",
|
||||
"id": f"msg_{idx}",
|
||||
"id": f"msg_{index}",
|
||||
}
|
||||
)
|
||||
# Then handle tool calls.
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
for tool_call in message.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
call_id = call_id or f"call_{idx}"
|
||||
item_id = item_id or f"fc_{idx}"
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"id": item_id,
|
||||
"call_id": call_id,
|
||||
"id": item_id or f"fc_{index}",
|
||||
"call_id": call_id or f"call_{index}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
call_id, _ = _split_tool_call_id(message.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append(
|
||||
{
|
||||
@ -201,8 +185,6 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
||||
"output": output_text,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
@ -239,12 +221,12 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
async def _iter_sse(response: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
data_lines = [item[5:].strip() for item in buffer if item.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
@ -259,71 +241,34 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
async def _consume_sse(response: Any) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
content_parts.append(delta)
|
||||
elif event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
content += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
raw_arguments = item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
arguments = json.loads(raw_arguments) if isinstance(raw_arguments, str) else raw_arguments
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
id=f"{item.get('call_id', 'call')}|{item.get('id', '')}",
|
||||
name=item.get("name", ""),
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = _map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Codex response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
finish_reason = event.get("response", {}).get("status", "completed")
|
||||
return "".join(content_parts) or None, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
return f"HTTP {status_code}: {raw}"
|
||||
def _friendly_error(status_code: int, body: str) -> str:
|
||||
return f"Codex API error ({status_code}): {body[:400]}"
|
||||
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from openai import AsyncOpenAI
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
AsyncOpenAI = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
|
||||
"""直接连接任意 OpenAI-compatible endpoint。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "no-key",
|
||||
api_base: str = "http://localhost:8000/v1",
|
||||
default_model: str = "default",
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self._client = None
|
||||
|
||||
def _client_or_raise(self):
|
||||
if AsyncOpenAI is None:
|
||||
raise RuntimeError("openai package is not installed")
|
||||
if self._client is None:
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
timeout=self.request_timeout_seconds,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
client = self._client_or_raise()
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self.sanitize_empty_content(messages),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice="auto")
|
||||
try:
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
parsed_tool_calls: list[ToolCallRequest] = []
|
||||
for tool_call in message.tool_calls or []:
|
||||
raw_arguments = tool_call.function.arguments
|
||||
if isinstance(raw_arguments, str):
|
||||
if json_repair is not None:
|
||||
arguments = json_repair.loads(raw_arguments)
|
||||
else:
|
||||
import json
|
||||
arguments = json.loads(raw_arguments)
|
||||
else:
|
||||
arguments = raw_arguments
|
||||
parsed_tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_payload = {}
|
||||
if usage is not None:
|
||||
usage_payload = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content=message.content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage=usage_payload,
|
||||
reasoning_content=getattr(message, "reasoning_content", None),
|
||||
provider_name="custom",
|
||||
model=model or self.default_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""Provider runtime 的统一工厂入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .anthropic import AnthropicProvider
|
||||
from .base import LLMProvider
|
||||
from .chain import FallbackProviderChain
|
||||
from .codex import OpenAICodexProvider
|
||||
from .custom import CustomProvider
|
||||
from .litellm import LiteLLMProvider
|
||||
from .runtime import (
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
normalize_provider_target,
|
||||
resolve_auxiliary_runtime,
|
||||
resolve_embedding_runtime,
|
||||
resolve_fallback_runtime,
|
||||
resolve_provider_runtime,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderBundle:
|
||||
"""一次运行所需的 provider 组合。
|
||||
|
||||
这里把三条常见链路收口到一起:
|
||||
- `main`:主对话
|
||||
- `fallback`:主链失败后的备用 provider
|
||||
- `auxiliary`:搜索摘要、压缩、memory flush 等辅助任务
|
||||
"""
|
||||
|
||||
main_runtime: ProviderRuntime
|
||||
main_provider: LLMProvider
|
||||
fallback_runtime: ProviderRuntime | None = None
|
||||
fallback_provider: LLMProvider | None = None
|
||||
auxiliary_runtime: ProviderRuntime | None = None
|
||||
auxiliary_provider: LLMProvider | None = None
|
||||
embedding_runtime: ProviderRuntime | None = None
|
||||
|
||||
|
||||
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
|
||||
"""构建统一 provider runtime。"""
|
||||
|
||||
return resolve_provider_runtime(**kwargs)
|
||||
|
||||
|
||||
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
|
||||
"""根据 runtime 创建具体 provider 实例。"""
|
||||
|
||||
if runtime.spec.provider_impl == "custom":
|
||||
return CustomProvider(
|
||||
api_key=runtime.api_key or "no-key",
|
||||
api_base=runtime.api_base or "http://localhost:8000/v1",
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if runtime.spec.provider_impl == "codex":
|
||||
return OpenAICodexProvider(
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if runtime.spec.provider_impl == "anthropic":
|
||||
return AnthropicProvider(
|
||||
api_key=runtime.api_key,
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
api_base=runtime.api_base,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=runtime.api_key,
|
||||
api_base=runtime.api_base,
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
provider_name=runtime.provider_name,
|
||||
extra_headers=runtime.extra_headers,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
routing=runtime.routing,
|
||||
)
|
||||
|
||||
|
||||
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
|
||||
"""构建主对话 provider。"""
|
||||
|
||||
fallback_target = kwargs.pop("fallback_target", None)
|
||||
if fallback_target is None and "fallback_model" in kwargs:
|
||||
fallback_target = kwargs.pop("fallback_model")
|
||||
|
||||
runtime = build_provider_runtime(
|
||||
auxiliary=False,
|
||||
fallback_target=fallback_target,
|
||||
role="main",
|
||||
source="main_config",
|
||||
**kwargs,
|
||||
)
|
||||
provider = make_provider_from_runtime(runtime)
|
||||
fallback_pair = make_fallback_provider(runtime, fallback_target)
|
||||
if fallback_pair is None:
|
||||
return runtime, provider
|
||||
fallback_runtime, fallback_provider = fallback_pair
|
||||
return runtime, FallbackProviderChain(runtime, provider, fallback_runtime, fallback_provider)
|
||||
|
||||
|
||||
def make_fallback_provider(
|
||||
primary_runtime: ProviderRuntime,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
) -> tuple[ProviderRuntime, LLMProvider] | None:
|
||||
"""构建 fallback provider。"""
|
||||
|
||||
runtime = resolve_fallback_runtime(primary_runtime, fallback_target or primary_runtime.fallback_target)
|
||||
if runtime is None:
|
||||
return None
|
||||
return runtime, make_provider_from_runtime(runtime)
|
||||
|
||||
|
||||
def make_aux_provider(
|
||||
main_runtime: ProviderRuntime | None = None,
|
||||
*,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
task_name: str = "auxiliary",
|
||||
**kwargs: Any,
|
||||
) -> tuple[ProviderRuntime, LLMProvider]:
|
||||
"""构建辅助任务 provider。"""
|
||||
|
||||
if target is None and kwargs:
|
||||
target = kwargs
|
||||
|
||||
if main_runtime is not None:
|
||||
runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name)
|
||||
else:
|
||||
normalized = normalize_provider_target(target)
|
||||
if normalized is None or not normalized.model:
|
||||
raise ValueError("Auxiliary provider without main_runtime requires at least a model")
|
||||
runtime = build_provider_runtime(
|
||||
model=normalized.model,
|
||||
provider_name=normalized.provider_name,
|
||||
api_key=normalized.api_key,
|
||||
api_base=normalized.api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds,
|
||||
extra_headers=normalized.extra_headers,
|
||||
routing=normalized.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auxiliary_config",
|
||||
)
|
||||
return runtime, make_provider_from_runtime(runtime)
|
||||
|
||||
|
||||
def make_embedding_runtime(
|
||||
main_runtime: ProviderRuntime,
|
||||
*,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
default_model: str = "text-embedding-v4",
|
||||
) -> ProviderRuntime | None:
|
||||
"""构建 embedding 专用 runtime。"""
|
||||
|
||||
return resolve_embedding_runtime(main_runtime, target=target, default_model=default_model)
|
||||
|
||||
|
||||
def make_provider_bundle(
|
||||
*,
|
||||
auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
auxiliary_task_name: str = "auxiliary",
|
||||
embedding_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
embedding_model: str = "text-embedding-v4",
|
||||
**kwargs: Any,
|
||||
) -> ProviderBundle:
|
||||
"""一次性构建 main/fallback/aux 三条 provider 链。"""
|
||||
|
||||
runtime_kwargs = dict(kwargs)
|
||||
fallback_target = runtime_kwargs.pop("fallback_target", None)
|
||||
if fallback_target is None and "fallback_model" in kwargs:
|
||||
fallback_target = runtime_kwargs.pop("fallback_model")
|
||||
|
||||
main_runtime = build_provider_runtime(
|
||||
auxiliary=False,
|
||||
fallback_target=fallback_target,
|
||||
role="main",
|
||||
source="main_config",
|
||||
**runtime_kwargs,
|
||||
)
|
||||
primary_provider = make_provider_from_runtime(main_runtime)
|
||||
fallback_pair = make_fallback_provider(main_runtime, fallback_target)
|
||||
if fallback_pair is None:
|
||||
main_provider: LLMProvider = primary_provider
|
||||
fallback_runtime = None
|
||||
fallback_provider = None
|
||||
else:
|
||||
fallback_runtime, fallback_provider = fallback_pair
|
||||
main_provider = FallbackProviderChain(main_runtime, primary_provider, fallback_runtime, fallback_provider)
|
||||
|
||||
auxiliary_runtime = None
|
||||
auxiliary_provider = None
|
||||
if auxiliary_target is not None:
|
||||
auxiliary_runtime, auxiliary_provider = make_aux_provider(
|
||||
main_runtime,
|
||||
target=auxiliary_target,
|
||||
task_name=auxiliary_task_name,
|
||||
)
|
||||
|
||||
embedding_runtime = make_embedding_runtime(
|
||||
main_runtime,
|
||||
target=embedding_target,
|
||||
default_model=embedding_model,
|
||||
)
|
||||
|
||||
return ProviderBundle(
|
||||
main_runtime=main_runtime,
|
||||
main_provider=main_provider,
|
||||
fallback_runtime=fallback_runtime,
|
||||
fallback_provider=fallback_provider,
|
||||
auxiliary_runtime=auxiliary_runtime,
|
||||
auxiliary_provider=auxiliary_provider,
|
||||
embedding_runtime=embedding_runtime,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ProviderBundle",
|
||||
"ProviderRoutingConfig",
|
||||
"ProviderRuntime",
|
||||
"ProviderTarget",
|
||||
"build_provider_runtime",
|
||||
"make_aux_provider",
|
||||
"make_embedding_runtime",
|
||||
"make_fallback_provider",
|
||||
"make_main_provider",
|
||||
"make_provider_bundle",
|
||||
"make_provider_from_runtime",
|
||||
]
|
||||
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_gateway
|
||||
from .runtime import ProviderRoutingConfig
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
litellm = None # type: ignore[assignment]
|
||||
acompletion = None # type: ignore[assignment]
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""通过 LiteLLM 统一访问大多数 provider。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
routing: ProviderRoutingConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.routing = routing
|
||||
self.provider_name = provider_name
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
if litellm is not None:
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.drop_params = True
|
||||
|
||||
def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
|
||||
"""为当前请求生成 LiteLLM 依赖的临时环境变量。
|
||||
|
||||
LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime
|
||||
之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。
|
||||
"""
|
||||
|
||||
if not api_key:
|
||||
return {}
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if spec is None or not spec.env_key:
|
||||
return {}
|
||||
overrides: dict[str, str] = {spec.env_key: api_key}
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_value in spec.env_extras:
|
||||
resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||
overrides[env_name] = resolved
|
||||
return overrides
|
||||
|
||||
@contextmanager
|
||||
def _temporary_env(self, overrides: dict[str, str]):
|
||||
"""只在当前请求期间注入 provider 需要的环境变量。"""
|
||||
|
||||
if not overrides:
|
||||
yield
|
||||
return
|
||||
|
||||
sentinel = object()
|
||||
previous: dict[str, object] = {}
|
||||
for key, value in overrides.items():
|
||||
previous[key] = os.environ.get(key, sentinel)
|
||||
os.environ[key] = value
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key, old_value in previous.items():
|
||||
if old_value is sentinel:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = str(old_value)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
if self._gateway:
|
||||
prefix = self._gateway.litellm_prefix
|
||||
resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
sanitized = []
|
||||
for message in messages:
|
||||
clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
|
||||
spec = find_by_model(original_model)
|
||||
if spec is None:
|
||||
return
|
||||
model_lower = original_model.lower()
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
|
||||
if self.provider_name != "openrouter" or self.routing is None:
|
||||
return
|
||||
provider_payload: dict[str, Any] = {}
|
||||
if self.routing.sort:
|
||||
provider_payload["sort"] = self.routing.sort
|
||||
if self.routing.only:
|
||||
provider_payload["only"] = self.routing.only
|
||||
if self.routing.ignore:
|
||||
provider_payload["ignore"] = self.routing.ignore
|
||||
if self.routing.order:
|
||||
provider_payload["order"] = self.routing.order
|
||||
if self.routing.require_parameters:
|
||||
provider_payload["require_parameters"] = True
|
||||
if self.routing.data_collection:
|
||||
provider_payload["data_collection"] = self.routing.data_collection
|
||||
if provider_payload:
|
||||
kwargs["provider"] = provider_payload
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
if acompletion is None:
|
||||
return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)
|
||||
|
||||
original_model = model or self.default_model
|
||||
resolved_model = self._resolve_model(original_model)
|
||||
sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": resolved_model,
|
||||
"messages": sanitized_messages,
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
self._apply_model_overrides(original_model, kwargs)
|
||||
self._apply_openrouter_routing(kwargs)
|
||||
env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)
|
||||
|
||||
try:
|
||||
with self._temporary_env(env_overrides):
|
||||
response = await acompletion(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
for tool_call in message.tool_calls or []:
|
||||
raw_arguments = tool_call.function.arguments
|
||||
if isinstance(raw_arguments, str):
|
||||
try:
|
||||
if json_repair is not None:
|
||||
arguments = json_repair.loads(raw_arguments)
|
||||
else:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except Exception as exc:
|
||||
# 这里不要因为单个 tool_call 参数坏掉而直接炸掉整轮请求。
|
||||
# 后面的 ToolExecutor 会把这个标记转换成一条标准 tool failure。
|
||||
arguments = {
|
||||
"__beaver_tool_argument_parse_error__": str(exc),
|
||||
"__raw_arguments__": raw_arguments,
|
||||
}
|
||||
else:
|
||||
arguments = raw_arguments
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_payload = {}
|
||||
if usage is not None:
|
||||
usage_payload = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content=getattr(message, "content", None),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
|
||||
usage=usage_payload,
|
||||
reasoning_content=getattr(message, "reasoning_content", None),
|
||||
provider_name=self.provider_name or "litellm",
|
||||
model=resolved_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProviderSpec:
|
||||
"""单个 provider 的元数据定义。"""
|
||||
|
||||
name: str
|
||||
keywords: tuple[str, ...]
|
||||
env_key: str
|
||||
display_name: str = ""
|
||||
litellm_prefix: str = ""
|
||||
skip_prefixes: tuple[str, ...] = ()
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
is_gateway: bool = False
|
||||
is_local: bool = False
|
||||
detect_by_key_prefix: str = ""
|
||||
detect_by_base_keyword: str = ""
|
||||
default_api_base: str = ""
|
||||
strip_model_prefix: bool = False
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
is_oauth: bool = False
|
||||
is_direct: bool = False
|
||||
supports_prompt_caching: bool = False
|
||||
api_mode: str = "chat_completions"
|
||||
provider_impl: str = "litellm"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="",
|
||||
display_name="Custom",
|
||||
is_direct=True,
|
||||
provider_impl="custom",
|
||||
api_mode="chat_completions",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter",
|
||||
is_gateway=True,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="siliconflow",
|
||||
keywords=("siliconflow",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="SiliconFlow",
|
||||
litellm_prefix="openai",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="siliconflow",
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="volcengine",
|
||||
keywords=("volcengine", "volces", "ark"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine",
|
||||
litellm_prefix="volcengine",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="volces",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
supports_prompt_caching=True,
|
||||
api_mode="anthropic_messages",
|
||||
provider_impl="anthropic",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex", "codex"),
|
||||
env_key="",
|
||||
display_name="OpenAI Codex",
|
||||
is_oauth=True,
|
||||
detect_by_base_keyword="codex",
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
api_mode="codex_responses",
|
||||
provider_impl="codex",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="",
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot",
|
||||
skip_prefixes=("github_copilot/",),
|
||||
is_oauth=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek",
|
||||
skip_prefixes=("deepseek/",),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini",
|
||||
skip_prefixes=("gemini/",),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai",
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope",
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot",
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||
default_api_base="https://api.moonshot.ai/v1",
|
||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax",
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm",
|
||||
is_local=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq",
|
||||
skip_prefixes=("groq/",),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
for spec in PROVIDERS:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""按模型名关键词匹配标准 provider。"""
|
||||
|
||||
model_lower = model.lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
standard_specs = [spec for spec in PROVIDERS if not spec.is_gateway and not spec.is_local]
|
||||
|
||||
# 显式前缀优先级最高。
|
||||
# 这里不能只看 standard provider:
|
||||
# - `openrouter/...` 应该直接命中 openrouter
|
||||
# - `hosted_vllm/...` 应该能回到 vllm 这个本地 provider
|
||||
# - `github_copilot/...codex` 也不应被误判成 openai_codex
|
||||
for spec in PROVIDERS:
|
||||
aliases = {spec.name}
|
||||
if spec.litellm_prefix:
|
||||
aliases.add(spec.litellm_prefix.replace("-", "_"))
|
||||
if model_prefix and normalized_prefix in aliases:
|
||||
return spec
|
||||
|
||||
for spec in standard_specs:
|
||||
if any(keyword in model_lower or keyword.replace("-", "_") in model_normalized for keyword in spec.keywords):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""按 config key / api_key / api_base 识别 gateway 或 local provider。"""
|
||||
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
return None
|
||||
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
@ -0,0 +1,408 @@
|
||||
"""Hermes 风格的 provider runtime resolution。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
|
||||
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRoutingConfig:
|
||||
"""OpenRouter provider routing 配置。"""
|
||||
|
||||
sort: str | None = None
|
||||
only: list[str] = field(default_factory=list)
|
||||
ignore: list[str] = field(default_factory=list)
|
||||
order: list[str] = field(default_factory=list)
|
||||
require_parameters: bool = False
|
||||
data_collection: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderTarget:
|
||||
"""一次 provider 选路请求的标准化配置。
|
||||
|
||||
这层不是具体 runtime,而是“调用方想要什么”:
|
||||
- 用哪个 provider
|
||||
- 跑哪个 model
|
||||
- 是否指定自定义 base_url
|
||||
- 是否带额外 headers / routing
|
||||
|
||||
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
|
||||
"""
|
||||
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRuntime:
|
||||
"""运行时真正使用的 provider 解析结果。"""
|
||||
|
||||
spec: ProviderSpec
|
||||
model: str
|
||||
provider_name: str
|
||||
api_mode: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
default_model: str | None = None
|
||||
request_timeout_seconds: float | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
fallback_target: ProviderTarget | None = None
|
||||
auxiliary: bool = False
|
||||
role: str = "main"
|
||||
source: str = "runtime"
|
||||
|
||||
|
||||
def resolve_provider_runtime(
|
||||
*,
|
||||
model: str,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: ProviderRoutingConfig | None = None,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
auxiliary: bool = False,
|
||||
role: str = "main",
|
||||
source: str = "runtime",
|
||||
) -> ProviderRuntime:
|
||||
"""把调用侧传入的配置解析成统一 runtime。"""
|
||||
|
||||
gateway = find_gateway(provider_name, api_key, api_base)
|
||||
if gateway is not None:
|
||||
spec = gateway
|
||||
elif provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
else:
|
||||
spec = find_by_model(model)
|
||||
|
||||
if spec is None:
|
||||
if api_base:
|
||||
spec = find_by_name("custom")
|
||||
else:
|
||||
raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")
|
||||
|
||||
resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
|
||||
resolved_api_base = api_base or spec.default_api_base or None
|
||||
|
||||
return ProviderRuntime(
|
||||
spec=spec,
|
||||
model=resolved_model,
|
||||
provider_name=spec.name,
|
||||
api_mode=spec.api_mode,
|
||||
api_key=api_key,
|
||||
api_base=resolved_api_base,
|
||||
default_model=resolved_model,
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
extra_headers=extra_headers or {},
|
||||
routing=routing,
|
||||
fallback_target=normalize_provider_target(fallback_target),
|
||||
auxiliary=auxiliary,
|
||||
role=role,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
|
||||
"""把 dict/对象形式的 provider 配置收敛成统一结构。
|
||||
|
||||
这里兼容几种常见写法,便于后续接 CLI / config / gateway:
|
||||
- `provider` 或 `provider_name`
|
||||
- `base_url` 或 `api_base`
|
||||
- `headers` 或 `extra_headers`
|
||||
- `timeout` 或 `request_timeout_seconds`
|
||||
"""
|
||||
|
||||
if target is None:
|
||||
return None
|
||||
if isinstance(target, ProviderTarget):
|
||||
return target
|
||||
|
||||
provider_name = target.get("provider_name")
|
||||
if provider_name is None:
|
||||
provider_name = target.get("provider")
|
||||
|
||||
api_base = target.get("api_base")
|
||||
if api_base is None:
|
||||
api_base = target.get("base_url")
|
||||
|
||||
extra_headers = target.get("extra_headers")
|
||||
if extra_headers is None:
|
||||
extra_headers = target.get("headers")
|
||||
|
||||
request_timeout_seconds = target.get("request_timeout_seconds")
|
||||
if request_timeout_seconds is None:
|
||||
request_timeout_seconds = target.get("timeout")
|
||||
|
||||
routing = target.get("routing")
|
||||
if isinstance(routing, dict):
|
||||
routing = ProviderRoutingConfig(**routing)
|
||||
|
||||
return ProviderTarget(
|
||||
provider_name=provider_name,
|
||||
model=target.get("model"),
|
||||
api_key=target.get("api_key"),
|
||||
api_base=api_base,
|
||||
extra_headers=dict(extra_headers or {}),
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
routing=routing,
|
||||
)
|
||||
|
||||
|
||||
def resolve_fallback_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None,
|
||||
) -> ProviderRuntime | None:
|
||||
"""把 fallback 配置解析成独立 runtime。
|
||||
|
||||
Hermes 的 fallback 是“主 provider 失败后切换到另一个 provider:model”。
|
||||
这里先把 fallback 解析独立出来,具体何时激活交给上层 chain/factory。
|
||||
"""
|
||||
|
||||
target = normalize_provider_target(fallback_target)
|
||||
if target is None or not target.model:
|
||||
return None
|
||||
|
||||
inferred_provider = target.provider_name
|
||||
if inferred_provider in {None, "", "main"}:
|
||||
inferred_provider = primary_runtime.provider_name
|
||||
|
||||
api_key = target.api_key
|
||||
api_base = target.api_base
|
||||
extra_headers = dict(target.extra_headers)
|
||||
|
||||
# 只有在 fallback 没明确切换 provider/base 时,才继承主链的凭据与 headers。
|
||||
if inferred_provider == primary_runtime.provider_name and not api_base:
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
return resolve_provider_runtime(
|
||||
model=target.model,
|
||||
provider_name=inferred_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=target.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=target.routing,
|
||||
auxiliary=False,
|
||||
role="fallback",
|
||||
source="fallback_config",
|
||||
)
|
||||
|
||||
|
||||
def resolve_auxiliary_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
*,
|
||||
task_name: str = "auxiliary",
|
||||
) -> ProviderRuntime:
|
||||
"""解析辅助任务专用 runtime。
|
||||
|
||||
支持三类输入:
|
||||
- `None` / `provider=main`:直接复用主链 provider
|
||||
- 显式 `provider + model`:走独立 provider
|
||||
- 仅给 `model`:按模型名自动匹配 provider
|
||||
"""
|
||||
|
||||
normalized = normalize_provider_target(target)
|
||||
if normalized is None:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
)
|
||||
|
||||
provider_name = normalized.provider_name
|
||||
if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if provider_name == "main":
|
||||
return resolve_provider_runtime(
|
||||
model=normalized.model or primary_runtime.model,
|
||||
provider_name=primary_runtime.provider_name,
|
||||
api_key=normalized.api_key or primary_runtime.api_key,
|
||||
api_base=normalized.api_base or primary_runtime.api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
)
|
||||
|
||||
if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auto->main",
|
||||
)
|
||||
|
||||
resolved_model = normalized.model or primary_runtime.model
|
||||
resolved_provider = normalized.provider_name
|
||||
if resolved_provider in {"auto", "", None} and not normalized.api_base:
|
||||
# `auto` 的第一阶段实现保持保守:
|
||||
# - 有显式 model 时按 model 匹配 provider
|
||||
# - 匹配不到则回退主链 provider
|
||||
spec = find_by_model(resolved_model)
|
||||
resolved_provider = spec.name if spec is not None else primary_runtime.provider_name
|
||||
|
||||
api_key = normalized.api_key
|
||||
api_base = normalized.api_base
|
||||
extra_headers = dict(normalized.extra_headers)
|
||||
|
||||
if resolved_provider == primary_runtime.provider_name and not api_base:
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
return resolve_provider_runtime(
|
||||
model=resolved_model,
|
||||
provider_name=resolved_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auxiliary_config",
|
||||
)
|
||||
|
||||
|
||||
def resolve_embedding_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
*,
|
||||
default_model: str = "text-embedding-v4",
|
||||
) -> ProviderRuntime | None:
|
||||
"""解析 embedding 专用 runtime。
|
||||
|
||||
目标是把“embedding 用哪个 model / api_base / api_key”也收进 provider 层,
|
||||
避免上层检索逻辑直接偷拿 main/aux provider 的凭据。
|
||||
"""
|
||||
|
||||
normalized = normalize_provider_target(target)
|
||||
|
||||
if normalized is None:
|
||||
# 没有显式 embedding 配置时,只允许在主链本身就是 OpenAI-compatible
|
||||
# 的情况下,继承它的 api_base/api_key。否则不做模糊猜测。
|
||||
if not _supports_openai_embeddings(primary_runtime):
|
||||
return None
|
||||
return resolve_provider_runtime(
|
||||
model=default_model,
|
||||
provider_name="openai",
|
||||
api_key=primary_runtime.api_key,
|
||||
api_base=primary_runtime.api_base,
|
||||
request_timeout_seconds=primary_runtime.request_timeout_seconds,
|
||||
extra_headers=dict(primary_runtime.extra_headers),
|
||||
routing=primary_runtime.routing,
|
||||
auxiliary=False,
|
||||
role="embedding",
|
||||
source="embedding_inherited",
|
||||
)
|
||||
|
||||
resolved_model = normalized.model or default_model
|
||||
resolved_provider = normalized.provider_name
|
||||
if resolved_provider in {None, "", "main", "auto"}:
|
||||
resolved_provider = "custom" if normalized.api_base else "openai"
|
||||
|
||||
api_key = normalized.api_key
|
||||
api_base = normalized.api_base
|
||||
extra_headers = dict(normalized.extra_headers)
|
||||
|
||||
if not api_base and _supports_openai_embeddings(primary_runtime):
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
runtime = resolve_provider_runtime(
|
||||
model=resolved_model,
|
||||
provider_name=resolved_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=normalized.routing,
|
||||
auxiliary=False,
|
||||
role="embedding",
|
||||
source="embedding_config",
|
||||
)
|
||||
if not _supports_openai_embeddings(runtime):
|
||||
raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
|
||||
return runtime
|
||||
|
||||
|
||||
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
|
||||
"""当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。"""
|
||||
|
||||
return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"}
|
||||
|
||||
|
||||
def _clone_runtime(
|
||||
runtime: ProviderRuntime,
|
||||
**changes: Any,
|
||||
) -> ProviderRuntime:
|
||||
"""基于现有 runtime 复制一个轻量变体。
|
||||
|
||||
用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"extra_headers": dict(runtime.extra_headers),
|
||||
"routing": runtime.routing,
|
||||
"fallback_target": runtime.fallback_target,
|
||||
}
|
||||
payload.update(changes)
|
||||
return replace(runtime, **payload)
|
||||
|
||||
|
||||
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
|
||||
"""根据 registry 规则应用必要前缀。"""
|
||||
|
||||
resolved = model
|
||||
if gateway_mode:
|
||||
prefix = spec.litellm_prefix
|
||||
if spec.strip_model_prefix:
|
||||
resolved = resolved.split("/")[-1]
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
if spec.litellm_prefix:
|
||||
resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix)
|
||||
if not any(resolved.startswith(item) for item in spec.skip_prefixes):
|
||||
resolved = f"{spec.litellm_prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
2
app-instance/backend/beaver/engine/runtime/__init__.py
Normal file
2
app-instance/backend/beaver/engine/runtime/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Runtime helper objects and execution context."""
|
||||
|
||||
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Session state and persistence."""
|
||||
|
||||
from .manager import SessionManager
|
||||
from .models import MessageRecord, SessionRecord, SessionUsage
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
__all__ = [
|
||||
"MessageRecord",
|
||||
"SessionManager",
|
||||
"SessionRecord",
|
||||
"SessionSearchService",
|
||||
"SessionStore",
|
||||
"SessionUsage",
|
||||
]
|
||||
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .models import MessageRecord
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""供 AgentLoop / services / MCP tools 使用的统一 session facade。"""
|
||||
|
||||
def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
|
||||
self.workspace = Path(workspace)
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
|
||||
self.store = SessionStore(self.db_path)
|
||||
self.search = SessionSearchService(self.store)
|
||||
|
||||
def close(self) -> None:
|
||||
self.store.close()
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> str:
|
||||
return self.store.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
|
||||
def get_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
record = self.store.get_session_record(session_id)
|
||||
return record.to_dict() if record is not None else None
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
self.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
session = self.get_session(session_id)
|
||||
if session is None:
|
||||
raise RuntimeError(f"Failed to create session {session_id!r}")
|
||||
return session
|
||||
|
||||
def append_message(self, session_id: str, **kwargs: Any) -> int:
|
||||
return self.store.append_message(session_id, **kwargs)
|
||||
|
||||
def get_event_records(self, session_id: str) -> list[MessageRecord]:
|
||||
"""返回当前 session 的完整事件流。
|
||||
|
||||
这里和 `get_messages_as_conversation()` 的区别很关键:
|
||||
- `get_event_records()` 面向 runtime / replay / audit,保留隐藏系统事件
|
||||
- `get_messages_as_conversation()` 面向 prompt builder,只暴露可进上下文的事件
|
||||
|
||||
第 6 阶段开始后,session 已不再只是“聊天消息存储”,而是在逐步收敛成
|
||||
“外部事件流 + 上层投影视图”。
|
||||
"""
|
||||
|
||||
return self.store.get_event_records(session_id)
|
||||
|
||||
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
|
||||
"""返回某一次 direct run / future bus run 对应的事件片段。"""
|
||||
|
||||
return self.store.get_run_event_records(session_id, run_id)
|
||||
|
||||
def list_run_ids(self, session_id: str) -> list[str]:
|
||||
"""按出现顺序列出当前 session 的所有 run_id。"""
|
||||
|
||||
return self.store.list_run_ids(session_id)
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
return self.store.get_messages_as_conversation(session_id)
|
||||
|
||||
def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""返回适合注入 prompt 的可见历史切片。
|
||||
|
||||
这里故意不直接暴露完整事件流,而是继续提供“模型可消费历史”这个投影视图:
|
||||
1. 只包含 `context_visible=True` 的事件
|
||||
2. 继续保留旧式窗口裁剪逻辑,避免当前主链行为突然变化
|
||||
3. 让 `ContextBuilder` 明确消费的是“上游裁剪后的可见片段”
|
||||
"""
|
||||
|
||||
history = self.get_messages_as_conversation(session_id)
|
||||
sliced = history[-max_messages:]
|
||||
for index, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[index:]
|
||||
break
|
||||
return sliced
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""兼容旧名称,实际返回可见历史切片。"""
|
||||
|
||||
return self.get_visible_history(session_id, max_messages=max_messages)
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
self.store.update_system_prompt(session_id, system_prompt)
|
||||
|
||||
def update_usage(self, session_id: str, **kwargs: Any) -> None:
|
||||
self.store.update_usage(session_id, **kwargs)
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
self.store.end_session(session_id, end_reason)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
self.store.reopen_session(session_id)
|
||||
|
||||
def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.search.list_sessions_rich(**kwargs)
|
||||
|
||||
def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.search.search_messages(**kwargs)
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
|
||||
return self.search.resolve_session_id(session_id_or_prefix)
|
||||
211
app-instance/backend/beaver/engine/session/models.py
Normal file
211
app-instance/backend/beaver/engine/session/models.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""Beaver session 子系统的数据模型。
|
||||
|
||||
这层只定义数据结构,不放数据库读写逻辑。目的是把:
|
||||
1. SQLite 行结构
|
||||
2. 运行时会话对象
|
||||
3. 对外暴露的 conversation message
|
||||
|
||||
三件事分开,避免后续所有地方都直接和裸字典耦合。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionUsage:
|
||||
"""会话维度的 usage/cost 统计。"""
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
actual_cost_usd: float | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"cache_write_tokens": self.cache_write_tokens,
|
||||
"reasoning_tokens": self.reasoning_tokens,
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"actual_cost_usd": self.actual_cost_usd,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MessageRecord:
|
||||
"""单条会话事件的结构化表示。
|
||||
|
||||
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
|
||||
1. 普通 user/assistant/tool 消息本身就是事件
|
||||
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
|
||||
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: str | None = None
|
||||
timestamp: float | None = None
|
||||
message_id: int | None = None
|
||||
run_id: str | None = None
|
||||
event_type: str | None = None
|
||||
event_payload: dict[str, Any] | None = None
|
||||
context_visible: bool = True
|
||||
tool_name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_call_id: str | None = None
|
||||
finish_reason: str | None = None
|
||||
reasoning: str | None = None
|
||||
reasoning_details: Any | None = None
|
||||
codex_reasoning_items: Any | None = None
|
||||
|
||||
def to_conversation_message(self) -> dict[str, Any]:
|
||||
"""转成 provider / context builder 可直接消费的消息格式。"""
|
||||
|
||||
if not self.context_visible:
|
||||
raise ValueError("Hidden session events cannot be converted into conversation messages")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
}
|
||||
if self.tool_name:
|
||||
payload["tool_name"] = self.tool_name
|
||||
if self.tool_calls:
|
||||
payload["tool_calls"] = self.tool_calls
|
||||
if self.tool_call_id:
|
||||
payload["tool_call_id"] = self.tool_call_id
|
||||
if self.finish_reason:
|
||||
payload["finish_reason"] = self.finish_reason
|
||||
if self.reasoning:
|
||||
payload["reasoning"] = self.reasoning
|
||||
if self.reasoning_details is not None:
|
||||
payload["reasoning_details"] = self.reasoning_details
|
||||
if self.codex_reasoning_items is not None:
|
||||
payload["codex_reasoning_items"] = self.codex_reasoning_items
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
|
||||
"""从 SQLite row/dict 恢复消息模型。"""
|
||||
|
||||
tool_calls = row.get("tool_calls")
|
||||
if isinstance(tool_calls, str):
|
||||
try:
|
||||
tool_calls = json.loads(tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
tool_calls = []
|
||||
|
||||
reasoning_details = row.get("reasoning_details")
|
||||
if isinstance(reasoning_details, str):
|
||||
try:
|
||||
reasoning_details = json.loads(reasoning_details)
|
||||
except json.JSONDecodeError:
|
||||
reasoning_details = None
|
||||
|
||||
codex_reasoning_items = row.get("codex_reasoning_items")
|
||||
if isinstance(codex_reasoning_items, str):
|
||||
try:
|
||||
codex_reasoning_items = json.loads(codex_reasoning_items)
|
||||
except json.JSONDecodeError:
|
||||
codex_reasoning_items = None
|
||||
|
||||
event_payload = row.get("event_payload")
|
||||
if isinstance(event_payload, str):
|
||||
try:
|
||||
event_payload = json.loads(event_payload)
|
||||
except json.JSONDecodeError:
|
||||
event_payload = None
|
||||
|
||||
return cls(
|
||||
message_id=row.get("id"),
|
||||
run_id=row.get("run_id"),
|
||||
role=row["role"],
|
||||
content=row.get("content"),
|
||||
event_type=row.get("event_type") or row.get("role"),
|
||||
event_payload=event_payload,
|
||||
context_visible=bool(row.get("context_visible", 1)),
|
||||
tool_name=row.get("tool_name"),
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=row.get("tool_call_id"),
|
||||
timestamp=row.get("timestamp"),
|
||||
finish_reason=row.get("finish_reason"),
|
||||
reasoning=row.get("reasoning"),
|
||||
reasoning_details=reasoning_details,
|
||||
codex_reasoning_items=codex_reasoning_items,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionRecord:
|
||||
"""单个 session 的结构化表示。"""
|
||||
|
||||
session_id: str
|
||||
source: str
|
||||
started_at: float
|
||||
last_active: float
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
model: str | None = None
|
||||
system_prompt: str | None = None
|
||||
parent_session_id: str | None = None
|
||||
ended_at: float | None = None
|
||||
end_reason: str | None = None
|
||||
message_count: int = 0
|
||||
tool_call_count: int = 0
|
||||
preview: str | None = None
|
||||
usage: SessionUsage = field(default_factory=SessionUsage)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload = {
|
||||
"id": self.session_id,
|
||||
"source": self.source,
|
||||
"user_id": self.user_id,
|
||||
"title": self.title,
|
||||
"model": self.model,
|
||||
"system_prompt": self.system_prompt,
|
||||
"parent_session_id": self.parent_session_id,
|
||||
"started_at": self.started_at,
|
||||
"last_active": self.last_active,
|
||||
"ended_at": self.ended_at,
|
||||
"end_reason": self.end_reason,
|
||||
"message_count": self.message_count,
|
||||
"tool_call_count": self.tool_call_count,
|
||||
"preview": self.preview,
|
||||
}
|
||||
payload.update(self.usage.to_dict())
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
|
||||
return cls(
|
||||
session_id=row["id"],
|
||||
source=row["source"],
|
||||
user_id=row.get("user_id"),
|
||||
title=row.get("title"),
|
||||
model=row.get("model"),
|
||||
system_prompt=row.get("system_prompt"),
|
||||
parent_session_id=row.get("parent_session_id"),
|
||||
started_at=row["started_at"],
|
||||
last_active=row["last_active"],
|
||||
ended_at=row.get("ended_at"),
|
||||
end_reason=row.get("end_reason"),
|
||||
message_count=row.get("message_count", 0),
|
||||
tool_call_count=row.get("tool_call_count", 0),
|
||||
preview=row.get("preview"),
|
||||
usage=SessionUsage(
|
||||
input_tokens=row.get("input_tokens", 0),
|
||||
output_tokens=row.get("output_tokens", 0),
|
||||
cache_read_tokens=row.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=row.get("cache_write_tokens", 0),
|
||||
reasoning_tokens=row.get("reasoning_tokens", 0),
|
||||
estimated_cost_usd=row.get("estimated_cost_usd", 0.0) or 0.0,
|
||||
actual_cost_usd=row.get("actual_cost_usd"),
|
||||
),
|
||||
)
|
||||
151
app-instance/backend/beaver/engine/session/search.py
Normal file
151
app-instance/backend/beaver/engine/session/search.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Beaver session 子系统的检索能力。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionSearchService:
|
||||
"""围绕 `SessionStore` 提供 browsing / FTS / lineage 辅助能力。"""
|
||||
|
||||
def __init__(self, store: SessionStore) -> None:
|
||||
self.store = store
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_fts5_query(query: str) -> str:
|
||||
quoted_parts: list[str] = []
|
||||
|
||||
def preserve(match: re.Match[str]) -> str:
|
||||
quoted_parts.append(match.group(0))
|
||||
return f"\x00Q{len(quoted_parts) - 1}\x00"
|
||||
|
||||
sanitized = re.sub(r'"[^"]*"', preserve, query)
|
||||
sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
|
||||
sanitized = re.sub(r"\*+", "*", sanitized)
|
||||
sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
|
||||
sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
|
||||
sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
|
||||
sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)
|
||||
|
||||
for index, quoted in enumerate(quoted_parts):
|
||||
sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
|
||||
return sanitized.strip()
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
|
||||
"""用完整 ID 或唯一前缀解析出目标 session_id。"""
|
||||
|
||||
exact = self.store.get_session_record(session_id_or_prefix)
|
||||
if exact is not None:
|
||||
return exact.session_id
|
||||
|
||||
escaped = (
|
||||
session_id_or_prefix
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
rows = self.store._fetchall(
|
||||
"""
|
||||
SELECT id
|
||||
FROM sessions
|
||||
WHERE id LIKE ? ESCAPE '\\'
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 2
|
||||
""",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
if len(rows) == 1:
|
||||
return rows[0]["id"]
|
||||
return None
|
||||
|
||||
def list_sessions_rich(
|
||||
self,
|
||||
*,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
include_children: bool = False,
|
||||
source: str | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""列出最近活跃的 session 及其摘要元数据。"""
|
||||
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
if not include_children:
|
||||
clauses.append("parent_session_id IS NULL")
|
||||
if source:
|
||||
clauses.append("source = ?")
|
||||
params.append(source)
|
||||
if exclude_sources:
|
||||
placeholders = ",".join("?" for _ in exclude_sources)
|
||||
clauses.append(f"source NOT IN ({placeholders})")
|
||||
params.extend(exclude_sources)
|
||||
|
||||
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
|
||||
params.extend([limit, offset])
|
||||
rows = self.store._fetchall(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM sessions
|
||||
{where}
|
||||
ORDER BY last_active DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
return rows
|
||||
|
||||
def search_messages(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
role_filter: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""使用 FTS5 搜索 session transcript。"""
|
||||
|
||||
query = self._sanitize_fts5_query(query)
|
||||
if not query:
|
||||
return []
|
||||
|
||||
clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
|
||||
params: list[Any] = [query]
|
||||
|
||||
if exclude_sources:
|
||||
placeholders = ",".join("?" for _ in exclude_sources)
|
||||
clauses.append(f"s.source NOT IN ({placeholders})")
|
||||
params.extend(exclude_sources)
|
||||
if role_filter:
|
||||
placeholders = ",".join("?" for _ in role_filter)
|
||||
clauses.append(f"m.role IN ({placeholders})")
|
||||
params.extend(role_filter)
|
||||
|
||||
params.extend([limit, offset])
|
||||
sql = f"""
|
||||
SELECT
|
||||
m.id,
|
||||
m.session_id,
|
||||
m.role,
|
||||
s.source,
|
||||
s.model,
|
||||
s.started_at AS session_started,
|
||||
snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
|
||||
FROM messages_fts
|
||||
JOIN messages m ON m.id = messages_fts.rowid
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE {' AND '.join(clauses)}
|
||||
ORDER BY rank
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.store._fetchall(sql, tuple(params))
|
||||
except sqlite3.Error as exc:
|
||||
raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc
|
||||
467
app-instance/backend/beaver/engine/session/store.py
Normal file
467
app-instance/backend/beaver/engine/session/store.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Beaver session 子系统的 SQLite 存储实现。
|
||||
|
||||
设计来源主要参考 Hermes-agent:
|
||||
1. SQLite 作为统一 session/transcript backend
|
||||
2. WAL 模式支持多读单写
|
||||
3. FTS5 支持跨 session 文本检索
|
||||
4. `parent_session_id` 支持 lineage
|
||||
|
||||
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from .models import MessageRecord, SessionRecord
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
title TEXT,
|
||||
model TEXT,
|
||||
system_prompt TEXT,
|
||||
parent_session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
last_active REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
end_reason TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
estimated_cost_usd REAL DEFAULT 0,
|
||||
actual_cost_usd REAL,
|
||||
preview TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id),
|
||||
run_id TEXT,
|
||||
role TEXT NOT NULL,
|
||||
event_type TEXT,
|
||||
event_payload TEXT,
|
||||
context_visible INTEGER NOT NULL DEFAULT 1,
|
||||
content TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
tool_call_id TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
finish_reason TEXT,
|
||||
reasoning TEXT,
|
||||
reasoning_details TEXT,
|
||||
codex_reasoning_items TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
|
||||
"""
|
||||
|
||||
FTS_TABLE_SQL = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||
content,
|
||||
content=messages,
|
||||
content_rowid=id
|
||||
);
|
||||
"""
|
||||
|
||||
FTS_TRIGGER_SQL = """
|
||||
DROP TRIGGER IF EXISTS messages_fts_insert;
|
||||
DROP TRIGGER IF EXISTS messages_fts_delete;
|
||||
DROP TRIGGER IF EXISTS messages_fts_update;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
"""
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""SQLite-backed session store."""
|
||||
|
||||
def __init__(self, db_path: str | Path) -> None:
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||
self._init_schema()
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.executescript(SCHEMA_SQL)
|
||||
try:
|
||||
self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
|
||||
except sqlite3.OperationalError:
|
||||
self._conn.executescript(FTS_TABLE_SQL)
|
||||
self._conn.executescript(FTS_TRIGGER_SQL)
|
||||
# 旧版本可能把 hidden 事件也写进了 FTS;初始化时顺手清掉这些噪声项。
|
||||
self._conn.execute(
|
||||
"""
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', id, content
|
||||
FROM messages
|
||||
WHERE context_visible = 0 AND content IS NOT NULL
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.close()
|
||||
|
||||
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||
with self._lock:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
result = fn(self._conn)
|
||||
self._conn.commit()
|
||||
return result
|
||||
except BaseException:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
row = self._conn.execute(sql, params).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
|
||||
with self._lock:
|
||||
rows = self._conn.execute(sql, params).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> str:
|
||||
"""确保 session 行存在;若不存在则创建,若存在则尽量补全缺失元数据。"""
|
||||
|
||||
now = time.time()
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> str:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (
|
||||
id, source, user_id, title, model, parent_session_id, started_at, last_active
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
source = CASE
|
||||
WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
|
||||
ELSE sessions.source
|
||||
END,
|
||||
user_id = COALESCE(sessions.user_id, excluded.user_id),
|
||||
title = COALESCE(sessions.title, excluded.title),
|
||||
model = COALESCE(sessions.model, excluded.model),
|
||||
parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
|
||||
""",
|
||||
(session_id, source, user_id, title, model, parent_session_id, now, now),
|
||||
)
|
||||
return session_id
|
||||
|
||||
return self._execute_write(_do)
|
||||
|
||||
def get_session_record(self, session_id: str) -> SessionRecord | None:
|
||||
row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||
return SessionRecord.from_row(row) if row else None
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""保存本 session 组装后的完整 system prompt snapshot。"""
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET system_prompt = ?, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(system_prompt, time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: float = 0.0,
|
||||
actual_cost_usd: float | None = None,
|
||||
absolute: bool = False,
|
||||
) -> None:
|
||||
"""更新会话 usage。默认按增量累加。"""
|
||||
|
||||
if absolute:
|
||||
sql = """
|
||||
UPDATE sessions
|
||||
SET input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = ?,
|
||||
last_active = ?
|
||||
WHERE id = ?
|
||||
"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
time.time(),
|
||||
session_id,
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
UPDATE sessions
|
||||
SET input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
cache_write_tokens = cache_write_tokens + ?,
|
||||
reasoning_tokens = reasoning_tokens + ?,
|
||||
estimated_cost_usd = estimated_cost_usd + ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE COALESCE(actual_cost_usd, 0) + ?
|
||||
END,
|
||||
last_active = ?
|
||||
WHERE id = ?
|
||||
"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
time.time(),
|
||||
session_id,
|
||||
)
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(sql, params)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
run_id: str | None = None,
|
||||
role: str,
|
||||
event_type: str | None = None,
|
||||
event_payload: dict[str, Any] | None = None,
|
||||
context_visible: bool = True,
|
||||
content: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
finish_reason: str | None = None,
|
||||
reasoning: str | None = None,
|
||||
reasoning_details: Any | None = None,
|
||||
codex_reasoning_items: Any | None = None,
|
||||
source: str = "unknown",
|
||||
title: str | None = None,
|
||||
model: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> int:
|
||||
"""向指定 session 追加一条消息。"""
|
||||
|
||||
self.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
now = time.time()
|
||||
tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
|
||||
event_payload_json = json.dumps(event_payload) if event_payload is not None else None
|
||||
reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
|
||||
codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
|
||||
preview = (content or "")[:120] if role == "user" and content else None
|
||||
tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> int:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
session_id, run_id, role, event_type, event_payload, context_visible, content,
|
||||
tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
|
||||
reasoning_details, codex_reasoning_items
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
session_id,
|
||||
run_id,
|
||||
role,
|
||||
event_type or role,
|
||||
event_payload_json,
|
||||
1 if context_visible else 0,
|
||||
content,
|
||||
tool_name,
|
||||
tool_calls_json,
|
||||
tool_call_id,
|
||||
now,
|
||||
finish_reason,
|
||||
reasoning,
|
||||
reasoning_details_json,
|
||||
codex_items_json,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET last_active = ?,
|
||||
message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ?,
|
||||
model = COALESCE(model, ?),
|
||||
preview = CASE
|
||||
WHEN preview IS NULL AND ? IS NOT NULL THEN ?
|
||||
ELSE preview
|
||||
END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now, tool_call_count, model, preview, preview, session_id),
|
||||
)
|
||||
return int(cursor.lastrowid)
|
||||
|
||||
return self._execute_write(_do)
|
||||
|
||||
def get_message_records(self, session_id: str) -> list[MessageRecord]:
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp, id
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
return [MessageRecord.from_row(row) for row in rows]
|
||||
|
||||
def get_event_records(self, session_id: str) -> list[MessageRecord]:
|
||||
"""返回当前 session 的完整事件流。
|
||||
|
||||
当前阶段里,事件流仍复用 `messages` 表承载,所以这里等价于读取全部 message records。
|
||||
后面如果单独拆出 run/checkpoint/system event 表,上层 manager 仍可以继续保持这个接口不变。
|
||||
"""
|
||||
|
||||
return self.get_message_records(session_id)
|
||||
|
||||
def list_run_ids(self, session_id: str) -> list[str]:
|
||||
"""按时间顺序列出当前 session 中出现过的 run_id。"""
|
||||
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT run_id
|
||||
FROM messages
|
||||
WHERE session_id = ? AND run_id IS NOT NULL
|
||||
GROUP BY run_id
|
||||
ORDER BY MIN(timestamp), MIN(id)
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
return [str(row["run_id"]) for row in rows if row.get("run_id")]
|
||||
|
||||
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
|
||||
"""返回某一次 run 对应的事件片段。"""
|
||||
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ? AND run_id = ?
|
||||
ORDER BY timestamp, id
|
||||
""",
|
||||
(session_id, run_id),
|
||||
)
|
||||
return [MessageRecord.from_row(row) for row in rows]
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
for record in self.get_event_records(session_id):
|
||||
if not record.context_visible:
|
||||
continue
|
||||
messages.append(record.to_conversation_message())
|
||||
return messages
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET ended_at = ?, end_reason = ?, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(time.time(), end_reason, time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET ended_at = NULL, end_reason = NULL, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
2
app-instance/backend/beaver/foundation/__init__.py
Normal file
2
app-instance/backend/beaver/foundation/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Foundation layer for shared Beaver primitives."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Configuration models and loaders."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Shared error types."""
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
"""Event contracts and dispatch helpers."""
|
||||
|
||||
from .message_bus import InboundMessage, MessageBus, OutboundMessage
|
||||
|
||||
__all__ = ["InboundMessage", "MessageBus", "OutboundMessage"]
|
||||
72
app-instance/backend/beaver/foundation/events/message_bus.py
Normal file
72
app-instance/backend/beaver/foundation/events/message_bus.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Minimal message bus for gateway-style host integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class InboundMessage:
|
||||
"""A minimal inbound message accepted by the gateway bridge."""
|
||||
|
||||
channel: str
|
||||
content: str
|
||||
session_id: str | None = None
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
execution_context: str | None = None
|
||||
model: str | None = None
|
||||
provider_name: str | None = None
|
||||
embedding_model: str | None = None
|
||||
message_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OutboundMessage:
|
||||
"""A minimal outbound message produced by the gateway bridge."""
|
||||
|
||||
channel: str
|
||||
content: str
|
||||
session_id: str | None
|
||||
finish_reason: str
|
||||
message_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
run_id: str | None = None
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class MessageBus:
|
||||
"""Minimal async message bus with inbound/outbound queues."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||
|
||||
async def publish_inbound(self, message: InboundMessage) -> None:
|
||||
await self.inbound.put(message)
|
||||
|
||||
async def consume_inbound(self) -> InboundMessage:
|
||||
return await self.inbound.get()
|
||||
|
||||
async def publish_outbound(self, message: OutboundMessage) -> None:
|
||||
await self.outbound.put(message)
|
||||
|
||||
async def consume_outbound(self) -> OutboundMessage:
|
||||
return await self.outbound.get()
|
||||
|
||||
@property
|
||||
def inbound_size(self) -> int:
|
||||
return self.inbound.qsize()
|
||||
|
||||
@property
|
||||
def outbound_size(self) -> int:
|
||||
return self.outbound.qsize()
|
||||
@ -0,0 +1,2 @@
|
||||
"""Shared data models."""
|
||||
|
||||
2
app-instance/backend/beaver/foundation/utils/__init__.py
Normal file
2
app-instance/backend/beaver/foundation/utils/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Common utility helpers."""
|
||||
|
||||
2
app-instance/backend/beaver/integrations/__init__.py
Normal file
2
app-instance/backend/beaver/integrations/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""External integrations."""
|
||||
|
||||
2
app-instance/backend/beaver/integrations/a2a/__init__.py
Normal file
2
app-instance/backend/beaver/integrations/a2a/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""A2A integration."""
|
||||
|
||||
2
app-instance/backend/beaver/integrations/mcp/__init__.py
Normal file
2
app-instance/backend/beaver/integrations/mcp/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""MCP integration."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Outlook integration."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Provider-specific integrations."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""WhatsApp integration."""
|
||||
|
||||
2
app-instance/backend/beaver/interfaces/__init__.py
Normal file
2
app-instance/backend/beaver/interfaces/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Thin interface layer for Beaver."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Channel interfaces."""
|
||||
|
||||
2
app-instance/backend/beaver/interfaces/cli/__init__.py
Normal file
2
app-instance/backend/beaver/interfaces/cli/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""CLI interface."""
|
||||
|
||||
59
app-instance/backend/beaver/interfaces/cli/main.py
Normal file
59
app-instance/backend/beaver/interfaces/cli/main.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""CLI entry for Beaver."""
|
||||
|
||||
try:
|
||||
import typer
|
||||
except ModuleNotFoundError: # pragma: no cover - fallback for skeleton-only environments
|
||||
class _FallbackTyper:
|
||||
def __init__(self, *_args, **_kwargs) -> None:
|
||||
pass
|
||||
|
||||
def command(self):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def __call__(self) -> None:
|
||||
raise RuntimeError("typer is not installed")
|
||||
|
||||
@staticmethod
|
||||
def echo(message: str) -> None:
|
||||
print(message)
|
||||
|
||||
@staticmethod
|
||||
def Option(default=None, *_args, **_kwargs):
|
||||
return default
|
||||
|
||||
typer = _FallbackTyper() # type: ignore[assignment]
|
||||
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
app = typer.Typer(help="Beaver backend CLI") if hasattr(typer, "Typer") else typer
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
message: str | None = typer.Option(None, "--message", "-m", help="Run one direct Beaver request."),
|
||||
workspace: str | None = typer.Option(None, "--workspace", help="Workspace root for this run."),
|
||||
) -> None:
|
||||
"""Thin CLI wrapper around AgentService.
|
||||
|
||||
CLI 现在不再自己维护执行逻辑,只负责:
|
||||
1. 解析命令行参数
|
||||
2. 调 AgentService
|
||||
3. 打印结果
|
||||
"""
|
||||
|
||||
service = AgentService(workspace=workspace)
|
||||
if not message:
|
||||
service.create_loop()
|
||||
typer.echo("Beaver engine booted.")
|
||||
return
|
||||
|
||||
result = service.run_direct(message, source="cli")
|
||||
typer.echo(result.output_text)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Project script entrypoint."""
|
||||
app()
|
||||
@ -0,0 +1,2 @@
|
||||
"""Gateway interface."""
|
||||
|
||||
189
app-instance/backend/beaver/interfaces/gateway/main.py
Normal file
189
app-instance/backend/beaver/interfaces/gateway/main.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Gateway entrypoint for Beaver.
|
||||
|
||||
当前阶段先不扩 bus / channels adapter,只做最小消息桥接:
|
||||
1. 启动时托管 `AgentService.start()`
|
||||
2. 常驻消费 `MessageBus.inbound`
|
||||
3. 调 `service.submit_direct(...)`
|
||||
4. 将结果写回 `MessageBus.outbound`
|
||||
5. 退出时走 `AgentService.shutdown()`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
async def _publish_bridge_error(
|
||||
bus: MessageBus,
|
||||
inbound: InboundMessage,
|
||||
*,
|
||||
detail: str,
|
||||
finish_reason: str = "error",
|
||||
) -> None:
|
||||
"""把 bridge 处理失败转换成结构化 outbound 错误消息。"""
|
||||
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
message_id=inbound.message_id,
|
||||
channel=inbound.channel,
|
||||
session_id=inbound.session_id,
|
||||
content=detail,
|
||||
finish_reason=finish_reason,
|
||||
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
|
||||
"""把尚未处理的 inbound 明确冲刷成 outbound 错误,而不是静默丢弃。"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
pending = bus.inbound.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await _publish_bridge_error(bus, pending, detail=reason, finish_reason="stopped")
|
||||
|
||||
|
||||
async def _await_bridge_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
|
||||
"""等待 bridge 退出;超时则取消,避免 shutdown 被桥接层反向卡死。"""
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=timeout_seconds)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def _bridge_inbound_to_runtime(
|
||||
service: AgentService,
|
||||
bus: MessageBus,
|
||||
stop_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Consume inbound messages, run the agent, and publish outbound results."""
|
||||
|
||||
while True:
|
||||
if stop_event.is_set():
|
||||
await _flush_pending_inbound(
|
||||
bus,
|
||||
reason="Gateway stopped before processing the inbound message",
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=0.25)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await service.submit_direct(
|
||||
inbound.content,
|
||||
session_id=inbound.session_id,
|
||||
source=f"gateway:{inbound.channel}",
|
||||
user_id=inbound.user_id,
|
||||
title=inbound.title,
|
||||
execution_context=inbound.execution_context,
|
||||
model=inbound.model,
|
||||
provider_name=inbound.provider_name,
|
||||
embedding_model=inbound.embedding_model,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await _publish_bridge_error(
|
||||
bus,
|
||||
inbound,
|
||||
detail="Gateway stopped before completing the inbound message",
|
||||
finish_reason="cancelled",
|
||||
)
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive bridge path
|
||||
await _publish_bridge_error(
|
||||
bus,
|
||||
inbound,
|
||||
detail=str(exc),
|
||||
)
|
||||
else:
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
message_id=inbound.message_id,
|
||||
channel=inbound.channel,
|
||||
session_id=result.session_id,
|
||||
run_id=result.run_id,
|
||||
content=result.output_text,
|
||||
finish_reason=result.finish_reason,
|
||||
provider_name=result.provider_name,
|
||||
model=result.model,
|
||||
usage=dict(result.usage),
|
||||
metadata={"inbound_metadata": dict(inbound.metadata)},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def run_gateway(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
service: AgentService | None = None,
|
||||
bus: MessageBus | None = None,
|
||||
manage_service_lifecycle: bool | None = None,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
shutdown_timeout_seconds: float | None = 5.0,
|
||||
shutdown_force: bool = True,
|
||||
) -> None:
|
||||
"""运行最小 gateway 宿主层与消息桥接。
|
||||
|
||||
默认 ownership 语义:
|
||||
- 未传 `service`:gateway 自己创建并接管其 lifecycle
|
||||
- 传入外部 `service`:默认只使用,不自动 start/shutdown
|
||||
"""
|
||||
|
||||
attached_service = service or AgentService(workspace=workspace)
|
||||
attached_bus = bus or MessageBus()
|
||||
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
|
||||
owned_stop_event = stop_event or asyncio.Event()
|
||||
started = False
|
||||
if owns_service:
|
||||
try:
|
||||
await attached_service.start()
|
||||
started = True
|
||||
except Exception:
|
||||
attached_service.close()
|
||||
raise
|
||||
|
||||
if not attached_service.is_running:
|
||||
raise RuntimeError(
|
||||
"Gateway requires AgentService running mode; start the injected service first "
|
||||
"or allow the gateway to manage its lifecycle."
|
||||
)
|
||||
|
||||
bridge_task = asyncio.create_task(_bridge_inbound_to_runtime(attached_service, attached_bus, owned_stop_event))
|
||||
try:
|
||||
await owned_stop_event.wait()
|
||||
finally:
|
||||
owned_stop_event.set()
|
||||
if owns_service and started:
|
||||
try:
|
||||
await attached_service.shutdown(
|
||||
timeout_seconds=shutdown_timeout_seconds,
|
||||
force=shutdown_force,
|
||||
)
|
||||
finally:
|
||||
await _await_bridge_shutdown(bridge_task)
|
||||
else:
|
||||
await _await_bridge_shutdown(bridge_task)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""同步 gateway 入口。"""
|
||||
|
||||
try:
|
||||
asyncio.run(run_gateway())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
2
app-instance/backend/beaver/interfaces/mcp/__init__.py
Normal file
2
app-instance/backend/beaver/interfaces/mcp/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""MCP server entrypoints."""
|
||||
|
||||
210
app-instance/backend/beaver/interfaces/mcp/memory_server.py
Normal file
210
app-instance/backend/beaver/interfaces/mcp/memory_server.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Beaver memory MCP server.
|
||||
|
||||
这个 server 用最精简的方式把两个内部能力暴露成 streamable-http MCP tools:
|
||||
1. `memory`
|
||||
2. `session_search`
|
||||
|
||||
运行方式:
|
||||
1. 直接用 Python:
|
||||
`python -m beaver.interfaces.mcp.memory_server --host 127.0.0.1 --port 8001`
|
||||
2. 或者用 FastMCP CLI:
|
||||
`fastmcp run beaver/interfaces/mcp/memory_server.py:mcp --transport http --port 8001`
|
||||
|
||||
默认 MCP 路径是 `/mcp`,FastMCP 的 HTTP transport 就是 streamable HTTP。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine.session import SessionManager
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
from beaver.tools.builtins.memory import memory_tool
|
||||
from beaver.tools.builtins.session_search import session_search as run_session_search
|
||||
|
||||
try: # pragma: no cover - import guard for environments without fastmcp
|
||||
from fastmcp import Context, FastMCP
|
||||
from fastmcp.server.lifespan import lifespan
|
||||
except ModuleNotFoundError: # pragma: no cover - handled at runtime in main()
|
||||
FastMCP = None # type: ignore[assignment]
|
||||
Context = Any # type: ignore[assignment]
|
||||
lifespan = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _require_fastmcp() -> None:
|
||||
if FastMCP is None or lifespan is None:
|
||||
raise RuntimeError(
|
||||
"fastmcp is not installed. Install it with `pip install fastmcp` "
|
||||
"or via this project's dependencies."
|
||||
)
|
||||
|
||||
|
||||
def _resolve_workspace_path(workspace: str | Path | None = None) -> Path:
|
||||
"""决定 memory server 使用的 workspace 根目录。"""
|
||||
|
||||
if workspace is not None:
|
||||
return Path(workspace).expanduser().resolve()
|
||||
env_workspace = os.getenv("BEAVER_WORKSPACE")
|
||||
if env_workspace:
|
||||
return Path(env_workspace).expanduser().resolve()
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
def _resolve_memory_dir(workspace: Path) -> Path:
|
||||
"""curated memory 的默认目录。"""
|
||||
|
||||
return workspace / "memory" / "curated"
|
||||
|
||||
|
||||
def _resolve_session_db_path(workspace: Path) -> Path:
|
||||
"""session store 的默认路径。"""
|
||||
|
||||
return workspace / "sessions" / "state.db"
|
||||
|
||||
|
||||
def create_memory_server(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
memory_dir: str | Path | None = None,
|
||||
session_db_path: str | Path | None = None,
|
||||
):
|
||||
"""创建并返回 FastMCP memory server 实例。"""
|
||||
|
||||
_require_fastmcp()
|
||||
workspace_path = _resolve_workspace_path(workspace)
|
||||
resolved_memory_dir = Path(memory_dir).expanduser().resolve() if memory_dir else _resolve_memory_dir(workspace_path)
|
||||
resolved_session_db = (
|
||||
Path(session_db_path).expanduser().resolve()
|
||||
if session_db_path
|
||||
else _resolve_session_db_path(workspace_path)
|
||||
)
|
||||
|
||||
@lifespan
|
||||
async def memory_server_lifespan(_server):
|
||||
"""在 server 生命周期内初始化共享 store/db。"""
|
||||
|
||||
store = MemoryStore(resolved_memory_dir)
|
||||
store.load_from_disk()
|
||||
session_manager = SessionManager(workspace=workspace_path, db_path=resolved_session_db)
|
||||
try:
|
||||
yield {
|
||||
"workspace_path": workspace_path,
|
||||
"memory_dir": resolved_memory_dir,
|
||||
"session_db_path": resolved_session_db,
|
||||
"memory_store": store,
|
||||
"session_manager": session_manager,
|
||||
}
|
||||
finally:
|
||||
session_manager.close()
|
||||
|
||||
server = FastMCP(
|
||||
name="Beaver Memory Server",
|
||||
instructions=(
|
||||
"Provides two MCP tools: `memory` for durable curated memory CRUD, "
|
||||
"and `session_search` for cross-session recall from transcript storage."
|
||||
),
|
||||
lifespan=memory_server_lifespan,
|
||||
)
|
||||
|
||||
@server.custom_route("/health", methods=["GET"])
|
||||
async def health_check(_request):
|
||||
"""最小 health check,方便远程探活。"""
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": True,
|
||||
"server": "beaver-memory",
|
||||
"transport": "streamable-http",
|
||||
"workspace": str(workspace_path),
|
||||
"memory_dir": str(resolved_memory_dir),
|
||||
"session_db_path": str(resolved_session_db),
|
||||
}
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def memory(
|
||||
action: str,
|
||||
target: str = "memory",
|
||||
content: str | None = None,
|
||||
old_text: str | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""CRUD for curated memory."""
|
||||
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required.")
|
||||
raw_result = memory_tool(
|
||||
action=action,
|
||||
target=target,
|
||||
content=content,
|
||||
old_text=old_text,
|
||||
store=ctx.lifespan_context["memory_store"],
|
||||
)
|
||||
return json.loads(raw_result)
|
||||
|
||||
@server.tool()
|
||||
async def session_search(
|
||||
query: str = "",
|
||||
role_filter: str | None = None,
|
||||
limit: int = 3,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Search prior sessions or browse recent ones."""
|
||||
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required.")
|
||||
raw_result = await run_session_search(
|
||||
query=query,
|
||||
role_filter=role_filter,
|
||||
limit=limit,
|
||||
db=ctx.lifespan_context["session_manager"],
|
||||
current_session_id=getattr(ctx, "session_id", None),
|
||||
)
|
||||
return json.loads(raw_result)
|
||||
|
||||
return server
|
||||
|
||||
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
"""构建最小命令行参数解析器。"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run Beaver memory MCP server over streamable HTTP.")
|
||||
parser.add_argument("--workspace", default=None, help="Workspace root. Defaults to BEAVER_WORKSPACE or cwd.")
|
||||
parser.add_argument("--memory-dir", default=None, help="Override curated memory directory.")
|
||||
parser.add_argument("--session-db", default=None, help="Override session SQLite database path.")
|
||||
parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host.")
|
||||
parser.add_argument("--port", default=8001, type=int, help="HTTP bind port.")
|
||||
parser.add_argument("--path", default="/mcp", help="MCP endpoint path.")
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""以 streamable HTTP 启动 memory server。"""
|
||||
|
||||
parser = build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
server = create_memory_server(
|
||||
workspace=args.workspace,
|
||||
memory_dir=args.memory_dir,
|
||||
session_db_path=args.session_db,
|
||||
)
|
||||
server.run(
|
||||
transport="http",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
path=args.path,
|
||||
)
|
||||
|
||||
|
||||
if FastMCP is not None:
|
||||
mcp = create_memory_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
app-instance/backend/beaver/interfaces/web/__init__.py
Normal file
2
app-instance/backend/beaver/interfaces/web/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Web interface."""
|
||||
|
||||
198
app-instance/backend/beaver/interfaces/web/app.py
Normal file
198
app-instance/backend/beaver/interfaces/web/app.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""FastAPI app factory for Beaver."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
from .deps import get_agent_service
|
||||
from .schemas import WebChatRequest, WebChatResponse, WebErrorResponse, WebStatusResponse
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
except ModuleNotFoundError: # pragma: no cover - fallback for skeleton-only environments
|
||||
class HTTPException(Exception):
|
||||
"""Minimal fallback exception matching FastAPI's constructor shape."""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
super().__init__(detail)
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
class Request: # type: ignore[override]
|
||||
"""Fallback request shim used only for import-time compatibility."""
|
||||
|
||||
def __init__(self, app: Any) -> None:
|
||||
self.app = app
|
||||
|
||||
class FastAPI: # type: ignore[override]
|
||||
"""Small fallback shim so the package can import before dependencies are installed."""
|
||||
|
||||
def __init__(self, *, title: str, lifespan: Callable[..., Any] | None = None) -> None:
|
||||
self.title = title
|
||||
self.lifespan = lifespan
|
||||
self.state = SimpleNamespace()
|
||||
|
||||
def get(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def post(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _app_lifespan(
|
||||
app: FastAPI,
|
||||
*,
|
||||
workspace: str | Path | None,
|
||||
service: AgentService | None,
|
||||
manage_service_lifecycle: bool | None,
|
||||
shutdown_timeout_seconds: float | None,
|
||||
shutdown_force: bool,
|
||||
) -> AsyncIterator[None]:
|
||||
"""把 Web app 接到 AgentService lifecycle 上。"""
|
||||
|
||||
attached_service = service or AgentService(workspace=workspace)
|
||||
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
|
||||
app.state.agent_service = attached_service
|
||||
started = False
|
||||
if owns_service:
|
||||
try:
|
||||
await attached_service.start()
|
||||
started = True
|
||||
except Exception:
|
||||
attached_service.close()
|
||||
raise
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if owns_service and started:
|
||||
await attached_service.shutdown(
|
||||
timeout_seconds=shutdown_timeout_seconds,
|
||||
force=shutdown_force,
|
||||
)
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
service: AgentService | None = None,
|
||||
manage_service_lifecycle: bool | None = None,
|
||||
shutdown_timeout_seconds: float | None = 5.0,
|
||||
shutdown_force: bool = True,
|
||||
) -> FastAPI:
|
||||
"""Create a Beaver web app hosted by AgentService running mode.
|
||||
|
||||
默认 ownership 语义:
|
||||
- 未传 `service`:app 自己创建并接管其 lifecycle
|
||||
- 传入外部 `service`:默认只挂载,不自动 start/shutdown
|
||||
|
||||
如果确实需要覆盖默认行为,可以显式传 `manage_service_lifecycle=True/False`。
|
||||
"""
|
||||
|
||||
app = FastAPI(
|
||||
title="Beaver Backend",
|
||||
lifespan=lambda fastapi_app: _app_lifespan(
|
||||
fastapi_app,
|
||||
workspace=workspace,
|
||||
service=service,
|
||||
manage_service_lifecycle=manage_service_lifecycle,
|
||||
shutdown_timeout_seconds=shutdown_timeout_seconds,
|
||||
shutdown_force=shutdown_force,
|
||||
),
|
||||
)
|
||||
|
||||
@app.get("/api/ping", response_model=WebStatusResponse)
|
||||
async def ping(request: Request) -> WebStatusResponse:
|
||||
agent_service = get_agent_service(request)
|
||||
running = agent_service.is_running
|
||||
return WebStatusResponse(
|
||||
status="ok",
|
||||
running=running,
|
||||
mode="running" if running else ("direct" if agent_service.has_loop else "idle"),
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/api/chat",
|
||||
response_model=WebChatResponse,
|
||||
responses={
|
||||
400: {"model": WebErrorResponse},
|
||||
409: {"model": WebErrorResponse},
|
||||
503: {"model": WebErrorResponse},
|
||||
},
|
||||
)
|
||||
async def chat(request: Request, payload: WebChatRequest) -> WebChatResponse:
|
||||
agent_service = get_agent_service(request)
|
||||
message = payload.message.strip()
|
||||
if not message:
|
||||
raise HTTPException(status_code=400, detail="'message' is required")
|
||||
|
||||
fallback_target = _model_dump(payload.fallback_target)
|
||||
auxiliary_target = _model_dump(payload.auxiliary_target)
|
||||
embedding_target = _model_dump(payload.embedding_target)
|
||||
|
||||
try:
|
||||
result = await agent_service.submit_direct(
|
||||
message,
|
||||
session_id=payload.session_id,
|
||||
source="web",
|
||||
user_id=payload.user_id,
|
||||
title=payload.title,
|
||||
execution_context=payload.execution_context,
|
||||
model=payload.model,
|
||||
provider_name=payload.provider_name,
|
||||
embedding_model=payload.embedding_model,
|
||||
temperature=payload.temperature,
|
||||
max_tokens=payload.max_tokens,
|
||||
max_tool_iterations=payload.max_tool_iterations,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except RuntimeError as exc:
|
||||
detail = str(exc)
|
||||
if "requires an active run() loop" in detail or "not ready" in detail:
|
||||
status_code = 503
|
||||
elif "submit_direct" in detail or "running" in detail:
|
||||
status_code = 409
|
||||
else:
|
||||
status_code = 503
|
||||
raise HTTPException(status_code=status_code, detail=detail) from exc
|
||||
|
||||
return WebChatResponse(
|
||||
session_id=result.session_id,
|
||||
run_id=result.run_id,
|
||||
output_text=result.output_text,
|
||||
finish_reason=result.finish_reason,
|
||||
tool_iterations=result.tool_iterations,
|
||||
provider_name=result.provider_name,
|
||||
model=result.model,
|
||||
usage=result.usage,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _model_dump(value: Any) -> dict[str, Any] | None:
|
||||
"""兼容 Pydantic v1/v2 的最小导出辅助。"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump(exclude_none=True)
|
||||
if hasattr(value, "dict"):
|
||||
return value.dict(exclude_none=True)
|
||||
return dict(value)
|
||||
27
app-instance/backend/beaver/interfaces/web/deps.py
Normal file
27
app-instance/backend/beaver/interfaces/web/deps.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Web dependency wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
try:
|
||||
from fastapi import HTTPException
|
||||
except ModuleNotFoundError: # pragma: no cover - fallback for skeleton-only environments
|
||||
class HTTPException(Exception):
|
||||
"""Minimal fallback exception matching FastAPI's constructor shape."""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
super().__init__(detail)
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
|
||||
def get_agent_service(request: Any) -> AgentService:
|
||||
"""从 app state 里取当前宿主层托管的 AgentService。"""
|
||||
|
||||
service = getattr(request.app.state, "agent_service", None)
|
||||
if not isinstance(service, AgentService):
|
||||
raise HTTPException(status_code=503, detail="AgentService is not ready")
|
||||
return service
|
||||
@ -0,0 +1,2 @@
|
||||
"""Web routes."""
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
"""Web request and response schemas."""
|
||||
|
||||
from .chat import WebChatRequest, WebChatResponse, WebErrorResponse, WebProviderTarget, WebStatusResponse
|
||||
|
||||
__all__ = [
|
||||
"WebChatRequest",
|
||||
"WebChatResponse",
|
||||
"WebErrorResponse",
|
||||
"WebProviderTarget",
|
||||
"WebStatusResponse",
|
||||
]
|
||||
93
app-instance/backend/beaver/interfaces/web/schemas/chat.py
Normal file
93
app-instance/backend/beaver/interfaces/web/schemas/chat.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Chat-related web schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel, Field
|
||||
except ModuleNotFoundError: # pragma: no cover - fallback for skeleton-only environments
|
||||
class BaseModel:
|
||||
"""Very small fallback shim used only so imports work without pydantic."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
annotations = getattr(self.__class__, "__annotations__", {})
|
||||
for name in annotations:
|
||||
default = getattr(self.__class__, name, None)
|
||||
if name in kwargs:
|
||||
value = kwargs[name]
|
||||
else:
|
||||
value = default
|
||||
setattr(self, name, value)
|
||||
|
||||
def model_dump(self, *, exclude_none: bool = False) -> dict[str, Any]:
|
||||
data = dict(self.__dict__)
|
||||
if exclude_none:
|
||||
data = {key: value for key, value in data.items() if value is not None}
|
||||
return data
|
||||
|
||||
def Field(default: Any = None, **kwargs: Any) -> Any:
|
||||
default_factory = kwargs.get("default_factory")
|
||||
if default_factory is not None:
|
||||
return default_factory()
|
||||
return default
|
||||
|
||||
|
||||
class WebProviderTarget(BaseModel):
|
||||
"""Web-facing provider target shape.
|
||||
|
||||
先保持和 runtime 里的 `ProviderTarget` 接近,但只暴露 Web 当前需要的字段。
|
||||
后面如果 provider 层扩字段,再由这里显式补齐。
|
||||
"""
|
||||
|
||||
provider: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class WebChatRequest(BaseModel):
|
||||
"""最小正式 chat 请求结构。"""
|
||||
|
||||
message: str = Field(min_length=1)
|
||||
session_id: str | None = None
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
execution_context: str | None = None
|
||||
model: str | None = None
|
||||
provider_name: str | None = None
|
||||
embedding_model: str | None = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
max_tool_iterations: int | None = None
|
||||
fallback_target: WebProviderTarget | None = None
|
||||
auxiliary_target: WebProviderTarget | None = None
|
||||
embedding_target: WebProviderTarget | None = None
|
||||
|
||||
|
||||
class WebChatResponse(BaseModel):
|
||||
"""最小正式 chat 响应结构。"""
|
||||
|
||||
session_id: str
|
||||
run_id: str
|
||||
output_text: str
|
||||
finish_reason: str
|
||||
tool_iterations: int
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
usage: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WebStatusResponse(BaseModel):
|
||||
"""Web 宿主层状态响应。"""
|
||||
|
||||
status: str
|
||||
running: bool
|
||||
mode: str
|
||||
|
||||
|
||||
class WebErrorResponse(BaseModel):
|
||||
"""统一错误响应结构。"""
|
||||
|
||||
detail: str
|
||||
2
app-instance/backend/beaver/memory/__init__.py
Normal file
2
app-instance/backend/beaver/memory/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Memory and experience stores."""
|
||||
|
||||
11
app-instance/backend/beaver/memory/curated/__init__.py
Normal file
11
app-instance/backend/beaver/memory/curated/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""Curated long-term memory primitives."""
|
||||
|
||||
from .snapshot import MemorySnapshot, capture_memory_snapshot
|
||||
from .store import MemoryStore, scan_memory_content
|
||||
|
||||
__all__ = [
|
||||
"MemorySnapshot",
|
||||
"MemoryStore",
|
||||
"capture_memory_snapshot",
|
||||
"scan_memory_content",
|
||||
]
|
||||
52
app-instance/backend/beaver/memory/curated/snapshot.py
Normal file
52
app-instance/backend/beaver/memory/curated/snapshot.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""curated memory 的冻结快照工具。
|
||||
|
||||
这个文件很小,但职责非常关键:它把“长期记忆的 live state”和“当前会话注入 prompt
|
||||
时使用的 frozen snapshot”明确分开。
|
||||
|
||||
设计目的:
|
||||
1. 让调用侧显式意识到:system prompt 使用的是一份冻结视图
|
||||
2. 避免后续 engine/context builder 直接偷读 live store,破坏 frozen snapshot 语义
|
||||
3. 给 prompt 组装层一个简单、稳定、可测试的数据结构
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .store import MemoryStore
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MemorySnapshot:
|
||||
"""当前 session 使用的冻结记忆快照。
|
||||
|
||||
这里不是 memory store 本体,而是“给 prompt builder 的只读投影”。
|
||||
一旦 capture 完成,这个对象就代表本 session 的注入视图,不应在会话中途被修改。
|
||||
"""
|
||||
|
||||
memory_block: str | None
|
||||
user_block: str | None
|
||||
|
||||
def as_prompt_sections(self) -> list[str]:
|
||||
"""按稳定顺序返回可直接拼接进 prompt 的 section 列表。
|
||||
|
||||
顺序固定为:
|
||||
1. user profile
|
||||
2. agent memory
|
||||
|
||||
这样后续 context builder 的输出更稳定,测试也更容易写。
|
||||
"""
|
||||
|
||||
return [section for section in (self.user_block, self.memory_block) if section]
|
||||
|
||||
|
||||
def capture_memory_snapshot(store: MemoryStore) -> MemorySnapshot:
|
||||
"""从 `MemoryStore` 提取当前 session 的 frozen snapshot。
|
||||
|
||||
前提是 `store.load_from_disk()` 已经在 session 启动时调用过,否则拿到的只是空快照。
|
||||
"""
|
||||
|
||||
return MemorySnapshot(
|
||||
memory_block=store.format_for_system_prompt("memory"),
|
||||
user_block=store.format_for_system_prompt("user"),
|
||||
)
|
||||
463
app-instance/backend/beaver/memory/curated/store.py
Normal file
463
app-instance/backend/beaver/memory/curated/store.py
Normal file
@ -0,0 +1,463 @@
|
||||
"""Beaver 的精炼长期记忆存储层。
|
||||
|
||||
这个文件实现的是以 Hermes-agent 为基线的 curated memory 模型,目标不是
|
||||
“把所有历史都存下来”,而是只保存跨会话仍然值得保留的稳定事实。
|
||||
|
||||
核心设计:
|
||||
1. 只保留两个持久化记忆桶:
|
||||
- ``memory``: agent 自己对环境、项目、工具 quirks 的长期备注
|
||||
- ``user``: 对用户偏好、习惯、身份信息的长期理解
|
||||
2. ``replace`` / ``remove`` 不使用 UUID,而是使用短语义片段做子串匹配。
|
||||
这是为了适配 LLM 更擅长“记住一句话片段”而不是“追踪一个随机 ID”的现实。
|
||||
3. 写入前先做安全扫描,避免把 prompt injection / secrets exfiltration
|
||||
一类危险内容写入长期记忆,再在未来会话中反向污染 system prompt。
|
||||
4. 写入协议严格遵守:
|
||||
- scan
|
||||
- lock
|
||||
- reload
|
||||
- validate
|
||||
- atomic write
|
||||
5. 本文件维护两份状态:
|
||||
- live state: 当前内存中的真实条目,tool 写入后立刻变化
|
||||
- frozen snapshot: 会话开始时冻结的一份 prompt 注入快照
|
||||
|
||||
其中最重要的一点是:本会话中新增的记忆会立刻写盘,但不会反向修改本会话
|
||||
已经冻结的 system prompt。这样可以保住 prefix cache,也避免“会话中途 prompt
|
||||
变了导致行为抖动”的问题。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError: # pragma: no cover - Windows fallback
|
||||
fcntl = None
|
||||
|
||||
try:
|
||||
import msvcrt
|
||||
except ImportError: # pragma: no cover - Unix platforms
|
||||
msvcrt = None
|
||||
|
||||
ENTRY_DELIMITER = "\n§\n"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
|
||||
_MEMORY_THREAT_PATTERNS: list[tuple[str, str]] = [
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"you\s+are\s+now\s+", "role_hijack"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don't\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
|
||||
(r"\$HOME/\.beaver/\.env|\~/\.beaver/\.env", "beaver_env"),
|
||||
]
|
||||
|
||||
_INVISIBLE_CHARS = {
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
def scan_memory_content(content: str) -> str | None:
|
||||
"""扫描待写入内容,拦截明显危险的记忆条目。
|
||||
|
||||
这里不是在做完备的安全审计,而是在做“进入长期记忆之前的最低限度闸门”。
|
||||
因为长期记忆会在未来会话中重新注入 system prompt,所以一旦把恶意文本写进去,
|
||||
风险远高于普通临时上下文。
|
||||
"""
|
||||
|
||||
for char in _INVISIBLE_CHARS:
|
||||
if char in content:
|
||||
return (
|
||||
f"Blocked: content contains invisible unicode character "
|
||||
f"U+{ord(char):04X}."
|
||||
)
|
||||
|
||||
for pattern, pattern_id in _MEMORY_THREAT_PATTERNS:
|
||||
if re.search(pattern, content, re.IGNORECASE):
|
||||
return (
|
||||
f"Blocked: content matches threat pattern '{pattern_id}'. "
|
||||
"Memory entries are injected into future system prompts."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""带容量上限的长期记忆存储。
|
||||
|
||||
这个类负责:
|
||||
1. 从磁盘加载 `MEMORY.md` / `USER.md`
|
||||
2. 在 session 启动时冻结 prompt snapshot
|
||||
3. 为 `add / replace / remove` 提供安全写接口
|
||||
4. 维护 live state 与 frozen snapshot 的边界
|
||||
|
||||
它不负责:
|
||||
1. 自动从对话里抽取要记住的内容
|
||||
2. session transcript 检索
|
||||
3. skills 的学习和发布
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str | Path,
|
||||
*,
|
||||
memory_char_limit: int = 2200,
|
||||
user_char_limit: int = 1375,
|
||||
) -> None:
|
||||
self.root = Path(root)
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
self.memory_entries: list[str] = []
|
||||
self.user_entries: list[str] = []
|
||||
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
|
||||
|
||||
def load_from_disk(self) -> None:
|
||||
"""从磁盘加载 live state,并冻结当前 session 的 prompt snapshot。
|
||||
|
||||
调用时机应该是“会话启动时”,而不是每次工具写入后。
|
||||
如果在每次写入后都重新 load 并更新 system prompt,就会破坏 frozen snapshot
|
||||
这个设计,导致本轮会话 prompt 前缀发生变化。
|
||||
"""
|
||||
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
self.memory_entries = list(dict.fromkeys(self._read_file(self._path_for("memory"))))
|
||||
self.user_entries = list(dict.fromkeys(self._read_file(self._path_for("user"))))
|
||||
self._system_prompt_snapshot = {
|
||||
"memory": self._render_block("memory", self.memory_entries),
|
||||
"user": self._render_block("user", self.user_entries),
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
def _file_lock(self, path: Path):
|
||||
"""对目标记忆文件加排他锁。
|
||||
|
||||
锁文件使用 sibling `.lock` 文件,而不是直接锁业务文件本身。
|
||||
原因是业务文件使用的是“临时文件写入 + os.replace 原子替换”,如果直接锁目标
|
||||
文件,替换时会让锁语义和文件句柄关系变得更脆弱。
|
||||
"""
|
||||
|
||||
lock_path = path.with_suffix(path.suffix + ".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if fcntl is None and msvcrt is None:
|
||||
yield
|
||||
return
|
||||
|
||||
if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0):
|
||||
lock_path.write_text(" ", encoding="utf-8")
|
||||
|
||||
fd = open(lock_path, "r+" if msvcrt else "a+", encoding="utf-8")
|
||||
try:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
elif msvcrt is not None: # pragma: no cover - Windows fallback
|
||||
fd.seek(0)
|
||||
msvcrt.locking(fd.fileno(), msvcrt.LK_LOCK, 1)
|
||||
yield
|
||||
finally:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
elif msvcrt is not None: # pragma: no cover - Windows fallback
|
||||
try:
|
||||
fd.seek(0)
|
||||
msvcrt.locking(fd.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
except OSError:
|
||||
pass
|
||||
fd.close()
|
||||
|
||||
def _path_for(self, target: str) -> Path:
|
||||
"""根据目标桶返回实际文件路径。"""
|
||||
if target == "user":
|
||||
return self.root / DEFAULT_USER_FILENAME
|
||||
return self.root / DEFAULT_MEMORY_FILENAME
|
||||
|
||||
def _entries_for(self, target: str) -> list[str]:
|
||||
"""读取某个目标桶当前的 live entries。"""
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: list[str]) -> None:
|
||||
"""更新某个目标桶在内存中的 live entries。"""
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
else:
|
||||
self.memory_entries = entries
|
||||
|
||||
def _char_limit(self, target: str) -> int:
|
||||
"""返回目标桶的字符预算。
|
||||
|
||||
这里使用字符数而不是 token 数,是因为字符预算更稳定,也不依赖具体模型。
|
||||
"""
|
||||
return self.user_char_limit if target == "user" else self.memory_char_limit
|
||||
|
||||
def _char_count(self, target: str) -> int:
|
||||
"""返回目标桶当前 live state 的字符占用。"""
|
||||
entries = self._entries_for(target)
|
||||
return len(ENTRY_DELIMITER.join(entries)) if entries else 0
|
||||
|
||||
def _reload_target(self, target: str) -> None:
|
||||
"""在持锁状态下重新从磁盘读取目标桶。
|
||||
|
||||
这是并发安全协议里最关键的一步之一。
|
||||
必须在拿到锁之后 reload,才能确保当前进程不会覆盖掉其他并发会话刚刚写入
|
||||
的最新内容。
|
||||
"""
|
||||
fresh = list(dict.fromkeys(self._read_file(self._path_for(target))))
|
||||
self._set_entries(target, fresh)
|
||||
|
||||
def save_to_disk(self, target: str) -> None:
|
||||
"""把当前 live entries 持久化到磁盘。"""
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
self._write_file(self._path_for(target), self._entries_for(target))
|
||||
|
||||
def add(self, target: str, content: str) -> dict[str, Any]:
|
||||
"""追加一条新的长期记忆。
|
||||
|
||||
规则:
|
||||
1. 空内容拒绝
|
||||
2. 安全扫描不通过拒绝
|
||||
3. 精确重复拒绝
|
||||
4. 超出字符预算拒绝
|
||||
5. 否则追加并立即写盘
|
||||
"""
|
||||
|
||||
content = content.strip()
|
||||
if not content:
|
||||
return {"success": False, "error": "Content cannot be empty."}
|
||||
|
||||
scan_error = scan_memory_content(content)
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (skipped duplicate).")
|
||||
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
limit = self._char_limit(target)
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit."
|
||||
),
|
||||
"current_entries": list(entries),
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
|
||||
"""用新的内容替换一条已有记忆。
|
||||
|
||||
这里按 `old_text in entry` 做子串匹配,而不是要求调用方提供完整条目或 UUID。
|
||||
如果命中多条且它们内容不同,会要求调用方给出更精确的片段,避免误替换。
|
||||
"""
|
||||
|
||||
old_text = old_text.strip()
|
||||
new_content = new_content.strip()
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
if not new_content:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "new_content cannot be empty. Use remove to delete entries.",
|
||||
}
|
||||
|
||||
scan_error = scan_memory_content(new_content)
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
matches = [(index, entry) for index, entry in enumerate(entries) if old_text in entry]
|
||||
if not matches:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
unique_texts = {entry for _, entry in matches}
|
||||
if len(unique_texts) > 1:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": [
|
||||
entry[:80] + ("..." if len(entry) > 80 else "")
|
||||
for _, entry in matches
|
||||
],
|
||||
}
|
||||
|
||||
index = matches[0][0]
|
||||
candidate_entries = list(entries)
|
||||
candidate_entries[index] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(candidate_entries))
|
||||
limit = self._char_limit(target)
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
|
||||
entries[index] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
def remove(self, target: str, old_text: str) -> dict[str, Any]:
|
||||
"""删除一条已有记忆。
|
||||
|
||||
删除和替换共享同样的匹配策略:优先服务于 LLM 可操作性,而不是数据库式的强 ID。
|
||||
"""
|
||||
|
||||
old_text = old_text.strip()
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
matches = [(index, entry) for index, entry in enumerate(entries) if old_text in entry]
|
||||
if not matches:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
unique_texts = {entry for _, entry in matches}
|
||||
if len(unique_texts) > 1:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": [
|
||||
entry[:80] + ("..." if len(entry) > 80 else "")
|
||||
for _, entry in matches
|
||||
],
|
||||
}
|
||||
|
||||
entries.pop(matches[0][0])
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
def format_for_system_prompt(self, target: str) -> str | None:
|
||||
"""返回 session 启动时冻结下来的 prompt block。
|
||||
|
||||
这里明确返回的是 frozen snapshot,而不是 live state。
|
||||
所以如果 session 中途调用 `add()` 写入了新记忆,这里不会立刻变化。
|
||||
"""
|
||||
|
||||
block = self._system_prompt_snapshot.get(target, "")
|
||||
return block or None
|
||||
|
||||
def _success_response(self, target: str, message: str | None = None) -> dict[str, Any]:
|
||||
"""统一生成 memory tool 的成功响应。
|
||||
|
||||
响应里返回 live entries 和占用信息,目的是让模型能“看到自己刚写进去什么”,
|
||||
即使 system prompt 仍然保持冻结不变。
|
||||
"""
|
||||
current = self._char_count(target)
|
||||
limit = self._char_limit(target)
|
||||
percent = min(100, int((current / limit) * 100)) if limit > 0 else 0
|
||||
payload: dict[str, Any] = {
|
||||
"success": True,
|
||||
"target": target,
|
||||
"entries": list(self._entries_for(target)),
|
||||
"entry_count": len(self._entries_for(target)),
|
||||
"usage": f"{percent}% — {current:,}/{limit:,} chars",
|
||||
}
|
||||
if message:
|
||||
payload["message"] = message
|
||||
return payload
|
||||
|
||||
def _render_block(self, target: str, entries: list[str]) -> str:
|
||||
"""把条目渲染成适合注入 system prompt 的块。"""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
current = len(ENTRY_DELIMITER.join(entries))
|
||||
limit = self._char_limit(target)
|
||||
percent = min(100, int((current / limit) * 100)) if limit > 0 else 0
|
||||
if target == "user":
|
||||
header = f"USER PROFILE (who the user is) [{percent}% — {current:,}/{limit:,} chars]"
|
||||
else:
|
||||
header = f"MEMORY (your personal notes) [{percent}% — {current:,}/{limit:,} chars]"
|
||||
separator = "═" * 46
|
||||
return f"{separator}\n{header}\n{separator}\n{ENTRY_DELIMITER.join(entries)}"
|
||||
|
||||
@staticmethod
|
||||
def _read_file(path: Path) -> list[str]:
|
||||
"""读取记忆文件并按 entry delimiter 拆分。
|
||||
|
||||
这里不额外加读锁,因为写入采用的是原子替换:读者只会看到旧完整文件或新完整文件,
|
||||
不会看到半写入状态。
|
||||
"""
|
||||
if not path.exists():
|
||||
return []
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
return []
|
||||
if not raw.strip():
|
||||
return []
|
||||
return [entry for entry in (item.strip() for item in raw.split(ENTRY_DELIMITER)) if entry]
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, entries: list[str]) -> None:
|
||||
"""以原子方式写入记忆文件。
|
||||
|
||||
这里不能直接 `open(path, "w")`,因为那会先截断原文件,再写新内容。
|
||||
如果恰好此时别的进程正在读,就可能读到空文件或半成品。
|
||||
|
||||
正确方式是:
|
||||
1. 在同目录创建临时文件
|
||||
2. 写入并 fsync
|
||||
3. 使用 `os.replace()` 原子替换
|
||||
"""
|
||||
content = ENTRY_DELIMITER.join(entries) if entries else ""
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as handle:
|
||||
handle.write(content)
|
||||
handle.flush()
|
||||
os.fsync(handle.fileno())
|
||||
os.replace(tmp_path, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
@ -0,0 +1,2 @@
|
||||
"""Reusable procedures."""
|
||||
|
||||
2
app-instance/backend/beaver/memory/runs/__init__.py
Normal file
2
app-instance/backend/beaver/memory/runs/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Run records."""
|
||||
|
||||
5
app-instance/backend/beaver/memory/search/__init__.py
Normal file
5
app-instance/backend/beaver/memory/search/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Session transcript search storage."""
|
||||
|
||||
from .transcript_store import TranscriptStore
|
||||
|
||||
__all__ = ["TranscriptStore"]
|
||||
@ -0,0 +1,46 @@
|
||||
"""兼容层:过渡期把旧 transcript store 导向新的 session 子系统。
|
||||
|
||||
真正的主实现现在在:
|
||||
1. `beaver.engine.session.store`
|
||||
2. `beaver.engine.session.search`
|
||||
3. `beaver.engine.session.manager`
|
||||
|
||||
保留这个文件只是为了避免已经写好的 MCP server / tool 导入立刻断掉。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine.session.manager import SessionManager
|
||||
|
||||
|
||||
class TranscriptStore:
|
||||
"""兼容旧接口的薄封装。"""
|
||||
|
||||
def __init__(self, db_path: str | Path) -> None:
|
||||
path = Path(db_path)
|
||||
workspace = path.parent.parent if path.parent.name == "sessions" else path.parent
|
||||
self.manager = SessionManager(workspace=workspace, db_path=path)
|
||||
|
||||
def close(self) -> None:
|
||||
self.manager.close()
|
||||
|
||||
def ensure_session(self, session_id: str, **kwargs: Any) -> str:
|
||||
return self.manager.ensure_session(session_id, **kwargs)
|
||||
|
||||
def append_message(self, session_id: str, **kwargs: Any) -> int:
|
||||
return self.manager.append_message(session_id, **kwargs)
|
||||
|
||||
def get_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
return self.manager.get_session(session_id)
|
||||
|
||||
def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.manager.list_sessions_rich(**kwargs)
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
return self.manager.get_messages_as_conversation(session_id)
|
||||
|
||||
def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.manager.search_messages(**kwargs)
|
||||
2
app-instance/backend/beaver/memory/skills/__init__.py
Normal file
2
app-instance/backend/beaver/memory/skills/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Memory related to skill evolution."""
|
||||
|
||||
2
app-instance/backend/beaver/memory/stores/__init__.py
Normal file
2
app-instance/backend/beaver/memory/stores/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Storage backends for memory."""
|
||||
|
||||
2
app-instance/backend/beaver/permissions/__init__.py
Normal file
2
app-instance/backend/beaver/permissions/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Permission and governance layer."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Execution guards."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Permission policies."""
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
"""Agent permission profiles."""
|
||||
|
||||
2
app-instance/backend/beaver/plugins/__init__.py
Normal file
2
app-instance/backend/beaver/plugins/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Plugin system for Beaver."""
|
||||
|
||||
2
app-instance/backend/beaver/plugins/hooks.py
Normal file
2
app-instance/backend/beaver/plugins/hooks.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Plugin extension hooks."""
|
||||
|
||||
2
app-instance/backend/beaver/plugins/loader.py
Normal file
2
app-instance/backend/beaver/plugins/loader.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Plugin loading hooks."""
|
||||
|
||||
2
app-instance/backend/beaver/plugins/registry.py
Normal file
2
app-instance/backend/beaver/plugins/registry.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Plugin registry."""
|
||||
|
||||
6
app-instance/backend/beaver/services/__init__.py
Normal file
6
app-instance/backend/beaver/services/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Application services for Beaver."""
|
||||
|
||||
from .agent_service import AgentService
|
||||
from .memory_service import MemoryService
|
||||
|
||||
__all__ = ["AgentService", "MemoryService"]
|
||||
2
app-instance/backend/beaver/services/admin_service.py
Normal file
2
app-instance/backend/beaver/services/admin_service.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Administrative application service."""
|
||||
|
||||
212
app-instance/backend/beaver/services/agent_service.py
Normal file
212
app-instance/backend/beaver/services/agent_service.py
Normal file
@ -0,0 +1,212 @@
|
||||
"""Application service for agent entry.
|
||||
|
||||
这层的职责是把“接口层如何调用 AgentLoop”统一收口。
|
||||
|
||||
接口层以后不应该各自做这些事情:
|
||||
1. 自己 new `AgentLoop`
|
||||
2. 自己决定何时 `boot()`
|
||||
3. 自己处理 direct run 的同步/异步包装
|
||||
|
||||
统一放在 `AgentService` 后,CLI / Web / Gateway 才能共享同一条运行主链。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""面向 interfaces 的统一 agent 运行入口。
|
||||
|
||||
这里明确区分两种调用模式:
|
||||
1. direct mode
|
||||
- 不启动后台运行循环
|
||||
- 直接调用 `process_direct()` / `run_direct()`
|
||||
2. running mode
|
||||
- 先 `await start()`
|
||||
- 之后所有外部任务都必须走 `submit_direct()`
|
||||
- 不允许再直接调用 `process_direct()`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
profile: AgentProfile | None = None,
|
||||
loader: EngineLoader | None = None,
|
||||
) -> None:
|
||||
self.profile = profile or AgentProfile()
|
||||
self.loader = loader or EngineLoader(workspace=workspace)
|
||||
self._loop: AgentLoop | None = None
|
||||
self._run_task: asyncio.Task[None] | None = None
|
||||
|
||||
def create_loop(self) -> AgentLoop:
|
||||
"""创建并缓存当前 service 使用的 AgentLoop。"""
|
||||
|
||||
if self._loop is None:
|
||||
self._loop = AgentLoop(profile=self.profile, loader=self.loader)
|
||||
self._loop.boot()
|
||||
return self._loop
|
||||
|
||||
@property
|
||||
def has_loop(self) -> bool:
|
||||
"""当前 service 是否已经创建过 loop。"""
|
||||
|
||||
return self._loop is not None
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""当前 service 是否处于 running mode。"""
|
||||
|
||||
return self._run_task is not None and not self._run_task.done()
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭当前 service 持有的 runtime。"""
|
||||
|
||||
if self._run_task is not None and not self._run_task.done():
|
||||
raise RuntimeError("AgentService.close() requires stop() before closing a running loop")
|
||||
self._run_task = None
|
||||
if self._loop is None:
|
||||
return
|
||||
try:
|
||||
self._loop.close()
|
||||
finally:
|
||||
self._loop = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动后台运行循环,进入 running mode。
|
||||
|
||||
进入 running mode 后:
|
||||
- 外部任务必须通过 `submit_direct()` 提交
|
||||
- `process_direct()` 不再允许直接调用
|
||||
"""
|
||||
|
||||
if self._run_task is not None and not self._run_task.done():
|
||||
return
|
||||
loop = self.create_loop()
|
||||
self._run_task = asyncio.create_task(loop.run())
|
||||
while not loop.is_running:
|
||||
if self._run_task.done():
|
||||
await self._run_task
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def _stop_impl(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: float | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""内部停止实现,支持 graceful timeout 和可选 force cancel。"""
|
||||
|
||||
if self._run_task is None:
|
||||
return
|
||||
run_task = self._run_task
|
||||
loop = self.create_loop()
|
||||
try:
|
||||
await loop.stop()
|
||||
if timeout_seconds is None:
|
||||
await run_task
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(run_task), timeout=timeout_seconds)
|
||||
except asyncio.TimeoutError as exc:
|
||||
if force:
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
raise TimeoutError(
|
||||
f"AgentService.stop() timed out after {timeout_seconds} seconds while draining queued tasks"
|
||||
) from exc
|
||||
finally:
|
||||
if run_task.done():
|
||||
self._run_task = None
|
||||
|
||||
async def stop(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: float | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""停止后台运行循环并等待退出。
|
||||
|
||||
参数:
|
||||
- `timeout_seconds`: graceful drain 的最长等待时间;`None` 表示一直等
|
||||
- `force`: 超时后是否 cancel 掉运行循环 task
|
||||
"""
|
||||
|
||||
await self._stop_impl(timeout_seconds=timeout_seconds, force=force)
|
||||
|
||||
async def shutdown(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: float | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""先停运行循环,再释放 runtime。"""
|
||||
|
||||
await self._stop_impl(timeout_seconds=timeout_seconds, force=force)
|
||||
self.close()
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
message: str,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResult:
|
||||
"""异步 direct run 入口。
|
||||
|
||||
仅在 direct mode 下可用。
|
||||
|
||||
如果 service 已经 `start()` 进入 running mode,
|
||||
调用方必须改用 `submit_direct()`,不能绕过运行队列直接执行。
|
||||
"""
|
||||
|
||||
if self._run_task is not None and not self._run_task.done():
|
||||
raise RuntimeError(
|
||||
"AgentService.process_direct() is unavailable while the service is running; "
|
||||
"use 'await AgentService.submit_direct(...)' after start()."
|
||||
)
|
||||
loop = self.create_loop()
|
||||
return await loop.process_direct(message, **kwargs)
|
||||
|
||||
async def submit_direct(
|
||||
self,
|
||||
message: str,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResult:
|
||||
"""向 running mode 下的 loop 提交 direct task。
|
||||
|
||||
这是 `start()` 之后唯一合法的外部任务入口。
|
||||
"""
|
||||
|
||||
loop = self.create_loop()
|
||||
return await loop.submit_direct(message, **kwargs)
|
||||
|
||||
def run_direct(
|
||||
self,
|
||||
message: str,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResult:
|
||||
"""同步 direct run 包装。
|
||||
|
||||
主要给当前 CLI 或简单脚本使用。真正的长期方向仍然是让 interfaces
|
||||
在 direct mode 下直接走 `await process_direct(...)`。
|
||||
"""
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"AgentService.run_direct() cannot be used inside an active event loop; "
|
||||
"use 'await AgentService.process_direct(...)' instead."
|
||||
)
|
||||
return asyncio.run(self.process_direct(message, **kwargs))
|
||||
65
app-instance/backend/beaver/services/memory_service.py
Normal file
65
app-instance/backend/beaver/services/memory_service.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Beaver memory 应用服务。
|
||||
|
||||
这层不是新的 memory 实现,而是对现有 `MemoryStore + MemorySnapshot` 的应用层包装。
|
||||
|
||||
目标只有三个:
|
||||
1. 把“本轮运行前需要 refresh live state”这件事集中到一个地方
|
||||
2. 把“给 context builder 的只能是 frozen snapshot”这条规则写死
|
||||
3. 让 `AgentLoop` 不再直接操作 `MemoryStore` 细节
|
||||
|
||||
设计边界:
|
||||
1. 记忆实际读写逻辑仍然在 `beaver.memory.curated.store.MemoryStore`
|
||||
2. memory tool 仍然直接写 store
|
||||
3. 本服务只负责 runtime 接入策略,不负责 CRUD 业务本身
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot, capture_memory_snapshot
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""统一封装 runtime 对 curated memory 的访问方式。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str | Path,
|
||||
*,
|
||||
store: MemoryStore | None = None,
|
||||
) -> None:
|
||||
self.root = Path(root)
|
||||
self.store = store or MemoryStore(self.root)
|
||||
self._snapshot: MemorySnapshot | None = None
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""启动时加载一次磁盘内容,建立首份 frozen snapshot 基线。"""
|
||||
|
||||
self.store.load_from_disk()
|
||||
self._snapshot = capture_memory_snapshot(self.store)
|
||||
|
||||
def reload_for_new_run(self) -> None:
|
||||
"""每次新 run 开始前刷新 live state。
|
||||
|
||||
这是 Hermes 风格 memory policy 的关键点:
|
||||
- 上一次会话中通过 tool 写入的持久记忆,下一次运行应该能看到
|
||||
- 但同一次 run 中途写入的新记忆,不应反向修改当前 frozen snapshot
|
||||
"""
|
||||
|
||||
self.store.load_from_disk()
|
||||
self._snapshot = capture_memory_snapshot(self.store)
|
||||
|
||||
def get_snapshot(self) -> MemorySnapshot:
|
||||
"""获取当前 run 应注入 system prompt 的 frozen snapshot。"""
|
||||
|
||||
if self._snapshot is None:
|
||||
# 兜底场景:如果调用方绕过 initialize/reload,首次读取时仍建立一份快照。
|
||||
self._snapshot = capture_memory_snapshot(self.store)
|
||||
return self._snapshot
|
||||
|
||||
def get_store(self) -> MemoryStore:
|
||||
"""暴露底层 store 给需要直接调用 CRUD 的工具层。"""
|
||||
|
||||
return self.store
|
||||
2
app-instance/backend/beaver/services/skill_service.py
Normal file
2
app-instance/backend/beaver/services/skill_service.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Application service for skills."""
|
||||
|
||||
10
app-instance/backend/beaver/services/team_service.py
Normal file
10
app-instance/backend/beaver/services/team_service.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Application service for coordinated team runs."""
|
||||
|
||||
|
||||
class TeamService:
|
||||
"""Placeholder service for multi-agent execution."""
|
||||
|
||||
def run(self, task: str) -> str:
|
||||
"""Return a placeholder summary until real backends are migrated."""
|
||||
return f"team run placeholder: {task}"
|
||||
|
||||
12
app-instance/backend/beaver/skills/__init__.py
Normal file
12
app-instance/backend/beaver/skills/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Skill system for Beaver."""
|
||||
|
||||
from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever
|
||||
from .catalog import SkillRecord, SkillsLoader
|
||||
|
||||
__all__ = [
|
||||
"SkillAssembler",
|
||||
"SkillAssemblyResult",
|
||||
"SkillEmbeddingRetriever",
|
||||
"SkillRecord",
|
||||
"SkillsLoader",
|
||||
]
|
||||
6
app-instance/backend/beaver/skills/assembler/__init__.py
Normal file
6
app-instance/backend/beaver/skills/assembler/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Skill assembly for Beaver."""
|
||||
|
||||
from .embedding_retriever import SkillEmbeddingRetriever
|
||||
from .task_assembler import SkillAssemblyResult, SkillAssembler
|
||||
|
||||
__all__ = ["SkillAssemblyResult", "SkillAssembler", "SkillEmbeddingRetriever"]
|
||||
@ -0,0 +1,188 @@
|
||||
"""Embedding-based skill candidate retrieval.
|
||||
|
||||
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
|
||||
阿里云百炼 `text-embedding-v4` 做最小语义召回:
|
||||
1. 复用当前 provider 的 `api_key/api_base`
|
||||
2. 先用 embedding 相似度召回一小批候选
|
||||
3. 再交给上层 LLM selector 做最终技能选择
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import json
|
||||
from urllib import request
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SkillEmbeddingRetriever:
|
||||
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key_env: str = "OPENAI_API_KEY",
|
||||
api_base_env: str = "OPENAI_API_BASE",
|
||||
model: str = "text-embedding-v4",
|
||||
timeout_seconds: float = 20.0,
|
||||
) -> None:
|
||||
self.api_key_env = api_key_env
|
||||
self.api_base_env = api_base_env
|
||||
self.model = model
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
candidates: list[dict[str, str]],
|
||||
top_k: int = 12,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""按 embedding 相似度召回 top-k 候选。
|
||||
|
||||
如果没有可用的 API Key / base URL,或者 embedding 调用失败,
|
||||
当前阶段先退回到“全部候选交给 LLM selector”。
|
||||
"""
|
||||
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
resolved_api_key = api_key or os.getenv(self.api_key_env)
|
||||
resolved_api_base = api_base or os.getenv(self.api_base_env)
|
||||
if not resolved_api_key or not resolved_api_base:
|
||||
return candidates
|
||||
|
||||
try:
|
||||
query_embedding = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=[query],
|
||||
model=model or self.model,
|
||||
)
|
||||
candidate_texts = [self._candidate_text(item) for item in candidates]
|
||||
candidate_embeddings = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=candidate_texts,
|
||||
model=model or self.model,
|
||||
)
|
||||
except Exception:
|
||||
return candidates
|
||||
|
||||
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
|
||||
return candidates
|
||||
|
||||
query_vector = query_embedding[0]
|
||||
scored: list[tuple[float, dict[str, str]]] = []
|
||||
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
|
||||
if not vector:
|
||||
continue
|
||||
scored.append((self._cosine_similarity(query_vector, vector), candidate))
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
return [item[1] for item in scored[:top_k]]
|
||||
|
||||
async def _embed_texts(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
texts: list[str],
|
||||
model: str,
|
||||
) -> list[list[float]]:
|
||||
"""调用 OpenAI-compatible embeddings 接口。
|
||||
|
||||
当前对齐的是你们实际在用的网关配置:
|
||||
- `POST {api_base}/embeddings`
|
||||
- `model=text-embedding-v4`
|
||||
- `encoding_format=float`
|
||||
"""
|
||||
|
||||
all_vectors: list[list[float]] = []
|
||||
endpoint = self._normalize_embeddings_endpoint(api_base)
|
||||
for start in range(0, len(texts), 10):
|
||||
batch = texts[start:start + 10]
|
||||
payload = await self._post_embeddings(
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=batch,
|
||||
)
|
||||
embeddings = payload.get("data") or []
|
||||
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
|
||||
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
|
||||
return all_vectors
|
||||
|
||||
async def _post_embeddings(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
) -> dict[str, Any]:
|
||||
return await asyncio.to_thread(
|
||||
self._post_embeddings_sync,
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def _post_embeddings_sync(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
) -> dict[str, Any]:
|
||||
body = json.dumps(
|
||||
{
|
||||
"model": model,
|
||||
"input": texts if len(texts) > 1 else texts[0],
|
||||
"encoding_format": "float",
|
||||
}
|
||||
).encode("utf-8")
|
||||
req = request.Request(
|
||||
endpoint,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
with request.urlopen(req, timeout=self.timeout_seconds) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def _candidate_text(candidate: dict[str, str]) -> str:
|
||||
name = (candidate.get("name") or "").strip()
|
||||
description = (candidate.get("description") or "").strip()
|
||||
return f"{name}\n{description}".strip()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_embeddings_endpoint(api_base: str) -> str:
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/embeddings"):
|
||||
return base
|
||||
if base.endswith("/v1"):
|
||||
return f"{base}/embeddings"
|
||||
return f"{base}/v1/embeddings"
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||||
if not left or not right or len(left) != len(right):
|
||||
return -1.0
|
||||
dot = sum(a * b for a, b in zip(left, right, strict=False))
|
||||
left_norm = math.sqrt(sum(a * a for a in left))
|
||||
right_norm = math.sqrt(sum(b * b for b in right))
|
||||
if left_norm == 0 or right_norm == 0:
|
||||
return -1.0
|
||||
return dot / (left_norm * right_norm)
|
||||
168
app-instance/backend/beaver/skills/assembler/task_assembler.py
Normal file
168
app-instance/backend/beaver/skills/assembler/task_assembler.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""LLM-driven skill assembler.
|
||||
|
||||
这层现在不再自己做规则打分,而是直接把:
|
||||
1. task description
|
||||
2. embedding 召回后的候选 skill 摘要
|
||||
|
||||
交给一个模型来决定本轮要激活哪些 skill。
|
||||
|
||||
当前目标非常克制:
|
||||
- 输入尽量简单
|
||||
- 输出只要 skill 名称
|
||||
- 没有命中就返回空 skills
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine.context import SkillContext
|
||||
from beaver.engine.providers.base import LLMProvider
|
||||
from beaver.engine.providers.runtime import ProviderRuntime
|
||||
from beaver.skills.catalog.loader import SkillsLoader
|
||||
from beaver.skills.catalog.utils import strip_frontmatter
|
||||
from .embedding_retriever import SkillEmbeddingRetriever
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SkillAssemblyResult:
|
||||
"""一次装配后真正要注入当前 run 的 skills。"""
|
||||
|
||||
activated_skills: list[SkillContext] = field(default_factory=list)
|
||||
|
||||
|
||||
class SkillAssembler:
|
||||
"""用 LLM 根据 task description 选择当前 run 的 skills。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader: SkillsLoader,
|
||||
retriever: SkillEmbeddingRetriever | None = None,
|
||||
) -> None:
|
||||
self.loader = loader
|
||||
self.retriever = retriever or SkillEmbeddingRetriever()
|
||||
|
||||
async def assemble(
|
||||
self,
|
||||
*,
|
||||
task_description: str,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
embedding_runtime: ProviderRuntime | None = None,
|
||||
top_k: int = 12,
|
||||
) -> SkillAssemblyResult:
|
||||
candidates = self.loader.build_selection_candidates()
|
||||
if not candidates:
|
||||
return SkillAssemblyResult()
|
||||
candidates = await self.retriever.retrieve(
|
||||
query=task_description,
|
||||
candidates=candidates,
|
||||
top_k=top_k,
|
||||
api_key=embedding_runtime.api_key if embedding_runtime is not None else None,
|
||||
api_base=embedding_runtime.api_base if embedding_runtime is not None else None,
|
||||
model=embedding_runtime.model if embedding_runtime is not None else None,
|
||||
)
|
||||
if not candidates:
|
||||
return SkillAssemblyResult()
|
||||
|
||||
selected_names = await self._select_skill_names(
|
||||
task_description=task_description,
|
||||
candidates=candidates,
|
||||
provider=provider,
|
||||
model=model,
|
||||
)
|
||||
if not selected_names:
|
||||
return SkillAssemblyResult()
|
||||
|
||||
activated_skills: list[SkillContext] = []
|
||||
for name in selected_names:
|
||||
raw_content = self.loader.load_skill(name)
|
||||
content = strip_frontmatter(raw_content).strip() if raw_content else ""
|
||||
if not content:
|
||||
continue
|
||||
activated_skills.append(SkillContext(name=name, content=content))
|
||||
|
||||
return SkillAssemblyResult(activated_skills=activated_skills)
|
||||
|
||||
async def _select_skill_names(
|
||||
self,
|
||||
*,
|
||||
task_description: str,
|
||||
candidates: list[dict[str, str]],
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
) -> list[str]:
|
||||
candidate_summary = self._render_candidates(candidates)
|
||||
candidate_names = {item["name"] for item in candidates}
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You select Beaver skills for a single run. "
|
||||
"Given a task description and candidate skill summaries, "
|
||||
"return only a JSON array of skill names to activate. "
|
||||
"Do not invent names. If nothing matches, return []."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Task description:\n{task_description}\n\n"
|
||||
f"Candidate skills:\n{candidate_summary}\n\n"
|
||||
"Return only JSON, for example: [\"skill-a\", \"skill-b\"]"
|
||||
),
|
||||
},
|
||||
]
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
tools=None,
|
||||
model=model,
|
||||
max_tokens=512,
|
||||
temperature=0,
|
||||
)
|
||||
if response.finish_reason == "error" or not response.content:
|
||||
return []
|
||||
|
||||
parsed = self._parse_selected_names(response.content)
|
||||
if not parsed:
|
||||
return []
|
||||
|
||||
# 只保留当前候选集中真实存在的 skill 名称,并维持模型输出顺序。
|
||||
filtered: list[str] = []
|
||||
for name in parsed:
|
||||
if name in candidate_names and name not in filtered:
|
||||
filtered.append(name)
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def _render_candidates(candidates: list[dict[str, str]]) -> str:
|
||||
lines: list[str] = []
|
||||
for item in candidates:
|
||||
lines.append(f"- {item['name']}: {item['description']}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _parse_selected_names(content: str) -> list[str]:
|
||||
cleaned = content.strip()
|
||||
if cleaned.startswith("```"):
|
||||
lines = cleaned.splitlines()
|
||||
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||
cleaned = "\n".join(lines[1:-1]).strip()
|
||||
|
||||
try:
|
||||
payload: Any = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if isinstance(payload, dict):
|
||||
for key in ("skills", "selected_skills", "activated_skills", "selected"):
|
||||
value = payload.get(key)
|
||||
if isinstance(value, list):
|
||||
payload = value
|
||||
break
|
||||
|
||||
if not isinstance(payload, list):
|
||||
return []
|
||||
return [item.strip() for item in payload if isinstance(item, str) and item.strip()]
|
||||
2
app-instance/backend/beaver/skills/builtin/__init__.py
Normal file
2
app-instance/backend/beaver/skills/builtin/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Built-in skill payloads."""
|
||||
|
||||
5
app-instance/backend/beaver/skills/catalog/__init__.py
Normal file
5
app-instance/backend/beaver/skills/catalog/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Skill catalog and indexing."""
|
||||
|
||||
from .loader import SkillRecord, SkillsLoader
|
||||
|
||||
__all__ = ["SkillRecord", "SkillsLoader"]
|
||||
281
app-instance/backend/beaver/skills/catalog/loader.py
Normal file
281
app-instance/backend/beaver/skills/catalog/loader.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""Beaver skills catalog loader。
|
||||
|
||||
第一版目标非常明确:
|
||||
|
||||
1. 扫描技能目录
|
||||
2. 读取 `SKILL.md`
|
||||
3. 解析前置元数据
|
||||
4. 生成可注入上下文的正文与索引
|
||||
|
||||
这层不负责:
|
||||
1. 动态选择本轮应该启用哪些 skill
|
||||
2. skill review / publishing
|
||||
3. skill 自动学习
|
||||
|
||||
这些决策属于 resolver 或更高层工作流。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
check_requirements,
|
||||
escape_xml,
|
||||
get_missing_requirements,
|
||||
parse_frontmatter,
|
||||
parse_skill_metadata_blob,
|
||||
strip_frontmatter,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SkillRecord:
|
||||
"""单个 skill 的目录级元数据。"""
|
||||
|
||||
name: str
|
||||
path: Path
|
||||
source: str
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""从 workspace/builtin 目录中发现并读取 skills。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str | Path,
|
||||
*,
|
||||
builtin_skills_dir: str | Path | None = None,
|
||||
extra_dirs: list[str | Path] | None = None,
|
||||
) -> None:
|
||||
self.workspace = Path(workspace)
|
||||
self.workspace_skills = self.workspace / "skills"
|
||||
self.builtin_skills = Path(builtin_skills_dir) if builtin_skills_dir is not None else Path(__file__).resolve().parent.parent / "builtin"
|
||||
self.extra_dirs = [Path(item) for item in (extra_dirs or [])]
|
||||
|
||||
def list_skills(self, *, filter_unavailable: bool = True) -> list[SkillRecord]:
|
||||
"""列出当前可见的 skills。
|
||||
|
||||
优先级:
|
||||
1. workspace
|
||||
2. extra/plugin 目录
|
||||
3. builtin
|
||||
|
||||
重名 skill 只保留优先级更高的那一个。
|
||||
"""
|
||||
|
||||
ordered_roots: list[tuple[str, Path]] = [
|
||||
("workspace", self.workspace_skills),
|
||||
*[("plugin", path) for path in self.extra_dirs],
|
||||
("builtin", self.builtin_skills),
|
||||
]
|
||||
found: dict[str, SkillRecord] = {}
|
||||
|
||||
for source, root in ordered_roots:
|
||||
if not root.exists():
|
||||
continue
|
||||
for skill_dir in root.iterdir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if not skill_dir.is_dir() or not skill_file.exists():
|
||||
continue
|
||||
name = skill_dir.name
|
||||
if name in found:
|
||||
continue
|
||||
record = SkillRecord(name=name, path=skill_file, source=source)
|
||||
if filter_unavailable and not self._record_available(record):
|
||||
continue
|
||||
found[name] = record
|
||||
return list(found.values())
|
||||
|
||||
def load_skill(self, name: str) -> str | None:
|
||||
"""按名称加载 skill 原始内容。"""
|
||||
|
||||
record = self._find_record(name)
|
||||
if record is None:
|
||||
return None
|
||||
return record.path.read_text(encoding="utf-8")
|
||||
|
||||
def get_skill_record(self, name: str) -> SkillRecord | None:
|
||||
"""按名称返回 skill record。"""
|
||||
|
||||
return self._find_record(name)
|
||||
|
||||
def get_skill_metadata(self, name: str) -> dict[str, Any] | None:
|
||||
"""读取 skill frontmatter 元数据。"""
|
||||
|
||||
content = self.load_skill(name)
|
||||
if content is None:
|
||||
return None
|
||||
metadata, _ = parse_frontmatter(content)
|
||||
return metadata
|
||||
|
||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||
"""加载指定 skills 的正文,并整理成上下文块。"""
|
||||
|
||||
sections: list[str] = []
|
||||
for name in skill_names:
|
||||
content = self.load_skill(name)
|
||||
if not content:
|
||||
continue
|
||||
body = strip_frontmatter(content).strip()
|
||||
if not body:
|
||||
continue
|
||||
sections.append(f"## {name}\n\n{body}")
|
||||
return "\n\n".join(sections)
|
||||
|
||||
def build_skills_summary(self) -> str:
|
||||
"""构建可注入 system prompt 的 skills index。
|
||||
|
||||
虽然函数名还沿用 `summary`,但当前语义已经更接近 Hermes 的 skills index:
|
||||
- 这里只告诉模型“系统里有哪些 skill 可用”
|
||||
- 不负责把 skill 正文塞进 system prompt
|
||||
- 真正激活的 skill 正文由 resolver/builder 走显式消息注入
|
||||
"""
|
||||
|
||||
skills = self.list_skills(filter_unavailable=False)
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
lines = ["<skills>"]
|
||||
for record in skills:
|
||||
frontmatter = self.get_skill_metadata(record.name) or {}
|
||||
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
|
||||
available = check_requirements(meta_blob)
|
||||
description = frontmatter.get("description") or record.name
|
||||
load_hint = f'Use skill_view(name="{record.name}") to load the full skill.'
|
||||
lines.append(f' <skill available="{str(available).lower()}">')
|
||||
lines.append(f" <name>{escape_xml(record.name)}</name>")
|
||||
lines.append(f" <description>{escape_xml(description)}</description>")
|
||||
lines.append(f" <load_hint>{escape_xml(load_hint)}</load_hint>")
|
||||
support_files = self.list_skill_supporting_files(record.name)
|
||||
if support_files:
|
||||
lines.append(" <supporting_files>")
|
||||
for file_path in support_files[:12]:
|
||||
lines.append(f" <file>{escape_xml(file_path)}</file>")
|
||||
if len(support_files) > 12:
|
||||
lines.append(" <file>...additional files omitted...</file>")
|
||||
lines.append(" </supporting_files>")
|
||||
if not available:
|
||||
missing = get_missing_requirements(meta_blob)
|
||||
if missing:
|
||||
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</skills>")
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_selection_candidates(self) -> list[dict[str, str]]:
|
||||
"""构建给 LLM selector 使用的候选 skill 摘要。
|
||||
|
||||
这里刻意保持精简,只给:
|
||||
- `name`
|
||||
- `description`
|
||||
|
||||
选择器的任务只是“从候选里挑名字”,不是直接阅读完整 skill 正文。
|
||||
真正激活后的 skill 正文仍然在后续阶段按需加载。
|
||||
"""
|
||||
|
||||
candidates: list[dict[str, str]] = []
|
||||
for record in self.list_skills(filter_unavailable=True):
|
||||
frontmatter = self.get_skill_metadata(record.name) or {}
|
||||
description = str(frontmatter.get("description") or "").strip()
|
||||
if not description:
|
||||
raw_content = self.load_skill(record.name) or ""
|
||||
body = strip_frontmatter(raw_content).strip()
|
||||
if body:
|
||||
description = " ".join(body.splitlines()[:3])[:240].strip()
|
||||
candidates.append(
|
||||
{
|
||||
"name": record.name,
|
||||
"description": description or record.name,
|
||||
}
|
||||
)
|
||||
return candidates
|
||||
|
||||
def list_skill_supporting_files(self, name: str) -> list[str]:
|
||||
"""列出 skill 目录下可按需查看的支持文件相对路径。"""
|
||||
|
||||
record = self._find_record(name)
|
||||
if record is None:
|
||||
return []
|
||||
skill_dir = record.path.parent
|
||||
results: list[str] = []
|
||||
for subdir in ("references", "templates", "scripts", "assets"):
|
||||
root = skill_dir / subdir
|
||||
if not root.exists():
|
||||
continue
|
||||
for file in sorted(root.rglob("*")):
|
||||
if file.is_file() and not file.is_symlink():
|
||||
results.append(str(file.relative_to(skill_dir)))
|
||||
return results
|
||||
|
||||
def view_skill(self, name: str, file_path: str | None = None) -> tuple[str, str] | None:
|
||||
"""读取 skill 正文或其支持文件。
|
||||
|
||||
返回 `(display_name, content)`:
|
||||
- `display_name` 用于提示当前读取的是 skill 本体还是某个支持文件
|
||||
- `content` 为实际文本内容
|
||||
"""
|
||||
|
||||
record = self._find_record(name)
|
||||
if record is None:
|
||||
return None
|
||||
if not self._record_available(record):
|
||||
frontmatter = self.get_skill_metadata(name) or {}
|
||||
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
|
||||
missing = get_missing_requirements(meta_blob)
|
||||
detail = f" Missing requirements: {missing}." if missing else ""
|
||||
raise ValueError(f"Skill '{name}' is currently unavailable.{detail}")
|
||||
|
||||
skill_dir = record.path.parent
|
||||
if not file_path:
|
||||
return ("SKILL.md", self._read_text_file(record.path, display_name="SKILL.md"))
|
||||
|
||||
candidate = (skill_dir / file_path).resolve()
|
||||
try:
|
||||
candidate.relative_to(skill_dir.resolve())
|
||||
except ValueError as exc:
|
||||
raise ValueError("Requested skill file must stay within the skill directory") from exc
|
||||
if not candidate.exists() or not candidate.is_file():
|
||||
raise FileNotFoundError(f"Skill file '{file_path}' does not exist")
|
||||
display_name = str(candidate.relative_to(skill_dir))
|
||||
return (display_name, self._read_text_file(candidate, display_name=display_name))
|
||||
|
||||
def get_always_skills(self) -> list[str]:
|
||||
"""返回标记为 always 的可用 skill 名称。"""
|
||||
|
||||
result: list[str] = []
|
||||
for record in self.list_skills(filter_unavailable=True):
|
||||
frontmatter = self.get_skill_metadata(record.name) or {}
|
||||
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
|
||||
if meta_blob.get("always") or str(frontmatter.get("always", "")).lower() == "true":
|
||||
result.append(record.name)
|
||||
return result
|
||||
|
||||
def _find_record(self, name: str) -> SkillRecord | None:
|
||||
for record in self.list_skills(filter_unavailable=False):
|
||||
if record.name == name:
|
||||
return record
|
||||
return None
|
||||
|
||||
def _record_available(self, record: SkillRecord) -> bool:
|
||||
content = record.path.read_text(encoding="utf-8")
|
||||
frontmatter, _ = parse_frontmatter(content)
|
||||
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
|
||||
return check_requirements(meta_blob)
|
||||
|
||||
@staticmethod
|
||||
def _read_text_file(path: Path, *, display_name: str) -> str:
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError(
|
||||
f"Skill file '{display_name}' is not UTF-8 text and cannot be viewed with skill_view."
|
||||
) from exc
|
||||
|
||||
def _skill_available(self, name: str) -> bool:
|
||||
record = self._find_record(name)
|
||||
if record is None:
|
||||
return False
|
||||
return self._record_available(record)
|
||||
122
app-instance/backend/beaver/skills/catalog/utils.py
Normal file
122
app-instance/backend/beaver/skills/catalog/utils.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""Skills catalog 的公共辅助函数。
|
||||
|
||||
这里专门放“解析和校验 skill 文件”的纯函数,避免 `loader.py` 里同时承担:
|
||||
|
||||
1. 目录扫描
|
||||
2. frontmatter 解析
|
||||
3. requirements 校验
|
||||
4. 文本裁剪/格式化
|
||||
|
||||
把这些细节拆出来之后,skills catalog 的边界会更清楚,后面无论是 reviews、publisher
|
||||
还是 runtime resolver,都可以复用同一套元数据解析规则。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
|
||||
"""解析 Markdown 文件顶部的极简 frontmatter。
|
||||
|
||||
当前先只支持最常见的:
|
||||
|
||||
```md
|
||||
---
|
||||
key: value
|
||||
key2: value2
|
||||
---
|
||||
body...
|
||||
```
|
||||
|
||||
这样足够支撑第一版 skills runtime,不提前把 YAML 解析器引进来。
|
||||
"""
|
||||
|
||||
if not content.startswith("---"):
|
||||
return {}, content
|
||||
|
||||
match = re.match(r"^---\n(.*?)\n---\n?", content, re.DOTALL)
|
||||
if match is None:
|
||||
return {}, content
|
||||
|
||||
metadata: dict[str, str] = {}
|
||||
for line in match.group(1).splitlines():
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
body = content[match.end():].strip()
|
||||
return metadata, body
|
||||
|
||||
|
||||
def strip_frontmatter(content: str) -> str:
|
||||
"""去掉 frontmatter,只保留 skill 正文。"""
|
||||
|
||||
_, body = parse_frontmatter(content)
|
||||
return body
|
||||
|
||||
|
||||
def parse_skill_metadata_blob(raw: str) -> dict[str, Any]:
|
||||
"""解析 metadata 字段里的 JSON 扩展配置。
|
||||
|
||||
为了兼容旧 nanobot 习惯,这里同时支持:
|
||||
- `nanobot`
|
||||
- `openclaw`
|
||||
|
||||
第一版主要关心的字段有:
|
||||
- `always`
|
||||
- `requires`
|
||||
"""
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
nested = data.get("nanobot", data.get("openclaw", data))
|
||||
return nested if isinstance(nested, dict) else {}
|
||||
|
||||
|
||||
def check_requirements(metadata: dict[str, Any]) -> bool:
|
||||
"""检查 skill 的最小 requirements 是否满足。"""
|
||||
|
||||
requires = metadata.get("requires", {})
|
||||
if not isinstance(requires, dict):
|
||||
return True
|
||||
|
||||
for binary in requires.get("bins", []):
|
||||
if not shutil.which(str(binary)):
|
||||
return False
|
||||
for env_name in requires.get("env", []):
|
||||
if not os.environ.get(str(env_name)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_missing_requirements(metadata: dict[str, Any]) -> str:
|
||||
"""返回缺失 requirements 的简短描述。"""
|
||||
|
||||
requires = metadata.get("requires", {})
|
||||
if not isinstance(requires, dict):
|
||||
return ""
|
||||
|
||||
missing: list[str] = []
|
||||
for binary in requires.get("bins", []):
|
||||
if not shutil.which(str(binary)):
|
||||
missing.append(f"CLI: {binary}")
|
||||
for env_name in requires.get("env", []):
|
||||
if not os.environ.get(str(env_name)):
|
||||
missing.append(f"ENV: {env_name}")
|
||||
return ", ".join(missing)
|
||||
|
||||
|
||||
def escape_xml(value: str) -> str:
|
||||
"""给 skills summary 做最小 XML 转义。"""
|
||||
|
||||
return value.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
2
app-instance/backend/beaver/skills/drafts/__init__.py
Normal file
2
app-instance/backend/beaver/skills/drafts/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Draft skills generated before review."""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user