beaver #2

Open
verachen wants to merge 9 commits from beaver into main
11 changed files with 2558 additions and 69 deletions

View File

@ -2,7 +2,15 @@
LIVEKIT_URL=ws://localhost:7880
LIVEKIT_API_KEY=
LIVEKIT_API_SECRET=
CUSTOM_AGENT_NAME=my-agent
CUSTOM_AGENT_PROFILE=normal
# CUSTOM_AGENT_NAME=normal-agent
CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver
# Beaver terminal text WebSocket
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
TERMINAL_PEER_ID=device-001
TERMINAL_DEVICE_NAME=desk-terminal
# ASR blackbox
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
@ -12,25 +20,34 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
CUSTOM_ASR_MAX_SPEECH_DURATION=12
# Keep false if forced ASR turns should reply even while input audio continues.
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
# CUSTOM_LLM_PROVIDER=beaver
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
# OpenAI-compatible LLM
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
# CUSTOM_LLM_MODEL=Qwen3.6-35B
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_API_KEY=sk-
# CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Qwen-VL
CUSTOM_LLM_API_KEY=
CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Mistral-Medium-3.5-128B
CUSTOM_LLM_API_KEY=sk-
CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_SAVE_MODEL_IMAGES=false
CUSTOM_SAVE_MODEL_IMAGES=true
# CUSTOM_TEXT_LLM_MODEL=
# CUSTOM_VISION_LLM_MODEL=
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_API_KEY=sk-
# CUSTOM_LLM_VERIFY_SSL=false
@ -71,4 +88,4 @@ CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
CUSTOM_MEMORY_TIMEOUT=2
CUSTOM_MEMORY_MAX_CHARS=2000
CUSTOM_MEMORY_API_KEY=
CUSTOM_PREEMPTIVE_GENERATION=true
CUSTOM_PREEMPTIVE_GENERATION=false

260
asr.py
View File

