feat: supported different models
This commit is contained in:
225
custom_agent.py
225
custom_agent.py
@ -1,78 +1,35 @@
|
||||
import logging
|
||||
import os
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from livekit import rtc
|
||||
|
||||
from asr import BlackboxSTT
|
||||
from livekit.agents import (
|
||||
Agent,
|
||||
AgentServer,
|
||||
AgentSession,
|
||||
APIConnectOptions,
|
||||
JobContext,
|
||||
JobProcess,
|
||||
LanguageCode,
|
||||
MetricsCollectedEvent,
|
||||
NOT_GIVEN,
|
||||
NotGivenOr,
|
||||
RecordingOptions,
|
||||
TurnHandlingOptions,
|
||||
cli,
|
||||
metrics,
|
||||
room_io,
|
||||
stt,
|
||||
text_transforms,
|
||||
utils,
|
||||
)
|
||||
from livekit.plugins import silero, openai
|
||||
from livekit.plugins import openai, silero
|
||||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||
from tts import BlackboxTTS
|
||||
|
||||
logger = logging.getLogger("custom-agent")
|
||||
|
||||
load_dotenv()
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||
|
||||
class SenseVoiceSTT(stt.STT):
|
||||
def __init__(self, url: str):
|
||||
super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False))
|
||||
self._url = url
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return "sensevoice"
|
||||
|
||||
async def _recognize_impl(
|
||||
self,
|
||||
buffer: utils.AudioBuffer,
|
||||
*,
|
||||
language: NotGivenOr[str] = NOT_GIVEN,
|
||||
conn_options: APIConnectOptions,
|
||||
) -> stt.SpeechEvent:
|
||||
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
|
||||
data.add_field('model_name', 'sensevoice')
|
||||
|
||||
lang = language if language is not NOT_GIVEN else 'auto'
|
||||
data.add_field('language', lang)
|
||||
|
||||
try:
|
||||
async with session.post(self._url, data=data, timeout=30) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"ASR server returned status {resp.status}")
|
||||
|
||||
result = await resp.json()
|
||||
if not result.get("result"):
|
||||
return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)
|
||||
|
||||
text = result["result"][0].get("clean_text", "")
|
||||
logger.info(f"SenseVoice ASR Result: {text}")
|
||||
return stt.SpeechEvent(
|
||||
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||
alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SenseVoice ASR error: {e}")
|
||||
raise
|
||||
|
||||
class CustomAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
@ -83,63 +40,86 @@ class CustomAgent(Agent):
|
||||
)
|
||||
|
||||
async def on_enter(self) -> None:
|
||||
self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||
pass
|
||||
|
||||
server = AgentServer()
|
||||
|
||||
|
||||
def prewarm(proc: JobProcess) -> None:
|
||||
# Load Silero VAD as requested
|
||||
proc.userdata["vad"] = silero.VAD.load()
|
||||
|
||||
|
||||
server.setup_fnc = prewarm
|
||||
|
||||
@server.rtc_session(agent_name="my-agent")
|
||||
|
||||
@server.rtc_session(agent_name=AGENT_NAME)
|
||||
async def entrypoint(ctx: JobContext) -> None:
|
||||
ctx.log_context_fields = {
|
||||
"room": ctx.room.name,
|
||||
}
|
||||
|
||||
# Configuration for custom local endpoints
|
||||
# These can be set in your .env file
|
||||
# Configuration for custom local endpoints. These can be set in your .env file.
|
||||
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
|
||||
|
||||
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
||||
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||
|
||||
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
|
||||
|
||||
VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
|
||||
PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")
|
||||
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||
if not MINIMAX_API_KEY:
|
||||
raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
|
||||
# Initialize SenseVoice STT and wrap with StreamAdapter
|
||||
sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
|
||||
stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])
|
||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
|
||||
)
|
||||
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||
|
||||
blackbox_stt = BlackboxSTT(
|
||||
url=ASR_URL,
|
||||
model_name=ASR_MODEL,
|
||||
language=ASR_LANGUAGE,
|
||||
output_language=ASR_OUTPUT_LANGUAGE,
|
||||
hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
|
||||
itn=os.getenv("CUSTOM_ASR_ITN"),
|
||||
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
||||
)
|
||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||
|
||||
import httpx
|
||||
from openai import AsyncClient as OpenAIAsyncClient
|
||||
|
||||
# Create a custom HTTP client that disables SSL verification
|
||||
http_client = httpx.AsyncClient(verify=False)
|
||||
|
||||
|
||||
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||
openai_client = OpenAIAsyncClient(
|
||||
api_key="sk-orez64WkG1NkfksB5j_hGA",
|
||||
api_key=MINIMAX_API_KEY,
|
||||
base_url=MINIMAX_BASE_URL,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
from tts_voxcpm import VoxCPMTTS
|
||||
|
||||
session: AgentSession = AgentSession(
|
||||
# 1. Custom SenseVoice ASR (STT) with StreamAdapter
|
||||
# 1. Custom ASR blackbox with StreamAdapter
|
||||
stt=stt_stream,
|
||||
# 2. Minimax LLM - Using OpenAI plugin with local base_url
|
||||
llm=openai.LLM(
|
||||
model=MINIMAX_MODEL,
|
||||
client=openai_client,
|
||||
),
|
||||
# 3. VoxCPM TTS - Custom implementation for blackbox API
|
||||
tts=VoxCPMTTS(
|
||||
url=VOXCPM_URL,
|
||||
prompt_wav_path=PROMPT_WAV,
|
||||
# 3. TTS blackbox
|
||||
tts=BlackboxTTS(
|
||||
url=TTS_URL,
|
||||
model_name=TTS_MODEL,
|
||||
params=_tts_params_from_env(TTS_MODEL),
|
||||
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
|
||||
sample_rate=TTS_SAMPLE_RATE,
|
||||
num_channels=TTS_NUM_CHANNELS,
|
||||
),
|
||||
# 4. Silero VAD
|
||||
vad=ctx.proc.userdata["vad"],
|
||||
@ -150,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
"false_interruption_timeout": 1.0,
|
||||
},
|
||||
),
|
||||
preemptive_generation=True,
|
||||
preemptive_generation=False,
|
||||
aec_warmup_duration=3.0,
|
||||
tts_text_transforms=[
|
||||
"filter_emoji",
|
||||
@ -165,7 +145,102 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
await session.start(
|
||||
agent=CustomAgent(),
|
||||
room=ctx.room,
|
||||
room_options=room_io.RoomOptions(
|
||||
audio_output=room_io.AudioOutputOptions(
|
||||
sample_rate=OUTPUT_SAMPLE_RATE,
|
||||
num_channels=TTS_NUM_CHANNELS,
|
||||
),
|
||||
),
|
||||
record=_recording_options_from_env(),
|
||||
)
|
||||
|
||||
|
||||
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||
params: dict[str, str] = {}
|
||||
model_name = model_name.lower()
|
||||
|
||||
if model_name == "voxcpmtts":
|
||||
params.update(
|
||||
{
|
||||
"streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||
"prompt_text": os.getenv(
|
||||
"CUSTOM_TTS_PROMPT_TEXT",
|
||||
os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
|
||||
),
|
||||
"cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
|
||||
"inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
|
||||
"do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
|
||||
"denoise": os.getenv("VOXCPM_DENOISE", "true"),
|
||||
"retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
|
||||
"retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
|
||||
"retry_badcase_ratio_threshold": os.getenv(
|
||||
"VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
|
||||
),
|
||||
}
|
||||
)
|
||||
elif model_name == "melotts":
|
||||
params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0")
|
||||
elif model_name == "cosyvoicetts":
|
||||
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
||||
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
||||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
||||
elif model_name == "sovitstts":
|
||||
params.update(
|
||||
{
|
||||
"text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
|
||||
"prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
|
||||
"text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
|
||||
"batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
|
||||
"media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
|
||||
"streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||
}
|
||||
)
|
||||
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
||||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
|
||||
if value:
|
||||
params[key] = value
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
value = os.getenv(name)
|
||||
if not value:
|
||||
return default
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
|
||||
return default
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
|
||||
logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
|
||||
return default
|
||||
|
||||
|
||||
def _recording_options_from_env() -> RecordingOptions:
|
||||
return RecordingOptions(
|
||||
audio=_env_bool("CUSTOM_RECORD_AUDIO", False),
|
||||
traces=_env_bool("CUSTOM_RECORD_TRACES", False),
|
||||
logs=_env_bool("CUSTOM_RECORD_LOGS", False),
|
||||
transcript=_env_bool("CUSTOM_RECORD_TRANSCRIPT", False),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli.run_app(server)
|
||||
|
||||
Reference in New Issue
Block a user