initial commit
This commit is contained in:
188
test_agent.py
Normal file
188
test_agent.py
Normal file
@ -0,0 +1,188 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import wave
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from livekit import rtc
|
||||
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("test-agent")
|
||||
|
||||
TOKEN_URL = "http://localhost:8000/getToken"
|
||||
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
|
||||
ROOM_NAME = "test-room20"
|
||||
WAV_FILE = "2food.wav"
|
||||
TEST_TIMEOUT = 30
|
||||
|
||||
class TestState:
|
||||
def __init__(self):
|
||||
self.agent_connected = False
|
||||
self.tts_received = False
|
||||
self.tts_count = 0
|
||||
|
||||
test_state = TestState()
|
||||
|
||||
|
||||
def get_token(agent_name="my-agent"):
|
||||
try:
|
||||
resp = requests.get(
|
||||
TOKEN_URL,
|
||||
params={
|
||||
"room": ROOM_NAME,
|
||||
"identity": f"test-{uuid.uuid4().hex[:6]}",
|
||||
"agent_name": agent_name,
|
||||
},
|
||||
timeout=5
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["token"]
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取token失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def publish_wav(room, wav_path):
|
||||
wav_path = Path(wav_path)
|
||||
if not wav_path.exists():
|
||||
logger.error(f"❌ WAV文件不存在: {wav_path}")
|
||||
raise FileNotFoundError(f"文件不存在: {wav_path}")
|
||||
|
||||
logger.info(f"📂 开始上传: {wav_path}")
|
||||
|
||||
with wave.open(str(wav_path), "rb") as wf:
|
||||
sample_rate = wf.getframerate()
|
||||
num_channels = wf.getnchannels()
|
||||
sample_width = wf.getsampwidth()
|
||||
|
||||
logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
|
||||
|
||||
source = AudioSource(sample_rate, num_channels)
|
||||
track = LocalAudioTrack.create_audio_track("mic", source)
|
||||
|
||||
await room.local_participant.publish_track(track)
|
||||
logger.info("📡 已发布音轨")
|
||||
|
||||
frame_duration = 0.02
|
||||
samples_per_frame = int(sample_rate * frame_duration)
|
||||
|
||||
while True:
|
||||
data = wf.readframes(samples_per_frame)
|
||||
if not data:
|
||||
break
|
||||
|
||||
audio = np.frombuffer(data, dtype=np.int16)
|
||||
if len(audio) == 0:
|
||||
continue
|
||||
|
||||
samples_per_channel = len(audio) // num_channels
|
||||
|
||||
frame = AudioFrame(
|
||||
data=data,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=num_channels,
|
||||
samples_per_channel=samples_per_channel,
|
||||
)
|
||||
|
||||
await source.capture_frame(frame)
|
||||
await asyncio.sleep(frame_duration)
|
||||
|
||||
logger.info("✅ WAV推流完成")
|
||||
|
||||
|
||||
async def test_agent():
|
||||
try:
|
||||
logger.info("🔑 正在获取token...")
|
||||
token = get_token()
|
||||
logger.info("✅ Token获取成功")
|
||||
|
||||
room = rtc.Room()
|
||||
|
||||
@room.on("participant_connected")
|
||||
def on_participant_connected(participant):
|
||||
logger.info(f"✅ 参与者加入: {participant.identity}")
|
||||
if "agent" in participant.identity.lower():
|
||||
test_state.agent_connected = True
|
||||
logger.info("🎉 Agent已连接!")
|
||||
|
||||
@room.on("participant_disconnected")
|
||||
def on_participant_disconnected(participant):
|
||||
logger.info(f"❌ 参与者离开: {participant.identity}")
|
||||
|
||||
@room.on("track_subscribed")
|
||||
def on_track_subscribed(track, publication, participant):
|
||||
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||
test_state.tts_count += 1
|
||||
logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
|
||||
test_state.tts_received = True
|
||||
|
||||
logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
|
||||
await room.connect(WS_URL, token)
|
||||
logger.info("✅ 已连接到房间")
|
||||
logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")
|
||||
|
||||
logger.info("⏳ 等待Agent连接...")
|
||||
for i in range(10):
|
||||
if test_state.agent_connected:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not test_state.agent_connected:
|
||||
logger.warning("⚠️ Agent未连接")
|
||||
return False
|
||||
|
||||
logger.info("🎙️ 正在上传测试音频...")
|
||||
await publish_wav(room, WAV_FILE)
|
||||
|
||||
logger.info("⏳ 等待Agent响应...")
|
||||
for i in range(TEST_TIMEOUT):
|
||||
if test_state.tts_received:
|
||||
logger.info("✅ 收到Agent TTS响应!")
|
||||
break
|
||||
if i % 5 == 0:
|
||||
logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("✅ 测试结果")
|
||||
logger.info("="*60)
|
||||
logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
|
||||
logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
|
||||
logger.info(f"TTS音频次数: {test_state.tts_count} 次")
|
||||
logger.info("="*60)
|
||||
|
||||
await room.disconnect()
|
||||
logger.info("✅ 已断开连接\n")
|
||||
|
||||
return test_state.agent_connected and test_state.tts_received
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 测试失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info("🚀 开始测试custom_agent...\n")
|
||||
success = await test_agent()
|
||||
|
||||
if success:
|
||||
logger.info("✅ 测试成功!custom_agent 正常工作")
|
||||
logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到,")
|
||||
logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
|
||||
return 0
|
||||
else:
|
||||
logger.error("❌ 测试失败")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
exit(exit_code)
|
||||
Reference in New Issue
Block a user