fix: fix voice interrupt
This commit is contained in:
260
asr.py
260
asr.py
@ -1,11 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from livekit import rtc
|
||||
from livekit.agents import (
|
||||
DEFAULT_API_CONNECT_OPTIONS,
|
||||
NOT_GIVEN,
|
||||
APIConnectionError,
|
||||
APIConnectOptions,
|
||||
@ -17,9 +20,15 @@ from livekit.agents import (
|
||||
utils,
|
||||
)
|
||||
from livekit.agents.utils import is_given
|
||||
from livekit.agents.vad import VAD, VADEventType
|
||||
|
||||
# Logger used by the code in this module.
logger = logging.getLogger("blackbox-asr")

# Connect options that _BoundedStreamAdapterWrapper hands to its RecognizeStream
# base class: zero retries, with the timeout copied from the library default.
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
    max_retry=0,
    timeout=DEFAULT_API_CONNECT_OPTIONS.timeout,
)

# Alias for stt.STTCapabilities. BoundedStreamAdapter.__init__ has a parameter
# named ``stt`` that hides the ``stt`` module inside its body, and it builds its
# capabilities through this alias instead.
STT_CAPABILITIES = stt.STTCapabilities
|
||||
|
||||
|
||||
class BlackboxSTT(stt.STT):
|
||||
def __init__(
|
||||
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
|
||||
url: str,
|
||||
*,
|
||||
model_name: str = "sensevoice",
|
||||
language: Optional[str] = "auto",
|
||||
language: str | None = "auto",
|
||||
output_language: str = "zh",
|
||||
hotwords: Optional[str] = None,
|
||||
itn: Optional[Union[bool, str]] = None,
|
||||
chunk_mode: Optional[Union[bool, str]] = None,
|
||||
hotwords: str | None = None,
|
||||
itn: bool | str | None = None,
|
||||
chunk_mode: bool | str | None = None,
|
||||
timeout: float = 30.0,
|
||||
http_session: Optional[aiohttp.ClientSession] = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
capabilities=stt.STTCapabilities(
|
||||
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
|
||||
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||
|
||||
|
||||
def _form_value(value: Union[bool, str]) -> str:
|
||||
def _form_value(value: bool | str) -> str:
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
return value
|
||||
|
||||
|
||||
class BoundedStreamAdapter(stt.STT):
    """Expose a wrapped STT as a streaming STT, using a VAD to cut segments.

    Streams created by :meth:`stream` are ``_BoundedStreamAdapterWrapper``
    instances. They submit a segment to the wrapped STT when the VAD reports the
    end of speech, or earlier once a segment reaches ``max_speech_duration``
    seconds. One-shot recognition is delegated to the wrapped STT unchanged.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        """Wrap ``stt`` and subscribe to its ``metrics_collected`` events.

        Args:
            stt: The STT that performs the actual recognition.
            vad: The VAD whose stream drives segmentation.
            max_speech_duration: Segment length, in seconds, at which a stream
                submits the segment without waiting for the VAD. ``None``
                disables that check.
            pre_speech_duration: Seconds of audio a stream retains while no
                speech is active, used to seed the next segment.
        """
        # NOTE: the ``stt`` parameter hides the ``stt`` module inside this body,
        # hence the module-level STT_CAPABILITIES alias.
        capabilities = STT_CAPABILITIES(
            streaming=True,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)

        self._stt = stt
        self._vad = vad
        self._pre_speech_duration = pre_speech_duration
        self._max_speech_duration = max_speech_duration

        # Re-emit the wrapped STT's metrics from this adapter (undone in aclose).
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The STT instance this adapter wraps."""
        return self._stt

    @property
    def model(self) -> str:
        """The wrapped STT's ``model``."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """The wrapped STT's ``provider``."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition to the wrapped STT and return its event."""
        event = await self._stt.recognize(
            buffer=buffer,
            language=language,
            conn_options=conn_options,
        )
        return event

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a recognize stream configured with this adapter's settings.

        Args:
            language: Forwarded to the wrapped STT on every segment.
            conn_options: Connect options used for the wrapped STT's
                ``recognize`` calls.
        """
        return _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Re-emit a ``metrics_collected`` event received from the wrapped STT."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Unsubscribe from the wrapped STT's ``metrics_collected`` events."""
        self._stt.off("metrics_collected", self._on_metrics_collected)
|
||||
|
||||
|
||||
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Recognize stream that cuts audio into segments and recognizes each one.

    Input frames are pushed to a VAD stream. A segment is submitted to the
    wrapped STT when the VAD reports END_OF_SPEECH, when a segment reaches
    ``max_speech_duration`` seconds, or when the input ends during speech.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        """Store the segmentation settings and initialise the base stream.

        Args:
            adapter: The owning adapter, passed to the base class as ``stt``.
            vad: VAD used to open the VAD stream in :meth:`_run`.
            wrapped_stt: STT whose ``recognize`` is called for each segment.
            language: Forwarded to every ``recognize`` call.
            conn_options: Connect options for the ``recognize`` calls. The base
                class receives DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS instead.
            max_speech_duration: Segment length in seconds that triggers a
                forced submission; ``None`` disables the check.
            pre_speech_duration: Seconds of audio retained while no speech is
                active.
        """
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        """Consume ``event_aiter`` to exhaustion without acting on any event."""
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Run segmentation and recognition until the input channel is exhausted.

        Three tasks cooperate:

        * ``_forward_input`` pushes frames to the VAD and accumulates them into
          the current segment (or the pre-roll when no speech is active).
        * ``_recognize_from_vad`` opens and closes segments on VAD events.
        * ``_recognize_worker`` takes queued segments one at a time and calls
          the wrapped STT.
        """
        vad_stream = self._vad.stream()
        # Guards the segment state shared by the input and VAD tasks. No await
        # happens while it is held.
        lock = asyncio.Lock()
        # Segments awaiting recognition; ``None`` tells the worker to stop.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()

        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            """Return the frame length in seconds."""
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            """Add ``frame`` to the pre-roll, dropping the oldest frames over budget."""
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)

            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            """Emit END_OF_SPEECH and queue ``frames``; do nothing if ``frames`` is empty."""
            if not frames:
                return

            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )

            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            """Recognize queued segments in order until the ``None`` sentinel arrives."""
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return

                try:
                    # Merging happens inside the try so that a segment that
                    # cannot be merged is logged and dropped like any other
                    # failed segment, instead of terminating the worker while
                    # later segments keep being queued.
                    merged_frames = utils.merge_frames(frames)
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # Best effort: one failed segment must not stop the stream.
                    logger.exception("ASR segment recognition failed")
                    continue

                # Skip results without any text.
                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue

                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            """Feed frames to the VAD and accumulate them into segment or pre-roll."""
            nonlocal segment_duration, segment_frames, speech_active

            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue

                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []

                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Cut the segment here. speech_active stays True, so
                            # the following frames start a new segment directly,
                            # without pre-roll.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)

                # Queue outside the lock because _enqueue_segment awaits.
                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)

            vad_stream.end_input()

            # Input ended during speech: submit what has been collected so far.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False

            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            """Open a segment on VAD START_OF_SPEECH and submit it on END_OF_SPEECH."""
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active

            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            speech_active = True
                            # Seed the segment with the retained pre-roll audio.
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue

                if event.type != VADEventType.END_OF_SPEECH:
                    continue

                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False

                # No-op when the segment is empty (e.g. right after a forced cut).
                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Both producers are done: let the worker drain the queue, then stop.
            await recognize_queue.put(None)
            await worker_task
        finally:
            await utils.aio.cancel_and_wait(*tasks, worker_task)
            await vad_stream.aclose()
|
||||
|
||||
Reference in New Issue
Block a user