from __future__ import annotations import json from collections.abc import AsyncIterator from dataclasses import dataclass from typing import Any import aiohttp from livekit.agents import llm from livekit.agents._exceptions import APIConnectionError from livekit.agents.types import ( DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectOptions, NotGivenOr, ) from livekit.agents.utils import shortuuid @dataclass class GatewaySessionState: room_name: str agent_id: str | None = None session_key: str | None = None session_mode: str = "per_room" def __post_init__(self) -> None: if self.session_mode != "per_room": raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room") if self.session_key is None: suffix = self.agent_id or "default" self.session_key = f"livekit:{self.room_name}:{suffix}" class HermesGatewayLLM(llm.LLM): def __init__( self, *, url: str, token: str | None, state: GatewaySessionState, agent_id: str | None = None, model_name: str = "hermes-agent", request_timeout: float = 30.0, ) -> None: super().__init__() self._url = url self._token = token self._state = state self._agent_id = agent_id self._model_name = model_name self._request_timeout = request_timeout self._http_session: aiohttp.ClientSession | None = None @property def model(self) -> str: return self._model_name @property def provider(self) -> str: return "hermes-gateway" def chat( self, *, chat_ctx: llm.ChatContext, tools: list[llm.Tool] | None = None, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN, tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN, extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, ) -> llm.LLMStream: return HermesGatewayLLMStream( self, chat_ctx=chat_ctx, tools=tools or [], conn_options=conn_options, ) def _ensure_http_session(self) -> aiohttp.ClientSession: if self._http_session is None: timeout = aiohttp.ClientTimeout(total=self._request_timeout) self._http_session = aiohttp.ClientSession(timeout=timeout) return self._http_session async def aclose(self) -> None: if self._http_session is not None: await self._http_session.close() self._http_session = None class HermesGatewayLLMStream(llm.LLMStream): def __init__( self, llm: HermesGatewayLLM, *, chat_ctx: llm.ChatContext, tools: list[llm.Tool], conn_options: APIConnectOptions, ) -> None: super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) self._llm = llm async def _run(self) -> None: request_id = shortuuid("gwreq_") async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws: await _connect(ws, token=self._llm._token) await _send_rpc( ws, method="sessions.create", params={ "key": self._llm._state.session_key, "sessionKey": self._llm._state.session_key, "agentId": self._llm._agent_id, "metadata": { "source": "livekit", "room": self._llm._state.room_name, }, "idempotencyKey": self._llm._state.session_key, }, request_id=shortuuid("gwcreate_"), ) await _send_rpc( ws, method="sessions.send", params={ "key": self._llm._state.session_key, "sessionKey": self._llm._state.session_key, "agentId": self._llm._agent_id, "messages": chat_context_to_gateway_messages(self.chat_ctx), "stream": True, "idempotencyKey": request_id, }, request_id=request_id, wait_response=False, ) streamed_text = "" async for frame in _iter_gateway_frames(ws): if is_error_response(frame, request_id=request_id): raise APIConnectionError(_gateway_error_message(frame), retryable=False) text = extract_text_delta(frame) if text: if text.startswith(streamed_text): text = text[len(streamed_text) :] if text: streamed_text += text self._event_ch.send_nowait( llm.ChatChunk( id=request_id, delta=llm.ChoiceDelta(role="assistant", content=text), ) ) if is_terminal_event(frame, request_id=request_id): return def build_connect_params(*, token: str | None) -> dict[str, Any]: params: dict[str, Any] = { "minProtocol": 3, "maxProtocol": 4, "client": { "id": "gateway-client", "version": "livekit-custom-agent", "platform": "python", "mode": "backend", }, "role": "operator", "scopes": ["operator.read", "operator.write"], "caps": [], "commands": [], "permissions": {}, "locale": "zh-CN", "userAgent": "livekit-custom-agent", } if token: params["auth"] = {"token": token} return params def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] for message in chat_ctx.messages(): content = _message_content_to_gateway_content(message.content) if content is None: continue messages.append({"role": message.role, "content": content}) return messages def extract_text_delta(frame: dict[str, Any]) -> str: payload = frame.get("payload") if not isinstance(payload, dict): payload = frame for path in ( ("delta", "content"), ("delta", "text"), ("message", "delta", "content"), ("message", "delta", "text"), ("message", "content"), ("content",), ("text",), ): value = _get_nested(payload, path) text = _content_to_text(value) if text: return text return "" def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool: if frame.get("type") == "res" and frame.get("id") == request_id: return True event = frame.get("event") if event in { "agent.done", "agent.completed", "agent.error", "session.message.completed", "session.run.completed", "sessions.run.completed", "run.completed", "run.failed", }: return True payload = frame.get("payload") if isinstance(payload, dict) and payload.get("done") is True: return True return False def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool: if ( frame.get("type") == "res" and frame.get("id") == request_id and frame.get("ok") is False ): return True return frame.get("event") in { "agent.error", "session.error", "session.run.failed", "sessions.run.failed", "run.failed", } def _gateway_error_message(frame: dict[str, Any]) -> str: error = frame.get("error") if isinstance(error, str): return f"OpenClaw gateway request failed: {error}" if isinstance(error, dict): message = error.get("message") or error.get("error") if isinstance(message, str): return f"OpenClaw gateway request failed: {message}" payload = frame.get("payload") if isinstance(payload, dict): message = payload.get("message") or payload.get("error") if isinstance(message, str): return f"OpenClaw gateway request failed: {message}" return f"OpenClaw gateway request failed: {frame!r}" async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None: first = await _receive_json(ws) if first.get("event") != "connect.challenge": raise RuntimeError(f"expected connect.challenge, received {first!r}") request_id = shortuuid("gwconnect_") await _send_rpc( ws, method="connect", params=build_connect_params(token=token), request_id=request_id, wait_response=False, ) response = await _wait_for_response(ws, request_id=request_id) if not response.get("ok"): raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}") async def _send_rpc( ws: aiohttp.ClientWebSocketResponse, *, method: str, params: dict[str, Any], request_id: str, wait_response: bool = True, ) -> dict[str, Any] | None: await ws.send_str( json.dumps( { "type": "req", "id": request_id, "method": method, "params": _drop_none(params), } ) ) if not wait_response: return None response = await _wait_for_response(ws, request_id=request_id) if not response.get("ok", False): raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}") return response async def _wait_for_response( ws: aiohttp.ClientWebSocketResponse, *, request_id: str ) -> dict[str, Any]: async for frame in _iter_gateway_frames(ws): if frame.get("type") == "res" and frame.get("id") == request_id: return frame raise RuntimeError(f"OpenClaw gateway closed before response {request_id}") async def _iter_gateway_frames( ws: aiohttp.ClientWebSocketResponse, ) -> AsyncIterator[dict[str, Any]]: async for message in ws: if message.type == aiohttp.WSMsgType.TEXT: yield json.loads(message.data) elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): return elif message.type == aiohttp.WSMsgType.ERROR: raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}") async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]: message = await ws.receive() if message.type != aiohttp.WSMsgType.TEXT: raise RuntimeError(f"expected gateway text frame, received {message.type!r}") return json.loads(message.data) def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any: parts: list[dict[str, Any]] = [] for item in content: if isinstance(item, str): if item: parts.append({"type": "text", "text": item}) elif isinstance(item, llm.ImageContent) and isinstance(item.image, str): parts.append({"type": "image_url", "image_url": {"url": item.image}}) if not parts: return None if len(parts) == 1 and parts[0]["type"] == "text": return parts[0]["text"] return parts def _content_to_text(value: Any) -> str: if isinstance(value, str): return value if isinstance(value, list): text_parts: list[str] = [] for item in value: if isinstance(item, str): text_parts.append(item) elif isinstance(item, dict): text = item.get("text") if isinstance(text, str): text_parts.append(text) return "".join(text_parts) return "" def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: current: Any = data for key in path: if not isinstance(current, dict): return None current = current.get(key) return current def _drop_none(value: Any) -> Any: if isinstance(value, dict): return {key: _drop_none(item) for key, item in value.items() if item is not None} if isinstance(value, list): return [_drop_none(item) for item in value] return value