From e706f1d4e204f1af50f301f1686e9aea69955722 Mon Sep 17 00:00:00 2001 From: 0Xiao0 <511201264@qq.com> Date: Mon, 11 May 2026 13:40:25 +0800 Subject: [PATCH] feat: support defferent models --- main/bridge_server.py | 445 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 389 insertions(+), 56 deletions(-) diff --git a/main/bridge_server.py b/main/bridge_server.py index 0ab4a44..ae827ff 100644 --- a/main/bridge_server.py +++ b/main/bridge_server.py @@ -5,40 +5,52 @@ import os import sys import httpx import json -import time import queue +import shutil import threading +import struct +import time +import wave +import opuslib + from typing import Any, Optional from livekit import rtc from livekit.rtc import AudioSource, AudioFrame from websockets.exceptions import ConnectionClosedError -import http.server -import multipart -from urllib.parse import parse_qs -# 配置信息 -# TOKEN_URL = "http://10.6.80.130:8000/v1/token" -# LIVEKIT_WS_URL = "ws://10.6.80.130:8000/" -# ROOM = "vera-room" -# IDENTITY = "vera-1" -# TOKEN_URL = "https://omnichat.bwgdi.com/v1/token" +try: + from livekit import api as livekit_api +except ImportError: + livekit_api = None + TOKEN_URL = "http://10.6.80.130:8000/getToken" -LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud" -# LIVEKIT_WS_URL = "wss://rtc.bwgdi.com/" +LIVEKIT_WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" ROOM = "test-livekit-room2" IDENTITY = "uv-livekit-hardcoded" +AGENT_NAME = "my-agent" import uuid # IDENTITY = f"uv-{uuid.uuid4().hex[:6]}" CONNECT_TIMEOUT_SECONDS = 10.0 WS_PORT = 8080 +AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True -SAMPLE_RATE = 16000 +INPUT_SAMPLE_RATE = 16000 +OUTPUT_SAMPLE_RATE = 24000 +INPUT_FRAME_DURATION_MS = 60 +INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000 +OUTPUT_FRAME_DURATION_MS = 20 +OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000 +TTS_IDLE_TIMEOUT_SECONDS = 0.8 +TTS_SILENCE_PEAK_THRESHOLD = 96 +TTS_PRE_ROLL_MS = 200 +SAVE_AGENT_TTS_WAV = True +AGENT_TTS_WAV_PREFIX = "agent_tts" async def fetch_token() -> str: async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: response = await client.get( TOKEN_URL, - params={"room": ROOM, "identity": IDENTITY, "agent_name": "my-agent"}, + params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME}, ) response.raise_for_status() @@ -57,12 +69,248 @@ class ESP32LiveKitBridge: self.room = rtc.Room() # 创建一个音频源,用于将 ESP32 的声音推送到 LiveKit # 注意:采样率需与 ESP32 发送的一致,通常是 16000 或 24000 - self.mic_source = AudioSource(sample_rate=SAMPLE_RATE, num_channels=1) + self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1) self.esp_ws = None # 保存 WebSocket 连接 self.audio_queue = queue.Queue() self.wav_writer_thread: Optional[threading.Thread] = None self.stop_event = threading.Event() self.agent_ready = asyncio.Event() # Moved here to be accessible earlier if needed + self.protocol_version = 1 + self.tts_active = False + self.forwarding_tracks: set[str] = set() + self.tts_idle_task: Optional[asyncio.Task] = None + self.tts_stream_id = 0 + self.agent_dispatch_task: Optional[asyncio.Task] = None + + async def ensure_agent_dispatched(self) -> None: + if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT: + return + + for participant in self.room.remote_participants.values(): + if AGENT_NAME in participant.identity: + print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}") + return + + if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done(): + print("Agent dispatch 正在进行中,跳过重复请求") + return + + self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent()) + await self.agent_dispatch_task + + async def _dispatch_agent(self) -> None: + print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}") + + try: + if await self._dispatch_agent_with_sdk(): + return + + if await self._dispatch_agent_with_cli(): + return + + print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI") + except Exception as e: + print(f"Agent dispatch 失败: {e}") + + async def _dispatch_agent_with_sdk(self) -> bool: + if livekit_api is None: + return False + + api_key = os.getenv("LIVEKIT_API_KEY") + api_secret = os.getenv("LIVEKIT_API_SECRET") + if not api_key or not api_secret: + return False + + lkapi = livekit_api.LiveKitAPI( + url=LIVEKIT_WS_URL, + api_key=api_key, + api_secret=api_secret, + ) + try: + dispatch = await lkapi.agent_dispatch.create_dispatch( + livekit_api.CreateAgentDispatchRequest( + agent_name=AGENT_NAME, + room=ROOM, + metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}), + ) + ) + print(f"Agent dispatch 已创建: {dispatch}") + finally: + await lkapi.aclose() + + return True + + async def _dispatch_agent_with_cli(self) -> bool: + lk_path = shutil.which("lk") + if lk_path is None: + return False + + process = await asyncio.create_subprocess_exec( + lk_path, + "dispatch", + "create", + "--room", + ROOM, + "--agent-name", + AGENT_NAME, + "--metadata", + json.dumps({"source": "bridge_server", "identity": IDENTITY}), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + stdout_text = stdout.decode("utf-8", errors="replace").strip() + stderr_text = stderr.decode("utf-8", errors="replace").strip() + + if stdout_text: + print(f"lk dispatch 输出: {stdout_text}") + if stderr_text: + print(f"lk dispatch 错误输出: {stderr_text}") + + if process.returncode != 0: + print(f"lk dispatch create 失败,退出码: {process.returncode}") + return False + + print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}") + return True + + async def _send_tts_state(self, state: str) -> None: + if self.esp_ws is None: + print(f"跳过 tts {state},ESP32 尚未连接") + return + await self.esp_ws.send(json.dumps({"type": "tts", "state": state})) + print(f"已发送 tts {state}") + + async def _start_tts(self) -> None: + if self.tts_active: + print("跳过 tts start,当前已处于激活状态") + return + await self._send_tts_state("start") + if self.esp_ws is not None: + self.tts_active = True + + async def _stop_tts(self) -> None: + if not self.tts_active: + print("跳过 tts stop,当前未激活") + return + await self._send_tts_state("stop") + self.tts_active = False + + async def _abort_tts(self, reason: str = "client_abort") -> None: + print(f"收到打断请求,停止当前 TTS: reason={reason}") + self.tts_stream_id += 1 + if self.tts_idle_task is not None: + self.tts_idle_task.cancel() + self.tts_idle_task = None + await self._stop_tts() + + def _reset_tts_idle_timer(self) -> None: + if self.tts_idle_task is not None: + self.tts_idle_task.cancel() + self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id)) + + async def _tts_idle_watchdog(self, stream_id: int) -> None: + try: + await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS) + if stream_id != self.tts_stream_id: + return + print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态") + await self._stop_tts() + except asyncio.CancelledError: + pass + + def _has_audible_audio(self, pcm_data: bytes) -> bool: + if len(pcm_data) < 2: + return False + + sample_count = len(pcm_data) // 2 + peak = 0 + for i in range(sample_count): + sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True) + if sample < 0: + sample = -sample + if sample > peak: + peak = sample + if peak >= TTS_SILENCE_PEAK_THRESHOLD: + return True + return False + + def _maybe_forward_remote_audio( + self, + track: rtc.Track, + publication: Optional[rtc.TrackPublication], + participant: rtc.RemoteParticipant, + source: str, + ) -> None: + if track.kind != rtc.TrackKind.KIND_AUDIO: + return + + track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}" + if track_sid in self.forwarding_tracks: + print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}") + return + + self.forwarding_tracks.add(track_sid) + print(f"收到音频流: {participant.identity} sid={track_sid} source={source}") + asyncio.create_task( + self.forward_audio_to_esp32( + rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1) + ) + ) + + def _scan_existing_remote_audio_tracks(self) -> None: + participants = list(self.room.remote_participants.values()) + print(f"开始扫描远端音频轨,participants={len(participants)}") + for participant in participants: + publications = getattr(participant, "track_publications", {}) or {} + print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}") + for pub_sid, publication in publications.items(): + track = getattr(publication, "track", None) + kind = getattr(publication, "kind", None) + print( + f"检查 publication: participant={participant.identity} pub_sid={pub_sid} " + f"kind={kind} has_track={track is not None}" + ) + if track is not None: + self._maybe_forward_remote_audio(track, publication, participant, "scan") + + def _extract_opus_payload(self, message: bytes) -> bytes: + if self.protocol_version == 2: + if len(message) < 16: + raise ValueError(f"version 2 音频包过短: {len(message)} bytes") + version, msg_type, _reserved, _timestamp, payload_size = struct.unpack( + "!HHIII", message[:16] + ) + if version != 2: + raise ValueError(f"version 2 包头版本异常: {version}") + if msg_type != 0: + raise ValueError(f"version 2 音频包类型异常: {msg_type}") + if len(message) < 16 + payload_size: + raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}") + return message[16:16 + payload_size] + + if self.protocol_version == 3: + if len(message) < 4: + raise ValueError(f"version 3 音频包过短: {len(message)} bytes") + msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4]) + if msg_type != 0: + raise ValueError(f"version 3 音频包类型异常: {msg_type}") + if len(message) < 4 + payload_size: + raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}") + return message[4:4 + payload_size] + + return message + + def _wrap_opus_payload(self, payload: bytes) -> bytes: + if self.protocol_version == 2: + header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload)) + return header + payload + + if self.protocol_version == 3: + header = struct.pack("!BBH", 0, 0, len(payload)) + return header + payload + + return payload def _wav_writer_loop(self): import wave @@ -71,7 +319,7 @@ class ESP32LiveKitBridge: with wave.open("bridge_debug.wav", "wb") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) # 16-bit - wav_file.setframerate(SAMPLE_RATE) + wav_file.setframerate(INPUT_SAMPLE_RATE) while not self.stop_event.is_set() or not self.audio_queue.empty(): try: # 使用 timeout 避免永久阻塞,以便检查 stop_event @@ -93,13 +341,23 @@ class ESP32LiveKitBridge: def on_connected(): print("✅ 成功连接到 LiveKit 房间") if self.room.remote_participants: + self._scan_existing_remote_audio_tracks() self.agent_ready.set() @self.room.on("participant_connected") def on_participant_connected(p: rtc.RemoteParticipant): print(f"👋 Agent ({p.identity}) 已加入房间") + self._scan_existing_remote_audio_tracks() self.agent_ready.set() + @self.room.on("participant_disconnected") + def on_participant_disconnected(p: rtc.RemoteParticipant): + print(f"👋 远端参与者离开房间: {p.identity}") + self.forwarding_tracks = { + track_sid for track_sid in self.forwarding_tracks + if not track_sid.endswith(f":{p.identity}") + } + @self.room.on("data_received") def on_data_received(data_packet: rtc.DataPacket): identity = data_packet.participant.identity if data_packet.participant else "未知" @@ -145,6 +403,7 @@ class ESP32LiveKitBridge: print(f"已连接到 LiveKit 房间: {self.room.name}") print(f"[livekit] local_identity={self.room.local_participant.identity}") print(f"[livekit] local_sid={self.room.local_participant.sid}") + print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}") # 2. 发布麦克风轨道 (ESP32 -> LiveKit) track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source) @@ -154,11 +413,10 @@ class ESP32LiveKitBridge: # 3. 监听房间内的音频 (LiveKit -> ESP32) @self.room.on("track_subscribed") def on_track_subscribed(track, publication, participant): - if track.kind == rtc.TrackKind.KIND_AUDIO: - print(f"收到音频流: {participant.identity}") - asyncio.create_task(self.forward_audio_to_esp32(rtc.AudioStream(track, sample_rate=SAMPLE_RATE, num_channels=1))) + self._maybe_forward_remote_audio(track, publication, participant, "event") print("等待 agent 加入...") + self._scan_existing_remote_audio_tracks() try: await asyncio.wait_for(self.agent_ready.wait(), timeout=10) print("✅ agent 已就绪") @@ -172,53 +430,121 @@ class ESP32LiveKitBridge: await self.room.disconnect() async def forward_audio_to_esp32(self, audio_stream): - """从 LiveKit 接收音频,通过 WebSocket 发回给 ESP32""" - import opuslib - import json - # 创建下行 Opus 编码器 - encoder = opuslib.Encoder(SAMPLE_RATE, 1, 'voip') - - # 1. 告知 ESP32 开始说话,切换 UI 到“说话中”并准备解码 - if self.esp_ws: - await self.esp_ws.send(json.dumps({"type": "tts", "state": "start"})) + """从 LiveKit 接收音频并发给 ESP32""" + + encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip') + pending_pcm = bytearray() + pre_roll_pcm = bytearray() + pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2 + tts_wav_file = None + tts_wav_path = None + stream_id = self.tts_stream_id + 1 + self.tts_stream_id = stream_id + + if SAVE_AGENT_TTS_WAV: + tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}.wav" + tts_wav_file = wave.open(tts_wav_path, "wb") + tts_wav_file.setnchannels(1) + tts_wav_file.setsampwidth(2) + tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE) + print(f"开始保存 Agent TTS 音频: {tts_wav_path}") try: async for event in audio_stream: - if self.esp_ws: + if stream_id != self.tts_stream_id: + print("检测到更新的 TTS 流,停止旧流转发") + break + + frame = event.frame + pcm_data = frame.data.tobytes() + if tts_wav_file is not None: + tts_wav_file.writeframes(pcm_data) + + has_audible_audio = self._has_audible_audio(pcm_data) + current_frame_buffered = False + if not self.tts_active: + pre_roll_pcm.extend(pcm_data) + current_frame_buffered = True + if len(pre_roll_pcm) > pre_roll_max_bytes: + del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes] + + if not has_audible_audio: + continue + + await self._start_tts() + if not self.tts_active: + continue + + print("检测到可听 TTS 音频,切换到 Speaking 并开始转发") + pending_pcm.extend(pre_roll_pcm) + pre_roll_pcm.clear() + + if has_audible_audio: + self._reset_tts_idle_timer() + + # LiveKit 下行帧长度不一定等于 ESP32 协议里声明的帧长。 + # 这里聚合成固定 20ms PCM,再编码为 Opus 发给设备,减少卡顿。 + if not current_frame_buffered: + pending_pcm.extend(pcm_data) + frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2 # mono 16-bit PCM + + while len(pending_pcm) >= frame_bytes and self.esp_ws and stream_id == self.tts_stream_id: try: - # AudioStream 迭代产生的是 AudioFrameEvent,需要从中提取 frame - frame = event.frame - # 将 PCM 编码为 Opus 才能发给 ESP32 - pcm_data = frame.data.tobytes() - - # 使用当前帧的实际采样数进行编码 - opus_packet = encoder.encode(pcm_data, frame.samples_per_channel) - await self.esp_ws.send(opus_packet) + opus_packet = encoder.encode( + bytes(pending_pcm[:frame_bytes]), + OUTPUT_SAMPLES_PER_OPUS_FRAME, + ) + del pending_pcm[:frame_bytes] + # print(f"发送 TTS 音频包: opus={len(opus_packet)} bytes, protocol={self.protocol_version}") + await self.esp_ws.send(self._wrap_opus_payload(opus_packet)) except Exception as e: print(f"发送回 ESP32 失败: {e}") + break + + except Exception as e: + print(f"音频流处理错误: {e}") + finally: - # 2. 音频流结束,告知 ESP32 停止说话,切换回聆听或闲置状态 - if self.esp_ws: - await self.esp_ws.send(json.dumps({"type": "tts", "state": "stop"})) + print("🎧 TTS 音频结束") + if tts_wav_file is not None: + tts_wav_file.close() + print(f"Agent TTS 音频已保存: {tts_wav_path}") + if stream_id == self.tts_stream_id and self.tts_idle_task is not None: + self.tts_idle_task.cancel() + self.tts_idle_task = None + if stream_id == self.tts_stream_id: + await self._stop_tts() async def handle_websocket(self, websocket): """处理来自 ESP32 的 WebSocket 连接""" self.esp_ws = websocket + header_version = websocket.request.headers.get("Protocol-Version") + try: + self.protocol_version = int(header_version) if header_version else 1 + except ValueError: + self.protocol_version = 1 print("ESP32 已连接") + print(f"ESP32 协议版本: {self.protocol_version}") + self.tts_stream_id += 1 + self.tts_active = False + if self.tts_idle_task is not None: + self.tts_idle_task.cancel() + self.tts_idle_task = None opus_decoder = None try: + await self.ensure_agent_dispatched() + # 发送 hello 告诉 ESP32 握手成功 hello_msg = { "type": "hello", "transport": "websocket", "audio_params": { "format": "opus", # 明确要求 ESP32 发送 Opus - "sample_rate": SAMPLE_RATE, + "sample_rate": OUTPUT_SAMPLE_RATE, "channels": 1, - "frame_duration": 60 + "frame_duration": OUTPUT_FRAME_DURATION_MS } } - import json await websocket.send(json.dumps(hello_msg)) async for message in websocket: @@ -229,18 +555,14 @@ class ESP32LiveKitBridge: print(f"收到过短的字节消息 ({len(message)} bytes),跳过") continue - # ESP32 默认使用 websocket_protocol version=1 (见 websocket_protocol.cc) - # 这个版本下,没有 4 字节的 header,接收到的就是原生的 Opus 数据帧。 - # 直接丢给 opuslib 解码即可。 - audio_data = message - print(f"收到音频包长度: {len(message)}") + audio_data = self._extract_opus_payload(message) + # print(f"收到音频包长度: {len(message)}, Opus 负载长度: {len(audio_data)}") if audio_data: try: # Create Opus decoder if not exists if opus_decoder is None: - import opuslib - print(f"初始化 Opus 解码器: {SAMPLE_RATE}Hz, mono") - opus_decoder = opuslib.Decoder(SAMPLE_RATE, 1) + print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono") + opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1) # 启动音频保存线程 self.stop_event.clear() @@ -249,8 +571,8 @@ class ESP32LiveKitBridge: thread.start() # Decode Opus packet. - # Frame size for 60ms is SAMPLE_RATE * 0.06 - frame_size = int(SAMPLE_RATE * 0.06) + # ESP32 -> bridge 保持按设备上行配置 60ms 解码。 + frame_size = INPUT_SAMPLES_PER_OPUS_FRAME pcm_bytes = opus_decoder.decode(audio_data, frame_size) # 将音频数据放入队列由后台线程保存 @@ -260,20 +582,26 @@ class ESP32LiveKitBridge: if num_samples > 0: try: # Use the more robust frame creation logic from test_client_wav.py - frame = AudioFrame(pcm_bytes, SAMPLE_RATE, 1, num_samples) + frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples) await self.mic_source.capture_frame(frame) except TypeError: # Fallback for different SDK versions - frame = AudioFrame.create(sample_rate=SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples) + frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples) memoryview(frame.data).cast('B')[:] = pcm_bytes await self.mic_source.capture_frame(frame) except Exception as e: print(f"Opus audio decode error ({len(message)} bytes): {e}") elif isinstance(message, str): - import json try: data = json.loads(message) print(f"收到 ESP32 JSON 消息: {data}") + msg_type = data.get("type") + if msg_type == "abort": + reason = data.get("reason") + if reason is None: + await self._abort_tts("button_abort") + else: + print(f"忽略非按键打断: reason={reason}") except json.JSONDecodeError: print(f"收到未知的字符消息: {message}") except ConnectionClosedError as e: @@ -283,6 +611,11 @@ class ESP32LiveKitBridge: finally: print("ESP32 断开连接") self.esp_ws = None + self.tts_stream_id += 1 + self.tts_active = False + if self.tts_idle_task is not None: + self.tts_idle_task.cancel() + self.tts_idle_task = None if hasattr(self, "wav_writer_thread") and self.wav_writer_thread: self.stop_event.set() # 我们不一定需要 join,因为是 daemon=True