392 lines
12 KiB
Python
392 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncIterator
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
|
|
from livekit.agents import llm
|
|
from livekit.agents._exceptions import APIConnectionError
|
|
from livekit.agents.types import (
|
|
DEFAULT_API_CONNECT_OPTIONS,
|
|
NOT_GIVEN,
|
|
APIConnectOptions,
|
|
NotGivenOr,
|
|
)
|
|
from livekit.agents.utils import shortuuid
|
|
|
|
|
|
@dataclass
class GatewaySessionState:
    """Identity of the gateway session used for one LiveKit room.

    Only the ``"per_room"`` session mode is accepted. When no explicit
    ``session_key`` is given, one is derived as
    ``livekit:<room_name>:<agent_id or "default">``.

    Raises:
        ValueError: if ``session_mode`` is anything other than ``"per_room"``.
    """

    room_name: str
    agent_id: str | None = None
    session_key: str | None = None
    session_mode: str = "per_room"

    def __post_init__(self) -> None:
        if self.session_mode != "per_room":
            raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room")
        if self.session_key is not None:
            # An explicitly supplied key always wins over the derived one.
            return
        # An empty agent_id falls back to "default", just like None.
        self.session_key = f"livekit:{self.room_name}:{self.agent_id or 'default'}"
|
|
|
|
|
|
class HermesGatewayLLM(llm.LLM):
    """LiveKit ``llm.LLM`` adapter that sends chat turns to a Hermes/OpenClaw gateway.

    Every call to :meth:`chat` returns a :class:`HermesGatewayLLMStream`. Each
    stream opens its own websocket through one lazily created, shared
    ``aiohttp.ClientSession``.

    Args:
        url: websocket URL of the gateway.
        token: optional auth token, sent during the ``connect`` handshake.
        state: per-room session state. Its ``session_key`` identifies the
            gateway session.
        agent_id: optional agent id, forwarded as ``agentId`` in session RPCs.
        model_name: value reported by :attr:`model`.
        request_timeout: seconds, passed as ``total`` to the session's
            ``aiohttp.ClientTimeout``.
    """

    def __init__(
        self,
        *,
        url: str,
        token: str | None,
        state: GatewaySessionState,
        agent_id: str | None = None,
        model_name: str = "hermes-agent",
        request_timeout: float = 30.0,
    ) -> None:
        super().__init__()
        self._url = url
        self._token = token
        self._state = state
        self._agent_id = agent_id
        self._model_name = model_name
        self._request_timeout = request_timeout
        # Created on first use by _ensure_http_session(); released by aclose().
        self._http_session: aiohttp.ClientSession | None = None

    @property
    def model(self) -> str:
        """Model name reported to LiveKit (configured via ``model_name``)."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Provider label reported to LiveKit."""
        return "hermes-gateway"

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> llm.LLMStream:
        """Start one streamed gateway turn for ``chat_ctx``.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are accepted
        only for interface compatibility and are ignored. ``tools`` is handed to
        the stream, but the gateway requests do not include it.
        """
        return HermesGatewayLLMStream(
            self,
            chat_ctx=chat_ctx,
            tools=tools or [],
            conn_options=conn_options,
        )

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it on demand.

        A new session is also created when the previous one has been closed,
        e.g. after :meth:`aclose`. Reusing a closed ``ClientSession`` would make
        every later request fail.
        """
        if self._http_session is None or self._http_session.closed:
            timeout = aiohttp.ClientTimeout(total=self._request_timeout)
            self._http_session = aiohttp.ClientSession(timeout=timeout)
        return self._http_session

    async def aclose(self) -> None:
        """Close the shared HTTP session, if one was created."""
        # Detach before awaiting, so that no caller can pick up a session that
        # is already shutting down.
        session, self._http_session = self._http_session, None
        if session is not None:
            await session.close()
|
|
|
|
|
|
class HermesGatewayLLMStream(llm.LLMStream):
    """One streamed chat turn against the Hermes/OpenClaw gateway.

    Each run opens its own websocket through the parent LLM's shared HTTP
    session. Frames read here therefore come only from this turn's connection.
    """

    def __init__(
        self,
        llm: HermesGatewayLLM,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        conn_options: APIConnectOptions,
    ) -> None:
        # NOTE: inside this method the ``llm`` parameter shadows the ``llm``
        # module. The name is kept so existing callers are unaffected. The
        # annotations are not evaluated at runtime, because of the module-level
        # ``from __future__ import annotations``.
        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        self._llm = llm

    async def _run(self) -> None:
        """Run one gateway turn and forward assistant text as ``ChatChunk`` events.

        Steps:

        1. Open a websocket and perform the ``connect`` handshake.
        2. Ensure the per-room session exists (``sessions.create``).
        3. Submit the chat history with ``sessions.send`` (``stream: True``).
        4. Relay text from incoming frames until a terminal frame for this
           request arrives.

        Raises:
            APIConnectionError: raised in three cases.
                * The gateway reported an error for the run (not retryable).
                * The websocket closed before a terminal frame.
                * The HTTP/websocket layer failed or timed out.
                The last two are retryable only when no text has been emitted
                yet, so a retry cannot duplicate output already delivered.
        """
        import asyncio  # local import: only needed for the timeout exception type

        request_id = shortuuid("gwreq_")
        state = self._llm._state
        # Holds the accumulated assistant text. It is also used to decide
        # whether a failure is safe to retry.
        streamed_text = ""

        try:
            async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws:
                await _connect(ws, token=self._llm._token)
                await _send_rpc(
                    ws,
                    method="sessions.create",
                    params={
                        "key": state.session_key,
                        "sessionKey": state.session_key,
                        "agentId": self._llm._agent_id,
                        "metadata": {
                            "source": "livekit",
                            "room": state.room_name,
                        },
                        # Every turn uses the same key. Presumably the gateway
                        # treats repeat creates for a room as no-ops; confirm.
                        "idempotencyKey": state.session_key,
                    },
                    request_id=shortuuid("gwcreate_"),
                )
                await _send_rpc(
                    ws,
                    method="sessions.send",
                    params={
                        "key": state.session_key,
                        "sessionKey": state.session_key,
                        "agentId": self._llm._agent_id,
                        "messages": chat_context_to_gateway_messages(self.chat_ctx),
                        "stream": True,
                        "idempotencyKey": request_id,
                    },
                    request_id=request_id,
                    # The reply for request_id is handled by the frame loop below.
                    wait_response=False,
                )

                async for frame in _iter_gateway_frames(ws):
                    if is_error_response(frame, request_id=request_id):
                        raise APIConnectionError(_gateway_error_message(frame), retryable=False)

                    text = extract_text_delta(frame)
                    # A frame may carry the full text so far rather than a delta
                    # (presumably e.g. a final response). Strip the prefix that
                    # was already emitted.
                    if text and text.startswith(streamed_text):
                        text = text[len(streamed_text) :]
                    if text:
                        streamed_text += text
                        self._event_ch.send_nowait(
                            llm.ChatChunk(
                                id=request_id,
                                delta=llm.ChoiceDelta(role="assistant", content=text),
                            )
                        )

                    if is_terminal_event(frame, request_id=request_id):
                        return
        except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
            raise APIConnectionError(
                f"OpenClaw gateway connection failed: {exc!r}",
                retryable=not streamed_text,
            ) from exc

        # The frame iterator ended (the socket closed) before any terminal frame
        # arrived. The reply is incomplete, so raise rather than return normally.
        raise APIConnectionError(
            "OpenClaw gateway closed the connection before the run completed",
            retryable=not streamed_text,
        )
|
|
|
|
|
|
def build_connect_params(*, token: str | None, locale: str = "zh-CN") -> dict[str, Any]:
    """Build the ``params`` object for the gateway ``connect`` RPC.

    The client identifies itself as a backend ``operator`` with
    ``operator.read``/``operator.write`` scopes. It accepts protocol versions 3
    through 4.

    Args:
        token: auth token. It is added as ``{"auth": {"token": token}}`` only
            when truthy, so both ``None`` and ``""`` omit the ``auth`` key.
        locale: locale string reported to the gateway. Defaults to ``"zh-CN"``.

    Returns:
        A new dict on every call, which the caller may safely mutate.
    """
    params: dict[str, Any] = {
        "minProtocol": 3,
        "maxProtocol": 4,
        "client": {
            "id": "gateway-client",
            "version": "livekit-custom-agent",
            "platform": "python",
            "mode": "backend",
        },
        "role": "operator",
        "scopes": ["operator.read", "operator.write"],
        "caps": [],
        "commands": [],
        "permissions": {},
        "locale": locale,
        "userAgent": "livekit-custom-agent",
    }
    if token:
        params["auth"] = {"token": token}
    return params
|
|
|
|
|
|
def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]:
    """Convert ``chat_ctx.messages()`` into gateway ``{"role", "content"}`` dicts.

    Messages whose content converts to ``None`` (nothing sendable) are dropped.
    The messages keep their original order.
    """
    converted = (
        (msg.role, _message_content_to_gateway_content(msg.content))
        for msg in chat_ctx.messages()
    )
    return [
        {"role": role, "content": body}
        for role, body in converted
        if body is not None
    ]
|
|
|
|
|
|
def extract_text_delta(frame: dict[str, Any]) -> str:
    """Return the first non-empty text found in a gateway frame, or ``""``.

    The search runs over ``frame["payload"]`` when that is a dict, and over the
    frame itself otherwise. Candidate locations are tried in a fixed priority
    order: delta fields first, then whole-message content, then bare
    ``content``/``text``.
    """
    raw_payload = frame.get("payload")
    source = raw_payload if isinstance(raw_payload, dict) else frame

    candidate_paths = (
        ("delta", "content"),
        ("delta", "text"),
        ("message", "delta", "content"),
        ("message", "delta", "text"),
        ("message", "content"),
        ("content",),
        ("text",),
    )
    # Evaluated lazily: later paths are only inspected if earlier ones yield "".
    texts = (_content_to_text(_get_nested(source, path)) for path in candidate_paths)
    return next((found for found in texts if found), "")
|
|
|
|
|
|
def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool:
    """Return True when ``frame`` ends the streamed run for ``request_id``.

    A frame is terminal when it is any of the following:

    * The ``res`` frame answering ``request_id``. The exception is an interim
      acknowledgement (``ok`` is not ``False`` and ``payload.status`` is
      ``"accepted"``). A gateway that acknowledges first and answers later
      (assumed two-stage behaviour; confirm against the gateway) would
      otherwise end the stream before any text arrives.
    * An event whose name is one of the known completion/failure events.
    * Any frame whose ``payload.done`` is exactly ``True``.
    """
    payload = frame.get("payload")

    if frame.get("type") == "res" and frame.get("id") == request_id:
        is_interim_ack = (
            frame.get("ok") is not False
            and isinstance(payload, dict)
            and payload.get("status") == "accepted"
        )
        return not is_interim_ack

    event = frame.get("event")
    # The isinstance guard stops a non-string (possibly unhashable) "event"
    # value from raising TypeError in the set lookup.
    if isinstance(event, str) and event in {
        "agent.done",
        "agent.completed",
        "agent.error",
        "session.message.completed",
        "session.run.completed",
        "sessions.run.completed",
        "run.completed",
        "run.failed",
    }:
        return True

    return isinstance(payload, dict) and payload.get("done") is True
|
|
|
|
|
|
def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool:
    """Return True when ``frame`` reports a failure for the current run.

    There are two failure cases:

    * A ``res`` frame answering ``request_id`` with ``ok`` exactly ``False``.
    * An event whose name is one of the known error/failure events. This is
      checked regardless of request id.
    """
    failed_reply = (
        frame.get("type") == "res"
        and frame.get("id") == request_id
        and frame.get("ok") is False
    )
    failure_events = {
        "agent.error",
        "session.error",
        "session.run.failed",
        "sessions.run.failed",
        "run.failed",
    }
    return failed_reply or frame.get("event") in failure_events
|
|
|
|
|
|
def _gateway_error_message(frame: dict[str, Any]) -> str:
|
|
error = frame.get("error")
|
|
if isinstance(error, str):
|
|
return f"OpenClaw gateway request failed: {error}"
|
|
if isinstance(error, dict):
|
|
message = error.get("message") or error.get("error")
|
|
if isinstance(message, str):
|
|
return f"OpenClaw gateway request failed: {message}"
|
|
|
|
payload = frame.get("payload")
|
|
if isinstance(payload, dict):
|
|
message = payload.get("message") or payload.get("error")
|
|
if isinstance(message, str):
|
|
return f"OpenClaw gateway request failed: {message}"
|
|
|
|
return f"OpenClaw gateway request failed: {frame!r}"
|
|
|
|
|
|
async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None:
    """Perform the gateway handshake on a freshly opened websocket.

    The gateway must first send a ``connect.challenge`` event. A ``connect``
    RPC (see ``build_connect_params``) is then sent and its response awaited.
    Any frames that arrive before that response are discarded.

    Raises:
        RuntimeError: the first frame is not a challenge, or the ``connect``
            response is not ``ok``.
    """
    challenge = await _receive_json(ws)
    if challenge.get("event") != "connect.challenge":
        raise RuntimeError(f"expected connect.challenge, received {challenge!r}")

    connect_id = shortuuid("gwconnect_")
    await _send_rpc(
        ws,
        method="connect",
        params=build_connect_params(token=token),
        request_id=connect_id,
        wait_response=False,
    )
    reply = await _wait_for_response(ws, request_id=connect_id)
    if reply.get("ok"):
        return
    raise RuntimeError(f"OpenClaw gateway connect failed: {reply.get('error')!r}")
|
|
|
|
|
|
async def _send_rpc(
    ws: aiohttp.ClientWebSocketResponse,
    *,
    method: str,
    params: dict[str, Any],
    request_id: str,
    wait_response: bool = True,
) -> dict[str, Any] | None:
    """Send a ``req`` frame and, optionally, wait for its ``res`` frame.

    ``None`` values are recursively removed from ``params`` before sending.

    Returns:
        The response frame when ``wait_response`` is True, otherwise ``None``.

    Raises:
        RuntimeError: when waiting and the response lacks a truthy ``ok``.
    """
    envelope = {
        "type": "req",
        "id": request_id,
        "method": method,
        "params": _drop_none(params),
    }
    await ws.send_str(json.dumps(envelope))

    if wait_response:
        response = await _wait_for_response(ws, request_id=request_id)
        if not response.get("ok", False):
            raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}")
        return response
    return None
|
|
|
|
|
|
async def _wait_for_response(
    ws: aiohttp.ClientWebSocketResponse, *, request_id: str
) -> dict[str, Any]:
    """Consume frames until the ``res`` frame for ``request_id`` arrives.

    Every other frame read along the way is discarded.

    Raises:
        RuntimeError: the frame stream ended before the response arrived.
    """
    async for frame in _iter_gateway_frames(ws):
        if frame.get("type") != "res":
            continue
        if frame.get("id") == request_id:
            return frame
    raise RuntimeError(f"OpenClaw gateway closed before response {request_id}")
|
|
|
|
|
|
async def _iter_gateway_frames(
    ws: aiohttp.ClientWebSocketResponse,
) -> AsyncIterator[dict[str, Any]]:
    """Yield decoded JSON text frames from ``ws`` until the socket closes.

    Messages that are neither text nor close nor error (e.g. binary) are
    skipped.

    Raises:
        RuntimeError: the websocket reported an error message.
    """
    closing_kinds = (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE)
    async for incoming in ws:
        kind = incoming.type
        if kind == aiohttp.WSMsgType.TEXT:
            yield json.loads(incoming.data)
            continue
        if kind in closing_kinds:
            return
        if kind == aiohttp.WSMsgType.ERROR:
            raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}")
|
|
|
|
|
|
async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]:
    """Read exactly one websocket message and decode it as JSON.

    Raises:
        RuntimeError: the message is not a text frame.
    """
    incoming = await ws.receive()
    if incoming.type == aiohttp.WSMsgType.TEXT:
        return json.loads(incoming.data)
    raise RuntimeError(f"expected gateway text frame, received {incoming.type!r}")
|
|
|
|
|
|
def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any:
|
|
parts: list[dict[str, Any]] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
if item:
|
|
parts.append({"type": "text", "text": item})
|
|
elif isinstance(item, llm.ImageContent) and isinstance(item.image, str):
|
|
parts.append({"type": "image_url", "image_url": {"url": item.image}})
|
|
|
|
if not parts:
|
|
return None
|
|
if len(parts) == 1 and parts[0]["type"] == "text":
|
|
return parts[0]["text"]
|
|
return parts
|
|
|
|
|
|
def _content_to_text(value: Any) -> str:
|
|
if isinstance(value, str):
|
|
return value
|
|
if isinstance(value, list):
|
|
text_parts: list[str] = []
|
|
for item in value:
|
|
if isinstance(item, str):
|
|
text_parts.append(item)
|
|
elif isinstance(item, dict):
|
|
text = item.get("text")
|
|
if isinstance(text, str):
|
|
text_parts.append(text)
|
|
return "".join(text_parts)
|
|
return ""
|
|
|
|
|
|
def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any:
|
|
current: Any = data
|
|
for key in path:
|
|
if not isinstance(current, dict):
|
|
return None
|
|
current = current.get(key)
|
|
return current
|
|
|
|
|
|
def _drop_none(value: Any) -> Any:
|
|
if isinstance(value, dict):
|
|
return {key: _drop_none(item) for key, item in value.items() if item is not None}
|
|
if isinstance(value, list):
|
|
return [_drop_none(item) for item in value]
|
|
return value
|