feat: 移除backend-old目录中的废弃文件
移除以下文件: - .dockerignore 和 .gitignore 配置文件 - A2A_Multiagent_change.md 设计文档 - COMMUNICATION.md 通讯信息文档 - Dockerfile 构建配置 - LICENSE 许可证文件 这些文件属于旧版本后端代码,不再需要维护。
This commit is contained in:
@ -1,13 +0,0 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info
|
||||
dist/
|
||||
build/
|
||||
.git
|
||||
.env
|
||||
.assets
|
||||
node_modules/
|
||||
bridge/dist/
|
||||
workspace/
|
||||
201
app-instance/backend-old/.gitignore
vendored
201
app-instance/backend-old/.gitignore
vendored
@ -1,201 +0,0 @@
|
||||
<<<<<<< HEAD
|
||||
.assets
|
||||
.env
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
docs/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.pyw
|
||||
*.pyz
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
tests/
|
||||
=======
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
>>>>>>> origin/main
|
||||
@ -1,753 +0,0 @@
|
||||
# A2A Multi-Agent 改造方案
|
||||
|
||||
## 1. 需求目标
|
||||
|
||||
当前 `spawn`/`sub-agent` 只有一种执行方式: 创建一个本地后台 subagent 去完成任务。
|
||||
|
||||
这次需求要改成:
|
||||
|
||||
1. 调用 `sub-agent` 时,不一定新建本地 subagent。
|
||||
2. 先从“已添加的 Agent”里找可用目标。
|
||||
3. 再从 skills 中声明的 `agent cards` 里找可用目标。
|
||||
4. 通过 A2A 协议把任务发给对应 agent。
|
||||
5. 支持一个任务发给多个 agent,形成 `agent group`,最后回到主 agent 汇总。
|
||||
6. 保持现有 `spawn(task, label)` 兼容,不破坏已有行为。
|
||||
|
||||
结论先说:
|
||||
|
||||
- 最合适的做法不是继续把能力堆进 `SubagentManager`。
|
||||
- 应该把“本地 subagent 执行”升级为“统一委派层”。
|
||||
- `spawn` 工具继续保留,但语义从“创建 subagent”扩展为“委派给合适的 agent / agent group”。
|
||||
|
||||
## 2. 当前代码现状
|
||||
|
||||
### 2.1 当前触发链路
|
||||
|
||||
现有链路很单一:
|
||||
|
||||
1. `AgentLoop` 初始化 `SubagentManager`
|
||||
- 位置: `nanobot/agent/loop.py:88-114`
|
||||
2. `AgentLoop._register_default_tools()` 注册 `SpawnTool`
|
||||
- 位置: `nanobot/agent/loop.py:116-138`
|
||||
3. LLM 调用 `spawn(task, label)`
|
||||
4. `SpawnTool.execute()` 直接转发给 `SubagentManager.spawn()`
|
||||
- 位置: `nanobot/agent/tools/spawn.py:67-76`
|
||||
5. `SubagentManager.spawn()` 创建本地 asyncio 后台任务
|
||||
- 位置: `nanobot/agent/subagent.py:64-93`
|
||||
6. `_run_subagent()` 用一个受限工具集运行本地子代理
|
||||
- 位置: `nanobot/agent/subagent.py:95-195`
|
||||
7. `_announce_result()` 把结果包装成 `channel="system"` 的消息回投主消息总线
|
||||
- 位置: `nanobot/agent/subagent.py:197-230`
|
||||
8. `AgentLoop._process_message()` 接到 `system` 消息,再整理成用户可见回复
|
||||
- 位置: `nanobot/agent/loop.py:331-347`
|
||||
|
||||
### 2.2 当前已经有但没接入调度链路的能力
|
||||
|
||||
仓库里已经有两类“候选 agent 信息”,但没有进入实际调度:
|
||||
|
||||
1. Plugin agents
|
||||
- `PluginLoader.find_agent()` 已能找 agent
|
||||
- 位置: `nanobot/agent/plugins.py:83-91`
|
||||
- `build_agents_summary()` 也已能汇总 agent 信息
|
||||
- 位置: `nanobot/agent/plugins.py:100-121`
|
||||
- 但当前 `AgentLoop` / `ContextBuilder` 并没有用它做调度
|
||||
|
||||
2. Skills
|
||||
- `SkillsLoader` 已能枚举 / 读取 skill
|
||||
- 位置: `nanobot/agent/skills.py:32-249`
|
||||
- 但 skill 目前只被当作 prompt 资源,不会暴露成“可路由 agent”
|
||||
|
||||
### 2.3 当前缺口
|
||||
|
||||
当前缺少这几层:
|
||||
|
||||
1. 统一的 `Agent Registry`
|
||||
2. A2A `agent card` 发现与缓存
|
||||
3. A2A client 调用层
|
||||
4. 统一的委派器,负责在“本地 subagent / plugin agent / skill agent card / agent group”之间做路由
|
||||
5. group 级别的状态管理和结果聚合
|
||||
|
||||
## 3. 推荐总方案
|
||||
|
||||
推荐采用“保留 `spawn` 工具名,重构内部执行层”的方案。
|
||||
|
||||
### 3.1 核心思路
|
||||
|
||||
把当前:
|
||||
|
||||
- `SpawnTool -> SubagentManager -> 本地 subagent`
|
||||
|
||||
改成:
|
||||
|
||||
- `SpawnTool -> DelegationManager -> AgentResolver -> Executor(local/plugin/a2a/group)`
|
||||
|
||||
也就是:
|
||||
|
||||
1. `spawn` 不再等价于“必须创建 subagent”。
|
||||
2. `spawn` 变成“委派任务”。
|
||||
3. 真正执行方式由委派层动态决定。
|
||||
|
||||
### 3.2 为什么这样最合适
|
||||
|
||||
如果直接继续扩 `SubagentManager`,很快会出现这些问题:
|
||||
|
||||
1. 一个类同时负责本地 LLM 运行、A2A 网络调用、agent card 发现、group 并发、结果聚合。
|
||||
2. 后续要支持 plugin agent、本地 named agent、A2A streaming 时会越来越乱。
|
||||
3. 当前 `SubagentManager` 的职责本来就已经比较明确: “本地后台 subagent 执行器”。
|
||||
|
||||
所以更合理的拆法是:
|
||||
|
||||
1. `SubagentManager` 保留或下沉为 `LocalSubagentExecutor`
|
||||
2. 新增 `DelegationManager` 作为统一入口
|
||||
3. 新增 `AgentRegistry` / `AgentResolver`
|
||||
4. 新增 `A2AClient`
|
||||
|
||||
## 4. 推荐模块拆分
|
||||
|
||||
### 4.1 新增 `DelegationManager`
|
||||
|
||||
建议新文件:
|
||||
|
||||
- `nanobot/agent/delegation.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 接收 `spawn` 请求
|
||||
2. 根据参数和任务内容选择目标 agent
|
||||
3. 决定执行方式
|
||||
4. 对 group 做并发调度
|
||||
5. 统一把结果回投主消息总线
|
||||
|
||||
建议接口:
|
||||
|
||||
```python
|
||||
class DelegationManager:
|
||||
async def dispatch(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
target: str | None = None,
|
||||
targets: list[str] | None = None,
|
||||
strategy: str = "auto",
|
||||
origin_channel: str = "cli",
|
||||
origin_chat_id: str = "direct",
|
||||
) -> str: ...
|
||||
```
|
||||
|
||||
### 4.2 保留本地执行器
|
||||
|
||||
当前 `nanobot/agent/subagent.py` 的 `_run_subagent()` 逻辑可以保留,但角色改为:
|
||||
|
||||
- `LocalSubagentExecutor`
|
||||
|
||||
也可以第一版不重命名文件,只把里面逻辑拆成:
|
||||
|
||||
1. `spawn_local()`
|
||||
2. `_run_local_subagent()`
|
||||
3. `_announce_local_result()`
|
||||
|
||||
这样可以最小改动落地。
|
||||
|
||||
### 4.3 新增 `AgentRegistry`
|
||||
|
||||
建议新文件:
|
||||
|
||||
- `nanobot/agent/agent_registry.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 汇总所有可调度 agent
|
||||
2. 统一输出规范化 descriptor
|
||||
3. 维护优先级和去重逻辑
|
||||
|
||||
统一后的 agent 来源:
|
||||
|
||||
1. workspace 中“已添加的 agent”
|
||||
2. plugin agents
|
||||
3. skill frontmatter 里声明的 `agent_cards`
|
||||
4. 必要时 fallback 到本地 `local-subagent`
|
||||
|
||||
建议统一 descriptor:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentDescriptor:
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
source: str # workspace | plugin | skill | builtin
|
||||
kind: str # local_prompt | a2a_remote | local_fallback
|
||||
protocol: str | None # a2a | None
|
||||
plugin_name: str | None = None
|
||||
skill_name: str | None = None
|
||||
model: str | None = None
|
||||
endpoint: str | None = None
|
||||
card_url: str | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
```
|
||||
|
||||
### 4.4 新增 A2A client 层
|
||||
|
||||
建议新目录:
|
||||
|
||||
- `nanobot/a2a/client.py`
|
||||
- `nanobot/a2a/cards.py`
|
||||
- `nanobot/a2a/models.py`
|
||||
|
||||
职责:
|
||||
|
||||
1. 获取 agent card
|
||||
2. 解析 card 能力
|
||||
3. 对远端 agent 发 JSON-RPC 请求
|
||||
4. 处理同步返回 / task 轮询 / streaming 兼容
|
||||
|
||||
## 5. 代码插入点
|
||||
|
||||
## 5.1 `nanobot/agent/loop.py`
|
||||
|
||||
### 插入点 A: `__init__`
|
||||
|
||||
当前:
|
||||
|
||||
- `self.subagents = SubagentManager(...)`
|
||||
- 位置: `nanobot/agent/loop.py:88-102`
|
||||
|
||||
建议改成:
|
||||
|
||||
1. 初始化 `PluginLoader`
|
||||
2. 初始化 `AgentRegistry`
|
||||
3. 初始化 `DelegationManager`
|
||||
4. `DelegationManager` 内部持有 `LocalSubagentExecutor` / `A2AExecutor`
|
||||
|
||||
推荐形态:
|
||||
|
||||
```python
|
||||
self.plugins = PluginLoader(workspace)
|
||||
self.agent_registry = AgentRegistry(workspace, plugins=self.plugins, ...)
|
||||
self.delegation = DelegationManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
registry=self.agent_registry,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### 插入点 B: `_register_default_tools`
|
||||
|
||||
当前:
|
||||
|
||||
- 注册 `SpawnTool(manager=self.subagents)`
|
||||
- 位置: `nanobot/agent/loop.py:130-134`
|
||||
|
||||
建议改成:
|
||||
|
||||
```python
|
||||
self.tools.register(SpawnTool(manager=self.delegation))
|
||||
```
|
||||
|
||||
### 插入点 C: `_set_tool_context`
|
||||
|
||||
当前会给 `spawn` 工具写 origin context:
|
||||
|
||||
- 位置: `nanobot/agent/loop.py:165-192`
|
||||
|
||||
这里逻辑可以继续保留,不需要大改,因为 A2A / group 结果最终也要回到原会话。
|
||||
|
||||
## 5.2 `nanobot/agent/tools/spawn.py`
|
||||
|
||||
当前 `SpawnTool` 参数只有:
|
||||
|
||||
- `task`
|
||||
- `label`
|
||||
|
||||
位置:
|
||||
|
||||
- schema: `nanobot/agent/tools/spawn.py:49-65`
|
||||
- execute: `nanobot/agent/tools/spawn.py:67-76`
|
||||
|
||||
建议扩成:
|
||||
|
||||
```python
|
||||
{
|
||||
"task": "string",
|
||||
"label": "string?",
|
||||
"target": "string?",
|
||||
"targets": "string[]?",
|
||||
"strategy": "auto|local|plugin|a2a|group"
|
||||
}
|
||||
```
|
||||
|
||||
兼容规则:
|
||||
|
||||
1. 老调用只传 `task/label` 时,等价于 `strategy="auto"`
|
||||
2. `target` 表示单目标
|
||||
3. `targets` 表示 group
|
||||
4. `strategy="local"` 强制走本地 subagent
|
||||
5. `strategy="a2a"` 强制只找 A2A 目标
|
||||
|
||||
## 5.3 `nanobot/agent/context.py`
|
||||
|
||||
当前 prompt 中只注入:
|
||||
|
||||
1. bootstrap
|
||||
2. memory
|
||||
3. skills summary
|
||||
|
||||
位置:
|
||||
|
||||
- `build_system_prompt()`: `nanobot/agent/context.py:38-76`
|
||||
|
||||
建议新增一段:
|
||||
|
||||
- `# Available Agents`
|
||||
|
||||
由 `AgentRegistry.build_agents_summary()` 生成,内容只放:
|
||||
|
||||
1. agent id / name
|
||||
2. 简短 description
|
||||
3. source
|
||||
4. protocol
|
||||
5. 是否支持 group / streaming
|
||||
|
||||
目标是让主 agent 知道:
|
||||
|
||||
1. 当前有哪些现成 agent 可用
|
||||
2. 什么时候应该 `spawn(target=...)`
|
||||
3. 哪些是 skill 暴露出来的 A2A agent
|
||||
|
||||
## 5.4 `nanobot/agent/skills.py`
|
||||
|
||||
这是 skill agent cards 的关键入口。
|
||||
|
||||
当前 skill frontmatter 已支持 `metadata` 字段,并会解析其中的 JSON:
|
||||
|
||||
- `_parse_nanobot_metadata()`: `nanobot/agent/skills.py:190-196`
|
||||
- `_get_skill_meta()`: `nanobot/agent/skills.py:209-212`
|
||||
|
||||
最推荐的做法不是去扫 `SKILL.md` 正文里的自由文本,而是约定 skill frontmatter 的 `metadata.nanobot.agent_cards`。
|
||||
|
||||
建议新增:
|
||||
|
||||
```python
|
||||
def list_skill_agent_cards(self) -> list[dict[str, Any]]: ...
|
||||
```
|
||||
|
||||
推荐 skill 写法:
|
||||
|
||||
```md
|
||||
---
|
||||
name: github-research
|
||||
description: GitHub research helper
|
||||
metadata: '{"nanobot":{"agent_cards":[{"id":"repo-analyst","url":"https://example.com/.well-known/agent-card","tags":["github","research"],"auth_env":"REPO_AGENT_TOKEN"}]}}'
|
||||
---
|
||||
```
|
||||
|
||||
为什么推荐这样做:
|
||||
|
||||
1. 当前 frontmatter 解析已经存在
|
||||
2. 不需要引入新的 skill 文件格式
|
||||
3. 不需要解析自由文本
|
||||
4. skill 打包/上传链路也不需要大改
|
||||
|
||||
## 5.5 `nanobot/agent/plugins.py`
|
||||
|
||||
当前 plugin agents 已能加载:
|
||||
|
||||
- `find_agent()`: `nanobot/agent/plugins.py:83-91`
|
||||
- `_load_agents()`: `nanobot/agent/plugins.py:210-229`
|
||||
|
||||
建议:
|
||||
|
||||
1. `AgentRegistry` 直接复用 `PluginLoader`
|
||||
2. plugin agent 作为“本地可执行 agent”来源之一
|
||||
|
||||
这里不建议把 plugin agent 强行转成 A2A。
|
||||
|
||||
更合理的处理是:
|
||||
|
||||
1. plugin agent 本地执行
|
||||
2. skill agent cards 远程 A2A 调用
|
||||
3. workspace 手动添加的 agent 也可走 A2A
|
||||
|
||||
## 5.6 `nanobot/config/schema.py`
|
||||
|
||||
当前 `ToolsConfig` 只有:
|
||||
|
||||
- `web`
|
||||
- `exec`
|
||||
- `restrict_to_workspace`
|
||||
- `mcp_servers`
|
||||
|
||||
位置:
|
||||
|
||||
- `nanobot/config/schema.py:337-347`
|
||||
|
||||
建议新增:
|
||||
|
||||
```python
|
||||
class A2AConfig(Base):
|
||||
enabled: bool = True
|
||||
timeout_seconds: int = 30
|
||||
poll_interval_seconds: int = 2
|
||||
card_cache_ttl_seconds: int = 300
|
||||
max_parallel_agents: int = 4
|
||||
allow_skill_cards: bool = True
|
||||
allow_workspace_agents: bool = True
|
||||
allowed_hosts: list[str] = Field(default_factory=list)
|
||||
```
|
||||
|
||||
然后挂到:
|
||||
|
||||
```python
|
||||
class ToolsConfig(Base):
|
||||
...
|
||||
a2a: A2AConfig = Field(default_factory=A2AConfig)
|
||||
```
|
||||
|
||||
## 5.7 `nanobot/web/server.py`
|
||||
|
||||
当前 web API 有:
|
||||
|
||||
- `/api/skills`
|
||||
- `/api/plugins`
|
||||
|
||||
位置:
|
||||
|
||||
- skills: `nanobot/web/server.py:702-843`
|
||||
- plugins: `nanobot/web/server.py:1000-1037`
|
||||
|
||||
建议新增:
|
||||
|
||||
1. `GET /api/agents`
|
||||
- 返回统一后的 agent registry
|
||||
2. `POST /api/agents`
|
||||
- 添加 workspace agent card
|
||||
3. `DELETE /api/agents/{id}`
|
||||
- 删除 workspace agent
|
||||
4. `POST /api/agents/refresh`
|
||||
- 刷新 card cache
|
||||
|
||||
这样“已添加的 Agent”才有明确的持久化来源。
|
||||
|
||||
## 6. 推荐的数据来源优先级
|
||||
|
||||
为了行为稳定,推荐 resolver 按以下优先级匹配:
|
||||
|
||||
1. workspace 手动添加的 agent
|
||||
2. plugin agents
|
||||
3. skill metadata 里的 agent cards
|
||||
4. fallback 到本地 subagent
|
||||
|
||||
原因:
|
||||
|
||||
1. workspace 手动添加通常是用户明确希望接入的 agent
|
||||
2. plugin agent 是本地稳定能力
|
||||
3. skill card 往往是外部资源,可信度和可用性最弱
|
||||
4. 本地 subagent 最后兜底,保证老行为不失效
|
||||
|
||||
## 7. A2A 协议接入建议
|
||||
|
||||
## 7.1 Agent Card 发现
|
||||
|
||||
建议支持 3 种入口:
|
||||
|
||||
1. 显式 `card_url`
|
||||
2. `base_url + /.well-known/agent-card`
|
||||
3. fallback `base_url + /.well-known/agent.json`
|
||||
|
||||
这样做的原因是:
|
||||
|
||||
1. 当前 A2A 文档和样例在 card 路径上存在新旧写法并存
|
||||
2. 兼容性会更好
|
||||
|
||||
## 7.2 RPC 调用兼容层
|
||||
|
||||
建议客户端优先尝试:
|
||||
|
||||
1. `tasks/send`
|
||||
2. 不支持时 fallback `message/send`
|
||||
|
||||
后续可选支持:
|
||||
|
||||
1. `tasks/sendSubscribe`
|
||||
2. `message/sendStream`
|
||||
3. `tasks/get`
|
||||
4. `tasks/cancel`
|
||||
|
||||
推荐第一期先做:
|
||||
|
||||
1. 非流式发任务
|
||||
2. 如果返回 `Task` 状态不是最终态,就轮询 `tasks/get`
|
||||
|
||||
这样能最小代价先打通。
|
||||
|
||||
## 7.3 发送给远端 agent 的上下文范围
|
||||
|
||||
不要把主会话完整 history 直接发给远端 agent。
|
||||
|
||||
建议第一版只发送:
|
||||
|
||||
1. 任务目标
|
||||
2. 必要的结构化说明
|
||||
3. 主 agent 整理好的最小上下文
|
||||
|
||||
原因:
|
||||
|
||||
1. 当前本地 subagent 也不共享主会话历史
|
||||
2. 外部 A2A agent 不可信时,最小化数据泄漏面
|
||||
3. 避免 token 膨胀
|
||||
|
||||
## 8. agent group 设计
|
||||
|
||||
## 8.1 什么时候触发 group
|
||||
|
||||
建议第一版只支持两种触发:
|
||||
|
||||
1. 用户明确指定多个 agent
|
||||
2. LLM 在工具调用里显式传 `targets=[...]`
|
||||
|
||||
不建议第一版做“自动拆成多个 agent 并行”的强自动化。
|
||||
|
||||
原因:
|
||||
|
||||
1. 容易失控
|
||||
2. 很难解释为什么调了这些 agent
|
||||
3. 对成本和网络调用不可控
|
||||
|
||||
## 8.2 group 执行链路
|
||||
|
||||
推荐链路:
|
||||
|
||||
1. `SpawnTool.execute()` 收到 `targets`
|
||||
2. `DelegationManager.dispatch()` 创建 `group_run_id`
|
||||
3. `AgentRegistry` 解析出每个 target 的 descriptor
|
||||
4. 按 executor 类型并发执行
|
||||
5. `asyncio.gather(..., return_exceptions=True)` 收集结果
|
||||
6. 统一做 group aggregation
|
||||
7. `_announce_group_result()` 回投主消息总线
|
||||
8. 主 agent 再生成最终用户回复
|
||||
|
||||
## 8.3 group 结果聚合
|
||||
|
||||
建议 group 执行器输出结构化结果:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentRunResult:
|
||||
agent_id: str
|
||||
status: str # ok | error | timeout | cancelled
|
||||
summary: str
|
||||
raw: dict[str, Any] | None = None
|
||||
```
|
||||
|
||||
group 最终回投内容建议类似:
|
||||
|
||||
```text
|
||||
[Agent group 'repo-check' completed]
|
||||
|
||||
Members:
|
||||
- researcher: ok
|
||||
- reviewer: ok
|
||||
- planner: error
|
||||
|
||||
Results:
|
||||
...
|
||||
|
||||
Summarize this naturally for the user. Mention disagreements if any.
|
||||
```
|
||||
|
||||
这样能继续复用当前 `system -> main agent -> user` 的输出模式。
|
||||
|
||||
## 9. 推荐触发方式
|
||||
|
||||
## 9.1 用户显式触发
|
||||
|
||||
用户说法示例:
|
||||
|
||||
1. “把这个任务交给 `github-reviewer`”
|
||||
2. “让 `researcher` 和 `reviewer` 一起处理”
|
||||
3. “如果有现成 agent 就不要新建 subagent”
|
||||
|
||||
这时主 agent 应调用:
|
||||
|
||||
```json
|
||||
{
|
||||
"task": "...",
|
||||
"target": "github-reviewer"
|
||||
}
|
||||
```
|
||||
|
||||
或者:
|
||||
|
||||
```json
|
||||
{
|
||||
"task": "...",
|
||||
"targets": ["researcher", "reviewer"],
|
||||
"strategy": "group"
|
||||
}
|
||||
```
|
||||
|
||||
## 9.2 模型自主触发
|
||||
|
||||
当主 agent 判断:
|
||||
|
||||
1. 任务独立可并行
|
||||
2. 已有 agent 专长明显更匹配
|
||||
3. 任务耗时长,适合后台执行
|
||||
|
||||
则调用 `spawn`,但不再默认认为一定是“新建本地 subagent”。
|
||||
|
||||
## 9.3 自动回退
|
||||
|
||||
如果没有找到匹配 agent:
|
||||
|
||||
1. `strategy=auto` -> fallback 本地 subagent
|
||||
2. `strategy=a2a` -> 直接返回未找到
|
||||
3. `strategy=group` 且部分目标不存在 -> 明确报错或只跑已解析目标,建议第一版严格报错
|
||||
|
||||
## 10. workspace 中“已添加 agent”的建议存储
|
||||
|
||||
建议新增:
|
||||
|
||||
- `workspace/agents/registry.json`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "github-reviewer",
|
||||
"name": "GitHub Reviewer",
|
||||
"description": "Review GitHub repository changes",
|
||||
"protocol": "a2a",
|
||||
"base_url": "https://reviewer.example.com/a2a",
|
||||
"card_url": "https://reviewer.example.com/.well-known/agent-card",
|
||||
"auth_env": "GITHUB_REVIEWER_TOKEN",
|
||||
"enabled": true,
|
||||
"tags": ["github", "review"]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
为什么不用直接塞进 `config.json`:
|
||||
|
||||
1. 这是 workspace 维度资源,不是全局运行参数
|
||||
2. web API 做增删改查更方便
|
||||
3. 不要求用户每次改 agent 都改配置再重启
|
||||
|
||||
## 11. 推荐实施顺序
|
||||
|
||||
### Phase 1: 打通单 agent 路由
|
||||
|
||||
目标:
|
||||
|
||||
1. 引入 `AgentRegistry`
|
||||
2. `spawn` 支持 `target`
|
||||
3. 支持 workspace agent 和 skill agent card
|
||||
4. 支持 A2A 单点调用
|
||||
5. 找不到时 fallback 本地 subagent
|
||||
|
||||
### Phase 2: 接入 plugin agent 本地执行
|
||||
|
||||
目标:
|
||||
|
||||
1. plugin agent 进入统一 registry
|
||||
2. plugin agent 可作为 `target`
|
||||
3. 本地 prompt-based agent 与 A2A remote agent 共存
|
||||
|
||||
### Phase 3: group 并发和聚合
|
||||
|
||||
目标:
|
||||
|
||||
1. `targets=[...]`
|
||||
2. 并发执行
|
||||
3. group 级状态跟踪
|
||||
4. 聚合后回投主 agent
|
||||
|
||||
### Phase 4: web 管理接口
|
||||
|
||||
目标:
|
||||
|
||||
1. `/api/agents`
|
||||
2. 添加 / 删除 / 刷新 agent
|
||||
3. 前端展示 unified registry
|
||||
|
||||
## 12. 兼容性要求
|
||||
|
||||
这次改造一定要保留以下兼容性:
|
||||
|
||||
1. 旧的 `spawn(task, label)` 调用仍然可用
|
||||
2. 没有 A2A agent 时,行为和现在一致
|
||||
3. skill 没写 `agent_cards` 时,skill 仍只是普通 skill
|
||||
4. plugin agent 不参与调度时,现有 plugin 机制不受影响
|
||||
|
||||
## 13. 风险点
|
||||
|
||||
### 13.1 A2A 规范新旧写法并存
|
||||
|
||||
从当前公开文档和样例看,存在这些并行写法:
|
||||
|
||||
1. card 路径: `/.well-known/agent-card` 和 `/.well-known/agent.json`
|
||||
2. RPC 方法: `tasks/send` 和 `message/send`
|
||||
|
||||
所以客户端必须做兼容适配,不能写死一种。
|
||||
|
||||
### 13.2 外部 agent 的安全边界
|
||||
|
||||
需要限制:
|
||||
|
||||
1. 白名单 host
|
||||
2. 超时
|
||||
3. card cache TTL
|
||||
4. skill card 是否允许自动启用
|
||||
|
||||
### 13.3 远端 agent 无法直接访问本地 workspace
|
||||
|
||||
这意味着:
|
||||
|
||||
1. 不能把“去读本地文件然后处理”原样发给远端 A2A agent
|
||||
2. 主 agent 需要先整理出必要上下文
|
||||
3. 第一版最好只做文本级委派
|
||||
|
||||
## 14. 我建议的落地结论
|
||||
|
||||
如果要控制改动面,又要保证后续可扩展,推荐最终采用下面这个结构:
|
||||
|
||||
```text
|
||||
AgentLoop
|
||||
-> SpawnTool
|
||||
-> DelegationManager
|
||||
-> AgentRegistry / AgentResolver
|
||||
-> LocalSubagentExecutor
|
||||
-> PluginAgentExecutor
|
||||
-> A2AExecutor
|
||||
-> AgentGroupExecutor
|
||||
-> announce_result() -> MessageBus(system) -> AgentLoop -> user
|
||||
```
|
||||
|
||||
也就是说:
|
||||
|
||||
1. `spawn` 工具保留
|
||||
2. `SubagentManager` 不再是唯一执行器
|
||||
3. `DelegationManager` 成为真正总入口
|
||||
4. skills 里的 `agent_cards` 用 frontmatter metadata 承载
|
||||
5. workspace agent 单独持久化
|
||||
6. group 通过并发 executor + 汇总消息实现
|
||||
|
||||
这是当前仓库里最稳妥、最符合现有架构的改法。
|
||||
|
||||
## 15. 外部参考
|
||||
|
||||
以下是我写这个方案时核对的 A2A 资料:
|
||||
|
||||
1. A2A Protocol Development Guide: https://a2aprotocol.ai/docs/guide/a2a-typescript-guide.html
|
||||
2. Python A2A Tutorial: https://a2aprotocol.ai/docs/guide/python-a2a-tutorial-20250513
|
||||
|
||||
注意:
|
||||
|
||||
1. 当前公开文档里既能看到 `tasks/send`,也能看到 `message/send`
|
||||
2. agent card 路径也能看到 `agent-card` 与 `agent.json` 两种写法
|
||||
3. 所以实现时建议做兼容层,不要只押一种命名
|
||||
@ -1,5 +0,0 @@
|
||||
We provide QR codes for joining the HKUDS discussion groups on **WeChat** and **Feishu**.
|
||||
|
||||
You can join by scanning the QR codes below:
|
||||
|
||||
<img src="https://github.com/HKUDS/.github/blob/main/profile/QR.png" alt="WeChat QR Code" width="400"/>
|
||||
@ -1,43 +0,0 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
# Install Node.js 20 for the WhatsApp bridge
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
apt-get purge -y gnupg && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies first (cached layer)
|
||||
COPY pyproject.toml README.md LICENSE ./
|
||||
RUN mkdir -p nanobot bridge && touch nanobot/__init__.py && \
|
||||
uv pip install --system --no-cache . && \
|
||||
rm -rf nanobot bridge
|
||||
|
||||
# Copy the full source and install
|
||||
COPY nanobot/ nanobot/
|
||||
COPY bridge/ bridge/
|
||||
COPY third_party/swarms/ third_party/swarms/
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Build the WhatsApp bridge
|
||||
WORKDIR /app/bridge
|
||||
RUN git config --global url."https://github.com/".insteadOf "ssh://git@github.com/" && \
|
||||
git config --global url."https://github.com/".insteadOf "git@github.com:" && \
|
||||
npm install && npm run build
|
||||
WORKDIR /app
|
||||
|
||||
# Create config directory
|
||||
RUN mkdir -p /root/.nanobot
|
||||
|
||||
# Gateway default port
|
||||
EXPOSE 18790
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@ -1,470 +0,0 @@
|
||||
# Boardware Genius Backend
|
||||
|
||||
这是 `Boardware Genius` 的后端服务仓库;当前技术命令和包名仍沿用 `nanobot`,但产品品牌按 `Boardware Genius` 表述:
|
||||
|
||||
- `nanobot web`:单用户 FastAPI 后端,供独立前端或 `/docs` 调试使用
|
||||
- `nanobot gateway`:常驻 worker,负责渠道接入、cron、heartbeat
|
||||
- MCP 动态工具接入
|
||||
- Outlook 集成:通过外部 `BW_Outlook_Mcp` 服务接入 Microsoft Graph / Exchange EWS
|
||||
- 工作区文件、技能、插件、代理、MCP 管理等 Web API
|
||||
|
||||
如果你后续要把它打包成 Docker 丢到服务器,这份 README 就是给开发和部署同事看的执行文档。
|
||||
|
||||
## 这套仓库现在是什么
|
||||
|
||||
这不是一个自带前端静态页面的全栈仓库,而是后端仓库:
|
||||
|
||||
- Web 模式启动的是 FastAPI API 服务
|
||||
- Gateway 模式启动的是常驻 agent / channel / cron 进程
|
||||
- WhatsApp 相关逻辑依赖 `bridge/` 里的 Node 20 bridge
|
||||
- Outlook 不是仓库内置模块,而是通过外部 `BW_Outlook_Mcp` 仓库接进来
|
||||
|
||||
更细的执行链路可以看 [workflow.md](./workflow.md)。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```text
|
||||
.
|
||||
├── nanobot/ # Python 主体:CLI、agent、web、channels、config、MCP
|
||||
├── bridge/ # WhatsApp bridge(Node 20)
|
||||
├── tests/ # 测试
|
||||
├── Dockerfile # 当前镜像构建文件
|
||||
├── docker-compose.yml # 当前自带 compose 示例(偏 gateway / CLI)
|
||||
└── workflow.md # 运行链路说明
|
||||
```
|
||||
|
||||
## 运行模式
|
||||
|
||||
| 命令 | 用途 | 默认端口 | 适合谁 |
|
||||
| --- | --- | --- | --- |
|
||||
| `nanobot agent` | 本地单轮 / 交互调试 | 无 | 开发排查 |
|
||||
| `nanobot web` | 启动 FastAPI 后端 | `18080` | 独立前端、接口调试、单用户使用 |
|
||||
| `nanobot gateway` | 启动常驻 worker | 无固定 HTTP 入口 | Telegram/Slack/Email/cron/heartbeat |
|
||||
| `nanobot status` | 查看配置和 provider 状态 | 无 | 开发、运维 |
|
||||
|
||||
注意:
|
||||
|
||||
- 如果你是给 Web 前端提供后端,请启动 `nanobot web`,不要误用 `gateway`
|
||||
- `gateway` 当前不是对外 Web API 服务
|
||||
- `web` 和 `gateway` 都会碰到同一份 workspace / cron / MCP 状态,通常不要在同一份数据目录上无脑同时跑两套
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python `>=3.11`
|
||||
- 推荐使用 `uv`
|
||||
- 如果要构建 WhatsApp bridge 或使用仓库自带 Dockerfile,需要 Node.js `20`
|
||||
|
||||
本地开发最省事的方式:
|
||||
|
||||
```bash
|
||||
uv sync --extra dev
|
||||
```
|
||||
|
||||
如果你不用 `uv`,也可以:
|
||||
|
||||
```bash
|
||||
python3 -m venv .venv
|
||||
. .venv/bin/activate
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## 本地快速启动
|
||||
|
||||
### 1. 初始化配置
|
||||
|
||||
```bash
|
||||
nanobot onboard
|
||||
```
|
||||
|
||||
初始化后默认会生成:
|
||||
|
||||
- 配置文件:`~/.nanobot/config.json`
|
||||
- 工作区:`~/.nanobot/workspace`
|
||||
|
||||
### 2. 填最小配置
|
||||
|
||||
下面是一份适合服务器环境的最小示例,重点是:
|
||||
|
||||
- 用绝对路径的 workspace
|
||||
- 建议打开 `restrictToWorkspace`
|
||||
- 先用 API Key provider,少踩 OAuth 交互坑
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "/root/.nanobot/workspace",
|
||||
"model": "openai/gpt-5"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openai": {
|
||||
"apiKey": "sk-xxxx"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"restrictToWorkspace": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
如果你不是跑在容器里,把 `/root/.nanobot/workspace` 换成你自己的绝对路径。
|
||||
|
||||
### 3. 检查配置
|
||||
|
||||
```bash
|
||||
nanobot status
|
||||
```
|
||||
|
||||
### 4. 本地调试 agent
|
||||
|
||||
```bash
|
||||
nanobot agent -m "你好"
|
||||
```
|
||||
|
||||
### 5. 启动 Web 后端
|
||||
|
||||
```bash
|
||||
nanobot web --host 0.0.0.0 --port 18080
|
||||
```
|
||||
|
||||
启动后可直接访问:
|
||||
|
||||
- `http://127.0.0.1:18080/docs`
|
||||
- `http://127.0.0.1:18080/api/ping`
|
||||
|
||||
## Web API 能力概览
|
||||
|
||||
当前 `nanobot web` 提供的 API 大致包括:
|
||||
|
||||
- 聊天与流式输出
|
||||
- 会话管理
|
||||
- cron 任务管理
|
||||
- skills / plugins / agents 管理
|
||||
- 工作区文件浏览、上传、下载、删除
|
||||
- MCP server 管理与测试
|
||||
- Outlook 集成状态、连接测试、连接/断开、Overview、Message Detail
|
||||
|
||||
如果你有独立前端,这个后端就是给前端接的;如果没有前端,也可以直接走 `/docs` 调试。
|
||||
|
||||
## Outlook MCP 集成
|
||||
|
||||
这是当前仓库里最容易部署时踩坑的一块。
|
||||
|
||||
### 关系先说清楚
|
||||
|
||||
当前后端不会自己实现 Outlook 协议,它依赖外部仓库 `BW_Outlook_Mcp`:
|
||||
|
||||
- 后端代码位置:`nanobot/web/outlook.py`
|
||||
- 默认查找逻辑:
|
||||
1. 先看环境变量 `NANOBOT_OUTLOOK_MCP_ROOT`
|
||||
2. 再看与本仓库同级目录的 `../BW_Outlook_Mcp`
|
||||
3. 如果以上都没有,就尝试直接执行 PATH 里的 `bw-outlook-mcp`
|
||||
|
||||
也就是说,部署同事必须额外把 `BW_Outlook_Mcp` 这个仓库准备好,或者把它直接安装进镜像。
|
||||
|
||||
### 推荐的两种接法
|
||||
|
||||
#### 方案 A:把 `BW_Outlook_Mcp` 安装进同一个 Python 环境
|
||||
|
||||
这是生产环境更稳的方案。
|
||||
|
||||
部署同事需要:
|
||||
|
||||
```bash
|
||||
git clone <你们的 BW_Outlook_Mcp 仓库地址> /srv/BW_Outlook_Mcp
|
||||
cd /srv/BW_Outlook_Mcp
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
安装完成后,容器或宿主机里能直接执行:
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp --help
|
||||
```
|
||||
|
||||
这样 Boardware Genius 就会直接用 PATH 里的 `bw-outlook-mcp`,不依赖额外挂载路径。
|
||||
|
||||
#### 方案 B:把 `BW_Outlook_Mcp` 作为外部目录挂进来
|
||||
|
||||
这是开发或临时部署更方便的方案。
|
||||
|
||||
部署同事需要至少做到两件事:
|
||||
|
||||
1. 把 `BW_Outlook_Mcp` 仓库拉到服务器
|
||||
2. 让这个目录里存在一个可执行的 `bw-outlook-mcp`
|
||||
|
||||
最简单的约定是:
|
||||
|
||||
```bash
|
||||
git clone <你们的 BW_Outlook_Mcp 仓库地址> /srv/BW_Outlook_Mcp
|
||||
cd /srv/BW_Outlook_Mcp
|
||||
python3 -m venv .venv
|
||||
. .venv/bin/activate
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
然后给 Boardware Genius 设置:
|
||||
|
||||
```bash
|
||||
export NANOBOT_OUTLOOK_MCP_ROOT=/srv/BW_Outlook_Mcp
|
||||
```
|
||||
|
||||
因为当前后端会优先寻找:
|
||||
|
||||
```text
|
||||
$NANOBOT_OUTLOOK_MCP_ROOT/.venv/bin/bw-outlook-mcp
|
||||
```
|
||||
|
||||
如果你挂了仓库目录但里面没有 `.venv/bin/bw-outlook-mcp`,那就必须确保 `bw-outlook-mcp` 已经在容器 PATH 里。
|
||||
|
||||
### Outlook 的认证和配置
|
||||
|
||||
`BW_Outlook_Mcp` 本身支持两套后端:
|
||||
|
||||
- `graph`:Microsoft 365 / Exchange Online
|
||||
- `ews`:本地或回迁后的 Exchange Server
|
||||
|
||||
#### Graph 登录
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth login-graph \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--client-id YOUR_CLIENT_ID \
|
||||
--tenant-id YOUR_TENANT_ID
|
||||
```
|
||||
|
||||
#### EWS 配置
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth setup-ews \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--email you@example.com \
|
||||
--username your_username \
|
||||
--domain example.com \
|
||||
--server mail.example.com
|
||||
```
|
||||
|
||||
如果你已经有固定 EWS URL,也可以改用:
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth setup-ews \
|
||||
--workspace /root/.nanobot/workspace \
|
||||
--email you@example.com \
|
||||
--username your_username \
|
||||
--service-endpoint https://mail.example.com/EWS/Exchange.asmx
|
||||
```
|
||||
|
||||
#### 查看状态
|
||||
|
||||
```bash
|
||||
bw-outlook-mcp auth status --workspace /root/.nanobot/workspace
|
||||
```
|
||||
|
||||
### Outlook 状态文件会落在哪里
|
||||
|
||||
所有 Outlook 相关状态默认都落在 workspace 下:
|
||||
|
||||
```text
|
||||
<workspace>/state/bw_outlook_mcp/
|
||||
├── config.json
|
||||
├── secrets.json
|
||||
├── graph_token_cache.bin
|
||||
├── delta_store.json
|
||||
└── idempotency.sqlite3
|
||||
```
|
||||
|
||||
所以 Docker 部署时,不要只挂配置文件;要把整份 `~/.nanobot` 或至少 workspace 做持久化。
|
||||
|
||||
### Nanobot 里如何注册 Outlook MCP
|
||||
|
||||
如果你通过 Web 接口完成 Outlook 连接,后端会自动把 MCP server 注册到配置里。
|
||||
|
||||
手工写配置时,结构类似这样:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcpServers": {
|
||||
"outlook": {
|
||||
"command": "bw-outlook-mcp",
|
||||
"args": ["serve", "--workspace", "/root/.nanobot/workspace"],
|
||||
"sensitive": true,
|
||||
"toolTimeout": 60
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
这里一定要用绝对路径,不要写 `~/.nanobot/workspace`。
|
||||
|
||||
### 可选的 Outlook 环境变量
|
||||
|
||||
| 变量 | 作用 |
|
||||
| --- | --- |
|
||||
| `NANOBOT_OUTLOOK_MCP_ROOT` | 指向外部 `BW_Outlook_Mcp` 仓库目录 |
|
||||
| `NANOBOT_OUTLOOK_MCP_COMMAND` | 强制指定 `bw-outlook-mcp` 可执行文件 |
|
||||
| `NANOBOT_OUTLOOK_MCP_EXTRA_ARGS` | 给 `bw-outlook-mcp serve` 追加参数 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_DOMAIN` | Web 连接表单的默认域名 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_EWS_URL` | Web 连接表单默认 EWS 地址 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_EWS_SERVER` | Web 连接表单默认 Exchange 主机 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_TIMEZONE` | Web 连接表单默认时区 |
|
||||
| `NANOBOT_OUTLOOK_DEFAULT_AUTODISCOVER` | Web 连接表单默认是否启用 autodiscover |
|
||||
|
||||
## Docker 部署
|
||||
|
||||
### 先说结论
|
||||
|
||||
服务器部署时,最重要的是持久化这份目录:
|
||||
|
||||
```text
|
||||
/root/.nanobot
|
||||
```
|
||||
|
||||
因为它里面不只是 `config.json`,还包括:
|
||||
|
||||
- workspace
|
||||
- sessions
|
||||
- cron 状态
|
||||
- Web 登录信息
|
||||
- Outlook 状态与 token 缓存
|
||||
|
||||
### 构建镜像
|
||||
|
||||
```bash
|
||||
docker build -t nanobot-backend:latest .
|
||||
```
|
||||
|
||||
### 首次初始化
|
||||
|
||||
第一次跑容器时,先执行一次:
|
||||
|
||||
```bash
|
||||
docker run --rm \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
onboard
|
||||
```
|
||||
|
||||
然后去编辑宿主机上的:
|
||||
|
||||
```text
|
||||
/srv/nanobot/data/config.json
|
||||
```
|
||||
|
||||
或者先进去执行:
|
||||
|
||||
```bash
|
||||
docker run --rm -it \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
status
|
||||
```
|
||||
|
||||
### 作为 Web 后端启动
|
||||
|
||||
如果你是给前端项目配后端,推荐这样跑:
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
--name nanobot-web \
|
||||
-p 18080:18080 \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
-e NANOBOT_OUTLOOK_MCP_ROOT=/opt/BW_Outlook_Mcp \
|
||||
-v /srv/BW_Outlook_Mcp:/opt/BW_Outlook_Mcp \
|
||||
nanobot-backend:latest \
|
||||
web --host 0.0.0.0 --port 18080
|
||||
```
|
||||
|
||||
如果你已经把 `bw-outlook-mcp` 安装进镜像了,就不需要挂 `/srv/BW_Outlook_Mcp`,也不需要 `NANOBOT_OUTLOOK_MCP_ROOT`。
|
||||
|
||||
### 作为 Gateway/Worker 启动
|
||||
|
||||
如果你要接 Telegram / Slack / Email / cron 之类的常驻能力,再跑 gateway:
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
--name nanobot-gateway \
|
||||
-v /srv/nanobot/data:/root/.nanobot \
|
||||
nanobot-backend:latest \
|
||||
gateway
|
||||
```
|
||||
|
||||
### 推荐的服务器 compose 片段
|
||||
|
||||
仓库自带的 [docker-compose.yml](./docker-compose.yml) 更偏本地 gateway/CLI 示例。
|
||||
如果你是部署 Web 后端到服务器,更建议单独写成这样:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
nanobot-web:
|
||||
image: nanobot-backend:latest
|
||||
container_name: nanobot-web
|
||||
command: ["web", "--host", "0.0.0.0", "--port", "18080"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18080:18080"
|
||||
volumes:
|
||||
- /srv/nanobot/data:/root/.nanobot
|
||||
- /srv/BW_Outlook_Mcp:/opt/BW_Outlook_Mcp
|
||||
environment:
|
||||
NANOBOT_OUTLOOK_MCP_ROOT: /opt/BW_Outlook_Mcp
|
||||
```
|
||||
|
||||
如果你想把 Outlook 依赖做得更稳,推荐直接把 `BW_Outlook_Mcp` 安装进镜像,而不是运行时挂载仓库。
|
||||
|
||||
## 部署给同事时,至少要交代这几件事
|
||||
|
||||
1. 这是后端仓库,不带前端静态页面,前端请单独部署
|
||||
2. Web API 用 `nanobot web` 启动,不是 `gateway`
|
||||
3. 数据目录必须持久化到 `/root/.nanobot`
|
||||
4. 如果要 Outlook,必须额外拉取 `BW_Outlook_Mcp`
|
||||
5. Outlook 有两种接法:装进镜像,或者挂外部仓库并设置 `NANOBOT_OUTLOOK_MCP_ROOT`
|
||||
6. Outlook 的状态文件也在 workspace 里,删容器不挂卷就会丢
|
||||
|
||||
## 常用命令
|
||||
|
||||
```bash
|
||||
nanobot onboard
|
||||
nanobot status
|
||||
nanobot agent -m "你好"
|
||||
nanobot web --host 0.0.0.0 --port 18080
|
||||
nanobot gateway
|
||||
nanobot provider login openai-codex
|
||||
```
|
||||
|
||||
## 开发备注
|
||||
|
||||
- `workflow.md` 记录了当前代码实际运行链路,和旧版 README 更接近“真实代码”
|
||||
- `nanobot/web/outlook.py` 是当前 Outlook 集成入口
|
||||
- `tests/` 里有 Web API、Email、Docker 相关测试
|
||||
- 如果要上服务器,建议在配置里显式打开 `tools.restrictToWorkspace=true`
|
||||
|
||||
## 排错
|
||||
|
||||
### Web 启动了,但 Outlook 相关接口报错
|
||||
|
||||
优先检查:
|
||||
|
||||
- `bw-outlook-mcp` 是否能在当前容器里执行
|
||||
- `NANOBOT_OUTLOOK_MCP_ROOT` 是否指向正确目录
|
||||
- 如果走目录挂载模式,目录里是否真的有 `.venv/bin/bw-outlook-mcp`
|
||||
|
||||
### MCP 注册了,但工具没有出现
|
||||
|
||||
检查:
|
||||
|
||||
- `config.json` 里的 `tools.mcpServers`
|
||||
- `nanobot web` 或 `nanobot agent` 启动时是否用了同一份 `~/.nanobot`
|
||||
- Outlook MCP 是否能单独执行 `bw-outlook-mcp auth status --workspace ...`
|
||||
|
||||
### Docker 里配置改了没生效
|
||||
|
||||
优先检查你挂载的是不是整份:
|
||||
|
||||
```text
|
||||
/srv/nanobot/data:/root/.nanobot
|
||||
```
|
||||
|
||||
不是只挂了某一个文件。
|
||||
@ -1,264 +0,0 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in Boardware Genius, please report it by:
|
||||
|
||||
1. **DO NOT** open a public GitHub issue
|
||||
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
|
||||
3. Include:
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Potential impact
|
||||
- Suggested fix (if any)
|
||||
|
||||
We aim to respond to security reports within 48 hours.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. API Key Management
|
||||
|
||||
**CRITICAL**: Never commit API keys to version control.
|
||||
|
||||
```bash
|
||||
# ✅ Good: Store in config file with restricted permissions
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
|
||||
# ❌ Bad: Hardcoding keys in code or committing them
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- Store API keys in `~/.nanobot/config.json` with file permissions set to `0600`
|
||||
- Consider using environment variables for sensitive keys
|
||||
- Use OS keyring/credential manager for production deployments
|
||||
- Rotate API keys regularly
|
||||
- Use separate API keys for development and production
|
||||
|
||||
### 2. Channel Access Control
|
||||
|
||||
**IMPORTANT**: Always configure `allowFrom` lists for production use.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["123456789", "987654321"]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"allowFrom": ["+1234567890"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- Empty `allowFrom` list will **ALLOW ALL** users (open by default for personal use)
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
|
||||
### 3. Shell Command Execution
|
||||
|
||||
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
|
||||
|
||||
- ✅ Review all tool usage in agent logs
|
||||
- ✅ Understand what commands the agent is running
|
||||
- ✅ Use a dedicated user account with limited privileges
|
||||
- ✅ Never run Boardware Genius as root
|
||||
- ❌ Don't disable security checks
|
||||
- ❌ Don't run on systems with sensitive data without careful review
|
||||
|
||||
**Blocked patterns:**
|
||||
- `rm -rf /` - Root filesystem deletion
|
||||
- Fork bombs
|
||||
- Filesystem formatting (`mkfs.*`)
|
||||
- Raw disk writes
|
||||
- Other destructive operations
|
||||
|
||||
### 4. File System Access
|
||||
|
||||
File operations have path traversal protection, but:
|
||||
|
||||
- ✅ Run Boardware Genius with a dedicated user account
|
||||
- ✅ Use filesystem permissions to protect sensitive directories
|
||||
- ✅ Regularly audit file operations in logs
|
||||
- ❌ Don't give unrestricted access to sensitive files
|
||||
|
||||
### 5. Network Security
|
||||
|
||||
**API Calls:**
|
||||
- All external API calls use HTTPS by default
|
||||
- Timeouts are configured to prevent hanging requests
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
|
||||
**Critical**: Keep dependencies updated!
|
||||
|
||||
```bash
|
||||
# Check for vulnerable dependencies
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# Update to latest secure versions
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
For Node.js dependencies (WhatsApp bridge):
|
||||
```bash
|
||||
cd bridge
|
||||
npm audit
|
||||
npm audit fix
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Keep `litellm` updated to the latest version for security fixes
|
||||
- We've updated `ws` to `>=8.17.1` to fix DoS vulnerability
|
||||
- Run `pip-audit` or `npm audit` regularly
|
||||
- Subscribe to security advisories for Boardware Genius and its dependencies
|
||||
|
||||
### 7. Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Isolate the Environment**
|
||||
```bash
|
||||
# Run in a container or VM
|
||||
docker run --rm -it python:3.11
|
||||
pip install nanobot-ai
|
||||
```
|
||||
|
||||
2. **Use a Dedicated User**
|
||||
```bash
|
||||
sudo useradd -m -s /bin/bash nanobot
|
||||
sudo -u nanobot nanobot gateway
|
||||
```
|
||||
|
||||
3. **Set Proper Permissions**
|
||||
```bash
|
||||
chmod 700 ~/.nanobot
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
chmod 700 ~/.nanobot/whatsapp-auth
|
||||
```
|
||||
|
||||
4. **Enable Logging**
|
||||
```bash
|
||||
# Configure log monitoring
|
||||
tail -f ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
|
||||
5. **Use Rate Limiting**
|
||||
- Configure rate limits on your API providers
|
||||
- Monitor usage for anomalies
|
||||
- Set spending limits on LLM APIs
|
||||
|
||||
6. **Regular Updates**
|
||||
```bash
|
||||
# Check for updates weekly
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
### 8. Development vs Production
|
||||
|
||||
**Development:**
|
||||
- Use separate API keys
|
||||
- Test with non-sensitive data
|
||||
- Enable verbose logging
|
||||
- Use a test Telegram bot
|
||||
|
||||
**Production:**
|
||||
- Use dedicated API keys with spending limits
|
||||
- Restrict file system access
|
||||
- Enable audit logging
|
||||
- Regular security reviews
|
||||
- Monitor for unusual activity
|
||||
|
||||
### 9. Data Privacy
|
||||
|
||||
- **Logs may contain sensitive information** - secure log files appropriately
|
||||
- **LLM providers see your prompts** - review their privacy policies
|
||||
- **Chat history is stored locally** - protect the `~/.nanobot` directory
|
||||
- **API keys are in plain text** - use OS keyring for production
|
||||
|
||||
### 10. Incident Response
|
||||
|
||||
If you suspect a security breach:
|
||||
|
||||
1. **Immediately revoke compromised API keys**
|
||||
2. **Review logs for unauthorized access**
|
||||
```bash
|
||||
grep "Access denied" ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
3. **Check for unexpected file modifications**
|
||||
4. **Rotate all credentials**
|
||||
5. **Update to latest version**
|
||||
6. **Report the incident** to maintainers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Built-in Security Controls
|
||||
|
||||
✅ **Input Validation**
|
||||
- Path traversal protection on file operations
|
||||
- Dangerous command pattern detection
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control
|
||||
- Failed authentication attempt logging
|
||||
- Open by default (configure allowFrom for production use)
|
||||
|
||||
✅ **Resource Protection**
|
||||
- Command execution timeouts (60s default)
|
||||
- Output truncation (10KB limit)
|
||||
- HTTP request timeouts (10-30s)
|
||||
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
⚠️ **Current Security Limitations:**
|
||||
|
||||
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
|
||||
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
|
||||
3. **No Session Management** - No automatic session expiry
|
||||
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
|
||||
5. **No Audit Trail** - Limited security event logging (enhance as needed)
|
||||
|
||||
## Security Checklist
|
||||
|
||||
Before deploying Boardware Genius:
|
||||
|
||||
- [ ] API keys stored securely (not in code)
|
||||
- [ ] Config file permissions set to 0600
|
||||
- [ ] `allowFrom` lists configured for all channels
|
||||
- [ ] Running as non-root user
|
||||
- [ ] File system permissions properly restricted
|
||||
- [ ] Dependencies updated to latest secure versions
|
||||
- [ ] Logs monitored for security events
|
||||
- [ ] Rate limits configured on API providers
|
||||
- [ ] Backup and disaster recovery plan in place
|
||||
- [ ] Security review of custom skills/tools
|
||||
|
||||
## Updates
|
||||
|
||||
**Last Updated**: 2026-02-03
|
||||
|
||||
For the latest security updates and announcements, check:
|
||||
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
|
||||
- Release Notes: https://github.com/HKUDS/nanobot/releases
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
@ -1,26 +0,0 @@
|
||||
{
|
||||
"name": "nanobot-whatsapp-bridge",
|
||||
"version": "0.1.0",
|
||||
"description": "WhatsApp bridge for Boardware Genius using Baileys",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "tsc && node dist/index.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"ws": "^8.17.1",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.14.0",
|
||||
"@types/ws": "^8.5.10",
|
||||
"typescript": "^5.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
}
|
||||
@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* Boardware Genius WhatsApp Bridge
|
||||
*
|
||||
* This bridge connects WhatsApp Web to the Boardware Genius Python backend
|
||||
* via WebSocket. It handles authentication, message forwarding,
|
||||
* and reconnection logic.
|
||||
*
|
||||
* Usage:
|
||||
* npm run build && npm start
|
||||
*
|
||||
* Or with custom settings:
|
||||
* BRIDGE_PORT=3001 AUTH_DIR=~/.nanobot/whatsapp npm start
|
||||
*/
|
||||
|
||||
// Polyfill crypto for Baileys in ESM
|
||||
import { webcrypto } from 'crypto';
|
||||
if (!globalThis.crypto) {
|
||||
(globalThis as any).crypto = webcrypto;
|
||||
}
|
||||
|
||||
import { BridgeServer } from './server.js';
|
||||
import { homedir } from 'os';
|
||||
import { join } from 'path';
|
||||
|
||||
const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
|
||||
const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
|
||||
const TOKEN = process.env.BRIDGE_TOKEN || undefined;
|
||||
|
||||
console.log('Boardware Genius WhatsApp Bridge');
|
||||
console.log('========================\n');
|
||||
|
||||
const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);
|
||||
|
||||
// Handle graceful shutdown
|
||||
process.on('SIGINT', async () => {
|
||||
console.log('\n\nShutting down...');
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
process.on('SIGTERM', async () => {
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
// Start the server
|
||||
server.start().catch((error) => {
|
||||
console.error('Failed to start bridge:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
@ -1,129 +0,0 @@
|
||||
/**
|
||||
* WebSocket server for Python-Node.js bridge communication.
|
||||
* Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
|
||||
*/
|
||||
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
import { WhatsAppClient, InboundMessage } from './whatsapp.js';
|
||||
|
||||
interface SendCommand {
|
||||
type: 'send';
|
||||
to: string;
|
||||
text: string;
|
||||
}
|
||||
|
||||
interface BridgeMessage {
|
||||
type: 'message' | 'status' | 'qr' | 'error';
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export class BridgeServer {
|
||||
private wss: WebSocketServer | null = null;
|
||||
private wa: WhatsAppClient | null = null;
|
||||
private clients: Set<WebSocket> = new Set();
|
||||
|
||||
constructor(private port: number, private authDir: string, private token?: string) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
// Bind to localhost only — never expose to external network
|
||||
this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
|
||||
console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
|
||||
if (this.token) console.log('🔒 Token authentication enabled');
|
||||
|
||||
// Initialize WhatsApp client
|
||||
this.wa = new WhatsAppClient({
|
||||
authDir: this.authDir,
|
||||
onMessage: (msg) => this.broadcast({ type: 'message', ...msg }),
|
||||
onQR: (qr) => this.broadcast({ type: 'qr', qr }),
|
||||
onStatus: (status) => this.broadcast({ type: 'status', status }),
|
||||
});
|
||||
|
||||
// Handle WebSocket connections
|
||||
this.wss.on('connection', (ws) => {
|
||||
if (this.token) {
|
||||
// Require auth handshake as first message
|
||||
const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
|
||||
ws.once('message', (data) => {
|
||||
clearTimeout(timeout);
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === 'auth' && msg.token === this.token) {
|
||||
console.log('🔗 Python client authenticated');
|
||||
this.setupClient(ws);
|
||||
} else {
|
||||
ws.close(4003, 'Invalid token');
|
||||
}
|
||||
} catch {
|
||||
ws.close(4003, 'Invalid auth message');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.log('🔗 Python client connected');
|
||||
this.setupClient(ws);
|
||||
}
|
||||
});
|
||||
|
||||
// Connect to WhatsApp
|
||||
await this.wa.connect();
|
||||
}
|
||||
|
||||
private setupClient(ws: WebSocket): void {
|
||||
this.clients.add(ws);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
try {
|
||||
const cmd = JSON.parse(data.toString()) as SendCommand;
|
||||
await this.handleCommand(cmd);
|
||||
ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
|
||||
} catch (error) {
|
||||
console.error('Error handling command:', error);
|
||||
ws.send(JSON.stringify({ type: 'error', error: String(error) }));
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
console.log('🔌 Python client disconnected');
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
}
|
||||
|
||||
private async handleCommand(cmd: SendCommand): Promise<void> {
|
||||
if (cmd.type === 'send' && this.wa) {
|
||||
await this.wa.sendMessage(cmd.to, cmd.text);
|
||||
}
|
||||
}
|
||||
|
||||
private broadcast(msg: BridgeMessage): void {
|
||||
const data = JSON.stringify(msg);
|
||||
for (const client of this.clients) {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
// Close all client connections
|
||||
for (const client of this.clients) {
|
||||
client.close();
|
||||
}
|
||||
this.clients.clear();
|
||||
|
||||
// Close WebSocket server
|
||||
if (this.wss) {
|
||||
this.wss.close();
|
||||
this.wss = null;
|
||||
}
|
||||
|
||||
// Disconnect WhatsApp
|
||||
if (this.wa) {
|
||||
await this.wa.disconnect();
|
||||
this.wa = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,3 +0,0 @@
|
||||
declare module 'qrcode-terminal' {
|
||||
export function generate(text: string, options?: { small?: boolean }): void;
|
||||
}
|
||||
@ -1,187 +0,0 @@
|
||||
/**
|
||||
* WhatsApp client wrapper using Baileys.
|
||||
* Based on OpenClaw's working implementation.
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import makeWASocket, {
|
||||
DisconnectReason,
|
||||
useMultiFileAuthState,
|
||||
fetchLatestBaileysVersion,
|
||||
makeCacheableSignalKeyStore,
|
||||
} from '@whiskeysockets/baileys';
|
||||
|
||||
import { Boom } from '@hapi/boom';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
import pino from 'pino';
|
||||
|
||||
const VERSION = '0.1.0';
|
||||
|
||||
export interface InboundMessage {
|
||||
id: string;
|
||||
sender: string;
|
||||
pn: string;
|
||||
content: string;
|
||||
timestamp: number;
|
||||
isGroup: boolean;
|
||||
}
|
||||
|
||||
export interface WhatsAppClientOptions {
|
||||
authDir: string;
|
||||
onMessage: (msg: InboundMessage) => void;
|
||||
onQR: (qr: string) => void;
|
||||
onStatus: (status: string) => void;
|
||||
}
|
||||
|
||||
export class WhatsAppClient {
|
||||
private sock: any = null;
|
||||
private options: WhatsAppClientOptions;
|
||||
private reconnecting = false;
|
||||
|
||||
constructor(options: WhatsAppClientOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
async connect(): Promise<void> {
|
||||
const logger = pino({ level: 'silent' });
|
||||
const { state, saveCreds } = await useMultiFileAuthState(this.options.authDir);
|
||||
const { version } = await fetchLatestBaileysVersion();
|
||||
|
||||
console.log(`Using Baileys version: ${version.join('.')}`);
|
||||
|
||||
// Create socket following OpenClaw's pattern
|
||||
this.sock = makeWASocket({
|
||||
auth: {
|
||||
creds: state.creds,
|
||||
keys: makeCacheableSignalKeyStore(state.keys, logger),
|
||||
},
|
||||
version,
|
||||
logger,
|
||||
printQRInTerminal: false,
|
||||
browser: ['nanobot', 'cli', VERSION],
|
||||
syncFullHistory: false,
|
||||
markOnlineOnConnect: false,
|
||||
});
|
||||
|
||||
// Handle WebSocket errors
|
||||
if (this.sock.ws && typeof this.sock.ws.on === 'function') {
|
||||
this.sock.ws.on('error', (err: Error) => {
|
||||
console.error('WebSocket error:', err.message);
|
||||
});
|
||||
}
|
||||
|
||||
// Handle connection updates
|
||||
this.sock.ev.on('connection.update', async (update: any) => {
|
||||
const { connection, lastDisconnect, qr } = update;
|
||||
|
||||
if (qr) {
|
||||
// Display QR code in terminal
|
||||
console.log('\n📱 Scan this QR code with WhatsApp (Linked Devices):\n');
|
||||
qrcode.generate(qr, { small: true });
|
||||
this.options.onQR(qr);
|
||||
}
|
||||
|
||||
if (connection === 'close') {
|
||||
const statusCode = (lastDisconnect?.error as Boom)?.output?.statusCode;
|
||||
const shouldReconnect = statusCode !== DisconnectReason.loggedOut;
|
||||
|
||||
console.log(`Connection closed. Status: ${statusCode}, Will reconnect: ${shouldReconnect}`);
|
||||
this.options.onStatus('disconnected');
|
||||
|
||||
if (shouldReconnect && !this.reconnecting) {
|
||||
this.reconnecting = true;
|
||||
console.log('Reconnecting in 5 seconds...');
|
||||
setTimeout(() => {
|
||||
this.reconnecting = false;
|
||||
this.connect();
|
||||
}, 5000);
|
||||
}
|
||||
} else if (connection === 'open') {
|
||||
console.log('✅ Connected to WhatsApp');
|
||||
this.options.onStatus('connected');
|
||||
}
|
||||
});
|
||||
|
||||
// Save credentials on update
|
||||
this.sock.ev.on('creds.update', saveCreds);
|
||||
|
||||
// Handle incoming messages
|
||||
this.sock.ev.on('messages.upsert', async ({ messages, type }: { messages: any[]; type: string }) => {
|
||||
if (type !== 'notify') return;
|
||||
|
||||
for (const msg of messages) {
|
||||
// Skip own messages
|
||||
if (msg.key.fromMe) continue;
|
||||
|
||||
// Skip status updates
|
||||
if (msg.key.remoteJid === 'status@broadcast') continue;
|
||||
|
||||
const content = this.extractMessageContent(msg);
|
||||
if (!content) continue;
|
||||
|
||||
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
||||
|
||||
this.options.onMessage({
|
||||
id: msg.key.id || '',
|
||||
sender: msg.key.remoteJid || '',
|
||||
pn: msg.key.remoteJidAlt || '',
|
||||
content,
|
||||
timestamp: msg.messageTimestamp as number,
|
||||
isGroup,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private extractMessageContent(msg: any): string | null {
|
||||
const message = msg.message;
|
||||
if (!message) return null;
|
||||
|
||||
// Text message
|
||||
if (message.conversation) {
|
||||
return message.conversation;
|
||||
}
|
||||
|
||||
// Extended text (reply, link preview)
|
||||
if (message.extendedTextMessage?.text) {
|
||||
return message.extendedTextMessage.text;
|
||||
}
|
||||
|
||||
// Image with caption
|
||||
if (message.imageMessage?.caption) {
|
||||
return `[Image] ${message.imageMessage.caption}`;
|
||||
}
|
||||
|
||||
// Video with caption
|
||||
if (message.videoMessage?.caption) {
|
||||
return `[Video] ${message.videoMessage.caption}`;
|
||||
}
|
||||
|
||||
// Document with caption
|
||||
if (message.documentMessage?.caption) {
|
||||
return `[Document] ${message.documentMessage.caption}`;
|
||||
}
|
||||
|
||||
// Voice/Audio message
|
||||
if (message.audioMessage) {
|
||||
return `[Voice Message]`;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async sendMessage(to: string, text: string): Promise<void> {
|
||||
if (!this.sock) {
|
||||
throw new Error('Not connected');
|
||||
}
|
||||
|
||||
await this.sock.sendMessage(to, { text });
|
||||
}
|
||||
|
||||
async disconnect(): Promise<void> {
|
||||
if (this.sock) {
|
||||
this.sock.end(undefined);
|
||||
this.sock = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,16 +0,0 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"declaration": true,
|
||||
"resolveJsonModule": true
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 12 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 5.6 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 6.8 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 6.0 MiB |
File diff suppressed because it is too large
Load Diff
@ -1,21 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
for dir in agent agent/tools bus config cron heartbeat session utils; do
|
||||
count=$(find "nanobot/$dir" -maxdepth 1 -name "*.py" -exec cat {} + | wc -l)
|
||||
printf " %-16s %5s lines\n" "$dir/" "$count"
|
||||
done
|
||||
|
||||
root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, providers/)"
|
||||
@ -1,31 +0,0 @@
|
||||
x-common-config: &common-config
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
volumes:
|
||||
- ~/.nanobot:/root/.nanobot
|
||||
|
||||
services:
|
||||
nanobot-gateway:
|
||||
container_name: nanobot-gateway
|
||||
<<: *common-config
|
||||
command: ["gateway"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- 18790:18790
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
nanobot-cli:
|
||||
<<: *common-config
|
||||
profiles:
|
||||
- cli
|
||||
command: ["status"]
|
||||
stdin_open: true
|
||||
tty: true
|
||||
@ -1,143 +0,0 @@
|
||||
# Boardware Genius 前后端分离启动指南(单用户直连)
|
||||
|
||||
本指南对应当前仓库:
|
||||
`/home/ivan/xuan/steven_project/nanobot`
|
||||
|
||||
## 1. 环境准备
|
||||
|
||||
- Python: `>=3.11`
|
||||
- Node.js: `>=18`
|
||||
- 包管理工具: `uv`、`npm`
|
||||
|
||||
在项目根目录执行:
|
||||
|
||||
```bash
|
||||
cd /home/ivan/xuan/steven_project/nanobot
|
||||
uv sync
|
||||
```
|
||||
|
||||
如果你第一次使用 Boardware Genius,需要先初始化:
|
||||
|
||||
```bash
|
||||
./.venv/bin/python -m nanobot onboard
|
||||
```
|
||||
|
||||
然后编辑配置文件(至少配置一个可用模型):
|
||||
|
||||
- `~/.nanobot/config.json`
|
||||
|
||||
## 2. 启动后端(Web API)
|
||||
|
||||
在项目根目录执行:
|
||||
|
||||
```bash
|
||||
cd /home/ivan/xuan/steven_project/nanobot
|
||||
./.venv/bin/python -m nanobot web --host 127.0.0.1 --port 10000
|
||||
```
|
||||
|
||||
启动成功后会看到类似日志:
|
||||
|
||||
- `Uvicorn running on http://127.0.0.1:10000`
|
||||
|
||||
可用接口示例:
|
||||
|
||||
- `GET http://127.0.0.1:10000/api/status`
|
||||
|
||||
### 2.1 准备登录账号 JSON(必需)
|
||||
|
||||
Web 登录会读取本地账号文件,默认路径:
|
||||
|
||||
- `/home/ivan/xuan/steven_project/nanobot/web_auth_users.json`
|
||||
|
||||
示例内容(任选一种格式):
|
||||
|
||||
```json
|
||||
{
|
||||
"users": [
|
||||
{ "username": "admin", "password": "123456" }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"admin": "123456",
|
||||
"alice": "alice_pwd"
|
||||
}
|
||||
```
|
||||
|
||||
也可通过环境变量指定自定义路径:
|
||||
|
||||
```bash
|
||||
export NANOBOT_AUTH_FILE=/your/path/users.json
|
||||
```
|
||||
|
||||
## 3. 启动前端(Next.js)
|
||||
|
||||
新开一个终端,执行:
|
||||
|
||||
```bash
|
||||
cd /home/ivan/xuan/steven_project/nanobot/frontend
|
||||
cp env_template .env.local
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
前端默认地址:
|
||||
|
||||
- `http://127.0.0.1:3080`
|
||||
|
||||
前端默认会请求:
|
||||
|
||||
- `NEXT_PUBLIC_API_URL=http://127.0.0.1:10000`
|
||||
|
||||
注意:如果你之前已经有 `frontend/.env.local`,请确认里面不是旧地址(例如 `localhost:8080`)。
|
||||
|
||||
如果你要改后端地址,修改:
|
||||
|
||||
- `frontend/.env.local`
|
||||
|
||||
## 4. 访问与验证
|
||||
|
||||
1. 打开 `http://127.0.0.1:3080`
|
||||
2. 首屏应进入登录页
|
||||
3. 使用 `web_auth_users.json` 中正确的账号密码登录
|
||||
4. 登录成功后进入对话页并可正常收发消息
|
||||
|
||||
## 5. 常见问题
|
||||
|
||||
### 5.1 前端显示“未连接/服务离线”
|
||||
|
||||
按顺序检查:
|
||||
|
||||
1. 后端是否在运行(终端是否有 `Uvicorn running ...`)
|
||||
2. 前端 `NEXT_PUBLIC_API_URL` 是否指向正确地址
|
||||
3. 端口是否被占用(`10000` / `3080`)
|
||||
|
||||
### 5.2 后端启动报 `No module named fastapi`
|
||||
|
||||
在项目根目录重新执行:
|
||||
|
||||
### 5.3 反向代理下登录后跳错前端域名
|
||||
|
||||
如果 API 域名和主前端域名不同,启动 backend 前显式设置主前端公开地址:
|
||||
|
||||
```bash
|
||||
export NANOBOT_FRONTEND_PUBLIC_BASE_URL=https://nanobot.bwgdi.com
|
||||
```
|
||||
|
||||
这样登录/注册成功后,backend 返回的 `frontend_base_url` 会固定为这个公开域名,而不是按 API 域名去拼 `:3080`。
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
### 5.3 需要开发测试工具(pytest/ruff)
|
||||
|
||||
```bash
|
||||
uv sync --extra dev
|
||||
```
|
||||
|
||||
## 6. 停止服务
|
||||
|
||||
- 在各自终端按 `Ctrl + C` 即可停止。
|
||||
@ -1,7 +0,0 @@
|
||||
"""
|
||||
Boardware Genius - A lightweight AI agent framework
|
||||
"""
|
||||
|
||||
__version__ = "0.1.4"
|
||||
__brand__ = "Boardware Genius"
|
||||
__logo__ = ""
|
||||
@ -1,8 +0,0 @@
|
||||
"""
|
||||
Entry point for running nanobot as a module: python -m nanobot
|
||||
"""
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@ -1,5 +0,0 @@
|
||||
"""A2A helpers."""
|
||||
|
||||
from nanobot.a2a.client import A2AClient
|
||||
|
||||
__all__ = ["A2AClient"]
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,35 +0,0 @@
|
||||
"""agent 核心模块导出入口。
|
||||
|
||||
这里刻意改成懒加载导出:
|
||||
1. 避免 `nanobot.agent` 被导入时立即拉起一整串重量级依赖;
|
||||
2. 降低循环导入概率,特别是 `loop/context/skills` 之间的交叉引用;
|
||||
3. 保持对外 API 不变,调用方仍然可以 `from nanobot.agent import AgentLoop`。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# 只有访问某个导出符号时才真正 import 对应模块,避免 import-time 副作用。
|
||||
if name == "AgentLoop":
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
return AgentLoop
|
||||
if name == "ContextBuilder":
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
return ContextBuilder
|
||||
if name == "MemoryStore":
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
|
||||
return MemoryStore
|
||||
if name == "SkillsLoader":
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
return SkillsLoader
|
||||
# 交给 Python 默认语义处理不存在的导出名。
|
||||
raise AttributeError(name)
|
||||
@ -1,419 +0,0 @@
|
||||
"""统一 agent 注册表。
|
||||
|
||||
这个模块把当前工作区里“可被委派”的执行体统一抽象成 `AgentDescriptor`:
|
||||
1. workspace 手工登记的远端 A2A agent;
|
||||
2. plugin 提供的本地 prompt agent;
|
||||
3. skill 元数据里声明的 agent cards;
|
||||
4. 内置 local fallback agent。
|
||||
|
||||
上层委派逻辑只和 `AgentDescriptor` 打交道,不需要关心来源细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.plugins import PluginLoader
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
_TOKEN_RE = re.compile(r"[a-z0-9_-]+")
|
||||
_CJK_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDescriptor:
|
||||
"""委派层使用的统一 agent 描述对象。"""
|
||||
|
||||
# 稳定 ID,供路由、持久化和精确匹配使用。
|
||||
id: str
|
||||
# 面向 UI/日志的展示名。
|
||||
name: str
|
||||
# 简短说明,主要供模型和前端展示。
|
||||
description: str
|
||||
# 来源类型:builtin / plugin / skill / workspace。
|
||||
source: str
|
||||
# 运行方式:local_prompt / local_fallback / a2a_remote 等。
|
||||
kind: str
|
||||
# 底层协议,目前主要是 a2a 或 None。
|
||||
protocol: str | None = None
|
||||
plugin_name: str | None = None
|
||||
skill_name: str | None = None
|
||||
model: str | None = None
|
||||
system_prompt: str | None = None
|
||||
endpoint: str | None = None
|
||||
base_url: str | None = None
|
||||
card_url: str | None = None
|
||||
auth_env: str | None = None
|
||||
auth_mode: str = "none"
|
||||
auth_audience: str | None = None
|
||||
auth_scopes: list[str] = field(default_factory=list)
|
||||
enabled: bool = True
|
||||
tags: list[str] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
support_streaming: bool = False
|
||||
|
||||
def matches(self, target: str) -> bool:
|
||||
"""判断给定目标字符串是否命中当前 agent。"""
|
||||
probe = (target or "").strip().lower()
|
||||
if not probe:
|
||||
return False
|
||||
# 同时支持按 id / name / alias 命中,方便模型用自然语言近似引用。
|
||||
candidates = {self.id.lower(), self.name.lower()}
|
||||
candidates.update(alias.lower() for alias in self.aliases if alias)
|
||||
return probe in candidates
|
||||
|
||||
def searchable_text(self) -> str:
|
||||
"""构造一段用于简单相关性匹配的可搜索文本。"""
|
||||
fields = [
|
||||
self.id,
|
||||
self.name,
|
||||
self.description,
|
||||
" ".join(self.tags),
|
||||
" ".join(self.aliases),
|
||||
self.plugin_name or "",
|
||||
self.skill_name or "",
|
||||
]
|
||||
return " ".join(part for part in fields if part).lower()
|
||||
|
||||
def public_dict(self) -> dict[str, Any]:
|
||||
"""导出给前端使用的安全字典。"""
|
||||
data = asdict(self)
|
||||
# system_prompt 属于内部实现细节,不应默认暴露给前端。
|
||||
data.pop("system_prompt", None)
|
||||
return data
|
||||
|
||||
|
||||
class WorkspaceAgentStore:
|
||||
"""workspace 级 agent 存储。
|
||||
|
||||
这里保存的是用户在 Web UI 或本地配置里手工登记的 agent,
|
||||
文件位置固定为 `<workspace>/agents/registry.json`。
|
||||
"""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
# 单独放到 `agents/` 目录,便于和 skills / memory / files 等目录职责分离。
|
||||
self.directory = workspace / "agents"
|
||||
self.path = self.directory / "registry.json"
|
||||
|
||||
def list_agents(self) -> list[dict[str, Any]]:
|
||||
"""读取并返回所有手工登记 agent。"""
|
||||
if not self.path.exists():
|
||||
return []
|
||||
try:
|
||||
raw = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError, ValueError):
|
||||
# 存储损坏时不抛异常拖垮主流程,直接视为空。
|
||||
return []
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
result: list[dict[str, Any]] = []
|
||||
for item in raw:
|
||||
# 仅接受带 id 的对象,保证后续 registry 至少有稳定主键。
|
||||
if isinstance(item, dict) and item.get("id"):
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
def save_agents(self, agents: list[dict[str, Any]]) -> None:
|
||||
"""将 agent 列表完整覆写到 registry 文件。"""
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
self.path.write_text(
|
||||
json.dumps(agents, indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def upsert_agent(self, agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""按 id 新增或更新一个 agent 记录。"""
|
||||
record = dict(agent)
|
||||
agent_id = str(record.get("id", "")).strip()
|
||||
if not agent_id:
|
||||
raise ValueError("Agent id is required")
|
||||
record["id"] = agent_id
|
||||
# 对基础展示字段做最小兜底,避免后续 UI 或提示词出现空值。
|
||||
record.setdefault("name", agent_id)
|
||||
record.setdefault("description", record["name"])
|
||||
record.setdefault("protocol", "a2a")
|
||||
record.setdefault("enabled", True)
|
||||
record.setdefault("tags", [])
|
||||
# 先剔除旧记录再 append,最后统一排序,保持存储文件稳定可读。
|
||||
agents = [a for a in self.list_agents() if a.get("id") != agent_id]
|
||||
agents.append(record)
|
||||
agents.sort(key=lambda item: item.get("id", "").lower())
|
||||
self.save_agents(agents)
|
||||
return record
|
||||
|
||||
def delete_agent(self, agent_id: str) -> bool:
|
||||
"""按 id 删除一个 agent,删除成功返回 True。"""
|
||||
target = agent_id.strip()
|
||||
if not target:
|
||||
return False
|
||||
agents = self.list_agents()
|
||||
filtered = [a for a in agents if a.get("id") != target]
|
||||
if len(filtered) == len(agents):
|
||||
return False
|
||||
self.save_agents(filtered)
|
||||
return True
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""构建并查询当前可委派 agent 集合。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
plugins: PluginLoader | None = None,
|
||||
skills: SkillsLoader | None = None,
|
||||
allow_skill_cards: bool = True,
|
||||
allow_workspace_agents: bool = True,
|
||||
include_local_fallback: bool = True,
|
||||
include_plugin_agents: bool = True,
|
||||
):
|
||||
self.workspace = workspace
|
||||
# 插件和技能加载器允许外部复用同一个实例,避免重复扫描磁盘。
|
||||
self.plugins = plugins or PluginLoader(workspace)
|
||||
self.skills = skills or SkillsLoader(workspace, extra_dirs=self.plugins.get_skill_dirs())
|
||||
self.allow_skill_cards = allow_skill_cards
|
||||
self.allow_workspace_agents = allow_workspace_agents
|
||||
self.include_local_fallback = include_local_fallback
|
||||
self.include_plugin_agents = include_plugin_agents
|
||||
self.workspace_store = WorkspaceAgentStore(workspace)
|
||||
|
||||
def list_agents(self, include_local_fallback: bool | None = None) -> list[AgentDescriptor]:
|
||||
"""按统一格式列出当前可见 agent。"""
|
||||
if include_local_fallback is None:
|
||||
include_local_fallback = self.include_local_fallback
|
||||
agents: list[AgentDescriptor] = []
|
||||
|
||||
if self.allow_workspace_agents:
|
||||
for record in self.workspace_store.list_agents():
|
||||
if not record.get("enabled", True):
|
||||
continue
|
||||
agent = self._workspace_record_to_descriptor(record)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
|
||||
# plugin agents 本质上是“带独立系统提示词的本地执行器”。
|
||||
if self.include_plugin_agents:
|
||||
for plugin in self.plugins.plugins.values():
|
||||
for agent in plugin.agents.values():
|
||||
agents.append(
|
||||
AgentDescriptor(
|
||||
id=f"plugin:{agent.name}",
|
||||
name=agent.name,
|
||||
description=agent.description or agent.name,
|
||||
source="plugin",
|
||||
kind="local_prompt",
|
||||
protocol=None,
|
||||
plugin_name=agent.plugin_name,
|
||||
model=agent.model,
|
||||
system_prompt=agent.system_prompt,
|
||||
aliases=[agent.name],
|
||||
metadata={"plugin_name": agent.plugin_name},
|
||||
)
|
||||
)
|
||||
|
||||
if self.allow_skill_cards:
|
||||
# skill 里声明的 card 视为远端 A2A agent 的静态入口。
|
||||
for card in self.skills.list_skill_agent_cards():
|
||||
agent = self._skill_card_to_descriptor(card)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
|
||||
if include_local_fallback:
|
||||
# 永远保留一个本地兜底执行器,确保自动路由时至少有可执行目标。
|
||||
agents.append(
|
||||
AgentDescriptor(
|
||||
id="local-subagent",
|
||||
name="Local Subagent",
|
||||
description="Local fallback agent that can use files, shell, and web tools.",
|
||||
source="builtin",
|
||||
kind="local_fallback",
|
||||
protocol=None,
|
||||
aliases=["subagent", "local"],
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[str] = set()
|
||||
result: list[AgentDescriptor] = []
|
||||
for agent in agents:
|
||||
# 去重规则按 id 小写匹配,优先保留先出现的来源。
|
||||
key = agent.id.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
result.append(agent)
|
||||
return result
|
||||
|
||||
def get_agent(self, target: str) -> AgentDescriptor | None:
|
||||
"""按 id / name / alias 获取单个 agent。"""
|
||||
probe = (target or "").strip()
|
||||
if not probe:
|
||||
return None
|
||||
for agent in self.list_agents():
|
||||
if agent.matches(probe):
|
||||
return agent
|
||||
return None
|
||||
|
||||
def suggest_agents(self, query: str, limit: int = 5) -> list[AgentDescriptor]:
|
||||
"""基于简单词项打分为一段任务文本推荐 agent。"""
|
||||
query_text = query or ""
|
||||
query_lower = query_text.lower()
|
||||
tokens = {token for token in _TOKEN_RE.findall(query_lower) if len(token) > 2}
|
||||
query_cjk_bigrams = self._cjk_bigrams(query_text)
|
||||
|
||||
scored: list[tuple[int, AgentDescriptor]] = []
|
||||
for agent in self.list_agents(include_local_fallback=False):
|
||||
haystack = agent.searchable_text()
|
||||
haystack_cjk_bigrams = self._cjk_bigrams(haystack)
|
||||
score = 0
|
||||
for token in tokens:
|
||||
# token 命中一次给基础分。
|
||||
if token in haystack:
|
||||
score += 2
|
||||
# 如果查询里直接出现了 agent 名或 id,再给更高权重。
|
||||
if agent.name.lower() in query_lower or agent.id.lower() in query_lower:
|
||||
score += 5
|
||||
for phrase in [agent.name, agent.id, *agent.tags, *agent.aliases]:
|
||||
phrase_text = str(phrase or "").strip()
|
||||
if not phrase_text:
|
||||
continue
|
||||
if phrase_text.lower() in query_lower or phrase_text in query_text:
|
||||
score += 3
|
||||
if query_cjk_bigrams and haystack_cjk_bigrams:
|
||||
# 中文任务没有空格分词,先用 bigram overlap 做粗粒度召回。
|
||||
score += min(6, len(query_cjk_bigrams & haystack_cjk_bigrams))
|
||||
if score > 0:
|
||||
scored.append((score, agent))
|
||||
|
||||
scored.sort(key=lambda item: (-item[0], item[1].name.lower()))
|
||||
return [agent for _, agent in scored[:limit]]
|
||||
|
||||
@staticmethod
|
||||
def _cjk_bigrams(text: str) -> set[str]:
|
||||
"""提取中文 bigram,用于中文任务的轻量召回。"""
|
||||
chunks = _CJK_RE.findall(str(text or ""))
|
||||
result: set[str] = set()
|
||||
for chunk in chunks:
|
||||
if len(chunk) == 1:
|
||||
result.add(chunk)
|
||||
continue
|
||||
for index in range(len(chunk) - 1):
|
||||
result.add(chunk[index:index + 2])
|
||||
return result
|
||||
|
||||
def build_agents_summary(self) -> str:
|
||||
"""把 agent 列表格式化成 prompt 可直接嵌入的 XML 片段。"""
|
||||
agents = self.list_agents()
|
||||
if not agents:
|
||||
return ""
|
||||
|
||||
def esc(value: str) -> str:
|
||||
# 这里手工转义最基础的 XML 特殊字符,避免描述文本破坏结构。
|
||||
return (
|
||||
value.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
|
||||
lines = ["<agents>"]
|
||||
for agent in agents:
|
||||
lines.append(" <agent>")
|
||||
lines.append(f" <id>{esc(agent.id)}</id>")
|
||||
lines.append(f" <name>{esc(agent.name)}</name>")
|
||||
lines.append(f" <source>{esc(agent.source)}</source>")
|
||||
lines.append(f" <kind>{esc(agent.kind)}</kind>")
|
||||
lines.append(f" <description>{esc(agent.description)}</description>")
|
||||
if agent.protocol:
|
||||
lines.append(f" <protocol>{esc(agent.protocol)}</protocol>")
|
||||
if agent.tags:
|
||||
lines.append(f" <tags>{esc(', '.join(agent.tags))}</tags>")
|
||||
lines.append(" </agent>")
|
||||
lines.append("</agents>")
|
||||
return "\n".join(lines)
|
||||
|
||||
def list_public_agents(self) -> list[dict[str, Any]]:
|
||||
"""列出脱敏后的 agent 结构,供 Web API 使用。"""
|
||||
return [agent.public_dict() for agent in self.list_agents()]
|
||||
|
||||
def _workspace_record_to_descriptor(self, record: dict[str, Any]) -> AgentDescriptor | None:
|
||||
"""把 workspace registry 里的原始记录转成统一描述对象。"""
|
||||
protocol = str(record.get("protocol") or "a2a").lower()
|
||||
if protocol != "a2a":
|
||||
# 当前仅支持把 workspace 记录解释成 A2A agent。
|
||||
return None
|
||||
agent_id = str(record.get("id", "")).strip()
|
||||
if not agent_id:
|
||||
return None
|
||||
name = str(record.get("name") or agent_id)
|
||||
return AgentDescriptor(
|
||||
id=agent_id,
|
||||
name=name,
|
||||
description=str(record.get("description") or name),
|
||||
source="workspace",
|
||||
kind="a2a_remote",
|
||||
protocol="a2a",
|
||||
endpoint=record.get("endpoint") or record.get("base_url"),
|
||||
base_url=record.get("base_url") or record.get("endpoint"),
|
||||
card_url=record.get("card_url"),
|
||||
auth_env=record.get("auth_env"),
|
||||
auth_mode=str(record.get("auth_mode") or "none").strip().lower() or "none",
|
||||
auth_audience=(str(record.get("auth_audience") or "").strip() or None),
|
||||
auth_scopes=[
|
||||
str(scope).strip()
|
||||
for scope in record.get("auth_scopes", [])
|
||||
if str(scope).strip()
|
||||
],
|
||||
enabled=bool(record.get("enabled", True)),
|
||||
tags=[str(tag) for tag in record.get("tags", []) if str(tag).strip()],
|
||||
aliases=[
|
||||
alias
|
||||
for alias in [record.get("name"), *record.get("aliases", [])]
|
||||
if isinstance(alias, str) and alias.strip()
|
||||
],
|
||||
capabilities=record.get("capabilities", {}) if isinstance(record.get("capabilities"), dict) else {},
|
||||
metadata=record.get("metadata", {}) if isinstance(record.get("metadata"), dict) else {},
|
||||
support_streaming=bool(record.get("support_streaming", False)),
|
||||
)
|
||||
|
||||
def _skill_card_to_descriptor(self, card: dict[str, Any]) -> AgentDescriptor | None:
|
||||
"""把 skill frontmatter 中的 agent card 转成统一描述对象。"""
|
||||
card_id = str(card.get("id") or "").strip()
|
||||
skill_name = str(card.get("skill_name") or "").strip()
|
||||
if not card_id:
|
||||
return None
|
||||
name = str(card.get("name") or card_id)
|
||||
return AgentDescriptor(
|
||||
id=card_id,
|
||||
name=name,
|
||||
description=str(card.get("description") or name),
|
||||
source="skill",
|
||||
kind="a2a_remote",
|
||||
protocol="a2a",
|
||||
skill_name=skill_name or None,
|
||||
endpoint=card.get("endpoint") or card.get("base_url"),
|
||||
base_url=card.get("base_url") or card.get("endpoint"),
|
||||
card_url=card.get("url") or card.get("card_url"),
|
||||
auth_env=card.get("auth_env"),
|
||||
auth_mode=str(card.get("auth_mode") or "none").strip().lower() or "none",
|
||||
auth_audience=(str(card.get("auth_audience") or "").strip() or None),
|
||||
auth_scopes=[
|
||||
str(scope).strip()
|
||||
for scope in card.get("auth_scopes", [])
|
||||
if str(scope).strip()
|
||||
],
|
||||
tags=[str(tag) for tag in card.get("tags", []) if str(tag).strip()],
|
||||
aliases=[
|
||||
alias
|
||||
for alias in [card.get("name"), *card.get("aliases", [])]
|
||||
if isinstance(alias, str) and alias.strip()
|
||||
],
|
||||
capabilities=card.get("capabilities", {}) if isinstance(card.get("capabilities"), dict) else {},
|
||||
metadata=card.get("metadata", {}) if isinstance(card.get("metadata"), dict) else {},
|
||||
support_streaming=bool(card.get("support_streaming", False)),
|
||||
)
|
||||
@ -1,257 +0,0 @@
|
||||
"""上下文构建器:负责为每次 LLM 调用组装完整消息上下文。
|
||||
|
||||
本模块主要做三件事:
|
||||
1. 生成 system prompt(身份、运行时信息、bootstrap 文件、记忆、技能摘要);
|
||||
2. 将历史消息与当前用户输入拼接成模型可消费的 messages;
|
||||
3. 在工具调用循环中追加 assistant/tool 消息,维持对话状态连续性。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.agent_registry import AgentRegistry
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""
|
||||
Agent 上下文装配器。
|
||||
|
||||
设计目标:
|
||||
- 把“静态配置”(AGENTS/USER/TOOLS 等)与“动态上下文”(时间、会话、历史)统一拼装;
|
||||
- 保持 prompt 结构稳定,降低模型行为波动;
|
||||
- 让工具调用前后的消息追加逻辑集中在一个位置,便于维护。
|
||||
"""
|
||||
|
||||
# bootstrap 文件按此顺序加载并拼接,顺序会影响最终提示词语义优先级。
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
skills_loader: SkillsLoader | None = None,
|
||||
agent_registry: AgentRegistry | None = None,
|
||||
):
|
||||
self.workspace = workspace
|
||||
# 记忆与技能都按 workspace 维度隔离,避免跨项目污染。
|
||||
self.memory = MemoryStore(workspace)
|
||||
# 若上层已构造好 SkillsLoader / AgentRegistry,则复用,避免重复扫描磁盘。
|
||||
self.skills = skills_loader or SkillsLoader(workspace)
|
||||
# agent_registry 可选:只有支持多 agent 委派时才会把可用 agent 摘要塞进 prompt。
|
||||
self.agent_registry = agent_registry
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
skill_names: list[str] | None = None,
|
||||
execution_context: str | None = None,
|
||||
) -> str:
|
||||
"""构建 system prompt(身份 + 配置 + 记忆 + 技能信息)。"""
|
||||
# skill_names 目前作为接口预留,便于未来按需只加载指定技能。
|
||||
parts = []
|
||||
|
||||
# 1) 核心身份段:包含当前时间、系统环境、工作区路径等动态信息。
|
||||
parts.append(self._get_identity())
|
||||
|
||||
# 2) workspace 里的 bootstrap 文件(若存在)按顺序拼接。
|
||||
bootstrap = self._load_bootstrap_files()
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
# 3) 长期记忆上下文(来自 memory/MEMORY.md 等)。
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory:
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
|
||||
# 4) 技能采用“渐进加载”策略。
|
||||
# 4.1 always 技能:直接把完整内容塞进当前 prompt。
|
||||
always_skills = self.skills.get_always_skills()
|
||||
if always_skills:
|
||||
always_content = self.skills.load_skills_for_context(always_skills)
|
||||
if always_content:
|
||||
parts.append(f"# Active Skills\n\n{always_content}")
|
||||
|
||||
# 4.2 可用技能:只放摘要,具体内容让 agent 运行时按需 read_file。
|
||||
# 这样可以控制 token 体积,避免把所有技能全文塞入上下文。
|
||||
skills_summary = self.skills.build_skills_summary()
|
||||
if skills_summary:
|
||||
parts.append(f"""# Skills
|
||||
|
||||
The following skills extend your capabilities. To use a skill, read its SKILL.md file using the read_file tool.
|
||||
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
||||
|
||||
{skills_summary}""")
|
||||
|
||||
if self.agent_registry:
|
||||
parts.append("""# Delegation Tools
|
||||
|
||||
Use `spawn_subagent` when the task should go to one delegated worker.
|
||||
Use `spawn_agent_team` when the task should be explored in parallel by multiple workers.
|
||||
At the top level, you do not need to choose concrete downstream agents.
|
||||
Use the `skills` argument when the delegated worker or team must follow specific skills.""")
|
||||
|
||||
if execution_context:
|
||||
# `execution_context` 用于 cron / system task 这类“不是普通用户消息”的额外运行说明。
|
||||
parts.append(f"# Execution Context\n\n{execution_context.strip()}")
|
||||
|
||||
# 各块之间用分隔线拼接,提升提示词可读性与结构稳定性。
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def _get_identity(self) -> str:
|
||||
"""生成核心身份段。"""
|
||||
import time as _time
|
||||
from datetime import datetime
|
||||
# 时间与时区在 system prompt 中显式给出,减少模型对“当前时间”的猜测。
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = _time.strftime("%Z") or "UTC"
|
||||
# 固化绝对工作区路径,帮助模型生成更准确的文件操作指令。
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
# 运行时信息可帮助模型在跨平台命令选择时更稳健(如 macOS/Linux 差异)。
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
|
||||
return f"""# Boardware Genius
|
||||
|
||||
You are Boardware Genius, a helpful AI assistant.
|
||||
|
||||
## Current Time
|
||||
{now} ({tz})
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
- Long-term memory: {workspace_path}/memory/MEMORY.md
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.
|
||||
|
||||
## Tool Call Guidelines
|
||||
- Before calling tools, you may briefly state your intent (e.g. "Let me check that"), but NEVER predict or describe the expected result before receiving it.
|
||||
- Before modifying a file, read it first to confirm its current content.
|
||||
- Do not assume a file or directory exists — use list_dir or read_file to verify.
|
||||
- After writing or editing a file, re-read it if accuracy matters.
|
||||
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||
- Do not write directly into `{workspace_path}/skills`; new or updated skills must go through the review flow before activation.
|
||||
|
||||
## Delegation Policy
|
||||
- Solve simple tasks yourself when the work is short, direct, and does not benefit from delegation.
|
||||
- Delegate only when the task is complex, multi-step, time-consuming, or benefits from specialized/parallel work.
|
||||
- Use `spawn_subagent` for one focused delegated worker when only the final result matters.
|
||||
- Use `spawn_agent_team` when multiple agents should explore the task in parallel, compare findings, or work across separate areas.
|
||||
- Do not delegate by default if you can complete the task reliably in the current turn.
|
||||
- Do not create or modify persistent local sub-agents unless the user explicitly asks for a reusable long-lived worker.
|
||||
|
||||
## Memory
|
||||
- Remember important facts: write to {workspace_path}/memory/MEMORY.md
|
||||
- Recall past events: grep {workspace_path}/memory/HISTORY.md"""
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""从 workspace 读取 bootstrap 文件并拼接。"""
|
||||
parts = []
|
||||
|
||||
for filename in self.BOOTSTRAP_FILES:
|
||||
file_path = self.workspace / filename
|
||||
if file_path.exists():
|
||||
# 缺失文件时静默跳过,保持默认可用。
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
parts.append(f"## {filename}\n\n{content}")
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
skill_names: list[str] | None = None,
|
||||
execution_context: str | None = None,
|
||||
media: list[str] | None = None,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""构建一次 LLM 调用的完整 messages 数组。"""
|
||||
messages = []
|
||||
|
||||
# 第 1 条固定是 system prompt。
|
||||
system_prompt = self.build_system_prompt(skill_names, execution_context=execution_context)
|
||||
if channel and chat_id:
|
||||
# 把当前会话路由信息也写入系统提示,便于模型做跨渠道决策。
|
||||
system_prompt += f"\n\n## Current Session\nChannel: {channel}\nChat ID: {chat_id}"
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 追加历史消息(通常已由 SessionManager 做窗口与清洗)。
|
||||
messages.extend(history)
|
||||
|
||||
# 追加当前用户输入;若带图片则转换为多模态 content 结构。
|
||||
user_content = self._build_user_content(current_message, media)
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
return messages
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""构建 user content,支持文本或“文本+图片”多模态格式。"""
|
||||
# 无媒体时直接走纯文本,保持最简单路径。
|
||||
if not media:
|
||||
return text
|
||||
|
||||
images = []
|
||||
for path in media:
|
||||
p = Path(path)
|
||||
mime, _ = mimetypes.guess_type(path)
|
||||
# 仅接收本地图片文件,其他媒体类型暂不注入到模型内容。
|
||||
if not p.is_file() or not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
# 按 data URL 形式内联图片,兼容支持 image_url 的 provider 接口。
|
||||
b64 = base64.b64encode(p.read_bytes()).decode()
|
||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||
|
||||
# 没有合法图片时回退纯文本,避免传空数组导致模型侧解析异常。
|
||||
if not images:
|
||||
return text
|
||||
# 多模态结构中把图片放前、文本放后,便于模型先“看图”再读文字指令。
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""把工具执行结果追加到 messages。"""
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result
|
||||
})
|
||||
return messages
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""把 assistant 消息追加到 messages(可携带 tool_calls/reasoning)。"""
|
||||
msg: dict[str, Any] = {"role": "assistant"}
|
||||
|
||||
# 始终写入 content 键:
|
||||
# 部分 provider 在 key 缺失时会拒绝请求(即使值是 None 也要有该键)。
|
||||
msg["content"] = content
|
||||
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
|
||||
# reasoning_content 是“思考模型”专用字段,仅在有值时附加。
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
|
||||
messages.append(msg)
|
||||
return messages
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,813 +0,0 @@
|
||||
"""Agent 主循环:Boardware Genius 的核心处理引擎。
|
||||
|
||||
职责概览:
|
||||
1. 从消息总线读取入站消息;
|
||||
2. 结合会话历史、记忆与工作区上下文构建提示词;
|
||||
3. 调用 LLM 并迭代执行工具调用;
|
||||
4. 将结果写回会话并发布出站消息;
|
||||
5. 在后台处理记忆归档与 MCP 工具连接生命周期。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.agent_registry import AgentRegistry
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.delegation import DelegationManager
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.plugins import PluginLoader
|
||||
from nanobot.agent.process_events import process_event_sink
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import DelegationTool, SpawnAgentTeamTool, SpawnSubagentTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import A2AConfig, ChannelsConfig, ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
AgentLoop 是 Boardware Genius 运行时的“对话编排器”。
|
||||
|
||||
一次标准处理链路:
|
||||
1. 接收入站消息(来自 CLI 或外部渠道);
|
||||
2. 恢复对应会话并构建当前轮上下文;
|
||||
3. 调用模型,解析工具调用并执行;
|
||||
4. 将本轮新增消息写入会话;
|
||||
5. 输出最终回复(或由消息工具自行发送)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 100,
|
||||
brave_api_key: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
a2a_config: "A2AConfig | None" = None,
|
||||
cron_service: CronService | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
session_manager: SessionManager | None = None,
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
authz_config: Any | None = None,
|
||||
backend_identity: Any | None = None,
|
||||
allow_spawn: bool = True,
|
||||
allow_message: bool = True,
|
||||
allow_cron: bool = True,
|
||||
include_local_fallback: bool = True,
|
||||
allow_local_delegation: bool = True,
|
||||
allow_plugin_delegation: bool = True,
|
||||
include_plugin_agents: bool = True,
|
||||
gateway_port: int = 18790,
|
||||
):
|
||||
from nanobot.config.schema import A2AConfig, ExecToolConfig
|
||||
# 基础依赖与运行参数。
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.memory_window = memory_window
|
||||
self.brave_api_key = brave_api_key
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.a2a_config = a2a_config or A2AConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.authz_config = authz_config
|
||||
self.backend_identity = backend_identity
|
||||
self.allow_spawn = allow_spawn
|
||||
self.allow_message = allow_message
|
||||
self.allow_cron = allow_cron
|
||||
self.include_local_fallback = include_local_fallback
|
||||
self.allow_local_delegation = allow_local_delegation
|
||||
self.allow_plugin_delegation = allow_plugin_delegation
|
||||
self.include_plugin_agents = include_plugin_agents
|
||||
|
||||
# 核心组件:上下文构建、会话管理、工具注册、子代理管理。
|
||||
self.plugins = PluginLoader(workspace)
|
||||
# SkillsLoader 需要感知 plugin 附带的 skill 目录,因此单独抽到 helper 构建。
|
||||
self.skills = self._build_skills_loader()
|
||||
self.agent_registry = AgentRegistry(
|
||||
workspace,
|
||||
plugins=self.plugins,
|
||||
skills=self.skills,
|
||||
allow_skill_cards=self.a2a_config.allow_skill_cards,
|
||||
allow_workspace_agents=self.a2a_config.allow_workspace_agents,
|
||||
include_local_fallback=self.include_local_fallback,
|
||||
include_plugin_agents=self.include_plugin_agents,
|
||||
)
|
||||
self.context = ContextBuilder(
|
||||
workspace,
|
||||
skills_loader=self.skills,
|
||||
agent_registry=self.agent_registry,
|
||||
)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
brave_api_key=brave_api_key,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
self.delegation = DelegationManager(
|
||||
provider=provider,
|
||||
model=self.model,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
registry=self.agent_registry,
|
||||
skills_loader=self.skills,
|
||||
local_executor=self.subagents,
|
||||
timeout_seconds=self.a2a_config.timeout_seconds,
|
||||
poll_interval_seconds=self.a2a_config.poll_interval_seconds,
|
||||
card_cache_ttl_seconds=self.a2a_config.card_cache_ttl_seconds,
|
||||
max_parallel_agents=self.a2a_config.max_parallel_agents,
|
||||
allowed_hosts=self.a2a_config.allowed_hosts,
|
||||
authz_config=self.authz_config,
|
||||
backend_identity=self.backend_identity,
|
||||
allow_local_delegation=self.allow_local_delegation,
|
||||
allow_plugin_delegation=self.allow_plugin_delegation,
|
||||
allow_local_fallback=self.include_local_fallback,
|
||||
gateway_port=gateway_port,
|
||||
)
|
||||
self.subagents.set_nested_delegate(self.delegation)
|
||||
|
||||
# 运行时状态位。
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
# `_mcp_report` 保存最近一次连接结果,供 Web API 展示状态和错误信息。
|
||||
self._mcp_report: dict[str, dict[str, Any]] = {}
|
||||
# 会话级记忆归档控制:避免同一会话并发归档。
|
||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
||||
self._register_default_tools()
|
||||
|
||||
def apply_runtime_config(self, *, authz_config: Any | None, backend_identity: Any | None) -> None:
|
||||
"""同步运行中 loop 的鉴权上下文,避免变更后必须重启。"""
|
||||
self.authz_config = authz_config
|
||||
self.backend_identity = backend_identity
|
||||
self.delegation.a2a_client.authz_config = authz_config
|
||||
self.delegation.a2a_client.backend_identity = backend_identity
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""注册默认工具集合。"""
|
||||
# 启用工作区限制时,文件读写工具仅允许访问 workspace 目录树。
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
protected_skill_paths = [self.workspace / "skills"]
|
||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(
|
||||
WriteFileTool(
|
||||
workspace=self.workspace,
|
||||
allowed_dir=allowed_dir,
|
||||
protected_paths=protected_skill_paths,
|
||||
)
|
||||
)
|
||||
self.tools.register(
|
||||
EditFileTool(
|
||||
workspace=self.workspace,
|
||||
allowed_dir=allowed_dir,
|
||||
protected_paths=protected_skill_paths,
|
||||
)
|
||||
)
|
||||
|
||||
# Shell 工具独立配置超时与目录约束。
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
protected_paths=protected_skill_paths,
|
||||
))
|
||||
|
||||
# 网络、消息、委派工具按职责注册。
|
||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||
self.tools.register(WebFetchTool())
|
||||
if self.allow_message:
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
if self.allow_spawn:
|
||||
self.tools.register(SpawnSubagentTool(manager=self.delegation))
|
||||
self.tools.register(SpawnAgentTeamTool(manager=self.delegation))
|
||||
|
||||
# 只有注入 cron_service 时才暴露 cron 工具,避免空引用。
|
||||
if self.cron_service and self.allow_cron:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
"""懒加载连接 MCP 服务器(单次连接,失败可重试)。"""
|
||||
# 已连接 / 正在连接 / 未配置时直接返回。
|
||||
if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
try:
|
||||
# 用 AsyncExitStack 统一托管各 MCP 连接的退出清理。
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
self._mcp_report = await connect_mcp_servers(
|
||||
self._mcp_servers,
|
||||
self.tools,
|
||||
self._mcp_stack,
|
||||
authz_config=self.authz_config,
|
||||
backend_identity=self.backend_identity,
|
||||
)
|
||||
self._mcp_connected = any(item.get("status") == "connected" for item in self._mcp_report.values())
|
||||
except Exception as e:
|
||||
# 失败后保留可重试能力:释放已建立资源,下一条消息再尝试连接。
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
self._mcp_report = {
|
||||
name: {
|
||||
"status": "error",
|
||||
"last_error": str(e),
|
||||
"tool_names": [],
|
||||
"tool_count": 0,
|
||||
"transport": "stdio" if getattr(cfg, "command", "") else "http",
|
||||
}
|
||||
for name, cfg in self._mcp_servers.items()
|
||||
}
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _clear_mcp_tools(self) -> None:
|
||||
"""移除当前 registry 里所有 MCP 工具包装器。"""
|
||||
for tool_name in list(self.tools.tool_names):
|
||||
if tool_name.startswith("mcp_"):
|
||||
self.tools.unregister(tool_name)
|
||||
|
||||
async def reload_mcp_servers(self, mcp_servers: dict | None) -> None:
|
||||
"""替换 MCP 配置并按新配置重新连接。"""
|
||||
# 先彻底关闭旧连接并移除旧工具,避免新旧配置混杂。
|
||||
await self.close_mcp()
|
||||
self._clear_mcp_tools()
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._mcp_report = {}
|
||||
if self._mcp_servers:
|
||||
await self._connect_mcp()
|
||||
|
||||
def get_mcp_servers_view(self) -> list[dict[str, Any]]:
|
||||
"""返回 MCP 静态配置与运行态状态合并后的视图。"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for name in sorted(self._mcp_servers):
|
||||
cfg = self._mcp_servers[name]
|
||||
report = self._mcp_report.get(name, {})
|
||||
sensitive = bool(getattr(cfg, "sensitive", False))
|
||||
tool_names = report.get("tool_names")
|
||||
if not isinstance(tool_names, list):
|
||||
# 若当前 report 不完整,则退化为扫描已注册工具名进行推断。
|
||||
tool_names = [
|
||||
item
|
||||
for item in self.tools.tool_names
|
||||
if item.startswith(f"mcp_{name}_")
|
||||
]
|
||||
result.append({
|
||||
"id": name,
|
||||
"name": name,
|
||||
"transport": "stdio" if getattr(cfg, "command", "") else "http",
|
||||
"url": getattr(cfg, "url", "") or None,
|
||||
"command": getattr(cfg, "command", "") or None,
|
||||
"args": list(getattr(cfg, "args", []) or []),
|
||||
"auth_mode": getattr(cfg, "auth_mode", "none") or "none",
|
||||
"auth_audience": getattr(cfg, "auth_audience", "") or None,
|
||||
"auth_scopes": [str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
|
||||
"headers": (
|
||||
{key: "***" for key in dict(getattr(cfg, "headers", {}) or {})}
|
||||
if sensitive
|
||||
else dict(getattr(cfg, "headers", {}) or {})
|
||||
),
|
||||
"env": (
|
||||
{key: "***" for key in dict(getattr(cfg, "env", {}) or {})}
|
||||
if sensitive
|
||||
else dict(getattr(cfg, "env", {}) or {})
|
||||
),
|
||||
"tool_timeout": int(getattr(cfg, "tool_timeout", 30)),
|
||||
"sensitive": sensitive,
|
||||
"enabled": True,
|
||||
"status": report.get("status", "disconnected"),
|
||||
"tool_count": int(report.get("tool_count", len(tool_names))),
|
||||
"tool_names": tool_names,
|
||||
"last_error": report.get("last_error"),
|
||||
})
|
||||
return result
|
||||
|
||||
def _set_tool_context(
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
message_id: str | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""把当前请求的路由上下文写入各工具的默认目标。
|
||||
|
||||
设计目的:
|
||||
1. 工具调用参数里不一定每次都显式传 `channel/chat_id`;
|
||||
2. 通过这里预注入默认值,工具可自动回落到“当前会话”;
|
||||
3. 每条消息处理前都调用一次,避免沿用上一轮残留上下文。
|
||||
"""
|
||||
# message 工具:需要 channel/chat_id 才能发消息;
|
||||
# message_id 在支持线程回复/引用回复的渠道里可用于“回这条消息”。
|
||||
if message_tool := self.tools.get("message"):
|
||||
# ToolRegistry.get() 返回通用 Tool | None,
|
||||
# 用 isinstance 确认具体类型后再调用专有 set_context()。
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(channel, chat_id, message_id)
|
||||
|
||||
# 委派工具:后台任务完成后需要把结果回投到原会话,
|
||||
# 因此只需记住来源 channel/chat_id。
|
||||
for tool_name in ("spawn_subagent", "spawn_agent_team"):
|
||||
if delegation_tool := self.tools.get(tool_name):
|
||||
if isinstance(delegation_tool, DelegationTool):
|
||||
delegation_tool.set_context(channel, chat_id, announce_via_bus=self._running)
|
||||
|
||||
# cron 工具:创建任务时会把 deliver 目标写入任务 payload,
|
||||
# 后续定时触发时才能把结果送回同一会话。
|
||||
if cron_tool := self.tools.get("cron"):
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(channel, chat_id, session_key=session_key)
|
||||
|
||||
def _build_skills_loader(self):
|
||||
"""构造可感知 plugin skill 目录的 SkillsLoader。"""
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
return SkillsLoader(self.workspace, extra_dirs=self.plugins.get_skill_dirs())
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""去除模型输出中的 `<think>...</think>` 推理块。"""
|
||||
# 某些模型会把思考内容混入最终文本,这里统一做显示层清洗。
|
||||
if not text:
|
||||
return None
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""把工具调用格式化为简短提示,如 `web_search("query")`。"""
|
||||
def _fmt(tc):
|
||||
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""执行 agent 迭代循环。
|
||||
|
||||
返回:
|
||||
- final_content: 最终可回复文本(无则为 None)
|
||||
- tools_used: 本轮调用过的工具名列表
|
||||
- messages: 迭代结束后的完整消息数组(含 tool 结果)
|
||||
"""
|
||||
messages = initial_messages
|
||||
tools = tool_registry or self.tools
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
# 循环直到拿到最终回复,或达到最大迭代次数。
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# 每一轮都带上当前消息状态与工具定义,让模型决定是否继续调工具。
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# 进度回调用于 CLI/渠道侧实时展示:先输出正文片段,再输出工具提示。
|
||||
if on_progress:
|
||||
clean = self._strip_think(response.content)
|
||||
if clean:
|
||||
await on_progress(clean)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
# 把 assistant 的“工具调用意图”写入对话,再逐个执行工具。
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
# 无工具调用即视为本轮收敛,输出最终内容。
|
||||
final_content = self._strip_think(response.content)
|
||||
# 将最终 assistant 回复写入消息链,确保会话可持久化回放。
|
||||
# 对于空/None 内容,回退到原始 content(或空串)避免丢失一轮回复。
|
||||
persist_content = final_content if final_content is not None else (response.content or "")
|
||||
messages = self.context.add_assistant_message(
|
||||
messages,
|
||||
persist_content,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
# 兜底提示:防止模型反复调工具导致“无终止回复”。
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
final_content = (
|
||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
# 将兜底回复也写入会话,避免刷新后看不到最终结论。
|
||||
messages = self.context.add_assistant_message(messages, final_content)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def run(self) -> None:
|
||||
"""启动常驻循环:持续消费入站消息并发布出站消息。"""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 用短超时轮询,便于 stop() 后快速退出循环。
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_inbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
# CLI 下若消息工具已代发,仍回一个空结束包通知“本轮结束”。
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {},
|
||||
))
|
||||
except Exception as e:
|
||||
# 单条消息失败不影响主循环存活。
|
||||
logger.error("Error processing message: {}", e)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=f"Sorry, I encountered an error: {str(e)}"
|
||||
))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""关闭 MCP 连接并释放退出栈。"""
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
# MCP SDK 在取消清理阶段可能抛出噪声异常,这里忽略即可。
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
|
||||
def stop(self) -> None:
|
||||
"""请求停止主循环。"""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
|
||||
"""获取会话级归档锁;不存在则创建。"""
|
||||
lock = self._consolidation_locks.get(session_key)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._consolidation_locks[session_key] = lock
|
||||
return lock
|
||||
|
||||
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
|
||||
"""在锁空闲时清理缓存,避免锁字典无限增长。"""
|
||||
if not lock.locked():
|
||||
self._consolidation_locks.pop(session_key, None)
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
execution_context: str | None = None,
|
||||
extra_tools: list[Tool] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""处理单条入站消息并返回出站消息(或 None)。"""
|
||||
# system 通道用于内部任务(如 cron/heartbeat),来源路由编码在 chat_id。
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"), session_key=key)
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
execution_context=execution_context,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# 内建斜杠命令:在进入模型前优先处理。
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
# `/new` 的语义是“开启新会话”,但在真正清空前要先做一次强制归档:
|
||||
# - 把尚未沉淀的消息写入 MEMORY/HISTORY;
|
||||
# - 若归档失败则直接返回,不执行清空,避免用户上下文丢失。
|
||||
|
||||
# 取会话级锁并标记 consolidating,防止与后台自动归档并发执行。
|
||||
# (同一会话同时归档可能导致重复写入或状态错乱)
|
||||
lock = self._get_consolidation_lock(session.key)
|
||||
self._consolidating.add(session.key)
|
||||
try:
|
||||
async with lock:
|
||||
# 只处理“未归档尾部”消息:
|
||||
# [0:last_consolidated] 视为已经落入长期记忆,
|
||||
# [last_consolidated:] 才是本次需要补归档的增量。
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if snapshot:
|
||||
# 用临时 Session 包装快照,再传给 consolidate:
|
||||
# 1) 不污染当前 live session 对象;
|
||||
# 2) 即便归档失败,也不会提前改动原会话结构。
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
# archive_all=True:对这个临时快照做“全量归档”,
|
||||
# 确保 /new 前的上下文尽可能完整地写入记忆文件。
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
# 归档过程任何异常都视为失败,保持原会话不动并给出明确提示。
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
finally:
|
||||
# 无论成功/失败都要撤销 in-progress 标记并清理空闲锁缓存,
|
||||
# 避免会话长期卡在 consolidating 状态。
|
||||
self._consolidating.discard(session.key)
|
||||
self._prune_consolidation_lock(session.key, lock)
|
||||
|
||||
# 走到这里说明归档已成功(或本就无增量可归档),才执行真正清空。
|
||||
session.clear()
|
||||
# clear 后立即落盘,保证重启后状态一致。
|
||||
self.sessions.save(session)
|
||||
# 使内存缓存失效,后续读取将基于磁盘中的“新空会话”重新构建。
|
||||
self.sessions.invalidate(session.key)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Boardware Genius commands:\n/new — Start a new conversation\n/help — Show available commands")
|
||||
|
||||
# 异步触发记忆归档:达到窗口阈值时在后台执行,不阻塞当前回复。
|
||||
unconsolidated = len(session.messages) - session.last_consolidated
|
||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||
self._consolidating.add(session.key)
|
||||
lock = self._get_consolidation_lock(session.key)
|
||||
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
finally:
|
||||
# 无论成功失败都要解注册状态,避免会话长期卡在 consolidating。
|
||||
self._consolidating.discard(session.key)
|
||||
self._prune_consolidation_lock(session.key, lock)
|
||||
_task = asyncio.current_task()
|
||||
if _task is not None:
|
||||
self._consolidation_tasks.discard(_task)
|
||||
|
||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||
self._consolidation_tasks.add(_task)
|
||||
|
||||
# 每轮处理前刷新工具上下文,并重置 message 工具的“本轮已发送”状态。
|
||||
self._set_tool_context(
|
||||
msg.channel,
|
||||
msg.chat_id,
|
||||
msg.metadata.get("message_id"),
|
||||
session_key=key,
|
||||
)
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
active_tools = self.tools
|
||||
if extra_tools:
|
||||
active_tools = self.tools.clone()
|
||||
for tool in extra_tools:
|
||||
active_tools.register(tool)
|
||||
|
||||
# 从会话中截取有限历史,避免上下文无限膨胀。
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
# 组装本轮发给模型的初始消息:
|
||||
# - history: 会话历史(已按窗口裁剪)
|
||||
# - current_message: 用户本轮输入
|
||||
# - media: 可选多模态附件(如图片)
|
||||
# - channel/chat_id: 当前会话路由信息(写入 system prompt 供工具决策)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
execution_context=execution_context,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
# `_bus_progress` 是“默认进度回调”:
|
||||
# - 当 _run_agent_loop 里出现中间文本/工具提示时被调用;
|
||||
# - 不走最终回复通道,而是作为“中间态事件”发到 outbound。
|
||||
#
|
||||
# 这样做的好处:
|
||||
# 1) CLI/渠道可以实时显示“正在做什么”,而不是一直静默等待;
|
||||
# 2) 进度消息与最终答复共用同一队列,但可通过 metadata 区分。
|
||||
meta = dict(msg.metadata or {})
|
||||
# `_progress=True`:标记这是进度事件,消费端可选择轻量渲染。
|
||||
meta["_progress"] = True
|
||||
# `_tool_hint=True`:标记这是工具调用提示(例如 web_search(...))。
|
||||
# 消费端可按配置独立开关(send_tool_hints)来显示/隐藏。
|
||||
meta["_tool_hint"] = tool_hint
|
||||
# 进度消息仍沿用原始 channel/chat_id,保证路由到当前会话。
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
# 执行核心 agent 迭代:
|
||||
# - 可能多轮“模型 -> 工具 -> 模型”
|
||||
# - on_progress 若外部未传,则默认走 `_bus_progress` 输出中间态
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
tool_registry=active_tools,
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
# 极少数情况下模型未给出最终文本(例如异常边界),这里兜底避免空回复。
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
# 日志只打印预览,避免超长内容污染日志输出。
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
# 把本轮新增消息(assistant/tool/final)写回会话并持久化到磁盘。
|
||||
# `1 + len(history)` 用于跳过本轮前已存在的 system+history 部分。
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
# 去重保护:
|
||||
# 若本轮 agent 已通过 message 工具主动发过消息,
|
||||
# 再返回 OutboundMessage 会导致渠道侧“同内容重复发送”。
|
||||
# 因此返回 None,交给上层按“已发过”路径结束本轮。
|
||||
return None
|
||||
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 500
|
||||
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""保存本轮新增消息到会话,并截断过长工具输出。"""
|
||||
from datetime import datetime
|
||||
for m in messages[skip:]:
|
||||
# 不持久化 reasoning_content,避免会话文件冗长且混入思考文本。
|
||||
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
|
||||
if entry.get("role") == "tool" and isinstance(entry.get("content"), str):
|
||||
content = entry["content"]
|
||||
if len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
# 大工具结果只保留前缀,兼顾可读性与存储体积。
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||
"""调用 MemoryStore 做记忆归档;成功返回 True。"""
|
||||
return await MemoryStore(self.workspace).consolidate(
|
||||
session, self.provider, self.model,
|
||||
archive_all=archive_all, memory_window=self.memory_window,
|
||||
)
|
||||
|
||||
async def process_system_announcement(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
origin_channel: str,
|
||||
origin_chat_id: str,
|
||||
sender_id: str = "delegation",
|
||||
) -> str:
|
||||
"""在无常驻 run() 的场景下,本地处理一条 system 公告。"""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(
|
||||
channel="system",
|
||||
sender_id=sender_id,
|
||||
chat_id=f"{origin_channel}:{origin_chat_id}",
|
||||
content=content,
|
||||
)
|
||||
response = await self._process_message(msg)
|
||||
return response.content if response else ""
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
process_event_callback: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||
execution_context: str | None = None,
|
||||
extra_tools: list[Tool] | None = None,
|
||||
) -> str:
|
||||
"""直接处理一条消息(用于 CLI 单轮或 cron 触发)。"""
|
||||
# 直连模式不依赖 run() 主循环,但仍需确保 MCP 可用。
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
# process_event_sink 只在当前调用链内生效,因此不会污染其他并发请求。
|
||||
with process_event_sink(process_event_callback):
|
||||
response = await self._process_message(
|
||||
msg,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
# execution_context / extra_tools 主要服务于 cron 和其他系统触发场景。
|
||||
execution_context=execution_context,
|
||||
extra_tools=extra_tools,
|
||||
)
|
||||
return response.content if response else ""
|
||||
@ -1,582 +0,0 @@
|
||||
"""Marketplace manager for Boardware Genius — discover, install, and manage plugin marketplaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketplaceEntry:
|
||||
"""A registered marketplace source."""
|
||||
|
||||
name: str
|
||||
source: str
|
||||
type: str # "local" or "git"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketplacePluginInfo:
|
||||
"""A plugin available in a marketplace."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
source_path: str # Relative path inside the marketplace (e.g. "./claude-plugins/data-toolkit")
|
||||
marketplace_name: str
|
||||
installed: bool
|
||||
|
||||
|
||||
class MarketplaceManager:
|
||||
"""
|
||||
Manages plugin marketplaces: register/remove marketplace sources, discover
|
||||
available plugins, and install/uninstall them into ``~/.nanobot/plugins/``.
|
||||
|
||||
Marketplace sources can be local directories or git repositories. Each
|
||||
marketplace root must contain ``.claude-plugin/marketplace.json`` with the
|
||||
manifest listing available plugins.
|
||||
|
||||
Config is persisted in ``~/.nanobot/marketplaces.json``.
|
||||
Git repos are cached in ``~/.nanobot/marketplace-cache/<name>/``.
|
||||
Installed plugins land in ``~/.nanobot/plugins/<plugin-name>/``.
|
||||
"""
|
||||
|
||||
CONFIG_PATH = Path.home() / ".nanobot" / "marketplaces.json"
|
||||
CACHE_DIR = Path.home() / ".nanobot" / "marketplace-cache"
|
||||
PLUGINS_DIR = Path.home() / ".nanobot" / "plugins"
|
||||
|
||||
GIT_TIMEOUT = 60 # seconds
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: Path | None = None,
|
||||
cache_dir: Path | None = None,
|
||||
plugins_dir: Path | None = None,
|
||||
):
|
||||
self.config_path = config_path or self.CONFIG_PATH
|
||||
self.cache_dir = cache_dir or self.CACHE_DIR
|
||||
self.plugins_dir = plugins_dir or self.PLUGINS_DIR
|
||||
|
||||
# ------------------------------------------------------------------ public
|
||||
|
||||
def list_marketplaces(self) -> list[MarketplaceEntry]:
|
||||
"""Return all registered marketplaces."""
|
||||
return self._load_config()
|
||||
|
||||
def add_marketplace(self, source: str) -> MarketplaceEntry:
|
||||
"""
|
||||
Register a new marketplace from a local path or git URL.
|
||||
|
||||
For git sources the repo is cloned (``--depth=1``) into the cache
|
||||
directory and the manifest is read to determine the marketplace name.
|
||||
For local sources the path must exist and contain a valid manifest.
|
||||
|
||||
Returns the created ``MarketplaceEntry``.
|
||||
|
||||
Raises ``ValueError`` on invalid source or duplicate name.
|
||||
"""
|
||||
source_type = self._detect_type(source)
|
||||
|
||||
if source_type == "git":
|
||||
entry = self._add_git_marketplace(source)
|
||||
else:
|
||||
entry = self._add_local_marketplace(source)
|
||||
|
||||
# Persist — update existing entry if one with the same name exists
|
||||
entries = self._load_config()
|
||||
replaced = False
|
||||
for i, existing in enumerate(entries):
|
||||
if existing.name == entry.name:
|
||||
logger.info(
|
||||
"Updating existing marketplace '{}' (old source: {} → new source: {})",
|
||||
entry.name,
|
||||
existing.source,
|
||||
entry.source,
|
||||
)
|
||||
entries[i] = entry
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
entries.append(entry)
|
||||
self._save_config(entries)
|
||||
logger.info("Registered marketplace '{}' from {}", entry.name, entry.source)
|
||||
return entry
|
||||
|
||||
def remove_marketplace(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a marketplace by name.
|
||||
|
||||
If the marketplace was cloned from git, the cached clone is also deleted.
|
||||
|
||||
Raises ``ValueError`` if the marketplace is not found.
|
||||
"""
|
||||
entries = self._load_config()
|
||||
entry = self._find_entry(entries, name)
|
||||
|
||||
# Clean up git cache if applicable
|
||||
cache_path = self.cache_dir / name
|
||||
if cache_path.exists():
|
||||
shutil.rmtree(cache_path)
|
||||
logger.debug("Removed cached clone at {}", cache_path)
|
||||
|
||||
entries = [e for e in entries if e.name != name]
|
||||
self._save_config(entries)
|
||||
logger.info("Removed marketplace '{}'", name)
|
||||
|
||||
def list_available_plugins(
|
||||
self, marketplace_name: str
|
||||
) -> list[MarketplacePluginInfo]:
|
||||
"""
|
||||
List all plugins offered by a registered marketplace.
|
||||
|
||||
For git marketplaces the cached clone is updated (``git pull --ff-only``)
|
||||
before reading the manifest.
|
||||
|
||||
Raises ``ValueError`` if the marketplace is not found or the manifest
|
||||
is missing/invalid.
|
||||
"""
|
||||
entries = self._load_config()
|
||||
entry = self._find_entry(entries, marketplace_name)
|
||||
root = self._resolve_root(entry)
|
||||
manifest = self._read_manifest(root, entry.name)
|
||||
|
||||
installed_names = self._installed_plugin_names()
|
||||
|
||||
plugins: list[MarketplacePluginInfo] = []
|
||||
for p in manifest.get("plugins", []):
|
||||
pname = p.get("name", "")
|
||||
if not pname:
|
||||
continue
|
||||
# Skip plugins whose names would be unsafe as directory names
|
||||
try:
|
||||
self._validate_name(pname, "plugin name")
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Skipping plugin with unsafe name '{}' in marketplace '{}'",
|
||||
pname,
|
||||
marketplace_name,
|
||||
)
|
||||
continue
|
||||
plugins.append(
|
||||
MarketplacePluginInfo(
|
||||
name=pname,
|
||||
description=p.get("description", ""),
|
||||
source_path=p.get("source", ""),
|
||||
marketplace_name=entry.name,
|
||||
installed=pname in installed_names,
|
||||
)
|
||||
)
|
||||
return plugins
|
||||
|
||||
def install_plugin(self, marketplace_name: str, plugin_name: str) -> Path:
|
||||
"""
|
||||
Install a plugin from a marketplace into ``~/.nanobot/plugins/``.
|
||||
|
||||
The plugin directory is copied (not symlinked) so it works even if the
|
||||
marketplace source is later removed.
|
||||
|
||||
Returns the ``Path`` to the installed plugin directory.
|
||||
|
||||
Raises ``ValueError`` if the marketplace or plugin is not found, or if
|
||||
the plugin source directory does not exist.
|
||||
"""
|
||||
self._validate_name(plugin_name, "plugin name")
|
||||
|
||||
entries = self._load_config()
|
||||
entry = self._find_entry(entries, marketplace_name)
|
||||
root = self._resolve_root(entry)
|
||||
manifest = self._read_manifest(root, entry.name)
|
||||
|
||||
plugin_meta = self._find_plugin_in_manifest(manifest, plugin_name, entry.name)
|
||||
source_rel = plugin_meta.get("source", "")
|
||||
source_dir = (root / source_rel).resolve()
|
||||
root_resolved = root.resolve()
|
||||
|
||||
# Guard against path traversal — source_dir must be inside the marketplace root
|
||||
if not str(source_dir).startswith(str(root_resolved)):
|
||||
raise ValueError(
|
||||
f"Plugin source '{source_rel}' resolves outside the marketplace "
|
||||
f"root ({root_resolved}). This looks like a path traversal attempt."
|
||||
)
|
||||
|
||||
if not source_dir.is_dir():
|
||||
raise ValueError(
|
||||
f"Plugin source directory does not exist: {source_dir}"
|
||||
)
|
||||
|
||||
dest = self.plugins_dir / plugin_name
|
||||
if dest.exists():
|
||||
logger.debug("Removing existing plugin dir at {}", dest)
|
||||
shutil.rmtree(dest)
|
||||
|
||||
self.plugins_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copytree(source_dir, dest)
|
||||
logger.info(
|
||||
"Installed plugin '{}' from marketplace '{}' → {}",
|
||||
plugin_name,
|
||||
entry.name,
|
||||
dest,
|
||||
)
|
||||
return dest
|
||||
|
||||
def update_marketplace(self, name: str) -> MarketplaceEntry:
|
||||
"""
|
||||
Update a marketplace's cached data.
|
||||
|
||||
For git marketplaces: clones if cache is missing, pulls if it exists.
|
||||
For local marketplaces: validates the path still exists.
|
||||
|
||||
Returns the ``MarketplaceEntry``.
|
||||
|
||||
Raises ``ValueError`` if the marketplace is not registered or the
|
||||
update fails.
|
||||
"""
|
||||
entries = self._load_config()
|
||||
entry = self._find_entry(entries, name)
|
||||
|
||||
if entry.type == "git":
|
||||
cache_path = self.cache_dir / name
|
||||
if not cache_path.exists():
|
||||
# Cache missing (e.g. fresh Docker container) — clone
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "clone", "--depth=1", entry.source, str(cache_path)],
|
||||
capture_output=True,
|
||||
timeout=self.GIT_TIMEOUT,
|
||||
check=True,
|
||||
)
|
||||
logger.info(
|
||||
"Cloned marketplace '{}' from {}", name, entry.source
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
stderr = (
|
||||
e.stderr.decode(errors="replace").strip()
|
||||
if e.stderr
|
||||
else ""
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to clone marketplace '{name}': {stderr}"
|
||||
) from e
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise ValueError(
|
||||
f"Git clone timed out after {self.GIT_TIMEOUT}s "
|
||||
f"for marketplace '{name}'"
|
||||
) from e
|
||||
else:
|
||||
# Cache exists — pull latest
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "pull", "--ff-only"],
|
||||
cwd=cache_path,
|
||||
capture_output=True,
|
||||
timeout=self.GIT_TIMEOUT,
|
||||
check=True,
|
||||
)
|
||||
logger.info(
|
||||
"Updated marketplace '{}' from {}", name, entry.source
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
stderr = (
|
||||
e.stderr.decode(errors="replace").strip()
|
||||
if e.stderr
|
||||
else ""
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to update marketplace '{name}': {stderr}"
|
||||
) from e
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise ValueError(
|
||||
f"Git pull timed out after {self.GIT_TIMEOUT}s "
|
||||
f"for marketplace '{name}'"
|
||||
) from e
|
||||
else:
|
||||
# Local marketplace — just verify path still exists
|
||||
path = Path(entry.source).expanduser().resolve()
|
||||
if not path.is_dir():
|
||||
raise ValueError(
|
||||
f"Local marketplace directory no longer exists: {path}"
|
||||
)
|
||||
logger.debug("Local marketplace '{}' verified at {}", name, path)
|
||||
|
||||
return entry
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str) -> None:
|
||||
"""
|
||||
Remove an installed plugin from ``~/.nanobot/plugins/``.
|
||||
|
||||
Raises ``ValueError`` if the plugin directory does not exist.
|
||||
"""
|
||||
dest = self.plugins_dir / plugin_name
|
||||
if not dest.exists():
|
||||
raise ValueError(
|
||||
f"Plugin '{plugin_name}' is not installed (expected at {dest})"
|
||||
)
|
||||
shutil.rmtree(dest)
|
||||
logger.info("Uninstalled plugin '{}'", plugin_name)
|
||||
|
||||
# ------------------------------------------------------------------ config
|
||||
|
||||
def _load_config(self) -> list[MarketplaceEntry]:
|
||||
"""Load the marketplaces config file. Returns empty list on missing/corrupt file."""
|
||||
if not self.config_path.exists():
|
||||
return []
|
||||
try:
|
||||
raw = json.loads(self.config_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw, list):
|
||||
logger.warning(
|
||||
"marketplaces.json is not a list, resetting to empty"
|
||||
)
|
||||
return []
|
||||
return [
|
||||
MarketplaceEntry(
|
||||
name=item["name"],
|
||||
source=item["source"],
|
||||
type=item["type"],
|
||||
)
|
||||
for item in raw
|
||||
if isinstance(item, dict) and "name" in item and "source" in item and "type" in item
|
||||
]
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to read marketplaces.json: {}", e)
|
||||
return []
|
||||
|
||||
def _save_config(self, entries: list[MarketplaceEntry]) -> None:
|
||||
"""Persist the marketplaces list to disk."""
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = [asdict(e) for e in entries]
|
||||
self.config_path.write_text(
|
||||
json.dumps(data, indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ helpers
|
||||
|
||||
@staticmethod
|
||||
def _validate_name(name: str, label: str = "name") -> None:
|
||||
"""Reject names that could cause path traversal when used in filesystem paths.
|
||||
|
||||
Raises ``ValueError`` if *name* contains ``/``, ``\\``, or is ``.`` / `..``.
|
||||
"""
|
||||
if "/" in name or "\\" in name or name in (".", ".."):
|
||||
raise ValueError(
|
||||
f"Invalid {label} '{name}': must not contain path separators "
|
||||
f"or be '.' / '..'"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _detect_type(source: str) -> str:
|
||||
"""Determine whether a source string is a git URL or a local path."""
|
||||
if (
|
||||
source.startswith("http://")
|
||||
or source.startswith("https://")
|
||||
or source.startswith("ssh://")
|
||||
or source.startswith("git://")
|
||||
or source.startswith("git@")
|
||||
or source.endswith(".git")
|
||||
):
|
||||
return "git"
|
||||
return "local"
|
||||
|
||||
def _find_entry(
|
||||
self, entries: list[MarketplaceEntry], name: str
|
||||
) -> MarketplaceEntry:
|
||||
"""Lookup a marketplace entry by name or raise ValueError."""
|
||||
for entry in entries:
|
||||
if entry.name == name:
|
||||
return entry
|
||||
raise ValueError(
|
||||
f"Marketplace '{name}' is not registered. "
|
||||
f"Use add_marketplace() first."
|
||||
)
|
||||
|
||||
def _resolve_root(self, entry: MarketplaceEntry) -> Path:
|
||||
"""
|
||||
Return the filesystem root of a marketplace.
|
||||
|
||||
For local marketplaces this is the source path directly.
|
||||
For git marketplaces this is the cached clone, updated with
|
||||
``git pull --ff-only`` before returning.
|
||||
"""
|
||||
if entry.type == "git":
|
||||
cache_path = self.cache_dir / entry.name
|
||||
if not cache_path.exists():
|
||||
raise ValueError(
|
||||
f"Git cache for marketplace '{entry.name}' not found at "
|
||||
f"{cache_path}. Try removing and re-adding the marketplace."
|
||||
)
|
||||
# Update the cached clone
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "pull", "--ff-only"],
|
||||
cwd=cache_path,
|
||||
capture_output=True,
|
||||
timeout=self.GIT_TIMEOUT,
|
||||
check=True,
|
||||
)
|
||||
logger.debug("Updated git cache for '{}'", entry.name)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(
|
||||
"git pull failed for '{}': {}",
|
||||
entry.name,
|
||||
e.stderr.decode(errors="replace").strip() if e.stderr else str(e),
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("git pull timed out for '{}'", entry.name)
|
||||
return cache_path
|
||||
else:
|
||||
path = Path(entry.source).expanduser().resolve()
|
||||
if not path.is_dir():
|
||||
raise ValueError(
|
||||
f"Local marketplace directory does not exist: {path}"
|
||||
)
|
||||
return path
|
||||
|
||||
def _read_manifest(self, root: Path, marketplace_name: str) -> dict:
|
||||
"""Read marketplace manifest, or auto-discover plugins if no manifest exists.
|
||||
|
||||
Looks for ``.claude-plugin/marketplace.json`` first. If that file is
|
||||
missing, falls back to scanning ``claude-plugins/`` for subdirectories
|
||||
that contain a ``plugin.json`` or ``.claude-plugin/plugin.json``.
|
||||
"""
|
||||
manifest_path = root / ".claude-plugin" / "marketplace.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
data = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to parse marketplace manifest at {manifest_path}: {e}"
|
||||
) from e
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(
|
||||
f"Marketplace manifest at {manifest_path} must be a JSON object"
|
||||
)
|
||||
if "plugins" not in data or not isinstance(data["plugins"], list):
|
||||
raise ValueError(
|
||||
f"Marketplace manifest at {manifest_path} missing 'plugins' array"
|
||||
)
|
||||
return data
|
||||
|
||||
# Fallback: auto-discover plugins under claude-plugins/
|
||||
return self._auto_discover_plugins(root, marketplace_name)
|
||||
|
||||
def _auto_discover_plugins(self, root: Path, marketplace_name: str) -> dict:
|
||||
"""Scan ``claude-plugins/`` for plugin directories and build a manifest."""
|
||||
plugins_dir = root / "claude-plugins"
|
||||
if not plugins_dir.is_dir():
|
||||
raise ValueError(
|
||||
f"Marketplace at {root} has no .claude-plugin/marketplace.json "
|
||||
f"and no claude-plugins/ directory to scan."
|
||||
)
|
||||
|
||||
plugins: list[dict] = []
|
||||
for plugin_dir in sorted(plugins_dir.iterdir()):
|
||||
if not plugin_dir.is_dir():
|
||||
continue
|
||||
# Read plugin metadata
|
||||
name = plugin_dir.name
|
||||
description = ""
|
||||
for candidate in (plugin_dir / "plugin.json", plugin_dir / ".claude-plugin" / "plugin.json"):
|
||||
if candidate.exists():
|
||||
try:
|
||||
meta = json.loads(candidate.read_text(encoding="utf-8"))
|
||||
name = meta.get("name", name)
|
||||
description = meta.get("description", "")
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
break
|
||||
plugins.append({
|
||||
"name": name,
|
||||
"source": f"./claude-plugins/{plugin_dir.name}",
|
||||
"description": description,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Auto-discovered {} plugins in marketplace '{}' (no manifest file)",
|
||||
len(plugins), marketplace_name,
|
||||
)
|
||||
return {"name": marketplace_name, "plugins": plugins}
|
||||
|
||||
@staticmethod
|
||||
def _find_plugin_in_manifest(
|
||||
manifest: dict, plugin_name: str, marketplace_name: str
|
||||
) -> dict:
|
||||
"""Find a plugin entry by name in a marketplace manifest."""
|
||||
for p in manifest.get("plugins", []):
|
||||
if p.get("name") == plugin_name:
|
||||
return p
|
||||
raise ValueError(
|
||||
f"Plugin '{plugin_name}' not found in marketplace '{marketplace_name}'. "
|
||||
f"Available: {[p.get('name') for p in manifest.get('plugins', [])]}"
|
||||
)
|
||||
|
||||
def _installed_plugin_names(self) -> set[str]:
|
||||
"""Return the set of currently installed plugin directory names."""
|
||||
if not self.plugins_dir.exists():
|
||||
return set()
|
||||
return {d.name for d in self.plugins_dir.iterdir() if d.is_dir()}
|
||||
|
||||
# ------------------------------------------------------------------ git
|
||||
|
||||
def _add_git_marketplace(self, source: str) -> MarketplaceEntry:
|
||||
"""Clone a git URL, read the manifest to get the name, move to cache."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp) / "repo"
|
||||
logger.debug("Cloning {} into temp dir", source)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "clone", "--depth=1", source, str(tmp_path)],
|
||||
capture_output=True,
|
||||
timeout=self.GIT_TIMEOUT,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
stderr = e.stderr.decode(errors="replace").strip() if e.stderr else ""
|
||||
raise ValueError(
|
||||
f"Failed to clone git repository '{source}': {stderr}"
|
||||
) from e
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise ValueError(
|
||||
f"Git clone timed out after {self.GIT_TIMEOUT}s for '{source}'"
|
||||
) from e
|
||||
|
||||
# Derive a fallback name from the git URL (e.g. "my-marketplace" from ".../my-marketplace.git")
|
||||
fallback_name = source.rstrip("/").rsplit("/", 1)[-1].removesuffix(".git") or "unknown"
|
||||
manifest = self._read_manifest(tmp_path, fallback_name)
|
||||
name = manifest.get("name")
|
||||
if not name or not isinstance(name, str):
|
||||
name = fallback_name
|
||||
self._validate_name(name, "marketplace name")
|
||||
|
||||
# Move to permanent cache location
|
||||
cache_path = self.cache_dir / name
|
||||
if cache_path.exists():
|
||||
shutil.rmtree(cache_path)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(tmp_path), str(cache_path))
|
||||
logger.debug("Cached git marketplace '{}' at {}", name, cache_path)
|
||||
|
||||
return MarketplaceEntry(name=name, source=source, type="git")
|
||||
|
||||
def _add_local_marketplace(self, source: str) -> MarketplaceEntry:
|
||||
"""Register a local directory as a marketplace source."""
|
||||
path = Path(source).expanduser().resolve()
|
||||
if not path.is_dir():
|
||||
raise ValueError(
|
||||
f"Local marketplace path does not exist or is not a directory: {path}"
|
||||
)
|
||||
|
||||
fallback_name = path.name
|
||||
manifest = self._read_manifest(path, fallback_name)
|
||||
name = manifest.get("name")
|
||||
if not name or not isinstance(name, str):
|
||||
name = fallback_name
|
||||
self._validate_name(name, "marketplace name")
|
||||
|
||||
return MarketplaceEntry(name=name, source=str(path), type="local")
|
||||
@ -1,143 +0,0 @@
|
||||
"""Memory system for persistent agent memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
_SAVE_MEMORY_TOOL = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "save_memory",
|
||||
"description": "Save the memory consolidation result to persistent storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"history_entry": {
|
||||
"type": "string",
|
||||
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||
},
|
||||
"memory_update": {
|
||||
"type": "string",
|
||||
"description": "Full updated long-term memory as markdown. Include all existing "
|
||||
"facts plus new ones. Return unchanged if nothing new.",
|
||||
},
|
||||
},
|
||||
"required": ["history_entry", "memory_update"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.memory_dir = ensure_dir(workspace / "memory")
|
||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def read_long_term(self) -> str:
|
||||
if self.memory_file.exists():
|
||||
return self.memory_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def write_long_term(self, content: str) -> None:
|
||||
self.memory_file.write_text(content, encoding="utf-8")
|
||||
|
||||
def append_history(self, entry: str) -> None:
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry.rstrip() + "\n\n")
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
long_term = self.read_long_term()
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
async def consolidate(
|
||||
self,
|
||||
session: Session,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
*,
|
||||
archive_all: bool = False,
|
||||
memory_window: int = 50,
|
||||
) -> bool:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||
|
||||
Returns True on success (including no-op), False on failure.
|
||||
"""
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
||||
else:
|
||||
keep_count = memory_window // 2
|
||||
if len(session.messages) <= keep_count:
|
||||
return True
|
||||
if len(session.messages) - session.last_consolidated <= 0:
|
||||
return True
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return True
|
||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
||||
|
||||
lines = []
|
||||
for m in old_messages:
|
||||
if not m.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
||||
|
||||
current_memory = self.read_long_term()
|
||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||
|
||||
## Current Long-term Memory
|
||||
{current_memory or "(empty)"}
|
||||
|
||||
## Conversation to Process
|
||||
{chr(10).join(lines)}"""
|
||||
|
||||
try:
|
||||
response = await provider.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=_SAVE_MEMORY_TOOL,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||
return False
|
||||
|
||||
args = response.tool_calls[0].arguments
|
||||
if entry := args.get("history_entry"):
|
||||
if not isinstance(entry, str):
|
||||
entry = json.dumps(entry, ensure_ascii=False)
|
||||
self.append_history(entry)
|
||||
if update := args.get("memory_update"):
|
||||
if not isinstance(update, str):
|
||||
update = json.dumps(update, ensure_ascii=False)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Memory consolidation failed")
|
||||
return False
|
||||
@ -1,291 +0,0 @@
|
||||
"""Plugin system for Boardware Genius - load agents, commands, and skills from plugin directories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginAgent:
|
||||
name: str
|
||||
description: str
|
||||
model: str | None
|
||||
system_prompt: str
|
||||
plugin_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginCommand:
|
||||
name: str
|
||||
description: str
|
||||
argument_hint: str | None
|
||||
content: str # Raw body with $ARGUMENTS placeholder
|
||||
plugin_name: str
|
||||
|
||||
def expand(self, arguments: str) -> str:
|
||||
return self.content.replace("$ARGUMENTS", arguments.strip())
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plugin:
|
||||
name: str
|
||||
description: str
|
||||
source: str # "global" or "workspace"
|
||||
agents: dict[str, PluginAgent] = field(default_factory=dict)
|
||||
commands: dict[str, PluginCommand] = field(default_factory=dict)
|
||||
skill_dirs: list[Path] = field(default_factory=list)
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
"""
|
||||
Loads plugins from global and workspace plugin directories.
|
||||
|
||||
Search paths (workspace takes priority over global):
|
||||
- Global: ~/.nanobot/plugins/<plugin-name>/
|
||||
- Workspace: <workspace>/plugins/<plugin-name>/
|
||||
|
||||
Each plugin directory may contain:
|
||||
- plugin.json — manifest with name/description
|
||||
- agents/<name>.md — agent definitions (frontmatter + system prompt)
|
||||
- commands/<name>.md — slash command definitions (frontmatter + content)
|
||||
- skills/<name>/SKILL.md — skill files exposed to SkillsLoader
|
||||
"""
|
||||
|
||||
GLOBAL_DIR = Path.home() / ".nanobot" / "plugins"
|
||||
|
||||
def __init__(self, workspace: Path, global_dir: Path | None = None):
|
||||
self.workspace = workspace
|
||||
self.global_dir = global_dir or self.GLOBAL_DIR
|
||||
self.workspace_dir = workspace / "plugins"
|
||||
self._plugins: dict[str, Plugin] | None = None
|
||||
|
||||
@property
|
||||
def plugins(self) -> dict[str, Plugin]:
|
||||
if self._plugins is None:
|
||||
self._plugins = self._load_all()
|
||||
return self._plugins
|
||||
|
||||
def find_command(self, cmd_name: str) -> PluginCommand | None:
|
||||
"""Find a command by name. Workspace plugins take priority over global."""
|
||||
for plugin in self.plugins.values():
|
||||
if plugin.source == "workspace" and cmd_name in plugin.commands:
|
||||
return plugin.commands[cmd_name]
|
||||
for plugin in self.plugins.values():
|
||||
if plugin.source == "global" and cmd_name in plugin.commands:
|
||||
return plugin.commands[cmd_name]
|
||||
return None
|
||||
|
||||
def find_agent(self, agent_name: str) -> PluginAgent | None:
|
||||
"""Find an agent by name. Workspace plugins take priority over global."""
|
||||
for plugin in self.plugins.values():
|
||||
if plugin.source == "workspace" and agent_name in plugin.agents:
|
||||
return plugin.agents[agent_name]
|
||||
for plugin in self.plugins.values():
|
||||
if plugin.source == "global" and agent_name in plugin.agents:
|
||||
return plugin.agents[agent_name]
|
||||
return None
|
||||
|
||||
def get_skill_dirs(self) -> list[Path]:
|
||||
"""Return all skill root directories contributed by plugins."""
|
||||
dirs = []
|
||||
for plugin in self.plugins.values():
|
||||
dirs.extend(plugin.skill_dirs)
|
||||
return dirs
|
||||
|
||||
def build_agents_summary(self) -> str:
|
||||
"""Build an XML summary of all plugin agents for the system prompt."""
|
||||
agents = []
|
||||
for plugin in self.plugins.values():
|
||||
agents.extend(plugin.agents.values())
|
||||
if not agents:
|
||||
return ""
|
||||
|
||||
def esc(s: str) -> str:
|
||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
lines = ["<agents>"]
|
||||
for agent in agents:
|
||||
lines.append(" <agent>")
|
||||
lines.append(f" <name>{esc(agent.name)}</name>")
|
||||
lines.append(f" <plugin>{esc(agent.plugin_name)}</plugin>")
|
||||
lines.append(f" <description>{esc(agent.description)}</description>")
|
||||
if agent.model:
|
||||
lines.append(f" <model>{esc(agent.model)}</model>")
|
||||
lines.append(" </agent>")
|
||||
lines.append("</agents>")
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_commands_summary(self) -> str:
|
||||
"""Build an XML summary of all plugin commands for the system prompt."""
|
||||
commands = []
|
||||
for plugin in self.plugins.values():
|
||||
commands.extend(plugin.commands.values())
|
||||
if not commands:
|
||||
return ""
|
||||
|
||||
def esc(s: str) -> str:
|
||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
lines = ["<commands>"]
|
||||
for cmd in commands:
|
||||
lines.append(" <command>")
|
||||
lines.append(f" <name>/{esc(cmd.name)}</name>")
|
||||
lines.append(f" <plugin>{esc(cmd.plugin_name)}</plugin>")
|
||||
lines.append(f" <description>{esc(cmd.description)}</description>")
|
||||
if cmd.argument_hint:
|
||||
lines.append(f" <argument-hint>{esc(cmd.argument_hint)}</argument-hint>")
|
||||
lines.append(" </command>")
|
||||
lines.append("</commands>")
|
||||
return "\n".join(lines)
|
||||
|
||||
# ------------------------------------------------------------------ private
|
||||
|
||||
def _load_all(self) -> dict[str, Plugin]:
|
||||
"""Load all plugins from global then workspace (workspace wins)."""
|
||||
plugins: dict[str, Plugin] = {}
|
||||
|
||||
if self.global_dir.exists():
|
||||
for plugin_dir in sorted(self.global_dir.iterdir()):
|
||||
if plugin_dir.is_dir():
|
||||
plugin = self._load_plugin(plugin_dir, "global")
|
||||
if plugin:
|
||||
plugins[plugin.name] = plugin
|
||||
logger.debug("Loaded global plugin: {}", plugin.name)
|
||||
|
||||
if self.workspace_dir.exists():
|
||||
for plugin_dir in sorted(self.workspace_dir.iterdir()):
|
||||
if plugin_dir.is_dir():
|
||||
plugin = self._load_plugin(plugin_dir, "workspace")
|
||||
if plugin:
|
||||
plugins[plugin.name] = plugin # override global
|
||||
logger.debug("Loaded workspace plugin: {}", plugin.name)
|
||||
|
||||
return plugins
|
||||
|
||||
def _load_plugin(self, plugin_dir: Path, source: str) -> Plugin | None:
|
||||
"""Load a single plugin from a directory."""
|
||||
try:
|
||||
name = plugin_dir.name
|
||||
description = ""
|
||||
|
||||
# Look for plugin.json at root, then fall back to .claude-plugin/plugin.json
|
||||
# so that Claude Code plugin repos work without copying files.
|
||||
manifest_file = plugin_dir / "plugin.json"
|
||||
if not manifest_file.exists():
|
||||
manifest_file = plugin_dir / ".claude-plugin" / "plugin.json"
|
||||
if manifest_file.exists():
|
||||
try:
|
||||
manifest = json.loads(manifest_file.read_text(encoding="utf-8"))
|
||||
name = manifest.get("name", name)
|
||||
description = manifest.get("description", "")
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to parse plugin.json in {}: {}", plugin_dir, e)
|
||||
|
||||
agents_dir = plugin_dir / "agents"
|
||||
agents = self._load_agents(agents_dir, name) if agents_dir.exists() else {}
|
||||
|
||||
commands_dir = plugin_dir / "commands"
|
||||
commands = self._load_commands(commands_dir, name) if commands_dir.exists() else {}
|
||||
|
||||
skills_dir = plugin_dir / "skills"
|
||||
skill_dirs = [skills_dir] if skills_dir.exists() else []
|
||||
|
||||
return Plugin(
|
||||
name=name,
|
||||
description=description,
|
||||
source=source,
|
||||
agents=agents,
|
||||
commands=commands,
|
||||
skill_dirs=skill_dirs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load plugin from {}: {}", plugin_dir, e)
|
||||
return None
|
||||
|
||||
def _load_agents(self, agents_dir: Path, plugin_name: str) -> dict[str, PluginAgent]:
|
||||
"""Load agent .md files from a directory."""
|
||||
agents: dict[str, PluginAgent] = {}
|
||||
for md_file in sorted(agents_dir.glob("*.md")):
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
meta, body = self._parse_frontmatter(content)
|
||||
name = meta.get("name", md_file.stem)
|
||||
description = meta.get("description", "")
|
||||
model = meta.get("model") or None
|
||||
agents[name] = PluginAgent(
|
||||
name=name,
|
||||
description=description,
|
||||
model=model,
|
||||
system_prompt=body,
|
||||
plugin_name=plugin_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load agent {}: {}", md_file, e)
|
||||
return agents
|
||||
|
||||
def _load_commands(self, commands_dir: Path, plugin_name: str) -> dict[str, PluginCommand]:
|
||||
"""Load command .md files from a directory."""
|
||||
commands: dict[str, PluginCommand] = {}
|
||||
for md_file in sorted(commands_dir.glob("*.md")):
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
meta, body = self._parse_frontmatter(content)
|
||||
name = md_file.stem
|
||||
description = meta.get("description", "")
|
||||
argument_hint = meta.get("argument-hint") or None
|
||||
commands[name] = PluginCommand(
|
||||
name=name,
|
||||
description=description,
|
||||
argument_hint=argument_hint,
|
||||
content=body,
|
||||
plugin_name=plugin_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load command {}: {}", md_file, e)
|
||||
return commands
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> tuple[dict[str, str], str]:
|
||||
"""
|
||||
Parse YAML frontmatter delimited by ``---`` lines.
|
||||
|
||||
Returns (meta_dict, body). Supports simple ``key: value`` pairs and
|
||||
block scalars (``key: |``). Does not require PyYAML.
|
||||
"""
|
||||
if not content.startswith("---"):
|
||||
return {}, content
|
||||
|
||||
match = re.match(r"^---\n(.*?)\n---\n?", content, re.DOTALL)
|
||||
if not match:
|
||||
return {}, content
|
||||
|
||||
raw = match.group(1)
|
||||
body = content[match.end():].strip()
|
||||
|
||||
meta: dict[str, str] = {}
|
||||
lines = raw.split("\n")
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
if ":" in line and not line.startswith((" ", "\t")):
|
||||
key, _, value = line.partition(":")
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if value == "|":
|
||||
# Block scalar: collect following indented lines
|
||||
block_lines: list[str] = []
|
||||
i += 1
|
||||
while i < len(lines) and (lines[i].startswith(" ") or lines[i] == ""):
|
||||
block_lines.append(lines[i][2:] if lines[i].startswith(" ") else "")
|
||||
i += 1
|
||||
meta[key] = "\n".join(block_lines).strip()
|
||||
continue
|
||||
else:
|
||||
meta[key] = value.strip("\"'")
|
||||
i += 1
|
||||
|
||||
return meta, body
|
||||
@ -1,84 +0,0 @@
|
||||
"""结构化过程事件辅助工具。
|
||||
|
||||
这个模块的作用是把“运行中的中间状态”从底层执行逻辑安全地带到上层 UI:
|
||||
1. 用 `ContextVar` 记录当前异步上下文是否挂了事件 sink;
|
||||
2. 用单独的 run_id 上下文把父子流程串起来;
|
||||
3. 让委派、MCP、A2A 等模块只管发事件,不需要知道 WebSocket/SSE 细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
ProcessEvent = dict[str, Any]
|
||||
ProcessEventSink = Callable[[ProcessEvent], Awaitable[None]]
|
||||
|
||||
# `_sink_var` 保存“当前异步上下文的事件接收器”。
|
||||
# 这样可以避免把回调一层层显式往下传,同时又不会污染并发请求之间的上下文。
|
||||
_sink_var: ContextVar[ProcessEventSink | None] = ContextVar("process_event_sink", default=None)
|
||||
# `_run_id_var` 保存“当前流程的父 run_id”。
|
||||
# 子流程发事件时可以把它挂到 `parent_run_id`,供前端拼接树状执行视图。
|
||||
_run_id_var: ContextVar[str | None] = ContextVar("process_current_run_id", default=None)
|
||||
|
||||
|
||||
def new_run_id(prefix: str = "run") -> str:
|
||||
"""生成一个短且可读的运行 ID。"""
|
||||
# 只截取 8 位十六进制是为了兼顾:
|
||||
# 1. 日志 / WebSocket 里更短、更容易肉眼追踪;
|
||||
# 2. 同一进程内短期冲突概率仍足够低。
|
||||
return f"{prefix}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
|
||||
"""返回带 `Z` 后缀的 UTC ISO8601 时间戳。"""
|
||||
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def process_event_sink(sink: ProcessEventSink | None):
|
||||
"""为当前异步上下文临时绑定一个事件 sink。"""
|
||||
# `ContextVar.set()` 会返回 token,退出时要 reset,避免泄漏到后续请求。
|
||||
token = _sink_var.set(sink)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_sink_var.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def process_run_context(run_id: str | None):
|
||||
"""为当前异步上下文绑定一个逻辑父 run_id。"""
|
||||
token = _run_id_var.set(run_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_run_id_var.reset(token)
|
||||
|
||||
|
||||
def current_process_run_id() -> str | None:
|
||||
"""读取当前上下文里绑定的 run_id。"""
|
||||
return _run_id_var.get()
|
||||
|
||||
|
||||
def has_process_event_sink() -> bool:
|
||||
"""判断当前上下文是否具备过程事件接收能力。"""
|
||||
return _sink_var.get() is not None
|
||||
|
||||
|
||||
async def emit_process_event(event_type: str, **payload: Any) -> None:
|
||||
"""在存在 sink 时发出一个结构化过程事件。"""
|
||||
sink = _sink_var.get()
|
||||
# 没有 sink 说明当前调用链不关心中间态,例如纯 CLI 单轮场景,直接静默跳过。
|
||||
if sink is None:
|
||||
return
|
||||
# `created_at` 允许调用方覆盖;未传时统一补 UTC 时间,方便前端排序。
|
||||
event: ProcessEvent = {
|
||||
"type": event_type,
|
||||
"created_at": payload.pop("created_at", utc_now_iso()),
|
||||
**payload,
|
||||
}
|
||||
await sink(event)
|
||||
@ -1,58 +0,0 @@
|
||||
"""委派执行结果的共享类型定义。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
_PLACEHOLDER_SUMMARY_MARKERS = (
|
||||
"task completed but no final response was generated",
|
||||
"no final response was generated",
|
||||
"已启动代理团队",
|
||||
"代理团队正在后台工作",
|
||||
"agent team [",
|
||||
"spawn_agent_team",
|
||||
"error calling llm",
|
||||
"litellm.timeout",
|
||||
"dashscopeexception",
|
||||
"service temporarily unavailable",
|
||||
"planner调用失败",
|
||||
"本任务当前不可执行",
|
||||
"无法由单一非sop工具完成",
|
||||
)
|
||||
|
||||
|
||||
def normalize_summary_text(text: str | None) -> str:
|
||||
"""把摘要文本压成便于判定的稳定形式。"""
|
||||
return " ".join(str(text or "").strip().split())
|
||||
|
||||
|
||||
def contains_placeholder_summary(text: str | None) -> bool:
|
||||
"""判断摘要是否只是占位兜底文本。"""
|
||||
normalized = normalize_summary_text(text).lower()
|
||||
if not normalized:
|
||||
return True
|
||||
return any(marker in normalized for marker in _PLACEHOLDER_SUMMARY_MARKERS)
|
||||
|
||||
|
||||
def has_meaningful_summary(text: str | None) -> bool:
|
||||
"""判断摘要是否包含可复用的真实结果。"""
|
||||
normalized = normalize_summary_text(text)
|
||||
return bool(normalized) and not contains_placeholder_summary(normalized)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRunResult:
|
||||
"""统一描述一次 agent 执行结果。"""
|
||||
|
||||
# 执行方的稳定 ID,适合程序判断和日志检索。
|
||||
agent_id: str
|
||||
# 展示给用户或前端时使用的人类可读名称。
|
||||
agent_name: str
|
||||
# 归一化状态:通常是 `ok` / `error` / `cancelled` 等。
|
||||
status: str
|
||||
# 面向上层的简要总结,是最终展示和二次总结的主要输入。
|
||||
summary: str
|
||||
# 可选原始载荷,保留底层协议返回值,便于调试或后续扩展。
|
||||
raw: dict[str, Any] | None = None
|
||||
@ -1,238 +0,0 @@
|
||||
"""Review-first skill installation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir, get_workspace_state_path, safe_filename, timestamp
|
||||
|
||||
|
||||
def _is_relative_to(path: Path, root: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _parse_frontmatter(content: str) -> dict[str, str]:
|
||||
if not content.startswith("---"):
|
||||
return {}
|
||||
|
||||
end = content.find("\n---", 3)
|
||||
if end == -1:
|
||||
return {}
|
||||
|
||||
metadata: dict[str, str] = {}
|
||||
for line in content[3:end].splitlines():
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip("\"'")
|
||||
return metadata
|
||||
|
||||
|
||||
def _parse_skill_metadata(raw: str) -> dict[str, Any]:
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
nested = data.get("nanobot", data.get("openclaw", {}))
|
||||
return nested if isinstance(nested, dict) else {}
|
||||
|
||||
|
||||
class SkillReviewManager:
|
||||
"""Stage workspace skill installs until the user explicitly approves them."""
|
||||
|
||||
REVIEW_META_FILE = "review.json"
|
||||
ARCHIVE_FILE = "upload.zip"
|
||||
STAGED_DIR = "staged"
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace.expanduser().resolve()
|
||||
self.workspace_skills = ensure_dir(self.workspace / "skills")
|
||||
self.reviews_dir = ensure_dir(get_workspace_state_path(self.workspace) / "skill-reviews")
|
||||
|
||||
def list_reviews(self) -> list[dict[str, Any]]:
|
||||
reviews: list[dict[str, Any]] = []
|
||||
for review_dir in sorted(self.reviews_dir.iterdir(), reverse=True):
|
||||
if not review_dir.is_dir():
|
||||
continue
|
||||
try:
|
||||
reviews.append(self._read_review(review_dir))
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
return reviews
|
||||
|
||||
def get_review(self, review_id: str) -> dict[str, Any]:
|
||||
return self._read_review(self._review_dir(review_id))
|
||||
|
||||
def create_review_from_zip(self, filename: str, content: bytes) -> dict[str, Any]:
|
||||
review_id = secrets.token_hex(8)
|
||||
review_dir = ensure_dir(self._review_dir(review_id))
|
||||
archive_path = review_dir / self.ARCHIVE_FILE
|
||||
archive_path.write_bytes(content)
|
||||
|
||||
staged_root = ensure_dir(review_dir / self.STAGED_DIR)
|
||||
preview = self._extract_archive(archive_path, staged_root, filename)
|
||||
review = {
|
||||
"id": review_id,
|
||||
"status": "pending_review",
|
||||
"created_at": timestamp(),
|
||||
"archive_name": filename,
|
||||
**preview,
|
||||
}
|
||||
self._write_review(review_dir, review)
|
||||
return review
|
||||
|
||||
def approve_review(self, review_id: str, overwrite: bool = False) -> dict[str, Any]:
|
||||
review_dir = self._review_dir(review_id)
|
||||
review = self._read_review(review_dir)
|
||||
|
||||
if review.get("status") == "approved":
|
||||
return review
|
||||
|
||||
skill_name = str(review.get("skill_name") or "").strip()
|
||||
if not skill_name:
|
||||
raise ValueError("Review is missing a skill_name")
|
||||
|
||||
source_dir = review_dir / self.STAGED_DIR / skill_name
|
||||
if not source_dir.is_dir():
|
||||
raise FileNotFoundError(f"Staged skill not found for review {review_id}")
|
||||
|
||||
target_dir = self.workspace_skills / skill_name
|
||||
if target_dir.exists():
|
||||
if not overwrite:
|
||||
raise FileExistsError(
|
||||
f"Skill '{skill_name}' already exists. Re-submit approval with overwrite=true."
|
||||
)
|
||||
shutil.rmtree(target_dir)
|
||||
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
review["status"] = "approved"
|
||||
review["approved_at"] = timestamp()
|
||||
review["overwrite"] = overwrite
|
||||
review["installed_path"] = str(target_dir / "SKILL.md")
|
||||
self._write_review(review_dir, review)
|
||||
return review
|
||||
|
||||
def discard_review(self, review_id: str) -> None:
|
||||
review_dir = self._review_dir(review_id)
|
||||
if not review_dir.exists():
|
||||
raise FileNotFoundError(f"Skill review '{review_id}' not found")
|
||||
shutil.rmtree(review_dir)
|
||||
|
||||
def _review_dir(self, review_id: str) -> Path:
|
||||
return self.reviews_dir / review_id
|
||||
|
||||
def _read_review(self, review_dir: Path) -> dict[str, Any]:
|
||||
review_file = review_dir / self.REVIEW_META_FILE
|
||||
if not review_file.exists():
|
||||
raise FileNotFoundError(f"Skill review metadata not found: {review_dir.name}")
|
||||
return json.loads(review_file.read_text(encoding="utf-8"))
|
||||
|
||||
def _write_review(self, review_dir: Path, review: dict[str, Any]) -> None:
|
||||
review_file = review_dir / self.REVIEW_META_FILE
|
||||
review_file.write_text(
|
||||
json.dumps(review, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _extract_archive(
|
||||
self,
|
||||
archive_path: Path,
|
||||
staged_root: Path,
|
||||
upload_name: str,
|
||||
) -> dict[str, Any]:
|
||||
with zipfile.ZipFile(archive_path, "r") as zf:
|
||||
file_infos = [info for info in zf.infolist() if not info.is_dir()]
|
||||
if not file_infos:
|
||||
raise ValueError("Zip archive is empty")
|
||||
|
||||
skill_md_entries: list[str] = []
|
||||
for info in file_infos:
|
||||
rel = PurePosixPath(info.filename)
|
||||
if rel.name != "SKILL.md":
|
||||
continue
|
||||
if len(rel.parts) not in (1, 2):
|
||||
raise ValueError(
|
||||
"SKILL.md must be at the archive root or inside a single top-level directory"
|
||||
)
|
||||
skill_md_entries.append(info.filename)
|
||||
|
||||
if not skill_md_entries:
|
||||
raise ValueError("Zip must contain a top-level SKILL.md file")
|
||||
|
||||
skill_md_entry = skill_md_entries[0]
|
||||
skill_md_parts = PurePosixPath(skill_md_entry).parts
|
||||
top_level_dir = skill_md_parts[0] if len(skill_md_parts) == 2 else ""
|
||||
frontmatter = _parse_frontmatter(
|
||||
zf.read(skill_md_entry).decode("utf-8", errors="replace")
|
||||
)
|
||||
|
||||
if top_level_dir:
|
||||
skill_name = top_level_dir
|
||||
else:
|
||||
skill_name = frontmatter.get("name") or Path(upload_name).stem
|
||||
|
||||
skill_name = safe_filename(skill_name).replace(" ", "-")
|
||||
if not skill_name:
|
||||
raise ValueError("Could not determine a safe skill name")
|
||||
|
||||
staged_skill_dir = staged_root / skill_name
|
||||
staged_skill_dir.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
extracted_files: list[str] = []
|
||||
for info in file_infos:
|
||||
raw_rel = PurePosixPath(info.filename)
|
||||
if "__MACOSX" in raw_rel.parts or raw_rel.name == ".DS_Store":
|
||||
continue
|
||||
|
||||
if top_level_dir:
|
||||
if not raw_rel.parts or raw_rel.parts[0] != top_level_dir:
|
||||
continue
|
||||
rel_parts = raw_rel.parts[1:]
|
||||
else:
|
||||
rel_parts = raw_rel.parts
|
||||
|
||||
if not rel_parts:
|
||||
continue
|
||||
if any(part in {"", ".", ".."} for part in rel_parts):
|
||||
raise ValueError(f"Unsafe archive entry: {info.filename}")
|
||||
|
||||
dest = staged_skill_dir.joinpath(*rel_parts)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
resolved_dest = dest.resolve()
|
||||
if not _is_relative_to(resolved_dest, staged_skill_dir.resolve()):
|
||||
raise ValueError(f"Unsafe archive entry: {info.filename}")
|
||||
|
||||
with zf.open(info) as src, open(dest, "wb") as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
extracted_files.append(PurePosixPath(*rel_parts).as_posix())
|
||||
|
||||
if not (staged_skill_dir / "SKILL.md").exists():
|
||||
raise ValueError("Staged skill is missing SKILL.md after extraction")
|
||||
|
||||
skill_meta = _parse_skill_metadata(frontmatter.get("metadata", ""))
|
||||
target_dir = self.workspace_skills / skill_name
|
||||
return {
|
||||
"skill_name": skill_name,
|
||||
"declared_name": frontmatter.get("name", skill_name),
|
||||
"description": frontmatter.get("description", ""),
|
||||
"metadata": frontmatter,
|
||||
"requires": skill_meta.get("requires", {}),
|
||||
"file_count": len(extracted_files),
|
||||
"files": sorted(extracted_files),
|
||||
"target_exists": target_dir.exists(),
|
||||
"target_path": str(target_dir / "SKILL.md"),
|
||||
"staged_path": str(staged_skill_dir / "SKILL.md"),
|
||||
}
|
||||
@ -1,284 +0,0 @@
|
||||
"""Skills loader for agent capabilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Default builtin skills directory (relative to this file)
|
||||
BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""
|
||||
Loader for agent skills.
|
||||
|
||||
Skills are markdown files (SKILL.md) that teach the agent how to use
|
||||
specific tools or perform certain tasks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
builtin_skills_dir: Path | None = None,
|
||||
extra_dirs: list[Path] | None = None,
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.workspace_skills = workspace / "skills"
|
||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||
if extra_dirs is None:
|
||||
from nanobot.agent.plugins import PluginLoader
|
||||
|
||||
extra_dirs = PluginLoader(workspace).get_skill_dirs()
|
||||
self.extra_dirs: list[Path] = extra_dirs
|
||||
|
||||
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
||||
"""
|
||||
List all available skills.
|
||||
|
||||
Args:
|
||||
filter_unavailable: If True, filter out skills with unmet requirements.
|
||||
|
||||
Returns:
|
||||
List of skill info dicts with 'name', 'path', 'source'.
|
||||
"""
|
||||
skills = []
|
||||
|
||||
# Workspace skills (highest priority)
|
||||
if self.workspace_skills.exists():
|
||||
for skill_dir in self.workspace_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists():
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
||||
|
||||
# Extra skill roots (e.g. plugin-provided skills)
|
||||
for extra_dir in self.extra_dirs:
|
||||
if extra_dir.exists():
|
||||
for skill_dir in extra_dir.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "plugin"})
|
||||
|
||||
# Built-in skills
|
||||
if self.builtin_skills and self.builtin_skills.exists():
|
||||
for skill_dir in self.builtin_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
||||
|
||||
# Filter by requirements
|
||||
if filter_unavailable:
|
||||
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
||||
return skills
|
||||
|
||||
def load_skill(self, name: str) -> str | None:
|
||||
"""
|
||||
Load a skill by name.
|
||||
|
||||
Args:
|
||||
name: Skill name (directory name).
|
||||
|
||||
Returns:
|
||||
Skill content or None if not found.
|
||||
"""
|
||||
# Check workspace first
|
||||
workspace_skill = self.workspace_skills / name / "SKILL.md"
|
||||
if workspace_skill.exists():
|
||||
return workspace_skill.read_text(encoding="utf-8")
|
||||
|
||||
# Check plugin-provided roots
|
||||
for extra_dir in self.extra_dirs:
|
||||
extra_skill = extra_dir / name / "SKILL.md"
|
||||
if extra_skill.exists():
|
||||
return extra_skill.read_text(encoding="utf-8")
|
||||
|
||||
# Check built-in
|
||||
if self.builtin_skills:
|
||||
builtin_skill = self.builtin_skills / name / "SKILL.md"
|
||||
if builtin_skill.exists():
|
||||
return builtin_skill.read_text(encoding="utf-8")
|
||||
|
||||
return None
|
||||
|
||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||
"""
|
||||
Load specific skills for inclusion in agent context.
|
||||
|
||||
Args:
|
||||
skill_names: List of skill names to load.
|
||||
|
||||
Returns:
|
||||
Formatted skills content.
|
||||
"""
|
||||
parts = []
|
||||
for name in skill_names:
|
||||
content = self.load_skill(name)
|
||||
if content:
|
||||
content = self._strip_frontmatter(content)
|
||||
parts.append(f"### Skill: {name}\n\n{content}")
|
||||
|
||||
return "\n\n---\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_skills_summary(self) -> str:
|
||||
"""
|
||||
Build a summary of all skills (name, description, path, availability).
|
||||
|
||||
This is used for progressive loading - the agent can read the full
|
||||
skill content using read_file when needed.
|
||||
|
||||
Returns:
|
||||
XML-formatted skills summary.
|
||||
"""
|
||||
all_skills = self.list_skills(filter_unavailable=False)
|
||||
if not all_skills:
|
||||
return ""
|
||||
|
||||
def escape_xml(s: str) -> str:
|
||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
lines = ["<skills>"]
|
||||
for s in all_skills:
|
||||
name = escape_xml(s["name"])
|
||||
path = s["path"]
|
||||
desc = escape_xml(self._get_skill_description(s["name"]))
|
||||
skill_meta = self._get_skill_meta(s["name"])
|
||||
available = self._check_requirements(skill_meta)
|
||||
|
||||
lines.append(f" <skill available=\"{str(available).lower()}\">")
|
||||
lines.append(f" <name>{name}</name>")
|
||||
lines.append(f" <description>{desc}</description>")
|
||||
lines.append(f" <location>{path}</location>")
|
||||
|
||||
# Show missing requirements for unavailable skills
|
||||
if not available:
|
||||
missing = self._get_missing_requirements(skill_meta)
|
||||
if missing:
|
||||
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
||||
|
||||
lines.append(" </skill>")
|
||||
lines.append("</skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
||||
"""Get a description of missing requirements."""
|
||||
missing = []
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
missing.append(f"CLI: {b}")
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
missing.append(f"ENV: {env}")
|
||||
return ", ".join(missing)
|
||||
|
||||
def _get_skill_description(self, name: str) -> str:
|
||||
"""Get the description of a skill from its frontmatter."""
|
||||
meta = self.get_skill_metadata(name)
|
||||
if meta and meta.get("description"):
|
||||
return meta["description"]
|
||||
return name # Fallback to skill name
|
||||
|
||||
def _strip_frontmatter(self, content: str) -> str:
|
||||
"""Remove YAML frontmatter from markdown content."""
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL)
|
||||
if match:
|
||||
return content[match.end():].strip()
|
||||
return content
|
||||
|
||||
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
||||
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
def _check_requirements(self, skill_meta: dict) -> bool:
|
||||
"""Check if skill requirements are met (bins, env vars)."""
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
return False
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_skill_meta(self, name: str) -> dict:
|
||||
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
||||
meta = self.get_skill_metadata(name) or {}
|
||||
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
|
||||
def get_always_skills(self) -> list[str]:
|
||||
"""Get skills marked as always=true that meet requirements."""
|
||||
result = []
|
||||
for s in self.list_skills(filter_unavailable=True):
|
||||
meta = self.get_skill_metadata(s["name"]) or {}
|
||||
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
if skill_meta.get("always") or meta.get("always"):
|
||||
result.append(s["name"])
|
||||
return result
|
||||
|
||||
def get_skill_metadata(self, name: str) -> dict | None:
|
||||
"""
|
||||
Get metadata from a skill's frontmatter.
|
||||
|
||||
Args:
|
||||
name: Skill name.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None.
|
||||
"""
|
||||
content = self.load_skill(name)
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
# Simple YAML parsing
|
||||
metadata = {}
|
||||
for line in match.group(1).split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
return metadata
|
||||
|
||||
return None
|
||||
|
||||
def get_skill_agent_cards(self, name: str) -> list[dict]:
|
||||
"""从 skill 元数据里提取 A2A agent card 声明。"""
|
||||
# 技能 frontmatter 里的 metadata 是字符串形式,先复用现有解析逻辑拿到 nanobot 扩展字段。
|
||||
meta = self.get_skill_metadata(name) or {}
|
||||
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
cards = skill_meta.get("agent_cards", [])
|
||||
if not isinstance(cards, list):
|
||||
return []
|
||||
|
||||
result = []
|
||||
for idx, card in enumerate(cards):
|
||||
if not isinstance(card, dict):
|
||||
continue
|
||||
# 复制一份,避免直接修改原 metadata 结构。
|
||||
item = dict(card)
|
||||
# 对缺失字段做兜底补全,保证后续 AgentRegistry 可以稳定消费。
|
||||
item.setdefault("id", item.get("name") or f"{name}-agent-{idx + 1}")
|
||||
item.setdefault("name", item["id"])
|
||||
item.setdefault("description", meta.get("description", item["name"]))
|
||||
# 额外挂回 skill_name,方便前端展示来源,也便于后续定位声明位置。
|
||||
item["skill_name"] = name
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
def list_skill_agent_cards(self) -> list[dict]:
|
||||
"""聚合所有可见 skill 中声明的 agent card。"""
|
||||
cards = []
|
||||
for skill in self.list_skills(filter_unavailable=False):
|
||||
cards.extend(self.get_skill_agent_cards(skill["name"]))
|
||||
return cards
|
||||
@ -1,311 +0,0 @@
|
||||
"""本地委派执行器。
|
||||
|
||||
这个类不再负责“后台任务管理”和“结果回流”,只保留一件事:
|
||||
在统一委派层要求执行本地任务时,提供一个受限工具集的本地 agent 执行环境。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time as _time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.run_result import AgentRunResult, has_meaningful_summary
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import NestedDelegateTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.delegation import DelegationManager
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""用受限工具集在本地执行委派任务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
brave_api_key: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
# 这里保存的都是本地执行所需的静态配置,不再维护后台任务表。
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.brave_api_key = brave_api_key
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._nested_delegate: DelegationManager | None = None
|
||||
|
||||
def set_nested_delegate(self, manager: "DelegationManager | None") -> None:
|
||||
"""注入 delegated worker 可用的受控下游委派器。"""
|
||||
self._nested_delegate = manager
|
||||
|
||||
async def run_local_task(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
agent_id: str = "local-subagent",
|
||||
agent_name: str = "Local Subagent",
|
||||
system_prompt: str | None = None,
|
||||
model: str | None = None,
|
||||
progress_callback: Callable[..., Awaitable[None]] | None = None,
|
||||
allow_nested_delegation: bool = True,
|
||||
skill_context: str = "",
|
||||
skill_names: list[str] | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""执行一次本地委派任务,并返回结构化结果。"""
|
||||
# 每次任务都新建一套局部工具注册表,避免不同任务之间共享临时状态。
|
||||
tools = self._build_local_tools(
|
||||
allow_nested_delegation=allow_nested_delegation,
|
||||
skill_names=skill_names,
|
||||
)
|
||||
prompt = self._build_subagent_prompt(
|
||||
task,
|
||||
agent_name=agent_name,
|
||||
custom_system_prompt=system_prompt,
|
||||
allow_nested_delegation=allow_nested_delegation,
|
||||
skill_context=skill_context,
|
||||
)
|
||||
# 本地委派不共享主会话历史,只带“专用 system prompt + 当前任务”。
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
# 本地子 agent 也走“模型 -> 工具 -> 模型”的短循环,但轮数更保守。
|
||||
max_iterations = 15
|
||||
iteration = 0
|
||||
final_result: str | None = None
|
||||
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=model or self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if progress_callback:
|
||||
# 进度回调只发对用户有价值的文本,不把 `<think>` 之类内部推理暴露出去。
|
||||
clean = self._strip_think(response.content)
|
||||
if clean:
|
||||
await progress_callback(clean, tool_hint=False)
|
||||
# 额外补一条短工具提示,让上层 UI 知道当前在做什么。
|
||||
await progress_callback(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": tool_call_dicts,
|
||||
})
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Agent [{}] executing: {} with arguments: {}", agent_id, tool_call.name, args_str)
|
||||
# 真正执行工具后,把结果回填到 messages,让下一轮模型能看到执行结果。
|
||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
else:
|
||||
# 没有继续调用工具时,视为任务已收敛,直接采纳当前回复。
|
||||
final_result = response.content
|
||||
break
|
||||
|
||||
status = "ok"
|
||||
raw: dict[str, Any] | None = None
|
||||
if not has_meaningful_summary(final_result):
|
||||
# 兜底避免出现“任务做完了但完全没文本”的空结果,并显式标记为失败,
|
||||
# 防止上层把这类占位结果学习成 procedure。
|
||||
final_result = "Task completed but no final response was generated."
|
||||
status = "error"
|
||||
raw = {
|
||||
"reason": "no_final_response_generated",
|
||||
"iterations": iteration,
|
||||
}
|
||||
|
||||
return AgentRunResult(
|
||||
agent_id=agent_id,
|
||||
agent_name=agent_name,
|
||||
status=status,
|
||||
summary=final_result,
|
||||
raw=raw,
|
||||
)
|
||||
|
||||
def _build_local_tools(
|
||||
self,
|
||||
*,
|
||||
allow_nested_delegation: bool,
|
||||
skill_names: list[str] | None = None,
|
||||
) -> ToolRegistry:
|
||||
"""构建本地委派可用的受限工具集。"""
|
||||
tools = ToolRegistry()
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
protected_skill_paths = [self.workspace / "skills"]
|
||||
# 文件工具统一按相同的 workspace / allowed_dir 约束注册。
|
||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(
|
||||
WriteFileTool(
|
||||
workspace=self.workspace,
|
||||
allowed_dir=allowed_dir,
|
||||
protected_paths=protected_skill_paths,
|
||||
)
|
||||
)
|
||||
tools.register(
|
||||
EditFileTool(
|
||||
workspace=self.workspace,
|
||||
allowed_dir=allowed_dir,
|
||||
protected_paths=protected_skill_paths,
|
||||
)
|
||||
)
|
||||
# 本地命令执行沿用主配置里的超时和 workspace 限制。
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
protected_paths=protected_skill_paths,
|
||||
))
|
||||
# 网络能力保持只读:搜索和抓取,不提供消息发送/再次委派等工具。
|
||||
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||
tools.register(WebFetchTool())
|
||||
if allow_nested_delegation and self._nested_delegate is not None:
|
||||
tools.register(NestedDelegateTool(manager=self._nested_delegate, default_skills=skill_names))
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Remove provider-specific think blocks from visible progress text."""
|
||||
if not text:
|
||||
return None
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""把工具调用列表格式化成简短进度提示。"""
|
||||
|
||||
def _fmt(tc):
|
||||
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
def _build_subagent_prompt(
|
||||
self,
|
||||
task: str,
|
||||
agent_name: str = "Local Subagent",
|
||||
custom_system_prompt: str | None = None,
|
||||
allow_nested_delegation: bool = True,
|
||||
skill_context: str = "",
|
||||
) -> str:
|
||||
"""构建子代理专用 system prompt。"""
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = _time.strftime("%Z") or "UTC"
|
||||
# plugin agent 的自定义系统提示拼到末尾,保留通用约束,再叠加个性化指令。
|
||||
extra = f"\n\n## Agent Instructions\n{custom_system_prompt.strip()}" if custom_system_prompt else ""
|
||||
can_do_lines = [
|
||||
"- Read and write files in the workspace",
|
||||
"- Execute shell commands",
|
||||
"- Search the web and fetch web pages",
|
||||
"- Complete the task thoroughly",
|
||||
]
|
||||
cannot_do_lines = [
|
||||
"- Send messages directly to users (no message tool available)",
|
||||
"- Access the main agent's conversation history",
|
||||
]
|
||||
delegation_section = (
|
||||
"\n## Downstream Delegation\n"
|
||||
"- Do not delegate further. Complete the task yourself with the tools you have."
|
||||
)
|
||||
if allow_nested_delegation and self._nested_delegate is not None:
|
||||
can_do_lines.append(
|
||||
"- Use `delegate_task` for controlled downstream delegation when specialized help is required"
|
||||
)
|
||||
cannot_do_lines.append("- Do not start agent teams or use background delegation tools")
|
||||
nested_summary = self._nested_delegate.build_nested_agents_summary()
|
||||
summary_block = f"\n\n{nested_summary}" if nested_summary else ""
|
||||
delegation_section = (
|
||||
"\n## Downstream Delegation\n"
|
||||
"- Use `delegate_task` only when a specialized downstream worker is actually needed.\n"
|
||||
"- `strategy=\"a2a\"` delegates directly to an available A2A agent.\n"
|
||||
"- `strategy=\"ephemeral_subagent\"` runs a temporary local worker for this task only.\n"
|
||||
"- Never create, register, or persist a new local sub-agent through `subagentctl.py`, `/api/subagents`, or registry edits."
|
||||
f"{summary_block}"
|
||||
)
|
||||
else:
|
||||
cannot_do_lines.append("- Spawn other subagents or downstream workers")
|
||||
skill_section = f"\n## Required Skills\n{skill_context.strip()}" if skill_context.strip() else ""
|
||||
|
||||
return f"""# {agent_name}
|
||||
|
||||
## Current Time
|
||||
{now} ({tz})
|
||||
|
||||
You are a delegated agent spawned by the main agent to complete a specific task.
|
||||
|
||||
## Rules
|
||||
1. Stay focused - complete only the assigned task, nothing else
|
||||
2. Your final response will be reported back to the main agent
|
||||
3. Do not initiate conversations or take on side tasks
|
||||
4. Be concise but informative in your findings
|
||||
5. Do not create or modify persistent local sub-agents unless the task explicitly requires that workflow
|
||||
|
||||
## What You Can Do
|
||||
{chr(10).join(can_do_lines)}
|
||||
|
||||
## What You Cannot Do
|
||||
{chr(10).join(cannot_do_lines)}
|
||||
|
||||
{delegation_section}
|
||||
|
||||
{skill_section}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {self.workspace}
|
||||
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
||||
|
||||
## Special Workflow
|
||||
- If the task is about creating, updating, repairing, or deleting a persistent local sub-agent, read `skills/subagent-manager/SKILL.md` before making changes.
|
||||
- For persistent local sub-agents, follow only the canonical workflow from that skill.
|
||||
- Do not manually create `workspace/agents/<id>/agent.json` as a substitute for a persistent sub-agent.
|
||||
- Do not manually edit `workspace/agents/registry.json` to register a persistent sub-agent.
|
||||
- A valid persistent sub-agent must be created through `subagentctl.py` or `/api/subagents` and must end up at `workspace/agents/<id>_agent/AGENTS.json`.
|
||||
|
||||
When you have completed the task, provide a clear summary of your findings or actions.{extra}"""
|
||||
@ -1,258 +0,0 @@
|
||||
"""Persistent local sub-agent storage helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from importlib.resources import files as pkg_files
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.config.schema import Config, MCPServerConfig
|
||||
|
||||
_INVALID_ID_RE = re.compile(r"[^a-z0-9-]+")
|
||||
|
||||
|
||||
def normalize_subagent_id(value: str) -> str:
|
||||
normalized = _INVALID_ID_RE.sub("-", str(value or "").strip().lower()).strip("-")
|
||||
normalized = re.sub(r"-{2,}", "-", normalized)
|
||||
if not normalized:
|
||||
raise ValueError("Sub-agent id is required")
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubagentSpec:
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
enabled: bool = True
|
||||
workspace: str = ""
|
||||
system_prompt: str = ""
|
||||
model: str | None = None
|
||||
delegation_mode: str = "remote_a2a_only"
|
||||
allow_mcp: bool = True
|
||||
tags: list[str] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
mcp_servers: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any], *, workspace_path: Path | None = None) -> "SubagentSpec":
|
||||
agent_id = normalize_subagent_id(payload.get("id", ""))
|
||||
name = str(payload.get("name") or agent_id).strip() or agent_id
|
||||
description = str(payload.get("description") or name).strip() or name
|
||||
workspace = str(payload.get("workspace") or "").strip()
|
||||
if not workspace and workspace_path is not None:
|
||||
workspace = str(workspace_path)
|
||||
tags = [str(item).strip() for item in payload.get("tags", []) if str(item).strip()]
|
||||
aliases = [str(item).strip() for item in payload.get("aliases", []) if str(item).strip()]
|
||||
mcp_servers = payload.get("mcp_servers", {})
|
||||
if not isinstance(mcp_servers, dict):
|
||||
mcp_servers = {}
|
||||
metadata = payload.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
return cls(
|
||||
id=agent_id,
|
||||
name=name,
|
||||
description=description,
|
||||
enabled=bool(payload.get("enabled", True)),
|
||||
workspace=workspace,
|
||||
system_prompt=str(payload.get("system_prompt") or "").strip(),
|
||||
model=(str(payload.get("model") or "").strip() or None),
|
||||
delegation_mode=(str(payload.get("delegation_mode") or "remote_a2a_only").strip() or "remote_a2a_only"),
|
||||
allow_mcp=bool(payload.get("allow_mcp", True)),
|
||||
tags=tags,
|
||||
aliases=aliases,
|
||||
mcp_servers=mcp_servers,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload = asdict(self)
|
||||
if not self.model:
|
||||
payload["model"] = None
|
||||
return payload
|
||||
|
||||
|
||||
class LocalSubagentStore:
|
||||
"""Persist sub-agent definitions under `<workspace>/agents/<id>_agent/`."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace.expanduser().resolve()
|
||||
self.directory = self.workspace / "agents"
|
||||
|
||||
def list_subagents(self) -> list[SubagentSpec]:
|
||||
if not self.directory.exists():
|
||||
return []
|
||||
result: list[SubagentSpec] = []
|
||||
for child in sorted(self.directory.iterdir()):
|
||||
agents_json = child / "AGENTS.json"
|
||||
if not child.is_dir() or not agents_json.exists():
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(agents_json.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
result.append(SubagentSpec.from_dict(payload, workspace_path=child))
|
||||
return result
|
||||
|
||||
def get_subagent(self, agent_id: str) -> SubagentSpec | None:
|
||||
path = self.agents_json_path(agent_id)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return SubagentSpec.from_dict(payload, workspace_path=self.subagent_dir(agent_id))
|
||||
|
||||
def upsert_subagent(self, payload: dict[str, Any], config: Config) -> SubagentSpec:
|
||||
agent_id = normalize_subagent_id(payload.get("id", ""))
|
||||
workspace_path = self.subagent_dir(agent_id)
|
||||
spec = SubagentSpec.from_dict(payload, workspace_path=workspace_path)
|
||||
|
||||
self._ensure_workspace(workspace_path)
|
||||
spec.workspace = str(workspace_path)
|
||||
self._sync_agents_md(workspace_path, spec)
|
||||
self.agents_json_path(agent_id).write_text(
|
||||
json.dumps(spec.to_dict(), indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from nanobot.agent.agent_registry import WorkspaceAgentStore
|
||||
|
||||
WorkspaceAgentStore(self.workspace).upsert_agent(self.build_registry_record(spec, config))
|
||||
return spec
|
||||
|
||||
def delete_subagent(self, agent_id: str) -> bool:
|
||||
agent_id = normalize_subagent_id(agent_id)
|
||||
target = self.subagent_dir(agent_id)
|
||||
if not target.exists():
|
||||
return False
|
||||
|
||||
from nanobot.agent.agent_registry import WorkspaceAgentStore
|
||||
|
||||
WorkspaceAgentStore(self.workspace).delete_agent(agent_id)
|
||||
shutil.rmtree(target)
|
||||
return True
|
||||
|
||||
def subagent_dir(self, agent_id: str) -> Path:
|
||||
return self.directory / f"{normalize_subagent_id(agent_id)}_agent"
|
||||
|
||||
def agents_json_path(self, agent_id: str) -> Path:
|
||||
return self.subagent_dir(agent_id) / "AGENTS.json"
|
||||
|
||||
def local_base_url(self, config: Config, agent_id: str) -> str:
|
||||
return f"http://127.0.0.1:{int(config.gateway.port)}/subagents/{normalize_subagent_id(agent_id)}"
|
||||
|
||||
def build_registry_record(self, spec: SubagentSpec, config: Config) -> dict[str, Any]:
|
||||
base_url = self.local_base_url(config, spec.id)
|
||||
card_url = f"{base_url}/.well-known/agent-card"
|
||||
return {
|
||||
"id": spec.id,
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"protocol": "a2a",
|
||||
"base_url": base_url,
|
||||
"endpoint": f"{base_url}/rpc",
|
||||
"card_url": card_url,
|
||||
"enabled": spec.enabled,
|
||||
"tags": sorted(set(["local-subagent", *spec.tags])),
|
||||
"aliases": sorted(set([spec.name, *spec.aliases])),
|
||||
"metadata": {
|
||||
**spec.metadata,
|
||||
"workspace": spec.workspace,
|
||||
"managed_by": "subagent-manager",
|
||||
"local_subagent": True,
|
||||
},
|
||||
"capabilities": {"streaming": False},
|
||||
"support_streaming": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def build_agent_card(spec: SubagentSpec, config: Config) -> dict[str, Any]:
|
||||
base_url = f"http://127.0.0.1:{int(config.gateway.port)}/subagents/{spec.id}"
|
||||
rpc_url = f"{base_url}/rpc"
|
||||
return {
|
||||
"id": spec.id,
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"url": rpc_url,
|
||||
"preferred_transport": "jsonrpc",
|
||||
"interfaces": [{"transport": "jsonrpc", "url": rpc_url}],
|
||||
"capabilities": {"streaming": False},
|
||||
"tags": sorted(set(["local-subagent", *spec.tags])),
|
||||
"metadata": {
|
||||
"workspace": spec.workspace,
|
||||
"managed_by": "subagent-manager",
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def coerce_mcp_servers(spec: SubagentSpec) -> dict[str, MCPServerConfig]:
|
||||
if not spec.allow_mcp:
|
||||
return {}
|
||||
result: dict[str, MCPServerConfig] = {}
|
||||
for name, payload in spec.mcp_servers.items():
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
try:
|
||||
result[name] = MCPServerConfig.model_validate(payload)
|
||||
except Exception:
|
||||
continue
|
||||
return result
|
||||
|
||||
def _ensure_workspace(self, workspace_path: Path) -> None:
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
templates_dir = pkg_files("nanobot") / "templates"
|
||||
for item in templates_dir.iterdir():
|
||||
if not item.name.endswith(".md") or item.name == "AGENTS.md":
|
||||
continue
|
||||
dest = workspace_path / item.name
|
||||
if not dest.exists():
|
||||
dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
|
||||
memory_dir = workspace_path / "memory"
|
||||
memory_dir.mkdir(exist_ok=True)
|
||||
memory_template = templates_dir / "memory" / "MEMORY.md"
|
||||
memory_file = memory_dir / "MEMORY.md"
|
||||
if not memory_file.exists():
|
||||
memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
history_file = memory_dir / "HISTORY.md"
|
||||
if not history_file.exists():
|
||||
history_file.write_text("", encoding="utf-8")
|
||||
(workspace_path / "skills").mkdir(exist_ok=True)
|
||||
|
||||
def _sync_agents_md(self, workspace_path: Path, spec: SubagentSpec) -> None:
|
||||
content = self._render_agents_md(spec)
|
||||
(workspace_path / "AGENTS.md").write_text(content, encoding="utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _render_agents_md(spec: SubagentSpec) -> str:
|
||||
prompt = spec.system_prompt.strip() or "Complete delegated tasks accurately and concisely."
|
||||
return f"""# {spec.name}
|
||||
|
||||
You are {spec.name}, a persistent local sub-agent managed by Boardware Genius.
|
||||
|
||||
## Role
|
||||
{spec.description}
|
||||
|
||||
## System Prompt
|
||||
{prompt}
|
||||
|
||||
## Constraints
|
||||
- Work only inside this workspace.
|
||||
- Respond only to delegated tasks.
|
||||
- Delegate only to remote A2A agents when delegation is enabled.
|
||||
- Do not create or manage local sub-agents.
|
||||
- Do not message end users directly.
|
||||
"""
|
||||
@ -1,6 +0,0 @@
|
||||
"""Agent tools module."""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
__all__ = ["Tool", "ToolRegistry"]
|
||||
@ -1,102 +0,0 @@
|
||||
"""Base class for agent tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for agent tools.
|
||||
|
||||
Tools are capabilities that the agent can use to interact with
|
||||
the environment, such as reading files, executing commands, etc.
|
||||
"""
|
||||
|
||||
_TYPE_MAP = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Tool name used in function calls."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Description of what the tool does."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""
|
||||
Execute the tool with given parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific parameters.
|
||||
|
||||
Returns:
|
||||
String result of the tool execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
|
||||
return self._validate(params, {**schema, "type": "object"}, "")
|
||||
|
||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||
t, label = schema.get("type"), path or "parameter"
|
||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
||||
return [f"{label} should be {t}"]
|
||||
|
||||
errors = []
|
||||
if "enum" in schema and val not in schema["enum"]:
|
||||
errors.append(f"{label} must be one of {schema['enum']}")
|
||||
if t in ("integer", "number"):
|
||||
if "minimum" in schema and val < schema["minimum"]:
|
||||
errors.append(f"{label} must be >= {schema['minimum']}")
|
||||
if "maximum" in schema and val > schema["maximum"]:
|
||||
errors.append(f"{label} must be <= {schema['maximum']}")
|
||||
if t == "string":
|
||||
if "minLength" in schema and len(val) < schema["minLength"]:
|
||||
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
||||
if "maxLength" in schema and len(val) > schema["maxLength"]:
|
||||
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
||||
if t == "object":
|
||||
props = schema.get("properties", {})
|
||||
for k in schema.get("required", []):
|
||||
if k not in val:
|
||||
errors.append(f"missing required {path + '.' + k if path else k}")
|
||||
for k, v in val.items():
|
||||
if k in props:
|
||||
errors.extend(self._validate(v, props[k], path + '.' + k if path else k))
|
||||
if t == "array" and "items" in schema:
|
||||
for i, item in enumerate(val):
|
||||
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]"))
|
||||
return errors
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""Convert tool to OpenAI function schema format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
}
|
||||
}
|
||||
@ -1,246 +0,0 @@
|
||||
"""cron 工具:给 Agent 提供“定时任务管理”能力。
|
||||
|
||||
这个工具是 LLM 在对话中可调用的 function tool,主要负责三件事:
|
||||
1. `add`:创建一个定时任务(周期/cron/一次性);
|
||||
2. `list`:列出现有任务;
|
||||
3. `remove`:删除指定任务。
|
||||
|
||||
设计定位说明:
|
||||
- 本工具只做“任务管理面”,不直接负责“定时器循环”;
|
||||
- 真正的调度与执行由 `CronService` 统一负责(start/stop/on_job);
|
||||
- 工具层通过 `set_context(channel, chat_id)` 注入当前会话路由,
|
||||
从而让定时任务在触发后把结果回投到正确会话。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
class CronTool(Tool):
|
||||
"""对话可调用的 cron 管理工具。
|
||||
|
||||
调用来源:
|
||||
- 主 agent 在工具调用回合中发起 `cron(...)`。
|
||||
|
||||
关键约束:
|
||||
- action 仅支持 `add/list/remove` 三种;
|
||||
- `add` 必须带 message,并且必须先注入 session 上下文(channel/chat_id);
|
||||
- 时间相关参数三选一:`every_seconds` / `cron_expr` / `at`。
|
||||
"""
|
||||
|
||||
def __init__(self, cron_service: CronService):
|
||||
# 持有同一个 CronService 实例,保证:
|
||||
# 1) CLI 命令与 agent 工具看到同一份 jobs.json;
|
||||
# 2) 任务状态(next_run、enabled)在进程内一致。
|
||||
self._cron = cron_service
|
||||
# 路由上下文由 AgentLoop 每轮注入。
|
||||
# 任务触发时将按该路由把结果投递回原会话。
|
||||
self._channel = ""
|
||||
self._chat_id = ""
|
||||
self._session_key = ""
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, session_key: str | None = None) -> None:
|
||||
"""设置当前会话路由上下文。
|
||||
|
||||
为什么需要它:
|
||||
- 用户在 A 会话里让 agent“每天提醒我”,
|
||||
任务未来触发时应回到 A,而不是误发到其他会话。
|
||||
- 因此 channel/chat_id 不依赖模型每次显式传参,
|
||||
而是由运行时在调用前预注入默认目标。
|
||||
"""
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._session_key = session_key or f"{channel}:{chat_id}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# 暴露给模型的工具名。模型会以 `cron(...)` 发起 function call。
|
||||
return "cron"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# 给模型看的简要能力描述,尽量短而明确。
|
||||
return "Schedule reminders and recurring tasks. Actions: add, list, remove. Use mode=reminder or task."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
# OpenAI function schema:
|
||||
# - 定义参数结构与类型;
|
||||
# - 由 ToolRegistry 在调用前做基础参数校验。
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "list", "remove"],
|
||||
"description": "Action to perform"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
# add 时的任务文本:
|
||||
# - 既可做“纯提醒文案”,也可做“交给 agent 执行的提示”。
|
||||
"description": "Reminder message (for add)"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["reminder", "task"],
|
||||
"description": "Execution mode: reminder sends message directly; task re-enters agent"
|
||||
},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
# 固定间隔调度(单位秒),内部会转换为毫秒。
|
||||
"description": "Interval in seconds (for recurring tasks)"
|
||||
},
|
||||
"cron_expr": {
|
||||
"type": "string",
|
||||
# 标准 cron 表达式(5 段),例如每天 9 点:0 9 * * *
|
||||
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
|
||||
},
|
||||
"tz": {
|
||||
"type": "string",
|
||||
# 仅与 cron_expr 搭配使用的 IANA 时区。
|
||||
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')"
|
||||
},
|
||||
"at": {
|
||||
"type": "string",
|
||||
# 一次性触发时间,ISO 格式(本地/带偏移都可由 fromisoformat 解析)。
|
||||
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
|
||||
},
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
message: str = "",
|
||||
mode: str | None = None,
|
||||
every_seconds: int | None = None,
|
||||
cron_expr: str | None = None,
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
"""工具主入口:按 action 分发到具体处理函数。
|
||||
|
||||
注意:
|
||||
- 这里不直接抛异常给上层;尽量返回可读错误字符串。
|
||||
- 真正未捕获异常(如非法日期解析)会被 ToolRegistry 包装成 Error 文本。
|
||||
"""
|
||||
# add:创建任务(并立即持久化),返回任务 ID。
|
||||
if action == "add":
|
||||
return self._add_job(message, mode, every_seconds, cron_expr, tz, at)
|
||||
# list:只读取并格式化输出,不改状态。
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
# remove:按 ID 删除任务并重置调度器。
|
||||
elif action == "remove":
|
||||
return self._remove_job(job_id)
|
||||
# schema 已限制枚举,这里是兜底防御。
|
||||
return f"Unknown action: {action}"
|
||||
|
||||
def _add_job(
|
||||
self,
|
||||
message: str,
|
||||
mode: str | None,
|
||||
every_seconds: int | None,
|
||||
cron_expr: str | None,
|
||||
tz: str | None,
|
||||
at: str | None,
|
||||
) -> str:
|
||||
"""创建任务并写入 CronService。
|
||||
|
||||
参数优先级(互斥选择):
|
||||
1. `every_seconds` -> 固定间隔任务
|
||||
2. `cron_expr` -> cron 表达式任务
|
||||
3. `at` -> 一次性任务(执行后自动删除)
|
||||
"""
|
||||
# message 是 add 的必填语义字段:没有内容就无法定义“要做什么”。
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
# channel/chat_id 由 AgentLoop 注入;
|
||||
# 若缺失,说明当前调用上下文不完整,无法保证结果回投目标正确。
|
||||
if not self._channel or not self._chat_id:
|
||||
return "Error: no session context (channel/chat_id)"
|
||||
# 时区仅对 cron 表达式有意义;避免用户误把 tz 用在 every/at 上。
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
# 尽早校验时区,提前给出明确错误,避免把非法数据写入存储。
|
||||
if tz:
|
||||
from zoneinfo import ZoneInfo
|
||||
try:
|
||||
ZoneInfo(tz)
|
||||
except (KeyError, Exception):
|
||||
return f"Error: unknown timezone '{tz}'"
|
||||
|
||||
# mode 缺省时默认按“提醒”处理:
|
||||
# - 与 cron skill 的说明一致;
|
||||
# - 避免把原始建任务指令再次送回 agent,造成任务自复制。
|
||||
normalized_mode = (mode or "reminder").strip().lower()
|
||||
if normalized_mode not in {"reminder", "task"}:
|
||||
return "Error: mode must be 'reminder' or 'task'"
|
||||
payload_kind = "system_event" if normalized_mode == "reminder" else "agent_turn"
|
||||
|
||||
# 构建调度对象:
|
||||
# - CronService 内部统一使用毫秒时间戳;
|
||||
# - `at` 任务默认 delete_after_run=True,执行一次后自动移除。
|
||||
delete_after = False
|
||||
if every_seconds:
|
||||
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
|
||||
elif cron_expr:
|
||||
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
|
||||
elif at:
|
||||
from datetime import datetime
|
||||
# fromisoformat 解析失败会抛 ValueError,
|
||||
# 该异常会由 ToolRegistry 统一转换为错误字符串返回给模型。
|
||||
dt = datetime.fromisoformat(at)
|
||||
at_ms = int(dt.timestamp() * 1000)
|
||||
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
||||
delete_after = True
|
||||
else:
|
||||
return "Error: either every_seconds, cron_expr, or at is required"
|
||||
|
||||
# 创建任务并持久化:
|
||||
# - name 使用 message 前 30 字符做简短标题,便于列表展示;
|
||||
# - deliver=True:任务触发后默认向当前会话投递结果;
|
||||
# - channel/to 使用注入上下文,确保消息路由一致。
|
||||
job = self._cron.add_job(
|
||||
name=message[:30],
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
payload_kind=payload_kind,
|
||||
session_key=self._session_key or None,
|
||||
deliver=True,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
)
|
||||
# 返回简明确认文本,便于模型后续引用 job_id 做删除或说明。
|
||||
return f"Created {normalized_mode} job '{job.name}' (id: {job.id})"
|
||||
|
||||
def _list_jobs(self) -> str:
|
||||
"""列出当前可见任务(默认仅启用任务)。"""
|
||||
jobs = self._cron.list_jobs()
|
||||
if not jobs:
|
||||
return "No scheduled jobs."
|
||||
# 输出格式保持轻量,避免把过多状态塞给模型。
|
||||
# 详细状态(next_run/last_error)可在 CLI 的 `nanobot cron list` 查看。
|
||||
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
|
||||
return "Scheduled jobs:\n" + "\n".join(lines)
|
||||
|
||||
def _remove_job(self, job_id: str | None) -> str:
|
||||
"""按 ID 删除任务。"""
|
||||
if not job_id:
|
||||
return "Error: job_id is required for remove"
|
||||
# remove_job 返回 bool,工具层负责转换成对话友好的文案。
|
||||
if self._cron.remove_job(job_id):
|
||||
return f"Removed job {job_id}"
|
||||
return f"Job {job_id} not found"
|
||||
@ -1,116 +0,0 @@
|
||||
"""结构化 cron 生命周期控制工具。
|
||||
|
||||
cron 任务不是普通用户对话,它经常需要在运行完成后主动告诉调度器:
|
||||
- 这个任务已经可以删掉;
|
||||
- 今天这一轮先结束,下一天再继续;
|
||||
- 下次应该改成新的时间表。
|
||||
|
||||
这个工具就是让模型把这些决策显式写成结构化数据,而不是只留在自然语言里。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.types import CronAction
|
||||
|
||||
|
||||
class CronActionTool(Tool):
|
||||
"""捕获模型输出的机器可读 cron 控制决策。"""
|
||||
|
||||
def __init__(self, job_id: str):
|
||||
# `job_id` 仅用于回显和审计,不参与决策本身。
|
||||
self.job_id = job_id
|
||||
# `_decision` 在本轮 agent 执行期间最多被写一次,外部在结束后读取。
|
||||
self._decision: CronAction | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cron_action"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Record a structured lifecycle action for the currently running cron job."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["none", "remove", "disable", "complete_today", "reschedule"],
|
||||
"description": "Lifecycle action for the current cron job",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Short reason for audit logs",
|
||||
},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Required when action=reschedule and using fixed interval",
|
||||
},
|
||||
"cron_expr": {
|
||||
"type": "string",
|
||||
"description": "Required when action=reschedule and using cron expression",
|
||||
},
|
||||
"tz": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone for cron_expr reschedules",
|
||||
},
|
||||
"at": {
|
||||
"type": "string",
|
||||
"description": "Required when action=reschedule and using one-time ISO datetime",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
@property
|
||||
def decision(self) -> CronAction | None:
|
||||
# 暴露最终结构化决策给 cron runtime,便于后处理调度状态。
|
||||
return self._decision
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
reason: str | None = None,
|
||||
every_seconds: int | None = None,
|
||||
cron_expr: str | None = None,
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> str:
|
||||
# 统一做小写规范化,避免模型传入 `Remove` / `REMOVE` 之类大小写变体。
|
||||
normalized = (action or "").strip().lower()
|
||||
allowed_actions = {"none", "remove", "disable", "complete_today", "reschedule"}
|
||||
if normalized not in allowed_actions:
|
||||
return f"Error: unsupported cron action '{action}'"
|
||||
# 非重排任务不允许额外携带调度字段,避免出现“说 remove 但又传 cron_expr”的脏数据。
|
||||
if normalized != "reschedule" and any(value is not None for value in (every_seconds, cron_expr, tz, at)):
|
||||
return "Error: schedule fields can only be used when action='reschedule'"
|
||||
|
||||
if normalized == "reschedule":
|
||||
# 重新排期必须在三种时间表达方式里三选一,不能都不传,也不能混传。
|
||||
options = int(every_seconds is not None) + int(bool(cron_expr)) + int(bool(at))
|
||||
if options != 1:
|
||||
return "Error: reschedule requires exactly one of every_seconds, cron_expr, or at"
|
||||
# 时区只有 cron 表达式才有意义。
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
|
||||
# 校验通过后,把本轮决策固化为 dataclass,交给 runtime 在执行后统一消费。
|
||||
self._decision = CronAction(
|
||||
action=normalized or "none",
|
||||
reason=(reason or "").strip() or None,
|
||||
every_seconds=every_seconds,
|
||||
cron_expr=cron_expr,
|
||||
tz=tz,
|
||||
at=at,
|
||||
)
|
||||
# 返回给模型/日志的是一条可读确认文本,方便工具调用结果出现在上下文里。
|
||||
detail = f" for job {self.job_id}"
|
||||
if self._decision.reason:
|
||||
detail += f" ({self._decision.reason})"
|
||||
return f"Recorded cron_action={self._decision.action}{detail}"
|
||||
@ -1,275 +0,0 @@
|
||||
"""File system tools: read, write, edit."""
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
|
||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||
p = Path(path).expanduser()
|
||||
if not p.is_absolute() and workspace:
|
||||
p = workspace / p
|
||||
resolved = p.resolve()
|
||||
if allowed_dir:
|
||||
try:
|
||||
resolved.relative_to(allowed_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_relative_to(path: Path, root: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(root.resolve())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _protected_write_error() -> str:
|
||||
return (
|
||||
"Error: Direct writes to workspace skills are blocked. "
|
||||
"Stage the skill for review and require explicit user approval before installation."
|
||||
)
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Tool to read file contents."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Read the contents of a file at the given path."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to read"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not file_path.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
return content
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
"""Tool to write content to a file."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path | None = None,
|
||||
allowed_dir: Path | None = None,
|
||||
protected_paths: list[Path] | None = None,
|
||||
):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
self._protected_paths = [p.expanduser().resolve() for p in protected_paths or []]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Write content to a file at the given path. Creates parent directories if needed."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to write to"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if any(_is_relative_to(file_path, protected) for protected in self._protected_paths):
|
||||
return _protected_write_error()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
return f"Successfully wrote {len(content)} bytes to {file_path}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error writing file: {str(e)}"
|
||||
|
||||
|
||||
class EditFileTool(Tool):
|
||||
"""Tool to edit a file by replacing text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path | None = None,
|
||||
allowed_dir: Path | None = None,
|
||||
protected_paths: list[Path] | None = None,
|
||||
):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
self._protected_paths = [p.expanduser().resolve() for p in protected_paths or []]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to edit"
|
||||
},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "The exact text to find and replace"
|
||||
},
|
||||
"new_text": {
|
||||
"type": "string",
|
||||
"description": "The text to replace with"
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"]
|
||||
}
|
||||
|
||||
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if any(_is_relative_to(file_path, protected) for protected in self._protected_paths):
|
||||
return _protected_write_error()
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
if old_text not in content:
|
||||
return self._not_found_message(old_text, content, path)
|
||||
|
||||
# Count occurrences
|
||||
count = content.count(old_text)
|
||||
if count > 1:
|
||||
return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
|
||||
|
||||
new_content = content.replace(old_text, new_text, 1)
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
return f"Successfully edited {file_path}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def _not_found_message(old_text: str, content: str, path: str) -> str:
|
||||
"""Build a helpful error when old_text is not found."""
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
|
||||
best_ratio, best_start = 0.0, 0
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines, lines[best_start : best_start + window],
|
||||
fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
|
||||
|
||||
class ListDirTool(Tool):
|
||||
"""Tool to list directory contents."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_dir"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "List the contents of a directory."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory path to list"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not dir_path.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
if not dir_path.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
items = []
|
||||
for item in sorted(dir_path.iterdir()):
|
||||
prefix = "📁 " if item.is_dir() else "📄 "
|
||||
items.append(f"{prefix}{item.name}")
|
||||
|
||||
if not items:
|
||||
return f"Directory {path} is empty"
|
||||
|
||||
return "\n".join(items)
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {str(e)}"
|
||||
@ -1,382 +0,0 @@
|
||||
"""MCP 客户端封装。
|
||||
|
||||
职责分两层:
|
||||
1. `connect_mcp_servers()` 负责建立与 MCP server 的连接,并把远端工具注册成 nanobot 本地工具;
|
||||
2. `MCPToolWrapper` 负责把单个远端 MCP tool 包装成可供 LLM 调用的 `Tool`,同时发出结构化过程事件。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.process_events import current_process_run_id, emit_process_event, new_run_id
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
def _iter_leaf_exceptions(exc: BaseException) -> list[BaseException]:
|
||||
if isinstance(exc, BaseExceptionGroup):
|
||||
leaves: list[BaseException] = []
|
||||
for sub_exc in exc.exceptions:
|
||||
leaves.extend(_iter_leaf_exceptions(sub_exc))
|
||||
return leaves
|
||||
return [exc]
|
||||
|
||||
|
||||
def _describe_mcp_exception(exc: BaseException, *, server_name: str, url: str | None = None) -> str:
|
||||
leaves = _iter_leaf_exceptions(exc)
|
||||
target = f" ({url})" if url else ""
|
||||
|
||||
for leaf in leaves:
|
||||
if isinstance(leaf, httpx.TimeoutException):
|
||||
return f"MCP server '{server_name}' timed out while waiting for a response{target}"
|
||||
if isinstance(leaf, httpx.ConnectError):
|
||||
return f"MCP server '{server_name}' is unreachable{target}"
|
||||
if isinstance(leaf, httpx.HTTPStatusError):
|
||||
return f"MCP server '{server_name}' returned HTTP {leaf.response.status_code}{target}"
|
||||
if isinstance(leaf, httpx.HTTPError):
|
||||
detail = str(leaf).strip() or leaf.__class__.__name__
|
||||
return f"MCP server '{server_name}' HTTP error{target}: {detail}"
|
||||
|
||||
detail_source = leaves[0] if leaves else exc
|
||||
detail = str(detail_source).strip() or detail_source.__class__.__name__
|
||||
if isinstance(exc, BaseExceptionGroup):
|
||||
return f"MCP server '{server_name}' failed: {detail_source.__class__.__name__}: {detail}"
|
||||
return detail
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
"""把单个 MCP server tool 包装成 nanobot Tool。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session,
|
||||
server_name: str,
|
||||
tool_def,
|
||||
*,
|
||||
call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
|
||||
tool_timeout: int = 30,
|
||||
sensitive: bool = False,
|
||||
):
|
||||
self._session = session
|
||||
self._call_tool = call_tool or self._default_call_tool
|
||||
# 记录来源服务名,便于日志、事件流和最终导出的工具名保持可追踪。
|
||||
self._server_name = server_name
|
||||
self._original_name = tool_def.name
|
||||
# 在 nanobot 内部为 MCP 工具统一加 `mcp_<server>_` 前缀,避免同名冲突。
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._description = tool_def.description or tool_def.name
|
||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
self._tool_timeout = tool_timeout
|
||||
self._sensitive = sensitive
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._parameters
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
# 每次 MCP 调用都分配独立 run_id,前端可以把它显示成树状子步骤。
|
||||
run_id = new_run_id("mcp")
|
||||
args_json = json.dumps(kwargs, ensure_ascii=False) if kwargs else "{}"
|
||||
await emit_process_event(
|
||||
"process_run_started",
|
||||
run_id=run_id,
|
||||
parent_run_id=current_process_run_id(),
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
title=f"{self._server_name}.{self._original_name}",
|
||||
status="running",
|
||||
metadata={
|
||||
"tool_name": self._original_name,
|
||||
"tool_args": None if self._sensitive else kwargs,
|
||||
"tool_timeout": self._tool_timeout,
|
||||
"sensitive": self._sensitive,
|
||||
},
|
||||
)
|
||||
# 在真正请求远端前先发一条 progress,方便 UI 及时显示“正在调用哪个工具”。
|
||||
await emit_process_event(
|
||||
"process_run_progress",
|
||||
run_id=run_id,
|
||||
parent_run_id=current_process_run_id(),
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
text=(
|
||||
f"Calling {self._original_name}"
|
||||
if self._sensitive
|
||||
else f"Calling {self._original_name} with {args_json}"
|
||||
),
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._call_tool(self._original_name, kwargs),
|
||||
timeout=self._tool_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# 超时被视为业务失败,但不抛异常给上层 agent 循环,而是返回可读错误文本。
|
||||
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||
summary = f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||
await emit_process_event(
|
||||
"process_run_status",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="error",
|
||||
text=summary,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
await emit_process_event(
|
||||
"process_run_finished",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="error",
|
||||
summary=summary,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
return summary
|
||||
|
||||
# MCP SDK 返回的是结构化 content block 列表,这里统一摊平成文本。
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
parts.append(block.text)
|
||||
else:
|
||||
parts.append(str(block))
|
||||
output = "\n".join(parts) or "(no output)"
|
||||
artifact_type = "text"
|
||||
artifact_data: Any | None = None
|
||||
stripped = output.strip()
|
||||
# 如果看起来像 JSON,则额外解析成结构化 artifact,方便前端做更丰富展示。
|
||||
if stripped.startswith("{") or stripped.startswith("["):
|
||||
try:
|
||||
artifact_data = json.loads(stripped)
|
||||
artifact_type = "json"
|
||||
except json.JSONDecodeError:
|
||||
artifact_data = None
|
||||
await emit_process_event(
|
||||
"process_run_artifact",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
title=f"{self._server_name}.{self._original_name} result",
|
||||
artifact_type="redacted" if self._sensitive else artifact_type,
|
||||
content=None if self._sensitive or artifact_data is not None else output,
|
||||
data=None if self._sensitive else artifact_data,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
await emit_process_event(
|
||||
"process_run_finished",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="done",
|
||||
summary=(
|
||||
f"{self._original_name} completed"
|
||||
if self._sensitive
|
||||
else output[:1000]
|
||||
),
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
return output
|
||||
|
||||
async def _default_call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
return await self._session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict,
|
||||
registry: ToolRegistry,
|
||||
stack: AsyncExitStack,
|
||||
*,
|
||||
authz_config: Any | None = None,
|
||||
backend_identity: Any | None = None,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""连接所有配置中的 MCP server,并把工具注册到 registry。"""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
from nanobot.authz.client import AuthzClient
|
||||
|
||||
async def _build_http_headers(server_name: str, cfg: Any) -> dict[str, str]:
|
||||
headers = dict(getattr(cfg, "headers", {}) or {})
|
||||
if getattr(cfg, "auth_mode", "none") != "oauth_backend_token":
|
||||
return headers
|
||||
|
||||
if not (
|
||||
authz_config
|
||||
and getattr(authz_config, "base_url", "").strip()
|
||||
and backend_identity
|
||||
and getattr(backend_identity, "client_id", "").strip()
|
||||
and getattr(backend_identity, "client_secret", "").strip()
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"MCP server '{server_name}' requires AuthZ backend token, but authz/backend identity is incomplete"
|
||||
)
|
||||
|
||||
authz_client = AuthzClient(
|
||||
getattr(authz_config, "base_url"),
|
||||
timeout_seconds=int(getattr(authz_config, "request_timeout_seconds", 10)),
|
||||
)
|
||||
raw_audience = str(getattr(cfg, "auth_audience", "") or "").strip()
|
||||
# Older managed Outlook configs stored `auth_audience="mcp"`, but AuthZ
|
||||
# permissions are issued against `mcp:<server_id>`.
|
||||
if not raw_audience or raw_audience == "mcp":
|
||||
audience = f"mcp:{server_name}"
|
||||
elif raw_audience.startswith("mcp:"):
|
||||
audience = raw_audience
|
||||
else:
|
||||
audience = f"mcp:{raw_audience}"
|
||||
token_response = await authz_client.issue_token(
|
||||
client_id=getattr(backend_identity, "client_id"),
|
||||
client_secret=getattr(backend_identity, "client_secret"),
|
||||
audience=audience,
|
||||
scopes=[str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
|
||||
)
|
||||
access_token = str(token_response.get("access_token") or "").strip()
|
||||
if not access_token:
|
||||
raise RuntimeError(f"MCP server '{server_name}' did not receive an access token from AuthZ")
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
return headers
|
||||
|
||||
async def _open_http_session(
|
||||
session_stack: AsyncExitStack,
|
||||
cfg: Any,
|
||||
*,
|
||||
headers: dict[str, str],
|
||||
):
|
||||
http_client = await session_stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=headers or None,
|
||||
follow_redirects=True,
|
||||
trust_env=False,
|
||||
)
|
||||
)
|
||||
read, write, _ = await session_stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
session = await session_stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
return session
|
||||
|
||||
async def _list_http_tools(server_name: str, cfg: Any):
|
||||
async with AsyncExitStack() as session_stack:
|
||||
headers = await _build_http_headers(server_name, cfg)
|
||||
session = await _open_http_session(session_stack, cfg, headers=headers)
|
||||
tools = await session.list_tools()
|
||||
return tools.tools
|
||||
|
||||
def _make_http_call_tool(server_name: str, cfg: Any) -> Callable[[str, dict[str, Any]], Awaitable[Any]]:
|
||||
async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
async with AsyncExitStack() as session_stack:
|
||||
headers = await _build_http_headers(server_name, cfg)
|
||||
session = await _open_http_session(session_stack, cfg, headers=headers)
|
||||
return await session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
return _call_tool
|
||||
|
||||
# `report` 会返回给调用方,用于 Web UI 展示连接状态和已发现工具。
|
||||
report: dict[str, dict[str, Any]] = {}
|
||||
for name, cfg in mcp_servers.items():
|
||||
report[name] = {
|
||||
"status": "disconnected",
|
||||
"last_error": None,
|
||||
"tool_names": [],
|
||||
"tool_count": 0,
|
||||
"transport": "stdio" if getattr(cfg, "command", "") else "http",
|
||||
}
|
||||
try:
|
||||
if cfg.command:
|
||||
# stdio 模式:本地拉起一个子进程,通过 stdin/stdout 与 MCP server 通信。
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(
|
||||
session,
|
||||
name,
|
||||
tool_def,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
elif cfg.url:
|
||||
if getattr(cfg, "auth_mode", "none") == "oauth_backend_token":
|
||||
tools_defs = await _list_http_tools(name, cfg)
|
||||
call_tool = _make_http_call_tool(name, cfg)
|
||||
for tool_def in tools_defs:
|
||||
wrapper = MCPToolWrapper(
|
||||
None,
|
||||
name,
|
||||
tool_def,
|
||||
call_tool=call_tool,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
else:
|
||||
headers = await _build_http_headers(name, cfg)
|
||||
session = await _open_http_session(stack, cfg, headers=headers)
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(
|
||||
session,
|
||||
name,
|
||||
tool_def,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
else:
|
||||
# 没有 command 也没有 url 的条目视为无效配置,跳过但不抛异常。
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
|
||||
report[name]["tool_count"] = len(report[name]["tool_names"])
|
||||
report[name]["status"] = "connected"
|
||||
logger.info(
|
||||
"MCP server '{}': connected, {} tools registered",
|
||||
name,
|
||||
len(report[name]["tool_names"]),
|
||||
)
|
||||
except Exception as e:
|
||||
# 单个 server 失败不影响其他 server 继续连;错误写进 report 供 UI 展示。
|
||||
error_detail = _describe_mcp_exception(
|
||||
e,
|
||||
server_name=name,
|
||||
url=str(getattr(cfg, "url", "") or "").strip() or None,
|
||||
)
|
||||
report[name]["status"] = "error"
|
||||
report[name]["last_error"] = error_detail
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, error_detail)
|
||||
return report
|
||||
@ -1,108 +0,0 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
class MessageTool(Tool):
|
||||
"""Tool to send messages to users on chat channels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||
default_channel: str = "",
|
||||
default_chat_id: str = "",
|
||||
default_message_id: str | None = None,
|
||||
):
|
||||
self._send_callback = send_callback
|
||||
self._default_channel = default_channel
|
||||
self._default_chat_id = default_chat_id
|
||||
self._default_message_id = default_message_id
|
||||
self._sent_in_turn: bool = False
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel = channel
|
||||
self._default_chat_id = chat_id
|
||||
self._default_message_id = message_id
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
self._send_callback = callback
|
||||
|
||||
def start_turn(self) -> None:
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "message"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Send a message to the user. Use this when you want to communicate something."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content to send"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, discord, etc.)"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID"
|
||||
},
|
||||
"media": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional: list of file paths to attach (images, audio, documents)"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
content: str,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
media: list[str] | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
message_id = message_id or self._default_message_id
|
||||
|
||||
if not channel or not chat_id:
|
||||
return "Error: No target channel/chat specified"
|
||||
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await self._send_callback(msg)
|
||||
self._sent_in_turn = True
|
||||
media_info = f" with {len(media)} attachments" if media else ""
|
||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||
except Exception as e:
|
||||
return f"Error sending message: {str(e)}"
|
||||
@ -1,96 +0,0 @@
|
||||
"""工具注册中心。
|
||||
|
||||
职责很单一:
|
||||
1. 保存当前可用工具实例;
|
||||
2. 向 LLM 暴露 function schema;
|
||||
3. 在执行前做基础参数校验,并把异常统一转成文本结果。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
Registry for agent tools.
|
||||
|
||||
Allows dynamic registration and execution of tools.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 工具名到实例的映射表;工具名在整个 registry 内必须唯一。
|
||||
self._tools: dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""注册一个工具实例。"""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def clone(self) -> "ToolRegistry":
|
||||
"""创建一个浅拷贝,复用同一批工具实例。"""
|
||||
# 这里不深拷贝工具对象,因为很多工具本身持有运行时状态或外部连接。
|
||||
# 当前需求只是“在一个请求里临时附加额外工具”,复用实例即可。
|
||||
other = ToolRegistry()
|
||||
other._tools = dict(self._tools)
|
||||
return other
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""Unregister a tool by name."""
|
||||
self._tools.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Tool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
def get_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in OpenAI format."""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||
"""
|
||||
Execute a tool by name with given parameters.
|
||||
|
||||
Args:
|
||||
name: Tool name.
|
||||
params: Tool parameters.
|
||||
|
||||
Returns:
|
||||
Tool execution result as string.
|
||||
|
||||
Raises:
|
||||
KeyError: If tool not found.
|
||||
"""
|
||||
_hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
|
||||
try:
|
||||
# schema 级参数校验放在真正调用前做,尽量把错误反馈成模型能自修复的文本。
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _hint
|
||||
result = await tool.execute(**params)
|
||||
# 约定:工具若返回以 Error 开头的文本,说明是业务失败而非程序崩溃。
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _hint
|
||||
return result
|
||||
except Exception as e:
|
||||
# 保持“不抛异常到模型层”的接口语义,统一回成可读文本。
|
||||
return f"Error executing {name}: {str(e)}" + _hint
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._tools
|
||||
@ -1,284 +0,0 @@
|
||||
"""Shell execution tool."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ExecTool(Tool):
|
||||
"""Tool to execute shell commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 60,
|
||||
working_dir: str | None = None,
|
||||
deny_patterns: list[str] | None = None,
|
||||
allow_patterns: list[str] | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
protected_paths: list[Path] | None = None,
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
self.deny_patterns = deny_patterns or [
|
||||
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
|
||||
r"\bdel\s+/[fq]\b", # del /f, del /q
|
||||
r"\brmdir\s+/s\b", # rmdir /s
|
||||
r"(?:^|[;&|]\s*)format\b", # format (as standalone command only)
|
||||
r"\b(mkfs|diskpart)\b", # disk operations
|
||||
r"\bdd\s+if=", # dd
|
||||
r">\s*/dev/sd", # write to disk
|
||||
r"\b(shutdown|reboot|poweroff)\b", # system power
|
||||
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
||||
]
|
||||
self.allow_patterns = allow_patterns or []
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.protected_paths = [Path(p).expanduser().resolve() for p in protected_paths or []]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "exec"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
|
||||
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
guard_error = self._guard_command(command, cwd)
|
||||
if guard_error:
|
||||
return guard_error
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
# Wait for the process to fully terminate so pipes are
|
||||
# drained and file descriptors are released.
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
return f"Error: Command timed out after {self.timeout} seconds"
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
if process.returncode != 0:
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
# Truncate very long output
|
||||
max_len = 10000
|
||||
if len(result) > max_len:
|
||||
result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)"
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
def _guard_command(self, command: str, cwd: str) -> str | None:
|
||||
"""Best-effort safety guard for potentially destructive commands."""
|
||||
cmd = command.strip()
|
||||
lower = cmd.lower()
|
||||
|
||||
for pattern in self.deny_patterns:
|
||||
if re.search(pattern, lower):
|
||||
return "Error: Command blocked by safety guard (dangerous pattern detected)"
|
||||
|
||||
if self.allow_patterns:
|
||||
if not any(re.search(p, lower) for p in self.allow_patterns):
|
||||
return "Error: Command blocked by safety guard (not in allowlist)"
|
||||
|
||||
if self.restrict_to_workspace:
|
||||
if "..\\" in cmd or "../" in cmd:
|
||||
return "Error: Command blocked by safety guard (path traversal detected)"
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
|
||||
# Only match absolute paths — avoid false positives on relative
|
||||
# paths like ".venv/bin/python" where "/bin/python" would be
|
||||
# incorrectly extracted by the old pattern.
|
||||
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
|
||||
|
||||
for raw in win_paths + posix_paths:
|
||||
try:
|
||||
p = Path(raw.strip()).resolve()
|
||||
except Exception:
|
||||
continue
|
||||
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
|
||||
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
protected_error = self._guard_protected_paths(command, cwd)
|
||||
if protected_error:
|
||||
return protected_error
|
||||
|
||||
return None
|
||||
|
||||
def _guard_protected_paths(self, command: str, cwd: str) -> str | None:
|
||||
if not self.protected_paths:
|
||||
return None
|
||||
|
||||
cwd_path = Path(cwd).expanduser().resolve()
|
||||
if self._is_blocked_clawhub_install(command, cwd_path):
|
||||
return self._protected_write_error()
|
||||
|
||||
if not self._looks_like_write(command):
|
||||
return None
|
||||
|
||||
for raw in self._extract_path_tokens(command):
|
||||
resolved = self._resolve_command_path(raw, cwd_path)
|
||||
if resolved and any(self._is_relative_to(resolved, root) for root in self.protected_paths):
|
||||
return self._protected_write_error()
|
||||
|
||||
return None
|
||||
|
||||
def _is_blocked_clawhub_install(self, command: str, cwd_path: Path) -> bool:
|
||||
lower = command.lower()
|
||||
if "clawhub" not in lower or not re.search(r"\b(install|update)\b", lower):
|
||||
return False
|
||||
|
||||
workdir = self._extract_flag_value(command, "--workdir")
|
||||
if workdir:
|
||||
resolved = self._resolve_command_path(workdir, cwd_path)
|
||||
return any(
|
||||
resolved == root.parent or self._is_relative_to(root, resolved)
|
||||
for root in self.protected_paths
|
||||
)
|
||||
|
||||
return any(cwd_path == root.parent for root in self.protected_paths)
|
||||
|
||||
@staticmethod
|
||||
def _protected_write_error() -> str:
|
||||
return (
|
||||
"Error: Direct writes to workspace skills are blocked. "
|
||||
"Stage the skill for review and require explicit user approval before installation."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_relative_to(path: Path, root: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_flag_value(command: str, flag: str) -> str | None:
|
||||
tokens = ExecTool._tokenize(command)
|
||||
for i, token in enumerate(tokens):
|
||||
if token == flag and i + 1 < len(tokens):
|
||||
return tokens[i + 1]
|
||||
if token.startswith(flag + "="):
|
||||
return token.split("=", 1)[1]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_write(command: str) -> bool:
|
||||
lower = command.lower()
|
||||
if re.search(r"(^|[^<])>>?\s*\S+", command):
|
||||
return True
|
||||
if re.search(r"\bsed\s+-i(?:\s|$)", lower):
|
||||
return True
|
||||
return bool(re.search(
|
||||
r"\b(cp|mv|rm|mkdir|touch|install|tee|tar|unzip|zip|chmod|chown|git|python|python3|node|npx|bash|sh|zsh|pwsh|powershell)\b",
|
||||
lower,
|
||||
))
|
||||
|
||||
@staticmethod
|
||||
def _extract_path_tokens(command: str) -> list[str]:
|
||||
tokens = ExecTool._tokenize(command)
|
||||
path_tokens: list[str] = []
|
||||
skip_next = False
|
||||
for i, token in enumerate(tokens):
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if token in {"--workdir", "-C"}:
|
||||
if i + 1 < len(tokens):
|
||||
path_tokens.append(tokens[i + 1])
|
||||
skip_next = True
|
||||
continue
|
||||
if "=" in token:
|
||||
key, value = token.split("=", 1)
|
||||
if key in {"--workdir"}:
|
||||
path_tokens.append(value)
|
||||
continue
|
||||
cleaned = token.strip("\"'")
|
||||
if ExecTool._looks_like_path_token(cleaned):
|
||||
path_tokens.append(cleaned)
|
||||
return path_tokens
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_path_token(token: str) -> bool:
|
||||
if not token or token in {".", ".."}:
|
||||
return True
|
||||
if token.startswith(("~", "/", "./", "../")):
|
||||
return True
|
||||
if re.match(r"^[A-Za-z]:\\", token):
|
||||
return True
|
||||
return "/" in token or "\\" in token
|
||||
|
||||
@staticmethod
|
||||
def _resolve_command_path(raw: str, cwd_path: Path) -> Path | None:
|
||||
token = raw.strip().strip("\"'")
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
path = Path(token).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = (cwd_path / path).resolve()
|
||||
else:
|
||||
path = path.resolve()
|
||||
return path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _tokenize(command: str) -> list[str]:
|
||||
try:
|
||||
return shlex.split(command, posix=os.name != "nt")
|
||||
except ValueError:
|
||||
return command.split()
|
||||
@ -1,204 +0,0 @@
|
||||
"""委派工具:分别暴露 subagent 与 agent team 两种调用接口。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.delegation import DelegationManager
|
||||
|
||||
|
||||
class DelegationTool(Tool):
|
||||
"""委派类工具的公共上下文注入逻辑。"""
|
||||
|
||||
def __init__(self, manager: "DelegationManager"):
|
||||
self._manager = manager
|
||||
self._origin_channel = "cli"
|
||||
self._origin_chat_id = "direct"
|
||||
self._announce_via_bus = True
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, announce_via_bus: bool = True) -> None:
|
||||
"""设置后台委派结果回传的目标会话。"""
|
||||
self._origin_channel = channel
|
||||
self._origin_chat_id = chat_id
|
||||
self._announce_via_bus = announce_via_bus
|
||||
|
||||
|
||||
class SpawnSubagentTool(DelegationTool):
|
||||
"""把任务委派给单个 subagent。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn_subagent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delegate a focused task to one background subagent. "
|
||||
"Use this for complex or time-consuming work that can run independently. "
|
||||
"You only provide the task and optional required skills; downstream routing decides the concrete agent. "
|
||||
"The subagent will report back when done."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The task for the delegated subagent to complete",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of skill names the delegated worker must follow",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
skills: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""创建并启动一个 subagent 后台任务。"""
|
||||
return await self._manager.dispatch_subagent(
|
||||
task=task,
|
||||
label=label,
|
||||
skills=skills,
|
||||
origin_channel=self._origin_channel,
|
||||
origin_chat_id=self._origin_chat_id,
|
||||
announce_via_bus=self._announce_via_bus,
|
||||
)
|
||||
|
||||
|
||||
class SpawnAgentTeamTool(DelegationTool):
|
||||
"""启动一个 agent team 任务。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn_agent_team"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Start an agent team for parallel exploration. "
|
||||
"Use this when multiple agents should investigate the task in parallel and return a combined result. "
|
||||
"You only provide the task and optional required skills; downstream routing selects the concrete members."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The shared task for the agent team",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the team task (for display)",
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of skill names the team must follow",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
skills: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""创建并启动一个 agent team 后台任务。"""
|
||||
return await self._manager.dispatch_agent_team(
|
||||
task=task,
|
||||
label=label,
|
||||
skills=skills,
|
||||
origin_channel=self._origin_channel,
|
||||
origin_chat_id=self._origin_chat_id,
|
||||
announce_via_bus=self._announce_via_bus,
|
||||
)
|
||||
|
||||
|
||||
class NestedDelegateTool(Tool):
|
||||
"""供 delegated worker 使用的受控下游委派工具。"""
|
||||
|
||||
def __init__(self, manager: "DelegationManager", default_skills: list[str] | None = None):
|
||||
self._manager = manager
|
||||
self._default_skills = [str(item).strip() for item in (default_skills or []) if str(item).strip()]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delegate_task"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Synchronously delegate a downstream task from a delegated worker. "
|
||||
"Use this only when specialized help is needed. "
|
||||
"It can route to an A2A agent or an ephemeral local subagent, but never creates a persistent subagent."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The downstream task to delegate",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the downstream task",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Optional agent ID or name for the downstream worker",
|
||||
},
|
||||
"strategy": {
|
||||
"type": "string",
|
||||
"enum": ["auto", "a2a", "ephemeral_subagent"],
|
||||
"description": "Routing strategy for downstream delegation. Default is auto.",
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional required skills for the downstream delegate. Defaults to the current worker's required skills.",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
target: str | None = None,
|
||||
strategy: str = "auto",
|
||||
skills: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""同步执行一次受控下游委派,并把结果返回给当前 worker。"""
|
||||
return await self._manager.delegate_for_subagent(
|
||||
task=task,
|
||||
label=label,
|
||||
target=target,
|
||||
strategy=strategy,
|
||||
skills=skills if skills is not None else list(self._default_skills),
|
||||
)
|
||||
@ -1,163 +0,0 @@
|
||||
"""Web tools: web_search and web_fetch."""
|
||||
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Shared constants
|
||||
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_7_2) AppleWebKit/537.36"
|
||||
MAX_REDIRECTS = 5 # Limit redirects to prevent DoS attacks
|
||||
|
||||
|
||||
def _strip_tags(text: str) -> str:
|
||||
"""Remove HTML tags and decode entities."""
|
||||
text = re.sub(r'<script[\s\S]*?</script>', '', text, flags=re.I)
|
||||
text = re.sub(r'<style[\s\S]*?</style>', '', text, flags=re.I)
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
return html.unescape(text).strip()
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
"""Normalize whitespace."""
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
return re.sub(r'\n{3,}', '\n\n', text).strip()
|
||||
|
||||
|
||||
def _validate_url(url: str) -> tuple[bool, str]:
|
||||
"""Validate URL: must be http(s) with valid domain."""
|
||||
try:
|
||||
p = urlparse(url)
|
||||
if p.scheme not in ('http', 'https'):
|
||||
return False, f"Only http/https allowed, got '{p.scheme or 'none'}'"
|
||||
if not p.netloc:
|
||||
return False, "Missing domain"
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
|
||||
"""Search the web using Brave Search API."""
|
||||
|
||||
name = "web_search"
|
||||
description = "Search the web. Returns titles, URLs, and snippets."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"count": {"type": "integer", "description": "Results (1-10)", "minimum": 1, "maximum": 10}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
||||
self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||
self.max_results = max_results
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
if not self.api_key:
|
||||
return "Error: BRAVE_API_KEY not configured"
|
||||
|
||||
try:
|
||||
n = min(max(count or self.max_results, 1), 10)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": n},
|
||||
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
||||
timeout=10.0
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
results = r.json().get("web", {}).get("results", [])
|
||||
if not results:
|
||||
return f"No results for: {query}"
|
||||
|
||||
lines = [f"Results for: {query}\n"]
|
||||
for i, item in enumerate(results[:n], 1):
|
||||
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
|
||||
if desc := item.get("description"):
|
||||
lines.append(f" {desc}")
|
||||
return "\n".join(lines)
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
class WebFetchTool(Tool):
|
||||
"""Fetch and extract content from a URL using Readability."""
|
||||
|
||||
name = "web_fetch"
|
||||
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "URL to fetch"},
|
||||
"extractMode": {"type": "string", "enum": ["markdown", "text"], "default": "markdown"},
|
||||
"maxChars": {"type": "integer", "minimum": 100}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, max_chars: int = 50000):
|
||||
self.max_chars = max_chars
|
||||
|
||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
|
||||
from readability import Document
|
||||
|
||||
max_chars = maxChars or self.max_chars
|
||||
|
||||
# Validate URL before fetching
|
||||
is_valid, error_msg = _validate_url(url)
|
||||
if not is_valid:
|
||||
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
max_redirects=MAX_REDIRECTS,
|
||||
timeout=30.0
|
||||
) as client:
|
||||
r = await client.get(url, headers={"User-Agent": USER_AGENT})
|
||||
r.raise_for_status()
|
||||
|
||||
ctype = r.headers.get("content-type", "")
|
||||
|
||||
# JSON
|
||||
if "application/json" in ctype:
|
||||
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
||||
# HTML
|
||||
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
|
||||
doc = Document(r.text)
|
||||
content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
|
||||
text = f"# {doc.title()}\n\n{content}" if doc.title() else content
|
||||
extractor = "readability"
|
||||
else:
|
||||
text, extractor = r.text, "raw"
|
||||
|
||||
truncated = len(text) > max_chars
|
||||
if truncated:
|
||||
text = text[:max_chars]
|
||||
|
||||
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
||||
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||
|
||||
def _to_markdown(self, html: str) -> str:
|
||||
"""Convert HTML to markdown."""
|
||||
# Convert links, headings, lists before stripping tags
|
||||
text = re.sub(r'<a\s+[^>]*href=["\']([^"\']+)["\'][^>]*>([\s\S]*?)</a>',
|
||||
lambda m: f'[{_strip_tags(m[2])}]({m[1]})', html, flags=re.I)
|
||||
text = re.sub(r'<h([1-6])[^>]*>([\s\S]*?)</h\1>',
|
||||
lambda m: f'\n{"#" * int(m[1])} {_strip_tags(m[2])}\n', text, flags=re.I)
|
||||
text = re.sub(r'<li[^>]*>([\s\S]*?)</li>', lambda m: f'\n- {_strip_tags(m[1])}', text, flags=re.I)
|
||||
text = re.sub(r'</(p|div|section|article)>', '\n\n', text, flags=re.I)
|
||||
text = re.sub(r'<(br|hr)\s*/?>', '\n', text, flags=re.I)
|
||||
return _normalize(_strip_tags(text))
|
||||
@ -1,63 +0,0 @@
|
||||
"""Agent Team swarms adapter package."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"AgentTeamOrchestrator",
|
||||
"BridgeAttempt",
|
||||
"BridgeResult",
|
||||
"ExecutionMode",
|
||||
"NanobotAgentAdapter",
|
||||
"ProcedureMemory",
|
||||
"ProcedureRecord",
|
||||
"ResolvedTeamPlan",
|
||||
"RunMemory",
|
||||
"RunRecord",
|
||||
"SwarmsBridge",
|
||||
"SwarmsPolicy",
|
||||
"SwarmsRunPlanner",
|
||||
"SwarmsRunResult",
|
||||
"SwarmsRunSpec",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "AgentTeamOrchestrator":
|
||||
from nanobot.agent_team.orchestrator import AgentTeamOrchestrator
|
||||
|
||||
return AgentTeamOrchestrator
|
||||
if name == "NanobotAgentAdapter":
|
||||
from nanobot.agent_team.swarms_adapter import NanobotAgentAdapter
|
||||
|
||||
return NanobotAgentAdapter
|
||||
if name == "SwarmsBridge":
|
||||
from nanobot.agent_team.swarms_bridge import SwarmsBridge
|
||||
|
||||
return SwarmsBridge
|
||||
if name == "SwarmsPolicy":
|
||||
from nanobot.agent_team.swarms_policy import SwarmsPolicy
|
||||
|
||||
return SwarmsPolicy
|
||||
if name == "SwarmsRunPlanner":
|
||||
from nanobot.agent_team.swarms_planner import SwarmsRunPlanner
|
||||
|
||||
return SwarmsRunPlanner
|
||||
if name in {"ProcedureMemory", "RunMemory"}:
|
||||
memory = import_module("nanobot.agent_team.memory")
|
||||
return getattr(memory, name)
|
||||
if name in {
|
||||
"BridgeAttempt",
|
||||
"BridgeResult",
|
||||
"ExecutionMode",
|
||||
"ProcedureRecord",
|
||||
"ResolvedTeamPlan",
|
||||
"RunRecord",
|
||||
"SwarmsRunResult",
|
||||
"SwarmsRunSpec",
|
||||
}:
|
||||
types = import_module("nanobot.agent_team.types")
|
||||
return getattr(types, name)
|
||||
raise AttributeError(name)
|
||||
@ -1,361 +0,0 @@
|
||||
"""Agent Team 的轻量持久化层。
|
||||
|
||||
这里没有引入数据库,
|
||||
而是参考轻量 file store 设计:
|
||||
1. 数据结构尽量稳定;
|
||||
2. 使用原子写覆盖,避免半写状态;
|
||||
3. 单文件规模保持小而可读,便于排查与测试。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.run_result import contains_placeholder_summary, has_meaningful_summary
|
||||
from nanobot.agent_team.types import (
|
||||
BridgeResult,
|
||||
ExecutionMode,
|
||||
ProcedureRecord,
|
||||
RunRecord,
|
||||
now_iso,
|
||||
)
|
||||
|
||||
# ASCII token 用于英文/agent id/命令片段匹配。
|
||||
_ASCII_TOKEN_RE = re.compile(r"[a-z0-9_:-]+")
|
||||
# 中文任务没有自然空格,这里退而求其次按单字切分,保证最小可匹配能力。
|
||||
_CJK_CHAR_RE = re.compile(r"[\u4e00-\u9fff]")
|
||||
|
||||
|
||||
def _memory_root(workspace: Path) -> Path:
|
||||
"""返回 agent team memory 根目录。
|
||||
|
||||
Demo 输出:
|
||||
`/workspace/agent_team`
|
||||
"""
|
||||
# 独立目录便于用户直接查看 procedure/runs 文件,不和其他 runtime 状态混在一起。
|
||||
root = workspace / "agent_team"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def _load_json(path: Path, default: Any) -> Any:
|
||||
"""从磁盘加载 JSON;损坏或不存在时回退到默认值。
|
||||
|
||||
Demo 输出:
|
||||
`[]`
|
||||
"""
|
||||
# agent team memory 不应因为单个文件损坏就拖垮主链路,所以统一做软失败。
|
||||
if not path.exists():
|
||||
return default
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, ValueError, json.JSONDecodeError):
|
||||
return default
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, payload: Any) -> None:
|
||||
"""把 JSON 原子写入目标路径。
|
||||
|
||||
Demo 输出:
|
||||
`None`
|
||||
"""
|
||||
# 先写临时文件再 `os.replace`,这样即使进程中断也不会留下半截 JSON。
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp_path.write_text(
|
||||
json.dumps(payload, indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(str(tmp_path), str(path))
|
||||
|
||||
|
||||
def task_tokens(text: str) -> list[str]:
|
||||
"""把任务文本压成可匹配的轻量 token 列表。
|
||||
|
||||
Demo 输出:
|
||||
`["生成", "周报", "writer-agent", "publish"]`
|
||||
"""
|
||||
# 统一小写,保证 agent id、英文命令和 task keywords 比较时大小写无关。
|
||||
lowered = (text or "").strip().lower()
|
||||
if not lowered:
|
||||
return []
|
||||
|
||||
# 英文 token 适合匹配 agent id、命令词和常见英文任务描述。
|
||||
ascii_tokens = [token for token in _ASCII_TOKEN_RE.findall(lowered) if len(token) > 1]
|
||||
# 中文这里按单字匹配,虽然粗糙,但比整句更利于无分词依赖的第一版实现。
|
||||
cjk_tokens = _CJK_CHAR_RE.findall(lowered)
|
||||
|
||||
# 用 `dict.fromkeys` 去重并保持原始顺序,便于后续测试断言更稳定。
|
||||
return list(dict.fromkeys([*ascii_tokens, *cjk_tokens]))
|
||||
|
||||
|
||||
def similarity_score(query_tokens: list[str], candidate_tokens: list[str]) -> float:
|
||||
"""按 token 重叠度计算相似度。
|
||||
|
||||
Demo 输出:
|
||||
`0.67`
|
||||
"""
|
||||
# 任一侧为空都说明没有稳定的匹配依据,直接给 0。
|
||||
if not query_tokens or not candidate_tokens:
|
||||
return 0.0
|
||||
|
||||
# 这里故意不做复杂权重,保持算法透明、可预测、可测试。
|
||||
query_set = set(query_tokens)
|
||||
candidate_set = set(candidate_tokens)
|
||||
overlap = len(query_set & candidate_set)
|
||||
if overlap <= 0:
|
||||
return 0.0
|
||||
|
||||
# 使用 `max(len(query), len(candidate))` 作为分母,让长任务模板不会被短查询轻易误命中。
|
||||
return overlap / max(len(query_set), len(candidate_set))
|
||||
|
||||
|
||||
def clip_confidence(value: float) -> float:
|
||||
"""把置信度裁剪到 `[0.0, 1.0]`。
|
||||
|
||||
Demo 输出:
|
||||
`0.8`
|
||||
"""
|
||||
# 所有 confidence 更新都收口到这里,避免散落的边界处理不一致。
|
||||
return max(0.0, min(1.0, round(value, 4)))
|
||||
|
||||
|
||||
class ProcedureMemory:
|
||||
"""管理 learned procedure 的持久化和匹配。
|
||||
|
||||
公开方法都带了 Demo 输出说明,便于用户直接对照磁盘结果和测试脚本理解行为。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
*,
|
||||
min_confidence: float = 0.55,
|
||||
match_threshold: float = 0.2,
|
||||
) -> None:
|
||||
"""初始化 procedure memory。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureMemory(workspace=/tmp/demo-workspace, procedures.json ready)`
|
||||
"""
|
||||
# `procedures.json` 用数组存储,人工排查时最直观。
|
||||
self.workspace = workspace
|
||||
self.path = _memory_root(workspace) / "procedures.json"
|
||||
# 低于该值的 procedure 即使匹配到关键词,也不建议作为复用提示。
|
||||
self.min_confidence = min_confidence
|
||||
# 匹配阈值保持较低,只作为 AutoSwarmBuilder / planner 的参考提示。
|
||||
self.match_threshold = match_threshold
|
||||
|
||||
def list_procedures(self) -> list[ProcedureRecord]:
|
||||
"""读取全部 procedure 记录并按置信度排序。
|
||||
|
||||
Demo 输出:
|
||||
`[ProcedureRecord(...), ProcedureRecord(...)]`
|
||||
"""
|
||||
# 文件损坏或不存在时直接回空列表,主流程会自动退回探索模式。
|
||||
raw = _load_json(self.path, [])
|
||||
records = [
|
||||
ProcedureRecord.from_dict(item)
|
||||
for item in raw
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
# 高置信度、最近更新的记录更靠前,方便测试和人工查看。
|
||||
records.sort(key=lambda item: (item.confidence, item.updated_at), reverse=True)
|
||||
return records
|
||||
|
||||
def match_procedure(self, task: str) -> ProcedureRecord | None:
|
||||
"""为当前任务匹配最合适的 procedure。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureRecord(id='procedure-a1b2c3d4', task_template='生成周报', ...)`
|
||||
"""
|
||||
# 没有 token 说明任务文本几乎为空,此时不应命中任何 procedure。
|
||||
query_tokens = task_tokens(task)
|
||||
if not query_tokens:
|
||||
return None
|
||||
|
||||
best_record: ProcedureRecord | None = None
|
||||
best_score = 0.0
|
||||
for record in self.list_procedures():
|
||||
# 明显是占位/空结果的历史 procedure 直接忽略,避免污染后续路由。
|
||||
if contains_placeholder_summary(record.summary):
|
||||
continue
|
||||
# 优先用关键词匹配;任务模板是人工兜底线索。
|
||||
candidate_tokens = record.task_keywords or task_tokens(record.task_template)
|
||||
score = similarity_score(query_tokens, candidate_tokens)
|
||||
# task_template 全量包含时,给一个小额加分,提高近似重跑命中率。
|
||||
if record.task_template and record.task_template.lower() in task.lower():
|
||||
score += 0.1
|
||||
# 最终排序同时考虑相似度、置信度和失败率,避免高失败 procedure 反复被选中。
|
||||
weighted = score + record.confidence * 0.2 - record.failure_rate() * 0.2
|
||||
if weighted > best_score:
|
||||
best_record = record
|
||||
best_score = weighted
|
||||
|
||||
# 分数不足则视为没有可靠命中,让上层走探索式执行。
|
||||
if best_record is None or best_score < self.match_threshold:
|
||||
return None
|
||||
return best_record
|
||||
|
||||
async def record_candidate(self, task: str, result: BridgeResult) -> ProcedureRecord | None:
|
||||
"""把探索阶段产出的候选 procedure 写入 memory。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureRecord(id='procedure-a1b2c3d4', confidence=0.6, success_count=1, ...)`
|
||||
"""
|
||||
# 只有 bridge 显式产出候选 procedure 时才会落盘。
|
||||
candidate = result.candidate_procedure
|
||||
if candidate is None:
|
||||
return None
|
||||
if not has_meaningful_summary(candidate.summary):
|
||||
return None
|
||||
|
||||
# 记录写入时间统一在这里刷新,保证磁盘上的排序行为可预测。
|
||||
timestamp = now_iso()
|
||||
# 任务 token 统一在持久化层补齐,保证不依赖具体 bridge 的实现细节。
|
||||
merged_keywords = list(dict.fromkeys([*candidate.task_keywords, *task_tokens(task)]))
|
||||
candidate.task_keywords = merged_keywords
|
||||
candidate.task_template = candidate.task_template or task
|
||||
candidate.summary = candidate.summary or result.summary
|
||||
candidate.confidence = clip_confidence(candidate.confidence or 0.55)
|
||||
candidate.created_at = candidate.created_at or timestamp
|
||||
candidate.updated_at = timestamp
|
||||
|
||||
records = self.list_procedures()
|
||||
best_index: int | None = None
|
||||
best_score = 0.0
|
||||
for index, record in enumerate(records):
|
||||
# 完全相同 agent 组合视为强相关;否则退回关键词重叠比对。
|
||||
same_agents = (
|
||||
record.strategy == candidate.strategy
|
||||
and record.agent_ids == candidate.agent_ids
|
||||
)
|
||||
score = 1.0 if same_agents else similarity_score(candidate.task_keywords, record.task_keywords)
|
||||
if score > best_score:
|
||||
best_index = index
|
||||
best_score = score
|
||||
|
||||
if best_index is not None and best_score >= 0.5:
|
||||
# 合并已有记录,避免每次探索都生成一条几乎重复的 procedure。
|
||||
current = records[best_index]
|
||||
current.task_template = candidate.task_template or current.task_template
|
||||
current.summary = candidate.summary or current.summary
|
||||
current.agent_ids = list(candidate.agent_ids) or current.agent_ids
|
||||
current.strategy = candidate.strategy or current.strategy
|
||||
current.task_keywords = list(dict.fromkeys([*current.task_keywords, *candidate.task_keywords]))
|
||||
current.confidence = clip_confidence(max(current.confidence, candidate.confidence))
|
||||
current.success_count += 1
|
||||
current.updated_at = timestamp
|
||||
current.metadata.update(candidate.metadata)
|
||||
current.source_run_id = candidate.source_run_id or current.source_run_id
|
||||
stored = current
|
||||
else:
|
||||
# 新候选第一次入库时直接记为一次成功学习。
|
||||
candidate.success_count = max(candidate.success_count, 1)
|
||||
candidate.failure_count = max(candidate.failure_count, 0)
|
||||
candidate.created_at = candidate.created_at or timestamp
|
||||
candidate.updated_at = timestamp
|
||||
records.append(candidate)
|
||||
stored = candidate
|
||||
|
||||
_atomic_write_json(self.path, [item.to_dict() for item in records])
|
||||
return stored
|
||||
|
||||
async def update_confidence(self, procedure_id: str, delta: float) -> ProcedureRecord | None:
|
||||
"""更新某条 procedure 的置信度与成败计数。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureRecord(id='procedure-a1b2c3d4', confidence=0.75, success_count=2, failure_count=0, ...)`
|
||||
"""
|
||||
# 没有主键时直接回空,避免误更新所有记录。
|
||||
if not procedure_id:
|
||||
return None
|
||||
|
||||
records = self.list_procedures()
|
||||
updated: ProcedureRecord | None = None
|
||||
for record in records:
|
||||
if record.id != procedure_id:
|
||||
continue
|
||||
# 所有状态变更都集中在这里,保证计数和 confidence 始终同步。
|
||||
record.confidence = clip_confidence(record.confidence + delta)
|
||||
# 统一刷新“最近一次使用”和“最近一次更新时间”,这两个字段都服务于路由与排障。
|
||||
timestamp = now_iso()
|
||||
record.updated_at = timestamp
|
||||
record.last_used_at = timestamp
|
||||
if delta >= 0:
|
||||
record.success_count += 1
|
||||
else:
|
||||
record.failure_count += 1
|
||||
updated = record
|
||||
break
|
||||
|
||||
if updated is None:
|
||||
return None
|
||||
|
||||
_atomic_write_json(self.path, [item.to_dict() for item in records])
|
||||
return updated
|
||||
|
||||
|
||||
class RunMemory:
|
||||
"""管理 run 级别的历史记录。"""
|
||||
|
||||
def __init__(self, workspace: Path, *, max_records: int = 200) -> None:
|
||||
"""初始化 run memory。
|
||||
|
||||
Demo 输出:
|
||||
`RunMemory(workspace=/tmp/demo-workspace, runs.json ready)`
|
||||
"""
|
||||
# `runs.json` 保持轻量滚动窗口,避免长期运行后无限膨胀。
|
||||
self.workspace = workspace
|
||||
self.path = _memory_root(workspace) / "runs.json"
|
||||
self.max_records = max(1, max_records)
|
||||
|
||||
def list_runs(self) -> list[RunRecord]:
|
||||
"""读取全部 run 记录。
|
||||
|
||||
Demo 输出:
|
||||
`[RunRecord(...), RunRecord(...)]`
|
||||
"""
|
||||
raw = _load_json(self.path, [])
|
||||
return [
|
||||
RunRecord.from_dict(item)
|
||||
for item in raw
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
|
||||
async def record_run(
|
||||
self,
|
||||
task: str,
|
||||
mode: ExecutionMode,
|
||||
result: BridgeResult,
|
||||
procedure_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""把一次 agent team 运行结果落盘。
|
||||
|
||||
Demo 输出:
|
||||
`RunRecord(id='run-1a2b3c4d', mode=<ExecutionMode.SWARMS: 'swarms'>, success=True, ...)`
|
||||
"""
|
||||
# 把 attempt/原始 bridge 结果也带进 metadata,后面排查 swarms 执行很有用。
|
||||
record = RunRecord(
|
||||
task=task,
|
||||
mode=mode,
|
||||
success=result.success,
|
||||
summary=result.summary,
|
||||
error=result.error,
|
||||
procedure_id=procedure_id or (result.matched_procedure.id if result.matched_procedure else None),
|
||||
metadata={
|
||||
"attempts": [attempt.to_dict() for attempt in result.attempts],
|
||||
"bridge_result": result.to_dict(),
|
||||
},
|
||||
)
|
||||
runs = self.list_runs()
|
||||
runs.append(record)
|
||||
# 只保留最近 N 条,保证 JSON 文件体积可控。
|
||||
if len(runs) > self.max_records:
|
||||
runs = runs[-self.max_records:]
|
||||
_atomic_write_json(self.path, [item.to_dict() for item in runs])
|
||||
return record
|
||||
@ -1,241 +0,0 @@
|
||||
"""Thin swarms orchestrator for `spawn_agent_team`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.agent_registry import AgentRegistry
|
||||
from nanobot.agent.process_events import emit_process_event
|
||||
from nanobot.agent_team.memory import ProcedureMemory, RunMemory
|
||||
from nanobot.agent_team.swarms_adapter import MemberRunner
|
||||
from nanobot.agent_team.swarms_bridge import SwarmsBridge
|
||||
from nanobot.agent_team.swarms_planner import SwarmsRunPlanner
|
||||
from nanobot.agent_team.swarms_policy import SwarmsPolicy
|
||||
from nanobot.agent_team.target_resolver import TargetResolver
|
||||
from nanobot.agent_team.types import BridgeResult, ExecutionMode
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class AgentTeamOrchestrator:
|
||||
"""Plan a swarms run, execute it, and persist the normalized result."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: Path,
|
||||
provider: LLMProvider,
|
||||
model: str | None,
|
||||
registry: AgentRegistry,
|
||||
bus: Any,
|
||||
local_executor: Any,
|
||||
member_runner: MemberRunner,
|
||||
max_parallel_agents: int = 4,
|
||||
gateway_port: int = 18790,
|
||||
) -> None:
|
||||
self.workspace = workspace
|
||||
self.registry = registry
|
||||
self.bus = bus
|
||||
self.local_executor = local_executor
|
||||
self.procedure_memory = ProcedureMemory(workspace)
|
||||
self.run_memory = RunMemory(workspace)
|
||||
self.policy = SwarmsPolicy(max_agents=max_parallel_agents)
|
||||
self.target_resolver = TargetResolver(
|
||||
workspace=workspace,
|
||||
registry=registry,
|
||||
provider=provider,
|
||||
model=model,
|
||||
max_parallel_agents=max_parallel_agents,
|
||||
gateway_port=gateway_port,
|
||||
)
|
||||
self.planner = SwarmsRunPlanner(
|
||||
model=model,
|
||||
registry=registry,
|
||||
target_resolver=self.target_resolver,
|
||||
procedure_memory=self.procedure_memory,
|
||||
policy=self.policy,
|
||||
)
|
||||
self.swarms = SwarmsBridge(
|
||||
workspace=workspace,
|
||||
registry=registry,
|
||||
member_runner=member_runner,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _clean_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
if value is not None
|
||||
and not (isinstance(value, str) and not value.strip())
|
||||
and not (isinstance(value, (list, tuple, set, dict)) and not value)
|
||||
}
|
||||
|
||||
async def _emit_trace(
|
||||
self,
|
||||
run_id: str,
|
||||
text: str,
|
||||
*,
|
||||
stage_label: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await emit_process_event(
|
||||
"process_run_progress",
|
||||
run_id=run_id,
|
||||
actor_type="system",
|
||||
actor_id="agent-team",
|
||||
actor_name="Agent Team",
|
||||
text=text,
|
||||
metadata=self._clean_metadata({
|
||||
"source": "agent_team_orchestrator",
|
||||
"stage_label": stage_label,
|
||||
**(metadata or {}),
|
||||
}),
|
||||
)
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
*,
|
||||
task: str,
|
||||
label: str,
|
||||
skills: list[str],
|
||||
origin: dict[str, str],
|
||||
announce_via_bus: bool,
|
||||
run_id: str,
|
||||
) -> BridgeResult:
|
||||
"""Run the team task through swarms only."""
|
||||
await self._emit_trace(
|
||||
run_id,
|
||||
"Preparing a swarms run specification for the agent team.",
|
||||
stage_label="准备 swarms 运行规格",
|
||||
metadata={
|
||||
"phase": "planning",
|
||||
"skills": list(skills),
|
||||
"origin": dict(origin),
|
||||
"announce_via_bus": announce_via_bus,
|
||||
},
|
||||
)
|
||||
spec = await self.planner.plan(task=task, label=label, skills=list(skills))
|
||||
await self._emit_trace(
|
||||
run_id,
|
||||
f"Swarms run spec is ready: {spec.swarm_type} with {len(spec.agent_ids)} agent(s).",
|
||||
stage_label="swarms 运行规格已就绪",
|
||||
metadata={
|
||||
"phase": "planning",
|
||||
"spec": spec.to_dict(),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Agent team [{}] running swarms type={} agents={}",
|
||||
run_id,
|
||||
spec.swarm_type,
|
||||
spec.agent_ids,
|
||||
)
|
||||
|
||||
cleanup: dict[str, Any] = {}
|
||||
try:
|
||||
result = await self.swarms.run_spec(spec=spec, run_id=run_id)
|
||||
finally:
|
||||
cleanup = await self._cleanup_created_specialists(spec, run_id)
|
||||
if cleanup:
|
||||
result.raw.setdefault("provisioning_cleanup", cleanup)
|
||||
if cleanup.get("created_targets"):
|
||||
# The run used temporary specialists that have now been removed; do not
|
||||
# persist a reusable procedure pointing at deleted agent ids.
|
||||
result.candidate_procedure = None
|
||||
result.raw.setdefault("origin", dict(origin))
|
||||
result.raw.setdefault("announce_via_bus", announce_via_bus)
|
||||
|
||||
stored_procedure = None
|
||||
if result.success:
|
||||
stored_procedure = await self.procedure_memory.record_candidate(task, result)
|
||||
await self.run_memory.record_run(
|
||||
task,
|
||||
ExecutionMode.SWARMS,
|
||||
result,
|
||||
procedure_id=(
|
||||
stored_procedure.id
|
||||
if stored_procedure is not None
|
||||
else (
|
||||
result.matched_procedure.id
|
||||
if result.matched_procedure is not None
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
await self._emit_trace(
|
||||
run_id,
|
||||
"Swarms agent team run completed.",
|
||||
stage_label="swarms 团队执行完成",
|
||||
metadata={
|
||||
"phase": "completed",
|
||||
"success": result.success,
|
||||
"mode": result.mode.value,
|
||||
"stored_procedure_id": stored_procedure.id if stored_procedure else None,
|
||||
"attempt_count": len(result.attempts),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
async def _cleanup_created_specialists(
|
||||
self,
|
||||
spec: Any,
|
||||
run_id: str,
|
||||
) -> dict[str, Any]:
|
||||
created_targets = self._created_provisioned_targets(spec)
|
||||
if not created_targets:
|
||||
return {}
|
||||
error = None
|
||||
try:
|
||||
deleted_targets = self.target_resolver.provisioning.cleanup_local_specialists(created_targets)
|
||||
except Exception as exc:
|
||||
deleted_targets = []
|
||||
error = str(exc)
|
||||
logger.warning("Failed to clean up auto-provisioned agent-team specialists: {}", exc)
|
||||
deleted_set = set(deleted_targets)
|
||||
cleanup = {
|
||||
"created_targets": created_targets,
|
||||
"deleted_targets": deleted_targets,
|
||||
"skipped_targets": [
|
||||
target
|
||||
for target in created_targets
|
||||
if target not in deleted_set
|
||||
],
|
||||
}
|
||||
if error is not None:
|
||||
cleanup["error"] = error
|
||||
try:
|
||||
await self._emit_trace(
|
||||
run_id,
|
||||
"Cleaned up auto-provisioned agent-team specialists.",
|
||||
stage_label="清理自动创建的团队成员",
|
||||
metadata={
|
||||
"phase": "cleanup",
|
||||
**cleanup,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to emit agent-team cleanup trace: {}", exc)
|
||||
return cleanup
|
||||
|
||||
@staticmethod
|
||||
def _created_provisioned_targets(spec: Any) -> list[str]:
|
||||
metadata = getattr(spec, "metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
return []
|
||||
target_plan = metadata.get("target_plan")
|
||||
if not isinstance(target_plan, dict):
|
||||
return []
|
||||
created_targets = target_plan.get("created_provisioned_targets")
|
||||
if not created_targets:
|
||||
plan_metadata = target_plan.get("metadata")
|
||||
if isinstance(plan_metadata, dict):
|
||||
created_targets = plan_metadata.get("created_provisioned_targets")
|
||||
return [
|
||||
target
|
||||
for target in dict.fromkeys(str(item).strip() for item in (created_targets or []))
|
||||
if target
|
||||
]
|
||||
@ -1,185 +0,0 @@
|
||||
"""Provision managed local A2A specialists for agent teams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.subagents import LocalSubagentStore, normalize_subagent_id
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpecialistProvisionResult:
|
||||
"""Result of ensuring a managed specialist exists."""
|
||||
|
||||
agent_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class ProvisioningManager:
|
||||
"""Manage local specialists through LocalSubagentStore."""
|
||||
|
||||
def __init__(self, workspace: Path, *, gateway_port: int = 18790) -> None:
|
||||
self.workspace = workspace
|
||||
self.gateway_port = int(os.getenv("APP_BACKEND_PORT") or gateway_port)
|
||||
self.store = LocalSubagentStore(workspace)
|
||||
|
||||
async def ensure_local_specialist_with_result(
|
||||
self,
|
||||
*,
|
||||
role: str,
|
||||
task: str,
|
||||
skills: list[str] | None = None,
|
||||
) -> SpecialistProvisionResult:
|
||||
"""创建或刷新一个本地 specialist,并返回它是否是首次创建。"""
|
||||
# role 可能来自上游 planner、用户输入或其他动态流程,这里先做兜底和规范化:
|
||||
# 1. 空值时退回到通用角色 "general specialist"
|
||||
# 2. 去掉首尾空白,避免生成不稳定的 agent 标识
|
||||
# 这样可以保证后续 id、显示名、标签等字段都基于同一个干净的角色名生成。
|
||||
role_name = str(role or "general specialist").strip() or "general specialist"
|
||||
|
||||
# agent_id 由“角色名 + 任务指纹”组成:
|
||||
# - 同一角色处理同一任务时会命中同一个 id,从而实现刷新/复用
|
||||
# - 同一角色处理不同任务时会得到不同 id,避免不同任务上下文互相污染
|
||||
agent_id = self._specialist_id(role_name, task)
|
||||
|
||||
# display_name 主要用于人类可读展示;它不影响真正的唯一性,
|
||||
# 唯一性仍由 agent_id 保证。
|
||||
display_name = self._display_name(role_name)
|
||||
|
||||
# 为即将 upsert 的 subagent 构造运行时配置。
|
||||
# 这里显式覆盖两个关键字段:
|
||||
# - workspace:确保 specialist 和当前 agent team 运行在同一个工作目录
|
||||
# - gateway.port:确保它连接到当前后端实例暴露的网关端口
|
||||
# 这样新建/刷新出来的本地 specialist 才能在正确的环境里工作。
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(self.workspace)
|
||||
config.gateway.port = self.gateway_port
|
||||
|
||||
# payload 是写入 LocalSubagentStore 的完整声明式规格。
|
||||
# store.upsert_subagent(...) 会根据这份规格创建或刷新 subagent。
|
||||
payload = {
|
||||
# 稳定唯一 id,用于判断“是否已存在”以及后续更新同一个 specialist。
|
||||
"id": agent_id,
|
||||
|
||||
# 人类可读名称,便于在 UI、日志或调试信息中识别角色。
|
||||
"name": display_name,
|
||||
|
||||
# 简短描述说明该 agent 的来源和用途:它是 agent team 自动托管的本地 A2A specialist。
|
||||
"description": f"Managed local A2A specialist for {role_name}.",
|
||||
|
||||
# system_prompt 注入角色视角、原始任务以及本次要求携带的技能上下文,
|
||||
# 是 specialist 实际行为边界和任务目标的核心输入。
|
||||
"system_prompt": self._system_prompt(role_name, task, skills or []),
|
||||
|
||||
# 允许它进行完整委派;也就是说该 specialist 自己可以继续向下分派任务,
|
||||
# 而不是被限制为只能本地直接回答。
|
||||
"delegation_mode": "full",
|
||||
|
||||
# 允许访问 MCP,表示这个 specialist 在受外层权限控制的前提下可以使用 MCP 能力。
|
||||
"allow_mcp": True,
|
||||
|
||||
# tags 用于分类、筛选和后续清理:
|
||||
# - auto-provisioned / agent-team:标明它是系统自动创建的团队成员
|
||||
# - role_name.replace(" ", "-"):保留一个角色维度标签,便于检索
|
||||
# - skills:把本次技能要求也落到标签中,方便观测和调试
|
||||
# 使用 set 去重、sorted 排序,保证结果稳定。
|
||||
"tags": sorted(set(["auto-provisioned", "agent-team", role_name.replace(" ", "-")] + list(skills or []))),
|
||||
|
||||
# aliases 提供额外可匹配名称,既支持原始角色名,也支持格式化后的展示名。
|
||||
"aliases": [role_name, display_name],
|
||||
|
||||
# metadata 存放程序消费的结构化信息:
|
||||
# - managed_by:标记由哪个模块托管,后续 cleanup 时会用来判定是否允许删除
|
||||
# - role:记录规范化后的角色名
|
||||
# - task_fingerprint:记录任务指纹,便于追踪这个 specialist 绑定的是哪类任务上下文
|
||||
"metadata": {
|
||||
"managed_by": "agent_team_provisioning",
|
||||
"role": role_name,
|
||||
"task_fingerprint": self._fingerprint(task),
|
||||
},
|
||||
}
|
||||
|
||||
# 先读取一次已有记录,用于区分“首次创建”还是“刷新已有 specialist”。
|
||||
# 注意:真正的写入动作由后面的 upsert 完成。
|
||||
existing = self.store.get_subagent(agent_id)
|
||||
|
||||
# upsert 语义是:
|
||||
# - 不存在则创建
|
||||
# - 已存在则按新的 payload/config 刷新
|
||||
# 这样调用方不需要区分 create / update 两条路径。
|
||||
spec = self.store.upsert_subagent(payload, config)
|
||||
|
||||
# 日志区分 provisioned 和 refreshed,便于排查:
|
||||
# - 为什么这次新建了一个 specialist
|
||||
# - 或者为什么只是把旧的配置重新覆盖了一次
|
||||
if existing is None:
|
||||
logger.info("Provisioned local A2A specialist {} for role '{}'", spec.id, role_name)
|
||||
else:
|
||||
logger.info("Refreshed local A2A specialist {} for role '{}'", spec.id, role_name)
|
||||
|
||||
# 返回两类关键信息:
|
||||
# - agent_id:供上游继续引用这个 specialist
|
||||
# - created:明确告知这次是首次创建,还是命中了已有对象并完成刷新
|
||||
return SpecialistProvisionResult(agent_id=spec.id, created=existing is None)
|
||||
|
||||
def cleanup_local_specialists(self, agent_ids: list[str]) -> list[str]:
|
||||
"""Delete managed specialists and return the ids actually removed."""
|
||||
deleted: list[str] = []
|
||||
for agent_id in dict.fromkeys(str(item).strip() for item in agent_ids if str(item).strip()):
|
||||
spec = self.store.get_subagent(agent_id)
|
||||
if spec is None:
|
||||
continue
|
||||
if not self._is_managed_specialist(spec.metadata, spec.tags):
|
||||
logger.warning("Skipping cleanup for unmanaged local specialist candidate {}", agent_id)
|
||||
continue
|
||||
if self.store.delete_subagent(agent_id):
|
||||
deleted.append(agent_id)
|
||||
logger.info("Cleaned up local A2A specialist {}", agent_id)
|
||||
return deleted
|
||||
|
||||
@staticmethod
|
||||
def _is_managed_specialist(metadata: dict[str, Any], tags: list[str]) -> bool:
|
||||
return (
|
||||
metadata.get("managed_by") == "agent_team_provisioning"
|
||||
or "auto-provisioned" in tags
|
||||
)
|
||||
|
||||
def _specialist_id(self, role: str, task: str) -> str:
|
||||
base = normalize_subagent_id(role)
|
||||
return normalize_subagent_id(f"{base}-{self._fingerprint(task)}")
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint(task: str) -> str:
|
||||
return hashlib.sha1(str(task or "").encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
@staticmethod
|
||||
def _display_name(role: str) -> str:
|
||||
return " ".join(part.capitalize() for part in re.split(r"[\s_-]+", role.strip()) if part)
|
||||
|
||||
def _system_prompt(self, role: str, task: str, skills: list[str]) -> str:
|
||||
# skills 是本次 team run 要求携带的技能上下文;这里仅写入提示词,
|
||||
# 真正的工具可用性和权限仍由外层 AgentLoop / tool registry 控制。
|
||||
skills_text = ", ".join(skills) if skills else "none"
|
||||
role_text = re.sub(r"\s+", " ", str(role or "").strip()) or "general specialist"
|
||||
|
||||
# 这里保持一套完全通用的提示模板:
|
||||
# - 不对具体角色做领域特化
|
||||
# - 不规定固定输出格式
|
||||
# - 只强调“按该角色名称隐含的职责边界来贡献结果”
|
||||
return (
|
||||
f"你是 nanobot agent team 中的 {role_text}。\n\n"
|
||||
"请围绕这个角色名称所隐含的职责边界处理原始团队任务。根据任务本身选择"
|
||||
"合适的方法、工具、下游委派方式和输出格式,不要强行套用固定报告模板。"
|
||||
"你的结果应该便于团队合并成最终答案;如果关键假设、阻塞点或风险会影响"
|
||||
"结论,请明确指出。\n\n"
|
||||
f"原始团队任务:\n{task}\n\n"
|
||||
f"本次要求的技能:\n{skills_text}"
|
||||
)
|
||||
@ -1,261 +0,0 @@
|
||||
# Agent Team 真实运行调用链
|
||||
|
||||
更新时间:2026-04-08
|
||||
|
||||
这份文档用于代码 review。它不再写伪代码流程图,而是按当前实现列出从 `spawn_agent_team` 被调用,到 swarms 多 agent 执行,再到结果公告和持久化的真实函数链路。
|
||||
|
||||
核心原则:
|
||||
|
||||
```text
|
||||
nanobot 负责入口、registry、权限、skills、事件、memory、BridgeResult。
|
||||
swarms 负责团队架构运行、agent 间讨论/编排、调用 adapter。
|
||||
```
|
||||
|
||||
## 主调用链
|
||||
|
||||
```text
|
||||
SpawnAgentTeamTool.execute()
|
||||
作用:LLM/tool 层入口,接收 task / label / skills。
|
||||
-》 DelegationManager.dispatch_agent_team()
|
||||
作用:把工具调用转换成 agent_team 委派请求,固定 mode="agent_team"、strategy="group"。
|
||||
-》 DelegationManager._dispatch()
|
||||
作用:生成 run_id、display_label、origin,创建后台 asyncio task,立即返回“Agent team started”。
|
||||
-》 DelegationManager._run_dispatch()
|
||||
作用:后台真正执行 agent_team 分支;发出团队开始事件,并把任务交给 orchestrator。
|
||||
-》 AgentTeamOrchestrator.run_task()
|
||||
作用:agent team 薄编排入口;只做 plan -> swarms -> memory,不自建 team runtime。
|
||||
-》 SwarmsRunPlanner.plan()
|
||||
作用:生成 SwarmsRunSpec,决定 swarm_type、agent_ids、skills、rules、max_loops。
|
||||
-》 SwarmsBridge.run_spec()
|
||||
作用:发出“启动 swarms runtime”事件,执行 swarms,并把 swarms 输出转成 BridgeResult。
|
||||
-》 SwarmsBridge._run_swarms()
|
||||
作用:把 SwarmsRunSpec.agent_ids 转成 AgentDescriptor,再包成 NanobotAgentAdapter。
|
||||
-》 load_swarms_runtime()
|
||||
作用:懒加载 vendored third_party/swarms,取 AutoSwarmBuilder / SwarmRouter / GroupChat。
|
||||
-》 swarms.SwarmRouter(...)
|
||||
作用:创建 swarms 统一路由器,传入 nanobot adapters、swarm_type、rules、max_loops。
|
||||
-》 SwarmRouter.run(task=...)
|
||||
作用:交给 swarms 运行对应架构,例如 GroupChat / SequentialWorkflow / ConcurrentWorkflow。
|
||||
-》 NanobotAgentAdapter.run()
|
||||
作用:swarms 调用每个 agent adapter;adapter 把 swarms conversation context 转回 nanobot 成员任务。
|
||||
-》 DelegationManager._run_team_member_for_swarms()
|
||||
作用:为该成员创建 child run,做权限检查,发 agent started/finished 事件。
|
||||
-》 DelegationManager._execute_descriptor()
|
||||
作用:真正执行成员 agent;local_prompt/local_fallback 走 local_executor,A2A agent 走 A2AClient。
|
||||
-》 local_executor.run_local_task() 或 A2AClient.run_task()
|
||||
作用:成员 agent 产出 AgentRunResult。
|
||||
-》 NanobotAgentAdapter.run()
|
||||
作用:收集 AgentRunResult 到 adapter.results,并把 summary 返回给 swarms。
|
||||
-》 SwarmRouter.run(task=...)
|
||||
作用:swarms 收集所有 adapter 响应,返回 raw_output/transcript。
|
||||
-》 SwarmsBridge._normalize_swarms_output()
|
||||
作用:优先用 adapter.results 生成可读 SwarmsRunResult.summary,并保留 raw_output。
|
||||
-》 SwarmsBridge.run_spec()
|
||||
作用:构造 BridgeAttempt、candidate ProcedureRecord、BridgeResult。
|
||||
-》 AgentTeamOrchestrator.run_task()
|
||||
作用:成功时 ProcedureMemory.record_candidate(),随后 RunMemory.record_run(),再返回 BridgeResult。
|
||||
-》 DelegationManager._run_dispatch()
|
||||
作用:发团队 finished 事件,并调用 _announce_orchestrator_result()。
|
||||
-》 DelegationManager._announce_orchestrator_result()
|
||||
作用:把 BridgeResult 组装成给主 agent 的总结消息。
|
||||
-》 DelegationManager._publish_announcement() 或 _notify_direct_announcement()
|
||||
作用:通过 bus 回流主 agent,或直连回调到本地会话。
|
||||
-》 DelegationManager._emit_direct_user_message()
|
||||
作用:如果有 process event sink,给 UI 发即时可见完成消息。
|
||||
```
|
||||
|
||||
## Plan 分支
|
||||
|
||||
`SwarmsRunPlanner.plan()` 内部有两个分支。
|
||||
|
||||
简单/常规任务:
|
||||
|
||||
```text
|
||||
SwarmsRunPlanner.plan()
|
||||
作用:读取 ProcedureMemory.match_procedure(task),判断不需要 AutoSwarmBuilder。
|
||||
-》 SwarmsRunPlanner._simple_required_roles()
|
||||
作用:从 skills 生成角色,例如 implementation specialist / test specialist;没有 skills 则用 general specialist / synthesis analyst。
|
||||
-》 TargetResolver.resolve_team_targets()
|
||||
作用:根据 task、skills、required_specialists 选择已有 registry agents;缺人时调用 provisioning。
|
||||
-》 AgentRegistry.suggest_agents() / AgentRegistry.get_agent()
|
||||
作用:从 workspace/plugin/skill/local registry 中查找可执行 agent。
|
||||
-》 ProvisioningManager.ensure_local_specialist()
|
||||
作用:缺少合适 agent 时创建 managed local A2A specialist,并写入 workspace agent registry。
|
||||
-》 SwarmsRunSpec(...)
|
||||
作用:返回默认 GroupChat 运行规格,带 agent_ids、skills、rules、target_plan metadata。
|
||||
```
|
||||
|
||||
复杂/开放任务:
|
||||
|
||||
```text
|
||||
SwarmsRunPlanner.plan()
|
||||
作用:如果任务较长、命中复杂关键词,或有 ProcedureMemory hint,则进入自动建队。
|
||||
-》 SwarmsRunPlanner._run_auto_swarm_builder()
|
||||
作用:调用 swarms.AutoSwarmBuilder 生成 router config 建议。
|
||||
-》 SwarmsRunPlanner._auto_builder_prompt()
|
||||
作用:把 task、skills、memory_hint 和硬约束写入 AutoSwarmBuilder prompt。
|
||||
-》 SwarmsPolicy.validate_auto_config()
|
||||
作用:只允许安全的 swarm_type,限制 max_agents/max_loops,剥掉 tools、MCP、API key 等越权字段。
|
||||
-》 SwarmsRunPlanner._roles_from_auto_config()
|
||||
作用:从 AutoSwarmBuilder 输出提取需要的角色描述。
|
||||
-》 TargetResolver.resolve_team_targets()
|
||||
作用:把角色描述映射成 nanobot registry 中真实可执行的 agent_ids。
|
||||
-》 SwarmsRunPlanner._rearrange_flow()
|
||||
作用:如果 swarm_type 是 AgentRearrange,则用 safe_swarms_name(agent_id) 生成 flow。
|
||||
-》 SwarmsRunSpec(...)
|
||||
作用:返回经过 policy 清洗后的 swarms 运行规格。
|
||||
```
|
||||
|
||||
## Swarms 执行链
|
||||
|
||||
```text
|
||||
SwarmsBridge.run_spec()
|
||||
作用:接收 SwarmsRunSpec,发 process_run_progress(stage_label="启动 swarms runtime")。
|
||||
-》 SwarmsBridge._run_swarms()
|
||||
作用:解析 spec.agent_ids,构造 adapters,并实例化 SwarmRouter。
|
||||
-》 NanobotAgentAdapter.__post_init__()
|
||||
作用:设置 swarms 可识别的 agent_name/name/__name__/system_prompt。
|
||||
-》 SwarmsBridge._rules_with_skills()
|
||||
作用:生成 swarms rules,加入“不要新增工具/凭证/外部 endpoint”和 skills 约束。
|
||||
-》 SwarmsBridge._task_with_skills()
|
||||
作用:把 spec.task 和 spec.skills 合并成传给 SwarmRouter.run(task=...) 的任务文本。
|
||||
-》 SwarmRouter.run(task=...)
|
||||
作用:swarms 按 spec.swarm_type 创建并运行实际 swarm。
|
||||
-》 GroupChat / SequentialWorkflow / ConcurrentWorkflow / AgentRearrange / MixtureOfAgents / HierarchicalSwarm
|
||||
作用:由 swarms 负责具体多 agent 架构的讨论、顺序、并行、动态流程或层级协作。
|
||||
-》 NanobotAgentAdapter.run()
|
||||
作用:当 swarms 需要某个 agent 响应时,调用 nanobot adapter。
|
||||
-》 SwarmsBridge._normalize_swarms_output()
|
||||
作用:把 swarms raw_output 和 adapter.results 合并成 SwarmsRunResult。
|
||||
-》 SwarmsBridge._candidate_procedure()
|
||||
作用:成功时构造可选 ProcedureRecord,供 ProcedureMemory 学习复用。
|
||||
-》 BridgeResult(...)
|
||||
作用:统一返回 success、summary、member_results、candidate_procedure、attempts、raw。
|
||||
```
|
||||
|
||||
## 成员执行链
|
||||
|
||||
```text
|
||||
NanobotAgentAdapter.run(task)
|
||||
作用:接收 swarms 传入的 conversation/task。
|
||||
-》 NanobotAgentAdapter._task_with_skills()
|
||||
作用:把 skills 注入成员任务文本,形成 delegated_task。
|
||||
-》 asyncio.run_coroutine_threadsafe(member_runner(...))
|
||||
作用:从 swarms 的同步调用线程切回 nanobot 当前事件循环。
|
||||
-》 DelegationManager._run_team_member_for_swarms(descriptor, task, parent_run_id, skills)
|
||||
作用:创建 child_run_id,保持父子 process tree。
|
||||
-》 DelegationManager._ensure_descriptor_allowed()
|
||||
作用:检查 local/plugin/A2A agent 是否允许被委派。
|
||||
-》 DelegationManager._emit_agent_started()
|
||||
作用:发出成员开始事件。
|
||||
-》 DelegationManager._execute_descriptor()
|
||||
作用:根据 AgentDescriptor.kind / protocol 选择执行方式。
|
||||
-》 local_executor.run_local_task()
|
||||
作用:执行 local_prompt / local_fallback agent,并传入 skill_context、skill_names、progress_callback。
|
||||
-》 A2AClient.run_task()
|
||||
作用:执行远端或本地 gateway 暴露的 A2A agent。
|
||||
-》 DelegationManager._emit_agent_finished()
|
||||
作用:发出成员完成事件。
|
||||
-》 NanobotAgentAdapter.run()
|
||||
作用:把 AgentRunResult 存入 adapter.results;成功时返回 result.summary,失败时返回 error 文本给 swarms。
|
||||
```
|
||||
|
||||
## skills 注入链
|
||||
|
||||
```text
|
||||
SpawnAgentTeamTool.execute(skills)
|
||||
作用:接收工具参数里的 skills。
|
||||
-》 DelegationManager.dispatch_agent_team(skills=skills)
|
||||
作用:把 skills 放进后台 dispatch 参数。
|
||||
-》 DelegationManager._dispatch(skills=skills)
|
||||
作用:把 skills 保存到后台 task 调用参数。
|
||||
-》 DelegationManager._run_dispatch(skills=skills)
|
||||
作用:把 skills 传给 AgentTeamOrchestrator.run_task()。
|
||||
-》 AgentTeamOrchestrator.run_task(skills=skills)
|
||||
作用:把 skills 传给 planner 和 swarms bridge。
|
||||
-》 SwarmsRunPlanner.plan(skills=skills)
|
||||
作用:skills 参与角色选择和 AutoSwarmBuilder prompt。
|
||||
-》 SwarmsRunSpec.skills
|
||||
作用:skills 固化到运行规格,供 events、rules、task、adapter 使用。
|
||||
-》 SwarmsBridge._rules_with_skills()
|
||||
作用:把 skills 写入 SwarmRouter rules。
|
||||
-》 SwarmsBridge._task_with_skills()
|
||||
作用:把 skills 写入 SwarmRouter.run(task=...) 的任务文本。
|
||||
-》 NanobotAgentAdapter._task_with_skills()
|
||||
作用:把 skills 写入每个成员看到的 delegated task。
|
||||
-》 DelegationManager._execute_descriptor(skill_names=skills)
|
||||
作用:本地 agent 获得 skill_context / skill_names;A2A agent 获得 augment 后的任务文本。
|
||||
```
|
||||
|
||||
## 结果返回链
|
||||
|
||||
```text
|
||||
SwarmsBridge._normalize_swarms_output()
|
||||
作用:生成 SwarmsRunResult(summary, raw_output, member_results)。
|
||||
-》 SwarmsBridge.run_spec()
|
||||
作用:生成 BridgeAttempt 和 BridgeResult。
|
||||
-》 AgentTeamOrchestrator.run_task()
|
||||
作用:写 ProcedureMemory 和 RunMemory。
|
||||
-》 DelegationManager._emit_group_finished()
|
||||
作用:把团队 run 标记为 done/error,metadata 带 attempts 和成员状态。
|
||||
-》 DelegationManager._announce_orchestrator_result()
|
||||
作用:把 BridgeResult 整理成主 agent 可读的系统消息。
|
||||
-》 DelegationManager._publish_announcement()
|
||||
作用:announce_via_bus=True 时,把消息 publish 到 inbound bus,让主 agent 继续总结。
|
||||
-》 DelegationManager._notify_direct_announcement()
|
||||
作用:announce_via_bus=False 时,直接调用本地回调回流会话。
|
||||
-》 DelegationManager._emit_direct_user_message()
|
||||
作用:有 process event sink 时,给前端/UI 发一条即时完成消息。
|
||||
```
|
||||
|
||||
## 当前放行的 swarms 架构
|
||||
|
||||
`SwarmsPolicy.allowed_swarm_types` 当前只放行能消费 nanobot adapters 的架构:
|
||||
|
||||
```text
|
||||
GroupChat
|
||||
SequentialWorkflow
|
||||
ConcurrentWorkflow
|
||||
AgentRearrange
|
||||
MixtureOfAgents
|
||||
HierarchicalSwarm
|
||||
```
|
||||
|
||||
`GraphWorkflow` / `HeavySwarm` 暂不直接放行,因为当前 vendored `SwarmRouter` 的相关 factory 还不能稳定消费 nanobot 提供的 `NanobotAgentAdapter`、registry、skills 和权限边界。
|
||||
|
||||
## 文件职责速查
|
||||
|
||||
```text
|
||||
agent/tools/spawn.py
|
||||
作用:定义 spawn_agent_team 工具入口。
|
||||
|
||||
agent/delegation.py
|
||||
作用:后台调度、process events、成员执行、结果公告。
|
||||
|
||||
agent_team/orchestrator.py
|
||||
作用:agent team 主 glue,负责 plan -> swarms -> memory。
|
||||
|
||||
agent_team/swarms_planner.py
|
||||
作用:生成 SwarmsRunSpec;需要时调用 AutoSwarmBuilder。
|
||||
|
||||
agent_team/swarms_policy.py
|
||||
作用:清洗 AutoSwarmBuilder 输出,限制 swarm_type、agents、loops 和越权字段。
|
||||
|
||||
agent_team/target_resolver.py
|
||||
作用:把角色需求解析成真实 agent_ids。
|
||||
|
||||
agent_team/provisioning.py
|
||||
作用:缺少合适成员时创建 managed local A2A specialist。
|
||||
|
||||
agent_team/swarms_adapter.py
|
||||
作用:懒加载 vendored swarms,并把 nanobot agent 包成 swarms 可调用 adapter。
|
||||
|
||||
agent_team/swarms_bridge.py
|
||||
作用:构造 SwarmRouter、运行 swarms、归一化 BridgeResult。
|
||||
|
||||
agent_team/memory.py
|
||||
作用:记录 RunMemory / ProcedureMemory。
|
||||
|
||||
agent_team/types.py
|
||||
作用:定义 SwarmsRunSpec、SwarmsRunResult、BridgeAttempt、BridgeResult 等共享类型。
|
||||
```
|
||||
@ -1,114 +0,0 @@
|
||||
"""Thin adapters between nanobot agents and the vendored swarms runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.agent_registry import AgentDescriptor
|
||||
from nanobot.agent.run_result import AgentRunResult
|
||||
|
||||
MemberRunner = Callable[[AgentDescriptor, str, str, list[str]], Awaitable[AgentRunResult]]
|
||||
|
||||
|
||||
def _candidate_swarms_roots() -> list[Path]:
|
||||
"""Return likely vendored swarms paths across source and packaged layouts."""
|
||||
module_path = Path(__file__).resolve()
|
||||
candidates = [
|
||||
module_path.parents[2] / "third_party" / "swarms",
|
||||
Path("/opt/app/backend/third_party/swarms"),
|
||||
Path("/app/third_party/swarms"),
|
||||
Path.cwd() / "third_party" / "swarms",
|
||||
Path.cwd() / "backend" / "third_party" / "swarms",
|
||||
]
|
||||
unique: list[Path] = []
|
||||
seen: set[str] = set()
|
||||
for candidate in candidates:
|
||||
key = str(candidate)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique.append(candidate)
|
||||
return unique
|
||||
|
||||
|
||||
def ensure_swarms_importable() -> None:
|
||||
"""Put the vendored swarms checkout on `sys.path` if needed."""
|
||||
for swarms_root in _candidate_swarms_roots():
|
||||
if swarms_root.exists() and str(swarms_root) not in sys.path:
|
||||
sys.path.insert(0, str(swarms_root))
|
||||
return
|
||||
|
||||
|
||||
def load_swarms_runtime() -> dict[str, Any]:
|
||||
"""Lazy-load swarms classes without making package import fragile."""
|
||||
ensure_swarms_importable()
|
||||
from swarms import AutoSwarmBuilder # type: ignore
|
||||
from swarms.structs.groupchat import GroupChat # type: ignore
|
||||
from swarms.structs.swarm_router import SwarmRouter # type: ignore
|
||||
|
||||
return {
|
||||
"AutoSwarmBuilder": AutoSwarmBuilder,
|
||||
"GroupChat": GroupChat,
|
||||
"SwarmRouter": SwarmRouter,
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in {"AutoSwarmBuilder", "GroupChat", "SwarmRouter"}:
|
||||
return load_swarms_runtime()[name]
|
||||
raise AttributeError(name)
|
||||
|
||||
|
||||
def safe_swarms_name(agent_id: str) -> str:
|
||||
"""Return a GroupChat-friendly ASCII-ish name for @mentions."""
|
||||
normalized = "".join(ch if ch.isalnum() else "_" for ch in str(agent_id or "agent"))
|
||||
normalized = normalized.strip("_") or "agent"
|
||||
return f"agent_{normalized}"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class NanobotAgentAdapter:
|
||||
"""Callable wrapper that lets swarms invoke a nanobot agent descriptor."""
|
||||
|
||||
descriptor: AgentDescriptor
|
||||
run_id: str
|
||||
loop: asyncio.AbstractEventLoop
|
||||
member_runner: MemberRunner
|
||||
skills: list[str]
|
||||
results: list[AgentRunResult] = field(default_factory=list, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.agent_name = safe_swarms_name(self.descriptor.id)
|
||||
self.name = self.agent_name
|
||||
self.system_prompt = self.descriptor.system_prompt or self.descriptor.description
|
||||
self.__name__ = self.agent_name
|
||||
|
||||
def __call__(self, conversation_context: str) -> str:
|
||||
return self.run(conversation_context)
|
||||
|
||||
def run(self, task: str, *args: Any, **kwargs: Any) -> str:
|
||||
delegated_task = self._task_with_skills(task)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.member_runner(self.descriptor, delegated_task, self.run_id, list(self.skills)),
|
||||
self.loop,
|
||||
)
|
||||
result = future.result(timeout=300)
|
||||
self.results.append(result)
|
||||
if result.status != "ok":
|
||||
return f"Error from {self.agent_name}: {result.summary}"
|
||||
return result.summary
|
||||
|
||||
def _task_with_skills(self, conversation_context: str) -> str:
|
||||
if not self.skills:
|
||||
return conversation_context
|
||||
return (
|
||||
"Required skills for this delegated team member:\n"
|
||||
f"{', '.join(self.skills)}\n\n"
|
||||
"Swarms conversation context:\n"
|
||||
f"{conversation_context}"
|
||||
).strip()
|
||||
@ -1,302 +0,0 @@
|
||||
"""Bridge from nanobot agent-team tasks into the vendored swarms runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.agent_registry import AgentRegistry
|
||||
from nanobot.agent.process_events import emit_process_event
|
||||
from nanobot.agent.run_result import has_meaningful_summary
|
||||
from nanobot.agent_team.swarms_adapter import MemberRunner, NanobotAgentAdapter, load_swarms_runtime
|
||||
from nanobot.agent_team.types import (
|
||||
BridgeAttempt,
|
||||
BridgeResult,
|
||||
ExecutionMode,
|
||||
ProcedureRecord,
|
||||
SwarmsRunResult,
|
||||
SwarmsRunSpec,
|
||||
)
|
||||
|
||||
|
||||
class SwarmsBridge:
|
||||
"""Execute a `SwarmsRunSpec` with `SwarmRouter` and normalize the output."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: Path,
|
||||
registry: AgentRegistry,
|
||||
member_runner: MemberRunner,
|
||||
) -> None:
|
||||
self.workspace = workspace
|
||||
self.registry = registry
|
||||
self.member_runner = member_runner
|
||||
|
||||
async def run_spec(self, *, spec: SwarmsRunSpec, run_id: str) -> BridgeResult:
|
||||
# 先发一条过程事件,告诉上层“swarms 执行阶段已经开始”。
|
||||
# metadata 里带完整 spec,便于前端或日志侧排查本次实际执行参数。
|
||||
await self._emit_progress(
|
||||
run_id,
|
||||
f"Starting swarms run: {spec.swarm_type}.",
|
||||
stage_label="启动 swarms runtime",
|
||||
metadata={"spec": spec.to_dict()},
|
||||
)
|
||||
|
||||
# 真正调用 swarms runtime,返回的是“桥接层内部使用”的 SwarmsRunResult。
|
||||
swarms_result = await self._run_swarms(spec=spec, run_id=run_id)
|
||||
|
||||
# success 不只看 swarms_result.success,还要求 summary 有实际内容。
|
||||
# 这样可以避免 runtime technically 跑完了,但最终没有任何可消费结论时,
|
||||
# 上层误把它当成一次成功执行。
|
||||
success = swarms_result.success and has_meaningful_summary(swarms_result.summary)
|
||||
error = None if success else (swarms_result.error or swarms_result.summary)
|
||||
|
||||
# BridgeAttempt 表示“这次 swarms 模式尝试”的完整快照;
|
||||
# 后续 BridgeResult.attempts 可以累计不同执行策略/回退路径的尝试记录。
|
||||
attempt = BridgeAttempt(
|
||||
mode=ExecutionMode.SWARMS,
|
||||
success=success,
|
||||
summary=swarms_result.summary,
|
||||
error=error,
|
||||
member_results=list(swarms_result.member_results),
|
||||
targets=list(spec.agent_ids),
|
||||
raw={
|
||||
"spec": spec.to_dict(),
|
||||
"swarms_result": swarms_result.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
# 只有成功时才生成 candidate procedure,避免把失败或空结果学习成可复用流程。
|
||||
candidate = self._candidate_procedure(spec, swarms_result, run_id) if success else None
|
||||
|
||||
# 再发一条归一化完成事件,让编排层知道 bridge 已经把 swarms 原始输出
|
||||
# 压成了 nanobot 可消费的标准结果结构。
|
||||
await self._emit_progress(
|
||||
run_id,
|
||||
"Swarms run returned a normalized bridge result.",
|
||||
stage_label="swarms 输出已归一",
|
||||
metadata={
|
||||
"success": success,
|
||||
"swarm_type": spec.swarm_type,
|
||||
"candidate_procedure_id": candidate.id if candidate else None,
|
||||
},
|
||||
)
|
||||
|
||||
# BridgeResult 是 swarms bridge 对外暴露的稳定边界:
|
||||
# - summary/member_results 给上层公告和持久化使用
|
||||
# - attempts/raw 保留足够多细节,便于后续解释和调试
|
||||
return BridgeResult(
|
||||
mode=ExecutionMode.SWARMS,
|
||||
success=success,
|
||||
summary=swarms_result.summary,
|
||||
error=error,
|
||||
member_results=list(swarms_result.member_results),
|
||||
candidate_procedure=candidate,
|
||||
attempts=[attempt],
|
||||
raw={
|
||||
"spec": spec.to_dict(),
|
||||
"swarms_result": swarms_result.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
async def _run_swarms(self, *, spec: SwarmsRunSpec, run_id: str) -> SwarmsRunResult:
|
||||
try:
|
||||
# 先把 spec.agent_ids 解析成当前 registry 中的 AgentDescriptor。
|
||||
# 这里显式校验 agent 必须存在,避免 swarms runtime 在更深处才报模糊错误。
|
||||
descriptors = []
|
||||
for agent_id in spec.agent_ids:
|
||||
descriptor = self.registry.get_agent(agent_id)
|
||||
if descriptor is None:
|
||||
raise ValueError(f"Agent not found for swarms run: {agent_id}")
|
||||
descriptors.append(descriptor)
|
||||
|
||||
# swarms runtime 运行在线程池里,但每个 NanobotAgentAdapter 最终仍要把执行
|
||||
# 切回当前事件循环中的 member_runner,因此这里提前拿到 running loop。
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# 把 nanobot 的 AgentDescriptor 包装成 swarms 可以直接调用的 adapter。
|
||||
# swarms 视角下它们只是“可调用 agent”;nanobot 视角下它们会回流到
|
||||
# member_runner,再由本地执行器或 A2A client 真正完成任务。
|
||||
adapters = [
|
||||
NanobotAgentAdapter(
|
||||
descriptor=descriptor,
|
||||
run_id=run_id,
|
||||
loop=loop,
|
||||
member_runner=self.member_runner,
|
||||
skills=list(spec.skills),
|
||||
)
|
||||
for descriptor in descriptors
|
||||
]
|
||||
|
||||
# SwarmRouter 是 vendored swarms runtime 的核心入口。
|
||||
# 这里把 planner 产出的 swarm_type / loops / flow / rules 全部映射进去。
|
||||
runtime = load_swarms_runtime()
|
||||
router = runtime["SwarmRouter"](
|
||||
name=spec.label or "nanobot-agent-team",
|
||||
description="Nanobot agent-team swarms router",
|
||||
agents=adapters,
|
||||
swarm_type=spec.swarm_type,
|
||||
max_loops=max(1, spec.max_loops),
|
||||
rearrange_flow=spec.rearrange_flow,
|
||||
rules=self._rules_with_skills(spec),
|
||||
autosave=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# swarms 的 router.run 是同步阻塞调用,因此放到线程池中执行,
|
||||
# 避免阻塞当前 asyncio 事件循环。
|
||||
raw_output = await asyncio.to_thread(router.run, task=self._task_with_skills(spec))
|
||||
|
||||
# swarms 原始输出结构并不稳定,统一在这里归一成 SwarmsRunResult。
|
||||
return self._normalize_swarms_output(raw_output, adapters)
|
||||
except Exception as exc:
|
||||
# 桥接层把异常收口成失败结果,而不是继续向上抛,
|
||||
# 这样 orchestrator 可以用统一的 BridgeResult 流程处理失败。
|
||||
return SwarmsRunResult(
|
||||
success=False,
|
||||
summary=f"Swarms execution failed: {exc}",
|
||||
raw_output=None,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
def _rules_with_skills(self, spec: SwarmsRunSpec) -> str:
|
||||
# 把上层规则和桥接层的硬约束拼到一起:
|
||||
# 1. 保留 planner 指定的 rules
|
||||
# 2. 明确禁止 swarms 擅自引入额外 agent、工具或凭证
|
||||
# 3. 把 skills 也写入规则,确保团队行为不偏离 nanobot 约束
|
||||
parts = [
|
||||
spec.rules or "Run the nanobot agent team through swarms and produce a concise synthesis.",
|
||||
"Do not add tools, credentials, network endpoints, or agents outside the provided nanobot adapters.",
|
||||
]
|
||||
if spec.skills:
|
||||
parts.append("Required nanobot skills: " + ", ".join(spec.skills))
|
||||
return "\n".join(parts)
|
||||
|
||||
def _task_with_skills(self, spec: SwarmsRunSpec) -> str:
|
||||
# skills 既体现在 rules 中,也直接拼到任务文本里,
|
||||
# 这样无论 swarms runtime 更依赖哪部分上下文,都能看到技能约束。
|
||||
if not spec.skills:
|
||||
return spec.task
|
||||
return (
|
||||
f"{spec.task}\n\n"
|
||||
"Required skills for this swarms run:\n"
|
||||
f"{', '.join(spec.skills)}"
|
||||
).strip()
|
||||
|
||||
def _normalize_swarms_output(
|
||||
self,
|
||||
raw_output: Any,
|
||||
adapters: list[NanobotAgentAdapter],
|
||||
) -> SwarmsRunResult:
|
||||
# 优先从 adapters 收集每个成员真实执行后的 AgentRunResult。
|
||||
# 这些结果比 swarms runtime 的自由格式输出更稳定、也更适合后续持久化。
|
||||
member_results = [
|
||||
result
|
||||
for adapter in adapters
|
||||
for result in adapter.results
|
||||
]
|
||||
|
||||
# summary 优先从成员结果推导;如果成员结果拿不到,再从 swarms 原始输出中兜底提取。
|
||||
summary = self._summary_from_swarms_output(raw_output, member_results)
|
||||
return SwarmsRunResult(
|
||||
success=bool(summary.strip()),
|
||||
summary=summary.strip(),
|
||||
raw_output=self._jsonable(raw_output),
|
||||
member_results=member_results,
|
||||
)
|
||||
|
||||
def _summary_from_swarms_output(self, raw_output: Any, member_results: list[Any]) -> str:
|
||||
# 如果已经拿到了结构化 member_results,就优先用它们生成总结,
|
||||
# 因为这比直接依赖 swarms 的原始输出更稳定、更贴近 nanobot 的结果模型。
|
||||
if member_results:
|
||||
return "\n\n".join(
|
||||
f"{result.agent_name} ({result.status}):\n{result.summary}"
|
||||
for result in member_results
|
||||
if str(result.summary or "").strip()
|
||||
)
|
||||
|
||||
# swarms 有时直接返回字符串,那就把它当作最终 summary。
|
||||
if isinstance(raw_output, str):
|
||||
return raw_output.strip()
|
||||
|
||||
# swarms 也可能返回 transcript/list 结构;这里尝试提取非 user/system 的发言,
|
||||
# 拼成一个可读摘要。
|
||||
if isinstance(raw_output, list):
|
||||
lines: list[str] = []
|
||||
for item in raw_output:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
role = str(item.get("role") or item.get("speaker") or "").strip()
|
||||
content = str(item.get("content") or item.get("message") or "").strip()
|
||||
if not content or role.lower() in {"user", "system"}:
|
||||
continue
|
||||
lines.append(f"{role}: {content}" if role else content)
|
||||
if lines:
|
||||
return "\n\n".join(lines)
|
||||
|
||||
# 最后兜底把原始输出尽量序列化成 JSON 文本;再不行就直接 str(...)。
|
||||
try:
|
||||
return json.dumps(raw_output, ensure_ascii=False, indent=2)
|
||||
except TypeError:
|
||||
return str(raw_output)
|
||||
|
||||
def _jsonable(self, value: Any) -> Any:
|
||||
# raw_output 最终要落到 BridgeResult / RunMemory 里,因此这里尽量保证它可序列化。
|
||||
# 若原值无法直接 JSON 化,则退回字符串表示,避免整个持久化流程失败。
|
||||
try:
|
||||
json.dumps(value, ensure_ascii=False)
|
||||
return value
|
||||
except TypeError:
|
||||
return str(value)
|
||||
|
||||
def _candidate_procedure(
|
||||
self,
|
||||
spec: SwarmsRunSpec,
|
||||
result: SwarmsRunResult,
|
||||
run_id: str,
|
||||
) -> ProcedureRecord:
|
||||
# bridge 只负责产出一个“可候选复用”的 procedure 草稿:
|
||||
# - task_template/agent_ids/strategy 记录执行骨架
|
||||
# - summary 提供人类可读概览
|
||||
# - metadata 记录它来自 swarms bridge
|
||||
# 真正是否持久化、如何更新统计,由更上层的 procedure memory 决定。
|
||||
return ProcedureRecord(
|
||||
task_template=spec.task,
|
||||
summary=result.summary,
|
||||
agent_ids=list(spec.agent_ids),
|
||||
strategy=spec.swarm_type,
|
||||
confidence=0.6,
|
||||
source_run_id=run_id,
|
||||
metadata={
|
||||
"source": "swarms_bridge",
|
||||
"swarm_type": spec.swarm_type,
|
||||
"auto_generated": spec.auto_generated,
|
||||
"skills": list(spec.skills),
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
text: str,
|
||||
*,
|
||||
stage_label: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
# 统一发 process_run_progress,让前端/日志看到 swarms bridge 当前阶段。
|
||||
await emit_process_event(
|
||||
"process_run_progress",
|
||||
run_id=run_id,
|
||||
actor_type="system",
|
||||
actor_id="swarms-bridge",
|
||||
actor_name="Swarms Bridge",
|
||||
text=text,
|
||||
metadata={
|
||||
"source": "swarms_bridge",
|
||||
"stage_label": stage_label,
|
||||
**(metadata or {}),
|
||||
},
|
||||
)
|
||||
@ -1,184 +0,0 @@
|
||||
"""Planner that prepares a minimal swarms run spec for agent-team tasks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.agent_registry import AgentRegistry
|
||||
from nanobot.agent_team.memory import ProcedureMemory
|
||||
from nanobot.agent_team.swarms_adapter import load_swarms_runtime, safe_swarms_name
|
||||
from nanobot.agent_team.swarms_policy import SwarmsPolicy
|
||||
from nanobot.agent_team.target_resolver import TargetResolver
|
||||
from nanobot.agent_team.types import SwarmsRunSpec
|
||||
|
||||
|
||||
class SwarmsRunPlanner:
|
||||
"""Generate `SwarmsRunSpec` without rebuilding swarms' own planner/runtime."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str | None,
|
||||
registry: AgentRegistry,
|
||||
target_resolver: TargetResolver,
|
||||
procedure_memory: ProcedureMemory,
|
||||
policy: SwarmsPolicy,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.registry = registry
|
||||
self.target_resolver = target_resolver
|
||||
self.procedure_memory = procedure_memory
|
||||
self.policy = policy
|
||||
|
||||
async def plan(self, *, task: str, label: str, skills: list[str]) -> SwarmsRunSpec:
|
||||
memory_hint = self.procedure_memory.match_procedure(task)
|
||||
if self._should_auto_build(task, skills, memory_hint):
|
||||
raw_config = await self._run_auto_swarm_builder(task, skills, memory_hint)
|
||||
return await self._spec_from_auto_config(task, label, skills, raw_config)
|
||||
|
||||
target_plan = await self.target_resolver.resolve_team_targets(
|
||||
task=task,
|
||||
skills=skills,
|
||||
required_specialists=self._simple_required_roles(task, skills),
|
||||
)
|
||||
return SwarmsRunSpec(
|
||||
task=task,
|
||||
label=label,
|
||||
skills=list(skills),
|
||||
swarm_type="GroupChat",
|
||||
agent_ids=list(target_plan.final_targets),
|
||||
auto_generated=False,
|
||||
max_loops=2,
|
||||
rules=self._default_rules(),
|
||||
metadata={
|
||||
"memory_hint": memory_hint.id if memory_hint else None,
|
||||
"target_plan": target_plan.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
def _should_auto_build(self, task: str, skills: list[str], memory_hint: Any) -> bool:
|
||||
source = task or ""
|
||||
text = source.lower()
|
||||
markers = ("架构", "调研", "复杂", "多阶段", "strategy", "architecture", "research")
|
||||
return len(source) > 80 or memory_hint is not None or any(
|
||||
marker in source or marker in text for marker in markers
|
||||
)
|
||||
|
||||
async def _run_auto_swarm_builder(self, task: str, skills: list[str], memory_hint: Any) -> dict[str, Any]:
|
||||
try:
|
||||
runtime = load_swarms_runtime()
|
||||
builder = runtime["AutoSwarmBuilder"](
|
||||
name="nanobot-auto-swarm-builder",
|
||||
description="Generate a safe swarms router config for nanobot",
|
||||
max_loops=1,
|
||||
model_name=self._auto_builder_model_name(),
|
||||
generate_router_config=True,
|
||||
execution_type="return-swarm-router-config",
|
||||
interactive=False,
|
||||
verbose=False,
|
||||
)
|
||||
raw = await asyncio.to_thread(
|
||||
builder.run,
|
||||
self._auto_builder_prompt(task, skills, memory_hint),
|
||||
)
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
model_dump = getattr(raw, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
payload = model_dump()
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
except Exception as exc:
|
||||
logger.warning("AutoSwarmBuilder failed; falling back to deterministic run spec: {}", exc)
|
||||
return {}
|
||||
|
||||
def _auto_builder_model_name(self) -> str:
|
||||
model_name = str(self.model or "").strip()
|
||||
if not model_name:
|
||||
return "gpt-4.1"
|
||||
if "/" in model_name:
|
||||
return model_name
|
||||
return f"openai/{model_name}"
|
||||
|
||||
def _auto_builder_prompt(self, task: str, skills: list[str], memory_hint: Any) -> str:
|
||||
return (
|
||||
"Build a multi-agent swarm router config for nanobot.\n\n"
|
||||
f"User task:\n{task}\n\n"
|
||||
f"Required nanobot skills:\n{skills}\n\n"
|
||||
f"Procedure memory hint:\n{memory_hint}\n\n"
|
||||
"Return a valid JSON object that matches the swarm router config schema.\n\n"
|
||||
"Hard constraints:\n"
|
||||
"- Every generated role must follow the listed skills.\n"
|
||||
"- Do not replace, ignore, or reinterpret the listed skills.\n"
|
||||
"- Do not add external tools, credentials, MCP URLs, or hidden side effects.\n"
|
||||
"- Prefer existing nanobot registry agents; only describe missing roles."
|
||||
)
|
||||
|
||||
async def _spec_from_auto_config(
|
||||
self,
|
||||
task: str,
|
||||
label: str,
|
||||
skills: list[str],
|
||||
raw_config: dict[str, Any],
|
||||
) -> SwarmsRunSpec:
|
||||
safe_config = self.policy.validate_auto_config(raw_config)
|
||||
target_plan = await self.target_resolver.resolve_team_targets(
|
||||
task=task,
|
||||
skills=skills,
|
||||
required_specialists=self._roles_from_auto_config(safe_config),
|
||||
)
|
||||
return SwarmsRunSpec(
|
||||
task=task,
|
||||
label=label,
|
||||
skills=list(skills),
|
||||
swarm_type=str(safe_config.get("swarm_type") or "GroupChat"),
|
||||
agent_ids=list(target_plan.final_targets),
|
||||
auto_generated=bool(raw_config),
|
||||
max_loops=min(int(safe_config.get("max_loops") or 2), self.policy.max_loops),
|
||||
rearrange_flow=self._rearrange_flow(safe_config, target_plan.final_targets),
|
||||
rules=str(safe_config.get("rules") or self._default_rules()),
|
||||
raw_auto_config=safe_config,
|
||||
metadata={
|
||||
"target_plan": target_plan.to_dict(),
|
||||
"auto_builder_returned_config": bool(raw_config),
|
||||
},
|
||||
)
|
||||
|
||||
def _rearrange_flow(self, config: dict[str, Any], agent_ids: list[str]) -> str | None:
|
||||
if str(config.get("swarm_type") or "") == "AgentRearrange" and agent_ids:
|
||||
return " -> ".join(safe_swarms_name(agent_id) for agent_id in agent_ids)
|
||||
flow = config.get("rearrange_flow") or config.get("flow")
|
||||
if flow:
|
||||
return str(flow)
|
||||
return None
|
||||
|
||||
def _roles_from_auto_config(self, config: dict[str, Any]) -> list[str]:
|
||||
roles: list[str] = []
|
||||
for item in config.get("agents", []) or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
role = str(
|
||||
item.get("description")
|
||||
or item.get("system_prompt")
|
||||
or item.get("agent_name")
|
||||
or ""
|
||||
).strip()
|
||||
if role:
|
||||
roles.append(role)
|
||||
return roles or ["general specialist", "synthesis analyst"]
|
||||
|
||||
def _simple_required_roles(self, task: str, skills: list[str]) -> list[str]:
|
||||
if skills:
|
||||
return [f"{skill} specialist" for skill in skills]
|
||||
return ["general specialist", "synthesis analyst"]
|
||||
|
||||
def _default_rules(self) -> str:
|
||||
return (
|
||||
"You are running inside a nanobot agent team. Follow the provided skills, "
|
||||
"stay within your assigned role, and produce a concise final synthesis."
|
||||
)
|
||||
@ -1,70 +0,0 @@
|
||||
"""Policy guardrails for swarms-generated agent team plans."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SwarmsPolicy:
|
||||
"""Clamp AutoSwarmBuilder output before nanobot executes it."""
|
||||
|
||||
allowed_swarm_types = {
|
||||
# Keep this list to swarms that consume the provided nanobot agent adapters.
|
||||
"GroupChat",
|
||||
"SequentialWorkflow",
|
||||
"ConcurrentWorkflow",
|
||||
"AgentRearrange",
|
||||
"MixtureOfAgents",
|
||||
"HierarchicalSwarm",
|
||||
}
|
||||
|
||||
def __init__(self, *, max_agents: int = 4, max_loops: int = 3) -> None:
|
||||
self.max_agents = max(1, max_agents)
|
||||
self.max_loops = max(1, max_loops)
|
||||
|
||||
def validate_auto_config(self, raw_config: dict[str, Any]) -> dict[str, Any]:
|
||||
config = self._plain_dict(raw_config)
|
||||
|
||||
swarm_type = str(
|
||||
config.get("swarm_type")
|
||||
or config.get("type")
|
||||
or config.get("architecture")
|
||||
or "GroupChat"
|
||||
)
|
||||
if swarm_type not in self.allowed_swarm_types:
|
||||
swarm_type = "GroupChat"
|
||||
config["swarm_type"] = swarm_type
|
||||
|
||||
agents = list(config.get("agents") or [])[: self.max_agents]
|
||||
config["agents"] = [self._sanitize_agent_spec(item) for item in agents]
|
||||
config["max_loops"] = min(max(1, int(config.get("max_loops") or 2)), self.max_loops)
|
||||
|
||||
# AutoSwarmBuilder may suggest structure, not grant capabilities.
|
||||
config.pop("tools", None)
|
||||
config.pop("mcp_url", None)
|
||||
config.pop("mcp_urls", None)
|
||||
config.pop("llm_api_key", None)
|
||||
config.pop("api_key", None)
|
||||
return config
|
||||
|
||||
def _plain_dict(self, raw_config: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_config, dict):
|
||||
return dict(raw_config)
|
||||
model_dump = getattr(raw_config, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
payload = model_dump()
|
||||
return dict(payload) if isinstance(payload, dict) else {}
|
||||
dict_method = getattr(raw_config, "dict", None)
|
||||
if callable(dict_method):
|
||||
payload = dict_method()
|
||||
return dict(payload) if isinstance(payload, dict) else {}
|
||||
return {}
|
||||
|
||||
def _sanitize_agent_spec(self, item: Any) -> dict[str, Any]:
|
||||
spec = self._plain_dict(item)
|
||||
return {
|
||||
"agent_name": str(spec.get("agent_name") or spec.get("name") or "specialist"),
|
||||
"description": str(spec.get("description") or spec.get("agent_description") or ""),
|
||||
"system_prompt": str(spec.get("system_prompt") or "")[:4000],
|
||||
"role": str(spec.get("role") or "worker"),
|
||||
}
|
||||
@ -1,267 +0,0 @@
|
||||
"""Resolve and provision team targets before execution.
|
||||
|
||||
该模块负责在真正启动 agent-team / swarms 执行前,把“任务需要哪些角色”
|
||||
转换成一组可执行的 agent id。它优先复用 registry 里已有的 agent;当没有合适
|
||||
agent 覆盖某个角色时,再通过 ProvisioningManager 在本地创建 A2A specialist。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.agent_registry import AgentDescriptor, AgentRegistry
|
||||
from nanobot.agent_team.provisioning import ProvisioningManager
|
||||
from nanobot.agent_team.types import ResolvedTeamPlan
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class TargetResolver:
|
||||
"""把任务级的 specialist 需求解析成最终可执行的 agent id 列表。
|
||||
|
||||
解析策略分两层:
|
||||
1. 先读取当前 registry 里所有可见 agent,并过滤掉 router/planner 等
|
||||
不适合作为群聊工作成员的 agent。
|
||||
2. 如果调用方明确给出 required_specialists,则把 role 和候选 agent 交给
|
||||
LLM 直接选择最合适的已有 agent;LLM 选不出来时才 provision 本地
|
||||
specialist。没有明确角色时,则直接使用过滤后的已有 agent;若为空再
|
||||
兜底创建 general specialist。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: Path,
|
||||
registry: AgentRegistry,
|
||||
provider: LLMProvider,
|
||||
model: str | None = None,
|
||||
max_parallel_agents: int = 16,
|
||||
gateway_port: int = 18790,
|
||||
provisioning: ProvisioningManager | None = None,
|
||||
) -> None:
|
||||
# max_parallel_agents 同时限制“最多尝试的角色数”和“最终返回的 agent 数”,
|
||||
# 避免一次 team run 生成过多并行成员。
|
||||
self.workspace = workspace
|
||||
self.registry = registry
|
||||
self.provider = provider
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_parallel_agents = max(1, max_parallel_agents)
|
||||
self.provisioning = provisioning or ProvisioningManager(workspace, gateway_port=gateway_port)
|
||||
|
||||
async def resolve_team_targets(
|
||||
self,
|
||||
*,
|
||||
task: str,
|
||||
skills: list[str] | None = None,
|
||||
required_specialists: list[str] | None = None,
|
||||
) -> ResolvedTeamPlan:
|
||||
"""解析一次 team run 的目标 agent。
|
||||
|
||||
Args:
|
||||
task: 用户原始任务,用于 LLM 选 agent 和 specialist provision prompt。
|
||||
skills: 本次任务要求携带的技能列表,会传给新 provision 的 specialist。
|
||||
required_specialists: 上游 planner 推导出的角色需求。例如来自
|
||||
AutoSwarmBuilder config 的 agent description,或 skills 的简单映射。
|
||||
|
||||
Returns:
|
||||
ResolvedTeamPlan: 包含已复用 agent、已 provision agent、最终执行目标、
|
||||
选择理由和审计 metadata。
|
||||
"""
|
||||
# 清理空字符串/空白角色,避免后续创建出没有意义的 specialist。
|
||||
required = [item for item in (required_specialists or []) if str(item).strip()]
|
||||
|
||||
# 直接读取 registry 当前所有可见 agent,再过滤掉 router、planner、
|
||||
# local-subagent 这类不适合作为 swarms/group worker 的 agent。
|
||||
suggestions = [
|
||||
agent
|
||||
for agent in self.registry.list_agents(include_local_fallback=False)
|
||||
if self._is_group_worker_candidate(agent)
|
||||
]
|
||||
|
||||
# selected: 从 registry 复用的已有 agent id。
|
||||
# covered_roles: 哪些 required role 已经被已有 agent 覆盖,用于 metadata。
|
||||
# provisioned: 为缺失角色新建/确保存在的本地 specialist id。
|
||||
# created_provisioned: 本次 run 真正新建出来的 specialist id;后续自动清理只看它,
|
||||
# 避免把之前已经存在、只是被刷新/复用的 specialist 误删。
|
||||
# actions: provision 审计记录,方便上层解释“为什么创建了某个 agent”。
|
||||
selected: list[str] = []
|
||||
covered_roles: list[str] = []
|
||||
provisioned: list[str] = []
|
||||
created_provisioned: list[str] = []
|
||||
actions: list[dict[str, str]] = []
|
||||
|
||||
if required:
|
||||
# 调用方给出了明确角色时,不再做本地词法规则匹配,而是直接把
|
||||
# role + task + 候选 agent 交给 LLM 判断最适合复用哪个已有 agent。
|
||||
# 这里切片是为了遵守 max_parallel_agents 上限。
|
||||
for role in required[: self.max_parallel_agents]:
|
||||
existing = await self._select_existing_for_role_with_llm(
|
||||
task=task,
|
||||
role=role,
|
||||
suggestions=suggestions,
|
||||
selected=selected,
|
||||
)
|
||||
if existing is not None:
|
||||
selected.append(existing.id)
|
||||
covered_roles.append(role)
|
||||
continue
|
||||
provision_result = await self.provisioning.ensure_local_specialist_with_result(
|
||||
role=role,
|
||||
task=task,
|
||||
skills=skills or [],
|
||||
)
|
||||
agent_id = provision_result.agent_id
|
||||
provisioned.append(agent_id)
|
||||
if provision_result.created:
|
||||
created_provisioned.append(agent_id)
|
||||
actions.append({
|
||||
"action": "ensure_local_specialist",
|
||||
"role": role,
|
||||
"agent_id": agent_id,
|
||||
"created": str(provision_result.created).lower(),
|
||||
})
|
||||
else:
|
||||
# 没有明确角色需求时,直接使用当前可见的已有 agent,最多取并行上限。
|
||||
selected = [agent.id for agent in suggestions[: self.max_parallel_agents]]
|
||||
if not selected:
|
||||
# 当前 registry 没有可用 worker 时,创建一个通用 specialist 作为最低可执行兜底。
|
||||
provision_result = await self.provisioning.ensure_local_specialist_with_result(
|
||||
role="general specialist",
|
||||
task=task,
|
||||
skills=skills or [],
|
||||
)
|
||||
agent_id = provision_result.agent_id
|
||||
provisioned.append(agent_id)
|
||||
if provision_result.created:
|
||||
created_provisioned.append(agent_id)
|
||||
actions.append({
|
||||
"action": "ensure_local_specialist",
|
||||
"role": "general specialist",
|
||||
"agent_id": agent_id,
|
||||
"created": str(provision_result.created).lower(),
|
||||
})
|
||||
|
||||
# 合并已有 agent 和新 provision 的 agent:
|
||||
# - dict.fromkeys 保留顺序并去重,避免同一个 agent 被重复加入;
|
||||
# - 最后再次截断,防止 selected + provisioned 总数超过并行上限。
|
||||
final_targets = list(dict.fromkeys([*selected, *provisioned]))[: self.max_parallel_agents]
|
||||
|
||||
# selection_reason 是给上层/日志展示的粗粒度解释,metadata 里会保留更细的明细。
|
||||
reason = (
|
||||
"已选择现有 registry agent。"
|
||||
if selected and not provisioned
|
||||
else "已选择现有 registry agent,并为缺失角色补充了 specialist。"
|
||||
if selected and provisioned
|
||||
else "没有匹配到合适的现有 agent,已补充本地 A2A specialist。"
|
||||
if provisioned
|
||||
else "没有匹配到合适的现有 agent,且未补充任何 specialist。"
|
||||
)
|
||||
logger.info(
|
||||
"Resolved agent-team targets selected={} provisioned={} final={}",
|
||||
selected,
|
||||
provisioned,
|
||||
final_targets,
|
||||
)
|
||||
|
||||
# ResolvedTeamPlan 是后续 orchestrator/swarms planner 使用的稳定边界:
|
||||
# final_targets 用于实际执行,selected/provisioned/actions/metadata 用于解释和调试。
|
||||
return ResolvedTeamPlan(
|
||||
selected_existing_targets=selected,
|
||||
provisioned_targets=provisioned,
|
||||
created_provisioned_targets=created_provisioned,
|
||||
final_targets=final_targets,
|
||||
selection_reason=reason,
|
||||
provision_actions=actions,
|
||||
metadata={
|
||||
"required_specialists": required,
|
||||
"available_agent_count": len(suggestions),
|
||||
"covered_roles": covered_roles,
|
||||
"created_provisioned_targets": created_provisioned,
|
||||
"max_parallel_agents": self.max_parallel_agents,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_group_worker_candidate(agent: AgentDescriptor) -> bool:
|
||||
"""判断一个 registry agent 是否适合作为 team/group worker。
|
||||
|
||||
router/planner 类 agent 通常负责调度,不应被当作普通成员加入 GroupChat 或
|
||||
swarms worker 列表;local-subagent 是通用本地代理入口,也避免在这里重复选中。
|
||||
"""
|
||||
probe = " ".join([
|
||||
agent.id,
|
||||
agent.name,
|
||||
agent.description,
|
||||
" ".join(agent.tags),
|
||||
" ".join(agent.aliases),
|
||||
]).lower()
|
||||
if agent.id == "local-subagent":
|
||||
return False
|
||||
return not any(marker in probe for marker in ("chat-router", "router", "planner"))
|
||||
|
||||
async def _select_existing_for_role_with_llm(
|
||||
self,
|
||||
*,
|
||||
task: str,
|
||||
role: str,
|
||||
suggestions: list[AgentDescriptor],
|
||||
selected: list[str],
|
||||
) -> AgentDescriptor | None:
|
||||
"""让 LLM 从已有候选 agent 中为 role 选择最合适的一个。"""
|
||||
candidates = [agent for agent in suggestions if agent.id not in selected]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
|
||||
lines = []
|
||||
for agent in candidates:
|
||||
tags = ", ".join(agent.tags) if agent.tags else "none"
|
||||
aliases = ", ".join(agent.aliases) if agent.aliases else "none"
|
||||
lines.append(
|
||||
f"- id: {agent.id}\n"
|
||||
f" name: {agent.name}\n"
|
||||
f" description: {agent.description}\n"
|
||||
f" tags: {tags}\n"
|
||||
f" aliases: {aliases}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.provider.chat(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You select one existing agent for a required team role.\n"
|
||||
"Return exactly one agent id from the candidate list, or NONE.\n"
|
||||
"Do not explain your reasoning."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Task:\n{task}\n\n"
|
||||
f"Required role:\n{role}\n\n"
|
||||
"Candidates:\n"
|
||||
f"{chr(10).join(lines)}\n\n"
|
||||
"Return exactly one candidate id, or NONE if none of them clearly fits."
|
||||
),
|
||||
},
|
||||
],
|
||||
model=self.model,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("LLM role selection failed for role '{}': {}", role, exc)
|
||||
return None
|
||||
|
||||
raw = str(response.content or "").strip()
|
||||
choice = raw.splitlines()[0].strip().strip("`'\"") if raw else ""
|
||||
candidate_map = {agent.id: agent for agent in candidates}
|
||||
if choice in candidate_map:
|
||||
return candidate_map[choice]
|
||||
if choice.upper() not in {"", "NONE"}:
|
||||
logger.info("LLM role selection returned unknown agent id '{}' for role '{}'", choice, role)
|
||||
return None
|
||||
@ -1,546 +0,0 @@
|
||||
"""Agent Team swarms 适配层的共享类型定义。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.run_result import AgentRunResult
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
"""返回统一格式的 UTC 时间戳字符串。
|
||||
|
||||
Demo 输出:
|
||||
`2026-03-31T12:00:00.000000+00:00`
|
||||
"""
|
||||
# 统一使用 UTC,避免跨机器或跨时区比较 run/procedure 时间时出现歧义。
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def new_record_id(prefix: str) -> str:
|
||||
"""为 memory 记录生成短 ID。
|
||||
|
||||
Demo 输出:
|
||||
`procedure-3fa2c7b1`
|
||||
"""
|
||||
# 这里保留可读前缀,方便磁盘文件、日志和测试断言定位数据来源。
|
||||
return f"{prefix}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def agent_result_to_dict(result: AgentRunResult) -> dict[str, Any]:
|
||||
"""把 `AgentRunResult` 转成可 JSON 序列化的字典。
|
||||
|
||||
Demo 输出:
|
||||
`{"agent_id": "writer", "agent_name": "Writer", "status": "ok", "summary": "...", "raw": {}}`
|
||||
"""
|
||||
# `raw` 允许为空,这里统一转成字典或 None,避免后续序列化分支散落各处。
|
||||
return {
|
||||
"agent_id": result.agent_id,
|
||||
"agent_name": result.agent_name,
|
||||
"status": result.status,
|
||||
"summary": result.summary,
|
||||
"raw": result.raw,
|
||||
}
|
||||
|
||||
|
||||
def agent_result_from_dict(payload: dict[str, Any]) -> AgentRunResult:
|
||||
"""从字典重建 `AgentRunResult`。
|
||||
|
||||
Demo 输出:
|
||||
`AgentRunResult(agent_id="writer", agent_name="Writer", status="ok", summary="...", raw=None)`
|
||||
"""
|
||||
# 所有字段都做最小兜底,防止历史磁盘记录缺字段时直接炸掉整个读取流程。
|
||||
return AgentRunResult(
|
||||
agent_id=str(payload.get("agent_id") or "unknown-agent"),
|
||||
agent_name=str(payload.get("agent_name") or payload.get("agent_id") or "Unknown Agent"),
|
||||
status=str(payload.get("status") or "error"),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
raw=payload.get("raw") if isinstance(payload.get("raw"), dict) else None,
|
||||
)
|
||||
|
||||
|
||||
class ExecutionMode(str, Enum):
|
||||
"""编排器支持的执行模式。"""
|
||||
|
||||
SWARMS = "swarms"
|
||||
|
||||
|
||||
def parse_execution_mode(value: Any, default: ExecutionMode = ExecutionMode.SWARMS) -> ExecutionMode:
|
||||
"""把持久化里的 mode 字符串解析成 ExecutionMode。"""
|
||||
raw = str(value or default.value)
|
||||
try:
|
||||
return ExecutionMode(raw)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ResolvedTeamPlan:
|
||||
"""最终执行前解析出的成员计划。"""
|
||||
|
||||
selected_existing_targets: list[str] = field(default_factory=list)
|
||||
provisioned_targets: list[str] = field(default_factory=list)
|
||||
created_provisioned_targets: list[str] = field(default_factory=list)
|
||||
final_targets: list[str] = field(default_factory=list)
|
||||
selection_reason: str = ""
|
||||
provision_actions: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"selected_existing_targets": list(self.selected_existing_targets),
|
||||
"provisioned_targets": list(self.provisioned_targets),
|
||||
"created_provisioned_targets": list(self.created_provisioned_targets),
|
||||
"final_targets": list(self.final_targets),
|
||||
"selection_reason": self.selection_reason,
|
||||
"provision_actions": [dict(item) for item in self.provision_actions],
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "ResolvedTeamPlan":
|
||||
return cls(
|
||||
selected_existing_targets=[
|
||||
str(item)
|
||||
for item in payload.get("selected_existing_targets", [])
|
||||
if str(item).strip()
|
||||
],
|
||||
provisioned_targets=[
|
||||
str(item)
|
||||
for item in payload.get("provisioned_targets", [])
|
||||
if str(item).strip()
|
||||
],
|
||||
created_provisioned_targets=[
|
||||
str(item)
|
||||
for item in payload.get("created_provisioned_targets", [])
|
||||
if str(item).strip()
|
||||
],
|
||||
final_targets=[
|
||||
str(item)
|
||||
for item in payload.get("final_targets", [])
|
||||
if str(item).strip()
|
||||
],
|
||||
selection_reason=str(payload.get("selection_reason") or ""),
|
||||
provision_actions=[
|
||||
dict(item)
|
||||
for item in payload.get("provision_actions", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SwarmsRunSpec:
|
||||
"""nanobot 交给 swarms runtime 的最小运行规格。"""
|
||||
|
||||
task: str
|
||||
label: str
|
||||
skills: list[str]
|
||||
swarm_type: str
|
||||
agent_ids: list[str]
|
||||
auto_generated: bool = False
|
||||
max_loops: int = 2
|
||||
rearrange_flow: str | None = None
|
||||
rules: str | None = None
|
||||
raw_auto_config: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"task": self.task,
|
||||
"label": self.label,
|
||||
"skills": list(self.skills),
|
||||
"swarm_type": self.swarm_type,
|
||||
"agent_ids": list(self.agent_ids),
|
||||
"auto_generated": self.auto_generated,
|
||||
"max_loops": self.max_loops,
|
||||
"rearrange_flow": self.rearrange_flow,
|
||||
"rules": self.rules,
|
||||
"raw_auto_config": dict(self.raw_auto_config),
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "SwarmsRunSpec":
|
||||
return cls(
|
||||
task=str(payload.get("task") or ""),
|
||||
label=str(payload.get("label") or ""),
|
||||
skills=[str(item) for item in payload.get("skills", []) if str(item).strip()],
|
||||
swarm_type=str(payload.get("swarm_type") or "GroupChat"),
|
||||
agent_ids=[str(item) for item in payload.get("agent_ids", []) if str(item).strip()],
|
||||
auto_generated=bool(payload.get("auto_generated", False)),
|
||||
max_loops=max(1, int(payload.get("max_loops") or 2)),
|
||||
rearrange_flow=str(payload["rearrange_flow"]) if payload.get("rearrange_flow") else None,
|
||||
rules=str(payload["rules"]) if payload.get("rules") else None,
|
||||
raw_auto_config=payload.get("raw_auto_config") if isinstance(payload.get("raw_auto_config"), dict) else {},
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SwarmsRunResult:
|
||||
"""swarms runtime 的原始输出归一化前结果。"""
|
||||
|
||||
success: bool
|
||||
summary: str
|
||||
raw_output: Any
|
||||
error: str | None = None
|
||||
member_results: list[AgentRunResult] = field(default_factory=list)
|
||||
transcript: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"success": self.success,
|
||||
"summary": self.summary,
|
||||
"raw_output": self.raw_output,
|
||||
"error": self.error,
|
||||
"member_results": [agent_result_to_dict(item) for item in self.member_results],
|
||||
"transcript": [dict(item) for item in self.transcript],
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "SwarmsRunResult":
|
||||
return cls(
|
||||
success=bool(payload.get("success", False)),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
raw_output=payload.get("raw_output"),
|
||||
error=str(payload["error"]) if payload.get("error") else None,
|
||||
member_results=[
|
||||
agent_result_from_dict(item)
|
||||
for item in payload.get("member_results", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
transcript=[
|
||||
dict(item)
|
||||
for item in payload.get("transcript", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProcedureRecord:
|
||||
"""一条可复用的 procedure 记录。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureRecord(id='procedure-a1b2c3d4', task_template='生成周报', agent_ids=['writer-agent'], strategy='single', confidence=0.65, ...)`
|
||||
"""
|
||||
|
||||
# 稳定主键会被 `RunMemory` 和公告信息引用。
|
||||
id: str = field(default_factory=lambda: new_record_id("procedure"))
|
||||
# 原始任务模板用于向后续执行注入“之前学到的做法”。
|
||||
task_template: str = ""
|
||||
# 一句话总结这个 procedure 适用的场景和执行方式。
|
||||
summary: str = ""
|
||||
# swarms bridge 会按这里列出的 agent 顺序/组合执行。
|
||||
agent_ids: list[str] = field(default_factory=list)
|
||||
# 第一版只实现 `single | parallel` 两种策略。
|
||||
strategy: str = "parallel"
|
||||
# 用简单关键词做粗粒度匹配,避免引入重型向量索引。
|
||||
task_keywords: list[str] = field(default_factory=list)
|
||||
# 置信度用于后续复用和人工排查。
|
||||
confidence: float = 0.5
|
||||
# 成功/失败计数用来估算 failure rate。
|
||||
success_count: int = 0
|
||||
failure_count: int = 0
|
||||
# 便于追踪该 procedure 从哪次探索 run 学来。
|
||||
source_run_id: str | None = None
|
||||
# 标准时间字段全部保留,方便 UI 或后续排序扩展。
|
||||
created_at: str = field(default_factory=now_iso)
|
||||
updated_at: str = field(default_factory=now_iso)
|
||||
last_used_at: str | None = None
|
||||
# 额外扩展字段集中收口到 metadata,避免频繁改 schema。
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def failure_rate(self) -> float:
|
||||
"""计算该 procedure 的累计失败率。
|
||||
|
||||
Demo 输出:
|
||||
`0.25`
|
||||
"""
|
||||
# 没有历史执行时直接返回 0,避免“新 procedure 天生失败率 100%”的误判。
|
||||
total = self.success_count + self.failure_count
|
||||
if total <= 0:
|
||||
return 0.0
|
||||
return self.failure_count / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""把 procedure 记录转成字典。
|
||||
|
||||
Demo 输出:
|
||||
`{"id": "procedure-a1b2c3d4", "strategy": "parallel", "agent_ids": ["agent-a", "agent-b"], ...}`
|
||||
"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"task_template": self.task_template,
|
||||
"summary": self.summary,
|
||||
"agent_ids": list(self.agent_ids),
|
||||
"strategy": self.strategy,
|
||||
"task_keywords": list(self.task_keywords),
|
||||
"confidence": self.confidence,
|
||||
"success_count": self.success_count,
|
||||
"failure_count": self.failure_count,
|
||||
"source_run_id": self.source_run_id,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"last_used_at": self.last_used_at,
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "ProcedureRecord":
|
||||
"""从字典重建 procedure 记录。
|
||||
|
||||
Demo 输出:
|
||||
`ProcedureRecord(id='procedure-a1b2c3d4', task_template='生成周报', ...)`
|
||||
"""
|
||||
return cls(
|
||||
id=str(payload.get("id") or new_record_id("procedure")),
|
||||
task_template=str(payload.get("task_template") or ""),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
agent_ids=[str(item) for item in payload.get("agent_ids", []) if str(item).strip()],
|
||||
strategy=str(payload.get("strategy") or "parallel"),
|
||||
task_keywords=[
|
||||
str(item)
|
||||
for item in payload.get("task_keywords", [])
|
||||
if str(item).strip()
|
||||
],
|
||||
confidence=float(payload.get("confidence") or 0.5),
|
||||
success_count=int(payload.get("success_count") or 0),
|
||||
failure_count=int(payload.get("failure_count") or 0),
|
||||
source_run_id=str(payload["source_run_id"]) if payload.get("source_run_id") else None,
|
||||
created_at=str(payload.get("created_at") or now_iso()),
|
||||
updated_at=str(payload.get("updated_at") or now_iso()),
|
||||
last_used_at=str(payload["last_used_at"]) if payload.get("last_used_at") else None,
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RunRecord:
|
||||
"""一次 agent team 运行的持久化记录。
|
||||
|
||||
Demo 输出:
|
||||
`RunRecord(id='run-1a2b3c4d', task='生成周报', mode=<ExecutionMode.SWARMS: 'swarms'>, success=True, ...)`
|
||||
"""
|
||||
|
||||
# run 记录也使用短 ID,便于文件和日志双向检索。
|
||||
id: str = field(default_factory=lambda: new_record_id("run"))
|
||||
# 原始任务文本是最重要的回溯信息,必须完整保留。
|
||||
task: str = ""
|
||||
# 执行模式会用于后续做简单统计和问题排查。
|
||||
mode: ExecutionMode = ExecutionMode.SWARMS
|
||||
# 归一化成功标记。
|
||||
success: bool = False
|
||||
# 最终摘要可直接展示在运维面板或调试脚本里。
|
||||
summary: str = ""
|
||||
# 失败时保留错误信息;成功时为 None。
|
||||
error: str | None = None
|
||||
# 命中的 procedure 主键,没有命中则为空。
|
||||
procedure_id: str | None = None
|
||||
# 记录创建时间。
|
||||
created_at: str = field(default_factory=now_iso)
|
||||
# metadata 会保存 attempts、raw 等调试信息。
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""把 run 记录转成字典。
|
||||
|
||||
Demo 输出:
|
||||
`{"id": "run-1a2b3c4d", "mode": "swarms", "success": true, ...}`
|
||||
"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"task": self.task,
|
||||
"mode": self.mode.value,
|
||||
"success": self.success,
|
||||
"summary": self.summary,
|
||||
"error": self.error,
|
||||
"procedure_id": self.procedure_id,
|
||||
"created_at": self.created_at,
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "RunRecord":
|
||||
"""从字典重建 run 记录。
|
||||
|
||||
Demo 输出:
|
||||
`RunRecord(id='run-1a2b3c4d', task='生成周报', mode=<ExecutionMode.SWARMS: 'swarms'>, ...)`
|
||||
"""
|
||||
return cls(
|
||||
id=str(payload.get("id") or new_record_id("run")),
|
||||
task=str(payload.get("task") or ""),
|
||||
mode=parse_execution_mode(payload.get("mode")),
|
||||
success=bool(payload.get("success", False)),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
error=str(payload["error"]) if payload.get("error") else None,
|
||||
procedure_id=str(payload["procedure_id"]) if payload.get("procedure_id") else None,
|
||||
created_at=str(payload.get("created_at") or now_iso()),
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BridgeAttempt:
|
||||
"""单次 bridge 执行尝试的归一化结果。
|
||||
|
||||
Demo 输出:
|
||||
`BridgeAttempt(mode=<ExecutionMode.SWARMS: 'swarms'>, success=False, summary='执行失败', error='timeout', targets=['writer-agent'])`
|
||||
"""
|
||||
|
||||
# 记录尝试来自哪个 bridge,便于 swarms 链路审计。
|
||||
mode: ExecutionMode
|
||||
# 是否成功决定最终团队结果状态。
|
||||
success: bool
|
||||
# 本次尝试的聚合摘要。
|
||||
summary: str
|
||||
# 若失败,则记录错误原因。
|
||||
error: str | None = None
|
||||
# 保留成员级结果,供公告和测试直接读取。
|
||||
member_results: list[AgentRunResult] = field(default_factory=list)
|
||||
# 记录本次尝试的目标 agent。
|
||||
targets: list[str] = field(default_factory=list)
|
||||
# 透传底层调试字段。
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""把单次尝试转成字典。
|
||||
|
||||
Demo 输出:
|
||||
`{"mode": "swarms", "success": false, "targets": ["writer-agent"], ...}`
|
||||
"""
|
||||
return {
|
||||
"mode": self.mode.value,
|
||||
"success": self.success,
|
||||
"summary": self.summary,
|
||||
"error": self.error,
|
||||
"member_results": [agent_result_to_dict(item) for item in self.member_results],
|
||||
"targets": list(self.targets),
|
||||
"raw": dict(self.raw),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "BridgeAttempt":
|
||||
"""从字典重建单次尝试。
|
||||
|
||||
Demo 输出:
|
||||
`BridgeAttempt(mode=<ExecutionMode.SWARMS: 'swarms'>, success=True, summary='swarms 完成', ...)`
|
||||
"""
|
||||
return cls(
|
||||
mode=parse_execution_mode(payload.get("mode")),
|
||||
success=bool(payload.get("success", False)),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
error=str(payload["error"]) if payload.get("error") else None,
|
||||
member_results=[
|
||||
agent_result_from_dict(item)
|
||||
for item in payload.get("member_results", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
targets=[str(item) for item in payload.get("targets", []) if str(item).strip()],
|
||||
raw=payload.get("raw") if isinstance(payload.get("raw"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BridgeResult:
|
||||
"""统一封装 `SwarmsBridge` 的最终输出。
|
||||
|
||||
Demo 输出:
|
||||
`BridgeResult(mode=<ExecutionMode.SWARMS: 'swarms'>, success=True, summary='swarms 已完成', ...)`
|
||||
"""
|
||||
|
||||
# 最终采用的执行模式。
|
||||
mode: ExecutionMode
|
||||
# 编排结果是否成功。
|
||||
success: bool
|
||||
# 最终可展示摘要。
|
||||
summary: str
|
||||
# 失败时的归一化错误说明。
|
||||
error: str | None = None
|
||||
# 当前结果对应的成员结果,一般取最终一次 attempt。
|
||||
member_results: list[AgentRunResult] = field(default_factory=list)
|
||||
# 探索阶段提炼出的候选 procedure。
|
||||
candidate_procedure: ProcedureRecord | None = None
|
||||
# 命中的历史 procedure,便于公告和 run 记录追踪。
|
||||
matched_procedure: ProcedureRecord | None = None
|
||||
# 支持记录多次尝试,便于后续扩展到 swarms 内部多阶段路由。
|
||||
attempts: list[BridgeAttempt] = field(default_factory=list)
|
||||
# 原始调试字段统一放在这里。
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def last_member_results(self) -> list[AgentRunResult]:
|
||||
"""返回最后一次有成员结果的 attempt。
|
||||
|
||||
Demo 输出:
|
||||
`[AgentRunResult(agent_id='writer-agent', agent_name='Writer Agent', status='ok', summary='...', raw=None)]`
|
||||
"""
|
||||
# 优先使用显式写入的最终成员结果,避免每次都从 attempts 倒推。
|
||||
if self.member_results:
|
||||
return list(self.member_results)
|
||||
# 若最终结果没显式写入,则从最后一个有成员结果的 attempt 回退。
|
||||
for attempt in reversed(self.attempts):
|
||||
if attempt.member_results:
|
||||
return list(attempt.member_results)
|
||||
return []
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""把 bridge 结果转成字典。
|
||||
|
||||
Demo 输出:
|
||||
`{"mode": "exploration", "success": true, "attempts": [...], "candidate_procedure": {...}}`
|
||||
"""
|
||||
return {
|
||||
"mode": self.mode.value,
|
||||
"success": self.success,
|
||||
"summary": self.summary,
|
||||
"error": self.error,
|
||||
"member_results": [agent_result_to_dict(item) for item in self.member_results],
|
||||
"candidate_procedure": self.candidate_procedure.to_dict() if self.candidate_procedure else None,
|
||||
"matched_procedure": self.matched_procedure.to_dict() if self.matched_procedure else None,
|
||||
"attempts": [attempt.to_dict() for attempt in self.attempts],
|
||||
"raw": dict(self.raw),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> "BridgeResult":
|
||||
"""从字典重建 bridge 结果。
|
||||
|
||||
Demo 输出:
|
||||
`BridgeResult(mode=<ExecutionMode.SWARMS: 'swarms'>, success=False, summary='执行失败', ...)`
|
||||
"""
|
||||
return cls(
|
||||
mode=parse_execution_mode(payload.get("mode")),
|
||||
success=bool(payload.get("success", False)),
|
||||
summary=str(payload.get("summary") or ""),
|
||||
error=str(payload["error"]) if payload.get("error") else None,
|
||||
member_results=[
|
||||
agent_result_from_dict(item)
|
||||
for item in payload.get("member_results", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
candidate_procedure=(
|
||||
ProcedureRecord.from_dict(payload["candidate_procedure"])
|
||||
if isinstance(payload.get("candidate_procedure"), dict)
|
||||
else None
|
||||
),
|
||||
matched_procedure=(
|
||||
ProcedureRecord.from_dict(payload["matched_procedure"])
|
||||
if isinstance(payload.get("matched_procedure"), dict)
|
||||
else None
|
||||
),
|
||||
attempts=[
|
||||
BridgeAttempt.from_dict(item)
|
||||
for item in payload.get("attempts", [])
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
raw=payload.get("raw") if isinstance(payload.get("raw"), dict) else {},
|
||||
)
|
||||
@ -1,5 +0,0 @@
|
||||
"""AuthZ service helpers."""
|
||||
|
||||
from nanobot.authz.client import AuthzClient
|
||||
|
||||
__all__ = ["AuthzClient"]
|
||||
@ -1,212 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BackendRegistrationResult:
|
||||
backend_id: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
created_at: str
|
||||
frontend_base_url: str | None = None
|
||||
|
||||
|
||||
class AuthzClient:
|
||||
def __init__(self, base_url: str, timeout_seconds: int = 10):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json_body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
# Internal AuthZ calls should not inherit shell proxy env vars.
|
||||
async with httpx.AsyncClient(
|
||||
timeout=self.timeout_seconds,
|
||||
follow_redirects=True,
|
||||
trust_env=False,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
method,
|
||||
f"{self.base_url}{path}",
|
||||
json=json_body,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.content:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
async def register_backend(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
base_url: str,
|
||||
frontend_base_url: str | None = None,
|
||||
backend_id: str | None = None,
|
||||
) -> BackendRegistrationResult:
|
||||
payload = {"name": name, "base_url": base_url}
|
||||
if backend_id:
|
||||
payload["backend_id"] = backend_id
|
||||
if frontend_base_url:
|
||||
payload["frontend_base_url"] = frontend_base_url
|
||||
data = await self._request("POST", "/backends/register", json_body=payload)
|
||||
return BackendRegistrationResult(
|
||||
backend_id=str(data["backend_id"]),
|
||||
client_id=str(data["client_id"]),
|
||||
client_secret=str(data["client_secret"]),
|
||||
created_at=str(data["created_at"]),
|
||||
frontend_base_url=str(data.get("frontend_base_url") or "").strip() or None,
|
||||
)
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
password: str,
|
||||
email: str | None = None,
|
||||
backend_name: str | None = None,
|
||||
backend_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
frontend_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
if email:
|
||||
payload["email"] = email
|
||||
|
||||
backend_payload: dict[str, Any] = {}
|
||||
if backend_name:
|
||||
payload["name"] = backend_name
|
||||
payload["backend_name"] = backend_name
|
||||
backend_payload["name"] = backend_name
|
||||
if backend_id:
|
||||
payload["backend_id"] = backend_id
|
||||
backend_payload["backend_id"] = backend_id
|
||||
if base_url:
|
||||
payload["base_url"] = base_url
|
||||
payload["public_base_url"] = base_url
|
||||
backend_payload["base_url"] = base_url
|
||||
if frontend_base_url:
|
||||
payload["frontend_base_url"] = frontend_base_url
|
||||
backend_payload["frontend_base_url"] = frontend_base_url
|
||||
|
||||
if backend_payload:
|
||||
payload["backend"] = backend_payload
|
||||
|
||||
data = await self._request("POST", "/oauth/register", json_body=payload)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def list_backends(self) -> list[dict[str, Any]]:
|
||||
data = await self._request("GET", "/backends")
|
||||
return data if isinstance(data, list) else []
|
||||
|
||||
async def get_backend(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("GET", f"/backends/{backend_id}")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def update_backend(
|
||||
self,
|
||||
backend_id: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
base_url: str | None = None,
|
||||
frontend_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
if name:
|
||||
payload["name"] = name
|
||||
if base_url:
|
||||
payload["base_url"] = base_url
|
||||
if frontend_base_url:
|
||||
payload["frontend_base_url"] = frontend_base_url
|
||||
data = await self._request("PUT", f"/backends/{backend_id}", json_body=payload)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def disable_backend(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("POST", f"/backends/{backend_id}/disable")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def enable_backend(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("POST", f"/backends/{backend_id}/enable")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def rotate_secret(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("POST", f"/backends/{backend_id}/rotate-secret")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def get_permissions(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("GET", f"/backends/{backend_id}/permissions")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def set_permissions(self, backend_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
data = await self._request("POST", f"/backends/{backend_id}/permissions", json_body=payload)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def get_outlook_settings(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("GET", f"/backends/{backend_id}/settings/outlook")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def set_outlook_settings(self, backend_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
data = await self._request("POST", f"/backends/{backend_id}/settings/outlook", json_body=payload)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def delete_outlook_settings(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("DELETE", f"/backends/{backend_id}/settings/outlook")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def list_channel_settings(self, backend_id: str) -> dict[str, Any]:
|
||||
data = await self._request("GET", f"/backends/{backend_id}/settings/channels")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def get_channel_settings(self, backend_id: str, channel_id: str) -> dict[str, Any]:
|
||||
data = await self._request("GET", f"/backends/{backend_id}/settings/channels/{channel_id}")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def set_channel_settings(
|
||||
self,
|
||||
backend_id: str,
|
||||
channel_id: str,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
data = await self._request(
|
||||
"POST",
|
||||
f"/backends/{backend_id}/settings/channels/{channel_id}",
|
||||
json_body=payload,
|
||||
)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def delete_channel_settings(self, backend_id: str, channel_id: str) -> dict[str, Any]:
|
||||
data = await self._request("DELETE", f"/backends/{backend_id}/settings/channels/{channel_id}")
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def issue_token(
|
||||
self,
|
||||
*,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
audience: str,
|
||||
scopes: list[str],
|
||||
) -> dict[str, Any]:
|
||||
data = await self._request(
|
||||
"POST",
|
||||
"/oauth/token",
|
||||
json_body={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"aud": audience,
|
||||
"scopes": scopes,
|
||||
},
|
||||
)
|
||||
return data if isinstance(data, dict) else {}
|
||||
@ -1,6 +0,0 @@
|
||||
"""Message bus module for decoupled channel-agent communication."""
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
__all__ = ["MessageBus", "InboundMessage", "OutboundMessage"]
|
||||
@ -1,38 +0,0 @@
|
||||
"""Event types for the message bus."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class InboundMessage:
|
||||
"""Message received from a chat channel."""
|
||||
|
||||
channel: str # telegram, discord, slack, whatsapp
|
||||
sender_id: str # User identifier
|
||||
chat_id: str # Chat/channel identifier
|
||||
content: str # Message text
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
media: list[str] = field(default_factory=list) # Media URLs
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
||||
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
||||
|
||||
@property
|
||||
def session_key(self) -> str:
|
||||
"""Unique key for session identification."""
|
||||
return self.session_key_override or f"{self.channel}:{self.chat_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutboundMessage:
|
||||
"""Message to send to a chat channel."""
|
||||
|
||||
channel: str
|
||||
chat_id: str
|
||||
content: str
|
||||
reply_to: str | None = None
|
||||
media: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
"""消息总线(MessageBus):用异步队列解耦“渠道层”和“Agent 核心层”。
|
||||
|
||||
核心思想:
|
||||
1. 渠道(Telegram/Discord/CLI 等)只负责收发消息,不直接调用 Agent 内部逻辑
|
||||
2. Agent 只关心“从入站队列取消息、处理后写回出站队列”
|
||||
3. 通过队列实现生产者/消费者解耦,提升并发稳定性与可维护性
|
||||
|
||||
为什么需要两个队列:
|
||||
- inbound:渠道 -> Agent
|
||||
- outbound:Agent -> 渠道
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
class MessageBus:
|
||||
"""
|
||||
异步消息总线。
|
||||
|
||||
典型流转:
|
||||
- 渠道监听到用户消息后调用 `publish_inbound`
|
||||
- Agent 主循环调用 `consume_inbound` 拿到消息并处理
|
||||
- Agent 产出回复后调用 `publish_outbound`
|
||||
- 渠道管理器调用 `consume_outbound` 并把回复发送到对应平台
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 入站队列:存放所有“用户 -> Agent”的消息事件。
|
||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
# 出站队列:存放所有“Agent -> 用户”的回复事件。
|
||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||
|
||||
async def publish_inbound(self, msg: InboundMessage) -> None:
|
||||
"""发布入站消息(由渠道层调用)。
|
||||
|
||||
参数:
|
||||
- msg: 一个 InboundMessage,包含 channel/sender/chat_id/content 等信息
|
||||
"""
|
||||
# put 是异步的:当队列受限时可自然背压;当前默认无长度上限。
|
||||
await self.inbound.put(msg)
|
||||
|
||||
async def consume_inbound(self) -> InboundMessage:
|
||||
"""消费下一条入站消息(由 Agent 主循环调用)。
|
||||
|
||||
行为:
|
||||
- 若队列为空会等待(阻塞当前协程,不阻塞事件循环)
|
||||
"""
|
||||
return await self.inbound.get()
|
||||
|
||||
async def publish_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""发布出站消息(由 Agent 调用)。
|
||||
|
||||
参数:
|
||||
- msg: 一个 OutboundMessage,包含目标 channel/chat_id 与内容
|
||||
"""
|
||||
await self.outbound.put(msg)
|
||||
|
||||
async def consume_outbound(self) -> OutboundMessage:
|
||||
"""消费下一条出站消息(由渠道分发器调用)。
|
||||
|
||||
行为:
|
||||
- 若队列为空会等待,直到 Agent 写入新的回复
|
||||
"""
|
||||
return await self.outbound.get()
|
||||
|
||||
@property
|
||||
def inbound_size(self) -> int:
|
||||
"""当前入站队列长度(待处理消息数)。"""
|
||||
# 常用于监控/调试:判断是否出现消息堆积。
|
||||
return self.inbound.qsize()
|
||||
|
||||
@property
|
||||
def outbound_size(self) -> int:
|
||||
"""当前出站队列长度(待发送回复数)。"""
|
||||
return self.outbound.qsize()
|
||||
@ -1,6 +0,0 @@
|
||||
"""Chat channels module with plugin architecture."""
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
__all__ = ["BaseChannel", "ChannelManager"]
|
||||
@ -1,131 +0,0 @@
|
||||
"""Base channel interface for chat platforms."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
"""
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""
|
||||
Check if a sender is allowed to use this bot.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise.
|
||||
"""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
|
||||
# If no allow list, allow everyone
|
||||
if not allow_list:
|
||||
return True
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str in allow_list:
|
||||
return True
|
||||
if "|" in sender_str:
|
||||
for part in sender_str.split("|"):
|
||||
if part and part in allow_list:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
return self._running
|
||||
@ -1,247 +0,0 @@
|
||||
"""DingTalk/DingDing channel implementation using Stream Mode."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
import httpx
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
DingTalkStreamClient,
|
||||
Credential,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
AckMessage,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
DINGTALK_AVAILABLE = True
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
# Fallback so class definitions don't crash at module level
|
||||
CallbackHandler = object # type: ignore[assignment,misc]
|
||||
CallbackMessage = None # type: ignore[assignment,misc]
|
||||
AckMessage = None # type: ignore[assignment,misc]
|
||||
ChatbotMessage = None # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
class NanobotDingTalkHandler(CallbackHandler):
|
||||
"""
|
||||
Standard DingTalk Stream SDK Callback Handler.
|
||||
Parses incoming messages and forwards them to the Nanobot channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: "DingTalkChannel"):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
|
||||
async def process(self, message: CallbackMessage):
|
||||
"""Process incoming stream message."""
|
||||
try:
|
||||
# Parse using SDK's ChatbotMessage for robust handling
|
||||
chatbot_msg = ChatbotMessage.from_dict(message.data)
|
||||
|
||||
# Extract text content; fall back to raw dict if SDK object is empty
|
||||
content = ""
|
||||
if chatbot_msg.text:
|
||||
content = chatbot_msg.text.content.strip()
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Received empty or unsupported message type: {}",
|
||||
chatbot_msg.message_type,
|
||||
)
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(content, sender_id, sender_name)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing DingTalk message: {}", e)
|
||||
# Return OK to avoid retry loop from DingTalk server
|
||||
return AckMessage.STATUS_OK, "Error"
|
||||
|
||||
|
||||
class DingTalkChannel(BaseChannel):
|
||||
"""
|
||||
DingTalk channel using Stream Mode.
|
||||
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Note: Currently only supports private (1:1) chat. Group messages are
|
||||
received but replies are sent back as private messages to the sender.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
|
||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DingTalkConfig = config
|
||||
self._client: Any = None
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
# Access Token management for sending messages
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0
|
||||
|
||||
# Hold references to background tasks to prevent GC
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the DingTalk bot with Stream Mode."""
|
||||
try:
|
||||
if not DINGTALK_AVAILABLE:
|
||||
logger.error(
|
||||
"DingTalk Stream SDK not installed. Run: pip install dingtalk-stream"
|
||||
)
|
||||
return
|
||||
|
||||
if not self.config.client_id or not self.config.client_secret:
|
||||
logger.error("DingTalk client_id and client_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient()
|
||||
|
||||
logger.info(
|
||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||
self.config.client_id,
|
||||
)
|
||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||
self._client = DingTalkStreamClient(credential)
|
||||
|
||||
# Register standard handler
|
||||
handler = NanobotDingTalkHandler(self)
|
||||
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
||||
|
||||
logger.info("DingTalk bot started with Stream Mode")
|
||||
|
||||
# Reconnect loop: restart stream if SDK exits or crashes
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning("DingTalk stream error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DingTalk bot."""
|
||||
self._running = False
|
||||
# Close the shared HTTP client
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
# Cancel outstanding background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""Get or refresh Access Token."""
|
||||
if self._access_token and time.time() < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
data = {
|
||||
"appKey": self.config.client_id,
|
||||
"appSecret": self.config.client_secret,
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot refresh token")
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data)
|
||||
resp.raise_for_status()
|
||||
res_data = resp.json()
|
||||
self._access_token = res_data.get("accessToken")
|
||||
# Expire 60s early to be safe
|
||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||
return self._access_token
|
||||
except Exception as e:
|
||||
logger.error("Failed to get DingTalk access token: {}", e)
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through DingTalk."""
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return
|
||||
|
||||
# oToMessages/batchSend: sends to individual users (private chat)
|
||||
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
|
||||
data = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [msg.chat_id], # chat_id is the user's staffId
|
||||
"msgKey": "sampleMarkdown",
|
||||
"msgParam": json.dumps({
|
||||
"text": msg.content,
|
||||
"title": "Nanobot Reply",
|
||||
}, ensure_ascii=False),
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed: {}", resp.text)
|
||||
else:
|
||||
logger.debug("DingTalk message sent to {}", msg.chat_id)
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message: {}", e)
|
||||
|
||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
permission checks before publishing to the bus.
|
||||
"""
|
||||
try:
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
@ -1,301 +0,0 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if not content:
|
||||
return []
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos <= 0:
|
||||
pos = cut.rfind(' ')
|
||||
if pos <= 0:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
chunks = _split_message(msg.content or "")
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Only set reply reference on the first chunk
|
||||
if i == 0 and msg.reply_to:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": payload.get("guild_id"),
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
@ -1,404 +0,0 @@
|
||||
"""Email channel implementation using IMAP polling + SMTP replies."""
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import imaplib
|
||||
import re
|
||||
import smtplib
|
||||
import ssl
|
||||
from datetime import date
|
||||
from email import policy
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import EmailMessage
|
||||
from email.parser import BytesParser
|
||||
from email.utils import parseaddr
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import EmailConfig
|
||||
|
||||
|
||||
class EmailChannel(BaseChannel):
|
||||
"""
|
||||
Email channel.
|
||||
|
||||
Inbound:
|
||||
- Poll IMAP mailbox for unread messages.
|
||||
- Convert each message into an inbound event.
|
||||
|
||||
Outbound:
|
||||
- Send responses via SMTP back to the sender address.
|
||||
"""
|
||||
|
||||
name = "email"
|
||||
_IMAP_MONTHS = (
|
||||
"Jan",
|
||||
"Feb",
|
||||
"Mar",
|
||||
"Apr",
|
||||
"May",
|
||||
"Jun",
|
||||
"Jul",
|
||||
"Aug",
|
||||
"Sep",
|
||||
"Oct",
|
||||
"Nov",
|
||||
"Dec",
|
||||
)
|
||||
|
||||
def __init__(self, config: EmailConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: EmailConfig = config
|
||||
self._last_subject_by_chat: dict[str, str] = {}
|
||||
self._last_message_id_by_chat: dict[str, str] = {}
|
||||
self._processed_uids: set[str] = set() # Capped to prevent unbounded growth
|
||||
self._MAX_PROCESSED_UIDS = 100000
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start polling IMAP for inbound emails."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning(
|
||||
"Email channel disabled: consent_granted is false. "
|
||||
"Set channels.email.consentGranted=true after explicit user permission."
|
||||
)
|
||||
return
|
||||
|
||||
if not self._validate_config():
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("Starting Email channel (IMAP polling mode)...")
|
||||
|
||||
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
||||
while self._running:
|
||||
try:
|
||||
inbound_items = await asyncio.to_thread(self._fetch_new_messages)
|
||||
for item in inbound_items:
|
||||
sender = item["sender"]
|
||||
subject = item.get("subject", "")
|
||||
message_id = item.get("message_id", "")
|
||||
|
||||
if subject:
|
||||
self._last_subject_by_chat[sender] = subject
|
||||
if message_id:
|
||||
self._last_message_id_by_chat[sender] = message_id
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender,
|
||||
chat_id=sender,
|
||||
content=item["content"],
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Email polling error: {}", e)
|
||||
|
||||
await asyncio.sleep(poll_seconds)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop polling loop."""
|
||||
self._running = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send email via SMTP."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning("Skip email send: consent_granted is false")
|
||||
return
|
||||
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
if not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply: auto_reply_enabled is false")
|
||||
return
|
||||
|
||||
if not self.config.smtp_host:
|
||||
logger.warning("Email channel SMTP host not configured")
|
||||
return
|
||||
|
||||
to_addr = msg.chat_id.strip()
|
||||
if not to_addr:
|
||||
logger.warning("Email channel missing recipient address")
|
||||
return
|
||||
|
||||
base_subject = self._last_subject_by_chat.get(to_addr, "Boardware Genius reply")
|
||||
subject = self._reply_subject(base_subject)
|
||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||
override = msg.metadata["subject"].strip()
|
||||
if override:
|
||||
subject = override
|
||||
|
||||
email_msg = EmailMessage()
|
||||
email_msg["From"] = self.config.from_address or self.config.smtp_username or self.config.imap_username
|
||||
email_msg["To"] = to_addr
|
||||
email_msg["Subject"] = subject
|
||||
email_msg.set_content(msg.content or "")
|
||||
|
||||
in_reply_to = self._last_message_id_by_chat.get(to_addr)
|
||||
if in_reply_to:
|
||||
email_msg["In-Reply-To"] = in_reply_to
|
||||
email_msg["References"] = in_reply_to
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||
raise
|
||||
|
||||
def _validate_config(self) -> bool:
|
||||
missing = []
|
||||
if not self.config.imap_host:
|
||||
missing.append("imap_host")
|
||||
if not self.config.imap_username:
|
||||
missing.append("imap_username")
|
||||
if not self.config.imap_password:
|
||||
missing.append("imap_password")
|
||||
if not self.config.smtp_host:
|
||||
missing.append("smtp_host")
|
||||
if not self.config.smtp_username:
|
||||
missing.append("smtp_username")
|
||||
if not self.config.smtp_password:
|
||||
missing.append("smtp_password")
|
||||
|
||||
if missing:
|
||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||
return False
|
||||
return True
|
||||
|
||||
def _smtp_send(self, msg: EmailMessage) -> None:
|
||||
timeout = 30
|
||||
if self.config.smtp_use_ssl:
|
||||
with smtplib.SMTP_SSL(
|
||||
self.config.smtp_host,
|
||||
self.config.smtp_port,
|
||||
timeout=timeout,
|
||||
) as smtp:
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
return
|
||||
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port, timeout=timeout) as smtp:
|
||||
if self.config.smtp_use_tls:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
|
||||
def _fetch_new_messages(self) -> list[dict[str, Any]]:
|
||||
"""Poll IMAP and return parsed unread messages."""
|
||||
return self._fetch_messages(
|
||||
search_criteria=("UNSEEN",),
|
||||
mark_seen=self.config.mark_seen,
|
||||
dedupe=True,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
def fetch_messages_between_dates(
|
||||
self,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch messages in [start_date, end_date) by IMAP date search.
|
||||
|
||||
This is used for historical summarization tasks (e.g. "yesterday").
|
||||
"""
|
||||
if end_date <= start_date:
|
||||
return []
|
||||
|
||||
return self._fetch_messages(
|
||||
search_criteria=(
|
||||
"SINCE",
|
||||
self._format_imap_date(start_date),
|
||||
"BEFORE",
|
||||
self._format_imap_date(end_date),
|
||||
),
|
||||
mark_seen=False,
|
||||
dedupe=False,
|
||||
limit=max(1, int(limit)),
|
||||
)
|
||||
|
||||
def _fetch_messages(
|
||||
self,
|
||||
search_criteria: tuple[str, ...],
|
||||
mark_seen: bool,
|
||||
dedupe: bool,
|
||||
limit: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch messages by arbitrary IMAP search criteria."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
mailbox = self.config.imap_mailbox or "INBOX"
|
||||
|
||||
if self.config.imap_use_ssl:
|
||||
client = imaplib.IMAP4_SSL(self.config.imap_host, self.config.imap_port)
|
||||
else:
|
||||
client = imaplib.IMAP4(self.config.imap_host, self.config.imap_port)
|
||||
|
||||
try:
|
||||
client.login(self.config.imap_username, self.config.imap_password)
|
||||
status, _ = client.select(mailbox)
|
||||
if status != "OK":
|
||||
return messages
|
||||
|
||||
status, data = client.search(None, *search_criteria)
|
||||
if status != "OK" or not data:
|
||||
return messages
|
||||
|
||||
ids = data[0].split()
|
||||
if limit > 0 and len(ids) > limit:
|
||||
ids = ids[-limit:]
|
||||
for imap_id in ids:
|
||||
status, fetched = client.fetch(imap_id, "(BODY.PEEK[] UID)")
|
||||
if status != "OK" or not fetched:
|
||||
continue
|
||||
|
||||
raw_bytes = self._extract_message_bytes(fetched)
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
uid = self._extract_uid(fetched)
|
||||
if dedupe and uid and uid in self._processed_uids:
|
||||
continue
|
||||
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(raw_bytes)
|
||||
sender = parseaddr(parsed.get("From", ""))[1].strip().lower()
|
||||
if not sender:
|
||||
continue
|
||||
|
||||
subject = self._decode_header_value(parsed.get("Subject", ""))
|
||||
date_value = parsed.get("Date", "")
|
||||
message_id = parsed.get("Message-ID", "").strip()
|
||||
body = self._extract_text_body(parsed)
|
||||
|
||||
if not body:
|
||||
body = "(empty email body)"
|
||||
|
||||
body = body[: self.config.max_body_chars]
|
||||
content = (
|
||||
f"Email received.\n"
|
||||
f"From: {sender}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Date: {date_value}\n\n"
|
||||
f"{body}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"date": date_value,
|
||||
"sender_email": sender,
|
||||
"uid": uid,
|
||||
}
|
||||
messages.append(
|
||||
{
|
||||
"sender": sender,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
if dedupe and uid:
|
||||
self._processed_uids.add(uid)
|
||||
# mark_seen is the primary dedup; this set is a safety net
|
||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||
|
||||
if mark_seen:
|
||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||
finally:
|
||||
try:
|
||||
client.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def _format_imap_date(cls, value: date) -> str:
|
||||
"""Format date for IMAP search (always English month abbreviations)."""
|
||||
month = cls._IMAP_MONTHS[value.month - 1]
|
||||
return f"{value.day:02d}-{month}-{value.year}"
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_bytes(fetched: list[Any]) -> bytes | None:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and len(item) >= 2 and isinstance(item[1], (bytes, bytearray)):
|
||||
return bytes(item[1])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_uid(fetched: list[Any]) -> str:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and item and isinstance(item[0], (bytes, bytearray)):
|
||||
head = bytes(item[0]).decode("utf-8", errors="ignore")
|
||||
m = re.search(r"UID\s+(\d+)", head)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _decode_header_value(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
return str(make_header(decode_header(value)))
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _extract_text_body(cls, msg: Any) -> str:
|
||||
"""Best-effort extraction of readable body text."""
|
||||
if msg.is_multipart():
|
||||
plain_parts: list[str] = []
|
||||
html_parts: list[str] = []
|
||||
for part in msg.walk():
|
||||
if part.get_content_disposition() == "attachment":
|
||||
continue
|
||||
content_type = part.get_content_type()
|
||||
try:
|
||||
payload = part.get_content()
|
||||
except Exception:
|
||||
payload_bytes = part.get_payload(decode=True) or b""
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
continue
|
||||
if content_type == "text/plain":
|
||||
plain_parts.append(payload)
|
||||
elif content_type == "text/html":
|
||||
html_parts.append(payload)
|
||||
if plain_parts:
|
||||
return "\n\n".join(plain_parts).strip()
|
||||
if html_parts:
|
||||
return cls._html_to_text("\n\n".join(html_parts)).strip()
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = msg.get_content()
|
||||
except Exception:
|
||||
payload_bytes = msg.get_payload(decode=True) or b""
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
return ""
|
||||
if msg.get_content_type() == "text/html":
|
||||
return cls._html_to_text(payload).strip()
|
||||
return payload.strip()
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(raw_html: str) -> str:
|
||||
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<\s*/\s*p\s*>", "\n", text, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
return html.unescape(text)
|
||||
|
||||
def _reply_subject(self, base_subject: str) -> str:
|
||||
subject = (base_subject or "").strip() or "Boardware Genius reply"
|
||||
prefix = self.config.subject_prefix or "Re: "
|
||||
if subject.lower().startswith("re:"):
|
||||
return subject
|
||||
return f"{prefix}{subject}"
|
||||
@ -1,733 +0,0 @@
|
||||
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import FeishuConfig
|
||||
|
||||
try:
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
GetFileRequest,
|
||||
GetMessageResourceRequest,
|
||||
P2ImMessageReceiveV1,
|
||||
)
|
||||
FEISHU_AVAILABLE = True
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
lark = None
|
||||
Emoji = None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"audio": "[audio]",
|
||||
"file": "[file]",
|
||||
"sticker": "[sticker]",
|
||||
}
|
||||
|
||||
|
||||
def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||
"""Extract text representation from share cards and interactive messages."""
|
||||
parts = []
|
||||
|
||||
if msg_type == "share_chat":
|
||||
parts.append(f"[shared chat: {content_json.get('chat_id', '')}]")
|
||||
elif msg_type == "share_user":
|
||||
parts.append(f"[shared user: {content_json.get('user_id', '')}]")
|
||||
elif msg_type == "interactive":
|
||||
parts.extend(_extract_interactive_content(content_json))
|
||||
elif msg_type == "share_calendar_event":
|
||||
parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]")
|
||||
elif msg_type == "system":
|
||||
parts.append("[system message]")
|
||||
elif msg_type == "merge_forward":
|
||||
parts.append("[merged forward messages]")
|
||||
|
||||
return "\n".join(parts) if parts else f"[{msg_type}]"
|
||||
|
||||
|
||||
def _extract_interactive_content(content: dict) -> list[str]:
|
||||
"""Recursively extract text and links from interactive card content."""
|
||||
parts = []
|
||||
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return [content] if content.strip() else []
|
||||
|
||||
if not isinstance(content, dict):
|
||||
return parts
|
||||
|
||||
if "title" in content:
|
||||
title = content["title"]
|
||||
if isinstance(title, dict):
|
||||
title_content = title.get("content", "") or title.get("text", "")
|
||||
if title_content:
|
||||
parts.append(f"title: {title_content}")
|
||||
elif isinstance(title, str):
|
||||
parts.append(f"title: {title}")
|
||||
|
||||
for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||
parts.extend(_extract_element_content(element))
|
||||
|
||||
card = content.get("card", {})
|
||||
if card:
|
||||
parts.extend(_extract_interactive_content(card))
|
||||
|
||||
header = content.get("header", {})
|
||||
if header:
|
||||
header_title = header.get("title", {})
|
||||
if isinstance(header_title, dict):
|
||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||
if header_text:
|
||||
parts.append(f"title: {header_text}")
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_element_content(element: dict) -> list[str]:
|
||||
"""Extract content from a single card element."""
|
||||
parts = []
|
||||
|
||||
if not isinstance(element, dict):
|
||||
return parts
|
||||
|
||||
tag = element.get("tag", "")
|
||||
|
||||
if tag in ("markdown", "lark_md"):
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
elif tag == "div":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
text_content = text.get("content", "") or text.get("text", "")
|
||||
if text_content:
|
||||
parts.append(text_content)
|
||||
elif isinstance(text, str):
|
||||
parts.append(text)
|
||||
for field in element.get("fields", []):
|
||||
if isinstance(field, dict):
|
||||
field_text = field.get("text", {})
|
||||
if isinstance(field_text, dict):
|
||||
c = field_text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
|
||||
elif tag == "a":
|
||||
href = element.get("href", "")
|
||||
text = element.get("text", "")
|
||||
if href:
|
||||
parts.append(f"link: {href}")
|
||||
if text:
|
||||
parts.append(text)
|
||||
|
||||
elif tag == "button":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
c = text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
url = element.get("url", "") or element.get("multi_url", {}).get("url", "")
|
||||
if url:
|
||||
parts.append(f"link: {url}")
|
||||
|
||||
elif tag == "img":
|
||||
alt = element.get("alt", {})
|
||||
parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]")
|
||||
|
||||
elif tag == "note":
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
elif tag == "column_set":
|
||||
for col in element.get("columns", []):
|
||||
for ce in col.get("elements", []):
|
||||
parts.extend(_extract_element_content(ce))
|
||||
|
||||
elif tag == "plain_text":
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
else:
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
Supports two formats:
|
||||
1. Direct format: {"title": "...", "content": [...]}
|
||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
"""
|
||||
def extract_from_lang(lang_content: dict) -> str | None:
|
||||
if not isinstance(lang_content, dict):
|
||||
return None
|
||||
title = lang_content.get("title", "")
|
||||
content_blocks = lang_content.get("content", [])
|
||||
if not isinstance(content_blocks, list):
|
||||
return None
|
||||
text_parts = []
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, list):
|
||||
continue
|
||||
for element in block:
|
||||
if isinstance(element, dict):
|
||||
tag = element.get("tag")
|
||||
if tag == "text":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "a":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "at":
|
||||
text_parts.append(f"@{element.get('user_name', 'user')}")
|
||||
return " ".join(text_parts).strip() if text_parts else None
|
||||
|
||||
# Try direct format first
|
||||
if "content" in content_json:
|
||||
result = extract_from_lang(content_json)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Try localized format
|
||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||
lang_content = content_json.get(lang_key)
|
||||
result = extract_from_lang(lang_content)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- App ID and App Secret from Feishu Open Platform
|
||||
- Bot capability enabled
|
||||
- Event subscription enabled (im.message.receive_v1)
|
||||
"""
|
||||
|
||||
name = "feishu"
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig = config
|
||||
self._client: Any = None
|
||||
self._ws_client: Any = None
|
||||
self._ws_thread: threading.Thread | None = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.app_secret:
|
||||
logger.error("Feishu app_id and app_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Create Lark client for sending messages
|
||||
self._client = lark.Client.builder() \
|
||||
.app_id(self.config.app_id) \
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
|
||||
# Create event handler (only register message receive, ignore other events)
|
||||
event_handler = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
).build()
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop
|
||||
def run_ws():
|
||||
while self._running:
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.warning("Feishu WebSocket error: {}", e)
|
||||
if self._running:
|
||||
import time; time.sleep(5)
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Feishu bot."""
|
||||
self._running = False
|
||||
if self._ws_client:
|
||||
try:
|
||||
self._ws_client.stop()
|
||||
except Exception as e:
|
||||
logger.warning("Error stopping WebSocket client: {}", e)
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
try:
|
||||
request = CreateMessageReactionRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||
.build()
|
||||
).build()
|
||||
|
||||
response = self._client.im.v1.message_reaction.create(request)
|
||||
|
||||
if not response.success():
|
||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||
else:
|
||||
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||
except Exception as e:
|
||||
logger.warning("Error adding reaction: {}", e)
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
if not self._client or not Emoji:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
||||
|
||||
_CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
|
||||
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()]
|
||||
if len(lines) < 3:
|
||||
return None
|
||||
split = lambda l: [c.strip() for c in l.strip("|").split("|")]
|
||||
headers = split(lines[0])
|
||||
rows = [split(l) for l in lines[2:]]
|
||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||
for i, h in enumerate(headers)]
|
||||
return {
|
||||
"tag": "table",
|
||||
"page_size": len(rows) + 1,
|
||||
"columns": columns,
|
||||
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
|
||||
}
|
||||
|
||||
def _build_card_elements(self, content: str) -> list[dict]:
|
||||
"""Split content into div/markdown + table elements for Feishu card."""
|
||||
elements, last_end = [], 0
|
||||
for m in self._TABLE_RE.finditer(content):
|
||||
before = content[last_end:m.start()]
|
||||
if before.strip():
|
||||
elements.extend(self._split_headings(before))
|
||||
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
|
||||
last_end = m.end()
|
||||
remaining = content[last_end:]
|
||||
if remaining.strip():
|
||||
elements.extend(self._split_headings(remaining))
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
def _split_headings(self, content: str) -> list[dict]:
|
||||
"""Split content by headings, converting headings to div elements."""
|
||||
protected = content
|
||||
code_blocks = []
|
||||
for m in self._CODE_BLOCK_RE.finditer(content):
|
||||
code_blocks.append(m.group(1))
|
||||
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
|
||||
|
||||
elements = []
|
||||
last_end = 0
|
||||
for m in self._HEADING_RE.finditer(protected):
|
||||
before = protected[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
text = m.group(2).strip()
|
||||
elements.append({
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": f"**{text}**",
|
||||
},
|
||||
})
|
||||
last_end = m.end()
|
||||
remaining = protected[last_end:].strip()
|
||||
if remaining:
|
||||
elements.append({"tag": "markdown", "content": remaining})
|
||||
|
||||
for i, cb in enumerate(code_blocks):
|
||||
for el in elements:
|
||||
if el.get("tag") == "markdown":
|
||||
el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
|
||||
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
||||
_AUDIO_EXTS = {".opus"}
|
||||
_FILE_TYPE_MAP = {
|
||||
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
||||
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
||||
}
|
||||
|
||||
def _upload_image_sync(self, file_path: str) -> str | None:
|
||||
"""Upload an image to Feishu and return the image_key."""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.image.create(request)
|
||||
if response.success():
|
||||
image_key = response.data.image_key
|
||||
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||
return image_key
|
||||
else:
|
||||
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading image {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||
"""Upload a file to Feishu and return the file_key."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
|
||||
file_name = os.path.basename(file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateFileRequest.builder() \
|
||||
.request_body(
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(file_name)
|
||||
.file(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.file.create(request)
|
||||
if response.success():
|
||||
file_key = response.data.file_key
|
||||
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||
return file_key
|
||||
else:
|
||||
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading file {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
|
||||
"""Download an image from Feishu message by message_id and image_key."""
|
||||
try:
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
# GetMessageResourceRequest returns BytesIO, need to read bytes
|
||||
if hasattr(file_data, 'read'):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("Error downloading image {}: {}", image_key, e)
|
||||
return None, None
|
||||
|
||||
def _download_file_sync(
|
||||
self, message_id: str, file_key: str, resource_type: str = "file"
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
|
||||
try:
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message_id)
|
||||
.file_key(file_key)
|
||||
.type(resource_type)
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
if hasattr(file_data, "read"):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||
return None, None
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
msg_type: str,
|
||||
content_json: dict,
|
||||
message_id: str | None = None
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Download media from Feishu and save to local disk.
|
||||
|
||||
Returns:
|
||||
(file_path, content_text) - file_path is None if download failed
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data, filename = None, None
|
||||
|
||||
if msg_type == "image":
|
||||
image_key = content_json.get("image_key")
|
||||
if image_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_image_sync, message_id, image_key
|
||||
)
|
||||
if not filename:
|
||||
filename = f"{image_key[:16]}.jpg"
|
||||
|
||||
elif msg_type in ("audio", "file", "media"):
|
||||
file_key = content_json.get("file_key")
|
||||
if file_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
||||
filename = f"{file_key[:16]}{ext}"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", msg_type, file_path)
|
||||
return str(file_path), f"[{msg_type}: {filename}]"
|
||||
|
||||
return None, f"[{msg_type}: download failed]"
|
||||
|
||||
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Send a single message (text/image/file/interactive) synchronously."""
|
||||
try:
|
||||
request = CreateMessageRequest.builder() \
|
||||
.receive_id_type(receive_id_type) \
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.msg_type(msg_type)
|
||||
.content(content)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.message.create(request)
|
||||
if not response.success():
|
||||
logger.error(
|
||||
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
||||
msg_type, response.code, response.msg, response.get_log_id()
|
||||
)
|
||||
return False
|
||||
logger.debug("Feishu {} message sent to {}", msg_type, receive_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Feishu, including media (images/files) if present."""
|
||||
if not self._client:
|
||||
logger.warning("Feishu client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for file_path in msg.media:
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("Media file not found: {}", file_path)
|
||||
continue
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext in self._IMAGE_EXTS:
|
||||
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
|
||||
if key:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
|
||||
)
|
||||
else:
|
||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||
if key:
|
||||
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)}
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
Schedules async handling in the main event loop.
|
||||
"""
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
event = data.event
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
# Deduplication check
|
||||
message_id = message.message_id
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Skip bot messages
|
||||
if sender.sender_type == "bot":
|
||||
return
|
||||
|
||||
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
||||
chat_id = message.chat_id
|
||||
chat_type = message.chat_type
|
||||
msg_type = message.message_type
|
||||
|
||||
# Add reaction
|
||||
await self._add_reaction(message_id, "THUMBSUP")
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
try:
|
||||
content_json = json.loads(message.content) if message.content else {}
|
||||
except json.JSONDecodeError:
|
||||
content_json = {}
|
||||
|
||||
if msg_type == "text":
|
||||
text = content_json.get("text", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "post":
|
||||
text = _extract_post_text(content_json)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type in ("image", "audio", "file", "media"):
|
||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||
# Handle share cards and interactive messages
|
||||
text = _extract_share_card_content(content_json, msg_type)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=reply_to,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing Feishu message: {}", e)
|
||||
@ -1,326 +0,0 @@
|
||||
"""渠道管理器:统一管理多聊天渠道的生命周期与消息路由。
|
||||
|
||||
本模块处在“Agent 核心逻辑”和“外部 IM 平台”之间,承担两类关键职责:
|
||||
1. 渠道生命周期管理:
|
||||
- 按配置初始化可用渠道(Telegram/Slack/Discord/WhatsApp/...);
|
||||
- 统一启动与停止,避免各渠道在 CLI 层分散管理。
|
||||
2. 出站消息分发:
|
||||
- 从 MessageBus 的 outbound 队列读取消息;
|
||||
- 根据 `msg.channel` 路由到目标渠道对象并执行 `send(...)`;
|
||||
- 对进度消息(_progress/_tool_hint)按全局开关过滤。
|
||||
|
||||
设计原则:
|
||||
- 渠道失败隔离:单个渠道启动/发送失败不应拖垮其它渠道;
|
||||
- 配置驱动:是否启用由 `config.channels.*.enabled` 决定;
|
||||
- 统一入口:上层只需与 MessageBus 交互,不关心各渠道细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
渠道协调器。
|
||||
|
||||
你可以把它看成一个“渠道运行时容器”:
|
||||
- `self.channels` 保存已启用渠道实例;
|
||||
- `_dispatch_outbound()` 作为中央分发协程持续消费 outbound 消息;
|
||||
- `start_all()/stop_all()` 负责渠道与分发协程的统一启停。
|
||||
|
||||
与 AgentLoop 的关系:
|
||||
- AgentLoop 只负责“生成 OutboundMessage”;
|
||||
- ChannelManager 负责“把 OutboundMessage 真的发出去”。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
# 全局配置(含渠道开关、进度消息开关等)
|
||||
self.config = config
|
||||
# 与 AgentLoop 共享同一 MessageBus,负责消费 outbound。
|
||||
self.bus = bus
|
||||
# name -> channel instance(只存启用且成功初始化的渠道)
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
# 出站分发后台任务句柄(由 start_all 创建,stop_all 取消)
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
# 构造时即按配置初始化渠道实例(不启动网络连接,仅实例化)。
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""按配置初始化渠道实例。
|
||||
|
||||
注意:
|
||||
- 这里只做“实例化”,不会进入各渠道的 start() 主循环;
|
||||
- ImportError 会被捕获并记录 warning,允许缺依赖时降级运行;
|
||||
- 未启用渠道不会创建实例,也不会出现在 enabled_channels 列表里。
|
||||
"""
|
||||
|
||||
# Telegram 渠道:
|
||||
# - 需要 telegram 配置开启;
|
||||
# - 额外透传 groq_api_key(用于语音/转写等能力时按渠道内部策略使用)。
|
||||
if self.config.channels.telegram.enabled:
|
||||
try:
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
self.channels["telegram"] = TelegramChannel(
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
# WhatsApp 渠道(通过 bridge 连接)
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
self.channels["whatsapp"] = WhatsAppChannel(
|
||||
self.config.channels.whatsapp, self.bus
|
||||
)
|
||||
logger.info("WhatsApp channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("WhatsApp channel not available: {}", e)
|
||||
|
||||
# Discord 渠道
|
||||
if self.config.channels.discord.enabled:
|
||||
try:
|
||||
from nanobot.channels.discord import DiscordChannel
|
||||
self.channels["discord"] = DiscordChannel(
|
||||
self.config.channels.discord, self.bus
|
||||
)
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
# 飞书 / Lark 渠道
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
self.channels["feishu"] = FeishuChannel(
|
||||
self.config.channels.feishu, self.bus
|
||||
)
|
||||
logger.info("Feishu channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Feishu channel not available: {}", e)
|
||||
|
||||
# Mochat 渠道
|
||||
if self.config.channels.mochat.enabled:
|
||||
try:
|
||||
from nanobot.channels.mochat import MochatChannel
|
||||
|
||||
self.channels["mochat"] = MochatChannel(
|
||||
self.config.channels.mochat, self.bus
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Mochat channel not available: {}", e)
|
||||
|
||||
# 钉钉渠道
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
try:
|
||||
from nanobot.channels.dingtalk import DingTalkChannel
|
||||
self.channels["dingtalk"] = DingTalkChannel(
|
||||
self.config.channels.dingtalk, self.bus
|
||||
)
|
||||
logger.info("DingTalk channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("DingTalk channel not available: {}", e)
|
||||
|
||||
# Email 渠道(IMAP 收件 + SMTP 发件)
|
||||
if self.config.channels.email.enabled:
|
||||
try:
|
||||
from nanobot.channels.email import EmailChannel
|
||||
self.channels["email"] = EmailChannel(
|
||||
self.config.channels.email, self.bus
|
||||
)
|
||||
logger.info("Email channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Email channel not available: {}", e)
|
||||
|
||||
# Slack 渠道
|
||||
if self.config.channels.slack.enabled:
|
||||
try:
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
self.channels["slack"] = SlackChannel(
|
||||
self.config.channels.slack, self.bus
|
||||
)
|
||||
logger.info("Slack channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Slack channel not available: {}", e)
|
||||
|
||||
# QQ 渠道
|
||||
if self.config.channels.qq.enabled:
|
||||
try:
|
||||
from nanobot.channels.qq import QQChannel
|
||||
self.channels["qq"] = QQChannel(
|
||||
self.config.channels.qq,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
# Matrix 渠道
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
from nanobot.channels.matrix import MatrixChannel
|
||||
self.channels["matrix"] = MatrixChannel(
|
||||
self.config.channels.matrix,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""启动单个渠道并隔离异常。
|
||||
|
||||
设计意图:
|
||||
- 不让一个渠道的启动失败影响其它渠道启动;
|
||||
- 错误统一记录日志,方便后续定位具体渠道问题。
|
||||
"""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""启动所有渠道与出站分发协程。
|
||||
|
||||
启动顺序:
|
||||
1. 启动 outbound 分发任务(先就绪,避免启动早期消息丢失);
|
||||
2. 并发启动所有渠道 start() 协程;
|
||||
3. `gather` 挂住,直到渠道协程返回(正常应长期运行)。
|
||||
"""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
# 启动出站分发协程:负责消费 bus.outbound 并调用 channel.send()。
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
# 启动渠道主循环。
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
# 等待所有渠道任务(理论上它们应常驻直到 stop_all 被调用)。
|
||||
# return_exceptions=True 可避免一个任务异常导致 gather 整体中断。
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""停止所有渠道并关闭出站分发任务。
|
||||
|
||||
停止顺序:
|
||||
1. 先取消分发协程,避免继续从队列取消息;
|
||||
2. 再逐个 stop 渠道,释放各自连接/资源;
|
||||
3. 各渠道停止异常仅记录,不影响其它渠道收尾。
|
||||
"""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
# 停止分发协程。
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有渠道实例。
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""消费 outbound 队列并路由发送到对应渠道。
|
||||
|
||||
分发规则:
|
||||
- `msg.channel` 决定目标渠道实例;
|
||||
- 若渠道不存在,记录 warning(通常表示渠道未启用或名称不匹配);
|
||||
- 进度消息可被全局开关过滤(send_progress / send_tool_hints)。
|
||||
|
||||
循环模型:
|
||||
- 使用 `wait_for(..., timeout=1.0)` 做短超时轮询,
|
||||
便于 stop_all 取消后快速退出;
|
||||
- Timeout 属于正常空闲态,不视为错误。
|
||||
"""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 从总线获取一条待发送消息;短超时保证可取消性。
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# 进度消息过滤:
|
||||
# - _progress=True 且 _tool_hint=True 受 send_tool_hints 控制
|
||||
# - _progress=True 且非工具提示受 send_progress 控制
|
||||
# 这样可以在渠道侧按需静默“中间态”,只保留最终回复。
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
# 按 channel 名路由发送。
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
# 实际发送由各渠道实现(统一接口:BaseChannel.send)。
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
# 单条发送失败不终止分发循环,避免“全局停摆”。
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 队列暂时无消息:继续下一轮轮询。
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# stop_all 取消任务时走这里退出循环。
|
||||
break
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""按名称获取渠道实例(未启用/不存在返回 None)。"""
|
||||
return self.channels.get(name)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""返回所有已启用渠道的运行状态快照。
|
||||
|
||||
返回结构示例:
|
||||
{
|
||||
"telegram": {"enabled": True, "running": True},
|
||||
"slack": {"enabled": True, "running": False},
|
||||
}
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
# 出现在 self.channels 里即表示“配置层已启用且实例化成功”。
|
||||
"enabled": True,
|
||||
# running 由渠道实例自身维护,反映连接/主循环当前状态。
|
||||
"running": channel.is_running
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""返回当前已启用并成功初始化的渠道名称列表。"""
|
||||
return list(self.channels.keys())
|
||||
@ -1,733 +0,0 @@
|
||||
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncResponse,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||
_ATTACH_MARKER = "[attachment: {}]"
|
||||
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||
_DEFAULT_ATTACH_NAME = "attachment"
|
||||
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
MATRIX_MARKDOWN = create_markdown(
|
||||
escape=True,
|
||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||
)
|
||||
|
||||
MATRIX_ALLOWED_HTML_TAGS = {
|
||||
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||
"caption", "sup", "sub", "img",
|
||||
}
|
||||
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||
"img": {"src", "alt", "title", "width", "height"},
|
||||
}
|
||||
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||
|
||||
|
||||
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||
if tag == "a" and attr == "href":
|
||||
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||
if tag == "img" and attr == "src":
|
||||
return value if value.lower().startswith("mxc://") else None
|
||||
if tag == "code" and attr == "class":
|
||||
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||
return " ".join(classes) if classes else None
|
||||
return value
|
||||
|
||||
|
||||
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||
attribute_filter=_filter_matrix_html_attribute,
|
||||
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||
strip_comments=True,
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
try:
|
||||
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||
except Exception:
|
||||
return None
|
||||
if not formatted:
|
||||
return None
|
||||
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||
inner = formatted[3:-4]
|
||||
if "<" not in inner and ">" not in inner:
|
||||
return None
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
return content
|
||||
|
||||
|
||||
class _NioLoguruHandler(logging.Handler):
|
||||
"""Route matrix-nio stdlib logs into Loguru."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame, depth = frame.f_back, depth + 1
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def _configure_nio_logging_bridge() -> None:
|
||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||
nio_logger = logging.getLogger("nio")
|
||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||
nio_logger.handlers = [_NioLoguruHandler()]
|
||||
nio_logger.propagate = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus, groq_api_key: str = ""):
|
||||
super().__init__(config, bus)
|
||||
self.groq_api_key = groq_api_key
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = False
|
||||
self._workspace: Path | None = None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
self._sync_ready_logged = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
self._running = True
|
||||
_configure_nio_logging_bridge()
|
||||
|
||||
store_path = get_data_dir() / "matrix-store"
|
||||
store_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = AsyncClient(
|
||||
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||
store_path=store_path,
|
||||
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||
)
|
||||
self.client.user_id = self.config.user_id
|
||||
self.client.access_token = self.config.access_token
|
||||
self.client.device_id = self.config.device_id
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||
|
||||
if self.config.device_id:
|
||||
try:
|
||||
self.client.load_store()
|
||||
except Exception:
|
||||
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||
else:
|
||||
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||
self._running = False
|
||||
for room_id in list(self._typing_tasks):
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
if self.client:
|
||||
self.client.stop_sync_forever()
|
||||
if self._sync_task:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||
timeout=self.config.sync_stop_grace_seconds)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||
"""Check path is inside workspace (when restriction enabled)."""
|
||||
if not self._restrict_to_workspace or not self._workspace:
|
||||
return True
|
||||
try:
|
||||
path.resolve(strict=False).relative_to(self._workspace)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||
"""Deduplicate and resolve outbound attachment paths."""
|
||||
seen: set[str] = set()
|
||||
candidates: list[Path] = []
|
||||
for raw in media:
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
continue
|
||||
path = Path(raw.strip()).expanduser()
|
||||
try:
|
||||
key = str(path.resolve(strict=False))
|
||||
except OSError:
|
||||
key = str(path)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(path)
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_outbound_attachment_content(
|
||||
*, filename: str, mime: str, size_bytes: int,
|
||||
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||
prefix = mime.split("/")[0]
|
||||
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||
content: dict[str, Any] = {
|
||||
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||
}
|
||||
if encryption_info:
|
||||
content["file"] = {**encryption_info, "url": mxc_url}
|
||||
else:
|
||||
content["url"] = mxc_url
|
||||
return content
|
||||
|
||||
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||
if not self.client:
|
||||
return False
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
if self._server_upload_limit_checked:
|
||||
return self._server_upload_limit_bytes
|
||||
self._server_upload_limit_checked = True
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
response = await self.client.content_repository_config()
|
||||
except Exception:
|
||||
return None
|
||||
upload_size = getattr(response, "upload_size", None)
|
||||
if isinstance(upload_size, int) and upload_size > 0:
|
||||
self._server_upload_limit_bytes = upload_size
|
||||
return upload_size
|
||||
return None
|
||||
|
||||
async def _effective_media_limit_bytes(self) -> int:
|
||||
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||
if server_limit is None:
|
||||
return local_limit
|
||||
return min(local_limit, server_limit) if local_limit else 0
|
||||
|
||||
async def _upload_and_send_attachment(
|
||||
self, room_id: str, path: Path, limit_bytes: int,
|
||||
relates_to: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||
if not self.client:
|
||||
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||
|
||||
resolved = path.expanduser().resolve(strict=False)
|
||||
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||
|
||||
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||
return fail
|
||||
try:
|
||||
size_bytes = resolved.stat().st_size
|
||||
except OSError:
|
||||
return fail
|
||||
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||
return _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||
try:
|
||||
with resolved.open("rb") as f:
|
||||
upload_result = await self.client.upload(
|
||||
f, content_type=mime, filename=filename,
|
||||
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||
filesize=size_bytes,
|
||||
)
|
||||
except Exception:
|
||||
return fail
|
||||
|
||||
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||
if isinstance(upload_response, UploadError):
|
||||
return fail
|
||||
mxc_url = getattr(upload_response, "content_uri", None)
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return fail
|
||||
|
||||
content = self._build_outbound_attachment_content(
|
||||
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||
)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
try:
|
||||
await self._send_room_content(room_id, content)
|
||||
except Exception:
|
||||
return fail
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound content; clear typing for non-progress messages."""
|
||||
if not self.client:
|
||||
return
|
||||
text = msg.content or ""
|
||||
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
failures: list[str] = []
|
||||
if candidates:
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
for path in candidates:
|
||||
if fail := await self._upload_and_send_attachment(
|
||||
room_id=msg.chat_id,
|
||||
path=path,
|
||||
limit_bytes=limit_bytes,
|
||||
relates_to=relates_to,
|
||||
):
|
||||
failures.append(fail)
|
||||
if failures:
|
||||
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||
if text or not candidates:
|
||||
content = _build_matrix_text_content(text)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
await self._send_room_content(msg.chat_id, content)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_success, SyncResponse)
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _log_response_error(self, label: str, response: Any) -> None:
|
||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||
|
||||
async def _on_sync_success(self, response: SyncResponse) -> None:
|
||||
if self._sync_ready_logged:
|
||||
return
|
||||
rooms = getattr(response, "rooms", None)
|
||||
joined = len(getattr(rooms, "join", {}) or {})
|
||||
invited = len(getattr(rooms, "invite", {}) or {})
|
||||
logger.info(
|
||||
"Matrix sync ready: user={} device={} joined_rooms={} invited_rooms={}",
|
||||
self.config.user_id,
|
||||
self.config.device_id or "-",
|
||||
joined,
|
||||
invited,
|
||||
)
|
||||
self._sync_ready_logged = True
|
||||
|
||||
async def _on_sync_error(self, response: SyncError) -> None:
|
||||
self._log_response_error("sync", response)
|
||||
|
||||
async def _on_join_error(self, response: JoinError) -> None:
|
||||
self._log_response_error("join", response)
|
||||
|
||||
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||
self._log_response_error("send", response)
|
||||
|
||||
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||
"""Best-effort typing indicator update."""
|
||||
if not self.client:
|
||||
return
|
||||
try:
|
||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||
if isinstance(response, RoomTypingError):
|
||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
await self._set_typing(room_id, True)
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
async def loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||
await self._set_typing(room_id, True)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||
|
||||
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||
if task := self._typing_tasks.pop(room_id, None):
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if clear_typing:
|
||||
await self._set_typing(room_id, False)
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||
if self.is_allowed(event.sender):
|
||||
await self.client.join(room.room_id)
|
||||
|
||||
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||
count = getattr(room, "member_count", None)
|
||||
return isinstance(count, int) and count <= 2
|
||||
|
||||
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||
"""Check m.mentions payload for bot mention."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return False
|
||||
mentions = (source.get("content") or {}).get("m.mentions")
|
||||
if not isinstance(mentions, dict):
|
||||
return False
|
||||
user_ids = mentions.get("user_ids")
|
||||
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||
return True
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
if self._is_direct_room(room):
|
||||
return True
|
||||
policy = self.config.group_policy
|
||||
if policy == "open":
|
||||
return True
|
||||
if policy == "allowlist":
|
||||
return room.room_id in (self.config.group_allow_from or [])
|
||||
if policy == "mention":
|
||||
return self._is_bot_mentioned(event)
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
return get_media_dir("matrix")
|
||||
|
||||
async def transcribe_audio(self, file_path: str) -> str:
|
||||
"""Best-effort audio transcription for inbound Matrix voice/audio messages."""
|
||||
try:
|
||||
return await GroqTranscriptionProvider(api_key=self.groq_api_key).transcribe(file_path)
|
||||
except Exception:
|
||||
logger.exception("Matrix audio transcription failed")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
if not (root_id := self._event_thread_root_id(event)):
|
||||
return None
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {"rel_type": "m.thread", "event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
return (isinstance(getattr(event, "key", None), dict)
|
||||
and isinstance(getattr(event, "hashes", None), dict)
|
||||
and isinstance(getattr(event, "iv", None), str))
|
||||
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
size = info.get("size") if isinstance(info, dict) else None
|
||||
return size if isinstance(size, int) and size >= 0 else None
|
||||
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||
return m
|
||||
m = getattr(event, "mimetype", None)
|
||||
return m if isinstance(m, str) and m else None
|
||||
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
if candidate := safe_filename(Path(body).name):
|
||||
return candidate
|
||||
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||
|
||||
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||
filename: str, mime: str | None) -> Path:
|
||||
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||
suffix = Path(safe_name).suffix
|
||||
if not suffix and mime:
|
||||
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||
suffix = suffix[:16]
|
||||
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||
|
||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||
if not self.client:
|
||||
return None
|
||||
response = await self.client.download(mxc=mxc_url)
|
||||
if isinstance(response, DownloadError):
|
||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||
return None
|
||||
body = getattr(response, "body", None)
|
||||
if isinstance(body, (bytes, bytearray)):
|
||||
return bytes(body)
|
||||
if isinstance(response, MemoryDownloadResponse):
|
||||
return bytes(response.body)
|
||||
if isinstance(body, (str, Path)):
|
||||
path = Path(body)
|
||||
if path.is_file():
|
||||
try:
|
||||
return path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||
return None
|
||||
try:
|
||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||
except (EncryptionError, ValueError, TypeError):
|
||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||
return None
|
||||
|
||||
async def _fetch_media_attachment(
|
||||
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||
atype = self._event_attachment_type(event)
|
||||
mime = self._event_mime(event)
|
||||
filename = self._event_filename(event, atype)
|
||||
mxc_url = getattr(event, "url", None)
|
||||
fail = _ATTACH_FAILED.format(filename)
|
||||
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return None, fail
|
||||
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
declared = self._event_declared_size_bytes(event)
|
||||
if declared is not None and declared > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
downloaded = await self._download_media_bytes(mxc_url)
|
||||
if downloaded is None:
|
||||
return None, fail
|
||||
|
||||
encrypted = self._is_encrypted_media_event(event)
|
||||
data = downloaded
|
||||
if encrypted:
|
||||
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||
return None, fail
|
||||
|
||||
if len(data) > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
path = self._build_attachment_path(event, atype, filename, mime)
|
||||
try:
|
||||
path.write_bytes(data)
|
||||
except OSError:
|
||||
return None, fail
|
||||
|
||||
attachment = {
|
||||
"type": atype, "mime": mime, "filename": filename,
|
||||
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||
"encrypted": encrypted, "size_bytes": len(data),
|
||||
"path": str(path), "mxc_url": mxc_url,
|
||||
}
|
||||
return attachment, _ATTACH_MARKER.format(path)
|
||||
|
||||
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||
"""Build common metadata for text and media handlers."""
|
||||
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||
meta["event_id"] = eid
|
||||
if thread := self._thread_metadata(event):
|
||||
meta.update(thread)
|
||||
return meta
|
||||
|
||||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
|
||||
if attachment and attachment.get("type") == "audio":
|
||||
transcription = await self.transcribe_audio(attachment["path"])
|
||||
if transcription:
|
||||
parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
parts.append(marker)
|
||||
elif marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
meta = self._base_metadata(room, event)
|
||||
meta["attachments"] = []
|
||||
if attachment:
|
||||
meta["attachments"] = [attachment]
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
@ -1,895 +0,0 @@
|
||||
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import MochatConfig
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
|
||||
try:
|
||||
import socketio
|
||||
SOCKETIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
socketio = None
|
||||
SOCKETIO_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import msgpack # noqa: F401
|
||||
MSGPACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSGPACK_AVAILABLE = False
|
||||
|
||||
MAX_SEEN_MESSAGE_IDS = 2000
|
||||
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MochatBufferedEntry:
|
||||
"""Buffered inbound entry for delayed dispatch."""
|
||||
raw_body: str
|
||||
author: str
|
||||
sender_name: str = ""
|
||||
sender_username: str = ""
|
||||
timestamp: int | None = None
|
||||
message_id: str = ""
|
||||
group_id: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelayState:
|
||||
"""Per-target delayed message state."""
|
||||
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
timer: asyncio.Task | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MochatTarget:
|
||||
"""Outbound target resolution result."""
|
||||
id: str
|
||||
is_panel: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_dict(value: Any) -> dict:
|
||||
"""Return *value* if it's a dict, else empty dict."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _str_field(src: dict, *keys: str) -> str:
|
||||
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||
for k in keys:
|
||||
v = src.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
return v.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _make_synthetic_event(
|
||||
message_id: str, author: str, content: Any,
|
||||
meta: Any, group_id: str, converse_id: str,
|
||||
timestamp: Any = None, *, author_info: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a synthetic ``message.add`` event dict."""
|
||||
payload: dict[str, Any] = {
|
||||
"messageId": message_id, "author": author,
|
||||
"content": content, "meta": _safe_dict(meta),
|
||||
"groupId": group_id, "converseId": converse_id,
|
||||
}
|
||||
if author_info is not None:
|
||||
payload["authorInfo"] = _safe_dict(author_info)
|
||||
return {
|
||||
"type": "message.add",
|
||||
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
|
||||
def normalize_mochat_content(content: Any) -> str:
|
||||
"""Normalize content payload to text."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||
"""Resolve id and target kind from user-provided target string."""
|
||||
trimmed = (raw or "").strip()
|
||||
if not trimmed:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
|
||||
lowered = trimmed.lower()
|
||||
cleaned, forced_panel = trimmed, False
|
||||
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||
if lowered.startswith(prefix):
|
||||
cleaned = trimmed[len(prefix):].strip()
|
||||
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||
break
|
||||
|
||||
if not cleaned:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||
|
||||
|
||||
def extract_mention_ids(value: Any) -> list[str]:
|
||||
"""Extract mention ids from heterogeneous mention payload."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
ids: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
ids.append(item.strip())
|
||||
elif isinstance(item, dict):
|
||||
for key in ("id", "userId", "_id"):
|
||||
candidate = item.get(key)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
ids.append(candidate.strip())
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||
"""Resolve mention state from payload metadata and text fallback."""
|
||||
meta = payload.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||
return True
|
||||
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||
return True
|
||||
if not agent_user_id:
|
||||
return False
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||
|
||||
|
||||
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||
"""Resolve mention requirement for group/panel conversations."""
|
||||
groups = config.groups or {}
|
||||
for key in (group_id, session_id, "*"):
|
||||
if key and key in groups:
|
||||
return bool(groups[key].require_mention)
|
||||
return bool(config.mention.require_in_groups)
|
||||
|
||||
|
||||
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||
"""Build text body from one or more buffered entries."""
|
||||
if not entries:
|
||||
return ""
|
||||
if len(entries) == 1:
|
||||
return entries[0].raw_body
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
if not entry.raw_body:
|
||||
continue
|
||||
if is_group:
|
||||
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||
if label:
|
||||
lines.append(f"{label}: {entry.raw_body}")
|
||||
continue
|
||||
lines.append(entry.raw_body)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def parse_timestamp(value: Any) -> int | None:
|
||||
"""Parse event timestamp to epoch milliseconds."""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig = config
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_data_path() / "mochat"
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
self._session_set: set[str] = set()
|
||||
self._panel_set: set[str] = set()
|
||||
self._auto_discover_sessions = self._auto_discover_panels = False
|
||||
|
||||
self._cold_sessions: set[str] = set()
|
||||
self._session_by_converse: dict[str, str] = {}
|
||||
|
||||
self._seen_set: dict[str, set[str]] = {}
|
||||
self._seen_queue: dict[str, deque[str]] = {}
|
||||
self._delay_states: dict[str, DelayState] = {}
|
||||
|
||||
self._fallback_mode = False
|
||||
self._session_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
self._target_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# ---- lifecycle ---------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Mochat channel workers and websocket connection."""
|
||||
if not self.config.claw_token:
|
||||
logger.error("Mochat claw_token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
await self._load_session_cursors()
|
||||
self._seed_targets_from_config()
|
||||
await self._refresh_targets(subscribe_new=False)
|
||||
|
||||
if not await self._start_socket_client():
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop())
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all workers and clean up resources."""
|
||||
self._running = False
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
await self._stop_fallback_workers()
|
||||
await self._cancel_delay_timers()
|
||||
|
||||
if self._socket:
|
||||
try:
|
||||
await self._socket.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
|
||||
if self._cursor_save_task:
|
||||
self._cursor_save_task.cancel()
|
||||
self._cursor_save_task = None
|
||||
await self._save_session_cursors()
|
||||
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound message to session or panel."""
|
||||
if not self.config.claw_token:
|
||||
logger.warning("Mochat claw_token missing, skip send")
|
||||
return
|
||||
|
||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||
if msg.media:
|
||||
parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
|
||||
content = "\n".join(parts).strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
target = resolve_mochat_target(msg.chat_id)
|
||||
if not target.id:
|
||||
logger.warning("Mochat outbound target is empty")
|
||||
return
|
||||
|
||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||
try:
|
||||
if is_panel:
|
||||
await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
|
||||
content, msg.reply_to, self._read_group_id(msg.metadata))
|
||||
else:
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send Mochat message: {}", e)
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
def _seed_targets_from_config(self) -> None:
|
||||
sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
|
||||
panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
|
||||
self._session_set.update(sessions)
|
||||
self._panel_set.update(panels)
|
||||
for sid in sessions:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
|
||||
cleaned = [str(v).strip() for v in values if str(v).strip()]
|
||||
return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
|
||||
|
||||
# ---- websocket ---------------------------------------------------------
|
||||
|
||||
async def _start_socket_client(self) -> bool:
|
||||
if not SOCKETIO_AVAILABLE:
|
||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
||||
return False
|
||||
|
||||
serializer = "default"
|
||||
if not self.config.socket_disable_msgpack:
|
||||
if MSGPACK_AVAILABLE:
|
||||
serializer = "msgpack"
|
||||
else:
|
||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||
|
||||
client = socketio.AsyncClient(
|
||||
reconnection=True,
|
||||
reconnection_attempts=self.config.max_retry_attempts or None,
|
||||
reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
|
||||
reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
|
||||
logger=False, engineio_logger=False, serializer=serializer,
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def connect() -> None:
|
||||
self._ws_connected, self._ws_ready = True, False
|
||||
logger.info("Mochat websocket connected")
|
||||
subscribed = await self._subscribe_all()
|
||||
self._ws_ready = subscribed
|
||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||
|
||||
@client.event
|
||||
async def disconnect() -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._ws_connected = self._ws_ready = False
|
||||
logger.warning("Mochat websocket disconnected")
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error("Mochat websocket connect error: {}", data)
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
|
||||
@client.on("claw.panel.events")
|
||||
async def on_panel_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "panel")
|
||||
|
||||
for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
|
||||
"notify:chat.message.update", "notify:chat.message.recall",
|
||||
"notify:chat.message.delete"):
|
||||
client.on(ev, self._build_notify_handler(ev))
|
||||
|
||||
socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
|
||||
socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
|
||||
|
||||
try:
|
||||
self._socket = client
|
||||
await client.connect(
|
||||
socket_url, transports=["websocket"], socketio_path=socket_path,
|
||||
auth={"token": self.config.claw_token},
|
||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
return False
|
||||
|
||||
def _build_notify_handler(self, event_name: str):
|
||||
async def handler(payload: Any) -> None:
|
||||
if event_name == "notify:chat.inbox.append":
|
||||
await self._handle_notify_inbox_append(payload)
|
||||
elif event_name.startswith("notify:chat.message."):
|
||||
await self._handle_notify_chat_message(payload)
|
||||
return handler
|
||||
|
||||
# ---- subscribe ---------------------------------------------------------
|
||||
|
||||
async def _subscribe_all(self) -> bool:
|
||||
ok = await self._subscribe_sessions(sorted(self._session_set))
|
||||
ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
|
||||
if self._auto_discover_sessions or self._auto_discover_panels:
|
||||
await self._refresh_targets(subscribe_new=True)
|
||||
return ok
|
||||
|
||||
async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
|
||||
if not session_ids:
|
||||
return True
|
||||
for sid in session_ids:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
ack = await self._socket_call("com.claw.im.subscribeSessions", {
|
||||
"sessionIds": session_ids, "cursors": self._session_cursor,
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(data, list):
|
||||
items = [i for i in data if isinstance(i, dict)]
|
||||
elif isinstance(data, dict):
|
||||
sessions = data.get("sessions")
|
||||
if isinstance(sessions, list):
|
||||
items = [i for i in sessions if isinstance(i, dict)]
|
||||
elif "sessionId" in data:
|
||||
items = [data]
|
||||
for p in items:
|
||||
await self._handle_watch_payload(p, "session")
|
||||
return True
|
||||
|
||||
async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
|
||||
if not self._auto_discover_panels and not panel_ids:
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._socket:
|
||||
return {"result": False, "message": "socket not connected"}
|
||||
try:
|
||||
raw = await self._socket.call(event_name, payload, timeout=10)
|
||||
except Exception as e:
|
||||
return {"result": False, "message": str(e)}
|
||||
return raw if isinstance(raw, dict) else {"result": True, "data": raw}
|
||||
|
||||
# ---- refresh / discovery -----------------------------------------------
|
||||
|
||||
async def _refresh_loop(self) -> None:
|
||||
interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running:
|
||||
await asyncio.sleep(interval_s)
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning("Mochat refresh failed: {}", e)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_targets(self, subscribe_new: bool) -> None:
|
||||
if self._auto_discover_sessions:
|
||||
await self._refresh_sessions_directory(subscribe_new)
|
||||
if self._auto_discover_panels:
|
||||
await self._refresh_panels(subscribe_new)
|
||||
|
||||
async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat listSessions failed: {}", e)
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
if not isinstance(sessions, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for s in sessions:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
sid = _str_field(s, "sessionId")
|
||||
if not sid:
|
||||
continue
|
||||
if sid not in self._session_set:
|
||||
self._session_set.add(sid)
|
||||
new_ids.append(sid)
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
cid = _str_field(s, "converseId")
|
||||
if cid:
|
||||
self._session_by_converse[cid] = sid
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_sessions(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_panels(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
if not isinstance(raw_panels, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for p in raw_panels:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
pt = p.get("type")
|
||||
if isinstance(pt, int) and pt != 0:
|
||||
continue
|
||||
pid = _str_field(p, "id", "_id")
|
||||
if pid and pid not in self._panel_set:
|
||||
self._panel_set.add(pid)
|
||||
new_ids.append(pid)
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_panels(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
# ---- fallback workers --------------------------------------------------
|
||||
|
||||
async def _ensure_fallback_workers(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._fallback_mode = True
|
||||
for sid in sorted(self._session_set):
|
||||
t = self._session_fallback_tasks.get(sid)
|
||||
if not t or t.done():
|
||||
self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
|
||||
for pid in sorted(self._panel_set):
|
||||
t = self._panel_fallback_tasks.get(pid)
|
||||
if not t or t.done():
|
||||
self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
|
||||
|
||||
async def _stop_fallback_workers(self) -> None:
|
||||
self._fallback_mode = False
|
||||
tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._session_fallback_tasks.clear()
|
||||
self._panel_fallback_tasks.clear()
|
||||
|
||||
async def _session_watch_worker(self, session_id: str) -> None:
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
payload = await self._post_json("/api/claw/sessions/watch", {
|
||||
"sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
|
||||
"timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
|
||||
})
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
resp = await self._post_json("/api/claw/groups/panels/messages", {
|
||||
"panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
|
||||
})
|
||||
msgs = resp.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
for m in reversed(msgs):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(m.get("messageId") or ""),
|
||||
author=str(m.get("author") or ""),
|
||||
content=m.get("content"),
|
||||
meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
|
||||
converse_id=panel_id, timestamp=m.get("createdAt"),
|
||||
author_info=m.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
|
||||
async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
target_id = _str_field(payload, "sessionId")
|
||||
if not target_id:
|
||||
return
|
||||
|
||||
lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
|
||||
async with lock:
|
||||
prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
|
||||
pc = payload.get("cursor")
|
||||
if target_kind == "session" and isinstance(pc, int) and pc >= 0:
|
||||
self._mark_session_cursor(target_id, pc)
|
||||
|
||||
raw_events = payload.get("events")
|
||||
if not isinstance(raw_events, list):
|
||||
return
|
||||
if target_kind == "session" and target_id in self._cold_sessions:
|
||||
self._cold_sessions.discard(target_id)
|
||||
return
|
||||
|
||||
for event in raw_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
seq = event.get("seq")
|
||||
if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
|
||||
self._mark_session_cursor(target_id, seq)
|
||||
if event.get("type") == "message.add":
|
||||
await self._process_inbound_event(target_id, event, target_kind)
|
||||
|
||||
async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
|
||||
payload = event.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
author = _str_field(payload, "author")
|
||||
if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
|
||||
return
|
||||
if not self.is_allowed(author):
|
||||
return
|
||||
|
||||
message_id = _str_field(payload, "messageId")
|
||||
seen_key = f"{target_kind}:{target_id}"
|
||||
if message_id and self._remember_message_id(seen_key, message_id):
|
||||
return
|
||||
|
||||
raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
|
||||
ai = _safe_dict(payload.get("authorInfo"))
|
||||
sender_name = _str_field(ai, "nickname", "email")
|
||||
sender_username = _str_field(ai, "agentId")
|
||||
|
||||
group_id = _str_field(payload, "groupId")
|
||||
is_group = bool(group_id)
|
||||
was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
|
||||
require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
|
||||
use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
|
||||
|
||||
if require_mention and not was_mentioned and not use_delay:
|
||||
return
|
||||
|
||||
entry = MochatBufferedEntry(
|
||||
raw_body=raw_body, author=author, sender_name=sender_name,
|
||||
sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
|
||||
message_id=message_id, group_id=group_id,
|
||||
)
|
||||
|
||||
if use_delay:
|
||||
delay_key = seen_key
|
||||
if was_mentioned:
|
||||
await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
|
||||
else:
|
||||
await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
|
||||
return
|
||||
|
||||
await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
|
||||
|
||||
# ---- dedup / buffering -------------------------------------------------
|
||||
|
||||
def _remember_message_id(self, key: str, message_id: str) -> bool:
|
||||
seen_set = self._seen_set.setdefault(key, set())
|
||||
seen_queue = self._seen_queue.setdefault(key, deque())
|
||||
if message_id in seen_set:
|
||||
return True
|
||||
seen_set.add(message_id)
|
||||
seen_queue.append(message_id)
|
||||
while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
|
||||
seen_set.discard(seen_queue.popleft())
|
||||
return False
|
||||
|
||||
async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
state.entries.append(entry)
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
|
||||
|
||||
async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
|
||||
await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
|
||||
await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
|
||||
|
||||
async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
if entry:
|
||||
state.entries.append(entry)
|
||||
current = asyncio.current_task()
|
||||
if state.timer and state.timer is not current:
|
||||
state.timer.cancel()
|
||||
state.timer = None
|
||||
entries = state.entries[:]
|
||||
state.entries.clear()
|
||||
if entries:
|
||||
await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
|
||||
|
||||
async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
|
||||
if not entries:
|
||||
return
|
||||
last = entries[-1]
|
||||
is_group = bool(last.group_id)
|
||||
body = build_buffered_body(entries, is_group) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=last.author, chat_id=target_id, content=body,
|
||||
metadata={
|
||||
"message_id": last.message_id, "timestamp": last.timestamp,
|
||||
"is_group": is_group, "group_id": last.group_id,
|
||||
"sender_name": last.sender_name, "sender_username": last.sender_username,
|
||||
"target_kind": target_kind, "was_mentioned": was_mentioned,
|
||||
"buffered_count": len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
async def _cancel_delay_timers(self) -> None:
|
||||
for state in self._delay_states.values():
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
self._delay_states.clear()
|
||||
|
||||
# ---- notify handlers ---------------------------------------------------
|
||||
|
||||
async def _handle_notify_chat_message(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
group_id = _str_field(payload, "groupId")
|
||||
panel_id = _str_field(payload, "converseId", "panelId")
|
||||
if not group_id or not panel_id:
|
||||
return
|
||||
if self._panel_set and panel_id not in self._panel_set:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(payload.get("_id") or payload.get("messageId") or ""),
|
||||
author=str(payload.get("author") or ""),
|
||||
content=payload.get("content"), meta=payload.get("meta"),
|
||||
group_id=group_id, converse_id=panel_id,
|
||||
timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
|
||||
async def _handle_notify_inbox_append(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict) or payload.get("type") != "message":
|
||||
return
|
||||
detail = payload.get("payload")
|
||||
if not isinstance(detail, dict):
|
||||
return
|
||||
if _str_field(detail, "groupId"):
|
||||
return
|
||||
converse_id = _str_field(detail, "converseId")
|
||||
if not converse_id:
|
||||
return
|
||||
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
await self._refresh_sessions_directory(self._ws_ready)
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(detail.get("messageId") or payload.get("_id") or ""),
|
||||
author=str(detail.get("messageAuthor") or ""),
|
||||
content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
|
||||
meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
|
||||
group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
|
||||
)
|
||||
await self._process_inbound_event(session_id, evt, "session")
|
||||
|
||||
# ---- cursor persistence ------------------------------------------------
|
||||
|
||||
def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
|
||||
if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
|
||||
return
|
||||
self._session_cursor[session_id] = cursor
|
||||
if not self._cursor_save_task or self._cursor_save_task.done():
|
||||
self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
|
||||
|
||||
async def _save_cursor_debounced(self) -> None:
|
||||
await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
|
||||
await self._save_session_cursors()
|
||||
|
||||
async def _load_session_cursors(self) -> None:
|
||||
if not self._cursor_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
for sid, cur in cursors.items():
|
||||
if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
|
||||
self._session_cursor[sid] = cur
|
||||
|
||||
async def _save_session_cursors(self) -> None:
|
||||
try:
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cursor_path.write_text(json.dumps({
|
||||
"schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._http:
|
||||
raise RuntimeError("Mochat HTTP client not initialized")
|
||||
url = f"{self.config.base_url.strip().rstrip('/')}{path}"
|
||||
response = await self._http.post(url, headers={
|
||||
"Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
|
||||
}, json=payload)
|
||||
if not response.is_success:
|
||||
raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
|
||||
try:
|
||||
parsed = response.json()
|
||||
except Exception:
|
||||
parsed = response.text
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
|
||||
if parsed["code"] != 200:
|
||||
msg = str(parsed.get("message") or parsed.get("name") or "request failed")
|
||||
raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
|
||||
data = parsed.get("data")
|
||||
return data if isinstance(data, dict) else {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
async def _api_send(self, path: str, id_key: str, id_val: str,
|
||||
content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
|
||||
"""Unified send helper for session and panel messages."""
|
||||
body: dict[str, Any] = {id_key: id_val, "content": content}
|
||||
if reply_to:
|
||||
body["replyTo"] = reply_to
|
||||
if group_id:
|
||||
body["groupId"] = group_id
|
||||
return await self._post_json(path, body)
|
||||
|
||||
@staticmethod
|
||||
def _read_group_id(metadata: dict[str, Any]) -> str | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
value = metadata.get("group_id") or metadata.get("groupId")
|
||||
return value.strip() if isinstance(value, str) and value.strip() else None
|
||||
@ -1,132 +0,0 @@
|
||||
"""QQ channel implementation using botpy SDK."""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
botpy = None
|
||||
C2CMessage = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botpy.message import C2CMessage
|
||||
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
"""Create a botpy Client subclass bound to the given channel."""
|
||||
intents = botpy.Intents(public_messages=True, direct_message=True)
|
||||
|
||||
class _Bot(botpy.Client):
|
||||
def __init__(self):
|
||||
super().__init__(intents=intents)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
await channel._on_message(message)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
await channel._on_message(message)
|
||||
|
||||
return _Bot
|
||||
|
||||
|
||||
class QQChannel(BaseChannel):
|
||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||
|
||||
name = "qq"
|
||||
|
||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: QQConfig = config
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
if not QQ_AVAILABLE:
|
||||
logger.error("QQ SDK not installed. Run: pip install qq-botpy")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.secret:
|
||||
logger.error("QQ app_id and secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
|
||||
logger.info("QQ bot started (C2C private message)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning("QQ bot error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("QQ bot stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through QQ."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
try:
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=0,
|
||||
content=msg.content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage") -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
try:
|
||||
# Dedup by message ID
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
author = data.author
|
||||
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
|
||||
content = (data.content or "").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=user_id,
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
@ -1,257 +0,0 @@
|
||||
"""Slack channel implementation using Socket Mode."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
name = "slack"
|
||||
|
||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: SlackConfig = config
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
if not self.config.bot_token or not self.config.app_token:
|
||||
logger.error("Slack bot/app token not configured")
|
||||
return
|
||||
if self.config.mode != "socket":
|
||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
self._web_client = AsyncWebClient(token=self.config.bot_token)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=self.config.app_token,
|
||||
web_client=self._web_client,
|
||||
)
|
||||
|
||||
self._socket_client.socket_mode_request_listeners.append(self._on_socket_request)
|
||||
|
||||
# Resolve bot user ID for mention handling
|
||||
try:
|
||||
auth = await self._web_client.auth_test()
|
||||
self._bot_user_id = auth.get("user_id")
|
||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||
except Exception as e:
|
||||
logger.warning("Slack auth_test failed: {}", e)
|
||||
|
||||
logger.info("Starting Slack Socket Mode client...")
|
||||
await self._socket_client.connect()
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Slack client."""
|
||||
self._running = False
|
||||
if self._socket_client:
|
||||
try:
|
||||
await self._socket_client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Slack socket close failed: {}", e)
|
||||
self._socket_client = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Slack."""
|
||||
if not self._web_client:
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||
use_thread = thread_ts and channel_type != "im"
|
||||
thread_ts_param = thread_ts if use_thread else None
|
||||
|
||||
if msg.content:
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content),
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
req: SocketModeRequest,
|
||||
) -> None:
|
||||
"""Handle incoming Socket Mode requests."""
|
||||
if req.type != "events_api":
|
||||
return
|
||||
|
||||
# Acknowledge right away
|
||||
await client.send_socket_mode_response(
|
||||
SocketModeResponse(envelope_id=req.envelope_id)
|
||||
)
|
||||
|
||||
payload = req.payload or {}
|
||||
event = payload.get("event") or {}
|
||||
event_type = event.get("type")
|
||||
|
||||
# Handle app mentions or plain messages
|
||||
if event_type not in ("message", "app_mention"):
|
||||
return
|
||||
|
||||
sender_id = event.get("user")
|
||||
chat_id = event.get("channel")
|
||||
|
||||
# Ignore bot/system messages (any subtype = not a normal user message)
|
||||
if event.get("subtype"):
|
||||
return
|
||||
if self._bot_user_id and sender_id == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Avoid double-processing: Slack sends both `message` and `app_mention`
|
||||
# for mentions in channels. Prefer `app_mention`.
|
||||
text = event.get("text") or ""
|
||||
if event_type == "message" and self._bot_user_id and f"<@{self._bot_user_id}>" in text:
|
||||
return
|
||||
|
||||
# Debug: log basic event shape
|
||||
logger.debug(
|
||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||
event_type,
|
||||
event.get("subtype"),
|
||||
sender_id,
|
||||
chat_id,
|
||||
event.get("channel_type"),
|
||||
text[:80],
|
||||
)
|
||||
if not sender_id or not chat_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type") or ""
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
return
|
||||
|
||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||
return
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
await self._web_client.reactions_add(
|
||||
channel=chat_id,
|
||||
name=self.config.react_emoji,
|
||||
timestamp=event.get("ts"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Slack reactions_add failed: {}", e)
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
},
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
if self.config.dm.policy == "allowlist":
|
||||
return sender_id in self.config.dm.allow_from
|
||||
return True
|
||||
|
||||
# Group / channel messages
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return True
|
||||
|
||||
def _should_respond_in_channel(self, event_type: str, text: str, chat_id: str) -> bool:
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
if self.config.group_policy == "mention":
|
||||
if event_type == "app_mention":
|
||||
return True
|
||||
return self._bot_user_id is not None and f"<@{self._bot_user_id}>" in text
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return False
|
||||
|
||||
def _strip_bot_mention(self, text: str) -> str:
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||
|
||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||
|
||||
@classmethod
|
||||
def _to_mrkdwn(cls, text: str) -> str:
|
||||
"""Convert Markdown to Slack mrkdwn, including tables."""
|
||||
if not text:
|
||||
return ""
|
||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||
return slackify_markdown(text)
|
||||
|
||||
@staticmethod
|
||||
def _convert_table(match: re.Match) -> str:
|
||||
"""Convert a Markdown table to a Slack-readable list."""
|
||||
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
|
||||
if len(lines) < 2:
|
||||
return match.group(0)
|
||||
headers = [h.strip() for h in lines[0].strip("|").split("|")]
|
||||
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
|
||||
rows: list[str] = []
|
||||
for line in lines[start:]:
|
||||
cells = [c.strip() for c in line.strip("|").split("|")]
|
||||
cells = (cells + [""] * len(headers))[: len(headers)]
|
||||
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
|
||||
@ -1,457 +0,0 @@
|
||||
"""Telegram channel implementation using python-telegram-bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, Update, ReplyParameters
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
Convert markdown to Telegram-safe HTML.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos == -1:
|
||||
pos = cut.rfind(' ')
|
||||
if pos == -1:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
if self.config.proxy:
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
|
||||
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
|
||||
return "photo"
|
||||
if ext == "ogg":
|
||||
return "voice"
|
||||
if ext in ("mp3", "m4a", "wav", "aac"):
|
||||
return "audio"
|
||||
return "document"
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Telegram."""
|
||||
if not self._app:
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
try:
|
||||
chat_id = int(msg.chat_id)
|
||||
except ValueError:
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
allow_sending_without_reply=True
|
||||
)
|
||||
|
||||
# Send media files
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
media_type = self._get_media_type(media_path)
|
||||
sender = {
|
||||
"photo": self._app.bot.send_photo,
|
||||
"voice": self._app.bot.send_voice,
|
||||
"audio": self._app.bot.send_audio,
|
||||
}.get(media_type, self._app.bot.send_document)
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
logger.error("Failed to send media {}: {}", media_path, e)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
for chunk in _split_message(msg.content):
|
||||
try:
|
||||
html = _markdown_to_telegram_html(chunk)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=html,
|
||||
parse_mode="HTML",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=chunk,
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
user = update.effective_user
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm Boardware Genius.\n\n"
|
||||
"Send me a message and I'll respond!\n"
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||
if not update.message:
|
||||
return
|
||||
await update.message.reply_text(
|
||||
"Boardware Genius commands:\n"
|
||||
"/new — Start a new conversation\n"
|
||||
"/help — Show available commands"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sender_id(user) -> str:
|
||||
"""Build sender_id with username for allowlist matching."""
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(update.effective_user),
|
||||
chat_id=str(update.message.chat_id),
|
||||
content=update.message.text,
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
elif message.voice:
|
||||
media_file = message.voice
|
||||
media_type = "voice"
|
||||
elif message.audio:
|
||||
media_file = message.audio
|
||||
media_type = "audio"
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||
|
||||
# Save to workspace/media/
|
||||
from pathlib import Path
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private"
|
||||
}
|
||||
)
|
||||
|
||||
def _start_typing(self, chat_id: str) -> None:
|
||||
"""Start sending 'typing...' indicator for a chat."""
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
while self._app:
|
||||
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
await asyncio.sleep(4)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||
"""Get file extension based on media type."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
"audio/ogg": ".ogg", "audio/mpeg": ".mp3", "audio/mp4": ".m4a",
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
return type_map.get(media_type, "")
|
||||
@ -1,148 +0,0 @@
|
||||
"""WhatsApp channel implementation using Node.js bridge."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import WhatsAppConfig
|
||||
|
||||
|
||||
class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
WhatsApp channel that connects to a Node.js bridge.
|
||||
|
||||
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
|
||||
Communication between Python and Node.js is via WebSocket.
|
||||
"""
|
||||
|
||||
name = "whatsapp"
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig = config
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
"to": msg.chat_id,
|
||||
"text": msg.content
|
||||
}
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
pn = data.get("pn", "")
|
||||
# New LID sytle typically:
|
||||
sender = data.get("sender", "")
|
||||
content = data.get("content", "")
|
||||
|
||||
# Extract just the phone number or lid as chat_id
|
||||
user_id = pn if pn else sender
|
||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||
logger.info("Sender {}", sender)
|
||||
|
||||
# Handle voice transcription if it's a voice message
|
||||
if content == "[Voice Message]":
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
metadata={
|
||||
"message_id": data.get("id"),
|
||||
"timestamp": data.get("timestamp"),
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
@ -1 +0,0 @@
|
||||
"""CLI module for Boardware Genius."""
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,6 +0,0 @@
|
||||
"""Configuration module for Boardware Genius."""
|
||||
|
||||
from nanobot.config.loader import load_config, get_config_path
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
__all__ = ["Config", "load_config", "get_config_path"]
|
||||
@ -1,97 +0,0 @@
|
||||
"""Configuration loading utilities."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Get the default configuration file path."""
|
||||
# 统一约定配置文件位置:~/.nanobot/config.json
|
||||
# 这样 CLI、Gateway、测试都能复用同一入口,不会出现路径分叉。
|
||||
return Path.home() / ".nanobot" / "config.json"
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Get the nanobot data directory."""
|
||||
# 延迟导入(函数内 import)可以减少模块初始化时的依赖耦合。
|
||||
# get_data_path() 内部会确保目录存在。
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
return get_data_path()
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> Config:
|
||||
"""
|
||||
Load configuration from file or create default.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file. Uses default if not provided.
|
||||
|
||||
Returns:
|
||||
Loaded configuration object.
|
||||
"""
|
||||
# 如果调用者没传路径,就走默认路径 ~/.nanobot/config.json
|
||||
path = config_path or get_config_path()
|
||||
|
||||
# 只有文件存在才尝试读取;不存在时直接返回默认 Config。
|
||||
if path.exists():
|
||||
try:
|
||||
# 1) 读取 JSON 原始配置
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# 2) 做向后兼容迁移(旧字段 -> 新字段)
|
||||
data = _migrate_config(data)
|
||||
# 3) 用 Pydantic 做强校验与类型转换
|
||||
# 例如:camelCase/snake_case 映射、默认值补齐、字段类型检查。
|
||||
return Config.model_validate(data)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
# 容错策略:配置损坏时不让程序崩溃,而是退回默认配置继续运行。
|
||||
print(f"Warning: Failed to load config from {path}: {e}")
|
||||
print("Using default configuration.")
|
||||
|
||||
# 配置文件不存在,或读取失败 -> 返回 schema 里的默认配置对象。
|
||||
return Config()
|
||||
|
||||
|
||||
def save_config(config: Config, config_path: Path | None = None) -> None:
|
||||
"""
|
||||
Save configuration to file.
|
||||
|
||||
Args:
|
||||
config: Configuration to save.
|
||||
config_path: Optional path to save to. Uses default if not provided.
|
||||
"""
|
||||
# 目标路径:优先用调用方传入路径,否则走默认路径。
|
||||
path = config_path or get_config_path()
|
||||
# 先确保父目录存在,避免 open(..., "w") 因目录缺失而失败。
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# model_dump(by_alias=True) 的关键点:
|
||||
# - schema 中很多字段 Python 侧是 snake_case(如 api_key)
|
||||
# - 配置文件对外希望保持 camelCase(如 apiKey)
|
||||
# - by_alias=True 会把字段按 alias 输出,保证写回文件的键名与用户配置习惯一致
|
||||
# (否则会写成 snake_case,和 README 示例不一致)。
|
||||
data = config.model_dump(by_alias=True)
|
||||
|
||||
# ensure_ascii=False: 保留中文等非 ASCII 字符,不转成 \uXXXX
|
||||
# indent=2: 让配置文件更易读、可手工编辑。
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def _migrate_config(data: dict) -> dict:
|
||||
"""Migrate old config formats to current."""
|
||||
# 这个函数专门做“历史配置兼容”:
|
||||
# 旧版字段:tools.exec.restrictToWorkspace
|
||||
# 新版字段:tools.restrictToWorkspace
|
||||
#
|
||||
# 迁移策略:
|
||||
# - 仅当旧字段存在且新字段不存在时才迁移
|
||||
# - 避免覆盖用户在新字段里已经明确设置的值
|
||||
tools = data.get("tools", {})
|
||||
exec_cfg = tools.get("exec", {})
|
||||
if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
|
||||
tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
|
||||
# 返回迁移后的原始 dict,后续再交给 Config.model_validate() 做结构化校验。
|
||||
return data
|
||||
@ -1,19 +0,0 @@
|
||||
"""Path helpers shared by config and channel integrations."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.loader import get_data_dir as _get_data_dir
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Return the global nanobot data directory (~/.nanobot)."""
|
||||
return _get_data_dir()
|
||||
|
||||
|
||||
def get_media_dir(channel: str | None = None) -> Path:
|
||||
"""Return the media directory, optionally namespaced by channel."""
|
||||
base = get_data_dir() / "media"
|
||||
if channel:
|
||||
base = base / str(channel)
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
return base
|
||||
@ -1,539 +0,0 @@
|
||||
"""nanobot 配置 Schema(基于 Pydantic)。
|
||||
|
||||
这份文件是“配置系统的单一结构定义”:
|
||||
1. 定义配置长什么样(字段、默认值、嵌套结构)
|
||||
2. 负责配置的类型校验与兼容(camelCase / snake_case)
|
||||
3. 提供若干读取辅助方法(如 provider 匹配、api_key/api_base 解析)
|
||||
|
||||
你可以把它理解为:
|
||||
- `loader.py` 负责“读写配置文件”
|
||||
- `schema.py` 负责“配置对象的结构和规则”
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic.alias_generators import to_camel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""所有配置模型的基类。
|
||||
|
||||
关键点:
|
||||
- `alias_generator=to_camel`:自动把 `api_key` 这种字段映射到 `apiKey`
|
||||
- `populate_by_name=True`:读取时同时接受 snake_case 和 camelCase
|
||||
|
||||
结果:
|
||||
- Python 代码内部统一使用 snake_case,便于可读性和一致性
|
||||
- 配置文件对外保持 camelCase,贴近 README 和用户习惯
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
|
||||
|
||||
|
||||
class WhatsAppConfig(Base):
|
||||
"""WhatsApp 渠道配置。
|
||||
|
||||
说明:
|
||||
- nanobot 通过单独的 bridge 进程与 WhatsApp 交互
|
||||
- 这里配置的是 bridge 的连接地址和访问控制
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
bridge_url: str = "ws://localhost:3001"
|
||||
bridge_token: str = "" # Shared token for bridge auth (optional, recommended)
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
|
||||
|
||||
|
||||
class TelegramConfig(Base):
|
||||
"""Telegram 渠道配置。
|
||||
|
||||
常用字段:
|
||||
- token:机器人凭证(必须)
|
||||
- allow_from:白名单(可选,空列表表示不限制)
|
||||
- proxy:在网络受限场景下可配置代理
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
token: str = "" # Bot token from @BotFather
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
|
||||
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||
reply_to_message: bool = False # If true, bot replies quote the original message
|
||||
|
||||
|
||||
class FeishuConfig(Base):
|
||||
"""飞书/Lark 渠道配置(基于长连接模式)。"""
|
||||
|
||||
enabled: bool = False
|
||||
app_id: str = "" # App ID from Feishu Open Platform
|
||||
app_secret: str = "" # App Secret from Feishu Open Platform
|
||||
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
|
||||
verification_token: str = "" # Verification Token for event subscription (optional)
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
|
||||
|
||||
|
||||
class DingTalkConfig(Base):
|
||||
"""钉钉渠道配置(Stream 模式)。"""
|
||||
|
||||
enabled: bool = False
|
||||
client_id: str = "" # AppKey
|
||||
client_secret: str = "" # AppSecret
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
"""Discord 渠道配置。"""
|
||||
|
||||
enabled: bool = False
|
||||
token: str = "" # Bot token from Discord Developer Portal
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
|
||||
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
|
||||
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
|
||||
|
||||
|
||||
class MatrixConfig(Base):
|
||||
"""Matrix (Element) 渠道配置。"""
|
||||
|
||||
enabled: bool = False
|
||||
homeserver: str = "https://matrix.org"
|
||||
access_token: str = ""
|
||||
user_id: str = "" # @bot:matrix.org
|
||||
device_id: str = ""
|
||||
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
|
||||
sync_stop_grace_seconds: int = (
|
||||
2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
|
||||
)
|
||||
max_media_bytes: int = (
|
||||
20 * 1024 * 1024
|
||||
) # Max attachment size accepted for Matrix media handling (inbound + outbound).
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
|
||||
|
||||
class EmailConfig(Base):
|
||||
"""Email 渠道配置(IMAP 收件 + SMTP 发件)。
|
||||
|
||||
设计思路:
|
||||
- IMAP 负责拉取新邮件
|
||||
- SMTP 负责自动回复
|
||||
- 行为参数控制轮询频率、正文截断、标记已读等策略
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
consent_granted: bool = False # Explicit owner permission to access mailbox data
|
||||
|
||||
# IMAP (receive)
|
||||
imap_host: str = ""
|
||||
imap_port: int = 993
|
||||
imap_username: str = ""
|
||||
imap_password: str = ""
|
||||
imap_mailbox: str = "INBOX"
|
||||
imap_use_ssl: bool = True
|
||||
|
||||
# SMTP (send)
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_username: str = ""
|
||||
smtp_password: str = ""
|
||||
smtp_use_tls: bool = True
|
||||
smtp_use_ssl: bool = False
|
||||
from_address: str = ""
|
||||
|
||||
# Behavior
|
||||
auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent
|
||||
poll_interval_seconds: int = 30
|
||||
mark_seen: bool = True
|
||||
max_body_chars: int = 12000
|
||||
subject_prefix: str = "Re: "
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses
|
||||
|
||||
|
||||
class MochatMentionConfig(Base):
|
||||
"""Mochat 提及(mention)规则。"""
|
||||
|
||||
require_in_groups: bool = False
|
||||
|
||||
|
||||
class MochatGroupRule(Base):
|
||||
"""Mochat 群组级别规则(可按群单独配置是否必须 @)。"""
|
||||
|
||||
require_mention: bool = False
|
||||
|
||||
|
||||
class MochatConfig(Base):
|
||||
"""Mochat 渠道配置。
|
||||
|
||||
包含三类参数:
|
||||
- 连接参数:base_url / socket_url / socket_path
|
||||
- 重连与轮询参数:各类 *_ms 与 retry 相关字段
|
||||
- 权限与会话参数:allow_from / sessions / panels / mention / groups
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
base_url: str = "https://mochat.io"
|
||||
socket_url: str = ""
|
||||
socket_path: str = "/socket.io"
|
||||
socket_disable_msgpack: bool = False
|
||||
socket_reconnect_delay_ms: int = 1000
|
||||
socket_max_reconnect_delay_ms: int = 10000
|
||||
socket_connect_timeout_ms: int = 10000
|
||||
refresh_interval_ms: int = 30000
|
||||
watch_timeout_ms: int = 25000
|
||||
watch_limit: int = 100
|
||||
retry_delay_ms: int = 500
|
||||
max_retry_attempts: int = 0 # 0 means unlimited retries
|
||||
claw_token: str = ""
|
||||
agent_user_id: str = ""
|
||||
sessions: list[str] = Field(default_factory=list)
|
||||
panels: list[str] = Field(default_factory=list)
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
|
||||
groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
|
||||
reply_delay_mode: str = "non-mention" # off | non-mention
|
||||
reply_delay_ms: int = 120000
|
||||
|
||||
|
||||
class SlackDMConfig(Base):
|
||||
"""Slack 私聊(DM)策略配置。"""
|
||||
|
||||
enabled: bool = True
|
||||
policy: str = "open" # "open" or "allowlist"
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs
|
||||
|
||||
|
||||
class SlackConfig(Base):
|
||||
"""Slack 渠道配置。"""
|
||||
|
||||
enabled: bool = False
|
||||
mode: str = "socket" # "socket" supported
|
||||
webhook_path: str = "/slack/events"
|
||||
bot_token: str = "" # xoxb-...
|
||||
app_token: str = "" # xapp-...
|
||||
user_token_read_only: bool = True
|
||||
reply_in_thread: bool = True
|
||||
react_emoji: str = "eyes"
|
||||
group_policy: str = "mention" # "mention", "open", "allowlist"
|
||||
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
|
||||
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
||||
|
||||
|
||||
class QQConfig(Base):
|
||||
"""QQ 渠道配置(botpy SDK)。"""
|
||||
|
||||
enabled: bool = False
|
||||
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
|
||||
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
|
||||
|
||||
|
||||
class ChannelsConfig(Base):
|
||||
"""所有聊天渠道的总配置。
|
||||
|
||||
除了具体渠道参数外,还有两个全局开关:
|
||||
- send_progress:是否把“处理中进度”推送到渠道
|
||||
- send_tool_hints:是否把“工具调用提示”推送到渠道
|
||||
"""
|
||||
|
||||
send_progress: bool = True # stream agent's text progress to the channel
|
||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
|
||||
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
|
||||
discord: DiscordConfig = Field(default_factory=DiscordConfig)
|
||||
feishu: FeishuConfig = Field(default_factory=FeishuConfig)
|
||||
mochat: MochatConfig = Field(default_factory=MochatConfig)
|
||||
dingtalk: DingTalkConfig = Field(default_factory=DingTalkConfig)
|
||||
email: EmailConfig = Field(default_factory=EmailConfig)
|
||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||
qq: QQConfig = Field(default_factory=QQConfig)
|
||||
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
||||
|
||||
|
||||
class AgentDefaults(Base):
|
||||
"""Agent 默认行为配置。
|
||||
|
||||
关键参数建议理解:
|
||||
- model:主模型标识
|
||||
- max_tokens:单次回复上限
|
||||
- max_tool_iterations:一次请求里最多工具循环次数
|
||||
- memory_window:每次送给模型的历史窗口大小
|
||||
"""
|
||||
|
||||
workspace: str = "~/.nanobot/workspace"
|
||||
model: str = "anthropic/claude-opus-4-5"
|
||||
max_tokens: int = 8192
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
memory_window: int = 100
|
||||
|
||||
|
||||
class AgentsConfig(Base):
|
||||
"""Agent 顶层配置(当前主要是 defaults)。"""
|
||||
|
||||
defaults: AgentDefaults = Field(default_factory=AgentDefaults)
|
||||
|
||||
|
||||
class ProviderConfig(Base):
|
||||
"""单个 LLM Provider 的通用配置结构。
|
||||
|
||||
字段说明:
|
||||
- api_key:访问凭证
|
||||
- api_base:可选自定义网关/代理地址
|
||||
- extra_headers:额外 HTTP 头(某些网关会要求)
|
||||
"""
|
||||
|
||||
api_key: str = ""
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||
request_timeout_seconds: int = 600
|
||||
|
||||
|
||||
class ProvidersConfig(Base):
|
||||
"""所有 Provider 的配置集合。
|
||||
|
||||
这里的字段名必须和 `providers/registry.py` 里的 ProviderSpec.name 对齐。
|
||||
这样 `_match_provider()` 才能通过 `getattr(self.providers, spec.name)` 正确取值。
|
||||
"""
|
||||
|
||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
|
||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
|
||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway
|
||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||
|
||||
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway 服务监听配置。"""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 18790
|
||||
|
||||
|
||||
class WebSearchConfig(Base):
|
||||
"""Web 搜索工具配置(当前主要是 Brave Search)。"""
|
||||
|
||||
api_key: str = "" # Brave Search API key
|
||||
max_results: int = 5
|
||||
|
||||
|
||||
class WebToolsConfig(Base):
|
||||
"""Web 工具总配置。"""
|
||||
|
||||
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
||||
|
||||
|
||||
class ExecToolConfig(Base):
|
||||
"""Shell 执行工具配置。"""
|
||||
|
||||
timeout: int = 60
|
||||
|
||||
|
||||
class MCPServerConfig(Base):
|
||||
"""单个 MCP 服务器配置(支持 stdio 与 HTTP 两种连接方式)。
|
||||
|
||||
使用方式:
|
||||
- stdio:配置 `command + args + env`
|
||||
- HTTP:配置 `url + headers`
|
||||
"""
|
||||
|
||||
command: str = "" # Stdio: command to run (e.g. "npx")
|
||||
args: list[str] = Field(default_factory=list) # Stdio: command arguments
|
||||
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
|
||||
url: str = "" # HTTP: streamable HTTP endpoint URL
|
||||
headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
|
||||
auth_mode: str = "none" # none | oauth_backend_token
|
||||
auth_audience: str = ""
|
||||
auth_scopes: list[str] = Field(default_factory=list)
|
||||
tool_timeout: int = 30 # Seconds before a tool call is cancelled
|
||||
sensitive: bool = False # Redact secrets/args from Web views and process events
|
||||
|
||||
|
||||
class A2AConfig(Base):
|
||||
"""A2A agent 委派配置。"""
|
||||
|
||||
# 总开关,预留给未来需要完全禁用远程委派的场景。
|
||||
enabled: bool = True
|
||||
# 单次远程任务的最长等待时间(秒)。
|
||||
timeout_seconds: int = 600
|
||||
# 非流式任务轮询间隔(秒)。
|
||||
poll_interval_seconds: int = 2
|
||||
# agent card 本地缓存 TTL,避免每次委派都重新拉远端元数据。
|
||||
card_cache_ttl_seconds: int = 300
|
||||
# group delegation 并发上限,防止一次性打爆本地或远端资源。
|
||||
max_parallel_agents: int = 4
|
||||
# 是否允许从 skill 元数据里暴露 agent cards。
|
||||
allow_skill_cards: bool = True
|
||||
# 是否允许读取 workspace/agents/registry.json 中的手工登记 agent。
|
||||
allow_workspace_agents: bool = True
|
||||
# 允许访问的远端 host 白名单;为空表示不限制。
|
||||
allowed_hosts: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolsConfig(Base):
|
||||
"""工具层总配置。
|
||||
|
||||
关键安全字段:
|
||||
- restrict_to_workspace:开启后,工具访问将被限制在 workspace 内
|
||||
"""
|
||||
|
||||
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
||||
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
|
||||
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||
a2a: A2AConfig = Field(default_factory=A2AConfig)
|
||||
|
||||
|
||||
class AuthzConfig(Base):
|
||||
"""外部 AuthZ/OAuth 服务配置。"""
|
||||
|
||||
enabled: bool = False
|
||||
base_url: str = "http://127.0.0.1:19090"
|
||||
request_timeout_seconds: int = 10
|
||||
outlook_mcp_url: str = ""
|
||||
|
||||
|
||||
class BackendIdentityConfig(Base):
|
||||
"""当前 backend 在 AuthZ 服务里的身份配置。"""
|
||||
|
||||
backend_id: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
name: str = "Local Backend"
|
||||
public_base_url: str = ""
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""nanobot 根配置对象。
|
||||
|
||||
这是业务代码中最常使用的配置入口:
|
||||
- `config.agents.defaults.model`
|
||||
- `config.channels.telegram.token`
|
||||
- `config.tools.restrict_to_workspace`
|
||||
等都会从这里往下访问。
|
||||
"""
|
||||
|
||||
agents: AgentsConfig = Field(default_factory=AgentsConfig)
|
||||
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
|
||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
authz: AuthzConfig = Field(default_factory=AuthzConfig)
|
||||
backend_identity: BackendIdentityConfig = Field(default_factory=BackendIdentityConfig)
|
||||
|
||||
@property
|
||||
def workspace_path(self) -> Path:
|
||||
"""返回展开后的 workspace 绝对路径对象。
|
||||
|
||||
`~` 会被替换成用户 home 目录,避免下游代码重复处理路径展开。
|
||||
"""
|
||||
return Path(self.agents.defaults.workspace).expanduser()
|
||||
|
||||
def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
|
||||
"""根据模型名与当前配置,匹配最合适的 provider。
|
||||
|
||||
返回值:
|
||||
- ProviderConfig | None:匹配到的配置项(含 api_key/api_base)
|
||||
- str | None:provider 的 registry 名称(例如 openrouter/deepseek)
|
||||
|
||||
匹配优先级(非常重要):
|
||||
1. 显式前缀匹配:`github-copilot/...` 这种明确前缀优先
|
||||
2. 关键字匹配:按 PROVIDERS 顺序匹配关键词
|
||||
3. 兜底匹配:选第一个“已配置 api_key 的非 OAuth provider”
|
||||
"""
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
# 统一做小写与连字符归一化,减少字符串匹配分歧。
|
||||
model_lower = (model or self.agents.defaults.model).lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
|
||||
# 关键字匹配函数:同时兼容 dash/underscore 两种写法。
|
||||
def _kw_matches(kw: str) -> bool:
|
||||
kw = kw.lower()
|
||||
return kw in model_lower or kw.replace("-", "_") in model_normalized
|
||||
|
||||
# 第 1 轮:显式前缀优先
|
||||
# 例如 `github-copilot/gpt-5.3-codex`,必须匹配 github_copilot,
|
||||
# 不能被 `codex` 关键字误匹配成 openai_codex。
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and model_prefix and normalized_prefix == spec.name:
|
||||
if spec.is_oauth or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# 第 2 轮:按关键字匹配(顺序由 PROVIDERS 决定)
|
||||
# 顺序很关键:registry 里前面的 provider 具有更高优先级。
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and any(_kw_matches(kw) for kw in spec.keywords):
|
||||
if spec.is_oauth or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# 第 3 轮:兜底匹配
|
||||
# 规则:仅考虑“非 OAuth + 有 api_key”的 provider。
|
||||
# 原因:OAuth provider 需要显式模型选择,不能静默兜底。
|
||||
for spec in PROVIDERS:
|
||||
if spec.is_oauth:
|
||||
continue
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and p.api_key:
|
||||
return p, spec.name
|
||||
return None, None
|
||||
|
||||
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||
"""获取匹配到的 ProviderConfig(含 api_key/api_base/extra_headers)。"""
|
||||
p, _ = self._match_provider(model)
|
||||
return p
|
||||
|
||||
def get_provider_name(self, model: str | None = None) -> str | None:
|
||||
"""获取匹配到的 provider 名称(例如 deepseek/openrouter)。"""
|
||||
_, name = self._match_provider(model)
|
||||
return name
|
||||
|
||||
def get_api_key(self, model: str | None = None) -> str | None:
|
||||
"""获取当前模型对应的 API key(无则返回 None)。"""
|
||||
p = self.get_provider(model)
|
||||
return p.api_key if p else None
|
||||
|
||||
def get_api_base(self, model: str | None = None) -> str | None:
|
||||
"""获取当前模型的 api_base。
|
||||
|
||||
规则:
|
||||
1. 若用户显式配置了 api_base,优先返回用户值
|
||||
2. 否则若匹配到的是 gateway provider,则可回退到 registry 默认 base
|
||||
3. 标准 provider(非 gateway)默认不在这里强制写 api_base
|
||||
"""
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
p, name = self._match_provider(model)
|
||||
if p and p.api_base:
|
||||
return p.api_base
|
||||
# 仅 gateway 在此处应用默认 api_base。
|
||||
# 标准 provider(如 moonshot)通常在 provider 初始化时通过环境变量处理,
|
||||
# 避免污染全局 litellm.api_base。
|
||||
if name:
|
||||
spec = find_by_name(name)
|
||||
if spec and spec.is_gateway and spec.default_api_base:
|
||||
return spec.default_api_base
|
||||
return None
|
||||
|
||||
# BaseSettings 相关:
|
||||
# - env_prefix="NANOBOT_":环境变量前缀,例如 NANOBOT_AGENTS__DEFAULTS__MODEL
|
||||
# - env_nested_delimiter="__":双下划线用于拆分嵌套层级
|
||||
model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")
|
||||
@ -1,6 +0,0 @@
|
||||
"""Cron service for scheduled agent tasks."""
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronSchedule
|
||||
|
||||
__all__ = ["CronService", "CronJob", "CronSchedule"]
|
||||
@ -1,116 +0,0 @@
|
||||
"""cron 任务运行时辅助逻辑。
|
||||
|
||||
这里负责把已经到点的 `CronJob` 真正翻译成一次可执行动作:
|
||||
1. 纯提醒型任务:直接向目标会话投递消息;
|
||||
2. agent task 型任务:构造自动执行上下文,再交给 `AgentLoop.process_direct()`;
|
||||
3. 额外注入 `cron_action` 工具,让模型可以反向控制后续调度。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.cron_action import CronActionTool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.cron.types import CronExecutionResult, CronJob
|
||||
|
||||
|
||||
async def _deliver_response(
|
||||
bus: MessageBus,
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
content: str | None,
|
||||
) -> None:
|
||||
# cron 统一通过 outbound 消息回到现有渠道层,避免绕开原有发送链路。
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content or "",
|
||||
))
|
||||
|
||||
|
||||
def _describe_schedule(job: CronJob) -> str:
|
||||
"""把调度对象转成面向模型的简短文本。"""
|
||||
if job.schedule.kind == "every":
|
||||
every_ms = job.schedule.every_ms or 0
|
||||
return f"every {every_ms // 1000}s"
|
||||
if job.schedule.kind == "cron":
|
||||
return job.schedule.expr or "cron"
|
||||
return "one-time"
|
||||
|
||||
|
||||
def _resolve_session_key(job: CronJob) -> str:
|
||||
"""为 cron task 选择一个应复用的会话 key。"""
|
||||
# 优先使用显式记录的 session_key,这样任务型 cron 可以延续原短期上下文。
|
||||
if job.payload.session_key:
|
||||
return job.payload.session_key
|
||||
# 如果老数据没有 session_key,但有 channel/to,则退化为路由键。
|
||||
if job.payload.channel and job.payload.to:
|
||||
return f"{job.payload.channel}:{job.payload.to}"
|
||||
# 再兜底到 cron 自己的命名空间,保证始终能生成稳定 key。
|
||||
return f"cron:{job.id}"
|
||||
|
||||
|
||||
def _build_execution_context(job: CronJob, session_key: str) -> str:
|
||||
"""构造注入给 agent 的自动执行上下文说明。"""
|
||||
schedule = _describe_schedule(job)
|
||||
return f"""This turn was triggered automatically by a scheduled cron job.
|
||||
|
||||
Job ID: {job.id}
|
||||
Job Name: {job.name}
|
||||
Schedule: {schedule}
|
||||
Origin Session: {session_key}
|
||||
|
||||
You are in autonomous scheduled-task mode:
|
||||
- This is not an interactive user turn.
|
||||
- Do not ask the user what to do next.
|
||||
- Execute the task, make the necessary tool calls, and report the concrete outcome.
|
||||
- If the task has reached a terminal condition, natural stopping point, or no longer needs future runs, emit a structured cron_action tool call instead of only describing it in text.
|
||||
- Use cron_action(action="complete_today", reason="...") when today's batch is complete and the job should resume next cycle.
|
||||
- Use cron_action(action="remove", reason="...") to delete the current job permanently.
|
||||
- Use cron_action(action="disable", reason="...") to stop the current job without deleting it.
|
||||
- Use cron_action(action="reschedule", ...) to change the current job's schedule deterministically.
|
||||
- Use the regular cron tool only if you truly need to inspect or manage additional jobs beyond the current one.
|
||||
"""
|
||||
|
||||
|
||||
async def run_cron_job(
|
||||
job: CronJob,
|
||||
*,
|
||||
agent: Any,
|
||||
bus: MessageBus,
|
||||
default_channel: str,
|
||||
default_chat_id: str,
|
||||
) -> CronExecutionResult:
|
||||
"""Execute one cron job according to its payload kind."""
|
||||
# deliver 目标允许任务使用自己的渠道配置,否则落回默认 web 会话。
|
||||
channel = job.payload.channel or default_channel
|
||||
chat_id = job.payload.to or default_chat_id
|
||||
|
||||
if job.payload.kind == "system_event":
|
||||
# 提醒模式不需要再过一层 agent 推理,直接把原消息投递给目标会话。
|
||||
message = job.payload.message
|
||||
if job.payload.deliver and job.payload.to:
|
||||
await _deliver_response(bus, channel=channel, chat_id=job.payload.to, content=message)
|
||||
return CronExecutionResult(response=message)
|
||||
|
||||
# task 模式会进入 agent 主循环,因此要准备复用的 session key 和运行说明。
|
||||
session_key = _resolve_session_key(job)
|
||||
execution_context = _build_execution_context(job, session_key)
|
||||
# 把 cron_action 作为“附加工具”注入,仅对当前这次 cron 执行生效。
|
||||
action_tool = CronActionTool(job.id)
|
||||
response = await agent.process_direct(
|
||||
content=job.payload.message,
|
||||
session_key=session_key,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
execution_context=execution_context,
|
||||
extra_tools=[action_tool],
|
||||
)
|
||||
# 若任务要求把最终结果投递出去,则沿用正常 outbound 消息链路。
|
||||
if job.payload.deliver and job.payload.to:
|
||||
await _deliver_response(bus, channel=channel, chat_id=job.payload.to, content=response)
|
||||
# runtime 同时返回文本结果和结构化动作,供 CronService 后续处理。
|
||||
return CronExecutionResult(response=response, action=action_tool.decision)
|
||||
@ -1,583 +0,0 @@
|
||||
"""Cron 调度服务(持久化 + 计算下一次触发 + 定时执行)。
|
||||
|
||||
这个模块是 nanobot 的“计划任务内核”,职责边界如下:
|
||||
1. 数据层:把任务状态持久化到 `jobs.json`,并在内存维护一个 `CronStore` 缓存;
|
||||
2. 调度层:根据 `at / every / cron` 规则计算每个任务的下一次触发时间;
|
||||
3. 执行层:在任务到点时调用 `on_job` 回调(通常由 gateway 注入,转到 agent 执行);
|
||||
4. 管理层:提供增删改查、启停、手动触发等公共 API。
|
||||
|
||||
关键设计点:
|
||||
- 单计时器模型:始终只保留“最近一次触发点”的 `asyncio.Task`,
|
||||
避免“每个任务一个 sleep 协程”导致的资源膨胀;
|
||||
- 懒加载存储:首次访问才读盘,后续以内存对象为准,写操作再落盘;
|
||||
- 容错优先:配置/解析异常尽量降级为空任务或不可调度,不让主服务崩溃。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Literal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.cron.types import (
|
||||
CronAction,
|
||||
CronExecutionResult,
|
||||
CronJob,
|
||||
CronJobState,
|
||||
CronPayload,
|
||||
CronSchedule,
|
||||
CronStore,
|
||||
)
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
"""返回当前 Unix 时间戳(毫秒,基于系统墙钟时间)。"""
|
||||
# 这里使用 wall-clock(time.time),因为 cron 语义本身就是“现实时间点”。
|
||||
# 若改用 monotonic,则无法直接表达“今天 9:00”这种绝对时刻。
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
||||
"""计算下一次运行时间(毫秒时间戳)。
|
||||
|
||||
返回 None 表示该任务当前不可运行(如参数非法、时间已过或 cron 解析失败)。
|
||||
"""
|
||||
if schedule.kind == "at":
|
||||
# 一次性定时:仅当目标时间晚于“现在”才有效。
|
||||
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
|
||||
|
||||
if schedule.kind == "every":
|
||||
if not schedule.every_ms or schedule.every_ms <= 0:
|
||||
return None
|
||||
# 固定间隔任务:以“当前时刻 + 间隔”作为下一次触发点。
|
||||
# 注意这里不做“对齐”计算(例如每分钟整点),仅做相对延迟:
|
||||
# - 优点:实现简单、行为稳定;
|
||||
# - 代价:若执行耗时较长,长期看会有“相位漂移”(不保证卡在固定秒位)。
|
||||
return now_ms + schedule.every_ms
|
||||
|
||||
if schedule.kind == "cron" and schedule.expr:
|
||||
try:
|
||||
from croniter import croniter
|
||||
from zoneinfo import ZoneInfo
|
||||
# 使用调用方传入的 now_ms 作为基准,保证在同一输入下行为可预测。
|
||||
base_time = now_ms / 1000
|
||||
# 未指定 tz 时,退回到当前系统本地时区。
|
||||
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
|
||||
base_dt = datetime.fromtimestamp(base_time, tz=tz)
|
||||
cron = croniter(schedule.expr, base_dt)
|
||||
next_dt = cron.get_next(datetime)
|
||||
return int(next_dt.timestamp() * 1000)
|
||||
except Exception:
|
||||
# 调度表达式或时区非法时,返回 None 让上层把任务视为不可调度。
|
||||
# 这里吞掉异常是有意设计:单个坏任务不应拖垮整个调度器。
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_schedule_for_add(schedule: CronSchedule) -> None:
|
||||
"""在创建任务前做必要校验,避免写入明显不可执行的调度。"""
|
||||
# 只有 cron 表达式支持时区字段,at/every 传 tz 视为配置错误。
|
||||
if schedule.tz and schedule.kind != "cron":
|
||||
raise ValueError("tz can only be used with cron schedules")
|
||||
|
||||
if schedule.kind == "cron" and schedule.tz:
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
ZoneInfo(schedule.tz)
|
||||
except Exception:
|
||||
raise ValueError(f"unknown timezone '{schedule.tz}'") from None
|
||||
|
||||
|
||||
_DAILY_LIMIT_PATTERNS = [
|
||||
re.compile(r"今日.*已达.*上限"),
|
||||
re.compile(r"已达\d+支上限"),
|
||||
re.compile(r"停止介绍"),
|
||||
re.compile(r"daily (?:cap|limit).*(?:reached|hit)", re.IGNORECASE),
|
||||
re.compile(r"today.*(?:reached|hit).*(?:cap|limit)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _looks_like_daily_limit_reached(response: str | None) -> bool:
|
||||
if not response:
|
||||
return False
|
||||
probe = response.strip()
|
||||
if not probe:
|
||||
return False
|
||||
return any(pattern.search(probe) for pattern in _DAILY_LIMIT_PATTERNS)
|
||||
|
||||
|
||||
def _next_daily_cycle_start_ms(job: CronJob, now_ms: int) -> int:
|
||||
"""Pick the next local-day anchor time for finite daily batch jobs."""
|
||||
tz = datetime.now().astimezone().tzinfo
|
||||
now_dt = datetime.fromtimestamp(now_ms / 1000, tz=tz)
|
||||
anchor_source_ms = job.created_at_ms or now_ms
|
||||
anchor_dt = datetime.fromtimestamp(anchor_source_ms / 1000, tz=tz)
|
||||
candidate = now_dt.replace(
|
||||
hour=anchor_dt.hour,
|
||||
minute=anchor_dt.minute,
|
||||
second=anchor_dt.second,
|
||||
microsecond=anchor_dt.microsecond,
|
||||
) + timedelta(days=1)
|
||||
return int(candidate.timestamp() * 1000)
|
||||
|
||||
|
||||
def _schedule_from_action(action: CronAction) -> CronSchedule:
|
||||
if action.every_seconds is not None:
|
||||
return CronSchedule(kind="every", every_ms=action.every_seconds * 1000)
|
||||
if action.cron_expr:
|
||||
return CronSchedule(kind="cron", expr=action.cron_expr, tz=action.tz)
|
||||
if action.at:
|
||||
dt = datetime.fromisoformat(action.at)
|
||||
return CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
|
||||
raise ValueError("reschedule action requires exactly one schedule field")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionOutcome:
|
||||
removed: bool = False
|
||||
explicit_next_run: bool = False
|
||||
managed_next_run_at_ms: int | None = None
|
||||
|
||||
|
||||
_CronCallbackResult = str | CronExecutionResult | None
|
||||
|
||||
|
||||
class CronService:
|
||||
"""管理并执行定时任务的服务对象。
|
||||
|
||||
运行模型(事件循环内):
|
||||
1. `start()` 时加载 store、重算 next_run、挂载单计时器;
|
||||
2. 计时器唤醒后 `_on_timer()` 找到到期任务并顺序执行;
|
||||
3. 每次状态变化后都 `_save_store()` + `_arm_timer()`,保持数据与调度一致。
|
||||
|
||||
并发假设:
|
||||
- 默认在同一个 asyncio 事件循环线程内被调用;
|
||||
- 代码未显式加锁,不保证跨线程并发安全;
|
||||
- 若要跨线程/多进程共享,应加文件锁或迁移到数据库事务模型。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store_path: Path,
|
||||
on_job: Callable[[CronJob], Coroutine[Any, Any, _CronCallbackResult]] | None = None,
|
||||
):
|
||||
# 任务持久化文件(默认:~/.nanobot/data/cron/jobs.json)。
|
||||
self.store_path = store_path
|
||||
# 任务执行回调:由 gateway 注入,用于真正触发 agent 处理。
|
||||
# CLI 仅做任务管理时可以不传(保持 None)。
|
||||
self.on_job = on_job
|
||||
# `_store` 采用懒加载;首次访问时才读盘。
|
||||
self._store: CronStore | None = None
|
||||
# 全局只维护一个“最近唤醒点”的计时任务,减少无效 wake-up。
|
||||
self._timer_task: asyncio.Task | None = None
|
||||
# 服务开关:只要 stop() 把它置 False,计时器回调会自然短路退出。
|
||||
self._running = False
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""从磁盘加载任务到内存(懒加载 + 内存缓存)。"""
|
||||
if self._store:
|
||||
# 已加载过直接返回内存对象,避免频繁磁盘 IO。
|
||||
return self._store
|
||||
|
||||
if self.store_path.exists():
|
||||
try:
|
||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||
jobs = []
|
||||
for j in data.get("jobs", []):
|
||||
# 反序列化时字段采用“宽松读取”:
|
||||
# - 新老版本缺失字段尽量给默认值;
|
||||
# - 以最大兼容性优先,减少升级时配置爆炸。
|
||||
jobs.append(CronJob(
|
||||
id=j["id"],
|
||||
name=j["name"],
|
||||
enabled=j.get("enabled", True),
|
||||
schedule=CronSchedule(
|
||||
kind=j["schedule"]["kind"],
|
||||
at_ms=j["schedule"].get("atMs"),
|
||||
every_ms=j["schedule"].get("everyMs"),
|
||||
expr=j["schedule"].get("expr"),
|
||||
tz=j["schedule"].get("tz"),
|
||||
),
|
||||
payload=CronPayload(
|
||||
kind=j["payload"].get("kind", "agent_turn"),
|
||||
message=j["payload"].get("message", ""),
|
||||
session_key=j["payload"].get("sessionKey"),
|
||||
deliver=j["payload"].get("deliver", False),
|
||||
channel=j["payload"].get("channel"),
|
||||
to=j["payload"].get("to"),
|
||||
),
|
||||
state=CronJobState(
|
||||
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
||||
last_run_at_ms=j.get("state", {}).get("lastRunAtMs"),
|
||||
last_status=j.get("state", {}).get("lastStatus"),
|
||||
last_error=j.get("state", {}).get("lastError"),
|
||||
),
|
||||
created_at_ms=j.get("createdAtMs", 0),
|
||||
updated_at_ms=j.get("updatedAtMs", 0),
|
||||
delete_after_run=j.get("deleteAfterRun", False),
|
||||
))
|
||||
self._store = CronStore(jobs=jobs)
|
||||
except Exception as e:
|
||||
# 文件损坏或结构异常时,不让服务崩溃,回退为空 store。
|
||||
logger.warning("Failed to load cron store: {}", e)
|
||||
self._store = CronStore()
|
||||
else:
|
||||
# 首次运行尚无文件时,初始化为空 store。
|
||||
self._store = CronStore()
|
||||
|
||||
return self._store
|
||||
|
||||
def _save_store(self) -> None:
|
||||
"""把内存中的任务快照写回磁盘。"""
|
||||
if not self._store:
|
||||
return
|
||||
|
||||
# 首次保存时自动创建上级目录。
|
||||
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = {
|
||||
"version": self._store.version,
|
||||
"jobs": [
|
||||
{
|
||||
"id": j.id,
|
||||
"name": j.name,
|
||||
"enabled": j.enabled,
|
||||
"schedule": {
|
||||
"kind": j.schedule.kind,
|
||||
"atMs": j.schedule.at_ms,
|
||||
"everyMs": j.schedule.every_ms,
|
||||
"expr": j.schedule.expr,
|
||||
"tz": j.schedule.tz,
|
||||
},
|
||||
"payload": {
|
||||
"kind": j.payload.kind,
|
||||
"message": j.payload.message,
|
||||
"sessionKey": j.payload.session_key,
|
||||
"deliver": j.payload.deliver,
|
||||
"channel": j.payload.channel,
|
||||
"to": j.payload.to,
|
||||
},
|
||||
"state": {
|
||||
"nextRunAtMs": j.state.next_run_at_ms,
|
||||
"lastRunAtMs": j.state.last_run_at_ms,
|
||||
"lastStatus": j.state.last_status,
|
||||
"lastError": j.state.last_error,
|
||||
},
|
||||
"createdAtMs": j.created_at_ms,
|
||||
"updatedAtMs": j.updated_at_ms,
|
||||
"deleteAfterRun": j.delete_after_run,
|
||||
}
|
||||
for j in self._store.jobs
|
||||
]
|
||||
}
|
||||
|
||||
# 这里是“整文件覆盖写”模型,不是事务性写入。
|
||||
# 若未来需要更强一致性,可升级为“临时文件 + 原子 rename”。
|
||||
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动服务并挂载下一次唤醒计时器。"""
|
||||
# 幂等启动语义:重复 start 不抛错,但会重算并重新挂载 timer。
|
||||
self._running = True
|
||||
self._load_store()
|
||||
# 每次启动都重算 next_run,避免沿用过期的历史状态。
|
||||
self._recompute_next_runs()
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止服务并取消当前计时器。"""
|
||||
self._running = False
|
||||
if self._timer_task:
|
||||
# 取消后不等待完成:让调用方快速返回,避免阻塞关停流程。
|
||||
self._timer_task.cancel()
|
||||
self._timer_task = None
|
||||
|
||||
def _recompute_next_runs(self) -> None:
|
||||
"""批量重算启用任务的下一次触发时间。"""
|
||||
if not self._store:
|
||||
return
|
||||
now = _now_ms()
|
||||
for job in self._store.jobs:
|
||||
if job.enabled:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
|
||||
|
||||
def _get_next_wake_ms(self) -> int | None:
|
||||
"""返回所有启用任务中最早的触发时间。"""
|
||||
if not self._store:
|
||||
return None
|
||||
times = [j.state.next_run_at_ms for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms]
|
||||
# 没有任何可触发任务则返回 None,上层据此不挂 timer。
|
||||
return min(times) if times else None
|
||||
|
||||
def _arm_timer(self) -> None:
|
||||
"""按“最近触发点”重置单计时器。"""
|
||||
# 每次状态变化后都重置 timer,保证只等待当前最近的一次触发。
|
||||
if self._timer_task:
|
||||
self._timer_task.cancel()
|
||||
|
||||
next_wake = self._get_next_wake_ms()
|
||||
if not next_wake or not self._running:
|
||||
return
|
||||
|
||||
delay_ms = max(0, next_wake - _now_ms())
|
||||
delay_s = delay_ms / 1000
|
||||
|
||||
async def tick():
|
||||
# sleep 期间若 timer 被 cancel,会抛 CancelledError 并自然结束任务。
|
||||
await asyncio.sleep(delay_s)
|
||||
if self._running:
|
||||
await self._on_timer()
|
||||
|
||||
self._timer_task = asyncio.create_task(tick())
|
||||
|
||||
async def _on_timer(self) -> None:
|
||||
"""计时器触发后执行所有到期任务,并继续调度下一轮。"""
|
||||
if not self._store:
|
||||
return
|
||||
|
||||
now = _now_ms()
|
||||
due_jobs = [
|
||||
j for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||
]
|
||||
|
||||
# 顺序执行,便于日志可读性与状态一致性;若后续有并发需求可在此扩展。
|
||||
# 这里“顺序而非并发”的取舍:
|
||||
# - 优点:状态更新顺序可预测,诊断简单;
|
||||
# - 代价:单个慢任务会延后后续任务执行。
|
||||
for job in due_jobs:
|
||||
await self._execute_job(job)
|
||||
|
||||
# 无论是否有 due job,都保存一次状态并重挂 timer,
|
||||
# 保证 next_run 与磁盘快照一致。
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_execution_result(
|
||||
callback_result: _CronCallbackResult,
|
||||
) -> CronExecutionResult:
|
||||
"""Normalize legacy string callbacks into the structured execution result."""
|
||||
if isinstance(callback_result, CronExecutionResult):
|
||||
return callback_result
|
||||
return CronExecutionResult(response=callback_result)
|
||||
|
||||
def _apply_structured_action(self, job: CronJob, action: CronAction) -> _ActionOutcome:
|
||||
"""Apply one structured cron control decision to the current job."""
|
||||
normalized = (action.action or "none").strip().lower()
|
||||
reason = action.reason or "no reason provided"
|
||||
if normalized == "none":
|
||||
return _ActionOutcome()
|
||||
if normalized == "remove":
|
||||
self._store.jobs = [item for item in self._store.jobs if item.id != job.id]
|
||||
logger.info("Cron: removed job '{}' via structured action ({})", job.name, reason)
|
||||
return _ActionOutcome(removed=True)
|
||||
if normalized == "disable":
|
||||
job.enabled = False
|
||||
job.state.next_run_at_ms = None
|
||||
logger.info("Cron: disabled job '{}' via structured action ({})", job.name, reason)
|
||||
return _ActionOutcome(explicit_next_run=True)
|
||||
if normalized == "complete_today":
|
||||
managed_next_run_at_ms = _next_daily_cycle_start_ms(job, _now_ms())
|
||||
logger.info(
|
||||
"Cron: job '{}' completed today's batch via structured action ({}), next cycle at {}",
|
||||
job.name,
|
||||
reason,
|
||||
managed_next_run_at_ms,
|
||||
)
|
||||
return _ActionOutcome(managed_next_run_at_ms=managed_next_run_at_ms)
|
||||
if normalized == "reschedule":
|
||||
schedule = _schedule_from_action(action)
|
||||
_validate_schedule_for_add(schedule)
|
||||
job.schedule = schedule
|
||||
job.enabled = True
|
||||
job.delete_after_run = schedule.kind == "at"
|
||||
job.state.next_run_at_ms = _compute_next_run(schedule, _now_ms())
|
||||
logger.info("Cron: rescheduled job '{}' via structured action ({})", job.name, reason)
|
||||
return _ActionOutcome(explicit_next_run=True)
|
||||
logger.warning("Cron: unknown structured action '{}' for job '{}'", normalized, job.name)
|
||||
return _ActionOutcome()
|
||||
|
||||
async def _execute_job(self, job: CronJob) -> None:
|
||||
"""执行单个任务并更新其运行状态。"""
|
||||
start_ms = _now_ms()
|
||||
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
|
||||
managed_next_run_at_ms: int | None = None
|
||||
removed_by_action = False
|
||||
explicit_next_run = False
|
||||
|
||||
try:
|
||||
result = CronExecutionResult()
|
||||
if self.on_job:
|
||||
# on_job 是业务注入点(如 gateway 中调用 agent.process_direct)。
|
||||
result = self._coerce_execution_result(await self.on_job(job))
|
||||
if result.action is not None:
|
||||
action_outcome = self._apply_structured_action(job, result.action)
|
||||
removed_by_action = action_outcome.removed
|
||||
explicit_next_run = action_outcome.explicit_next_run
|
||||
managed_next_run_at_ms = action_outcome.managed_next_run_at_ms
|
||||
elif job.schedule.kind == "every" and _looks_like_daily_limit_reached(result.response):
|
||||
managed_next_run_at_ms = _next_daily_cycle_start_ms(job, _now_ms())
|
||||
logger.info(
|
||||
"Cron: job '{}' reached daily terminal state, snoozed until {}",
|
||||
job.name,
|
||||
managed_next_run_at_ms,
|
||||
)
|
||||
# 无论回调是否返回内容,只要没有抛异常都视为成功。
|
||||
job.state.last_status = "ok"
|
||||
job.state.last_error = None
|
||||
logger.info("Cron: job '{}' completed", job.name)
|
||||
|
||||
except Exception as e:
|
||||
# 执行失败仅影响当前任务,不中断调度器整体运行。
|
||||
job.state.last_status = "error"
|
||||
job.state.last_error = str(e)
|
||||
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
||||
|
||||
job.state.last_run_at_ms = start_ms
|
||||
job.updated_at_ms = _now_ms()
|
||||
if removed_by_action:
|
||||
return
|
||||
if explicit_next_run:
|
||||
return
|
||||
if managed_next_run_at_ms is not None:
|
||||
# 终态任务:跳过本日剩余频繁触发,等到下一日周期起点再恢复。
|
||||
job.state.next_run_at_ms = managed_next_run_at_ms
|
||||
return
|
||||
|
||||
# 一次性任务:执行后按配置删除或停用,避免重复触发。
|
||||
if job.schedule.kind == "at":
|
||||
if job.delete_after_run:
|
||||
# 一次性且要求删除:直接从 store 移除,后续 list 不再显示。
|
||||
self._store.jobs = [j for j in self._store.jobs if j.id != job.id]
|
||||
else:
|
||||
# 一次性但不删除:仅禁用,便于事后审计/手动重启。
|
||||
job.enabled = False
|
||||
job.state.next_run_at_ms = None
|
||||
else:
|
||||
# 周期任务:立即计算下一次触发时间,供下轮 timer 使用。
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
|
||||
# ========== Public API ==========
|
||||
|
||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||
"""列出任务,默认仅返回已启用任务。"""
|
||||
store = self._load_store()
|
||||
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
|
||||
# 以 next_run 升序返回,便于直接展示“谁最先执行”。
|
||||
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float("inf"))
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
name: str,
|
||||
schedule: CronSchedule,
|
||||
message: str,
|
||||
payload_kind: Literal["system_event", "agent_turn"] = "agent_turn",
|
||||
session_key: str | None = None,
|
||||
deliver: bool = False,
|
||||
channel: str | None = None,
|
||||
to: str | None = None,
|
||||
delete_after_run: bool = False,
|
||||
) -> CronJob:
|
||||
"""创建并持久化新任务。"""
|
||||
store = self._load_store()
|
||||
# 添加前做参数合法性校验,尽早失败并给上层明确异常。
|
||||
_validate_schedule_for_add(schedule)
|
||||
now = _now_ms()
|
||||
|
||||
job = CronJob(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
name=name,
|
||||
enabled=True,
|
||||
schedule=schedule,
|
||||
payload=CronPayload(
|
||||
kind=payload_kind,
|
||||
message=message,
|
||||
session_key=session_key,
|
||||
deliver=deliver,
|
||||
channel=channel,
|
||||
to=to,
|
||||
),
|
||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||
created_at_ms=now,
|
||||
updated_at_ms=now,
|
||||
delete_after_run=delete_after_run,
|
||||
)
|
||||
|
||||
store.jobs.append(job)
|
||||
# 每次变更都立即落盘并重排 timer,避免“内存态/调度态”漂移。
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
|
||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||
return job
|
||||
|
||||
def remove_job(self, job_id: str) -> bool:
|
||||
"""按 ID 删除任务;存在并删除成功时返回 True。"""
|
||||
store = self._load_store()
|
||||
before = len(store.jobs)
|
||||
store.jobs = [j for j in store.jobs if j.id != job_id]
|
||||
removed = len(store.jobs) < before
|
||||
|
||||
if removed:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
logger.info("Cron: removed job {}", job_id)
|
||||
|
||||
# 返回布尔值给上层决定提示文案(found/not found)。
|
||||
return removed
|
||||
|
||||
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
|
||||
"""启用或停用任务,并同步更新 next_run。"""
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
job.enabled = enabled
|
||||
job.updated_at_ms = _now_ms()
|
||||
if enabled:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
else:
|
||||
job.state.next_run_at_ms = None
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
return job
|
||||
# 没找到任务时返回 None,调用方据此输出“not found”。
|
||||
return None
|
||||
|
||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||
"""手动触发任务执行。
|
||||
|
||||
默认遵守启用状态;`force=True` 时即使任务被禁用也会执行一次。
|
||||
"""
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
if not force and not job.enabled:
|
||||
# 遵守启用状态:禁用任务默认不执行。
|
||||
return False
|
||||
await self._execute_job(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
return True
|
||||
return False
|
||||
|
||||
def status(self) -> dict:
|
||||
"""返回服务运行状态摘要。"""
|
||||
store = self._load_store()
|
||||
# 这个接口主要用于 status 面板,不暴露详细任务内容。
|
||||
return {
|
||||
"enabled": self._running,
|
||||
"jobs": len(store.jobs),
|
||||
"next_wake_at_ms": self._get_next_wake_ms(),
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
"""cron 模型对象定义。
|
||||
|
||||
这些 dataclass 主要承担两类职责:
|
||||
1. 作为内存中的稳定结构,供 CronService / Web API / Agent 工具共用;
|
||||
2. 作为持久化 JSON 的逻辑模型,尽量保持字段语义直观、兼容性友好。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronSchedule:
|
||||
"""Schedule definition for a cron job."""
|
||||
# `kind` 决定其余字段哪一个生效。
|
||||
kind: Literal["at", "every", "cron"]
|
||||
# `at`:绝对触发时间,毫秒时间戳。
|
||||
at_ms: int | None = None
|
||||
# `every`:固定间隔,毫秒。
|
||||
every_ms: int | None = None
|
||||
# `cron`:标准 5 段 cron 表达式,例如 `0 9 * * *`。
|
||||
expr: str | None = None
|
||||
# cron 表达式使用的时区;其余 kind 不应设置。
|
||||
tz: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronPayload:
|
||||
"""What to do when the job runs."""
|
||||
# system_event: 直接向目标会话投递消息(典型:提醒)
|
||||
# agent_turn: 把 message 当作 prompt 再交给 agent 执行
|
||||
kind: Literal["system_event", "agent_turn"] = "agent_turn"
|
||||
message: str = ""
|
||||
# 任务型 cron 若希望复用原会话短期记忆,可在这里保存 session_key。
|
||||
session_key: str | None = None
|
||||
# 是否把执行结果发回渠道层。
|
||||
deliver: bool = False
|
||||
channel: str | None = None # e.g. "whatsapp"
|
||||
to: str | None = None # e.g. phone number
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronAction:
|
||||
"""Structured cron control decision emitted by the LLM."""
|
||||
# `action` 是唯一必填字段,其余字段只在特定动作下有意义。
|
||||
action: Literal["none", "remove", "disable", "complete_today", "reschedule"] = "none"
|
||||
reason: str | None = None
|
||||
every_seconds: int | None = None
|
||||
cron_expr: str | None = None
|
||||
tz: str | None = None
|
||||
at: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronExecutionResult:
|
||||
"""Structured result of one cron execution."""
|
||||
# 模型最终输出文本。
|
||||
response: str | None = None
|
||||
# 可选结构化调度动作,例如 complete_today / remove / reschedule。
|
||||
action: CronAction | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronJobState:
|
||||
"""Runtime state of a job."""
|
||||
# 调度器计算出的下次执行时间。
|
||||
next_run_at_ms: int | None = None
|
||||
# 最近一次实际执行时间。
|
||||
last_run_at_ms: int | None = None
|
||||
# 最近一次执行结果状态。
|
||||
last_status: Literal["ok", "error", "skipped"] | None = None
|
||||
# 最近一次错误详情,便于 UI 排查。
|
||||
last_error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronJob:
|
||||
"""A scheduled job."""
|
||||
# 稳定主键。
|
||||
id: str
|
||||
# 展示名,主要用于 UI 和日志。
|
||||
name: str
|
||||
enabled: bool = True
|
||||
schedule: CronSchedule = field(default_factory=lambda: CronSchedule(kind="every"))
|
||||
payload: CronPayload = field(default_factory=CronPayload)
|
||||
state: CronJobState = field(default_factory=CronJobState)
|
||||
# 创建 / 更新时间都使用毫秒时间戳,便于直接序列化。
|
||||
created_at_ms: int = 0
|
||||
updated_at_ms: int = 0
|
||||
# 一次性任务执行后是否自动删除。
|
||||
delete_after_run: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronStore:
|
||||
"""Persistent store for cron jobs."""
|
||||
version: int = 1
|
||||
jobs: list[CronJob] = field(default_factory=list)
|
||||
@ -1,5 +0,0 @@
|
||||
"""Heartbeat service for periodic agent wake-ups."""
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
|
||||
__all__ = ["HeartbeatService"]
|
||||
@ -1,137 +0,0 @@
|
||||
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# Default interval: 30 minutes
|
||||
DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60
|
||||
|
||||
# Token the agent replies with when there is nothing to report
|
||||
HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK"
|
||||
|
||||
# The prompt sent to agent during heartbeat
|
||||
HEARTBEAT_PROMPT = (
|
||||
"Read HEARTBEAT.md in your workspace and follow any instructions listed there. "
|
||||
f"If nothing needs attention, reply with exactly: {HEARTBEAT_OK_TOKEN}"
|
||||
)
|
||||
|
||||
|
||||
def _is_heartbeat_empty(content: str | None) -> bool:
|
||||
"""Check if HEARTBEAT.md has no actionable content."""
|
||||
if not content:
|
||||
return True
|
||||
|
||||
# Lines to skip: empty, headers, HTML comments, empty checkboxes
|
||||
skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"}
|
||||
|
||||
for line in content.split("\n"):
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or line.startswith("<!--") or line in skip_patterns:
|
||||
continue
|
||||
return False # Found actionable content
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class HeartbeatService:
|
||||
"""
|
||||
Periodic heartbeat service that wakes the agent to check for tasks.
|
||||
|
||||
The agent reads HEARTBEAT.md from the workspace and executes any tasks
|
||||
listed there. If it has something to report, the response is forwarded
|
||||
to the user via on_notify. If nothing needs attention, the agent replies
|
||||
HEARTBEAT_OK and the response is silently dropped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
on_heartbeat: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||
interval_s: int = DEFAULT_HEARTBEAT_INTERVAL_S,
|
||||
enabled: bool = True,
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.on_heartbeat = on_heartbeat
|
||||
self.on_notify = on_notify
|
||||
self.interval_s = interval_s
|
||||
self.enabled = enabled
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
@property
|
||||
def heartbeat_file(self) -> Path:
|
||||
return self.workspace / "HEARTBEAT.md"
|
||||
|
||||
def _read_heartbeat_file(self) -> str | None:
|
||||
"""Read HEARTBEAT.md content."""
|
||||
if self.heartbeat_file.exists():
|
||||
try:
|
||||
return self.heartbeat_file.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the heartbeat service."""
|
||||
if not self.enabled:
|
||||
logger.info("Heartbeat disabled")
|
||||
return
|
||||
if self._running:
|
||||
logger.warning("Heartbeat already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
logger.info("Heartbeat started (every {}s)", self.interval_s)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the heartbeat service."""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
"""Main heartbeat loop."""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.interval_s)
|
||||
if self._running:
|
||||
await self._tick()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Heartbeat error: {}", e)
|
||||
|
||||
async def _tick(self) -> None:
|
||||
"""Execute a single heartbeat tick."""
|
||||
content = self._read_heartbeat_file()
|
||||
|
||||
# Skip if HEARTBEAT.md is empty or doesn't exist
|
||||
if _is_heartbeat_empty(content):
|
||||
logger.debug("Heartbeat: no tasks (HEARTBEAT.md empty)")
|
||||
return
|
||||
|
||||
logger.info("Heartbeat: checking for tasks...")
|
||||
|
||||
if self.on_heartbeat:
|
||||
try:
|
||||
response = await self.on_heartbeat(HEARTBEAT_PROMPT)
|
||||
if HEARTBEAT_OK_TOKEN in response.upper():
|
||||
logger.info("Heartbeat: OK (nothing to report)")
|
||||
else:
|
||||
logger.info("Heartbeat: completed, delivering response")
|
||||
if self.on_notify:
|
||||
await self.on_notify(response)
|
||||
except Exception:
|
||||
logger.exception("Heartbeat execution failed")
|
||||
|
||||
async def trigger_now(self) -> str | None:
|
||||
"""Manually trigger a heartbeat."""
|
||||
if self.on_heartbeat:
|
||||
return await self.on_heartbeat(HEARTBEAT_PROMPT)
|
||||
return None
|
||||
@ -1,186 +0,0 @@
|
||||
"""Structured LLM audit logging persisted in backend storage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import get_logs_path
|
||||
|
||||
_MAX_TEXT_PREVIEW = 1000
|
||||
_MAX_TRACEBACK_PREVIEW = 8000
|
||||
_REDACTED = "***REDACTED***"
|
||||
_SENSITIVE_KEYS = {
|
||||
"api_key",
|
||||
"authorization",
|
||||
"proxy_authorization",
|
||||
"x_api_key",
|
||||
"x-api-key",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"secret",
|
||||
"password",
|
||||
}
|
||||
|
||||
|
||||
def get_llm_audit_log_path() -> Path:
|
||||
"""Return the persisted LLM audit log path."""
|
||||
return get_logs_path() / "llm_audit.jsonl"
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _truncate_text(text: str, limit: int = _MAX_TEXT_PREVIEW) -> str:
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[:limit] + "...(truncated)"
|
||||
|
||||
|
||||
def _redact_value(key: str, value: Any) -> Any:
|
||||
if key.lower() in _SENSITIVE_KEYS and value is not None:
|
||||
return _REDACTED
|
||||
return value
|
||||
|
||||
|
||||
def redact_mapping(mapping: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Redact common secret-like keys in a mapping."""
|
||||
if not mapping:
|
||||
return {}
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, value in mapping.items():
|
||||
if isinstance(value, dict):
|
||||
sanitized[key] = redact_mapping(value)
|
||||
continue
|
||||
if isinstance(value, list):
|
||||
sanitized[key] = [
|
||||
redact_mapping(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
continue
|
||||
sanitized[key] = _redact_value(str(key), value)
|
||||
return sanitized
|
||||
|
||||
|
||||
def summarize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Build a compact audit-safe summary of prompt messages."""
|
||||
summary: list[dict[str, Any]] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
item: dict[str, Any] = {
|
||||
"index": idx,
|
||||
"role": msg.get("role"),
|
||||
}
|
||||
if "name" in msg:
|
||||
item["name"] = msg.get("name")
|
||||
if "tool_call_id" in msg:
|
||||
item["tool_call_id"] = msg.get("tool_call_id")
|
||||
|
||||
content = msg.get("content")
|
||||
if content is None:
|
||||
item["content_kind"] = "none"
|
||||
elif isinstance(content, str):
|
||||
item["content_kind"] = "text"
|
||||
item["content_length"] = len(content)
|
||||
item["content_preview"] = _truncate_text(content)
|
||||
elif isinstance(content, list):
|
||||
item["content_kind"] = "blocks"
|
||||
item["content_blocks"] = len(content)
|
||||
item["content_preview"] = _truncate_text(json.dumps(content, ensure_ascii=False))
|
||||
else:
|
||||
rendered = str(content)
|
||||
item["content_kind"] = type(content).__name__
|
||||
item["content_length"] = len(rendered)
|
||||
item["content_preview"] = _truncate_text(rendered)
|
||||
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
item["tool_calls"] = summarize_tool_calls(tool_calls)
|
||||
|
||||
summary.append(item)
|
||||
return summary
|
||||
|
||||
|
||||
def summarize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Summarize outgoing or incoming tool calls."""
|
||||
summary: list[dict[str, Any]] = []
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
if hasattr(tool_call, "function"):
|
||||
function = getattr(tool_call, "function")
|
||||
arguments = getattr(function, "arguments", None)
|
||||
summary.append({
|
||||
"index": idx,
|
||||
"id": getattr(tool_call, "id", None),
|
||||
"name": getattr(function, "name", None),
|
||||
"arguments_preview": _truncate_text(str(arguments) if arguments is not None else ""),
|
||||
})
|
||||
continue
|
||||
|
||||
if isinstance(tool_call, dict):
|
||||
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
|
||||
summary.append({
|
||||
"index": idx,
|
||||
"id": tool_call.get("id"),
|
||||
"name": fn.get("name"),
|
||||
"arguments_preview": _truncate_text(str(fn.get("arguments", ""))),
|
||||
})
|
||||
continue
|
||||
|
||||
summary.append({
|
||||
"index": idx,
|
||||
"repr": _truncate_text(str(tool_call)),
|
||||
})
|
||||
return summary
|
||||
|
||||
|
||||
def summarize_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
|
||||
"""Summarize tool definitions sent to the provider."""
|
||||
if not tools:
|
||||
return []
|
||||
summary: list[dict[str, Any]] = []
|
||||
for idx, tool in enumerate(tools):
|
||||
function = tool.get("function") if isinstance(tool, dict) else None
|
||||
entry = {
|
||||
"index": idx,
|
||||
"type": tool.get("type") if isinstance(tool, dict) else None,
|
||||
}
|
||||
if isinstance(function, dict):
|
||||
entry["name"] = function.get("name")
|
||||
params = function.get("parameters")
|
||||
if params is not None:
|
||||
entry["parameters_preview"] = _truncate_text(json.dumps(params, ensure_ascii=False))
|
||||
else:
|
||||
entry["preview"] = _truncate_text(str(tool))
|
||||
summary.append(entry)
|
||||
return summary
|
||||
|
||||
|
||||
def write_llm_audit_event(event: dict[str, Any]) -> None:
|
||||
"""Append one JSONL audit event to backend storage."""
|
||||
payload = {
|
||||
"ts": _utc_now_iso(),
|
||||
**event,
|
||||
}
|
||||
path = get_llm_audit_log_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to persist LLM audit log: {}", exc)
|
||||
|
||||
|
||||
def summarize_exception(exc: BaseException) -> dict[str, str]:
|
||||
return {
|
||||
"type": type(exc).__name__,
|
||||
"message": str(exc),
|
||||
}
|
||||
|
||||
|
||||
def truncate_traceback(text: str) -> str:
|
||||
return _truncate_text(text, _MAX_TRACEBACK_PREVIEW)
|
||||
@ -1,7 +0,0 @@
|
||||
"""LLM provider abstraction module."""
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
|
||||
@ -1,120 +0,0 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.request_timeout_seconds = (
|
||||
max(1.0, float(request_timeout_seconds))
|
||||
if request_timeout_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Replace empty text content that causes provider 400 errors.
|
||||
|
||||
Empty content can appear when MCP tools return nothing. Most providers
|
||||
reject empty-string content or empty text blocks in list content.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(msg)
|
||||
if filtered:
|
||||
clean["content"] = filtered
|
||||
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
model: Model identifier (provider-specific).
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
@ -1,61 +0,0 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "no-key",
|
||||
api_base: str = "http://localhost:8000/v1",
|
||||
default_model: str = "default",
|
||||
request_timeout_seconds: float | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=self.request_timeout_seconds,
|
||||
)
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self._sanitize_empty_content(messages),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice="auto")
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
tool_calls = [
|
||||
ToolCallRequest(id=tc.id, name=tc.function.name,
|
||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments)
|
||||
for tc in (msg.tool_calls or [])
|
||||
]
|
||||
u = response.usage
|
||||
return LLMResponse(
|
||||
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -1,364 +0,0 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import json
|
||||
import json_repair
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.llm_audit import (
|
||||
redact_mapping,
|
||||
summarize_exception,
|
||||
summarize_messages,
|
||||
summarize_tool_calls,
|
||||
summarize_tools,
|
||||
truncate_traceback,
|
||||
write_llm_audit_event,
|
||||
)
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.registry import find_by_model, find_gateway
|
||||
|
||||
|
||||
# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers.
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""
|
||||
LLM provider using LiteLLM for multi-provider support.
|
||||
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
|
||||
a unified interface. Provider-specific logic is driven by the registry
|
||||
(see providers/registry.py) — no if-elif chains needed here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Detect gateway / local deployment.
|
||||
# provider_name (from config key) is the primary signal;
|
||||
# api_key / api_base are fallback for auto-detection.
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
|
||||
# Configure environment variables
|
||||
if api_key:
|
||||
self._setup_env(api_key, api_base, default_model)
|
||||
|
||||
if api_base:
|
||||
litellm.api_base = api_base
|
||||
|
||||
# Disable LiteLLM logging noise
|
||||
litellm.suppress_debug_info = True
|
||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
||||
litellm.drop_params = True
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||
"""Set environment variables based on detected provider."""
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if not spec:
|
||||
return
|
||||
if not spec.env_key:
|
||||
# OAuth/provider-only specs (for example: openai_codex)
|
||||
return
|
||||
|
||||
# Gateway/local overrides existing env; standard provider doesn't
|
||||
if self._gateway:
|
||||
os.environ[spec.env_key] = api_key
|
||||
else:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""Resolve model name by applying provider/gateway prefixes."""
|
||||
if self._gateway:
|
||||
# Gateway mode: apply gateway prefix, skip provider-specific prefixes
|
||||
prefix = self._gateway.litellm_prefix
|
||||
if self._gateway.strip_model_prefix:
|
||||
model = model.split("/")[-1]
|
||||
if prefix and not model.startswith(f"{prefix}/"):
|
||||
model = f"{prefix}/{model}"
|
||||
return model
|
||||
|
||||
# Standard mode: auto-prefix for known providers
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
|
||||
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
"""Normalize explicit provider prefixes like `github-copilot/...`."""
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
|
||||
def _supports_cache_control(self, model: str) -> bool:
|
||||
"""Return True when the provider supports cache_control on content blocks."""
|
||||
if self._gateway is not None:
|
||||
return self._gateway.supports_prompt_caching
|
||||
spec = find_by_model(model)
|
||||
return spec is not None and spec.supports_prompt_caching
|
||||
|
||||
def _apply_cache_control(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
"""Return copies of messages and tools with cache_control injected."""
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
content = msg["content"]
|
||||
if isinstance(content, str):
|
||||
new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
else:
|
||||
new_content = list(content)
|
||||
new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
|
||||
new_messages.append({**msg, "content": new_content})
|
||||
else:
|
||||
new_messages.append(msg)
|
||||
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}
|
||||
|
||||
return new_messages, new_tools
|
||||
|
||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||
"""Apply model-specific parameter overrides from the registry."""
|
||||
model_lower = model.lower()
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||
sanitized = []
|
||||
for msg in messages:
|
||||
clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
|
||||
# Strict providers require "content" even when assistant only has tool_calls
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
model = self._resolve_model(original_model)
|
||||
request_id = str(uuid.uuid4())
|
||||
sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
|
||||
|
||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": sanitized_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
# Pass api_key directly — more reliable than env vars alone
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if self.request_timeout_seconds is not None:
|
||||
kwargs["timeout"] = self.request_timeout_seconds
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
request_event = {
|
||||
"event": "llm_request",
|
||||
"request_id": request_id,
|
||||
"provider_impl": type(self).__name__,
|
||||
"gateway": self._gateway.name if self._gateway else None,
|
||||
"original_model": original_model,
|
||||
"resolved_model": model,
|
||||
"api_base": self.api_base,
|
||||
"has_api_key": bool(self.api_key),
|
||||
"temperature": kwargs.get("temperature"),
|
||||
"max_tokens": kwargs.get("max_tokens"),
|
||||
"timeout": kwargs.get("timeout"),
|
||||
"tool_choice": kwargs.get("tool_choice"),
|
||||
"message_count": len(sanitized_messages),
|
||||
"messages": summarize_messages(sanitized_messages),
|
||||
"tools": summarize_tools(tools),
|
||||
"extra_headers": redact_mapping(self.extra_headers),
|
||||
}
|
||||
write_llm_audit_event(request_event)
|
||||
logger.info(
|
||||
"LLM request [{}]: model={} messages={} tools={}",
|
||||
request_id,
|
||||
model,
|
||||
len(sanitized_messages),
|
||||
len(tools or []),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
parsed = self._parse_response(response)
|
||||
write_llm_audit_event({
|
||||
"event": "llm_response",
|
||||
"request_id": request_id,
|
||||
"provider_impl": type(self).__name__,
|
||||
"original_model": original_model,
|
||||
"resolved_model": model,
|
||||
"finish_reason": parsed.finish_reason,
|
||||
"usage": parsed.usage,
|
||||
"content_preview": parsed.content[:1000] if parsed.content else None,
|
||||
"reasoning_preview": parsed.reasoning_content[:1000] if parsed.reasoning_content else None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"arguments_preview": str(tc.arguments)[:1000],
|
||||
}
|
||||
for tc in parsed.tool_calls
|
||||
],
|
||||
"raw_tool_calls": summarize_tool_calls(
|
||||
getattr(response.choices[0].message, "tool_calls", None) or []
|
||||
),
|
||||
})
|
||||
logger.info(
|
||||
"LLM response [{}]: model={} finish_reason={} tool_calls={}",
|
||||
request_id,
|
||||
model,
|
||||
parsed.finish_reason,
|
||||
len(parsed.tool_calls),
|
||||
)
|
||||
return parsed
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
write_llm_audit_event({
|
||||
"event": "llm_error",
|
||||
"request_id": request_id,
|
||||
"provider_impl": type(self).__name__,
|
||||
"gateway": self._gateway.name if self._gateway else None,
|
||||
"original_model": original_model,
|
||||
"resolved_model": model,
|
||||
"api_base": self.api_base,
|
||||
"error": summarize_exception(e),
|
||||
"traceback": truncate_traceback(tb),
|
||||
"message_count": len(sanitized_messages),
|
||||
"messages": summarize_messages(sanitized_messages),
|
||||
"tools": summarize_tools(tools),
|
||||
})
|
||||
logger.exception("LLM error [{}]: model={} provider call failed", request_id, model)
|
||||
# Return error as content for graceful handling
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse LiteLLM response into our standard format."""
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
tool_calls = []
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tc in message.tool_calls:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
usage = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
@ -1,329 +0,0 @@
|
||||
"""OpenAI Codex Responses Provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
|
||||
|
||||
class OpenAICodexProvider(LLMProvider):
|
||||
"""Use Codex OAuth to call the Responses API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_model: str = "openai-codex/gpt-5.1-codex",
|
||||
request_timeout_seconds: float | None = None,
|
||||
):
|
||||
super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": _strip_model_prefix(model),
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"instructions": system_prompt,
|
||||
"input": input_items,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"prompt_cache_key": _prompt_cache_key(messages),
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
|
||||
url = DEFAULT_CODEX_URL
|
||||
|
||||
try:
|
||||
try:
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
verify=True,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
except Exception as e:
|
||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||
raise
|
||||
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
verify=False,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Codex: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
def _build_headers(account_id: str, token: str) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": DEFAULT_ORIGINATOR,
|
||||
"User-Agent": "nanobot (python)",
|
||||
"accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
async def _request_codex(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
body: dict[str, Any],
|
||||
verify: bool,
|
||||
timeout_seconds: float,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async with httpx.AsyncClient(timeout=timeout_seconds, verify=verify) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Handle text first.
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed",
|
||||
"id": f"msg_{idx}",
|
||||
}
|
||||
)
|
||||
# Then handle tool calls.
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
call_id = call_id or f"call_{idx}"
|
||||
item_id = item_id or f"fc_{idx}"
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"id": item_id,
|
||||
"call_id": call_id,
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_text,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
content += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = _map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Codex response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
return f"HTTP {status_code}: {raw}"
|
||||
@ -1,462 +0,0 @@
|
||||
"""
|
||||
Provider Registry — single source of truth for LLM provider metadata.
|
||||
|
||||
Adding a new provider:
|
||||
1. Add a ProviderSpec to PROVIDERS below.
|
||||
2. Add a field to ProvidersConfig in config/schema.py.
|
||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
||||
|
||||
Order matters — it controls match priority and fallback. Gateways first.
|
||||
Every entry writes out all fields so you can copy-paste as a template.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSpec:
|
||||
"""One LLM provider's metadata. See PROVIDERS below for real examples.
|
||||
|
||||
Placeholders in env_extras values:
|
||||
{api_key} — the user's API key
|
||||
{api_base} — api_base from config, or this spec's default_api_base
|
||||
"""
|
||||
|
||||
# identity
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# model prefixing
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
|
||||
# gateway / local detection
|
||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
|
||||
# gateway behavior
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
|
||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
|
||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||
|
||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||
is_direct: bool = False
|
||||
|
||||
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||
supports_prompt_caching: bool = False
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PROVIDERS — the registry. Order = priority. Copy any entry as template.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
|
||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="",
|
||||
display_name="Custom",
|
||||
litellm_prefix="",
|
||||
is_direct=True,
|
||||
),
|
||||
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
|
||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||
ProviderSpec(
|
||||
name="siliconflow",
|
||||
keywords=("siliconflow",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="SiliconFlow",
|
||||
litellm_prefix="openai",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="siliconflow",
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
||||
ProviderSpec(
|
||||
name="volcengine",
|
||||
keywords=("volcengine", "volces", "ark"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="volces",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# OpenAI Codex: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex", "codex"),
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="OpenAI Codex",
|
||||
litellm_prefix="", # Not routed through LiteLLM
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="codex",
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
|
||||
# Github Copilot: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||
skip_prefixes=("github_copilot/",),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(
|
||||
("ZHIPUAI_API_KEY", "{api_key}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(
|
||||
("MOONSHOT_API_BASE", "{api_base}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(
|
||||
("kimi-k2.5", {"temperature": 1.0}),
|
||||
),
|
||||
),
|
||||
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# Detected when config key is "vllm" (provider_name="vllm").
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="", # user must provide in config
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||
model_lower = model.lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local]
|
||||
|
||||
# Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex.
|
||||
for spec in std_specs:
|
||||
if model_prefix and normalized_prefix == spec.name:
|
||||
return spec
|
||||
|
||||
for spec in std_specs:
|
||||
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""Detect gateway/local provider.
|
||||
|
||||
Priority:
|
||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||
|
||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||
"""
|
||||
# 1. Direct match by config key
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||
for spec in PROVIDERS:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user