Files
livekit_agents/test_agent.py
2026-05-07 15:13:15 +08:00

188 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import requests
import logging
from pathlib import Path
import uuid
import wave
import numpy as np
from datetime import datetime
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
# Configure root logging once at import time: INFO level, with timestamp,
# logger name and level in front of every message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("test-agent")

# HTTP endpoint queried by get_token() for a room access token.
TOKEN_URL = "http://localhost:8000/getToken"
# WebSocket URL handed to room.connect().
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
# Room name sent to the token endpoint as the "room" query parameter.
ROOM_NAME = "test-room20"
# WAV file streamed into the room by publish_wav().
WAV_FILE = "2food.wav"
# Number of one-second polls to wait for the agent's audio response.
TEST_TIMEOUT = 30
class TestState:
    """Mutable record of what the test run has observed so far."""

    def __init__(self):
        # Number of audio tracks subscribed to during the run.
        self.tts_count = 0
        # Both flags start cleared; the room event handlers set them.
        self.agent_connected = self.tts_received = False


# Module-wide instance shared by the room event handlers and the test driver.
test_state = TestState()
def get_token(agent_name="my-agent"):
    """Request a room access token from the token endpoint.

    A GET request is sent to ``TOKEN_URL`` carrying the room name, a
    freshly generated ``test-<6 hex chars>`` identity and *agent_name*.

    Args:
        agent_name: Value sent as the ``agent_name`` query parameter.

    Returns:
        The ``"token"`` field of the JSON response.

    Raises:
        Exception: Any failure (network error, non-2xx status, malformed
            body, missing key) is logged and then re-raised unchanged.
    """
    try:
        query = {
            "room": ROOM_NAME,
            "identity": f"test-{uuid.uuid4().hex[:6]}",
            "agent_name": agent_name,
        }
        response = requests.get(TOKEN_URL, params=query, timeout=5)
        response.raise_for_status()
        payload = response.json()
        return payload["token"]
    except Exception as e:
        # Log for the operator, then let the caller decide what to do.
        logger.error(f"❌ 获取token失败: {e}")
        raise
async def publish_wav(room, wav_path):
    """Stream a 16-bit PCM WAV file into *room* as a published audio track.

    The file is read in 20 ms chunks; each chunk is wrapped in an
    ``AudioFrame`` and pushed to an ``AudioSource`` whose track is published
    under the name ``"mic"``. Frames are paced against the event-loop clock
    so the stream runs at approximately real time.

    Args:
        room: Connected room object; its ``local_participant`` publishes
            the track.
        wav_path: Path (``str`` or ``Path``) of the WAV file to stream.

    Raises:
        FileNotFoundError: If *wav_path* does not exist.
        ValueError: If the WAV file is not 16-bit PCM.
    """
    wav_path = Path(wav_path)
    if not wav_path.exists():
        logger.error(f"❌ WAV文件不存在: {wav_path}")
        raise FileNotFoundError(f"文件不存在: {wav_path}")
    logger.info(f"📂 开始上传: {wav_path}")

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()
        logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")

        # Frames are built from the raw bytes as 16-bit samples. Any other
        # sample width would be sent with a wrong sample count (garbled
        # audio), so reject it before a track is published.
        if sample_width != 2:
            raise ValueError(
                f"仅支持16位PCM WAV, 当前为{sample_width * 8}位: {wav_path}"
            )

        source = AudioSource(sample_rate, num_channels)
        track = LocalAudioTrack.create_audio_track("mic", source)
        await room.local_participant.publish_track(track)
        logger.info("📡 已发布音轨")

        frame_duration = 0.02
        samples_per_frame = int(sample_rate * frame_duration)
        # Bytes occupied by one sample across all channels.
        bytes_per_sample_frame = sample_width * num_channels

        # Pace against absolute deadlines instead of sleeping a fixed 20 ms
        # after every capture; a fixed sleep adds the capture time on top of
        # each period and makes the stream drift slower than real time.
        loop = asyncio.get_running_loop()
        next_deadline = loop.time()

        while True:
            data = wf.readframes(samples_per_frame)
            if not data:
                break
            # The last chunk may be shorter than a full frame, so derive the
            # sample count from the bytes actually read.
            frame = AudioFrame(
                data=data,
                sample_rate=sample_rate,
                num_channels=num_channels,
                samples_per_channel=len(data) // bytes_per_sample_frame,
            )
            await source.capture_frame(frame)

            next_deadline += frame_duration
            delay = next_deadline - loop.time()
            if delay > 0:
                await asyncio.sleep(delay)

    logger.info("✅ WAV推流完成")
