# Snapshot of xiaozhi-esp32_bak/main/bridge_server.py captured from a
# repository web viewer on 2026-05-11 13:40:25 +08:00 (643 lines, 26 KiB, Python).
#
# Viewer warning: this file contains Unicode characters that might be confused
# with other characters. If that is intentional, the warning can be ignored.

import asyncio
import websockets
import os
import sys
import httpx
import json
import queue
import shutil
import threading
import struct
import time
import wave
import opuslib
from typing import Any, Optional
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame
from websockets.exceptions import ConnectionClosedError
# The LiveKit server SDK is optional. It is only needed for SDK-based agent
# dispatch; when it is missing, dispatch falls back to the `lk` CLI.
livekit_api = None
try:
    from livekit import api as livekit_api
except ImportError:
    pass
# --- Deployment configuration -----------------------------------------------
# Each connection setting can be overridden by an environment variable named
# BRIDGE_<SETTING>; the fallbacks are the values this bridge was written with.
TOKEN_URL = os.getenv("BRIDGE_TOKEN_URL", "http://10.6.80.130:8000/getToken")
LIVEKIT_WS_URL = os.getenv("BRIDGE_LIVEKIT_WS_URL", "wss://esp32-vt80c4y6.livekit.cloud")
ROOM = os.getenv("BRIDGE_ROOM", "test-livekit-room2")
IDENTITY = os.getenv("BRIDGE_IDENTITY", "uv-livekit-hardcoded")
AGENT_NAME = os.getenv("BRIDGE_AGENT_NAME", "my-agent")
import uuid  # noqa: E402
# For a unique identity per run, set BRIDGE_IDENTITY, or use
# f"uv-{uuid.uuid4().hex[:6]}" as the IDENTITY fallback above.

# Timeout (seconds) passed to the LiveKit room connection.
CONNECT_TIMEOUT_SECONDS = 10.0
# Port of the WebSocket server the ESP32 connects to.
WS_PORT = int(os.getenv("BRIDGE_WS_PORT", "8080"))
# Dispatch AGENT_NAME into ROOM whenever an ESP32 connects (skipped if present).
AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True

# Uplink audio (ESP32 -> LiveKit): decoded Opus at 16 kHz, 60 ms packets.
INPUT_SAMPLE_RATE = 16000
INPUT_FRAME_DURATION_MS = 60
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
# Downlink audio (LiveKit -> ESP32): re-encoded as 20 ms Opus frames at 24 kHz.
OUTPUT_SAMPLE_RATE = 24000
OUTPUT_FRAME_DURATION_MS = 20
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000

# Seconds without audible agent audio before "tts stop" is sent to the ESP32.
TTS_IDLE_TIMEOUT_SECONDS = 0.8
# A PCM frame counts as audible when any |int16 sample| reaches this peak.
TTS_SILENCE_PEAK_THRESHOLD = 96
# Most recent audio (ms) kept while idle and flushed when speech starts.
TTS_PRE_ROLL_MS = 200
# Save the raw agent downlink audio to "<prefix>_<unix time>.wav" files.
SAVE_AGENT_TTS_WAV = True
AGENT_TTS_WAV_PREFIX = "agent_tts"
async def fetch_token() -> str:
    """Request a LiveKit access token for ROOM / IDENTITY from TOKEN_URL.

    Returns:
        The JWT taken from the ``token`` field of the JSON response.

    Raises:
        httpx.HTTPStatusError: if the token service answers with an error status.
        ValueError: if the body is not a JSON object with a non-empty ``token``.
    """
    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        response = await client.get(
            TOKEN_URL,
            params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME},
        )
        response.raise_for_status()
        payload = response.json()
    # Guard before calling .get(): a JSON array or scalar would otherwise
    # surface as an opaque AttributeError.
    if not isinstance(payload, dict):
        raise ValueError(f"token response is not a JSON object: {payload!r}")
    token = payload.get("token")
    if not isinstance(token, str) or not token:
        raise ValueError(f"token response missing token field: {payload}")
    print(f"[token] room={payload.get('room')} identity={payload.get('identity')}")
    # Log only a prefix: the full JWT is a bearer credential for the room.
    print(f"[token] jwt_prefix={token[:16]}... len={len(token)}")
    return token
