Files
livekit_agents/custom_agent.py
0Xiao0 746053fd58 fix
2026-05-13 15:35:04 +08:00

252 lines
8.4 KiB
Python

import logging
import os
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from asr import BlackboxSTT
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
JobContext,
JobProcess,
MetricsCollectedEvent,
RecordingOptions,
TurnHandlingOptions,
cli,
metrics,
room_io,
stt,
)
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
from tts import BlackboxTTS
# Logger shared by the helpers and the entrypoint in this module.
logger = logging.getLogger("custom-agent")
# Settings live in a ".env" file that sits next to this source file.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
# Load the .env file first: every os.getenv() call below (and inside the
# functions of this module) relies on its values already being in the environment.
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
# Agent name handed to @server.rtc_session; empty string when the variable is unset.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
class CustomAgent(Agent):
    """Agent persona ("Kelly") that ``entrypoint`` starts for each session."""

    def __init__(self) -> None:
        """Create the agent with its fixed system instructions."""
        super().__init__(
            # Adjacent string literals are joined with nothing in between, so every
            # sentence except the last carries its own trailing space. Without it
            # the prompt reads "...assistant.Keep ... friendly.You are ...".
            instructions=(
                "Your name is Kelly, built by LiveKit. You are a helpful assistant. "
                "Keep your responses concise and friendly. "
                "You are interacting with the user via a local ASR and LLM pipeline."
            ),
        )

    async def on_enter(self) -> None:
        """Entry hook; intentionally does nothing."""
        # The spoken greeting is disabled. To greet the user on entry, restore:
        # self.session.generate_reply(instructions="greet the user and introduce yourself")
# Worker server object; the session entrypoint is registered on it by decorator.
server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Load the Silero VAD model and store it under ``proc.userdata["vad"]``.

    ``entrypoint`` later reads the model back from ``ctx.proc.userdata["vad"]``.
    """
    vad_model = silero.VAD.load()
    proc.userdata["vad"] = vad_model


# Install the loader as the server's setup function.
server.setup_fnc = prewarm
@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Build the STT -> LLM -> TTS pipeline from the environment and start a session.

    All endpoints and models are configured through ``CUSTOM_*`` environment
    variables (with ``VOXCPM_*`` fallbacks for the TTS settings). The session is
    started in ``ctx.room`` with a ``CustomAgent``.

    Args:
        ctx: Job context supplying the room and the per-process ``userdata``
            (the VAD model stored by ``prewarm``).

    Raises:
        RuntimeError: If ``CUSTOM_LLM_API_KEY`` is unset or empty.
    """
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    asr_url = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    asr_model = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    asr_language = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    asr_output_language = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")

    llm_base_url = os.getenv("CUSTOM_LLM_BASE_URL")
    llm_model = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
    if not llm_api_key:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")

    # CUSTOM_* values win; the VOXCPM_* names are accepted as fallbacks.
    tts_url = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
    )
    tts_model = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    tts_sample_rate = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    tts_num_channels = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    # The room output rate follows the TTS rate unless explicitly overridden.
    output_sample_rate = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", tts_sample_rate)

    vad = ctx.proc.userdata["vad"]

    blackbox_stt = BlackboxSTT(
        url=asr_url,
        model_name=asr_model,
        language=asr_language,
        output_language=asr_output_language,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    # The blackbox STT is wrapped in a StreamAdapter together with the VAD.
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=vad)

    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    # NOTE(security): certificate verification is OFF unless CUSTOM_LLM_VERIFY_SSL
    # enables it, while the API key is sent over this connection. The default is
    # kept for compatibility with existing deployments, but it is made visible.
    verify_ssl = _env_bool("CUSTOM_LLM_VERIFY_SSL", False)
    if not verify_ssl:
        logger.warning(
            "TLS certificate verification is disabled for the LLM endpoint; "
            "set CUSTOM_LLM_VERIFY_SSL=true to enable it"
        )
    http_client = httpx.AsyncClient(verify=verify_ssl)

    async def _close_http_client() -> None:
        await http_client.aclose()

    # One HTTP client is created per job; close it at job shutdown so its
    # connection pool is not leaked.
    ctx.add_shutdown_callback(_close_http_client)

    # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
    # An unset or empty value becomes None, which leaves the client's default
    # base URL in place (same as omitting the argument).
    openai_client = OpenAIAsyncClient(
        api_key=llm_api_key,
        base_url=llm_base_url or None,
        http_client=http_client,
    )

    session = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
        llm=openai.LLM(
            model=llm_model,
            client=openai_client,
        ),
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=tts_url,
            model_name=tts_model,
            params=_tts_params_from_env(tts_model),
            prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
            sample_rate=tts_sample_rate,
            num_channels=tts_num_channels,
        ),
        # 4. Silero VAD
        vad=vad,
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=False,
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    await session.start(
        agent=CustomAgent(),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=output_sample_rate,
                num_channels=tts_num_channels,
            ),
        ),
        record=_recording_options_from_env(),
    )
