- 集成MCP连接管理器,支持MCP服务器连接 - 添加多种内置工具:ClarifyTool、CronTool、DelegateTool、ExecuteCodeTool、 PatchFileTool、ProcessTool、SendMessageTool、SpawnTool、TerminalTool、 TodoTool、WebFetchTool、WebSearchTool、WriteFileTool等 - 实现工具注册和装配功能 - 添加技能选择上下文参数 - 支持思考模式控制参数thinking_enabled feat(coordinator): 重构任务执行计划器参数命名 - 将learning_candidate_enabled重命名为allow_candidate_generation - 更新TeamGraphScheduler中的参数传递 - 修改LocalAgentRunner中的相关参数处理 - 更新README文档中的相应描述 refactor(context): 标准化工具调用参数格式 - 添加_json导入用于参数序列化 - 实现_provider_tool_calls方法标准化OpenAI兼容的工具调用载荷 - 修复工具调用中参数非字符串类型的序列化问题 refactor(session): 优化消息历史记录过滤逻辑 - 修改get_messages_as_conversation为基于运行状态过滤消息 - 排除未完成、失败或错误结束的运行记录 - 改进对话历史的可见性控制机制 fix(store): 修复FTS索引重建逻辑 - 添加异常处理防止FTS索引创建失败 - 实现_rebuild_fts_index方法重新构建全文搜索索引 - 优化索引触发器和表的维护流程
193 lines
8.0 KiB
Python
193 lines
8.0 KiB
Python
"""MCP connection manager."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import AsyncExitStack
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from beaver.foundation.config import AuthzConfig, BackendIdentityConfig, MCPServerConfig
|
|
from beaver.integrations.authz import AuthzClient
|
|
from beaver.tools.mcp.wrapper import MCPToolWrapper
|
|
from beaver.tools.registry import ToolRegistry
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class MCPConnectionReport:
|
|
status: str = "disconnected"
|
|
last_error: str | None = None
|
|
tool_names: list[str] = field(default_factory=list)
|
|
tool_count: int = 0
|
|
transport: str = "http"
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
return {
|
|
"status": self.status,
|
|
"last_error": self.last_error,
|
|
"tool_names": list(self.tool_names),
|
|
"tool_count": self.tool_count,
|
|
"transport": self.transport,
|
|
}
|
|
|
|
|
|
class MCPConnectionManager:
|
|
def __init__(
|
|
self,
|
|
servers: dict[str, MCPServerConfig],
|
|
*,
|
|
authz_config: AuthzConfig | None = None,
|
|
backend_identity: BackendIdentityConfig | None = None,
|
|
) -> None:
|
|
self.servers = servers
|
|
self.authz_config = authz_config
|
|
self.backend_identity = backend_identity
|
|
self.stack = AsyncExitStack()
|
|
self.connected = False
|
|
self._connect_lock = asyncio.Lock()
|
|
self.report: dict[str, MCPConnectionReport] = {}
|
|
|
|
async def connect_all(self, registry: ToolRegistry) -> dict[str, dict[str, Any]]:
|
|
async with self._connect_lock:
|
|
if self.connected:
|
|
return {key: value.to_dict() for key, value in self.report.items()}
|
|
self.report = {}
|
|
for server_id, cfg in self.servers.items():
|
|
self.report[server_id] = MCPConnectionReport(transport=cfg.transport)
|
|
try:
|
|
if cfg.command:
|
|
await self._connect_stdio(server_id, cfg, registry)
|
|
elif cfg.url:
|
|
await self._connect_http(server_id, cfg, registry)
|
|
else:
|
|
raise ValueError("MCP server requires command or url")
|
|
self.report[server_id].status = "connected"
|
|
self.report[server_id].tool_count = len(self.report[server_id].tool_names)
|
|
except Exception as exc:
|
|
self.report[server_id].status = "error"
|
|
self.report[server_id].last_error = _describe_exception(exc, server_id=server_id, url=cfg.url or None)
|
|
self.connected = True
|
|
return {key: value.to_dict() for key, value in self.report.items()}
|
|
|
|
async def close(self) -> None:
|
|
await self.stack.aclose()
|
|
self.connected = False
|
|
|
|
async def _headers(self, server_id: str, cfg: MCPServerConfig) -> dict[str, str]:
|
|
headers = dict(cfg.headers or {})
|
|
if cfg.auth_mode != "oauth_backend_token":
|
|
return headers
|
|
if not (
|
|
self.authz_config
|
|
and self.authz_config.enabled
|
|
and self.authz_config.base_url
|
|
and self.backend_identity
|
|
and self.backend_identity.client_id
|
|
and self.backend_identity.client_secret
|
|
):
|
|
raise RuntimeError("oauth_backend_token requires AuthZ and backend identity")
|
|
audience = cfg.auth_audience or f"mcp:{server_id}"
|
|
client = AuthzClient(self.authz_config.base_url, timeout_seconds=self.authz_config.request_timeout_seconds)
|
|
token = await client.issue_token(
|
|
client_id=self.backend_identity.client_id,
|
|
client_secret=self.backend_identity.client_secret,
|
|
audience=audience,
|
|
scopes=list(cfg.auth_scopes),
|
|
)
|
|
access_token = str(token.get("access_token") or "").strip()
|
|
if not access_token:
|
|
raise RuntimeError("AuthZ did not return an access token")
|
|
headers["Authorization"] = f"Bearer {access_token}"
|
|
return headers
|
|
|
|
async def _open_http_session(self, cfg: MCPServerConfig, headers: dict[str, str]):
|
|
from mcp import ClientSession
|
|
from mcp.client.streamable_http import streamable_http_client
|
|
|
|
http_client = await self.stack.enter_async_context(
|
|
httpx.AsyncClient(headers=headers or None, follow_redirects=True, trust_env=False)
|
|
)
|
|
read, write, _ = await self.stack.enter_async_context(streamable_http_client(cfg.url, http_client=http_client))
|
|
session = await self.stack.enter_async_context(ClientSession(read, write))
|
|
await session.initialize()
|
|
return session
|
|
|
|
async def _connect_http(self, server_id: str, cfg: MCPServerConfig, registry: ToolRegistry) -> None:
|
|
headers = await self._headers(server_id, cfg)
|
|
session = await self._open_http_session(cfg, headers)
|
|
tools = await session.list_tools()
|
|
for tool_def in tools.tools:
|
|
async def call_tool(tool_name: str, args: dict[str, Any], *, _session=session) -> Any:
|
|
return await _session.call_tool(tool_name, arguments=args)
|
|
|
|
wrapper = MCPToolWrapper(
|
|
server_id,
|
|
tool_def,
|
|
call_tool,
|
|
cfg.tool_timeout,
|
|
cfg.sensitive,
|
|
cfg.kind,
|
|
cfg.category,
|
|
cfg.display_name,
|
|
)
|
|
registry.register(wrapper, replace=True)
|
|
if wrapper.spec.name not in self.report[server_id].tool_names:
|
|
self.report[server_id].tool_names.append(wrapper.spec.name)
|
|
|
|
async def _connect_stdio(self, server_id: str, cfg: MCPServerConfig, registry: ToolRegistry) -> None:
|
|
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import stdio_client
|
|
|
|
params = StdioServerParameters(command=cfg.command, args=list(cfg.args), env=dict(cfg.env) or None)
|
|
read, write = await self.stack.enter_async_context(stdio_client(params))
|
|
session = await self.stack.enter_async_context(ClientSession(read, write))
|
|
await session.initialize()
|
|
tools = await session.list_tools()
|
|
for tool_def in tools.tools:
|
|
async def call_tool(tool_name: str, args: dict[str, Any], *, _session=session) -> Any:
|
|
return await _session.call_tool(tool_name, arguments=args)
|
|
|
|
wrapper = MCPToolWrapper(
|
|
server_id,
|
|
tool_def,
|
|
call_tool,
|
|
cfg.tool_timeout,
|
|
cfg.sensitive,
|
|
cfg.kind,
|
|
cfg.category,
|
|
cfg.display_name,
|
|
)
|
|
registry.register(wrapper, replace=True)
|
|
if wrapper.spec.name not in self.report[server_id].tool_names:
|
|
self.report[server_id].tool_names.append(wrapper.spec.name)
|
|
|
|
|
|
async def test_mcp_server(
    server_id: str,
    cfg: MCPServerConfig,
    *,
    authz_config: AuthzConfig | None = None,
    backend_identity: BackendIdentityConfig | None = None,
) -> dict[str, Any]:
    """Probe a single MCP server using a throwaway registry and manager.

    The manager is closed in a ``finally`` block whether or not the probe
    succeeds.

    Returns:
        ``{"ok": <status == "connected">, "server": server_id, **report}`` where
        ``report`` is the server's connection report (empty if absent).
    """
    manager = MCPConnectionManager(
        {server_id: cfg},
        authz_config=authz_config,
        backend_identity=backend_identity,
    )
    scratch_registry = ToolRegistry()
    try:
        reports = await manager.connect_all(scratch_registry)
        entry = reports.get(server_id, {})
        is_connected = entry.get("status") == "connected"
        return {"ok": is_connected, "server": server_id, **entry}
    finally:
        await manager.close()
|
|
|
|
|
|
def _describe_exception(exc: BaseException, *, server_id: str, url: str | None = None) -> str:
    """Turn an MCP connection failure into a short, human-readable message.

    Exception groups (as raised out of task groups) are first unwrapped to
    their first leaf exception, so that transport errors are still classified
    as timeouts, unreachable hosts or HTTP status failures.

    Args:
        exc: The exception raised while connecting.
        server_id: Id of the server, included in the message.
        url: Optional server URL, appended in parentheses when given.
    """
    exc = _first_leaf_exception(exc)
    target = f" ({url})" if url else ""
    if isinstance(exc, httpx.TimeoutException):
        return f"MCP server '{server_id}' timed out{target}"
    if isinstance(exc, httpx.ConnectError):
        return f"MCP server '{server_id}' is unreachable{target}"
    if isinstance(exc, httpx.HTTPStatusError):
        return f"MCP server '{server_id}' returned HTTP {exc.response.status_code}{target}"
    detail = str(exc).strip() or exc.__class__.__name__
    return f"MCP server '{server_id}' failed{target}: {detail}"


def _first_leaf_exception(exc: BaseException, *, max_depth: int = 32) -> BaseException:
    """Return the first leaf of (possibly nested) exception groups, else *exc* itself.

    Duck-types on a non-empty ``exceptions`` sequence of exceptions, so it works
    with the built-in ``ExceptionGroup`` and compatible backports. Depth is
    bounded to guard against pathological self-referencing containers.
    """
    current = exc
    for _ in range(max_depth):
        members = getattr(current, "exceptions", None)
        if not isinstance(members, (list, tuple)) or not members or not isinstance(members[0], BaseException):
            break
        current = members[0]
    return current
|