initial commit

This commit is contained in:
0Xiao0
2026-05-07 15:13:15 +08:00
commit ac81d4a9eb
7 changed files with 781 additions and 0 deletions

171
custom_agent.py Normal file
View File

@ -0,0 +1,171 @@
import logging
import os
import aiohttp
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
APIConnectOptions,
JobContext,
JobProcess,
LanguageCode,
MetricsCollectedEvent,
NOT_GIVEN,
NotGivenOr,
TurnHandlingOptions,
cli,
metrics,
room_io,
stt,
text_transforms,
utils,
)
from livekit.plugins import silero, openai
from livekit.plugins.turn_detector.multilingual import MultilingualModel
logger = logging.getLogger("custom-agent")
load_dotenv()
class SenseVoiceSTT(stt.STT):
def __init__(self, url: str):
super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False))
self._url = url
@property
def model(self) -> str:
return "sensevoice"
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
async with aiohttp.ClientSession() as session:
data = aiohttp.FormData()
data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
data.add_field('model_name', 'sensevoice')
lang = language if language is not NOT_GIVEN else 'auto'
data.add_field('language', lang)
try:
async with session.post(self._url, data=data, timeout=30) as resp:
if resp.status != 200:
raise Exception(f"ASR server returned status {resp.status}")
result = await resp.json()
if not result.get("result"):
return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)
text = result["result"][0].get("clean_text", "")
logger.info(f"SenseVoice ASR Result: {text}")
return stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
)
except Exception as e:
logger.error(f"SenseVoice ASR error: {e}")
raise
class CustomAgent(Agent):
def __init__(self) -> None:
super().__init__(
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
"Keep your responses concise and friendly."
"You are interacting with the user via a local ASR and LLM pipeline.",
)
async def on_enter(self) -> None:
self.session.generate_reply(instructions="greet the user and introduce yourself")
server = AgentServer()
def prewarm(proc: JobProcess) -> None:
# Load Silero VAD as requested
proc.userdata["vad"] = silero.VAD.load()
server.setup_fnc = prewarm
@server.rtc_session(agent_name="my-agent")
async def entrypoint(ctx: JobContext) -> None:
ctx.log_context_fields = {
"room": ctx.room.name,
}
# Configuration for custom local endpoints
# These can be set in your .env file
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")
# Initialize SenseVoice STT and wrap with StreamAdapter
sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])
import httpx
from openai import AsyncClient as OpenAIAsyncClient
# Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False)
# Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient(
api_key="sk-orez64WkG1NkfksB5j_hGA",
base_url=MINIMAX_BASE_URL,
http_client=http_client,
)
from tts_voxcpm import VoxCPMTTS
session: AgentSession = AgentSession(
# 1. Custom SenseVoice ASR (STT) with StreamAdapter
stt=stt_stream,
# 2. Minimax LLM - Using OpenAI plugin with local base_url
llm=openai.LLM(
model=MINIMAX_MODEL,
client=openai_client,
),
# 3. VoxCPM TTS - Custom implementation for blackbox API
tts=VoxCPMTTS(
url=VOXCPM_URL,
prompt_wav_path=PROMPT_WAV,
),
# 4. Silero VAD
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=MultilingualModel(),
interruption={
"resume_false_interruption": True,
"false_interruption_timeout": 1.0,
},
),
preemptive_generation=True,
aec_warmup_duration=3.0,
tts_text_transforms=[
"filter_emoji",
"filter_markdown",
],
)
@session.on("metrics_collected")
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
metrics.log_metrics(ev.metrics)
await session.start(
agent=CustomAgent(),
room=ctx.room,
)
if __name__ == "__main__":
cli.run_app(server)

188
test_agent.py Normal file
View File

