Files
xiaozhi-esp32/main/bridge_server.py
2026-05-25 17:21:11 +08:00

1124 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import base64
import json
import os
import shutil
import struct
import sys
import time
import traceback
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
import httpx
import opuslib
import websockets
from livekit import rtc
from livekit.rtc import AudioFrame, AudioSource
from websockets.exceptions import ConnectionClosedError
try:
from livekit import api as livekit_api
except ImportError:
livekit_api = None
# Token service endpoint queried by fetch_token() for a LiveKit access token.
TOKEN_URL = "http://172.19.0.240:8000/getToken"
# LiveKit server URL; used for the room connection and for the server API client.
LIVEKIT_WS_URL = "ws://172.19.0.240:7880"
# Prefixes joined with a random 8-hex-digit tag to name each session's room / identity.
ROOM_PREFIX = "test-livekit"
IDENTITY_PREFIX = "uv-livekit"
# Agent name used in dispatch requests and when recognising agent participants.
AGENT_NAME = "my-agent"
# Passed to rtc.RoomOptions(connect_timeout=...); overridable through the environment.
CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0"))
# How long to wait for an agent participant once the mic track is published.
AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0"))
# TCP port the WebSocket server listens on for ESP32 clients.
WS_PORT = 8080
# "token": the token request carries the agent name; "bridge": the bridge creates
# the dispatch itself. Any other value disables both code paths.
AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").lower()
# Two directory levels above this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory where received vision frames are written.
VISION_FRAME_SAVE_DIR = Path(os.getenv("VISION_FRAME_SAVE_DIR", str(PROJECT_ROOT / "vision_frames")))
# Uplink (device -> LiveKit): rate used for the Opus decoder and the mic AudioSource.
INPUT_SAMPLE_RATE = 16000
# Downlink (LiveKit -> device): rate used for the AudioStream, the Opus encoder
# and the audio_params advertised in the server hello.
OUTPUT_SAMPLE_RATE = 24000
INPUT_FRAME_DURATION_MS = 20
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
# Frame-size argument given to the Opus decoder (60 ms worth of samples).
INPUT_MAX_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * 60 // 1000
OUTPUT_FRAME_DURATION_MS = 20
# Number of samples encoded into each outgoing Opus packet.
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
# Seconds without audible audio before the watchdog sends tts "stop".
TTS_IDLE_TIMEOUT_SECONDS = 0.25
# Absolute 16-bit sample value at or above which a frame counts as audible.
TTS_SILENCE_PEAK_THRESHOLD = 96
# Amount of audio kept from before TTS start and sent once playback begins.
TTS_PRE_ROLL_MS = 80
# Consecutive audible frames required before tts "start" is sent.
TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = 1
# Consecutive silent frames required after an interrupt before TTS may restart.
TTS_INTERRUPT_SILENCE_FRAMES = 3
# Data-channel topics used when publishing to the agent.
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
# Characters treated as sentence terminators when picking the subtitle to show.
# NOTE(review): this literal may have lost full-width punctuation during
# extraction; verify against the repository copy.
TTS_DISPLAY_SENTENCE_BREAKS = "。!?!?;"
# Texts longer than this many characters are scrolled instead of sent whole.
TTS_DISPLAY_SCROLL_WIDTH = int(os.getenv("TTS_DISPLAY_SCROLL_WIDTH", "18"))
# Delay between two scroll steps.
TTS_DISPLAY_SCROLL_INTERVAL_SECONDS = float(os.getenv("TTS_DISPLAY_SCROLL_INTERVAL_SECONDS", "0.18"))
# Padding appended to the text before it wraps around while scrolling.
# NOTE(review): possibly several spaces originally; confirm whitespace was not collapsed.
TTS_DISPLAY_SCROLL_GAP = " "
# After an abort, downlink audio and agent transcripts are ignored for this long.
TTS_INTERRUPT_SUPPRESS_SECONDS = float(os.getenv("TTS_INTERRUPT_SUPPRESS_SECONDS", "0.8"))
@dataclass
class DeviceSession:
    """Mutable state tying one ESP32 WebSocket connection to one LiveKit room."""

    # Value of the X-Device-Id / Device-Id header, or a generated "ws-<hex>" id.
    device_id: str
    # Client connection; reset to None when the session is closed.
    websocket: Any
    # Taken from the Protocol-Version header; selects the binary audio framing.
    protocol_version: int
    # Room and participant identity generated for this session.
    room_name: str
    identity: str
    # LiveKit room object for this session.
    room: rtc.Room
    # Audio source that receives the PCM decoded from the device's Opus packets.
    mic_source: AudioSource
    # Set as soon as an agent participant is seen in the room.
    agent_ready: asyncio.Event
    # Track sid -> task forwarding that track's audio to the device.
    forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)
    # True between the tts "start" and "stop" messages sent to the device.
    tts_active: bool = False
    # Watchdog task that stops TTS after a period without audible audio.
    tts_idle_task: Optional[asyncio.Task] = None
    # Task scrolling a long subtitle across the device display.
    tts_display_task: Optional[asyncio.Task] = None
    # Bumped on connect, abort and close so in-flight TTS work can detect staleness.
    tts_stream_id: int = 0
    # Last subtitle derived from the agent transcript (used to skip duplicates).
    tts_transcript_text: str = ""
    # Subtitle currently shown / scrolled, and whether it came from a final segment.
    tts_display_text: str = ""
    tts_display_final: bool = False
    # time.monotonic() deadline until which agent audio/transcripts are ignored.
    tts_suppressed_until: float = 0.0
    # Task running the bridge-side agent dispatch, if one was started.
    agent_dispatch_task: Optional[asyncio.Task] = None
    # Guards _close_session against running twice.
    closed: bool = False
    # Count of decoded uplink frames, and time of the most recent capture-log window.
    captured_frame_count: int = 0
    first_capture_log_time: float = 0.0
