383 lines
16 KiB
Python
383 lines
16 KiB
Python
"""MCP 客户端封装。
|
||
|
||
职责分两层:
|
||
1. `connect_mcp_servers()` 负责建立与 MCP server 的连接,并把远端工具注册成 nanobot 本地工具;
|
||
2. `MCPToolWrapper` 负责把单个远端 MCP tool 包装成可供 LLM 调用的 `Tool`,同时发出结构化过程事件。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
from collections.abc import Awaitable, Callable
|
||
from contextlib import AsyncExitStack
|
||
from typing import Any
|
||
|
||
import httpx
|
||
from loguru import logger
|
||
|
||
from nanobot.agent.process_events import current_process_run_id, emit_process_event, new_run_id
|
||
from nanobot.agent.tools.base import Tool
|
||
from nanobot.agent.tools.registry import ToolRegistry
|
||
|
||
|
||
def _iter_leaf_exceptions(exc: BaseException) -> list[BaseException]:
|
||
if isinstance(exc, BaseExceptionGroup):
|
||
leaves: list[BaseException] = []
|
||
for sub_exc in exc.exceptions:
|
||
leaves.extend(_iter_leaf_exceptions(sub_exc))
|
||
return leaves
|
||
return [exc]
|
||
|
||
|
||
def _describe_mcp_exception(exc: BaseException, *, server_name: str, url: str | None = None) -> str:
|
||
leaves = _iter_leaf_exceptions(exc)
|
||
target = f" ({url})" if url else ""
|
||
|
||
for leaf in leaves:
|
||
if isinstance(leaf, httpx.TimeoutException):
|
||
return f"MCP server '{server_name}' timed out while waiting for a response{target}"
|
||
if isinstance(leaf, httpx.ConnectError):
|
||
return f"MCP server '{server_name}' is unreachable{target}"
|
||
if isinstance(leaf, httpx.HTTPStatusError):
|
||
return f"MCP server '{server_name}' returned HTTP {leaf.response.status_code}{target}"
|
||
if isinstance(leaf, httpx.HTTPError):
|
||
detail = str(leaf).strip() or leaf.__class__.__name__
|
||
return f"MCP server '{server_name}' HTTP error{target}: {detail}"
|
||
|
||
detail_source = leaves[0] if leaves else exc
|
||
detail = str(detail_source).strip() or detail_source.__class__.__name__
|
||
if isinstance(exc, BaseExceptionGroup):
|
||
return f"MCP server '{server_name}' failed: {detail_source.__class__.__name__}: {detail}"
|
||
return detail
|
||
|
||
|
||
class MCPToolWrapper(Tool):
    """Expose a single remote MCP server tool as a nanobot ``Tool``.

    The exported tool name is prefixed with ``mcp_<server>_`` so tools from
    different servers cannot collide. Every call emits structured process
    events (started / progress / artifact / finished) so a UI can render the
    call as a nested step. When ``sensitive`` is set, tool arguments, results
    and error details are withheld from those events.
    """

    def __init__(
        self,
        session,
        server_name: str,
        tool_def,
        *,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
        tool_timeout: int = 30,
        sensitive: bool = False,
    ):
        """Build the wrapper.

        Args:
            session: MCP client session used by the default call path; may be
                ``None`` when ``call_tool`` is supplied.
            server_name: Configured server name; used in the exported tool name,
                in log messages and as the event actor.
            tool_def: Tool definition from ``list_tools()``; its ``name``,
                ``description`` and ``inputSchema`` attributes are read.
            call_tool: Optional coroutine function ``(tool_name, arguments)``
                performing the remote call; defaults to ``session.call_tool``.
            tool_timeout: Seconds to wait for one call before giving up.
            sensitive: Withhold arguments, results and error details from
                process events.
        """
        self._session = session
        self._call_tool = call_tool or self._default_call_tool
        # Keep the source server name so logs, events and the exported tool name stay traceable.
        self._server_name = server_name
        self._original_name = tool_def.name
        # Prefix with `mcp_<server>_` so same-named tools from different servers do not collide.
        self._name = f"mcp_{server_name}_{tool_def.name}"
        self._description = tool_def.description or tool_def.name
        self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
        self._tool_timeout = tool_timeout
        self._sensitive = sensitive

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    def _actor_fields(self) -> dict[str, Any]:
        """Return the actor identification shared by every process event of this tool."""
        return {
            "actor_type": "mcp",
            "actor_id": self._server_name,
            "actor_name": self._server_name,
        }

    def _event_metadata(self) -> dict[str, Any]:
        """Return the metadata attached to every event after ``process_run_started``."""
        return {"tool_name": self._original_name, "sensitive": self._sensitive}

    async def _emit_failure(self, run_id: str, summary: str) -> None:
        """Close a failed run: an ``error`` status event, then ``process_run_finished``."""
        await emit_process_event(
            "process_run_status",
            run_id=run_id,
            **self._actor_fields(),
            status="error",
            text=summary,
            metadata=self._event_metadata(),
        )
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            **self._actor_fields(),
            status="error",
            summary=summary,
            metadata=self._event_metadata(),
        )

    async def execute(self, **kwargs: Any) -> str:
        """Call the remote tool and return its output flattened to text.

        A timeout is not raised. It is reported through process events and
        returned as a readable error string. Any other exception from the call
        is reported the same way (so the run is always closed) and then
        re-raised unchanged.

        A result whose ``isError`` flag is set is still returned as text, but
        its run finishes with status ``error``.
        """
        from mcp import types

        # Each MCP call gets its own run_id so the UI can show it as a child step.
        run_id = new_run_id("mcp")
        parent_run_id = current_process_run_id()
        args_json = json.dumps(kwargs, ensure_ascii=False) if kwargs else "{}"

        await emit_process_event(
            "process_run_started",
            run_id=run_id,
            parent_run_id=parent_run_id,
            **self._actor_fields(),
            title=f"{self._server_name}.{self._original_name}",
            status="running",
            metadata={
                "tool_name": self._original_name,
                "tool_args": None if self._sensitive else kwargs,
                "tool_timeout": self._tool_timeout,
                "sensitive": self._sensitive,
            },
        )
        # Emit progress before the remote request so the UI shows which tool is being called right away.
        await emit_process_event(
            "process_run_progress",
            run_id=run_id,
            parent_run_id=parent_run_id,
            **self._actor_fields(),
            text=(
                f"Calling {self._original_name}"
                if self._sensitive
                else f"Calling {self._original_name} with {args_json}"
            ),
            metadata=self._event_metadata(),
        )

        try:
            result = await asyncio.wait_for(
                self._call_tool(self._original_name, kwargs),
                timeout=self._tool_timeout,
            )
        except asyncio.TimeoutError:
            # A timeout counts as a tool failure, but it is returned as readable
            # text instead of being raised to the agent loop.
            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
            summary = f"(MCP tool call timed out after {self._tool_timeout}s)"
            await self._emit_failure(run_id, summary)
            return summary
        except Exception as exc:
            # Other failures still propagate to the caller. Close the run first,
            # otherwise its process-tree node would stay "running" forever.
            detail = (
                type(exc).__name__
                if self._sensitive
                else (str(exc).strip() or type(exc).__name__)
            )
            logger.warning("MCP tool '{}' failed: {}", self._name, detail)
            await self._emit_failure(run_id, f"(MCP tool call failed: {detail})")
            raise

        # The MCP SDK returns a list of structured content blocks; flatten them to text.
        parts = [
            block.text if isinstance(block, types.TextContent) else str(block)
            for block in result.content
        ]
        output = "\n".join(parts) or "(no output)"
        # MCP reports tool-level failures through `isError` on the result instead of raising.
        failed = bool(getattr(result, "isError", False))

        artifact_type = "text"
        artifact_data: Any | None = None
        stripped = output.strip()
        # If the output looks like JSON, also parse it into a structured artifact for richer UI rendering.
        if stripped.startswith(("{", "[")):
            try:
                artifact_data = json.loads(stripped)
                artifact_type = "json"
            except json.JSONDecodeError:
                artifact_data = None

        await emit_process_event(
            "process_run_artifact",
            run_id=run_id,
            **self._actor_fields(),
            title=f"{self._server_name}.{self._original_name} result",
            artifact_type="redacted" if self._sensitive else artifact_type,
            content=None if self._sensitive or artifact_data is not None else output,
            data=None if self._sensitive else artifact_data,
            metadata=self._event_metadata(),
        )

        if self._sensitive:
            summary = f"{self._original_name} {'failed' if failed else 'completed'}"
        else:
            summary = output[:1000]
        await emit_process_event(
            "process_run_finished",
            run_id=run_id,
            **self._actor_fields(),
            status="error" if failed else "done",
            summary=summary,
            metadata=self._event_metadata(),
        )
        return output

    async def _default_call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call ``tool_name`` on the long-lived session passed to ``__init__``."""
        return await self._session.call_tool(tool_name, arguments=arguments)
