feat: add hermes gateway llm adapter
This commit is contained in:
262
test_hermes_gateway.py
Normal file
262
test_hermes_gateway.py
Normal file
@ -0,0 +1,262 @@
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from custom.hermes_gateway import (
|
||||
GatewaySessionState,
|
||||
HermesGatewayLLM,
|
||||
build_connect_params,
|
||||
chat_context_to_gateway_messages,
|
||||
extract_text_delta,
|
||||
is_error_response,
|
||||
is_terminal_event,
|
||||
)
|
||||
from livekit.agents import ChatContext, llm
|
||||
from livekit.agents._exceptions import APIConnectionError
|
||||
|
||||
|
||||
def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None:
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="system", content="system prompt")
|
||||
ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")])
|
||||
|
||||
messages = chat_context_to_gateway_messages(ctx)
|
||||
|
||||
assert messages == [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look here"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None:
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}}
|
||||
)
|
||||
== "hi"
|
||||
)
|
||||
assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there"
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "session.message.delta",
|
||||
"payload": {"message": {"content": [{"type": "text", "text": "!"}]}},
|
||||
}
|
||||
)
|
||||
== "!"
|
||||
)
|
||||
|
||||
|
||||
def test_per_room_session_state_reuses_stable_session_key() -> None:
|
||||
state = GatewaySessionState(room_name="kitchen-room", agent_id="helper")
|
||||
|
||||
assert state.session_key == "livekit:kitchen-room:helper"
|
||||
state.session_key = "gateway-session-123"
|
||||
assert state.session_key == "gateway-session-123"
|
||||
|
||||
|
||||
def test_build_connect_params_uses_backend_operator_defaults() -> None:
|
||||
params = build_connect_params(token="secret-token")
|
||||
|
||||
assert params["client"] == {
|
||||
"id": "gateway-client",
|
||||
"version": "livekit-custom-agent",
|
||||
"platform": "python",
|
||||
"mode": "backend",
|
||||
}
|
||||
assert params["role"] == "operator"
|
||||
assert params["scopes"] == ["operator.read", "operator.write"]
|
||||
assert params["auth"] == {"token": "secret-token"}
|
||||
assert "device" not in params
|
||||
|
||||
|
||||
def test_gateway_response_helpers_match_only_current_send_request() -> None:
|
||||
assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1")
|
||||
assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1")
|
||||
assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1")
|
||||
assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1")
|
||||
|
||||
|
||||
def test_hermes_llm_reports_provider_and_model() -> None:
|
||||
state = GatewaySessionState(room_name="kitchen", agent_id="helper")
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url="ws://gateway.test/ws",
|
||||
token="token",
|
||||
state=state,
|
||||
agent_id="helper",
|
||||
model_name="hermes-agent",
|
||||
)
|
||||
|
||||
assert gateway_llm.provider == "hermes-gateway"
|
||||
assert gateway_llm.model == "hermes-agent"
|
||||
|
||||
|
||||
def test_gateway_session_state_rejects_non_per_room_mode() -> None:
|
||||
with pytest.raises(ValueError, match="per_room"):
|
||||
GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn")
|
||||
|
||||
|
||||
async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
payload = json.loads(message.data)
|
||||
received.append(payload)
|
||||
method = payload.get("method")
|
||||
request_id = payload.get("id")
|
||||
if method == "connect":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.create":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "res",
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"result": {"sessionKey": "livekit:kitchen:helper"},
|
||||
}
|
||||
)
|
||||
elif method == "sessions.send":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "agent",
|
||||
"payload": {"delta": {"content": "你好"}},
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "res",
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}},
|
||||
}
|
||||
)
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||
token="secret-token",
|
||||
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||
agent_id="helper",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="杯子在哪里")
|
||||
|
||||
try:
|
||||
collected = await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await gateway_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert collected.text == "你好"
|
||||
assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"]
|
||||
send_request = received[2]
|
||||
assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper"
|
||||
assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}]
|
||||
|
||||
|
||||
def test_extract_text_delta_reads_final_message_content() -> None:
|
||||
assert (
|
||||
extract_text_delta(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "session.message.completed",
|
||||
"payload": {
|
||||
"message": {
|
||||
"content": [
|
||||
{"type": "text", "text": "完整回复"},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
== "完整回复"
|
||||
)
|
||||
|
||||
|
||||
def test_is_error_response_accepts_error_events() -> None:
|
||||
assert is_error_response(
|
||||
{"type": "event", "event": "agent.error", "payload": {"error": "boom"}},
|
||||
request_id="send-1",
|
||||
)
|
||||
assert is_error_response(
|
||||
{"type": "event", "event": "run.failed", "payload": {"message": "boom"}},
|
||||
request_id="send-1",
|
||||
)
|
||||
|
||||
|
||||
async def test_llm_stream_maps_gateway_error_events_to_api_connection_error(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
payload = json.loads(message.data)
|
||||
method = payload.get("method")
|
||||
request_id = payload.get("id")
|
||||
if method == "connect":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.create":
|
||||
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||
elif method == "sessions.send":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "event",
|
||||
"event": "run.failed",
|
||||
"payload": {"message": "gateway exploded"},
|
||||
}
|
||||
)
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
gateway_llm = HermesGatewayLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||
token=None,
|
||||
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||
agent_id="helper",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="hello")
|
||||
|
||||
try:
|
||||
with pytest.raises(APIConnectionError, match="gateway exploded"):
|
||||
await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await gateway_llm.aclose()
|
||||
await runner.cleanup()
|
||||
Reference in New Issue
Block a user