feat: support defferent models
This commit is contained in:
@ -5,40 +5,52 @@ import os
|
||||
import sys
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
import queue
|
||||
import shutil
|
||||
import threading
|
||||
import struct
|
||||
import time
|
||||
import wave
|
||||
import opuslib
|
||||
|
||||
from typing import Any, Optional
|
||||
from livekit import rtc
|
||||
from livekit.rtc import AudioSource, AudioFrame
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
import http.server
|
||||
import multipart
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
# 配置信息
|
||||
# TOKEN_URL = "http://10.6.80.130:8000/v1/token"
|
||||
# LIVEKIT_WS_URL = "ws://10.6.80.130:8000/"
|
||||
# ROOM = "vera-room"
|
||||
# IDENTITY = "vera-1"
|
||||
# TOKEN_URL = "https://omnichat.bwgdi.com/v1/token"
|
||||
try:
|
||||
from livekit import api as livekit_api
|
||||
except ImportError:
|
||||
livekit_api = None
|
||||
|
||||
TOKEN_URL = "http://10.6.80.130:8000/getToken"
|
||||
LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud"
|
||||
# LIVEKIT_WS_URL = "wss://rtc.bwgdi.com/"
|
||||
LIVEKIT_WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
|
||||
ROOM = "test-livekit-room2"
|
||||
IDENTITY = "uv-livekit-hardcoded"
|
||||
AGENT_NAME = "my-agent"
|
||||
import uuid
|
||||
# IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
|
||||
CONNECT_TIMEOUT_SECONDS = 10.0
|
||||
WS_PORT = 8080
|
||||
AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
INPUT_SAMPLE_RATE = 16000
|
||||
OUTPUT_SAMPLE_RATE = 24000
|
||||
INPUT_FRAME_DURATION_MS = 60
|
||||
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
|
||||
OUTPUT_FRAME_DURATION_MS = 20
|
||||
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
|
||||
TTS_IDLE_TIMEOUT_SECONDS = 0.8
|
||||
TTS_SILENCE_PEAK_THRESHOLD = 96
|
||||
TTS_PRE_ROLL_MS = 200
|
||||
SAVE_AGENT_TTS_WAV = True
|
||||
AGENT_TTS_WAV_PREFIX = "agent_tts"
|
||||
|
||||
async def fetch_token() -> str:
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
TOKEN_URL,
|
||||
params={"room": ROOM, "identity": IDENTITY, "agent_name": "my-agent"},
|
||||
params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@ -57,12 +69,248 @@ class ESP32LiveKitBridge:
|
||||
self.room = rtc.Room()
|
||||
# 创建一个音频源,用于将 ESP32 的声音推送到 LiveKit
|
||||
# 注意:采样率需与 ESP32 发送的一致,通常是 16000 或 24000
|
||||
self.mic_source = AudioSource(sample_rate=SAMPLE_RATE, num_channels=1)
|
||||
self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1)
|
||||
self.esp_ws = None # 保存 WebSocket 连接
|
||||
self.audio_queue = queue.Queue()
|
||||
self.wav_writer_thread: Optional[threading.Thread] = None
|
||||
self.stop_event = threading.Event()
|
||||
self.agent_ready = asyncio.Event() # Moved here to be accessible earlier if needed
|
||||
self.protocol_version = 1
|
||||
self.tts_active = False
|
||||
self.forwarding_tracks: set[str] = set()
|
||||
self.tts_idle_task: Optional[asyncio.Task] = None
|
||||
self.tts_stream_id = 0
|
||||
self.agent_dispatch_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def ensure_agent_dispatched(self) -> None:
|
||||
if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT:
|
||||
return
|
||||
|
||||
for participant in self.room.remote_participants.values():
|
||||
if AGENT_NAME in participant.identity:
|
||||
print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
|
||||
return
|
||||
|
||||
if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done():
|
||||
print("Agent dispatch 正在进行中,跳过重复请求")
|
||||
return
|
||||
|
||||
self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent())
|
||||
await self.agent_dispatch_task
|
||||
|
||||
async def _dispatch_agent(self) -> None:
|
||||
print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}")
|
||||
|
||||
try:
|
||||
if await self._dispatch_agent_with_sdk():
|
||||
return
|
||||
|
||||
if await self._dispatch_agent_with_cli():
|
||||
return
|
||||
|
||||
print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
|
||||
except Exception as e:
|
||||
print(f"Agent dispatch 失败: {e}")
|
||||
|
||||
async def _dispatch_agent_with_sdk(self) -> bool:
|
||||
if livekit_api is None:
|
||||
return False
|
||||
|
||||
api_key = os.getenv("LIVEKIT_API_KEY")
|
||||
api_secret = os.getenv("LIVEKIT_API_SECRET")
|
||||
if not api_key or not api_secret:
|
||||
return False
|
||||
|
||||
lkapi = livekit_api.LiveKitAPI(
|
||||
url=LIVEKIT_WS_URL,
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
)
|
||||
try:
|
||||
dispatch = await lkapi.agent_dispatch.create_dispatch(
|
||||
livekit_api.CreateAgentDispatchRequest(
|
||||
agent_name=AGENT_NAME,
|
||||
room=ROOM,
|
||||
metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}),
|
||||
)
|
||||
)
|
||||
print(f"Agent dispatch 已创建: {dispatch}")
|
||||
finally:
|
||||
await lkapi.aclose()
|
||||
|
||||
return True
|
||||
|
||||
async def _dispatch_agent_with_cli(self) -> bool:
|
||||
lk_path = shutil.which("lk")
|
||||
if lk_path is None:
|
||||
return False
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
lk_path,
|
||||
"dispatch",
|
||||
"create",
|
||||
"--room",
|
||||
ROOM,
|
||||
"--agent-name",
|
||||
AGENT_NAME,
|
||||
"--metadata",
|
||||
json.dumps({"source": "bridge_server", "identity": IDENTITY}),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout_text = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_text = stderr.decode("utf-8", errors="replace").strip()
|
||||
|
||||
if stdout_text:
|
||||
print(f"lk dispatch 输出: {stdout_text}")
|
||||
if stderr_text:
|
||||
print(f"lk dispatch 错误输出: {stderr_text}")
|
||||
|
||||
if process.returncode != 0:
|
||||
print(f"lk dispatch create 失败,退出码: {process.returncode}")
|
||||
return False
|
||||
|
||||
print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}")
|
||||
return True
|
||||
|
||||
async def _send_tts_state(self, state: str) -> None:
|
||||
if self.esp_ws is None:
|
||||
print(f"跳过 tts {state},ESP32 尚未连接")
|
||||
return
|
||||
await self.esp_ws.send(json.dumps({"type": "tts", "state": state}))
|
||||
print(f"已发送 tts {state}")
|
||||
|
||||
async def _start_tts(self) -> None:
|
||||
if self.tts_active:
|
||||
print("跳过 tts start,当前已处于激活状态")
|
||||
return
|
||||
await self._send_tts_state("start")
|
||||
if self.esp_ws is not None:
|
||||
self.tts_active = True
|
||||
|
||||
async def _stop_tts(self) -> None:
|
||||
if not self.tts_active:
|
||||
print("跳过 tts stop,当前未激活")
|
||||
return
|
||||
await self._send_tts_state("stop")
|
||||
self.tts_active = False
|
||||
|
||||
async def _abort_tts(self, reason: str = "client_abort") -> None:
|
||||
print(f"收到打断请求,停止当前 TTS: reason={reason}")
|
||||
self.tts_stream_id += 1
|
||||
if self.tts_idle_task is not None:
|
||||
self.tts_idle_task.cancel()
|
||||
self.tts_idle_task = None
|
||||
await self._stop_tts()
|
||||
|
||||
def _reset_tts_idle_timer(self) -> None:
|
||||
if self.tts_idle_task is not None:
|
||||
self.tts_idle_task.cancel()
|
||||
self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id))
|
||||
|
||||
async def _tts_idle_watchdog(self, stream_id: int) -> None:
|
||||
try:
|
||||
await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
|
||||
if stream_id != self.tts_stream_id:
|
||||
return
|
||||
print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
|
||||
await self._stop_tts()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _has_audible_audio(self, pcm_data: bytes) -> bool:
|
||||
if len(pcm_data) < 2:
|
||||
return False
|
||||
|
||||
sample_count = len(pcm_data) // 2
|
||||
peak = 0
|
||||
for i in range(sample_count):
|
||||
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
|
||||
if sample < 0:
|
||||
sample = -sample
|
||||
if sample > peak:
|
||||
peak = sample
|
||||
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _maybe_forward_remote_audio(
|
||||
self,
|
||||
track: rtc.Track,
|
||||
publication: Optional[rtc.TrackPublication],
|
||||
participant: rtc.RemoteParticipant,
|
||||
source: str,
|
||||
) -> None:
|
||||
if track.kind != rtc.TrackKind.KIND_AUDIO:
|
||||
return
|
||||
|
||||
track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}"
|
||||
if track_sid in self.forwarding_tracks:
|
||||
print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
|
||||
return
|
||||
|
||||
self.forwarding_tracks.add(track_sid)
|
||||
print(f"收到音频流: {participant.identity} sid={track_sid} source={source}")
|
||||
asyncio.create_task(
|
||||
self.forward_audio_to_esp32(
|
||||
rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
|
||||
)
|
||||
)
|
||||
|
||||
def _scan_existing_remote_audio_tracks(self) -> None:
|
||||
participants = list(self.room.remote_participants.values())
|
||||
print(f"开始扫描远端音频轨,participants={len(participants)}")
|
||||
for participant in participants:
|
||||
publications = getattr(participant, "track_publications", {}) or {}
|
||||
print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}")
|
||||
for pub_sid, publication in publications.items():
|
||||
track = getattr(publication, "track", None)
|
||||
kind = getattr(publication, "kind", None)
|
||||
print(
|
||||
f"检查 publication: participant={participant.identity} pub_sid={pub_sid} "
|
||||
f"kind={kind} has_track={track is not None}"
|
||||
)
|
||||
if track is not None:
|
||||
self._maybe_forward_remote_audio(track, publication, participant, "scan")
|
||||
|
||||
def _extract_opus_payload(self, message: bytes) -> bytes:
|
||||
if self.protocol_version == 2:
|
||||
if len(message) < 16:
|
||||
raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
|
||||
version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
|
||||
"!HHIII", message[:16]
|
||||
)
|
||||
if version != 2:
|
||||
raise ValueError(f"version 2 包头版本异常: {version}")
|
||||
if msg_type != 0:
|
||||
raise ValueError(f"version 2 音频包类型异常: {msg_type}")
|
||||
if len(message) < 16 + payload_size:
|
||||
raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
|
||||
return message[16:16 + payload_size]
|
||||
|
||||
if self.protocol_version == 3:
|
||||
if len(message) < 4:
|
||||
raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
|
||||
msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
|
||||
if msg_type != 0:
|
||||
raise ValueError(f"version 3 音频包类型异常: {msg_type}")
|
||||
if len(message) < 4 + payload_size:
|
||||
raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
|
||||
return message[4:4 + payload_size]
|
||||
|
||||
return message
|
||||
|
||||
def _wrap_opus_payload(self, payload: bytes) -> bytes:
|
||||
if self.protocol_version == 2:
|
||||
header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
|
||||
return header + payload
|
||||
|
||||
if self.protocol_version == 3:
|
||||
header = struct.pack("!BBH", 0, 0, len(payload))
|
||||
return header + payload
|
||||
|
||||
return payload
|
||||
|
||||
def _wav_writer_loop(self):
|
||||
import wave
|
||||
@ -71,7 +319,7 @@ class ESP32LiveKitBridge:
|
||||
with wave.open("bridge_debug.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2) # 16-bit
|
||||
wav_file.setframerate(SAMPLE_RATE)
|
||||
wav_file.setframerate(INPUT_SAMPLE_RATE)
|
||||
while not self.stop_event.is_set() or not self.audio_queue.empty():
|
||||
try:
|
||||
# 使用 timeout 避免永久阻塞,以便检查 stop_event
|
||||
@ -93,13 +341,23 @@ class ESP32LiveKitBridge:
|
||||
def on_connected():
|
||||
print("✅ 成功连接到 LiveKit 房间")
|
||||
if self.room.remote_participants:
|
||||
self._scan_existing_remote_audio_tracks()
|
||||
self.agent_ready.set()
|
||||
|
||||
@self.room.on("participant_connected")
|
||||
def on_participant_connected(p: rtc.RemoteParticipant):
|
||||
print(f"👋 Agent ({p.identity}) 已加入房间")
|
||||
self._scan_existing_remote_audio_tracks()
|
||||
self.agent_ready.set()
|
||||
|
||||
@self.room.on("participant_disconnected")
|
||||
def on_participant_disconnected(p: rtc.RemoteParticipant):
|
||||
print(f"👋 远端参与者离开房间: {p.identity}")
|
||||
self.forwarding_tracks = {
|
||||
track_sid for track_sid in self.forwarding_tracks
|
||||
if not track_sid.endswith(f":{p.identity}")
|
||||
}
|
||||
|
||||
@self.room.on("data_received")
|
||||
def on_data_received(data_packet: rtc.DataPacket):
|
||||
identity = data_packet.participant.identity if data_packet.participant else "未知"
|
||||
@ -145,6 +403,7 @@ class ESP32LiveKitBridge:
|
||||
print(f"已连接到 LiveKit 房间: {self.room.name}")
|
||||
print(f"[livekit] local_identity={self.room.local_participant.identity}")
|
||||
print(f"[livekit] local_sid={self.room.local_participant.sid}")
|
||||
print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}")
|
||||
|
||||
# 2. 发布麦克风轨道 (ESP32 -> LiveKit)
|
||||
track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
|
||||
@ -154,11 +413,10 @@ class ESP32LiveKitBridge:
|
||||
# 3. 监听房间内的音频 (LiveKit -> ESP32)
|
||||
@self.room.on("track_subscribed")
|
||||
def on_track_subscribed(track, publication, participant):
|
||||
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||
print(f"收到音频流: {participant.identity}")
|
||||
asyncio.create_task(self.forward_audio_to_esp32(rtc.AudioStream(track, sample_rate=SAMPLE_RATE, num_channels=1)))
|
||||
self._maybe_forward_remote_audio(track, publication, participant, "event")
|
||||
|
||||
print("等待 agent 加入...")
|
||||
self._scan_existing_remote_audio_tracks()
|
||||
try:
|
||||
await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
|
||||
print("✅ agent 已就绪")
|
||||
@ -172,53 +430,121 @@ class ESP32LiveKitBridge:
|
||||
await self.room.disconnect()
|
||||
|
||||
async def forward_audio_to_esp32(self, audio_stream):
|
||||
"""从 LiveKit 接收音频,通过 WebSocket 发回给 ESP32"""
|
||||
import opuslib
|
||||
import json
|
||||
# 创建下行 Opus 编码器
|
||||
encoder = opuslib.Encoder(SAMPLE_RATE, 1, 'voip')
|
||||
"""从 LiveKit 接收音频并发给 ESP32"""
|
||||
|
||||
# 1. 告知 ESP32 开始说话,切换 UI 到“说话中”并准备解码
|
||||
if self.esp_ws:
|
||||
await self.esp_ws.send(json.dumps({"type": "tts", "state": "start"}))
|
||||
encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip')
|
||||
pending_pcm = bytearray()
|
||||
pre_roll_pcm = bytearray()
|
||||
pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
|
||||
tts_wav_file = None
|
||||
tts_wav_path = None
|
||||
stream_id = self.tts_stream_id + 1
|
||||
self.tts_stream_id = stream_id
|
||||
|
||||
if SAVE_AGENT_TTS_WAV:
|
||||
tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}.wav"
|
||||
tts_wav_file = wave.open(tts_wav_path, "wb")
|
||||
tts_wav_file.setnchannels(1)
|
||||
tts_wav_file.setsampwidth(2)
|
||||
tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE)
|
||||
print(f"开始保存 Agent TTS 音频: {tts_wav_path}")
|
||||
|
||||
try:
|
||||
async for event in audio_stream:
|
||||
if self.esp_ws:
|
||||
try:
|
||||
# AudioStream 迭代产生的是 AudioFrameEvent,需要从中提取 frame
|
||||
frame = event.frame
|
||||
# 将 PCM 编码为 Opus 才能发给 ESP32
|
||||
pcm_data = frame.data.tobytes()
|
||||
if stream_id != self.tts_stream_id:
|
||||
print("检测到更新的 TTS 流,停止旧流转发")
|
||||
break
|
||||
|
||||
# 使用当前帧的实际采样数进行编码
|
||||
opus_packet = encoder.encode(pcm_data, frame.samples_per_channel)
|
||||
await self.esp_ws.send(opus_packet)
|
||||
frame = event.frame
|
||||
pcm_data = frame.data.tobytes()
|
||||
if tts_wav_file is not None:
|
||||
tts_wav_file.writeframes(pcm_data)
|
||||
|
||||
has_audible_audio = self._has_audible_audio(pcm_data)
|
||||
current_frame_buffered = False
|
||||
if not self.tts_active:
|
||||
pre_roll_pcm.extend(pcm_data)
|
||||
current_frame_buffered = True
|
||||
if len(pre_roll_pcm) > pre_roll_max_bytes:
|
||||
del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes]
|
||||
|
||||
if not has_audible_audio:
|
||||
continue
|
||||
|
||||
await self._start_tts()
|
||||
if not self.tts_active:
|
||||
continue
|
||||
|
||||
print("检测到可听 TTS 音频,切换到 Speaking 并开始转发")
|
||||
pending_pcm.extend(pre_roll_pcm)
|
||||
pre_roll_pcm.clear()
|
||||
|
||||
if has_audible_audio:
|
||||
self._reset_tts_idle_timer()
|
||||
|
||||
# LiveKit 下行帧长度不一定等于 ESP32 协议里声明的帧长。
|
||||
# 这里聚合成固定 20ms PCM,再编码为 Opus 发给设备,减少卡顿。
|
||||
if not current_frame_buffered:
|
||||
pending_pcm.extend(pcm_data)
|
||||
frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2 # mono 16-bit PCM
|
||||
|
||||
while len(pending_pcm) >= frame_bytes and self.esp_ws and stream_id == self.tts_stream_id:
|
||||
try:
|
||||
opus_packet = encoder.encode(
|
||||
bytes(pending_pcm[:frame_bytes]),
|
||||
OUTPUT_SAMPLES_PER_OPUS_FRAME,
|
||||
)
|
||||
del pending_pcm[:frame_bytes]
|
||||
# print(f"发送 TTS 音频包: opus={len(opus_packet)} bytes, protocol={self.protocol_version}")
|
||||
await self.esp_ws.send(self._wrap_opus_payload(opus_packet))
|
||||
except Exception as e:
|
||||
print(f"发送回 ESP32 失败: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"音频流处理错误: {e}")
|
||||
|
||||
finally:
|
||||
# 2. 音频流结束,告知 ESP32 停止说话,切换回聆听或闲置状态
|
||||
if self.esp_ws:
|
||||
await self.esp_ws.send(json.dumps({"type": "tts", "state": "stop"}))
|
||||
print("🎧 TTS 音频结束")
|
||||
if tts_wav_file is not None:
|
||||
tts_wav_file.close()
|
||||
print(f"Agent TTS 音频已保存: {tts_wav_path}")
|
||||
if stream_id == self.tts_stream_id and self.tts_idle_task is not None:
|
||||
self.tts_idle_task.cancel()
|
||||
self.tts_idle_task = None
|
||||
if stream_id == self.tts_stream_id:
|
||||
await self._stop_tts()
|
||||
|
||||
async def handle_websocket(self, websocket):
|
||||
"""处理来自 ESP32 的 WebSocket 连接"""
|
||||
self.esp_ws = websocket
|
||||
header_version = websocket.request.headers.get("Protocol-Version")
|
||||
try:
|
||||
self.protocol_version = int(header_version) if header_version else 1
|
||||
except ValueError:
|
||||
self.protocol_version = 1
|
||||
print("ESP32 已连接")
|
||||
print(f"ESP32 协议版本: {self.protocol_version}")
|
||||
self.tts_stream_id += 1
|
||||
self.tts_active = False
|
||||
if self.tts_idle_task is not None:
|
||||
self.tts_idle_task.cancel()
|
||||
self.tts_idle_task = None
|
||||
opus_decoder = None
|
||||
try:
|
||||
await self.ensure_agent_dispatched()
|
||||
|
||||
# 发送 hello 告诉 ESP32 握手成功
|
||||
hello_msg = {
|
||||
"type": "hello",
|
||||
"transport": "websocket",
|
||||
"audio_params": {
|
||||
"format": "opus", # 明确要求 ESP32 发送 Opus
|
||||
"sample_rate": SAMPLE_RATE,
|
||||
"sample_rate": OUTPUT_SAMPLE_RATE,
|
||||
"channels": 1,
|
||||
"frame_duration": 60
|
||||
"frame_duration": OUTPUT_FRAME_DURATION_MS
|
||||
}
|
||||
}
|
||||
import json
|
||||
await websocket.send(json.dumps(hello_msg))
|
||||
|
||||
async for message in websocket:
|
||||
@ -229,18 +555,14 @@ class ESP32LiveKitBridge:
|
||||
print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
|
||||
continue
|
||||
|
||||
# ESP32 默认使用 websocket_protocol version=1 (见 websocket_protocol.cc)
|
||||
# 这个版本下,没有 4 字节的 header,接收到的就是原生的 Opus 数据帧。
|
||||
# 直接丢给 opuslib 解码即可。
|
||||
audio_data = message
|
||||
print(f"收到音频包长度: {len(message)}")
|
||||
audio_data = self._extract_opus_payload(message)
|
||||
# print(f"收到音频包长度: {len(message)}, Opus 负载长度: {len(audio_data)}")
|
||||
if audio_data:
|
||||
try:
|
||||
# Create Opus decoder if not exists
|
||||
if opus_decoder is None:
|
||||
import opuslib
|
||||
print(f"初始化 Opus 解码器: {SAMPLE_RATE}Hz, mono")
|
||||
opus_decoder = opuslib.Decoder(SAMPLE_RATE, 1)
|
||||
print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
|
||||
opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
|
||||
|
||||
# 启动音频保存线程
|
||||
self.stop_event.clear()
|
||||
@ -249,8 +571,8 @@ class ESP32LiveKitBridge:
|
||||
thread.start()
|
||||
|
||||
# Decode Opus packet.
|
||||
# Frame size for 60ms is SAMPLE_RATE * 0.06
|
||||
frame_size = int(SAMPLE_RATE * 0.06)
|
||||
# ESP32 -> bridge 保持按设备上行配置 60ms 解码。
|
||||
frame_size = INPUT_SAMPLES_PER_OPUS_FRAME
|
||||
pcm_bytes = opus_decoder.decode(audio_data, frame_size)
|
||||
|
||||
# 将音频数据放入队列由后台线程保存
|
||||
@ -260,20 +582,26 @@ class ESP32LiveKitBridge:
|
||||
if num_samples > 0:
|
||||
try:
|
||||
# Use the more robust frame creation logic from test_client_wav.py
|
||||
frame = AudioFrame(pcm_bytes, SAMPLE_RATE, 1, num_samples)
|
||||
frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
|
||||
await self.mic_source.capture_frame(frame)
|
||||
except TypeError:
|
||||
# Fallback for different SDK versions
|
||||
frame = AudioFrame.create(sample_rate=SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
|
||||
frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
|
||||
memoryview(frame.data).cast('B')[:] = pcm_bytes
|
||||
await self.mic_source.capture_frame(frame)
|
||||
except Exception as e:
|
||||
print(f"Opus audio decode error ({len(message)} bytes): {e}")
|
||||
elif isinstance(message, str):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(message)
|
||||
print(f"收到 ESP32 JSON 消息: {data}")
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "abort":
|
||||
reason = data.get("reason")
|
||||
if reason is None:
|
||||
await self._abort_tts("button_abort")
|
||||
else:
|
||||
print(f"忽略非按键打断: reason={reason}")
|
||||
except json.JSONDecodeError:
|
||||
print(f"收到未知的字符消息: {message}")
|
||||
except ConnectionClosedError as e:
|
||||
@ -283,6 +611,11 @@ class ESP32LiveKitBridge:
|
||||
finally:
|
||||
print("ESP32 断开连接")
|
||||
self.esp_ws = None
|
||||
self.tts_stream_id += 1
|
||||
self.tts_active = False
|
||||
if self.tts_idle_task is not None:
|
||||
self.tts_idle_task.cancel()
|
||||
self.tts_idle_task = None
|
||||
if hasattr(self, "wav_writer_thread") and self.wav_writer_thread:
|
||||
self.stop_event.set()
|
||||
# 我们不一定需要 join,因为是 daemon=True
|
||||
|
||||
Reference in New Issue
Block a user