第一次提交
This commit is contained in:
6
app-instance/backend/nanobot/agent/tools/__init__.py
Normal file
6
app-instance/backend/nanobot/agent/tools/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Agent tools module."""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
__all__ = ["Tool", "ToolRegistry"]
|
||||
102
app-instance/backend/nanobot/agent/tools/base.py
Normal file
102
app-instance/backend/nanobot/agent/tools/base.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""Base class for agent tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Tool(ABC):
    """
    Abstract base class for agent tools.

    Tools are capabilities that the agent can use to interact with
    the environment, such as reading files, executing commands, etc.

    Subclasses supply ``name``, ``description``, ``parameters`` (a JSON
    Schema whose top-level type is ``object``) and an async ``execute``.
    The base class adds lightweight schema validation and export to the
    OpenAI function-calling schema format.
    """

    # JSON Schema type name -> Python type(s) checked with isinstance().
    _TYPE_MAP = {
        "string": str,
        "integer": int,
        "number": (int, float),
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    # Schema types for which a Python bool has to be rejected explicitly,
    # because bool is a subclass of int and would otherwise pass isinstance().
    _NUMERIC_TYPES = ("integer", "number")

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name used in function calls."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Description of what the tool does."""
        pass

    @property
    @abstractmethod
    def parameters(self) -> dict[str, Any]:
        """JSON Schema for tool parameters."""
        pass

    @abstractmethod
    async def execute(self, **kwargs: Any) -> str:
        """
        Execute the tool with given parameters.

        Args:
            **kwargs: Tool-specific parameters.

        Returns:
            String result of the tool execution.
        """
        pass

    def validate_params(self, params: dict[str, Any]) -> list[str]:
        """Validate tool parameters against JSON schema.

        Args:
            params: Arguments proposed for ``execute``.

        Returns:
            A list of human-readable error strings; empty when valid.

        Raises:
            ValueError: If the tool's schema declares a top-level type
                other than ``object``.
        """
        schema = self.parameters or {}
        if schema.get("type", "object") != "object":
            raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
        # Force the top-level type so a schema that omits "type" is still
        # validated as an object.
        return self._validate(params, {**schema, "type": "object"}, "")

    def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
        """Recursively validate ``val`` against ``schema``.

        Supports ``type``, ``enum``, ``minimum``/``maximum``,
        ``minLength``/``maxLength``, ``required``/``properties`` and
        ``items``. Properties not listed in ``properties`` are ignored.

        Args:
            val: The value to check.
            schema: The (sub)schema describing ``val``.
            path: Dotted/indexed location of ``val`` used in messages;
                empty for the top-level parameter object.

        Returns:
            A list of error strings; empty when ``val`` is valid.
        """
        t, label = schema.get("type"), path or "parameter"
        if t in self._TYPE_MAP:
            # JSON Schema treats booleans as distinct from numbers, but in
            # Python ``isinstance(True, int)`` is True, so reject bools here.
            bool_as_number = t in self._NUMERIC_TYPES and isinstance(val, bool)
            if bool_as_number or not isinstance(val, self._TYPE_MAP[t]):
                # A type mismatch makes the remaining checks meaningless.
                return [f"{label} should be {t}"]

        errors = []
        if "enum" in schema and val not in schema["enum"]:
            errors.append(f"{label} must be one of {schema['enum']}")
        if t in self._NUMERIC_TYPES:
            if "minimum" in schema and val < schema["minimum"]:
                errors.append(f"{label} must be >= {schema['minimum']}")
            if "maximum" in schema and val > schema["maximum"]:
                errors.append(f"{label} must be <= {schema['maximum']}")
        if t == "string":
            if "minLength" in schema and len(val) < schema["minLength"]:
                errors.append(f"{label} must be at least {schema['minLength']} chars")
            if "maxLength" in schema and len(val) > schema["maxLength"]:
                errors.append(f"{label} must be at most {schema['maxLength']} chars")
        if t == "object":
            props = schema.get("properties", {})
            for k in schema.get("required", []):
                if k not in val:
                    errors.append(f"missing required {path + '.' + k if path else k}")
            for k, v in val.items():
                if k in props:
                    errors.extend(self._validate(v, props[k], path + '.' + k if path else k))
        if t == "array" and "items" in schema:
            for i, item in enumerate(val):
                errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]"))
        return errors

    def to_schema(self) -> dict[str, Any]:
        """Convert tool to OpenAI function schema format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            }
        }
|
||||
246
app-instance/backend/nanobot/agent/tools/cron.py
Normal file
246
app-instance/backend/nanobot/agent/tools/cron.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""cron 工具:给 Agent 提供“定时任务管理”能力。
|
||||
|
||||
这个工具是 LLM 在对话中可调用的 function tool,主要负责三件事:
|
||||
1. `add`:创建一个定时任务(周期/cron/一次性);
|
||||
2. `list`:列出现有任务;
|
||||
3. `remove`:删除指定任务。
|
||||
|
||||
设计定位说明:
|
||||
- 本工具只做“任务管理面”,不直接负责“定时器循环”;
|
||||
- 真正的调度与执行由 `CronService` 统一负责(start/stop/on_job);
|
||||
- 工具层通过 `set_context(channel, chat_id)` 注入当前会话路由,
|
||||
从而让定时任务在触发后把结果回投到正确会话。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
class CronTool(Tool):
    """Conversation-callable tool for managing cron jobs.

    Invoked by the main agent as ``cron(...)`` during a tool-call turn.

    Constraints:
    - ``action`` must be one of ``add`` / ``list`` / ``remove``;
    - ``add`` requires ``message`` and a previously injected session
      context (channel / chat_id);
    - a job uses one timing source; ``every_seconds``, ``cron_expr`` and
      ``at`` are checked in that order and the first one provided wins.
    """

    def __init__(self, cron_service: CronService):
        # Keep the CronService instance handed in by the caller. The intent
        # is that CLI commands and this tool share one service so that they
        # see the same job store and in-process job state.
        self._cron = cron_service
        # Routing context, injected via set_context() before tool calls.
        # Jobs created here are delivered back to this channel/chat.
        self._channel = ""
        self._chat_id = ""
        self._session_key = ""

    def set_context(self, channel: str, chat_id: str, session_key: str | None = None) -> None:
        """Set the routing context of the current conversation.

        A job created in conversation A should report back to A when it
        fires, so the target is injected by the runtime instead of being
        passed explicitly by the model on every call.

        Args:
            channel: Channel identifier of the current conversation.
            chat_id: Chat identifier within that channel.
            session_key: Optional explicit session key; defaults to
                ``"<channel>:<chat_id>"`` when omitted or empty.
        """
        self._channel = channel
        self._chat_id = chat_id
        self._session_key = session_key or f"{channel}:{chat_id}"

    @property
    def name(self) -> str:
        # Tool name exposed to the model; calls arrive as `cron(...)`.
        return "cron"

    @property
    def description(self) -> str:
        # Short capability summary shown to the model.
        return "Schedule reminders and recurring tasks. Actions: add, list, remove. Use mode=reminder or task."

    @property
    def parameters(self) -> dict[str, Any]:
        # OpenAI function schema describing the accepted arguments.
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["add", "list", "remove"],
                    "description": "Action to perform"
                },
                "message": {
                    "type": "string",
                    # Job text for `add`: either a plain reminder text or a
                    # prompt handed back to the agent, depending on `mode`.
                    "description": "Reminder message (for add)"
                },
                "mode": {
                    "type": "string",
                    "enum": ["reminder", "task"],
                    "description": "Execution mode: reminder sends message directly; task re-enters agent"
                },
                "every_seconds": {
                    "type": "integer",
                    # Fixed interval in seconds; converted to milliseconds
                    # when the schedule is built.
                    "description": "Interval in seconds (for recurring tasks)"
                },
                "cron_expr": {
                    "type": "string",
                    # Cron expression, e.g. every day at 09:00: 0 9 * * *
                    "description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
                },
                "tz": {
                    "type": "string",
                    # IANA timezone; only accepted together with cron_expr.
                    "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')"
                },
                "at": {
                    "type": "string",
                    # One-shot trigger time, parsed with datetime.fromisoformat.
                    "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
                },
                "job_id": {
                    "type": "string",
                    "description": "Job ID (for remove)"
                }
            },
            "required": ["action"]
        }

    async def execute(
        self,
        action: str,
        message: str = "",
        mode: str | None = None,
        every_seconds: int | None = None,
        cron_expr: str | None = None,
        tz: str | None = None,
        at: str | None = None,
        job_id: str | None = None,
        **kwargs: Any
    ) -> str:
        """Tool entry point: dispatch on ``action``.

        Validation problems are returned as readable ``Error: ...``
        strings rather than raised. Exceptions that are not handled here
        (for example a ``ValueError`` from parsing an invalid ``at``
        value) propagate to the caller.

        Returns:
            A confirmation, a job listing, or an error message.
        """
        # add: create a job and return its id.
        if action == "add":
            return self._add_job(message, mode, every_seconds, cron_expr, tz, at)
        # list: read-only, formats the current jobs.
        elif action == "list":
            return self._list_jobs()
        # remove: delete a job by id.
        elif action == "remove":
            return self._remove_job(job_id)
        # The schema restricts the enum; this is a defensive fallback.
        return f"Unknown action: {action}"

    def _add_job(
        self,
        message: str,
        mode: str | None,
        every_seconds: int | None,
        cron_expr: str | None,
        tz: str | None,
        at: str | None,
    ) -> str:
        """Create a job and hand it to the CronService.

        Timing precedence (first one provided wins):
        1. ``every_seconds`` -> fixed-interval job
        2. ``cron_expr``     -> cron-expression job
        3. ``at``            -> one-shot job (``delete_after_run=True``)

        Returns:
            A confirmation containing the job id, or an ``Error: ...``
            string when validation fails.

        Raises:
            ValueError: If ``at`` is not parseable by
                ``datetime.fromisoformat``.
        """
        # message is mandatory for add: without it there is nothing to do.
        if not message:
            return "Error: message is required for add"
        # channel/chat_id come from set_context(); without them the result
        # could not be routed back to the right conversation.
        if not self._channel or not self._chat_id:
            return "Error: no session context (channel/chat_id)"
        # A timezone is only meaningful for cron expressions.
        if tz and not cron_expr:
            return "Error: tz can only be used with cron_expr"
        # Validate the timezone early so invalid data is never stored.
        if tz:
            from zoneinfo import ZoneInfo
            try:
                ZoneInfo(tz)
            except Exception:
                # Deliberately broad: besides the KeyError-derived
                # "not found" error, malformed keys raise other types.
                return f"Error: unknown timezone '{tz}'"

        # Default to "reminder" when mode is omitted, so the original
        # create-job instruction is not fed back to the agent again.
        normalized_mode = (mode or "reminder").strip().lower()
        if normalized_mode not in {"reminder", "task"}:
            return "Error: mode must be 'reminder' or 'task'"
        payload_kind = "system_event" if normalized_mode == "reminder" else "agent_turn"

        # A negative interval cannot be scheduled; reject it instead of
        # reporting a successfully created job. Zero keeps its original
        # meaning of "not provided" so the other timing sources still apply.
        if every_seconds is not None and every_seconds < 0:
            return "Error: every_seconds must be a positive integer"

        # Build the schedule. Times are expressed in milliseconds.
        delete_after = False
        if every_seconds:
            schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
        elif cron_expr:
            schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
        elif at:
            from datetime import datetime
            # fromisoformat raises ValueError on invalid input; that is
            # intentionally left to propagate to the caller.
            dt = datetime.fromisoformat(at)
            at_ms = int(dt.timestamp() * 1000)
            schedule = CronSchedule(kind="at", at_ms=at_ms)
            # One-shot jobs are flagged for removal after they run.
            delete_after = True
        else:
            return "Error: either every_seconds, cron_expr, or at is required"

        # Create the job:
        # - name is the first 30 characters of message, for compact listings;
        # - deliver=True with channel/to from the injected context so the
        #   result is routed to the current conversation.
        job = self._cron.add_job(
            name=message[:30],
            schedule=schedule,
            message=message,
            payload_kind=payload_kind,
            session_key=self._session_key or None,
            deliver=True,
            channel=self._channel,
            to=self._chat_id,
            delete_after_run=delete_after,
        )
        # Short confirmation so the model can refer to the job id later.
        return f"Created {normalized_mode} job '{job.name}' (id: {job.id})"

    def _list_jobs(self) -> str:
        """List the jobs returned by ``CronService.list_jobs()`` (default arguments)."""
        jobs = self._cron.list_jobs()
        if not jobs:
            return "No scheduled jobs."
        # Keep the output light: name, id and schedule kind only.
        lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
        return "Scheduled jobs:\n" + "\n".join(lines)

    def _remove_job(self, job_id: str | None) -> str:
        """Remove a job by id and report whether it existed."""
        if not job_id:
            return "Error: job_id is required for remove"
        # remove_job's truthy/falsy result is turned into a chat-friendly text.
        if self._cron.remove_job(job_id):
            return f"Removed job {job_id}"
        return f"Job {job_id} not found"
|
||||
116
app-instance/backend/nanobot/agent/tools/cron_action.py
Normal file
116
app-instance/backend/nanobot/agent/tools/cron_action.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""结构化 cron 生命周期控制工具。
|
||||
|
||||
cron 任务不是普通用户对话,它经常需要在运行完成后主动告诉调度器:
|
||||
- 这个任务已经可以删掉;
|
||||
- 今天这一轮先结束,下一天再继续;
|
||||
- 下次应该改成新的时间表。
|
||||
|
||||
这个工具就是让模型把这些决策显式写成结构化数据,而不是只留在自然语言里。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.types import CronAction
|
||||
|
||||
|
||||
class CronActionTool(Tool):
    """Capture a machine-readable cron lifecycle decision emitted by the model."""

    # Lifecycle actions accepted by execute(); mirrors the schema enum.
    _ALLOWED_ACTIONS = frozenset({"none", "remove", "disable", "complete_today", "reschedule"})

    def __init__(self, job_id: str):
        # `job_id` is only echoed in the confirmation text; it does not
        # influence the recorded decision.
        self.job_id = job_id
        # Last successfully validated decision; read via `decision`.
        self._decision: CronAction | None = None

    @property
    def name(self) -> str:
        return "cron_action"

    @property
    def description(self) -> str:
        return "Record a structured lifecycle action for the currently running cron job."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["none", "remove", "disable", "complete_today", "reschedule"],
                    "description": "Lifecycle action for the current cron job",
                },
                "reason": {
                    "type": "string",
                    "description": "Short reason for audit logs",
                },
                "every_seconds": {
                    "type": "integer",
                    "description": "Required when action=reschedule and using fixed interval",
                },
                "cron_expr": {
                    "type": "string",
                    "description": "Required when action=reschedule and using cron expression",
                },
                "tz": {
                    "type": "string",
                    "description": "Optional timezone for cron_expr reschedules",
                },
                "at": {
                    "type": "string",
                    "description": "Required when action=reschedule and using one-time ISO datetime",
                },
            },
            "required": ["action"],
        }

    @property
    def decision(self) -> CronAction | None:
        # The recorded decision, or None if no valid call has been made.
        return self._decision

    async def execute(
        self,
        action: str,
        reason: str | None = None,
        every_seconds: int | None = None,
        cron_expr: str | None = None,
        tz: str | None = None,
        at: str | None = None,
        **_kwargs: Any,
    ) -> str:
        """Validate and record a lifecycle decision for the current job.

        Args:
            action: One of ``none``, ``remove``, ``disable``,
                ``complete_today``, ``reschedule`` (case-insensitive).
            reason: Optional free-text reason; blank values are stored
                as ``None``.
            every_seconds: New fixed interval in seconds (reschedule only).
            cron_expr: New cron expression (reschedule only).
            tz: Timezone, only valid together with ``cron_expr``.
            at: New one-time ISO datetime (reschedule only).

        Returns:
            A confirmation text on success, or an ``Error: ...`` string
            when validation fails (in which case nothing is recorded).
        """
        # Normalize case/whitespace so variants such as `Remove` are accepted.
        normalized = (action or "").strip().lower()
        if normalized not in self._ALLOWED_ACTIONS:
            return f"Error: unsupported cron action '{action}'"
        # Schedule fields are only meaningful for reschedule; reject
        # contradictory input such as "remove" plus a cron_expr.
        if normalized != "reschedule" and any(value is not None for value in (every_seconds, cron_expr, tz, at)):
            return "Error: schedule fields can only be used when action='reschedule'"

        if normalized == "reschedule":
            # Exactly one of the three timing sources must be supplied.
            options = int(every_seconds is not None) + int(bool(cron_expr)) + int(bool(at))
            if options != 1:
                return "Error: reschedule requires exactly one of every_seconds, cron_expr, or at"
            # A zero or negative interval is not a usable schedule; refuse
            # it here instead of recording a decision that cannot be applied.
            if every_seconds is not None and every_seconds <= 0:
                return "Error: every_seconds must be a positive integer"
            # A timezone is only meaningful for cron expressions.
            if tz and not cron_expr:
                return "Error: tz can only be used with cron_expr"

        # Validation passed: freeze the decision for later consumption.
        self._decision = CronAction(
            action=normalized,
            reason=(reason or "").strip() or None,
            every_seconds=every_seconds,
            cron_expr=cron_expr,
            tz=tz,
            at=at,
        )
        # Readable confirmation that ends up in the tool-call result.
        detail = f" for job {self.job_id}"
        if self._decision.reason:
            detail += f" ({self._decision.reason})"
        return f"Recorded cron_action={self._decision.action}{detail}"
|
||||
275
app-instance/backend/nanobot/agent/tools/filesystem.py
Normal file
275
app-instance/backend/nanobot/agent/tools/filesystem.py
Normal file
@ -0,0 +1,275 @@
|
||||
"""File system tools: read, write, edit."""
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
|
||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||
p = Path(path).expanduser()
|
||||
if not p.is_absolute() and workspace:
|
||||
p = workspace / p
|
||||
resolved = p.resolve()
|
||||
if allowed_dir:
|
||||
try:
|
||||
resolved.relative_to(allowed_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_relative_to(path: Path, root: Path) -> bool:
    """Return True when ``path`` is ``root`` (after resolving it) or lies below it.

    Only ``root`` is resolved here; ``path`` is compared as given.
    """
    return path.is_relative_to(root.resolve())
|
||||
|
||||
|
||||
def _protected_write_error() -> str:
    """Return the fixed error text used when a write targets a protected path."""
    refusal = "Error: Direct writes to workspace skills are blocked."
    remedy = "Stage the skill for review and require explicit user approval before installation."
    return f"{refusal} {remedy}"
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
    """Agent tool that returns the UTF-8 text content of a file."""

    def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
        # workspace anchors relative paths; allowed_dir restricts access.
        self._workspace = workspace
        self._allowed_dir = allowed_dir

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return "Read the contents of a file at the given path."

    @property
    def parameters(self) -> dict[str, Any]:
        path_spec = {
            "type": "string",
            "description": "The file path to read"
        }
        return {
            "type": "object",
            "properties": {"path": path_spec},
            "required": ["path"]
        }

    async def execute(self, path: str, **kwargs: Any) -> str:
        """Read ``path`` and return its text, or an ``Error...`` string.

        No exception escapes: permission problems and any other failure
        are converted into an error message for the caller.
        """
        try:
            target = _resolve_path(path, self._workspace, self._allowed_dir)
            if not target.exists():
                return f"Error: File not found: {path}"
            if not target.is_file():
                return f"Error: Not a file: {path}"
            return target.read_text(encoding="utf-8")
        except PermissionError as exc:
            # Covers both the sandbox check and OS-level permission errors.
            return f"Error: {exc}"
        except Exception as exc:
            return f"Error reading file: {exc}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
    """Tool to write content to a file."""

    def __init__(
        self,
        workspace: Path | None = None,
        allowed_dir: Path | None = None,
        protected_paths: list[Path] | None = None,
    ):
        # workspace anchors relative paths; allowed_dir restricts access.
        self._workspace = workspace
        self._allowed_dir = allowed_dir
        # Directories (or files) that must never be written through this
        # tool; normalized once so later comparisons use resolved paths.
        self._protected_paths = [p.expanduser().resolve() for p in protected_paths or []]

    @property
    def name(self) -> str:
        return "write_file"

    @property
    def description(self) -> str:
        return "Write content to a file at the given path. Creates parent directories if needed."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "The file path to write to"
                },
                "content": {
                    "type": "string",
                    "description": "The content to write"
                }
            },
            "required": ["path", "content"]
        }

    async def execute(self, path: str, content: str, **kwargs: Any) -> str:
        """Write ``content`` to ``path`` as UTF-8, creating parent directories.

        Args:
            path: Target file path (relative paths use the workspace).
            content: Text to write; replaces any existing file content.

        Returns:
            A success message with the UTF-8 encoded size of ``content``
            in bytes, or an ``Error...`` string. No exception escapes.
        """
        try:
            file_path = _resolve_path(path, self._workspace, self._allowed_dir)
            # Refuse writes at or below any protected location.
            if any(_is_relative_to(file_path, protected) for protected in self._protected_paths):
                return _protected_write_error()
            file_path.parent.mkdir(parents=True, exist_ok=True)
            file_path.write_text(content, encoding="utf-8")
            # len(content) counts characters; the message promises bytes,
            # which differs for any non-ASCII text, so measure the encoding.
            byte_count = len(content.encode("utf-8"))
            return f"Successfully wrote {byte_count} bytes to {file_path}"
        except PermissionError as e:
            # Covers both the sandbox check and OS-level permission errors.
            return f"Error: {e}"
        except Exception as e:
            return f"Error writing file: {str(e)}"
|
||||
|
||||
|
||||
class EditFileTool(Tool):
    """Agent tool that replaces one exact occurrence of a text in a file."""

    def __init__(
        self,
        workspace: Path | None = None,
        allowed_dir: Path | None = None,
        protected_paths: list[Path] | None = None,
    ):
        # workspace anchors relative paths; allowed_dir restricts access.
        self._workspace = workspace
        self._allowed_dir = allowed_dir
        # Locations that must not be modified, normalized up front.
        self._protected_paths = [
            entry.expanduser().resolve() for entry in (protected_paths or [])
        ]

    @property
    def name(self) -> str:
        return "edit_file"

    @property
    def description(self) -> str:
        return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."

    @property
    def parameters(self) -> dict[str, Any]:
        properties = {
            "path": {
                "type": "string",
                "description": "The file path to edit"
            },
            "old_text": {
                "type": "string",
                "description": "The exact text to find and replace"
            },
            "new_text": {
                "type": "string",
                "description": "The text to replace with"
            },
        }
        return {
            "type": "object",
            "properties": properties,
            "required": ["path", "old_text", "new_text"]
        }

    async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
        """Replace the single occurrence of ``old_text`` with ``new_text``.

        The file is left untouched when ``old_text`` is missing (a
        diagnostic with the closest match is returned) or when it occurs
        more than once (a warning asking for more context is returned).
        No exception escapes; failures become ``Error...`` strings.
        """
        try:
            target = _resolve_path(path, self._workspace, self._allowed_dir)
            for protected in self._protected_paths:
                if _is_relative_to(target, protected):
                    return _protected_write_error()
            if not target.exists():
                return f"Error: File not found: {path}"

            original = target.read_text(encoding="utf-8")

            # One count decides all three outcomes: absent, ambiguous, unique.
            occurrences = original.count(old_text)
            if occurrences == 0:
                return self._not_found_message(old_text, original, path)
            if occurrences > 1:
                return f"Warning: old_text appears {occurrences} times. Please provide more context to make it unique."

            target.write_text(original.replace(old_text, new_text, 1), encoding="utf-8")
            return f"Successfully edited {target}"
        except PermissionError as exc:
            # Covers both the sandbox check and OS-level permission errors.
            return f"Error: {exc}"
        except Exception as exc:
            return f"Error editing file: {exc}"

    @staticmethod
    def _not_found_message(old_text: str, content: str, path: str) -> str:
        """Explain that ``old_text`` is absent and show the closest region.

        Slides a window of ``len(old_text lines)`` lines over the file,
        scores each position with ``difflib.SequenceMatcher`` and, if the
        best score exceeds 0.5, includes a unified diff against it.
        """
        file_lines = content.splitlines(keepends=True)
        wanted_lines = old_text.splitlines(keepends=True)
        span = len(wanted_lines)

        # At least one position is always scored, even for short files.
        positions = range(max(1, len(file_lines) - span + 1))
        scores = [
            difflib.SequenceMatcher(None, wanted_lines, file_lines[pos : pos + span]).ratio()
            for pos in positions
        ]
        # max() keeps the first position on ties, i.e. the earliest best match.
        best_start = max(positions, key=scores.__getitem__)
        best_ratio = scores[best_start]

        if best_ratio <= 0.5:
            return f"Error: old_text not found in {path}. No similar text found. Verify the file content."

        diff_lines = difflib.unified_diff(
            wanted_lines,
            file_lines[best_start : best_start + span],
            fromfile="old_text (provided)",
            tofile=f"{path} (actual, line {best_start + 1})",
            lineterm="",
        )
        diff = "\n".join(diff_lines)
        return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
|
||||
|
||||
class ListDirTool(Tool):
    """Agent tool that lists the entries of a directory."""

    def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
        # workspace anchors relative paths; allowed_dir restricts access.
        self._workspace = workspace
        self._allowed_dir = allowed_dir

    @property
    def name(self) -> str:
        return "list_dir"

    @property
    def description(self) -> str:
        return "List the contents of a directory."

    @property
    def parameters(self) -> dict[str, Any]:
        path_spec = {
            "type": "string",
            "description": "The directory path to list"
        }
        return {
            "type": "object",
            "properties": {"path": path_spec},
            "required": ["path"]
        }

    async def execute(self, path: str, **kwargs: Any) -> str:
        """Return one line per directory entry, sorted, with a kind marker.

        Directories are prefixed with a folder emoji and everything else
        with a document emoji. No exception escapes; failures become
        ``Error...`` strings.
        """
        try:
            target = _resolve_path(path, self._workspace, self._allowed_dir)
            if not target.exists():
                return f"Error: Directory not found: {path}"
            if not target.is_dir():
                return f"Error: Not a directory: {path}"

            listing = [
                ("📁 " if entry.is_dir() else "📄 ") + entry.name
                for entry in sorted(target.iterdir())
            ]
            if listing:
                return "\n".join(listing)
            return f"Directory {path} is empty"
        except PermissionError as exc:
            # Covers both the sandbox check and OS-level permission errors.
            return f"Error: {exc}"
        except Exception as exc:
            return f"Error listing directory: {exc}"
|
||||
346
app-instance/backend/nanobot/agent/tools/mcp.py
Normal file
346
app-instance/backend/nanobot/agent/tools/mcp.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""MCP 客户端封装。
|
||||
|
||||
职责分两层:
|
||||
1. `connect_mcp_servers()` 负责建立与 MCP server 的连接,并把远端工具注册成 nanobot 本地工具;
|
||||
2. `MCPToolWrapper` 负责把单个远端 MCP tool 包装成可供 LLM 调用的 `Tool`,同时发出结构化过程事件。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.process_events import current_process_run_id, emit_process_event, new_run_id
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
    """Wrap a single MCP server tool as a nanobot ``Tool``.

    Besides forwarding the call, ``execute`` emits structured process
    events (started / progress / artifact / finished, or status+finished
    on failure) so the call can be displayed as a sub-step of the run.
    When ``sensitive`` is set, arguments and results are left out of the
    emitted events.
    """

    def __init__(
        self,
        session,
        server_name: str,
        tool_def,
        *,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
        tool_timeout: int = 30,
        sensitive: bool = False,
    ):
        self._session = session
        # Allow the caller to inject how the remote call is made; default
        # to calling the session directly.
        self._call_tool = call_tool or self._default_call_tool
        # Keep the originating server name for logs, events and the
        # exported tool name.
        self._server_name = server_name
        self._original_name = tool_def.name
        # Prefix with `mcp_<server>_` to avoid clashes between tools of
        # the same name from different sources.
        self._name = f"mcp_{server_name}_{tool_def.name}"
        self._description = tool_def.description or tool_def.name
        self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
        self._tool_timeout = tool_timeout
        self._sensitive = sensitive

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    def _event_metadata(self) -> dict[str, Any]:
        """Return a fresh metadata dict shared by most emitted events."""
        return {"tool_name": self._original_name, "sensitive": self._sensitive}

    async def _emit_failure(self, run_id: Any, summary: str) -> None:
        """Emit the error status event followed by the finished event."""
        await emit_process_event(
            "process_run_status",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="error",
            text=summary,
            metadata=self._event_metadata(),
        )
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="error",
            summary=summary,
            metadata=self._event_metadata(),
        )

    async def execute(self, **kwargs: Any) -> str:
        """Call the remote MCP tool and return its output as text.

        Args:
            **kwargs: Arguments forwarded unchanged to the remote tool.

        Returns:
            The text blocks of the result joined by newlines (non-text
            blocks are rendered with ``str()``), ``"(no output)"`` when
            empty, or a timeout notice when the call exceeds
            ``tool_timeout`` seconds.

        Raises:
            Exception: Any non-timeout error raised by the remote call is
                re-raised after the failure events have been emitted.
        """
        from mcp import types
        # Each call gets its own run_id so it can be shown as a child step.
        run_id = new_run_id("mcp")
        args_json = json.dumps(kwargs, ensure_ascii=False) if kwargs else "{}"
        await emit_process_event(
            "process_run_started",
            run_id=run_id,
            parent_run_id=current_process_run_id(),
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            title=f"{self._server_name}.{self._original_name}",
            status="running",
            metadata={
                "tool_name": self._original_name,
                "tool_args": None if self._sensitive else kwargs,
                "tool_timeout": self._tool_timeout,
                "sensitive": self._sensitive,
            },
        )
        # Emit a progress event before the remote request so the UI can
        # show which tool is being called.
        await emit_process_event(
            "process_run_progress",
            run_id=run_id,
            parent_run_id=current_process_run_id(),
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            text=(
                f"Calling {self._original_name}"
                if self._sensitive
                else f"Calling {self._original_name} with {args_json}"
            ),
            metadata=self._event_metadata(),
        )
        try:
            result = await asyncio.wait_for(
                self._call_tool(self._original_name, kwargs),
                timeout=self._tool_timeout,
            )
        except asyncio.TimeoutError:
            # A timeout is reported as readable text instead of raising.
            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
            summary = f"(MCP tool call timed out after {self._tool_timeout}s)"
            await self._emit_failure(run_id, summary)
            return summary
        except Exception as exc:
            # Close the run in the event stream before propagating, so a
            # failed call is not left in the "running" state. The message
            # is withheld for sensitive tools because it may echo arguments.
            logger.warning("MCP tool '{}' failed with {}", self._name, type(exc).__name__)
            detail = type(exc).__name__ if self._sensitive else f"{type(exc).__name__}: {exc}"
            await self._emit_failure(run_id, f"(MCP tool call failed: {detail})"[:1000])
            raise

        # Flatten the structured content blocks into plain text.
        parts = []
        for block in result.content:
            if isinstance(block, types.TextContent):
                parts.append(block.text)
            else:
                parts.append(str(block))
        output = "\n".join(parts) or "(no output)"
        artifact_type = "text"
        artifact_data: Any | None = None
        stripped = output.strip()
        # If the output looks like JSON, also parse it into a structured
        # artifact; fall back to plain text when parsing fails.
        if stripped.startswith(("{", "[")):
            try:
                artifact_data = json.loads(stripped)
                artifact_type = "json"
            except json.JSONDecodeError:
                artifact_data = None
        await emit_process_event(
            "process_run_artifact",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            title=f"{self._server_name}.{self._original_name} result",
            artifact_type="redacted" if self._sensitive else artifact_type,
            # Text content is only sent when there is no parsed JSON payload.
            content=None if self._sensitive or artifact_data is not None else output,
            data=None if self._sensitive else artifact_data,
            metadata=self._event_metadata(),
        )
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="done",
            summary=(
                f"{self._original_name} completed"
                if self._sensitive
                else output[:1000]
            ),
            metadata=self._event_metadata(),
        )
        return output

    async def _default_call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call the tool directly on the wrapped MCP session."""
        return await self._session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
