"""LiveKit voice agent wiring a custom SenseVoice STT, an OpenAI-compatible LLM, and VoxCPM TTS."""
import logging
import os

import aiohttp
from dotenv import load_dotenv

from livekit import rtc
from livekit.agents import (
    NOT_GIVEN,
    Agent,
    AgentServer,
    AgentSession,
    APIConnectOptions,
    JobContext,
    JobProcess,
    LanguageCode,
    MetricsCollectedEvent,
    NotGivenOr,
    TurnHandlingOptions,
    cli,
    metrics,
    room_io,
    stt,
    text_transforms,
    utils,
)
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
# Module-level logger for this agent.
logger = logging.getLogger("custom-agent")

# Load endpoint/model configuration from a .env file, if one is present.
load_dotenv()
class SenseVoiceSTT(stt.STT):
    """Non-streaming STT adapter that posts buffered audio to a SenseVoice ASR HTTP endpoint.

    The endpoint is expected to accept multipart form data (``audio``, ``model_name``,
    ``language``) and return JSON of the shape ``{"result": [{"clean_text": ...}]}``.
    """

    def __init__(self, url: str):
        """Create the adapter.

        Args:
            url: Base URL of the SenseVoice "blackbox" ASR service.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False)
        )
        self._url = url

    @property
    def model(self) -> str:
        return "sensevoice"

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send the whole audio buffer to the ASR server and return a final-transcript event.

        Raises:
            RuntimeError: if the server responds with a non-200 status.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        async with aiohttp.ClientSession() as session:
            data = aiohttp.FormData()
            data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
            data.add_field('model_name', 'sensevoice')

            # Fall back to server-side language auto-detection when no language was given.
            lang = language if language is not NOT_GIVEN else 'auto'
            data.add_field('language', lang)

            try:
                # aiohttp expects a ClientTimeout object; passing a bare number is deprecated.
                async with session.post(
                    self._url, data=data, timeout=aiohttp.ClientTimeout(total=30)
                ) as resp:
                    if resp.status != 200:
                        # RuntimeError instead of bare Exception; callers catching
                        # Exception (as below) still see it.
                        raise RuntimeError(f"ASR server returned status {resp.status}")

                    result = await resp.json()
                    # An empty/missing "result" list means nothing was recognized.
                    if not result.get("result"):
                        return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)

                    text = result["result"][0].get("clean_text", "")
                    # Lazy %-style args avoid formatting when INFO is disabled.
                    logger.info("SenseVoice ASR Result: %s", text)
                    return stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        # NOTE(review): the result language is hard-coded to "zh" even when a
                        # different `language` was requested — confirm this is intentional.
                        alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
                    )
            except Exception as e:
                logger.error("SenseVoice ASR error: %s", e)
                raise
class CustomAgent(Agent):
    """Voice-assistant persona ("Kelly") used by the local ASR/LLM/TTS pipeline."""

    def __init__(self) -> None:
        # Persona prompt; the sentence fragments are joined exactly as before
        # (no separating spaces) to keep the prompt byte-identical.
        persona = (
            "Your name is Kelly, built by LiveKit. You are a helpful assistant."
            "Keep your responses concise and friendly."
            "You are interacting with the user via a local ASR and LLM pipeline."
        )
        super().__init__(instructions=persona)

    async def on_enter(self) -> None:
        """Produce an initial greeting as soon as the agent joins the session."""
        self.session.generate_reply(instructions="greet the user and introduce yourself")
server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Preload the Silero VAD model once per worker process.

    The loaded model is stashed in ``proc.userdata["vad"]`` so each job can
    reuse it without reloading.
    """
    vad_model = silero.VAD.load()
    proc.userdata["vad"] = vad_model


# Run `prewarm` whenever a new job process starts.
server.setup_fnc = prewarm
@server.rtc_session(agent_name="my-agent")
async def entrypoint(ctx: JobContext) -> None:
    """Wire the full voice pipeline (SenseVoice STT -> LLM -> VoxCPM TTS) for one room.

    Args:
        ctx: Job context supplying the room and the prewarmed VAD
            (``ctx.proc.userdata["vad"]``).
    """
    # Function-local imports, grouped up front: these are only needed inside
    # the job process (tts_voxcpm is a project-local module).
    import httpx

    from openai import AsyncClient as OpenAIAsyncClient
    from tts_voxcpm import VoxCPMTTS

    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints; all overridable via .env.
    ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")

    MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
    MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
    # SECURITY: prefer the environment variable. The fallback preserves the
    # previous hard-coded key for backward compatibility, but that key is now
    # committed to source control — rotate it and drop the fallback.
    MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")

    VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
    PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")

    # SenseVoice is non-streaming, so wrap it in a StreamAdapter driven by the
    # VAD that prewarm() loaded.
    sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
    stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])

    # SECURITY: SSL verification is disabled for the self-hosted LLM gateway —
    # only acceptable on a trusted network; re-enable once the endpoint has a
    # valid certificate.
    http_client = httpx.AsyncClient(verify=False)

    # OpenAI-compatible client pointed at the local/proxy LLM endpoint.
    openai_client = OpenAIAsyncClient(
        api_key=MINIMAX_API_KEY,
        base_url=MINIMAX_BASE_URL,
        http_client=http_client,
    )

    session: AgentSession = AgentSession(
        # 1. Custom SenseVoice ASR (STT) with StreamAdapter
        stt=stt_stream,
        # 2. Minimax LLM - Using OpenAI plugin with local base_url
        llm=openai.LLM(
            model=MINIMAX_MODEL,
            client=openai_client,
        ),
        # 3. VoxCPM TTS - Custom implementation for blackbox API
        tts=VoxCPMTTS(
            url=VOXCPM_URL,
            prompt_wav_path=PROMPT_WAV,
        ),
        # 4. Silero VAD (loaded once in prewarm)
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=True,
        aec_warmup_duration=3.0,
        # Strip emoji/markdown before synthesis so the TTS reads clean text.
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    await session.start(
        agent=CustomAgent(),
        room=ctx.room,
    )
if __name__ == "__main__":
    # Launch the agent server via the LiveKit agents CLI when run as a script.
    cli.run_app(server)