beaver #2
12
.env.example
12
.env.example
@ -13,6 +13,9 @@ CUSTOM_ASR_HOTWORDS=
|
||||
CUSTOM_ASR_ITN=
|
||||
CUSTOM_ASR_CHUNK_MODE=
|
||||
|
||||
# LLM backend: openai/openai-compatible or hermes_gateway/openclaw.
|
||||
CUSTOM_LLM_PROVIDER=openai
|
||||
|
||||
# OpenAI-compatible LLM
|
||||
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
|
||||
# CUSTOM_LLM_MODEL=Qwen3.6-35B
|
||||
@ -28,6 +31,15 @@ CUSTOM_SAVE_MODEL_IMAGES=false
|
||||
# CUSTOM_TEXT_LLM_MODEL=
|
||||
# CUSTOM_VISION_LLM_MODEL=
|
||||
|
||||
# Hermes Agent via OpenClaw Gateway WebSocket, one Gateway session per LiveKit room.
|
||||
# CUSTOM_LLM_PROVIDER=hermes_gateway
|
||||
# CUSTOM_HERMES_GATEWAY_URL=ws://localhost:1977/ws
|
||||
# CUSTOM_HERMES_AGENT_ID=
|
||||
# CUSTOM_HERMES_API_KEY=
|
||||
# CUSTOM_HERMES_SESSION_MODE=per_room
|
||||
# CUSTOM_HERMES_MODEL=hermes-agent
|
||||
# CUSTOM_HERMES_REQUEST_TIMEOUT=30
|
||||
|
||||
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
|
||||
# CUSTOM_LLM_MODEL=deepseek-v4-flash
|
||||
# CUSTOM_LLM_API_KEY=
|
||||
|
||||
108
custom_agent.py
108
custom_agent.py
@ -9,6 +9,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||||
from memory import MemoryRecallClient
|
||||
from tts import BlackboxTTS
|
||||
|
||||
@ -638,12 +639,20 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
|
||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
||||
if not LLM_API_KEY:
|
||||
if LLM_PROVIDER not in {"openai", "openai-compatible", "hermes", "hermes_gateway", "openclaw"}:
|
||||
raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
|
||||
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
|
||||
logger.info(
|
||||
"Using LLM provider=%s model=%s base_url=%s",
|
||||
LLM_PROVIDER,
|
||||
LLM_MODEL,
|
||||
LLM_BASE_URL or "OpenAI default",
|
||||
)
|
||||
|
||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
|
||||
@ -668,38 +677,75 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
)
|
||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||
|
||||
import httpx
|
||||
from openai import AsyncClient as OpenAIAsyncClient
|
||||
if LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
||||
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
||||
if not gateway_url:
|
||||
raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")
|
||||
|
||||
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
|
||||
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
|
||||
|
||||
if LLM_BASE_URL:
|
||||
openai_client = OpenAIAsyncClient(
|
||||
api_key=LLM_API_KEY,
|
||||
base_url=LLM_BASE_URL,
|
||||
http_client=http_client,
|
||||
hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None
|
||||
hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower()
|
||||
if hermes_session_mode != "per_room":
|
||||
raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room")
|
||||
hermes_token = (
|
||||
os.getenv("CUSTOM_HERMES_API_KEY")
|
||||
or os.getenv("CUSTOM_HERMES_TOKEN")
|
||||
or LLM_API_KEY
|
||||
or None
|
||||
)
|
||||
hermes_state = GatewaySessionState(
|
||||
room_name=ctx.room.name,
|
||||
agent_id=hermes_agent_id,
|
||||
session_mode=hermes_session_mode,
|
||||
)
|
||||
base_llm = HermesGatewayLLM(
|
||||
url=gateway_url,
|
||||
token=hermes_token,
|
||||
state=hermes_state,
|
||||
agent_id=hermes_agent_id,
|
||||
model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"),
|
||||
request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0),
|
||||
)
|
||||
text_llm = base_llm
|
||||
vision_llm = base_llm
|
||||
logger.info(
|
||||
"Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s",
|
||||
gateway_url,
|
||||
hermes_agent_id or "default",
|
||||
hermes_state.session_key,
|
||||
)
|
||||
else:
|
||||
openai_client = OpenAIAsyncClient(
|
||||
api_key=LLM_API_KEY,
|
||||
http_client=http_client,
|
||||
)
|
||||
import httpx
|
||||
from openai import AsyncClient as OpenAIAsyncClient
|
||||
|
||||
base_llm = openai.LLM(
|
||||
model=LLM_MODEL,
|
||||
client=openai_client,
|
||||
)
|
||||
text_llm = (
|
||||
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
|
||||
if TEXT_LLM_MODEL != LLM_MODEL
|
||||
else base_llm
|
||||
)
|
||||
vision_llm = (
|
||||
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
|
||||
if VISION_LLM_MODEL != LLM_MODEL
|
||||
else base_llm
|
||||
)
|
||||
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
|
||||
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
|
||||
|
||||
if LLM_BASE_URL:
|
||||
openai_client = OpenAIAsyncClient(
|
||||
api_key=LLM_API_KEY,
|
||||
base_url=LLM_BASE_URL,
|
||||
http_client=http_client,
|
||||
)
|
||||
else:
|
||||
openai_client = OpenAIAsyncClient(
|
||||
api_key=LLM_API_KEY,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
base_llm = openai.LLM(
|
||||
model=LLM_MODEL,
|
||||
client=openai_client,
|
||||
)
|
||||
text_llm = (
|
||||
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
|
||||
if TEXT_LLM_MODEL != LLM_MODEL
|
||||
else base_llm
|
||||
)
|
||||
vision_llm = (
|
||||
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
|
||||
if VISION_LLM_MODEL != LLM_MODEL
|
||||
else base_llm
|
||||
)
|
||||
vision_store = VisionFrameStore(
|
||||
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
|
||||
)
|
||||
@ -707,7 +753,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
session: AgentSession = AgentSession(
|
||||
# 1. Custom ASR blackbox with StreamAdapter
|
||||
stt=stt_stream,
|
||||
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
|
||||
# 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway.
|
||||
llm=base_llm,
|
||||
# 3. TTS blackbox
|
||||
tts=BlackboxTTS(
|
||||
|
||||
391
hermes_gateway.py
Normal file
391
hermes_gateway.py
Normal file
@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from livekit.agents import llm
|
||||
from livekit.agents._exceptions import APIConnectionError
|
||||
from livekit.agents.types import (
|
||||
DEFAULT_API_CONNECT_OPTIONS,
|
||||
NOT_GIVEN,
|
||||
APIConnectOptions,
|
||||
NotGivenOr,
|
||||
)
|
||||
from livekit.agents.utils import shortuuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewaySessionState:
|
||||
room_name: str
|
||||
agent_id: str | None = None
|
||||
session_key: str | None = None
|
||||
session_mode: str = "per_room"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.session_mode != "per_room":
|
||||
raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room")
|
||||
if self.session_key is None:
|
||||
suffix = self.agent_id or "default"
|
||||
self.session_key = f"livekit:{self.room_name}:{suffix}"
|
||||
|
||||
|
||||
class HermesGatewayLLM(llm.LLM):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
token: str | None,
|
||||
state: GatewaySessionState,
|
||||
agent_id: str | None = None,
|
||||
model_name: str = "hermes-agent",
|
||||
request_timeout: float = 30.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._url = url
|
||||
self._token = token
|
||||
self._state = state
|
||||
self._agent_id = agent_id
|
||||
self._model_name = model_name
|
||||
self._request_timeout = request_timeout
|
||||
self._http_session: aiohttp.ClientSession | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return "hermes-gateway"
|
||||
|
||||
def chat(
|
||||
self,
|
||||
*,
|
||||
chat_ctx: llm.ChatContext,
|
||||
tools: list[llm.Tool] | None = None,
|
||||
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
|
||||
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
|
||||
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
||||
) -> llm.LLMStream:
|
||||
return HermesGatewayLLMStream(
|
||||
self,
|
||||
chat_ctx=chat_ctx,
|
||||
tools=tools or [],
|
||||
conn_options=conn_options,
|
||||
)
|
||||
|
||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
timeout = aiohttp.ClientTimeout(total=self._request_timeout)
|
||||
self._http_session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._http_session
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
|
||||
class HermesGatewayLLMStream(llm.LLMStream):
|
||||
def __init__(
|
||||
self,
|
||||
llm: HermesGatewayLLM,
|
||||
*,
|
||||
chat_ctx: llm.ChatContext,
|
||||
tools: list[llm.Tool],
|
||||
conn_options: APIConnectOptions,
|
||||
) -> None:
|
||||
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||
self._llm = llm
|
||||
|
||||
async def _run(self) -> None:
|
||||
request_id = shortuuid("gwreq_")
|
||||
async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws:
|
||||
await _connect(ws, token=self._llm._token)
|
||||
await _send_rpc(
|
||||
ws,
|
||||
method="sessions.create",
|
||||
params={
|
||||
"key": self._llm._state.session_key,
|
||||
"sessionKey": self._llm._state.session_key,
|
||||
"agentId": self._llm._agent_id,
|
||||
"metadata": {
|
||||
"source": "livekit",
|
||||
"room": self._llm._state.room_name,
|
||||
},
|
||||
"idempotencyKey": self._llm._state.session_key,
|
||||
},
|
||||
request_id=shortuuid("gwcreate_"),
|
||||
)
|
||||
await _send_rpc(
|
||||
ws,
|
||||
method="sessions.send",
|
||||
params={
|
||||
"key": self._llm._state.session_key,
|
||||
"sessionKey": self._llm._state.session_key,
|
||||
"agentId": self._llm._agent_id,
|
||||
"messages": chat_context_to_gateway_messages(self.chat_ctx),
|
||||
"stream": True,
|
||||
"idempotencyKey": request_id,
|
||||
},
|
||||
request_id=request_id,
|
||||
wait_response=False,
|
||||
)
|
||||
|
||||
streamed_text = ""
|
||||
async for frame in _iter_gateway_frames(ws):
|
||||
if is_error_response(frame, request_id=request_id):
|
||||
raise APIConnectionError(_gateway_error_message(frame), retryable=False)
|
||||
text = extract_text_delta(frame)
|
||||
if text:
|
||||
if text.startswith(streamed_text):
|
||||
text = text[len(streamed_text) :]
|
||||
if text:
|
||||
streamed_text += text
|
||||
self._event_ch.send_nowait(
|
||||
llm.ChatChunk(
|
||||
id=request_id,
|
||||
delta=llm.ChoiceDelta(role="assistant", content=text),
|
||||
)
|
||||
)
|
||||
if is_terminal_event(frame, request_id=request_id):
|
||||
return
|
||||
|
||||
|
||||
def build_connect_params(*, token: str | None) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {
|
||||
"minProtocol": 3,
|
||||
"maxProtocol": 4,
|
||||
"client": {
|
||||
"id": "gateway-client",
|
||||
"version": "livekit-custom-agent",
|
||||
"platform": "python",
|
||||
"mode": "backend",
|
||||
},
|
||||
"role": "operator",
|
||||
"scopes": ["operator.read", "operator.write"],
|
||||
"caps": [],
|
||||
"commands": [],
|
||||
"permissions": {},
|
||||
"locale": "zh-CN",
|
||||
"userAgent": "livekit-custom-agent",
|
||||
}
|
||||
if token:
|
||||
params["auth"] = {"token": token}
|
||||
return params
|
||||
|
||||
|
||||
def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
for message in chat_ctx.messages():
|
||||
content = _message_content_to_gateway_content(message.content)
|
||||
if content is None:
|
||||
continue
|
||||
messages.append({"role": message.role, "content": content})
|
||||
return messages
|
||||
|
||||
|
||||
def extract_text_delta(frame: dict[str, Any]) -> str:
|
||||
payload = frame.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
payload = frame
|
||||
|
||||
for path in (
|
||||
("delta", "content"),
|
||||
("delta", "text"),
|
||||
("message", "delta", "content"),
|
||||
("message", "delta", "text"),
|
||||
("message", "content"),
|
||||
("content",),
|
||||
("text",),
|
||||
):
|
||||
value = _get_nested(payload, path)
|
||||
text = _content_to_text(value)
|
||||
if text:
|
||||
return text
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool:
|
||||
if frame.get("type") == "res" and frame.get("id") == request_id:
|
||||
return True
|
||||
|
||||
event = frame.get("event")
|
||||
if event in {
|
||||
"agent.done",
|
||||
"agent.completed",
|
||||
"agent.error",
|
||||
"session.message.completed",
|
||||
"session.run.completed",
|
||||
"sessions.run.completed",
|
||||
"run.completed",
|
||||
"run.failed",
|
||||
}:
|
||||
return True
|
||||
|
||||
payload = frame.get("payload")
|
||||
if isinstance(payload, dict) and payload.get("done") is True:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool:
|
||||
if (
|
||||
frame.get("type") == "res"
|
||||
and frame.get("id") == request_id
|
||||
and frame.get("ok") is False
|
||||
):
|
||||
return True
|
||||
|
||||
return frame.get("event") in {
|
||||
"agent.error",
|
||||
"session.error",
|
||||
"session.run.failed",
|
||||
"sessions.run.failed",
|
||||
"run.failed",
|
||||
}
|
||||
|
||||
|
||||
def _gateway_error_message(frame: dict[str, Any]) -> str:
|
||||
error = frame.get("error")
|
||||
if isinstance(error, str):
|
||||
return f"OpenClaw gateway request failed: {error}"
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message") or error.get("error")
|
||||
if isinstance(message, str):
|
||||
return f"OpenClaw gateway request failed: {message}"
|
||||
|
||||
payload = frame.get("payload")
|
||||
if isinstance(payload, dict):
|
||||
message = payload.get("message") or payload.get("error")
|
||||
if isinstance(message, str):
|
||||
return f"OpenClaw gateway request failed: {message}"
|
||||
|
||||
return f"OpenClaw gateway request failed: {frame!r}"
|
||||
|
||||
|
||||
async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None:
|
||||
first = await _receive_json(ws)
|
||||
if first.get("event") != "connect.challenge":
|
||||
raise RuntimeError(f"expected connect.challenge, received {first!r}")
|
||||
|
||||
request_id = shortuuid("gwconnect_")
|
||||
await _send_rpc(
|
||||
ws,
|
||||
method="connect",
|
||||
params=build_connect_params(token=token),
|
||||
request_id=request_id,
|
||||
wait_response=False,
|
||||
)
|
||||
response = await _wait_for_response(ws, request_id=request_id)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}")
|
||||
|
||||
|
||||
async def _send_rpc(
|
||||
ws: aiohttp.ClientWebSocketResponse,
|
||||
*,
|
||||
method: str,
|
||||
params: dict[str, Any],
|
||||
request_id: str,
|
||||
wait_response: bool = True,
|
||||
) -> dict[str, Any] | None:
|
||||
await ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "req",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": _drop_none(params),
|
||||
}
|
||||
)
|
||||
)
|
||||
if not wait_response:
|
||||
return None
|
||||
response = await _wait_for_response(ws, request_id=request_id)
|
||||
if not response.get("ok", False):
|
||||
raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}")
|
||||
return response
|
||||
|
||||
|
||||
async def _wait_for_response(
|
||||
ws: aiohttp.ClientWebSocketResponse, *, request_id: str
|
||||
) -> dict[str, Any]:
|
||||
async for frame in _iter_gateway_frames(ws):
|
||||
if frame.get("type") == "res" and frame.get("id") == request_id:
|
||||
return frame
|
||||
raise RuntimeError(f"OpenClaw gateway closed before response {request_id}")
|
||||
|
||||
|
||||
async def _iter_gateway_frames(
|
||||
ws: aiohttp.ClientWebSocketResponse,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for message in ws:
|
||||
if message.type == aiohttp.WSMsgType.TEXT:
|
||||
yield json.loads(message.data)
|
||||
elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
|
||||
return
|
||||
elif message.type == aiohttp.WSMsgType.ERROR:
|
||||
raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}")
|
||||
|
||||
|
||||
async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]:
|
||||
message = await ws.receive()
|
||||
if message.type != aiohttp.WSMsgType.TEXT:
|
||||
raise RuntimeError(f"expected gateway text frame, received {message.type!r}")
|
||||
return json.loads(message.data)
|
||||
|
||||
|
||||
def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any:
|
||||
parts: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
if item:
|
||||
parts.append({"type": "text", "text": item})
|
||||
elif isinstance(item, llm.ImageContent) and isinstance(item.image, str):
|
||||
parts.append({"type": "image_url", "image_url": {"url": item.image}})
|
||||
|
||||
if not parts:
|
||||
return None
|
||||
if len(parts) == 1 and parts[0]["type"] == "text":
|
||||
return parts[0]["text"]
|
||||
return parts
|
||||
|
||||
|
||||
def _content_to_text(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
text_parts: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
text_parts.append(text)
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any:
|
||||
current: Any = data
|
||||
for key in path:
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(key)
|
||||
return current
|
||||
|
||||
|
||||
def _drop_none(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {key: _drop_none(item) for key, item in value.items() if item is not None}
|
||||
if isinstance(value, list):
|
||||
return [_drop_none(item) for item in value]
|
||||
return value
|
||||
262
test_hermes_gateway.py
Normal file
262
test_hermes_gateway.py
Normal file
@ -0,0 +1,262 @@
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from custom.hermes_gateway import (
|
||||
GatewaySessionState,
|
||||
HermesGatewayLLM,
|
||||
build_connect_params,
|
||||
chat_context_to_gateway_messages,
|
||||
extract_text_delta,
|
||||
is_error_response,
|
||||
is_terminal_event,
|
||||
)
|
||||
from livekit.agents import ChatContext, llm
|
||||
from livekit.agents._exceptions import APIConnectionError
|
||||
|
||||
|
||||
def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None:
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="system", content="system prompt")
|
||||
ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")])
|
||||
|
||||
messages = chat_context_to_gateway_messages(ctx)
|
||||
|
||||
assert messages == [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look here"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None:
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}}
|
||||
)
|
||||
== "hi"
|
||||
)
|
||||
assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there"
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "session.message.delta",
|
||||
"payload": {"message": {"content": [{"type": "text", "text": "!"}]}},
|
||||
}
|
||||
)
|
||||
== "!"
|
||||
)
|
||||
|
||||
|
||||
def test_per_room_session_state_reuses_stable_session_key() -> None:
|
||||
state = GatewaySessionState(room_name="kitchen-room", agent_id="helper")
|
||||
|
||||
assert state.session_key == "livekit:kitchen-room:helper"
|
||||
state.session_key = "gateway-session-123"
|
||||
assert state.session_key == "gateway-session-123"
|
||||
|
||||
|
||||
def test_build_connect_params_uses_backend_operator_defaults() -> None:
|
||||
params = build_connect_params(token="secret-token")
|
||||
|
||||
assert params["client"] == {
|
||||
"id": "gateway-client",
|
||||
"version": "livekit-custom-agent",
|
||||
"platform": "python",
|
||||
"mode": "backend",
|
||||
}
|
||||
assert params["role"] == "operator"
|
||||
assert params["scopes"] == ["operator.read", "operator.write"]
|
||||
assert params["auth"] == {"token": "secret-token"}
|
||||
assert "device" not in params
|
||||
|
||||
|
||||
def test_gateway_response_helpers_match_only_current_send_request() -> None:
|
||||
assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1")
|
||||
assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1")
|
||||
assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1")
|
||||
assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1")
|
||||
|
||||
|
||||
def test_hermes_llm_reports_provider_and_model() -> None:
|
||||
state = GatewaySessionState(room_name="kitchen", agent_id="helper")
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url="ws://gateway.test/ws",
|
||||
token="token",
|
||||
state=state,
|
||||
agent_id="helper",
|
||||
model_name="hermes-agent",
|
||||
)
|
||||
|
||||
assert gateway_llm.provider == "hermes-gateway"
|
||||
assert gateway_llm.model == "hermes-agent"
|
||||
|
||||
|
||||
def test_gateway_session_state_rejects_non_per_room_mode() -> None:
|
||||
with pytest.raises(ValueError, match="per_room"):
|
||||
GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn")
|
||||
|
||||
|
||||
async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
payload = json.loads(message.data)
|
||||
received.append(payload)
|
||||
method = payload.get("method")
|
||||
request_id = payload.get("id")
|
||||
if method == "connect":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.create":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "res",
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"result": {"sessionKey": "livekit:kitchen:helper"},
|
||||
}
|
||||
)
|
||||
elif method == "sessions.send":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "agent",
|
||||
"payload": {"delta": {"content": "你好"}},
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "res",
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}},
|
||||
}
|
||||
)
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||
token="secret-token",
|
||||
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||
agent_id="helper",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="杯子在哪里")
|
||||
|
||||
try:
|
||||
collected = await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await gateway_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert collected.text == "你好"
|
||||
assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"]
|
||||
send_request = received[2]
|
||||
assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper"
|
||||
assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}]
|
||||
|
||||
|
||||
def test_extract_text_delta_reads_final_message_content() -> None:
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "session.message.completed",
|
||||
"payload": {
|
||||
"message": {
|
||||
"content": [
|
||||
{"type": "text", "text": "完整回复"},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
== "完整回复"
|
||||
)
|
||||
|
||||
|
||||
def test_is_error_response_accepts_error_events() -> None:
|
||||
assert is_error_response(
|
||||
{"type": "event", "event": "agent.error", "payload": {"error": "boom"}},
|
||||
request_id="send-1",
|
||||
)
|
||||
assert is_error_response(
|
||||
{"type": "event", "event": "run.failed", "payload": {"message": "boom"}},
|
||||
request_id="send-1",
|
||||
)
|
||||
|
||||
|
||||
async def test_llm_stream_maps_gateway_error_events_to_api_connection_error(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
payload = json.loads(message.data)
|
||||
method = payload.get("method")
|
||||
request_id = payload.get("id")
|
||||
if method == "connect":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.create":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.send":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "run.failed",
|
||||
"payload": {"message": "gateway exploded"},
|
||||
}
|
||||
)
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||
token=None,
|
||||
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||
agent_id="helper",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="hello")
|
||||
|
||||
try:
|
||||
with pytest.raises(APIConnectionError, match="gateway exploded"):
|
||||
await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await gateway_llm.aclose()
|
||||
await runner.cleanup()
|
||||
Reference in New Issue
Block a user