# File-viewer banner captured during extraction (188 lines, 6.0 KiB, Python); not part of the program.
import asyncio
import logging
import sys
import uuid
import wave
from datetime import datetime
from pathlib import Path

import numpy as np
import requests
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack

# Configure the root logger at import time: INFO and above, with a
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger shared by every function in this script.
logger = logging.getLogger("test-agent")

# HTTP endpoint queried by get_token() for a room access token.
TOKEN_URL: str = "http://localhost:8000/getToken"
# WebSocket URL handed to room.connect().
WS_URL: str = "wss://esp32-vt80c4y6.livekit.cloud"
# Room name sent to the token endpoint and shown in the log output.
ROOM_NAME: str = "test-room20"
# WAV file streamed into the room as the test utterance.
WAV_FILE: str = "2food.wav"
# Number of one-second polls to wait for an audio reply.
TEST_TIMEOUT: int = 30

class TestState:
    """Mutable flags recording what the test run has observed so far.

    The attributes are written by the room event handlers registered in
    ``test_agent`` and read back when the result is reported.
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out explicitly so pytest does not try to collect it (and warn
    # about the ``__init__`` constructor) when this module is imported
    # by a test run.
    __test__ = False

    def __init__(self) -> None:
        # True once a participant whose identity contains "agent" is seen.
        self.agent_connected: bool = False
        # True once at least one audio track has been subscribed.
        self.tts_received: bool = False
        # Number of audio tracks subscribed so far.
        self.tts_count: int = 0


# Single shared instance mutated by the event handlers in test_agent().
test_state = TestState()


def get_token(agent_name="my-agent"):
    """Fetch a room access token from the token endpoint.

    Sends a GET request to ``TOKEN_URL`` carrying the room name, a
    freshly generated ``test-xxxxxx`` identity and ``agent_name``.

    Args:
        agent_name: Value sent as the ``agent_name`` query parameter.

    Returns:
        The ``"token"`` field of the JSON response.

    Raises:
        Exception: Any failure (network error, non-2xx status, malformed
            body) is logged and then re-raised unchanged.
    """
    try:
        # A random six-hex-digit suffix gives every call its own identity.
        query = {
            "room": ROOM_NAME,
            "identity": f"test-{uuid.uuid4().hex[:6]}",
            "agent_name": agent_name,
        }
        response = requests.get(TOKEN_URL, params=query, timeout=5)
        response.raise_for_status()
        payload = response.json()
        return payload["token"]
    except Exception as err:
        logger.error(f"❌ 获取token失败: {err}")
        raise


async def publish_wav(room, wav_path):
    """Publish a WAV file into ``room`` as a local audio track.

    The file is read in 20 ms chunks; each chunk is wrapped in an
    ``AudioFrame`` and pushed to an ``AudioSource``, sleeping one frame
    duration between pushes so the audio is sent at roughly real-time speed.

    Args:
        room: Connected room whose ``local_participant`` publishes the track.
        wav_path: Path (``str`` or ``Path``) of the WAV file to stream.

    Raises:
        FileNotFoundError: If ``wav_path`` does not exist.
        ValueError: If the file is not 16-bit PCM. The raw bytes are handed
            to ``AudioFrame`` as int16 samples, so any other sample width
            would be streamed as garbage.
    """
    wav_path = Path(wav_path)
    if not wav_path.exists():
        logger.error(f"❌ WAV文件不存在: {wav_path}")
        raise FileNotFoundError(f"文件不存在: {wav_path}")

    logger.info(f"📂 开始上传: {wav_path}")

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()

        logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")

        # Reject unsupported sample widths before anything is published,
        # instead of silently reinterpreting the bytes as int16.
        if sample_width != 2:
            logger.error(f"❌ 不支持的采样位宽: {sample_width*8}bit (仅支持16bit PCM)")
            raise ValueError(f"仅支持16bit PCM WAV: {wav_path} ({sample_width*8}bit)")

        source = AudioSource(sample_rate, num_channels)
        track = LocalAudioTrack.create_audio_track("mic", source)

        await room.local_participant.publish_track(track)
        logger.info("📡 已发布音轨")

        frame_duration = 0.02  # seconds of audio per pushed frame
        samples_per_frame = int(sample_rate * frame_duration)
        # Bytes occupied by one sample across all channels.
        block_align = num_channels * sample_width

        while True:
            data = wf.readframes(samples_per_frame)
            if not data:
                break

            # The last chunk is usually shorter than a full frame, so the
            # sample count is derived from the bytes actually read.
            samples_per_channel = len(data) // block_align
            if samples_per_channel == 0:
                # Trailing fragment smaller than one sample (truncated file).
                break
            data = data[: samples_per_channel * block_align]

            frame = AudioFrame(
                data=data,
                sample_rate=sample_rate,
                num_channels=num_channels,
                samples_per_channel=samples_per_channel,
            )

            await source.capture_frame(frame)
            await asyncio.sleep(frame_duration)

    logger.info("✅ WAV推流完成")


async def test_agent():
    """Run the end-to-end check against the agent.

    Steps: fetch a token, connect to the room, wait up to 10 seconds for a
    participant whose identity contains "agent", stream ``WAV_FILE`` into
    the room, then wait up to ``TEST_TIMEOUT`` seconds for an audio track
    to be subscribed. Observations are recorded on ``test_state``.

    Returns:
        bool: True only if an agent was seen and at least one audio track
        was received. False otherwise, including when any step raises
        (the exception is logged, not propagated).
    """
    room = None
    connected = False

    def note_participant(participant):
        # Shared by the join event and the scan of already-present participants.
        if "agent" in participant.identity.lower():
            test_state.agent_connected = True
            logger.info("🎉 Agent已连接!")

    try:
        logger.info("🔑 正在获取token...")
        token = get_token()
        logger.info("✅ Token获取成功")

        room = rtc.Room()

        @room.on("participant_connected")
        def on_participant_connected(participant):
            logger.info(f"✅ 参与者加入: {participant.identity}")
            note_participant(participant)

        @room.on("participant_disconnected")
        def on_participant_disconnected(participant):
            logger.info(f"❌ 参与者离开: {participant.identity}")

        @room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            if track.kind == rtc.TrackKind.KIND_AUDIO:
                test_state.tts_count += 1
                logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
                test_state.tts_received = True

        logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
        await room.connect(WS_URL, token)
        connected = True
        logger.info("✅ 已连接到房间")
        logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")

        # "participant_connected" only reports participants that join after
        # us; an agent already in the room would otherwise go unnoticed.
        for participant in room.remote_participants.values():
            logger.info(f"👥 房间内已有参与者: {participant.identity}")
            note_participant(participant)

        logger.info("⏳ 等待Agent连接...")
        for _ in range(10):
            if test_state.agent_connected:
                break
            await asyncio.sleep(1)

        if not test_state.agent_connected:
            logger.warning("⚠️ Agent未连接")
            return False

        logger.info("🎙️ 正在上传测试音频...")
        await publish_wav(room, WAV_FILE)

        logger.info("⏳ 等待Agent响应...")
        for i in range(TEST_TIMEOUT):
            if test_state.tts_received:
                logger.info("✅ 收到Agent TTS响应!")
                break
            if i % 5 == 0:
                logger.info(f"  等待中... ({i+1}/{TEST_TIMEOUT}秒)")
            await asyncio.sleep(1)

        # Short grace period so late track events are still counted.
        await asyncio.sleep(2)

        logger.info("\n" + "="*60)
        logger.info("✅ 测试结果")
        logger.info("="*60)
        logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
        logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
        logger.info(f"TTS音频次数: {test_state.tts_count} 次")
        logger.info("="*60)

        return test_state.agent_connected and test_state.tts_received

    except Exception as e:
        logger.error(f"❌ 测试失败: {e}", exc_info=True)
        return False

    finally:
        # Always leave the room, including on the early "no agent" return
        # and on errors, so the connection is not left open.
        if connected:
            try:
                await room.disconnect()
                logger.info("✅ 已断开连接\n")
            except Exception as e:
                logger.warning(f"⚠️ 断开连接失败: {e}")


async def main():
    """Run the agent test and translate its outcome into an exit code.

    Returns:
        int: 0 when ``test_agent`` reports success, 1 otherwise.
    """
    logger.info("🚀 开始测试custom_agent...\n")
    passed = await test_agent()

    # Failure path first; everything below it is the success report.
    if not passed:
        logger.error("❌ 测试失败")
        return 1

    logger.info("✅ 测试成功!custom_agent 正常工作")
    logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到,")
    logger.info("   或通过 agent-starter-react 这样的客户端交互查看")
    return 0


if __name__ == "__main__":
    # Use sys.exit() rather than the bare exit() helper: exit() is injected
    # by the site module for interactive use and is not guaranteed to exist
    # (for example under ``python -S``).
    sys.exit(asyncio.run(main()))