feat: supported different models
This commit is contained in:
154
asr.py
Normal file
154
asr.py
Normal file
@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from livekit import rtc
|
||||
from livekit.agents import (
|
||||
NOT_GIVEN,
|
||||
APIConnectionError,
|
||||
APIConnectOptions,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
LanguageCode,
|
||||
NotGivenOr,
|
||||
stt,
|
||||
utils,
|
||||
)
|
||||
from livekit.agents.utils import is_given
|
||||
|
||||
logger = logging.getLogger("blackbox-asr")
|
||||
|
||||
|
||||
class BlackboxSTT(stt.STT):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
model_name: str = "sensevoice",
|
||||
language: Optional[str] = "auto",
|
||||
output_language: str = "zh",
|
||||
hotwords: Optional[str] = None,
|
||||
itn: Optional[Union[bool, str]] = None,
|
||||
chunk_mode: Optional[Union[bool, str]] = None,
|
||||
timeout: float = 30.0,
|
||||
http_session: Optional[aiohttp.ClientSession] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
capabilities=stt.STTCapabilities(
|
||||
streaming=False,
|
||||
interim_results=False,
|
||||
diarization=False,
|
||||
)
|
||||
)
|
||||
self._url = url
|
||||
self._model_name = model_name
|
||||
self._language = language
|
||||
self._output_language = output_language
|
||||
self._timeout = timeout
|
||||
self._http_session = http_session
|
||||
self._extra_fields: dict[str, str] = {}
|
||||
|
||||
if hotwords:
|
||||
self._extra_fields["hotwords"] = hotwords
|
||||
if itn is not None:
|
||||
self._extra_fields["itn"] = _form_value(itn)
|
||||
if chunk_mode is not None:
|
||||
self._extra_fields["chunk_mode"] = _form_value(chunk_mode)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return "asr-blackbox"
|
||||
|
||||
def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
self._http_session = utils.http_context.http_session()
|
||||
return self._http_session
|
||||
|
||||
async def _recognize_impl(
|
||||
self,
|
||||
buffer: utils.AudioBuffer,
|
||||
*,
|
||||
language: NotGivenOr[str] = NOT_GIVEN,
|
||||
conn_options: APIConnectOptions,
|
||||
) -> stt.SpeechEvent:
|
||||
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
|
||||
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
|
||||
form.add_field("model_name", self._model_name)
|
||||
|
||||
resolved_language = language if is_given(language) else self._language
|
||||
if resolved_language:
|
||||
form.add_field("language", resolved_language)
|
||||
for key, value in self._extra_fields.items():
|
||||
form.add_field(key, value)
|
||||
|
||||
try:
|
||||
async with self._ensure_session().post(
|
||||
self._url,
|
||||
data=form,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self._timeout,
|
||||
sock_connect=conn_options.timeout,
|
||||
),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
raise APIStatusError(
|
||||
message=f"ASR blackbox error: {error_text}",
|
||||
status_code=resp.status,
|
||||
request_id=None,
|
||||
body=error_text,
|
||||
)
|
||||
|
||||
payload = await resp.json()
|
||||
logger.info("ASR blackbox raw result: %s", payload)
|
||||
text = _extract_asr_text(payload)
|
||||
if not text:
|
||||
raise APIConnectionError("ASR blackbox returned an empty transcript")
|
||||
|
||||
logger.info("ASR blackbox result: %s", text)
|
||||
return stt.SpeechEvent(
|
||||
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||
alternatives=[
|
||||
stt.SpeechData(
|
||||
text=text,
|
||||
language=LanguageCode(self._output_language),
|
||||
)
|
||||
],
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise APITimeoutError("ASR blackbox request timed out") from e
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
|
||||
|
||||
|
||||
def _extract_asr_text(payload: dict[str, Any]) -> str:
|
||||
text = payload.get("text")
|
||||
if isinstance(text, str):
|
||||
return text.strip()
|
||||
|
||||
result = payload.get("result")
|
||||
if isinstance(result, list) and result:
|
||||
first = result[0]
|
||||
if isinstance(first, dict):
|
||||
for key in ("clean_text", "text", "raw_text"):
|
||||
value = first.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(first, str):
|
||||
return first.strip()
|
||||
|
||||
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||
|
||||
|
||||
def _form_value(value: Union[bool, str]) -> str:
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
return value
|
||||
Reference in New Issue
Block a user