401 lines
14 KiB
Python
401 lines
14 KiB
Python
import asyncio
|
|
import logging
|
|
from collections import deque
|
|
from collections.abc import AsyncIterable
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
|
|
from livekit import rtc
|
|
from livekit.agents import (
|
|
DEFAULT_API_CONNECT_OPTIONS,
|
|
NOT_GIVEN,
|
|
APIConnectionError,
|
|
APIConnectOptions,
|
|
APIStatusError,
|
|
APITimeoutError,
|
|
LanguageCode,
|
|
NotGivenOr,
|
|
stt,
|
|
utils,
|
|
)
|
|
from livekit.agents.utils import is_given
|
|
from livekit.agents.vad import VAD, VADEventType
|
|
|
|
logger = logging.getLogger("blackbox-asr")

# Connection options for the outer RecognizeStream created by the adapter.
# Retries are disabled at this level (max_retry=0); the caller-supplied
# conn_options are instead handed to every recognize() call on the wrapped STT.
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
    timeout=DEFAULT_API_CONNECT_OPTIONS.timeout,
    max_retry=0,
)

# Alias for code whose local `stt` parameter hides the `stt` module
# (BoundedStreamAdapter.__init__ takes a parameter named `stt`).
STT_CAPABILITIES = stt.STTCapabilities
|
|
|
|
|
|
class BlackboxSTT(stt.STT):
    """Batch (non-streaming) STT backed by an HTTP "ASR blackbox" endpoint.

    Each recognition request POSTs the audio as a WAV file inside a multipart
    form to ``url`` and turns the JSON reply into a single FINAL_TRANSCRIPT
    event. The advertised capabilities have ``streaming=False``, so streaming
    use needs an adapter around this class.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: str | None = "auto",
        output_language: str = "zh",
        hotwords: str | None = None,
        itn: bool | str | None = None,
        chunk_mode: bool | str | None = None,
        timeout: float = 30.0,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Sent as the ``model_name`` form field and reported by
                the ``model`` property.
            language: Default for the ``language`` form field. A falsy value
                (``None`` or ``""``) omits the field. A per-call ``language``
                overrides it.
            output_language: Language code attached to every returned
                transcript. It is not read from the service response.
            hotwords: Sent as the ``hotwords`` field when non-empty.
            itn: Sent as the ``itn`` field when not ``None``. Booleans are
                serialized as ``"true"``/``"false"``.
            chunk_mode: Sent as the ``chunk_mode`` field, same rules as ``itn``.
            timeout: Total per-request timeout in seconds.
            http_session: Session to use. When ``None``, one is obtained
                lazily from ``utils.http_context.http_session()``.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session
        # Optional form fields, pre-serialized once and sent on every request.
        self._extra_fields: dict[str, str] = {}

        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Model name sent to the service."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Provider identifier for this STT."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """POST ``buffer`` to the service and return its transcript.

        Args:
            buffer: Audio to transcribe. It is combined and encoded as WAV.
            language: Overrides the configured ``language`` form field.
            conn_options: ``conn_options.timeout`` bounds the socket connect
                phase. The whole request is bounded by the ``timeout`` given to
                the constructor.

        Returns:
            A FINAL_TRANSCRIPT event with one alternative whose language is the
            configured ``output_language``.

        Raises:
            APIStatusError: The service answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: Transport error (retryable), a 200 body that is
                not valid JSON, or an empty transcript. The last two are marked
                non-retryable because resending the same audio should not
                change the outcome. Errors raised by ``_extract_asr_text`` for
                unrecognized response shapes also propagate.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)

        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    payload = await resp.json()
                except ValueError as e:
                    # A JSON decode failure is a ValueError, which neither
                    # handler below would catch. Surface it as an API error.
                    raise APIConnectionError(
                        "ASR blackbox returned an invalid JSON body", retryable=False
                    ) from e
                logger.info("ASR blackbox raw result: %s", payload)

                text = _extract_asr_text(payload)
                if not text:
                    # Retrying the same audio will not produce a transcript.
                    raise APIConnectionError(
                        "ASR blackbox returned an empty transcript", retryable=False
                    )

                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        # TimeoutError is caught first so timeouts that are also ClientErrors
        # are reported as timeouts.
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
|
|
|
|
|
|
def _extract_asr_text(payload: dict[str, Any]) -> str:
|
|
text = payload.get("text")
|
|
if isinstance(text, str):
|
|
return text.strip()
|
|
|
|
result = payload.get("result")
|
|
if isinstance(result, list) and result:
|
|
first = result[0]
|
|
if isinstance(first, dict):
|
|
for key in ("clean_text", "text", "raw_text"):
|
|
value = first.get(key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value.strip()
|
|
if isinstance(first, str):
|
|
return first.strip()
|
|
|
|
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
|
|
|
|
|
def _form_value(value: bool | str) -> str:
|
|
if isinstance(value, bool):
|
|
return str(value).lower()
|
|
return value
|
|
|
|
|
|
class BoundedStreamAdapter(stt.STT):
    """Present a non-streaming STT as a streaming one, segmenting audio with VAD.

    ``stream()`` returns a ``_BoundedStreamAdapterWrapper``. That wrapper cuts
    incoming audio into speech segments and transcribes each one with the
    wrapped STT's ``recognize``.

    * ``max_speech_duration`` (seconds, ``None`` for no limit) caps how long a
      segment may grow before it is sent anyway.
    * ``pre_speech_duration`` (seconds) of audio preceding the VAD start event
      is prepended to each segment.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        # Inside this method the `stt` parameter hides the `stt` module, which
        # is why the capabilities class is reached through STT_CAPABILITIES.
        capabilities = STT_CAPABILITIES(
            streaming=True,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)
        self._vad = vad
        self._stt = stt
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration
        # Re-emit the wrapped STT's metrics so listeners on the adapter see them.
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The underlying non-streaming STT."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name of the wrapped STT."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider of the wrapped STT."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition directly to the wrapped STT."""
        wrapped = self._stt
        return await wrapped.recognize(
            buffer=buffer,
            language=language,
            conn_options=conn_options,
        )

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Open a VAD-segmented recognition stream.

        ``language`` and ``conn_options`` are forwarded to every ``recognize``
        call the stream makes on the wrapped STT.
        """
        return _BoundedStreamAdapterWrapper(
            self,
            wrapped_stt=self._stt,
            vad=self._vad,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
            language=language,
            conn_options=conn_options,
        )

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Forward a wrapped-STT metrics event as this adapter's own."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Stop forwarding metrics. The wrapped STT itself is not closed here."""
        self._stt.off("metrics_collected", self._on_metrics_collected)
|
|
|
|
|
|
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """RecognizeStream that turns VAD-delimited audio into batch ``recognize`` calls.

    Input frames go to a VAD stream and are also buffered locally:

    * While no speech is active, the last ``pre_speech_duration`` seconds are
      kept as pre-roll.
    * When the VAD reports START_OF_SPEECH, the pre-roll seeds a new segment.
      The segment grows until END_OF_SPEECH, end of input, or
      ``max_speech_duration`` seconds (``None`` disables the cap).
    * Finished segments are recognized one at a time by a background worker,
      using the wrapped STT and the caller's ``conn_options``.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        # The outer stream uses the no-retry options; `conn_options` is kept
        # for the per-segment recognize() calls on the wrapped STT.
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        # Drain the events without acting on them. This overrides the
        # base-class hook (its behavior is not visible here), presumably
        # because metrics come from the wrapped STT, which
        # BoundedStreamAdapter re-emits.
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Drive the VAD, segment buffering and recognition until input ends.

        Three coroutines share the closure state below:

        * ``_forward_input`` pushes every input frame into the VAD and buffers
          it, either as pre-roll (no speech active) or into the current
          segment. It force-closes the segment once it reaches
          ``max_speech_duration``.
        * ``_recognize_from_vad`` opens a segment on START_OF_SPEECH (seeded
          with the pre-roll) and hands it off on END_OF_SPEECH.
        * ``_recognize_worker`` recognizes queued segments in order and emits
          their FINAL_TRANSCRIPT events.
        """
        vad_stream = self._vad.stream()
        # Guards the segment/pre-roll state shared by _forward_input and
        # _recognize_from_vad.
        lock = asyncio.Lock()
        # Finished segments awaiting recognition. ``None`` tells the worker to exit.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()

        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            """Return the frame length in seconds."""
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            """Add ``frame`` to the pre-roll, dropping the oldest frames past the limit."""
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)

            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            """Emit END_OF_SPEECH and queue ``frames`` for recognition (no-op if empty)."""
            if not frames:
                return

            if forced:
                # Log the actual length of the cut segment. It includes any
                # pre-roll and can exceed the configured cap.
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    sum(_frame_duration(f) for f in frames),
                )

            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            """Recognize queued segments sequentially until the ``None`` sentinel."""
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return

                try:
                    # Merging is inside the try so that a failure on one
                    # segment only drops that segment. Previously it stopped
                    # this single worker, and with it every later transcript
                    # in the stream.
                    merged_frames = utils.merge_frames(frames)
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # Best effort: log the failed segment and move on.
                    logger.exception("ASR segment recognition failed")
                    continue

                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue

                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            """Feed input to the VAD and buffer it. Flush the open segment at end of input."""
            nonlocal segment_duration, segment_frames, speech_active

            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue

                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []

                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Cut the segment but leave speech_active set: the
                            # rest of the utterance starts a new segment.
                            # NOTE(review): no new START_OF_SPEECH is emitted
                            # for that continuation, so consumers see two
                            # END_OF_SPEECH events for one START. Confirm
                            # downstream tolerates this.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)

                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)

            vad_stream.end_input()

            # Input is exhausted: send whatever speech is still being collected.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                speech_active = False

            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            """Open and close segments in response to VAD events."""
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active

            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            # The pre-roll compensates for audio that arrived
                            # before the VAD reported the start.
                            speech_active = True
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue

                if event.type != VADEventType.END_OF_SPEECH:
                    continue

                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False

                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Both producers are done: let the worker drain the queue, then stop.
            await recognize_queue.put(None)
            await worker_task
        finally:
            await utils.aio.cancel_and_wait(*tasks, worker_task)
            await vad_stream.aclose()
|