"""Bridge between ESP32 devices (WebSocket + Opus) and per-device LiveKit rooms.

Each ESP32 WebSocket connection gets its own LiveKit room. Uplink Opus audio
from the device is decoded and published as a microphone track. Remote audio
(the agent's TTS) is re-encoded to Opus and streamed back to the device,
together with JSON ``tts`` / ``stt`` state and subtitle messages.
"""

import asyncio
import json
import os
import shutil
import struct
import sys
import time
import traceback
import uuid
from array import array
from dataclasses import dataclass, field
from typing import Any, Coroutine, Optional

import httpx
import opuslib
import websockets
from livekit import rtc
from livekit.rtc import AudioFrame, AudioSource
from websockets.exceptions import ConnectionClosedError

try:
    from livekit import api as livekit_api
except ImportError:  # Optional: without it, bridge-mode dispatch falls back to the `lk` CLI.
    livekit_api = None

TOKEN_URL = "http://172.19.0.240:8000/getToken"
LIVEKIT_WS_URL = "ws://172.19.0.240:7880"
ROOM_PREFIX = "test-livekit"
IDENTITY_PREFIX = "uv-livekit"
AGENT_NAME = "my-agent"
CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0"))
AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0"))
WS_PORT = 8080
# "token": agent_name is passed to the token endpoint.
# "bridge": this process creates the agent dispatch itself (SDK, then `lk` CLI).
AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").lower()
# Set BRIDGE_DEBUG_LOG=1 to enable verbose diagnostic output.
DEBUG_LOGGING = os.getenv("BRIDGE_DEBUG_LOG", "0").strip().lower() in {"1", "true", "yes", "on"}

INPUT_SAMPLE_RATE = 16000  # Hz, ESP32 uplink (Opus decode rate)
OUTPUT_SAMPLE_RATE = 24000  # Hz, downlink to the ESP32 (Opus encode rate)
INPUT_FRAME_DURATION_MS = 20
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
# Decode buffer size for one uplink packet: 60 ms worth of samples.
INPUT_MAX_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * 60 // 1000
OUTPUT_FRAME_DURATION_MS = 20
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000

TTS_IDLE_TIMEOUT_SECONDS = 0.25
TTS_SILENCE_PEAK_THRESHOLD = 96  # int16 peak below which a frame counts as silence
TTS_PRE_ROLL_MS = 80
TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = 1
TTS_INTERRUPT_SILENCE_FRAMES = 3
INTERRUPT_TOPIC = "lk.interrupt"
# Sentence terminators used to pick the current subtitle sentence:
# full-width 。!?; plus their ASCII counterparts !?;.
TTS_DISPLAY_SENTENCE_BREAKS = "\u3002\uff01\uff1f!?\uff1b;"
TTS_DISPLAY_SCROLL_WIDTH = int(os.getenv("TTS_DISPLAY_SCROLL_WIDTH", "18"))
TTS_DISPLAY_SCROLL_INTERVAL_SECONDS = float(os.getenv("TTS_DISPLAY_SCROLL_INTERVAL_SECONDS", "0.18"))
TTS_DISPLAY_SCROLL_GAP = " "


def _debug(message: str) -> None:
    """Print *message* only when BRIDGE_DEBUG_LOG is enabled."""
    if DEBUG_LOGGING:
        print(message)


@dataclass
class DeviceSession:
    """Per-connection state shared by the WebSocket handler and LiveKit callbacks."""

    device_id: str
    websocket: Any  # set to None once the session is closed
    # 2 and 3 use framed binary audio (see _extract_opus_payload); other values carry raw Opus.
    protocol_version: int
    room_name: str
    identity: str
    room: rtc.Room
    mic_source: AudioSource  # receives decoded ESP32 uplink PCM
    agent_ready: asyncio.Event  # set once an agent participant is seen in the room
    # Remote-audio forwarding tasks, keyed by track sid (or "audio:<identity>" fallback).
    forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)
    tts_active: bool = False  # True between the "start" and "stop" tts messages sent to the device
    tts_idle_task: Optional[asyncio.Task] = None
    tts_display_task: Optional[asyncio.Task] = None
    # Bumped on interrupt/close; tasks bound to an older id stop acting on the session.
    tts_stream_id: int = 0
    tts_transcript_text: str = ""  # last agent subtitle already handled (used for dedupe)
    tts_display_text: str = ""
    tts_display_final: bool = False
    agent_dispatch_task: Optional[asyncio.Task] = None
    closed: bool = False
    captured_frame_count: int = 0
    first_capture_log_time: float = 0.0  # monotonic time of the last uplink debug log


async def fetch_token(room_name: str, identity: str) -> str:
    """Fetch a LiveKit access token for *identity* in *room_name* from TOKEN_URL.

    In "token" dispatch mode the agent name is sent too (presumably so the token
    service can attach an agent dispatch — confirm against the token server).

    Raises:
        httpx.HTTPStatusError: the token endpoint returned a non-2xx status.
        ValueError: the response is not a JSON object with a non-empty "token".
    """
    params = {"room": room_name, "identity": identity}
    if AGENT_DISPATCH_MODE == "token":
        params["agent_name"] = AGENT_NAME
    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        response = await client.get(TOKEN_URL, params=params)
        response.raise_for_status()
        payload: Any = response.json()
    token = payload.get("token") if isinstance(payload, dict) else None
    if not isinstance(token, str) or not token:
        raise ValueError(f"token response missing token field: {payload}")
    _debug(f"[token] room={payload.get('room')} identity={payload.get('identity')} len={len(token)}")
    return token


