Compare commits
18 Commits
6ec16bf68e
...
beaver
| Author | SHA1 | Date | |
|---|---|---|---|
| 820dc44053 | |||
| 78b9138c17 | |||
| 52e6d3cd9c | |||
| 0a50f25dfa | |||
| 34cf1b9736 | |||
| af261d3b63 | |||
| 879c73bfee | |||
| b1cad592e2 | |||
| e7529dc47b | |||
| 7efd9eba98 | |||
| e097323176 | |||
| 2064db15dc | |||
| f272053a95 | |||
| fba51a5257 | |||
| b18c5b40da | |||
| 89011fed81 | |||
| 3a2f5c4252 | |||
| 746053fd58 |
91
.env.example
Normal file
91
.env.example
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# LiveKit connection
|
||||||
|
LIVEKIT_URL=ws://localhost:7880
|
||||||
|
LIVEKIT_API_KEY=
|
||||||
|
LIVEKIT_API_SECRET=
|
||||||
|
|
||||||
|
CUSTOM_AGENT_PROFILE=normal
|
||||||
|
# CUSTOM_AGENT_NAME=normal-agent
|
||||||
|
CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver
|
||||||
|
|
||||||
|
# Beaver terminal text WebSocket
|
||||||
|
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
|
||||||
|
TERMINAL_PEER_ID=device-001
|
||||||
|
TERMINAL_DEVICE_NAME=desk-terminal
|
||||||
|
|
||||||
|
# ASR blackbox
|
||||||
|
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
|
||||||
|
CUSTOM_ASR_MODEL=qwen
|
||||||
|
CUSTOM_ASR_LANGUAGE=Chinese
|
||||||
|
CUSTOM_ASR_OUTPUT_LANGUAGE=zh
|
||||||
|
CUSTOM_ASR_HOTWORDS=
|
||||||
|
CUSTOM_ASR_ITN=
|
||||||
|
CUSTOM_ASR_CHUNK_MODE=
|
||||||
|
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
|
||||||
|
CUSTOM_ASR_MAX_SPEECH_DURATION=12
|
||||||
|
# Keep false if forced ASR turns should reply even while input audio continues.
|
||||||
|
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
|
||||||
|
|
||||||
|
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
|
||||||
|
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
|
||||||
|
# CUSTOM_LLM_PROVIDER=beaver
|
||||||
|
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
|
||||||
|
|
||||||
|
# OpenAI-compatible LLM
|
||||||
|
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
|
||||||
|
# CUSTOM_LLM_MODEL=Qwen3.6-35B
|
||||||
|
# CUSTOM_LLM_API_KEY=sk-
|
||||||
|
# CUSTOM_LLM_VERIFY_SSL=false
|
||||||
|
|
||||||
|
CUSTOM_LLM_BASE_URL=http:/localhost/v1
|
||||||
|
CUSTOM_LLM_MODEL=Mistral-Medium-3.5-128B
|
||||||
|
CUSTOM_LLM_API_KEY=sk-
|
||||||
|
CUSTOM_LLM_VERIFY_SSL=false
|
||||||
|
CUSTOM_SAVE_MODEL_IMAGES=true
|
||||||
|
|
||||||
|
# CUSTOM_TEXT_LLM_MODEL=
|
||||||
|
# CUSTOM_VISION_LLM_MODEL=
|
||||||
|
|
||||||
|
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
|
||||||
|
# CUSTOM_LLM_MODEL=deepseek-v4-flash
|
||||||
|
# CUSTOM_LLM_API_KEY=sk-
|
||||||
|
# CUSTOM_LLM_VERIFY_SSL=false
|
||||||
|
|
||||||
|
|
||||||
|
# TTS blackbox
|
||||||
|
CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
|
||||||
|
CUSTOM_TTS_MODEL=voxcpmtts
|
||||||
|
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
|
||||||
|
CUSTOM_TTS_STREAMING=true
|
||||||
|
# CUSTOM_TTS_PROMPT_TEXT=澳门有乜嘢好食嘅
|
||||||
|
|
||||||
|
# VoxCPM TTS parameters
|
||||||
|
VOXCPM_CFG_VALUE=2.0
|
||||||
|
VOXCPM_INFERENCE_TIMESTEPS=10
|
||||||
|
VOXCPM_DO_NORMALIZE=true
|
||||||
|
VOXCPM_DENOISE=true
|
||||||
|
VOXCPM_RETRY_BADCASE=true
|
||||||
|
VOXCPM_RETRY_BADCASE_MAX_TIMES=3
|
||||||
|
VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD=6.0
|
||||||
|
|
||||||
|
# MeloTTS parameters
|
||||||
|
CUSTOM_TTS_SPEED=1.0
|
||||||
|
|
||||||
|
# CosyVoice parameters
|
||||||
|
CUSTOM_TTS_SPK_ID=
|
||||||
|
CUSTOM_TTS_MODE=
|
||||||
|
CUSTOM_TTS_INSTRUCT_TEXT=
|
||||||
|
|
||||||
|
# GPT-SoVITS parameters
|
||||||
|
CUSTOM_TTS_TEXT_LANG=zh
|
||||||
|
CUSTOM_TTS_PROMPT_LANG=zh
|
||||||
|
CUSTOM_TTS_TEXT_SPLIT_METHOD=cut0
|
||||||
|
CUSTOM_TTS_BATCH_SIZE=1
|
||||||
|
CUSTOM_TTS_MEDIA_TYPE=wav
|
||||||
|
CUSTOM_TTS_REF_AUDIO_PATH=
|
||||||
|
|
||||||
|
|
||||||
|
CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
|
||||||
|
CUSTOM_MEMORY_TIMEOUT=2
|
||||||
|
CUSTOM_MEMORY_MAX_CHARS=2000
|
||||||
|
CUSTOM_MEMORY_API_KEY=
|
||||||
|
CUSTOM_PREEMPTIVE_GENERATION=false
|
||||||
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
__pycache__/
|
||||||
|
.env
|
||||||
|
model_images/
|
||||||
260
asr.py
260
asr.py
@ -1,11 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional, Union
|
from collections import deque
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from livekit import rtc
|
from livekit import rtc
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
NOT_GIVEN,
|
NOT_GIVEN,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIConnectOptions,
|
APIConnectOptions,
|
||||||
@ -17,9 +20,15 @@ from livekit.agents import (
|
|||||||
utils,
|
utils,
|
||||||
)
|
)
|
||||||
from livekit.agents.utils import is_given
|
from livekit.agents.utils import is_given
|
||||||
|
from livekit.agents.vad import VAD, VADEventType
|
||||||
|
|
||||||
logger = logging.getLogger("blackbox-asr")
|
logger = logging.getLogger("blackbox-asr")
|
||||||
|
|
||||||
|
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
|
||||||
|
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
|
||||||
|
)
|
||||||
|
STT_CAPABILITIES = stt.STTCapabilities
|
||||||
|
|
||||||
|
|
||||||
class BlackboxSTT(stt.STT):
|
class BlackboxSTT(stt.STT):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
|
|||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
model_name: str = "sensevoice",
|
model_name: str = "sensevoice",
|
||||||
language: Optional[str] = "auto",
|
language: str | None = "auto",
|
||||||
output_language: str = "zh",
|
output_language: str = "zh",
|
||||||
hotwords: Optional[str] = None,
|
hotwords: str | None = None,
|
||||||
itn: Optional[Union[bool, str]] = None,
|
itn: bool | str | None = None,
|
||||||
chunk_mode: Optional[Union[bool, str]] = None,
|
chunk_mode: bool | str | None = None,
|
||||||
timeout: float = 30.0,
|
timeout: float = 30.0,
|
||||||
http_session: Optional[aiohttp.ClientSession] = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
capabilities=stt.STTCapabilities(
|
capabilities=stt.STTCapabilities(
|
||||||
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
|
|||||||
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||||
|
|
||||||
|
|
||||||
def _form_value(value: Union[bool, str]) -> str:
|
def _form_value(value: bool | str) -> str:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return str(value).lower()
|
return str(value).lower()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BoundedStreamAdapter(stt.STT):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
stt: stt.STT,
|
||||||
|
vad: VAD,
|
||||||
|
max_speech_duration: float | None = 12.0,
|
||||||
|
pre_speech_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
capabilities=STT_CAPABILITIES(
|
||||||
|
streaming=True,
|
||||||
|
interim_results=False,
|
||||||
|
diarization=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._vad = vad
|
||||||
|
self._stt = stt
|
||||||
|
self._max_speech_duration = max_speech_duration
|
||||||
|
self._pre_speech_duration = pre_speech_duration
|
||||||
|
self._stt.on("metrics_collected", self._on_metrics_collected)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wrapped_stt(self) -> stt.STT:
|
||||||
|
return self._stt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._stt.model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> str:
|
||||||
|
return self._stt.provider
|
||||||
|
|
||||||
|
async def _recognize_impl(
|
||||||
|
self,
|
||||||
|
buffer: utils.AudioBuffer,
|
||||||
|
*,
|
||||||
|
language: NotGivenOr[str] = NOT_GIVEN,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
) -> stt.SpeechEvent:
|
||||||
|
return await self._stt.recognize(
|
||||||
|
buffer=buffer, language=language, conn_options=conn_options
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
language: NotGivenOr[str] = NOT_GIVEN,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
) -> stt.RecognizeStream:
|
||||||
|
return _BoundedStreamAdapterWrapper(
|
||||||
|
self,
|
||||||
|
vad=self._vad,
|
||||||
|
wrapped_stt=self._stt,
|
||||||
|
language=language,
|
||||||
|
conn_options=conn_options,
|
||||||
|
max_speech_duration=self._max_speech_duration,
|
||||||
|
pre_speech_duration=self._pre_speech_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
self.emit("metrics_collected", *args, **kwargs)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
self._stt.off("metrics_collected", self._on_metrics_collected)
|
||||||
|
|
||||||
|
|
||||||
|
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
adapter: BoundedStreamAdapter,
|
||||||
|
*,
|
||||||
|
vad: VAD,
|
||||||
|
wrapped_stt: stt.STT,
|
||||||
|
language: NotGivenOr[str],
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
max_speech_duration: float | None,
|
||||||
|
pre_speech_duration: float,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
|
||||||
|
self._vad = vad
|
||||||
|
self._wrapped_stt = wrapped_stt
|
||||||
|
self._wrapped_stt_conn_options = conn_options
|
||||||
|
self._language = language
|
||||||
|
self._max_speech_duration = max_speech_duration
|
||||||
|
self._pre_speech_duration = pre_speech_duration
|
||||||
|
|
||||||
|
async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
|
||||||
|
async for _ in event_aiter:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
vad_stream = self._vad.stream()
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
|
||||||
|
|
||||||
|
speech_active = False
|
||||||
|
segment_frames: list[rtc.AudioFrame] = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
pre_roll_frames: deque[rtc.AudioFrame] = deque()
|
||||||
|
pre_roll_duration = 0.0
|
||||||
|
|
||||||
|
def _frame_duration(frame: rtc.AudioFrame) -> float:
|
||||||
|
return frame.samples_per_channel / frame.sample_rate
|
||||||
|
|
||||||
|
def _append_pre_roll(frame: rtc.AudioFrame) -> None:
|
||||||
|
nonlocal pre_roll_duration
|
||||||
|
pre_roll_frames.append(frame)
|
||||||
|
pre_roll_duration += _frame_duration(frame)
|
||||||
|
|
||||||
|
while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
|
||||||
|
pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())
|
||||||
|
|
||||||
|
async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
|
||||||
|
if not frames:
|
||||||
|
return
|
||||||
|
|
||||||
|
if forced:
|
||||||
|
logger.info(
|
||||||
|
"Forcing ASR segment after %.2fs of continuous speech",
|
||||||
|
self._max_speech_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
|
||||||
|
await recognize_queue.put(frames)
|
||||||
|
|
||||||
|
async def _recognize_worker() -> None:
|
||||||
|
while True:
|
||||||
|
frames = await recognize_queue.get()
|
||||||
|
if frames is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
merged_frames = utils.merge_frames(frames)
|
||||||
|
try:
|
||||||
|
t_event = await self._wrapped_stt.recognize(
|
||||||
|
buffer=merged_frames,
|
||||||
|
language=self._language,
|
||||||
|
conn_options=self._wrapped_stt_conn_options,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("ASR segment recognition failed")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not t_event.alternatives or not t_event.alternatives[0].text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
stt.SpeechEvent(
|
||||||
|
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||||
|
alternatives=[t_event.alternatives[0]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _forward_input() -> None:
|
||||||
|
nonlocal segment_duration, segment_frames, speech_active
|
||||||
|
|
||||||
|
async for input_frame in self._input_ch:
|
||||||
|
if isinstance(input_frame, self._FlushSentinel):
|
||||||
|
vad_stream.flush()
|
||||||
|
continue
|
||||||
|
|
||||||
|
vad_stream.push_frame(input_frame)
|
||||||
|
forced_frames: list[rtc.AudioFrame] = []
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
if speech_active:
|
||||||
|
segment_frames.append(input_frame)
|
||||||
|
segment_duration += _frame_duration(input_frame)
|
||||||
|
if (
|
||||||
|
self._max_speech_duration is not None
|
||||||
|
and segment_duration >= self._max_speech_duration
|
||||||
|
):
|
||||||
|
forced_frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
else:
|
||||||
|
_append_pre_roll(input_frame)
|
||||||
|
|
||||||
|
if forced_frames:
|
||||||
|
await _enqueue_segment(forced_frames, forced=True)
|
||||||
|
|
||||||
|
vad_stream.end_input()
|
||||||
|
|
||||||
|
final_frames: list[rtc.AudioFrame] = []
|
||||||
|
async with lock:
|
||||||
|
if speech_active and segment_frames:
|
||||||
|
final_frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
speech_active = False
|
||||||
|
|
||||||
|
if final_frames:
|
||||||
|
await _enqueue_segment(final_frames)
|
||||||
|
|
||||||
|
async def _recognize_from_vad() -> None:
|
||||||
|
nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
|
||||||
|
|
||||||
|
async for event in vad_stream:
|
||||||
|
if event.type == VADEventType.START_OF_SPEECH:
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
|
||||||
|
)
|
||||||
|
async with lock:
|
||||||
|
if not speech_active:
|
||||||
|
speech_active = True
|
||||||
|
segment_frames = list(pre_roll_frames)
|
||||||
|
segment_duration = sum(_frame_duration(f) for f in segment_frames)
|
||||||
|
pre_roll_frames.clear()
|
||||||
|
pre_roll_duration = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.type != VADEventType.END_OF_SPEECH:
|
||||||
|
continue
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
speech_active = False
|
||||||
|
|
||||||
|
await _enqueue_segment(frames)
|
||||||
|
|
||||||
|
worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
|
||||||
|
asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
await recognize_queue.put(None)
|
||||||
|
await worker_task
|
||||||
|
finally:
|
||||||
|
await utils.aio.cancel_and_wait(*tasks, worker_task)
|
||||||
|
await vad_stream.aclose()
|
||||||
|
|||||||
128
beaver_llm.py
Normal file
128
beaver_llm.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
from beaver_terminal_client import BeaverTerminalClient
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from custom.beaver_terminal_client import BeaverTerminalClient
|
||||||
|
from livekit.agents import llm
|
||||||
|
from livekit.agents.types import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
NOT_GIVEN,
|
||||||
|
APIConnectOptions,
|
||||||
|
NotGivenOr,
|
||||||
|
)
|
||||||
|
from livekit.agents.utils import shortuuid
|
||||||
|
|
||||||
|
logger = logging.getLogger("beaver-llm")
|
||||||
|
|
||||||
|
|
||||||
|
def latest_user_text(chat_ctx: llm.ChatContext) -> str:
|
||||||
|
for message in reversed(chat_ctx.messages()):
|
||||||
|
if message.role != "user":
|
||||||
|
continue
|
||||||
|
return _content_to_text(message.content)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_text(content: Sequence[llm.ChatContent]) -> str:
|
||||||
|
text_parts = [item for item in content if isinstance(item, str)]
|
||||||
|
return "\n".join(text_parts)
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverLLM(llm.LLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
peer_id: str,
|
||||||
|
device_name: str,
|
||||||
|
model_name: str = "beaver-terminal",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name)
|
||||||
|
self._model_name = model_name
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> str:
|
||||||
|
return "beaver"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_id(self) -> str | None:
|
||||||
|
return self._client.session_id
|
||||||
|
|
||||||
|
async def connect(self, *, warmup_text: str | None = None) -> str | None:
|
||||||
|
warmup_reply: str | None = None
|
||||||
|
async with self._lock:
|
||||||
|
await self._client.connect()
|
||||||
|
if warmup_text and warmup_text.strip():
|
||||||
|
warmup_reply = await self._client.send_text(warmup_text.strip())
|
||||||
|
|
||||||
|
if warmup_reply is None:
|
||||||
|
logger.info("Beaver handshake completed session_id=%s", self.session_id)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Beaver handshake warmup completed session_id=%s reply_len=%s",
|
||||||
|
self.session_id,
|
||||||
|
len(warmup_reply),
|
||||||
|
)
|
||||||
|
return warmup_reply
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool] | None = None,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
|
||||||
|
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
|
||||||
|
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
||||||
|
) -> llm.LLMStream:
|
||||||
|
return BeaverLLMStream(
|
||||||
|
self,
|
||||||
|
chat_ctx=chat_ctx,
|
||||||
|
tools=tools or [],
|
||||||
|
conn_options=conn_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._client.close()
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverLLMStream(llm.LLMStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
beaver_llm: BeaverLLM,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool],
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||||
|
self._beaver_llm = beaver_llm
|
||||||
|
self._request_id = shortuuid("beaver_")
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
user_text = latest_user_text(self.chat_ctx)
|
||||||
|
async with self._beaver_llm._lock:
|
||||||
|
reply = await self._beaver_llm._client.send_text(user_text)
|
||||||
|
|
||||||
|
if reply:
|
||||||
|
self._send_text_chunk(reply)
|
||||||
|
|
||||||
|
def _send_text_chunk(self, text: str) -> None:
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
llm.ChatChunk(
|
||||||
|
id=self._request_id,
|
||||||
|
delta=llm.ChoiceDelta(role="assistant", content=text),
|
||||||
|
)
|
||||||
|
)
|
||||||
238
beaver_terminal_client.py
Normal file
238
beaver_terminal_client.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
logger = logging.getLogger("beaver-terminal-client")
|
||||||
|
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
|
||||||
|
DEFAULT_TERMINAL_PEER_ID = "device-001"
|
||||||
|
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
|
||||||
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageIdGenerator:
|
||||||
|
peer_id: str
|
||||||
|
nonce: str | None = None
|
||||||
|
initial_counter: int = 0
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.counter = self.initial_counter
|
||||||
|
|
||||||
|
def next_id(self) -> str:
|
||||||
|
self.counter += 1
|
||||||
|
if self.nonce:
|
||||||
|
return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
|
||||||
|
return f"{self.peer_id}-{self.counter:06d}"
|
||||||
|
|
||||||
|
|
||||||
|
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": peer_id,
|
||||||
|
"device_name": device_name,
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "message",
|
||||||
|
"message_id": message_id,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
peer_id: str,
|
||||||
|
device_name: str,
|
||||||
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
|
message_ids: MessageIdGenerator | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._url = url
|
||||||
|
self._peer_id = peer_id
|
||||||
|
self._device_name = device_name
|
||||||
|
self._owned_session = http_session is None
|
||||||
|
self._http_session = http_session
|
||||||
|
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||||
|
self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
|
||||||
|
self.session_id: str | None = None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
session = self._ensure_http_session()
|
||||||
|
try:
|
||||||
|
self._ws = await session.ws_connect(self._url)
|
||||||
|
await self._send_json(
|
||||||
|
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||||
|
)
|
||||||
|
frame = await self._receive_json()
|
||||||
|
if frame.get("type") != "connected":
|
||||||
|
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||||
|
session_id = frame.get("session_id")
|
||||||
|
self.session_id = session_id if isinstance(session_id, str) else None
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
|
||||||
|
await self._cleanup_failed_connection()
|
||||||
|
raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
|
||||||
|
except Exception:
|
||||||
|
await self._cleanup_failed_connection()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> str:
|
||||||
|
for attempt in range(2):
|
||||||
|
if not self._websocket_is_open():
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
message_id = self._message_ids.next_id()
|
||||||
|
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_json(message_frame)
|
||||||
|
return await self._wait_for_reply(message_id)
|
||||||
|
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||||
|
if attempt == 1:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
"Beaver websocket closed before assistant reply"
|
||||||
|
) from exc
|
||||||
|
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
raise BeaverTerminalError("unreachable Beaver send state")
|
||||||
|
|
||||||
|
async def _wait_for_reply(self, message_id: str) -> str:
|
||||||
|
while True:
|
||||||
|
frame = await self._receive_json()
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||||
|
reply = frame.get("reply")
|
||||||
|
if isinstance(reply, str):
|
||||||
|
return reply
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (
|
||||||
|
frame_type == "message"
|
||||||
|
and frame.get("role") == "assistant"
|
||||||
|
and frame.get("message_id") == message_id
|
||||||
|
):
|
||||||
|
text = frame.get("text")
|
||||||
|
if frame.get("finish_reason") == "error":
|
||||||
|
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
|
||||||
|
return text if isinstance(text, str) else ""
|
||||||
|
|
||||||
|
if frame_type == "error":
|
||||||
|
error = frame.get("error")
|
||||||
|
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||||
|
|
||||||
|
async def ping(self) -> bool:
|
||||||
|
await self._send_json({"type": "ping"})
|
||||||
|
while True:
|
||||||
|
frame = await self._receive_json()
|
||||||
|
if frame.get("type") == "pong":
|
||||||
|
return True
|
||||||
|
if frame.get("type") == "error":
|
||||||
|
error = frame.get("error")
|
||||||
|
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
if self._owned_session and self._http_session is not None:
|
||||||
|
await self._http_session.close()
|
||||||
|
self._http_session = None
|
||||||
|
|
||||||
|
async def _close_websocket(self) -> None:
|
||||||
|
if self._ws is not None:
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
|
||||||
|
async def _cleanup_failed_connection(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
self.session_id = None
|
||||||
|
if self._owned_session and self._http_session is not None:
|
||||||
|
await self._http_session.close()
|
||||||
|
self._http_session = None
|
||||||
|
|
||||||
|
def _websocket_is_open(self) -> bool:
|
||||||
|
return self._ws is not None and not self._ws.closed
|
||||||
|
|
||||||
|
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._http_session is None:
|
||||||
|
self._http_session = aiohttp.ClientSession()
|
||||||
|
return self._http_session
|
||||||
|
|
||||||
|
async def _send_json(self, frame: dict[str, Any]) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||||
|
await self._ws.send_json(frame)
|
||||||
|
|
||||||
|
async def _receive_json(self) -> dict[str, Any]:
|
||||||
|
if self._ws is None:
|
||||||
|
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||||
|
|
||||||
|
message = await self._ws.receive()
|
||||||
|
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
|
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||||
|
if message.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||||
|
)
|
||||||
|
if message.type != aiohttp.WSMsgType.TEXT:
|
||||||
|
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||||
|
data = message.json()
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def client_from_env() -> BeaverTerminalClient:
|
||||||
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
|
return BeaverTerminalClient(
|
||||||
|
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
|
||||||
|
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
|
||||||
|
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_console() -> None:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
client = client_from_env()
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
logger.info("Connected to Beaver session_id=%s", client.session_id)
|
||||||
|
while True:
|
||||||
|
text = await asyncio.to_thread(input, "> ")
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
if text in {"quit", "exit"}:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
reply = await client.send_text(text)
|
||||||
|
except BeaverTerminalError as exc:
|
||||||
|
logger.error("Beaver turn failed: %s", exc)
|
||||||
|
continue
|
||||||
|
print(reply)
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(run_console())
|
||||||
1133
custom_agent.py
1133
custom_agent.py
File diff suppressed because it is too large
Load Diff
391
hermes_gateway.py
Normal file
391
hermes_gateway.py
Normal file
@ -0,0 +1,391 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from livekit.agents import llm
|
||||||
|
from livekit.agents._exceptions import APIConnectionError
|
||||||
|
from livekit.agents.types import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
NOT_GIVEN,
|
||||||
|
APIConnectOptions,
|
||||||
|
NotGivenOr,
|
||||||
|
)
|
||||||
|
from livekit.agents.utils import shortuuid
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GatewaySessionState:
|
||||||
|
room_name: str
|
||||||
|
agent_id: str | None = None
|
||||||
|
session_key: str | None = None
|
||||||
|
session_mode: str = "per_room"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.session_mode != "per_room":
|
||||||
|
raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room")
|
||||||
|
if self.session_key is None:
|
||||||
|
suffix = self.agent_id or "default"
|
||||||
|
self.session_key = f"livekit:{self.room_name}:{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
class HermesGatewayLLM(llm.LLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
token: str | None,
|
||||||
|
state: GatewaySessionState,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
model_name: str = "hermes-agent",
|
||||||
|
request_timeout: float = 30.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._url = url
|
||||||
|
self._token = token
|
||||||
|
self._state = state
|
||||||
|
self._agent_id = agent_id
|
||||||
|
self._model_name = model_name
|
||||||
|
self._request_timeout = request_timeout
|
||||||
|
self._http_session: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> str:
|
||||||
|
return "hermes-gateway"
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool] | None = None,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
|
||||||
|
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
|
||||||
|
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
||||||
|
) -> llm.LLMStream:
|
||||||
|
return HermesGatewayLLMStream(
|
||||||
|
self,
|
||||||
|
chat_ctx=chat_ctx,
|
||||||
|
tools=tools or [],
|
||||||
|
conn_options=conn_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._http_session is None:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=self._request_timeout)
|
||||||
|
self._http_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
return self._http_session
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self._http_session is not None:
|
||||||
|
await self._http_session.close()
|
||||||
|
self._http_session = None
|
||||||
|
|
||||||
|
|
||||||
|
class HermesGatewayLLMStream(llm.LLMStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm: HermesGatewayLLM,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool],
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||||
|
self._llm = llm
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
request_id = shortuuid("gwreq_")
|
||||||
|
async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws:
|
||||||
|
await _connect(ws, token=self._llm._token)
|
||||||
|
await _send_rpc(
|
||||||
|
ws,
|
||||||
|
method="sessions.create",
|
||||||
|
params={
|
||||||
|
"key": self._llm._state.session_key,
|
||||||
|
"sessionKey": self._llm._state.session_key,
|
||||||
|
"agentId": self._llm._agent_id,
|
||||||
|
"metadata": {
|
||||||
|
"source": "livekit",
|
||||||
|
"room": self._llm._state.room_name,
|
||||||
|
},
|
||||||
|
"idempotencyKey": self._llm._state.session_key,
|
||||||
|
},
|
||||||
|
request_id=shortuuid("gwcreate_"),
|
||||||
|
)
|
||||||
|
await _send_rpc(
|
||||||
|
ws,
|
||||||
|
method="sessions.send",
|
||||||
|
params={
|
||||||
|
"key": self._llm._state.session_key,
|
||||||
|
"sessionKey": self._llm._state.session_key,
|
||||||
|
"agentId": self._llm._agent_id,
|
||||||
|
"messages": chat_context_to_gateway_messages(self.chat_ctx),
|
||||||
|
"stream": True,
|
||||||
|
"idempotencyKey": request_id,
|
||||||
|
},
|
||||||
|
request_id=request_id,
|
||||||
|
wait_response=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
streamed_text = ""
|
||||||
|
async for frame in _iter_gateway_frames(ws):
|
||||||
|
if is_error_response(frame, request_id=request_id):
|
||||||
|
raise APIConnectionError(_gateway_error_message(frame), retryable=False)
|
||||||
|
text = extract_text_delta(frame)
|
||||||
|
if text:
|
||||||
|
if text.startswith(streamed_text):
|
||||||
|
text = text[len(streamed_text) :]
|
||||||
|
if text:
|
||||||
|
streamed_text += text
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
llm.ChatChunk(
|
||||||
|
id=request_id,
|
||||||
|
delta=llm.ChoiceDelta(role="assistant", content=text),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_terminal_event(frame, request_id=request_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def build_connect_params(*, token: str | None) -> dict[str, Any]:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"minProtocol": 3,
|
||||||
|
"maxProtocol": 4,
|
||||||
|
"client": {
|
||||||
|
"id": "gateway-client",
|
||||||
|
"version": "livekit-custom-agent",
|
||||||
|
"platform": "python",
|
||||||
|
"mode": "backend",
|
||||||
|
},
|
||||||
|
"role": "operator",
|
||||||
|
"scopes": ["operator.read", "operator.write"],
|
||||||
|
"caps": [],
|
||||||
|
"commands": [],
|
||||||
|
"permissions": {},
|
||||||
|
"locale": "zh-CN",
|
||||||
|
"userAgent": "livekit-custom-agent",
|
||||||
|
}
|
||||||
|
if token:
|
||||||
|
params["auth"] = {"token": token}
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]:
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
for message in chat_ctx.messages():
|
||||||
|
content = _message_content_to_gateway_content(message.content)
|
||||||
|
if content is None:
|
||||||
|
continue
|
||||||
|
messages.append({"role": message.role, "content": content})
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_delta(frame: dict[str, Any]) -> str:
|
||||||
|
payload = frame.get("payload")
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
payload = frame
|
||||||
|
|
||||||
|
for path in (
|
||||||
|
("delta", "content"),
|
||||||
|
("delta", "text"),
|
||||||
|
("message", "delta", "content"),
|
||||||
|
("message", "delta", "text"),
|
||||||
|
("message", "content"),
|
||||||
|
("content",),
|
||||||
|
("text",),
|
||||||
|
):
|
||||||
|
value = _get_nested(payload, path)
|
||||||
|
text = _content_to_text(value)
|
||||||
|
if text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool:
|
||||||
|
if frame.get("type") == "res" and frame.get("id") == request_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
event = frame.get("event")
|
||||||
|
if event in {
|
||||||
|
"agent.done",
|
||||||
|
"agent.completed",
|
||||||
|
"agent.error",
|
||||||
|
"session.message.completed",
|
||||||
|
"session.run.completed",
|
||||||
|
"sessions.run.completed",
|
||||||
|
"run.completed",
|
||||||
|
"run.failed",
|
||||||
|
}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
payload = frame.get("payload")
|
||||||
|
if isinstance(payload, dict) and payload.get("done") is True:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool:
|
||||||
|
if (
|
||||||
|
frame.get("type") == "res"
|
||||||
|
and frame.get("id") == request_id
|
||||||
|
and frame.get("ok") is False
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return frame.get("event") in {
|
||||||
|
"agent.error",
|
||||||
|
"session.error",
|
||||||
|
"session.run.failed",
|
||||||
|
"sessions.run.failed",
|
||||||
|
"run.failed",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _gateway_error_message(frame: dict[str, Any]) -> str:
|
||||||
|
error = frame.get("error")
|
||||||
|
if isinstance(error, str):
|
||||||
|
return f"OpenClaw gateway request failed: {error}"
|
||||||
|
if isinstance(error, dict):
|
||||||
|
message = error.get("message") or error.get("error")
|
||||||
|
if isinstance(message, str):
|
||||||
|
return f"OpenClaw gateway request failed: {message}"
|
||||||
|
|
||||||
|
payload = frame.get("payload")
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
message = payload.get("message") or payload.get("error")
|
||||||
|
if isinstance(message, str):
|
||||||
|
return f"OpenClaw gateway request failed: {message}"
|
||||||
|
|
||||||
|
return f"OpenClaw gateway request failed: {frame!r}"
|
||||||
|
|
||||||
|
|
||||||
|
async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None:
|
||||||
|
first = await _receive_json(ws)
|
||||||
|
if first.get("event") != "connect.challenge":
|
||||||
|
raise RuntimeError(f"expected connect.challenge, received {first!r}")
|
||||||
|
|
||||||
|
request_id = shortuuid("gwconnect_")
|
||||||
|
await _send_rpc(
|
||||||
|
ws,
|
||||||
|
method="connect",
|
||||||
|
params=build_connect_params(token=token),
|
||||||
|
request_id=request_id,
|
||||||
|
wait_response=False,
|
||||||
|
)
|
||||||
|
response = await _wait_for_response(ws, request_id=request_id)
|
||||||
|
if not response.get("ok"):
|
||||||
|
raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_rpc(
|
||||||
|
ws: aiohttp.ClientWebSocketResponse,
|
||||||
|
*,
|
||||||
|
method: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
request_id: str,
|
||||||
|
wait_response: bool = True,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
await ws.send_str(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "req",
|
||||||
|
"id": request_id,
|
||||||
|
"method": method,
|
||||||
|
"params": _drop_none(params),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not wait_response:
|
||||||
|
return None
|
||||||
|
response = await _wait_for_response(ws, request_id=request_id)
|
||||||
|
if not response.get("ok", False):
|
||||||
|
raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def _wait_for_response(
|
||||||
|
ws: aiohttp.ClientWebSocketResponse, *, request_id: str
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
async for frame in _iter_gateway_frames(ws):
|
||||||
|
if frame.get("type") == "res" and frame.get("id") == request_id:
|
||||||
|
return frame
|
||||||
|
raise RuntimeError(f"OpenClaw gateway closed before response {request_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _iter_gateway_frames(
|
||||||
|
ws: aiohttp.ClientWebSocketResponse,
|
||||||
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
|
async for message in ws:
|
||||||
|
if message.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
yield json.loads(message.data)
|
||||||
|
elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
|
||||||
|
return
|
||||||
|
elif message.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]:
|
||||||
|
message = await ws.receive()
|
||||||
|
if message.type != aiohttp.WSMsgType.TEXT:
|
||||||
|
raise RuntimeError(f"expected gateway text frame, received {message.type!r}")
|
||||||
|
return json.loads(message.data)
|
||||||
|
|
||||||
|
|
||||||
|
def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any:
|
||||||
|
parts: list[dict[str, Any]] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
if item:
|
||||||
|
parts.append({"type": "text", "text": item})
|
||||||
|
elif isinstance(item, llm.ImageContent) and isinstance(item.image, str):
|
||||||
|
parts.append({"type": "image_url", "image_url": {"url": item.image}})
|
||||||
|
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
if len(parts) == 1 and parts[0]["type"] == "text":
|
||||||
|
return parts[0]["text"]
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_text(value: Any) -> str:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str):
|
||||||
|
text_parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
text_parts.append(text)
|
||||||
|
return "".join(text_parts)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any:
|
||||||
|
current: Any = data
|
||||||
|
for key in path:
|
||||||
|
if not isinstance(current, dict):
|
||||||
|
return None
|
||||||
|
current = current.get(key)
|
||||||
|
return current
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_none(value: Any) -> Any:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {key: _drop_none(item) for key, item in value.items() if item is not None}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_drop_none(item) for item in value]
|
||||||
|
return value
|
||||||
292
memory.py
Normal file
292
memory.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
|
||||||
|
|
||||||
|
logger = logging.getLogger("memory-recall")
|
||||||
|
|
||||||
|
_LOCATION_STOPWORDS = {
|
||||||
|
"哪里",
|
||||||
|
"在哪",
|
||||||
|
"在哪里",
|
||||||
|
"哪儿",
|
||||||
|
"位置",
|
||||||
|
"什么地方",
|
||||||
|
"帮我找",
|
||||||
|
"帮我寻找",
|
||||||
|
"找一下",
|
||||||
|
"找",
|
||||||
|
"请问",
|
||||||
|
"请",
|
||||||
|
"吗",
|
||||||
|
"呢",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRecallClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
api_key: str | None = None,
|
||||||
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._url = url
|
||||||
|
self._timeout = timeout
|
||||||
|
self._max_chars = max_chars
|
||||||
|
self._api_key = api_key
|
||||||
|
self._http_session = http_session
|
||||||
|
self._cached_payload: Any | None = None
|
||||||
|
|
||||||
|
def _ensure_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._http_session is None:
|
||||||
|
self._http_session = utils.http_context.http_session()
|
||||||
|
return self._http_session
|
||||||
|
|
||||||
|
async def recall(self, query: str) -> str:
|
||||||
|
query = query.strip()
|
||||||
|
if not query:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if self._api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self._ensure_session().get(
|
||||||
|
self._url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=self._timeout),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
raise APIStatusError(
|
||||||
|
message=f"Memory recall error: {error_text}",
|
||||||
|
status_code=resp.status,
|
||||||
|
request_id=None,
|
||||||
|
body=error_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await resp.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
data = await resp.text()
|
||||||
|
|
||||||
|
self._cached_payload = data
|
||||||
|
return self._format_memory(data, query)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Memory recall timed out after %.1fs, using cached room graph", self._timeout
|
||||||
|
)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
logger.warning("Memory recall connection error: %s, using cached room graph", e)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
except (APIConnectionError, APIStatusError, APITimeoutError) as e:
|
||||||
|
logger.warning("Memory recall failed: %s, using cached room graph", e)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
|
||||||
|
def _format_memory(self, data: Any, query: str) -> str:
|
||||||
|
memory = _format_room_graph_memory(data, query)
|
||||||
|
if len(memory) > self._max_chars:
|
||||||
|
memory = memory[: self._max_chars].rstrip()
|
||||||
|
return memory
|
||||||
|
|
||||||
|
def _format_cached_memory(self, query: str) -> str:
|
||||||
|
if self._cached_payload is None:
|
||||||
|
return ""
|
||||||
|
return self._format_memory(self._cached_payload, query)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_room_graph_memory(payload: Any, query: str) -> str:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
logger.warning("Unsupported room graph response: %s", payload)
|
||||||
|
return ""
|
||||||
|
objects = payload.get("objects", [])
|
||||||
|
relations = payload.get("relations", [])
|
||||||
|
summary = payload.get("summary", "")
|
||||||
|
|
||||||
|
if not objects and not relations and not summary:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
query_terms = _query_terms(query)
|
||||||
|
relevant_objects, relevant_relations = _relevant_room_graph(
|
||||||
|
objects=objects,
|
||||||
|
relations=relations,
|
||||||
|
query_terms=query_terms,
|
||||||
|
)
|
||||||
|
|
||||||
|
objects_text = json.dumps(
|
||||||
|
relevant_objects or _compact_items(objects, limit=12),
|
||||||
|
ensure_ascii=False,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
relations_text = json.dumps(
|
||||||
|
relevant_relations or _compact_items(relations, limit=24),
|
||||||
|
ensure_ascii=False,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你是一个物品定位助手。
|
||||||
|
|
||||||
|
目标物品:{query}
|
||||||
|
相关物品:{objects_text}
|
||||||
|
相关空间关系:{relations_text}
|
||||||
|
房间概览:{summary}
|
||||||
|
|
||||||
|
回答要求:
|
||||||
|
1. 只说明它和其他物品的位置关系。
|
||||||
|
2. 不要编造不存在的关系。
|
||||||
|
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
|
||||||
|
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
|
||||||
|
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
|
||||||
|
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
|
||||||
|
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
|
||||||
|
query_terms,
|
||||||
|
len(relevant_objects),
|
||||||
|
len(objects) if isinstance(objects, list) else 0,
|
||||||
|
len(relevant_relations),
|
||||||
|
len(relations) if isinstance(relations, list) else 0,
|
||||||
|
len(prompt),
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _query_terms(query: str) -> list[str]:
|
||||||
|
normalized = re.sub(r"[\s??。!,、,.!]", "", query)
|
||||||
|
for word in _LOCATION_STOPWORDS:
|
||||||
|
normalized = normalized.replace(word, "")
|
||||||
|
|
||||||
|
terms = [normalized] if normalized else []
|
||||||
|
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
|
||||||
|
if token not in _LOCATION_STOPWORDS and token not in terms:
|
||||||
|
terms.append(token)
|
||||||
|
return terms[:4]
|
||||||
|
|
||||||
|
|
||||||
|
def _relevant_room_graph(
|
||||||
|
*,
|
||||||
|
objects: Any,
|
||||||
|
relations: Any,
|
||||||
|
query_terms: list[str],
|
||||||
|
) -> tuple[list[Any], list[Any]]:
|
||||||
|
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
matched_ids: set[str] = set()
|
||||||
|
matched_objects: list[Any] = []
|
||||||
|
object_by_id: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
obj_id = _object_id(obj)
|
||||||
|
if obj_id:
|
||||||
|
object_by_id[obj_id] = obj
|
||||||
|
|
||||||
|
obj_text = _compact_text(obj)
|
||||||
|
if any(term and term in obj_text for term in query_terms):
|
||||||
|
matched_objects.append(obj)
|
||||||
|
if obj_id:
|
||||||
|
matched_ids.add(obj_id)
|
||||||
|
|
||||||
|
relevant_relations: list[Any] = []
|
||||||
|
related_ids: set[str] = set(matched_ids)
|
||||||
|
for relation in relations:
|
||||||
|
relation_text = _compact_text(relation)
|
||||||
|
relation_ids = _ids_in_value(relation)
|
||||||
|
if (
|
||||||
|
any(term and term in relation_text for term in query_terms)
|
||||||
|
or bool(matched_ids.intersection(relation_ids))
|
||||||
|
):
|
||||||
|
relevant_relations.append(relation)
|
||||||
|
related_ids.update(relation_ids)
|
||||||
|
|
||||||
|
relevant_objects = list(matched_objects)
|
||||||
|
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
|
||||||
|
for obj_id in related_ids:
|
||||||
|
obj = object_by_id.get(obj_id)
|
||||||
|
key = _object_key(obj)
|
||||||
|
if obj is not None and key not in seen_object_keys:
|
||||||
|
relevant_objects.append(obj)
|
||||||
|
seen_object_keys.add(key)
|
||||||
|
|
||||||
|
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_items(items: Any, *, limit: int) -> list[Any]:
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return []
|
||||||
|
return [_compact_item(item) for item in items[:limit]]
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_item(item: Any) -> Any:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
return item
|
||||||
|
|
||||||
|
preferred_keys = (
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"label",
|
||||||
|
"class",
|
||||||
|
"category",
|
||||||
|
"type",
|
||||||
|
"text",
|
||||||
|
"source",
|
||||||
|
"target",
|
||||||
|
"subject",
|
||||||
|
"object",
|
||||||
|
"relation",
|
||||||
|
"predicate",
|
||||||
|
"description",
|
||||||
|
)
|
||||||
|
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
|
||||||
|
return compact or item
|
||||||
|
|
||||||
|
|
||||||
|
def _object_id(obj: Any) -> str | None:
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return None
|
||||||
|
for key in ("id", "object_id", "uuid", "name", "label"):
|
||||||
|
value = obj.get(key)
|
||||||
|
if isinstance(value, (str, int)):
|
||||||
|
return str(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _object_key(obj: Any) -> str:
|
||||||
|
return _object_id(obj) or _compact_text(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _ids_in_value(value: Any) -> set[str]:
|
||||||
|
ids: set[str] = set()
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for key, item in value.items():
|
||||||
|
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
|
||||||
|
if isinstance(item, (str, int)):
|
||||||
|
ids.add(str(item))
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
obj_id = _object_id(item)
|
||||||
|
if obj_id:
|
||||||
|
ids.add(obj_id)
|
||||||
|
ids.update(_ids_in_value(item))
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
ids.update(_ids_in_value(item))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_text(value: Any) -> str:
|
||||||
|
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
|
||||||
264
start_agent_profiles.py
Normal file
264
start_agent_profiles.py
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentProfile:
|
||||||
|
agent_name: str
|
||||||
|
llm_provider: str
|
||||||
|
input_mode: str
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_PROFILES = {
|
||||||
|
"normal": AgentProfile(
|
||||||
|
agent_name="normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode="voice",
|
||||||
|
),
|
||||||
|
"beaver": AgentProfile(
|
||||||
|
agent_name="beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode="voice",
|
||||||
|
),
|
||||||
|
"vision-normal": AgentProfile(
|
||||||
|
agent_name="vision-normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode="vision_voice",
|
||||||
|
),
|
||||||
|
"vision-beaver": AgentProfile(
|
||||||
|
agent_name="vision-beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode="vision_voice",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
DEFAULT_PROFILES = ("normal", "beaver")
|
||||||
|
|
||||||
|
|
||||||
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if normalized in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_profiles(value: str) -> list[str]:
|
||||||
|
if value.strip().lower() == "all":
|
||||||
|
return list(AGENT_PROFILES)
|
||||||
|
|
||||||
|
profiles: list[str] = []
|
||||||
|
for raw_profile in value.split(","):
|
||||||
|
profile = raw_profile.strip().lower().replace("_", "-")
|
||||||
|
if not profile:
|
||||||
|
continue
|
||||||
|
if profile not in AGENT_PROFILES:
|
||||||
|
valid = ", ".join([*AGENT_PROFILES, "all"])
|
||||||
|
raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
|
||||||
|
profiles.append(profile)
|
||||||
|
|
||||||
|
if not profiles:
|
||||||
|
raise ValueError("at least one profile is required")
|
||||||
|
return list(dict.fromkeys(profiles))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_http_port_base(value: str | None) -> int:
|
||||||
|
if value is None or not value.strip():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
port = int(value)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"invalid HTTP port base {value!r}") from exc
|
||||||
|
|
||||||
|
if port < 0 or port > 65535:
|
||||||
|
raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
def _profile_http_port(http_port_base: int, index: int) -> int:
|
||||||
|
if http_port_base == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
port = http_port_base + index
|
||||||
|
if port > 65535:
|
||||||
|
raise ValueError("HTTP port range exceeds 65535")
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
|
||||||
|
profile = AGENT_PROFILES[profile_name]
|
||||||
|
env = os.environ.copy()
|
||||||
|
env.update(
|
||||||
|
{
|
||||||
|
"CUSTOM_AGENT_PROFILE": profile_name,
|
||||||
|
"CUSTOM_AGENT_NAME": profile.agent_name,
|
||||||
|
"CUSTOM_LLM_PROVIDER": profile.llm_provider,
|
||||||
|
"CUSTOM_AGENT_INPUT_MODE": profile.input_mode,
|
||||||
|
"CUSTOM_AGENT_HTTP_PORT": str(http_port),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
|
||||||
|
while line := await stream.readline():
|
||||||
|
text = line.decode("utf-8", errors="replace").rstrip()
|
||||||
|
print(f"[{prefix}] {text}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _start_profile(
|
||||||
|
profile_name: str,
|
||||||
|
*,
|
||||||
|
mode: str,
|
||||||
|
http_port: int,
|
||||||
|
reload: bool,
|
||||||
|
) -> asyncio.subprocess.Process:
|
||||||
|
profile = AGENT_PROFILES[profile_name]
|
||||||
|
script_path = Path(__file__).with_name("custom_agent.py")
|
||||||
|
args = [sys.executable, str(script_path), mode]
|
||||||
|
if mode == "dev" and not reload:
|
||||||
|
args.append("--no-reload")
|
||||||
|
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
*args,
|
||||||
|
env=_child_env(profile_name, http_port=http_port),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
|
||||||
|
f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if process.stdout is not None:
|
||||||
|
asyncio.create_task(_pipe_output(profile_name, process.stdout))
|
||||||
|
if process.stderr is not None:
|
||||||
|
asyncio.create_task(_pipe_output(profile_name, process.stderr))
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
async def _terminate(processes: list[asyncio.subprocess.Process]) -> None:
|
||||||
|
for process in processes:
|
||||||
|
if process.returncode is None:
|
||||||
|
process.terminate()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(*(process.wait() for process in processes)),
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
for process in processes:
|
||||||
|
if process.returncode is None:
|
||||||
|
process.kill()
|
||||||
|
await asyncio.gather(*(process.wait() for process in processes))
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(
|
||||||
|
profiles: list[str],
|
||||||
|
*,
|
||||||
|
mode: str,
|
||||||
|
http_port_base: int,
|
||||||
|
reload: bool,
|
||||||
|
) -> int:
|
||||||
|
processes = [
|
||||||
|
await _start_profile(
|
||||||
|
profile,
|
||||||
|
mode=mode,
|
||||||
|
http_port=_profile_http_port(http_port_base, index),
|
||||||
|
reload=reload,
|
||||||
|
)
|
||||||
|
for index, profile in enumerate(profiles)
|
||||||
|
]
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
loop.add_signal_handler(sig, stop_event.set)
|
||||||
|
|
||||||
|
wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
|
||||||
|
stop_task = asyncio.create_task(stop_event.wait())
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
[*wait_tasks, stop_task],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
|
||||||
|
exit_code = 0
|
||||||
|
if stop_task not in done:
|
||||||
|
for task in done:
|
||||||
|
process = wait_tasks[task]
|
||||||
|
exit_code = task.result() or 0
|
||||||
|
print(f"agent profile process exited: pid={process.pid} code={exit_code}", flush=True)
|
||||||
|
|
||||||
|
await _terminate(processes)
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
|
||||||
|
parser.add_argument(
|
||||||
|
"mode",
|
||||||
|
nargs="?",
|
||||||
|
default="dev",
|
||||||
|
choices=("console", "dev", "start", "connect"),
|
||||||
|
help="custom_agent.py CLI mode to run for each profile",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profiles",
|
||||||
|
default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)),
|
||||||
|
help="comma-separated profiles to start, or 'all'",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--http-port-base",
|
||||||
|
default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"),
|
||||||
|
help=(
|
||||||
|
"base HTTP health-check port for profile workers; "
|
||||||
|
"0 lets the OS assign free ports"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reload",
|
||||||
|
action="store_true",
|
||||||
|
default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False),
|
||||||
|
help="enable auto-reload in dev mode",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
profiles = _parse_profiles(args.profiles)
|
||||||
|
http_port_base = _parse_http_port_base(args.http_port_base)
|
||||||
|
except ValueError as exc:
|
||||||
|
parser.error(str(exc))
|
||||||
|
|
||||||
|
raise SystemExit(
|
||||||
|
asyncio.run(
|
||||||
|
_run(
|
||||||
|
profiles,
|
||||||
|
mode=args.mode,
|
||||||
|
http_port_base=http_port_base,
|
||||||
|
reload=args.reload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
169
test_beaver_llm.py
Normal file
169
test_beaver_llm.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
try:
|
||||||
|
from custom.beaver_llm import BeaverLLM, latest_user_text
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from beaver_llm import BeaverLLM, latest_user_text
|
||||||
|
from livekit.agents import ChatContext
|
||||||
|
|
||||||
|
|
||||||
|
def test_latest_user_text_uses_most_recent_user_message() -> None:
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="first")
|
||||||
|
ctx.add_message(role="assistant", content="ignored")
|
||||||
|
ctx.add_message(role="user", content=["second", "line"])
|
||||||
|
|
||||||
|
assert latest_user_text(ctx) == "second\nline"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="system", content="ignored instructions")
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert collected.text == "beaver reply"
|
||||||
|
assert received[0] == {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "livekit-room",
|
||||||
|
"device_name": "livekit-custom-agent",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
assert received[1]["type"] == "message"
|
||||||
|
assert received[1]["message_id"].startswith("livekit-room-")
|
||||||
|
assert received[1]["message_id"].endswith("-000001")
|
||||||
|
assert received[1]["text"] == "hello beaver"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks: list[str] = []
|
||||||
|
async with beaver_llm.chat(chat_ctx=ctx) as stream:
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.delta and chunk.delta.content:
|
||||||
|
chunks.append(chunk.delta.content)
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert chunks == ["beaver reply"]
|
||||||
458
test_beaver_terminal_client.py
Normal file
458
test_beaver_terminal_client.py
Normal file
@ -0,0 +1,458 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from custom.beaver_terminal_client import (
|
||||||
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalConnectionClosed,
|
||||||
|
BeaverTerminalError,
|
||||||
|
MessageIdGenerator,
|
||||||
|
build_connect_frame,
|
||||||
|
build_message_frame,
|
||||||
|
)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from beaver_terminal_client import (
|
||||||
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalConnectionClosed,
|
||||||
|
BeaverTerminalError,
|
||||||
|
MessageIdGenerator,
|
||||||
|
build_connect_frame,
|
||||||
|
build_message_frame,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
||||||
|
frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
|
||||||
|
|
||||||
|
assert frame == {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "device-001",
|
||||||
|
"device_name": "desk-terminal",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_message_frame_uses_message_id_and_text() -> None:
|
||||||
|
frame = build_message_frame(message_id="device-001-000001", text="hello")
|
||||||
|
|
||||||
|
assert frame == {
|
||||||
|
"type": "message",
|
||||||
|
"message_id": "device-001-000001",
|
||||||
|
"text": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
|
||||||
|
generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
|
||||||
|
|
||||||
|
assert generator.next_id() == "device-001-000008"
|
||||||
|
assert generator.next_id() == "device-001-000009"
|
||||||
|
assert generator.counter == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id_generator_can_include_nonce() -> None:
|
||||||
|
generator = MessageIdGenerator(peer_id="device-001", nonce="run12345")
|
||||||
|
|
||||||
|
assert generator.next_id() == "device-001-run12345-000001"
|
||||||
|
assert generator.next_id() == "device-001-run12345-000002"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_connects_sends_text_and_returns_assistant_reply(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "assistant reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert client.session_id == "terminal-dev:local:device-001"
|
||||||
|
assert reply == "assistant reply"
|
||||||
|
assert received == [
|
||||||
|
{
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "device-001",
|
||||||
|
"device_name": "desk-terminal",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"message_id": "device-001-000001",
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": False,
|
||||||
|
"duplicate": True,
|
||||||
|
"pending": False,
|
||||||
|
"reply": "cached assistant reply",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert reply == "cached assistant reply"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json({"type": "error", "error": "text is required"})
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
with pytest.raises(BeaverTerminalError, match="text is required"):
|
||||||
|
await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_cleans_up_owned_session_when_connect_fails(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.Response:
|
||||||
|
return web.Response(status=200, text="not a websocket")
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
|
||||||
|
await client.connect()
|
||||||
|
assert client._http_session is None
|
||||||
|
assert client._ws is None
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "failed turn",
|
||||||
|
"finish_reason": "error",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
with pytest.raises(BeaverTerminalError, match="failed turn"):
|
||||||
|
await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "ping":
|
||||||
|
await ws.send_json({"type": "pong"})
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
assert await client.ping()
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
connect_peer_ids: list[str] = []
|
||||||
|
message_ids: list[str] = []
|
||||||
|
connection_count = 0
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
nonlocal connection_count
|
||||||
|
connection_count += 1
|
||||||
|
current_connection = connection_count
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
connect_peer_ids.append(frame["peer_id"])
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
message_ids.append(frame["message_id"])
|
||||||
|
if current_connection == 1:
|
||||||
|
await ws.close()
|
||||||
|
continue
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-2",
|
||||||
|
"text": "reply after reconnect",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert reply == "reply after reconnect"
|
||||||
|
assert connect_peer_ids == ["device-001", "device-001"]
|
||||||
|
assert message_ids == ["device-001-000001", "device-001-000002"]
|
||||||
262
test_hermes_gateway.py
Normal file
262
test_hermes_gateway.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from custom.hermes_gateway import (
|
||||||
|
GatewaySessionState,
|
||||||
|
HermesGatewayLLM,
|
||||||
|
build_connect_params,
|
||||||
|
chat_context_to_gateway_messages,
|
||||||
|
extract_text_delta,
|
||||||
|
is_error_response,
|
||||||
|
is_terminal_event,
|
||||||
|
)
|
||||||
|
from livekit.agents import ChatContext, llm
|
||||||
|
from livekit.agents._exceptions import APIConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None:
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="system", content="system prompt")
|
||||||
|
ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")])
|
||||||
|
|
||||||
|
messages = chat_context_to_gateway_messages(ctx)
|
||||||
|
|
||||||
|
assert messages == [
|
||||||
|
{"role": "system", "content": "system prompt"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "look here"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None:
|
||||||
|
assert (
|
||||||
|
extract_text_delta(
|
||||||
|
{"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}}
|
||||||
|
)
|
||||||
|
== "hi"
|
||||||
|
)
|
||||||
|
assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there"
|
||||||
|
assert (
|
||||||
|
extract_text_delta(
|
||||||
|
{
|
||||||
|
"type": "event",
|
||||||
|
"event": "session.message.delta",
|
||||||
|
"payload": {"message": {"content": [{"type": "text", "text": "!"}]}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
== "!"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_per_room_session_state_reuses_stable_session_key() -> None:
|
||||||
|
state = GatewaySessionState(room_name="kitchen-room", agent_id="helper")
|
||||||
|
|
||||||
|
assert state.session_key == "livekit:kitchen-room:helper"
|
||||||
|
state.session_key = "gateway-session-123"
|
||||||
|
assert state.session_key == "gateway-session-123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_connect_params_uses_backend_operator_defaults() -> None:
|
||||||
|
params = build_connect_params(token="secret-token")
|
||||||
|
|
||||||
|
assert params["client"] == {
|
||||||
|
"id": "gateway-client",
|
||||||
|
"version": "livekit-custom-agent",
|
||||||
|
"platform": "python",
|
||||||
|
"mode": "backend",
|
||||||
|
}
|
||||||
|
assert params["role"] == "operator"
|
||||||
|
assert params["scopes"] == ["operator.read", "operator.write"]
|
||||||
|
assert params["auth"] == {"token": "secret-token"}
|
||||||
|
assert "device" not in params
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_response_helpers_match_only_current_send_request() -> None:
|
||||||
|
assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1")
|
||||||
|
assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1")
|
||||||
|
assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1")
|
||||||
|
assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_llm_reports_provider_and_model() -> None:
|
||||||
|
state = GatewaySessionState(room_name="kitchen", agent_id="helper")
|
||||||
|
gateway_llm = HermesGatewayLLM(
|
||||||
|
url="ws://gateway.test/ws",
|
||||||
|
token="token",
|
||||||
|
state=state,
|
||||||
|
agent_id="helper",
|
||||||
|
model_name="hermes-agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert gateway_llm.provider == "hermes-gateway"
|
||||||
|
assert gateway_llm.model == "hermes-agent"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_session_state_rejects_non_per_room_mode() -> None:
|
||||||
|
with pytest.raises(ValueError, match="per_room"):
|
||||||
|
GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
payload = json.loads(message.data)
|
||||||
|
received.append(payload)
|
||||||
|
method = payload.get("method")
|
||||||
|
request_id = payload.get("id")
|
||||||
|
if method == "connect":
|
||||||
|
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||||
|
elif method == "sessions.create":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "res",
|
||||||
|
"id": request_id,
|
||||||
|
"ok": True,
|
||||||
|
"result": {"sessionKey": "livekit:kitchen:helper"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif method == "sessions.send":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "event",
|
||||||
|
"event": "agent",
|
||||||
|
"payload": {"delta": {"content": "你好"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "res",
|
||||||
|
"id": request_id,
|
||||||
|
"ok": True,
|
||||||
|
"result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
gateway_llm = HermesGatewayLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||||
|
token="secret-token",
|
||||||
|
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||||
|
agent_id="helper",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="杯子在哪里")
|
||||||
|
|
||||||
|
try:
|
||||||
|
collected = await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await gateway_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert collected.text == "你好"
|
||||||
|
assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"]
|
||||||
|
send_request = received[2]
|
||||||
|
assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper"
|
||||||
|
assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_delta_reads_final_message_content() -> None:
|
||||||
|
assert (
|
||||||
|
extract_text_delta(
|
||||||
|
{
|
||||||
|
"type": "event",
|
||||||
|
"event": "session.message.completed",
|
||||||
|
"payload": {
|
||||||
|
"message": {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "完整回复"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
== "完整回复"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_error_response_accepts_error_events() -> None:
|
||||||
|
assert is_error_response(
|
||||||
|
{"type": "event", "event": "agent.error", "payload": {"error": "boom"}},
|
||||||
|
request_id="send-1",
|
||||||
|
)
|
||||||
|
assert is_error_response(
|
||||||
|
{"type": "event", "event": "run.failed", "payload": {"message": "boom"}},
|
||||||
|
request_id="send-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_llm_stream_maps_gateway_error_events_to_api_connection_error(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
payload = json.loads(message.data)
|
||||||
|
method = payload.get("method")
|
||||||
|
request_id = payload.get("id")
|
||||||
|
if method == "connect":
|
||||||
|
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||||
|
elif method == "sessions.create":
|
||||||
|
await ws.send_json({"type": "res", "id": request_id, "ok": True})
|
||||||
|
elif method == "sessions.send":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "event",
|
||||||
|
"event": "run.failed",
|
||||||
|
"payload": {"message": "gateway exploded"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
gateway_llm = HermesGatewayLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
|
||||||
|
token=None,
|
||||||
|
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
|
||||||
|
agent_id="helper",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="hello")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(APIConnectionError, match="gateway exploded"):
|
||||||
|
await gateway_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await gateway_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
31
tts.py
31
tts.py
@ -3,10 +3,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import wave
|
import wave
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@ -29,13 +29,13 @@ class BlackboxTTS(tts.TTS):
|
|||||||
*,
|
*,
|
||||||
url: str,
|
url: str,
|
||||||
model_name: str = "voxcpmtts",
|
model_name: str = "voxcpmtts",
|
||||||
params: Optional[Mapping[str, object]] = None,
|
params: Mapping[str, object] | None = None,
|
||||||
prompt_wav_path: Optional[str] = None,
|
prompt_wav_path: str | None = None,
|
||||||
prompt_wav_field: str = "prompt_wav",
|
prompt_wav_field: str = "prompt_wav",
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
http_session: Optional[aiohttp.ClientSession] = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
capabilities=tts.TTSCapabilities(streaming=False),
|
capabilities=tts.TTSCapabilities(streaming=False),
|
||||||
@ -88,6 +88,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
self._tts: BlackboxTTS = tts
|
self._tts: BlackboxTTS = tts
|
||||||
|
|
||||||
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
||||||
|
started_at = time.perf_counter()
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
form = aiohttp.FormData(default_to_multipart=True)
|
||||||
form.add_field("text", self.input_text)
|
form.add_field("text", self.input_text)
|
||||||
form.add_field("model_name", self._tts._model_name)
|
form.add_field("model_name", self._tts._model_name)
|
||||||
@ -131,6 +132,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
content_type = resp.headers.get("Content-Type", "audio/wav")
|
content_type = resp.headers.get("Content-Type", "audio/wav")
|
||||||
logged_wav_format = False
|
logged_wav_format = False
|
||||||
wav_header_probe = bytearray()
|
wav_header_probe = bytearray()
|
||||||
|
first_audio_at: float | None = None
|
||||||
|
chunk_count = 0
|
||||||
|
total_bytes = 0
|
||||||
output_emitter.initialize(
|
output_emitter.initialize(
|
||||||
request_id=utils.shortuuid(),
|
request_id=utils.shortuuid(),
|
||||||
sample_rate=self._tts.sample_rate,
|
sample_rate=self._tts.sample_rate,
|
||||||
@ -140,6 +144,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
|
|
||||||
async for data, _ in resp.content.iter_chunks():
|
async for data, _ in resp.content.iter_chunks():
|
||||||
if data:
|
if data:
|
||||||
|
chunk_count += 1
|
||||||
|
total_bytes += len(data)
|
||||||
|
if first_audio_at is None:
|
||||||
|
first_audio_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
|
||||||
|
first_audio_at - started_at,
|
||||||
|
len(self.input_text),
|
||||||
|
len(data),
|
||||||
|
)
|
||||||
if not logged_wav_format:
|
if not logged_wav_format:
|
||||||
wav_header_probe.extend(data)
|
wav_header_probe.extend(data)
|
||||||
logged_wav_format = _log_wav_format(
|
logged_wav_format = _log_wav_format(
|
||||||
@ -156,6 +170,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
logged_wav_format = True
|
logged_wav_format = True
|
||||||
output_emitter.push(data)
|
output_emitter.push(data)
|
||||||
output_emitter.flush()
|
output_emitter.flush()
|
||||||
|
finished_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
|
||||||
|
finished_at - started_at,
|
||||||
|
(first_audio_at - started_at) if first_audio_at else -1.0,
|
||||||
|
chunk_count,
|
||||||
|
total_bytes,
|
||||||
|
len(self.input_text),
|
||||||
|
)
|
||||||
except asyncio.TimeoutError as e:
|
except asyncio.TimeoutError as e:
|
||||||
raise APITimeoutError("TTS blackbox request timed out") from e
|
raise APITimeoutError("TTS blackbox request timed out") from e
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user