async def fetch_token(room_name: str, identity: str) -> str:
    """Request a LiveKit access token for *identity* in *room_name*.

    When the dispatch mode is "token" the agent name is added to the query.

    Raises:
        httpx.HTTPStatusError: the token service answered with an error status.
        ValueError: the JSON reply has no non-empty string "token" field.
    """
    query = {"room": room_name, "identity": identity}
    if AGENT_DISPATCH_MODE == "token":
        query["agent_name"] = AGENT_NAME

    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as http:
        reply = await http.get(TOKEN_URL, params=query)
        reply.raise_for_status()
        payload: dict[str, Any] = reply.json()

    jwt = payload.get("token")
    if isinstance(jwt, str) and jwt:
        return jwt
    raise ValueError(f"token response missing token field: {payload}")
class ESP32LiveKitBridge:
def __init__(self) -> None:
self.device_sessions: dict[str, DeviceSession] = {}
def _log_exception(self, prefix: str, exc: BaseException) -> None:
print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
if formatted_tb.strip():
print(formatted_tb.rstrip())
def _build_device_id(self, websocket: Any) -> str:
headers = websocket.request.headers
requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
if requested_id:
return requested_id
return f"ws-{id(websocket):x}"
def _build_session_names(self, device_id: str) -> tuple[str, str]:
    """Generate a fresh (room_name, identity) pair sharing one random tag.

    *device_id* is only used in the log line.
    """
    tag = uuid.uuid4().hex[:8]
    names = (f"{ROOM_PREFIX}-{tag}", f"{IDENTITY_PREFIX}-{tag}")
    print(f"[session] device={device_id} room={names[0]} identity={names[1]}")
    return names
def _is_agent_participant(self, participant: rtc.RemoteParticipant) -> bool:
    """Tell whether *participant*'s identity looks like an agent.

    An identity qualifies when it starts with "agent-" or contains AGENT_NAME.
    A missing or empty identity is treated as the empty string.
    """
    ident = getattr(participant, "identity", "") or ""
    if ident.startswith("agent-"):
        return True
    return AGENT_NAME in ident
def _get_agent_identities(self, session: DeviceSession) -> list[str]:
return [
participant.identity
for participant in session.room.remote_participants.values()
if self._is_agent_participant(participant)
]
def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
agent_identities = self._get_agent_identities(session)
# print(
# f"[agent-check] source={source} room={session.room_name} "
# f"agent_count={len(agent_identities)} agents={agent_identities}"
# )
async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
    """Ask the LiveKit server whether a dispatch already exists for the room.

    Returns False without querying when the livekit api package, the
    LIVEKIT_API_KEY / LIVEKIT_API_SECRET variables, or the
    ListAgentDispatchRequest type are unavailable. Query failures are logged
    and treated as "no dispatch".

    NOTE(review): the call passes a request object to list_dispatch and then
    probes several attribute names on the reply; confirm this matches the
    installed SDK's list_dispatch signature and return type.
    """
    if livekit_api is None:
        return False
    key = os.getenv("LIVEKIT_API_KEY")
    secret = os.getenv("LIVEKIT_API_SECRET")
    if not (key and secret):
        return False
    list_request_type = getattr(livekit_api, "ListAgentDispatchRequest", None)
    if list_request_type is None:
        return False

    client = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=key,
        api_secret=secret,
    )
    try:
        listing = await client.agent_dispatch.list_dispatch(
            list_request_type(room=session.room_name)
        )
        # Use the first non-empty collection among the known attribute names.
        entries: Any = []
        for attribute in ("agent_dispatches", "dispatches", "items"):
            candidate = getattr(listing, attribute, None)
            if candidate:
                entries = candidate
                break
        for entry in entries:
            agent = getattr(entry, "agent_name", None)
            if getattr(entry, "room", None) != session.room_name:
                continue
            # A dispatch without an agent name counts as a match too.
            if agent is not None and agent != AGENT_NAME:
                continue
            print(
                f"检测到已有 dispatch: room={session.room_name} "
                f"agent={agent}"
            )
            return True
    except Exception as exc:
        print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
    finally:
        await client.aclose()
    return False
async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
    """Dispatch the agent from the bridge unless that is unnecessary.

    Nothing happens when the dispatch mode is not "bridge", a dispatch already
    exists, an agent is already in the room, or a dispatch task is still running.
    """
    if AGENT_DISPATCH_MODE != "bridge":
        return
    if await self._has_existing_dispatch(session):
        return
    agent_present = any(
        self._is_agent_participant(remote)
        for remote in session.room.remote_participants.values()
    )
    if agent_present:
        return
    running = session.agent_dispatch_task
    if running is not None and not running.done():
        return
    session.agent_dispatch_task = asyncio.create_task(self._dispatch_agent(session))
    await session.agent_dispatch_task
async def _dispatch_agent(self, session: DeviceSession) -> None:
    """Create an agent dispatch, trying the SDK first and the lk CLI second.

    Failures are logged and never propagate to the caller.
    """
    print(f"准备 dispatch agent: room={session.room_name}, agent={AGENT_NAME}")
    try:
        dispatched = (
            await self._dispatch_agent_with_sdk(session)
            or await self._dispatch_agent_with_cli(session)
        )
        if not dispatched:
            print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
    except Exception as exc:
        print(f"Agent dispatch 失败: {exc}")
