diff --git a/main/bridge_server.py b/main/bridge_server.py index fd0ada1..597f1be 100644 --- a/main/bridge_server.py +++ b/main/bridge_server.py @@ -1,6 +1,5 @@ import asyncio import base64 -import contextlib import json import os import re @@ -12,7 +11,7 @@ import traceback import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional +from typing import Any, Coroutine, Optional import httpx import opuslib @@ -123,6 +122,9 @@ class DeviceSession: last_interrupt_time: float = 0.0 last_uplink_audible_time: float = 0.0 agent_dispatch_task: Optional[asyncio.Task] = None + room_connect_task: Optional[asyncio.Task] = None + background_tasks: set[asyncio.Task[Any]] = field(default_factory=set) + forwarding_track_participants: dict[str, str] = field(default_factory=dict) closed: bool = False captured_frame_count: int = 0 first_capture_log_time: float = 0.0 @@ -209,11 +211,15 @@ class ESP32LiveKitBridge: session: DeviceSession, task: asyncio.Task[Any], ) -> None: + if session.room_connect_task is task: + session.room_connect_task = None if task.cancelled(): return try: task.result() except Exception as exc: + if session.closed: + return self._log_exception( f"LiveKit 房间连接后台任务失败: room={session.room_name}", exc, @@ -222,6 +228,48 @@ class ESP32LiveKitBridge: if websocket is not None: asyncio.create_task(websocket.close(code=1011, reason="livekit connect failed")) + def _create_session_task( + self, + session: DeviceSession, + coroutine: Coroutine[Any, Any, Any], + description: str, + ) -> Optional[asyncio.Task[Any]]: + if session.closed: + coroutine.close() + return None + + task = asyncio.create_task(coroutine) + session.background_tasks.add(task) + task.add_done_callback( + lambda done_task: self._handle_session_task_done( + session, + done_task, + description, + ) + ) + return task + + def _handle_session_task_done( + self, + session: DeviceSession, + task: asyncio.Task[Any], + description: str, + ) -> None: + session.background_tasks.discard(task) + if task.cancelled(): + return + try: + task.result() + except Exception as exc: + if not session.closed: + self._log_exception(f"{description} 失败: room={session.room_name}", exc) + + async def _disconnect_room_quietly(self, session: DeviceSession, reason: str) -> None: + try: + await session.room.disconnect() + except Exception as exc: + print(f"断开 LiveKit 房间失败: room={session.room_name} reason={reason} error={exc}") + async def _capture_mic_frame( self, session: DeviceSession, @@ -523,6 +571,8 @@ class ESP32LiveKitBridge: return False async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None: + if session.closed: + return payload = { "type": "interrupt", "topic": INTERRUPT_TOPIC, @@ -553,6 +603,8 @@ class ESP32LiveKitBridge: return path async def _publish_vision_frame(self, session: DeviceSession, message: dict[str, Any]) -> None: + if session.closed: + return image = message.get("image") if not isinstance(image, str) or not image: print("收到 vision frame,但 image 字段为空") @@ -603,6 +655,8 @@ class ESP32LiveKitBridge: print(f"发送 vision frame 失败,publish_data 签名不兼容: {last_error}") async def _publish_mcp_message(self, session: DeviceSession, message: dict[str, Any]) -> None: + if session.closed: + return payload = message.get("payload") if not isinstance(payload, dict): print(f"收到 ESP32 MCP 消息但缺少 payload: {message}") @@ -652,6 +706,8 @@ class ESP32LiveKitBridge: *, source_identity: str, ) -> None: + if session.closed: + return if session.websocket is None: print("跳过 MCP 请求,ESP32 尚未连接") return @@ -671,6 +727,8 @@ class ESP32LiveKitBridge: ) async def _send_tts_state(self, session: DeviceSession, state: str) -> None: + if session.closed: + return if session.websocket is None: print(f"跳过 tts {state},ESP32 尚未连接") return @@ -678,6 +736,8 @@ class ESP32LiveKitBridge: print(f"已发送 tts {state}: device={session.device_id}") async def _send_emotion(self, session: DeviceSession, emotion: str) -> None: + if session.closed: + return if session.websocket is None: print(f"跳过 emotion {emotion},ESP32 尚未连接") return @@ -703,6 +763,8 @@ class ESP32LiveKitBridge: await self._send_emotion(session, emotion) async def _send_tts_text(self, session: DeviceSession, text: str, final: bool) -> None: + if session.closed: + return if session.websocket is None: return raw_text = text @@ -733,12 +795,18 @@ class ESP32LiveKitBridge: if len(text) <= TTS_DISPLAY_SCROLL_WIDTH: self._cancel_tts_display_task(session) - asyncio.create_task(self._send_tts_text(session, text, final)) + self._create_session_task( + session, + self._send_tts_text(session, text, final), + "发送 TTS 字幕", + ) return if session.tts_display_task is None or session.tts_display_task.done(): - session.tts_display_task = asyncio.create_task( - self._scroll_tts_display_text(session, session.tts_stream_id) + session.tts_display_task = self._create_session_task( + session, + self._scroll_tts_display_text(session, session.tts_stream_id), + "滚动 TTS 字幕", ) async def _scroll_tts_display_text(self, session: DeviceSession, stream_id: int) -> None: @@ -857,7 +925,11 @@ class ESP32LiveKitBridge: f"[agent-state] room={session.room_name} identity={participant.identity} state={state}" ) if state == "thinking": - asyncio.create_task(self._start_thinking(session)) + self._create_session_task( + session, + self._start_thinking(session), + "处理 agent thinking 状态", + ) async def _stop_tts(self, session: DeviceSession) -> None: if not session.tts_active and not session.tts_thinking: @@ -898,14 +970,20 @@ class ESP32LiveKitBridge: session.tts_suppressed_until = now + TTS_INTERRUPT_SUPPRESS_SECONDS session.tts_waiting_for_user_audio_after_interrupt = True await self._force_stop_tts(session, reason) - asyncio.create_task(self._send_agent_interrupt(session, reason)) + self._create_session_task( + session, + self._send_agent_interrupt(session, reason), + "发送 agent interrupt", + ) def _reset_tts_idle_timer(self, session: DeviceSession) -> None: session.tts_last_audible_at = time.monotonic() if session.tts_idle_task is not None: session.tts_idle_task.cancel() - session.tts_idle_task = asyncio.create_task( - self._tts_idle_watchdog(session, session.tts_stream_id) + session.tts_idle_task = self._create_session_task( + session, + self._tts_idle_watchdog(session, session.tts_stream_id), + "TTS idle watchdog", ) async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None: @@ -959,6 +1037,8 @@ class ESP32LiveKitBridge: participant: rtc.RemoteParticipant, source: str, ) -> None: + if session.closed or session.websocket is None: + return if track.kind != rtc.TrackKind.KIND_AUDIO: return @@ -974,6 +1054,7 @@ class ESP32LiveKitBridge: if existing_task is not None and existing_task.done(): print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}") session.forwarding_tracks.pop(track_sid, None) + session.forwarding_track_participants.pop(track_sid, None) task = asyncio.create_task( self.forward_audio_to_esp32( @@ -983,17 +1064,37 @@ class ESP32LiveKitBridge: ) ) session.forwarding_tracks[track_sid] = task + session.forwarding_track_participants[track_sid] = participant.identity print( f"收到音频流: {participant.identity} sid={track_sid} " f"source={source} room={session.room_name}" ) + def _cancel_forwarding_tracks( + self, + session: DeviceSession, + participant_identity: Optional[str] = None, + ) -> list[asyncio.Task[Any]]: + cancelled: list[asyncio.Task[Any]] = [] + for track_sid, task in list(session.forwarding_tracks.items()): + track_participant = session.forwarding_track_participants.get(track_sid) + if participant_identity is not None and track_participant != participant_identity: + continue + session.forwarding_tracks.pop(track_sid, None) + session.forwarding_track_participants.pop(track_sid, None) + if not task.done(): + task.cancel() + cancelled.append(task) + return cancelled + def _scan_participant_audio_tracks( self, session: DeviceSession, participant: rtc.RemoteParticipant, source: str, ) -> None: + if session.closed: + return publications = getattr(participant, "track_publications", None) or {} for publication in publications.values(): track = getattr(publication, "track", None) @@ -1061,6 +1162,76 @@ class ESP32LiveKitBridge: return normalized[start:].strip() or normalized + async def _handle_room_connected(self, session: DeviceSession) -> None: + if session.closed: + return + print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}") + self._log_agent_participants(session, "connected") + for participant in list(session.room.remote_participants.values()): + if session.closed: + return + if self._is_agent_participant(participant, session.agent_name): + session.agent_ready.set() + self._scan_participant_audio_tracks(session, participant, "connected_scan") + self._handle_agent_state(session, participant) + + async def _handle_participant_connected( + self, + session: DeviceSession, + participant: rtc.RemoteParticipant, + ) -> None: + if session.closed: + return + role = "Agent" if self._is_agent_participant(participant, session.agent_name) else "Remote participant" + print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}") + self._log_agent_participants(session, "participant_connected") + if self._is_agent_participant(participant, session.agent_name): + session.agent_ready.set() + self._scan_participant_audio_tracks( + session, participant, "participant_connected_scan" + ) + self._handle_agent_state(session, participant) + + async def _handle_participant_attributes_changed( + self, + session: DeviceSession, + changed: list[str], + participant: rtc.Participant, + ) -> None: + if session.closed: + return + if AGENT_STATE_ATTRIBUTE not in changed: + return + if not isinstance(participant, rtc.RemoteParticipant): + return + if not self._is_agent_participant(participant, session.agent_name): + return + self._handle_agent_state(session, participant) + + async def _handle_track_subscribed( + self, + session: DeviceSession, + track: rtc.Track, + publication: rtc.TrackPublication, + participant: rtc.RemoteParticipant, + source: str, + ) -> None: + if session.closed: + return + self._maybe_forward_remote_audio(session, track, publication, participant, source) + + async def _handle_track_published( + self, + session: DeviceSession, + publication: rtc.RemoteTrackPublication, + participant: rtc.RemoteParticipant, + ) -> None: + if session.closed: + return + track = getattr(publication, "track", None) + if track is not None: + self._maybe_forward_remote_audio(session, track, publication, participant, "published") + def _register_room_handlers(self, session: DeviceSession) -> None: @session.room.on("connection_state_changed") def on_connection_state_changed(state: int) -> None: @@ -1069,44 +1240,32 @@ class ESP32LiveKitBridge: @session.room.on("connected") def on_connected() -> None: - print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}") - self._log_agent_participants(session, "connected") - for participant in session.room.remote_participants.values(): - if self._is_agent_participant(participant, session.agent_name): - session.agent_ready.set() - self._scan_participant_audio_tracks(session, participant, "connected_scan") - self._handle_agent_state(session, participant) + self._create_session_task( + session, + self._handle_room_connected(session), + "处理 LiveKit connected 事件", + ) @session.room.on("participant_connected") def on_participant_connected(participant: rtc.RemoteParticipant) -> None: - role = "Agent" if self._is_agent_participant(participant, session.agent_name) else "Remote participant" - print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}") - self._log_agent_participants(session, "participant_connected") - if self._is_agent_participant(participant, session.agent_name): - session.agent_ready.set() - self._scan_participant_audio_tracks( - session, participant, "participant_connected_scan" - ) - self._handle_agent_state(session, participant) + self._create_session_task( + session, + self._handle_participant_connected(session, participant), + "处理 LiveKit participant_connected 事件", + ) @session.room.on("participant_disconnected") def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None: print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}") - session.forwarding_tracks = { - track_sid: task - for track_sid, task in session.forwarding_tracks.items() - if not track_sid.endswith(f":{participant.identity}") - } + self._cancel_forwarding_tracks(session, participant.identity) @session.room.on("participant_attributes_changed") def on_participant_attributes_changed(changed: list[str], participant: rtc.Participant) -> None: - if AGENT_STATE_ATTRIBUTE not in changed: - return - if not isinstance(participant, rtc.RemoteParticipant): - return - if not self._is_agent_participant(participant, session.agent_name): - return - self._handle_agent_state(session, participant) + self._create_session_task( + session, + self._handle_participant_attributes_changed(session, changed, participant), + "处理 LiveKit participant_attributes_changed 事件", + ) @session.room.on("data_received") def on_data_received(data_packet: rtc.DataPacket) -> None: @@ -1133,12 +1292,14 @@ class ESP32LiveKitBridge: ): mcp_payload = payload.get("payload") if isinstance(mcp_payload, dict): - asyncio.create_task( + self._create_session_task( + session, self._forward_mcp_to_device( session, mcp_payload, source_identity=identity, - ) + ), + "转发 MCP 到 ESP32", ) else: print(f"收到 MCP 数据但缺少 payload: {payload}") @@ -1167,7 +1328,11 @@ class ESP32LiveKitBridge: ) if emotion and emotion != session.tts_emotion: session.tts_emotion = emotion - asyncio.create_task(self._send_emotion(session, emotion)) + self._create_session_task( + session, + self._send_emotion(session, emotion), + "发送 emotion", + ) display_text = self._current_tts_display_text(tts_text) print(f"[livekit-llm] display_text={display_text!r} final={segment.final}") if not display_text or display_text == session.tts_transcript_text: @@ -1177,14 +1342,19 @@ class ESP32LiveKitBridge: if not segment.final: continue display_text = segment.text - asyncio.create_task(self._start_thinking(session)) + self._create_session_task( + session, + self._start_thinking(session), + "发送 TTS thinking", + ) if session.websocket is not None: ws = session.websocket if is_agent: self._update_tts_display_text(session, display_text, segment.final) else: - asyncio.create_task( + self._create_session_task( + session, ws.send( json.dumps( { @@ -1193,7 +1363,8 @@ class ESP32LiveKitBridge: "final": segment.final, } ) - ) + ), + "发送 STT 到 ESP32", ) @session.room.on("track_subscribed") @@ -1202,7 +1373,11 @@ class ESP32LiveKitBridge: publication: rtc.TrackPublication, participant: rtc.RemoteParticipant, ) -> None: - self._maybe_forward_remote_audio(session, track, publication, participant, "event") + self._create_session_task( + session, + self._handle_track_subscribed(session, track, publication, participant, "event"), + "处理 LiveKit track_subscribed 事件", + ) @session.room.on("track_published") def on_track_published( @@ -1214,18 +1389,25 @@ class ESP32LiveKitBridge: # f"📡 远端音轨已发布: room={session.room_name} identity={participant.identity} " # f"track_sid={track_sid}" # ) - track = getattr(publication, "track", None) - if track is not None: - self._maybe_forward_remote_audio(session, track, publication, participant, "published") + self._create_session_task( + session, + self._handle_track_published(session, publication, participant), + "处理 LiveKit track_published 事件", + ) async def _connect_session_room(self, session: DeviceSession) -> None: + if session.closed: + return self._register_room_handlers(session) + connected = False # print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}") # print(f"[config] token_url={TOKEN_URL}") # print(f"[config] room={session.room_name} identity={session.identity}") # print(f"[config] livekit_connect_timeout={CONNECT_TIMEOUT_SECONDS}") token = await fetch_token(session.room_name, session.identity, session.agent_name) + if session.closed: + return try: await session.room.connect( @@ -1233,7 +1415,10 @@ class ESP32LiveKitBridge: token, options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS), ) + connected = True except Exception as exc: + if session.closed: + return self._log_exception( f"连接 LiveKit 房间失败: room={session.room_name}", exc, @@ -1245,6 +1430,9 @@ class ESP32LiveKitBridge: ) raise + if session.closed: + await self._disconnect_room_quietly(session, "session_closed_after_connect") + return print(f"已连接到 LiveKit 房间: {session.room.name}") # print(f"[livekit] local_identity={session.room.local_participant.identity}") # print(f"[livekit] local_sid={session.room.local_participant.sid}") @@ -1252,6 +1440,9 @@ class ESP32LiveKitBridge: self._log_agent_participants(session, "after_connect") await self.ensure_agent_dispatched(session) + if session.closed: + await self._disconnect_room_quietly(session, "session_closed_after_dispatch") + return track = rtc.LocalAudioTrack.create_audio_track( f"esp32-mic-{session.device_id}", @@ -1259,6 +1450,9 @@ class ESP32LiveKitBridge: ) options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) publication = await session.room.local_participant.publish_track(track, options) + if session.closed: + await self._disconnect_room_quietly(session, "session_closed_after_publish") + return publication_sid = getattr(publication, "sid", None) track_sid = getattr(track, "sid", None) # print( @@ -1278,6 +1472,9 @@ class ESP32LiveKitBridge: return print(f"⚠️ agent 等待超时: room={session.room_name}") + if connected and session.closed: + await self._disconnect_room_quietly(session, "session_closed_after_agent_wait") + async def start(self) -> None: print(f"[config] websocket_port={WS_PORT}") print(f"[config] websocket_max_queue={WS_MAX_QUEUE} websocket_max_size={WS_MAX_SIZE}") @@ -1323,14 +1520,47 @@ class ESP32LiveKitBridge: session.tts_active = False session.tts_thinking = False session.tts_stream_id += 1 + cleanup_tasks: list[asyncio.Task[Any]] = [] + current_task = asyncio.current_task() + room_connect_pending = ( + session.room_connect_task is not None + and not session.room_connect_task.done() + ) + + if ( + not room_connect_pending + and session.agent_dispatch_task is not None + and not session.agent_dispatch_task.done() + ): + session.agent_dispatch_task.cancel() + if session.agent_dispatch_task is not current_task: + cleanup_tasks.append(session.agent_dispatch_task) + session.agent_dispatch_task = None + + cleanup_tasks.extend(self._cancel_forwarding_tracks(session)) + if session.tts_idle_task is not None: session.tts_idle_task.cancel() + if session.tts_idle_task is not current_task: + cleanup_tasks.append(session.tts_idle_task) session.tts_idle_task = None - self._cancel_tts_display_task(session) - try: - await session.room.disconnect() - except Exception as exc: - print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}") + if session.tts_display_task is not None: + session.tts_display_task.cancel() + if session.tts_display_task is not current_task: + cleanup_tasks.append(session.tts_display_task) + session.tts_display_task = None + + for task in list(session.background_tasks): + if task is current_task or task.done(): + continue + task.cancel() + cleanup_tasks.append(task) + + if cleanup_tasks: + await asyncio.gather(*cleanup_tasks, return_exceptions=True) + + if not room_connect_pending: + await self._disconnect_room_quietly(session, "session_close") async def forward_audio_to_esp32( self, @@ -1500,11 +1730,20 @@ class ESP32LiveKitBridge: except Exception as exc: print(f"音频流处理错误: {exc}") finally: + close_stream = getattr(audio_stream, "aclose", None) or getattr(audio_stream, "close", None) + if close_stream is not None: + try: + result = close_stream() + if result is not None and hasattr(result, "__await__"): + await result + except Exception as exc: + print(f"关闭 LiveKit 音频流失败: {exc}") print("🎧 TTS 音频结束") task = session.forwarding_tracks.get(track_sid) current_task = asyncio.current_task() if task is current_task: session.forwarding_tracks.pop(track_sid, None) + session.forwarding_track_participants.pop(track_sid, None) if stream_id == session.tts_stream_id and session.tts_idle_task is not None: session.tts_idle_task.cancel() session.tts_idle_task = None @@ -1566,9 +1805,14 @@ class ESP32LiveKitBridge: f"已发送 server hello: device={device_id} room={session.room_name} " f"audio={OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms" ) - asyncio.create_task(self._run_emotion_test_sequence(session)) + self._create_session_task( + session, + self._run_emotion_test_sequence(session), + "emotion 测试序列", + ) room_connect_task = asyncio.create_task(self._connect_session_room(session)) + session.room_connect_task = room_connect_task room_connect_task.add_done_callback( lambda task: self._track_room_connect_task(session, task) ) @@ -1656,10 +1900,6 @@ class ESP32LiveKitBridge: self._log_exception("WebSocket 其他错误", exc) finally: print(f"ESP32 断开连接: device={device_id} room={session.room_name}") - if room_connect_task is not None and not room_connect_task.done(): - room_connect_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await room_connect_task await self._close_session(session) self.device_sessions.pop(device_id, None)