Files
xiaozhi-esp32/main/test_client_wav.py
2026-04-27 10:39:21 +08:00

195 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import sys
import wave
import time

from dotenv import load_dotenv

# Load credentials (e.g. LIVEKIT_URL) from .env.local, if present,
# before the LiveKit SDK is imported.
load_dotenv(".env.local")

from livekit import rtc
async def publish_audio_from_wav(room: rtc.Room, wav_path: str):
    """Read a 16-bit PCM WAV file and publish it to the room as a microphone track.

    The file is streamed in 20 ms chunks, paced against wall-clock time so the
    audio arrives at real-time speed (dumping it instantly would keep the
    receiving agent from recognizing it as speech).

    Args:
        room: A connected LiveKit room; its local participant publishes the track.
        wav_path: Path to a 16-bit PCM (S16LE) WAV file.
    """
    print(f"🎵 准备加载音频文件: {wav_path}")
    if not os.path.exists(wav_path):
        print(f"❌ 找不到文件 {wav_path}")
        return

    with wave.open(wav_path, "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        if sampwidth != 2:
            print("❌ 错误:只支持 16-bit PCM (S16LE) 编码的 WAV 文件。")
            return
        print(f"📊 音频信息: 采样率 {sample_rate}Hz, 通道数 {num_channels}")

        source = rtc.AudioSource(sample_rate, num_channels)
        track = rtc.LocalAudioTrack.create_audio_track("agent_input_audio", source)
        options = rtc.TrackPublishOptions()
        options.source = rtc.TrackSource.SOURCE_MICROPHONE
        await room.local_participant.publish_track(track, options)
        print("✅ 成功发布麦克风音频轨道,开始推流...")

        # Read and push 20 ms of audio per iteration.
        chunk_duration_ms = 20
        samples_per_chunk = int(sample_rate * (chunk_duration_ms / 1000.0))

        # Pace against an absolute schedule rather than sleeping a fixed 20 ms
        # after every send: a fixed post-send sleep adds the per-chunk
        # processing time on top of the sleep, so drift accumulates and the
        # stream runs slower than real time.
        start = time.monotonic()
        chunks_sent = 0
        while True:
            data = wf.readframes(samples_per_chunk)
            if not data:
                break
            # The final chunk may be shorter than 20 ms; derive the true
            # sample count from the bytes actually read.
            frame_samples = len(data) // (num_channels * sampwidth)
            try:
                # LiveKit Python SDK >= 0.15 constructor signature.
                audio_frame = rtc.AudioFrame(data, sample_rate, num_channels, frame_samples)
                await source.capture_frame(audio_frame)
            except TypeError:
                # Older SDK versions: allocate the frame first, then copy the
                # raw PCM bytes into its buffer.
                frame = rtc.AudioFrame.create(sample_rate, num_channels, frame_samples)
                frame.data[:] = data
                await source.capture_frame(frame)
            # Sleep only until this chunk's scheduled deadline (real-time pacing).
            chunks_sent += 1
            deadline = start + chunks_sent * (chunk_duration_ms / 1000.0)
            delay = deadline - time.monotonic()
            if delay > 0:
                await asyncio.sleep(delay)

    print("🎉 录音流推送完毕!等待 Agent 回复中...")
async def save_audio_stream(track: rtc.RemoteAudioTrack):
    """Record a remote audio track to a timestamped 16-bit WAV file.

    The WAV writer is opened lazily on the first received frame, once the
    actual channel count and sample rate are known.

    Args:
        track: The subscribed remote audio track to record.
    """
    # Timestamp the filename so runs don't clobber each other's output.
    filename = f"agent_response_{int(time.time())}.wav"
    # Fix: the message previously printed a corrupted placeholder instead of
    # the actual output filename.
    print(f"🎙️ 正在将 Agent 的声音写入文件: {filename}")
    stream = rtc.AudioStream(track)
    wf = None
    try:
        async for event in stream:
            # Initialize the WAV header from the first frame's format.
            if wf is None:
                wf = wave.open(filename, "wb")
                wf.setnchannels(event.frame.num_channels)
                wf.setsampwidth(2)  # 16-bit PCM
                wf.setframerate(event.frame.sample_rate)
            # Append this frame's raw 16-bit PCM samples.
            wf.writeframes(bytes(event.frame.data))
    except Exception as e:
        # Best-effort recorder: report stream errors/disconnects and keep
        # whatever audio was written so far.
        print(f"音频流断开或出错: {e}")
    finally:
        if wf is not None:
            wf.close()
            print(f"💾 Agent 语音结果已成功保存到: {filename}")
async def main(room_url: str, token: str, wav_path: str):
    """Connect to the room, wait for the agent, stream the WAV, then run until stopped.

    All event handlers are registered before ``room.connect`` so no early
    events are missed.

    Args:
        room_url: LiveKit server URL.
        token: Access token for joining the room.
        wav_path: Path to the WAV file to publish once the agent is ready.
    """
    room = rtc.Room()
    agent_ready = asyncio.Event()

    @room.on("connected")
    def on_connected():
        print("✅ 成功连接到 LiveKit 房间")
        # If an agent (remote participant) is already in the room at connect
        # time, trigger readiness immediately.
        if room.remote_participants:
            agent_ready.set()

    @room.on("participant_connected")
    def on_participant_connected(participant: rtc.RemoteParticipant):
        print(f"👋 Agent ({participant.identity}) 已加入房间")
        agent_ready.set()

    @room.on("data_received")
    def on_data_received(data_packet: rtc.DataPacket):
        identity = data_packet.participant.identity if data_packet.participant else "未知"
        try:
            print(f"📩 [数据接收 | {identity}]: {data_packet.data.decode('utf-8')}")
        except Exception:
            # Best-effort: silently skip payloads that are not valid UTF-8 text.
            pass

    @room.on("transcription_received")
    def on_transcription_received(
        segments: list[rtc.TranscriptionSegment],
        participant: rtc.Participant,
        track_pub: rtc.TrackPublication
    ):
        identity = participant.identity if participant else "未知"
        for segment in segments:
            status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
            print(f"🗣️ [{status} | {identity}]: {segment.text}")

    @room.on("track_subscribed")
    def on_track_subscribed(
        track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant
    ):
        # When the agent publishes a new audio track, record it to disk.
        if track.kind == rtc.TrackKind.KIND_AUDIO:
            asyncio.create_task(save_audio_stream(track))

    print("⏳ 正在建立连接...")
    await room.connect(room_url, token)
    print("⏳ 等待 Agent 初始化并加入房间...")
    await agent_ready.wait()
    # Small grace period so the agent side (model loading, etc.) is fully ready.
    await asyncio.sleep(0.5)
    # Start pushing the local WAV audio in the background.
    asyncio.create_task(publish_audio_from_wav(room, wav_path))
    try:
        # Park forever; the script is ended externally (e.g. Ctrl-C).
        await asyncio.Event().wait()
    except KeyboardInterrupt:
        print("\n断开连接中...")
    finally:
        await room.disconnect()
if __name__ == "__main__":
    # CLI: test_client_wav.py <wav-path> [LIVEKIT_URL] [LIVEKIT_TOKEN]
    if len(sys.argv) < 2:
        print("❌ 用法: python test_client_wav.py <WAV文件路径> [LIVEKIT_URL] [LIVEKIT_TOKEN/API_KEY]")
        print("说明:\n1. 必须提供 WAV 路径。")
        print("2. 自动从 .env.local 读取 LIVEKIT_URL。并在没有提供 Token 时自动向 localhost:8000/getToken 请求。")
        sys.exit(1)

    wav_file = sys.argv[1]
    url = os.getenv("LIVEKIT_URL")
    token = None
    # Positional overrides, matching the usage text: argv[2] is the server
    # URL, argv[3] the token.  (Previously argv[2] was ignored unless argv[3]
    # was also supplied; now a URL-only invocation works as documented.)
    if len(sys.argv) >= 3:
        url = sys.argv[2]
    if len(sys.argv) >= 4:
        token = sys.argv[3]

    if not token:
        import urllib.request
        import json
        import random

        # Use a fresh random room per run so agents left over from a previous
        # session can't pile up in the same room and produce duplicate replies.
        unique_room = f"test-room-{random.randint(1000, 9999)}"
        print(f"🔄 正在通过本地服务获取 Token请求加入全新独立房间: {unique_room} ...")
        try:
            req = urllib.request.urlopen(f"http://localhost:8000/getToken?room={unique_room}&identity=python_tester&agent_name=my-agent")
            res_body = req.read().decode('utf-8')
            data = json.loads(res_body)
            token = data.get("token")
            print("✅ 成功获取了包含了 Agent dispatch 的临时 Token")
        except Exception as e:
            print(f"❌ 获取 Token 失败,错误信息: {e}")
            print("若本地 token 服务未启动,请手动提供有效的测试 token")
            sys.exit(1)

    if not url or not token:
        print("❌ 缺少 LiveKit URL 或 Token")
        sys.exit(1)

    asyncio.run(main(url, token, wav_file))