diff --git a/.env.example b/.env.example index 8f023d6..a159380 100644 --- a/.env.example +++ b/.env.example @@ -13,6 +13,9 @@ CUSTOM_ASR_HOTWORDS= CUSTOM_ASR_ITN= CUSTOM_ASR_CHUNK_MODE= +# LLM backend: openai/openai-compatible or hermes_gateway/openclaw. +CUSTOM_LLM_PROVIDER=openai + # OpenAI-compatible LLM # CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1 # CUSTOM_LLM_MODEL=Qwen3.6-35B @@ -28,6 +31,15 @@ CUSTOM_SAVE_MODEL_IMAGES=false # CUSTOM_TEXT_LLM_MODEL= # CUSTOM_VISION_LLM_MODEL= +# Hermes Agent via OpenClaw Gateway WebSocket, one Gateway session per LiveKit room. +# CUSTOM_LLM_PROVIDER=hermes_gateway +# CUSTOM_HERMES_GATEWAY_URL=ws://localhost:1977/ws +# CUSTOM_HERMES_AGENT_ID= +# CUSTOM_HERMES_API_KEY= +# CUSTOM_HERMES_SESSION_MODE=per_room +# CUSTOM_HERMES_MODEL=hermes-agent +# CUSTOM_HERMES_REQUEST_TIMEOUT=30 + # CUSTOM_LLM_BASE_URL=https://api.deepseek.com # CUSTOM_LLM_MODEL=deepseek-v4-flash # CUSTOM_LLM_API_KEY= diff --git a/custom_agent.py b/custom_agent.py index bb98a1e..d6e9ecd 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from pathlib import Path from dotenv import load_dotenv +from hermes_gateway import GatewaySessionState, HermesGatewayLLM from memory import MemoryRecallClient from tts import BlackboxTTS @@ -638,12 +639,20 @@ async def entrypoint(ctx: JobContext) -> None: LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL") LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max") LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY") + LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower() TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL) VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL) INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) - if not LLM_API_KEY: + if LLM_PROVIDER not in {"openai", "openai-compatible", "hermes", "hermes_gateway", "openclaw"}: + raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}") + if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") - logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default") + logger.info( + "Using LLM provider=%s model=%s base_url=%s", + LLM_PROVIDER, + LLM_MODEL, + LLM_BASE_URL or "OpenAI default", + ) TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox" @@ -668,38 +677,75 @@ async def entrypoint(ctx: JobContext) -> None: ) stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"]) - import httpx - from openai import AsyncClient as OpenAIAsyncClient + if LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}: + gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip() + if not gateway_url: + raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}") - # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL. - http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False)) - - if LLM_BASE_URL: - openai_client = OpenAIAsyncClient( - api_key=LLM_API_KEY, - base_url=LLM_BASE_URL, - http_client=http_client, + hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None + hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower() + if hermes_session_mode != "per_room": + raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room") + hermes_token = ( + os.getenv("CUSTOM_HERMES_API_KEY") + or os.getenv("CUSTOM_HERMES_TOKEN") + or LLM_API_KEY + or None + ) + hermes_state = GatewaySessionState( + room_name=ctx.room.name, + agent_id=hermes_agent_id, + session_mode=hermes_session_mode, + ) + base_llm = HermesGatewayLLM( + url=gateway_url, + token=hermes_token, + state=hermes_state, + agent_id=hermes_agent_id, + model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"), + request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0), + ) + text_llm = base_llm + vision_llm = base_llm + logger.info( + "Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s", + gateway_url, + hermes_agent_id or "default", + hermes_state.session_key, ) else: - openai_client = OpenAIAsyncClient( - api_key=LLM_API_KEY, - http_client=http_client, - ) + import httpx + from openai import AsyncClient as OpenAIAsyncClient - base_llm = openai.LLM( - model=LLM_MODEL, - client=openai_client, - ) - text_llm = ( - openai.LLM(model=TEXT_LLM_MODEL, client=openai_client) - if TEXT_LLM_MODEL != LLM_MODEL - else base_llm - ) - vision_llm = ( - openai.LLM(model=VISION_LLM_MODEL, client=openai_client) - if VISION_LLM_MODEL != LLM_MODEL - else base_llm - ) + # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL. + http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False)) + + if LLM_BASE_URL: + openai_client = OpenAIAsyncClient( + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, + http_client=http_client, + ) + else: + openai_client = OpenAIAsyncClient( + api_key=LLM_API_KEY, + http_client=http_client, + ) + + base_llm = openai.LLM( + model=LLM_MODEL, + client=openai_client, + ) + text_llm = ( + openai.LLM(model=TEXT_LLM_MODEL, client=openai_client) + if TEXT_LLM_MODEL != LLM_MODEL + else base_llm + ) + vision_llm = ( + openai.LLM(model=VISION_LLM_MODEL, client=openai_client) + if VISION_LLM_MODEL != LLM_MODEL + else base_llm + ) vision_store = VisionFrameStore( max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0) ) @@ -707,7 +753,7 @@ async def entrypoint(ctx: JobContext) -> None: session: AgentSession = AgentSession( # 1. Custom ASR blackbox with StreamAdapter stt=stt_stream, - # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI. + # 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway. llm=base_llm, # 3. TTS blackbox tts=BlackboxTTS( diff --git a/hermes_gateway.py b/hermes_gateway.py new file mode 100644 index 0000000..f11b121 --- /dev/null +++ b/hermes_gateway.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import Any + +import aiohttp + +from livekit.agents import llm +from livekit.agents._exceptions import APIConnectionError +from livekit.agents.types import ( + DEFAULT_API_CONNECT_OPTIONS, + NOT_GIVEN, + APIConnectOptions, + NotGivenOr, +) +from livekit.agents.utils import shortuuid + + +@dataclass +class GatewaySessionState: + room_name: str + agent_id: str | None = None + session_key: str | None = None + session_mode: str = "per_room" + + def __post_init__(self) -> None: + if self.session_mode != "per_room": + raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room") + if self.session_key is None: + suffix = self.agent_id or "default" + self.session_key = f"livekit:{self.room_name}:{suffix}" + + +class HermesGatewayLLM(llm.LLM): + def __init__( + self, + *, + url: str, + token: str | None, + state: GatewaySessionState, + agent_id: str | None = None, + model_name: str = "hermes-agent", + request_timeout: float = 30.0, + ) -> None: + super().__init__() + self._url = url + self._token = token + self._state = state + self._agent_id = agent_id + self._model_name = model_name + self._request_timeout = request_timeout + self._http_session: aiohttp.ClientSession | None = None + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "hermes-gateway" + + def chat( + self, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool] | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN, + tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN, + extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + ) -> llm.LLMStream: + return HermesGatewayLLMStream( + self, + chat_ctx=chat_ctx, + tools=tools or [], + conn_options=conn_options, + ) + + def _ensure_http_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + timeout = aiohttp.ClientTimeout(total=self._request_timeout) + self._http_session = aiohttp.ClientSession(timeout=timeout) + return self._http_session + + async def aclose(self) -> None: + if self._http_session is not None: + await self._http_session.close() + self._http_session = None + + +class HermesGatewayLLMStream(llm.LLMStream): + def __init__( + self, + llm: HermesGatewayLLM, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool], + conn_options: APIConnectOptions, + ) -> None: + super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) + self._llm = llm + + async def _run(self) -> None: + request_id = shortuuid("gwreq_") + async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws: + await _connect(ws, token=self._llm._token) + await _send_rpc( + ws, + method="sessions.create", + params={ + "key": self._llm._state.session_key, + "sessionKey": self._llm._state.session_key, + "agentId": self._llm._agent_id, + "metadata": { + "source": "livekit", + "room": self._llm._state.room_name, + }, + "idempotencyKey": self._llm._state.session_key, + }, + request_id=shortuuid("gwcreate_"), + ) + await _send_rpc( + ws, + method="sessions.send", + params={ + "key": self._llm._state.session_key, + "sessionKey": self._llm._state.session_key, + "agentId": self._llm._agent_id, + "messages": chat_context_to_gateway_messages(self.chat_ctx), + "stream": True, + "idempotencyKey": request_id, + }, + request_id=request_id, + wait_response=False, + ) + + streamed_text = "" + async for frame in _iter_gateway_frames(ws): + if is_error_response(frame, request_id=request_id): + raise APIConnectionError(_gateway_error_message(frame), retryable=False) + text = extract_text_delta(frame) + if text: + if text.startswith(streamed_text): + text = text[len(streamed_text) :] + if text: + streamed_text += text + self._event_ch.send_nowait( + llm.ChatChunk( + id=request_id, + delta=llm.ChoiceDelta(role="assistant", content=text), + ) + ) + if is_terminal_event(frame, request_id=request_id): + return + + +def build_connect_params(*, token: str | None) -> dict[str, Any]: + params: dict[str, Any] = { + "minProtocol": 3, + "maxProtocol": 4, + "client": { + "id": "gateway-client", + "version": "livekit-custom-agent", + "platform": "python", + "mode": "backend", + }, + "role": "operator", + "scopes": ["operator.read", "operator.write"], + "caps": [], + "commands": [], + "permissions": {}, + "locale": "zh-CN", + "userAgent": "livekit-custom-agent", + } + if token: + params["auth"] = {"token": token} + return params + + +def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + for message in chat_ctx.messages(): + content = _message_content_to_gateway_content(message.content) + if content is None: + continue + messages.append({"role": message.role, "content": content}) + return messages + + +def extract_text_delta(frame: dict[str, Any]) -> str: + payload = frame.get("payload") + if not isinstance(payload, dict): + payload = frame + + for path in ( + ("delta", "content"), + ("delta", "text"), + ("message", "delta", "content"), + ("message", "delta", "text"), + ("message", "content"), + ("content",), + ("text",), + ): + value = _get_nested(payload, path) + text = _content_to_text(value) + if text: + return text + + return "" + + +def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool: + if frame.get("type") == "res" and frame.get("id") == request_id: + return True + + event = frame.get("event") + if event in { + "agent.done", + "agent.completed", + "agent.error", + "session.message.completed", + "session.run.completed", + "sessions.run.completed", + "run.completed", + "run.failed", + }: + return True + + payload = frame.get("payload") + if isinstance(payload, dict) and payload.get("done") is True: + return True + + return False + + +def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool: + if ( + frame.get("type") == "res" + and frame.get("id") == request_id + and frame.get("ok") is False + ): + return True + + return frame.get("event") in { + "agent.error", + "session.error", + "session.run.failed", + "sessions.run.failed", + "run.failed", + } + + +def _gateway_error_message(frame: dict[str, Any]) -> str: + error = frame.get("error") + if isinstance(error, str): + return f"OpenClaw gateway request failed: {error}" + if isinstance(error, dict): + message = error.get("message") or error.get("error") + if isinstance(message, str): + return f"OpenClaw gateway request failed: {message}" + + payload = frame.get("payload") + if isinstance(payload, dict): + message = payload.get("message") or payload.get("error") + if isinstance(message, str): + return f"OpenClaw gateway request failed: {message}" + + return f"OpenClaw gateway request failed: {frame!r}" + + +async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None: + first = await _receive_json(ws) + if first.get("event") != "connect.challenge": + raise RuntimeError(f"expected connect.challenge, received {first!r}") + + request_id = shortuuid("gwconnect_") + await _send_rpc( + ws, + method="connect", + params=build_connect_params(token=token), + request_id=request_id, + wait_response=False, + ) + response = await _wait_for_response(ws, request_id=request_id) + if not response.get("ok"): + raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}") + + +async def _send_rpc( + ws: aiohttp.ClientWebSocketResponse, + *, + method: str, + params: dict[str, Any], + request_id: str, + wait_response: bool = True, +) -> dict[str, Any] | None: + await ws.send_str( + json.dumps( + { + "type": "req", + "id": request_id, + "method": method, + "params": _drop_none(params), + } + ) + ) + if not wait_response: + return None + response = await _wait_for_response(ws, request_id=request_id) + if not response.get("ok", False): + raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}") + return response + + +async def _wait_for_response( + ws: aiohttp.ClientWebSocketResponse, *, request_id: str +) -> dict[str, Any]: + async for frame in _iter_gateway_frames(ws): + if frame.get("type") == "res" and frame.get("id") == request_id: + return frame + raise RuntimeError(f"OpenClaw gateway closed before response {request_id}") + + +async def _iter_gateway_frames( + ws: aiohttp.ClientWebSocketResponse, +) -> AsyncIterator[dict[str, Any]]: + async for message in ws: + if message.type == aiohttp.WSMsgType.TEXT: + yield json.loads(message.data) + elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): + return + elif message.type == aiohttp.WSMsgType.ERROR: + raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}") + + +async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]: + message = await ws.receive() + if message.type != aiohttp.WSMsgType.TEXT: + raise RuntimeError(f"expected gateway text frame, received {message.type!r}") + return json.loads(message.data) + + +def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any: + parts: list[dict[str, Any]] = [] + for item in content: + if isinstance(item, str): + if item: + parts.append({"type": "text", "text": item}) + elif isinstance(item, llm.ImageContent) and isinstance(item.image, str): + parts.append({"type": "image_url", "image_url": {"url": item.image}}) + + if not parts: + return None + if len(parts) == 1 and parts[0]["type"] == "text": + return parts[0]["text"] + return parts + + +def _content_to_text(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, list): + text_parts: list[str] = [] + for item in value: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + text_parts.append(text) + return "".join(text_parts) + return "" + + +def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: + current: Any = data + for key in path: + if not isinstance(current, dict): + return None + current = current.get(key) + return current + + +def _drop_none(value: Any) -> Any: + if isinstance(value, dict): + return {key: _drop_none(item) for key, item in value.items() if item is not None} + if isinstance(value, list): + return [_drop_none(item) for item in value] + return value diff --git a/test_hermes_gateway.py b/test_hermes_gateway.py new file mode 100644 index 0000000..739e202 --- /dev/null +++ b/test_hermes_gateway.py @@ -0,0 +1,262 @@ +import json + +import aiohttp +import pytest +from aiohttp import web + +from custom.hermes_gateway import ( + GatewaySessionState, + HermesGatewayLLM, + build_connect_params, + chat_context_to_gateway_messages, + extract_text_delta, + is_error_response, + is_terminal_event, +) +from livekit.agents import ChatContext, llm +from livekit.agents._exceptions import APIConnectionError + + +def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None: + ctx = ChatContext.empty() + ctx.add_message(role="system", content="system prompt") + ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")]) + + messages = chat_context_to_gateway_messages(ctx) + + assert messages == [ + {"role": "system", "content": "system prompt"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "look here"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + }, + ] + + +def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None: + assert ( + extract_text_delta( + {"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}} + ) + == "hi" + ) + assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there" + assert ( + extract_text_delta( + { + "type": "event", + "event": "session.message.delta", + "payload": {"message": {"content": [{"type": "text", "text": "!"}]}}, + } + ) + == "!" + ) + + +def test_per_room_session_state_reuses_stable_session_key() -> None: + state = GatewaySessionState(room_name="kitchen-room", agent_id="helper") + + assert state.session_key == "livekit:kitchen-room:helper" + state.session_key = "gateway-session-123" + assert state.session_key == "gateway-session-123" + + +def test_build_connect_params_uses_backend_operator_defaults() -> None: + params = build_connect_params(token="secret-token") + + assert params["client"] == { + "id": "gateway-client", + "version": "livekit-custom-agent", + "platform": "python", + "mode": "backend", + } + assert params["role"] == "operator" + assert params["scopes"] == ["operator.read", "operator.write"] + assert params["auth"] == {"token": "secret-token"} + assert "device" not in params + + +def test_gateway_response_helpers_match_only_current_send_request() -> None: + assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1") + assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1") + assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1") + assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1") + + +def test_hermes_llm_reports_provider_and_model() -> None: + state = GatewaySessionState(room_name="kitchen", agent_id="helper") + gateway_llm = HermesGatewayLLM( + url="ws://gateway.test/ws", + token="token", + state=state, + agent_id="helper", + model_name="hermes-agent", + ) + + assert gateway_llm.provider == "hermes-gateway" + assert gateway_llm.model == "hermes-agent" + + +def test_gateway_session_state_rejects_non_per_room_mode() -> None: + with pytest.raises(ValueError, match="per_room"): + GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn") + + +async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None: + received: list[dict[str, object]] = [] + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}}) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + payload = json.loads(message.data) + received.append(payload) + method = payload.get("method") + request_id = payload.get("id") + if method == "connect": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.create": + await ws.send_json( + { + "type": "res", + "id": request_id, + "ok": True, + "result": {"sessionKey": "livekit:kitchen:helper"}, + } + ) + elif method == "sessions.send": + await ws.send_json( + { + "type": "event", + "event": "agent", + "payload": {"delta": {"content": "你好"}}, + } + ) + await ws.send_json( + { + "type": "res", + "id": request_id, + "ok": True, + "result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}}, + } + ) + await ws.close() + + return ws + + app = web.Application() + app.router.add_get("/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + gateway_llm = HermesGatewayLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/ws", + token="secret-token", + state=GatewaySessionState(room_name="kitchen", agent_id="helper"), + agent_id="helper", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="杯子在哪里") + + try: + collected = await gateway_llm.chat(chat_ctx=ctx).collect() + finally: + await gateway_llm.aclose() + await runner.cleanup() + + assert collected.text == "你好" + assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"] + send_request = received[2] + assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper" + assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}] + + +def test_extract_text_delta_reads_final_message_content() -> None: + assert ( + extract_text_delta( + { + "type": "event", + "event": "session.message.completed", + "payload": { + "message": { + "content": [ + {"type": "text", "text": "完整回复"}, + ] + } + }, + } + ) + == "完整回复" + ) + + +def test_is_error_response_accepts_error_events() -> None: + assert is_error_response( + {"type": "event", "event": "agent.error", "payload": {"error": "boom"}}, + request_id="send-1", + ) + assert is_error_response( + {"type": "event", "event": "run.failed", "payload": {"message": "boom"}}, + request_id="send-1", + ) + + +async def test_llm_stream_maps_gateway_error_events_to_api_connection_error( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}}) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + payload = json.loads(message.data) + method = payload.get("method") + request_id = payload.get("id") + if method == "connect": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.create": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.send": + await ws.send_json( + { + "type": "event", + "event": "run.failed", + "payload": {"message": "gateway exploded"}, + } + ) + await ws.close() + + return ws + + app = web.Application() + app.router.add_get("/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + gateway_llm = HermesGatewayLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/ws", + token=None, + state=GatewaySessionState(room_name="kitchen", agent_id="helper"), + agent_id="helper", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="hello") + + try: + with pytest.raises(APIConnectionError, match="gateway exploded"): + await gateway_llm.chat(chat_ctx=ctx).collect() + finally: + await gateway_llm.aclose() + await runner.cleanup()