async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
    """Create the dispatch through the LiveKit server API.

    Returns False when the api package or the LIVEKIT_API_KEY /
    LIVEKIT_API_SECRET variables are missing, True once the dispatch has been
    created. Errors from the API call propagate to the caller.
    """
    if livekit_api is None:
        return False
    key = os.getenv("LIVEKIT_API_KEY")
    secret = os.getenv("LIVEKIT_API_SECRET")
    if not (key and secret):
        return False

    client = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=key,
        api_secret=secret,
    )
    try:
        metadata = json.dumps(
            {
                "source": "bridge_server",
                "identity": session.identity,
                "device_id": session.device_id,
            }
        )
        request = livekit_api.CreateAgentDispatchRequest(
            agent_name=AGENT_NAME,
            room=session.room_name,
            metadata=metadata,
        )
        dispatch = await client.agent_dispatch.create_dispatch(request)
        print(f"Agent dispatch 已创建: {dispatch}")
    finally:
        # Always release the API client, even when the request failed.
        await client.aclose()
    return True
async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
    """Create the dispatch by running `lk dispatch create`.

    Returns False when the `lk` executable is not on PATH or exits non-zero,
    True on success. The command's stdout / stderr are echoed to the log.
    """
    executable = shutil.which("lk")
    if executable is None:
        return False

    metadata = json.dumps(
        {
            "source": "bridge_server",
            "identity": session.identity,
            "device_id": session.device_id,
        }
    )
    command = [
        executable,
        "dispatch",
        "create",
        "--room",
        session.room_name,
        "--agent-name",
        AGENT_NAME,
        "--metadata",
        metadata,
    ]
    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    raw_out, raw_err = await process.communicate()
    stdout_text = raw_out.decode("utf-8", errors="replace").strip()
    stderr_text = raw_err.decode("utf-8", errors="replace").strip()
    if stdout_text:
        print(f"lk dispatch 输出: {stdout_text}")
    if stderr_text:
        print(f"lk dispatch 错误输出: {stderr_text}")

    if process.returncode == 0:
        print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={AGENT_NAME}")
        return True
    print(f"lk dispatch create 失败,退出码: {process.returncode}")
    return False
async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
    """Publish *payload* as JSON on the room's data channel.

    The packet is addressed to the agents in the room when there are any,
    otherwise it is sent without explicit destinations. The first attempt
    carries the interrupt topic; if publish_data rejects the keyword arguments
    with TypeError, a second attempt is made without the topic.

    Returns:
        True when a publish call succeeded, False otherwise.
    """
    local = getattr(session.room, "local_participant", None)
    if local is None:
        print("跳过发送 agent 控制事件local participant 尚未就绪")
        return False

    encoded = json.dumps(payload).encode("utf-8")
    agent_identities = self._get_agent_identities(session)
    base_options: dict[str, Any] = {}
    if agent_identities:
        base_options["destination_identities"] = agent_identities

    signature_error: Optional[Exception] = None
    candidates = ({"topic": INTERRUPT_TOPIC, **base_options}, base_options)
    for options in candidates:
        try:
            await local.publish_data(encoded, **options)
            print(
                f"已发送 agent 控制事件: topic={options.get('topic', '<default>')} "
                f"targets={agent_identities or 'broadcast'} payload={payload}"
            )
            return True
        except TypeError as exc:
            # Keyword set not accepted; remember it and try the next variant.
            signature_error = exc
        except Exception as exc:
            print(f"发送 agent 控制事件失败: {exc}")
            return False

    if signature_error is not None:
        print(f"发送 agent 控制事件失败publish_data 签名不兼容: {signature_error}")
    return False
async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
    """Send an "interrupt" control event to the agent and warn if it failed."""
    event = {
        "type": "interrupt",
        "topic": INTERRUPT_TOPIC,
        "reason": reason,
        "room": session.room_name,
        "identity": session.identity,
        "device_id": session.device_id,
    }
    if await self._publish_agent_event(session, event):
        return
    print("警告: bridge 已停止 TTS但 agent 侧 interrupt 未确认送出")
def _save_vision_frame(self, session: DeviceSession, image: str) -> Optional[Path]:
    """Decode a base64 image and write it below VISION_FRAME_SAVE_DIR.

    The file is named "<timestamp_ms>_<device id>.jpg", with every character of
    the device id other than alphanumerics, "-" and "_" replaced by "_".

    Returns:
        The written path, or None when the payload is not valid base64 or the
        file could not be written. Both failures are logged, so a full disk or
        a permission problem no longer escapes into the WebSocket handler.
    """
    try:
        image_bytes = base64.b64decode(image, validate=True)
    except (ValueError, TypeError) as exc:
        # binascii.Error (invalid base64) is a ValueError subclass.
        print(f"vision frame base64 解码失败: {exc}")
        return None

    safe_device_id = "".join(
        char if char.isalnum() or char in ("-", "_") else "_"
        for char in session.device_id
    )
    timestamp_ms = int(time.time() * 1000)
    path = VISION_FRAME_SAVE_DIR / f"{timestamp_ms}_{safe_device_id}.jpg"
    try:
        VISION_FRAME_SAVE_DIR.mkdir(parents=True, exist_ok=True)
        path.write_bytes(image_bytes)
    except OSError as exc:
        print(f"vision frame 保存失败: path={path} error={exc}")
        return None
    return path
