643 lines
26 KiB
Python
643 lines
26 KiB
Python
|
||
import asyncio
|
||
import websockets
|
||
import os
|
||
import sys
|
||
import httpx
|
||
import json
|
||
import queue
|
||
import shutil
|
||
import threading
|
||
import struct
|
||
import time
|
||
import wave
|
||
import opuslib
|
||
|
||
from typing import Any, Optional
|
||
from livekit import rtc
|
||
from livekit.rtc import AudioSource, AudioFrame
|
||
from websockets.exceptions import ConnectionClosedError
|
||
|
||
# Optional server-side API client. It is only used to create agent dispatches
# through the SDK; when the package is unavailable `livekit_api` stays None and
# the bridge falls back to the `lk` CLI for dispatching.
livekit_api = None
try:
    from livekit import api as livekit_api
except ImportError:
    pass
# --- LiveKit / token-service connection settings ----------------------------
# Each setting can be overridden through the environment variable named in its
# os.getenv() call; the defaults are the values the bridge was written with.
TOKEN_URL = os.getenv("BRIDGE_TOKEN_URL", "http://10.6.80.130:8000/getToken")
LIVEKIT_WS_URL = os.getenv("BRIDGE_LIVEKIT_WS_URL", "wss://esp32-vt80c4y6.livekit.cloud")
ROOM = os.getenv("BRIDGE_ROOM", "test-livekit-room2")
IDENTITY = os.getenv("BRIDGE_IDENTITY", "uv-livekit-hardcoded")
# Name of the agent to dispatch; remote participants whose identity contains
# this string are treated as the agent.
AGENT_NAME = os.getenv("BRIDGE_AGENT_NAME", "my-agent")

import uuid  # kept for the per-run identity alternative below

# For a unique identity per run use: IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"

CONNECT_TIMEOUT_SECONDS = 10.0
# Port of the WebSocket server the ESP32 connects to.
WS_PORT = int(os.getenv("BRIDGE_WS_PORT", "8080"))
# Dispatch the agent (via the SDK or the `lk` CLI) whenever an ESP32 connects.
AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True

# --- Audio format ------------------------------------------------------------
# Uplink (ESP32 -> LiveKit): mono Opus decoded at this rate.
INPUT_SAMPLE_RATE = 16000
# Downlink (LiveKit -> ESP32): mono PCM re-encoded to Opus at this rate.
OUTPUT_SAMPLE_RATE = 24000

INPUT_FRAME_DURATION_MS = 60
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000

OUTPUT_FRAME_DURATION_MS = 20
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000

# --- Downlink TTS gating -----------------------------------------------------
# Seconds without audible downlink audio before "tts stop" is sent.
TTS_IDLE_TIMEOUT_SECONDS = 0.8
# A frame counts as audible when any |int16 sample| reaches this value.
TTS_SILENCE_PEAK_THRESHOLD = 96
# Audio kept from just before the first audible frame so onsets are not clipped.
TTS_PRE_ROLL_MS = 200

# Record the agent's downlink audio to WAV files named with this prefix.
SAVE_AGENT_TTS_WAV = True
AGENT_TTS_WAV_PREFIX = "agent_tts"
async def fetch_token() -> str:
    """Fetch a LiveKit access token for ROOM / IDENTITY from the token service.

    Returns:
        The JWT from the ``token`` field of the JSON response.

    Raises:
        httpx.HTTPStatusError: if the service answers with a non-2xx status.
        ValueError: if the body is not valid JSON, is not a JSON object, or
            lacks a non-empty string ``token`` field.
    """
    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        response = await client.get(
            TOKEN_URL,
            params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME},
        )
        response.raise_for_status()

    payload: Any = response.json()
    if not isinstance(payload, dict):
        raise ValueError(f"token response is not a JSON object: {payload!r}")

    token = payload.get("token")
    if not isinstance(token, str) or not token:
        raise ValueError(f"token response missing token field: {payload}")

    print(f"[token] room={payload.get('room')} identity={payload.get('identity')}")
    # Only a short prefix is logged: the full JWT is a bearer credential.
    print(f"[token] jwt_prefix={token[:16]}... len={len(token)}")
    return token