@ -0,0 +1,188 @@
import asyncio
import requests
import logging
from pathlib import Path
import uuid
import wave
import numpy as np
from datetime import datetime
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("test-agent")
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
ROOM_NAME = "test-room20"
WAV_FILE = "2food.wav"
TEST_TIMEOUT = 30
class TestState:
def __init__(self):
self.agent_connected = False
self.tts_received = False
self.tts_count = 0
test_state = TestState()
def get_token(agent_name="my-agent"):
try:
resp = requests.get(
TOKEN_URL,
params={
"room": ROOM_NAME,
"identity": f"test-{uuid.uuid4().hex[:6]}",
"agent_name": agent_name,
},
timeout=5
)
resp.raise_for_status()
return resp.json()["token"]
except Exception as e:
logger.error(f"❌ 获取token失败: {e}")
raise
async def publish_wav(room, wav_path):
wav_path = Path(wav_path)
if not wav_path.exists():
logger.error(f"❌ WAV文件不存在: {wav_path}")
raise FileNotFoundError(f"文件不存在: {wav_path}")
logger.info(f"📂 开始上传: {wav_path}")
with wave.open(str(wav_path), "rb") as wf:
sample_rate = wf.getframerate()
num_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
source = AudioSource(sample_rate, num_channels)
track = LocalAudioTrack.create_audio_track("mic", source)
await room.local_participant.publish_track(track)
logger.info("📡 已发布音轨")
frame_duration = 0.02
samples_per_frame = int(sample_rate * frame_duration)
while True:
data = wf.readframes(samples_per_frame)
if not data:
break
audio = np.frombuffer(data, dtype=np.int16)
if len(audio) == 0:
continue
samples_per_channel = len(audio) // num_channels
frame = AudioFrame(
data=data,
sample_rate=sample_rate,
num_channels=num_channels,
samples_per_channel=samples_per_channel,
)
await source.capture_frame(frame)
await asyncio.sleep(frame_duration)
logger.info("✅ WAV推流完成")
async def test_agent():
try:
logger.info("🔑 正在获取token...")
token = get_token()
logger.info("✅ Token获取成功")
room = rtc.Room()
@room.on("participant_connected")
def on_participant_connected(participant):
logger.info(f"✅ 参与者加入: {participant.identity}")
if "agent" in participant.identity.lower():
test_state.agent_connected = True
logger.info("🎉 Agent已连接")
@room.on("participant_disconnected")
def on_participant_disconnected(participant):
logger.info(f"❌ 参与者离开: {participant.identity}")
@room.on("track_subscribed")
def on_track_subscribed(track, publication, participant):
if track.kind == rtc.TrackKind.KIND_AUDIO:
test_state.tts_count += 1
logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
test_state.tts_received = True
logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
await room.connect(WS_URL, token)
logger.info("✅ 已连接到房间")
logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")
logger.info("⏳ 等待Agent连接...")
for i in range(10):
if test_state.agent_connected:
break
await asyncio.sleep(1)
if not test_state.agent_connected:
logger.warning("⚠️ Agent未连接")
return False
logger.info("🎙️ 正在上传测试音频...")
await publish_wav(room, WAV_FILE)
logger.info("⏳ 等待Agent响应...")
for i in range(TEST_TIMEOUT):
if test_state.tts_received:
logger.info("✅ 收到Agent TTS响应!")
break
if i % 5 == 0:
logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
await asyncio.sleep(1)
await asyncio.sleep(2)
logger.info("\n" + "="*60)
logger.info("✅ 测试结果")
logger.info("="*60)
logger.info(f"Agent连接: {'' if test_state.agent_connected else ''}")
logger.info(f"收到TTS响应: {'' if test_state.tts_received else ''}")
logger.info(f"TTS音频次数: {test_state.tts_count}")
logger.info("="*60)
await room.disconnect()
logger.info("✅ 已断开连接\n")
return test_state.agent_connected and test_state.tts_received
except Exception as e:
logger.error(f"❌ 测试失败: {e}", exc_info=True)
return False
async def main():
logger.info("🚀 开始测试custom_agent...\n")
success = await test_agent()
if success:
logger.info("✅ 测试成功custom_agent 正常工作")
logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到")
logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
return 0
else:
logger.error("❌ 测试失败")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
exit(exit_code)

53
test_asr.py Normal file
View File

@ -0,0 +1,53 @@
import asyncio
import logging
import wave
from custom_agent import SenseVoiceSTT
from livekit import rtc
from livekit.agents import utils
# 设置日志级别以查看输出
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-asr")
async def test():
# 替换为你本地的一个音频文件路径
audio_path = "/home/verachen/Music/voice/2food.wav"
# 初始化 ASR
stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox")
print(f"Testing ASR connectivity with file: {audio_path}")
try:
# 读取音频文件
with wave.open(audio_path, 'rb') as wf:
frames = wf.readframes(wf.getnframes())
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
# 实际上 SenseVoiceSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
# 所以我们需要传递一个包含 AudioFrame 的 list
# 这里我们模拟一个 Frame
frame = rtc.AudioFrame(
data=frames,
sample_rate=wf.getframerate(),
num_channels=wf.getnchannels(),
samples_per_channel=wf.getnframes()
)
# 调用 recognize
result = await stt.recognize(buffer=[frame])
if result.alternatives:
print(f"\n--- ASR Result ---")
print(f"Text: {result.alternatives[0].text}")
print(f"------------------\n")
else:
print("ASR returned no text.")
except FileNotFoundError:
print(f"Error: Audio file not found at {audio_path}")
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(test())

130
test_livekit.py Normal file
View File

