Files
beaver_project/app-instance/backend/nanobot/a2a/client.py
steven_li cdfc222c9f feat: 添加swarms团队编排功能并优化agent委派系统
- 引入AgentTeamOrchestrator支持多agent协同任务执行
- 增加第三方swarms库依赖并配置git协议替换以改善包管理
- 扩展DelegationManager支持团队任务调度和进度跟踪
- 实现中文bigram分词算法提升中文任务检索准确性
- 调整A2AClient和DelegationManager超时时间从30秒增至600秒
- 优化AgentRunResult状态判断逻辑增加有意义摘要检测
- 修改Dockerfile配置npm仓库镜像地址和git协议映射
- 更新CLI命令行接口支持网关端口配置传递
- 调整提供者超时配置机制增强请求稳定性
- 移除过时的support_group字段简化agent描述符结构
- 增强错误处理和进度事件报告机制改进用户体验
2026-04-14 14:34:23 +08:00

1217 lines
47 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""A2A client implementation.

The goal is not to cover every vendor variant, but to provide one robust
compatibility path:

1. Fetch the agent card first and resolve the available endpoints and the
   preferred transport.
2. Prefer a streaming subscription so real-time progress is visible.
3. Fall back to polling when streaming is unavailable or interrupted.
4. Support both JSON-RPC and HTTP+JSON (REST) style interfaces.
"""
from __future__ import annotations
import asyncio
import json
import os
import time
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import urlparse, urlunparse
import httpx
from nanobot.agent.agent_registry import AgentDescriptor
from nanobot.agent.run_result import AgentRunResult, has_meaningful_summary
class A2AError(RuntimeError):
    """Common exception raised whenever an A2A request fails."""


class A2AUnsupportedMethodError(A2AError):
    """Raised when the remote endpoint does not support a given method.

    Callers use this subtype to decide that another method variant or
    transport target may be tried instead.
    """
@dataclass
class A2AStreamEvent:
    """Normalized form of a single event from an A2A subscription stream.

    Only ``kind`` is required; every other attribute defaults to "unknown"
    (``None``) or "not final" (``False``).
    """

    # Event category such as "task", "message", "status-update" or "artifact-update".
    kind: str
    # Remote task id; once present the caller can register it for cancel/resume.
    task_id: str | None = field(default=None)
    # Normalized status text accumulated so far.
    status: str | None = field(default=None)
    # Incremental, user-presentable text carried by this event.
    text: str | None = field(default=None)
    # True when this event marks a terminal state.
    final: bool = field(default=False)
    # Original event payload, kept for debugging and future extensions.
    raw: dict[str, Any] | None = field(default=None)
@dataclass
class _StreamState:
    """Accumulator that folds streamed A2A events into a running task state.

    It tracks the latest task id and status, artifact texts (in first-seen
    order), message texts and status-update texts. ``build_summary`` turns
    these into the final summary.
    """

    # Remote task id, once any event has reported one.
    task_id: str | None = None
    # Normalized status; "working" until an event reports something else.
    status: str = "working"
    # Artifact id -> accumulated artifact text.
    artifacts: dict[str, str] = field(default_factory=dict)
    # Artifact ids in the order they were first seen.
    artifact_order: list[str] = field(default_factory=list)
    # Texts from message/task events and from other non-artifact events.
    messages: list[str] = field(default_factory=list)
    # Texts carried by status-update events.
    status_messages: list[str] = field(default_factory=list)
    # Most recently applied (unwrapped) result object.
    latest_result: dict[str, Any] | None = None
    # True once any event was flagged final or carried a terminal status.
    final_seen: bool = False

    def apply(self, result: dict[str, Any], client: A2AClient) -> A2AStreamEvent:
        """Absorb one raw result object and return the normalized stream event.

        Args:
            result: Unwrapped result object of a single stream event.
            client: Client whose status-normalization and text-extraction
                helpers are used.

        Returns:
            An ``A2AStreamEvent`` carrying the accumulated task id and status.
        """
        self.latest_result = result
        kind = str(result.get("kind") or result.get("type") or "result").lower()
        task_id = str(result.get("id") or result.get("taskId") or "").strip() or None
        if task_id:
            self.task_id = task_id

        raw_status = result.get("status")
        normalized_status = client._normalize_status(raw_status)
        # A non-"ok" status always overrides the current one; an "ok" status only
        # counts when a task/status-update event reports a terminal state.
        if normalized_status not in {"", "ok"}:
            self.status = normalized_status
        elif kind in {"task", "status-update"} and client._is_terminal_status(raw_status):
            self.status = client._normalize_status(raw_status)

        text = ""
        if kind == "artifact-update":
            # Artifact updates may be incremental chunks of the same artifact.
            text = self._apply_artifact_update(result, client)
        elif kind == "status-update":
            # Some implementations put the status text under "status", others under
            # "message"; try both.
            text = client._extract_text(result.get("status")) or client._extract_text(
                result.get("message")
            )
            self._append_unique(self.status_messages, text)
        elif kind in {"message", "task"}:
            self._apply_task_or_message(result, client)
            text = client._extract_text(result)
        else:
            text = client._extract_text(result)

        final = bool(result.get("final")) or client._is_terminal_status(raw_status)
        if final:
            self.final_seen = True
            if self.status == "working":
                # Even without a more specific status, replace "working" with
                # whatever the final event reports.
                self.status = client._normalize_status(raw_status)
        if text and kind not in {"artifact-update", "message", "task"}:
            self._append_unique(self.messages, text)

        return A2AStreamEvent(
            kind=kind,
            task_id=self.task_id,
            status=self.status,
            text=text or None,
            final=final,
            raw=result,
        )

    def build_summary(self, client: A2AClient) -> str:
        """Build the final summary with priority artifacts -> messages -> status texts.

        Falls back to text extracted from the latest result, or "" when nothing
        was seen.
        """
        artifact_text = "\n".join(
            self.artifacts[artifact_id]
            for artifact_id in self.artifact_order
            if self.artifacts.get(artifact_id)
        ).strip()
        if artifact_text:
            return artifact_text
        message_text = "\n".join(text for text in self.messages if text).strip()
        if message_text:
            return message_text
        status_text = "\n".join(text for text in self.status_messages if text).strip()
        if status_text:
            return status_text
        if self.latest_result:
            return client._extract_text(self.latest_result)
        return ""

    def _apply_artifact_update(self, result: dict[str, Any], client: A2AClient) -> str:
        """Merge one artifact-update event into the accumulated artifacts.

        Returns:
            The text carried by this update, or "" when it had none.
        """
        artifact = result.get("artifact")
        if not isinstance(artifact, dict):
            artifact = result
        artifact_id = str(
            artifact.get("artifactId")
            or artifact.get("id")
            or result.get("artifactId")
            or f"artifact-{len(self.artifact_order) + 1}"
        )
        text = client._extract_text(artifact)
        if not text:
            return ""
        if artifact_id not in self.artifacts:
            self.artifacts[artifact_id] = ""
            self.artifact_order.append(artifact_id)
        # append=true means an incremental chunk; otherwise the text replaces the artifact.
        if result.get("append") or artifact.get("append"):
            self.artifacts[artifact_id] += text
        else:
            self.artifacts[artifact_id] = text
        return text

    def _apply_task_or_message(self, result: dict[str, Any], client: A2AClient) -> None:
        """Merge the artifacts and text of a task/message result into the state.

        Artifacts without an explicit id are keyed by their 1-based position in
        the snapshot. If the same (or a growing) task snapshot is applied more
        than once, the earlier copy is then overwritten rather than duplicated.
        Deriving the id from the current artifact count, as before, produced a
        new id on every re-application.
        """
        artifacts = result.get("artifacts")
        if isinstance(artifacts, list):
            for position, artifact in enumerate(artifacts, start=1):
                if not isinstance(artifact, dict):
                    continue
                artifact_id = str(
                    artifact.get("artifactId")
                    or artifact.get("id")
                    or f"artifact-{position}"
                )
                text = client._extract_text(artifact)
                if not text:
                    continue
                if artifact_id not in self.artifacts:
                    self.artifact_order.append(artifact_id)
                self.artifacts[artifact_id] = text
        text = client._extract_text(result)
        self._append_unique(self.messages, text)

    @staticmethod
    def _append_unique(collection: list[str], text: str) -> None:
        """Append ``text`` only if it is non-empty and differs from the last entry.

        This keeps repeated stream events from flooding the collection.
        """
        if text and (not collection or collection[-1] != text):
            collection.append(text)
@dataclass(frozen=True)
class _A2ATransportTarget:
    """Immutable (transport mode, endpoint URL) pair resolved for a remote agent."""

    # Transport flavour; the client's resolver produces "jsonrpc" or "rest".
    mode: str
    # Endpoint URL to call for this transport.
    endpoint: str
class A2AClient:
    """A2A client that supports JSON-RPC with an HTTP+JSON (REST) fallback path.

    Every HTTP call opens a short-lived ``httpx.AsyncClient`` configured with
    ``timeout_seconds`` and the optional ``transport``. Fetched agent cards are
    cached in memory per card URL for ``card_cache_ttl_seconds``.
    """

    def __init__(
        self,
        timeout_seconds: int = 600,
        poll_interval_seconds: int = 2,
        card_cache_ttl_seconds: int = 300,
        allowed_hosts: list[str] | None = None,
        transport: httpx.AsyncBaseTransport | None = None,
        authz_config: Any | None = None,
        backend_identity: Any | None = None,
    ):
        """Store timeouts, polling cadence, host allow-list and auth settings.

        Args:
            timeout_seconds: HTTP timeout for every request. It is also the overall
                deadline used while polling a task.
            poll_interval_seconds: Sleep between two polls of a running task.
            card_cache_ttl_seconds: Lifetime of a cached agent card.
            allowed_hosts: Optional host allow-list, compared lower-cased.
                ``None`` or empty disables the check.
            transport: Optional httpx transport passed to every client.
            authz_config: AuthZ settings. ``base_url`` and
                ``request_timeout_seconds`` are read when a backend token is needed.
            backend_identity: Object exposing ``client_id`` / ``client_secret``,
                used for backend token issuance.
        """
        # These settings bound timeouts, polling cadence and which hosts may be contacted.
        self.timeout_seconds = timeout_seconds
        self.poll_interval_seconds = poll_interval_seconds
        self.card_cache_ttl_seconds = card_cache_ttl_seconds
        self.allowed_hosts = {host.lower() for host in (allowed_hosts or []) if host}
        self.transport = transport
        self.authz_config = authz_config
        self.backend_identity = backend_identity
        # Lower-cased card URL -> (time.monotonic() expiry, card JSON object).
        self._card_cache: dict[str, tuple[float, dict[str, Any]]] = {}

    async def run_task(
        self,
        agent: AgentDescriptor,
        task: str,
        label: str | None = None,
        event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None = None,
        task_callback: Callable[[str], Awaitable[None]] | None = None,
        prefer_streaming: bool = True,
    ) -> AgentRunResult:
        """Run one task on a remote A2A agent and return its normalized result.

        Transport targets resolved from the agent card are tried in order. On each
        target, streaming is attempted first when the card allows it. Otherwise the
        task is sent once and, if still running, polled until it is terminal. Only
        "method unsupported" errors move on to the next target. Once a remote task
        exists it is never re-submitted elsewhere.

        Args:
            agent: Descriptor of the remote agent.
            task: Task text, sent as a user message.
            label: Optional label attached as message metadata.
            event_callback: Awaited with every normalized stream event.
            task_callback: Awaited with the remote task id once it is known.
            prefer_streaming: Set to False to skip the streaming attempt.

        Raises:
            A2AUnsupportedMethodError: No target supported the required methods.
            A2AError: Any other A2A failure.
        """
        card = await self.fetch_agent_card(agent)
        params = self._build_message_params(task, label)
        targets = self._resolve_transport_targets(card, agent)
        if not targets:
            raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint")
        last_unsupported: Exception | None = None
        for target in targets:
            try:
                # Prefer streaming when the card allows it, to surface intermediate progress.
                if prefer_streaming and self._supports_streaming(card):
                    stream_result = await self._run_task_streaming(
                        target=target,
                        params=params,
                        agent=agent,
                        event_callback=event_callback,
                        task_callback=task_callback,
                    )
                    if stream_result is not None:
                        return stream_result
                # Streaming unavailable: send once, then poll depending on the status.
                result = await self._send_task(target, params, agent)
                if self._is_task_result(result) and not self._is_terminal_status(result.get("status")):
                    task_id = str(result.get("id") or result.get("taskId") or "").strip()
                    if task_id:
                        if task_callback:
                            await task_callback(task_id)
                        try:
                            result = await self._poll_task(target, task_id, agent)
                        except A2AUnsupportedMethodError as exc:
                            # The task already exists remotely. Trying the next target
                            # would submit it a second time, so report a hard failure.
                            raise A2AError(
                                f"A2A task '{task_id}' was submitted but could not be polled: {exc}"
                            ) from exc
                return self._build_run_result(agent, result)
            except A2AUnsupportedMethodError as exc:
                last_unsupported = exc
                continue
        if last_unsupported:
            raise last_unsupported
        raise A2AError(f"Agent '{agent.id}' does not expose a usable A2A endpoint")

    async def fetch_agent_card(self, agent: AgentDescriptor) -> dict[str, Any]:
        """Return the remote agent card, served from the local TTL cache when fresh."""
        _, card = await self.fetch_agent_card_with_url(agent)
        return card

    async def fetch_agent_card_with_url(self, agent: AgentDescriptor) -> tuple[str, dict[str, Any]]:
        """Fetch the agent card and return it together with the URL that served it.

        Candidate URLs are tried in order. A fresh cache entry for a candidate is
        returned without a network request.

        Raises:
            A2AError: No candidate URL produced a card.
        """
        urls = self._candidate_card_urls(agent)
        last_error: Exception | None = None
        for url in urls:
            cache_key = url.lower()
            cached = self._card_cache.get(cache_key)
            if cached and cached[0] > time.monotonic():
                return url, cached[1]
            try:
                card = await self._fetch_json(url, agent)
            except Exception as exc:
                # Best effort: remember the failure and try the next candidate location.
                last_error = exc
                continue
            if isinstance(card, dict):
                self._card_cache[cache_key] = (
                    time.monotonic() + self.card_cache_ttl_seconds,
                    card,
                )
                return url, card
        if last_error:
            raise A2AError(f"Failed to fetch agent card for '{agent.id}': {last_error}")
        raise A2AError(f"Failed to fetch agent card for '{agent.id}'")

    def invalidate_card(self, agent: AgentDescriptor) -> None:
        """Drop every cached card associated with ``agent``'s candidate URLs."""
        for url in self._candidate_card_urls(agent):
            self._card_cache.pop(url.lower(), None)

    def _candidate_card_urls(self, agent: AgentDescriptor) -> list[str]:
        """Derive the ordered, de-duplicated candidate card URLs for ``agent``.

        The explicit ``card_url`` comes first, followed by the well-known
        locations under ``base_url`` (or ``endpoint``).
        """
        urls: list[str] = []
        if agent.card_url:
            urls.append(agent.card_url)
        base_url = str(agent.base_url or agent.endpoint or "").rstrip("/")
        if base_url:
            urls.extend(
                [
                    f"{base_url}/.well-known/agent-card",
                    f"{base_url}/.well-known/agent-card.json",
                    f"{base_url}/.well-known/agent.json",
                ]
            )
        deduped: list[str] = []
        seen: set[str] = set()
        for url in urls:
            normalized = url.strip()
            if not normalized or normalized.lower() in seen:
                continue
            seen.add(normalized.lower())
            deduped.append(normalized)
        return deduped

    async def _run_task_streaming(
        self,
        target: _A2ATransportTarget,
        params: dict[str, Any],
        agent: AgentDescriptor,
        event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
        task_callback: Callable[[str], Awaitable[None]] | None,
    ) -> AgentRunResult | None:
        """Try to run the task through a streaming method.

        Returns:
            The run result, or ``None`` when no streaming variant is supported, in
            which case the caller falls back to send + poll.

        Raises:
            A2AError: A streaming variant was supported but failed. If a remote
                task id had already been seen, the task is resumed or polled
                instead of being re-submitted through another variant.
        """
        if target.mode == "rest":
            stream_variants = [("message/stream", params)]
        else:
            stream_variants = [
                ("tasks/sendSubscribe", {"id": str(uuid.uuid4()), **params}),
                ("message/stream", params),
            ]
        saw_supported_stream = False
        last_error: A2AError | None = None
        for method, payload in stream_variants:
            state = _StreamState()
            try:
                # Each variant is tried independently; a terminal result returns immediately.
                return await self._consume_stream_method(
                    target=target,
                    method=method,
                    params=payload,
                    agent=agent,
                    event_callback=event_callback,
                    task_callback=task_callback,
                    state=state,
                    allow_resume=True,
                )
            except A2AError as exc:
                last_error = exc
                if state.task_id is None:
                    # No remote task was registered, so the next variant cannot duplicate work.
                    if not isinstance(exc, A2AUnsupportedMethodError):
                        saw_supported_stream = True
                    continue
            # The stream broke after the remote task was created. Resume or poll that
            # task; trying another variant would submit the same task a second time.
            try:
                return await self._resume_or_poll(
                    target=target,
                    agent=agent,
                    task_id=state.task_id,
                    state=state,
                    event_callback=event_callback,
                    task_callback=task_callback,
                )
            except A2AError as resume_exc:
                # Wrap as a plain A2AError so callers never read this as "unsupported"
                # and retry on another target (which would also re-submit the task).
                raise A2AError(
                    f"A2A task '{state.task_id}' was interrupted and could not be resumed: {resume_exc}"
                ) from resume_exc
        if saw_supported_stream and last_error:
            raise last_error
        return None

    async def _consume_stream_method(
        self,
        target: _A2ATransportTarget,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
        event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
        task_callback: Callable[[str], Awaitable[None]] | None,
        state: _StreamState,
        allow_resume: bool,
    ) -> AgentRunResult:
        """Consume one streaming method until a final result arrives.

        Raises:
            A2AUnsupportedMethodError: The stream produced no event at all.
            A2AError: The stream failed before a task id was known, or ended
                without a final event and could not be resumed.
        """
        saw_event = False
        seen_task_id: str | None = state.task_id
        try:
            async for body in self._stream_request(target, method, params, agent):
                saw_event = True
                result = self._unwrap_result_object(body)
                event = state.apply(result, self)
                # Notify the caller when a new task id appears, so it can register
                # the task for cancellation or resumption.
                if task_callback and event.task_id and event.task_id != seen_task_id:
                    seen_task_id = event.task_id
                    await task_callback(event.task_id)
                if event_callback:
                    await event_callback(event)
                if event.final:
                    return self._build_run_result(agent, result, state)
        except (httpx.ReadError, httpx.RemoteProtocolError, httpx.TimeoutException) as exc:
            if not state.task_id:
                # Some transport errors carry an empty message; keep the error readable.
                raise A2AError(str(exc) or type(exc).__name__) from exc
        if state.final_seen:
            return self._build_run_result(agent, state.latest_result or {}, state)
        if allow_resume and state.task_id:
            # The stream ended before a terminal state: resubscribe, or fall back to polling.
            return await self._resume_or_poll(
                target=target,
                agent=agent,
                task_id=state.task_id,
                state=state,
                event_callback=event_callback,
                task_callback=task_callback,
            )
        if saw_event:
            raise A2AError("A2A stream ended before a final result was received")
        raise A2AUnsupportedMethodError(method)

    async def _resume_or_poll(
        self,
        target: _A2ATransportTarget,
        agent: AgentDescriptor,
        task_id: str,
        state: _StreamState,
        event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
        task_callback: Callable[[str], Awaitable[None]] | None,
    ) -> AgentRunResult:
        """Resubscribe to an interrupted task; poll it if resubscription is unsupported."""
        try:
            return await self._consume_stream_method(
                target=target,
                method="tasks/subscribe" if target.mode == "rest" else "tasks/resubscribe",
                params={"id": task_id},
                agent=agent,
                event_callback=event_callback,
                task_callback=task_callback,
                state=state,
                allow_resume=False,
            )
        except A2AUnsupportedMethodError:
            result = await self._poll_task(target, task_id, agent)
            return self._build_run_result(agent, result, state)

    async def cancel_task(self, agent: AgentDescriptor, task_id: str) -> bool:
        """Best-effort cancellation of a remote A2A task.

        Returns:
            True if a cancel request was accepted by some target, False when the
            id is blank or no target supports cancellation.
        """
        task_id = task_id.strip()
        if not task_id:
            return False
        card = await self.fetch_agent_card(agent)
        targets = self._resolve_transport_targets(card, agent)
        if not targets:
            raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint")
        for target in targets:
            try:
                if target.mode == "rest":
                    # REST style uses `/tasks/{id}:cancel`.
                    await self._request_json(
                        "POST",
                        self._rest_endpoint(target.endpoint, f"/tasks/{task_id}:cancel"),
                        agent,
                        json_body={"name": f"tasks/{task_id}"},
                    )
                else:
                    # JSON-RPC style uses `tasks/cancel`.
                    await self._rpc_jsonrpc(target.endpoint, "tasks/cancel", {"id": task_id}, agent)
                return True
            except A2AUnsupportedMethodError:
                continue
        return False

    async def _send_task(
        self,
        target: _A2ATransportTarget,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ) -> dict[str, Any]:
        """Send one non-streaming task request and return the unwrapped result."""
        if target.mode == "rest":
            body = await self._request_json(
                "POST",
                self._rest_endpoint(target.endpoint, "/message:send"),
                agent,
                json_body=self._build_rest_payload(params),
            )
            return self._unwrap_result_object(body)
        send_variants = [
            ("tasks/send", {"id": str(uuid.uuid4()), **params}),
            ("message/send", params),
        ]
        last_error: Exception | None = None
        for method, payload in send_variants:
            try:
                response = await self._rpc_jsonrpc(target.endpoint, method, payload, agent)
                return self._unwrap_result_object(response)
            except A2AUnsupportedMethodError as exc:
                last_error = exc
                continue
        raise last_error or A2AError("No supported A2A send method found")

    async def _poll_task(
        self,
        target: _A2ATransportTarget,
        task_id: str,
        agent: AgentDescriptor,
    ) -> dict[str, Any]:
        """Poll a remote task until it reaches a terminal state.

        Raises:
            A2AError: The task is still not terminal after ``timeout_seconds``.
        """
        deadline = time.monotonic() + self.timeout_seconds
        while time.monotonic() < deadline:
            if target.mode == "rest":
                body = await self._request_json(
                    "GET",
                    self._rest_endpoint(target.endpoint, f"/tasks/{task_id}"),
                    agent,
                )
                result = self._unwrap_result_object(body)
            else:
                response = await self._rpc_jsonrpc(target.endpoint, "tasks/get", {"id": task_id}, agent)
                result = self._unwrap_result_object(response)
            if self._is_terminal_status(result.get("status")):
                return result
            await asyncio.sleep(self.poll_interval_seconds)
        raise A2AError(
            f"A2A task '{task_id}' timed out after {self.timeout_seconds} seconds"
        )

    async def _fetch_json(self, url: str, agent: AgentDescriptor) -> dict[str, Any]:
        """GET ``url`` and return its JSON object body."""
        self._check_allowed_host(url)
        async with httpx.AsyncClient(
            timeout=self.timeout_seconds,
            transport=self.transport,
        ) as client:
            response = await client.get(url, headers=await self._build_headers(agent))
            response.raise_for_status()
            payload = response.json()
            if not isinstance(payload, dict):
                raise A2AError("Agent card response must be a JSON object")
            return payload

    async def _rpc_jsonrpc(
        self,
        endpoint: str,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ) -> dict[str, Any]:
        """Send a single JSON-RPC 2.0 request and return the response envelope.

        Raises:
            A2AUnsupportedMethodError: HTTP 404/405/501, or a JSON-RPC error that
                means the method is unavailable.
            A2AError: Any other HTTP/JSON-RPC error, or a non-JSON / non-object body.
        """
        self._check_allowed_host(endpoint)
        payload = {
            "jsonrpc": "2.0",
            "id": str(uuid.uuid4()),
            "method": method,
            "params": params,
        }
        async with httpx.AsyncClient(
            timeout=self.timeout_seconds,
            transport=self.transport,
        ) as client:
            try:
                response = await client.post(
                    endpoint,
                    json=payload,
                    headers=await self._build_headers(agent),
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code in {404, 405, 501}:
                    raise A2AUnsupportedMethodError(method) from exc
                raise A2AError(str(exc)) from exc
            try:
                body = response.json()
            except json.JSONDecodeError as exc:
                # Keep parity with _request_json: surface bad bodies as A2AError.
                raise A2AError(f"Invalid JSON response from {endpoint}") from exc
            if not isinstance(body, dict):
                raise A2AError("A2A RPC response must be a JSON object")
            self._raise_for_rpc_error(body)
            return body

    async def _stream_request(
        self,
        target: _A2ATransportTarget,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ):
        """Yield stream event bodies using the implementation for ``target.mode``."""
        if target.mode == "rest":
            async for body in self._stream_rest(target.endpoint, method, params, agent):
                yield body
            return
        async for body in self._stream_jsonrpc(target.endpoint, method, params, agent):
            yield body

    async def _stream_jsonrpc(
        self,
        endpoint: str,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ):
        """Yield event bodies from a JSON-RPC streaming call (SSE or a single JSON body)."""
        self._check_allowed_host(endpoint)
        payload = {
            "jsonrpc": "2.0",
            "id": str(uuid.uuid4()),
            "method": method,
            "params": params,
        }
        async with httpx.AsyncClient(
            timeout=self.timeout_seconds,
            transport=self.transport,
        ) as client:
            try:
                async with client.stream(
                    "POST",
                    endpoint,
                    json=payload,
                    headers=await self._build_headers(agent),
                ) as response:
                    try:
                        response.raise_for_status()
                    except httpx.HTTPStatusError as exc:
                        if exc.response.status_code in {404, 405, 501}:
                            raise A2AUnsupportedMethodError(method) from exc
                        raise A2AError(str(exc)) from exc
                    content_type = response.headers.get("content-type", "").lower()
                    if "text/event-stream" in content_type:
                        # Standard SSE: events are assembled from event/data lines; only
                        # the data payload is consumed here.
                        async for raw_event in self._iter_sse_events(response):
                            if raw_event.strip() == "[DONE]":
                                break
                            yield self._parse_stream_body(raw_event)
                    else:
                        body = await response.aread()
                        if not body:
                            return
                        yield self._parse_stream_body(body.decode("utf-8"))
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code in {404, 405, 501}:
                    raise A2AUnsupportedMethodError(method) from exc
                raise A2AError(str(exc)) from exc

    async def _stream_rest(
        self,
        endpoint: str,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ):
        """Yield event bodies from a REST-style streaming endpoint.

        ``message/stream`` maps to ``POST /message:stream``. ``tasks/subscribe``
        maps to ``/tasks/{id}:subscribe``, tried with GET and then POST.
        """
        if method == "message/stream":
            requests = [
                (
                    "POST",
                    self._rest_endpoint(endpoint, "/message:stream"),
                    self._build_rest_payload(params),
                )
            ]
        elif method == "tasks/subscribe":
            task_id = str(params.get("id") or "").strip()
            if not task_id:
                raise A2AError("Missing task id for REST task subscribe")
            subscribe_url = self._rest_endpoint(endpoint, f"/tasks/{task_id}:subscribe")
            requests = [("GET", subscribe_url, None), ("POST", subscribe_url, None)]
        else:
            raise A2AUnsupportedMethodError(method)
        last_error: Exception | None = None
        for http_method, url, payload in requests:
            self._check_allowed_host(url)
            async with httpx.AsyncClient(
                timeout=self.timeout_seconds,
                transport=self.transport,
            ) as client:
                try:
                    async with client.stream(
                        http_method,
                        url,
                        json=payload,
                        headers=await self._build_headers(agent),
                    ) as response:
                        try:
                            response.raise_for_status()
                        except httpx.HTTPStatusError as exc:
                            # A streamed response is not buffered. Read the error body
                            # so the helpers below can use `.text` / `.json()`;
                            # otherwise httpx raises ResponseNotRead instead of the
                            # intended A2A error.
                            await response.aread()
                            if self._is_unsupported_http_error(exc, allow_validation_errors=True):
                                raise A2AUnsupportedMethodError(method) from exc
                            raise A2AError(self._http_error_message(exc)) from exc
                        content_type = response.headers.get("content-type", "").lower()
                        if "text/event-stream" in content_type:
                            async for raw_event in self._iter_sse_events(response):
                                if raw_event.strip() == "[DONE]":
                                    break
                                yield self._parse_stream_body(raw_event)
                            return
                        body = await response.aread()
                        if not body:
                            return
                        yield self._parse_stream_body(body.decode("utf-8"))
                        return
                except A2AUnsupportedMethodError as exc:
                    last_error = exc
                    continue
        raise last_error or A2AUnsupportedMethodError(method)

    async def _request_json(
        self,
        http_method: str,
        url: str,
        agent: AgentDescriptor,
        *,
        json_body: dict[str, Any] | None = None,
        params: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Send a plain HTTP request and return its JSON object body.

        Raises:
            A2AUnsupportedMethodError: HTTP 404/405/501.
            A2AError: Any other HTTP error, or a non-JSON / non-object body.
        """
        self._check_allowed_host(url)
        async with httpx.AsyncClient(
            timeout=self.timeout_seconds,
            transport=self.transport,
        ) as client:
            try:
                response = await client.request(
                    http_method,
                    url,
                    json=json_body,
                    params=params,
                    headers=await self._build_headers(agent),
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                if self._is_unsupported_http_error(exc):
                    raise A2AUnsupportedMethodError(url) from exc
                raise A2AError(self._http_error_message(exc)) from exc
            try:
                body = response.json()
            except json.JSONDecodeError as exc:
                raise A2AError(f"Invalid JSON response from {url}") from exc
            if not isinstance(body, dict):
                raise A2AError("A2A response must be a JSON object")
            return body

    async def _iter_sse_events(self, response: httpx.Response):
        """Split an SSE response into events and yield each event's joined data text."""
        data_lines: list[str] = []
        async for line in response.aiter_lines():
            if line == "":
                # A blank line terminates the current event.
                if data_lines:
                    yield "\n".join(data_lines)
                    data_lines = []
                continue
            if line.startswith(":"):
                # Comment line.
                continue
            field, _, value = line.partition(":")
            value = value.lstrip()
            if field == "data":
                data_lines.append(value)
        if data_lines:
            yield "\n".join(data_lines)

    @staticmethod
    def _is_unsupported_rpc_error(code: Any, message: str) -> bool:
        """Return True if a JSON-RPC error means "this method is not available".

        -32601 is JSON-RPC "Method not found". Some servers only say so in the
        message text, so "not found" is accepted as well. The exception is A2A's
        TaskNotFoundError (-32001): it means the task is unknown, not that the
        method is unsupported, and must not trigger a fallback that could re-submit
        the task.
        """
        if code == -32601:
            return True
        if code == -32001:
            return False
        return "not found" in message.lower()

    @classmethod
    def _raise_for_rpc_error(cls, body: dict[str, Any]) -> None:
        """Raise the matching A2A exception if ``body`` carries a JSON-RPC error object."""
        error = body.get("error")
        if not isinstance(error, dict):
            return
        code = error.get("code")
        message = str(error.get("message") or "unknown error")
        if cls._is_unsupported_rpc_error(code, message):
            raise A2AUnsupportedMethodError(message)
        raise A2AError(message)

    @classmethod
    def _parse_stream_body(cls, raw_event: str) -> dict[str, Any]:
        """Parse one stream event as a JSON object and raise on an embedded error object."""
        try:
            body = json.loads(raw_event)
        except json.JSONDecodeError as exc:
            raise A2AError(f"Invalid A2A stream payload: {raw_event}") from exc
        if not isinstance(body, dict):
            raise A2AError("A2A stream payload must be a JSON object")
        cls._raise_for_rpc_error(body)
        return body

    def _resolve_transport_targets(
        self,
        card: dict[str, Any],
        agent: AgentDescriptor,
    ) -> list[_A2ATransportTarget]:
        """Resolve the ordered, de-duplicated transport targets to try.

        Order: the card's preferred transport, then declared interfaces, then
        the primary URL under the remaining transport mode(s).
        """
        default_url = self._resolve_primary_url(card, agent)
        declared = self._collect_declared_interfaces(card)
        preferred = self._normalize_transport(
            card.get("preferred_transport") or card.get("preferredTransport")
        )
        candidates: list[_A2ATransportTarget] = []
        if preferred in {"jsonrpc", "rest"}:
            preferred_url = declared.get(preferred) or default_url
            if preferred_url:
                candidates.append(self._transport_target(preferred, preferred_url))
        for mode in ("jsonrpc", "rest"):
            url = declared.get(mode)
            if url:
                candidates.append(self._transport_target(mode, url))
        if default_url:
            if preferred not in {"jsonrpc", "rest"}:
                candidates.append(self._transport_target("jsonrpc", default_url))
                candidates.append(self._transport_target("rest", default_url))
            elif preferred == "jsonrpc":
                candidates.append(self._transport_target("rest", default_url))
            else:
                candidates.append(self._transport_target("jsonrpc", default_url))
        deduped: list[_A2ATransportTarget] = []
        seen: set[tuple[str, str]] = set()
        for target in candidates:
            # Keep each mode + endpoint pair once to avoid repeated attempts.
            key = (target.mode, target.endpoint.rstrip("/").lower())
            if key in seen:
                continue
            seen.add(key)
            deduped.append(target)
        return deduped

    def _collect_declared_interfaces(self, card: dict[str, Any]) -> dict[str, str]:
        """Map transport mode -> URL from the first interface list the card declares."""
        interfaces = None
        for key in (
            "additional_interfaces",
            "additionalInterfaces",
            "interfaces",
            "supported_interfaces",
            "supportedInterfaces",
        ):
            candidate = card.get(key)
            if isinstance(candidate, list):
                interfaces = candidate
                break
        result: dict[str, str] = {}
        if not isinstance(interfaces, list):
            return result
        for item in interfaces:
            if not isinstance(item, dict):
                continue
            mode = self._normalize_transport(item.get("transport"))
            url = str(item.get("url") or "").strip()
            if mode in {"jsonrpc", "rest"} and url:
                result.setdefault(mode, url)
        return result

    def _resolve_primary_url(self, card: dict[str, Any], agent: AgentDescriptor) -> str:
        """Return the card's primary URL, or the local config when it binds 0.0.0.0/::."""
        card_url = str(card.get("url") or "").strip()
        fallback = str(agent.endpoint or agent.base_url or "").strip()
        if card_url and (urlparse(card_url).hostname or "").strip() not in {"0.0.0.0", "::"}:
            return card_url
        return fallback or card_url

    def _transport_target(self, mode: str, url: str) -> _A2ATransportTarget:
        """Build a target, normalizing REST URLs to their ``/v1`` root."""
        normalized_url = url.strip()
        if mode == "rest":
            normalized_url = self._normalize_rest_base_url(normalized_url)
        return _A2ATransportTarget(mode=mode, endpoint=normalized_url)

    @staticmethod
    def _normalize_transport(value: Any) -> str | None:
        """Normalize a transport name to "jsonrpc", "rest", "grpc" or None."""
        text = str(value or "").strip().lower()
        if not text:
            return None
        if text in {"jsonrpc", "json-rpc"}:
            return "jsonrpc"
        if text in {"http+json", "http-json", "http_json", "rest"}:
            return "rest"
        if text == "grpc":
            return "grpc"
        return None

    @staticmethod
    def _normalize_rest_base_url(url: str) -> str:
        """Reduce any REST endpoint variant to its ``/v1`` root URL."""
        parsed = urlparse(url)
        path = parsed.path.rstrip("/")
        for suffix in ("/message:send", "/message:stream"):
            if path.endswith(suffix):
                path = path[: -len(suffix)]
                break
        if "/tasks/" in path and (path.endswith(":cancel") or path.endswith(":subscribe")):
            path = path.split("/tasks/", 1)[0]
        if not path.endswith("/v1"):
            path = f"{path}/v1" if path else "/v1"
        return urlunparse(parsed._replace(path=path, params="", query="", fragment="")).rstrip("/")

    @staticmethod
    def _rest_endpoint(base_url: str, route: str) -> str:
        """Join a REST root URL and a route."""
        return f"{base_url.rstrip('/')}{route}"

    @staticmethod
    def _supports_streaming(card: dict[str, Any]) -> bool:
        """Read the card's streaming capability; missing information means supported."""
        capabilities = card.get("capabilities")
        if not isinstance(capabilities, dict) or "streaming" not in capabilities:
            return True
        streaming = capabilities.get("streaming")
        if isinstance(streaming, dict):
            for key in ("enabled", "supported"):
                if key in streaming:
                    return bool(streaming.get(key))
            return True
        return bool(streaming)

    @classmethod
    def _unwrap_result_object(cls, payload: dict[str, Any]) -> dict[str, Any]:
        """Extract the actual result object from the various protocol envelopes.

        Unwraps a ``result`` envelope, then a typed wrapper key (``task``,
        ``statusUpdate``, ...), tagging it with a ``kind`` if absent.
        """
        candidate: Any = payload
        if isinstance(candidate, dict) and isinstance(candidate.get("result"), dict):
            candidate = candidate["result"]
        if not isinstance(candidate, dict):
            raise A2AError("Malformed A2A response")
        for key, kind in (
            ("task", "task"),
            ("message", "message"),
            ("statusUpdate", "status-update"),
            ("status_update", "status-update"),
            ("artifactUpdate", "artifact-update"),
            ("artifact_update", "artifact-update"),
        ):
            value = candidate.get(key)
            if isinstance(value, dict):
                result = dict(value)
                result.setdefault("kind", kind)
                return result
        return candidate

    @staticmethod
    def _is_unsupported_http_error(
        exc: httpx.HTTPStatusError,
        *,
        allow_validation_errors: bool = False,
    ) -> bool:
        """Decide whether an HTTP error should be read as "method not supported".

        404/405/501 always qualify. With ``allow_validation_errors``, a 400/422
        whose body mentions "not supported"/"unsupported" also qualifies. The
        response body must already be read.
        """
        status_code = exc.response.status_code
        if status_code in {404, 405, 501}:
            return True
        if allow_validation_errors and status_code in {400, 422}:
            message = exc.response.text.lower()
            return "not supported" in message or "unsupported" in message
        return False

    @staticmethod
    def _http_error_message(exc: httpx.HTTPStatusError) -> str:
        """Extract a readable error text from a JSON error body, else use ``str(exc)``."""
        try:
            payload = exc.response.json()
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict):
            for key in ("detail", "title", "message", "error"):
                value = payload.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        return str(exc)

    async def _build_headers(self, agent: AgentDescriptor) -> dict[str, str]:
        """Build request headers, including the optional bearer token.

        The bearer token is an AuthZ-issued backend token when ``auth_mode`` is
        "oauth_backend_token"; otherwise it comes from the environment variable
        named by ``auth_env``. String entries of ``metadata["headers"]`` are
        applied last.
        """
        headers = {"Accept": "application/json, text/event-stream"}
        auth_mode = (agent.auth_mode or "none").strip().lower()
        if auth_mode == "oauth_backend_token":
            headers["Authorization"] = f"Bearer {await self._issue_backend_token(agent)}"
        else:
            token = os.environ.get(agent.auth_env or "")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        extra = agent.metadata.get("headers")
        if isinstance(extra, dict):
            for key, value in extra.items():
                if key and isinstance(value, str):
                    headers[key] = value
        return headers

    async def _issue_backend_token(self, agent: AgentDescriptor) -> str:
        """Request a backend access token from AuthZ for this agent's audience.

        The audience is prefixed with "a2a:" if needed, and scopes default to
        ["run_task"].

        Raises:
            A2AError: AuthZ config or identity is incomplete, the audience is
                empty, or no access token was returned.
        """
        from nanobot.authz.client import AuthzClient

        authz_base_url = str(getattr(self.authz_config, "base_url", "") or "").strip()
        client_id = str(getattr(self.backend_identity, "client_id", "") or "").strip()
        client_secret = str(getattr(self.backend_identity, "client_secret", "") or "").strip()
        if not authz_base_url or not client_id or not client_secret:
            raise A2AError(
                f"A2A agent '{agent.id}' requires AuthZ backend tokens, but authz/backend identity is incomplete"
            )
        audience = str(agent.auth_audience or agent.id).strip()
        if not audience:
            raise A2AError(f"A2A agent '{agent.id}' is missing auth audience")
        if not audience.startswith("a2a:"):
            audience = f"a2a:{audience}"
        scopes = [scope for scope in agent.auth_scopes if scope]
        if not scopes:
            scopes = ["run_task"]
        authz_client = AuthzClient(
            authz_base_url,
            timeout_seconds=int(getattr(self.authz_config, "request_timeout_seconds", 10)),
        )
        token_response = await authz_client.issue_token(
            client_id=client_id,
            client_secret=client_secret,
            audience=audience,
            scopes=scopes,
        )
        access_token = str(token_response.get("access_token") or "").strip()
        if not access_token:
            raise A2AError(f"A2A agent '{agent.id}' did not receive an access token from AuthZ")
        return access_token

    def _check_allowed_host(self, url: str) -> None:
        """Raise A2AError if an allow-list is configured and ``url``'s host is not on it."""
        if not self.allowed_hosts:
            return
        host = (urlparse(url).hostname or "").lower()
        if host not in self.allowed_hosts:
            raise A2AError(f"Host '{host}' is not allowed for A2A access")

    @staticmethod
    def _build_message_params(task: str, label: str | None) -> dict[str, Any]:
        """Wrap the delegated task text as standard A2A ``message`` params."""
        message = {
            "messageId": str(uuid.uuid4()),
            "role": "user",
            "parts": [{"type": "text", "kind": "text", "text": task}],
        }
        if label:
            message["metadata"] = {"label": label}
        return {"message": message}

    @classmethod
    def _build_rest_payload(cls, params: dict[str, Any]) -> dict[str, Any]:
        """Convert generic message params into a REST-style payload (deep copy).

        Roles become ROLE_* enum literals and ``parts`` are flattened into a
        ``content`` list.
        """
        payload = json.loads(json.dumps(params))
        message = payload.get("message")
        if not isinstance(message, dict):
            return payload
        # Some REST implementations require enum literals for the role, not free strings.
        role = str(message.get("role") or "").strip().lower()
        if role == "user":
            message["role"] = "ROLE_USER"
        elif role == "agent":
            message["role"] = "ROLE_AGENT"
        parts = message.pop("parts", None)
        if isinstance(parts, list) and "content" not in message:
            # REST style usually flattens `parts` into a `content` array.
            content: list[dict[str, Any]] = []
            for part in parts:
                if not isinstance(part, dict):
                    continue
                if "text" in part:
                    content.append({"text": part["text"]})
                    continue
                if "file" in part:
                    content.append({"file": part["file"]})
                    continue
                if "data" in part:
                    content.append({"data": part["data"]})
            message["content"] = content
        return payload

    def _build_run_result(
        self,
        agent: AgentDescriptor,
        result: dict[str, Any],
        state: _StreamState | None = None,
    ) -> AgentRunResult:
        """Convert a remote result (plus optional stream state) into an AgentRunResult.

        The status is forced to "error" when the summary is not meaningful.
        """
        summary = ""
        if state:
            summary = state.build_summary(self)
        if not summary:
            summary = self._extract_text(result) or json.dumps(result, ensure_ascii=False)
        status = self._normalize_status(result.get("status"))
        if not has_meaningful_summary(summary):
            status = "error"
        return AgentRunResult(
            agent_id=agent.id,
            agent_name=agent.name,
            status=status,
            summary=summary,
            raw=result,
        )

    @staticmethod
    def _is_task_result(result: dict[str, Any]) -> bool:
        """Return True when the result looks like a task object."""
        if "status" in result:
            return True
        kind = str(result.get("kind") or "").lower()
        return kind == "task"

    @staticmethod
    def _is_terminal_status(status: Any) -> bool:
        """Return True when the status is terminal.

        "rejected" is one of the A2A terminal task states. Without it a rejected
        task would be polled until the overall timeout.
        """
        state = A2AClient._normalized_state_token(status)
        return state in {
            "completed",
            "complete",
            "failed",
            "error",
            "cancelled",
            "canceled",
            "rejected",
        }

    @staticmethod
    def _normalize_status(status: Any) -> str:
        """Normalize vendor status names to "ok"/"working"/"error"/"cancelled"/other."""
        state = A2AClient._normalized_state_token(status or "ok")
        if state in {"", "completed", "complete", "success", "ok"}:
            return "ok"
        if state in {"working", "running", "in_progress", "submitted", "queued"}:
            return "working"
        if state in {"failed", "error"}:
            return "error"
        if state in {"cancelled", "canceled"}:
            return "cancelled"
        return state

    @staticmethod
    def _normalized_state_token(status: Any) -> str:
        """Extract the lower-cased core state token, e.g. stripping a `TASK_STATE_` prefix."""
        if isinstance(status, dict):
            state = str(status.get("state") or status.get("status") or "")
        else:
            state = str(status or "")
        state = state.strip().lower()
        if state.startswith("task_state_"):
            state = state[len("task_state_") :]
        return state

    @classmethod
    def _extract_text(cls, payload: Any) -> str:
        """Recursively pull the most useful text out of a nested payload.

        Direct text keys are checked first, then container keys, then a dict
        ``status``. List items are joined with newlines.
        """
        if isinstance(payload, str):
            return payload
        if isinstance(payload, dict):
            for key in ("text", "content", "summary", "finalText", "final_text"):
                value = payload.get(key)
                text = cls._extract_text(value)
                if text:
                    return text
            for key in (
                "message",
                "artifact",
                "artifacts",
                "messages",
                "parts",
                "output",
                "toolResults",
                "tool_results",
                "task",
                "statusUpdate",
                "status_update",
                "artifactUpdate",
                "artifact_update",
                "history",
            ):
                value = payload.get(key)
                text = cls._extract_text(value)
                if text:
                    return text
            if "status" in payload and isinstance(payload["status"], dict):
                text = cls._extract_text(payload["status"])
                if text:
                    return text
            return ""
        if isinstance(payload, list):
            parts = [cls._extract_text(item) for item in payload]
            return "\n".join(part for part in parts if part)
        return ""