async def _publish_vision_frame(self, session: DeviceSession, message: dict[str, Any]) -> None:
    """Save a camera frame from the device and forward it to the agent.

    *message* is the decoded JSON from the device; its "image" field must be a
    non-empty base64 string. The frame is stored on disk first, and nothing is
    published when saving fails. Publishing mirrors the control-event path:
    first with the vision topic, then without it if publish_data raises
    TypeError.
    """
    image = message.get("image")
    if not (isinstance(image, str) and image):
        print("收到 vision frame但 image 字段为空")
        return

    saved_path = self._save_vision_frame(session, image)
    if saved_path is None:
        return
    print(f"已保存 vision frame: {saved_path}")

    local = getattr(session.room, "local_participant", None)
    if local is None:
        print("跳过发送 vision framelocal participant 尚未就绪")
        return

    # NOTE(review): the whole base64 image travels in one data packet; confirm
    # this stays within the server's data-packet size limits.
    envelope = {
        "type": "vision_frame",
        "topic": VISION_FRAME_TOPIC,
        "room": session.room_name,
        "identity": session.identity,
        "device_id": session.device_id,
        "mime_type": message.get("mime_type", "image/jpeg"),
        "image": image,
        "saved_path": str(saved_path),
    }
    data = json.dumps(envelope).encode("utf-8")
    agent_identities = self._get_agent_identities(session)
    base_options: dict[str, Any] = {}
    if agent_identities:
        base_options["destination_identities"] = agent_identities

    signature_error: Optional[Exception] = None
    for options in ({"topic": VISION_FRAME_TOPIC, **base_options}, base_options):
        try:
            await local.publish_data(data, **options)
        except TypeError as exc:
            signature_error = exc
            continue
        except Exception as exc:
            print(f"发送 vision frame 失败: {exc}")
            return
        print(
            f"已发送 vision frame: bytes={len(data)} "
            f"targets={agent_identities or 'broadcast'}"
        )
        return

    if signature_error is not None:
        print(f"发送 vision frame 失败publish_data 签名不兼容: {signature_error}")
async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
if session.websocket is None:
print(f"跳过 tts {state}ESP32 尚未连接")
return
await session.websocket.send(json.dumps({"type": "tts", "state": state}))
print(f"已发送 tts {state}: device={session.device_id}")
async def _send_tts_text(self, session: DeviceSession, text: str, final: bool) -> None:
if session.websocket is None:
return
await session.websocket.send(
json.dumps(
{
"type": "tts",
"state": "sentence_start",
"text": text,
"final": final,
}
)
)
def _cancel_tts_display_task(self, session: DeviceSession) -> None:
if session.tts_display_task is not None:
session.tts_display_task.cancel()
session.tts_display_task = None
def _update_tts_display_text(self, session: DeviceSession, text: str, final: bool) -> None:
    """Record the subtitle to show and make sure it reaches the device.

    Text that fits within TTS_DISPLAY_SCROLL_WIDTH is sent once in the
    background, after cancelling any running scroll task. Longer text is
    handled by the scroll task, which is started if it is not already running
    and picks up the new text from the session.

    Must be called from within the running event loop.
    """
    session.tts_display_text = text
    session.tts_display_final = final

    if len(text) <= TTS_DISPLAY_SCROLL_WIDTH:
        self._cancel_tts_display_task(session)
        # asyncio only holds weak references to tasks, so keep the background
        # send alive in a per-bridge set (created on first use) until it is done.
        background = self.__dict__.setdefault("_background_tasks", set())
        send_task = asyncio.create_task(self._send_tts_text(session, text, final))
        background.add(send_task)

        def _on_send_done(done: "asyncio.Task[None]") -> None:
            background.discard(done)
            if done.cancelled():
                return
            # Retrieve the exception so a closed socket is logged here rather
            # than surfacing as "Task exception was never retrieved".
            error = done.exception()
            if error is not None:
                print(f"发送 TTS 字幕失败: device={session.device_id} error={error}")

        send_task.add_done_callback(_on_send_done)
        return

    if session.tts_display_task is None or session.tts_display_task.done():
        session.tts_display_task = asyncio.create_task(
            self._scroll_tts_display_text(session, session.tts_stream_id)
        )
async def _scroll_tts_display_text(self, session: DeviceSession, stream_id: int) -> None:
    """Scroll the session's subtitle across a fixed-width window.

    Runs while *stream_id* is still the session's current stream and the
    websocket is set. Each step re-reads session.tts_display_text, so updates
    are picked up live. The task ends when the text becomes empty, or after
    sending it once when it fits the window. Cancellation is silent; other
    errors are logged.
    """
    position = 0
    previous_window = ""
    try:
        while stream_id == session.tts_stream_id and session.websocket is not None:
            current = session.tts_display_text
            if not current:
                return
            if len(current) <= TTS_DISPLAY_SCROLL_WIDTH:
                # Short enough to show whole: send once (if new) and stop scrolling.
                if current != previous_window:
                    await self._send_tts_text(session, current, session.tts_display_final)
                return

            padded = current + TTS_DISPLAY_SCROLL_GAP
            if position >= len(padded):
                position = 0
            # Append the head again so the window can wrap past the end.
            wrapped = padded + padded[:TTS_DISPLAY_SCROLL_WIDTH]
            window = wrapped[position:position + TTS_DISPLAY_SCROLL_WIDTH].rstrip()
            if window and window != previous_window:
                await self._send_tts_text(
                    session,
                    window,
                    session.tts_display_final,
                )
                previous_window = window
            position += 1
            await asyncio.sleep(TTS_DISPLAY_SCROLL_INTERVAL_SECONDS)
    except asyncio.CancelledError:
        return
    except Exception as exc:
        print(f"TTS 字幕滚动失败: {exc}")
