- 引入AgentTeamOrchestrator支持多agent协同任务执行 - 增加第三方swarms库依赖并配置git协议替换以改善包管理 - 扩展DelegationManager支持团队任务调度和进度跟踪 - 实现中文bigram分词算法提升中文任务检索准确性 - 调整A2AClient和DelegationManager超时时间从30秒增至600秒 - 优化AgentRunResult状态判断逻辑增加有意义摘要检测 - 修改Dockerfile配置npm仓库镜像地址和git协议映射 - 更新CLI命令行接口支持网关端口配置传递 - 调整提供者超时配置机制增强请求稳定性 - 移除过时的support_group字段简化agent描述符结构 - 增强错误处理和进度事件报告机制改进用户体验
1217 lines
47 KiB
Python
1217 lines
47 KiB
Python
"""A2A 客户端实现。

目标不是完整覆盖所有厂商变体，而是提供一条足够稳健的兼容链路：

1. 先拉取 agent card，解析可用端点和偏好传输方式；
2. 优先尝试流式订阅，以拿到实时进度；
3. 流式不可用或中断时，回退到轮询；
4. 同时兼容 JSON-RPC 和 HTTP+JSON 风格接口。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import time
|
||
import uuid
|
||
from collections.abc import Awaitable, Callable
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
from urllib.parse import urlparse, urlunparse
|
||
|
||
import httpx
|
||
|
||
from nanobot.agent.agent_registry import AgentDescriptor
|
||
from nanobot.agent.run_result import AgentRunResult, has_meaningful_summary
|
||
|
||
|
||
class A2AError(RuntimeError):
    """Common error type raised whenever an A2A request cannot be completed."""
|
||
|
||
|
||
class A2AUnsupportedMethodError(A2AError):
    """Raised when the remote endpoint does not implement the requested method.

    Callers catch this specifically to fall back to another method or transport.
    """
|
||
|
||
|
||
@dataclass
class A2AStreamEvent:
    """One normalized event from an A2A subscription stream."""

    # Event kind, e.g. "task", "message", "status-update", "artifact-update".
    kind: str
    # Remote task id; once known, callers can register it for cancel/resume.
    task_id: str | None = None
    # Normalized status text.
    status: str | None = None
    # Incremental text suitable for showing to a user.
    text: str | None = None
    # True when this event is final / terminal.
    final: bool = False
    # The raw event payload, kept for debugging and future extensions.
    raw: dict[str, Any] | None = None
|
||
|
||
|
||
@dataclass
|
||
class _StreamState:
|
||
"""流式任务状态累加器。"""
|
||
|
||
task_id: str | None = None
|
||
status: str = "working"
|
||
artifacts: dict[str, str] = field(default_factory=dict)
|
||
artifact_order: list[str] = field(default_factory=list)
|
||
messages: list[str] = field(default_factory=list)
|
||
status_messages: list[str] = field(default_factory=list)
|
||
latest_result: dict[str, Any] | None = None
|
||
final_seen: bool = False
|
||
|
||
def apply(self, result: dict[str, Any], client: A2AClient) -> A2AStreamEvent:
|
||
"""吸收一条原始结果并产出归一化流事件。"""
|
||
self.latest_result = result
|
||
kind = str(result.get("kind") or result.get("type") or "result").lower()
|
||
task_id = str(result.get("id") or result.get("taskId") or "").strip() or None
|
||
if task_id:
|
||
self.task_id = task_id
|
||
|
||
raw_status = result.get("status")
|
||
normalized_status = client._normalize_status(raw_status)
|
||
# 非 ok 状态会覆盖当前状态;否则在 task/status-update 终态时再更新。
|
||
if normalized_status not in {"", "ok"}:
|
||
self.status = normalized_status
|
||
elif kind in {"task", "status-update"} and client._is_terminal_status(raw_status):
|
||
self.status = client._normalize_status(raw_status)
|
||
|
||
text = ""
|
||
if kind == "artifact-update":
|
||
# artifact-update 需要增量拼接同一个 artifact 的文本内容。
|
||
text = self._apply_artifact_update(result, client)
|
||
elif kind == "status-update":
|
||
# 某些实现把状态消息放在 status 里,有些放在 message 里,这里都兜一遍。
|
||
text = client._extract_text(result.get("status")) or client._extract_text(
|
||
result.get("message")
|
||
)
|
||
self._append_unique(self.status_messages, text)
|
||
elif kind in {"message", "task"}:
|
||
self._apply_task_or_message(result, client)
|
||
text = client._extract_text(result)
|
||
else:
|
||
text = client._extract_text(result)
|
||
|
||
final = bool(result.get("final")) or client._is_terminal_status(raw_status)
|
||
if final:
|
||
self.final_seen = True
|
||
if self.status == "working":
|
||
# 即使没拿到更明确状态,也尽量用终态把 working 覆盖掉。
|
||
self.status = client._normalize_status(raw_status)
|
||
if text and kind not in {"artifact-update", "message", "task"}:
|
||
self._append_unique(self.messages, text)
|
||
|
||
return A2AStreamEvent(
|
||
kind=kind,
|
||
task_id=self.task_id,
|
||
status=self.status,
|
||
text=text or None,
|
||
final=final,
|
||
raw=result,
|
||
)
|
||
|
||
def build_summary(self, client: A2AClient) -> str:
|
||
"""按 artifact -> message -> status 的优先级生成最终摘要。"""
|
||
artifact_text = "\n".join(
|
||
self.artifacts[artifact_id]
|
||
for artifact_id in self.artifact_order
|
||
if self.artifacts.get(artifact_id)
|
||
).strip()
|
||
if artifact_text:
|
||
return artifact_text
|
||
|
||
message_text = "\n".join(text for text in self.messages if text).strip()
|
||
if message_text:
|
||
return message_text
|
||
|
||
status_text = "\n".join(text for text in self.status_messages if text).strip()
|
||
if status_text:
|
||
return status_text
|
||
|
||
if self.latest_result:
|
||
return client._extract_text(self.latest_result)
|
||
return ""
|
||
|
||
def _apply_artifact_update(self, result: dict[str, Any], client: A2AClient) -> str:
|
||
"""把一条 artifact-update 事件并入累积状态。"""
|
||
artifact = result.get("artifact")
|
||
if not isinstance(artifact, dict):
|
||
artifact = result
|
||
artifact_id = str(
|
||
artifact.get("artifactId")
|
||
or artifact.get("id")
|
||
or result.get("artifactId")
|
||
or f"artifact-{len(self.artifact_order) + 1}"
|
||
)
|
||
text = client._extract_text(artifact)
|
||
if not text:
|
||
return ""
|
||
|
||
if artifact_id not in self.artifacts:
|
||
self.artifacts[artifact_id] = ""
|
||
self.artifact_order.append(artifact_id)
|
||
|
||
# append=true 时做增量拼接,否则视为完整覆盖。
|
||
if result.get("append") or artifact.get("append"):
|
||
self.artifacts[artifact_id] += text
|
||
else:
|
||
self.artifacts[artifact_id] = text
|
||
return text
|
||
|
||
def _apply_task_or_message(self, result: dict[str, Any], client: A2AClient) -> None:
|
||
"""把 task/message 类型结果中的 artifact 和文本提取出来。"""
|
||
artifacts = result.get("artifacts")
|
||
if isinstance(artifacts, list):
|
||
for index, artifact in enumerate(artifacts):
|
||
if not isinstance(artifact, dict):
|
||
continue
|
||
artifact_id = str(
|
||
artifact.get("artifactId")
|
||
or artifact.get("id")
|
||
or f"artifact-{len(self.artifact_order) + index + 1}"
|
||
)
|
||
text = client._extract_text(artifact)
|
||
if not text:
|
||
continue
|
||
if artifact_id not in self.artifacts:
|
||
self.artifact_order.append(artifact_id)
|
||
self.artifacts[artifact_id] = text
|
||
|
||
text = client._extract_text(result)
|
||
self._append_unique(self.messages, text)
|
||
|
||
@staticmethod
|
||
def _append_unique(collection: list[str], text: str) -> None:
|
||
"""仅当文本与上一个不同才追加,避免流式重复刷屏。"""
|
||
if text and (not collection or collection[-1] != text):
|
||
collection.append(text)
|
||
|
||
|
||
@dataclass(frozen=True)
|
||
class _A2ATransportTarget:
|
||
"""解析后的远端传输目标。"""
|
||
|
||
mode: str
|
||
endpoint: str
|
||
|
||
|
||
class A2AClient:
|
||
"""支持 JSON-RPC 与 HTTP+JSON 回退链路的 A2A 客户端。"""
|
||
|
||
def __init__(
|
||
self,
|
||
timeout_seconds: int = 600,
|
||
poll_interval_seconds: int = 2,
|
||
card_cache_ttl_seconds: int = 300,
|
||
allowed_hosts: list[str] | None = None,
|
||
transport: httpx.AsyncBaseTransport | None = None,
|
||
authz_config: Any | None = None,
|
||
backend_identity: Any | None = None,
|
||
):
|
||
# 这些参数决定超时、轮询频率和安全边界。
|
||
self.timeout_seconds = timeout_seconds
|
||
self.poll_interval_seconds = poll_interval_seconds
|
||
self.card_cache_ttl_seconds = card_cache_ttl_seconds
|
||
self.allowed_hosts = {host.lower() for host in (allowed_hosts or []) if host}
|
||
self.transport = transport
|
||
self.authz_config = authz_config
|
||
self.backend_identity = backend_identity
|
||
self._card_cache: dict[str, tuple[float, dict[str, Any]]] = {}
|
||
|
||
async def run_task(
|
||
self,
|
||
agent: AgentDescriptor,
|
||
task: str,
|
||
label: str | None = None,
|
||
event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None = None,
|
||
task_callback: Callable[[str], Awaitable[None]] | None = None,
|
||
prefer_streaming: bool = True,
|
||
) -> AgentRunResult:
|
||
"""执行一次远端 A2A 任务。"""
|
||
card = await self.fetch_agent_card(agent)
|
||
params = self._build_message_params(task, label)
|
||
targets = self._resolve_transport_targets(card, agent)
|
||
if not targets:
|
||
raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint")
|
||
|
||
last_unsupported: Exception | None = None
|
||
for target in targets:
|
||
try:
|
||
# 若 card 支持流式,则优先尝试流式以获取中间态。
|
||
if prefer_streaming and self._supports_streaming(card):
|
||
stream_result = await self._run_task_streaming(
|
||
target=target,
|
||
params=params,
|
||
agent=agent,
|
||
event_callback=event_callback,
|
||
task_callback=task_callback,
|
||
)
|
||
if stream_result is not None:
|
||
return stream_result
|
||
|
||
# 流式不可用时回退到普通 send,再视状态决定是否轮询。
|
||
result = await self._send_task(target, params, agent)
|
||
if self._is_task_result(result) and not self._is_terminal_status(result.get("status")):
|
||
task_id = str(result.get("id") or result.get("taskId") or "").strip()
|
||
if task_id:
|
||
if task_callback:
|
||
await task_callback(task_id)
|
||
result = await self._poll_task(target, task_id, agent)
|
||
|
||
return self._build_run_result(agent, result)
|
||
except A2AUnsupportedMethodError as exc:
|
||
last_unsupported = exc
|
||
continue
|
||
|
||
if last_unsupported:
|
||
raise last_unsupported
|
||
raise A2AError(f"Agent '{agent.id}' does not expose a usable A2A endpoint")
|
||
|
||
async def fetch_agent_card(self, agent: AgentDescriptor) -> dict[str, Any]:
|
||
"""拉取远端 agent card,并带本地 TTL 缓存。"""
|
||
_, card = await self.fetch_agent_card_with_url(agent)
|
||
return card
|
||
|
||
async def fetch_agent_card_with_url(self, agent: AgentDescriptor) -> tuple[str, dict[str, Any]]:
|
||
"""拉取远端 agent card,并返回命中的 card URL。"""
|
||
urls = self._candidate_card_urls(agent)
|
||
last_error: Exception | None = None
|
||
for url in urls:
|
||
cache_key = url.lower()
|
||
cached = self._card_cache.get(cache_key)
|
||
if cached and cached[0] > time.monotonic():
|
||
return url, cached[1]
|
||
|
||
try:
|
||
card = await self._fetch_json(url, agent)
|
||
except Exception as exc:
|
||
last_error = exc
|
||
continue
|
||
|
||
if isinstance(card, dict):
|
||
self._card_cache[cache_key] = (
|
||
time.monotonic() + self.card_cache_ttl_seconds,
|
||
card,
|
||
)
|
||
return url, card
|
||
|
||
if last_error:
|
||
raise A2AError(f"Failed to fetch agent card for '{agent.id}': {last_error}")
|
||
raise A2AError(f"Failed to fetch agent card for '{agent.id}'")
|
||
|
||
def invalidate_card(self, agent: AgentDescriptor) -> None:
|
||
"""清空某个 agent 相关的 card 缓存。"""
|
||
for url in self._candidate_card_urls(agent):
|
||
self._card_cache.pop(url.lower(), None)
|
||
|
||
def _candidate_card_urls(self, agent: AgentDescriptor) -> list[str]:
|
||
"""根据 agent 配置推导一组候选 card URL。"""
|
||
urls: list[str] = []
|
||
if agent.card_url:
|
||
urls.append(agent.card_url)
|
||
base_url = str(agent.base_url or agent.endpoint or "").rstrip("/")
|
||
if base_url:
|
||
urls.extend(
|
||
[
|
||
f"{base_url}/.well-known/agent-card",
|
||
f"{base_url}/.well-known/agent-card.json",
|
||
f"{base_url}/.well-known/agent.json",
|
||
]
|
||
)
|
||
deduped: list[str] = []
|
||
seen: set[str] = set()
|
||
for url in urls:
|
||
normalized = url.strip()
|
||
if not normalized or normalized.lower() in seen:
|
||
continue
|
||
seen.add(normalized.lower())
|
||
deduped.append(normalized)
|
||
return deduped
|
||
|
||
async def _run_task_streaming(
|
||
self,
|
||
target: _A2ATransportTarget,
|
||
params: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
|
||
task_callback: Callable[[str], Awaitable[None]] | None,
|
||
) -> AgentRunResult | None:
|
||
"""优先尝试流式方法执行任务,失败时可回退为 None。"""
|
||
if target.mode == "rest":
|
||
stream_variants = [("message/stream", params)]
|
||
else:
|
||
stream_variants = [
|
||
("tasks/sendSubscribe", {"id": str(uuid.uuid4()), **params}),
|
||
("message/stream", params),
|
||
]
|
||
|
||
saw_supported_stream = False
|
||
last_error: Exception | None = None
|
||
for method, payload in stream_variants:
|
||
state = _StreamState()
|
||
try:
|
||
# 每个流式方法都独立尝试;一旦成功拿到终态结果就直接返回。
|
||
return await self._consume_stream_method(
|
||
target=target,
|
||
method=method,
|
||
params=payload,
|
||
agent=agent,
|
||
event_callback=event_callback,
|
||
task_callback=task_callback,
|
||
state=state,
|
||
allow_resume=True,
|
||
)
|
||
except A2AUnsupportedMethodError as exc:
|
||
last_error = exc
|
||
continue
|
||
except A2AError as exc:
|
||
# 已经跑到一半但中断时,如果拿到了 task_id,就尝试恢复订阅或轮询。
|
||
saw_supported_stream = True
|
||
last_error = exc
|
||
if state.task_id:
|
||
try:
|
||
return await self._resume_or_poll(
|
||
target=target,
|
||
agent=agent,
|
||
task_id=state.task_id,
|
||
state=state,
|
||
event_callback=event_callback,
|
||
task_callback=task_callback,
|
||
)
|
||
except A2AError as resume_exc:
|
||
last_error = resume_exc
|
||
continue
|
||
else:
|
||
saw_supported_stream = True
|
||
|
||
if saw_supported_stream and last_error:
|
||
raise last_error
|
||
return None
|
||
|
||
    async def _consume_stream_method(
        self,
        target: _A2ATransportTarget,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
        event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
        task_callback: Callable[[str], Awaitable[None]] | None,
        state: _StreamState,
        allow_resume: bool,
    ) -> AgentRunResult:
        """Consume one concrete streaming method until a terminal result arrives.

        Every streamed body is unwrapped, folded into *state* and reported via
        the callbacks. The first final event ends the stream with a run result.

        If the stream ends (or breaks with a read/protocol/timeout error after a
        task id is known) without a final event:

        - a previously seen final result is returned;
        - otherwise, with *allow_resume*, the task is resumed or polled.

        Raises:
            A2AError: on a transport break before any task id was seen, or when
                the stream produced events but no final one (and could not resume).
            A2AUnsupportedMethodError: when the stream produced no events at all.
        """
        saw_event = False
        seen_task_id: str | None = state.task_id
        try:
            async for body in self._stream_request(target, method, params, agent):
                saw_event = True
                result = self._unwrap_result_object(body)
                event = state.apply(result, self)
                # Report each newly seen task id upstream so it can be cancelled or resumed.
                if task_callback and event.task_id and event.task_id != seen_task_id:
                    seen_task_id = event.task_id
                    await task_callback(event.task_id)
                if event_callback:
                    await event_callback(event)
                if event.final:
                    return self._build_run_result(agent, result, state)
        except (httpx.ReadError, httpx.RemoteProtocolError, httpx.TimeoutException) as exc:
            # Without a task id there is nothing to resume, so fail right away;
            # with one, fall through to the recovery logic below.
            if not state.task_id:
                raise A2AError(str(exc)) from exc

        if state.final_seen:
            return self._build_run_result(agent, state.latest_result or {}, state)
        if allow_resume and state.task_id:
            # Stream ended without a final result: resubscribe, or fall back to polling.
            return await self._resume_or_poll(
                target=target,
                agent=agent,
                task_id=state.task_id,
                state=state,
                event_callback=event_callback,
                task_callback=task_callback,
            )
        if saw_event:
            raise A2AError("A2A stream ended before a final result was received")
        raise A2AUnsupportedMethodError(method)
|
||
|
||
async def _resume_or_poll(
|
||
self,
|
||
target: _A2ATransportTarget,
|
||
agent: AgentDescriptor,
|
||
task_id: str,
|
||
state: _StreamState,
|
||
event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None,
|
||
task_callback: Callable[[str], Awaitable[None]] | None,
|
||
) -> AgentRunResult:
|
||
"""在流式中断后尝试恢复订阅,失败则退化为轮询。"""
|
||
try:
|
||
return await self._consume_stream_method(
|
||
target=target,
|
||
method="tasks/subscribe" if target.mode == "rest" else "tasks/resubscribe",
|
||
params={"id": task_id},
|
||
agent=agent,
|
||
event_callback=event_callback,
|
||
task_callback=task_callback,
|
||
state=state,
|
||
allow_resume=False,
|
||
)
|
||
except A2AUnsupportedMethodError:
|
||
result = await self._poll_task(target, task_id, agent)
|
||
return self._build_run_result(agent, result, state)
|
||
|
||
async def cancel_task(self, agent: AgentDescriptor, task_id: str) -> bool:
|
||
"""尽力取消一个远端 A2A 任务。"""
|
||
task_id = task_id.strip()
|
||
if not task_id:
|
||
return False
|
||
card = await self.fetch_agent_card(agent)
|
||
targets = self._resolve_transport_targets(card, agent)
|
||
if not targets:
|
||
raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint")
|
||
for target in targets:
|
||
try:
|
||
if target.mode == "rest":
|
||
# REST 风格通常使用 `/tasks/{id}:cancel`。
|
||
await self._request_json(
|
||
"POST",
|
||
self._rest_endpoint(target.endpoint, f"/tasks/{task_id}:cancel"),
|
||
agent,
|
||
json_body={"name": f"tasks/{task_id}"},
|
||
)
|
||
else:
|
||
# JSON-RPC 风格使用 `tasks/cancel`。
|
||
await self._rpc_jsonrpc(target.endpoint, "tasks/cancel", {"id": task_id}, agent)
|
||
return True
|
||
except A2AUnsupportedMethodError:
|
||
continue
|
||
return False
|
||
|
||
async def _send_task(
|
||
self,
|
||
target: _A2ATransportTarget,
|
||
params: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
) -> dict[str, Any]:
|
||
"""发送一次非流式任务请求。"""
|
||
if target.mode == "rest":
|
||
body = await self._request_json(
|
||
"POST",
|
||
self._rest_endpoint(target.endpoint, "/message:send"),
|
||
agent,
|
||
json_body=self._build_rest_payload(params),
|
||
)
|
||
return self._unwrap_result_object(body)
|
||
|
||
send_variants = [
|
||
("tasks/send", {"id": str(uuid.uuid4()), **params}),
|
||
("message/send", params),
|
||
]
|
||
last_error: Exception | None = None
|
||
for method, payload in send_variants:
|
||
try:
|
||
response = await self._rpc_jsonrpc(target.endpoint, method, payload, agent)
|
||
return self._unwrap_result_object(response)
|
||
except A2AUnsupportedMethodError as exc:
|
||
last_error = exc
|
||
continue
|
||
|
||
raise last_error or A2AError("No supported A2A send method found")
|
||
|
||
async def _poll_task(
|
||
self,
|
||
target: _A2ATransportTarget,
|
||
task_id: str,
|
||
agent: AgentDescriptor,
|
||
) -> dict[str, Any]:
|
||
"""轮询远端 task,直到进入终态或超时。"""
|
||
deadline = time.monotonic() + self.timeout_seconds
|
||
while time.monotonic() < deadline:
|
||
if target.mode == "rest":
|
||
body = await self._request_json(
|
||
"GET",
|
||
self._rest_endpoint(target.endpoint, f"/tasks/{task_id}"),
|
||
agent,
|
||
)
|
||
result = self._unwrap_result_object(body)
|
||
else:
|
||
response = await self._rpc_jsonrpc(target.endpoint, "tasks/get", {"id": task_id}, agent)
|
||
result = self._unwrap_result_object(response)
|
||
|
||
if self._is_terminal_status(result.get("status")):
|
||
return result
|
||
await asyncio.sleep(self.poll_interval_seconds)
|
||
|
||
raise A2AError(
|
||
f"A2A task '{task_id}' timed out after {self.timeout_seconds} seconds"
|
||
)
|
||
|
||
async def _fetch_json(self, url: str, agent: AgentDescriptor) -> dict[str, Any]:
|
||
"""以 GET 方式拉取 JSON 对象。"""
|
||
self._check_allowed_host(url)
|
||
async with httpx.AsyncClient(
|
||
timeout=self.timeout_seconds,
|
||
transport=self.transport,
|
||
) as client:
|
||
response = await client.get(url, headers=await self._build_headers(agent))
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
if not isinstance(payload, dict):
|
||
raise A2AError("Agent card response must be a JSON object")
|
||
return payload
|
||
|
||
async def _rpc_jsonrpc(
|
||
self,
|
||
endpoint: str,
|
||
method: str,
|
||
params: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
) -> dict[str, Any]:
|
||
"""发送一条 JSON-RPC 请求。"""
|
||
self._check_allowed_host(endpoint)
|
||
payload = {
|
||
"jsonrpc": "2.0",
|
||
"id": str(uuid.uuid4()),
|
||
"method": method,
|
||
"params": params,
|
||
}
|
||
|
||
async with httpx.AsyncClient(
|
||
timeout=self.timeout_seconds,
|
||
transport=self.transport,
|
||
) as client:
|
||
try:
|
||
response = await client.post(
|
||
endpoint,
|
||
json=payload,
|
||
headers=await self._build_headers(agent),
|
||
)
|
||
response.raise_for_status()
|
||
except httpx.HTTPStatusError as exc:
|
||
if exc.response.status_code in {404, 405, 501}:
|
||
raise A2AUnsupportedMethodError(method) from exc
|
||
raise A2AError(str(exc)) from exc
|
||
body = response.json()
|
||
|
||
if not isinstance(body, dict):
|
||
raise A2AError("A2A RPC response must be a JSON object")
|
||
|
||
error = body.get("error")
|
||
if isinstance(error, dict):
|
||
code = error.get("code")
|
||
message = str(error.get("message") or "unknown error")
|
||
if code == -32601 or "not found" in message.lower():
|
||
raise A2AUnsupportedMethodError(message)
|
||
raise A2AError(message)
|
||
|
||
return body
|
||
|
||
async def _stream_request(
|
||
self,
|
||
target: _A2ATransportTarget,
|
||
method: str,
|
||
params: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
):
|
||
"""根据 transport mode 选择具体流式实现。"""
|
||
if target.mode == "rest":
|
||
async for body in self._stream_rest(target.endpoint, method, params, agent):
|
||
yield body
|
||
return
|
||
|
||
async for body in self._stream_jsonrpc(target.endpoint, method, params, agent):
|
||
yield body
|
||
|
||
    async def _stream_jsonrpc(
        self,
        endpoint: str,
        method: str,
        params: dict[str, Any],
        agent: AgentDescriptor,
    ):
        """Stream events from a JSON-RPC endpoint via a POSTed request.

        An ``text/event-stream`` response yields one parsed JSON object per SSE
        event, stopping at a ``[DONE]`` sentinel. Any other response is read
        whole and yields a single parsed object, or nothing if the body is empty.

        Raises:
            A2AUnsupportedMethodError: on HTTP 404/405/501, or when a parsed
                payload carries an error that ``_parse_stream_body`` treats as
                "unsupported".
            A2AError: on any other HTTP status error or error payload.
        """
        self._check_allowed_host(endpoint)
        payload = {
            "jsonrpc": "2.0",
            "id": str(uuid.uuid4()),
            "method": method,
            "params": params,
        }

        async with httpx.AsyncClient(
            timeout=self.timeout_seconds,
            transport=self.transport,
        ) as client:
            try:
                async with client.stream(
                    "POST",
                    endpoint,
                    json=payload,
                    headers=await self._build_headers(agent),
                ) as response:
                    try:
                        response.raise_for_status()
                    except httpx.HTTPStatusError as exc:
                        if exc.response.status_code in {404, 405, 501}:
                            raise A2AUnsupportedMethodError(method) from exc
                        raise A2AError(str(exc)) from exc

                    content_type = response.headers.get("content-type", "").lower()
                    if "text/event-stream" in content_type:
                        # SSE: only the assembled ``data`` payload of each event is consumed.
                        async for raw_event in self._iter_sse_events(response):
                            if raw_event.strip() == "[DONE]":
                                break
                            yield self._parse_stream_body(raw_event)
                    else:
                        # Non-SSE reply: treat the whole body as one JSON event.
                        body = await response.aread()
                        if not body:
                            return
                        yield self._parse_stream_body(body.decode("utf-8"))
            # Presumably defensive: the inner handler above already converts the
            # status error raised by raise_for_status().
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code in {404, 405, 501}:
                    raise A2AUnsupportedMethodError(method) from exc
                raise A2AError(str(exc)) from exc
|
||
|
||
async def _stream_rest(
|
||
self,
|
||
endpoint: str,
|
||
method: str,
|
||
params: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
):
|
||
"""通过 REST 风格流式接口接收事件。"""
|
||
if method == "message/stream":
|
||
requests = [
|
||
(
|
||
"POST",
|
||
self._rest_endpoint(endpoint, "/message:stream"),
|
||
self._build_rest_payload(params),
|
||
)
|
||
]
|
||
elif method == "tasks/subscribe":
|
||
task_id = str(params.get("id") or "").strip()
|
||
if not task_id:
|
||
raise A2AError("Missing task id for REST task subscribe")
|
||
subscribe_url = self._rest_endpoint(endpoint, f"/tasks/{task_id}:subscribe")
|
||
requests = [("GET", subscribe_url, None), ("POST", subscribe_url, None)]
|
||
else:
|
||
raise A2AUnsupportedMethodError(method)
|
||
|
||
last_error: Exception | None = None
|
||
for http_method, url, payload in requests:
|
||
self._check_allowed_host(url)
|
||
async with httpx.AsyncClient(
|
||
timeout=self.timeout_seconds,
|
||
transport=self.transport,
|
||
) as client:
|
||
try:
|
||
async with client.stream(
|
||
http_method,
|
||
url,
|
||
json=payload,
|
||
headers=await self._build_headers(agent),
|
||
) as response:
|
||
try:
|
||
response.raise_for_status()
|
||
except httpx.HTTPStatusError as exc:
|
||
if self._is_unsupported_http_error(exc, allow_validation_errors=True):
|
||
raise A2AUnsupportedMethodError(method) from exc
|
||
raise A2AError(self._http_error_message(exc)) from exc
|
||
|
||
content_type = response.headers.get("content-type", "").lower()
|
||
if "text/event-stream" in content_type:
|
||
async for raw_event in self._iter_sse_events(response):
|
||
if raw_event.strip() == "[DONE]":
|
||
break
|
||
yield self._parse_stream_body(raw_event)
|
||
return
|
||
|
||
body = await response.aread()
|
||
if not body:
|
||
return
|
||
yield self._parse_stream_body(body.decode("utf-8"))
|
||
return
|
||
except A2AUnsupportedMethodError as exc:
|
||
last_error = exc
|
||
continue
|
||
|
||
raise last_error or A2AUnsupportedMethodError(method)
|
||
|
||
async def _request_json(
|
||
self,
|
||
http_method: str,
|
||
url: str,
|
||
agent: AgentDescriptor,
|
||
*,
|
||
json_body: dict[str, Any] | None = None,
|
||
params: dict[str, str] | None = None,
|
||
) -> dict[str, Any]:
|
||
"""发送普通 HTTP JSON 请求。"""
|
||
self._check_allowed_host(url)
|
||
async with httpx.AsyncClient(
|
||
timeout=self.timeout_seconds,
|
||
transport=self.transport,
|
||
) as client:
|
||
try:
|
||
response = await client.request(
|
||
http_method,
|
||
url,
|
||
json=json_body,
|
||
params=params,
|
||
headers=await self._build_headers(agent),
|
||
)
|
||
response.raise_for_status()
|
||
except httpx.HTTPStatusError as exc:
|
||
if self._is_unsupported_http_error(exc):
|
||
raise A2AUnsupportedMethodError(url) from exc
|
||
raise A2AError(self._http_error_message(exc)) from exc
|
||
|
||
try:
|
||
body = response.json()
|
||
except json.JSONDecodeError as exc:
|
||
raise A2AError(f"Invalid JSON response from {url}") from exc
|
||
|
||
if not isinstance(body, dict):
|
||
raise A2AError("A2A response must be a JSON object")
|
||
return body
|
||
|
||
async def _iter_sse_events(self, response: httpx.Response):
|
||
"""把 SSE 响应流按事件边界还原为 data 文本块。"""
|
||
data_lines: list[str] = []
|
||
async for line in response.aiter_lines():
|
||
if line == "":
|
||
if data_lines:
|
||
yield "\n".join(data_lines)
|
||
data_lines = []
|
||
continue
|
||
if line.startswith(":"):
|
||
continue
|
||
field, _, value = line.partition(":")
|
||
value = value.lstrip()
|
||
if field == "data":
|
||
data_lines.append(value)
|
||
if data_lines:
|
||
yield "\n".join(data_lines)
|
||
|
||
@staticmethod
|
||
def _parse_stream_body(raw_event: str) -> dict[str, Any]:
|
||
"""解析单条流事件 JSON,并统一处理远端 error 对象。"""
|
||
try:
|
||
body = json.loads(raw_event)
|
||
except json.JSONDecodeError as exc:
|
||
raise A2AError(f"Invalid A2A stream payload: {raw_event}") from exc
|
||
if not isinstance(body, dict):
|
||
raise A2AError("A2A stream payload must be a JSON object")
|
||
|
||
error = body.get("error")
|
||
if isinstance(error, dict):
|
||
code = error.get("code")
|
||
message = str(error.get("message") or "unknown error")
|
||
if code == -32601 or "not found" in message.lower():
|
||
raise A2AUnsupportedMethodError(message)
|
||
raise A2AError(message)
|
||
return body
|
||
|
||
def _resolve_transport_targets(
|
||
self,
|
||
card: dict[str, Any],
|
||
agent: AgentDescriptor,
|
||
) -> list[_A2ATransportTarget]:
|
||
"""根据 card 和本地配置解析一组可尝试的传输目标。"""
|
||
default_url = self._resolve_primary_url(card, agent)
|
||
declared = self._collect_declared_interfaces(card)
|
||
preferred = self._normalize_transport(
|
||
card.get("preferred_transport") or card.get("preferredTransport")
|
||
)
|
||
|
||
candidates: list[_A2ATransportTarget] = []
|
||
if preferred in {"jsonrpc", "rest"}:
|
||
preferred_url = declared.get(preferred) or default_url
|
||
if preferred_url:
|
||
candidates.append(self._transport_target(preferred, preferred_url))
|
||
|
||
for mode in ("jsonrpc", "rest"):
|
||
url = declared.get(mode)
|
||
if url:
|
||
candidates.append(self._transport_target(mode, url))
|
||
|
||
if default_url:
|
||
if preferred not in {"jsonrpc", "rest"}:
|
||
candidates.append(self._transport_target("jsonrpc", default_url))
|
||
candidates.append(self._transport_target("rest", default_url))
|
||
elif preferred == "jsonrpc":
|
||
candidates.append(self._transport_target("rest", default_url))
|
||
else:
|
||
candidates.append(self._transport_target("jsonrpc", default_url))
|
||
|
||
deduped: list[_A2ATransportTarget] = []
|
||
seen: set[tuple[str, str]] = set()
|
||
for target in candidates:
|
||
# 同一个 mode + endpoint 只保留一次,避免重复尝试。
|
||
key = (target.mode, target.endpoint.rstrip("/").lower())
|
||
if key in seen:
|
||
continue
|
||
seen.add(key)
|
||
deduped.append(target)
|
||
return deduped
|
||
|
||
def _collect_declared_interfaces(self, card: dict[str, Any]) -> dict[str, str]:
|
||
"""从 card 的 interfaces 字段里提取声明过的 transport/url。"""
|
||
interfaces = None
|
||
for key in (
|
||
"additional_interfaces",
|
||
"additionalInterfaces",
|
||
"interfaces",
|
||
"supported_interfaces",
|
||
"supportedInterfaces",
|
||
):
|
||
candidate = card.get(key)
|
||
if isinstance(candidate, list):
|
||
interfaces = candidate
|
||
break
|
||
|
||
result: dict[str, str] = {}
|
||
if not isinstance(interfaces, list):
|
||
return result
|
||
|
||
for item in interfaces:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
mode = self._normalize_transport(item.get("transport"))
|
||
url = str(item.get("url") or "").strip()
|
||
if mode in {"jsonrpc", "rest"} and url:
|
||
result.setdefault(mode, url)
|
||
return result
|
||
|
||
def _resolve_primary_url(self, card: dict[str, Any], agent: AgentDescriptor) -> str:
|
||
"""解析 card 的主 URL;当 card 返回 0.0.0.0 时退回本地配置。"""
|
||
card_url = str(card.get("url") or "").strip()
|
||
fallback = str(agent.endpoint or agent.base_url or "").strip()
|
||
if card_url and (urlparse(card_url).hostname or "").strip() not in {"0.0.0.0", "::"}:
|
||
return card_url
|
||
return fallback or card_url
|
||
|
||
def _transport_target(self, mode: str, url: str) -> _A2ATransportTarget:
|
||
"""构造标准化的 transport target。"""
|
||
normalized_url = url.strip()
|
||
if mode == "rest":
|
||
normalized_url = self._normalize_rest_base_url(normalized_url)
|
||
return _A2ATransportTarget(mode=mode, endpoint=normalized_url)
|
||
|
||
@staticmethod
|
||
def _normalize_transport(value: Any) -> str | None:
|
||
"""把不同命名风格的 transport 文本归一化。"""
|
||
text = str(value or "").strip().lower()
|
||
if not text:
|
||
return None
|
||
if text in {"jsonrpc", "json-rpc"}:
|
||
return "jsonrpc"
|
||
if text in {"http+json", "http-json", "http_json", "rest"}:
|
||
return "rest"
|
||
if text == "grpc":
|
||
return "grpc"
|
||
return None
|
||
|
||
@staticmethod
|
||
def _normalize_rest_base_url(url: str) -> str:
|
||
"""把各种 REST 端点变体规整到 `/v1` 根路径。"""
|
||
parsed = urlparse(url)
|
||
path = parsed.path.rstrip("/")
|
||
for suffix in ("/message:send", "/message:stream"):
|
||
if path.endswith(suffix):
|
||
path = path[: -len(suffix)]
|
||
break
|
||
if "/tasks/" in path and (path.endswith(":cancel") or path.endswith(":subscribe")):
|
||
path = path.split("/tasks/", 1)[0]
|
||
if not path.endswith("/v1"):
|
||
path = f"{path}/v1" if path else "/v1"
|
||
return urlunparse(parsed._replace(path=path, params="", query="", fragment="")).rstrip("/")
|
||
|
||
@staticmethod
|
||
def _rest_endpoint(base_url: str, route: str) -> str:
|
||
"""基于 REST 根路径拼接具体路由。"""
|
||
return f"{base_url.rstrip('/')}{route}"
|
||
|
||
@staticmethod
|
||
def _supports_streaming(card: dict[str, Any]) -> bool:
|
||
"""根据 card capability 判断是否支持流式。"""
|
||
capabilities = card.get("capabilities")
|
||
if not isinstance(capabilities, dict) or "streaming" not in capabilities:
|
||
return True
|
||
streaming = capabilities.get("streaming")
|
||
if isinstance(streaming, dict):
|
||
for key in ("enabled", "supported"):
|
||
if key in streaming:
|
||
return bool(streaming.get(key))
|
||
return True
|
||
return bool(streaming)
|
||
|
||
@classmethod
|
||
def _unwrap_result_object(cls, payload: dict[str, Any]) -> dict[str, Any]:
|
||
"""从不同协议变体里提取真正的结果对象。"""
|
||
candidate: Any = payload
|
||
if isinstance(candidate, dict) and isinstance(candidate.get("result"), dict):
|
||
candidate = candidate["result"]
|
||
if not isinstance(candidate, dict):
|
||
raise A2AError("Malformed A2A response")
|
||
|
||
for key, kind in (
|
||
("task", "task"),
|
||
("message", "message"),
|
||
("statusUpdate", "status-update"),
|
||
("status_update", "status-update"),
|
||
("artifactUpdate", "artifact-update"),
|
||
("artifact_update", "artifact-update"),
|
||
):
|
||
value = candidate.get(key)
|
||
if isinstance(value, dict):
|
||
result = dict(value)
|
||
result.setdefault("kind", kind)
|
||
return result
|
||
return candidate
|
||
|
||
@staticmethod
|
||
def _is_unsupported_http_error(
|
||
exc: httpx.HTTPStatusError,
|
||
*,
|
||
allow_validation_errors: bool = False,
|
||
) -> bool:
|
||
"""判断 HTTP 错误是否应被解释为“方法不支持”。"""
|
||
status_code = exc.response.status_code
|
||
if status_code in {404, 405, 501}:
|
||
return True
|
||
if allow_validation_errors and status_code in {400, 422}:
|
||
message = exc.response.text.lower()
|
||
return "not supported" in message or "unsupported" in message
|
||
return False
|
||
|
||
@staticmethod
|
||
def _http_error_message(exc: httpx.HTTPStatusError) -> str:
|
||
"""从 HTTP 错误响应中抽取更可读的错误文本。"""
|
||
try:
|
||
payload = exc.response.json()
|
||
except json.JSONDecodeError:
|
||
payload = None
|
||
if isinstance(payload, dict):
|
||
for key in ("detail", "title", "message", "error"):
|
||
value = payload.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value.strip()
|
||
return str(exc)
|
||
|
||
async def _build_headers(self, agent: AgentDescriptor) -> dict[str, str]:
    """Build the HTTP request headers for calls to *agent*.

    - ``Accept`` always covers both JSON and SSE (``text/event-stream``).
    - When ``agent.auth_mode`` is ``"oauth_backend_token"`` (case-insensitive),
      ``Authorization`` carries a bearer token from ``self._issue_backend_token``.
      Otherwise it carries the value of the environment variable named by
      ``agent.auth_env``, when that value is non-blank.
    - String-valued entries of ``agent.metadata["headers"]`` are copied last,
      so they override the headers above.
    """
    headers = {"Accept": "application/json, text/event-stream"}
    auth_mode = (agent.auth_mode or "none").strip().lower()
    if auth_mode == "oauth_backend_token":
        headers["Authorization"] = f"Bearer {await self._issue_backend_token(agent)}"
    else:
        # Env values frequently carry a trailing newline (e.g. read from a
        # file); untrimmed, that produces an invalid header value, and a
        # whitespace-only value would send an empty bearer token.
        token = (os.environ.get(agent.auth_env or "") or "").strip()
        if token:
            headers["Authorization"] = f"Bearer {token}"
    extra = (agent.metadata or {}).get("headers")
    if isinstance(extra, dict):
        for key, value in extra.items():
            if key and isinstance(value, str):
                headers[key] = value
    return headers
|
||
|
||
async def _issue_backend_token(self, agent: AgentDescriptor) -> str:
    """Obtain a bearer access token for *agent* from the AuthZ service.

    Reads the AuthZ base URL and request timeout (default 10) from
    ``self.authz_config``, and the client credentials from
    ``self.backend_identity``. The audience is ``agent.auth_audience`` (falling
    back to ``agent.id``), prefixed with ``a2a:`` unless already present. The
    scopes are the truthy entries of ``agent.auth_scopes``, or ``["run_task"]``
    when none remain.

    Raises:
        A2AError: when the AuthZ/backend identity settings are incomplete, the
            audience is empty, or AuthZ returns no ``access_token``.
    """
    from nanobot.authz.client import AuthzClient

    def _setting(source: object, name: str) -> str:
        return str(getattr(source, name, "") or "").strip()

    base_url = _setting(self.authz_config, "base_url")
    client_id = _setting(self.backend_identity, "client_id")
    client_secret = _setting(self.backend_identity, "client_secret")
    if not all((base_url, client_id, client_secret)):
        raise A2AError(
            f"A2A agent '{agent.id}' requires AuthZ backend tokens, but authz/backend identity is incomplete"
        )

    audience = str(agent.auth_audience or agent.id).strip()
    if not audience:
        raise A2AError(f"A2A agent '{agent.id}' is missing auth audience")
    if not audience.startswith("a2a:"):
        audience = "a2a:" + audience

    requested_scopes = [scope for scope in agent.auth_scopes if scope] or ["run_task"]

    timeout = int(getattr(self.authz_config, "request_timeout_seconds", 10))
    authz_client = AuthzClient(base_url, timeout_seconds=timeout)
    token_response = await authz_client.issue_token(
        client_id=client_id,
        client_secret=client_secret,
        audience=audience,
        scopes=requested_scopes,
    )

    access_token = str(token_response.get("access_token") or "").strip()
    if access_token:
        return access_token
    raise A2AError(f"A2A agent '{agent.id}' did not receive an access token from AuthZ")
|
||
|
||
def _check_allowed_host(self, url: str) -> None:
    """Enforce the optional host allowlist for *url*.

    An empty or unset ``self.allowed_hosts`` disables the check. Otherwise the
    host of *url* (``urlparse(...).hostname``, lower-cased; ``""`` if absent)
    must be a member of ``self.allowed_hosts``.

    Raises:
        A2AError: if an allowlist is configured and the host is not on it.
    """
    allowed = self.allowed_hosts
    if allowed:
        host = (urlparse(url).hostname or "").lower()
        if host in allowed:
            return
        raise A2AError(f"Host '{host}' is not allowed for A2A access")
|
||
|
||
@staticmethod
|
||
def _build_message_params(task: str, label: str | None) -> dict[str, Any]:
|
||
"""把委派任务包装成 A2A 标准 message 参数。"""
|
||
message = {
|
||
"messageId": str(uuid.uuid4()),
|
||
"role": "user",
|
||
"parts": [{"type": "text", "kind": "text", "text": task}],
|
||
}
|
||
if label:
|
||
message["metadata"] = {"label": label}
|
||
return {"message": message}
|
||
|
||
@classmethod
def _build_rest_payload(cls, params: dict[str, Any]) -> dict[str, Any]:
    """Convert generic message params into the REST (HTTP+JSON) payload shape.

    Works on a deep copy made by a JSON round-trip, so *params* is never mutated.
    When ``payload["message"]`` is a dict:

    - a ``role`` of ``user``/``agent`` (case- and whitespace-insensitive) becomes
      the enum literal ``ROLE_USER``/``ROLE_AGENT``;
    - ``parts`` is always removed. When it was a list and no ``content`` key
      exists, it is flattened into ``content``. Each dict part contributes only
      the first of ``text``/``file``/``data`` it has; other parts are dropped.
    """
    payload = json.loads(json.dumps(params))
    message = payload.get("message")
    if isinstance(message, dict):
        # Some REST implementations require enum-literal roles instead of free-form strings.
        enum_role = {"user": "ROLE_USER", "agent": "ROLE_AGENT"}.get(
            str(message.get("role") or "").strip().lower()
        )
        if enum_role is not None:
            message["role"] = enum_role

        raw_parts = message.pop("parts", None)
        if isinstance(raw_parts, list) and "content" not in message:
            flattened: list[dict[str, Any]] = []
            for part in raw_parts:
                if not isinstance(part, dict):
                    continue
                field_name = next(
                    (name for name in ("text", "file", "data") if name in part),
                    None,
                )
                if field_name is not None:
                    flattened.append({field_name: part[field_name]})
            message["content"] = flattened
    return payload
|
||
|
||
def _build_run_result(
    self,
    agent: AgentDescriptor,
    result: dict[str, Any],
    state: _StreamState | None = None,
) -> AgentRunResult:
    """Convert a remote result object into a unified ``AgentRunResult``.

    The summary comes from the stream accumulator's ``build_summary`` when a
    *state* is given. Failing that, it is text extracted from *result*, and
    finally *result* serialized as JSON. The status is the normalized remote
    ``status``, replaced by ``"error"`` when ``has_meaningful_summary`` rejects
    the summary.
    """
    summary = state.build_summary(self) if state else ""
    if not summary:
        # json.dumps of a dict is never empty, so the summary is always a non-empty string here.
        summary = self._extract_text(result) or json.dumps(result, ensure_ascii=False)
    normalized = self._normalize_status(result.get("status"))
    meaningful = has_meaningful_summary(summary)
    return AgentRunResult(
        agent_id=agent.id,
        agent_name=agent.name,
        status=normalized if meaningful else "error",
        summary=summary,
        raw=result,
    )
|
||
|
||
@staticmethod
|
||
def _is_task_result(result: dict[str, Any]) -> bool:
|
||
"""判断返回值是否表示一个 task 对象。"""
|
||
if "status" in result:
|
||
return True
|
||
kind = str(result.get("kind") or "").lower()
|
||
return kind == "task"
|
||
|
||
@staticmethod
def _is_terminal_status(status: Any) -> bool:
    """Return True when *status* denotes a finished task.

    *status* may be a plain string or a dict with a ``state``/``status`` field.
    It is reduced via ``A2AClient._normalized_state_token``, which lower-cases it
    and strips a ``TASK_STATE_`` prefix. Terminal states are completed, failed,
    cancelled and rejected, including the spelling variants listed below.
    """
    state = A2AClient._normalized_state_token(status)
    # "rejected" (the remote agent declined the task) is a terminal TaskState
    # in the A2A spec. Without it, a loop waiting for a terminal state would
    # presumably keep polling a rejected task until its overall timeout.
    return state in {
        "completed",
        "complete",
        "failed",
        "error",
        "cancelled",
        "canceled",
        "rejected",
    }
|
||
|
||
@staticmethod
def _normalize_status(status: Any) -> str:
    """Map the many remote status spellings onto a small canonical set.

    *status* is reduced via ``A2AClient._normalized_state_token``; a falsy
    *status* is treated as ``"ok"``.

    Returns:
        ``"ok"`` for completed/success (or empty) states, ``"working"`` for
        in-flight states, ``"error"`` for failed/rejected states, ``"cancelled"``
        for either cancel spelling. Any other token is returned unchanged.
    """
    state = A2AClient._normalized_state_token(status or "ok")
    if state in {"", "completed", "complete", "success", "ok"}:
        return "ok"
    if state in {"working", "running", "in_progress", "submitted", "queued"}:
        return "working"
    # "rejected" is the A2A terminal state for a task the remote agent declined;
    # report it as a failure rather than leaking a non-canonical status.
    if state in {"failed", "error", "rejected"}:
        return "error"
    if state in {"cancelled", "canceled"}:
        return "cancelled"
    return state
|
||
|
||
@staticmethod
|
||
def _normalized_state_token(status: Any) -> str:
|
||
"""抽取状态里的核心 token,例如去掉 `TASK_STATE_` 前缀。"""
|
||
if isinstance(status, dict):
|
||
state = str(status.get("state") or status.get("status") or "")
|
||
else:
|
||
state = str(status or "")
|
||
state = state.strip().lower()
|
||
if state.startswith("task_state_"):
|
||
state = state[len("task_state_") :]
|
||
return state
|
||
|
||
@classmethod
def _extract_text(cls, payload: Any) -> str:
    """Best-effort extraction of the most useful text from a nested A2A payload.

    - ``str``: returned as-is.
    - ``list``: the texts of all items, joined with newlines (empty ones skipped).
    - ``dict``: the first non-empty text found, probing in this order:
        1. direct text fields (``text``, ``content``, ``summary``,
           ``finalText``, ``final_text``);
        2. container fields (``message``, ``artifact(s)``, ``messages``,
           ``parts``, ``output``, tool results, wrapped task/update objects);
        3. a dict-valued ``status``, e.g. its ``message``;
        4. ``history``.
    - anything else: ``""``.

    ``status`` is probed before ``history``. In A2A the task history typically
    also contains the user's own request, so probing it first would echo the
    prompt back as the "result" whenever a task has no artifacts.
    """
    if isinstance(payload, str):
        return payload
    if isinstance(payload, list):
        parts = [cls._extract_text(item) for item in payload]
        return "\n".join(part for part in parts if part)
    if not isinstance(payload, dict):
        return ""

    for key in (
        # Direct text fields.
        "text",
        "content",
        "summary",
        "finalText",
        "final_text",
        # Containers that may hold text further down.
        "message",
        "artifact",
        "artifacts",
        "messages",
        "parts",
        "output",
        "toolResults",
        "tool_results",
        "task",
        "statusUpdate",
        "status_update",
        "artifactUpdate",
        "artifact_update",
    ):
        text = cls._extract_text(payload.get(key))
        if text:
            return text

    # Only a dict status is inspected: a plain string status such as
    # "completed" is a state name, not content.
    status = payload.get("status")
    if isinstance(status, dict):
        text = cls._extract_text(status)
        if text:
            return text

    return cls._extract_text(payload.get("history"))
|