|
||||
async def connect_mcp_servers(
    mcp_servers: dict,
    registry: ToolRegistry,
    stack: AsyncExitStack,
    *,
    authz_config: Any | None = None,
    backend_identity: Any | None = None,
) -> dict[str, dict[str, Any]]:
    """Connect every configured MCP server and register its tools in ``registry``.

    Each entry of ``mcp_servers`` is handled independently: a failure is logged
    and recorded in the returned report, and the remaining servers are still
    attempted.

    Args:
        mcp_servers: Mapping of server name to its config object. A config with
            a truthy ``command`` is started over stdio; otherwise a truthy
            ``url`` selects streamable HTTP; entries with neither are skipped.
        registry: Registry that receives one ``MCPToolWrapper`` per tool.
        stack: Exit stack that owns the long-lived sessions (stdio and plain
            HTTP). Servers whose ``auth_mode`` is ``"oauth_backend_token"``
            open a short-lived session per operation instead.
        authz_config: AuthZ settings (``base_url``, ``request_timeout_seconds``);
            only needed for ``oauth_backend_token`` servers.
        backend_identity: Client credentials (``client_id``, ``client_secret``)
            used to request the backend token.

    Returns:
        Per-server report with the keys ``status`` ("disconnected", "connected"
        or "error"), ``last_error``, ``tool_names``, ``tool_count`` and
        ``transport``. It is intended for showing connection state in the UI.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client
    from nanobot.authz.client import AuthzClient

    # Returned to the caller so the Web UI can show connection state and tools.
    report: dict[str, dict[str, Any]] = {}

    def _text(obj: Any, attr: str) -> str:
        """Return ``obj.attr`` as a stripped string; missing or None becomes ""."""
        return str(getattr(obj, attr, "") or "").strip()

    async def _build_http_headers(server_name: str, cfg: Any) -> dict[str, str]:
        """Build request headers, adding a fresh AuthZ bearer token when required."""
        headers = dict(getattr(cfg, "headers", {}) or {})
        if getattr(cfg, "auth_mode", "none") != "oauth_backend_token":
            return headers

        identity_complete = (
            authz_config
            and _text(authz_config, "base_url")
            and backend_identity
            and _text(backend_identity, "client_id")
            and _text(backend_identity, "client_secret")
        )
        if not identity_complete:
            raise RuntimeError(
                f"MCP server '{server_name}' requires AuthZ backend token, but authz/backend identity is incomplete"
            )

        authz_client = AuthzClient(
            getattr(authz_config, "base_url"),
            timeout_seconds=int(getattr(authz_config, "request_timeout_seconds", 10)),
        )
        raw_audience = str(getattr(cfg, "auth_audience", "") or "").strip()
        # Older managed Outlook configs stored `auth_audience="mcp"`, but AuthZ
        # permissions are issued against `mcp:<server_id>`.
        if not raw_audience or raw_audience == "mcp":
            audience = f"mcp:{server_name}"
        elif raw_audience.startswith("mcp:"):
            audience = raw_audience
        else:
            audience = f"mcp:{raw_audience}"
        token_response = await authz_client.issue_token(
            client_id=getattr(backend_identity, "client_id"),
            client_secret=getattr(backend_identity, "client_secret"),
            audience=audience,
            scopes=[str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
        )
        access_token = str(token_response.get("access_token") or "").strip()
        if not access_token:
            raise RuntimeError(f"MCP server '{server_name}' did not receive an access token from AuthZ")
        headers["Authorization"] = f"Bearer {access_token}"
        return headers

    async def _open_http_session(
        session_stack: AsyncExitStack,
        cfg: Any,
        *,
        headers: dict[str, str],
    ):
        """Open and initialize a streamable-HTTP session owned by ``session_stack``."""
        http_client = await session_stack.enter_async_context(
            httpx.AsyncClient(
                headers=headers or None,
                follow_redirects=True,
                trust_env=False,
            )
        )
        read, write, _ = await session_stack.enter_async_context(
            streamable_http_client(cfg.url, http_client=http_client)
        )
        session = await session_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        return session

    async def _list_http_tools(server_name: str, cfg: Any):
        """List a server's tools through a throwaway session."""
        async with AsyncExitStack() as session_stack:
            headers = await _build_http_headers(server_name, cfg)
            session = await _open_http_session(session_stack, cfg, headers=headers)
            tools = await session.list_tools()
            return tools.tools

    def _make_http_call_tool(server_name: str, cfg: Any) -> Callable[[str, dict[str, Any]], Awaitable[Any]]:
        """Return a callable that opens a fresh session (and token) per tool call."""

        async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
            async with AsyncExitStack() as session_stack:
                headers = await _build_http_headers(server_name, cfg)
                session = await _open_http_session(session_stack, cfg, headers=headers)
                return await session.call_tool(tool_name, arguments=arguments)

        return _call_tool

    def _register_tools(
        server_name: str,
        cfg: Any,
        tool_defs: Any,
        session: Any,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
    ) -> None:
        """Wrap each tool definition, register it, and record its name in the report."""
        # Only pass `call_tool` when one was supplied, so session-backed wrappers
        # are constructed exactly as before.
        extra: dict[str, Any] = {} if call_tool is None else {"call_tool": call_tool}
        for tool_def in tool_defs:
            wrapper = MCPToolWrapper(
                session,
                server_name,
                tool_def,
                tool_timeout=cfg.tool_timeout,
                sensitive=bool(getattr(cfg, "sensitive", False)),
                **extra,
            )
            registry.register(wrapper)
            logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, server_name)
            report[server_name]["tool_names"].append(wrapper.name)

    for name, cfg in mcp_servers.items():
        entry: dict[str, Any] = {
            "status": "disconnected",
            "last_error": None,
            "tool_names": [],
            "tool_count": 0,
            "transport": "stdio" if getattr(cfg, "command", "") else "http",
        }
        report[name] = entry
        try:
            if cfg.command:
                # stdio mode: spawn a local subprocess and talk to the MCP
                # server over its stdin/stdout.
                params = StdioServerParameters(
                    command=cfg.command, args=cfg.args, env=cfg.env or None
                )
                read, write = await stack.enter_async_context(stdio_client(params))
                session = await stack.enter_async_context(ClientSession(read, write))
                await session.initialize()
                tools = await session.list_tools()
                _register_tools(name, cfg, tools.tools, session)
            elif cfg.url:
                if getattr(cfg, "auth_mode", "none") == "oauth_backend_token":
                    # Token-authenticated servers get no shared session: every
                    # call re-authenticates through its own short-lived session.
                    tools_defs = await _list_http_tools(name, cfg)
                    call_tool = _make_http_call_tool(name, cfg)
                    _register_tools(name, cfg, tools_defs, None, call_tool)
                else:
                    headers = await _build_http_headers(name, cfg)
                    session = await _open_http_session(stack, cfg, headers=headers)
                    tools = await session.list_tools()
                    _register_tools(name, cfg, tools.tools, session)
            else:
                # An entry with neither command nor url is an invalid config:
                # skip it without raising.
                logger.warning("MCP server '{}': no command or url configured, skipping", name)
                continue

            entry["status"] = "connected"
            logger.info(
                "MCP server '{}': connected, {} tools registered",
                name,
                len(entry["tool_names"]),
            )
        except Exception as e:
            # One failing server must not stop the others; the error is stored
            # in the report for display.
            entry["status"] = "error"
            entry["last_error"] = str(e)
            logger.error("MCP server '{}': failed to connect: {}", name, e)
        finally:
            # Keep the count consistent with `tool_names` even when the server
            # failed after some tools had already been registered.
            entry["tool_count"] = len(entry["tool_names"])
    return report
