# File: livekit_agents/asr.py
# Snapshot: 2026-06-12 11:17:12 +08:00 (401 lines, 14 KiB, Python)

import asyncio
import logging
from collections import deque
from collections.abc import AsyncIterable
from typing import Any
import aiohttp
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
LanguageCode,
NotGivenOr,
stt,
utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType
# Module-level logger shared by every class and helper in this file.
logger = logging.getLogger("blackbox-asr")
# Connect options passed to the base class of the streaming wrapper: retries
# are disabled (max_retry=0) while the default API timeout is kept.
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
# Alias for stt.STTCapabilities. BoundedStreamAdapter.__init__ takes a
# keyword parameter named `stt`, which shadows the `stt` module inside that
# method, so the capabilities class is reached through this name instead.
STT_CAPABILITIES = stt.STTCapabilities
class BlackboxSTT(stt.STT):
    """Non-streaming STT client for an HTTP "blackbox" ASR endpoint.

    Each recognition uploads the buffered audio as a WAV file in a multipart
    form and expects a JSON body back; the transcript is pulled out of that
    body by ``_extract_asr_text``.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: str | None = "auto",
        output_language: str = "zh",
        hotwords: str | None = None,
        itn: bool | str | None = None,
        chunk_mode: bool | str | None = None,
        timeout: float = 30.0,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint the multipart request is POSTed to.
            model_name: Value of the ``model_name`` form field; also reported
                by the ``model`` property.
            language: Default value of the ``language`` form field. The field
                is omitted when this is falsy.
            output_language: Language code attached to returned transcripts.
            hotwords: Optional ``hotwords`` form field; omitted when falsy.
            itn: Optional ``itn`` form field; omitted when None. Booleans are
                converted with ``_form_value``.
            chunk_mode: Optional ``chunk_mode`` form field; handled like
                ``itn``.
            timeout: Total timeout, in seconds, applied to each request.
            http_session: Session to reuse. When None, one is obtained lazily
                from ``utils.http_context.http_session()``.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session
        # Optional form fields, already converted to the strings sent on the wire.
        self._extra_fields: dict[str, str] = {}
        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Model name sent with every request."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this integration."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send ``buffer`` to the ASR endpoint and return the final transcript.

        Args:
            buffer: Audio to transcribe; combined into a single WAV payload.
            language: Per-call override of the ``language`` form field. Falls
                back to the language given at construction when not provided.
            conn_options: Supplies the socket connect timeout.

        Returns:
            A FINAL_TRANSCRIPT event with one alternative tagged with the
            configured output language.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: Transport failure, a body that cannot be
                decoded as JSON, or an empty transcript.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)
        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)
        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )
                try:
                    payload = await resp.json()
                except ValueError as e:
                    # json.JSONDecodeError and UnicodeDecodeError are ValueErrors,
                    # not aiohttp.ClientError, so the handlers below would let
                    # them escape as raw exceptions. Surface them as API errors.
                    raise APIConnectionError(
                        "ASR blackbox returned a response that is not valid JSON"
                    ) from e
                logger.info("ASR blackbox raw result: %s", payload)
                text = _extract_asr_text(payload)
                if not text:
                    raise APIConnectionError("ASR blackbox returned an empty transcript")
                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
def _extract_asr_text(payload: dict[str, Any]) -> str:
text = payload.get("text")
if isinstance(text, str):
return text.strip()
result = payload.get("result")
if isinstance(result, list) and result:
first = result[0]
if isinstance(first, dict):
for key in ("clean_text", "text", "raw_text"):
value = first.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(first, str):
return first.strip()
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: bool | str) -> str:
if isinstance(value, bool):
return str(value).lower()
return value
class BoundedStreamAdapter(stt.STT):
    """Streaming facade over a non-streaming STT, segmented by a VAD.

    Streams created here cut the incoming audio into speech segments and hand
    each segment to the wrapped STT. A segment is also cut once it reaches
    ``max_speech_duration`` seconds, and each segment starts with up to
    ``pre_speech_duration`` seconds of audio received before speech began.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        """Wrap ``stt`` so it can be used through the streaming interface.

        Args:
            stt: Recognizer that transcribes each segment.
            vad: Voice activity detector used to find segment boundaries.
            max_speech_duration: Longest segment, in seconds, before a cut is
                forced; None disables the limit.
            pre_speech_duration: Seconds of pre-speech audio kept per segment.
        """
        # `stt` is the parameter inside this method, so the capabilities class
        # has to come from the module-level STT_CAPABILITIES alias.
        capabilities = STT_CAPABILITIES(
            streaming=True,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)
        self._stt = stt
        self._vad = vad
        self._pre_speech_duration = pre_speech_duration
        self._max_speech_duration = max_speech_duration
        # Re-emit the wrapped recognizer's metrics from this adapter.
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The recognizer this adapter delegates to."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name reported by the wrapped recognizer."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider name reported by the wrapped recognizer."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition to the wrapped recognizer."""
        event = await self._stt.recognize(
            buffer=buffer,
            language=language,
            conn_options=conn_options,
        )
        return event

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a VAD-segmented stream bound to this adapter's settings."""
        wrapper = _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )
        return wrapper

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Forward a metrics event from the wrapped recognizer unchanged."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Stop forwarding metrics from the wrapped recognizer."""
        self._stt.off("metrics_collected", self._on_metrics_collected)
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Recognize stream that segments audio with a VAD and bounds segment length.

    Incoming frames are pushed to a VAD stream. Between START_OF_SPEECH and
    END_OF_SPEECH the frames are collected into a segment, which is queued for
    the wrapped STT. A segment is also cut and queued as soon as it reaches
    ``max_speech_duration`` seconds, so long utterances are transcribed in
    pieces instead of as one large request.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        """Store the collaborators and limits used by ``_run``.

        Args:
            adapter: Adapter that owns this stream.
            vad: Detector that provides speech start/end events.
            wrapped_stt: Recognizer called once per queued segment.
            language: Language forwarded to the wrapped recognizer.
            conn_options: Connect options forwarded to the wrapped recognizer.
                The stream itself is created with
                DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS.
            max_speech_duration: Segment length, in seconds, at which a cut is
                forced; None disables forced cuts.
            pre_speech_duration: Seconds of audio kept from before speech start.
        """
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        """Consume the event iterator without recording anything."""
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Drive the VAD, collect speech segments and transcribe them in order.

        Three tasks cooperate:
          * ``_forward_input`` pushes frames to the VAD and accumulates them,
          * ``_recognize_from_vad`` reacts to VAD start/end events,
          * ``_recognize_worker`` transcribes queued segments one at a time.
        """
        vad_stream = self._vad.stream()
        # Guards the segment state shared by _forward_input and _recognize_from_vad.
        lock = asyncio.Lock()
        # Segments awaiting recognition; None tells the worker to stop.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        # Most recent frames received while no speech is active.
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            """Length of ``frame`` in seconds."""
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            """Add ``frame`` to the pre-roll and drop the oldest frames over the limit."""
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)
            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            """Emit END_OF_SPEECH and queue ``frames``; empty segments are ignored."""
            if not frames:
                return
            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )
            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            """Transcribe queued segments sequentially until the None sentinel."""
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return
                try:
                    # Merging is part of the best-effort region: a merge error
                    # must skip this segment, not terminate the worker and
                    # leave every later segment sitting in the queue.
                    merged_frames = utils.merge_frames(frames)
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # One failed segment is logged and dropped so that the
                    # stream keeps running.
                    logger.exception("ASR segment recognition failed")
                    continue
                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            """Push input to the VAD and accumulate it as segment or pre-roll audio."""
            nonlocal segment_duration, segment_frames, speech_active
            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue
                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []
                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Cut the segment here; speech_active stays set so
                            # the following frames open the next segment.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)
                # Queue outside the lock so the VAD task is not held up.
                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)
            vad_stream.end_input()
            # Input ended in the middle of speech: flush what was collected.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            """Open a segment on VAD speech start and queue it on speech end."""
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            speech_active = True
                            # Seed the segment with the audio buffered before
                            # the start event arrived.
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue
                if event.type != VADEventType.END_OF_SPEECH:
                    continue
                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Both producers are done: let the worker drain the queue and exit.
            await recognize_queue.put(None)
            await worker_task
        finally:
            await utils.aio.cancel_and_wait(*tasks, worker_task)
            await vad_stream.aclose()