@ -1,11 +1,14 @@
import asyncio
import logging
from typing import Any, Optional, Union
from collections import deque
from collections.abc import AsyncIterable
from typing import Any
import aiohttp
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
@ -17,9 +20,15 @@ from livekit.agents import (
utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType
logger = logging.getLogger("blackbox-asr")
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
STT_CAPABILITIES = stt.STTCapabilities
class BlackboxSTT(stt.STT):
def __init__(
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
url: str,
*,
model_name: str = "sensevoice",
language: Optional[str] = "auto",
language: str | None = "auto",
output_language: str = "zh",
hotwords: Optional[str] = None,
itn: Optional[Union[bool, str]] = None,
chunk_mode: Optional[Union[bool, str]] = None,
hotwords: str | None = None,
itn: bool | str | None = None,
chunk_mode: bool | str | None = None,
timeout: float = 30.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=stt.STTCapabilities(
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
def _form_value(value: bool | str) -> str:
if isinstance(value, bool):
return str(value).lower()
return value
class BoundedStreamAdapter(stt.STT):
    """Streaming facade over a non-streaming STT with a cap on segment length.

    Streams created by this adapter use the given VAD to cut audio into
    segments and hand each segment to the wrapped STT's ``recognize``.
    ``max_speech_duration`` bounds how long a single segment may grow before it
    is forced out; ``None`` (or any non-positive value) disables the bound.

    Args:
        stt: The wrapped recognizer whose ``recognize`` is called per segment.
        vad: Voice activity detector used to find segment boundaries.
        max_speech_duration: Maximum segment length in seconds, or ``None``.
        pre_speech_duration: Seconds of audio kept from before speech start.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        # ``stt`` the parameter shadows the ``stt`` module inside this body,
        # hence the module-level STT_CAPABILITIES alias.
        super().__init__(
            capabilities=STT_CAPABILITIES(
                streaming=True,
                interim_results=False,
                diarization=False,
            )
        )
        self._vad = vad
        self._stt = stt
        # A zero or negative limit would satisfy "duration >= limit" on every
        # frame and force an endless run of one-frame segments; treat it as
        # "no limit" instead.
        if max_speech_duration is not None and max_speech_duration <= 0:
            max_speech_duration = None
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration
        # Re-emit the wrapped STT's metrics from this adapter.
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The underlying non-streaming recognizer."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name reported by the wrapped STT."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider name reported by the wrapped STT."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition to the wrapped STT."""
        return await self._stt.recognize(
            buffer=buffer, language=language, conn_options=conn_options
        )

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a VAD-segmented stream bound to this adapter's settings."""
        return _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Forward a ``metrics_collected`` event from the wrapped STT."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Detach the metrics listener registered in ``__init__``."""
        self._stt.off("metrics_collected", self._on_metrics_collected)
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Recognize stream that segments input with a VAD and bounds segment length.

    Audio frames are pushed to a VAD stream; frames received while speech is
    active are buffered and sent to the wrapped STT's ``recognize`` when the VAD
    reports end of speech, when the buffered duration reaches
    ``max_speech_duration``, or when the input ends.

    NOTE(review): this block was recovered from a diff view with indentation
    stripped; nesting was reconstructed from syntax and should be confirmed
    against the real file.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        # The stream itself runs with the zero-retry options; the caller's
        # conn_options are kept for the per-segment recognize() calls.
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        # Drain the event iterator without recording anything here.
        # Presumably metrics come from the wrapped STT via the adapter's
        # "metrics_collected" forwarding — confirm.
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Pump input frames through the VAD and recognize each speech segment."""
        vad_stream = self._vad.stream()
        # Guards the segment/pre-roll state shared by the two tasks below.
        lock = asyncio.Lock()
        # Segments awaiting recognition; None tells the worker to stop.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        # Most recent non-speech frames, prepended to the next segment.
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            # Duration in seconds of one frame.
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            # Append to the pre-roll and drop the oldest frames until the
            # buffered duration is no longer above pre_speech_duration.
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)
            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            # Emit END_OF_SPEECH and queue the segment for recognition; empty
            # segments are ignored.
            if not frames:
                return
            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )
            # NOTE(review): assumed to run for every non-empty segment, not
            # only forced ones — confirm against the original indentation.
            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            # Recognize queued segments one at a time, in order.
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return
                merged_frames = utils.merge_frames(frames)
                try:
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # A failed segment is logged and dropped so later segments
                    # can still be recognized.
                    logger.exception("ASR segment recognition failed")
                    continue
                # Skip results with no alternatives or empty text.
                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            # Feed input frames to the VAD while buffering them locally.
            nonlocal segment_duration, segment_frames, speech_active
            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue
                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []
                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Cut the segment here; speech_active stays set so
                            # following frames start a new segment.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)
                # Enqueue outside the lock.
                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)
            vad_stream.end_input()
            # Input ended: flush whatever speech is still buffered.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                speech_active = False
            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            # Translate VAD start/end events into segment boundaries.
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            speech_active = True
                            # Seed the new segment with the buffered pre-roll.
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue
                if event.type != VADEventType.END_OF_SPEECH:
                    continue
                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Both producers are done: let the worker drain the queue and stop.
            await recognize_queue.put(None)
            await worker_task
        finally:
            await utils.aio.cancel_and_wait(*tasks, worker_task)
            await vad_stream.aclose()

128
beaver_llm.py Normal file
View File

@ -0,0 +1,128 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Sequence
from typing import Any
try:
from beaver_terminal_client import BeaverTerminalClient
except ModuleNotFoundError:
from custom.beaver_terminal_client import BeaverTerminalClient
from livekit.agents import llm
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectOptions,
NotGivenOr,
)
from livekit.agents.utils import shortuuid
logger = logging.getLogger("beaver-llm")
def latest_user_text(chat_ctx: llm.ChatContext) -> str:
    """Return the text of the most recent user message in *chat_ctx*.

    Messages are scanned newest-first; the first one whose role is ``"user"``
    is converted with ``_content_to_text``. An empty string is returned when
    there is no user message.
    """
    newest_user = next(
        (entry for entry in reversed(chat_ctx.messages()) if entry.role == "user"),
        None,
    )
    if newest_user is None:
        return ""
    return _content_to_text(newest_user.content)
def _content_to_text(content: Sequence[llm.ChatContent]) -> str:
text_parts = [item for item in content if isinstance(item, str)]
return "\n".join(text_parts)
class BeaverLLM(llm.LLM):
    """LLM backend that answers by sending text to a Beaver terminal WebSocket.

    One ``BeaverTerminalClient`` is owned by the instance. ``_lock`` is held
    around the connect and send calls made through this class and its streams.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        model_name: str = "beaver-terminal",
    ) -> None:
        super().__init__()
        self._model_name = model_name
        self._lock = asyncio.Lock()
        self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name)

    @property
    def model(self) -> str:
        """Model name supplied at construction."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier."""
        return "beaver"

    @property
    def session_id(self) -> str | None:
        """Session id currently held by the underlying client, if any."""
        return self._client.session_id

    async def connect(self, *, warmup_text: str | None = None) -> str | None:
        """Connect the client and optionally send one warmup message.

        Args:
            warmup_text: Text to send right after connecting. Skipped when
                ``None`` or blank; otherwise sent stripped.

        Returns:
            The reply to the warmup message, or ``None`` when none was sent.
        """
        prompt = (warmup_text or "").strip()
        reply: str | None = None
        async with self._lock:
            await self._client.connect()
            if prompt:
                reply = await self._client.send_text(prompt)

        if reply is not None:
            logger.info(
                "Beaver handshake warmup completed session_id=%s reply_len=%s",
                self.session_id,
                len(reply),
            )
        else:
            logger.info("Beaver handshake completed session_id=%s", self.session_id)
        return reply

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> llm.LLMStream:
        """Start a Beaver turn for *chat_ctx*.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are
        accepted but not used by this backend.
        """
        stream_tools = tools or []
        return BeaverLLMStream(
            self,
            chat_ctx=chat_ctx,
            tools=stream_tools,
            conn_options=conn_options,
        )

    async def aclose(self) -> None:
        """Close the underlying Beaver client."""
        await self._client.close()
class BeaverLLMStream(llm.LLMStream):
    """Single-turn stream: sends the latest user text to Beaver and emits the reply."""

    def __init__(
        self,
        beaver_llm: BeaverLLM,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        self._request_id = shortuuid("beaver_")
        self._beaver_llm = beaver_llm

    async def _run(self) -> None:
        """Send the newest user message and emit the reply as one chunk."""
        owner = self._beaver_llm
        prompt = latest_user_text(self.chat_ctx)
        # Hold the owner's lock for the duration of the send.
        async with owner._lock:
            answer = await owner._client.send_text(prompt)
        if not answer:
            # Empty replies produce no chunk.
            return
        self._send_text_chunk(answer)

    def _send_text_chunk(self, text: str) -> None:
        """Emit *text* as an assistant chunk tagged with this stream's request id."""
        delta = llm.ChoiceDelta(role="assistant", content=text)
        self._event_ch.send_nowait(llm.ChatChunk(id=self._request_id, delta=delta))

238
beaver_terminal_client.py Normal file
View File

@ -0,0 +1,238 @@
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import aiohttp
from dotenv import load_dotenv
logger = logging.getLogger("beaver-terminal-client")
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
class BeaverTerminalError(RuntimeError):
    """Base error raised by the Beaver terminal client."""
class BeaverTerminalConnectionClosed(BeaverTerminalError):
    """Raised when the Beaver websocket is closed, errors out, or cannot be opened."""
@dataclass
class MessageIdGenerator:
    """Produce sequential message ids of the form ``peer[-nonce]-NNNNNN``."""

    peer_id: str
    nonce: str | None = None
    initial_counter: int = 0

    def __post_init__(self) -> None:
        # Running counter; the first id issued is initial_counter + 1.
        self.counter = self.initial_counter

    def next_id(self) -> str:
        """Advance the counter and return the next id (counter zero-padded to 6 digits)."""
        self.counter += 1
        segments = [self.peer_id]
        if self.nonce:
            segments.append(self.nonce)
        segments.append(f"{self.counter:06d}")
        return "-".join(segments)
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
    """Build the ``connect`` frame announcing this peer as a text-capable device."""
    frame: dict[str, Any] = dict(
        type="connect",
        peer_id=peer_id,
        device_name=device_name,
        capabilities=["text"],
    )
    return frame
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
    """Build a ``message`` frame carrying *text* under *message_id*."""
    frame: dict[str, Any] = dict(type="message", message_id=message_id, text=text)
    return frame
class BeaverTerminalClient:
    """Text-only client for a Beaver terminal WebSocket channel.

    The client sends a ``connect`` frame after opening the socket, expects a
    ``connected`` frame back, and then exchanges ``message`` frames. It is not
    internally synchronised; callers are expected to serialise calls.

    Args:
        url: WebSocket URL of the Beaver channel.
        peer_id: Peer identifier sent in the connect frame.
        device_name: Device name sent in the connect frame.
        http_session: Optional shared aiohttp session. When omitted, the
            client creates and closes its own session.
        message_ids: Optional id generator; defaults to one keyed on *peer_id*.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        http_session: aiohttp.ClientSession | None = None,
        message_ids: MessageIdGenerator | None = None,
    ) -> None:
        self._url = url
        self._peer_id = peer_id
        self._device_name = device_name
        # Only a session created by this client is closed by it.
        self._owned_session = http_session is None
        self._http_session = http_session
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
        self.session_id: str | None = None

    async def connect(self) -> None:
        """Open a fresh websocket and perform the connect handshake.

        Any existing websocket is closed first.

        Raises:
            BeaverTerminalConnectionClosed: The socket could not be opened or
                closed during the handshake.
            BeaverTerminalError: The server's first frame was not ``connected``
                or was not a valid JSON object.
        """
        await self._close_websocket()
        session = self._ensure_http_session()
        try:
            self._ws = await session.ws_connect(self._url)
            await self._send_json(
                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
            )
            frame = await self._receive_json()
            if frame.get("type") != "connected":
                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
            session_id = frame.get("session_id")
            self.session_id = session_id if isinstance(session_id, str) else None
        except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
            await self._cleanup_failed_connection()
            raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
        except Exception:
            await self._cleanup_failed_connection()
            raise

    async def send_text(self, text: str) -> str:
        """Send *text* as a user message and return the assistant reply.

        If the connection drops mid-turn, the client reconnects once and
        resends the text under a new message id.

        Raises:
            BeaverTerminalConnectionClosed: The connection dropped on both
                attempts, or reconnecting failed.
            BeaverTerminalError: The server reported an error or sent an
                unexpected frame.
        """
        for attempt in range(2):
            if not self._websocket_is_open():
                await self.connect()

            message_id = self._message_ids.next_id()
            message_frame = build_message_frame(message_id=message_id, text=text)
            try:
                await self._send_json(message_frame)
                return await self._wait_for_reply(message_id)
            except (
                aiohttp.ClientConnectionError,
                # Older aiohttp releases raise the builtin ConnectionResetError
                # when writing to a closing transport; treat it the same way.
                ConnectionError,
                BeaverTerminalConnectionClosed,
            ) as exc:
                if attempt == 1:
                    raise BeaverTerminalConnectionClosed(
                        "Beaver websocket closed before assistant reply"
                    ) from exc
                logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
                await self.connect()
        raise BeaverTerminalError("unreachable Beaver send state")

    async def _wait_for_reply(self, message_id: str) -> str:
        """Read frames until the reply for *message_id* arrives.

        An ``ack`` frame with a string ``reply`` or an assistant ``message``
        frame for the same id completes the turn. Unrelated frames are skipped.

        Raises:
            BeaverTerminalError: An ``error`` frame arrived, or the assistant
                message carried ``finish_reason == "error"``.
        """
        while True:
            frame = await self._receive_json()
            frame_type = frame.get("type")
            if frame_type == "ack" and frame.get("message_id") == message_id:
                reply = frame.get("reply")
                if isinstance(reply, str):
                    return reply
                # An ack without a reply only confirms receipt; keep waiting.
                continue
            if (
                frame_type == "message"
                and frame.get("role") == "assistant"
                and frame.get("message_id") == message_id
            ):
                text = frame.get("text")
                if frame.get("finish_reason") == "error":
                    raise BeaverTerminalError(
                        text if isinstance(text, str) else "assistant turn failed"
                    )
                return text if isinstance(text, str) else ""
            if frame_type == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def ping(self) -> bool:
        """Send a ``ping`` frame and wait for ``pong``.

        Frames other than ``pong`` and ``error`` received while waiting are
        discarded.

        Raises:
            BeaverTerminalError: The server answered with an ``error`` frame.
        """
        await self._send_json({"type": "ping"})
        while True:
            frame = await self._receive_json()
            if frame.get("type") == "pong":
                return True
            if frame.get("type") == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def close(self) -> None:
        """Close the websocket and, if owned, the HTTP session."""
        await self._close_websocket()
        await self._close_owned_session()

    async def _close_websocket(self) -> None:
        """Close and forget the current websocket, if any."""
        if self._ws is not None:
            await self._ws.close()
            self._ws = None

    async def _close_owned_session(self) -> None:
        """Close the HTTP session when this client created it."""
        if self._owned_session and self._http_session is not None:
            await self._http_session.close()
            self._http_session = None

    async def _cleanup_failed_connection(self) -> None:
        """Drop all connection state after a failed connect attempt."""
        await self._close_websocket()
        self.session_id = None
        await self._close_owned_session()

    def _websocket_is_open(self) -> bool:
        """Return True when a websocket exists and is not closed."""
        return self._ws is not None and not self._ws.closed

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating one when none exists."""
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _send_json(self, frame: dict[str, Any]) -> None:
        """Send *frame* as JSON.

        Raises:
            BeaverTerminalError: No websocket is connected.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        await self._ws.send_json(frame)

    async def _receive_json(self) -> dict[str, Any]:
        """Receive one frame and decode it as a JSON object.

        Raises:
            BeaverTerminalConnectionClosed: The socket closed or errored.
            BeaverTerminalError: No websocket is connected, or the frame was
                not text, not valid JSON, or not a JSON object.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")

        message = await self._ws.receive()
        if message.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSED,
            aiohttp.WSMsgType.CLOSING,
        ):
            raise BeaverTerminalConnectionClosed("Beaver websocket closed")
        if message.type == aiohttp.WSMsgType.ERROR:
            raise BeaverTerminalConnectionClosed(
                f"Beaver websocket error: {self._ws.exception()!r}"
            )
        if message.type != aiohttp.WSMsgType.TEXT:
            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")

        try:
            data = message.json()
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError; surface malformed frames as
            # a client error so callers catching BeaverTerminalError see them.
            raise BeaverTerminalError(
                f"expected Beaver JSON frame, received {message.data!r}"
            ) from exc
        if not isinstance(data, dict):
            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
        return data
def client_from_env() -> BeaverTerminalClient:
    """Build a client from ``BEAVER_WS_URL``, ``TERMINAL_PEER_ID`` and ``TERMINAL_DEVICE_NAME``.

    The ``.env`` file next to this module is loaded first; unset variables fall
    back to the module-level defaults.
    """
    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
    ws_url = os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL)
    peer = os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID)
    device = os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME)
    return BeaverTerminalClient(url=ws_url, peer_id=peer, device_name=device)