class ESP32LiveKitBridge:
    """Owns one DeviceSession per connected ESP32 and relays audio and events."""

    def __init__(self) -> None:
        self.device_sessions: dict[str, DeviceSession] = {}
        # Strong references to fire-and-forget tasks: the event loop only keeps weak
        # references, so an unreferenced task can be garbage-collected mid-run.
        self._background_tasks: set[asyncio.Task[Any]] = set()

    # ------------------------------------------------------------------ helpers

    def _spawn(self, coro: Coroutine[Any, Any, Any], label: str) -> asyncio.Task[Any]:
        """Start *coro* as a tracked background task whose failure is logged."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(lambda done: self._on_background_task_done(done, label))
        return task

    def _on_background_task_done(self, task: asyncio.Task[Any], label: str) -> None:
        """Release a finished background task and report its exception, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            print(f"background task failed: {label} error={exc!r}")

    def _log_exception(self, prefix: str, exc: BaseException) -> None:
        """Print the exception type/repr followed by its formatted traceback."""
        print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
        formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        if formatted_tb.strip():
            print(formatted_tb.rstrip())

    def _build_device_id(self, websocket: Any) -> str:
        """Use the X-Device-Id / Device-Id header, else an id derived from the socket object."""
        headers = websocket.request.headers
        requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
        if requested_id:
            return requested_id
        return f"ws-{id(websocket):x}"

    def _build_session_names(self, device_id: str) -> tuple[str, str]:
        """Return a fresh (room_name, identity) pair sharing one random 8-hex-char tag."""
        session_tag = uuid.uuid4().hex[:8]
        room_name = f"{ROOM_PREFIX}-{session_tag}"
        identity = f"{IDENTITY_PREFIX}-{session_tag}"
        print(f"[session] device={device_id} room={room_name} identity={identity}")
        return room_name, identity

    def _is_agent_participant(self, participant: rtc.RemoteParticipant) -> bool:
        """Heuristic: identities starting with "agent-" or containing AGENT_NAME are agents."""
        identity = getattr(participant, "identity", "") or ""
        return identity.startswith("agent-") or AGENT_NAME in identity

    def _get_agent_identities(self, session: DeviceSession) -> list[str]:
        """Identities of all remote participants currently classified as agents."""
        return [
            participant.identity
            for participant in session.room.remote_participants.values()
            if self._is_agent_participant(participant)
        ]

    def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
        """Debug-log which agents are in the room at checkpoint *source*."""
        if not DEBUG_LOGGING:
            return
        agent_identities = self._get_agent_identities(session)
        print(
            f"[agent-check] source={source} room={session.room_name} "
            f"agent_count={len(agent_identities)} agents={agent_identities}"
        )

    def _dispatch_metadata(self, session: DeviceSession) -> str:
        """JSON metadata attached to agent dispatches created by this bridge."""
        return json.dumps(
            {
                "source": "bridge_server",
                "identity": session.identity,
                "device_id": session.device_id,
            }
        )

    # --------------------------------------------------------- agent dispatch

    async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
        """Return True if the LiveKit API lists a matching dispatch for this room.

        Returns False when livekit-api, credentials or ListAgentDispatchRequest are
        unavailable, or when the query fails (the failure is printed).
        """
        if livekit_api is None:
            return False
        api_key = os.getenv("LIVEKIT_API_KEY")
        api_secret = os.getenv("LIVEKIT_API_SECRET")
        if not api_key or not api_secret:
            return False
        request_cls = getattr(livekit_api, "ListAgentDispatchRequest", None)
        if request_cls is None:
            return False
        lkapi = livekit_api.LiveKitAPI(url=LIVEKIT_WS_URL, api_key=api_key, api_secret=api_secret)
        try:
            response = await lkapi.agent_dispatch.list_dispatch(request_cls(room=session.room_name))
            # The response field name differs between SDK versions.
            dispatches = (
                getattr(response, "agent_dispatches", None)
                or getattr(response, "dispatches", None)
                or getattr(response, "items", None)
                or []
            )
            for dispatch in dispatches:
                dispatch_agent_name = getattr(dispatch, "agent_name", None)
                dispatch_room = getattr(dispatch, "room", None)
                if dispatch_room == session.room_name and (
                    dispatch_agent_name == AGENT_NAME or dispatch_agent_name is None
                ):
                    print(
                        f"检测到已有 dispatch: room={session.room_name} "
                        f"agent={dispatch_agent_name}"
                    )
                    return True
        except Exception as exc:
            print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
        finally:
            await lkapi.aclose()
        return False

    async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
        """In "bridge" mode, dispatch the agent unless one exists, is present, or is pending."""
        if AGENT_DISPATCH_MODE != "bridge":
            _debug(f"跳过 bridge 手动 dispatch: mode={AGENT_DISPATCH_MODE}")
            return
        if await self._has_existing_dispatch(session):
            return
        for participant in session.room.remote_participants.values():
            if self._is_agent_participant(participant):
                _debug(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
                return
        if session.agent_dispatch_task is not None and not session.agent_dispatch_task.done():
            _debug("Agent dispatch 正在进行中,跳过重复请求")
            return
        session.agent_dispatch_task = asyncio.create_task(self._dispatch_agent(session))
        await session.agent_dispatch_task

    async def _dispatch_agent(self, session: DeviceSession) -> None:
        """Create a dispatch via the SDK, falling back to the `lk` CLI; errors are printed."""
        print(f"准备 dispatch agent: room={session.room_name}, agent={AGENT_NAME}")
        try:
            if await self._dispatch_agent_with_sdk(session):
                return
            if await self._dispatch_agent_with_cli(session):
                return
            print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
        except Exception as exc:
            print(f"Agent dispatch 失败: {exc}")

    async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
        """Create the dispatch with livekit-api; False if the SDK or credentials are missing."""
        if livekit_api is None:
            return False
        api_key = os.getenv("LIVEKIT_API_KEY")
        api_secret = os.getenv("LIVEKIT_API_SECRET")
        if not api_key or not api_secret:
            return False
        lkapi = livekit_api.LiveKitAPI(url=LIVEKIT_WS_URL, api_key=api_key, api_secret=api_secret)
        try:
            dispatch = await lkapi.agent_dispatch.create_dispatch(
                livekit_api.CreateAgentDispatchRequest(
                    agent_name=AGENT_NAME,
                    room=session.room_name,
                    metadata=self._dispatch_metadata(session),
                )
            )
            print(f"Agent dispatch 已创建: {dispatch}")
        finally:
            await lkapi.aclose()
        return True

    async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
        """Create the dispatch with `lk dispatch create`; False if `lk` is missing or fails."""
        lk_path = shutil.which("lk")
        if lk_path is None:
            return False
        # Argument list (no shell), so room/device values cannot inject shell syntax.
        process = await asyncio.create_subprocess_exec(
            lk_path,
            "dispatch",
            "create",
            "--room",
            session.room_name,
            "--agent-name",
            AGENT_NAME,
            "--metadata",
            self._dispatch_metadata(session),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        stdout_text = stdout.decode("utf-8", errors="replace").strip()
        stderr_text = stderr.decode("utf-8", errors="replace").strip()
        if stdout_text:
            print(f"lk dispatch 输出: {stdout_text}")
        if stderr_text:
            print(f"lk dispatch 错误输出: {stderr_text}")
        if process.returncode != 0:
            print(f"lk dispatch create 失败,退出码: {process.returncode}")
            return False
        print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={AGENT_NAME}")
        return True

    # ------------------------------------------------------ agent control / TTS

    async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
        """Publish *payload* as JSON data to the agents (or broadcast if none are known).

        Tries publish_data with INTERRUPT_TOPIC first; if that raises TypeError
        (an SDK without the keyword), retries without the topic.
        Returns True if a publish succeeded.
        """
        participant = getattr(session.room, "local_participant", None)
        if participant is None:
            print("跳过发送 agent 控制事件,local participant 尚未就绪")
            return False
        data = json.dumps(payload).encode("utf-8")
        agent_identities = self._get_agent_identities(session)
        kwargs: dict[str, Any] = {}
        if agent_identities:
            kwargs["destination_identities"] = agent_identities
        last_error: Optional[Exception] = None
        for attempt in ({"topic": INTERRUPT_TOPIC, **kwargs}, kwargs):
            try:
                await participant.publish_data(data, **attempt)
                print(
                    f"已发送 agent 控制事件: topic={attempt.get('topic', '')} "
                    f"targets={agent_identities or 'broadcast'} payload={payload}"
                )
                return True
            except TypeError as exc:
                last_error = exc
            except Exception as exc:
                print(f"发送 agent 控制事件失败: {exc}")
                return False
        if last_error is not None:
            print(f"发送 agent 控制事件失败,publish_data 签名不兼容: {last_error}")
        return False

    async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
        """Tell the agent to stop speaking; only warns if the event could not be sent."""
        payload = {
            "type": "interrupt",
            "topic": INTERRUPT_TOPIC,
            "reason": reason,
            "room": session.room_name,
            "identity": session.identity,
            "device_id": session.device_id,
        }
        ok = await self._publish_agent_event(session, payload)
        if not ok:
            print("警告: bridge 已停止 TTS,但 agent 侧 interrupt 未确认送出")

    async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
        """Send {"type": "tts", "state": state} to the device, if it is connected."""
        if session.websocket is None:
            print(f"跳过 tts {state},ESP32 尚未连接")
            return
        await session.websocket.send(json.dumps({"type": "tts", "state": state}))
        print(f"已发送 tts {state}: device={session.device_id}")

    async def _send_tts_text(self, session: DeviceSession, text: str, final: bool) -> None:
        """Send one subtitle line to the device as a "sentence_start" tts message."""
        if session.websocket is None:
            return
        await session.websocket.send(
            json.dumps(
                {
                    "type": "tts",
                    "state": "sentence_start",
                    "text": text,
                    "final": final,
                }
            )
        )

    def _cancel_tts_display_task(self, session: DeviceSession) -> None:
        """Cancel and forget the subtitle scroll task, if any."""
        if session.tts_display_task is not None:
            session.tts_display_task.cancel()
            session.tts_display_task = None

    def _update_tts_display_text(self, session: DeviceSession, text: str, final: bool) -> None:
        """Show *text* on the device: send it directly if short, otherwise scroll it."""
        session.tts_display_text = text
        session.tts_display_final = final
        if len(text) <= TTS_DISPLAY_SCROLL_WIDTH:
            self._cancel_tts_display_task(session)
            self._spawn(self._send_tts_text(session, text, final), "send tts text")
            return
        # A running scroller re-reads tts_display_text on each tick, so one task is enough.
        if session.tts_display_task is None or session.tts_display_task.done():
            session.tts_display_task = asyncio.create_task(
                self._scroll_tts_display_text(session, session.tts_stream_id)
            )

    async def _scroll_tts_display_text(self, session: DeviceSession, stream_id: int) -> None:
        """Marquee-scroll session.tts_display_text in TTS_DISPLAY_SCROLL_WIDTH windows.

        Runs until the stream id changes, the device disconnects, the text becomes
        empty or short enough to show whole, or the task is cancelled.
        """
        offset = 0
        last_sent = ""
        try:
            while stream_id == session.tts_stream_id and session.websocket is not None:
                text = session.tts_display_text
                if not text:
                    return
                if len(text) <= TTS_DISPLAY_SCROLL_WIDTH:
                    if text != last_sent:
                        await self._send_tts_text(session, text, session.tts_display_final)
                    return
                scroll_text = text + TTS_DISPLAY_SCROLL_GAP
                if offset >= len(scroll_text):
                    offset = 0
                # Append the head again so the window wraps around seamlessly.
                looped_text = scroll_text + scroll_text[:TTS_DISPLAY_SCROLL_WIDTH]
                visible_text = looped_text[offset:offset + TTS_DISPLAY_SCROLL_WIDTH].rstrip()
                if visible_text and visible_text != last_sent:
                    await self._send_tts_text(session, visible_text, session.tts_display_final)
                    last_sent = visible_text
                offset += 1
                await asyncio.sleep(TTS_DISPLAY_SCROLL_INTERVAL_SECONDS)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            print(f"TTS 字幕滚动失败: {exc}")

    async def _start_tts(self, session: DeviceSession) -> None:
        """Switch the device to speaking mode (sends tts "start") unless already active."""
        if session.tts_active:
            print("跳过 tts start,当前已处于激活状态")
            return
        # Keep subtitles that arrived before the audio; otherwise reset subtitle state.
        if not session.tts_display_text:
            session.tts_transcript_text = ""
            session.tts_display_final = False
            self._cancel_tts_display_task(session)
        # Flip the flag before awaiting so concurrent callers cannot send "start" twice.
        session.tts_active = True
        try:
            await self._send_tts_state(session, "start")
        except Exception:
            session.tts_active = False
            raise

    async def _stop_tts(self, session: DeviceSession) -> None:
        """Switch the device back to listening (sends tts "stop") and clear subtitle state."""
        if not session.tts_active:
            print("跳过 tts stop,当前未激活")
            return
        # Update all state before awaiting so concurrent callers cannot send "stop" twice.
        session.tts_active = False
        session.tts_transcript_text = ""
        session.tts_display_text = ""
        session.tts_display_final = False
        self._cancel_tts_display_task(session)
        await self._send_tts_state(session, "stop")

    async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
        """Handle a device barge-in: invalidate the stream, notify the agent, stop TTS."""
        print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")
        session.tts_stream_id += 1
        if session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
            session.tts_idle_task = None
        self._cancel_tts_display_task(session)
        await self._send_agent_interrupt(session, reason)
        await self._stop_tts(session)

    def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
        """Restart the watchdog that stops TTS after TTS_IDLE_TIMEOUT_SECONDS of no audible audio."""
        if session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
        session.tts_idle_task = asyncio.create_task(
            self._tts_idle_watchdog(session, session.tts_stream_id)
        )

    async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
        """Sleep for the idle timeout, then stop TTS if the stream id is unchanged."""
        try:
            await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
            if stream_id != session.tts_stream_id:
                return
            print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
            await self._stop_tts(session)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            print(f"TTS idle watchdog failed to stop TTS: {exc}")

    def _has_audible_audio(self, pcm_data: bytes) -> bool:
        """Return True if the peak of little-endian int16 PCM reaches TTS_SILENCE_PEAK_THRESHOLD."""
        sample_count = len(pcm_data) // 2
        if sample_count == 0:
            return False
        samples = array("h")
        samples.frombytes(pcm_data[: sample_count * 2])  # ignore a trailing odd byte
        if sys.byteorder != "little":
            samples.byteswap()
        # max/min run in C instead of a per-sample Python loop.
        peak = max(max(samples), -min(samples))
        return peak >= TTS_SILENCE_PEAK_THRESHOLD

    # ------------------------------------------------------------ remote audio

    def _maybe_forward_remote_audio(
        self,
        session: DeviceSession,
        track: rtc.Track,
        publication: Optional[rtc.TrackPublication],
        participant: rtc.RemoteParticipant,
        source: str,
    ) -> None:
        """Start one forward_audio_to_esp32 task per remote audio track (deduplicated by sid)."""
        if session.closed or track.kind != rtc.TrackKind.KIND_AUDIO:
            return
        track_sid = (
            getattr(track, "sid", None)
            or getattr(publication, "sid", None)
            or f"audio:{participant.identity}"
        )
        existing_task = session.forwarding_tracks.get(track_sid)
        if existing_task is not None and not existing_task.done():
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return
        if existing_task is not None and existing_task.done():
            print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
            session.forwarding_tracks.pop(track_sid, None)
        task = asyncio.create_task(
            self.forward_audio_to_esp32(
                session,
                track_sid,
                rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1),
            )
        )
        session.forwarding_tracks[track_sid] = task
        print(
            f"收到音频流: {participant.identity} sid={track_sid} "
            f"source={source} room={session.room_name}"
        )

    def _scan_participant_audio_tracks(
        self,
        session: DeviceSession,
        participant: rtc.RemoteParticipant,
        source: str,
    ) -> None:
        """Forward every already-available track of *participant* (non-audio ones are ignored)."""
        publications = getattr(participant, "track_publications", None) or {}
        for publication in publications.values():
            track = getattr(publication, "track", None)
            if track is None:
                continue
            self._maybe_forward_remote_audio(session, track, publication, participant, source)

    # ------------------------------------------------------------ wire format

    def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
        """Strip the binary protocol header and return the Opus payload.

        Version 2: 16-byte big-endian header (version u16, type u16, reserved u32,
        timestamp u32, payload_size u32). Version 3: 4-byte header (type u8,
        reserved u8, payload_size u16). Any other version is raw Opus.

        Raises:
            ValueError: the header is short, has the wrong version/type, or the
                declared payload size exceeds the message.
        """
        if session.protocol_version == 2:
            if len(message) < 16:
                raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
            version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
                "!HHIII", message[:16]
            )
            if version != 2:
                raise ValueError(f"version 2 包头版本异常: {version}")
            if msg_type != 0:
                raise ValueError(f"version 2 音频包类型异常: {msg_type}")
            if len(message) < 16 + payload_size:
                raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
            return message[16:16 + payload_size]
        if session.protocol_version == 3:
            if len(message) < 4:
                raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
            msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
            if msg_type != 0:
                raise ValueError(f"version 3 音频包类型异常: {msg_type}")
            if len(message) < 4 + payload_size:
                raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
            return message[4:4 + payload_size]
        return message

    def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
        """Inverse of _extract_opus_payload: prefix the header for the session's protocol."""
        if session.protocol_version == 2:
            return struct.pack("!HHIII", 2, 0, 0, 0, len(payload)) + payload
        if session.protocol_version == 3:
            return struct.pack("!BBH", 0, 0, len(payload)) + payload
        return payload

    def _current_tts_display_text(self, text: str) -> str:
        """Pick the sentence to show from a cumulative transcript.

        Whitespace is collapsed. If the text ends with a sentence break, the last
        complete sentence is returned; otherwise the trailing in-progress fragment.
        Falls back to the whole normalized text if the pick is empty.
        """
        normalized = " ".join(text.split())
        if not normalized:
            return ""
        last_break = -1
        previous_break = -1
        for index, char in enumerate(normalized):
            if char in TTS_DISPLAY_SENTENCE_BREAKS:
                previous_break = last_break
                last_break = index
        if last_break == -1:
            return normalized
        if last_break == len(normalized) - 1:
            start = previous_break + 1
        else:
            start = last_break + 1
        return normalized[start:].strip() or normalized

    # ------------------------------------------------------------ room events

    def _register_room_handlers(self, session: DeviceSession) -> None:
        """Attach LiveKit room event handlers that route events into *session*."""
        room = session.room

        @room.on("connection_state_changed")
        def on_connection_state_changed(state: int) -> None:
            if DEBUG_LOGGING:
                print(f"[livekit] room={session.room_name} state={rtc.ConnectionState.Name(state)}")

        @room.on("connected")
        def on_connected() -> None:
            print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
            self._log_agent_participants(session, "connected")
            for participant in room.remote_participants.values():
                if self._is_agent_participant(participant):
                    session.agent_ready.set()
                self._scan_participant_audio_tracks(session, participant, "connected_scan")

        @room.on("participant_connected")
        def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
            is_agent = self._is_agent_participant(participant)
            role = "Agent" if is_agent else "Remote participant"
            print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
            self._log_agent_participants(session, "participant_connected")
            if is_agent:
                session.agent_ready.set()
            self._scan_participant_audio_tracks(session, participant, "participant_connected_scan")

        @room.on("participant_disconnected")
        def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
            print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
            # Only fallback keys ("audio:<identity>") identify the participant. Cancel
            # those tasks instead of merely dropping references to still-running tasks.
            # NOTE(review): tasks keyed by a real sid presumably end when their
            # AudioStream ends — confirm against the livekit SDK.
            suffix = f":{participant.identity}"
            for track_sid in [sid for sid in session.forwarding_tracks if sid.endswith(suffix)]:
                session.forwarding_tracks.pop(track_sid).cancel()

        @room.on("data_received")
        def on_data_received(data_packet: rtc.DataPacket) -> None:
            identity = data_packet.participant.identity if data_packet.participant else "未知"
            text = data_packet.data.decode("utf-8", errors="replace")
            print(f"📩 [数据接收 | room={session.room_name} | {identity}]: {text}")

        @room.on("transcription_received")
        def on_transcription_received(
            segments: list[rtc.TranscriptionSegment],
            participant: rtc.Participant,
            track_pub: rtc.TrackPublication,
        ) -> None:
            identity = participant.identity if participant else "未知"
            is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(
                participant
            )
            for segment in segments:
                status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
                print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")
                if is_agent:
                    # Agent speech -> subtitles (deduplicated against the last shown text).
                    display_text = self._current_tts_display_text(segment.text)
                    if not display_text or display_text == session.tts_transcript_text:
                        continue
                    session.tts_transcript_text = display_text
                    if session.websocket is not None:
                        self._update_tts_display_text(session, display_text, segment.final)
                elif segment.final:
                    # User speech -> final STT result only.
                    ws = session.websocket
                    if ws is not None:
                        self._spawn(
                            ws.send(
                                json.dumps(
                                    {
                                        "type": "stt",
                                        "text": segment.text,
                                        "final": segment.final,
                                    }
                                )
                            ),
                            "send stt",
                        )

        @room.on("track_subscribed")
        def on_track_subscribed(
            track: rtc.Track,
            publication: rtc.TrackPublication,
            participant: rtc.RemoteParticipant,
        ) -> None:
            self._maybe_forward_remote_audio(session, track, publication, participant, "event")

        @room.on("track_published")
        def on_track_published(
            publication: rtc.RemoteTrackPublication,
            participant: rtc.RemoteParticipant,
        ) -> None:
            _debug(
                f"📡 远端音轨已发布: room={session.room_name} identity={participant.identity} "
                f"track_sid={getattr(publication, 'sid', None)}"
            )
            track = getattr(publication, "track", None)
            if track is not None:
                self._maybe_forward_remote_audio(session, track, publication, participant, "published")

    async def _connect_session_room(self, session: DeviceSession) -> None:
        """Join the LiveKit room, ensure an agent, publish the mic track, wait for the agent.

        Waiting for the agent gives up after AGENT_READY_TIMEOUT_SECONDS with a warning.
        Token and connection errors propagate to the caller.
        """
        self._register_room_handlers(session)
        token = await fetch_token(session.room_name, session.identity)
        try:
            await session.room.connect(
                LIVEKIT_WS_URL,
                token,
                options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
            )
        except Exception as exc:
            self._log_exception(f"连接 LiveKit 房间失败: room={session.room_name}", exc)
            print(
                "提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
                "请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
                "*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
            )
            raise
        print(f"已连接到 LiveKit 房间: {session.room.name}")
        if DEBUG_LOGGING:
            local = session.room.local_participant
            print(
                f"[livekit] local_identity={local.identity} local_sid={local.sid} "
                f"remote_participants={list(session.room.remote_participants.keys())}"
            )
        self._log_agent_participants(session, "after_connect")
        await self.ensure_agent_dispatched(session)

        track = rtc.LocalAudioTrack.create_audio_track(
            f"esp32-mic-{session.device_id}",
            session.mic_source,
        )
        options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        publication = await session.room.local_participant.publish_track(track, options)
        _debug(
            f"已发布 ESP32 mic track: room={session.room_name} "
            f"track_sid={getattr(track, 'sid', None)} "
            f"publication_sid={getattr(publication, 'sid', None)}"
        )
        self._log_agent_participants(session, "after_publish_mic")

        _debug(f"等待 agent 加入: room={session.room_name}")
        try:
            await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
            _debug(f"✅ agent 已就绪: room={session.room_name}")
        except asyncio.TimeoutError:
            print(f"⚠️ agent 等待超时: room={session.room_name}")

    # ------------------------------------------------------------ lifecycle

    async def start(self) -> None:
        """Print the effective configuration."""
        print(f"[config] websocket_port={WS_PORT}")
        print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
        print(f"[config] token_url={TOKEN_URL}")
        print(f"[config] agent_dispatch_mode={AGENT_DISPATCH_MODE}")

    async def close(self) -> None:
        """Close every live session and cancel outstanding background tasks."""
        for session in list(self.device_sessions.values()):
            await self._close_session(session)
        for task in list(self._background_tasks):
            task.cancel()

    async def _close_session(self, session: DeviceSession) -> None:
        """Idempotently tear down *session*: stop tasks, invalidate TTS, leave the room."""
        if session.closed:
            return
        session.closed = True
        session.websocket = None
        session.tts_active = False
        session.tts_stream_id += 1  # makes every stream-bound task stand down
        if session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
            session.tts_idle_task = None
        self._cancel_tts_display_task(session)
        for task in session.forwarding_tracks.values():
            task.cancel()
        session.forwarding_tracks.clear()
        try:
            await session.room.disconnect()
        except Exception as exc:
            print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")

    async def forward_audio_to_esp32(
        self,
        session: DeviceSession,
        track_sid: str,
        audio_stream: rtc.AudioStream,
    ) -> None:
        """Re-encode a remote audio track to Opus and stream it to the device.

        - While TTS is inactive, frames feed a TTS_PRE_ROLL_MS pre-roll buffer.
          After TTS_START_CONSECUTIVE_AUDIBLE_FRAMES audible frames, TTS starts
          and the pre-roll is flushed first.
        - Each audible frame re-arms the idle watchdog.
        - When the stream id changes (interrupt), buffers are dropped and playback
          stays muted until TTS_INTERRUPT_SILENCE_FRAMES silent frames are seen.
        """
        encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
        frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2  # int16 mono
        pending_pcm = bytearray()
        pre_roll_pcm = bytearray()
        pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
        audible_frame_streak = 0
        silence_frame_streak = 0
        waiting_for_post_interrupt_silence = False
        stream_id = session.tts_stream_id
        print(
            f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
            f"track_sid={track_sid} stream_id={stream_id}"
        )
        try:
            async for event in audio_stream:
                if stream_id != session.tts_stream_id:
                    print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
                    pending_pcm.clear()
                    pre_roll_pcm.clear()
                    audible_frame_streak = 0
                    silence_frame_streak = 0
                    waiting_for_post_interrupt_silence = True
                    stream_id = session.tts_stream_id
                    if session.tts_active:
                        await self._stop_tts(session)
                    continue

                pcm_data = event.frame.data.tobytes()
                has_audible_audio = self._has_audible_audio(pcm_data)

                if waiting_for_post_interrupt_silence:
                    # Drop the tail of the interrupted reply until a run of silence.
                    if has_audible_audio:
                        silence_frame_streak = 0
                        continue
                    silence_frame_streak += 1
                    if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
                        continue
                    print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
                    waiting_for_post_interrupt_silence = False
                    silence_frame_streak = 0
                    stream_id = session.tts_stream_id
                    continue

                current_frame_buffered = False
                if not session.tts_active:
                    pre_roll_pcm.extend(pcm_data)
                    current_frame_buffered = True
                    if len(pre_roll_pcm) > pre_roll_max_bytes:
                        del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]
                    if not has_audible_audio:
                        audible_frame_streak = 0
                        continue
                    audible_frame_streak += 1
                    if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
                        continue
                    await self._start_tts(session)
                    if not session.tts_active:
                        continue
                    print(
                        "检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
                        f"(frames={audible_frame_streak})"
                    )
                    # Drop sub-frame leftovers from the previous utterance, then
                    # lead with the pre-roll (which already contains this frame).
                    pending_pcm.clear()
                    pending_pcm.extend(pre_roll_pcm)
                    pre_roll_pcm.clear()
                    audible_frame_streak = 0

                if has_audible_audio:
                    self._reset_tts_idle_timer(session)
                if not current_frame_buffered:
                    pending_pcm.extend(pcm_data)

                while (
                    len(pending_pcm) >= frame_bytes
                    and stream_id == session.tts_stream_id
                    and session.websocket is not None
                ):
                    try:
                        opus_packet = encoder.encode(
                            bytes(pending_pcm[:frame_bytes]),
                            OUTPUT_SAMPLES_PER_OPUS_FRAME,
                        )
                        del pending_pcm[:frame_bytes]
                        await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
                    except Exception as exc:
                        print(f"发送回 ESP32 失败: {exc}")
                        break
        except Exception as exc:
            print(f"音频流处理错误: {exc}")
        finally:
            print("🎧 TTS 音频结束")
            if session.forwarding_tracks.get(track_sid) is asyncio.current_task():
                session.forwarding_tracks.pop(track_sid, None)
            if stream_id == session.tts_stream_id:
                if session.tts_idle_task is not None:
                    session.tts_idle_task.cancel()
                    session.tts_idle_task = None
                try:
                    await self._stop_tts(session)
                except Exception as exc:
                    print(f"failed to stop TTS after stream end: {exc}")
            try:
                await audio_stream.aclose()  # release the native stream resources
            except Exception as exc:
                print(f"failed to close audio stream: sid={track_sid} error={exc}")

    # ------------------------------------------------------------ websocket

    async def _capture_mic_pcm(self, session: DeviceSession, pcm_bytes: bytes) -> None:
        """Push decoded int16 mono PCM into the session's LiveKit mic source."""
        num_samples = len(pcm_bytes) // 2
        if num_samples <= 0:
            return
        session.captured_frame_count += 1
        now = time.monotonic()
        if session.captured_frame_count <= 5 or now - session.first_capture_log_time >= 5.0:
            session.first_capture_log_time = now
            _debug(
                f"[uplink] capture_frame count={session.captured_frame_count} "
                f"bytes={len(pcm_bytes)} samples={num_samples} "
                f"room={session.room_name}"
            )
        try:
            frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
            await session.mic_source.capture_frame(frame)
        except TypeError:
            # SDK variant without that constructor: allocate a frame and copy into it.
            frame = AudioFrame.create(
                sample_rate=INPUT_SAMPLE_RATE,
                num_channels=1,
                samples_per_channel=num_samples,
            )
            memoryview(frame.data).cast("B")[:] = pcm_bytes
            await session.mic_source.capture_frame(frame)

    async def _handle_audio_message(
        self,
        session: DeviceSession,
        message: bytes,
        opus_decoder: Any,
    ) -> Any:
        """Decode one binary uplink message into the mic source.

        Malformed packets are logged and dropped instead of ending the session.
        Returns the (lazily created) Opus decoder for reuse by the caller.
        """
        if len(message) < 4:
            print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
            return opus_decoder
        try:
            audio_data = self._extract_opus_payload(session, message)
        except ValueError as exc:
            print(f"dropping malformed audio packet: {exc}")
            return opus_decoder
        if not audio_data:
            return opus_decoder
        try:
            if opus_decoder is None:
                _debug(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
                opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
            pcm_bytes = opus_decoder.decode(audio_data, INPUT_MAX_SAMPLES_PER_OPUS_FRAME)
            await self._capture_mic_pcm(session, pcm_bytes)
        except Exception as exc:
            print(f"Opus audio decode error ({len(message)} bytes): {exc}")
        return opus_decoder

    async def _handle_text_message(self, session: DeviceSession, message: str) -> None:
        """Handle a JSON control message from the device (currently only "abort")."""
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            print(f"收到未知的字符消息: {message}")
            return
        if not isinstance(data, dict):
            print(f"收到未知的字符消息: {message}")
            return
        _debug(f"收到 ESP32 JSON 消息: {data}")
        if data.get("type") == "abort":
            reason = data.get("reason")
            abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
            print(f"处理 ESP32 打断请求: reason={abort_reason}")
            await self._abort_tts(session, abort_reason)

    async def handle_websocket(self, websocket: Any) -> None:
        """Serve one ESP32 connection for its whole lifetime.

        Replaces any previous session with the same device_id, joins a fresh
        LiveKit room, sends the "hello" handshake, then relays messages until the
        socket closes. The session is always torn down on exit.
        """
        header_version = websocket.request.headers.get("Protocol-Version")
        try:
            protocol_version = int(header_version) if header_version else 1
        except ValueError:
            protocol_version = 1

        device_id = self._build_device_id(websocket)
        existing_session = self.device_sessions.pop(device_id, None)
        if existing_session is not None:
            print(f"检测到重复 device_id,关闭旧 session: device={device_id}")
            old_websocket = existing_session.websocket
            await self._close_session(existing_session)
            # Close the stale socket in the background so its handler exits, without
            # blocking this connection on a possibly dead peer's close handshake.
            if old_websocket is not None and old_websocket is not websocket:
                self._spawn(old_websocket.close(), f"close stale websocket device={device_id}")

        room_name, identity = self._build_session_names(device_id)
        session = DeviceSession(
            device_id=device_id,
            websocket=websocket,
            protocol_version=protocol_version,
            room_name=room_name,
            identity=identity,
            room=rtc.Room(),
            mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
            agent_ready=asyncio.Event(),
        )
        self.device_sessions[device_id] = session
        _debug(f"ESP32 已连接: device={device_id} protocol_version={session.protocol_version}")
        session.tts_stream_id += 1
        opus_decoder = None

        try:
            await self._connect_session_room(session)
            hello_msg = {
                "type": "hello",
                "transport": "websocket",
                "session": {
                    "room": session.room_name,
                    "identity": session.identity,
                },
                "audio_params": {
                    "format": "opus",
                    "sample_rate": OUTPUT_SAMPLE_RATE,
                    "channels": 1,
                    "frame_duration": OUTPUT_FRAME_DURATION_MS,
                },
            }
            await websocket.send(json.dumps(hello_msg))

            async for message in websocket:
                if isinstance(message, bytes):
                    opus_decoder = await self._handle_audio_message(session, message, opus_decoder)
                elif isinstance(message, str):
                    await self._handle_text_message(session, message)
        except ConnectionClosedError as exc:
            print(f"ESP32 异常断开: {exc}")
        except Exception as exc:
            self._log_exception("WebSocket 其他错误", exc)
        finally:
            print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
            await self._close_session(session)
            # A newer connection with the same device_id may already own the slot.
            if self.device_sessions.get(device_id) is session:
                del self.device_sessions[device_id]


async def main() -> None:
    """Run the WebSocket server until cancelled, then close all sessions."""
    bridge = ESP32LiveKitBridge()
    try:
        await bridge.start()
        async with websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT):
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            await asyncio.Future()  # run forever
    finally:
        await bridge.close()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C: asyncio.run() cancels the main task, so main()'s finally still runs.
        pass
    except Exception as exc:
        print(f"[error] {exc}", file=sys.stderr)
        sys.exit(1)