Files
livekit_agents/asr.py
2026-05-11 11:22:01 +08:00

155 lines
4.9 KiB
Python

import asyncio
import logging
from typing import Any, Optional, Union
import aiohttp
from livekit import rtc
from livekit.agents import (
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
LanguageCode,
NotGivenOr,
stt,
utils,
)
from livekit.agents.utils import is_given
# Module-level logger shared by everything in this file; records are emitted
# under the "blackbox-asr" logger name.
logger = logging.getLogger("blackbox-asr")
class BlackboxSTT(stt.STT):
    """Non-streaming STT backend that delegates recognition to an HTTP ASR service.

    Each recognition call converts the buffered audio to WAV bytes, POSTs it as
    a form upload to the configured URL and wraps the returned text in a single
    ``FINAL_TRANSCRIPT`` speech event.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: Optional[str] = "auto",
        output_language: str = "zh",
        hotwords: Optional[str] = None,
        itn: Optional[Union[bool, str]] = None,
        chunk_mode: Optional[Union[bool, str]] = None,
        timeout: float = 30.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration and the optional form fields.

        Args:
            url: Endpoint every recognition request is POSTed to.
            model_name: Sent as the ``model_name`` form field and exposed
                through the ``model`` property.
            language: Default ``language`` form field, used when a call does
                not supply its own; a falsy value omits the field.
            output_language: Language code attached to the returned transcript.
            hotwords: Sent as the ``hotwords`` form field when truthy.
            itn: Sent as the ``itn`` form field when not ``None``.
            chunk_mode: Sent as the ``chunk_mode`` form field when not ``None``.
            timeout: Total request timeout handed to ``aiohttp.ClientTimeout``.
            http_session: Session to reuse; when omitted, one is fetched from
                ``utils.http_context`` on first use.
        """
        capabilities = stt.STTCapabilities(
            streaming=False,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)

        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session

        # Optional form fields are resolved once here so that every request
        # can simply replay them. Insertion order: hotwords, itn, chunk_mode.
        extras: dict[str, str] = {}
        if hotwords:
            extras["hotwords"] = hotwords
        for field_name, raw in (("itn", itn), ("chunk_mode", chunk_mode)):
            if raw is not None:
                extras[field_name] = _form_value(raw)
        self._extra_fields = extras

    @property
    def model(self) -> str:
        """Model name passed to the constructor."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this backend."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching and caching one on first use."""
        session = self._http_session
        if session is None:
            session = self._http_session = utils.http_context.http_session()
        return session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send ``buffer`` to the ASR service and return its transcript.

        Args:
            buffer: Audio to recognise; combined and encoded as WAV bytes.
            language: Per-call language; falls back to the constructor value
                when not given.
            conn_options: Supplies the socket-connect timeout for the request.

        Returns:
            A ``FINAL_TRANSCRIPT`` event with one alternative tagged with the
            configured output language.

        Raises:
            APIStatusError: The service answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: The request failed at the HTTP client level,
                or the transcript was empty.
        """
        wav_bytes = rtc.combine_audio_frames(buffer).to_wav_bytes()

        body = aiohttp.FormData()
        body.add_field("audio", wav_bytes, filename="audio.wav", content_type="audio/wav")
        body.add_field("model_name", self._model_name)

        if is_given(language):
            request_language = language
        else:
            request_language = self._language
        if request_language:
            body.add_field("language", request_language)

        for field_name, field_value in self._extra_fields.items():
            body.add_field(field_name, field_value)

        try:
            session = self._ensure_session()
            request_timeout = aiohttp.ClientTimeout(
                total=self._timeout,
                sock_connect=conn_options.timeout,
            )
            async with session.post(
                self._url,
                data=body,
                timeout=request_timeout,
            ) as response:
                if response.status != 200:
                    # Carry the response body in the error so callers can see
                    # what the service reported.
                    failure_body = await response.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {failure_body}",
                        status_code=response.status,
                        request_id=None,
                        body=failure_body,
                    )

                data = await response.json()
                logger.info("ASR blackbox raw result: %s", data)

                transcript = _extract_asr_text(data)
                # NOTE(review): an empty transcript is surfaced as a connection
                # error; confirm this is the intended handling for silent audio.
                if not transcript:
                    raise APIConnectionError("ASR blackbox returned an empty transcript")
                logger.info("ASR blackbox result: %s", transcript)

                alternative = stt.SpeechData(
                    text=transcript,
                    language=LanguageCode(self._output_language),
                )
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[alternative],
                )
        except asyncio.TimeoutError as exc:
            raise APITimeoutError("ASR blackbox request timed out") from exc
        except aiohttp.ClientError as exc:
            raise APIConnectionError(f"ASR blackbox connection error: {exc}") from exc
def _extract_asr_text(payload: dict[str, Any]) -> str:
text = payload.get("text")
if isinstance(text, str):
return text.strip()
result = payload.get("result")
if isinstance(result, list) and result:
first = result[0]
if isinstance(first, dict):
for key in ("clean_text", "text", "raw_text"):
value = first.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(first, str):
return first.strip()
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
if isinstance(value, bool):
return str(value).lower()
return value