async def run_console() -> None:
    """Run an interactive console against the Beaver channel configured in the environment.

    Each non-empty input line is sent as a message and the reply is printed.
    ``quit``, ``exit`` or end of input (Ctrl-D / closed stdin) ends the session.
    Failed turns are logged and the loop continues. The client is always
    closed on exit.
    """
    logging.basicConfig(level=logging.INFO)
    client = client_from_env()
    try:
        await client.connect()
        logger.info("Connected to Beaver session_id=%s", client.session_id)
        while True:
            try:
                # input() blocks, so run it off the event loop thread.
                text = await asyncio.to_thread(input, "> ")
            except EOFError:
                # stdin was closed; treat it like an explicit quit instead of
                # crashing with a traceback.
                return
            text = text.strip()
            if not text:
                continue
            if text in {"quit", "exit"}:
                return
            try:
                reply = await client.send_text(text)
            except BeaverTerminalError as exc:
                logger.error("Beaver turn failed: %s", exc)
                continue
            print(reply)
    finally:
        await client.close()
# Running this module directly starts the interactive console.
if __name__ == "__main__":
    asyncio.run(run_console())

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import json
import logging
@ -8,11 +9,14 @@ from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path
from beaver_llm import BeaverLLM
from beaver_terminal_client import BeaverTerminalError
from dotenv import load_dotenv
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from asr import BlackboxSTT, BoundedStreamAdapter
from livekit.agents import (
Agent,
AgentServer,
@ -30,7 +34,6 @@ from livekit.agents import (
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
@ -40,7 +43,6 @@ logger = logging.getLogger("custom-agent")
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
@ -60,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
@ -69,18 +71,121 @@ GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
DEFAULT_AGENT_PROFILE = "normal"
@dataclass(frozen=True)
class AgentProfile:
    """Immutable bundle of per-profile defaults for the agent."""

    # Default agent name, used when CUSTOM_AGENT_NAME is unset or empty.
    agent_name: str
    # Default LLM provider, used when CUSTOM_LLM_PROVIDER is unset.
    llm_provider: str
    # Default input mode, used when CUSTOM_AGENT_INPUT_MODE is unset.
    input_mode: str
# Built-in profiles keyed by canonical profile name.
AGENT_PROFILES = {
    "normal": AgentProfile(
        agent_name="normal-agent",
        llm_provider="openai-compatible",
        input_mode=VOICE_INPUT_MODE,
    ),
    "beaver": AgentProfile(
        agent_name="beaver-agent",
        llm_provider="beaver",
        input_mode=VOICE_INPUT_MODE,
    ),
    "vision-normal": AgentProfile(
        agent_name="vision-normal-agent",
        llm_provider="openai-compatible",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
    "vision-beaver": AgentProfile(
        agent_name="vision-beaver-agent",
        llm_provider="beaver",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
}
# Alternative spellings mapped to canonical profile names. Lookups use the
# lowercased value with underscores replaced by hyphens.
AGENT_PROFILE_ALIASES = {
    "default": "normal",
    "openai": "normal",
    "openai-compatible": "normal",
    "llm": "normal",
    "text": "normal",
    "voice": "normal",
    "vision": "vision-normal",
    "vision-llm": "vision-normal",
    "vision-openai": "vision-normal",
    "vision-openai-compatible": "vision-normal",
}
def _normalize_agent_profile(value: str | None) -> str:
    """Map a configured profile string to a canonical profile name.

    The value is stripped, lowercased, has underscores replaced by hyphens and
    is then resolved through the alias table. Missing, blank, or unknown
    values yield the default profile; unknown values also log a warning.
    """
    cleaned = (value or "").strip()
    if not cleaned:
        return DEFAULT_AGENT_PROFILE

    key = cleaned.lower().replace("_", "-")
    candidate = AGENT_PROFILE_ALIASES.get(key, key)
    if candidate not in AGENT_PROFILES:
        logger.warning(
            "Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
            value,
            DEFAULT_AGENT_PROFILE,
        )
        return DEFAULT_AGENT_PROFILE
    return candidate
def _agent_profile_from_name(agent_name: str | None) -> str | None:
    """Return the profile whose default agent name equals *agent_name*.

    The name is compared after stripping, lowercasing and replacing
    underscores with hyphens. Returns ``None`` for missing, blank, or
    unmatched names.
    """
    wanted = (agent_name or "").strip().lower().replace("_", "-")
    if not wanted:
        return None
    return next(
        (key for key, candidate in AGENT_PROFILES.items() if candidate.agent_name == wanted),
        None,
    )
def _selected_agent_profile_name() -> str:
    """Choose the active profile name from the environment.

    A non-blank ``CUSTOM_AGENT_PROFILE`` wins. Otherwise the profile is
    inferred from ``CUSTOM_AGENT_NAME``, falling back to the default profile.
    """
    explicit = os.getenv("CUSTOM_AGENT_PROFILE")
    if explicit and explicit.strip():
        return _normalize_agent_profile(explicit)

    inferred = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME"))
    return DEFAULT_AGENT_PROFILE if inferred is None else inferred
# Resolve the active profile once at import time.
AGENT_PROFILE_NAME = _selected_agent_profile_name()
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
# A non-empty CUSTOM_AGENT_NAME overrides the profile's default agent name.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"confident",
"confused",
"cool",
"crying",
"delicious",
"embarrassed",
"funny",
"happy",
"kissy",
"laughing",
"loving",
"neutral",
"relaxed",
"sad",
"shocked",
"silly",
"sleepy",
"surprised",
"fearful",
"calm",
"concerned",
"thinking",
"winking",
}
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
@ -290,6 +395,7 @@ class CustomAgent(Agent):
yield chunk
return _stream()
async def _observe_emotion_prefix(
self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
@ -545,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
def _with_vision_as_latest_user_message(
chat_ctx: ChatContext, vision_frame: VisionFrame
) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
@ -612,7 +720,25 @@ def _model_image_save_dir_from_env() -> Path | None:
return Path(__file__).with_name("model_images")
server = AgentServer()
def _agent_server_from_env() -> AgentServer:
    """Create the ``AgentServer``, honouring ``CUSTOM_AGENT_HTTP_PORT`` when set.

    With the variable unset the server is built with its defaults. A value
    that is not an integer, or lies outside 0-65535, logs a warning and is
    replaced by port 0.
    """
    raw_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
    if raw_port is None:
        return AgentServer()

    try:
        port = int(raw_port)
    except ValueError:
        logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        port = 0

    if not 0 <= port <= 65535:
        logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        port = 0
    return AgentServer(port=port)
server = _agent_server_from_env()
def prewarm(proc: JobProcess) -> None:
@ -634,16 +760,37 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
if not LLM_API_KEY:
INPUT_MODE = _normalize_input_mode(
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
)
if LLM_PROVIDER not in {
"openai",
"openai-compatible",
"hermes",
"hermes_gateway",
"openclaw",
"beaver",
}:
raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
logger.info(
"Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
AGENT_PROFILE_NAME,
AGENT_NAME or "<automatic>",
INPUT_MODE,
LLM_PROVIDER,
LLM_MODEL,
LLM_BASE_URL or "OpenAI default",
)
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
@ -656,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
beaver_warmup_text: str | None = None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -666,40 +814,117 @@ async def entrypoint(ctx: JobContext) -> None:
itn=os.getenv("CUSTOM_ASR_ITN"),
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
)
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
stt_stream = BoundedStreamAdapter(
stt=blackbox_stt,
vad=ctx.proc.userdata["vad"],
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
)
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
allow_interruptions = _env_bool(
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
ASR_MAX_SPEECH_DURATION <= 0,
)
import httpx
from openai import AsyncClient as OpenAIAsyncClient
if LLM_PROVIDER == "beaver":
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
if not beaver_url:
raise RuntimeError(
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
)
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
beaver_peer_id = (
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
)
beaver_device_name = (
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
or "livekit-custom-agent"
)
base_llm = BeaverLLM(
url=beaver_url,
peer_id=beaver_peer_id,
device_name=beaver_device_name,
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
)
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
text_llm = base_llm
vision_llm = base_llm
logger.info(
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
beaver_url,
beaver_peer_id,
beaver_device_name,
ctx.room.name,
bool(beaver_warmup_text and beaver_warmup_text.strip()),
)
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
if not gateway_url:
raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")
if LLM_BASE_URL:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
base_url=LLM_BASE_URL,
http_client=http_client,
hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None
hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower()
if hermes_session_mode != "per_room":
raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room")
hermes_token = (
os.getenv("CUSTOM_HERMES_API_KEY")
or os.getenv("CUSTOM_HERMES_TOKEN")
or LLM_API_KEY
or None
)
hermes_state = GatewaySessionState(
room_name=ctx.room.name,
agent_id=hermes_agent_id,
session_mode=hermes_session_mode,
)
base_llm = HermesGatewayLLM(
url=gateway_url,
token=hermes_token,
state=hermes_state,
agent_id=hermes_agent_id,
model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"),
request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0),
)
text_llm = base_llm
vision_llm = base_llm
logger.info(
"Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s",
gateway_url,
hermes_agent_id or "default",
hermes_state.session_key,
)
else:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
http_client=http_client,
)
import httpx
from openai import AsyncClient as OpenAIAsyncClient
base_llm = openai.LLM(
model=LLM_MODEL,
client=openai_client,
)
text_llm = (
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
if TEXT_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_llm = (
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
if VISION_LLM_MODEL != LLM_MODEL
else base_llm
)
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
if LLM_BASE_URL:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
base_url=LLM_BASE_URL,
http_client=http_client,
)
else:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
http_client=http_client,
)
base_llm = openai.LLM(
model=LLM_MODEL,
client=openai_client,
)
text_llm = (
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
if TEXT_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_llm = (
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
if VISION_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_store = VisionFrameStore(
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
)
@ -707,7 +932,7 @@ async def entrypoint(ctx: JobContext) -> None:
session: AgentSession = AgentSession(
# 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream,
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
# 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway.
llm=base_llm,
# 3. TTS blackbox
tts=BlackboxTTS(
@ -721,13 +946,14 @@ async def entrypoint(ctx: JobContext) -> None:
# 4. Silero VAD
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=MultilingualModel(),
turn_detection=turn_detection,
interruption={
"enabled": allow_interruptions,
"resume_false_interruption": True,
"false_interruption_timeout": 1.0,
},
),
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", LLM_PROVIDER != "beaver"),
aec_warmup_duration=3.0,
tts_text_transforms=[
"filter_emoji",
@ -753,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
def _interrupt_done(fut: asyncio.Future[None]) -> None:
try:
fut.result()
except Exception:
logger.exception("Bridge interrupt failed")
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
reason = None if payload is None else payload.get("reason")
logger.info(
"Received bridge interrupt: topic=%s reason=%s",
packet_topic or "<default>",
reason if isinstance(reason, str) and reason else "<unspecified>",
)
try:
interrupt_fut = session.interrupt(force=True)
except RuntimeError:
logger.exception("Bridge interrupt received before AgentSession was running")
return
interrupt_fut.add_done_callback(_interrupt_done)
if packet_topic == INTERRUPT_TOPIC:
payload: dict[str, object] | None = None
try:
decoded = json.loads(data_packet.data.decode("utf-8"))
if isinstance(decoded, dict):
payload = decoded
except Exception:
logger.exception("Failed to decode interrupt payload")
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
logger.exception("Failed to decode data payload")
return
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
@ -820,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
),
record=_recording_options_from_env(),
)
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
_start_beaver_background_warmup(
ctx=ctx,
beaver_llm=base_llm,
warmup_text=beaver_warmup_text,
)
def _start_beaver_background_warmup(
    *,
    ctx: JobContext,
    beaver_llm: BeaverLLM,
    warmup_text: str | None,
) -> None:
    """Kick off the Beaver handshake in the background and cancel it on shutdown.

    A task named ``beaver_background_warmup`` awaits
    ``beaver_llm.connect(warmup_text=...)``. Failures are logged and swallowed
    so they never propagate to the caller. A shutdown callback registered on
    ``ctx`` cancels the task if it is still running.

    Args:
        ctx: Job context; supplies the room name for logging and the
            shutdown-callback hook.
        beaver_llm: Beaver LLM whose ``connect`` coroutine is awaited.
        warmup_text: Optional text forwarded to ``connect``.
    """

    async def _handshake() -> None:
        try:
            reply = await beaver_llm.connect(warmup_text=warmup_text)
        except BeaverTerminalError:
            logger.warning(
                "Beaver background handshake failed; will retry on first user turn",
                exc_info=True,
            )
        except Exception:
            logger.exception("Unexpected Beaver background handshake failure")
        else:
            room_name = ctx.room.name
            session_id = beaver_llm.session_id
            sent_warmup = bool(warmup_text and warmup_text.strip())
            reply_len = 0 if reply is None else len(reply)
            logger.info(
                "Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
                room_name,
                session_id,
                sent_warmup,
                reply_len,
            )

    task = asyncio.create_task(_handshake(), name="beaver_background_warmup")

    async def _cancel_on_shutdown() -> None:
        # Nothing to do once the handshake has already finished.
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    ctx.add_shutdown_callback(_cancel_on_shutdown)
def _tts_params_from_env(model_name: str) -> dict[str, str]:
@ -917,6 +1226,14 @@ def _env_bool(name: str, default: bool) -> bool:
return default
def _first_env(*names: str) -> str | None:
for name in names:
value = os.getenv(name)
if value and value.strip():
return value.strip()
return None
def _recording_options_from_env() -> RecordingOptions:
return RecordingOptions(
audio=_env_bool("CUSTOM_RECORD_AUDIO", False),

391
hermes_gateway.py Normal file
View File

@ -0,0 +1,391 @@
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
import aiohttp
from livekit.agents import llm
from livekit.agents._exceptions import APIConnectionError
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectOptions,
NotGivenOr,
)
from livekit.agents.utils import shortuuid
@dataclass
class GatewaySessionState:
    """Per-room session identity used when talking to the gateway.

    Only ``session_mode="per_room"`` is accepted. When no ``session_key`` is
    given, one is derived as ``livekit:<room_name>:<agent_id or "default">``.

    Raises:
        ValueError: If ``session_mode`` is anything other than ``"per_room"``.
    """

    room_name: str
    agent_id: str | None = None
    session_key: str | None = None
    session_mode: str = "per_room"

    def __post_init__(self) -> None:
        """Validate the session mode and fill in a default session key."""
        if self.session_mode != "per_room":
            raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room")
        if self.session_key is not None:
            return
        owner = self.agent_id or "default"
        self.session_key = f"livekit:{self.room_name}:{owner}"
class HermesGatewayLLM(llm.LLM):
    """LLM backend that forwards chat turns to a gateway over a websocket.

    The instance stores connection settings and lazily creates one shared
    ``aiohttp.ClientSession``. Each ``chat`` call returns a
    ``HermesGatewayLLMStream`` bound to this instance.
    """

    def __init__(
        self,
        *,
        url: str,
        token: str | None,
        state: GatewaySessionState,
        agent_id: str | None = None,
        model_name: str = "hermes-agent",
        request_timeout: float = 30.0,
    ) -> None:
        """Store gateway settings; no network activity happens here.

        Args:
            url: Websocket URL of the gateway.
            token: Optional auth token sent in the connect request.
            state: Session state providing the session key and room name.
            agent_id: Optional agent id forwarded in session RPCs.
            model_name: Value reported by the ``model`` property.
            request_timeout: Total timeout (seconds) for the HTTP session.
        """
        super().__init__()
        self._http_session: aiohttp.ClientSession | None = None
        self._request_timeout = request_timeout
        self._model_name = model_name
        self._agent_id = agent_id
        self._state = state
        self._token = token
        self._url = url

    @property
    def model(self) -> str:
        """Model name configured at construction time."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this backend."""
        return "hermes-gateway"

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> llm.LLMStream:
        """Create a stream for one chat turn.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are
        accepted but not forwarded to the stream.
        """
        active_tools = tools or []
        return HermesGatewayLLMStream(
            self,
            chat_ctx=chat_ctx,
            tools=active_tools,
            conn_options=conn_options,
        )

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it on first use."""
        session = self._http_session
        if session is None:
            session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._request_timeout)
            )
            self._http_session = session
        return session

    async def aclose(self) -> None:
        """Close the shared HTTP session, if one was created."""
        session = self._http_session
        if session is None:
            return
        await session.close()
        self._http_session = None
class HermesGatewayLLMStream(llm.LLMStream):
    """Stream for one chat turn, run over a fresh gateway websocket.

    ``_run`` opens a websocket, performs the connect handshake, issues
    ``sessions.create`` and ``sessions.send``, then forwards text from incoming
    frames as ``ChatChunk`` events until a terminal frame arrives.
    """

    def __init__(
        self,
        llm: HermesGatewayLLM,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        # Keep a typed handle on the owning LLM for its URL, token and session state.
        self._llm = llm

    async def _run(self) -> None:
        """Send the chat context to the gateway and emit the streamed reply.

        Raises:
            APIConnectionError: (non-retryable) when a frame is classified as
                an error by ``is_error_response``.
        """
        # One id is used as the RPC id, the idempotency key and the chunk id.
        request_id = shortuuid("gwreq_")
        async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws:
            await _connect(ws, token=self._llm._token)
            # sessions.create uses the default wait_response=True, so this
            # blocks until the gateway answers the create request.
            await _send_rpc(
                ws,
                method="sessions.create",
                params={
                    "key": self._llm._state.session_key,
                    "sessionKey": self._llm._state.session_key,
                    "agentId": self._llm._agent_id,
                    "metadata": {
                        "source": "livekit",
                        "room": self._llm._state.room_name,
                    },
                    "idempotencyKey": self._llm._state.session_key,
                },
                request_id=shortuuid("gwcreate_"),
            )
            # sessions.send is fired without waiting; its frames are consumed
            # by the loop below.
            await _send_rpc(
                ws,
                method="sessions.send",
                params={
                    "key": self._llm._state.session_key,
                    "sessionKey": self._llm._state.session_key,
                    "agentId": self._llm._agent_id,
                    "messages": chat_context_to_gateway_messages(self.chat_ctx),
                    "stream": True,
                    "idempotencyKey": request_id,
                },
                request_id=request_id,
                wait_response=False,
            )

            # Everything emitted so far; used to strip already-sent prefixes.
            streamed_text = ""
            async for frame in _iter_gateway_frames(ws):
                if is_error_response(frame, request_id=request_id):
                    raise APIConnectionError(_gateway_error_message(frame), retryable=False)

                text = extract_text_delta(frame)
                if text:
                    # If the frame repeats what was already streamed as a
                    # prefix, emit only the new suffix.
                    if text.startswith(streamed_text):
                        text = text[len(streamed_text) :]
                    if text:
                        streamed_text += text
                        self._event_ch.send_nowait(
                            llm.ChatChunk(
                                id=request_id,
                                delta=llm.ChoiceDelta(role="assistant", content=text),
                            )
                        )

                if is_terminal_event(frame, request_id=request_id):
                    return
def build_connect_params(*, token: str | None) -> dict[str, Any]:
    """Build the parameter object for the gateway ``connect`` request.

    Args:
        token: Auth token; when truthy it is added as ``{"auth": {"token": ...}}``.

    Returns:
        A new dict of connect parameters.
    """
    client_info: dict[str, Any] = {
        "id": "gateway-client",
        "version": "livekit-custom-agent",
        "platform": "python",
        "mode": "backend",
    }
    connect_params: dict[str, Any] = {
        "minProtocol": 3,
        "maxProtocol": 4,
        "client": client_info,
        "role": "operator",
        "scopes": ["operator.read", "operator.write"],
        "caps": [],
        "commands": [],
        "permissions": {},
        "locale": "zh-CN",
        "userAgent": "livekit-custom-agent",
    }
    if not token:
        return connect_params
    connect_params["auth"] = {"token": token}
    return connect_params
def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]:
    """Convert chat-context messages into gateway ``{"role", "content"}`` dicts.

    Messages whose content converts to ``None`` (nothing usable) are omitted.
    """
    converted: list[dict[str, Any]] = []
    for entry in chat_ctx.messages():
        body = _message_content_to_gateway_content(entry.content)
        if body is not None:
            converted.append({"role": entry.role, "content": body})
    return converted
def extract_text_delta(frame: dict[str, Any]) -> str:
    """Return the first non-empty text found in a gateway frame, else ``""``.

    The search runs over ``frame["payload"]`` when that is a dict, otherwise
    over the frame itself. Candidate locations are tried in a fixed order.
    """
    raw_payload = frame.get("payload")
    source = raw_payload if isinstance(raw_payload, dict) else frame

    candidate_paths = (
        ("delta", "content"),
        ("delta", "text"),
        ("message", "delta", "content"),
        ("message", "delta", "text"),
        ("message", "content"),
        ("content",),
        ("text",),
    )
    for path in candidate_paths:
        found = _content_to_text(_get_nested(source, path))
        if found:
            return found
    return ""
def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool:
    """Tell whether *frame* ends the current request.

    A frame is terminal when it is the ``res`` frame for *request_id*, when
    its ``event`` is one of the known completion/failure events, or when its
    dict payload carries ``done is True``.
    """
    terminal_events = {
        "agent.done",
        "agent.completed",
        "agent.error",
        "session.message.completed",
        "session.run.completed",
        "sessions.run.completed",
        "run.completed",
        "run.failed",
    }
    answers_request = frame.get("type") == "res" and frame.get("id") == request_id
    if answers_request or frame.get("event") in terminal_events:
        return True
    payload = frame.get("payload")
    return isinstance(payload, dict) and payload.get("done") is True
def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool:
    """Tell whether *frame* reports a failure.

    True for the ``res`` frame of *request_id* with ``ok is False``, or for
    any frame whose ``event`` is one of the known error events.
    """
    if frame.get("type") == "res" and frame.get("id") == request_id:
        if frame.get("ok") is False:
            return True
    error_events = {
        "agent.error",
        "session.error",
        "session.run.failed",
        "sessions.run.failed",
        "run.failed",
    }
    return frame.get("event") in error_events
def _gateway_error_message(frame: dict[str, Any]) -> str:
error = frame.get("error")
if isinstance(error, str):
return f"OpenClaw gateway request failed: {error}"
if isinstance(error, dict):
message = error.get("message") or error.get("error")
if isinstance(message, str):
return f"OpenClaw gateway request failed: {message}"
payload = frame.get("payload")
if isinstance(payload, dict):
message = payload.get("message") or payload.get("error")
if isinstance(message, str):
return f"OpenClaw gateway request failed: {message}"
return f"OpenClaw gateway request failed: {frame!r}"
async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None:
    """Perform the gateway connect handshake on an open websocket.

    Expects the first frame to be a ``connect.challenge`` event, then sends a
    ``connect`` request and waits for its response.

    Raises:
        RuntimeError: If the first frame is not a challenge, or the connect
            response is not ``ok``.
    """
    challenge = await _receive_json(ws)
    if challenge.get("event") != "connect.challenge":
        raise RuntimeError(f"expected connect.challenge, received {challenge!r}")

    connect_id = shortuuid("gwconnect_")
    await _send_rpc(
        ws,
        method="connect",
        params=build_connect_params(token=token),
        request_id=connect_id,
        wait_response=False,
    )
    reply = await _wait_for_response(ws, request_id=connect_id)
    if reply.get("ok"):
        return
    raise RuntimeError(f"OpenClaw gateway connect failed: {reply.get('error')!r}")
async def _send_rpc(
    ws: aiohttp.ClientWebSocketResponse,
    *,
    method: str,
    params: dict[str, Any],
    request_id: str,
    wait_response: bool = True,
) -> dict[str, Any] | None:
    """Send one ``req`` frame and optionally wait for its ``res`` frame.

    ``None`` values inside *params* are dropped before serialisation.

    Returns:
        The response frame, or ``None`` when ``wait_response`` is false.

    Raises:
        RuntimeError: If the awaited response is not ``ok``.
    """
    request = {
        "type": "req",
        "id": request_id,
        "method": method,
        "params": _drop_none(params),
    }
    await ws.send_str(json.dumps(request))

    if wait_response:
        response = await _wait_for_response(ws, request_id=request_id)
        if response.get("ok", False):
            return response
        raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}")
    return None
async def _wait_for_response(
    ws: aiohttp.ClientWebSocketResponse, *, request_id: str
) -> dict[str, Any]:
    """Read frames until the ``res`` frame for *request_id* arrives.

    Frames that do not match are discarded.

    Raises:
        RuntimeError: If the frame stream ends before a matching response.
    """
    async for frame in _iter_gateway_frames(ws):
        is_response = frame.get("type") == "res"
        if is_response and frame.get("id") == request_id:
            return frame
    raise RuntimeError(f"OpenClaw gateway closed before response {request_id}")
async def _iter_gateway_frames(
    ws: aiohttp.ClientWebSocketResponse,
) -> AsyncIterator[dict[str, Any]]:
    """Yield each text websocket message decoded from JSON.

    Iteration stops on a close message and raises ``RuntimeError`` on an
    error message; other message types are ignored.
    """
    closing_types = (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE)
    async for message in ws:
        kind = message.type
        if kind == aiohttp.WSMsgType.TEXT:
            yield json.loads(message.data)
            continue
        if kind in closing_types:
            return
        if kind == aiohttp.WSMsgType.ERROR:
            raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}")
async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]:
    """Receive exactly one websocket message and decode it as JSON.

    Raises:
        RuntimeError: If the message is not a text frame.
    """
    message = await ws.receive()
    if message.type == aiohttp.WSMsgType.TEXT:
        return json.loads(message.data)
    raise RuntimeError(f"expected gateway text frame, received {message.type!r}")
def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any:
parts: list[dict[str, Any]] = []
for item in content:
if isinstance(item, str):
if item:
parts.append({"type": "text", "text": item})
elif isinstance(item, llm.ImageContent) and isinstance(item.image, str):
parts.append({"type": "image_url", "image_url": {"url": item.image}})
if not parts:
return None
if len(parts) == 1 and parts[0]["type"] == "text":
return parts[0]["text"]
return parts
def _content_to_text(value: Any) -> str:
if isinstance(value, str):
return value
if isinstance(value, list):
text_parts: list[str] = []
for item in value:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
text_parts.append(text)
return "".join(text_parts)
return ""
def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any:
current: Any = data
for key in path:
if not isinstance(current, dict):
return None
current = current.get(key)
return current
def _drop_none(value: Any) -> Any:
if isinstance(value, dict):
return {key: _drop_none(item) for key, item in value.items() if item is not None}
if isinstance(value, list):
return [_drop_none(item) for item in value]
return value

264
start_agent_profiles.py Normal file
View File

@ -0,0 +1,264 @@
from __future__ import annotations
import argparse
import asyncio
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from dotenv import load_dotenv
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
@dataclass(frozen=True)
class AgentProfile:
    """Immutable settings that distinguish one agent worker profile."""

    # Exported to the child process as CUSTOM_AGENT_NAME.
    agent_name: str
    # Exported to the child process as CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Exported to the child process as CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
# Known profiles, keyed by the name accepted on the command line.
# Each row is (profile key, agent name, LLM provider, input mode).
AGENT_PROFILES = {
    profile_key: AgentProfile(
        agent_name=agent_name,
        llm_provider=llm_provider,
        input_mode=input_mode,
    )
    for profile_key, agent_name, llm_provider, input_mode in (
        ("normal", "normal-agent", "openai-compatible", "voice"),
        ("beaver", "beaver-agent", "beaver", "voice"),
        ("vision-normal", "vision-normal-agent", "openai-compatible", "vision_voice"),
        ("vision-beaver", "vision-beaver-agent", "beaver", "vision_voice"),
    )
}

# Profiles started when neither --profiles nor CUSTOM_AGENT_PROFILES is given.
DEFAULT_PROFILES = ("normal", "beaver")
def _env_bool(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:
return default
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
def _parse_profiles(value: str) -> list[str]:
    """Parse a comma-separated profile list into unique, known profile names.

    The single word ``all`` selects every known profile. Names are
    lower-cased and underscores become hyphens. Blank entries are skipped and
    duplicates are dropped, keeping first-occurrence order.

    Raises:
        ValueError: On an unknown profile name or when no profile remains.
    """
    if value.strip().lower() == "all":
        return list(AGENT_PROFILES)

    # A dict doubles as an insertion-ordered set.
    selected: dict[str, None] = {}
    for raw in value.split(","):
        name = raw.strip().lower().replace("_", "-")
        if not name:
            continue
        if name not in AGENT_PROFILES:
            valid = ", ".join([*AGENT_PROFILES, "all"])
            raise ValueError(f"unknown profile {raw!r}; valid values: {valid}")
        selected.setdefault(name, None)

    if not selected:
        raise ValueError("at least one profile is required")
    return list(selected)
def _parse_http_port_base(value: str | None) -> int:
if value is None or not value.strip():
return 0
try:
port = int(value)
except ValueError as exc:
raise ValueError(f"invalid HTTP port base {value!r}") from exc
if port < 0 or port > 65535:
raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
return port
def _profile_http_port(http_port_base: int, index: int) -> int:
if http_port_base == 0:
return 0
port = http_port_base + index
if port > 65535:
raise ValueError("HTTP port range exceeds 65535")
return port
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
    """Build the environment for a profile's child process.

    Starts from a copy of the current environment and overrides the
    profile-specific ``CUSTOM_*`` variables.

    Raises:
        KeyError: If *profile_name* is not a known profile.
    """
    profile = AGENT_PROFILES[profile_name]
    overrides = {
        "CUSTOM_AGENT_PROFILE": profile_name,
        "CUSTOM_AGENT_NAME": profile.agent_name,
        "CUSTOM_LLM_PROVIDER": profile.llm_provider,
        "CUSTOM_AGENT_INPUT_MODE": profile.input_mode,
        "CUSTOM_AGENT_HTTP_PORT": str(http_port),
    }
    return {**os.environ, **overrides}
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
while line := await stream.readline():
text = line.decode("utf-8", errors="replace").rstrip()
print(f"[{prefix}] {text}", flush=True)
# Strong references to the output-forwarding tasks. The event loop only keeps
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes; entries are removed when a task completes.
_PIPE_TASKS: set[asyncio.Task[None]] = set()


async def _start_profile(
    profile_name: str,
    *,
    mode: str,
    http_port: int,
    reload: bool,
) -> asyncio.subprocess.Process:
    """Spawn ``custom_agent.py`` for one profile and forward its output.

    The child runs with the current interpreter and an environment built by
    ``_child_env``. Its stdout and stderr are each forwarded line by line to
    this process's stdout, prefixed with the profile name.

    Args:
        profile_name: Key into ``AGENT_PROFILES``.
        mode: CLI mode passed to ``custom_agent.py``.
        http_port: Value exported as ``CUSTOM_AGENT_HTTP_PORT``.
        reload: When false and ``mode == "dev"``, ``--no-reload`` is appended.

    Returns:
        The started subprocess.
    """
    profile = AGENT_PROFILES[profile_name]
    script_path = Path(__file__).with_name("custom_agent.py")
    args = [sys.executable, str(script_path), mode]
    if mode == "dev" and not reload:
        args.append("--no-reload")

    process = await asyncio.create_subprocess_exec(
        *args,
        env=_child_env(profile_name, http_port=http_port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(
        f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
        f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
        flush=True,
    )

    for stream in (process.stdout, process.stderr):
        if stream is None:
            continue
        pipe_task = asyncio.create_task(_pipe_output(profile_name, stream))
        # Hold the task until it is done so forwarding cannot vanish mid-run.
        _PIPE_TASKS.add(pipe_task)
        pipe_task.add_done_callback(_PIPE_TASKS.discard)
    return process
async def _terminate(processes: list[asyncio.subprocess.Process]) -> None:
for process in processes:
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(
asyncio.gather(*(process.wait() for process in processes)),
timeout=10.0,
)
except asyncio.TimeoutError:
for process in processes:
if process.returncode is None:
process.kill()
await asyncio.gather(*(process.wait() for process in processes))
async def _run(
    profiles: list[str],
    *,
    mode: str,
    http_port_base: int,
    reload: bool,
) -> int:
    """Run one child process per profile until one exits or a signal arrives.

    All HTTP ports are computed before any child is spawned, so an invalid
    port range fails without leaving children behind. If spawning a later
    profile fails, the children already started are terminated before the
    error propagates.

    Once running, the function waits for SIGINT/SIGTERM or for the first
    child to exit, then terminates all children.

    Args:
        profiles: Profile names to start, in order.
        mode: CLI mode passed to each child.
        http_port_base: Base port; profile *i* gets ``base + i`` (0 means 0).
        reload: Forwarded to ``_start_profile``.

    Returns:
        0 when stopped by a signal; otherwise the first non-zero exit code
        among the children that exited first (0 if they all exited cleanly).

    Raises:
        ValueError: If the computed port range exceeds 65535.
    """
    # Validate the whole port range up front, before spawning anything.
    ports = [_profile_http_port(http_port_base, index) for index, _ in enumerate(profiles)]

    processes: list[asyncio.subprocess.Process] = []
    try:
        for profile, port in zip(profiles, ports):
            processes.append(
                await _start_profile(
                    profile,
                    mode=mode,
                    http_port=port,
                    reload=reload,
                )
            )
    except BaseException:
        # Do not orphan the children that did start.
        await _terminate(processes)
        raise

    stop_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, stop_event.set)

    wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
    stop_task = asyncio.create_task(stop_event.wait())
    done, _ = await asyncio.wait(
        [*wait_tasks, stop_task],
        return_when=asyncio.FIRST_COMPLETED,
    )

    exit_code = 0
    if stop_task not in done:
        for task in done:
            process = wait_tasks[task]
            code = task.result() or 0
            print(f"agent profile process exited: pid={process.pid} code={code}", flush=True)
            # Keep the first failure; a later clean exit must not mask it.
            if exit_code == 0:
                exit_code = code

    # No-op if the stop task already completed.
    stop_task.cancel()
    await _terminate(processes)
    return exit_code
def main() -> None:
    """Parse command-line options and run the selected agent profiles.

    Invalid ``--profiles`` or ``--http-port-base`` values are reported through
    the argument parser. The process exits with the code returned by ``_run``.
    """
    parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
    parser.add_argument(
        "mode",
        nargs="?",
        default="dev",
        choices=("console", "dev", "start", "connect"),
        help="custom_agent.py CLI mode to run for each profile",
    )

    default_profiles = os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES))
    parser.add_argument(
        "--profiles",
        default=default_profiles,
        help="comma-separated profiles to start, or 'all'",
    )

    default_port_base = os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0")
    parser.add_argument(
        "--http-port-base",
        default=default_port_base,
        help=(
            "base HTTP health-check port for profile workers; "
            "0 lets the OS assign free ports"
        ),
    )

    parser.add_argument(
        "--reload",
        action="store_true",
        default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False),
        help="enable auto-reload in dev mode",
    )

    options = parser.parse_args()
    try:
        profiles = _parse_profiles(options.profiles)
        http_port_base = _parse_http_port_base(options.http_port_base)
    except ValueError as exc:
        # parser.error prints usage and exits, so execution stops here.
        parser.error(str(exc))

    exit_code = asyncio.run(
        _run(
            profiles,
            mode=options.mode,
            http_port_base=http_port_base,
            reload=options.reload,
        )
    )
    raise SystemExit(exit_code)
if __name__ == "__main__":
main()

169
test_beaver_llm.py Normal file
View File

@ -0,0 +1,169 @@
import asyncio
import json
import aiohttp
from aiohttp import web
try:
from custom.beaver_llm import BeaverLLM, latest_user_text
except ModuleNotFoundError:
from beaver_llm import BeaverLLM, latest_user_text
from livekit.agents import ChatContext
def test_latest_user_text_uses_most_recent_user_message() -> None:
    """The newest user message wins and list content is joined with newlines."""
    ctx = ChatContext.empty()
    history = (
        ("user", "first"),
        ("assistant", "ignored"),
        ("user", ["second", "line"]),
    )
    for role, content in history:
        ctx.add_message(role=role, content=content)

    assert latest_user_text(ctx) == "second\nline"
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
    unused_tcp_port: int,
) -> None:
    """BeaverLLM connects, sends only the latest user text, and returns the reply.

    A local aiohttp websocket server records every frame it receives so the
    test can assert on the connect frame and on the message frame.
    """
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        # Fake server: "connect" -> "connected"; "message" -> "ack" followed
        # by an assistant "message" echoing the same message_id.
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:livekit-room",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:livekit-room",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "beaver reply",
                        "finish_reason": "stop",
                    }
                )
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )
    ctx = ChatContext.empty()
    # The system message must not reach the server; only the user text is asserted below.
    ctx.add_message(role="system", content="ignored instructions")
    ctx.add_message(role="user", content="hello beaver")

    try:
        collected = await beaver_llm.chat(chat_ctx=ctx).collect()
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert collected.text == "beaver reply"
    assert received[0] == {
        "type": "connect",
        "peer_id": "livekit-room",
        "device_name": "livekit-custom-agent",
        "capabilities": ["text"],
    }
    assert received[1]["type"] == "message"
    # The message id is checked by prefix/suffix only; the middle part is not pinned.
    assert received[1]["message_id"].startswith("livekit-room-")
    assert received[1]["message_id"].endswith("-000001")
    assert received[1]["text"] == "hello beaver"
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
    unused_tcp_port: int,
) -> None:
    """A delayed assistant reply yields exactly one chunk with the reply text.

    The fake server pauses between the ack and the assistant message. The
    stream must produce nothing but the real reply.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:livekit-room",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:livekit-room",
                        "accepted": True,
                    }
                )
                # Simulate a slow backend between the ack and the reply.
                await asyncio.sleep(0.05)
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "beaver reply",
                        "finish_reason": "stop",
                    }
                )
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )
    ctx = ChatContext.empty()
    ctx.add_message(role="user", content="hello beaver")

    try:
        # Collect only chunks that carry text content.
        chunks: list[str] = []
        async with beaver_llm.chat(chat_ctx=ctx) as stream:
            async for chunk in stream:
                if chunk.delta and chunk.delta.content:
                    chunks.append(chunk.delta.content)
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert chunks == ["beaver reply"]