@ -0,0 +1,130 @@
import asyncio
import requests
from livekit import rtc
import wave
import numpy as np
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" # 你的 LiveKit Server 地址
ROOM_NAME = "test-room20"
import uuid
IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
# IDENTITY = "test-user0"
def get_token():
resp = requests.get(
TOKEN_URL,
params={
"room": ROOM_NAME,
"identity": IDENTITY,
"agent_name": "my-agent", # 关键!!!
},
)
data = resp.json()
return data["token"]
async def main():
token = get_token()
room = rtc.Room()
@room.on("participant_connected")
def on_participant_connected(participant):
print(f"✅ 有人加入房间: {participant.identity}")
@room.on("participant_disconnected")
def on_participant_disconnected(participant):
print(f"❌ 有人离开房间: {participant.identity}")
print("🔌 正在连接房间...")
await room.connect(WS_URL, token)
print("✅ 已连接房间:", ROOM_NAME)
print("当前房间成员:")
for p in room.remote_participants.values():
print(" -", p.identity)
@room.on("data_received")
def on_data_received(data, participant, kind, topic):
try:
msg = data.decode()
print(f"📩 来自 {participant.identity}: {msg}")
except:
print("📩 收到二进制数据")
@room.on("track_subscribed")
def on_track_subscribed(track, publication, participant):
print(f"🎧 订阅轨道: {participant.identity}")
if track.kind == rtc.TrackKind.KIND_AUDIO:
print("👉 TTS 音频来了")
# 等一下确保连接稳定
await asyncio.sleep(1)
await room.local_participant.publish_data(
b"hello",
reliable=True,
topic="chat"
)
# 上传 wav
await publish_wav(room, "2food.wav")
await room.disconnect()
async def publish_wav(room, wav_path):
print("🎵 开始上传本地 wav:", wav_path)
wf = wave.open(wav_path, "rb")
sample_rate = wf.getframerate()
num_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
# 创建音频源
source = AudioSource(sample_rate, num_channels)
# 创建本地音轨
track = LocalAudioTrack.create_audio_track("mic", source)
# 发布轨道
await room.local_participant.publish_track(track)
print("📡 已发布音轨")
frame_duration = 0.02 # 20ms
samples_per_frame = int(sample_rate * frame_duration)
while True:
data = wf.readframes(samples_per_frame)
if not data:
break
# 用于计算长度
audio = np.frombuffer(data, dtype=np.int16)
if len(audio) == 0:
continue
samples_per_channel = len(audio) // num_channels
frame = AudioFrame(
data=data, # ✅ 关键:用 bytes
sample_rate=sample_rate,
num_channels=num_channels,
samples_per_channel=samples_per_channel,
)
await source.capture_frame(frame)
await asyncio.sleep(frame_duration)
print("✅ wav 推流结束")
if __name__ == "__main__":
asyncio.run(main())

71
test_minimax.py Normal file
View File

@ -0,0 +1,71 @@
import asyncio
import os
import logging
from dotenv import load_dotenv
from livekit.agents.llm import ChatContext
from livekit.plugins import openai
# Configure logging to see what's happening
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-minimax")
async def test_minimax():
print("Loading .env...")
load_dotenv()
# Configuration from environment or defaults from custom_agent.py
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
# Using the hardcoded key from custom_agent.py as a fallback if not in .env
API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")
import httpx
from openai import AsyncClient as OpenAIAsyncClient
print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}")
# Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False)
# Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient(
api_key=API_KEY,
base_url=MINIMAX_BASE_URL,
http_client=http_client,
)
llm = openai.LLM(
model=MINIMAX_MODEL,
client=openai_client,
)
print("Creating ChatContext...")
chat_ctx = ChatContext()
chat_ctx.add_message(
content="Hello! Can you introduce yourself? Please reply in Chinese.",
role="user",
)
print(f"\n--- Testing Streaming Chat ---")
print(f"Request: {chat_ctx.items[-1].content}")
print("Response: ", end="", flush=True)
try:
print("\nCalling llm.chat()...")
stream = llm.chat(chat_ctx=chat_ctx)
print("Iterating over stream...")
async for chunk in stream:
if chunk.delta and chunk.delta.content:
print(chunk.delta.content, end="", flush=True)
print("\n--- Test Completed Successfully ---")
except Exception as e:
logger.error(f"\nTest failed with error: {e}")
if __name__ == "__main__":
print("Starting...")
try:
asyncio.run(asyncio.wait_for(test_minimax(), timeout=30))
except asyncio.TimeoutError:
print("\nTest timed out after 30 seconds.")
except Exception as e:
print(f"\nAn error occurred: {e}")

50
test_voxcpm.py Normal file
View File