|
||||
108
app-instance/backend/nanobot/agent/tools/message.py
Normal file
108
app-instance/backend/nanobot/agent/tools/message.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
class MessageTool(Tool):
    """Agent tool that delivers a message to a user on a chat channel.

    Delivery is delegated to an injected async callback. The tool keeps a
    default destination (channel, chat id, message id) that is used whenever a
    call does not name one, plus a flag recording whether anything was sent
    since the last ``start_turn()``.
    """

    def __init__(
        self,
        send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
        default_channel: str = "",
        default_chat_id: str = "",
        default_message_id: str | None = None,
    ):
        """Store the delivery callback and the default destination."""
        self._send_callback = send_callback
        self._default_channel = default_channel
        self._default_chat_id = default_chat_id
        self._default_message_id = default_message_id
        # True once a message was successfully handed to the callback.
        self._sent_in_turn: bool = False

    def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
        """Replace the default destination used by later ``execute`` calls."""
        self._default_channel, self._default_chat_id = channel, chat_id
        self._default_message_id = message_id

    def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
        """Install the coroutine function that actually delivers messages."""
        self._send_callback = callback

    def start_turn(self) -> None:
        """Clear the per-turn "a message was sent" flag."""
        self._sent_in_turn = False

    @property
    def name(self) -> str:
        """Tool name used in function calls."""
        return "message"

    @property
    def description(self) -> str:
        """Description of the tool shown to the model."""
        return "Send a message to the user. Use this when you want to communicate something."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema of the arguments accepted by ``execute``."""
        properties: dict[str, Any] = {
            "content": {
                "type": "string",
                "description": "The message content to send"
            },
            "channel": {
                "type": "string",
                "description": "Optional: target channel (telegram, discord, etc.)"
            },
            "chat_id": {
                "type": "string",
                "description": "Optional: target chat/user ID"
            },
            "media": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional: list of file paths to attach (images, audio, documents)"
            },
        }
        return {"type": "object", "properties": properties, "required": ["content"]}

    async def execute(
        self,
        content: str,
        channel: str | None = None,
        chat_id: str | None = None,
        message_id: str | None = None,
        media: list[str] | None = None,
        **kwargs: Any
    ) -> str:
        """Send ``content`` to the given or default destination.

        Args:
            content: Message text.
            channel: Target channel; falls back to the default channel.
            chat_id: Target chat/user id; falls back to the default chat id.
            message_id: Carried in the outbound metadata; falls back to the
                default message id.
            media: Optional attachment paths.
            **kwargs: Ignored extra arguments.

        Returns:
            A confirmation string, or a string starting with "Error" when no
            destination or callback is available or the callback raises.
        """
        target_channel = channel or self._default_channel
        target_chat = chat_id or self._default_chat_id
        reply_to = message_id or self._default_message_id

        if not (target_channel and target_chat):
            return "Error: No target channel/chat specified"

        deliver = self._send_callback
        if not deliver:
            return "Error: Message sending not configured"

        outbound = OutboundMessage(
            channel=target_channel,
            chat_id=target_chat,
            content=content,
            media=media or [],
            metadata={"message_id": reply_to},
        )

        try:
            await deliver(outbound)
            self._sent_in_turn = True
            attachment_note = f" with {len(media)} attachments" if media else ""
            return f"Message sent to {target_channel}:{target_chat}{attachment_note}"
        except Exception as e:
            return f"Error sending message: {str(e)}"
|
||||
96
app-instance/backend/nanobot/agent/tools/registry.py
Normal file
96
app-instance/backend/nanobot/agent/tools/registry.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""工具注册中心。
|
||||
|
||||
职责很单一:
|
||||
1. 保存当前可用工具实例;
|
||||
2. 向 LLM 暴露 function schema;
|
||||
3. 在执行前做基础参数校验,并把异常统一转成文本结果。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ToolRegistry:
    """Registry of the tool instances currently available to the agent.

    It stores tools by name, exposes their function schemas, and runs them
    with parameter validation, turning every failure into a text result.
    """

    def __init__(self):
        """Create an empty registry."""
        # Maps tool name -> instance; names are unique within one registry.
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Register ``tool`` under ``tool.name``, replacing any previous entry."""
        self._tools[tool.name] = tool

    def clone(self) -> "ToolRegistry":
        """Return a shallow copy that shares the same tool instances."""
        # Tools are deliberately not deep-copied: many hold runtime state or
        # external connections. The copy exists so extra tools can be attached
        # for a single request without touching this registry.
        other = ToolRegistry()
        other._tools = dict(self._tools)
        return other

    def unregister(self, name: str) -> None:
        """Remove the tool called ``name``; unknown names are ignored."""
        self._tools.pop(name, None)

    def get(self, name: str) -> Tool | None:
        """Return the tool called ``name``, or None when it is not registered."""
        return self._tools.get(name)

    def has(self, name: str) -> bool:
        """Return True when a tool called ``name`` is registered."""
        return name in self._tools

    def get_definitions(self) -> list[dict[str, Any]]:
        """Return ``to_schema()`` of every registered tool, in registration order."""
        return [tool.to_schema() for tool in self._tools.values()]

    async def execute(self, name: str, params: dict[str, Any]) -> str:
        """Validate ``params`` and run the tool called ``name``.

        This method does not raise for tool problems: an unknown tool, invalid
        parameters, and exceptions raised by the tool are all reported as a
        string starting with "Error".

        Args:
            name: Tool name.
            params: Tool parameters, passed to the tool as keyword arguments.

        Returns:
            The tool's result. Validation failures, tool exceptions and tool
            results that start with "Error" get a retry hint appended.
        """
        _hint = "\n\n[Analyze the error above and try a different approach.]"

        tool = self._tools.get(name)
        # Compare against None: a registered tool object may itself be falsy
        # (e.g. if it defines __len__), and must still be executed.
        if tool is None:
            return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"

        try:
            # Schema validation runs before the call so the model receives an
            # error message it can correct itself from.
            errors = tool.validate_params(params)
            if errors:
                return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _hint
            result = await tool.execute(**params)
            # Convention: a result starting with "Error" is a business failure
            # reported by the tool, not a crash.
            if isinstance(result, str) and result.startswith("Error"):
                return result + _hint
            return result
        except Exception as e:
            # Keep the "never raise into the model layer" contract.
            return f"Error executing {name}: {str(e)}" + _hint

    @property
    def tool_names(self) -> list[str]:
        """Names of all registered tools, in registration order."""
        return list(self._tools.keys())

    def __len__(self) -> int:
        """Number of registered tools."""
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """Support ``name in registry``."""
        return name in self._tools