View File

@ -0,0 +1,458 @@
import asyncio
import json
import sys
from pathlib import Path
import aiohttp
import pytest
from aiohttp import web
if __name__ == "__main__":
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
raise SystemExit(pytest.main([__file__]))
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
def test_build_connect_frame_uses_stable_peer_id() -> None:
    """The connect frame carries the peer id, device name and text capability."""
    expected = {
        "type": "connect",
        "peer_id": "device-001",
        "device_name": "desk-terminal",
        "capabilities": ["text"],
    }

    actual = build_connect_frame(peer_id="device-001", device_name="desk-terminal")

    assert actual == expected
def test_build_message_frame_uses_message_id_and_text() -> None:
    """The message frame contains exactly the type, message id and text."""
    expected = {
        "type": "message",
        "message_id": "device-001-000001",
        "text": "hello",
    }

    actual = build_message_frame(message_id="device-001-000001", text="hello")

    assert actual == expected
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
    """Ids continue from the initial counter and the counter tracks the last id."""
    generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)

    produced = [generator.next_id(), generator.next_id()]

    assert produced == ["device-001-000008", "device-001-000009"]
    assert generator.counter == 9
def test_message_id_generator_can_include_nonce() -> None:
    """A nonce, when given, sits between the peer id and the counter."""
    generator = MessageIdGenerator(peer_id="device-001", nonce="run12345")

    produced = [generator.next_id(), generator.next_id()]

    assert produced == [
        "device-001-run12345-000001",
        "device-001-run12345-000002",
    ]