async def _start_tts(self, session: DeviceSession) -> None:
if session.tts_active:
print("跳过 tts start当前已处于激活状态")
return
if time.monotonic() < session.tts_suppressed_until:
print("跳过 tts start中断后的残留音频仍在抑制窗口内")
return
if not session.tts_display_text:
session.tts_transcript_text = ""
session.tts_display_final = False
self._cancel_tts_display_task(session)
await self._send_tts_state(session, "start")
session.tts_active = True
async def _stop_tts(self, session: DeviceSession) -> None:
if not session.tts_active:
print("跳过 tts stop当前未激活")
return
self._cancel_tts_display_task(session)
await self._send_tts_state(session, "stop")
session.tts_active = False
session.tts_transcript_text = ""
session.tts_display_text = ""
session.tts_display_final = False
async def _force_stop_tts(self, session: DeviceSession, reason: str) -> None:
self._cancel_tts_display_task(session)
if session.tts_idle_task is not None:
session.tts_idle_task.cancel()
session.tts_idle_task = None
session.tts_active = False
session.tts_transcript_text = ""
session.tts_display_text = ""
session.tts_display_final = False
await self._send_tts_state(session, "stop")
print(f"已强制停止本地 TTS: device={session.device_id} reason={reason}")
async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
    """Handle a barge-in from the device.

    Invalidates the current TTS stream, opens the suppression window, stops
    local playback, and notifies the agent in the background.
    """
    print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")
    session.tts_stream_id += 1
    suppress_deadline = time.monotonic() + TTS_INTERRUPT_SUPPRESS_SECONDS
    session.tts_suppressed_until = suppress_deadline
    await self._force_stop_tts(session, reason)
    # Fire-and-forget: the interrupt is sent without delaying the caller.
    interrupt = self._send_agent_interrupt(session, reason)
    asyncio.create_task(interrupt)
def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
if session.tts_idle_task is not None:
session.tts_idle_task.cancel()
session.tts_idle_task = asyncio.create_task(
self._tts_idle_watchdog(session, session.tts_stream_id)
)
async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
    """Stop TTS if the idle timeout elapses without the timer being reset.

    Does nothing when the session has moved on to another stream in the
    meantime. Cancellation (a timer reset) ends the task silently.
    """
    try:
        await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
        if stream_id == session.tts_stream_id:
            print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s切回聆听状态")
            await self._stop_tts(session)
    except asyncio.CancelledError:
        return
def _has_audible_audio(self, pcm_data: bytes) -> bool:
if len(pcm_data) < 2:
return False
sample_count = len(pcm_data) // 2
peak = 0
for i in range(sample_count):
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
if sample < 0:
sample = -sample
if sample > peak:
peak = sample
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
return True
return False
def _maybe_forward_remote_audio(
    self,
    session: DeviceSession,
    track: rtc.Track,
    publication: Optional[rtc.TrackPublication],
    participant: rtc.RemoteParticipant,
    source: str,
) -> None:
    """Start forwarding a remote audio track to the device, once per track.

    Non-audio tracks are ignored. The task is keyed by the track sid, falling
    back to the publication sid and then to "audio:<identity>". A live task
    for the same key is left alone; a finished one is replaced. *source* only
    labels the log lines.
    """
    if track.kind != rtc.TrackKind.KIND_AUDIO:
        return

    track_sid = (
        getattr(track, "sid", None)
        or getattr(publication, "sid", None)
        or f"audio:{participant.identity}"
    )
    previous = session.forwarding_tracks.get(track_sid)
    if previous is not None:
        if not previous.done():
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return
        print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
        session.forwarding_tracks.pop(track_sid, None)

    stream = rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
    forwarder = asyncio.create_task(
        self.forward_audio_to_esp32(session, track_sid, stream)
    )
    session.forwarding_tracks[track_sid] = forwarder
    print(
        f"收到音频流: {participant.identity} sid={track_sid} "
        f"source={source} room={session.room_name}"
    )
def _scan_participant_audio_tracks(
self,
session: DeviceSession,
participant: rtc.RemoteParticipant,
source: str,
) -> None:
publications = getattr(participant, "track_publications", None) or {}
for publication in publications.values():
track = getattr(publication, "track", None)
if track is None:
continue
self._maybe_forward_remote_audio(session, track, publication, participant, source)
def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
if session.protocol_version == 2:
if len(message) < 16:
raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
"!HHIII", message[:16]
)
if version != 2:
raise ValueError(f"version 2 包头版本异常: {version}")
if msg_type != 0:
raise ValueError(f"version 2 音频包类型异常: {msg_type}")
if len(message) < 16 + payload_size:
raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
return message[16:16 + payload_size]
if session.protocol_version == 3:
if len(message) < 4:
raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
if msg_type != 0:
raise ValueError(f"version 3 音频包类型异常: {msg_type}")
if len(message) < 4 + payload_size:
raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
return message[4:4 + payload_size]
return message
def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
if session.protocol_version == 2:
header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
return header + payload
if session.protocol_version == 3:
header = struct.pack("!BBH", 0, 0, len(payload))
return header + payload
return payload
def _current_tts_display_text(self, text: str) -> str:
    """Pick the sentence of *text* that should currently be displayed.

    Whitespace is collapsed first. With no sentence break the whole text is
    returned. If the text ends on a break, the last complete sentence is
    returned; otherwise the unfinished tail after the last break. Falls back
    to the whole text when that slice is blank.
    """
    collapsed = " ".join(text.split())
    if not collapsed:
        return ""

    break_positions = [
        position
        for position, char in enumerate(collapsed)
        if char in TTS_DISPLAY_SENTENCE_BREAKS
    ]
    if not break_positions:
        return collapsed

    final_break = break_positions[-1]
    if final_break == len(collapsed) - 1:
        # Text ends with a terminator: start after the break before it, if any.
        earlier_break = break_positions[-2] if len(break_positions) > 1 else -1
        begin = earlier_break + 1
    else:
        begin = final_break + 1
    return collapsed[begin:].strip() or collapsed