async def test_agent():
    """Run one end-to-end check of the agent in ``ROOM_NAME``.

    Steps: fetch a token, connect to the room, wait up to 10 s for a
    participant whose identity contains "agent", stream ``WAV_FILE`` into
    the room, then wait up to ``TEST_TIMEOUT`` seconds for an audio track
    subscription (treated as the agent's TTS response). A summary is logged
    and the room is always disconnected once it has been connected.

    Returns:
        True if an agent was seen and an audio track was received; False if
        either is missing or any step raised (the error is logged).
    """

    def _is_agent(participant):
        return "agent" in participant.identity.lower()

    try:
        # Clear flags left over from a previous run in the same process.
        test_state.agent_connected = False
        test_state.tts_received = False
        test_state.tts_count = 0

        logger.info("🔑 正在获取token...")
        token = get_token()
        logger.info("✅ Token获取成功")

        room = rtc.Room()

        @room.on("participant_connected")
        def on_participant_connected(participant):
            logger.info(f"✅ 参与者加入: {participant.identity}")
            if _is_agent(participant):
                test_state.agent_connected = True
                logger.info("🎉 Agent已连接")

        @room.on("participant_disconnected")
        def on_participant_disconnected(participant):
            logger.info(f"❌ 参与者离开: {participant.identity}")

        @room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            if track.kind == rtc.TrackKind.KIND_AUDIO:
                test_state.tts_count += 1
                logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
                test_state.tts_received = True

        logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
        await room.connect(WS_URL, token)
        try:
            logger.info("✅ 已连接到房间")
            logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")

            # "participant_connected" only fires for participants that join
            # after this client, so also look at who is already in the room.
            for participant in room.remote_participants.values():
                if _is_agent(participant):
                    test_state.agent_connected = True
                    logger.info("🎉 Agent已连接")

            logger.info("⏳ 等待Agent连接...")
            for _ in range(10):
                if test_state.agent_connected:
                    break
                await asyncio.sleep(1)
            if not test_state.agent_connected:
                logger.warning("⚠️ Agent未连接")
                return False

            logger.info("🎙️ 正在上传测试音频...")
            await publish_wav(room, WAV_FILE)

            logger.info("⏳ 等待Agent响应...")
            for i in range(TEST_TIMEOUT):
                if test_state.tts_received:
                    logger.info("✅ 收到Agent TTS响应!")
                    break
                if i % 5 == 0:
                    logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
                await asyncio.sleep(1)
            # Short grace period before summarising and leaving the room.
            await asyncio.sleep(2)

            logger.info("\n" + "="*60)
            logger.info("✅ 测试结果")
            logger.info("="*60)
            # Both branches used to be empty strings, so the summary never
            # showed whether each check passed.
            logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
            logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
            logger.info(f"TTS音频次数: {test_state.tts_count}")
            logger.info("="*60)
            return test_state.agent_connected and test_state.tts_received
        finally:
            # Runs on success, on the early "no agent" return and on errors,
            # so the connection is never left open.
            await room.disconnect()
            logger.info("✅ 已断开连接\n")
    except Exception as e:
        logger.error(f"❌ 测试失败: {e}", exc_info=True)
        return False
async def main():
    """Run the agent test and translate its outcome into an exit code.

    Returns:
        0 when ``test_agent()`` reports success, 1 otherwise.
    """
    logger.info("🚀 开始测试custom_agent...\n")
    passed = await test_agent()

    if not passed:
        logger.error("❌ 测试失败")
        return 1

    success_lines = (
        "✅ 测试成功custom_agent 正常工作",
        "💡 提示: Agent内部的转录和响应日志只能在Agent自身看到",
        " 或通过 agent-starter-react 这样的客户端交互查看",
    )
    for line in success_lines:
        logger.info(line)
    return 0
if __name__ == "__main__":
    # Script entry point: run the async driver and use its return value as
    # the process exit status. SystemExit is raised directly because the
    # bare `exit` helper is injected by the `site` module for interactive
    # use and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(asyncio.run(main()))