async def test_client_connects_sends_text_and_returns_assistant_reply(
    unused_tcp_port: int,
) -> None:
    """Happy path: connect, send one text, receive the assistant reply.

    Also pins the exact frames the client puts on the wire.
    """
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        # Fake server: "connect" -> "connected"; "message" -> "ack" followed
        # by an assistant "message" with the same message_id.
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "assistant reply",
                        "finish_reason": "stop",
                    }
                )
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    # The session id comes from the server's "connected" frame.
    assert client.session_id == "terminal-dev:local:device-001"
    assert reply == "assistant reply"
    assert received == [
        {
            "type": "connect",
            "peer_id": "device-001",
            "device_name": "desk-terminal",
            "capabilities": ["text"],
        },
        {
            "type": "message",
            "message_id": "device-001-000001",
            "text": "hello",
        },
    ]
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
    """A duplicate ack that carries a ``reply`` is returned as the reply.

    The fake server never sends an assistant message; the reply text exists
    only in the ack frame.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                # Not accepted, flagged duplicate, not pending, reply inline.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": False,
                        "duplicate": True,
                        "pending": False,
                        "reply": "cached assistant reply",
                    }
                )
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    assert reply == "cached assistant reply"
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
    """An ``error`` frame in response to a message raises BeaverTerminalError."""

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                # Reply with an error frame instead of an ack.
                await ws.send_json({"type": "error", "error": "text is required"})
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        # The server's error text must appear in the raised exception.
        with pytest.raises(BeaverTerminalError, match="text is required"):
            await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_cleans_up_owned_session_when_connect_fails(
    unused_tcp_port: int,
) -> None:
    """A failed connect leaves no HTTP session or websocket on the client.

    The endpoint answers with a plain HTTP 200 response instead of upgrading
    to a websocket, so ``connect`` is expected to fail.
    """

    async def websocket_handler(request: web.Request) -> web.Response:
        return web.Response(status=200, text="not a websocket")

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
            await client.connect()
        # Checked after the failure: both private handles must be cleared.
        assert client._http_session is None
        assert client._ws is None
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
    unused_tcp_port: int,
) -> None:
    """An assistant message with ``finish_reason == "error"`` raises.

    The message is acked normally first; the failure is signalled only by the
    finish reason of the assistant message.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "failed turn",
                        "finish_reason": "error",
                    }
                )
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        # The assistant text must appear in the raised exception.
        with pytest.raises(BeaverTerminalError, match="failed turn"):
            await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
    """``ping()`` returns a truthy value when the server answers with ``pong``."""

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "ping":
                # Application-level JSON pong, not a websocket control frame.
                await ws.send_json({"type": "pong"})
        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        assert await client.ping()
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
    unused_tcp_port: int,
) -> None:
    """Check that a dropped socket leads to a reconnect under the same peer id.

    The fake server closes the first connection as soon as it receives a
    ``message`` frame, and serves ack + assistant reply on later connections.
    The test expects the reply from the second connection, two ``connect``
    frames carrying the same peer id, and two distinct, sequential message ids.
    """
    connect_peer_ids: list[str] = []
    message_ids: list[str] = []
    connection_count = 0

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal connection_count
        connection_count += 1
        # Remember which connection this handler instance is serving.
        current_connection = connection_count

        socket = web.WebSocketResponse()
        await socket.prepare(request)
        async for incoming in socket:
            frame = json.loads(incoming.data)
            frame_type = frame["type"]
            if frame_type == "connect":
                connect_peer_ids.append(frame["peer_id"])
                handshake_reply = {
                    "type": "connected",
                    "channel_id": "terminal-dev",
                    "session_id": "terminal-dev:local:device-001",
                }
                await socket.send_json(handshake_reply)
            elif frame_type == "message":
                message_ids.append(frame["message_id"])
                if current_connection == 1:
                    # First connection: drop the socket instead of answering.
                    await socket.close()
                    continue
                ack = {
                    "type": "ack",
                    "message_id": frame["message_id"],
                    "session_id": "terminal-dev:local:device-001",
                    "accepted": True,
                }
                await socket.send_json(ack)
                assistant_reply = {
                    "type": "message",
                    "role": "assistant",
                    "message_id": frame["message_id"],
                    "run_id": "run-2",
                    "text": "reply after reconnect",
                    "finish_reason": "stop",
                }
                await socket.send_json(assistant_reply)
        return socket

    server_app = web.Application()
    server_app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(server_app)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", unused_tcp_port).start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )
    try:
        await client.connect()
        # NOTE(review): short pause kept from the original; presumably lets the
        # handshake settle before sending - confirm.
        await asyncio.sleep(0.01)
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    assert reply == "reply after reconnect"
    assert connect_peer_ids == ["device-001", "device-001"]
    assert message_ids == ["device-001-000001", "device-001-000002"]

