"""A2A 客户端实现。 目标不是完整覆盖所有厂商变体,而是提供一条足够稳的兼容链路: 1. 先拉 agent card,解析可用端点和偏好传输; 2. 优先尝试流式订阅,拿到实时进度; 3. 流式不可用或中断时,回退到轮询; 4. 同时兼容 JSON-RPC 和 HTTP+JSON 风格接口。 """ from __future__ import annotations import asyncio import json import os import time import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any from urllib.parse import urlparse, urlunparse import httpx from nanobot.agent.agent_registry import AgentDescriptor from nanobot.agent.run_result import AgentRunResult, has_meaningful_summary class A2AError(RuntimeError): """A2A 请求失败时抛出的统一异常。""" class A2AUnsupportedMethodError(A2AError): """远端端点不支持某个方法时抛出的异常。""" @dataclass class A2AStreamEvent: """A2A 订阅流事件的归一化表示。""" # 事件类型,例如 task / message / status-update / artifact-update。 kind: str # 远端任务 ID;一旦出现,上层就可以登记用于取消或恢复订阅。 task_id: str | None = None # 归一化后的状态文本。 status: str | None = None # 适合展示给用户的增量文本。 text: str | None = None # 是否已到达终态。 final: bool = False # 原始事件体,便于调试和后续扩展。 raw: dict[str, Any] | None = None @dataclass class _StreamState: """流式任务状态累加器。""" task_id: str | None = None status: str = "working" artifacts: dict[str, str] = field(default_factory=dict) artifact_order: list[str] = field(default_factory=list) messages: list[str] = field(default_factory=list) status_messages: list[str] = field(default_factory=list) latest_result: dict[str, Any] | None = None final_seen: bool = False def apply(self, result: dict[str, Any], client: A2AClient) -> A2AStreamEvent: """吸收一条原始结果并产出归一化流事件。""" self.latest_result = result kind = str(result.get("kind") or result.get("type") or "result").lower() task_id = str(result.get("id") or result.get("taskId") or "").strip() or None if task_id: self.task_id = task_id raw_status = result.get("status") normalized_status = client._normalize_status(raw_status) # 非 ok 状态会覆盖当前状态;否则在 task/status-update 终态时再更新。 if normalized_status not in {"", "ok"}: self.status = normalized_status elif kind in {"task", "status-update"} and client._is_terminal_status(raw_status): self.status = client._normalize_status(raw_status) text = "" if kind == "artifact-update": # artifact-update 需要增量拼接同一个 artifact 的文本内容。 text = self._apply_artifact_update(result, client) elif kind == "status-update": # 某些实现把状态消息放在 status 里,有些放在 message 里,这里都兜一遍。 text = client._extract_text(result.get("status")) or client._extract_text( result.get("message") ) self._append_unique(self.status_messages, text) elif kind in {"message", "task"}: self._apply_task_or_message(result, client) text = client._extract_text(result) else: text = client._extract_text(result) final = bool(result.get("final")) or client._is_terminal_status(raw_status) if final: self.final_seen = True if self.status == "working": # 即使没拿到更明确状态,也尽量用终态把 working 覆盖掉。 self.status = client._normalize_status(raw_status) if text and kind not in {"artifact-update", "message", "task"}: self._append_unique(self.messages, text) return A2AStreamEvent( kind=kind, task_id=self.task_id, status=self.status, text=text or None, final=final, raw=result, ) def build_summary(self, client: A2AClient) -> str: """按 artifact -> message -> status 的优先级生成最终摘要。""" artifact_text = "\n".join( self.artifacts[artifact_id] for artifact_id in self.artifact_order if self.artifacts.get(artifact_id) ).strip() if artifact_text: return artifact_text message_text = "\n".join(text for text in self.messages if text).strip() if message_text: return message_text status_text = "\n".join(text for text in self.status_messages if text).strip() if status_text: return status_text if self.latest_result: return client._extract_text(self.latest_result) return "" def _apply_artifact_update(self, result: dict[str, Any], client: A2AClient) -> str: """把一条 artifact-update 事件并入累积状态。""" artifact = result.get("artifact") if not isinstance(artifact, dict): artifact = result artifact_id = str( artifact.get("artifactId") or artifact.get("id") or result.get("artifactId") or f"artifact-{len(self.artifact_order) + 1}" ) text = client._extract_text(artifact) if not text: return "" if artifact_id not in self.artifacts: self.artifacts[artifact_id] = "" self.artifact_order.append(artifact_id) # append=true 时做增量拼接,否则视为完整覆盖。 if result.get("append") or artifact.get("append"): self.artifacts[artifact_id] += text else: self.artifacts[artifact_id] = text return text def _apply_task_or_message(self, result: dict[str, Any], client: A2AClient) -> None: """把 task/message 类型结果中的 artifact 和文本提取出来。""" artifacts = result.get("artifacts") if isinstance(artifacts, list): for index, artifact in enumerate(artifacts): if not isinstance(artifact, dict): continue artifact_id = str( artifact.get("artifactId") or artifact.get("id") or f"artifact-{len(self.artifact_order) + index + 1}" ) text = client._extract_text(artifact) if not text: continue if artifact_id not in self.artifacts: self.artifact_order.append(artifact_id) self.artifacts[artifact_id] = text text = client._extract_text(result) self._append_unique(self.messages, text) @staticmethod def _append_unique(collection: list[str], text: str) -> None: """仅当文本与上一个不同才追加,避免流式重复刷屏。""" if text and (not collection or collection[-1] != text): collection.append(text) @dataclass(frozen=True) class _A2ATransportTarget: """解析后的远端传输目标。""" mode: str endpoint: str class A2AClient: """支持 JSON-RPC 与 HTTP+JSON 回退链路的 A2A 客户端。""" def __init__( self, timeout_seconds: int = 600, poll_interval_seconds: int = 2, card_cache_ttl_seconds: int = 300, allowed_hosts: list[str] | None = None, transport: httpx.AsyncBaseTransport | None = None, authz_config: Any | None = None, backend_identity: Any | None = None, ): # 这些参数决定超时、轮询频率和安全边界。 self.timeout_seconds = timeout_seconds self.poll_interval_seconds = poll_interval_seconds self.card_cache_ttl_seconds = card_cache_ttl_seconds self.allowed_hosts = {host.lower() for host in (allowed_hosts or []) if host} self.transport = transport self.authz_config = authz_config self.backend_identity = backend_identity self._card_cache: dict[str, tuple[float, dict[str, Any]]] = {} async def run_task( self, agent: AgentDescriptor, task: str, label: str | None = None, event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None = None, task_callback: Callable[[str], Awaitable[None]] | None = None, prefer_streaming: bool = True, ) -> AgentRunResult: """执行一次远端 A2A 任务。""" card = await self.fetch_agent_card(agent) params = self._build_message_params(task, label) targets = self._resolve_transport_targets(card, agent) if not targets: raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint") last_unsupported: Exception | None = None for target in targets: try: # 若 card 支持流式,则优先尝试流式以获取中间态。 if prefer_streaming and self._supports_streaming(card): stream_result = await self._run_task_streaming( target=target, params=params, agent=agent, event_callback=event_callback, task_callback=task_callback, ) if stream_result is not None: return stream_result # 流式不可用时回退到普通 send,再视状态决定是否轮询。 result = await self._send_task(target, params, agent) if self._is_task_result(result) and not self._is_terminal_status(result.get("status")): task_id = str(result.get("id") or result.get("taskId") or "").strip() if task_id: if task_callback: await task_callback(task_id) result = await self._poll_task(target, task_id, agent) return self._build_run_result(agent, result) except A2AUnsupportedMethodError as exc: last_unsupported = exc continue if last_unsupported: raise last_unsupported raise A2AError(f"Agent '{agent.id}' does not expose a usable A2A endpoint") async def fetch_agent_card(self, agent: AgentDescriptor) -> dict[str, Any]: """拉取远端 agent card,并带本地 TTL 缓存。""" _, card = await self.fetch_agent_card_with_url(agent) return card async def fetch_agent_card_with_url(self, agent: AgentDescriptor) -> tuple[str, dict[str, Any]]: """拉取远端 agent card,并返回命中的 card URL。""" urls = self._candidate_card_urls(agent) last_error: Exception | None = None for url in urls: cache_key = url.lower() cached = self._card_cache.get(cache_key) if cached and cached[0] > time.monotonic(): return url, cached[1] try: card = await self._fetch_json(url, agent) except Exception as exc: last_error = exc continue if isinstance(card, dict): self._card_cache[cache_key] = ( time.monotonic() + self.card_cache_ttl_seconds, card, ) return url, card if last_error: raise A2AError(f"Failed to fetch agent card for '{agent.id}': {last_error}") raise A2AError(f"Failed to fetch agent card for '{agent.id}'") def invalidate_card(self, agent: AgentDescriptor) -> None: """清空某个 agent 相关的 card 缓存。""" for url in self._candidate_card_urls(agent): self._card_cache.pop(url.lower(), None) def _candidate_card_urls(self, agent: AgentDescriptor) -> list[str]: """根据 agent 配置推导一组候选 card URL。""" urls: list[str] = [] if agent.card_url: urls.append(agent.card_url) base_url = str(agent.base_url or agent.endpoint or "").rstrip("/") if base_url: urls.extend( [ f"{base_url}/.well-known/agent-card", f"{base_url}/.well-known/agent-card.json", f"{base_url}/.well-known/agent.json", ] ) deduped: list[str] = [] seen: set[str] = set() for url in urls: normalized = url.strip() if not normalized or normalized.lower() in seen: continue seen.add(normalized.lower()) deduped.append(normalized) return deduped async def _run_task_streaming( self, target: _A2ATransportTarget, params: dict[str, Any], agent: AgentDescriptor, event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None, task_callback: Callable[[str], Awaitable[None]] | None, ) -> AgentRunResult | None: """优先尝试流式方法执行任务,失败时可回退为 None。""" if target.mode == "rest": stream_variants = [("message/stream", params)] else: stream_variants = [ ("tasks/sendSubscribe", {"id": str(uuid.uuid4()), **params}), ("message/stream", params), ] saw_supported_stream = False last_error: Exception | None = None for method, payload in stream_variants: state = _StreamState() try: # 每个流式方法都独立尝试;一旦成功拿到终态结果就直接返回。 return await self._consume_stream_method( target=target, method=method, params=payload, agent=agent, event_callback=event_callback, task_callback=task_callback, state=state, allow_resume=True, ) except A2AUnsupportedMethodError as exc: last_error = exc continue except A2AError as exc: # 已经跑到一半但中断时,如果拿到了 task_id,就尝试恢复订阅或轮询。 saw_supported_stream = True last_error = exc if state.task_id: try: return await self._resume_or_poll( target=target, agent=agent, task_id=state.task_id, state=state, event_callback=event_callback, task_callback=task_callback, ) except A2AError as resume_exc: last_error = resume_exc continue else: saw_supported_stream = True if saw_supported_stream and last_error: raise last_error return None async def _consume_stream_method( self, target: _A2ATransportTarget, method: str, params: dict[str, Any], agent: AgentDescriptor, event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None, task_callback: Callable[[str], Awaitable[None]] | None, state: _StreamState, allow_resume: bool, ) -> AgentRunResult: """消费一个具体流式方法,直到拿到终态结果。""" saw_event = False seen_task_id: str | None = state.task_id try: async for body in self._stream_request(target, method, params, agent): saw_event = True result = self._unwrap_result_object(body) event = state.apply(result, self) # 首次看到 task_id 时通知上层登记,以便取消或恢复。 if task_callback and event.task_id and event.task_id != seen_task_id: seen_task_id = event.task_id await task_callback(event.task_id) if event_callback: await event_callback(event) if event.final: return self._build_run_result(agent, result, state) except (httpx.ReadError, httpx.RemoteProtocolError, httpx.TimeoutException) as exc: if not state.task_id: raise A2AError(str(exc)) from exc if state.final_seen: return self._build_run_result(agent, state.latest_result or {}, state) if allow_resume and state.task_id: # 流结束但还没终态时,尝试恢复订阅或退化成轮询。 return await self._resume_or_poll( target=target, agent=agent, task_id=state.task_id, state=state, event_callback=event_callback, task_callback=task_callback, ) if saw_event: raise A2AError("A2A stream ended before a final result was received") raise A2AUnsupportedMethodError(method) async def _resume_or_poll( self, target: _A2ATransportTarget, agent: AgentDescriptor, task_id: str, state: _StreamState, event_callback: Callable[[A2AStreamEvent], Awaitable[None]] | None, task_callback: Callable[[str], Awaitable[None]] | None, ) -> AgentRunResult: """在流式中断后尝试恢复订阅,失败则退化为轮询。""" try: return await self._consume_stream_method( target=target, method="tasks/subscribe" if target.mode == "rest" else "tasks/resubscribe", params={"id": task_id}, agent=agent, event_callback=event_callback, task_callback=task_callback, state=state, allow_resume=False, ) except A2AUnsupportedMethodError: result = await self._poll_task(target, task_id, agent) return self._build_run_result(agent, result, state) async def cancel_task(self, agent: AgentDescriptor, task_id: str) -> bool: """尽力取消一个远端 A2A 任务。""" task_id = task_id.strip() if not task_id: return False card = await self.fetch_agent_card(agent) targets = self._resolve_transport_targets(card, agent) if not targets: raise A2AError(f"Agent '{agent.id}' does not expose a supported A2A endpoint") for target in targets: try: if target.mode == "rest": # REST 风格通常使用 `/tasks/{id}:cancel`。 await self._request_json( "POST", self._rest_endpoint(target.endpoint, f"/tasks/{task_id}:cancel"), agent, json_body={"name": f"tasks/{task_id}"}, ) else: # JSON-RPC 风格使用 `tasks/cancel`。 await self._rpc_jsonrpc(target.endpoint, "tasks/cancel", {"id": task_id}, agent) return True except A2AUnsupportedMethodError: continue return False async def _send_task( self, target: _A2ATransportTarget, params: dict[str, Any], agent: AgentDescriptor, ) -> dict[str, Any]: """发送一次非流式任务请求。""" if target.mode == "rest": body = await self._request_json( "POST", self._rest_endpoint(target.endpoint, "/message:send"), agent, json_body=self._build_rest_payload(params), ) return self._unwrap_result_object(body) send_variants = [ ("tasks/send", {"id": str(uuid.uuid4()), **params}), ("message/send", params), ] last_error: Exception | None = None for method, payload in send_variants: try: response = await self._rpc_jsonrpc(target.endpoint, method, payload, agent) return self._unwrap_result_object(response) except A2AUnsupportedMethodError as exc: last_error = exc continue raise last_error or A2AError("No supported A2A send method found") async def _poll_task( self, target: _A2ATransportTarget, task_id: str, agent: AgentDescriptor, ) -> dict[str, Any]: """轮询远端 task,直到进入终态或超时。""" deadline = time.monotonic() + self.timeout_seconds while time.monotonic() < deadline: if target.mode == "rest": body = await self._request_json( "GET", self._rest_endpoint(target.endpoint, f"/tasks/{task_id}"), agent, ) result = self._unwrap_result_object(body) else: response = await self._rpc_jsonrpc(target.endpoint, "tasks/get", {"id": task_id}, agent) result = self._unwrap_result_object(response) if self._is_terminal_status(result.get("status")): return result await asyncio.sleep(self.poll_interval_seconds) raise A2AError( f"A2A task '{task_id}' timed out after {self.timeout_seconds} seconds" ) async def _fetch_json(self, url: str, agent: AgentDescriptor) -> dict[str, Any]: """以 GET 方式拉取 JSON 对象。""" self._check_allowed_host(url) async with httpx.AsyncClient( timeout=self.timeout_seconds, transport=self.transport, ) as client: response = await client.get(url, headers=await self._build_headers(agent)) response.raise_for_status() payload = response.json() if not isinstance(payload, dict): raise A2AError("Agent card response must be a JSON object") return payload async def _rpc_jsonrpc( self, endpoint: str, method: str, params: dict[str, Any], agent: AgentDescriptor, ) -> dict[str, Any]: """发送一条 JSON-RPC 请求。""" self._check_allowed_host(endpoint) payload = { "jsonrpc": "2.0", "id": str(uuid.uuid4()), "method": method, "params": params, } async with httpx.AsyncClient( timeout=self.timeout_seconds, transport=self.transport, ) as client: try: response = await client.post( endpoint, json=payload, headers=await self._build_headers(agent), ) response.raise_for_status() except httpx.HTTPStatusError as exc: if exc.response.status_code in {404, 405, 501}: raise A2AUnsupportedMethodError(method) from exc raise A2AError(str(exc)) from exc body = response.json() if not isinstance(body, dict): raise A2AError("A2A RPC response must be a JSON object") error = body.get("error") if isinstance(error, dict): code = error.get("code") message = str(error.get("message") or "unknown error") if code == -32601 or "not found" in message.lower(): raise A2AUnsupportedMethodError(message) raise A2AError(message) return body async def _stream_request( self, target: _A2ATransportTarget, method: str, params: dict[str, Any], agent: AgentDescriptor, ): """根据 transport mode 选择具体流式实现。""" if target.mode == "rest": async for body in self._stream_rest(target.endpoint, method, params, agent): yield body return async for body in self._stream_jsonrpc(target.endpoint, method, params, agent): yield body async def _stream_jsonrpc( self, endpoint: str, method: str, params: dict[str, Any], agent: AgentDescriptor, ): """通过 JSON-RPC 流式接口接收事件。""" self._check_allowed_host(endpoint) payload = { "jsonrpc": "2.0", "id": str(uuid.uuid4()), "method": method, "params": params, } async with httpx.AsyncClient( timeout=self.timeout_seconds, transport=self.transport, ) as client: try: async with client.stream( "POST", endpoint, json=payload, headers=await self._build_headers(agent), ) as response: try: response.raise_for_status() except httpx.HTTPStatusError as exc: if exc.response.status_code in {404, 405, 501}: raise A2AUnsupportedMethodError(method) from exc raise A2AError(str(exc)) from exc content_type = response.headers.get("content-type", "").lower() if "text/event-stream" in content_type: # 标准 SSE 按 event/data 行拼装;这里只消费 data 载荷。 async for raw_event in self._iter_sse_events(response): if raw_event.strip() == "[DONE]": break yield self._parse_stream_body(raw_event) else: body = await response.aread() if not body: return yield self._parse_stream_body(body.decode("utf-8")) except httpx.HTTPStatusError as exc: if exc.response.status_code in {404, 405, 501}: raise A2AUnsupportedMethodError(method) from exc raise A2AError(str(exc)) from exc async def _stream_rest( self, endpoint: str, method: str, params: dict[str, Any], agent: AgentDescriptor, ): """通过 REST 风格流式接口接收事件。""" if method == "message/stream": requests = [ ( "POST", self._rest_endpoint(endpoint, "/message:stream"), self._build_rest_payload(params), ) ] elif method == "tasks/subscribe": task_id = str(params.get("id") or "").strip() if not task_id: raise A2AError("Missing task id for REST task subscribe") subscribe_url = self._rest_endpoint(endpoint, f"/tasks/{task_id}:subscribe") requests = [("GET", subscribe_url, None), ("POST", subscribe_url, None)] else: raise A2AUnsupportedMethodError(method) last_error: Exception | None = None for http_method, url, payload in requests: self._check_allowed_host(url) async with httpx.AsyncClient( timeout=self.timeout_seconds, transport=self.transport, ) as client: try: async with client.stream( http_method, url, json=payload, headers=await self._build_headers(agent), ) as response: try: response.raise_for_status() except httpx.HTTPStatusError as exc: if self._is_unsupported_http_error(exc, allow_validation_errors=True): raise A2AUnsupportedMethodError(method) from exc raise A2AError(self._http_error_message(exc)) from exc content_type = response.headers.get("content-type", "").lower() if "text/event-stream" in content_type: async for raw_event in self._iter_sse_events(response): if raw_event.strip() == "[DONE]": break yield self._parse_stream_body(raw_event) return body = await response.aread() if not body: return yield self._parse_stream_body(body.decode("utf-8")) return except A2AUnsupportedMethodError as exc: last_error = exc continue raise last_error or A2AUnsupportedMethodError(method) async def _request_json( self, http_method: str, url: str, agent: AgentDescriptor, *, json_body: dict[str, Any] | None = None, params: dict[str, str] | None = None, ) -> dict[str, Any]: """发送普通 HTTP JSON 请求。""" self._check_allowed_host(url) async with httpx.AsyncClient( timeout=self.timeout_seconds, transport=self.transport, ) as client: try: response = await client.request( http_method, url, json=json_body, params=params, headers=await self._build_headers(agent), ) response.raise_for_status() except httpx.HTTPStatusError as exc: if self._is_unsupported_http_error(exc): raise A2AUnsupportedMethodError(url) from exc raise A2AError(self._http_error_message(exc)) from exc try: body = response.json() except json.JSONDecodeError as exc: raise A2AError(f"Invalid JSON response from {url}") from exc if not isinstance(body, dict): raise A2AError("A2A response must be a JSON object") return body async def _iter_sse_events(self, response: httpx.Response): """把 SSE 响应流按事件边界还原为 data 文本块。""" data_lines: list[str] = [] async for line in response.aiter_lines(): if line == "": if data_lines: yield "\n".join(data_lines) data_lines = [] continue if line.startswith(":"): continue field, _, value = line.partition(":") value = value.lstrip() if field == "data": data_lines.append(value) if data_lines: yield "\n".join(data_lines) @staticmethod def _parse_stream_body(raw_event: str) -> dict[str, Any]: """解析单条流事件 JSON,并统一处理远端 error 对象。""" try: body = json.loads(raw_event) except json.JSONDecodeError as exc: raise A2AError(f"Invalid A2A stream payload: {raw_event}") from exc if not isinstance(body, dict): raise A2AError("A2A stream payload must be a JSON object") error = body.get("error") if isinstance(error, dict): code = error.get("code") message = str(error.get("message") or "unknown error") if code == -32601 or "not found" in message.lower(): raise A2AUnsupportedMethodError(message) raise A2AError(message) return body def _resolve_transport_targets( self, card: dict[str, Any], agent: AgentDescriptor, ) -> list[_A2ATransportTarget]: """根据 card 和本地配置解析一组可尝试的传输目标。""" default_url = self._resolve_primary_url(card, agent) declared = self._collect_declared_interfaces(card) preferred = self._normalize_transport( card.get("preferred_transport") or card.get("preferredTransport") ) candidates: list[_A2ATransportTarget] = [] if preferred in {"jsonrpc", "rest"}: preferred_url = declared.get(preferred) or default_url if preferred_url: candidates.append(self._transport_target(preferred, preferred_url)) for mode in ("jsonrpc", "rest"): url = declared.get(mode) if url: candidates.append(self._transport_target(mode, url)) if default_url: if preferred not in {"jsonrpc", "rest"}: candidates.append(self._transport_target("jsonrpc", default_url)) candidates.append(self._transport_target("rest", default_url)) elif preferred == "jsonrpc": candidates.append(self._transport_target("rest", default_url)) else: candidates.append(self._transport_target("jsonrpc", default_url)) deduped: list[_A2ATransportTarget] = [] seen: set[tuple[str, str]] = set() for target in candidates: # 同一个 mode + endpoint 只保留一次,避免重复尝试。 key = (target.mode, target.endpoint.rstrip("/").lower()) if key in seen: continue seen.add(key) deduped.append(target) return deduped def _collect_declared_interfaces(self, card: dict[str, Any]) -> dict[str, str]: """从 card 的 interfaces 字段里提取声明过的 transport/url。""" interfaces = None for key in ( "additional_interfaces", "additionalInterfaces", "interfaces", "supported_interfaces", "supportedInterfaces", ): candidate = card.get(key) if isinstance(candidate, list): interfaces = candidate break result: dict[str, str] = {} if not isinstance(interfaces, list): return result for item in interfaces: if not isinstance(item, dict): continue mode = self._normalize_transport(item.get("transport")) url = str(item.get("url") or "").strip() if mode in {"jsonrpc", "rest"} and url: result.setdefault(mode, url) return result def _resolve_primary_url(self, card: dict[str, Any], agent: AgentDescriptor) -> str: """解析 card 的主 URL;当 card 返回 0.0.0.0 时退回本地配置。""" card_url = str(card.get("url") or "").strip() fallback = str(agent.endpoint or agent.base_url or "").strip() if card_url and (urlparse(card_url).hostname or "").strip() not in {"0.0.0.0", "::"}: return card_url return fallback or card_url def _transport_target(self, mode: str, url: str) -> _A2ATransportTarget: """构造标准化的 transport target。""" normalized_url = url.strip() if mode == "rest": normalized_url = self._normalize_rest_base_url(normalized_url) return _A2ATransportTarget(mode=mode, endpoint=normalized_url) @staticmethod def _normalize_transport(value: Any) -> str | None: """把不同命名风格的 transport 文本归一化。""" text = str(value or "").strip().lower() if not text: return None if text in {"jsonrpc", "json-rpc"}: return "jsonrpc" if text in {"http+json", "http-json", "http_json", "rest"}: return "rest" if text == "grpc": return "grpc" return None @staticmethod def _normalize_rest_base_url(url: str) -> str: """把各种 REST 端点变体规整到 `/v1` 根路径。""" parsed = urlparse(url) path = parsed.path.rstrip("/") for suffix in ("/message:send", "/message:stream"): if path.endswith(suffix): path = path[: -len(suffix)] break if "/tasks/" in path and (path.endswith(":cancel") or path.endswith(":subscribe")): path = path.split("/tasks/", 1)[0] if not path.endswith("/v1"): path = f"{path}/v1" if path else "/v1" return urlunparse(parsed._replace(path=path, params="", query="", fragment="")).rstrip("/") @staticmethod def _rest_endpoint(base_url: str, route: str) -> str: """基于 REST 根路径拼接具体路由。""" return f"{base_url.rstrip('/')}{route}" @staticmethod def _supports_streaming(card: dict[str, Any]) -> bool: """根据 card capability 判断是否支持流式。""" capabilities = card.get("capabilities") if not isinstance(capabilities, dict) or "streaming" not in capabilities: return True streaming = capabilities.get("streaming") if isinstance(streaming, dict): for key in ("enabled", "supported"): if key in streaming: return bool(streaming.get(key)) return True return bool(streaming) @classmethod def _unwrap_result_object(cls, payload: dict[str, Any]) -> dict[str, Any]: """从不同协议变体里提取真正的结果对象。""" candidate: Any = payload if isinstance(candidate, dict) and isinstance(candidate.get("result"), dict): candidate = candidate["result"] if not isinstance(candidate, dict): raise A2AError("Malformed A2A response") for key, kind in ( ("task", "task"), ("message", "message"), ("statusUpdate", "status-update"), ("status_update", "status-update"), ("artifactUpdate", "artifact-update"), ("artifact_update", "artifact-update"), ): value = candidate.get(key) if isinstance(value, dict): result = dict(value) result.setdefault("kind", kind) return result return candidate @staticmethod def _is_unsupported_http_error( exc: httpx.HTTPStatusError, *, allow_validation_errors: bool = False, ) -> bool: """判断 HTTP 错误是否应被解释为“方法不支持”。""" status_code = exc.response.status_code if status_code in {404, 405, 501}: return True if allow_validation_errors and status_code in {400, 422}: message = exc.response.text.lower() return "not supported" in message or "unsupported" in message return False @staticmethod def _http_error_message(exc: httpx.HTTPStatusError) -> str: """从 HTTP 错误响应中抽取更可读的错误文本。""" try: payload = exc.response.json() except json.JSONDecodeError: payload = None if isinstance(payload, dict): for key in ("detail", "title", "message", "error"): value = payload.get(key) if isinstance(value, str) and value.strip(): return value.strip() return str(exc) async def _build_headers(self, agent: AgentDescriptor) -> dict[str, str]: """构造请求头,包括可选的 Bearer Token 和自定义 headers。""" headers = {"Accept": "application/json, text/event-stream"} auth_mode = (agent.auth_mode or "none").strip().lower() if auth_mode == "oauth_backend_token": headers["Authorization"] = f"Bearer {await self._issue_backend_token(agent)}" else: token = os.environ.get(agent.auth_env or "") if token: headers["Authorization"] = f"Bearer {token}" extra = agent.metadata.get("headers") if isinstance(extra, dict): for key, value in extra.items(): if key and isinstance(value, str): headers[key] = value return headers async def _issue_backend_token(self, agent: AgentDescriptor) -> str: from nanobot.authz.client import AuthzClient authz_base_url = str(getattr(self.authz_config, "base_url", "") or "").strip() client_id = str(getattr(self.backend_identity, "client_id", "") or "").strip() client_secret = str(getattr(self.backend_identity, "client_secret", "") or "").strip() if not authz_base_url or not client_id or not client_secret: raise A2AError( f"A2A agent '{agent.id}' requires AuthZ backend tokens, but authz/backend identity is incomplete" ) audience = str(agent.auth_audience or agent.id).strip() if not audience: raise A2AError(f"A2A agent '{agent.id}' is missing auth audience") if not audience.startswith("a2a:"): audience = f"a2a:{audience}" scopes = [scope for scope in agent.auth_scopes if scope] if not scopes: scopes = ["run_task"] authz_client = AuthzClient( authz_base_url, timeout_seconds=int(getattr(self.authz_config, "request_timeout_seconds", 10)), ) token_response = await authz_client.issue_token( client_id=client_id, client_secret=client_secret, audience=audience, scopes=scopes, ) access_token = str(token_response.get("access_token") or "").strip() if not access_token: raise A2AError(f"A2A agent '{agent.id}' did not receive an access token from AuthZ") return access_token def _check_allowed_host(self, url: str) -> None: """在配置了白名单时校验远端 host 是否允许访问。""" if not self.allowed_hosts: return host = (urlparse(url).hostname or "").lower() if host not in self.allowed_hosts: raise A2AError(f"Host '{host}' is not allowed for A2A access") @staticmethod def _build_message_params(task: str, label: str | None) -> dict[str, Any]: """把委派任务包装成 A2A 标准 message 参数。""" message = { "messageId": str(uuid.uuid4()), "role": "user", "parts": [{"type": "text", "kind": "text", "text": task}], } if label: message["metadata"] = {"label": label} return {"message": message} @classmethod def _build_rest_payload(cls, params: dict[str, Any]) -> dict[str, Any]: """把通用 message 参数转换成 REST 风格 payload。""" payload = json.loads(json.dumps(params)) message = payload.get("message") if not isinstance(message, dict): return payload # 某些 REST 实现要求 role 使用枚举字面量,而不是自由字符串。 role = str(message.get("role") or "").strip().lower() if role == "user": message["role"] = "ROLE_USER" elif role == "agent": message["role"] = "ROLE_AGENT" parts = message.pop("parts", None) if isinstance(parts, list) and "content" not in message: # REST 风格通常把 `parts` 拍平成 `content` 数组。 content: list[dict[str, Any]] = [] for part in parts: if not isinstance(part, dict): continue if "text" in part: content.append({"text": part["text"]}) continue if "file" in part: content.append({"file": part["file"]}) continue if "data" in part: content.append({"data": part["data"]}) message["content"] = content return payload def _build_run_result( self, agent: AgentDescriptor, result: dict[str, Any], state: _StreamState | None = None, ) -> AgentRunResult: """把远端结果对象转换成统一的 AgentRunResult。""" summary = "" if state: summary = state.build_summary(self) if not summary: summary = self._extract_text(result) or json.dumps(result, ensure_ascii=False) status = self._normalize_status(result.get("status")) if not has_meaningful_summary(summary): status = "error" return AgentRunResult( agent_id=agent.id, agent_name=agent.name, status=status, summary=summary, raw=result, ) @staticmethod def _is_task_result(result: dict[str, Any]) -> bool: """判断返回值是否表示一个 task 对象。""" if "status" in result: return True kind = str(result.get("kind") or "").lower() return kind == "task" @staticmethod def _is_terminal_status(status: Any) -> bool: """判断状态是否已进入终态。""" state = A2AClient._normalized_state_token(status) return state in {"completed", "complete", "failed", "error", "cancelled", "canceled"} @staticmethod def _normalize_status(status: Any) -> str: """把五花八门的远端状态名归一化。""" state = A2AClient._normalized_state_token(status or "ok") if state in {"", "completed", "complete", "success", "ok"}: return "ok" if state in {"working", "running", "in_progress", "submitted", "queued"}: return "working" if state in {"failed", "error"}: return "error" if state in {"cancelled", "canceled"}: return "cancelled" return state @staticmethod def _normalized_state_token(status: Any) -> str: """抽取状态里的核心 token,例如去掉 `TASK_STATE_` 前缀。""" if isinstance(status, dict): state = str(status.get("state") or status.get("status") or "") else: state = str(status or "") state = state.strip().lower() if state.startswith("task_state_"): state = state[len("task_state_") :] return state @classmethod def _extract_text(cls, payload: Any) -> str: """从嵌套对象里尽可能提取最有价值的文本内容。""" if isinstance(payload, str): return payload if isinstance(payload, dict): for key in ("text", "content", "summary", "finalText", "final_text"): value = payload.get(key) text = cls._extract_text(value) if text: return text for key in ( "message", "artifact", "artifacts", "messages", "parts", "output", "toolResults", "tool_results", "task", "statusUpdate", "status_update", "artifactUpdate", "artifact_update", "history", ): value = payload.get(key) text = cls._extract_text(value) if text: return text if "status" in payload and isinstance(payload["status"], dict): text = cls._extract_text(payload["status"]) if text: return text return "" if isinstance(payload, list): parts = [cls._extract_text(item) for item in payload] return "\n".join(part for part in parts if part) return ""