第一次提交
This commit is contained in:
346
app-instance/backend/nanobot/agent/tools/mcp.py
Normal file
346
app-instance/backend/nanobot/agent/tools/mcp.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""MCP 客户端封装。
|
||||
|
||||
职责分两层:
|
||||
1. `connect_mcp_servers()` 负责建立与 MCP server 的连接,并把远端工具注册成 nanobot 本地工具;
|
||||
2. `MCPToolWrapper` 负责把单个远端 MCP tool 包装成可供 LLM 调用的 `Tool`,同时发出结构化过程事件。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.process_events import current_process_run_id, emit_process_event, new_run_id
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
"""把单个 MCP server tool 包装成 nanobot Tool。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session,
|
||||
server_name: str,
|
||||
tool_def,
|
||||
*,
|
||||
call_tool: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
|
||||
tool_timeout: int = 30,
|
||||
sensitive: bool = False,
|
||||
):
|
||||
self._session = session
|
||||
self._call_tool = call_tool or self._default_call_tool
|
||||
# 记录来源服务名,便于日志、事件流和最终导出的工具名保持可追踪。
|
||||
self._server_name = server_name
|
||||
self._original_name = tool_def.name
|
||||
# 在 nanobot 内部为 MCP 工具统一加 `mcp_<server>_` 前缀,避免同名冲突。
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._description = tool_def.description or tool_def.name
|
||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
self._tool_timeout = tool_timeout
|
||||
self._sensitive = sensitive
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._parameters
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
# 每次 MCP 调用都分配独立 run_id,前端可以把它显示成树状子步骤。
|
||||
run_id = new_run_id("mcp")
|
||||
args_json = json.dumps(kwargs, ensure_ascii=False) if kwargs else "{}"
|
||||
await emit_process_event(
|
||||
"process_run_started",
|
||||
run_id=run_id,
|
||||
parent_run_id=current_process_run_id(),
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
title=f"{self._server_name}.{self._original_name}",
|
||||
status="running",
|
||||
metadata={
|
||||
"tool_name": self._original_name,
|
||||
"tool_args": None if self._sensitive else kwargs,
|
||||
"tool_timeout": self._tool_timeout,
|
||||
"sensitive": self._sensitive,
|
||||
},
|
||||
)
|
||||
# 在真正请求远端前先发一条 progress,方便 UI 及时显示“正在调用哪个工具”。
|
||||
await emit_process_event(
|
||||
"process_run_progress",
|
||||
run_id=run_id,
|
||||
parent_run_id=current_process_run_id(),
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
text=(
|
||||
f"Calling {self._original_name}"
|
||||
if self._sensitive
|
||||
else f"Calling {self._original_name} with {args_json}"
|
||||
),
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._call_tool(self._original_name, kwargs),
|
||||
timeout=self._tool_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# 超时被视为业务失败,但不抛异常给上层 agent 循环,而是返回可读错误文本。
|
||||
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||
summary = f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||
await emit_process_event(
|
||||
"process_run_status",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="error",
|
||||
text=summary,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
await emit_process_event(
|
||||
"process_run_finished",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="error",
|
||||
summary=summary,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
return summary
|
||||
|
||||
# MCP SDK 返回的是结构化 content block 列表,这里统一摊平成文本。
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
parts.append(block.text)
|
||||
else:
|
||||
parts.append(str(block))
|
||||
output = "\n".join(parts) or "(no output)"
|
||||
artifact_type = "text"
|
||||
artifact_data: Any | None = None
|
||||
stripped = output.strip()
|
||||
# 如果看起来像 JSON,则额外解析成结构化 artifact,方便前端做更丰富展示。
|
||||
if stripped.startswith("{") or stripped.startswith("["):
|
||||
try:
|
||||
artifact_data = json.loads(stripped)
|
||||
artifact_type = "json"
|
||||
except json.JSONDecodeError:
|
||||
artifact_data = None
|
||||
await emit_process_event(
|
||||
"process_run_artifact",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
title=f"{self._server_name}.{self._original_name} result",
|
||||
artifact_type="redacted" if self._sensitive else artifact_type,
|
||||
content=None if self._sensitive or artifact_data is not None else output,
|
||||
data=None if self._sensitive else artifact_data,
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
await emit_process_event(
|
||||
"process_run_finished",
|
||||
run_id=run_id,
|
||||
actor_type="mcp",
|
||||
actor_id=self._server_name,
|
||||
actor_name=self._server_name,
|
||||
status="done",
|
||||
summary=(
|
||||
f"{self._original_name} completed"
|
||||
if self._sensitive
|
||||
else output[:1000]
|
||||
),
|
||||
metadata={"tool_name": self._original_name, "sensitive": self._sensitive},
|
||||
)
|
||||
return output
|
||||
|
||||
async def _default_call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
return await self._session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict,
|
||||
registry: ToolRegistry,
|
||||
stack: AsyncExitStack,
|
||||
*,
|
||||
authz_config: Any | None = None,
|
||||
backend_identity: Any | None = None,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""连接所有配置中的 MCP server,并把工具注册到 registry。"""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
from nanobot.authz.client import AuthzClient
|
||||
|
||||
async def _build_http_headers(server_name: str, cfg: Any) -> dict[str, str]:
|
||||
headers = dict(getattr(cfg, "headers", {}) or {})
|
||||
if getattr(cfg, "auth_mode", "none") != "oauth_backend_token":
|
||||
return headers
|
||||
|
||||
if not (
|
||||
authz_config
|
||||
and getattr(authz_config, "base_url", "").strip()
|
||||
and backend_identity
|
||||
and getattr(backend_identity, "client_id", "").strip()
|
||||
and getattr(backend_identity, "client_secret", "").strip()
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"MCP server '{server_name}' requires AuthZ backend token, but authz/backend identity is incomplete"
|
||||
)
|
||||
|
||||
authz_client = AuthzClient(
|
||||
getattr(authz_config, "base_url"),
|
||||
timeout_seconds=int(getattr(authz_config, "request_timeout_seconds", 10)),
|
||||
)
|
||||
raw_audience = str(getattr(cfg, "auth_audience", "") or "").strip()
|
||||
# Older managed Outlook configs stored `auth_audience="mcp"`, but AuthZ
|
||||
# permissions are issued against `mcp:<server_id>`.
|
||||
if not raw_audience or raw_audience == "mcp":
|
||||
audience = f"mcp:{server_name}"
|
||||
elif raw_audience.startswith("mcp:"):
|
||||
audience = raw_audience
|
||||
else:
|
||||
audience = f"mcp:{raw_audience}"
|
||||
token_response = await authz_client.issue_token(
|
||||
client_id=getattr(backend_identity, "client_id"),
|
||||
client_secret=getattr(backend_identity, "client_secret"),
|
||||
audience=audience,
|
||||
scopes=[str(item) for item in list(getattr(cfg, "auth_scopes", []) or [])],
|
||||
)
|
||||
access_token = str(token_response.get("access_token") or "").strip()
|
||||
if not access_token:
|
||||
raise RuntimeError(f"MCP server '{server_name}' did not receive an access token from AuthZ")
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
return headers
|
||||
|
||||
async def _open_http_session(
|
||||
session_stack: AsyncExitStack,
|
||||
cfg: Any,
|
||||
*,
|
||||
headers: dict[str, str],
|
||||
):
|
||||
http_client = await session_stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=headers or None,
|
||||
follow_redirects=True,
|
||||
trust_env=False,
|
||||
)
|
||||
)
|
||||
read, write, _ = await session_stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
session = await session_stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
return session
|
||||
|
||||
async def _list_http_tools(server_name: str, cfg: Any):
|
||||
async with AsyncExitStack() as session_stack:
|
||||
headers = await _build_http_headers(server_name, cfg)
|
||||
session = await _open_http_session(session_stack, cfg, headers=headers)
|
||||
tools = await session.list_tools()
|
||||
return tools.tools
|
||||
|
||||
def _make_http_call_tool(server_name: str, cfg: Any) -> Callable[[str, dict[str, Any]], Awaitable[Any]]:
|
||||
async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
async with AsyncExitStack() as session_stack:
|
||||
headers = await _build_http_headers(server_name, cfg)
|
||||
session = await _open_http_session(session_stack, cfg, headers=headers)
|
||||
return await session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
return _call_tool
|
||||
|
||||
# `report` 会返回给调用方,用于 Web UI 展示连接状态和已发现工具。
|
||||
report: dict[str, dict[str, Any]] = {}
|
||||
for name, cfg in mcp_servers.items():
|
||||
report[name] = {
|
||||
"status": "disconnected",
|
||||
"last_error": None,
|
||||
"tool_names": [],
|
||||
"tool_count": 0,
|
||||
"transport": "stdio" if getattr(cfg, "command", "") else "http",
|
||||
}
|
||||
try:
|
||||
if cfg.command:
|
||||
# stdio 模式:本地拉起一个子进程,通过 stdin/stdout 与 MCP server 通信。
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(
|
||||
session,
|
||||
name,
|
||||
tool_def,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
elif cfg.url:
|
||||
if getattr(cfg, "auth_mode", "none") == "oauth_backend_token":
|
||||
tools_defs = await _list_http_tools(name, cfg)
|
||||
call_tool = _make_http_call_tool(name, cfg)
|
||||
for tool_def in tools_defs:
|
||||
wrapper = MCPToolWrapper(
|
||||
None,
|
||||
name,
|
||||
tool_def,
|
||||
call_tool=call_tool,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
else:
|
||||
headers = await _build_http_headers(name, cfg)
|
||||
session = await _open_http_session(stack, cfg, headers=headers)
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(
|
||||
session,
|
||||
name,
|
||||
tool_def,
|
||||
tool_timeout=cfg.tool_timeout,
|
||||
sensitive=bool(getattr(cfg, "sensitive", False)),
|
||||
)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
report[name]["tool_names"].append(wrapper.name)
|
||||
else:
|
||||
# 没有 command 也没有 url 的条目视为无效配置,跳过但不抛异常。
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
|
||||
report[name]["tool_count"] = len(report[name]["tool_names"])
|
||||
report[name]["status"] = "connected"
|
||||
logger.info(
|
||||
"MCP server '{}': connected, {} tools registered",
|
||||
name,
|
||||
len(report[name]["tool_names"]),
|
||||
)
|
||||
except Exception as e:
|
||||
# 单个 server 失败不影响其他 server 继续连;错误写进 report 供 UI 展示。
|
||||
report[name]["status"] = "error"
|
||||
report[name]["last_error"] = str(e)
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
return report
|
||||
Reference in New Issue
Block a user