195 lines
7.4 KiB
Python
195 lines
7.4 KiB
Python
import asyncio
|
||
import os
|
||
import sys
|
||
import wave
|
||
import time
|
||
from dotenv import load_dotenv
|
||
|
||
# Try to load credentials from .env.local
|
||
load_dotenv(".env.local")
|
||
|
||
from livekit import rtc
|
||
|
||
async def publish_audio_from_wav(room: rtc.Room, wav_path: str):
|
||
print(f"🎵 准备加载音频文件: {wav_path}")
|
||
if not os.path.exists(wav_path):
|
||
print(f"❌ 找不到文件 {wav_path}")
|
||
return
|
||
|
||
with wave.open(wav_path, "rb") as wf:
|
||
sample_rate = wf.getframerate()
|
||
num_channels = wf.getnchannels()
|
||
sampwidth = wf.getsampwidth()
|
||
|
||
if sampwidth != 2:
|
||
print("❌ 错误:只支持 16-bit PCM (S16LE) 编码的 WAV 文件。")
|
||
return
|
||
|
||
print(f"📊 音频信息: 采样率 {sample_rate}Hz, 通道数 {num_channels}")
|
||
|
||
source = rtc.AudioSource(sample_rate, num_channels)
|
||
track = rtc.LocalAudioTrack.create_audio_track("agent_input_audio", source)
|
||
options = rtc.TrackPublishOptions()
|
||
options.source = rtc.TrackSource.SOURCE_MICROPHONE
|
||
|
||
await room.local_participant.publish_track(track, options)
|
||
print("✅ 成功发布麦克风音频轨道,开始推流...")
|
||
|
||
# 每次读取并推送 20ms 的数据
|
||
chunk_duration_ms = 20
|
||
samples_per_chunk = int(sample_rate * (chunk_duration_ms / 1000.0))
|
||
bytes_per_chunk = samples_per_chunk * num_channels * sampwidth
|
||
|
||
while True:
|
||
data = wf.readframes(samples_per_chunk)
|
||
if not data:
|
||
break
|
||
|
||
# 根据已读取的数据计算出完整的采样数(最后一帧可能不足 20ms)
|
||
frame_samples = len(data) // (num_channels * sampwidth)
|
||
|
||
# 使用 LiveKit SDK 封装 AudioFrame
|
||
try:
|
||
# LiveKit Python SDK version >= 0.15
|
||
audio_frame = rtc.AudioFrame(data, sample_rate, num_channels, frame_samples)
|
||
await source.capture_frame(audio_frame)
|
||
except TypeError:
|
||
# 若抛出 TypeError,可能 SDK 版本有差异,尝试旧版本 API 或者直接 copy
|
||
frame = rtc.AudioFrame.create(sample_rate, num_channels, frame_samples)
|
||
frame.data[:] = data
|
||
await source.capture_frame(frame)
|
||
|
||
# 严格控制发送速率,避免瞬时把整个音频发过去而导致对面不识别(模拟真实发音)
|
||
await asyncio.sleep(chunk_duration_ms / 1000.0)
|
||
|
||
print("🎉 录音流推送完毕!等待 Agent 回复中...")
|
||
|
||
async def save_audio_stream(track: rtc.RemoteAudioTrack):
|
||
# 为避免旧文件冲突,加个时间戳
|
||
filename = f"agent_response_{int(time.time())}.wav"
|
||
print(f"🎙️ 正在将 Agent 的声音写入文件: {filename}")
|
||
stream = rtc.AudioStream(track)
|
||
|
||
wf = None
|
||
try:
|
||
async for event in stream:
|
||
# 接收到第一帧时初始化 WAV 格式
|
||
if wf is None:
|
||
wf = wave.open(filename, "wb")
|
||
wf.setnchannels(event.frame.num_channels)
|
||
wf.setsampwidth(2) # 16-bit
|
||
wf.setframerate(event.frame.sample_rate)
|
||
|
||
# 写入当前帧音频数据 (16-bit PCM)
|
||
wf.writeframes(bytes(event.frame.data))
|
||
except Exception as e:
|
||
print(f"音频流断开或出错: {e}")
|
||
finally:
|
||
if wf is not None:
|
||
wf.close()
|
||
print(f"💾 Agent 语音结果已成功保存到: {filename}")
|
||
|
||
|
||
async def main(room_url: str, token: str, wav_path: str):
|
||
room = rtc.Room()
|
||
agent_ready = asyncio.Event()
|
||
|
||
@room.on("connected")
|
||
def on_connected():
|
||
print("✅ 成功连接到 LiveKit 房间")
|
||
# 如果连接时房间里已经有 Agent(远程参与者),直接准备触发
|
||
if room.remote_participants:
|
||
agent_ready.set()
|
||
|
||
@room.on("participant_connected")
|
||
def on_participant_connected(participant: rtc.RemoteParticipant):
|
||
print(f"👋 Agent ({participant.identity}) 已加入房间")
|
||
agent_ready.set()
|
||
|
||
@room.on("data_received")
|
||
def on_data_received(data_packet: rtc.DataPacket):
|
||
identity = data_packet.participant.identity if data_packet.participant else "未知"
|
||
try:
|
||
print(f"📩 [数据接收 | {identity}]: {data_packet.data.decode('utf-8')}")
|
||
except Exception:
|
||
pass
|
||
|
||
@room.on("transcription_received")
|
||
def on_transcription_received(
|
||
segments: list[rtc.TranscriptionSegment],
|
||
participant: rtc.Participant,
|
||
track_pub: rtc.TrackPublication
|
||
):
|
||
identity = participant.identity if participant else "未知"
|
||
for segment in segments:
|
||
status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
|
||
print(f"🗣️ [{status} | {identity}]: {segment.text}")
|
||
|
||
@room.on("track_subscribed")
|
||
def on_track_subscribed(
|
||
track: rtc.Track,
|
||
publication: rtc.RemoteTrackPublication,
|
||
participant: rtc.RemoteParticipant
|
||
):
|
||
# 当 Agent 发出新的声音(音频轨道)时,我们订阅并保存
|
||
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||
asyncio.create_task(save_audio_stream(track))
|
||
|
||
print("⏳ 正在建立连接...")
|
||
await room.connect(room_url, token)
|
||
|
||
print("⏳ 等待 Agent 初始化并加入房间...")
|
||
await agent_ready.wait()
|
||
# 稍微延迟半秒钟,确保 Agent 侧面的准备(如模型加载等)一切就绪
|
||
await asyncio.sleep(0.5)
|
||
|
||
# 开始推送本地 wav 音频
|
||
asyncio.create_task(publish_audio_from_wav(room, wav_path))
|
||
|
||
try:
|
||
await asyncio.Event().wait()
|
||
except KeyboardInterrupt:
|
||
print("\n断开连接中...")
|
||
finally:
|
||
await room.disconnect()
|
||
|
||
if __name__ == "__main__":
|
||
if len(sys.argv) < 2:
|
||
print("❌ 用法: python test_client_wav.py <WAV文件路径> [LIVEKIT_URL] [LIVEKIT_TOKEN/API_KEY]")
|
||
print("说明:\n1. 必须提供 WAV 路径。")
|
||
print("2. 自动从 .env.local 读取 LIVEKIT_URL。并在没有提供 Token 时自动向 localhost:8000/getToken 请求。")
|
||
sys.exit(1)
|
||
|
||
wav_file = sys.argv[1]
|
||
|
||
url = os.getenv("LIVEKIT_URL")
|
||
|
||
token = None
|
||
if len(sys.argv) >= 4:
|
||
url = sys.argv[2]
|
||
token = sys.argv[3]
|
||
|
||
if not token:
|
||
import urllib.request
|
||
import json
|
||
import random
|
||
# 每次使用随机的测试房间,防止上一次没退出的 agent 堆积在同一个房间里导致多重回复
|
||
unique_room = f"test-room-{random.randint(1000, 9999)}"
|
||
print(f"🔄 正在通过本地服务获取 Token,请求加入全新独立房间: {unique_room} ...")
|
||
try:
|
||
req = urllib.request.urlopen(f"http://localhost:8000/getToken?room={unique_room}&identity=python_tester&agent_name=my-agent")
|
||
res_body = req.read().decode('utf-8')
|
||
data = json.loads(res_body)
|
||
token = data.get("token")
|
||
print("✅ 成功获取了包含了 Agent dispatch 的临时 Token!")
|
||
except Exception as e:
|
||
print(f"❌ 获取 Token 失败,错误信息: {e}")
|
||
print("若本地 token 服务未启动,请手动提供有效的测试 token")
|
||
sys.exit(1)
|
||
|
||
if not url or not token:
|
||
print("❌ 缺少 LiveKit URL 或 Token")
|
||
sys.exit(1)
|
||
|
||
asyncio.run(main(url, token, wav_file))
|