class ESP32LiveKitBridge:
    """Relay audio between an ESP32 WebSocket client and a LiveKit room.

    * Uplink (ESP32 -> LiveKit): binary WebSocket messages carry Opus packets,
      optionally wrapped in a protocol v2/v3 header. They are decoded to
      16-bit mono PCM at ``INPUT_SAMPLE_RATE``, published through the local
      "esp32-mic" track and copied to ``bridge_debug.wav`` by a writer thread.
    * Downlink (LiveKit -> ESP32): remote audio tracks are read at
      ``OUTPUT_SAMPLE_RATE``, re-framed into ``OUTPUT_FRAME_DURATION_MS``
      chunks, Opus-encoded and sent back. Playback is bracketed by
      ``{"type": "tts", "state": "start" | "stop"}`` control messages.
    * Text: STT transcription segments are pushed to the ESP32 as
      ``{"type": "stt", ...}``. An ``{"type": "abort"}`` message without a
      ``reason`` interrupts the current TTS playback.
    """

    def __init__(self):
        self.room = rtc.Room()
        # Audio source carrying ESP32 microphone audio into LiveKit; its sample
        # rate must match the PCM decoded from the ESP32 uplink.
        self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1)
        self.esp_ws = None  # currently active ESP32 WebSocket connection, if any
        # Decoded uplink PCM, drained by the debug WAV writer thread.
        self.audio_queue = queue.Queue()
        self.wav_writer_thread: Optional[threading.Thread] = None
        self.stop_event = threading.Event()
        self.agent_ready = asyncio.Event()
        self.protocol_version = 1
        self.tts_active = False
        # Track SIDs (or "audio:<identity>" fallbacks) that have a forwarder.
        self.forwarding_tracks: set[str] = set()
        self.tts_idle_task: Optional[asyncio.Task] = None
        # TTS generation counter. It is bumped on abort, on ESP32 (dis)connect
        # and when a forwarder starts; idle watchdogs and buffered downlink
        # audio belonging to an older generation are discarded.
        self.tts_stream_id = 0
        # Identity of the newest downlink forwarder; older forwarders exit.
        self.forwarder_seq = 0
        self.agent_dispatch_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget tasks so they cannot be GC'd.
        self.background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Background task helpers
    # ------------------------------------------------------------------
    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a tracked background task whose errors get logged."""
        task = asyncio.create_task(coro)
        self.background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Drop the reference to a finished task and report its exception, if any."""
        self.background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            print(f"后台任务异常: {exc!r}")

    # ------------------------------------------------------------------
    # Agent dispatch
    # ------------------------------------------------------------------
    def _is_agent_participant(self, participant) -> bool:
        """Return True if *participant* looks like the agent.

        It matches either by LiveKit participant kind (when the SDK exposes
        one) or by AGENT_NAME appearing in the identity.
        """
        agent_kind = getattr(getattr(rtc, "ParticipantKind", None), "PARTICIPANT_KIND_AGENT", None)
        if agent_kind is not None and getattr(participant, "kind", None) == agent_kind:
            return True
        return AGENT_NAME in participant.identity

    async def ensure_agent_dispatched(self) -> None:
        """Dispatch AGENT_NAME into ROOM unless it is already present or pending."""
        if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT:
            return
        for participant in self.room.remote_participants.values():
            if self._is_agent_participant(participant):
                print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
                return
        if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done():
            print("Agent dispatch 正在进行中,跳过重复请求")
            return
        self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent())
        await self.agent_dispatch_task

    async def _dispatch_agent(self) -> None:
        """Try SDK dispatch, then the ``lk`` CLI; failures are logged, never raised."""
        print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}")
        try:
            if await self._dispatch_agent_with_sdk():
                return
            if await self._dispatch_agent_with_cli():
                return
            print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
        except Exception as e:
            print(f"Agent dispatch 失败: {e}")

    async def _dispatch_agent_with_sdk(self) -> bool:
        """Create the dispatch via livekit-api.

        Returns:
            False when livekit-api or LIVEKIT_API_KEY / LIVEKIT_API_SECRET is
            unavailable, True after a successful call.

        Raises:
            Whatever the API call raises.
        """
        if livekit_api is None:
            return False
        api_key = os.getenv("LIVEKIT_API_KEY")
        api_secret = os.getenv("LIVEKIT_API_SECRET")
        if not api_key or not api_secret:
            return False
        lkapi = livekit_api.LiveKitAPI(
            url=LIVEKIT_WS_URL,
            api_key=api_key,
            api_secret=api_secret,
        )
        try:
            dispatch = await lkapi.agent_dispatch.create_dispatch(
                livekit_api.CreateAgentDispatchRequest(
                    agent_name=AGENT_NAME,
                    room=ROOM,
                    metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}),
                )
            )
            print(f"Agent dispatch 已创建: {dispatch}")
        finally:
            await lkapi.aclose()
        return True

    async def _dispatch_agent_with_cli(self) -> bool:
        """Create the dispatch with ``lk dispatch create``.

        Returns:
            False if ``lk`` is not on PATH or exits non-zero, True otherwise.
        """
        lk_path = shutil.which("lk")
        if lk_path is None:
            return False
        process = await asyncio.create_subprocess_exec(
            lk_path,
            "dispatch",
            "create",
            "--room",
            ROOM,
            "--agent-name",
            AGENT_NAME,
            "--metadata",
            json.dumps({"source": "bridge_server", "identity": IDENTITY}),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        stdout_text = stdout.decode("utf-8", errors="replace").strip()
        stderr_text = stderr.decode("utf-8", errors="replace").strip()
        if stdout_text:
            print(f"lk dispatch 输出: {stdout_text}")
        if stderr_text:
            print(f"lk dispatch 错误输出: {stderr_text}")
        if process.returncode != 0:
            print(f"lk dispatch create 失败,退出码: {process.returncode}")
            return False
        print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}")
        return True

    # ------------------------------------------------------------------
    # TTS state machine
    # ------------------------------------------------------------------
    async def _send_tts_state(self, state: str) -> None:
        """Send ``{"type": "tts", "state": state}`` to the ESP32 if connected."""
        if self.esp_ws is None:
            print(f"跳过 tts {state}ESP32 尚未连接")
            return
        await self.esp_ws.send(json.dumps({"type": "tts", "state": state}))
        print(f"已发送 tts {state}")

    async def _start_tts(self) -> None:
        """Send ``tts start``; marks TTS active only if an ESP32 is connected."""
        if self.tts_active:
            print("跳过 tts start当前已处于激活状态")
            return
        await self._send_tts_state("start")
        if self.esp_ws is not None:
            self.tts_active = True

    async def _stop_tts(self) -> None:
        """Send ``tts stop`` and mark TTS inactive, even if the send fails."""
        if not self.tts_active:
            print("跳过 tts stop当前未激活")
            return
        try:
            await self._send_tts_state("stop")
        finally:
            self.tts_active = False

    async def _abort_tts(self, reason: str = "client_abort") -> None:
        """Interrupt playback: start a new TTS generation and send ``tts stop``."""
        print(f"收到打断请求,停止当前 TTS: reason={reason}")
        self.tts_stream_id += 1
        if self.tts_idle_task is not None:
            self.tts_idle_task.cancel()
            self.tts_idle_task = None
        await self._stop_tts()

    def _reset_tts_state(self) -> None:
        """Start a new TTS generation locally, without messaging the device."""
        self.tts_stream_id += 1
        self.tts_active = False
        if self.tts_idle_task is not None:
            self.tts_idle_task.cancel()
            self.tts_idle_task = None

    def _reset_tts_idle_timer(self) -> None:
        """(Re)arm the idle watchdog for the current TTS generation."""
        if self.tts_idle_task is not None:
            self.tts_idle_task.cancel()
        self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id))

    async def _tts_idle_watchdog(self, stream_id: int) -> None:
        """After TTS_IDLE_TIMEOUT_SECONDS, send ``tts stop`` if still the same generation."""
        try:
            await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
            if stream_id != self.tts_stream_id:
                return
            print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s切回聆听状态")
            await self._stop_tts()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task is never awaited; report instead of leaking the error.
            print(f"TTS 空闲处理失败: {e}")

    def _has_audible_audio(self, pcm_data: bytes) -> bool:
        """Return True if any little-endian int16 sample has |value| >= threshold."""
        sample_count = len(pcm_data) // 2
        if sample_count == 0:
            return False
        # One C-level unpack instead of a per-sample Python loop; a trailing
        # odd byte is ignored, as before.
        samples = struct.unpack_from(f"<{sample_count}h", pcm_data)
        threshold = TTS_SILENCE_PEAK_THRESHOLD
        return max(samples) >= threshold or -min(samples) >= threshold

    # ------------------------------------------------------------------
    # Remote track discovery
    # ------------------------------------------------------------------
    def _maybe_forward_remote_audio(
        self,
        track: rtc.Track,
        publication: Optional[rtc.TrackPublication],
        participant: rtc.RemoteParticipant,
        source: str,
    ) -> None:
        """Start a downlink forwarder for *track* unless it is not audio or already forwarded."""
        if track.kind != rtc.TrackKind.KIND_AUDIO:
            return
        track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}"
        if track_sid in self.forwarding_tracks:
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return
        self.forwarding_tracks.add(track_sid)
        print(f"收到音频流: {participant.identity} sid={track_sid} source={source}")
        stream = rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
        self._spawn(self.forward_audio_to_esp32(stream, track_sid))

    def _scan_existing_remote_audio_tracks(self) -> None:
        """Attach forwarders to already-subscribed remote tracks (logging each one)."""
        participants = list(self.room.remote_participants.values())
        print(f"开始扫描远端音频轨participants={len(participants)}")
        for participant in participants:
            publications = getattr(participant, "track_publications", {}) or {}
            print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}")
            for pub_sid, publication in publications.items():
                track = getattr(publication, "track", None)
                kind = getattr(publication, "kind", None)
                print(
                    f"检查 publication: participant={participant.identity} pub_sid={pub_sid} "
                    f"kind={kind} has_track={track is not None}"
                )
                if track is not None:
                    self._maybe_forward_remote_audio(track, publication, participant, "scan")

    # ------------------------------------------------------------------
    # ESP32 wire framing
    # ------------------------------------------------------------------
    def _extract_opus_payload(self, message: bytes) -> bytes:
        """Strip the protocol header from an uplink audio message.

        v2 header is ``!HHIII`` (version, type, reserved, timestamp,
        payload_size; 16 bytes). v3 header is ``!BBH`` (type, reserved,
        payload_size; 4 bytes). Any other version carries raw Opus.

        Raises:
            ValueError: on a truncated packet or an unexpected header field.
        """
        if self.protocol_version == 2:
            if len(message) < 16:
                raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
            version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
                "!HHIII", message[:16]
            )
            if version != 2:
                raise ValueError(f"version 2 包头版本异常: {version}")
            if msg_type != 0:
                raise ValueError(f"version 2 音频包类型异常: {msg_type}")
            if len(message) < 16 + payload_size:
                raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
            return message[16:16 + payload_size]
        if self.protocol_version == 3:
            if len(message) < 4:
                raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
            msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
            if msg_type != 0:
                raise ValueError(f"version 3 音频包类型异常: {msg_type}")
            if len(message) < 4 + payload_size:
                raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
            return message[4:4 + payload_size]
        return message

    def _wrap_opus_payload(self, payload: bytes) -> bytes:
        """Prefix a downlink Opus packet with the header of the negotiated protocol."""
        if self.protocol_version == 2:
            header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
            return header + payload
        if self.protocol_version == 3:
            header = struct.pack("!BBH", 0, 0, len(payload))
            return header + payload
        return payload

    # ------------------------------------------------------------------
    # Uplink debug recording
    # ------------------------------------------------------------------
    def _wav_writer_loop(self):
        """Thread body: drain audio_queue into bridge_debug.wav until stopped and empty."""
        print("启动音频保存线程...")
        try:
            with wave.open("bridge_debug.wav", "wb") as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(INPUT_SAMPLE_RATE)
                while not self.stop_event.is_set() or not self.audio_queue.empty():
                    try:
                        # Poll with a timeout so stop_event is re-checked regularly.
                        pcm_bytes = self.audio_queue.get(timeout=0.5)
                    except queue.Empty:
                        continue
                    wav_file.writeframes(pcm_bytes)
        except Exception as e:
            print(f"音频保存线程错误: {e}")
        finally:
            print("音频保存线程退出")

    async def _restart_wav_writer(self) -> None:
        """Stop any previous writer thread (bounded wait) and start a fresh one."""
        previous = self.wav_writer_thread
        if previous is not None and previous.is_alive():
            # Otherwise clearing stop_event would keep the old thread alive and
            # two threads would write bridge_debug.wav from one queue.
            self.stop_event.set()
            await asyncio.to_thread(previous.join, 2.0)
        self.stop_event.clear()
        thread = threading.Thread(target=self._wav_writer_loop, daemon=True)
        self.wav_writer_thread = thread
        thread.start()

    # ------------------------------------------------------------------
    # LiveKit lifecycle
    # ------------------------------------------------------------------
    def _register_room_handlers(self) -> None:
        """Attach LiveKit room event handlers (logging, agent presence, audio, STT)."""

        @self.room.on("connection_state_changed")
        def on_connection_state_changed(state: int) -> None:
            print(f"[livekit] state={rtc.ConnectionState.Name(state)}")

        @self.room.on("connected")
        def on_connected():
            print("✅ 成功连接到 LiveKit 房间")
            if self.room.remote_participants:
                self._scan_existing_remote_audio_tracks()
                self.agent_ready.set()

        @self.room.on("participant_connected")
        def on_participant_connected(p: rtc.RemoteParticipant):
            print(f"👋 Agent ({p.identity}) 已加入房间")
            self._scan_existing_remote_audio_tracks()
            self.agent_ready.set()

        @self.room.on("participant_disconnected")
        def on_participant_disconnected(p: rtc.RemoteParticipant):
            print(f"👋 远端参与者离开房间: {p.identity}")
            # Drops only "audio:<identity>" fallback keys; real track SIDs are
            # released by their forwarder when its stream ends.
            self.forwarding_tracks = {
                track_sid for track_sid in self.forwarding_tracks
                if not track_sid.endswith(f":{p.identity}")
            }

        @self.room.on("data_received")
        def on_data_received(data_packet: rtc.DataPacket):
            identity = data_packet.participant.identity if data_packet.participant else "未知"
            try:
                print(f"📩 [数据接收 | {identity}]: {data_packet.data.decode('utf-8')}")
            except UnicodeDecodeError:
                print(f"📩 [数据接收 | {identity}]: <{len(data_packet.data)} bytes>")

        @self.room.on("transcription_received")
        def on_transcription_received(
            segments: list[rtc.TranscriptionSegment],
            participant: rtc.Participant,
            track_pub: rtc.TrackPublication
        ):
            identity = participant.identity if participant else "未知"
            for segment in segments:
                status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
                print(f"🗣️ [{status} | {identity}]: {segment.text}")
                # Push the recognition result to the ESP32 in real time.
                ws = self.esp_ws
                if ws is not None:
                    self._spawn(ws.send(json.dumps({
                        "type": "stt",
                        "text": segment.text,
                        "final": segment.final
                    })))

        # Registered before connecting so tracks subscribed during connect are
        # not missed; duplicates from the later scan are filtered by SID.
        @self.room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            self._maybe_forward_remote_audio(track, publication, participant, "event")

    async def start(self):
        """Connect to LiveKit, publish the ESP32 mic track and wait (<=10 s) for the agent."""
        self._register_room_handlers()

        # 1. Fetch a token and connect to LiveKit.
        print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
        print(f"[config] token_url={TOKEN_URL}")
        print(f"[config] room={ROOM} identity={IDENTITY}")
        token = await fetch_token()
        await asyncio.wait_for(
            self.room.connect(
                LIVEKIT_WS_URL,
                token,
                options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
            ),
            timeout=CONNECT_TIMEOUT_SECONDS + 2.0,
        )
        print(f"已连接到 LiveKit 房间: {self.room.name}")
        print(f"[livekit] local_identity={self.room.local_participant.identity}")
        print(f"[livekit] local_sid={self.room.local_participant.sid}")
        print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}")

        # 2. Publish the microphone track (ESP32 -> LiveKit).
        track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
        options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        await self.room.local_participant.publish_track(track, options)

        # 3. Forward room audio (LiveKit -> ESP32) and wait for the agent.
        print("等待 agent 加入...")
        self._scan_existing_remote_audio_tracks()
        try:
            await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
            print("✅ agent 已就绪")
        except asyncio.TimeoutError:
            print("⚠️ agent 等待超时(后续可能收不到音频)")

    async def close(self):
        """Shut down: stop the WAV writer, cancel background work, leave the room."""
        self.stop_event.set()
        if self.tts_idle_task is not None:
            self.tts_idle_task.cancel()
            self.tts_idle_task = None
        pending = [task for task in self.background_tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        if self.room:
            await self.room.disconnect()

    # ------------------------------------------------------------------
    # Downlink: LiveKit -> ESP32
    # ------------------------------------------------------------------
    async def forward_audio_to_esp32(self, audio_stream, track_sid: Optional[str] = None):
        """Forward one remote LiveKit audio stream to the ESP32 as Opus packets.

        While TTS is inactive, frames go into a pre-roll buffer capped at
        ``TTS_PRE_ROLL_MS``. The first audible frame sends ``tts start`` and
        flushes the pre-roll. Audio is then re-framed into fixed
        ``OUTPUT_FRAME_DURATION_MS`` chunks, Opus-encoded and sent. The idle
        watchdog sends ``tts stop`` after ``TTS_IDLE_TIMEOUT_SECONDS`` without
        audible audio.

        If ``tts_stream_id`` changes underneath this forwarder (button abort,
        ESP32 reconnect), buffered audio is dropped. Forwarding then pauses
        until the agent has been quiet for ``TTS_IDLE_TIMEOUT_SECONDS``. The
        forwarder keeps running, so later utterances are still delivered; it
        exits only when the stream ends or a newer forwarder replaces it.

        Args:
            audio_stream: async-iterable of LiveKit audio frame events.
            track_sid: key this stream was registered under in
                ``forwarding_tracks``. It is released when the stream ends, so
                a later scan can attach a new forwarder.
        """
        encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip')
        frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2  # mono 16-bit PCM
        pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
        pending_pcm = bytearray()
        pre_roll_pcm = bytearray()
        # Claim the downlink. Older forwarders see forwarder_seq move and exit;
        # the new generation invalidates their idle watchdogs.
        self.forwarder_seq += 1
        forwarder_id = self.forwarder_seq
        self.tts_stream_id += 1
        generation = self.tts_stream_id
        # Monotonic time of the last audible frame while muted after an abort
        # or reconnect; None while forwarding normally.
        muted_last_audible: Optional[float] = None
        superseded = False
        tts_wav_file = None
        tts_wav_path = None
        if SAVE_AGENT_TTS_WAV:
            tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}.wav"
            tts_wav_file = wave.open(tts_wav_path, "wb")
            tts_wav_file.setnchannels(1)
            tts_wav_file.setsampwidth(2)
            tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE)
            print(f"开始保存 Agent TTS 音频: {tts_wav_path}")
        try:
            async for event in audio_stream:
                if forwarder_id != self.forwarder_seq:
                    print("检测到更新的 TTS 流,停止旧流转发")
                    superseded = True
                    break
                if generation != self.tts_stream_id:
                    # Abort or ESP32 (re)connect: drop the old generation's audio
                    # and swallow the rest of the interrupted utterance.
                    generation = self.tts_stream_id
                    pending_pcm.clear()
                    pre_roll_pcm.clear()
                    muted_last_audible = time.monotonic()
                pcm_data = event.frame.data.tobytes()
                if tts_wav_file is not None:
                    tts_wav_file.writeframes(pcm_data)
                has_audible_audio = self._has_audible_audio(pcm_data)

                if muted_last_audible is not None:
                    now = time.monotonic()
                    if now - muted_last_audible < TTS_IDLE_TIMEOUT_SECONDS:
                        if has_audible_audio:
                            muted_last_audible = now
                        continue
                    muted_last_audible = None

                current_frame_buffered = False
                if not self.tts_active:
                    pre_roll_pcm.extend(pcm_data)
                    current_frame_buffered = True
                    if len(pre_roll_pcm) > pre_roll_max_bytes:
                        del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes]
                    if not has_audible_audio or self.esp_ws is None:
                        continue
                    await self._start_tts()
                    if not self.tts_active:
                        continue
                    print("检测到可听 TTS 音频,切换到 Speaking 并开始转发")
                    # Start the utterance from the pre-roll, not from a stale
                    # sub-frame tail left over by the previous one.
                    pending_pcm.clear()
                    pending_pcm.extend(pre_roll_pcm)
                    pre_roll_pcm.clear()
                if has_audible_audio:
                    self._reset_tts_idle_timer()
                # LiveKit downlink frames need not match the frame length the
                # ESP32 protocol declares. Re-chunk into fixed 20 ms PCM before
                # Opus-encoding to reduce stutter.
                if not current_frame_buffered:
                    pending_pcm.extend(pcm_data)
                while len(pending_pcm) >= frame_bytes and generation == self.tts_stream_id:
                    ws = self.esp_ws
                    if ws is None:
                        break
                    try:
                        opus_packet = encoder.encode(
                            bytes(pending_pcm[:frame_bytes]),
                            OUTPUT_SAMPLES_PER_OPUS_FRAME,
                        )
                        del pending_pcm[:frame_bytes]
                        await ws.send(self._wrap_opus_payload(opus_packet))
                    except Exception as e:
                        print(f"发送回 ESP32 失败: {e}")
                        break
        except Exception as e:
            print(f"音频流处理错误: {e}")
        finally:
            print("🎧 TTS 音频结束")
            if tts_wav_file is not None:
                tts_wav_file.close()
                print(f"Agent TTS 音频已保存: {tts_wav_path}")
            aclose = getattr(audio_stream, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception as e:
                    print(f"音频流关闭失败: {e}")
            # A superseded track keeps its key so scans do not re-attach it;
            # an ended stream releases it so the track can be picked up again.
            if track_sid is not None and not superseded:
                self.forwarding_tracks.discard(track_sid)
            if forwarder_id == self.forwarder_seq:
                if self.tts_idle_task is not None:
                    self.tts_idle_task.cancel()
                    self.tts_idle_task = None
                await self._stop_tts()

    # ------------------------------------------------------------------
    # Uplink: ESP32 -> LiveKit
    # ------------------------------------------------------------------
    async def _capture_pcm(self, pcm_bytes: bytes, num_samples: int) -> None:
        """Push mono 16-bit PCM into the LiveKit mic source."""
        try:
            frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
            await self.mic_source.capture_frame(frame)
        except TypeError:
            # Fallback for SDK versions with a different AudioFrame constructor.
            frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
            memoryview(frame.data).cast('B')[:] = pcm_bytes
            await self.mic_source.capture_frame(frame)

    async def _handle_audio_message(self, message: bytes, opus_decoder):
        """Decode one uplink packet and publish it.

        Returns:
            The (possibly newly created) decoder for this connection.
        """
        try:
            audio_data = self._extract_opus_payload(message)
        except ValueError as e:
            # One malformed packet must not tear down the whole connection.
            print(f"丢弃无效音频包 ({len(message)} bytes): {e}")
            return opus_decoder
        if not audio_data:
            return opus_decoder
        try:
            if opus_decoder is None:
                print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
                opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
                await self._restart_wav_writer()
            # ESP32 -> bridge is decoded with the device's 60 ms uplink frame size.
            pcm_bytes = opus_decoder.decode(audio_data, INPUT_SAMPLES_PER_OPUS_FRAME)
            self.audio_queue.put(pcm_bytes)  # saved by the writer thread
            num_samples = len(pcm_bytes) // 2
            if num_samples > 0:
                await self._capture_pcm(pcm_bytes, num_samples)
        except Exception as e:
            print(f"Opus audio decode error ({len(message)} bytes): {e}")
        return opus_decoder

    async def _handle_text_message(self, message: str) -> None:
        """Handle a JSON control message; only a reason-less ``abort`` acts."""
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            print(f"收到未知的字符消息: {message}")
            return
        print(f"收到 ESP32 JSON 消息: {data}")
        if not isinstance(data, dict):
            return
        if data.get("type") == "abort":
            reason = data.get("reason")
            if reason is None:
                await self._abort_tts("button_abort")
            else:
                print(f"忽略非按键打断: reason={reason}")

    async def handle_websocket(self, websocket):
        """Serve one ESP32 WebSocket connection until it closes."""
        self.esp_ws = websocket
        header_version = websocket.request.headers.get("Protocol-Version")
        try:
            self.protocol_version = int(header_version) if header_version else 1
        except ValueError:
            self.protocol_version = 1
        print("ESP32 已连接")
        print(f"ESP32 协议版本: {self.protocol_version}")
        self._reset_tts_state()
        opus_decoder = None
        try:
            await self.ensure_agent_dispatched()
            # Send hello to tell the ESP32 the handshake succeeded.
            hello_msg = {
                "type": "hello",
                "transport": "websocket",
                "audio_params": {
                    "format": "opus",  # explicitly ask the ESP32 for Opus
                    "sample_rate": OUTPUT_SAMPLE_RATE,
                    "channels": 1,
                    "frame_duration": OUTPUT_FRAME_DURATION_MS
                }
            }
            await websocket.send(json.dumps(hello_msg))
            async for message in websocket:
                if isinstance(message, bytes):
                    opus_decoder = await self._handle_audio_message(message, opus_decoder)
                elif isinstance(message, str):
                    await self._handle_text_message(message)
        except ConnectionClosedError as e:
            print(f"ESP32 异常断开: {e}")
        except Exception as e:
            print(f"WebSocket 其他错误: {e}")
        finally:
            print("ESP32 断开连接")
            # A newer connection may already own the shared state; only this
            # connection's owner may tear it down.
            if self.esp_ws is websocket:
                self.esp_ws = None
                self._reset_tts_state()
                if self.wav_writer_thread is not None:
                    # The daemon writer exits once the queue is drained.
                    self.stop_event.set()
async def main():
    """Join LiveKit, then serve ESP32 WebSocket clients until cancelled."""
    bridge = ESP32LiveKitBridge()
    try:
        await bridge.start()
        server = websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT)
        async with server:
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            # A future that is never resolved: park here until cancelled.
            never_done = asyncio.get_running_loop().create_future()
            await never_done
    finally:
        await bridge.close()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as exc:  # top-level boundary: report and exit non-zero
        sys.stderr.write(f"[error] {exc}\n")
        raise SystemExit(1)