|
||||
284
app-instance/backend/nanobot/agent/tools/shell.py
Normal file
284
app-instance/backend/nanobot/agent/tools/shell.py
Normal file
@ -0,0 +1,284 @@
|
||||
"""Shell execution tool."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ExecTool(Tool):
    """Agent tool that runs a shell command and returns its output as text.

    Before running anything the command passes a best-effort safety guard:
    deny/allow regex lists, an optional "stay inside the working directory"
    check, and an optional block on writes into protected paths.
    """

    def __init__(
        self,
        timeout: int = 60,
        working_dir: str | None = None,
        deny_patterns: list[str] | None = None,
        allow_patterns: list[str] | None = None,
        restrict_to_workspace: bool = False,
        protected_paths: list[Path] | None = None,
    ):
        """Configure the tool.

        Args:
            timeout: Seconds to wait for a command before killing it.
            working_dir: Default working directory for commands.
            deny_patterns: Regexes (matched against the lower-cased command)
                that block a command. A falsy value selects the built-in list.
            allow_patterns: If non-empty, a command must match at least one.
            restrict_to_workspace: Reject path traversal and absolute paths
                outside the working directory.
            protected_paths: Directories that commands must not write into.
        """
        self.timeout = timeout
        self.working_dir = working_dir
        self.deny_patterns = deny_patterns or [
            r"\brm\s+-[rf]{1,2}\b",  # rm -r, rm -rf, rm -fr
            r"\bdel\s+/[fq]\b",  # del /f, del /q
            r"\brmdir\s+/s\b",  # rmdir /s
            r"(?:^|[;&|]\s*)format\b",  # format (as standalone command only)
            r"\b(mkfs|diskpart)\b",  # disk operations
            r"\bdd\s+if=",  # dd
            r">\s*/dev/sd",  # write to disk
            r"\b(shutdown|reboot|poweroff)\b",  # system power
            r":\(\)\s*\{.*\};\s*:",  # fork bomb
        ]
        self.allow_patterns = allow_patterns or []
        self.restrict_to_workspace = restrict_to_workspace
        self.protected_paths = [Path(p).expanduser().resolve() for p in protected_paths or []]

    @property
    def name(self) -> str:
        """Tool name used in function calls."""
        return "exec"

    @property
    def description(self) -> str:
        """Description of the tool shown to the model."""
        return "Execute a shell command and return its output. Use with caution."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema of the arguments accepted by ``execute``."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The shell command to execute"
                },
                "working_dir": {
                    "type": "string",
                    "description": "Optional working directory for the command"
                }
            },
            "required": ["command"]
        }

    async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
        """Run ``command`` in a shell and return its output.

        Args:
            command: Shell command line.
            working_dir: Directory to run in; falls back to the configured
                default, then to the process's current directory.
            **kwargs: Ignored extra arguments.

        Returns:
            stdout, then ``STDERR:`` output if any, then the exit code when it
            is non-zero; "(no output)" if there is nothing. Output longer than
            10000 characters is truncated. Guard rejections, timeouts and
            launch failures are returned as strings starting with "Error".
        """
        # NOTE(review): with restrict_to_workspace the guard's boundary is this
        # cwd, which the caller can override via `working_dir` - confirm that
        # is intended.
        cwd = working_dir or self.working_dir or os.getcwd()
        guard_error = self._guard_command(command, cwd)
        if guard_error:
            return guard_error

        try:
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=cwd,
            )

            try:
                stdout, stderr = await asyncio.wait_for(
                    process.communicate(),
                    timeout=self.timeout
                )
            except asyncio.TimeoutError:
                process.kill()
                # Wait for the process to fully terminate so pipes are
                # drained and file descriptors are released.
                try:
                    await asyncio.wait_for(process.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    pass
                return f"Error: Command timed out after {self.timeout} seconds"

            output_parts = []

            if stdout:
                output_parts.append(stdout.decode("utf-8", errors="replace"))

            if stderr:
                stderr_text = stderr.decode("utf-8", errors="replace")
                if stderr_text.strip():
                    output_parts.append(f"STDERR:\n{stderr_text}")

            if process.returncode != 0:
                output_parts.append(f"\nExit code: {process.returncode}")

            result = "\n".join(output_parts) if output_parts else "(no output)"

            # Truncate very long output
            max_len = 10000
            if len(result) > max_len:
                result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)"

            return result

        except Exception as e:
            return f"Error executing command: {str(e)}"

    def _guard_command(self, command: str, cwd: str) -> str | None:
        """Best-effort safety guard for potentially destructive commands.

        Returns:
            An "Error: ..." message when the command is rejected, else None.
        """
        cmd = command.strip()
        lower = cmd.lower()

        for pattern in self.deny_patterns:
            if re.search(pattern, lower):
                return "Error: Command blocked by safety guard (dangerous pattern detected)"

        if self.allow_patterns:
            if not any(re.search(p, lower) for p in self.allow_patterns):
                return "Error: Command blocked by safety guard (not in allowlist)"

        if self.restrict_to_workspace:
            if "..\\" in cmd or "../" in cmd:
                return "Error: Command blocked by safety guard (path traversal detected)"

            cwd_path = Path(cwd).resolve()

            win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
            # Only match absolute paths — avoid false positives on relative
            # paths like ".venv/bin/python" where "/bin/python" would be
            # incorrectly extracted by the old pattern.
            posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)

            for raw in win_paths + posix_paths:
                try:
                    p = Path(raw.strip()).resolve()
                except Exception:
                    continue
                if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
                    return "Error: Command blocked by safety guard (path outside working dir)"

        protected_error = self._guard_protected_paths(command, cwd)
        if protected_error:
            return protected_error

        return None

    def _guard_protected_paths(self, command: str, cwd: str) -> str | None:
        """Reject commands that appear to write into a protected path.

        Returns:
            The protected-write error message, or None when the command is
            allowed or no protected paths are configured.
        """
        if not self.protected_paths:
            return None

        cwd_path = Path(cwd).expanduser().resolve()
        if self._is_blocked_clawhub_install(command, cwd_path):
            return self._protected_write_error()

        if not self._looks_like_write(command):
            return None

        for raw in self._extract_path_tokens(command):
            resolved = self._resolve_command_path(raw, cwd_path)
            if resolved and any(self._is_relative_to(resolved, root) for root in self.protected_paths):
                return self._protected_write_error()

        return None

    def _is_blocked_clawhub_install(self, command: str, cwd_path: Path) -> bool:
        """Return True for a ``clawhub install/update`` aimed at a protected path.

        The target directory is the ``--workdir`` value when given, otherwise
        ``cwd_path``.
        """
        lower = command.lower()
        if "clawhub" not in lower or not re.search(r"\b(install|update)\b", lower):
            return False

        workdir = self._extract_flag_value(command, "--workdir")
        if workdir:
            resolved = self._resolve_command_path(workdir, cwd_path)
            if resolved is None:
                # The workdir could not be resolved (empty after unquoting,
                # unknown "~user", ...). Fail closed rather than letting a
                # None reach Path.relative_to(), which raises TypeError.
                return True
            return any(
                resolved == root.parent or self._is_relative_to(root, resolved)
                for root in self.protected_paths
            )

        return any(cwd_path == root.parent for root in self.protected_paths)

    @staticmethod
    def _protected_write_error() -> str:
        """Return the message used when a protected-path write is blocked."""
        return (
            "Error: Direct writes to workspace skills are blocked. "
            "Stage the skill for review and require explicit user approval before installation."
        )

    @staticmethod
    def _is_relative_to(path: Path, root: Path) -> bool:
        """Return True when ``path`` equals ``root`` or lies underneath it."""
        try:
            path.relative_to(root)
            return True
        except ValueError:
            return False

    @staticmethod
    def _extract_flag_value(command: str, flag: str) -> str | None:
        """Return the value of ``flag`` given as ``flag VALUE`` or ``flag=VALUE``."""
        tokens = ExecTool._tokenize(command)
        for i, token in enumerate(tokens):
            if token == flag and i + 1 < len(tokens):
                return tokens[i + 1]
            if token.startswith(flag + "="):
                return token.split("=", 1)[1]
        return None

    @staticmethod
    def _looks_like_write(command: str) -> bool:
        """Heuristically decide whether ``command`` may modify files.

        True for output redirection, ``sed -i``, or any of a fixed list of
        programs (file utilities, archivers, git, interpreters and shells).
        """
        lower = command.lower()
        if re.search(r"(^|[^<])>>?\s*\S+", command):
            return True
        if re.search(r"\bsed\s+-i(?:\s|$)", lower):
            return True
        return bool(re.search(
            r"\b(cp|mv|rm|mkdir|touch|install|tee|tar|unzip|zip|chmod|chown|git|python|python3|node|npx|bash|sh|zsh|pwsh|powershell)\b",
            lower,
        ))

    @staticmethod
    def _extract_path_tokens(command: str) -> list[str]:
        """Collect the command tokens that look like file-system paths.

        Values of ``--workdir`` / ``-C`` (separate token) and ``--workdir=``
        are always included; other ``key=value`` tokens are skipped.
        """
        tokens = ExecTool._tokenize(command)
        path_tokens: list[str] = []
        skip_next = False
        for i, token in enumerate(tokens):
            if skip_next:
                skip_next = False
                continue
            if token in {"--workdir", "-C"}:
                if i + 1 < len(tokens):
                    path_tokens.append(tokens[i + 1])
                    skip_next = True
                continue
            if "=" in token:
                key, value = token.split("=", 1)
                if key in {"--workdir"}:
                    path_tokens.append(value)
                continue
            cleaned = token.strip("\"'")
            if ExecTool._looks_like_path_token(cleaned):
                path_tokens.append(cleaned)
        return path_tokens

    @staticmethod
    def _looks_like_path_token(token: str) -> bool:
        """Return True when ``token`` is shaped like a path.

        Empty tokens, "." and ".." count as paths, as do tokens starting with
        "~", "/", "./", "../" or a Windows drive, and any token containing a
        slash or backslash.
        """
        if not token or token in {".", ".."}:
            return True
        if token.startswith(("~", "/", "./", "../")):
            return True
        if re.match(r"^[A-Za-z]:\\", token):
            return True
        return "/" in token or "\\" in token

    @staticmethod
    def _resolve_command_path(raw: str, cwd_path: Path) -> Path | None:
        """Resolve a command token to an absolute path.

        Relative tokens are resolved against ``cwd_path``.

        Returns:
            The resolved path, or None when the token is empty after removing
            surrounding whitespace/quotes or cannot be resolved.
        """
        token = raw.strip().strip("\"'")
        if not token:
            return None
        try:
            path = Path(token).expanduser()
            if not path.is_absolute():
                path = (cwd_path / path).resolve()
            else:
                path = path.resolve()
            return path
        except Exception:
            return None

    @staticmethod
    def _tokenize(command: str) -> list[str]:
        """Split ``command`` shell-style, falling back to whitespace splitting.

        POSIX quoting rules are used except on Windows (``os.name == "nt"``).
        """
        try:
            return shlex.split(command, posix=os.name != "nt")
        except ValueError:
            # Unbalanced quotes etc.: degrade to a plain split.
            return command.split()
|
||||
105
app-instance/backend/nanobot/agent/tools/spawn.py
Normal file
105
app-instance/backend/nanobot/agent/tools/spawn.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""spawn 工具:用于把任务委派给后台 agent。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.delegation import DelegationManager
|
||||
|
||||
|
||||
class SpawnTool(Tool):
    """Tool that delegates a task to a background agent.

    The request is handed to a ``DelegationManager``. The schema lets the
    model name a single target, a list of targets, or a routing strategy.
    The tool also remembers the originating channel/chat so the manager knows
    where to send the result.
    """

    def __init__(self, manager: "DelegationManager"):
        """Bind the tool to the manager that performs the delegation."""
        self._manager = manager
        # Defaults for the direct CLI case; expected to be overwritten through
        # set_context() before a real turn.
        self._origin_channel, self._origin_chat_id = "cli", "direct"
        self._announce_via_bus = True

    def set_context(self, channel: str, chat_id: str, announce_via_bus: bool = True) -> None:
        """Record the conversation that delegated results are reported back to."""
        self._announce_via_bus = announce_via_bus
        self._origin_chat_id = chat_id
        self._origin_channel = channel

    @property
    def name(self) -> str:
        """Tool name the model uses in function calls."""
        return "spawn"

    @property
    def description(self) -> str:
        """Model-facing description: background execution, report when done."""
        sentences = (
            "Delegate a task to a background agent. ",
            "Use this for complex or time-consuming work that can run independently. ",
            "You can target a specific agent, a group of agents, or let the system choose. ",
            "The delegated agent(s) will report back when done.",
        )
        return "".join(sentences)

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema of the arguments accepted by ``execute``."""
        properties: dict[str, Any] = {
            "task": {
                "type": "string",
                "description": "The task for the delegated agent to complete",
            },
            "label": {
                "type": "string",
                "description": "Optional short label for the task (for display)",
            },
            "target": {
                "type": "string",
                "description": "Optional agent ID or name for a single target",
            },
            "targets": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional list of agent IDs/names for a group task",
            },
            "strategy": {
                "type": "string",
                "enum": ["auto", "local", "plugin", "a2a", "group"],
                "description": "Routing strategy. Default is auto.",
            },
        }
        return {"type": "object", "properties": properties, "required": ["task"]}

    async def execute(
        self,
        task: str,
        label: str | None = None,
        target: str | None = None,
        targets: list[str] | None = None,
        strategy: str = "auto",
        **kwargs: Any,
    ) -> str:
        """Forward a delegation request to the manager.

        The tool does no work itself; it passes the arguments plus the stored
        origin to ``DelegationManager.dispatch`` and returns its result.

        Args:
            task: Task for the delegated agent.
            label: Optional short display label.
            target: Optional single agent id or name.
            targets: Optional list of agent ids/names.
            strategy: Routing strategy, "auto" by default.
            **kwargs: Ignored extra arguments.

        Returns:
            Whatever ``dispatch`` returns.
        """
        request: dict[str, Any] = {
            "task": task,
            "label": label,
            "target": target,
            "targets": targets,
            "strategy": strategy,
            "origin_channel": self._origin_channel,
            "origin_chat_id": self._origin_chat_id,
            "announce_via_bus": self._announce_via_bus,
        }
        return await self._manager.dispatch(**request)
