Files
beaver_project/app-instance/backend-old/nanobot/agent/tools/mcp.py

383 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""MCP 客户端封装。
职责分两层:
1. `connect_mcp_servers()` 负责建立与 MCP server 的连接,并把远端工具注册成 nanobot 本地工具;
2. `MCPToolWrapper` 负责把单个远端 MCP tool 包装成可供 LLM 调用的 `Tool`,同时发出结构化过程事件。
"""
import asyncio
import json
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any
import httpx
from loguru import logger
from nanobot.agent.process_events import current_process_run_id, emit_process_event, new_run_id
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
def _iter_leaf_exceptions(exc: BaseException) -> list[BaseException]:
if isinstance(exc, BaseExceptionGroup):
leaves: list[BaseException] = []
for sub_exc in exc.exceptions:
leaves.extend(_iter_leaf_exceptions(sub_exc))
return leaves
return [exc]
def _describe_mcp_exception(exc: BaseException, *, server_name: str, url: str | None = None) -> str:
    """Build a short, human-readable message for an MCP connection/call failure.

    The leaf exceptions of ``exc`` (see ``_iter_leaf_exceptions``) are scanned
    in order. The first leaf that is an httpx error decides the message:
    timeout, unreachable, HTTP status, or generic HTTP error. If no leaf is an
    httpx error, the text of the first leaf is used. That text is prefixed with
    the server name and exception class only when ``exc`` itself was an
    exception group.

    Args:
        exc: The exception (or exception group) raised while talking to the server.
        server_name: Configured server name, embedded in the message.
        url: Optional server URL appended in parentheses to httpx-based messages.

    Returns:
        The message text.
    """
    target = f" ({url})" if url else ""

    def _http_message(err: BaseException) -> str | None:
        # Order matters: TimeoutException, ConnectError and HTTPStatusError are
        # all subclasses of httpx.HTTPError, so the generic case must come last.
        if isinstance(err, httpx.TimeoutException):
            return f"MCP server '{server_name}' timed out while waiting for a response{target}"
        if isinstance(err, httpx.ConnectError):
            return f"MCP server '{server_name}' is unreachable{target}"
        if isinstance(err, httpx.HTTPStatusError):
            return f"MCP server '{server_name}' returned HTTP {err.response.status_code}{target}"
        if isinstance(err, httpx.HTTPError):
            detail = str(err).strip() or err.__class__.__name__
            return f"MCP server '{server_name}' HTTP error{target}: {detail}"
        return None

    leaves = _iter_leaf_exceptions(exc)
    for candidate in leaves:
        message = _http_message(candidate)
        if message is not None:
            return message

    detail_source = leaves[0] if leaves else exc
    detail = str(detail_source).strip() or detail_source.__class__.__name__
    if not isinstance(exc, BaseExceptionGroup):
        return detail
    return f"MCP server '{server_name}' failed: {detail_source.__class__.__name__}: {detail}"
class MCPToolWrapper(Tool):
    """Wrap a single remote MCP server tool as a nanobot ``Tool``.

    The exported tool name is ``mcp_<server>_<tool>``, so identically named
    tools from different servers do not clash. Every ``execute`` call is
    reported as its own process run through ``emit_process_event``. When
    ``sensitive`` is true, arguments and results are left out of those events.
    """

    def __init__(
        self,
        session,
        server_name: str,
        tool_def,
        *,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
        tool_timeout: int = 30,
        sensitive: bool = False,
    ):
        """Create the wrapper.

        Args:
            session: MCP client session used by the default call path. May be
                ``None`` when ``call_tool`` is supplied.
            server_name: Configured MCP server name. It is used in the exported
                tool name, in logs and in process events.
            tool_def: Remote tool definition. ``name``, ``description`` and
                ``inputSchema`` are read from it.
            call_tool: Optional coroutine function ``(tool_name, arguments)``
                used instead of ``session.call_tool``.
            tool_timeout: Seconds to wait for one call before reporting a timeout.
            sensitive: When true, arguments and results are redacted from events.
        """
        self._session = session
        self._call_tool = call_tool or self._default_call_tool
        # Keep the source server name so logs, events and the exported tool name stay traceable.
        self._server_name = server_name
        self._original_name = tool_def.name
        # Prefix with `mcp_<server>_` to avoid name collisions between servers.
        self._name = f"mcp_{server_name}_{tool_def.name}"
        self._description = tool_def.description or tool_def.name
        self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
        self._tool_timeout = tool_timeout
        self._sensitive = sensitive

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    def _event_metadata(self) -> dict[str, Any]:
        """Return a fresh metadata dict attached to the run's follow-up events."""
        return {"tool_name": self._original_name, "sensitive": self._sensitive}

    async def _emit_failure(self, run_id: str, summary: str) -> None:
        """Close a failed run by emitting an error status and then an error finish event."""
        await emit_process_event(
            "process_run_status",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="error",
            text=summary,
            metadata=self._event_metadata(),
        )
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="error",
            summary=summary,
            metadata=self._event_metadata(),
        )

    async def execute(self, **kwargs: Any) -> str:
        """Call the remote tool with ``kwargs`` and return its output as text.

        Returns:
            The result's content blocks joined by newlines. Text blocks
            contribute their ``text``; other blocks are converted with ``str``.
            Returns ``"(no output)"`` when that is empty, or a readable
            message if the call exceeded ``tool_timeout``.

        Raises:
            Exception: Any non-timeout error from the call is re-raised after
                the process run has been closed with an error status.
        """
        from mcp import types

        # Every MCP call gets its own run_id so the frontend can show it as a nested step.
        run_id = new_run_id("mcp")
        args_json = json.dumps(kwargs, ensure_ascii=False) if kwargs else "{}"
        await emit_process_event(
            "process_run_started",
            run_id=run_id,
            parent_run_id=current_process_run_id(),
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            title=f"{self._server_name}.{self._original_name}",
            status="running",
            metadata={
                "tool_name": self._original_name,
                "tool_args": None if self._sensitive else kwargs,
                "tool_timeout": self._tool_timeout,
                "sensitive": self._sensitive,
            },
        )
        # Emit progress before the remote request so the UI can show which tool is being called.
        await emit_process_event(
            "process_run_progress",
            run_id=run_id,
            parent_run_id=current_process_run_id(),
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            text=(
                f"Calling {self._original_name}"
                if self._sensitive
                else f"Calling {self._original_name} with {args_json}"
            ),
            metadata=self._event_metadata(),
        )
        try:
            result = await asyncio.wait_for(
                self._call_tool(self._original_name, kwargs),
                timeout=self._tool_timeout,
            )
        except asyncio.TimeoutError:
            # A timeout is a business failure. It is returned as readable text
            # rather than raised into the agent loop.
            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
            summary = f"(MCP tool call timed out after {self._tool_timeout}s)"
            await self._emit_failure(run_id, summary)
            return summary
        except Exception as exc:
            # Other failures still propagate to the caller as before. Close the
            # run first so the UI does not show it as "running" forever.
            detail = _describe_mcp_exception(exc, server_name=self._server_name)
            logger.warning("MCP tool '{}' failed: {}", self._name, detail)
            summary = (
                f"{self._original_name} failed"
                if self._sensitive
                else f"(MCP tool call failed: {detail})"
            )
            await self._emit_failure(run_id, summary)
            raise

        # The MCP SDK returns a list of structured content blocks; flatten them to text.
        parts = [
            block.text if isinstance(block, types.TextContent) else str(block)
            for block in result.content
        ]
        output = "\n".join(parts) or "(no output)"

        artifact_type = "text"
        artifact_data: Any | None = None
        stripped = output.strip()
        # If the output looks like JSON, also parse it into a structured artifact for richer UI display.
        if stripped.startswith(("{", "[")):
            try:
                artifact_data = json.loads(stripped)
                artifact_type = "json"
            except json.JSONDecodeError:
                artifact_data = None
        await emit_process_event(
            "process_run_artifact",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            title=f"{self._server_name}.{self._original_name} result",
            artifact_type="redacted" if self._sensitive else artifact_type,
            content=None if self._sensitive or artifact_data is not None else output,
            data=None if self._sensitive else artifact_data,
            metadata=self._event_metadata(),
        )
        # MCP reports tool-execution failures in-band via `isError`. Surface that
        # in the run status instead of always reporting success.
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            actor_type="mcp",
            actor_id=self._server_name,
            actor_name=self._server_name,
            status="error" if bool(getattr(result, "isError", False)) else "done",
            summary=(
                f"{self._original_name} completed"
                if self._sensitive
                else output[:1000]
            ),
            metadata=self._event_metadata(),
        )
        return output

    async def _default_call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call ``tool_name`` on the session supplied at construction time."""
        return await self._session.call_tool(tool_name, arguments=arguments)
async def connect_mcp_servers(
    mcp_servers: dict,
    registry: ToolRegistry,
    stack: AsyncExitStack,
    *,
    authz_config: Any | None = None,
    backend_identity: Any | None = None,
) -> dict[str, dict[str, Any]]:
    """Connect every configured MCP server and register its tools in ``registry``.

    Each entry of ``mcp_servers`` is handled in one of three ways:

    * ``cfg.command`` set: launch a stdio MCP server and keep its session open
      on ``stack``.
    * ``cfg.url`` set and ``auth_mode == "oauth_backend_token"``: list the tools
      once, then open a new session (with a freshly issued AuthZ token) for
      every tool call.
    * ``cfg.url`` set otherwise: open one streamable-HTTP session and keep it
      open on ``stack``.

    Entries with neither ``command`` nor ``url`` are skipped with a warning. A
    failure on one server is recorded in the report and does not stop the
    others. Transports already opened for the failing server are closed right
    away instead of staying on ``stack`` until shutdown.

    Args:
        mcp_servers: Mapping of server name to server config object.
        registry: Receives one ``MCPToolWrapper`` per remote tool.
        stack: Owns the long-lived sessions; closing it disconnects them.
        authz_config: AuthZ settings (``base_url``, ``request_timeout_seconds``).
            Only used by ``oauth_backend_token`` servers.
        backend_identity: Backend client credentials (``client_id``,
            ``client_secret``). Only used by ``oauth_backend_token`` servers.

    Returns:
        A per-server report for the Web UI with ``status``, ``last_error``,
        ``tool_names``, ``tool_count`` and ``transport`` keys.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client
    from nanobot.authz.client import AuthzClient

    def _text_attr(obj: Any, attr: str) -> str:
        """Return ``obj.<attr>`` as a stripped string; a missing or ``None`` value becomes ``""``."""
        return str(getattr(obj, attr, "") or "").strip()

    async def _build_http_headers(server_name: str, cfg: Any) -> dict[str, str]:
        """Return the HTTP headers for ``cfg``.

        For ``oauth_backend_token`` servers, a Bearer token issued by AuthZ is added.

        Raises:
            RuntimeError: If the AuthZ/backend identity settings are incomplete,
                or if AuthZ returned no access token.
        """
        headers = dict(getattr(cfg, "headers", {}) or {})
        if getattr(cfg, "auth_mode", "none") != "oauth_backend_token":
            return headers
        # `_text_attr` tolerates None / non-str values. A bare `.strip()` would
        # raise AttributeError here and hide the real configuration problem.
        if not (
            authz_config
            and _text_attr(authz_config, "base_url")
            and backend_identity
            and _text_attr(backend_identity, "client_id")
            and _text_attr(backend_identity, "client_secret")
        ):
            raise RuntimeError(
                f"MCP server '{server_name}' requires AuthZ backend token, but authz/backend identity is incomplete"
            )
        raw_timeout = getattr(authz_config, "request_timeout_seconds", None)
        authz_client = AuthzClient(
            getattr(authz_config, "base_url"),
            # A missing or None timeout falls back to 10s instead of failing in int(None).
            timeout_seconds=int(raw_timeout) if raw_timeout is not None else 10,
        )
        raw_audience = str(getattr(cfg, "auth_audience", "") or "").strip()
        # Older managed Outlook configs stored `auth_audience="mcp"`, but AuthZ
        # permissions are issued against `mcp:<server_id>`.
        if not raw_audience or raw_audience == "mcp":
            audience = f"mcp:{server_name}"
        elif raw_audience.startswith("mcp:"):
            audience = raw_audience
        else:
            audience = f"mcp:{raw_audience}"
        token_response = await authz_client.issue_token(
            client_id=getattr(backend_identity, "client_id"),
            client_secret=getattr(backend_identity, "client_secret"),
            audience=audience,
            scopes=[str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
        )
        access_token = str(token_response.get("access_token") or "").strip()
        if not access_token:
            raise RuntimeError(f"MCP server '{server_name}' did not receive an access token from AuthZ")
        headers["Authorization"] = f"Bearer {access_token}"
        return headers

    async def _open_http_session(
        session_stack: AsyncExitStack,
        cfg: Any,
        *,
        headers: dict[str, str],
    ):
        """Open and initialize a streamable-HTTP MCP session whose resources live on ``session_stack``."""
        http_client = await session_stack.enter_async_context(
            httpx.AsyncClient(
                headers=headers or None,
                follow_redirects=True,
                trust_env=False,
            )
        )
        read, write, _ = await session_stack.enter_async_context(
            streamable_http_client(cfg.url, http_client=http_client)
        )
        session = await session_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        return session

    async def _open_stdio_session(session_stack: AsyncExitStack, cfg: Any):
        """Launch ``cfg.command`` as a stdio MCP server and initialize a session on ``session_stack``."""
        params = StdioServerParameters(
            command=cfg.command, args=cfg.args, env=cfg.env or None
        )
        read, write = await session_stack.enter_async_context(stdio_client(params))
        session = await session_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        return session

    async def _connect_persistent(server_name: str, cfg: Any) -> tuple[Any, list[Any]]:
        """Open a long-lived stdio/HTTP session, list its tools, and hand it to ``stack``.

        Everything is opened on a private stack first. If any step fails,
        leaving the ``async with`` closes what was opened so far, including a
        spawned subprocess or an HTTP stream. Exiting the transports this way
        also lets their task groups surface the underlying error. Only on
        success is ownership moved to the shared ``stack`` via ``pop_all()``.
        """
        async with AsyncExitStack() as server_stack:
            if cfg.command:
                session = await _open_stdio_session(server_stack, cfg)
            else:
                headers = await _build_http_headers(server_name, cfg)
                session = await _open_http_session(server_stack, cfg, headers=headers)
            tools = await session.list_tools()
            owned = server_stack.pop_all()
        await stack.enter_async_context(owned)
        return session, tools.tools

    async def _list_http_tools(server_name: str, cfg: Any):
        """List the tools of an HTTP server using a short-lived session."""
        async with AsyncExitStack() as session_stack:
            headers = await _build_http_headers(server_name, cfg)
            session = await _open_http_session(session_stack, cfg, headers=headers)
            tools = await session.list_tools()
            return tools.tools

    def _make_http_call_tool(server_name: str, cfg: Any) -> Callable[[str, dict[str, Any]], Awaitable[Any]]:
        """Return a call function that opens a freshly authenticated session for every call."""
        async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
            async with AsyncExitStack() as session_stack:
                headers = await _build_http_headers(server_name, cfg)
                session = await _open_http_session(session_stack, cfg, headers=headers)
                return await session.call_tool(tool_name, arguments=arguments)
        return _call_tool

    def _register_tools(
        server_name: str,
        cfg: Any,
        tool_defs: list[Any],
        tool_names: list[str],
        *,
        session: Any = None,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
    ) -> None:
        """Wrap each tool definition, register it, and append its exported name to ``tool_names``."""
        for tool_def in tool_defs:
            wrapper = MCPToolWrapper(
                session,
                server_name,
                tool_def,
                call_tool=call_tool,
                tool_timeout=cfg.tool_timeout,
                sensitive=bool(getattr(cfg, "sensitive", False)),
            )
            registry.register(wrapper)
            logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, server_name)
            tool_names.append(wrapper.name)

    # `report` is returned to the caller so the Web UI can show connection status and discovered tools.
    report: dict[str, dict[str, Any]] = {}
    for name, cfg in mcp_servers.items():
        entry: dict[str, Any] = {
            "status": "disconnected",
            "last_error": None,
            "tool_names": [],
            "tool_count": 0,
            "transport": "stdio" if getattr(cfg, "command", "") else "http",
        }
        report[name] = entry
        try:
            if cfg.command:
                # stdio mode: spawn a local subprocess and talk MCP over stdin/stdout.
                session, tool_defs = await _connect_persistent(name, cfg)
                _register_tools(name, cfg, tool_defs, entry["tool_names"], session=session)
            elif cfg.url:
                if getattr(cfg, "auth_mode", "none") == "oauth_backend_token":
                    # Tokens are issued per call, so no session is kept open here.
                    tool_defs = await _list_http_tools(name, cfg)
                    _register_tools(
                        name,
                        cfg,
                        tool_defs,
                        entry["tool_names"],
                        call_tool=_make_http_call_tool(name, cfg),
                    )
                else:
                    session, tool_defs = await _connect_persistent(name, cfg)
                    _register_tools(name, cfg, tool_defs, entry["tool_names"], session=session)
            else:
                # An entry with neither command nor url is invalid config: skip it without raising.
                logger.warning("MCP server '{}': no command or url configured, skipping", name)
                continue
            entry["tool_count"] = len(entry["tool_names"])
            entry["status"] = "connected"
            logger.info(
                "MCP server '{}': connected, {} tools registered",
                name,
                len(entry["tool_names"]),
            )
        except Exception as e:
            # One failing server must not prevent the others from connecting;
            # the error goes into the report for the UI.
            error_detail = _describe_mcp_exception(
                e,
                server_name=name,
                url=str(getattr(cfg, "url", "") or "").strip() or None,
            )
            entry["status"] = "error"
            entry["last_error"] = error_detail
            logger.error("MCP server '{}': failed to connect: {}", name, error_detail)
    return report