262
test_hermes_gateway.py Normal file
View File

@ -0,0 +1,262 @@
import json
import aiohttp
import pytest
from aiohttp import web
from custom.hermes_gateway import (
GatewaySessionState,
HermesGatewayLLM,
build_connect_params,
chat_context_to_gateway_messages,
extract_text_delta,
is_error_response,
is_terminal_event,
)
from livekit.agents import ChatContext, llm
from livekit.agents._exceptions import APIConnectionError
def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None:
    """Check that text and image content are both converted into gateway messages.

    A plain-string system message is expected to stay a string, while a mixed
    text + image user message is expected to become a list of typed parts.
    """
    image = llm.ImageContent(image="data:image/png;base64,abc")
    chat_ctx = ChatContext.empty()
    chat_ctx.add_message(role="system", content="system prompt")
    chat_ctx.add_message(role="user", content=["look here", image])

    expected_user_parts = [
        {"type": "text", "text": "look here"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
    ]
    expected = [
        {"role": "system", "content": "system prompt"},
        {"role": "user", "content": expected_user_parts},
    ]

    converted = chat_context_to_gateway_messages(chat_ctx)
    assert converted == expected
def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None:
    """Check ``extract_text_delta`` against three differently shaped gateway events."""
    # Shape 1: text nested under payload.delta.content.
    delta_content_event = {
        "type": "event",
        "event": "agent",
        "payload": {"delta": {"content": "hi"}},
    }
    assert extract_text_delta(delta_content_event) == "hi"

    # Shape 2: text directly under payload.text (leading space must survive).
    payload_text_event = {"type": "event", "event": "agent", "payload": {"text": " there"}}
    assert extract_text_delta(payload_text_event) == " there"

    # Shape 3: text inside a list of typed content parts on payload.message.
    message_parts_event = {
        "type": "event",
        "event": "session.message.delta",
        "payload": {"message": {"content": [{"type": "text", "text": "!"}]}},
    }
    assert extract_text_delta(message_parts_event) == "!"
def test_per_room_session_state_reuses_stable_session_key() -> None:
    """Check the default session key format and that an assigned key is kept."""
    session_state = GatewaySessionState(
        room_name="kitchen-room",
        agent_id="helper",
    )
    # Default key is derived from the room name and agent id.
    assert session_state.session_key == "livekit:kitchen-room:helper"

    # A key assigned afterwards must be returned unchanged.
    session_state.session_key = "gateway-session-123"
    assert session_state.session_key == "gateway-session-123"
def test_build_connect_params_uses_backend_operator_defaults() -> None:
    """Check the client descriptor, role, scopes and auth built for a token-only connect."""
    expected_client = {
        "id": "gateway-client",
        "version": "livekit-custom-agent",
        "platform": "python",
        "mode": "backend",
    }

    connect_params = build_connect_params(token="secret-token")

    assert connect_params["client"] == expected_client
    assert connect_params["role"] == "operator"
    assert connect_params["scopes"] == ["operator.read", "operator.write"]
    assert connect_params["auth"] == {"token": "secret-token"}
    # No device block is expected in the resulting params.
    assert "device" not in connect_params
def test_gateway_response_helpers_match_only_current_send_request() -> None:
    """Check that response helpers only react to responses whose id matches ``request_id``."""
    # Responses addressed to the request under test.
    matching_ok = {"type": "res", "id": "send-1", "ok": True}
    matching_failed = {"type": "res", "id": "send-1", "ok": False}
    assert is_terminal_event(matching_ok, request_id="send-1")
    assert is_error_response(matching_failed, request_id="send-1")

    # Responses addressed to a different request must be ignored.
    other_ok = {"type": "res", "id": "connect-1", "ok": True}
    other_failed = {"type": "res", "id": "connect-1", "ok": False}
    assert not is_terminal_event(other_ok, request_id="send-1")
    assert not is_error_response(other_failed, request_id="send-1")
def test_hermes_llm_reports_provider_and_model() -> None:
    """Check the ``provider`` and ``model`` values exposed by ``HermesGatewayLLM``."""
    session_state = GatewaySessionState(room_name="kitchen", agent_id="helper")
    hermes = HermesGatewayLLM(
        url="ws://gateway.test/ws",
        token="token",
        state=session_state,
        agent_id="helper",
        model_name="hermes-agent",
    )

    reported = (hermes.provider, hermes.model)
    assert reported[0] == "hermes-gateway"
    assert reported[1] == "hermes-agent"
def test_gateway_session_state_rejects_non_per_room_mode() -> None:
    """Check that constructing the state with ``session_mode="per_turn"`` raises ``ValueError``."""
    unsupported_kwargs = {
        "room_name": "kitchen",
        "agent_id": "helper",
        "session_mode": "per_turn",
    }
    # The error message is expected to mention the supported "per_room" mode.
    with pytest.raises(ValueError, match="per_room"):
        GatewaySessionState(**unsupported_kwargs)
async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None:
    """Check the RPC sequence sent to the gateway and the text collected from it.

    The fake gateway emits a ``connect.challenge`` event, then answers the
    ``connect``, ``sessions.create`` and ``sessions.send`` requests. For
    ``sessions.send`` it streams one text delta followed by the final
    response and closes the socket.
    """
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        socket = web.WebSocketResponse()
        await socket.prepare(request)
        # The gateway speaks first with a challenge event.
        await socket.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
        async for incoming in socket:
            assert incoming.type == aiohttp.WSMsgType.TEXT
            rpc = json.loads(incoming.data)
            received.append(rpc)
            method = rpc.get("method")
            request_id = rpc.get("id")
            if method == "connect":
                await socket.send_json({"type": "res", "id": request_id, "ok": True})
            elif method == "sessions.create":
                created = {
                    "type": "res",
                    "id": request_id,
                    "ok": True,
                    "result": {"sessionKey": "livekit:kitchen:helper"},
                }
                await socket.send_json(created)
            elif method == "sessions.send":
                text_delta = {
                    "type": "event",
                    "event": "agent",
                    "payload": {"delta": {"content": "你好"}},
                }
                await socket.send_json(text_delta)
                final_response = {
                    "type": "res",
                    "id": request_id,
                    "ok": True,
                    "result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}},
                }
                await socket.send_json(final_response)
                await socket.close()
        return socket

    server_app = web.Application()
    server_app.router.add_get("/ws", websocket_handler)
    runner = web.AppRunner(server_app)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", unused_tcp_port).start()

    gateway_llm = HermesGatewayLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/ws",
        token="secret-token",
        state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
        agent_id="helper",
    )
    chat_ctx = ChatContext.empty()
    chat_ctx.add_message(role="user", content="杯子在哪里")
    try:
        collected = await gateway_llm.chat(chat_ctx=chat_ctx).collect()
    finally:
        await gateway_llm.aclose()
        await runner.cleanup()

    assert collected.text == "你好"

    methods_seen = [item["method"] for item in received]
    assert methods_seen == ["connect", "sessions.create", "sessions.send"]

    # The third request is the sessions.send call carrying the chat messages.
    send_params = received[2]["params"]
    assert send_params["sessionKey"] == "livekit:kitchen:helper"
    assert send_params["messages"] == [{"role": "user", "content": "杯子在哪里"}]