|
||
|
||
|
||
async def connect_mcp_servers(
    mcp_servers: dict,
    registry: ToolRegistry,
    stack: AsyncExitStack,
    *,
    authz_config: Any | None = None,
    backend_identity: Any | None = None,
) -> dict[str, dict[str, Any]]:
    """Connect every configured MCP server and register its tools in ``registry``.

    Each server is handled independently. A failure is logged and recorded in
    the returned report, and the remaining servers are still processed.

    Transports:

    * ``cfg.command`` set: spawn a stdio subprocess and keep one session open.
    * ``cfg.url`` with ``auth_mode == "oauth_backend_token"``: list the tools
      once, then open a fresh HTTP session, with a freshly issued AuthZ token,
      for every tool call.
    * ``cfg.url`` otherwise: keep one streamable-HTTP session open.
    * neither: the entry is skipped with a warning and its status stays
      ``"disconnected"``.

    Long-lived sessions are attached to ``stack`` and stay open until the
    caller closes it. Contexts opened for a server whose setup fails are
    closed immediately instead of lingering on ``stack``.

    Args:
        mcp_servers: Mapping of server name to server config object.
        registry: Registry that receives one ``MCPToolWrapper`` per remote tool.
        stack: Exit stack that owns the long-lived sessions.
        authz_config: AuthZ settings (``base_url``, optional
            ``request_timeout_seconds``); needed for ``oauth_backend_token``.
        backend_identity: Backend credentials (``client_id``,
            ``client_secret``); needed for ``oauth_backend_token``.

    Returns:
        A per-server report for the Web UI. Each entry has ``status``
        (``"connected"``, ``"error"`` or ``"disconnected"``), ``last_error``,
        ``tool_names``, ``tool_count`` and ``transport``.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client

    from nanobot.authz.client import AuthzClient

    def _text(value: Any) -> str:
        """Return ``value`` as a stripped string, treating ``None`` as empty."""
        return str(value or "").strip()

    async def _build_http_headers(server_name: str, cfg: Any) -> dict[str, str]:
        """Return HTTP headers for ``cfg``, adding an AuthZ bearer token when required.

        Raises:
            RuntimeError: The server needs a backend token but the AuthZ
                settings or backend credentials are incomplete, or AuthZ
                returned no access token.
        """
        headers = dict(getattr(cfg, "headers", {}) or {})
        if getattr(cfg, "auth_mode", "none") != "oauth_backend_token":
            return headers

        # `_text` tolerates `None` (and non-str values), where a bare
        # `.strip()` would raise AttributeError instead of this RuntimeError.
        if not (
            authz_config
            and _text(getattr(authz_config, "base_url", ""))
            and backend_identity
            and _text(getattr(backend_identity, "client_id", ""))
            and _text(getattr(backend_identity, "client_secret", ""))
        ):
            raise RuntimeError(
                f"MCP server '{server_name}' requires AuthZ backend token, but authz/backend identity is incomplete"
            )

        authz_client = AuthzClient(
            authz_config.base_url,
            timeout_seconds=int(getattr(authz_config, "request_timeout_seconds", 10)),
        )
        raw_audience = _text(getattr(cfg, "auth_audience", ""))
        # Older managed Outlook configs stored `auth_audience="mcp"`, but AuthZ
        # permissions are issued against `mcp:<server_id>`.
        if not raw_audience or raw_audience == "mcp":
            audience = f"mcp:{server_name}"
        elif raw_audience.startswith("mcp:"):
            audience = raw_audience
        else:
            audience = f"mcp:{raw_audience}"
        token_response = await authz_client.issue_token(
            client_id=backend_identity.client_id,
            client_secret=backend_identity.client_secret,
            audience=audience,
            scopes=[str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
        )
        access_token = _text(token_response.get("access_token"))
        if not access_token:
            raise RuntimeError(f"MCP server '{server_name}' did not receive an access token from AuthZ")
        headers["Authorization"] = f"Bearer {access_token}"
        return headers

    async def _open_http_session(
        session_stack: AsyncExitStack,
        cfg: Any,
        *,
        headers: dict[str, str],
    ):
        """Open and initialize a streamable-HTTP MCP session whose contexts live on ``session_stack``."""
        http_client = await session_stack.enter_async_context(
            httpx.AsyncClient(
                headers=headers or None,
                follow_redirects=True,
                trust_env=False,
            )
        )
        read, write, _ = await session_stack.enter_async_context(
            streamable_http_client(cfg.url, http_client=http_client)
        )
        session = await session_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        return session

    async def _list_http_tools(server_name: str, cfg: Any):
        """List the server's tool definitions over a short-lived HTTP session."""
        async with AsyncExitStack() as session_stack:
            headers = await _build_http_headers(server_name, cfg)
            session = await _open_http_session(session_stack, cfg, headers=headers)
            tools = await session.list_tools()
            return tools.tools

    def _make_http_call_tool(server_name: str, cfg: Any) -> Callable[[str, dict[str, Any]], Awaitable[Any]]:
        """Return a call function that opens a fresh, freshly-authorized HTTP session per call."""

        async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
            async with AsyncExitStack() as session_stack:
                headers = await _build_http_headers(server_name, cfg)
                session = await _open_http_session(session_stack, cfg, headers=headers)
                return await session.call_tool(tool_name, arguments=arguments)

        return _call_tool

    def _register_tools(
        server_name: str,
        cfg: Any,
        tool_defs: Any,
        tool_names: list[str],
        *,
        session: Any = None,
        call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
    ) -> None:
        """Wrap each tool definition, register it, and append its exported name to ``tool_names``."""
        sensitive = bool(getattr(cfg, "sensitive", False))
        for tool_def in tool_defs:
            wrapper = MCPToolWrapper(
                session,
                server_name,
                tool_def,
                call_tool=call_tool,
                tool_timeout=cfg.tool_timeout,
                sensitive=sensitive,
            )
            registry.register(wrapper)
            logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, server_name)
            tool_names.append(wrapper.name)

    # `report` is returned to the caller so the Web UI can show connection status and discovered tools.
    report: dict[str, dict[str, Any]] = {}
    for name, cfg in mcp_servers.items():
        entry: dict[str, Any] = {
            "status": "disconnected",
            "last_error": None,
            "tool_names": [],
            "tool_count": 0,
            "transport": "stdio" if getattr(cfg, "command", "") else "http",
        }
        report[name] = entry
        try:
            session = None
            call_tool = None
            # Open this server's contexts on a private stack, so a failure part-way
            # through setup closes them right away instead of leaving them on `stack`.
            async with AsyncExitStack() as server_stack:
                if cfg.command:
                    # stdio mode: spawn a local subprocess and talk MCP over its stdin/stdout.
                    params = StdioServerParameters(
                        command=cfg.command, args=cfg.args, env=cfg.env or None
                    )
                    read, write = await server_stack.enter_async_context(stdio_client(params))
                    session = await server_stack.enter_async_context(ClientSession(read, write))
                    await session.initialize()
                    tool_defs = (await session.list_tools()).tools
                elif cfg.url and getattr(cfg, "auth_mode", "none") == "oauth_backend_token":
                    # Backend-token servers open a fresh session (and token) per call,
                    # so nothing needs to stay open here.
                    tool_defs = await _list_http_tools(name, cfg)
                    call_tool = _make_http_call_tool(name, cfg)
                elif cfg.url:
                    headers = await _build_http_headers(name, cfg)
                    session = await _open_http_session(server_stack, cfg, headers=headers)
                    tool_defs = (await session.list_tools()).tools
                else:
                    # Entries with neither command nor url are invalid config: skip without raising.
                    logger.warning("MCP server '{}': no command or url configured, skipping", name)
                    continue
                # Setup succeeded: hand the open contexts over to the caller's stack
                # so they stay alive until that stack is closed.
                stack.push_async_callback(server_stack.pop_all().aclose)

            _register_tools(
                name,
                cfg,
                tool_defs,
                entry["tool_names"],
                session=session,
                call_tool=call_tool,
            )
            entry["tool_count"] = len(entry["tool_names"])
            entry["status"] = "connected"
            logger.info(
                "MCP server '{}': connected, {} tools registered",
                name,
                len(entry["tool_names"]),
            )
        except Exception as e:
            # One failing server does not stop the others; the error goes into the report for the UI.
            error_detail = _describe_mcp_exception(
                e,
                server_name=name,
                url=str(getattr(cfg, "url", "") or "").strip() or None,
            )
            entry["status"] = "error"
            entry["last_error"] = error_detail
            logger.error("MCP server '{}': failed to connect: {}", name, error_detail)
    return report
|