97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
"""工具注册中心。
|
||
|
||
职责很单一:
|
||
1. 保存当前可用工具实例;
|
||
2. 向 LLM 暴露 function schema;
|
||
3. 在执行前做基础参数校验,并把异常统一转成文本结果。
|
||
"""
|
||
|
||
from typing import Any
|
||
|
||
from nanobot.agent.tools.base import Tool
|
||
|
||
|
||
class ToolRegistry:
|
||
"""
|
||
Registry for agent tools.
|
||
|
||
Allows dynamic registration and execution of tools.
|
||
"""
|
||
|
||
def __init__(self):
|
||
# 工具名到实例的映射表;工具名在整个 registry 内必须唯一。
|
||
self._tools: dict[str, Tool] = {}
|
||
|
||
def register(self, tool: Tool) -> None:
|
||
"""注册一个工具实例。"""
|
||
self._tools[tool.name] = tool
|
||
|
||
def clone(self) -> "ToolRegistry":
|
||
"""创建一个浅拷贝,复用同一批工具实例。"""
|
||
# 这里不深拷贝工具对象,因为很多工具本身持有运行时状态或外部连接。
|
||
# 当前需求只是“在一个请求里临时附加额外工具”,复用实例即可。
|
||
other = ToolRegistry()
|
||
other._tools = dict(self._tools)
|
||
return other
|
||
|
||
def unregister(self, name: str) -> None:
|
||
"""Unregister a tool by name."""
|
||
self._tools.pop(name, None)
|
||
|
||
def get(self, name: str) -> Tool | None:
|
||
"""Get a tool by name."""
|
||
return self._tools.get(name)
|
||
|
||
def has(self, name: str) -> bool:
|
||
"""Check if a tool is registered."""
|
||
return name in self._tools
|
||
|
||
def get_definitions(self) -> list[dict[str, Any]]:
|
||
"""Get all tool definitions in OpenAI format."""
|
||
return [tool.to_schema() for tool in self._tools.values()]
|
||
|
||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||
"""
|
||
Execute a tool by name with given parameters.
|
||
|
||
Args:
|
||
name: Tool name.
|
||
params: Tool parameters.
|
||
|
||
Returns:
|
||
Tool execution result as string.
|
||
|
||
Raises:
|
||
KeyError: If tool not found.
|
||
"""
|
||
_hint = "\n\n[Analyze the error above and try a different approach.]"
|
||
|
||
tool = self._tools.get(name)
|
||
if not tool:
|
||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||
|
||
try:
|
||
# schema 级参数校验放在真正调用前做,尽量把错误反馈成模型能自修复的文本。
|
||
errors = tool.validate_params(params)
|
||
if errors:
|
||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _hint
|
||
result = await tool.execute(**params)
|
||
# 约定:工具若返回以 Error 开头的文本,说明是业务失败而非程序崩溃。
|
||
if isinstance(result, str) and result.startswith("Error"):
|
||
return result + _hint
|
||
return result
|
||
except Exception as e:
|
||
# 保持“不抛异常到模型层”的接口语义,统一回成可读文本。
|
||
return f"Error executing {name}: {str(e)}" + _hint
|
||
|
||
@property
|
||
def tool_names(self) -> list[str]:
|
||
"""Get list of registered tool names."""
|
||
return list(self._tools.keys())
|
||
|
||
def __len__(self) -> int:
|
||
return len(self._tools)
|
||
|
||
def __contains__(self, name: str) -> bool:
|
||
return name in self._tools
|