class ESP32LiveKitBridge:
    """Relay audio between a single ESP32 WebSocket client and a LiveKit room.

    Upstream (ESP32 -> LiveKit): binary messages are unwrapped according to the
    ``Protocol-Version`` header, decoded from Opus to 16-bit mono PCM at
    INPUT_SAMPLE_RATE, captured into a published microphone track, and copied
    to ``bridge_debug.wav`` by a background thread.

    Downstream (LiveKit -> ESP32): remote audio tracks are re-encoded into
    OUTPUT_FRAME_DURATION_MS Opus packets; audible audio is bracketed by
    ``{"type": "tts", "state": "start"}`` / ``"stop"`` control messages.
    """

    def __init__(self):
        self.room = rtc.Room()
        # Audio source that feeds the ESP32 microphone audio into LiveKit. Its
        # sample rate must match the PCM decoded from the device.
        self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1)
        self.esp_ws = None  # the currently active ESP32 WebSocket, if any
        # Decoded upstream PCM waiting to be written by the WAV writer thread.
        self.audio_queue = queue.Queue()
        self.wav_writer_thread: Optional[threading.Thread] = None
        # Stop signal of the current WAV writer; replaced for every new writer.
        self.stop_event = threading.Event()
        self.agent_ready = asyncio.Event()
        self.protocol_version = 1
        # True from a "tts start" until the next "tts stop".
        self.tts_active = False
        # Keys of remote audio tracks that currently have a forwarder.
        self.forwarding_tracks: set[str] = set()
        self.tts_idle_task: Optional[asyncio.Task] = None
        # Interrupt generation, bumped by an abort and by ESP32 connect /
        # disconnect. Forwarders drop buffered audio when it changes, and idle
        # watchdogs armed under an older value do nothing.
        self.tts_stream_id = 0
        # Bumped whenever a forwarder starts; older forwarders then exit.
        self.forwarder_generation = 0
        self.agent_dispatch_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget tasks (the event loop itself
        # only keeps weak references to tasks).
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Background tasks
    # ------------------------------------------------------------------

    def _spawn(self, coro) -> asyncio.Task:
        """Run *coro* as a background task, keeping a reference until it ends."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Release a finished background task and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            print(f"后台任务异常: {exc!r}")

    # ------------------------------------------------------------------
    # Agent dispatch
    # ------------------------------------------------------------------

    async def ensure_agent_dispatched(self) -> None:
        """Create an agent dispatch for ROOM unless it is unnecessary.

        Skipped when auto-dispatch is disabled, when a remote participant whose
        identity contains AGENT_NAME is already in the room, or while a previous
        dispatch attempt is still running.
        """
        if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT:
            return

        for participant in self.room.remote_participants.values():
            if AGENT_NAME in participant.identity:
                print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
                return

        if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done():
            print("Agent dispatch 正在进行中,跳过重复请求")
            return

        self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent())
        await self.agent_dispatch_task

    async def _dispatch_agent(self) -> None:
        """Dispatch the agent via the SDK, falling back to the `lk` CLI; never raises."""
        print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}")

        try:
            if await self._dispatch_agent_with_sdk():
                return

            if await self._dispatch_agent_with_cli():
                return

            print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
        except Exception as e:
            print(f"Agent dispatch 失败: {e}")

    async def _dispatch_agent_with_sdk(self) -> bool:
        """Create the dispatch through the livekit-api SDK.

        Returns False (without doing anything) when the SDK is not installed or
        LIVEKIT_API_KEY / LIVEKIT_API_SECRET are unset; True once created.
        """
        if livekit_api is None:
            return False

        api_key = os.getenv("LIVEKIT_API_KEY")
        api_secret = os.getenv("LIVEKIT_API_SECRET")
        if not api_key or not api_secret:
            return False

        lkapi = livekit_api.LiveKitAPI(
            url=LIVEKIT_WS_URL,
            api_key=api_key,
            api_secret=api_secret,
        )
        try:
            dispatch = await lkapi.agent_dispatch.create_dispatch(
                livekit_api.CreateAgentDispatchRequest(
                    agent_name=AGENT_NAME,
                    room=ROOM,
                    metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}),
                )
            )
            print(f"Agent dispatch 已创建: {dispatch}")
        finally:
            await lkapi.aclose()

        return True

    async def _dispatch_agent_with_cli(self) -> bool:
        """Create the dispatch with ``lk dispatch create``.

        Returns False if `lk` is not on PATH or exits non-zero, True on success.
        """
        lk_path = shutil.which("lk")
        if lk_path is None:
            return False

        process = await asyncio.create_subprocess_exec(
            lk_path,
            "dispatch",
            "create",
            "--room",
            ROOM,
            "--agent-name",
            AGENT_NAME,
            "--metadata",
            json.dumps({"source": "bridge_server", "identity": IDENTITY}),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        stdout_text = stdout.decode("utf-8", errors="replace").strip()
        stderr_text = stderr.decode("utf-8", errors="replace").strip()

        if stdout_text:
            print(f"lk dispatch 输出: {stdout_text}")
        if stderr_text:
            print(f"lk dispatch 错误输出: {stderr_text}")

        if process.returncode != 0:
            print(f"lk dispatch create 失败,退出码: {process.returncode}")
            return False

        print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}")
        return True

    # ------------------------------------------------------------------
    # TTS state signalling
    # ------------------------------------------------------------------

    async def _send_tts_state(self, state: str) -> bool:
        """Send ``{"type": "tts", "state": state}`` to the ESP32.

        Returns True if the message was sent, False if no ESP32 is connected or
        its connection is already closed (the error is logged, not raised, so
        the forwarder and watchdog callers keep running).
        """
        ws = self.esp_ws
        if ws is None:
            print(f"跳过 tts {state},ESP32 尚未连接")
            return False
        try:
            await ws.send(json.dumps({"type": "tts", "state": state}))
        except websockets.exceptions.ConnectionClosed as e:
            print(f"发送 tts {state} 失败,连接已关闭: {e}")
            return False
        print(f"已发送 tts {state}")
        return True

    async def _start_tts(self) -> None:
        """Put the ESP32 into speaking mode; no-op if already active."""
        if self.tts_active:
            print("跳过 tts start,当前已处于激活状态")
            return
        # Claim the state before awaiting so a concurrent caller cannot send a
        # second "start"; roll back if the message could not be sent.
        self.tts_active = True
        if not await self._send_tts_state("start"):
            self.tts_active = False

    async def _stop_tts(self) -> None:
        """Put the ESP32 back into listening mode; no-op if not active."""
        if not self.tts_active:
            print("跳过 tts stop,当前未激活")
            return
        # Cleared before awaiting so concurrent callers do not send "stop" twice.
        self.tts_active = False
        await self._send_tts_state("stop")

    async def _abort_tts(self, reason: str = "client_abort") -> None:
        """Interrupt current playback: bump the interrupt generation and stop TTS."""
        print(f"收到打断请求,停止当前 TTS: reason={reason}")
        self.tts_stream_id += 1
        self._cancel_tts_idle_timer()
        await self._stop_tts()

    def _reset_tts_session(self) -> None:
        """Invalidate in-flight TTS without messaging the ESP32 (used on connect/disconnect)."""
        self.tts_stream_id += 1
        self.tts_active = False
        self._cancel_tts_idle_timer()

    def _cancel_tts_idle_timer(self) -> None:
        """Cancel the pending idle watchdog, if any."""
        if self.tts_idle_task is not None:
            self.tts_idle_task.cancel()
            self.tts_idle_task = None

    def _reset_tts_idle_timer(self) -> None:
        """(Re)arm the idle watchdog for the current interrupt generation."""
        self._cancel_tts_idle_timer()
        self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id))

    async def _tts_idle_watchdog(self, stream_id: int) -> None:
        """Send "tts stop" after TTS_IDLE_TIMEOUT_SECONDS without audible audio."""
        try:
            await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
        except asyncio.CancelledError:
            return  # re-armed or cancelled; a newer timer (if any) takes over
        if stream_id != self.tts_stream_id:
            return
        # Detach before stopping so re-arming during the "stop" send cannot
        # cancel this task half-way through.
        if self.tts_idle_task is asyncio.current_task():
            self.tts_idle_task = None
        print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
        await self._stop_tts()

    def _has_audible_audio(self, pcm_data: bytes, threshold: Optional[int] = None) -> bool:
        """Return True if any little-endian int16 sample has ``abs(sample) >= threshold``.

        *threshold* defaults to TTS_SILENCE_PEAK_THRESHOLD. A trailing odd byte
        is ignored; input shorter than one sample is never audible.
        """
        if threshold is None:
            threshold = TTS_SILENCE_PEAK_THRESHOLD
        usable = len(pcm_data) - len(pcm_data) % 2
        if usable == 0:
            return False
        samples = struct.iter_unpack("<h", memoryview(pcm_data)[:usable])
        return any(abs(sample) >= threshold for (sample,) in samples)

    # ------------------------------------------------------------------
    # Remote audio discovery
    # ------------------------------------------------------------------

    def _maybe_forward_remote_audio(
        self,
        track: "rtc.Track",
        publication: "Optional[rtc.TrackPublication]",
        participant: "rtc.RemoteParticipant",
        source: str,
    ) -> None:
        """Start a downlink forwarder for *track* unless it is not audio or already forwarded.

        *source* only labels the log line (e.g. "event" or "scan").
        """
        if track.kind != rtc.TrackKind.KIND_AUDIO:
            return

        track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}"
        if track_sid in self.forwarding_tracks:
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return

        self.forwarding_tracks.add(track_sid)
        print(f"收到音频流: {participant.identity} sid={track_sid} source={source}")
        try:
            stream = rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
        except Exception as e:
            # Release the claim so a later scan/event can retry this track.
            self.forwarding_tracks.discard(track_sid)
            print(f"创建音频流失败: {participant.identity} sid={track_sid}: {e}")
            return
        self._spawn(self.forward_audio_to_esp32(stream, track_sid=track_sid))

    def _scan_existing_remote_audio_tracks(self) -> None:
        """Offer every already-subscribed remote track to _maybe_forward_remote_audio."""
        participants = list(self.room.remote_participants.values())
        print(f"开始扫描远端音频轨,participants={len(participants)}")
        for participant in participants:
            publications = getattr(participant, "track_publications", {}) or {}
            print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}")
            for pub_sid, publication in publications.items():
                track = getattr(publication, "track", None)
                kind = getattr(publication, "kind", None)
                print(
                    f"检查 publication: participant={participant.identity} pub_sid={pub_sid} "
                    f"kind={kind} has_track={track is not None}"
                )
                if track is not None:
                    self._maybe_forward_remote_audio(track, publication, participant, "scan")

    # ------------------------------------------------------------------
    # ESP32 audio framing
    # ------------------------------------------------------------------

    def _extract_opus_payload(self, message: bytes) -> bytes:
        """Strip the protocol header from an ESP32 binary message.

        Version 2 uses a 16-byte ``!HHIII`` header (version, type, reserved,
        timestamp, payload size); version 3 a 4-byte ``!BBH`` header (type,
        reserved, payload size); any other version carries raw Opus.

        Raises:
            ValueError: for a truncated packet, an unexpected version/type, or
                a payload shorter than the declared size.
        """
        if self.protocol_version == 2:
            if len(message) < 16:
                raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
            version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
                "!HHIII", message[:16]
            )
            if version != 2:
                raise ValueError(f"version 2 包头版本异常: {version}")
            if msg_type != 0:
                raise ValueError(f"version 2 音频包类型异常: {msg_type}")
            if len(message) < 16 + payload_size:
                raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
            return message[16:16 + payload_size]

        if self.protocol_version == 3:
            if len(message) < 4:
                raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
            msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
            if msg_type != 0:
                raise ValueError(f"version 3 音频包类型异常: {msg_type}")
            if len(message) < 4 + payload_size:
                raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
            return message[4:4 + payload_size]

        return message

    def _wrap_opus_payload(self, payload: bytes) -> bytes:
        """Prefix an Opus packet with the header of the negotiated protocol version."""
        if self.protocol_version == 2:
            header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
            return header + payload

        if self.protocol_version == 3:
            header = struct.pack("!BBH", 0, 0, len(payload))
            return header + payload

        return payload

    # ------------------------------------------------------------------
    # Upstream debug recording
    # ------------------------------------------------------------------

    def _wav_writer_loop(self, stop_event: Optional[threading.Event] = None):
        """Thread body: append queued upstream PCM to ``bridge_debug.wav``.

        Runs until *stop_event* (default: the bridge's current ``stop_event``)
        is set and the queue has been drained.
        """
        stop = self.stop_event if stop_event is None else stop_event
        print("启动音频保存线程...")
        try:
            with wave.open("bridge_debug.wav", "wb") as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(INPUT_SAMPLE_RATE)
                while not stop.is_set() or not self.audio_queue.empty():
                    try:
                        # Poll with a timeout so the stop signal is noticed.
                        pcm_bytes = self.audio_queue.get(timeout=0.5)
                    except queue.Empty:
                        continue
                    wav_file.writeframes(pcm_bytes)
        except Exception as e:
            print(f"音频保存线程错误: {e}")
        finally:
            print("音频保存线程退出")

    async def _start_wav_writer(self) -> None:
        """Start a fresh WAV writer thread after retiring the previous one.

        Each writer gets its own stop event, so a later connection cannot
        un-stop (or be stopped by) an older writer, and the previous writer is
        joined first so two threads never write ``bridge_debug.wav`` at once.
        """
        await self._stop_wav_writer()
        stop_event = threading.Event()
        self.stop_event = stop_event
        thread = threading.Thread(target=self._wav_writer_loop, args=(stop_event,), daemon=True)
        self.wav_writer_thread = thread
        thread.start()

    async def _stop_wav_writer(self, timeout: float = 2.0) -> None:
        """Signal the current WAV writer to finish and wait for it without blocking the loop."""
        self.stop_event.set()
        thread = self.wav_writer_thread
        if thread is not None and thread.is_alive():
            await asyncio.to_thread(thread.join, timeout)

    # ------------------------------------------------------------------
    # LiveKit side
    # ------------------------------------------------------------------

    def _register_room_handlers(self) -> None:
        """Attach the LiveKit room event handlers used by the bridge."""

        @self.room.on("connection_state_changed")
        def on_connection_state_changed(state: int) -> None:
            print(f"[livekit] state={rtc.ConnectionState.Name(state)}")

        @self.room.on("connected")
        def on_connected():
            print("✅ 成功连接到 LiveKit 房间")
            if self.room.remote_participants:
                self._scan_existing_remote_audio_tracks()
                self.agent_ready.set()

        @self.room.on("participant_connected")
        def on_participant_connected(p: rtc.RemoteParticipant):
            print(f"👋 Agent ({p.identity}) 已加入房间")
            self._scan_existing_remote_audio_tracks()
            self.agent_ready.set()

        @self.room.on("participant_disconnected")
        def on_participant_disconnected(p: rtc.RemoteParticipant):
            print(f"👋 远端参与者离开房间: {p.identity}")
            # Drop this participant's fallback keys ("audio:<identity>"); keys
            # that are real sids are released by their forwarder when it ends.
            self.forwarding_tracks = {
                track_sid for track_sid in self.forwarding_tracks
                if not track_sid.endswith(f":{p.identity}")
            }

        @self.room.on("data_received")
        def on_data_received(data_packet: rtc.DataPacket):
            identity = data_packet.participant.identity if data_packet.participant else "未知"
            try:
                print(f"📩 [数据接收 | {identity}]: {data_packet.data.decode('utf-8')}")
            except UnicodeDecodeError:
                pass  # non-UTF-8 (binary) payloads are not logged

        @self.room.on("transcription_received")
        def on_transcription_received(
            segments: list[rtc.TranscriptionSegment],
            participant: rtc.Participant,
            track_pub: rtc.TrackPublication,
        ):
            identity = participant.identity if participant else "未知"
            for segment in segments:
                status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
                print(f"🗣️ [{status} | {identity}]: {segment.text}")
                # Push the transcript to the ESP32 in real time.
                ws = self.esp_ws
                if ws is not None:
                    self._spawn(ws.send(json.dumps({
                        "type": "stt",
                        "text": segment.text,
                        "final": segment.final,
                    })))

        # Room audio (LiveKit -> ESP32).
        @self.room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            self._maybe_forward_remote_audio(track, publication, participant, "event")

    async def start(self):
        """Connect to LiveKit, publish the ESP32 microphone track and wait for the agent.

        Raises whatever fetch_token() raises, and asyncio.TimeoutError when the
        room connection does not complete within CONNECT_TIMEOUT_SECONDS + 2 s.
        Waiting for the agent itself is best-effort (10 s, logged on timeout).
        """
        # Handlers, including track_subscribed, are attached before connecting
        # so tracks subscribed during connect() are not missed.
        self._register_room_handlers()

        # 1. Fetch a token and connect to LiveKit.
        print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
        print(f"[config] token_url={TOKEN_URL}")
        print(f"[config] room={ROOM} identity={IDENTITY}")
        token = await fetch_token()

        await asyncio.wait_for(
            self.room.connect(
                LIVEKIT_WS_URL,
                token,
                options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
            ),
            timeout=CONNECT_TIMEOUT_SECONDS + 2.0,
        )
        print(f"已连接到 LiveKit 房间: {self.room.name}")
        print(f"[livekit] local_identity={self.room.local_participant.identity}")
        print(f"[livekit] local_sid={self.room.local_participant.sid}")
        print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}")

        # 2. Publish the microphone track (ESP32 -> LiveKit).
        track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
        options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        await self.room.local_participant.publish_track(track, options)

        # 3. Downlink audio is forwarded by the track_subscribed handler; also
        #    pick up tracks that were subscribed before it could fire.
        print("等待 agent 加入...")
        self._scan_existing_remote_audio_tracks()
        try:
            await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
            print("✅ agent 已就绪")
        except asyncio.TimeoutError:
            print("⚠️ agent 等待超时(后续可能收不到音频)")

    async def close(self):
        """Gracefully shut down: stop recording, leave the room, finish the WAV file."""
        self.stop_event.set()
        if self.room:
            await self.room.disconnect()
        # The writer is a daemon thread; join it so bridge_debug.wav is closed
        # properly instead of the thread being killed at interpreter exit.
        await self._stop_wav_writer()

    async def forward_audio_to_esp32(self, audio_stream, track_sid: Optional[str] = None):
        """Forward one LiveKit audio stream to the ESP32 as Opus packets.

        While TTS is inactive, frames are kept as up to TTS_PRE_ROLL_MS of
        pre-roll. The first audible frame sends "tts start", and the PCM is then
        re-chunked into OUTPUT_FRAME_DURATION_MS Opus packets. When the
        interrupt generation (``tts_stream_id``) changes, buffered audio is
        dropped and forwarding pauses until the stream has been inaudible for
        TTS_IDLE_TIMEOUT_SECONDS, so an aborted utterance is not resumed.
        The loop ends when the stream ends or a newer forwarder starts.

        Args:
            audio_stream: async iterable of frame events with ``.frame.data``
                holding int16 PCM; expected mono at OUTPUT_SAMPLE_RATE (as
                created by _maybe_forward_remote_audio).
            track_sid: key under which the caller registered the track in
                ``forwarding_tracks``; released when the stream is finished.
        """
        frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2  # mono 16-bit PCM
        pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
        pending_pcm = bytearray()
        pre_roll_pcm = bytearray()

        self.forwarder_generation += 1
        generation = self.forwarder_generation
        seen_interrupt = self.tts_stream_id
        muted = False
        last_audible_at = 0.0
        superseded = False

        tts_wav_file = None
        tts_wav_path = None
        try:
            encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip')

            if SAVE_AGENT_TTS_WAV:
                # The generation suffix keeps forwarders started within the
                # same second from truncating each other's file.
                tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}_{generation}.wav"
                try:
                    tts_wav_file = wave.open(tts_wav_path, "wb")
                except OSError as e:
                    print(f"无法创建 Agent TTS 录音文件 {tts_wav_path}: {e}")
                else:
                    tts_wav_file.setnchannels(1)
                    tts_wav_file.setsampwidth(2)
                    tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE)
                    print(f"开始保存 Agent TTS 音频: {tts_wav_path}")

            async for event in audio_stream:
                if generation != self.forwarder_generation:
                    print("检测到更新的 TTS 流,停止旧流转发")
                    superseded = True
                    break

                pcm_data = event.frame.data.tobytes()
                if tts_wav_file is not None:
                    tts_wav_file.writeframes(pcm_data)

                now = time.monotonic()
                has_audible_audio = self._has_audible_audio(pcm_data)

                if seen_interrupt != self.tts_stream_id:
                    # Aborted, or the ESP32 (re)connected: anything buffered
                    # belongs to an utterance that must not be resumed.
                    seen_interrupt = self.tts_stream_id
                    pending_pcm.clear()
                    pre_roll_pcm.clear()
                    muted = True
                    last_audible_at = now
                    print("TTS 已中断,丢弃缓冲音频,等待当前语音结束")

                if muted:
                    # Elapsed time is checked first so a gap between frames
                    # also counts as silence.
                    if now - last_audible_at < TTS_IDLE_TIMEOUT_SECONDS:
                        if has_audible_audio:
                            last_audible_at = now
                        continue
                    muted = False
                    print("中断的语音已结束,恢复转发")

                current_frame_buffered = False
                if not self.tts_active:
                    # Not speaking yet: keep a short pre-roll so the onset of
                    # the utterance is not clipped once audio becomes audible.
                    pre_roll_pcm.extend(pcm_data)
                    current_frame_buffered = True
                    if len(pre_roll_pcm) > pre_roll_max_bytes:
                        del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes]

                    if not has_audible_audio or self.esp_ws is None:
                        continue

                    await self._start_tts()
                    if not self.tts_active:
                        continue

                    print("检测到可听 TTS 音频,切换到 Speaking 并开始转发")
                    pending_pcm.extend(pre_roll_pcm)
                    pre_roll_pcm.clear()

                if has_audible_audio:
                    self._reset_tts_idle_timer()

                # LiveKit frame sizes need not match the frame duration the
                # device expects, so PCM is re-chunked into fixed-size Opus
                # frames before sending.
                if not current_frame_buffered:
                    pending_pcm.extend(pcm_data)

                while (
                    len(pending_pcm) >= frame_bytes
                    and self.esp_ws is not None
                    and seen_interrupt == self.tts_stream_id
                ):
                    ws = self.esp_ws
                    chunk = bytes(pending_pcm[:frame_bytes])
                    del pending_pcm[:frame_bytes]
                    try:
                        opus_packet = encoder.encode(chunk, OUTPUT_SAMPLES_PER_OPUS_FRAME)
                        await ws.send(self._wrap_opus_payload(opus_packet))
                    except Exception as e:
                        print(f"发送回 ESP32 失败: {e}")
                        break

        except Exception as e:
            print(f"音频流处理错误: {e}")

        finally:
            print("🎧 TTS 音频结束")
            if tts_wav_file is not None:
                tts_wav_file.close()
                print(f"Agent TTS 音频已保存: {tts_wav_path}")
            # Leaving an `async for` does not close the iterator, so close the
            # stream explicitly when it supports it.
            aclose = getattr(audio_stream, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception as e:
                    print(f"关闭音频流失败: {e}")
            if track_sid is not None and not superseded:
                # The stream is finished: let a later scan/event pick the track
                # up again. A superseded track keeps its claim, so rescans do
                # not resurrect it and displace the newer forwarder.
                self.forwarding_tracks.discard(track_sid)
            if generation == self.forwarder_generation:
                # Still the newest forwarder: its audio is over, so is TTS.
                self._cancel_tts_idle_timer()
                await self._stop_tts()

    # ------------------------------------------------------------------
    # ESP32 side
    # ------------------------------------------------------------------

    async def _capture_mic_pcm(self, pcm_bytes: bytes) -> None:
        """Push decoded 16-bit mono PCM into the LiveKit microphone source."""
        num_samples = len(pcm_bytes) // 2
        if num_samples <= 0:
            return
        try:
            frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
            await self.mic_source.capture_frame(frame)
        except TypeError:
            # Fallback for SDK versions whose AudioFrame constructor differs.
            frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
            memoryview(frame.data).cast('B')[:] = pcm_bytes
            await self.mic_source.capture_frame(frame)

    async def _handle_text_message(self, message: str) -> None:
        """Handle one text (JSON control) message from the ESP32."""
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            print(f"收到未知的字符消息: {message}")
            return
        print(f"收到 ESP32 JSON 消息: {data}")
        # Valid JSON that is not an object (e.g. a list) carries no "type".
        if not isinstance(data, dict):
            return
        if data.get("type") != "abort":
            return
        reason = data.get("reason")
        if reason is None:
            # Only a bare abort (the button case) stops playback.
            await self._abort_tts("button_abort")
        else:
            print(f"忽略非按键打断: reason={reason}")

    async def handle_websocket(self, websocket):
        """Serve one ESP32 WebSocket connection until it closes.

        The newest connection becomes ``self.esp_ws``. On exit a connection
        only resets state it still owns, so the late close of an old socket
        cannot detach a newer one or stop its recorder.
        """
        self.esp_ws = websocket
        header_version = websocket.request.headers.get("Protocol-Version")
        try:
            self.protocol_version = int(header_version) if header_version else 1
        except ValueError:
            self.protocol_version = 1
        print("ESP32 已连接")
        print(f"ESP32 协议版本: {self.protocol_version}")
        self._reset_tts_session()

        opus_decoder = None
        writer_thread = None  # WAV writer started by this connection, if any
        try:
            # Tell the ESP32 the handshake succeeded.
            hello_msg = {
                "type": "hello",
                "transport": "websocket",
                "audio_params": {
                    "format": "opus",  # the ESP32 is asked to send Opus
                    "sample_rate": OUTPUT_SAMPLE_RATE,
                    "channels": 1,
                    "frame_duration": OUTPUT_FRAME_DURATION_MS
                }
            }
            await websocket.send(json.dumps(hello_msg))
            # Dispatch in the background so a slow API/CLI call does not hold
            # up the handshake or the audio loop.
            self._spawn(self.ensure_agent_dispatched())

            async for message in websocket:
                # ESP32 data -> LiveKit
                if isinstance(message, bytes):
                    # Messages this short cannot be audio packets (e.g. pings).
                    if len(message) < 4:
                        print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
                        continue
                    try:
                        audio_data = self._extract_opus_payload(message)
                    except ValueError as e:
                        # One malformed packet must not end the whole session.
                        print(f"丢弃无效音频包: {e}")
                        continue
                    if not audio_data:
                        continue
                    try:
                        if opus_decoder is None:
                            print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
                            opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
                            await self._start_wav_writer()
                            writer_thread = self.wav_writer_thread
                        # Uplink packets are decoded with the device's uplink
                        # frame size (INPUT_SAMPLES_PER_OPUS_FRAME).
                        pcm_bytes = opus_decoder.decode(audio_data, INPUT_SAMPLES_PER_OPUS_FRAME)
                        # Hand a copy to the background recorder thread.
                        self.audio_queue.put(pcm_bytes)
                        await self._capture_mic_pcm(pcm_bytes)
                    except Exception as e:
                        print(f"Opus audio decode error ({len(message)} bytes): {e}")
                elif isinstance(message, str):
                    await self._handle_text_message(message)
        except ConnectionClosedError as e:
            print(f"ESP32 异常断开: {e}")
        except Exception as e:
            print(f"WebSocket 其他错误: {e}")
        finally:
            print("ESP32 断开连接")
            if writer_thread is not None and self.wav_writer_thread is writer_thread:
                # Our writer drains the queue and exits; as a daemon thread it
                # is not joined here.
                self.stop_event.set()
            if self.esp_ws is websocket:
                self.esp_ws = None
                self._reset_tts_session()
async def main():
    """Connect the bridge to LiveKit, then serve ESP32 WebSocket clients forever.

    The bridge is closed on the way out in every case: a startup failure, a
    server error, or cancellation of this coroutine.
    """
    bridge = ESP32LiveKitBridge()
    try:
        await bridge.start()

        # WebSocket server the ESP32 connects to.
        server = websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT)
        async with server:
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            # A future that is never resolved: keeps the server up until the
            # task is cancelled.
            never_done = asyncio.get_running_loop().create_future()
            await never_done
    finally:
        await bridge.close()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this long-running server: exit
        # quietly with the conventional SIGINT status instead of a traceback.
        print("[info] interrupted, shutting down", file=sys.stderr)
        sys.exit(130)
    except Exception as exc:
        print(f"[error] {exc}", file=sys.stderr)
        sys.exit(1)