feat: support defferent models

This commit is contained in:
0Xiao0
2026-05-11 13:40:25 +08:00
parent 10125c9cd4
commit e706f1d4e2

View File

@ -5,40 +5,52 @@ import os
import sys
import httpx
import json
import time
import queue
import shutil
import threading
import struct
import time
import wave
import opuslib
from typing import Any, Optional
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame
from websockets.exceptions import ConnectionClosedError
import http.server
import multipart
from urllib.parse import parse_qs
# 配置信息
# TOKEN_URL = "http://10.6.80.130:8000/v1/token"
# LIVEKIT_WS_URL = "ws://10.6.80.130:8000/"
# ROOM = "vera-room"
# IDENTITY = "vera-1"
# TOKEN_URL = "https://omnichat.bwgdi.com/v1/token"
try:
from livekit import api as livekit_api
except ImportError:
livekit_api = None
TOKEN_URL = "http://10.6.80.130:8000/getToken"
LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud"
# LIVEKIT_WS_URL = "wss://rtc.bwgdi.com/"
LIVEKIT_WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
ROOM = "test-livekit-room2"
IDENTITY = "uv-livekit-hardcoded"
AGENT_NAME = "my-agent"
import uuid
# IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
CONNECT_TIMEOUT_SECONDS = 10.0
WS_PORT = 8080
AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True
SAMPLE_RATE = 16000
INPUT_SAMPLE_RATE = 16000
OUTPUT_SAMPLE_RATE = 24000
INPUT_FRAME_DURATION_MS = 60
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
OUTPUT_FRAME_DURATION_MS = 20
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
TTS_IDLE_TIMEOUT_SECONDS = 0.8
TTS_SILENCE_PEAK_THRESHOLD = 96
TTS_PRE_ROLL_MS = 200
SAVE_AGENT_TTS_WAV = True
AGENT_TTS_WAV_PREFIX = "agent_tts"
async def fetch_token() -> str:
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
response = await client.get(
TOKEN_URL,
params={"room": ROOM, "identity": IDENTITY, "agent_name": "my-agent"},
params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME},
)
response.raise_for_status()
@ -57,12 +69,248 @@ class ESP32LiveKitBridge:
self.room = rtc.Room()
# 创建一个音频源,用于将 ESP32 的声音推送到 LiveKit
# 注意:采样率需与 ESP32 发送的一致,通常是 16000 或 24000
self.mic_source = AudioSource(sample_rate=SAMPLE_RATE, num_channels=1)
self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1)
self.esp_ws = None # 保存 WebSocket 连接
self.audio_queue = queue.Queue()
self.wav_writer_thread: Optional[threading.Thread] = None
self.stop_event = threading.Event()
self.agent_ready = asyncio.Event() # Moved here to be accessible earlier if needed
self.protocol_version = 1
self.tts_active = False
self.forwarding_tracks: set[str] = set()
self.tts_idle_task: Optional[asyncio.Task] = None
self.tts_stream_id = 0
self.agent_dispatch_task: Optional[asyncio.Task] = None
async def ensure_agent_dispatched(self) -> None:
if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT:
return
for participant in self.room.remote_participants.values():
if AGENT_NAME in participant.identity:
print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
return
if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done():
print("Agent dispatch 正在进行中,跳过重复请求")
return
self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent())
await self.agent_dispatch_task
async def _dispatch_agent(self) -> None:
print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}")
try:
if await self._dispatch_agent_with_sdk():
return
if await self._dispatch_agent_with_cli():
return
print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
except Exception as e:
print(f"Agent dispatch 失败: {e}")
async def _dispatch_agent_with_sdk(self) -> bool:
if livekit_api is None:
return False
api_key = os.getenv("LIVEKIT_API_KEY")
api_secret = os.getenv("LIVEKIT_API_SECRET")
if not api_key or not api_secret:
return False
lkapi = livekit_api.LiveKitAPI(
url=LIVEKIT_WS_URL,
api_key=api_key,
api_secret=api_secret,
)
try:
dispatch = await lkapi.agent_dispatch.create_dispatch(
livekit_api.CreateAgentDispatchRequest(
agent_name=AGENT_NAME,
room=ROOM,
metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}),
)
)
print(f"Agent dispatch 已创建: {dispatch}")
finally:
await lkapi.aclose()
return True
async def _dispatch_agent_with_cli(self) -> bool:
lk_path = shutil.which("lk")
if lk_path is None:
return False
process = await asyncio.create_subprocess_exec(
lk_path,
"dispatch",
"create",
"--room",
ROOM,
"--agent-name",
AGENT_NAME,
"--metadata",
json.dumps({"source": "bridge_server", "identity": IDENTITY}),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
stdout_text = stdout.decode("utf-8", errors="replace").strip()
stderr_text = stderr.decode("utf-8", errors="replace").strip()
if stdout_text:
print(f"lk dispatch 输出: {stdout_text}")
if stderr_text:
print(f"lk dispatch 错误输出: {stderr_text}")
if process.returncode != 0:
print(f"lk dispatch create 失败,退出码: {process.returncode}")
return False
print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}")
return True
async def _send_tts_state(self, state: str) -> None:
if self.esp_ws is None:
print(f"跳过 tts {state}ESP32 尚未连接")
return
await self.esp_ws.send(json.dumps({"type": "tts", "state": state}))
print(f"已发送 tts {state}")
async def _start_tts(self) -> None:
if self.tts_active:
print("跳过 tts start当前已处于激活状态")
return
await self._send_tts_state("start")
if self.esp_ws is not None:
self.tts_active = True
async def _stop_tts(self) -> None:
if not self.tts_active:
print("跳过 tts stop当前未激活")
return
await self._send_tts_state("stop")
self.tts_active = False
async def _abort_tts(self, reason: str = "client_abort") -> None:
print(f"收到打断请求,停止当前 TTS: reason={reason}")
self.tts_stream_id += 1
if self.tts_idle_task is not None:
self.tts_idle_task.cancel()
self.tts_idle_task = None
await self._stop_tts()
def _reset_tts_idle_timer(self) -> None:
if self.tts_idle_task is not None:
self.tts_idle_task.cancel()
self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id))
async def _tts_idle_watchdog(self, stream_id: int) -> None:
try:
await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
if stream_id != self.tts_stream_id:
return
print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s切回聆听状态")
await self._stop_tts()
except asyncio.CancelledError:
pass
def _has_audible_audio(self, pcm_data: bytes) -> bool:
if len(pcm_data) < 2:
return False
sample_count = len(pcm_data) // 2
peak = 0
for i in range(sample_count):
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
if sample < 0:
sample = -sample
if sample > peak:
peak = sample
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
return True
return False
def _maybe_forward_remote_audio(
self,
track: rtc.Track,
publication: Optional[rtc.TrackPublication],
participant: rtc.RemoteParticipant,
source: str,
) -> None:
if track.kind != rtc.TrackKind.KIND_AUDIO:
return
track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}"
if track_sid in self.forwarding_tracks:
print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
return
self.forwarding_tracks.add(track_sid)
print(f"收到音频流: {participant.identity} sid={track_sid} source={source}")
asyncio.create_task(
self.forward_audio_to_esp32(
rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
)
)
def _scan_existing_remote_audio_tracks(self) -> None:
participants = list(self.room.remote_participants.values())
print(f"开始扫描远端音频轨participants={len(participants)}")
for participant in participants:
publications = getattr(participant, "track_publications", {}) or {}
print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}")
for pub_sid, publication in publications.items():
track = getattr(publication, "track", None)
kind = getattr(publication, "kind", None)
print(
f"检查 publication: participant={participant.identity} pub_sid={pub_sid} "
f"kind={kind} has_track={track is not None}"
)
if track is not None:
self._maybe_forward_remote_audio(track, publication, participant, "scan")
def _extract_opus_payload(self, message: bytes) -> bytes:
if self.protocol_version == 2:
if len(message) < 16:
raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
"!HHIII", message[:16]
)
if version != 2:
raise ValueError(f"version 2 包头版本异常: {version}")
if msg_type != 0:
raise ValueError(f"version 2 音频包类型异常: {msg_type}")
if len(message) < 16 + payload_size:
raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
return message[16:16 + payload_size]
if self.protocol_version == 3:
if len(message) < 4:
raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
if msg_type != 0:
raise ValueError(f"version 3 音频包类型异常: {msg_type}")
if len(message) < 4 + payload_size:
raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
return message[4:4 + payload_size]
return message
def _wrap_opus_payload(self, payload: bytes) -> bytes:
if self.protocol_version == 2:
header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
return header + payload
if self.protocol_version == 3:
header = struct.pack("!BBH", 0, 0, len(payload))
return header + payload
return payload
def _wav_writer_loop(self):
import wave
@ -71,7 +319,7 @@ class ESP32LiveKitBridge:
with wave.open("bridge_debug.wav", "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(SAMPLE_RATE)
wav_file.setframerate(INPUT_SAMPLE_RATE)
while not self.stop_event.is_set() or not self.audio_queue.empty():
try:
# 使用 timeout 避免永久阻塞,以便检查 stop_event
@ -93,13 +341,23 @@ class ESP32LiveKitBridge:
def on_connected():
print("✅ 成功连接到 LiveKit 房间")
if self.room.remote_participants:
self._scan_existing_remote_audio_tracks()
self.agent_ready.set()
@self.room.on("participant_connected")
def on_participant_connected(p: rtc.RemoteParticipant):
print(f"👋 Agent ({p.identity}) 已加入房间")
self._scan_existing_remote_audio_tracks()
self.agent_ready.set()
@self.room.on("participant_disconnected")
def on_participant_disconnected(p: rtc.RemoteParticipant):
print(f"👋 远端参与者离开房间: {p.identity}")
self.forwarding_tracks = {
track_sid for track_sid in self.forwarding_tracks
if not track_sid.endswith(f":{p.identity}")
}
@self.room.on("data_received")
def on_data_received(data_packet: rtc.DataPacket):
identity = data_packet.participant.identity if data_packet.participant else "未知"
@ -145,6 +403,7 @@ class ESP32LiveKitBridge:
print(f"已连接到 LiveKit 房间: {self.room.name}")
print(f"[livekit] local_identity={self.room.local_participant.identity}")
print(f"[livekit] local_sid={self.room.local_participant.sid}")
print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}")
# 2. 发布麦克风轨道 (ESP32 -> LiveKit)
track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
@ -154,11 +413,10 @@ class ESP32LiveKitBridge:
# 3. 监听房间内的音频 (LiveKit -> ESP32)
@self.room.on("track_subscribed")
def on_track_subscribed(track, publication, participant):
if track.kind == rtc.TrackKind.KIND_AUDIO:
print(f"收到音频流: {participant.identity}")
asyncio.create_task(self.forward_audio_to_esp32(rtc.AudioStream(track, sample_rate=SAMPLE_RATE, num_channels=1)))
self._maybe_forward_remote_audio(track, publication, participant, "event")
print("等待 agent 加入...")
self._scan_existing_remote_audio_tracks()
try:
await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
print("✅ agent 已就绪")
@ -172,53 +430,121 @@ class ESP32LiveKitBridge:
await self.room.disconnect()
async def forward_audio_to_esp32(self, audio_stream):
"""从 LiveKit 接收音频,通过 WebSocket 发回给 ESP32"""
import opuslib
import json
# 创建下行 Opus 编码器
encoder = opuslib.Encoder(SAMPLE_RATE, 1, 'voip')
"""从 LiveKit 接收音频并发给 ESP32"""
# 1. 告知 ESP32 开始说话,切换 UI 到“说话中”并准备解码
if self.esp_ws:
await self.esp_ws.send(json.dumps({"type": "tts", "state": "start"}))
encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip')
pending_pcm = bytearray()
pre_roll_pcm = bytearray()
pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
tts_wav_file = None
tts_wav_path = None
stream_id = self.tts_stream_id + 1
self.tts_stream_id = stream_id
if SAVE_AGENT_TTS_WAV:
tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}.wav"
tts_wav_file = wave.open(tts_wav_path, "wb")
tts_wav_file.setnchannels(1)
tts_wav_file.setsampwidth(2)
tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE)
print(f"开始保存 Agent TTS 音频: {tts_wav_path}")
try:
async for event in audio_stream:
if self.esp_ws:
try:
# AudioStream 迭代产生的是 AudioFrameEvent需要从中提取 frame
frame = event.frame
# 将 PCM 编码为 Opus 才能发给 ESP32
pcm_data = frame.data.tobytes()
if stream_id != self.tts_stream_id:
print("检测到更新的 TTS 流,停止旧流转发")
break
# 使用当前帧的实际采样数进行编码
opus_packet = encoder.encode(pcm_data, frame.samples_per_channel)
await self.esp_ws.send(opus_packet)
frame = event.frame
pcm_data = frame.data.tobytes()
if tts_wav_file is not None:
tts_wav_file.writeframes(pcm_data)
has_audible_audio = self._has_audible_audio(pcm_data)
current_frame_buffered = False
if not self.tts_active:
pre_roll_pcm.extend(pcm_data)
current_frame_buffered = True
if len(pre_roll_pcm) > pre_roll_max_bytes:
del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes]
if not has_audible_audio:
continue
await self._start_tts()
if not self.tts_active:
continue
print("检测到可听 TTS 音频,切换到 Speaking 并开始转发")
pending_pcm.extend(pre_roll_pcm)
pre_roll_pcm.clear()
if has_audible_audio:
self._reset_tts_idle_timer()
# LiveKit 下行帧长度不一定等于 ESP32 协议里声明的帧长。
# 这里聚合成固定 20ms PCM再编码为 Opus 发给设备,减少卡顿。
if not current_frame_buffered:
pending_pcm.extend(pcm_data)
frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2 # mono 16-bit PCM
while len(pending_pcm) >= frame_bytes and self.esp_ws and stream_id == self.tts_stream_id:
try:
opus_packet = encoder.encode(
bytes(pending_pcm[:frame_bytes]),
OUTPUT_SAMPLES_PER_OPUS_FRAME,
)
del pending_pcm[:frame_bytes]
# print(f"发送 TTS 音频包: opus={len(opus_packet)} bytes, protocol={self.protocol_version}")
await self.esp_ws.send(self._wrap_opus_payload(opus_packet))
except Exception as e:
print(f"发送回 ESP32 失败: {e}")
break
except Exception as e:
print(f"音频流处理错误: {e}")
finally:
# 2. 音频结束,告知 ESP32 停止说话,切换回聆听或闲置状态
if self.esp_ws:
await self.esp_ws.send(json.dumps({"type": "tts", "state": "stop"}))
print("🎧 TTS 音频结束")
if tts_wav_file is not None:
tts_wav_file.close()
print(f"Agent TTS 音频已保存: {tts_wav_path}")
if stream_id == self.tts_stream_id and self.tts_idle_task is not None:
self.tts_idle_task.cancel()
self.tts_idle_task = None
if stream_id == self.tts_stream_id:
await self._stop_tts()
async def handle_websocket(self, websocket):
"""处理来自 ESP32 的 WebSocket 连接"""
self.esp_ws = websocket
header_version = websocket.request.headers.get("Protocol-Version")
try:
self.protocol_version = int(header_version) if header_version else 1
except ValueError:
self.protocol_version = 1
print("ESP32 已连接")
print(f"ESP32 协议版本: {self.protocol_version}")
self.tts_stream_id += 1
self.tts_active = False
if self.tts_idle_task is not None:
self.tts_idle_task.cancel()
self.tts_idle_task = None
opus_decoder = None
try:
await self.ensure_agent_dispatched()
# 发送 hello 告诉 ESP32 握手成功
hello_msg = {
"type": "hello",
"transport": "websocket",
"audio_params": {
"format": "opus", # 明确要求 ESP32 发送 Opus
"sample_rate": SAMPLE_RATE,
"sample_rate": OUTPUT_SAMPLE_RATE,
"channels": 1,
"frame_duration": 60
"frame_duration": OUTPUT_FRAME_DURATION_MS
}
}
import json
await websocket.send(json.dumps(hello_msg))
async for message in websocket:
@ -229,18 +555,14 @@ class ESP32LiveKitBridge:
print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
continue
# ESP32 默认使用 websocket_protocol version=1 (见 websocket_protocol.cc)
# 这个版本下,没有 4 字节的 header接收到的就是原生的 Opus 数据帧。
# 直接丢给 opuslib 解码即可。
audio_data = message
print(f"收到音频包长度: {len(message)}")
audio_data = self._extract_opus_payload(message)
# print(f"收到音频包长度: {len(message)}, Opus 负载长度: {len(audio_data)}")
if audio_data:
try:
# Create Opus decoder if not exists
if opus_decoder is None:
import opuslib
print(f"初始化 Opus 解码器: {SAMPLE_RATE}Hz, mono")
opus_decoder = opuslib.Decoder(SAMPLE_RATE, 1)
print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
# 启动音频保存线程
self.stop_event.clear()
@ -249,8 +571,8 @@ class ESP32LiveKitBridge:
thread.start()
# Decode Opus packet.
# Frame size for 60ms is SAMPLE_RATE * 0.06
frame_size = int(SAMPLE_RATE * 0.06)
# ESP32 -> bridge 保持按设备上行配置 60ms 解码。
frame_size = INPUT_SAMPLES_PER_OPUS_FRAME
pcm_bytes = opus_decoder.decode(audio_data, frame_size)
# 将音频数据放入队列由后台线程保存
@ -260,20 +582,26 @@ class ESP32LiveKitBridge:
if num_samples > 0:
try:
# Use the more robust frame creation logic from test_client_wav.py
frame = AudioFrame(pcm_bytes, SAMPLE_RATE, 1, num_samples)
frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
await self.mic_source.capture_frame(frame)
except TypeError:
# Fallback for different SDK versions
frame = AudioFrame.create(sample_rate=SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
memoryview(frame.data).cast('B')[:] = pcm_bytes
await self.mic_source.capture_frame(frame)
except Exception as e:
print(f"Opus audio decode error ({len(message)} bytes): {e}")
elif isinstance(message, str):
import json
try:
data = json.loads(message)
print(f"收到 ESP32 JSON 消息: {data}")
msg_type = data.get("type")
if msg_type == "abort":
reason = data.get("reason")
if reason is None:
await self._abort_tts("button_abort")
else:
print(f"忽略非按键打断: reason={reason}")
except json.JSONDecodeError:
print(f"收到未知的字符消息: {message}")
except ConnectionClosedError as e:
@ -283,6 +611,11 @@ class ESP32LiveKitBridge:
finally:
print("ESP32 断开连接")
self.esp_ws = None
self.tts_stream_id += 1
self.tts_active = False
if self.tts_idle_task is not None:
self.tts_idle_task.cancel()
self.tts_idle_task = None
if hasattr(self, "wav_writer_thread") and self.wav_writer_thread:
self.stop_event.set()
# 我们不一定需要 join因为是 daemon=True