155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Any, Optional, Union
|
|
|
|
import aiohttp
|
|
|
|
from livekit import rtc
|
|
from livekit.agents import (
|
|
NOT_GIVEN,
|
|
APIConnectionError,
|
|
APIConnectOptions,
|
|
APIStatusError,
|
|
APITimeoutError,
|
|
LanguageCode,
|
|
NotGivenOr,
|
|
stt,
|
|
utils,
|
|
)
|
|
from livekit.agents.utils import is_given
|
|
|
|
logger = logging.getLogger("blackbox-asr")
|
|
|
|
|
|
class BlackboxSTT(stt.STT):
    """Non-streaming speech-to-text adapter for an HTTP "blackbox" ASR endpoint.

    Each recognition call uploads the buffered audio as a WAV file in a
    multipart form POST to ``url`` and converts the JSON reply into a single
    final transcript event.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: Optional[str] = "auto",
        output_language: str = "zh",
        hotwords: Optional[str] = None,
        itn: Optional[Union[bool, str]] = None,
        chunk_mode: Optional[Union[bool, str]] = None,
        timeout: float = 30.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration; no network I/O happens here.

        Args:
            url: Endpoint the audio is POSTed to.
            model_name: Sent as the ``model_name`` form field and reported by
                the ``model`` property.
            language: Default value of the ``language`` form field. A falsy
                value (``None`` or ``""``) omits the field.
            output_language: Language code attached to the returned transcript.
            hotwords: Sent as the ``hotwords`` form field when non-empty.
            itn: Sent as the ``itn`` form field when not ``None``, after
                conversion with ``_form_value``.
            chunk_mode: Sent as the ``chunk_mode`` form field when not
                ``None``, after conversion with ``_form_value``.
            timeout: Total request timeout passed to ``aiohttp.ClientTimeout``.
            http_session: Session to reuse. When ``None``, one is obtained from
                ``utils.http_context.http_session()`` on first use.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session

        # Optional form fields, pre-rendered as strings so every request can
        # append them without re-validating.
        self._extra_fields: dict[str, str] = {}
        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Name of the ASR model requested from the endpoint."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this adapter."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Upload ``buffer`` to the ASR endpoint and return the final transcript.

        Args:
            buffer: Audio frames to transcribe; combined and encoded as WAV.
            language: Per-call override of the ``language`` form field. When
                not given, the instance default is used.
            conn_options: Its ``timeout`` is used as the socket connect timeout.

        Returns:
            A ``FINAL_TRANSCRIPT`` event with one alternative. The text is
            empty when the endpoint recognised nothing.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: The request failed at the HTTP client level, or
                the response body was not JSON in a supported shape.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)

        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    payload = await resp.json()
                except ValueError as e:
                    # A malformed JSON body surfaces as a ValueError subclass,
                    # which the aiohttp.ClientError handler below would miss.
                    raise APIConnectionError("ASR blackbox returned invalid JSON") from e

                logger.info("ASR blackbox raw result: %s", payload)

                # _extract_asr_text calls payload.get(); reject non-object JSON
                # here so it fails with an API error, not an AttributeError.
                if not isinstance(payload, dict):
                    raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")

                text = _extract_asr_text(payload)
                if text:
                    logger.info("ASR blackbox result: %s", text)
                else:
                    # Silence legitimately yields no text. Report it as an empty
                    # final transcript rather than as a connection error, which
                    # would misreport a successful request as a transport failure
                    # (and presumably trigger caller-side retries of the same
                    # audio; verify against livekit.agents).
                    logger.debug("ASR blackbox returned an empty transcript")

                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
|
|
|
|
|
|
def _extract_asr_text(payload: dict[str, Any]) -> str:
|
|
text = payload.get("text")
|
|
if isinstance(text, str):
|
|
return text.strip()
|
|
|
|
result = payload.get("result")
|
|
if isinstance(result, list) and result:
|
|
first = result[0]
|
|
if isinstance(first, dict):
|
|
for key in ("clean_text", "text", "raw_text"):
|
|
value = first.get(key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value.strip()
|
|
if isinstance(first, str):
|
|
return first.strip()
|
|
|
|
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
|
|
|
|
|
def _form_value(value: Union[bool, str]) -> str:
|
|
if isinstance(value, bool):
|
|
return str(value).lower()
|
|
return value
|