"""LiveKit voice agent wiring a local SenseVoice ASR endpoint, an
OpenAI-compatible Minimax LLM endpoint, and a VoxCPM TTS server into a
single AgentSession pipeline."""

import logging
import os

import aiohttp
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
    Agent,
    AgentServer,
    AgentSession,
    APIConnectOptions,
    JobContext,
    JobProcess,
    LanguageCode,
    MetricsCollectedEvent,
    NOT_GIVEN,
    NotGivenOr,
    TurnHandlingOptions,
    cli,
    metrics,
    room_io,
    stt,
    text_transforms,
    utils,
)
from livekit.plugins import silero, openai
from livekit.plugins.turn_detector.multilingual import MultilingualModel

logger = logging.getLogger("custom-agent")

load_dotenv()


class SenseVoiceSTT(stt.STT):
    """Non-streaming STT adapter for a local SenseVoice ASR HTTP service.

    Posts buffered audio as a WAV multipart upload and converts the JSON
    reply into a single FINAL_TRANSCRIPT speech event.
    """

    def __init__(self, url: str):
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False, interim_results=False, diarization=False
            )
        )
        self._url = url

    @property
    def model(self) -> str:
        return "sensevoice"

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send *buffer* to the ASR server and return the final transcript.

        Raises:
            RuntimeError: if the server answers with a non-200 status.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        async with aiohttp.ClientSession() as session:
            data = aiohttp.FormData()
            data.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
            data.add_field("model_name", "sensevoice")
            # Let the server auto-detect the language unless one was given.
            data.add_field("language", language if language is not NOT_GIVEN else "auto")

            try:
                async with session.post(self._url, data=data, timeout=30) as resp:
                    if resp.status != 200:
                        raise RuntimeError(f"ASR server returned status {resp.status}")

                    result = await resp.json()
                    if not result.get("result"):
                        # Nothing recognized: emit an empty final transcript.
                        return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)

                    text = result["result"][0].get("clean_text", "")
                    logger.info(f"SenseVoice ASR Result: {text}")
                    return stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        # NOTE(review): language is hard-coded to "zh" even though
                        # the request may ask for auto-detection — confirm intent.
                        alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
                    )
            except Exception as e:
                logger.error(f"SenseVoice ASR error: {e}")
                raise


class CustomAgent(Agent):
    """Assistant persona ("Kelly") that greets the user when the session starts."""

    def __init__(self) -> None:
        super().__init__(
            instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
            "Keep your responses concise and friendly."
            "You are interacting with the user via a local ASR and LLM pipeline.",
        )

    async def on_enter(self) -> None:
        self.session.generate_reply(instructions="greet the user and introduce yourself")