|
||||
163
app-instance/backend/nanobot/agent/tools/web.py
Normal file
163
app-instance/backend/nanobot/agent/tools/web.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""Web tools: web_search and web_fetch."""
|
||||
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Shared constants
|
||||
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_7_2) AppleWebKit/537.36"
|
||||
MAX_REDIRECTS = 5 # Limit redirects to prevent DoS attacks
|
||||
|
||||
|
||||
def _strip_tags(text: str) -> str:
    """Return ``text`` without HTML markup and with entities decoded.

    ``<script>`` and ``<style>`` elements are dropped together with their
    content, every remaining tag is removed, HTML entities are unescaped and
    surrounding whitespace is trimmed.
    """
    for block_pattern in (r'<script[\s\S]*?</script>', r'<style[\s\S]*?</style>'):
        text = re.sub(block_pattern, '', text, flags=re.I)
    tagless = re.sub(r'<[^>]+>', '', text)
    decoded = html.unescape(tagless)
    return decoded.strip()
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
    """Collapse runs of spaces/tabs to one space and 3+ newlines to two, then trim."""
    single_spaced = re.sub(r'[ \t]+', ' ', text)
    compact = re.sub(r'\n{3,}', '\n\n', single_spaced)
    return compact.strip()
|
||||
|
||||
|
||||
def _validate_url(url: str) -> tuple[bool, str]:
    """Check that ``url`` is an http(s) URL that names a host.

    Args:
        url: The URL to check.

    Returns:
        ``(True, "")`` when the URL is acceptable, otherwise ``(False, reason)``.
    """
    try:
        parsed = urlparse(url)
        if parsed.scheme not in ('http', 'https'):
            return False, f"Only http/https allowed, got '{parsed.scheme or 'none'}'"
        # Check the hostname rather than the whole netloc: a netloc such as
        # ":8080" or "user@" is non-empty but contains no host to connect to.
        if not parsed.hostname:
            return False, "Missing domain"
        return True, ""
    except Exception as e:
        # urlparse raises ValueError for malformed input such as an
        # unterminated IPv6 literal; report it as a validation failure.
        return False, str(e)
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
    """Web search tool backed by the Brave Search API."""

    name = "web_search"
    description = "Search the web. Returns titles, URLs, and snippets."
    parameters = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query"},
            "count": {"type": "integer", "description": "Results (1-10)", "minimum": 1, "maximum": 10}
        },
        "required": ["query"]
    }

    def __init__(self, api_key: str | None = None, max_results: int = 5):
        """Store the API key (argument first, then ``BRAVE_API_KEY``) and default result count."""
        self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
        self.max_results = max_results

    async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
        """Search for ``query`` and return a numbered text list of results.

        Args:
            query: Search query.
            count: Desired number of results; a falsy value uses
                ``max_results``. The value is clamped to the range 1-10.
            **kwargs: Ignored extra arguments.

        Returns:
            The formatted results, a "No results for: ..." line, or a string
            starting with "Error" when the key is missing or the request fails.
        """
        if not self.api_key:
            return "Error: BRAVE_API_KEY not configured"

        try:
            wanted = count or self.max_results
            limit = min(max(wanted, 1), 10)
            request_headers = {"Accept": "application/json", "X-Subscription-Token": self.api_key}
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.search.brave.com/res/v1/web/search",
                    params={"q": query, "count": limit},
                    headers=request_headers,
                    timeout=10.0
                )
                response.raise_for_status()

            hits = response.json().get("web", {}).get("results", [])
            if not hits:
                return f"No results for: {query}"

            lines = [f"Results for: {query}\n"]
            for rank, hit in enumerate(hits[:limit], 1):
                lines.append(f"{rank}. {hit.get('title', '')}\n {hit.get('url', '')}")
                snippet = hit.get("description")
                if snippet:
                    lines.append(f" {snippet}")
            return "\n".join(lines)
        except Exception as e:
            return f"Error: {e}"
|
||||
|
||||
|
||||
class WebFetchTool(Tool):
    """Tool that downloads a URL and extracts readable content from it."""

    name = "web_fetch"
    description = "Fetch URL and extract readable content (HTML → markdown/text)."
    parameters = {
        "type": "object",
        "properties": {
            "url": {"type": "string", "description": "URL to fetch"},
            "extractMode": {"type": "string", "enum": ["markdown", "text"], "default": "markdown"},
            "maxChars": {"type": "integer", "minimum": 100}
        },
        "required": ["url"]
    }

    def __init__(self, max_chars: int = 50000):
        """Store the default limit on the number of returned characters."""
        self.max_chars = max_chars

    async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
        """Fetch ``url`` and return a JSON document describing its content.

        JSON responses are pretty-printed, HTML is reduced to its main content
        with Readability (as markdown or plain text), and anything else is
        returned as raw text.

        Args:
            url: URL to fetch; must pass ``_validate_url``.
            extractMode: "markdown" converts the HTML summary to markdown; any
                other value strips the tags.
            maxChars: Character limit; a falsy value uses ``max_chars``.
            **kwargs: Ignored extra arguments.

        Returns:
            A JSON string with ``url``, ``finalUrl``, ``status``, ``extractor``,
            ``truncated``, ``length`` and ``text``; or with ``error`` and
            ``url`` when validation or the request fails.
        """
        from readability import Document

        limit = maxChars or self.max_chars

        # Reject bad URLs before touching the network.
        ok, reason = _validate_url(url)
        if not ok:
            return json.dumps({"error": f"URL validation failed: {reason}", "url": url}, ensure_ascii=False)

        try:
            async with httpx.AsyncClient(
                follow_redirects=True,
                max_redirects=MAX_REDIRECTS,
                timeout=30.0
            ) as client:
                r = await client.get(url, headers={"User-Agent": USER_AGENT})
                r.raise_for_status()

            content_type = r.headers.get("content-type", "")

            if "application/json" in content_type:
                extractor = "json"
                text = json.dumps(r.json(), indent=2, ensure_ascii=False)
            elif "text/html" in content_type or r.text[:256].lower().startswith(("<!doctype", "<html")):
                # Also treat bodies that merely start like HTML as HTML.
                doc = Document(r.text)
                if extractMode == "markdown":
                    content = self._to_markdown(doc.summary())
                else:
                    content = _strip_tags(doc.summary())
                if doc.title():
                    text = f"# {doc.title()}\n\n{content}"
                else:
                    text = content
                extractor = "readability"
            else:
                extractor = "raw"
                text = r.text

            truncated = len(text) > limit
            if truncated:
                text = text[:limit]

            payload = {
                "url": url,
                "finalUrl": str(r.url),
                "status": r.status_code,
                "extractor": extractor,
                "truncated": truncated,
                "length": len(text),
                "text": text,
            }
            return json.dumps(payload, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)

    def _to_markdown(self, html: str) -> str:
        """Convert an HTML fragment to rough markdown.

        Links, headings and list items are rewritten first, block-closing tags
        and ``<br>``/``<hr>`` become line breaks, then remaining tags are
        stripped and whitespace normalized.
        """
        # Structure-carrying elements are converted before tags are stripped.
        converted = re.sub(
            r'<a\s+[^>]*href=["\']([^"\']+)["\'][^>]*>([\s\S]*?)</a>',
            lambda m: f'[{_strip_tags(m[2])}]({m[1]})',
            html,
            flags=re.I,
        )
        converted = re.sub(
            r'<h([1-6])[^>]*>([\s\S]*?)</h\1>',
            lambda m: f'\n{"#" * int(m[1])} {_strip_tags(m[2])}\n',
            converted,
            flags=re.I,
        )
        converted = re.sub(
            r'<li[^>]*>([\s\S]*?)</li>',
            lambda m: f'\n- {_strip_tags(m[1])}',
            converted,
            flags=re.I,
        )
        converted = re.sub(r'</(p|div|section|article)>', '\n\n', converted, flags=re.I)
        converted = re.sub(r'<(br|hr)\s*/?>', '\n', converted, flags=re.I)
        stripped = _strip_tags(converted)
        return _normalize(stripped)
|
||||
Reference in New Issue
Block a user