fix: fix voice interupt

This commit is contained in:
0Xiao0
2026-06-12 11:17:12 +08:00
parent 78b9138c17
commit 820dc44053
8 changed files with 537 additions and 48 deletions

View File

@ -20,6 +20,10 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
CUSTOM_ASR_MAX_SPEECH_DURATION=12
# Keep false if forced ASR turns should reply even while input audio continues.
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.

260
asr.py
View File

@ -1,11 +1,14 @@
import asyncio
import logging
from typing import Any, Optional, Union
from collections import deque
from collections.abc import AsyncIterable
from typing import Any
import aiohttp
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
@ -17,9 +20,15 @@ from livekit.agents import (
utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType
logger = logging.getLogger("blackbox-asr")
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
STT_CAPABILITIES = stt.STTCapabilities
class BlackboxSTT(stt.STT):
def __init__(
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
url: str,
*,
model_name: str = "sensevoice",
language: Optional[str] = "auto",
language: str | None = "auto",
output_language: str = "zh",
hotwords: Optional[str] = None,
itn: Optional[Union[bool, str]] = None,
chunk_mode: Optional[Union[bool, str]] = None,
hotwords: str | None = None,
itn: bool | str | None = None,
chunk_mode: bool | str | None = None,
timeout: float = 30.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=stt.STTCapabilities(
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
def _form_value(value: bool | str) -> str:
if isinstance(value, bool):
return str(value).lower()
return value
class BoundedStreamAdapter(stt.STT):
def __init__(
self,
*,
stt: stt.STT,
vad: VAD,
max_speech_duration: float | None = 12.0,
pre_speech_duration: float = 0.5,
) -> None:
super().__init__(
capabilities=STT_CAPABILITIES(
streaming=True,
interim_results=False,
diarization=False,
)
)
self._vad = vad
self._stt = stt
self._max_speech_duration = max_speech_duration
self._pre_speech_duration = pre_speech_duration
self._stt.on("metrics_collected", self._on_metrics_collected)
@property
def wrapped_stt(self) -> stt.STT:
return self._stt
@property
def model(self) -> str:
return self._stt.model
@property
def provider(self) -> str:
return self._stt.provider
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> stt.SpeechEvent:
return await self._stt.recognize(
buffer=buffer, language=language, conn_options=conn_options
)
def stream(
self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> stt.RecognizeStream:
return _BoundedStreamAdapterWrapper(
self,
vad=self._vad,
wrapped_stt=self._stt,
language=language,
conn_options=conn_options,
max_speech_duration=self._max_speech_duration,
pre_speech_duration=self._pre_speech_duration,
)
def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
self.emit("metrics_collected", *args, **kwargs)
async def aclose(self) -> None:
self._stt.off("metrics_collected", self._on_metrics_collected)
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
def __init__(
self,
adapter: BoundedStreamAdapter,
*,
vad: VAD,
wrapped_stt: stt.STT,
language: NotGivenOr[str],
conn_options: APIConnectOptions,
max_speech_duration: float | None,
pre_speech_duration: float,
) -> None:
super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
self._vad = vad
self._wrapped_stt = wrapped_stt
self._wrapped_stt_conn_options = conn_options
self._language = language
self._max_speech_duration = max_speech_duration
self._pre_speech_duration = pre_speech_duration
async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
async for _ in event_aiter:
pass
async def _run(self) -> None:
vad_stream = self._vad.stream()
lock = asyncio.Lock()
recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
speech_active = False
segment_frames: list[rtc.AudioFrame] = []
segment_duration = 0.0
pre_roll_frames: deque[rtc.AudioFrame] = deque()
pre_roll_duration = 0.0
def _frame_duration(frame: rtc.AudioFrame) -> float:
return frame.samples_per_channel / frame.sample_rate
def _append_pre_roll(frame: rtc.AudioFrame) -> None:
nonlocal pre_roll_duration
pre_roll_frames.append(frame)
pre_roll_duration += _frame_duration(frame)
while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())
async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
if not frames:
return
if forced:
logger.info(
"Forcing ASR segment after %.2fs of continuous speech",
self._max_speech_duration,
)
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
await recognize_queue.put(frames)
async def _recognize_worker() -> None:
while True:
frames = await recognize_queue.get()
if frames is None:
return
merged_frames = utils.merge_frames(frames)
try:
t_event = await self._wrapped_stt.recognize(
buffer=merged_frames,
language=self._language,
conn_options=self._wrapped_stt_conn_options,
)
except Exception:
logger.exception("ASR segment recognition failed")
continue
if not t_event.alternatives or not t_event.alternatives[0].text:
continue
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[t_event.alternatives[0]],
)
)
async def _forward_input() -> None:
nonlocal segment_duration, segment_frames, speech_active
async for input_frame in self._input_ch:
if isinstance(input_frame, self._FlushSentinel):
vad_stream.flush()
continue
vad_stream.push_frame(input_frame)
forced_frames: list[rtc.AudioFrame] = []
async with lock:
if speech_active:
segment_frames.append(input_frame)
segment_duration += _frame_duration(input_frame)
if (
self._max_speech_duration is not None
and segment_duration >= self._max_speech_duration
):
forced_frames = segment_frames
segment_frames = []
segment_duration = 0.0
else:
_append_pre_roll(input_frame)
if forced_frames:
await _enqueue_segment(forced_frames, forced=True)
vad_stream.end_input()
final_frames: list[rtc.AudioFrame] = []
async with lock:
if speech_active and segment_frames:
final_frames = segment_frames
segment_frames = []
segment_duration = 0.0
speech_active = False
if final_frames:
await _enqueue_segment(final_frames)
async def _recognize_from_vad() -> None:
nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
async for event in vad_stream:
if event.type == VADEventType.START_OF_SPEECH:
self._event_ch.send_nowait(
stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
)
async with lock:
if not speech_active:
speech_active = True
segment_frames = list(pre_roll_frames)
segment_duration = sum(_frame_duration(f) for f in segment_frames)
pre_roll_frames.clear()
pre_roll_duration = 0.0
continue
if event.type != VADEventType.END_OF_SPEECH:
continue
async with lock:
frames = segment_frames
segment_frames = []
segment_duration = 0.0
speech_active = False
await _enqueue_segment(frames)
worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
tasks = [
asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
]
try:
await asyncio.gather(*tasks)
await recognize_queue.put(None)
await worker_task
finally:
await utils.aio.cancel_and_wait(*tasks, worker_task)
await vad_stream.aclose()

View File

@ -109,6 +109,7 @@ class BeaverLLMStream(llm.LLMStream):
) -> None:
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
self._beaver_llm = beaver_llm
self._request_id = shortuuid("beaver_")
async def _run(self) -> None:
user_text = latest_user_text(self.chat_ctx)
@ -116,9 +117,12 @@ class BeaverLLMStream(llm.LLMStream):
reply = await self._beaver_llm._client.send_text(user_text)
if reply:
self._send_text_chunk(reply)
def _send_text_chunk(self, text: str) -> None:
self._event_ch.send_nowait(
llm.ChatChunk(
id=shortuuid("beaver_"),
delta=llm.ChoiceDelta(role="assistant", content=reply),
id=self._request_id,
delta=llm.ChoiceDelta(role="assistant", content=text),
)
)

