diff --git a/main/Kconfig.projbuild b/main/Kconfig.projbuild index ffaf7c0..c136b4c 100644 --- a/main/Kconfig.projbuild +++ b/main/Kconfig.projbuild @@ -6,6 +6,34 @@ config OTA_URL help The application will access this URL to check for new firmwares and server address. +config USE_DIRECT_WEBSOCKET + bool "Use direct WebSocket without OTA" + default n + help + Skip the OTA server check and use the WebSocket settings below directly. + +config WEBSOCKET_URL + string "Default WebSocket URL" + depends on USE_DIRECT_WEBSOCKET + default "ws://10.6.80.130:8080" + help + The WebSocket server URL used when direct WebSocket mode is enabled. + +config WEBSOCKET_TOKEN + string "Default WebSocket token" + depends on USE_DIRECT_WEBSOCKET + default "" + help + Optional Authorization token for the direct WebSocket server. + +config WEBSOCKET_PROTOCOL_VERSION + int "Default WebSocket protocol version" + depends on USE_DIRECT_WEBSOCKET + range 1 3 + default 1 + help + Protocol-Version header and hello version used by the WebSocket protocol. 
+ choice prompt "Flash Assets" default FLASH_DEFAULT_ASSETS if !USE_EMOTE_MESSAGE_STYLE diff --git a/main/application.cc b/main/application.cc index d54ac6b..23da40f 100644 --- a/main/application.cc +++ b/main/application.cc @@ -302,11 +302,15 @@ void Application::HandleActivationDoneEvent() { SystemInfo::PrintHeapStats(); SetDeviceState(kDeviceStateIdle); - has_server_time_ = ota_->HasServerTime(); + if (ota_ != nullptr) { + has_server_time_ = ota_->HasServerTime(); + } auto display = Board::GetInstance().GetDisplay(); - std::string message = std::string(Lang::Strings::VERSION) + ota_->GetCurrentVersion(); - display->ShowNotification(message.c_str()); + if (ota_ != nullptr) { + std::string message = std::string(Lang::Strings::VERSION) + ota_->GetCurrentVersion(); + display->ShowNotification(message.c_str()); + } display->SetChatMessage("system", ""); // Release OTA object after activation is complete @@ -321,6 +325,10 @@ void Application::HandleActivationDoneEvent() { } void Application::ActivationTask() { +#if CONFIG_USE_DIRECT_WEBSOCKET + CheckAssetsVersion(); + InitializeProtocol(); +#else // Create OTA object for activation process ota_ = std::make_unique(); @@ -332,6 +340,7 @@ void Application::ActivationTask() { // Initialize the protocol InitializeProtocol(); +#endif // Signal completion to main loop xEventGroupSetBits(event_group_, MAIN_EVENT_ACTIVATION_DONE); @@ -477,6 +486,9 @@ void Application::InitializeProtocol() { display->SetStatus(Lang::Strings::LOADING_PROTOCOL); +#if CONFIG_USE_DIRECT_WEBSOCKET + protocol_ = std::make_unique(); +#else if (ota_->HasMqttConfig()) { protocol_ = std::make_unique(); } else if (ota_->HasWebsocketConfig()) { @@ -485,6 +497,7 @@ void Application::InitializeProtocol() { ESP_LOGW(TAG, "No protocol specified in the OTA config, using MQTT"); protocol_ = std::make_unique(); } +#endif protocol_->OnConnected([this]() { DismissAlert(); @@ -1128,4 +1141,3 @@ void Application::ResetProtocol() { protocol_.reset(); }); } - diff 
--git a/main/boards/waveshare/esp32-s3-touch-lcd-4.3c/sdkconfig.4_3c b/main/boards/waveshare/esp32-s3-touch-lcd-4.3c/sdkconfig.4_3c index 2f73cb6..2129bc5 100755 --- a/main/boards/waveshare/esp32-s3-touch-lcd-4.3c/sdkconfig.4_3c +++ b/main/boards/waveshare/esp32-s3-touch-lcd-4.3c/sdkconfig.4_3c @@ -598,6 +598,10 @@ CONFIG_PARTITION_TABLE_MD5=y # Xiaozhi Assistant # CONFIG_OTA_URL="https://api.tenclass.net/xiaozhi/ota/" +CONFIG_USE_DIRECT_WEBSOCKET=y +CONFIG_WEBSOCKET_URL="ws://10.6.80.130:8080" +CONFIG_WEBSOCKET_TOKEN="" +CONFIG_WEBSOCKET_PROTOCOL_VERSION=1 # CONFIG_FLASH_NONE_ASSETS is not set CONFIG_FLASH_DEFAULT_ASSETS=y # CONFIG_FLASH_CUSTOM_ASSETS is not set diff --git a/main/bridge_server.py b/main/bridge_server.py new file mode 100644 index 0000000..d189b1d --- /dev/null +++ b/main/bridge_server.py @@ -0,0 +1,919 @@ +import asyncio +import json +import os +import shutil +import struct +import sys +import time +import traceback +import uuid +from dataclasses import dataclass, field +from typing import Any, Optional + +import httpx +import opuslib +import websockets +from livekit import rtc +from livekit.rtc import AudioFrame, AudioSource +from websockets.exceptions import ConnectionClosedError + +try: + from livekit import api as livekit_api +except ImportError: + livekit_api = None + +TOKEN_URL = "http://10.6.80.130:8000/getToken" +LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud" +ROOM_PREFIX = "test-livekit" +IDENTITY_PREFIX = "uv-livekit" +AGENT_NAME = "my-agent" +CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0")) +AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0")) +WS_PORT = 8080 +AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").lower() + +INPUT_SAMPLE_RATE = 16000 +OUTPUT_SAMPLE_RATE = 24000 +INPUT_FRAME_DURATION_MS = 60 +INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000 +OUTPUT_FRAME_DURATION_MS = 20 
+OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
+# Seconds without audible downlink audio before the bridge sends "tts stop".
+TTS_IDLE_TIMEOUT_SECONDS = 0.8
+# Absolute 16-bit sample value at or above which a frame counts as audible.
+TTS_SILENCE_PEAK_THRESHOLD = 96
+# Amount of audio kept before TTS start so the beginning is not clipped.
+TTS_PRE_ROLL_MS = 200
+# Consecutive audible frames required before TTS is declared started.
+TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = 3
+# Consecutive silent frames required after an interrupt before TTS may restart.
+TTS_INTERRUPT_SILENCE_FRAMES = 3
+INTERRUPT_TOPIC = "lk.interrupt"
+
+
+@dataclass
+class DeviceSession:
+    """State for one ESP32 WebSocket connection and the LiveKit room paired with it."""
+
+    device_id: str
+    # Set to None once the session is closed.
+    websocket: Any
+    # Selects the binary framing used by _extract_opus_payload / _wrap_opus_payload.
+    protocol_version: int
+    room_name: str
+    identity: str
+    room: rtc.Room
+    mic_source: AudioSource
+    # Set when a participant recognised as an agent is seen in the room.
+    agent_ready: asyncio.Event
+    # Track sid -> task forwarding that track's audio to the device.
+    forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)
+    tts_active: bool = False
+    tts_idle_task: Optional[asyncio.Task] = None
+    # Bumped on abort/close; forwarders and watchdogs compare against it to
+    # detect that the stream they were started for is stale.
+    tts_stream_id: int = 0
+    tts_transcript_text: str = ""
+    agent_dispatch_task: Optional[asyncio.Task] = None
+    closed: bool = False
+    captured_frame_count: int = 0
+    first_capture_log_time: float = 0.0
+
+
+async def fetch_token(room_name: str, identity: str) -> str:
+    """Fetch a LiveKit access token for ``identity`` in ``room_name`` from TOKEN_URL.
+
+    When AGENT_DISPATCH_MODE is "token" the agent name is passed along as well.
+
+    Raises:
+        httpx.HTTPStatusError: the token endpoint answered with an error status.
+        ValueError: the JSON response has no non-empty string ``token`` field.
+    """
+    params = {"room": room_name, "identity": identity}
+    if AGENT_DISPATCH_MODE == "token":
+        params["agent_name"] = AGENT_NAME
+
+    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
+        response = await client.get(TOKEN_URL, params=params)
+        response.raise_for_status()
+
+    payload: dict[str, Any] = response.json()
+    token = payload.get("token")
+    if not isinstance(token, str) or not token:
+        raise ValueError(f"token response missing token field: {payload}")
+
+    print(f"[token] room={payload.get('room')} identity={payload.get('identity')}")
+    # Only a short prefix and the length are logged, never the full token.
+    print(f"[token] jwt_prefix={token[:16]}... len={len(token)}")
+    return token
+
+
+class ESP32LiveKitBridge:
+    """Bridges ESP32 WebSocket clients to LiveKit rooms, one DeviceSession per device id."""
+
+    def __init__(self) -> None:
+        self.device_sessions: dict[str, DeviceSession] = {}
+
+    def _log_exception(self, prefix: str, exc: BaseException) -> None:
+        """Print the exception type, repr and formatted traceback under ``prefix``."""
+        print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
+        formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
+        if formatted_tb.strip():
+            print(formatted_tb.rstrip())
+
+    def _build_device_id(self, websocket: Any) -> str:
+        """Return the client-supplied device id header, or an id derived from the socket object."""
+        headers = websocket.request.headers
+        requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
+        if requested_id:
+            return requested_id
+        return f"ws-{id(websocket):x}"
+
+    def _build_session_names(self, device_id: str) -> tuple[str, str]:
+        """Generate a fresh ``(room_name, identity)`` pair sharing one random 8-hex-digit tag."""
+        session_tag = uuid.uuid4().hex[:8]
+        room_name = f"{ROOM_PREFIX}-{session_tag}"
+        identity = f"{IDENTITY_PREFIX}-{session_tag}"
+        print(f"[session] device={device_id} room={room_name} identity={identity}")
+        return room_name, identity
+
+    def _is_agent_participant(self, participant: rtc.RemoteParticipant) -> bool:
+        """Heuristic: identity starts with "agent-" or contains AGENT_NAME."""
+        identity = getattr(participant, "identity", "") or ""
+        return identity.startswith("agent-") or AGENT_NAME in identity
+
+    def _get_agent_identities(self, session: DeviceSession) -> list[str]:
+        """List identities of remote participants currently recognised as agents."""
+        return [
+            participant.identity
+            for participant in session.room.remote_participants.values()
+            if self._is_agent_participant(participant)
+        ]
+
+    def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
+        """Log how many agents are in the room; ``source`` labels the call site."""
+        agent_identities = self._get_agent_identities(session)
+        print(
+            f"[agent-check] source={source} room={session.room_name} "
+            f"agent_count={len(agent_identities)} agents={agent_identities}"
+        )
+
+    async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
+        """Return True if the LiveKit API already lists a matching dispatch for the room.
+
+        Returns False when the livekit API module, the API credentials or the
+        list request class are unavailable, and also when the query itself fails.
+        """
+        if livekit_api is None:
+            return False
+
+        api_key = os.getenv("LIVEKIT_API_KEY")
+        api_secret = os.getenv("LIVEKIT_API_SECRET")
+        if not api_key or not api_secret:
+            return False
+
+        request_cls = getattr(livekit_api, "ListAgentDispatchRequest", None)
+        if request_cls is None:
+            return False
+
+        lkapi = livekit_api.LiveKitAPI(
+            url=LIVEKIT_WS_URL,
+            api_key=api_key,
+            api_secret=api_secret,
+        )
+        try:
+            response = await lkapi.agent_dispatch.list_dispatch(
+                request_cls(room=session.room_name)
+            )
+            # The attribute holding the list is probed under several names.
+            dispatches = (
+                getattr(response, "agent_dispatches", None)
+                or getattr(response, "dispatches", None)
+                or getattr(response, "items", None)
+                or []
+            )
+            for dispatch in dispatches:
+                dispatch_agent_name = getattr(dispatch, "agent_name", None)
+                dispatch_room = getattr(dispatch, "room", None)
+                if dispatch_room == session.room_name and (
+                    dispatch_agent_name == AGENT_NAME or dispatch_agent_name is None
+                ):
+                    print(
+                        f"检测到已有 dispatch: room={session.room_name} "
+                        f"agent={dispatch_agent_name}"
+                    )
+                    return True
+        except Exception as exc:
+            # Best effort: a failed lookup is treated as "no dispatch yet".
+            print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
+        finally:
+            await lkapi.aclose()
+
+        return False
+
+    async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
+        """Dispatch the agent into the room unless one is already present or pending.
+
+        Does nothing unless AGENT_DISPATCH_MODE is "bridge".
+        """
+        if AGENT_DISPATCH_MODE != "bridge":
+            print(f"跳过 bridge 手动 dispatch: mode={AGENT_DISPATCH_MODE}")
+            return
+
+        if await self._has_existing_dispatch(session):
+            return
+
+        for participant in session.room.remote_participants.values():
+            if self._is_agent_participant(participant):
+                print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
+                return
+
+        if session.agent_dispatch_task is not None and not session.agent_dispatch_task.done():
+            print("Agent dispatch 正在进行中,跳过重复请求")
+            return
+
+        session.agent_dispatch_task = asyncio.create_task(self._dispatch_agent(session))
+        await session.agent_dispatch_task
+
+    async def _dispatch_agent(self, session: DeviceSession) -> None:
+        """Try the SDK first, then the ``lk`` CLI; failures are logged, not raised."""
+        print(f"准备 dispatch agent: room={session.room_name}, agent={AGENT_NAME}")
+
+        try:
+            if await self._dispatch_agent_with_sdk(session):
+                return
+
+            if await self._dispatch_agent_with_cli(session):
+                return
+
+            print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
+        except Exception as exc:
+            print(f"Agent dispatch 失败: {exc}")
+
+    async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
+        """Create the dispatch through the LiveKit API.
+
+        Returns False when the API module or the credentials are missing, True
+        once the dispatch was created; API errors propagate to the caller.
+        """
+        if livekit_api is None:
+            return False
+
+        api_key = os.getenv("LIVEKIT_API_KEY")
+        api_secret = os.getenv("LIVEKIT_API_SECRET")
+        if not api_key or not api_secret:
+            return False
+
+        lkapi = livekit_api.LiveKitAPI(
+            url=LIVEKIT_WS_URL,
+            api_key=api_key,
+            api_secret=api_secret,
+        )
+        try:
+            dispatch = await lkapi.agent_dispatch.create_dispatch(
+                livekit_api.CreateAgentDispatchRequest(
+                    agent_name=AGENT_NAME,
+                    room=session.room_name,
+                    metadata=json.dumps(
+                        {
+                            "source": "bridge_server",
+                            "identity": session.identity,
+                            "device_id": session.device_id,
+                        }
+                    ),
+                )
+            )
+            print(f"Agent dispatch 已创建: {dispatch}")
+        finally:
+            await lkapi.aclose()
+
+        return True
+
+    async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
+        """Create the dispatch with the ``lk`` CLI.
+
+        Returns False when ``lk`` is not on PATH or exits with a non-zero code.
+        """
+        lk_path = shutil.which("lk")
+        if lk_path is None:
+            return False
+
+        # Arguments are passed as an argv list (no shell), so the values are
+        # not subject to shell interpretation.
+        process = await asyncio.create_subprocess_exec(
+            lk_path,
+            "dispatch",
+            "create",
+            "--room",
+            session.room_name,
+            "--agent-name",
+            AGENT_NAME,
+            "--metadata",
+            json.dumps(
+                {
+                    "source": "bridge_server",
+                    "identity": session.identity,
+                    "device_id": session.device_id,
+                }
+            ),
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+        )
+        stdout, stderr = await process.communicate()
+        stdout_text = stdout.decode("utf-8", errors="replace").strip()
+        stderr_text = stderr.decode("utf-8", errors="replace").strip()
+
+        if stdout_text:
+            print(f"lk dispatch 输出: {stdout_text}")
+        if stderr_text:
+            print(f"lk dispatch 错误输出: {stderr_text}")
+
+        if process.returncode != 0:
+            print(f"lk dispatch create 失败,退出码: {process.returncode}")
+            return False
+
+        print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={AGENT_NAME}")
+        return True
+
+    async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
+        """Publish ``payload`` as JSON on the room data channel; return True on success.
+
+        The data is addressed to the known agent identities when there are any.
+        A first attempt passes ``topic``; if that raises TypeError the call is
+        retried without it.
+        """
+        participant = getattr(session.room, "local_participant", None)
+        if participant is None:
+            print("跳过发送 agent 控制事件,local participant 尚未就绪")
+            return False
+
+        data = json.dumps(payload).encode("utf-8")
+        agent_identities = self._get_agent_identities(session)
+        kwargs: dict[str, Any] = {}
+        if agent_identities:
+            kwargs["destination_identities"] = agent_identities
+
+        last_error: Optional[Exception] = None
+        for attempt in ({"topic": INTERRUPT_TOPIC, **kwargs}, kwargs):
+            try:
+                await participant.publish_data(data, **attempt)
+                print(
+                    f"已发送 agent 控制事件: topic={attempt.get('topic', '')} "
+                    f"targets={agent_identities or 'broadcast'} payload={payload}"
+                )
+                return True
+            except TypeError as exc:
+                # Signature mismatch: remember it and fall through to the next variant.
+                last_error = exc
+            except Exception as exc:
+                print(f"发送 agent 控制事件失败: {exc}")
+                return False
+
+        if last_error is not None:
+            print(f"发送 agent 控制事件失败,publish_data 签名不兼容: {last_error}")
+        return False
+
+    async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
+        """Send an "interrupt" control event to the agent and warn if it could not be published."""
+        payload = {
+            "type": "interrupt",
+            "topic": INTERRUPT_TOPIC,
+            "reason": reason,
+            "room": session.room_name,
+            "identity": session.identity,
+            "device_id": session.device_id,
+        }
+        ok = await self._publish_agent_event(session, payload)
+        if not ok:
+            print("警告: bridge 已停止 TTS,但 agent 侧 interrupt 未确认送出")
+
+    async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
+        """Send a ``{"type": "tts", "state": state}`` message to the device, if still connected."""
+        if session.websocket is None:
+            print(f"跳过 tts {state},ESP32 尚未连接")
+            return
+        await session.websocket.send(json.dumps({"type": "tts", "state": state}))
+        print(f"已发送 tts {state}: device={session.device_id}")
+
+    async def _start_tts(self, session: DeviceSession) -> None:
+        """Notify the device that TTS starts and mark the session active (no-op if already active)."""
+        if session.tts_active:
+            print("跳过 tts start,当前已处于激活状态")
+            return
+        session.tts_transcript_text = ""
+        await self._send_tts_state(session, "start")
+        session.tts_active = True
+
+    async def _stop_tts(self, session: DeviceSession) -> None:
+        """Notify the device that TTS stopped and mark the session inactive (no-op if inactive).
+
+        The local state is cleared even when the notification cannot be sent;
+        the send error still propagates to the caller.
+        """
+        if not session.tts_active:
+            print("跳过 tts stop,当前未激活")
+            return
+        try:
+            await self._send_tts_state(session, "stop")
+        finally:
+            # Without this, a failed send would leave tts_active stuck at True
+            # and every later stop attempt would retry on the dead socket.
+            session.tts_active = False
+            session.tts_transcript_text = ""
+
+    async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
+        """Handle a device-initiated interrupt: invalidate the stream, notify the agent, stop TTS."""
+        print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")
+        # Bumping the id makes running forwarders and watchdogs see their stream as stale.
+        session.tts_stream_id += 1
+        if session.tts_idle_task is not None:
+            session.tts_idle_task.cancel()
+            session.tts_idle_task = None
+        await self._send_agent_interrupt(session, reason)
+        await self._stop_tts(session)
+
+    def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
+        """Restart the idle watchdog for the current TTS stream."""
+        if session.tts_idle_task is not None:
+            session.tts_idle_task.cancel()
+        session.tts_idle_task = asyncio.create_task(
+            self._tts_idle_watchdog(session, session.tts_stream_id)
+        )
+
+    async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
+        """Stop TTS after TTS_IDLE_TIMEOUT_SECONDS unless cancelled or the stream changed."""
+        try:
+            await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
+            if stream_id != session.tts_stream_id:
+                return
+            print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
+            await self._stop_tts(session)
+        except asyncio.CancelledError:
+            pass
+
+    def _has_audible_audio(self, pcm_data: bytes) -> bool:
+        """Return True if any little-endian 16-bit sample reaches TTS_SILENCE_PEAK_THRESHOLD."""
+        if len(pcm_data) < 2:
+            return False
+
+        sample_count = len(pcm_data) // 2
+        peak = 0
+        for i in range(sample_count):
+            sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
+            if sample < 0:
+                sample = -sample
+            if sample > peak:
+                peak = sample
+                if peak >= TTS_SILENCE_PEAK_THRESHOLD:
+                    return True
+        return False
+
+    def _maybe_forward_remote_audio(
+        self,
+        session: DeviceSession,
+        track: rtc.Track,
+        publication: Optional[rtc.TrackPublication],
+        participant: rtc.RemoteParticipant,
+        source: str,
+    ) -> None:
+        """Start a forwarding task for an audio track unless one is already running for its sid.
+
+        ``source`` only labels the log output.
+        """
+        if track.kind != rtc.TrackKind.KIND_AUDIO:
+            return
+
+        # Fall back to a synthetic "audio:<identity>" key when no sid is available.
+        track_sid = (
+            getattr(track, "sid", None)
+            or getattr(publication, "sid", None)
+            or f"audio:{participant.identity}"
+        )
+        existing_task = session.forwarding_tracks.get(track_sid)
+        if existing_task is not None and not existing_task.done():
+            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
+            return
+        if existing_task is not None and existing_task.done():
+            print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
+            session.forwarding_tracks.pop(track_sid, None)
+
+        task = asyncio.create_task(
+            self.forward_audio_to_esp32(
+                session,
+                track_sid,
+                rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1),
+            )
+        )
+        session.forwarding_tracks[track_sid] = task
+        print(
+            f"收到音频流: {participant.identity} sid={track_sid} "
+            f"source={source} room={session.room_name}"
+        )
+
+    def _scan_participant_audio_tracks(
+        self,
+        session: DeviceSession,
+        participant: rtc.RemoteParticipant,
+        source: str,
+    ) -> None:
+        """Offer every publication of ``participant`` that already has a track for forwarding."""
+        publications = getattr(participant, "track_publications", None) or {}
+        for publication in publications.values():
+            track = getattr(publication, "track", None)
+            if track is None:
+                continue
+            self._maybe_forward_remote_audio(session, track, publication, participant, source)
+
+    def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
+        """Strip the binary framing of the session's protocol version and return the payload.
+
+        Version 2 uses a 16-byte big-endian header (version, type, reserved,
+        timestamp, payload size); version 3 a 4-byte one (type, reserved,
+        payload size). Any other version returns ``message`` unchanged.
+
+        Raises:
+            ValueError: the header is truncated or inconsistent with the message.
+        """
+        if session.protocol_version == 2:
+            if len(message) < 16:
+                raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
+            version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
+                "!HHIII", message[:16]
+            )
+            if version != 2:
+                raise ValueError(f"version 2 包头版本异常: {version}")
+            if msg_type != 0:
+                raise ValueError(f"version 2 音频包类型异常: {msg_type}")
+            if len(message) < 16 + payload_size:
+                raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
+            return message[16:16 + payload_size]
+
+        if session.protocol_version == 3:
+            if len(message) < 4:
+                raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
+            msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
+            if msg_type != 0:
+                raise ValueError(f"version 3 音频包类型异常: {msg_type}")
+            if len(message) < 4 + payload_size:
+                raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
+            return message[4:4 + payload_size]
+
+        return message
+
+    def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
+        """Prefix ``payload`` with the header of the session's protocol version (inverse of extract)."""
+        if session.protocol_version == 2:
+            header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
+            return header + payload
+
+        if session.protocol_version == 3:
+            header = struct.pack("!BBH", 0, 0, len(payload))
+            return header + payload
+
+        return payload
+
+    def _register_room_handlers(self, session: DeviceSession) -> None:
+        """Attach the LiveKit room event callbacks for ``session``."""
+
+        @session.room.on("connection_state_changed")
+        def on_connection_state_changed(state: int) -> None:
+            print(f"[livekit] room={session.room_name} state={rtc.ConnectionState.Name(state)}")
+
+        @session.room.on("connected")
+        def on_connected() -> None:
+            print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
+            self._log_agent_participants(session, "connected")
+            for participant in session.room.remote_participants.values():
+                if self._is_agent_participant(participant):
+                    session.agent_ready.set()
+                self._scan_participant_audio_tracks(session, participant, "connected_scan")
+
+        @session.room.on("participant_connected")
+        def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
+            role = "Agent" if self._is_agent_participant(participant) else "Remote participant"
+            print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
+            self._log_agent_participants(session, "participant_connected")
+            if self._is_agent_participant(participant):
+                session.agent_ready.set()
+            self._scan_participant_audio_tracks(
+                session, participant, "participant_connected_scan"
+            )
+
+        @session.room.on("participant_disconnected")
+        def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
+            print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
+            # Only the synthetic "audio:<identity>" keys can be matched to the
+            # participant. Cancel those tasks instead of merely forgetting them,
+            # so they do not keep running without any reference.
+            suffix = f":{participant.identity}"
+            remaining: dict[str, asyncio.Task[Any]] = {}
+            for track_sid, task in session.forwarding_tracks.items():
+                if track_sid.endswith(suffix):
+                    task.cancel()
+                else:
+                    remaining[track_sid] = task
+            session.forwarding_tracks = remaining
+
+        @session.room.on("data_received")
+        def on_data_received(data_packet: rtc.DataPacket) -> None:
+            identity = data_packet.participant.identity if data_packet.participant else "未知"
+            try:
+                print(
+                    f"📩 [数据接收 | room={session.room_name} | {identity}]: "
+                    f"{data_packet.data.decode('utf-8')}"
+                )
+            except Exception:
+                # Logging only: payloads that are not valid UTF-8 are ignored.
+                pass
+
+        @session.room.on("transcription_received")
+        def on_transcription_received(
+            segments: list[rtc.TranscriptionSegment],
+            participant: rtc.Participant,
+            track_pub: rtc.TrackPublication,
+        ) -> None:
+            identity = participant.identity if participant else "未知"
+            is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(participant)
+            for segment in segments:
+                status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
+                print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")
+                if session.websocket is not None:
+                    ws = session.websocket
+                    # Agent transcripts go out as TTS sentences, everything else as STT text.
+                    # NOTE(review): these sends are fire-and-forget tasks; a failed
+                    # send is not observed here.
+                    if is_agent:
+                        session.tts_transcript_text = segment.text
+                        asyncio.create_task(
+                            ws.send(
+                                json.dumps(
+                                    {
+                                        "type": "tts",
+                                        "state": "sentence_start",
+                                        "text": session.tts_transcript_text,
+                                        "final": segment.final,
+                                    }
+                                )
+                            )
+                        )
+                    else:
+                        asyncio.create_task(
+                            ws.send(
+                                json.dumps(
+                                    {
+                                        "type": "stt",
+                                        "text": segment.text,
+                                        "final": segment.final,
+                                    }
+                                )
+                            )
+                        )
+
+        @session.room.on("track_subscribed")
+        def on_track_subscribed(
+            track: rtc.Track,
+            publication: rtc.TrackPublication,
+            participant: rtc.RemoteParticipant,
+        ) -> None:
+            self._maybe_forward_remote_audio(session, track, publication, participant, "event")
+
+        @session.room.on("track_published")
+        def on_track_published(
+            publication: rtc.RemoteTrackPublication,
+            participant: rtc.RemoteParticipant,
+        ) -> None:
+            track_sid = getattr(publication, "sid", None)
+            print(
+                f"📡 远端音轨已发布: room={session.room_name} identity={participant.identity} "
+                f"track_sid={track_sid}"
+            )
+            track = getattr(publication, "track", None)
+            if track is not None:
+                self._maybe_forward_remote_audio(session, track, publication, participant, "published")
+
+    async def _connect_session_room(self, session: DeviceSession) -> None:
+        """Join the LiveKit room, dispatch the agent, publish the mic track and wait for the agent.
+
+        Token and connection errors propagate; an agent that does not show up
+        within AGENT_READY_TIMEOUT_SECONDS only produces a warning.
+        """
+        self._register_room_handlers(session)
+
+        print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
+        print(f"[config] token_url={TOKEN_URL}")
+        print(f"[config] room={session.room_name} identity={session.identity}")
+        print(f"[config] livekit_connect_timeout={CONNECT_TIMEOUT_SECONDS}")
+        token = await fetch_token(session.room_name, session.identity)
+
+        try:
+            await session.room.connect(
+                LIVEKIT_WS_URL,
+                token,
+                options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
+            )
+        except Exception as exc:
+            self._log_exception(
+                f"连接 LiveKit 房间失败: room={session.room_name}",
+                exc,
+            )
+            print(
+                "提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
+                "请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
+                "*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
+            )
+            raise
+
+        print(f"已连接到 LiveKit 房间: {session.room.name}")
+        print(f"[livekit] local_identity={session.room.local_participant.identity}")
+        print(f"[livekit] local_sid={session.room.local_participant.sid}")
+        print(f"[livekit] remote_participants={list(session.room.remote_participants.keys())}")
+        self._log_agent_participants(session, "after_connect")
+
+        await self.ensure_agent_dispatched(session)
+
+        track = rtc.LocalAudioTrack.create_audio_track(
+            f"esp32-mic-{session.device_id}",
+            session.mic_source,
+        )
+        options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
+        publication = await session.room.local_participant.publish_track(track, options)
+        publication_sid = getattr(publication, "sid", None)
+        track_sid = getattr(track, "sid", None)
+        print(
+            f"已发布 ESP32 mic track: room={session.room_name} "
+            f"track_sid={track_sid} publication_sid={publication_sid}"
+        )
+        self._log_agent_participants(session, "after_publish_mic")
+
+        print(f"等待 agent 加入: room={session.room_name}")
+        try:
+            await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
+            print(f"✅ agent 已就绪: room={session.room_name}")
+        except asyncio.TimeoutError:
+            print(f"⚠️ agent 等待超时: room={session.room_name}")
+
+    async def start(self) -> None:
+        """Log the effective configuration."""
+        print(f"[config] websocket_port={WS_PORT}")
+        print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
+        print(f"[config] token_url={TOKEN_URL}")
+        print(f"[config] agent_dispatch_mode={AGENT_DISPATCH_MODE}")
+
+    async def close(self) -> None:
+        """Close every known device session."""
+        for session in list(self.device_sessions.values()):
+            await self._close_session(session)
+
+    async def _close_session(self, session: DeviceSession) -> None:
+        """Tear down a session once: stop TTS bookkeeping, cancel its tasks, leave the room."""
+        if session.closed:
+            return
+        session.closed = True
+        session.websocket = None
+        session.tts_active = False
+        session.tts_stream_id += 1
+        if session.tts_idle_task is not None:
+            session.tts_idle_task.cancel()
+            session.tts_idle_task = None
+        # Cancel the audio forwarders explicitly instead of relying on the
+        # room disconnect to end their streams.
+        for task in list(session.forwarding_tracks.values()):
+            task.cancel()
+        session.forwarding_tracks.clear()
+        try:
+            await session.room.disconnect()
+        except Exception as exc:
+            print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
+
+    async def forward_audio_to_esp32(
+        self,
+        session: DeviceSession,
+        track_sid: str,
+        audio_stream: rtc.AudioStream,
+    ) -> None:
+        """Re-encode a remote audio stream to Opus and send it to the device as TTS audio.
+
+        TTS is only started after TTS_START_CONSECUTIVE_AUDIBLE_FRAMES audible
+        frames in a row; up to TTS_PRE_ROLL_MS of preceding audio is kept and
+        sent first. After an interrupt (``tts_stream_id`` changed) audio is
+        discarded until TTS_INTERRUPT_SILENCE_FRAMES silent frames were seen.
+        """
+        encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
+        pending_pcm = bytearray()
+        pre_roll_pcm = bytearray()
+        # 2 bytes per 16-bit mono sample.
+        pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
+        audible_frame_streak = 0
+        silence_frame_streak = 0
+        waiting_for_post_interrupt_silence = False
+        stream_id = session.tts_stream_id
+        print(
+            f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
+            f"track_sid={track_sid} stream_id={stream_id}"
+        )
+
+        try:
+            async for event in audio_stream:
+                if stream_id != session.tts_stream_id:
+                    # Interrupted: drop everything buffered and wait for a silence gap.
+                    print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
+                    pending_pcm.clear()
+                    pre_roll_pcm.clear()
+                    audible_frame_streak = 0
+                    silence_frame_streak = 0
+                    waiting_for_post_interrupt_silence = True
+                    stream_id = session.tts_stream_id
+                    if session.tts_active:
+                        await self._stop_tts(session)
+                    continue
+
+                frame = event.frame
+                pcm_data = frame.data.tobytes()
+                has_audible_audio = self._has_audible_audio(pcm_data)
+
+                if waiting_for_post_interrupt_silence:
+                    if has_audible_audio:
+                        silence_frame_streak = 0
+                        continue
+
+                    silence_frame_streak += 1
+                    if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
+                        continue
+
+                    print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
+                    waiting_for_post_interrupt_silence = False
+                    silence_frame_streak = 0
+                    stream_id = session.tts_stream_id
+                    continue
+
+                current_frame_buffered = False
+                if not session.tts_active:
+                    # Not speaking yet: keep a bounded pre-roll and wait for
+                    # enough consecutive audible frames.
+                    pre_roll_pcm.extend(pcm_data)
+                    current_frame_buffered = True
+                    if len(pre_roll_pcm) > pre_roll_max_bytes:
+                        del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]
+
+                    if not has_audible_audio:
+                        audible_frame_streak = 0
+                        continue
+
+                    audible_frame_streak += 1
+                    if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
+                        continue
+
+                    await self._start_tts(session)
+                    if not session.tts_active:
+                        continue
+
+                    print(
+                        "检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
+                        f"(frames={audible_frame_streak})"
+                    )
+                    pending_pcm.extend(pre_roll_pcm)
+                    pre_roll_pcm.clear()
+                    audible_frame_streak = 0
+
+                if has_audible_audio:
+                    self._reset_tts_idle_timer(session)
+
+                # The current frame is already part of the pre-roll when it was buffered above.
+                if not current_frame_buffered:
+                    pending_pcm.extend(pcm_data)
+                frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2
+
+                while (
+                    len(pending_pcm) >= frame_bytes
+                    and stream_id == session.tts_stream_id
+                    and session.websocket is not None
+                ):
+                    try:
+                        opus_packet = encoder.encode(
+                            bytes(pending_pcm[:frame_bytes]),
+                            OUTPUT_SAMPLES_PER_OPUS_FRAME,
+                        )
+                        del pending_pcm[:frame_bytes]
+                        await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
+                    except Exception as exc:
+                        print(f"发送回 ESP32 失败: {exc}")
+                        break
+
+        except Exception as exc:
+            print(f"音频流处理错误: {exc}")
+        finally:
+            print("🎧 TTS 音频结束")
+            # Only unregister if the registry still points at this very task.
+            task = session.forwarding_tracks.get(track_sid)
+            current_task = asyncio.current_task()
+            if task is current_task:
+                session.forwarding_tracks.pop(track_sid, None)
+            if stream_id == session.tts_stream_id and session.tts_idle_task is not None:
+                session.tts_idle_task.cancel()
+                session.tts_idle_task = None
+            if stream_id == session.tts_stream_id:
+                await self._stop_tts(session)
+
+    async def handle_websocket(self, websocket: Any) -> None:
+        """Serve one ESP32 WebSocket connection until it closes.
+
+        Creates the DeviceSession (replacing an existing one with the same
+        device id), joins the LiveKit room, sends the hello message, then
+        forwards binary Opus frames to the room and handles JSON control
+        messages. Malformed frames and messages are dropped individually.
+        """
+        header_version = websocket.request.headers.get("Protocol-Version")
+        try:
+            protocol_version = int(header_version) if header_version else 1
+        except ValueError:
+            protocol_version = 1
+
+        device_id = self._build_device_id(websocket)
+        existing_session = self.device_sessions.get(device_id)
+        if existing_session is not None:
+            print(f"检测到重复 device_id,关闭旧 session: device={device_id}")
+            await self._close_session(existing_session)
+            self.device_sessions.pop(device_id, None)
+
+        room_name, identity = self._build_session_names(device_id)
+        session = DeviceSession(
+            device_id=device_id,
+            websocket=websocket,
+            protocol_version=protocol_version,
+            room_name=room_name,
+            identity=identity,
+            room=rtc.Room(),
+            mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
+            agent_ready=asyncio.Event(),
+        )
+        self.device_sessions[device_id] = session
+
+        print(f"ESP32 已连接: device={device_id}")
+        print(f"ESP32 协议版本: {session.protocol_version}")
+        session.tts_stream_id += 1
+        opus_decoder = None
+
+        try:
+            await self._connect_session_room(session)
+
+            hello_msg = {
+                "type": "hello",
+                "transport": "websocket",
+                "session": {
+                    "room": session.room_name,
+                    "identity": session.identity,
+                },
+                "audio_params": {
+                    "format": "opus",
+                    "sample_rate": OUTPUT_SAMPLE_RATE,
+                    "channels": 1,
+                    "frame_duration": OUTPUT_FRAME_DURATION_MS,
+                },
+            }
+            await websocket.send(json.dumps(hello_msg))
+
+            async for message in websocket:
+                if session.closed:
+                    # A newer connection with the same device id replaced this
+                    # session; stop serving the stale socket.
+                    break
+
+                if isinstance(message, bytes):
+                    if len(message) < 4:
+                        print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
+                        continue
+
+                    try:
+                        audio_data = self._extract_opus_payload(session, message)
+                    except ValueError as exc:
+                        # One malformed frame must not tear down the whole session.
+                        print(f"Malformed audio packet dropped ({len(message)} bytes): {exc}")
+                        continue
+                    if not audio_data:
+                        continue
+
+                    try:
+                        if opus_decoder is None:
+                            print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
+                            opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
+
+                        pcm_bytes = opus_decoder.decode(audio_data, INPUT_SAMPLES_PER_OPUS_FRAME)
+
+                        num_samples = len(pcm_bytes) // 2
+                        if num_samples > 0:
+                            session.captured_frame_count += 1
+                            now = time.monotonic()
+                            # Log the first five frames, then at most once every five seconds.
+                            if (
+                                session.captured_frame_count <= 5
+                                or now - session.first_capture_log_time >= 5.0
+                            ):
+                                session.first_capture_log_time = now
+                                print(
+                                    f"[uplink] capture_frame count={session.captured_frame_count} "
+                                    f"bytes={len(pcm_bytes)} samples={num_samples} "
+                                    f"room={session.room_name}"
+                                )
+                            try:
+                                frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
+                                await session.mic_source.capture_frame(frame)
+                            except TypeError:
+                                # Fallback when the positional AudioFrame constructor is rejected.
+                                frame = AudioFrame.create(
+                                    sample_rate=INPUT_SAMPLE_RATE,
+                                    num_channels=1,
+                                    samples_per_channel=num_samples,
+                                )
+                                memoryview(frame.data).cast("B")[:] = pcm_bytes
+                                await session.mic_source.capture_frame(frame)
+                    except Exception as exc:
+                        print(f"Opus audio decode error ({len(message)} bytes): {exc}")
+                elif isinstance(message, str):
+                    try:
+                        data = json.loads(message)
+                        print(f"收到 ESP32 JSON 消息: {data}")
+                        if not isinstance(data, dict):
+                            # Valid JSON but not an object: there is no "type" to dispatch on.
+                            continue
+                        msg_type = data.get("type")
+                        if msg_type == "abort":
+                            reason = data.get("reason")
+                            abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
+                            print(f"处理 ESP32 打断请求: reason={abort_reason}")
+                            await self._abort_tts(session, abort_reason)
+                    except json.JSONDecodeError:
+                        print(f"收到未知的字符消息: {message}")
+        except ConnectionClosedError as exc:
+            print(f"ESP32 异常断开: {exc}")
+        except Exception as exc:
+            self._log_exception("WebSocket 其他错误", exc)
+        finally:
+            print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
+            await self._close_session(session)
+            # Only unregister our own session: a reconnect with the same device
+            # id may already have stored a newer one under this key.
+            if self.device_sessions.get(device_id) is session:
+                self.device_sessions.pop(device_id, None)
+
+
+async def main() -> None:
+    """Run the WebSocket server on WS_PORT until cancelled, then close all sessions."""
+    bridge = ESP32LiveKitBridge()
+    try:
+        await bridge.start()
+        async with websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT):
+            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
+            # Never completes; keeps the server alive until the task is cancelled.
+            await asyncio.Future()
+    finally:
+        await bridge.close()
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except Exception as exc:
+        print(f"[error] {exc}", file=sys.stderr)
+        sys.exit(1)
diff --git a/main/protocols/websocket_protocol.cc
b/main/protocols/websocket_protocol.cc index 4b74a5e..cd775d1 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -85,10 +85,21 @@ bool WebsocketProtocol::OpenAudioChannel() { std::string url = settings.GetString("url"); std::string token = settings.GetString("token"); int version = settings.GetInt("version"); +#if CONFIG_USE_DIRECT_WEBSOCKET + url = CONFIG_WEBSOCKET_URL; + token = CONFIG_WEBSOCKET_TOKEN; + version = CONFIG_WEBSOCKET_PROTOCOL_VERSION; +#endif if (version != 0) { version_ = version; } + if (url.empty()) { + ESP_LOGE(TAG, "Websocket URL is not set"); + SetError(Lang::Strings::SERVER_NOT_CONNECTED); + return false; + } + error_occurred_ = false; auto network = Board::GetInstance().GetNetwork();