diff --git a/.env.example b/.env.example index 8f023d6..001ecc9 100644 --- a/.env.example +++ b/.env.example @@ -2,7 +2,15 @@ LIVEKIT_URL=ws://localhost:7880 LIVEKIT_API_KEY= LIVEKIT_API_SECRET= -CUSTOM_AGENT_NAME=my-agent + +CUSTOM_AGENT_PROFILE=normal +# CUSTOM_AGENT_NAME=normal-agent +CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver + +# Beaver terminal text WebSocket +BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws +TERMINAL_PEER_ID=device-001 +TERMINAL_DEVICE_NAME=desk-terminal # ASR blackbox CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox @@ -12,25 +20,34 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh CUSTOM_ASR_HOTWORDS= CUSTOM_ASR_ITN= CUSTOM_ASR_CHUNK_MODE= +# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable. +CUSTOM_ASR_MAX_SPEECH_DURATION=12 +# Keep false if forced ASR turns should reply even while input audio continues. +CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false + +# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver. +# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override. +# CUSTOM_LLM_PROVIDER=beaver +CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready # OpenAI-compatible LLM # CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1 # CUSTOM_LLM_MODEL=Qwen3.6-35B -# CUSTOM_LLM_API_KEY= +# CUSTOM_LLM_API_KEY=sk- # CUSTOM_LLM_VERIFY_SSL=false -CUSTOM_LLM_BASE_URL=http://localhost/v1 -CUSTOM_LLM_MODEL=Qwen-VL -CUSTOM_LLM_API_KEY= +CUSTOM_LLM_BASE_URL=http:/localhost/v1 +CUSTOM_LLM_MODEL=Mistral-Medium-3.5-128B +CUSTOM_LLM_API_KEY=sk- CUSTOM_LLM_VERIFY_SSL=false -CUSTOM_SAVE_MODEL_IMAGES=false +CUSTOM_SAVE_MODEL_IMAGES=true # CUSTOM_TEXT_LLM_MODEL= # CUSTOM_VISION_LLM_MODEL= # CUSTOM_LLM_BASE_URL=https://api.deepseek.com # CUSTOM_LLM_MODEL=deepseek-v4-flash -# CUSTOM_LLM_API_KEY= +# CUSTOM_LLM_API_KEY=sk- # CUSTOM_LLM_VERIFY_SSL=false @@ -71,4 +88,4 @@ CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph CUSTOM_MEMORY_TIMEOUT=2 CUSTOM_MEMORY_MAX_CHARS=2000 CUSTOM_MEMORY_API_KEY= -CUSTOM_PREEMPTIVE_GENERATION=true +CUSTOM_PREEMPTIVE_GENERATION=false diff --git a/asr.py b/asr.py index 4052321..b8c6d76 100644 --- a/asr.py +++ b/asr.py @@ -1,11 +1,14 @@ import asyncio import logging -from typing import Any, Optional, Union +from collections import deque +from collections.abc import AsyncIterable +from typing import Any import aiohttp from livekit import rtc from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectionError, APIConnectOptions, @@ -17,9 +20,15 @@ from livekit.agents import ( utils, ) from livekit.agents.utils import is_given +from livekit.agents.vad import VAD, VADEventType logger = logging.getLogger("blackbox-asr") +DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions( + max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout +) +STT_CAPABILITIES = stt.STTCapabilities + class BlackboxSTT(stt.STT): def __init__( @@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT): url: str, *, model_name: str = "sensevoice", - language: Optional[str] = "auto", + language: str | None = "auto", output_language: str = "zh", - hotwords: Optional[str] = None, - itn: Optional[Union[bool, str]] = None, - chunk_mode: Optional[Union[bool, str]] = None, + hotwords: str | None = None, + itn: bool | str | None = None, + chunk_mode: bool | str | None = None, timeout: float = 30.0, - http_session: Optional[aiohttp.ClientSession] = None, + http_session: aiohttp.ClientSession | None = None, ) -> None: super().__init__( capabilities=stt.STTCapabilities( @@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str: raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}") -def _form_value(value: Union[bool, str]) -> str: +def _form_value(value: bool | str) -> str: if isinstance(value, bool): return str(value).lower() return value + + +class BoundedStreamAdapter(stt.STT): + def __init__( + self, + *, + stt: stt.STT, + vad: VAD, + max_speech_duration: float | None = 12.0, + pre_speech_duration: float = 0.5, + ) -> None: + super().__init__( + capabilities=STT_CAPABILITIES( + streaming=True, + interim_results=False, + diarization=False, + ) + ) + self._vad = vad + self._stt = stt + self._max_speech_duration = max_speech_duration + self._pre_speech_duration = pre_speech_duration + self._stt.on("metrics_collected", self._on_metrics_collected) + + @property + def wrapped_stt(self) -> stt.STT: + return self._stt + + @property + def model(self) -> str: + return self._stt.model + + @property + def provider(self) -> str: + return self._stt.provider + + async def _recognize_impl( + self, + buffer: utils.AudioBuffer, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> stt.SpeechEvent: + return await self._stt.recognize( + buffer=buffer, language=language, conn_options=conn_options + ) + + def stream( + self, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> stt.RecognizeStream: + return _BoundedStreamAdapterWrapper( + self, + vad=self._vad, + wrapped_stt=self._stt, + language=language, + conn_options=conn_options, + max_speech_duration=self._max_speech_duration, + pre_speech_duration=self._pre_speech_duration, + ) + + def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None: + self.emit("metrics_collected", *args, **kwargs) + + async def aclose(self) -> None: + self._stt.off("metrics_collected", self._on_metrics_collected) + + +class _BoundedStreamAdapterWrapper(stt.RecognizeStream): + def __init__( + self, + adapter: BoundedStreamAdapter, + *, + vad: VAD, + wrapped_stt: stt.STT, + language: NotGivenOr[str], + conn_options: APIConnectOptions, + max_speech_duration: float | None, + pre_speech_duration: float, + ) -> None: + super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS) + self._vad = vad + self._wrapped_stt = wrapped_stt + self._wrapped_stt_conn_options = conn_options + self._language = language + self._max_speech_duration = max_speech_duration + self._pre_speech_duration = pre_speech_duration + + async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None: + async for _ in event_aiter: + pass + + async def _run(self) -> None: + vad_stream = self._vad.stream() + lock = asyncio.Lock() + recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue() + + speech_active = False + segment_frames: list[rtc.AudioFrame] = [] + segment_duration = 0.0 + pre_roll_frames: deque[rtc.AudioFrame] = deque() + pre_roll_duration = 0.0 + + def _frame_duration(frame: rtc.AudioFrame) -> float: + return frame.samples_per_channel / frame.sample_rate + + def _append_pre_roll(frame: rtc.AudioFrame) -> None: + nonlocal pre_roll_duration + pre_roll_frames.append(frame) + pre_roll_duration += _frame_duration(frame) + + while pre_roll_duration > self._pre_speech_duration and pre_roll_frames: + pre_roll_duration -= _frame_duration(pre_roll_frames.popleft()) + + async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None: + if not frames: + return + + if forced: + logger.info( + "Forcing ASR segment after %.2fs of continuous speech", + self._max_speech_duration, + ) + + self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)) + await recognize_queue.put(frames) + + async def _recognize_worker() -> None: + while True: + frames = await recognize_queue.get() + if frames is None: + return + + merged_frames = utils.merge_frames(frames) + try: + t_event = await self._wrapped_stt.recognize( + buffer=merged_frames, + language=self._language, + conn_options=self._wrapped_stt_conn_options, + ) + except Exception: + logger.exception("ASR segment recognition failed") + continue + + if not t_event.alternatives or not t_event.alternatives[0].text: + continue + + self._event_ch.send_nowait( + stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[t_event.alternatives[0]], + ) + ) + + async def _forward_input() -> None: + nonlocal segment_duration, segment_frames, speech_active + + async for input_frame in self._input_ch: + if isinstance(input_frame, self._FlushSentinel): + vad_stream.flush() + continue + + vad_stream.push_frame(input_frame) + forced_frames: list[rtc.AudioFrame] = [] + + async with lock: + if speech_active: + segment_frames.append(input_frame) + segment_duration += _frame_duration(input_frame) + if ( + self._max_speech_duration is not None + and segment_duration >= self._max_speech_duration + ): + forced_frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + else: + _append_pre_roll(input_frame) + + if forced_frames: + await _enqueue_segment(forced_frames, forced=True) + + vad_stream.end_input() + + final_frames: list[rtc.AudioFrame] = [] + async with lock: + if speech_active and segment_frames: + final_frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + speech_active = False + + if final_frames: + await _enqueue_segment(final_frames) + + async def _recognize_from_vad() -> None: + nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active + + async for event in vad_stream: + if event.type == VADEventType.START_OF_SPEECH: + self._event_ch.send_nowait( + stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH) + ) + async with lock: + if not speech_active: + speech_active = True + segment_frames = list(pre_roll_frames) + segment_duration = sum(_frame_duration(f) for f in segment_frames) + pre_roll_frames.clear() + pre_roll_duration = 0.0 + continue + + if event.type != VADEventType.END_OF_SPEECH: + continue + + async with lock: + frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + speech_active = False + + await _enqueue_segment(frames) + + worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize") + tasks = [ + asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"), + asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"), + ] + try: + await asyncio.gather(*tasks) + await recognize_queue.put(None) + await worker_task + finally: + await utils.aio.cancel_and_wait(*tasks, worker_task) + await vad_stream.aclose() diff --git a/beaver_llm.py b/beaver_llm.py new file mode 100644 index 0000000..1050f3e --- /dev/null +++ b/beaver_llm.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Sequence +from typing import Any + +try: + from beaver_terminal_client import BeaverTerminalClient +except ModuleNotFoundError: + from custom.beaver_terminal_client import BeaverTerminalClient +from livekit.agents import llm +from livekit.agents.types import ( + DEFAULT_API_CONNECT_OPTIONS, + NOT_GIVEN, + APIConnectOptions, + NotGivenOr, +) +from livekit.agents.utils import shortuuid + +logger = logging.getLogger("beaver-llm") + + +def latest_user_text(chat_ctx: llm.ChatContext) -> str: + for message in reversed(chat_ctx.messages()): + if message.role != "user": + continue + return _content_to_text(message.content) + return "" + + +def _content_to_text(content: Sequence[llm.ChatContent]) -> str: + text_parts = [item for item in content if isinstance(item, str)] + return "\n".join(text_parts) + + +class BeaverLLM(llm.LLM): + def __init__( + self, + *, + url: str, + peer_id: str, + device_name: str, + model_name: str = "beaver-terminal", + ) -> None: + super().__init__() + self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name) + self._model_name = model_name + self._lock = asyncio.Lock() + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "beaver" + + @property + def session_id(self) -> str | None: + return self._client.session_id + + async def connect(self, *, warmup_text: str | None = None) -> str | None: + warmup_reply: str | None = None + async with self._lock: + await self._client.connect() + if warmup_text and warmup_text.strip(): + warmup_reply = await self._client.send_text(warmup_text.strip()) + + if warmup_reply is None: + logger.info("Beaver handshake completed session_id=%s", self.session_id) + else: + logger.info( + "Beaver handshake warmup completed session_id=%s reply_len=%s", + self.session_id, + len(warmup_reply), + ) + return warmup_reply + + def chat( + self, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool] | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN, + tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN, + extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + ) -> llm.LLMStream: + return BeaverLLMStream( + self, + chat_ctx=chat_ctx, + tools=tools or [], + conn_options=conn_options, + ) + + async def aclose(self) -> None: + await self._client.close() + + +class BeaverLLMStream(llm.LLMStream): + def __init__( + self, + beaver_llm: BeaverLLM, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool], + conn_options: APIConnectOptions, + ) -> None: + super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) + self._beaver_llm = beaver_llm + self._request_id = shortuuid("beaver_") + + async def _run(self) -> None: + user_text = latest_user_text(self.chat_ctx) + async with self._beaver_llm._lock: + reply = await self._beaver_llm._client.send_text(user_text) + + if reply: + self._send_text_chunk(reply) + + def _send_text_chunk(self, text: str) -> None: + self._event_ch.send_nowait( + llm.ChatChunk( + id=self._request_id, + delta=llm.ChoiceDelta(role="assistant", content=text), + ) + ) diff --git a/beaver_terminal_client.py b/beaver_terminal_client.py new file mode 100644 index 0000000..b209da6 --- /dev/null +++ b/beaver_terminal_client.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import asyncio +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import aiohttp +from dotenv import load_dotenv + +logger = logging.getLogger("beaver-terminal-client") +DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws" +DEFAULT_TERMINAL_PEER_ID = "device-001" +DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal" +CUSTOM_ENV_PATH = Path(__file__).with_name(".env") + + +class BeaverTerminalError(RuntimeError): + pass + + +class BeaverTerminalConnectionClosed(BeaverTerminalError): + pass + + +@dataclass +class MessageIdGenerator: + peer_id: str + nonce: str | None = None + initial_counter: int = 0 + + def __post_init__(self) -> None: + self.counter = self.initial_counter + + def next_id(self) -> str: + self.counter += 1 + if self.nonce: + return f"{self.peer_id}-{self.nonce}-{self.counter:06d}" + return f"{self.peer_id}-{self.counter:06d}" + + +def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]: + return { + "type": "connect", + "peer_id": peer_id, + "device_name": device_name, + "capabilities": ["text"], + } + + +def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]: + return { + "type": "message", + "message_id": message_id, + "text": text, + } + + +class BeaverTerminalClient: + def __init__( + self, + *, + url: str, + peer_id: str, + device_name: str, + http_session: aiohttp.ClientSession | None = None, + message_ids: MessageIdGenerator | None = None, + ) -> None: + self._url = url + self._peer_id = peer_id + self._device_name = device_name + self._owned_session = http_session is None + self._http_session = http_session + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id) + self.session_id: str | None = None + + async def connect(self) -> None: + await self._close_websocket() + session = self._ensure_http_session() + try: + self._ws = await session.ws_connect(self._url) + await self._send_json( + build_connect_frame(peer_id=self._peer_id, device_name=self._device_name) + ) + frame = await self._receive_json() + if frame.get("type") != "connected": + raise BeaverTerminalError(f"expected connected frame, received {frame!r}") + session_id = frame.get("session_id") + self.session_id = session_id if isinstance(session_id, str) else None + except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc: + await self._cleanup_failed_connection() + raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc + except Exception: + await self._cleanup_failed_connection() + raise + + async def send_text(self, text: str) -> str: + for attempt in range(2): + if not self._websocket_is_open(): + await self.connect() + + message_id = self._message_ids.next_id() + message_frame = build_message_frame(message_id=message_id, text=text) + + try: + await self._send_json(message_frame) + return await self._wait_for_reply(message_id) + except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc: + if attempt == 1: + raise BeaverTerminalConnectionClosed( + "Beaver websocket closed before assistant reply" + ) from exc + logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id") + await self.connect() + + raise BeaverTerminalError("unreachable Beaver send state") + + async def _wait_for_reply(self, message_id: str) -> str: + while True: + frame = await self._receive_json() + frame_type = frame.get("type") + if frame_type == "ack" and frame.get("message_id") == message_id: + reply = frame.get("reply") + if isinstance(reply, str): + return reply + continue + + if ( + frame_type == "message" + and frame.get("role") == "assistant" + and frame.get("message_id") == message_id + ): + text = frame.get("text") + if frame.get("finish_reason") == "error": + raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed") + return text if isinstance(text, str) else "" + + if frame_type == "error": + error = frame.get("error") + raise BeaverTerminalError(error if isinstance(error, str) else "unknown error") + + async def ping(self) -> bool: + await self._send_json({"type": "ping"}) + while True: + frame = await self._receive_json() + if frame.get("type") == "pong": + return True + if frame.get("type") == "error": + error = frame.get("error") + raise BeaverTerminalError(error if isinstance(error, str) else "unknown error") + + async def close(self) -> None: + await self._close_websocket() + if self._owned_session and self._http_session is not None: + await self._http_session.close() + self._http_session = None + + async def _close_websocket(self) -> None: + if self._ws is not None: + await self._ws.close() + self._ws = None + + async def _cleanup_failed_connection(self) -> None: + await self._close_websocket() + self.session_id = None + if self._owned_session and self._http_session is not None: + await self._http_session.close() + self._http_session = None + + def _websocket_is_open(self) -> bool: + return self._ws is not None and not self._ws.closed + + def _ensure_http_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def _send_json(self, frame: dict[str, Any]) -> None: + if self._ws is None: + raise BeaverTerminalError("Beaver websocket is not connected") + await self._ws.send_json(frame) + + async def _receive_json(self) -> dict[str, Any]: + if self._ws is None: + raise BeaverTerminalError("Beaver websocket is not connected") + + message = await self._ws.receive() + if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): + raise BeaverTerminalConnectionClosed("Beaver websocket closed") + if message.type == aiohttp.WSMsgType.ERROR: + raise BeaverTerminalConnectionClosed( + f"Beaver websocket error: {self._ws.exception()!r}" + ) + if message.type != aiohttp.WSMsgType.TEXT: + raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}") + data = message.json() + if not isinstance(data, dict): + raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}") + return data + + +def client_from_env() -> BeaverTerminalClient: + load_dotenv(dotenv_path=CUSTOM_ENV_PATH) + return BeaverTerminalClient( + url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL), + peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID), + device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME), + ) + + +async def run_console() -> None: + logging.basicConfig(level=logging.INFO) + client = client_from_env() + try: + await client.connect() + logger.info("Connected to Beaver session_id=%s", client.session_id) + while True: + text = await asyncio.to_thread(input, "> ") + text = text.strip() + if not text: + continue + if text in {"quit", "exit"}: + return + try: + reply = await client.send_text(text) + except BeaverTerminalError as exc: + logger.error("Beaver turn failed: %s", exc) + continue + print(reply) + finally: + await client.close() + + +if __name__ == "__main__": + asyncio.run(run_console()) diff --git a/custom_agent.py b/custom_agent.py index bb98a1e..840efe1 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import logging @@ -8,11 +9,14 @@ from collections.abc import AsyncIterable from dataclasses import dataclass from pathlib import Path +from beaver_llm import BeaverLLM +from beaver_terminal_client import BeaverTerminalError from dotenv import load_dotenv +from hermes_gateway import GatewaySessionState, HermesGatewayLLM from memory import MemoryRecallClient from tts import BlackboxTTS -from asr import BlackboxSTT +from asr import BlackboxSTT, BoundedStreamAdapter from livekit.agents import ( Agent, AgentServer, @@ -30,7 +34,6 @@ from livekit.agents import ( llm, metrics, room_io, - stt, ) from livekit.agents.voice.generation import update_instructions as update_chat_instructions from livekit.plugins import openai, silero @@ -40,7 +43,6 @@ logger = logging.getLogger("custom-agent") CUSTOM_ENV_PATH = Path(__file__).with_name(".env") load_dotenv(dotenv_path=CUSTOM_ENV_PATH) -AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "") ROOM_LOCATOR_INSTRUCTIONS = """ 你是一个房间物品定位助手。 @@ -60,7 +62,7 @@ GENERAL_INSTRUCTIONS = """ EMOTION_INSTRUCTIONS = """ 每次回复必须先输出一个情绪标签,格式严格为: -emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。 +emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。 情绪标签之后直接输出给用户的正常回复,不要解释标签。 """.strip() @@ -69,18 +71,121 @@ GENERAL_MODE = "general" VOICE_INPUT_MODE = "voice" VISION_VOICE_INPUT_MODE = "vision_voice" AUTO_INPUT_MODE = "auto" +INTERRUPT_TOPIC = "lk.interrupt" VISION_FRAME_TOPIC = "vision.frame" +DEFAULT_AGENT_PROFILE = "normal" + + +@dataclass(frozen=True) +class AgentProfile: + agent_name: str + llm_provider: str + input_mode: str + + +AGENT_PROFILES = { + "normal": AgentProfile( + agent_name="normal-agent", + llm_provider="openai-compatible", + input_mode=VOICE_INPUT_MODE, + ), + "beaver": AgentProfile( + agent_name="beaver-agent", + llm_provider="beaver", + input_mode=VOICE_INPUT_MODE, + ), + "vision-normal": AgentProfile( + agent_name="vision-normal-agent", + llm_provider="openai-compatible", + input_mode=VISION_VOICE_INPUT_MODE, + ), + "vision-beaver": AgentProfile( + agent_name="vision-beaver-agent", + llm_provider="beaver", + input_mode=VISION_VOICE_INPUT_MODE, + ), +} +AGENT_PROFILE_ALIASES = { + "default": "normal", + "openai": "normal", + "openai-compatible": "normal", + "llm": "normal", + "text": "normal", + "voice": "normal", + "vision": "vision-normal", + "vision-llm": "vision-normal", + "vision-openai": "vision-normal", + "vision-openai-compatible": "vision-normal", +} + + +def _normalize_agent_profile(value: str | None) -> str: + if not value or not value.strip(): + return DEFAULT_AGENT_PROFILE + + normalized = value.strip().lower().replace("_", "-") + profile = AGENT_PROFILE_ALIASES.get(normalized, normalized) + if profile in AGENT_PROFILES: + return profile + + logger.warning( + "Invalid CUSTOM_AGENT_PROFILE=%r, using %s", + value, + DEFAULT_AGENT_PROFILE, + ) + return DEFAULT_AGENT_PROFILE + + +def _agent_profile_from_name(agent_name: str | None) -> str | None: + if not agent_name or not agent_name.strip(): + return None + + normalized = agent_name.strip().lower().replace("_", "-") + for profile_name, profile in AGENT_PROFILES.items(): + if normalized == profile.agent_name: + return profile_name + return None + + +def _selected_agent_profile_name() -> str: + configured_profile = os.getenv("CUSTOM_AGENT_PROFILE") + if configured_profile and configured_profile.strip(): + return _normalize_agent_profile(configured_profile) + + inferred_profile = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME")) + if inferred_profile is not None: + return inferred_profile + + return DEFAULT_AGENT_PROFILE + + +AGENT_PROFILE_NAME = _selected_agent_profile_name() +AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME] +AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name DEFAULT_EMOTION = "neutral" EMOTION_LABELS = { - "neutral", - "happy", - "sad", "angry", + "confident", + "confused", + "cool", + "crying", + "delicious", + "embarrassed", + "funny", + "happy", + "kissy", + "laughing", + "loving", + "neutral", + "relaxed", + "sad", + "shocked", + "silly", + "sleepy", "surprised", - "fearful", - "calm", - "concerned", + "thinking", + "winking", } EMOTION_PREFIX_RE = re.compile(r"^\s*\s*", re.IGNORECASE) TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE) @@ -290,6 +395,7 @@ class CustomAgent(Agent): yield chunk return _stream() + async def _observe_emotion_prefix( self, chunk: llm.ChatChunk | str | FlushSentinel ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: @@ -545,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s return chat_ctx -def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext: +def _with_vision_as_latest_user_message( + chat_ctx: ChatContext, vision_frame: VisionFrame +) -> ChatContext: chat_ctx = chat_ctx.copy() image_content = llm.ImageContent( image=vision_frame.image_data_url, @@ -612,7 +720,25 @@ def _model_image_save_dir_from_env() -> Path | None: return Path(__file__).with_name("model_images") -server = AgentServer() +def _agent_server_from_env() -> AgentServer: + configured_port = os.getenv("CUSTOM_AGENT_HTTP_PORT") + if configured_port is None: + return AgentServer() + + try: + port = int(configured_port) + except ValueError: + logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port) + port = 0 + + if port < 0 or port > 65535: + logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port) + port = 0 + + return AgentServer(port=port) + + +server = _agent_server_from_env() def prewarm(proc: JobProcess) -> None: @@ -634,16 +760,37 @@ async def entrypoint(ctx: JobContext) -> None: ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice") ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto") ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh") + ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0) LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL") LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max") LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY") + LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower() TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL) VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL) - INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) - if not LLM_API_KEY: + INPUT_MODE = _normalize_input_mode( + os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode) + ) + if LLM_PROVIDER not in { + "openai", + "openai-compatible", + "hermes", + "hermes_gateway", + "openclaw", + "beaver", + }: + raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}") + if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") - logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default") + logger.info( + "Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s", + AGENT_PROFILE_NAME, + AGENT_NAME or "", + INPUT_MODE, + LLM_PROVIDER, + LLM_MODEL, + LLM_BASE_URL or "OpenAI default", + ) TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox" @@ -656,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None: MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0) MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000) MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None + beaver_warmup_text: str | None = None blackbox_stt = BlackboxSTT( url=ASR_URL, @@ -666,40 +814,117 @@ async def entrypoint(ctx: JobContext) -> None: itn=os.getenv("CUSTOM_ASR_ITN"), chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"), ) - stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"]) + stt_stream = BoundedStreamAdapter( + stt=blackbox_stt, + vad=ctx.proc.userdata["vad"], + max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None, + ) + turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel() + allow_interruptions = _env_bool( + "CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR", + ASR_MAX_SPEECH_DURATION <= 0, + ) - import httpx - from openai import AsyncClient as OpenAIAsyncClient + if LLM_PROVIDER == "beaver": + beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL") + if not beaver_url: + raise RuntimeError( + f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}" + ) - # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL. - http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False)) + beaver_peer_id = ( + _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}" + ) + beaver_device_name = ( + _first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME") + or "livekit-custom-agent" + ) + base_llm = BeaverLLM( + url=beaver_url, + peer_id=beaver_peer_id, + device_name=beaver_device_name, + model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"), + ) + beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT") + text_llm = base_llm + vision_llm = base_llm + logger.info( + "Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s", + beaver_url, + beaver_peer_id, + beaver_device_name, + ctx.room.name, + bool(beaver_warmup_text and beaver_warmup_text.strip()), + ) + elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}: + gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip() + if not gateway_url: + raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}") - if LLM_BASE_URL: - openai_client = OpenAIAsyncClient( - api_key=LLM_API_KEY, - base_url=LLM_BASE_URL, - http_client=http_client, + hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None + hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower() + if hermes_session_mode != "per_room": + raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room") + hermes_token = ( + os.getenv("CUSTOM_HERMES_API_KEY") + or os.getenv("CUSTOM_HERMES_TOKEN") + or LLM_API_KEY + or None + ) + hermes_state = GatewaySessionState( + room_name=ctx.room.name, + agent_id=hermes_agent_id, + session_mode=hermes_session_mode, + ) + base_llm = HermesGatewayLLM( + url=gateway_url, + token=hermes_token, + state=hermes_state, + agent_id=hermes_agent_id, + model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"), + request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0), + ) + text_llm = base_llm + vision_llm = base_llm + logger.info( + "Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s", + gateway_url, + hermes_agent_id or "default", + hermes_state.session_key, ) else: - openai_client = OpenAIAsyncClient( - api_key=LLM_API_KEY, - http_client=http_client, - ) + import httpx + from openai import AsyncClient as OpenAIAsyncClient - base_llm = openai.LLM( - model=LLM_MODEL, - client=openai_client, - ) - text_llm = ( - openai.LLM(model=TEXT_LLM_MODEL, client=openai_client) - if TEXT_LLM_MODEL != LLM_MODEL - else base_llm - ) - vision_llm = ( - openai.LLM(model=VISION_LLM_MODEL, client=openai_client) - if VISION_LLM_MODEL != LLM_MODEL - else base_llm - ) + # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL. + http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False)) + + if LLM_BASE_URL: + openai_client = OpenAIAsyncClient( + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, + http_client=http_client, + ) + else: + openai_client = OpenAIAsyncClient( + api_key=LLM_API_KEY, + http_client=http_client, + ) + + base_llm = openai.LLM( + model=LLM_MODEL, + client=openai_client, + ) + text_llm = ( + openai.LLM(model=TEXT_LLM_MODEL, client=openai_client) + if TEXT_LLM_MODEL != LLM_MODEL + else base_llm + ) + vision_llm = ( + openai.LLM(model=VISION_LLM_MODEL, client=openai_client) + if VISION_LLM_MODEL != LLM_MODEL + else base_llm + ) vision_store = VisionFrameStore( max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0) ) @@ -707,7 +932,7 @@ async def entrypoint(ctx: JobContext) -> None: session: AgentSession = AgentSession( # 1. Custom ASR blackbox with StreamAdapter stt=stt_stream, - # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI. + # 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway. llm=base_llm, # 3. TTS blackbox tts=BlackboxTTS( @@ -721,13 +946,14 @@ async def entrypoint(ctx: JobContext) -> None: # 4. Silero VAD vad=ctx.proc.userdata["vad"], turn_handling=TurnHandlingOptions( - turn_detection=MultilingualModel(), + turn_detection=turn_detection, interruption={ + "enabled": allow_interruptions, "resume_false_interruption": True, "false_interruption_timeout": 1.0, }, ), - preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True), + preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", LLM_PROVIDER != "beaver"), aec_warmup_duration=3.0, tts_text_transforms=[ "filter_emoji", @@ -753,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None: @ctx.room.on("data_received") def _on_data_received(data_packet) -> None: packet_topic = getattr(data_packet, "topic", None) - if packet_topic not in {None, "", VISION_FRAME_TOPIC}: + + def _interrupt_done(fut: asyncio.Future[None]) -> None: + try: + fut.result() + except Exception: + logger.exception("Bridge interrupt failed") + + def _handle_interrupt(payload: dict[str, object] | None = None) -> None: + reason = None if payload is None else payload.get("reason") + logger.info( + "Received bridge interrupt: topic=%s reason=%s", + packet_topic or "", + reason if isinstance(reason, str) and reason else "", + ) + try: + interrupt_fut = session.interrupt(force=True) + except RuntimeError: + logger.exception("Bridge interrupt received before AgentSession was running") + return + interrupt_fut.add_done_callback(_interrupt_done) + + if packet_topic == INTERRUPT_TOPIC: + payload: dict[str, object] | None = None + try: + decoded = json.loads(data_packet.data.decode("utf-8")) + if isinstance(decoded, dict): + payload = decoded + except Exception: + logger.exception("Failed to decode interrupt payload") + _handle_interrupt(payload) return - if INPUT_MODE == VOICE_INPUT_MODE: - logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE) + if packet_topic not in {None, "", VISION_FRAME_TOPIC}: return try: payload = json.loads(data_packet.data.decode("utf-8")) except Exception: - logger.exception("Failed to decode vision frame payload") + logger.exception("Failed to decode data payload") + return + + if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC: + _handle_interrupt(payload) + return + + if INPUT_MODE == VOICE_INPUT_MODE: + logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE) return if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC: @@ -820,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None: ), record=_recording_options_from_env(), ) + if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM): + _start_beaver_background_warmup( + ctx=ctx, + beaver_llm=base_llm, + warmup_text=beaver_warmup_text, + ) + + +def _start_beaver_background_warmup( + *, + ctx: JobContext, + beaver_llm: BeaverLLM, + warmup_text: str | None, +) -> None: + async def _warmup() -> None: + try: + warmup_reply = await beaver_llm.connect(warmup_text=warmup_text) + except BeaverTerminalError: + logger.warning( + "Beaver background handshake failed; will retry on first user turn", + exc_info=True, + ) + return + except Exception: + logger.exception("Unexpected Beaver background handshake failure") + return + + logger.info( + "Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s", + ctx.room.name, + beaver_llm.session_id, + bool(warmup_text and warmup_text.strip()), + len(warmup_reply) if warmup_reply is not None else 0, + ) + + warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup") + + async def _cancel_warmup() -> None: + if warmup_task.done(): + return + warmup_task.cancel() + try: + await warmup_task + except asyncio.CancelledError: + pass + + ctx.add_shutdown_callback(_cancel_warmup) def _tts_params_from_env(model_name: str) -> dict[str, str]: @@ -917,6 +1226,14 @@ def _env_bool(name: str, default: bool) -> bool: return default +def _first_env(*names: str) -> str | None: + for name in names: + value = os.getenv(name) + if value and value.strip(): + return value.strip() + return None + + def _recording_options_from_env() -> RecordingOptions: return RecordingOptions( audio=_env_bool("CUSTOM_RECORD_AUDIO", False), diff --git a/hermes_gateway.py b/hermes_gateway.py new file mode 100644 index 0000000..f11b121 --- /dev/null +++ b/hermes_gateway.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import Any + +import aiohttp + +from livekit.agents import llm +from livekit.agents._exceptions import APIConnectionError +from livekit.agents.types import ( + DEFAULT_API_CONNECT_OPTIONS, + NOT_GIVEN, + APIConnectOptions, + NotGivenOr, +) +from livekit.agents.utils import shortuuid + + +@dataclass +class GatewaySessionState: + room_name: str + agent_id: str | None = None + session_key: str | None = None + session_mode: str = "per_room" + + def __post_init__(self) -> None: + if self.session_mode != "per_room": + raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room") + if self.session_key is None: + suffix = self.agent_id or "default" + self.session_key = f"livekit:{self.room_name}:{suffix}" + + +class HermesGatewayLLM(llm.LLM): + def __init__( + self, + *, + url: str, + token: str | None, + state: GatewaySessionState, + agent_id: str | None = None, + model_name: str = "hermes-agent", + request_timeout: float = 30.0, + ) -> None: + super().__init__() + self._url = url + self._token = token + self._state = state + self._agent_id = agent_id + self._model_name = model_name + self._request_timeout = request_timeout + self._http_session: aiohttp.ClientSession | None = None + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "hermes-gateway" + + def chat( + self, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool] | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN, + tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN, + extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + ) -> llm.LLMStream: + return HermesGatewayLLMStream( + self, + chat_ctx=chat_ctx, + tools=tools or [], + conn_options=conn_options, + ) + + def _ensure_http_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + timeout = aiohttp.ClientTimeout(total=self._request_timeout) + self._http_session = aiohttp.ClientSession(timeout=timeout) + return self._http_session + + async def aclose(self) -> None: + if self._http_session is not None: + await self._http_session.close() + self._http_session = None + + +class HermesGatewayLLMStream(llm.LLMStream): + def __init__( + self, + llm: HermesGatewayLLM, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool], + conn_options: APIConnectOptions, + ) -> None: + super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) + self._llm = llm + + async def _run(self) -> None: + request_id = shortuuid("gwreq_") + async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws: + await _connect(ws, token=self._llm._token) + await _send_rpc( + ws, + method="sessions.create", + params={ + "key": self._llm._state.session_key, + "sessionKey": self._llm._state.session_key, + "agentId": self._llm._agent_id, + "metadata": { + "source": "livekit", + "room": self._llm._state.room_name, + }, + "idempotencyKey": self._llm._state.session_key, + }, + request_id=shortuuid("gwcreate_"), + ) + await _send_rpc( + ws, + method="sessions.send", + params={ + "key": self._llm._state.session_key, + "sessionKey": self._llm._state.session_key, + "agentId": self._llm._agent_id, + "messages": chat_context_to_gateway_messages(self.chat_ctx), + "stream": True, + "idempotencyKey": request_id, + }, + request_id=request_id, + wait_response=False, + ) + + streamed_text = "" + async for frame in _iter_gateway_frames(ws): + if is_error_response(frame, request_id=request_id): + raise APIConnectionError(_gateway_error_message(frame), retryable=False) + text = extract_text_delta(frame) + if text: + if text.startswith(streamed_text): + text = text[len(streamed_text) :] + if text: + streamed_text += text + self._event_ch.send_nowait( + llm.ChatChunk( + id=request_id, + delta=llm.ChoiceDelta(role="assistant", content=text), + ) + ) + if is_terminal_event(frame, request_id=request_id): + return + + +def build_connect_params(*, token: str | None) -> dict[str, Any]: + params: dict[str, Any] = { + "minProtocol": 3, + "maxProtocol": 4, + "client": { + "id": "gateway-client", + "version": "livekit-custom-agent", + "platform": "python", + "mode": "backend", + }, + "role": "operator", + "scopes": ["operator.read", "operator.write"], + "caps": [], + "commands": [], + "permissions": {}, + "locale": "zh-CN", + "userAgent": "livekit-custom-agent", + } + if token: + params["auth"] = {"token": token} + return params + + +def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + for message in chat_ctx.messages(): + content = _message_content_to_gateway_content(message.content) + if content is None: + continue + messages.append({"role": message.role, "content": content}) + return messages + + +def extract_text_delta(frame: dict[str, Any]) -> str: + payload = frame.get("payload") + if not isinstance(payload, dict): + payload = frame + + for path in ( + ("delta", "content"), + ("delta", "text"), + ("message", "delta", "content"), + ("message", "delta", "text"), + ("message", "content"), + ("content",), + ("text",), + ): + value = _get_nested(payload, path) + text = _content_to_text(value) + if text: + return text + + return "" + + +def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool: + if frame.get("type") == "res" and frame.get("id") == request_id: + return True + + event = frame.get("event") + if event in { + "agent.done", + "agent.completed", + "agent.error", + "session.message.completed", + "session.run.completed", + "sessions.run.completed", + "run.completed", + "run.failed", + }: + return True + + payload = frame.get("payload") + if isinstance(payload, dict) and payload.get("done") is True: + return True + + return False + + +def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool: + if ( + frame.get("type") == "res" + and frame.get("id") == request_id + and frame.get("ok") is False + ): + return True + + return frame.get("event") in { + "agent.error", + "session.error", + "session.run.failed", + "sessions.run.failed", + "run.failed", + } + + +def _gateway_error_message(frame: dict[str, Any]) -> str: + error = frame.get("error") + if isinstance(error, str): + return f"OpenClaw gateway request failed: {error}" + if isinstance(error, dict): + message = error.get("message") or error.get("error") + if isinstance(message, str): + return f"OpenClaw gateway request failed: {message}" + + payload = frame.get("payload") + if isinstance(payload, dict): + message = payload.get("message") or payload.get("error") + if isinstance(message, str): + return f"OpenClaw gateway request failed: {message}" + + return f"OpenClaw gateway request failed: {frame!r}" + + +async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None: + first = await _receive_json(ws) + if first.get("event") != "connect.challenge": + raise RuntimeError(f"expected connect.challenge, received {first!r}") + + request_id = shortuuid("gwconnect_") + await _send_rpc( + ws, + method="connect", + params=build_connect_params(token=token), + request_id=request_id, + wait_response=False, + ) + response = await _wait_for_response(ws, request_id=request_id) + if not response.get("ok"): + raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}") + + +async def _send_rpc( + ws: aiohttp.ClientWebSocketResponse, + *, + method: str, + params: dict[str, Any], + request_id: str, + wait_response: bool = True, +) -> dict[str, Any] | None: + await ws.send_str( + json.dumps( + { + "type": "req", + "id": request_id, + "method": method, + "params": _drop_none(params), + } + ) + ) + if not wait_response: + return None + response = await _wait_for_response(ws, request_id=request_id) + if not response.get("ok", False): + raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}") + return response + + +async def _wait_for_response( + ws: aiohttp.ClientWebSocketResponse, *, request_id: str +) -> dict[str, Any]: + async for frame in _iter_gateway_frames(ws): + if frame.get("type") == "res" and frame.get("id") == request_id: + return frame + raise RuntimeError(f"OpenClaw gateway closed before response {request_id}") + + +async def _iter_gateway_frames( + ws: aiohttp.ClientWebSocketResponse, +) -> AsyncIterator[dict[str, Any]]: + async for message in ws: + if message.type == aiohttp.WSMsgType.TEXT: + yield json.loads(message.data) + elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): + return + elif message.type == aiohttp.WSMsgType.ERROR: + raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}") + + +async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]: + message = await ws.receive() + if message.type != aiohttp.WSMsgType.TEXT: + raise RuntimeError(f"expected gateway text frame, received {message.type!r}") + return json.loads(message.data) + + +def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any: + parts: list[dict[str, Any]] = [] + for item in content: + if isinstance(item, str): + if item: + parts.append({"type": "text", "text": item}) + elif isinstance(item, llm.ImageContent) and isinstance(item.image, str): + parts.append({"type": "image_url", "image_url": {"url": item.image}}) + + if not parts: + return None + if len(parts) == 1 and parts[0]["type"] == "text": + return parts[0]["text"] + return parts + + +def _content_to_text(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, list): + text_parts: list[str] = [] + for item in value: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + text_parts.append(text) + return "".join(text_parts) + return "" + + +def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: + current: Any = data + for key in path: + if not isinstance(current, dict): + return None + current = current.get(key) + return current + + +def _drop_none(value: Any) -> Any: + if isinstance(value, dict): + return {key: _drop_none(item) for key, item in value.items() if item is not None} + if isinstance(value, list): + return [_drop_none(item) for item in value] + return value diff --git a/start_agent_profiles.py b/start_agent_profiles.py new file mode 100644 index 0000000..c6a109b --- /dev/null +++ b/start_agent_profiles.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import argparse +import asyncio +import os +import signal +import sys +from dataclasses import dataclass +from pathlib import Path + +from dotenv import load_dotenv + +CUSTOM_ENV_PATH = Path(__file__).with_name(".env") +load_dotenv(dotenv_path=CUSTOM_ENV_PATH) + + +@dataclass(frozen=True) +class AgentProfile: + agent_name: str + llm_provider: str + input_mode: str + + +AGENT_PROFILES = { + "normal": AgentProfile( + agent_name="normal-agent", + llm_provider="openai-compatible", + input_mode="voice", + ), + "beaver": AgentProfile( + agent_name="beaver-agent", + llm_provider="beaver", + input_mode="voice", + ), + "vision-normal": AgentProfile( + agent_name="vision-normal-agent", + llm_provider="openai-compatible", + input_mode="vision_voice", + ), + "vision-beaver": AgentProfile( + agent_name="vision-beaver-agent", + llm_provider="beaver", + input_mode="vision_voice", + ), +} +DEFAULT_PROFILES = ("normal", "beaver") + + +def _env_bool(name: str, default: bool) -> bool: + value = os.getenv(name) + if value is None: + return default + + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + + return default + + +def _parse_profiles(value: str) -> list[str]: + if value.strip().lower() == "all": + return list(AGENT_PROFILES) + + profiles: list[str] = [] + for raw_profile in value.split(","): + profile = raw_profile.strip().lower().replace("_", "-") + if not profile: + continue + if profile not in AGENT_PROFILES: + valid = ", ".join([*AGENT_PROFILES, "all"]) + raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}") + profiles.append(profile) + + if not profiles: + raise ValueError("at least one profile is required") + return list(dict.fromkeys(profiles)) + + +def _parse_http_port_base(value: str | None) -> int: + if value is None or not value.strip(): + return 0 + + try: + port = int(value) + except ValueError as exc: + raise ValueError(f"invalid HTTP port base {value!r}") from exc + + if port < 0 or port > 65535: + raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535") + return port + + +def _profile_http_port(http_port_base: int, index: int) -> int: + if http_port_base == 0: + return 0 + + port = http_port_base + index + if port > 65535: + raise ValueError("HTTP port range exceeds 65535") + return port + + +def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]: + profile = AGENT_PROFILES[profile_name] + env = os.environ.copy() + env.update( + { + "CUSTOM_AGENT_PROFILE": profile_name, + "CUSTOM_AGENT_NAME": profile.agent_name, + "CUSTOM_LLM_PROVIDER": profile.llm_provider, + "CUSTOM_AGENT_INPUT_MODE": profile.input_mode, + "CUSTOM_AGENT_HTTP_PORT": str(http_port), + } + ) + return env + + +async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None: + while line := await stream.readline(): + text = line.decode("utf-8", errors="replace").rstrip() + print(f"[{prefix}] {text}", flush=True) + + +async def _start_profile( + profile_name: str, + *, + mode: str, + http_port: int, + reload: bool, +) -> asyncio.subprocess.Process: + profile = AGENT_PROFILES[profile_name] + script_path = Path(__file__).with_name("custom_agent.py") + args = [sys.executable, str(script_path), mode] + if mode == "dev" and not reload: + args.append("--no-reload") + + process = await asyncio.create_subprocess_exec( + *args, + env=_child_env(profile_name, http_port=http_port), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + print( + f"started {profile_name}: pid={process.pid} agent={profile.agent_name} " + f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}", + flush=True, + ) + if process.stdout is not None: + asyncio.create_task(_pipe_output(profile_name, process.stdout)) + if process.stderr is not None: + asyncio.create_task(_pipe_output(profile_name, process.stderr)) + return process + + +async def _terminate(processes: list[asyncio.subprocess.Process]) -> None: + for process in processes: + if process.returncode is None: + process.terminate() + + try: + await asyncio.wait_for( + asyncio.gather(*(process.wait() for process in processes)), + timeout=10.0, + ) + except asyncio.TimeoutError: + for process in processes: + if process.returncode is None: + process.kill() + await asyncio.gather(*(process.wait() for process in processes)) + + +async def _run( + profiles: list[str], + *, + mode: str, + http_port_base: int, + reload: bool, +) -> int: + processes = [ + await _start_profile( + profile, + mode=mode, + http_port=_profile_http_port(http_port_base, index), + reload=reload, + ) + for index, profile in enumerate(profiles) + ] + stop_event = asyncio.Event() + loop = asyncio.get_running_loop() + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, stop_event.set) + + wait_tasks = {asyncio.create_task(process.wait()): process for process in processes} + stop_task = asyncio.create_task(stop_event.wait()) + done, _ = await asyncio.wait( + [*wait_tasks, stop_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + exit_code = 0 + if stop_task not in done: + for task in done: + process = wait_tasks[task] + exit_code = task.result() or 0 + print(f"agent profile process exited: pid={process.pid} code={exit_code}", flush=True) + + await _terminate(processes) + return exit_code + + +def main() -> None: + parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.") + parser.add_argument( + "mode", + nargs="?", + default="dev", + choices=("console", "dev", "start", "connect"), + help="custom_agent.py CLI mode to run for each profile", + ) + parser.add_argument( + "--profiles", + default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)), + help="comma-separated profiles to start, or 'all'", + ) + parser.add_argument( + "--http-port-base", + default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"), + help=( + "base HTTP health-check port for profile workers; " + "0 lets the OS assign free ports" + ), + ) + parser.add_argument( + "--reload", + action="store_true", + default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False), + help="enable auto-reload in dev mode", + ) + args = parser.parse_args() + + try: + profiles = _parse_profiles(args.profiles) + http_port_base = _parse_http_port_base(args.http_port_base) + except ValueError as exc: + parser.error(str(exc)) + + raise SystemExit( + asyncio.run( + _run( + profiles, + mode=args.mode, + http_port_base=http_port_base, + reload=args.reload, + ) + ) + ) + + +if __name__ == "__main__": + main() diff --git a/test_beaver_llm.py b/test_beaver_llm.py new file mode 100644 index 0000000..fe6f6f9 --- /dev/null +++ b/test_beaver_llm.py @@ -0,0 +1,169 @@ +import asyncio +import json + +import aiohttp +from aiohttp import web + +try: + from custom.beaver_llm import BeaverLLM, latest_user_text +except ModuleNotFoundError: + from beaver_llm import BeaverLLM, latest_user_text +from livekit.agents import ChatContext + + +def test_latest_user_text_uses_most_recent_user_message() -> None: + ctx = ChatContext.empty() + ctx.add_message(role="user", content="first") + ctx.add_message(role="assistant", content="ignored") + ctx.add_message(role="user", content=["second", "line"]) + + assert latest_user_text(ctx) == "second\nline" + + +async def test_beaver_llm_sends_latest_user_text_and_returns_reply( + unused_tcp_port: int, +) -> None: + received: list[dict[str, object]] = [] + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + frame = json.loads(message.data) + received.append(frame) + + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:livekit-room", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:livekit-room", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "beaver reply", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + beaver_llm = BeaverLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="livekit-room", + device_name="livekit-custom-agent", + ) + ctx = ChatContext.empty() + ctx.add_message(role="system", content="ignored instructions") + ctx.add_message(role="user", content="hello beaver") + + try: + collected = await beaver_llm.chat(chat_ctx=ctx).collect() + finally: + await beaver_llm.aclose() + await runner.cleanup() + + assert collected.text == "beaver reply" + assert received[0] == { + "type": "connect", + "peer_id": "livekit-room", + "device_name": "livekit-custom-agent", + "capabilities": ["text"], + } + assert received[1]["type"] == "message" + assert received[1]["message_id"].startswith("livekit-room-") + assert received[1]["message_id"].endswith("-000001") + assert received[1]["text"] == "hello beaver" + + +async def test_beaver_llm_waits_for_slow_reply_without_placeholder( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + frame = json.loads(message.data) + + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:livekit-room", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:livekit-room", + "accepted": True, + } + ) + await asyncio.sleep(0.05) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "beaver reply", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + beaver_llm = BeaverLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="livekit-room", + device_name="livekit-custom-agent", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="hello beaver") + + try: + chunks: list[str] = [] + async with beaver_llm.chat(chat_ctx=ctx) as stream: + async for chunk in stream: + if chunk.delta and chunk.delta.content: + chunks.append(chunk.delta.content) + finally: + await beaver_llm.aclose() + await runner.cleanup() + + assert chunks == ["beaver reply"] diff --git a/test_beaver_terminal_client.py b/test_beaver_terminal_client.py new file mode 100644 index 0000000..1f8a755 --- /dev/null +++ b/test_beaver_terminal_client.py @@ -0,0 +1,458 @@ +import asyncio +import json +import sys +from pathlib import Path + +import aiohttp +import pytest +from aiohttp import web + +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + raise SystemExit(pytest.main([__file__])) + +try: + from custom.beaver_terminal_client import ( + BeaverTerminalClient, + BeaverTerminalConnectionClosed, + BeaverTerminalError, + MessageIdGenerator, + build_connect_frame, + build_message_frame, + ) +except ModuleNotFoundError: + from beaver_terminal_client import ( + BeaverTerminalClient, + BeaverTerminalConnectionClosed, + BeaverTerminalError, + MessageIdGenerator, + build_connect_frame, + build_message_frame, + ) + + +def test_build_connect_frame_uses_stable_peer_id() -> None: + frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal") + + assert frame == { + "type": "connect", + "peer_id": "device-001", + "device_name": "desk-terminal", + "capabilities": ["text"], + } + + +def test_build_message_frame_uses_message_id_and_text() -> None: + frame = build_message_frame(message_id="device-001-000001", text="hello") + + assert frame == { + "type": "message", + "message_id": "device-001-000001", + "text": "hello", + } + + +def test_message_id_generator_uses_monotonic_peer_counter() -> None: + generator = MessageIdGenerator(peer_id="device-001", initial_counter=7) + + assert generator.next_id() == "device-001-000008" + assert generator.next_id() == "device-001-000009" + assert generator.counter == 9 + + +def test_message_id_generator_can_include_nonce() -> None: + generator = MessageIdGenerator(peer_id="device-001", nonce="run12345") + + assert generator.next_id() == "device-001-run12345-000001" + assert generator.next_id() == "device-001-run12345-000002" + + +async def test_client_connects_sends_text_and_returns_assistant_reply( + unused_tcp_port: int, +) -> None: + received: list[dict[str, object]] = [] + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + frame = json.loads(message.data) + received.append(frame) + + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:device-001", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "assistant reply", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + reply = await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + assert client.session_id == "terminal-dev:local:device-001" + assert reply == "assistant reply" + assert received == [ + { + "type": "connect", + "peer_id": "device-001", + "device_name": "desk-terminal", + "capabilities": ["text"], + }, + { + "type": "message", + "message_id": "device-001-000001", + "text": "hello", + }, + ] + + +async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:device-001", + "accepted": False, + "duplicate": True, + "pending": False, + "reply": "cached assistant reply", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + reply = await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + assert reply == "cached assistant reply" + + +async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "message": + await ws.send_json({"type": "error", "error": "text is required"}) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + with pytest.raises(BeaverTerminalError, match="text is required"): + await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + +async def test_client_cleans_up_owned_session_when_connect_fails( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.Response: + return web.Response(status=200, text="not a websocket") + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"): + await client.connect() + assert client._http_session is None + assert client._ws is None + finally: + await client.close() + await runner.cleanup() + + +async def test_client_treats_assistant_finish_reason_error_as_failed_turn( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:device-001", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "failed turn", + "finish_reason": "error", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + with pytest.raises(BeaverTerminalError, match="failed turn"): + await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + +async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "ping": + await ws.send_json({"type": "pong"}) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + assert await client.ping() + finally: + await client.close() + await runner.cleanup() + + +async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send( + unused_tcp_port: int, +) -> None: + connect_peer_ids: list[str] = [] + message_ids: list[str] = [] + connection_count = 0 + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + nonlocal connection_count + connection_count += 1 + current_connection = connection_count + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + connect_peer_ids.append(frame["peer_id"]) + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + elif frame["type"] == "message": + message_ids.append(frame["message_id"]) + if current_connection == 1: + await ws.close() + continue + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:device-001", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-2", + "text": "reply after reconnect", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + await client.connect() + await asyncio.sleep(0.01) + reply = await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + assert reply == "reply after reconnect" + assert connect_peer_ids == ["device-001", "device-001"] + assert message_ids == ["device-001-000001", "device-001-000002"] diff --git a/test_hermes_gateway.py b/test_hermes_gateway.py new file mode 100644 index 0000000..739e202 --- /dev/null +++ b/test_hermes_gateway.py @@ -0,0 +1,262 @@ +import json + +import aiohttp +import pytest +from aiohttp import web + +from custom.hermes_gateway import ( + GatewaySessionState, + HermesGatewayLLM, + build_connect_params, + chat_context_to_gateway_messages, + extract_text_delta, + is_error_response, + is_terminal_event, +) +from livekit.agents import ChatContext, llm +from livekit.agents._exceptions import APIConnectionError + + +def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None: + ctx = ChatContext.empty() + ctx.add_message(role="system", content="system prompt") + ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")]) + + messages = chat_context_to_gateway_messages(ctx) + + assert messages == [ + {"role": "system", "content": "system prompt"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "look here"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + }, + ] + + +def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None: + assert ( + extract_text_delta( + {"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}} + ) + == "hi" + ) + assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there" + assert ( + extract_text_delta( + { + "type": "event", + "event": "session.message.delta", + "payload": {"message": {"content": [{"type": "text", "text": "!"}]}}, + } + ) + == "!" + ) + + +def test_per_room_session_state_reuses_stable_session_key() -> None: + state = GatewaySessionState(room_name="kitchen-room", agent_id="helper") + + assert state.session_key == "livekit:kitchen-room:helper" + state.session_key = "gateway-session-123" + assert state.session_key == "gateway-session-123" + + +def test_build_connect_params_uses_backend_operator_defaults() -> None: + params = build_connect_params(token="secret-token") + + assert params["client"] == { + "id": "gateway-client", + "version": "livekit-custom-agent", + "platform": "python", + "mode": "backend", + } + assert params["role"] == "operator" + assert params["scopes"] == ["operator.read", "operator.write"] + assert params["auth"] == {"token": "secret-token"} + assert "device" not in params + + +def test_gateway_response_helpers_match_only_current_send_request() -> None: + assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1") + assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1") + assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1") + assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1") + + +def test_hermes_llm_reports_provider_and_model() -> None: + state = GatewaySessionState(room_name="kitchen", agent_id="helper") + gateway_llm = HermesGatewayLLM( + url="ws://gateway.test/ws", + token="token", + state=state, + agent_id="helper", + model_name="hermes-agent", + ) + + assert gateway_llm.provider == "hermes-gateway" + assert gateway_llm.model == "hermes-agent" + + +def test_gateway_session_state_rejects_non_per_room_mode() -> None: + with pytest.raises(ValueError, match="per_room"): + GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn") + + +async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None: + received: list[dict[str, object]] = [] + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}}) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + payload = json.loads(message.data) + received.append(payload) + method = payload.get("method") + request_id = payload.get("id") + if method == "connect": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.create": + await ws.send_json( + { + "type": "res", + "id": request_id, + "ok": True, + "result": {"sessionKey": "livekit:kitchen:helper"}, + } + ) + elif method == "sessions.send": + await ws.send_json( + { + "type": "event", + "event": "agent", + "payload": {"delta": {"content": "你好"}}, + } + ) + await ws.send_json( + { + "type": "res", + "id": request_id, + "ok": True, + "result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}}, + } + ) + await ws.close() + + return ws + + app = web.Application() + app.router.add_get("/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + gateway_llm = HermesGatewayLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/ws", + token="secret-token", + state=GatewaySessionState(room_name="kitchen", agent_id="helper"), + agent_id="helper", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="杯子在哪里") + + try: + collected = await gateway_llm.chat(chat_ctx=ctx).collect() + finally: + await gateway_llm.aclose() + await runner.cleanup() + + assert collected.text == "你好" + assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"] + send_request = received[2] + assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper" + assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}] + + +def test_extract_text_delta_reads_final_message_content() -> None: + assert ( + extract_text_delta( + { + "type": "event", + "event": "session.message.completed", + "payload": { + "message": { + "content": [ + {"type": "text", "text": "完整回复"}, + ] + } + }, + } + ) + == "完整回复" + ) + + +def test_is_error_response_accepts_error_events() -> None: + assert is_error_response( + {"type": "event", "event": "agent.error", "payload": {"error": "boom"}}, + request_id="send-1", + ) + assert is_error_response( + {"type": "event", "event": "run.failed", "payload": {"message": "boom"}}, + request_id="send-1", + ) + + +async def test_llm_stream_maps_gateway_error_events_to_api_connection_error( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}}) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + payload = json.loads(message.data) + method = payload.get("method") + request_id = payload.get("id") + if method == "connect": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.create": + await ws.send_json({"type": "res", "id": request_id, "ok": True}) + elif method == "sessions.send": + await ws.send_json( + { + "type": "event", + "event": "run.failed", + "payload": {"message": "gateway exploded"}, + } + ) + await ws.close() + + return ws + + app = web.Application() + app.router.add_get("/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + gateway_llm = HermesGatewayLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/ws", + token=None, + state=GatewaySessionState(room_name="kitchen", agent_id="helper"), + agent_id="helper", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="hello") + + try: + with pytest.raises(APIConnectionError, match="gateway exploded"): + await gateway_llm.chat(chat_ctx=ctx).collect() + finally: + await gateway_llm.aclose() + await runner.cleanup() diff --git a/tts.py b/tts.py index a3cbc58..1afaa2b 100644 --- a/tts.py +++ b/tts.py @@ -7,7 +7,6 @@ import time import wave from collections.abc import Mapping from io import BytesIO -from typing import Optional import aiohttp @@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS): *, url: str, model_name: str = "voxcpmtts", - params: Optional[Mapping[str, object]] = None, - prompt_wav_path: Optional[str] = None, + params: Mapping[str, object] | None = None, + prompt_wav_path: str | None = None, prompt_wav_field: str = "prompt_wav", sample_rate: int = 16000, num_channels: int = 1, timeout: float = 60.0, - http_session: Optional[aiohttp.ClientSession] = None, + http_session: aiohttp.ClientSession | None = None, ) -> None: super().__init__( capabilities=tts.TTSCapabilities(streaming=False),