def test_extract_text_delta_reads_final_message_content() -> None:
    """Check that text is extracted from a ``session.message.completed`` event."""
    content_parts = [
        {"type": "text", "text": "完整回复"},
    ]
    completed_event = {
        "type": "event",
        "event": "session.message.completed",
        "payload": {"message": {"content": content_parts}},
    }

    extracted = extract_text_delta(completed_event)
    assert extracted == "完整回复"
def test_is_error_response_accepts_error_events() -> None:
    """Check that ``agent.error`` and ``run.failed`` events count as error responses."""
    agent_error_event = {
        "type": "event",
        "event": "agent.error",
        "payload": {"error": "boom"},
    }
    run_failed_event = {
        "type": "event",
        "event": "run.failed",
        "payload": {"message": "boom"},
    }

    # Neither event carries an id; both are still expected to be flagged.
    assert is_error_response(agent_error_event, request_id="send-1")
    assert is_error_response(run_failed_event, request_id="send-1")
async def test_llm_stream_maps_gateway_error_events_to_api_connection_error(
    unused_tcp_port: int,
) -> None:
    """Check that a ``run.failed`` gateway event surfaces as ``APIConnectionError``.

    The fake gateway accepts ``connect`` and ``sessions.create`` and answers
    ``sessions.send`` with a ``run.failed`` event before closing the socket.
    The raised error is expected to match the event's message.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        socket = web.WebSocketResponse()
        await socket.prepare(request)
        # The gateway speaks first with a challenge event.
        await socket.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
        async for incoming in socket:
            assert incoming.type == aiohttp.WSMsgType.TEXT
            rpc = json.loads(incoming.data)
            method = rpc.get("method")
            request_id = rpc.get("id")
            if method in ("connect", "sessions.create"):
                # Both setup calls get the same bare success response.
                await socket.send_json({"type": "res", "id": request_id, "ok": True})
            elif method == "sessions.send":
                failure_event = {
                    "type": "event",
                    "event": "run.failed",
                    "payload": {"message": "gateway exploded"},
                }
                await socket.send_json(failure_event)
                await socket.close()
        return socket

    server_app = web.Application()
    server_app.router.add_get("/ws", websocket_handler)
    runner = web.AppRunner(server_app)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", unused_tcp_port).start()

    gateway_llm = HermesGatewayLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/ws",
        token=None,
        state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
        agent_id="helper",
    )
    chat_ctx = ChatContext.empty()
    chat_ctx.add_message(role="user", content="hello")
    try:
        with pytest.raises(APIConnectionError, match="gateway exploded"):
            await gateway_llm.chat(chat_ctx=chat_ctx).collect()
    finally:
        await gateway_llm.aclose()
        await runner.cleanup()

7
tts.py
View File

@ -7,7 +7,6 @@ import time
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import Optional
import aiohttp
@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS):
*,
url: str,
model_name: str = "voxcpmtts",
params: Optional[Mapping[str, object]] = None,
prompt_wav_path: Optional[str] = None,
params: Mapping[str, object] | None = None,
prompt_wav_path: str | None = None,
prompt_wav_field: str = "prompt_wav",
sample_rate: int = 16000,
num_channels: int = 1,
timeout: float = 60.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(streaming=False),