fix: fix voice interrupt

This commit is contained in:
0Xiao0
2026-06-12 11:17:12 +08:00
parent 78b9138c17
commit 820dc44053
8 changed files with 537 additions and 48 deletions

260
asr.py
View File

@ -1,11 +1,14 @@
import asyncio
import logging
from typing import Any, Optional, Union
from collections import deque
from collections.abc import AsyncIterable
from typing import Any
import aiohttp
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
@ -17,9 +20,15 @@ from livekit.agents import (
utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType
# Module-level logger for this ASR integration.
logger = logging.getLogger("blackbox-asr")
# Connect options with retries disabled (max_retry=0) and the library's default
# timeout; handed to the RecognizeStream base class by _BoundedStreamAdapterWrapper.
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
# Alias for stt.STTCapabilities, used in BoundedStreamAdapter.__init__ where the
# `stt` keyword parameter shadows the `stt` module name.
STT_CAPABILITIES = stt.STTCapabilities
class BlackboxSTT(stt.STT):
def __init__(
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
url: str,
*,
model_name: str = "sensevoice",
language: Optional[str] = "auto",
language: str | None = "auto",
output_language: str = "zh",
hotwords: Optional[str] = None,
itn: Optional[Union[bool, str]] = None,
chunk_mode: Optional[Union[bool, str]] = None,
hotwords: str | None = None,
itn: bool | str | None = None,
chunk_mode: bool | str | None = None,
timeout: float = 30.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=stt.STTCapabilities(
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
def _form_value(value: bool | str) -> str:
if isinstance(value, bool):
return str(value).lower()
return value
class BoundedStreamAdapter(stt.STT):
    """Streaming facade over another STT, using a VAD to cut audio into segments.

    Streams created by this adapter buffer audio between VAD speech boundaries
    and pass each segment to the wrapped STT's ``recognize``. A segment is also
    cut once continuous speech reaches ``max_speech_duration`` seconds.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        """Wrap ``stt`` and forward its ``metrics_collected`` events.

        Args:
            stt: The STT whose ``recognize`` is called for each segment.
            vad: Voice activity detector used to find speech boundaries.
            max_speech_duration: Seconds of continuous speech after which a
                segment is forced; ``None`` disables forced segmentation.
            pre_speech_duration: Seconds of audio preceding the VAD
                start-of-speech event that are kept at the head of a segment.
        """
        # The ``stt`` parameter shadows the ``stt`` module inside this body, so
        # the capabilities type is reached through the module-level alias.
        capabilities = STT_CAPABILITIES(
            streaming=True,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)

        self._stt = stt
        self._vad = vad
        self._pre_speech_duration = pre_speech_duration
        self._max_speech_duration = max_speech_duration

        # Re-emit the wrapped STT's metrics from this adapter.
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The STT instance this adapter delegates to."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name reported by the wrapped STT."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider name reported by the wrapped STT."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition to the wrapped STT unchanged."""
        event = await self._stt.recognize(
            buffer=buffer,
            language=language,
            conn_options=conn_options,
        )
        return event

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a VAD-segmented recognize stream backed by the wrapped STT."""
        wrapper = _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )
        return wrapper

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Forward a ``metrics_collected`` event from the wrapped STT."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Detach the metrics forwarder from the wrapped STT."""
        self._stt.off("metrics_collected", self._on_metrics_collected)
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Recognize stream that segments input audio with a VAD.

    Incoming frames are pushed to a VAD stream and buffered locally. A segment
    is closed on the VAD end-of-speech event, when continuous speech reaches
    ``max_speech_duration``, or when input ends while speech is active. Each
    closed segment is queued and recognized by the wrapped STT, one at a time,
    and non-empty results are emitted as FINAL_TRANSCRIPT events.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        """Store the segmentation settings.

        Args:
            adapter: The owning adapter, registered as this stream's STT.
            vad: Voice activity detector that supplies speech boundaries.
            wrapped_stt: STT whose ``recognize`` is called per segment.
            language: Language forwarded to every ``recognize`` call.
            conn_options: Connect options forwarded to every ``recognize``
                call; the stream itself uses the no-retry module default.
            max_speech_duration: Seconds of continuous speech after which a
                segment is forced; ``None`` disables forced segmentation.
            pre_speech_duration: Seconds of pre-speech audio retained and
                prepended to a segment when speech starts.
        """
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        """Consume the event iterator and discard every event.

        NOTE(review): presumably this replaces the base-class metrics task
        because metrics are forwarded from the wrapped STT by the adapter --
        confirm against the base class.
        """
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Run the input-forwarding, VAD-event and recognition tasks until input ends."""
        vad_stream = self._vad.stream()
        # Guards the segment / pre-roll state shared by the input-forwarding
        # task and the VAD-event task.
        lock = asyncio.Lock()
        # Closed segments awaiting recognition; ``None`` tells the worker to stop.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        # Most recent non-speech frames, bounded by ``pre_speech_duration``.
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            """Return the frame length in seconds."""
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            """Add a frame to the pre-roll and drop the oldest ones over budget."""
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)
            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(
            frames: list[rtc.AudioFrame], *, forced: bool = False
        ) -> None:
            """Emit END_OF_SPEECH and queue ``frames``; empty segments are ignored."""
            if not frames:
                return
            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )
            # NOTE(review): END_OF_SPEECH is also emitted for forced splits while
            # the VAD still reports speech, and no new START_OF_SPEECH follows --
            # confirm this is the intended event sequence.
            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            """Recognize queued segments sequentially until the stop marker arrives."""
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return
                try:
                    # Merging happens inside the try block so that one segment
                    # that cannot be merged is logged and skipped instead of
                    # terminating the worker and stalling all later segments.
                    merged_frames = utils.merge_frames(frames)
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # Best effort: a failed segment is dropped, the stream goes on.
                    logger.exception("ASR segment recognition failed")
                    continue
                # Skip results without text.
                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            """Feed input frames to the VAD and buffer them into segment or pre-roll."""
            nonlocal segment_duration, segment_frames, speech_active
            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue
                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []
                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Cut the segment here; speech_active stays True so
                            # the following frames start a new segment.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)
                # Enqueue outside the lock.
                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)

            vad_stream.end_input()
            # Input ended while speech was active: flush what has been buffered.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            """Open and close segments in response to VAD events."""
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            speech_active = True
                            # Seed the new segment with the retained pre-roll.
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue
                if event.type != VADEventType.END_OF_SPEECH:
                    continue
                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Both producers are done: let the worker drain the queue and exit.
            await recognize_queue.put(None)
            await worker_task
        finally:
            try:
                await utils.aio.cancel_and_wait(*tasks, worker_task)
            finally:
                # Close the VAD stream even if task cleanup is interrupted.
                await vad_stream.aclose()