def _register_room_handlers(self, session: DeviceSession) -> None:
    """Attach all LiveKit room event handlers for *session*.

    The handlers track agent presence, start audio forwarding for remote audio
    tracks, log data packets, and relay transcriptions to the device (agent
    text as TTS subtitles, other participants' final text as "stt" messages).
    """
    # asyncio keeps only weak references to tasks; park background sends in a
    # per-bridge set (created on first use) until they complete.
    background_tasks = self.__dict__.setdefault("_background_tasks", set())

    def _on_background_done(task: "asyncio.Task[Any]") -> None:
        background_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieve the exception so a closed socket is logged here instead of
        # surfacing later as "Task exception was never retrieved".
        error = task.exception()
        if error is not None:
            print(f"发送 stt 文本失败: device={session.device_id} error={error}")

    def _send_in_background(ws: Any, message: str) -> None:
        task = asyncio.create_task(ws.send(message))
        background_tasks.add(task)
        task.add_done_callback(_on_background_done)

    @session.room.on("connection_state_changed")
    def on_connection_state_changed(state: int) -> None:
        # State changes are currently not logged.
        pass

    @session.room.on("connected")
    def on_connected() -> None:
        print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
        self._log_agent_participants(session, "connected")
        # Participants already in the room do not trigger participant_connected.
        for participant in session.room.remote_participants.values():
            if self._is_agent_participant(participant):
                session.agent_ready.set()
            self._scan_participant_audio_tracks(session, participant, "connected_scan")

    @session.room.on("participant_connected")
    def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
        role = "Agent" if self._is_agent_participant(participant) else "Remote participant"
        print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
        self._log_agent_participants(session, "participant_connected")
        if self._is_agent_participant(participant):
            session.agent_ready.set()
        self._scan_participant_audio_tracks(
            session, participant, "participant_connected_scan"
        )

    @session.room.on("participant_disconnected")
    def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
        print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
        # Only entries keyed by the "audio:<identity>" fallback sid match here.
        session.forwarding_tracks = {
            track_sid: task
            for track_sid, task in session.forwarding_tracks.items()
            if not track_sid.endswith(f":{participant.identity}")
        }

    @session.room.on("data_received")
    def on_data_received(data_packet: rtc.DataPacket) -> None:
        identity = data_packet.participant.identity if data_packet.participant else "未知"
        try:
            # Undecodable bytes are replaced so the packet is still logged.
            text = data_packet.data.decode("utf-8", errors="replace")
        except Exception as exc:
            print(
                f"⚠️ 数据包解码失败: room={session.room_name} "
                f"identity={identity} error={exc}"
            )
            return
        print(f"📩 [数据接收 | room={session.room_name} | {identity}]: {text}")

    @session.room.on("transcription_received")
    def on_transcription_received(
        segments: list[rtc.TranscriptionSegment],
        participant: rtc.Participant,
        track_pub: rtc.TrackPublication,
    ) -> None:
        identity = participant.identity if participant else "未知"
        is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(participant)
        for segment in segments:
            status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
            print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")

            if is_agent:
                # Drop agent text that belongs to a turn that was just interrupted.
                if time.monotonic() < session.tts_suppressed_until:
                    continue
                display_text = self._current_tts_display_text(segment.text)
                if not display_text or display_text == session.tts_transcript_text:
                    continue
                session.tts_transcript_text = display_text
                if session.websocket is not None:
                    self._update_tts_display_text(session, display_text, segment.final)
                continue

            # Non-agent speech: only final results are relayed.
            if not segment.final:
                continue
            ws = session.websocket
            if ws is not None:
                _send_in_background(
                    ws,
                    json.dumps(
                        {
                            "type": "stt",
                            "text": segment.text,
                            "final": segment.final,
                        }
                    ),
                )

    @session.room.on("track_subscribed")
    def on_track_subscribed(
        track: rtc.Track,
        publication: rtc.TrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        self._maybe_forward_remote_audio(session, track, publication, participant, "event")

    @session.room.on("track_published")
    def on_track_published(
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        # A publication may already carry its track; forward it right away.
        track = getattr(publication, "track", None)
        if track is not None:
            self._maybe_forward_remote_audio(session, track, publication, participant, "published")
async def _connect_session_room(self, session: DeviceSession) -> None:
    """Join the session's LiveKit room and publish the device microphone.

    Steps: register event handlers, fetch a token, connect, make sure an agent
    is dispatched, publish the mic track, then wait up to
    AGENT_READY_TIMEOUT_SECONDS for an agent. A timeout is only logged.

    Raises:
        Exception: whatever fetch_token or the room connection raises.
            Connection errors are logged with a troubleshooting hint first.
    """
    self._register_room_handlers(session)
    token = await fetch_token(session.room_name, session.identity)

    connect_options = rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS)
    try:
        await session.room.connect(LIVEKIT_WS_URL, token, options=connect_options)
    except Exception as exc:
        self._log_exception(
            f"连接 LiveKit 房间失败: room={session.room_name}",
            exc,
        )
        print(
            "提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
            "请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
            "*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
        )
        raise

    print(f"已连接到 LiveKit 房间: {session.room.name}")
    self._log_agent_participants(session, "after_connect")
    await self.ensure_agent_dispatched(session)

    mic_track = rtc.LocalAudioTrack.create_audio_track(
        f"esp32-mic-{session.device_id}",
        session.mic_source,
    )
    publish_options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
    await session.room.local_participant.publish_track(mic_track, publish_options)
    self._log_agent_participants(session, "after_publish_mic")

    try:
        await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        print(f"⚠️ agent 等待超时: room={session.room_name}")
async def start(self) -> None:
    """Print the effective bridge configuration."""
    banner = (
        f"[config] websocket_port={WS_PORT}",
        f"[config] livekit_ws_url={LIVEKIT_WS_URL}",
        f"[config] token_url={TOKEN_URL}",
        f"[config] agent_dispatch_mode={AGENT_DISPATCH_MODE}",
    )
    for line in banner:
        print(line)
async def close(self) -> None:
    """Close every session that is currently registered."""
    # Snapshot first so the registry may change while sessions are closing.
    snapshot = tuple(self.device_sessions.values())
    for active in snapshot:
        await self._close_session(active)
async def _close_session(self, session: DeviceSession) -> None:
if session.closed:
return
session.closed = True
session.websocket = None
session.tts_active = False
session.tts_stream_id += 1
if session.tts_idle_task is not None:
session.tts_idle_task.cancel()
session.tts_idle_task = None
self._cancel_tts_display_task(session)
try:
await session.room.disconnect()
except Exception as exc:
print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
async def forward_audio_to_esp32(
    self,
    session: DeviceSession,
    track_sid: str,
    audio_stream: rtc.AudioStream,
) -> None:
    """Re-encode one remote audio track as Opus and stream it to the device.

    For every PCM frame the loop:
      * resynchronises after an interrupt (stream id change) and then waits
        for a run of silent frames before allowing playback again;
      * drops audio while the post-interrupt suppression window is open;
      * while TTS is inactive, keeps a short pre-roll and starts TTS on the
        first audible frame(s);
      * while TTS is active, feeds the idle watchdog and sends the buffered
        audio in fixed-size Opus packets.

    The task stops when the stream ends or the session has been closed. On
    exit it unregisters itself and, if it still owns the current stream, stops
    TTS.
    """
    encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
    pending_pcm = bytearray()
    pre_roll_pcm = bytearray()
    # 16-bit mono PCM: two bytes per sample.
    pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
    frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2
    audible_frame_streak = 0
    silence_frame_streak = 0
    waiting_for_post_interrupt_silence = False
    stream_id = session.tts_stream_id
    print(
        f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
        f"track_sid={track_sid} stream_id={stream_id}"
    )
    try:
        async for event in audio_stream:
            if session.closed:
                # Nothing can be delivered any more; stop instead of buffering
                # audio that will never be sent.
                break

            if stream_id != session.tts_stream_id:
                print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                stream_id = session.tts_stream_id
                if session.tts_active:
                    await self._stop_tts(session)
                continue

            frame = event.frame
            pcm_data = frame.data.tobytes()
            has_audible_audio = self._has_audible_audio(pcm_data)

            if time.monotonic() < session.tts_suppressed_until:
                # Leftover audio of the interrupted turn: discard it.
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                continue

            if waiting_for_post_interrupt_silence:
                if has_audible_audio:
                    silence_frame_streak = 0
                    continue
                silence_frame_streak += 1
                if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
                    continue
                print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
                waiting_for_post_interrupt_silence = False
                silence_frame_streak = 0
                stream_id = session.tts_stream_id
                continue

            current_frame_buffered = False
            if not session.tts_active:
                # Keep a bounded pre-roll so the start of the utterance is not clipped.
                pre_roll_pcm.extend(pcm_data)
                current_frame_buffered = True
                if len(pre_roll_pcm) > pre_roll_max_bytes:
                    del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]
                if not has_audible_audio:
                    audible_frame_streak = 0
                    continue
                audible_frame_streak += 1
                if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
                    continue
                await self._start_tts(session)
                if not session.tts_active:
                    continue
                print(
                    "检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
                    f"(frames={audible_frame_streak})"
                )
                pending_pcm.extend(pre_roll_pcm)
                pre_roll_pcm.clear()
                audible_frame_streak = 0

            if has_audible_audio:
                self._reset_tts_idle_timer(session)
            if not current_frame_buffered:
                pending_pcm.extend(pcm_data)

            while (
                len(pending_pcm) >= frame_bytes
                and stream_id == session.tts_stream_id
                and session.websocket is not None
            ):
                # Take the chunk out of the buffer before encoding, so a chunk
                # that fails is dropped instead of being retried forever.
                chunk = bytes(pending_pcm[:frame_bytes])
                del pending_pcm[:frame_bytes]
                try:
                    opus_packet = encoder.encode(chunk, OUTPUT_SAMPLES_PER_OPUS_FRAME)
                    await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
                except Exception as exc:
                    print(f"发送回 ESP32 失败: {exc}")
                    break
    except Exception as exc:
        print(f"音频流处理错误: {exc}")
    finally:
        print("🎧 TTS 音频结束")
        # Unregister only if the registry still points at this very task.
        task = session.forwarding_tracks.get(track_sid)
        current_task = asyncio.current_task()
        if task is current_task:
            session.forwarding_tracks.pop(track_sid, None)
        if stream_id == session.tts_stream_id and session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
            session.tts_idle_task = None
        if stream_id == session.tts_stream_id:
            await self._stop_tts(session)