View File

@ -28,6 +28,7 @@ class BeaverTerminalConnectionClosed(BeaverTerminalError):
@dataclass
class MessageIdGenerator:
peer_id: str
nonce: str | None = None
initial_counter: int = 0
def __post_init__(self) -> None:
@ -35,6 +36,8 @@ class MessageIdGenerator:
def next_id(self) -> str:
self.counter += 1
if self.nonce:
return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
return f"{self.peer_id}-{self.counter:06d}"
@ -77,6 +80,7 @@ class BeaverTerminalClient:
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
try:
self._ws = await session.ws_connect(self._url)
await self._send_json(
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
@ -86,6 +90,12 @@ class BeaverTerminalClient:
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
session_id = frame.get("session_id")
self.session_id = session_id if isinstance(session_id, str) else None
except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
await self._cleanup_failed_connection()
raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
except Exception:
await self._cleanup_failed_connection()
raise
async def send_text(self, text: str) -> str:
for attempt in range(2):
@ -153,6 +163,13 @@ class BeaverTerminalClient:
await self._ws.close()
self._ws = None
async def _cleanup_failed_connection(self) -> None:
await self._close_websocket()
self.session_id = None
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import json
import logging
@ -9,12 +10,13 @@ from dataclasses import dataclass
from pathlib import Path
from beaver_llm import BeaverLLM
from beaver_terminal_client import BeaverTerminalError
from dotenv import load_dotenv
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from asr import BlackboxSTT, BoundedStreamAdapter
from livekit.agents import (
Agent,
AgentServer,
@ -32,7 +34,6 @@ from livekit.agents import (
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
@ -61,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
@ -70,6 +71,7 @@ GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
DEFAULT_AGENT_PROFILE = "normal"
@ -163,14 +165,27 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"confident",
"confused",
"cool",
"crying",
"delicious",
"embarrassed",
"funny",
"happy",
"kissy",
"laughing",
"loving",
"neutral",
"relaxed",
"sad",
"shocked",
"silly",
"sleepy",
"surprised",
"fearful",
"calm",
"concerned",
"thinking",
"winking",
}
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
@ -380,6 +395,7 @@ class CustomAgent(Agent):
yield chunk
return _stream()
async def _observe_emotion_prefix(
self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
@ -635,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
def _with_vision_as_latest_user_message(
chat_ctx: ChatContext, vision_frame: VisionFrame
) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
@ -742,6 +760,7 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
@ -749,7 +768,9 @@ async def entrypoint(ctx: JobContext) -> None:
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
INPUT_MODE = _normalize_input_mode(
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
)
if LLM_PROVIDER not in {
"openai",
"openai-compatible",
@ -782,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
beaver_warmup_text: str | None = None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -792,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
itn=os.getenv("CUSTOM_ASR_ITN"),
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
)
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
stt_stream = BoundedStreamAdapter(
stt=blackbox_stt,
vad=ctx.proc.userdata["vad"],
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
)
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
allow_interruptions = _env_bool(
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
ASR_MAX_SPEECH_DURATION <= 0,
)
if LLM_PROVIDER == "beaver":
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
if not beaver_url:
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
raise RuntimeError(
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
)
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
beaver_peer_id = (
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
)
beaver_device_name = (
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
or "livekit-custom-agent"
@ -811,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
)
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
text_llm = base_llm
vision_llm = base_llm
logger.info(
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
beaver_url,
beaver_peer_id,
beaver_device_name,
ctx.room.name,
base_llm.session_id,
bool(beaver_warmup_text and beaver_warmup_text.strip()),
len(warmup_reply) if warmup_reply is not None else 0,
)
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
@ -914,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
# 4. Silero VAD
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=MultilingualModel(),
turn_detection=turn_detection,
interruption={
"enabled": allow_interruptions,
"resume_false_interruption": True,
"false_interruption_timeout": 1.0,
},
@ -946,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
def _interrupt_done(fut: asyncio.Future[None]) -> None:
try:
fut.result()
except Exception:
logger.exception("Bridge interrupt failed")
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
reason = None if payload is None else payload.get("reason")
logger.info(
"Received bridge interrupt: topic=%s reason=%s",
packet_topic or "<default>",
reason if isinstance(reason, str) and reason else "<unspecified>",
)
try:
interrupt_fut = session.interrupt(force=True)
except RuntimeError:
logger.exception("Bridge interrupt received before AgentSession was running")
return
interrupt_fut.add_done_callback(_interrupt_done)
if packet_topic == INTERRUPT_TOPIC:
payload: dict[str, object] | None = None
try:
decoded = json.loads(data_packet.data.decode("utf-8"))
if isinstance(decoded, dict):
payload = decoded
except Exception:
logger.exception("Failed to decode interrupt payload")
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
logger.exception("Failed to decode data payload")
return
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
@ -1013,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
),
record=_recording_options_from_env(),
)
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
_start_beaver_background_warmup(
ctx=ctx,
beaver_llm=base_llm,
warmup_text=beaver_warmup_text,
)
def _start_beaver_background_warmup(
*,
ctx: JobContext,
beaver_llm: BeaverLLM,
warmup_text: str | None,
) -> None:
async def _warmup() -> None:
try:
warmup_reply = await beaver_llm.connect(warmup_text=warmup_text)
except BeaverTerminalError:
logger.warning(
"Beaver background handshake failed; will retry on first user turn",
exc_info=True,
)
return
except Exception:
logger.exception("Unexpected Beaver background handshake failure")
return
logger.info(
"Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
ctx.room.name,
beaver_llm.session_id,
bool(warmup_text and warmup_text.strip()),
len(warmup_reply) if warmup_reply is not None else 0,
)
warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup")
async def _cancel_warmup() -> None:
if warmup_task.done():
return
warmup_task.cancel()
try:
await warmup_task
except asyncio.CancelledError:
pass
ctx.add_shutdown_callback(_cancel_warmup)
def _tts_params_from_env(model_name: str) -> dict[str, str]:

View File

@ -1,3 +1,4 @@
import asyncio
import json
import aiohttp
@ -96,3 +97,73 @@ async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
assert received[1]["message_id"].startswith("livekit-room-")
assert received[1]["message_id"].endswith("-000001")
assert received[1]["text"] == "hello beaver"
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
unused_tcp_port: int,
) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:livekit-room",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:livekit-room",
"accepted": True,
}
)
await asyncio.sleep(0.05)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "beaver reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
beaver_llm = BeaverLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="livekit-room",
device_name="livekit-custom-agent",
)
ctx = ChatContext.empty()
ctx.add_message(role="user", content="hello beaver")
try:
chunks: list[str] = []
async with beaver_llm.chat(chat_ctx=ctx) as stream:
async for chunk in stream:
if chunk.delta and chunk.delta.content:
chunks.append(chunk.delta.content)
finally:
await beaver_llm.aclose()
await runner.cleanup()
assert chunks == ["beaver reply"]

View File

@ -14,6 +14,7 @@ if __name__ == "__main__":
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
@ -22,6 +23,7 @@ try:
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
@ -244,6 +246,36 @@ async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
await runner.cleanup()
async def test_client_cleans_up_owned_session_when_connect_fails(
unused_tcp_port: int,
) -> None:
async def websocket_handler(request: web.Request) -> web.Response:
return web.Response(status=200, text="not a websocket")
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
await client.connect()
assert client._http_session is None
assert client._ws is None
finally:
await client.close()
await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
unused_tcp_port: int,
) -> None:

7
tts.py
View File

@ -7,7 +7,6 @@ import time
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import Optional
import aiohttp
@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS):
*,
url: str,
model_name: str = "voxcpmtts",
params: Optional[Mapping[str, object]] = None,
prompt_wav_path: Optional[str] = None,
params: Mapping[str, object] | None = None,
prompt_wav_path: str | None = None,
prompt_wav_field: str = "prompt_wav",
sample_rate: int = 16000,
num_channels: int = 1,
timeout: float = 60.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(streaming=False),