@ -0,0 +1,50 @@
import asyncio
import os
import logging
from tts_voxcpm import VoxCPMTTS
from livekit.agents import tts
logging.basicConfig(level=logging.INFO)
async def test_tts():
# Use the URL from the user's curl command
url = "http://10.6.80.21:5002/tts-blackbox"
# Check if we have a real wav file to test with
# In the earlier find_by_name, we found tests/change-sophie.wav
prompt_wav = "/home/verachen/Music/voice/2food.wav"
if not os.path.exists(prompt_wav):
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
print(f"Testing VoxCPMTTS with URL: {url}")
print(f"Using prompt wav: {prompt_wav}")
vox_tts = VoxCPMTTS(
url=url,
prompt_wav_path=prompt_wav
)
text = "你好,这是一段测试文本"
print(f"Synthesizing text: {text}")
try:
stream = vox_tts.synthesize(text)
audio_frame = await stream.collect()
print(f"Successfully synthesized audio!")
print(f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?")
# Actually AudioFrame has duration or samples
print(f"Samples: {len(audio_frame.data) // 2}")
# Save to file for manual check if possible
with open("test_output.wav", "wb") as f:
# This won't be a valid WAV yet if it's just raw PCM,
# but if collect() returns combined frames, we can use to_wav_bytes()
f.write(audio_frame.to_wav_bytes())
print("Saved output to test_output.wav")
except Exception as e:
print(f"TTS test failed: {e}")
if __name__ == "__main__":
asyncio.run(test_tts())

118
tts_voxcpm.py Normal file
View File

@ -0,0 +1,118 @@
import aiohttp
import logging
import os
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS
logger = logging.getLogger("voxcpm-tts")
class VoxCPMTTS(tts.TTS):
def __init__(
self,
*,
url: str,
model_name: str = "voxcpmtts",
prompt_text: str = "澳门有乜嘢好食嘅",
prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
cfg_value: str = "2.0",
inference_timesteps: str = "10",
do_normalize: str = "true",
denoise: str = "true",
retry_badcase: str = "true",
retry_badcase_max_times: str = "3",
retry_badcase_ratio_threshold: str = "6.0",
sample_rate: int = 16000,
):
super().__init__(
capabilities=tts.TTSCapabilities(streaming=False),
sample_rate=sample_rate,
num_channels=1,
)
self._url = url
self._opts = {
"model_name": model_name,
"streaming": "false",
"prompt_text": prompt_text,
"cfg_value": str(cfg_value),
"inference_timesteps": str(inference_timesteps),
"do_normalize": str(do_normalize),
"denoise": str(denoise),
"retry_badcase": str(retry_badcase),
"retry_badcase_max_times": str(retry_badcase_max_times),
"retry_badcase_ratio_threshold": str(retry_badcase_ratio_threshold),
}
self._prompt_wav_path = prompt_wav_path
@property
def model(self) -> str:
return self._opts["model_name"]
def synthesize(
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> tts.ChunkedStream:
return VoxCPMStream(
self, text, self._url, self._opts, self._prompt_wav_path, conn_options=conn_options
)
class VoxCPMStream(tts.ChunkedStream):
def __init__(
self,
tts: VoxCPMTTS,
text: str,
url: str,
opts: dict,
prompt_wav_path: str,
conn_options: APIConnectOptions,
):
super().__init__(tts=tts, input_text=text, conn_options=conn_options)
self._url = url
self._opts = opts
self._prompt_wav_path = prompt_wav_path
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
# Initialize emitter early to avoid "AudioEmitter isn't started" error on failure
output_emitter.initialize(
request_id="",
sample_rate=self._tts.sample_rate,
num_channels=self._tts.num_channels,
mime_type="audio/wav",
)
async with aiohttp.ClientSession() as session:
data = aiohttp.FormData()
data.add_field("text", self.input_text)
for k, v in self._opts.items():
data.add_field(k, v)
# Open the prompt wav file if it exists
f = None
if os.path.exists(self._prompt_wav_path):
f = open(self._prompt_wav_path, "rb")
data.add_field("prompt_wav", f, filename="prompt.wav", content_type="audio/wav")
else:
logger.warning(
f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field"
)
try:
# Set a reasonable timeout for synthesis
async with session.post(
self._url, data=data, timeout=aiohttp.ClientTimeout(total=60)
) as resp:
if resp.status != 200:
err_text = await resp.text()
logger.error(f"VoxCPM TTS error: {resp.status} {err_text}")
return
# Read the entire audio data (since streaming=false)
audio_data = await resp.read()
output_emitter.push(audio_data)
output_emitter.flush()
except Exception as e:
logger.error(f"VoxCPM TTS request failed: {e}")
finally:
if f:
f.close()