async def handle_websocket(self, websocket: Any) -> None:
    """Serve one ESP32 WebSocket connection from handshake to disconnect.

    Creates a session (replacing any earlier session registered under the same
    device id), sends the server hello, joins the LiveKit room, then processes
    messages: binary frames are Opus audio for the agent; text frames are JSON
    control messages ("abort", or "vision" with state "frame").

    A malformed audio packet or an unusable JSON message is logged and skipped
    rather than ending the connection. The session is always closed on exit.
    """
    header_version = websocket.request.headers.get("Protocol-Version")
    try:
        protocol_version = int(header_version) if header_version else 1
    except ValueError:
        protocol_version = 1

    device_id = self._build_device_id(websocket)
    existing_session = self.device_sessions.get(device_id)
    if existing_session is not None:
        print(f"检测到重复 device_id关闭旧 session: device={device_id}")
        await self._close_session(existing_session)
        self.device_sessions.pop(device_id, None)

    room_name, identity = self._build_session_names(device_id)
    session = DeviceSession(
        device_id=device_id,
        websocket=websocket,
        protocol_version=protocol_version,
        room_name=room_name,
        identity=identity,
        room=rtc.Room(),
        mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
        agent_ready=asyncio.Event(),
    )
    self.device_sessions[device_id] = session
    print(f"ESP32 已连接: device={device_id}")
    print(f"ESP32 协议版本: {session.protocol_version}")
    session.tts_stream_id += 1
    # Created lazily on the first audio packet.
    opus_decoder = None
    try:
        hello_msg = {
            "type": "hello",
            "transport": "websocket",
            "session": {
                "room": session.room_name,
                "identity": session.identity,
            },
            "audio_params": {
                "format": "opus",
                "sample_rate": OUTPUT_SAMPLE_RATE,
                "channels": 1,
                "frame_duration": OUTPUT_FRAME_DURATION_MS,
            },
        }
        await websocket.send(json.dumps(hello_msg))
        print(f"已发送 server hello: device={device_id} room={session.room_name}")
        await self._connect_session_room(session)

        async for message in websocket:
            if isinstance(message, bytes):
                if len(message) < 4:
                    print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
                    continue
                try:
                    audio_data = self._extract_opus_payload(session, message)
                except ValueError as exc:
                    # One bad frame must not tear down the whole connection.
                    print(f"丢弃格式错误的音频包 ({len(message)} bytes): {exc}")
                    continue
                if not audio_data:
                    continue
                try:
                    if opus_decoder is None:
                        opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
                    pcm_bytes = opus_decoder.decode(audio_data, INPUT_MAX_SAMPLES_PER_OPUS_FRAME)
                    num_samples = len(pcm_bytes) // 2
                    if num_samples > 0:
                        session.captured_frame_count += 1
                        now = time.monotonic()
                        if (
                            session.captured_frame_count <= 5
                            or now - session.first_capture_log_time >= 5.0
                        ):
                            # Marks a (currently silent) periodic uplink log point.
                            session.first_capture_log_time = now
                        try:
                            frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
                            await session.mic_source.capture_frame(frame)
                        except TypeError:
                            # Fallback for AudioFrame variants that reject the
                            # positional constructor: allocate, then copy in.
                            frame = AudioFrame.create(
                                sample_rate=INPUT_SAMPLE_RATE,
                                num_channels=1,
                                samples_per_channel=num_samples,
                            )
                            memoryview(frame.data).cast("B")[:] = pcm_bytes
                            await session.mic_source.capture_frame(frame)
                except Exception as exc:
                    print(f"Opus audio decode error ({len(message)} bytes): {exc}")
            elif isinstance(message, str):
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    print(f"收到未知的字符消息: {message}")
                    continue
                if not isinstance(data, dict):
                    # Valid JSON that is not an object (list, number, ...) has
                    # no "type" to dispatch on.
                    print(f"收到未知的字符消息: {message}")
                    continue
                msg_type = data.get("type")
                if msg_type == "abort":
                    reason = data.get("reason")
                    abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
                    print(f"处理 ESP32 打断请求: reason={abort_reason}")
                    await self._abort_tts(session, abort_reason)
                elif msg_type == "vision" and data.get("state") == "frame":
                    await self._publish_vision_frame(session, data)
    except ConnectionClosedError as exc:
        print(f"ESP32 异常断开: {exc}")
    except Exception as exc:
        self._log_exception("WebSocket 其他错误", exc)
    finally:
        print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
        await self._close_session(session)
        # A reconnect may already have registered a newer session under this
        # device id; only remove the entry if it is still ours.
        if self.device_sessions.get(device_id) is session:
            self.device_sessions.pop(device_id, None)
async def main() -> None:
    """Run the bridge's WebSocket server until the task is cancelled.

    All sessions are closed on the way out, whatever the reason for exiting.
    """
    bridge = ESP32LiveKitBridge()
    try:
        await bridge.start()
        server = websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT)
        async with server:
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            # A future nobody resolves: blocks until cancellation.
            never_done = asyncio.Future()
            await never_done
    finally:
        await bridge.close()
# Script entry point: run the bridge; report a fatal error on stderr and exit 1.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as exc:
        print(f"[error] {exc}", file=sys.stderr)
        raise SystemExit(1)