server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Load the Silero VAD once per worker process so every job reuses it."""
    proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


@server.rtc_session(agent_name="my-agent")
async def entrypoint(ctx: JobContext) -> None:
    """Per-room job entrypoint: assemble the STT/LLM/TTS/VAD session and start it."""
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Endpoints for the local ASR / LLM / TTS services; override via .env.
    ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")

    MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
    MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
    # SECURITY: prefer MINIMAX_API_KEY from the environment; the inline
    # fallback key should be rotated and removed before sharing this code.
    MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")

    VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
    PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")

    # SenseVoice is non-streaming; StreamAdapter + VAD segmentation makes it
    # usable as a streaming STT inside the session.
    sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
    stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])

    # Imported lazily so the worker only pays the cost when a job runs.
    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    # SECURITY: SSL verification is disabled for the local/self-signed gateway;
    # do not point this client at public endpoints.
    http_client = httpx.AsyncClient(verify=False)

    openai_client = OpenAIAsyncClient(
        api_key=MINIMAX_API_KEY,
        base_url=MINIMAX_BASE_URL,
        http_client=http_client,
    )

    from tts_voxcpm import VoxCPMTTS

    session: AgentSession = AgentSession(
        # 1. Custom SenseVoice ASR (STT) with StreamAdapter
        stt=stt_stream,
        # 2. Minimax LLM - OpenAI-compatible endpoint via the OpenAI plugin
        llm=openai.LLM(
            model=MINIMAX_MODEL,
            client=openai_client,
        ),
        # 3. VoxCPM TTS - custom implementation for the blackbox API
        tts=VoxCPMTTS(
            url=VOXCPM_URL,
            prompt_wav_path=PROMPT_WAV,
        ),
        # 4. Silero VAD (prewarmed per-process)
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=True,
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    await session.start(
        agent=CustomAgent(),
        room=ctx.room,
    )


if __name__ == "__main__":
    cli.run_app(server)
"""End-to-end smoke test: join a LiveKit room, stream a local WAV file at
the agent, and verify that the agent connects and replies with TTS audio."""

import asyncio
import logging
import sys
import uuid
import wave
from datetime import datetime
from pathlib import Path

import numpy as np
import requests
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("test-agent")

TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
ROOM_NAME = "test-room20"
WAV_FILE = "2food.wav"
TEST_TIMEOUT = 30  # seconds to wait for the agent's TTS reply


class TestState:
    """Mutable flags shared between room event callbacks and the test body."""

    def __init__(self):
        self.agent_connected = False
        self.tts_received = False
        self.tts_count = 0


test_state = TestState()


def get_token(agent_name="my-agent"):
    """Fetch a room token from the local token server.

    Passing ``agent_name`` requests explicit agent dispatch into the room.

    Raises:
        requests.HTTPError: if the token server replies with an error status.
    """
    try:
        resp = requests.get(
            TOKEN_URL,
            params={
                "room": ROOM_NAME,
                "identity": f"test-{uuid.uuid4().hex[:6]}",
                "agent_name": agent_name,
            },
            timeout=5,
        )
        resp.raise_for_status()
        return resp.json()["token"]
    except Exception as e:
        logger.error(f"❌ 获取token失败: {e}")
        raise


async def publish_wav(room, wav_path):
    """Publish *wav_path* into *room* as a live audio track in 20 ms frames."""
    wav_path = Path(wav_path)
    if not wav_path.exists():
        logger.error(f"❌ WAV文件不存在: {wav_path}")
        raise FileNotFoundError(f"文件不存在: {wav_path}")

    logger.info(f"📂 开始上传: {wav_path}")

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()

        logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")

        source = AudioSource(sample_rate, num_channels)
        track = LocalAudioTrack.create_audio_track("mic", source)

        await room.local_participant.publish_track(track)
        logger.info("📡 已发布音轨")

        # Pace the upload at real time: one 20 ms frame per iteration.
        frame_duration = 0.02
        samples_per_frame = int(sample_rate * frame_duration)

        while True:
            data = wf.readframes(samples_per_frame)
            if not data:
                break

            audio = np.frombuffer(data, dtype=np.int16)
            if len(audio) == 0:
                continue

            samples_per_channel = len(audio) // num_channels

            frame = AudioFrame(
                data=data,
                sample_rate=sample_rate,
                num_channels=num_channels,
                samples_per_channel=samples_per_channel,
            )

            await source.capture_frame(frame)
            await asyncio.sleep(frame_duration)

    logger.info("✅ WAV推流完成")


async def test_agent():
    """Run the full scenario; return True when the agent joined and answered."""
    try:
        logger.info("🔑 正在获取token...")
        token = get_token()
        logger.info("✅ Token获取成功")

        room = rtc.Room()

        @room.on("participant_connected")
        def on_participant_connected(participant):
            logger.info(f"✅ 参与者加入: {participant.identity}")
            # Heuristic: the dispatched agent's identity contains "agent".
            if "agent" in participant.identity.lower():
                test_state.agent_connected = True
                logger.info("🎉 Agent已连接!")

        @room.on("participant_disconnected")
        def on_participant_disconnected(participant):
            logger.info(f"❌ 参与者离开: {participant.identity}")

        @room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            if track.kind == rtc.TrackKind.KIND_AUDIO:
                test_state.tts_count += 1
                logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
                test_state.tts_received = True

        logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
        await room.connect(WS_URL, token)
        logger.info("✅ 已连接到房间")
        logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")

        # Give the agent up to 10 s to be dispatched and join.
        logger.info("⏳ 等待Agent连接...")
        for i in range(10):
            if test_state.agent_connected:
                break
            await asyncio.sleep(1)

        if not test_state.agent_connected:
            logger.warning("⚠️ Agent未连接")
            return False

        logger.info("🎙️ 正在上传测试音频...")
        await publish_wav(room, WAV_FILE)

        logger.info("⏳ 等待Agent响应...")
        for i in range(TEST_TIMEOUT):
            if test_state.tts_received:
                logger.info("✅ 收到Agent TTS响应!")
                break
            if i % 5 == 0:
                logger.info(f"   等待中... ({i+1}/{TEST_TIMEOUT}秒)")
            await asyncio.sleep(1)

        # Small grace period so trailing audio events are counted.
        await asyncio.sleep(2)

        logger.info("\n" + "="*60)
        logger.info("✅ 测试结果")
        logger.info("="*60)
        logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
        logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
        logger.info(f"TTS音频次数: {test_state.tts_count} 次")
        logger.info("="*60)

        await room.disconnect()
        logger.info("✅ 已断开连接\n")

        return test_state.agent_connected and test_state.tts_received

    except Exception as e:
        logger.error(f"❌ 测试失败: {e}", exc_info=True)
        return False


async def main():
    """Entry point; return a process exit code (0 on success, 1 on failure)."""
    logger.info("🚀 开始测试custom_agent...\n")
    success = await test_agent()

    if success:
        logger.info("✅ 测试成功!custom_agent 正常工作")
        logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到,")
        logger.info("   或通过 agent-starter-react 这样的客户端交互查看")
        return 0
    else:
        logger.error("❌ 测试失败")
        return 1


if __name__ == "__main__":
    # Use sys.exit so the exit code propagates even under `python -O`/frozen runs.
    sys.exit(asyncio.run(main()))
"""Connectivity check for the SenseVoice ASR endpoint: read a local WAV,
wrap it in a single AudioFrame, and run one recognition round-trip."""

import asyncio
import logging
import wave

from custom_agent import SenseVoiceSTT
from livekit import rtc
from livekit.agents import utils

# Enable INFO logging so the ASR adapter's output is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-asr")


async def test():
    """Run one recognize() call against the local ASR server and print the text."""
    # Path to a local audio sample to transcribe.
    audio_path = "/home/verachen/Music/voice/2food.wav"

    # The ASR adapter under test.
    stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox")

    print(f"Testing ASR connectivity with file: {audio_path}")

    try:
        with wave.open(audio_path, 'rb') as wf:
            pcm = wf.readframes(wf.getnframes())
            # SenseVoiceSTT._recognize_impl combines the buffer with
            # rtc.combine_audio_frames(...).to_wav_bytes(), so a one-element
            # list holding a single AudioFrame is a sufficient AudioBuffer.
            single_frame = rtc.AudioFrame(
                data=pcm,
                sample_rate=wf.getframerate(),
                num_channels=wf.getnchannels(),
                samples_per_channel=wf.getnframes(),
            )

        result = await stt.recognize(buffer=[single_frame])

        if not result.alternatives:
            print("ASR returned no text.")
            return

        print(f"\n--- ASR Result ---")
        print(f"Text: {result.alternatives[0].text}")
        print(f"------------------\n")

    except FileNotFoundError:
        print(f"Error: Audio file not found at {audio_path}")
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(test())
"room": ROOM_NAME, + "identity": IDENTITY, + "agent_name": "my-agent", # 关键!!! + }, + ) + data = resp.json() + return data["token"] + + +async def main(): + token = get_token() + + room = rtc.Room() + + @room.on("participant_connected") + def on_participant_connected(participant): + print(f"✅ 有人加入房间: {participant.identity}") + + @room.on("participant_disconnected") + def on_participant_disconnected(participant): + print(f"❌ 有人离开房间: {participant.identity}") + + print("🔌 正在连接房间...") + await room.connect(WS_URL, token) + + print("✅ 已连接房间:", ROOM_NAME) + print("当前房间成员:") + for p in room.remote_participants.values(): + print(" -", p.identity) + + @room.on("data_received") + def on_data_received(data, participant, kind, topic): + try: + msg = data.decode() + print(f"📩 来自 {participant.identity}: {msg}") + except: + print("📩 收到二进制数据") + + @room.on("track_subscribed") + def on_track_subscribed(track, publication, participant): + print(f"🎧 订阅轨道: {participant.identity}") + + if track.kind == rtc.TrackKind.KIND_AUDIO: + print("👉 TTS 音频来了") + + # 等一下确保连接稳定 + await asyncio.sleep(1) + await room.local_participant.publish_data( + b"hello", + reliable=True, + topic="chat" + ) + # 上传 wav + await publish_wav(room, "2food.wav") + + await room.disconnect() + + +async def publish_wav(room, wav_path): + print("🎵 开始上传本地 wav:", wav_path) + + wf = wave.open(wav_path, "rb") + + sample_rate = wf.getframerate() + num_channels = wf.getnchannels() + sample_width = wf.getsampwidth() + + print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit") + + # 创建音频源 + source = AudioSource(sample_rate, num_channels) + + # 创建本地音轨 + track = LocalAudioTrack.create_audio_track("mic", source) + + # 发布轨道 + await room.local_participant.publish_track(track) + print("📡 已发布音轨") + + frame_duration = 0.02 # 20ms + samples_per_frame = int(sample_rate * frame_duration) + + while True: + data = wf.readframes(samples_per_frame) + if not data: + break + + # 用于计算长度 + audio = np.frombuffer(data, dtype=np.int16) 
+ + if len(audio) == 0: + continue + + samples_per_channel = len(audio) // num_channels + + frame = AudioFrame( + data=data, # ✅ 关键:用 bytes + sample_rate=sample_rate, + num_channels=num_channels, + samples_per_channel=samples_per_channel, + ) + + await source.capture_frame(frame) + await asyncio.sleep(frame_duration) + print("✅ wav 推流结束") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/test_minimax.py b/test_minimax.py new file mode 100644 index 0000000..d64edff --- /dev/null +++ b/test_minimax.py @@ -0,0 +1,71 @@ +import asyncio +import os +import logging +from dotenv import load_dotenv +from livekit.agents.llm import ChatContext +from livekit.plugins import openai + +# Configure logging to see what's happening +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("test-minimax") + +async def test_minimax(): + print("Loading .env...") + load_dotenv() + + # Configuration from environment or defaults from custom_agent.py + MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1") + MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI") + # Using the hardcoded key from custom_agent.py as a fallback if not in .env + API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA") + + import httpx + from openai import AsyncClient as OpenAIAsyncClient + + print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}") + + # Create a custom HTTP client that disables SSL verification + http_client = httpx.AsyncClient(verify=False) + + # Create the OpenAI AsyncClient with the custom HTTP client + openai_client = OpenAIAsyncClient( + api_key=API_KEY, + base_url=MINIMAX_BASE_URL, + http_client=http_client, + ) + + llm = openai.LLM( + model=MINIMAX_MODEL, + client=openai_client, + ) + + print("Creating ChatContext...") + chat_ctx = ChatContext() + chat_ctx.add_message( + content="Hello! Can you introduce yourself? 
"""Standalone check that the Minimax (OpenAI-compatible) LLM endpoint
answers a streaming chat request through the LiveKit openai plugin."""

import asyncio
import os
import logging
from dotenv import load_dotenv
from livekit.agents.llm import ChatContext
from livekit.plugins import openai

# Configure logging to see what's happening
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-minimax")


async def test_minimax():
    """Send one streaming chat request and print the response incrementally."""
    print("Loading .env...")
    load_dotenv()

    # Configuration from environment or defaults from custom_agent.py
    MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
    MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
    # SECURITY: fallback key mirrors custom_agent.py; rotate and remove it.
    API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")

    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}")

    # SECURITY: SSL verification disabled for the local/self-signed gateway only.
    http_client = httpx.AsyncClient(verify=False)

    try:
        openai_client = OpenAIAsyncClient(
            api_key=API_KEY,
            base_url=MINIMAX_BASE_URL,
            http_client=http_client,
        )

        llm = openai.LLM(
            model=MINIMAX_MODEL,
            client=openai_client,
        )

        print("Creating ChatContext...")
        chat_ctx = ChatContext()
        chat_ctx.add_message(
            content="Hello! Can you introduce yourself? Please reply in Chinese.",
            role="user",
        )

        print(f"\n--- Testing Streaming Chat ---")
        print(f"Request: {chat_ctx.items[-1].content}")
        print("Response: ", end="", flush=True)

        try:
            print("\nCalling llm.chat()...")
            stream = llm.chat(chat_ctx=chat_ctx)
            print("Iterating over stream...")
            async for chunk in stream:
                if chunk.delta and chunk.delta.content:
                    print(chunk.delta.content, end="", flush=True)
            print("\n--- Test Completed Successfully ---")
        except Exception as e:
            logger.error(f"\nTest failed with error: {e}")
    finally:
        # Close the HTTP client we created so no connections leak.
        await http_client.aclose()


if __name__ == "__main__":
    print("Starting...")
    try:
        asyncio.run(asyncio.wait_for(test_minimax(), timeout=30))
    except asyncio.TimeoutError:
        print("\nTest timed out after 30 seconds.")
    except Exception as e:
        print(f"\nAn error occurred: {e}")
"""Manual check of the VoxCPM TTS adapter: synthesize one sentence, report
its duration, and save the result as a WAV file for listening."""

import asyncio
import os
import logging
from tts_voxcpm import VoxCPMTTS
from livekit.agents import tts

logging.basicConfig(level=logging.INFO)


async def test_tts():
    """Run one synthesis round-trip and write the audio to test_output.wav."""
    # Endpoint taken from the working curl command.
    url = "http://10.6.80.21:5002/tts-blackbox"

    # Reference voice sample used as the cloning prompt.
    prompt_wav = "/home/verachen/Music/voice/2food.wav"
    if not os.path.exists(prompt_wav):
        # VoxCPMTTS tolerates a missing prompt (it skips the field), but warn early.
        print(f"Warning: prompt wav not found at {prompt_wav}")

    print(f"Testing VoxCPMTTS with URL: {url}")
    print(f"Using prompt wav: {prompt_wav}")

    vox_tts = VoxCPMTTS(
        url=url,
        prompt_wav_path=prompt_wav
    )

    text = "你好,这是一段测试文本"
    print(f"Synthesizing text: {text}")

    try:
        stream = vox_tts.synthesize(text)
        audio_frame = await stream.collect()

        # 16-bit PCM: 2 bytes per sample per channel.
        num_samples = len(audio_frame.data) // (2 * audio_frame.num_channels)
        duration_s = num_samples / audio_frame.sample_rate

        print(f"Successfully synthesized audio!")
        print(f"Audio duration: {duration_s:.2f}s ({num_samples} samples per channel)")

        # Save as a proper WAV container for a manual listening check.
        with open("test_output.wav", "wb") as f:
            f.write(audio_frame.to_wav_bytes())
        print("Saved output to test_output.wav")

    except Exception as e:
        print(f"TTS test failed: {e}")


if __name__ == "__main__":
    asyncio.run(test_tts())
"""Non-streaming TTS adapter for a VoxCPM "blackbox" HTTP synthesis server."""

import aiohttp
import logging
import os
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS

logger = logging.getLogger("voxcpm-tts")


class VoxCPMTTS(tts.TTS):
    """TTS plugin that posts text (plus a voice-cloning prompt WAV) to VoxCPM.

    All synthesis options are sent as string form fields, matching the
    server's multipart API.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        prompt_text: str = "澳门有乜嘢好食嘅",
        prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
        cfg_value: str = "2.0",
        inference_timesteps: str = "10",
        do_normalize: str = "true",
        denoise: str = "true",
        retry_badcase: str = "true",
        retry_badcase_max_times: str = "3",
        retry_badcase_ratio_threshold: str = "6.0",
        sample_rate: int = 16000,
    ):
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=1,
        )
        self._url = url
        # str(...) coercion keeps the form fields valid even if numbers are passed.
        self._opts = {
            "model_name": model_name,
            "streaming": "false",
            "prompt_text": prompt_text,
            "cfg_value": str(cfg_value),
            "inference_timesteps": str(inference_timesteps),
            "do_normalize": str(do_normalize),
            "denoise": str(denoise),
            "retry_badcase": str(retry_badcase),
            "retry_badcase_max_times": str(retry_badcase_max_times),
            "retry_badcase_ratio_threshold": str(retry_badcase_ratio_threshold),
        }
        self._prompt_wav_path = prompt_wav_path

    @property
    def model(self) -> str:
        return self._opts["model_name"]

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Return a chunked stream that performs one synthesis request."""
        return VoxCPMStream(
            self, text, self._url, self._opts, self._prompt_wav_path, conn_options=conn_options
        )


class VoxCPMStream(tts.ChunkedStream):
    """One-shot synthesis request; emits the whole WAV response at once."""

    def __init__(
        self,
        tts: VoxCPMTTS,
        text: str,
        url: str,
        opts: dict,
        prompt_wav_path: str,
        conn_options: APIConnectOptions,
    ):
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
        self._url = url
        self._opts = opts
        self._prompt_wav_path = prompt_wav_path

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """POST the request and push the returned WAV bytes to the emitter.

        Raises:
            RuntimeError: if the server answers with a non-200 status
                (mirrors SenseVoiceSTT's error handling so the framework
                can surface/retry failures instead of producing silence).
        """
        # Initialize early so the emitter is in a valid state even on failure.
        output_emitter.initialize(
            request_id="",
            sample_rate=self._tts.sample_rate,
            num_channels=self._tts.num_channels,
            mime_type="audio/wav",
        )

        async with aiohttp.ClientSession() as session:
            data = aiohttp.FormData()
            data.add_field("text", self.input_text)
            for k, v in self._opts.items():
                data.add_field(k, v)

            # The prompt file must stay open until the POST completes, so it
            # is closed in the finally block rather than a `with` statement.
            f = None
            if os.path.exists(self._prompt_wav_path):
                f = open(self._prompt_wav_path, "rb")
                data.add_field("prompt_wav", f, filename="prompt.wav", content_type="audio/wav")
            else:
                logger.warning(
                    f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field"
                )

            try:
                async with session.post(
                    self._url, data=data, timeout=aiohttp.ClientTimeout(total=60)
                ) as resp:
                    if resp.status != 200:
                        err_text = await resp.text()
                        logger.error(f"VoxCPM TTS error: {resp.status} {err_text}")
                        raise RuntimeError(f"VoxCPM TTS server returned status {resp.status}")

                    # streaming=false: the server returns the full audio body.
                    audio_data = await resp.read()

                    output_emitter.push(audio_data)
                    output_emitter.flush()
            except Exception as e:
                # Log and re-raise so callers see the failure instead of silence.
                logger.error(f"VoxCPM TTS request failed: {e}")
                raise
            finally:
                if f:
                    f.close()