def _tts_params_from_env(model_name: str) -> dict[str, str]:
    """Collect the backend-specific TTS request parameters from the environment.

    The model name is matched case-insensitively against ``voxcpmtts``,
    ``melotts``, ``cosyvoicetts`` and ``sovitstts``. Any other name yields an
    empty dict. All values are returned as the raw strings read from the
    environment (or their string defaults).

    NOTE(review): for CosyVoice the ``model`` parameter is filled from
    ``CUSTOM_TTS_MODE`` -- confirm the key/variable naming against the TTS service.
    """
    backend = model_name.lower()

    if backend == "voxcpmtts":
        # CUSTOM_TTS_PROMPT_TEXT wins over VOXCPM_PROMPT_TEXT, which wins over the default.
        prompt_fallback = os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅")
        return {
            "streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
            "prompt_text": os.getenv("CUSTOM_TTS_PROMPT_TEXT", prompt_fallback),
            "cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
            "inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
            "do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
            "denoise": os.getenv("VOXCPM_DENOISE", "true"),
            "retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
            "retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
            "retry_badcase_ratio_threshold": os.getenv(
                "VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
            ),
        }

    if backend == "melotts":
        return {"speed": os.getenv("CUSTOM_TTS_SPEED", "1.0")}

    if backend == "cosyvoicetts":
        # Every CosyVoice parameter is optional: only non-empty variables are sent.
        cosy: dict[str, str] = {}
        optional_vars = (
            ("spk_id", "CUSTOM_TTS_SPK_ID"),
            ("model", "CUSTOM_TTS_MODE"),
            ("prompt_text", "CUSTOM_TTS_PROMPT_TEXT"),
            ("instruct_text", "CUSTOM_TTS_INSTRUCT_TEXT"),
        )
        for key, variable in optional_vars:
            _set_if_present(cosy, key, os.getenv(variable))
        return cosy

    if backend == "sovitstts":
        sovits: dict[str, str] = {
            "text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
            "prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
            "text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
            "batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
            "media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
            "streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
        }
        # The reference audio and its transcript are added only when configured.
        _set_if_present(sovits, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
        _set_if_present(sovits, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
        return sovits

    # Unrecognised backend: no extra parameters.
    return {}
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
    """Store ``value`` under ``key`` in ``params`` unless it is ``None`` or empty.

    ``params`` is modified in place; nothing is returned.
    """
    if not value:
        return
    params[key] = value
def _env_int(name: str, default: int) -> int:
    """Read environment variable ``name`` as an integer.

    Returns ``default`` when the variable is unset or empty. A value that
    ``int()`` rejects is reported with a warning and also yields ``default``.
    """
    raw = os.getenv(name)
    if raw:
        try:
            return int(raw)
        except ValueError:
            logger.warning("Invalid integer for %s=%r, using %s", name, raw, default)
    return default
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Matching ignores case and surrounding whitespace: ``1``/``true``/``yes``/``on``
    give ``True`` and ``0``/``false``/``no``/``off`` give ``False``.

    Returns ``default`` when the variable is unset or blank. Any other value is
    reported with a warning and also yields ``default``.
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    if not normalized:
        # A blank entry such as ``NAME=`` in the .env file means "not configured".
        # Treat it like an unset variable, as _env_int does, instead of logging
        # a misleading "Invalid boolean" warning.
        return default
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
    return default
def _recording_options_from_env() -> RecordingOptions:
    """Build the session recording options from the ``CUSTOM_RECORD_*`` flags.

    Each of ``audio``, ``traces``, ``logs`` and ``transcript`` is read with
    ``_env_bool`` and defaults to ``False``.
    """
    env_flags = {
        "audio": "CUSTOM_RECORD_AUDIO",
        "traces": "CUSTOM_RECORD_TRACES",
        "logs": "CUSTOM_RECORD_LOGS",
        "transcript": "CUSTOM_RECORD_TRANSCRIPT",
    }
    enabled = {option: _env_bool(variable, False) for option, variable in env_flags.items()}
    return RecordingOptions(**enabled)
# When executed as a script, hand the server to the livekit.agents CLI runner.
if __name__ == "__main__":
    cli.run_app(server)