Files
xiaozhi-esp32/main/bridge_server.py
2026-06-12 14:23:41 +08:00

1690 lines
69 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import base64
import contextlib
import json
import os
import re
import shutil
import struct
import sys
import time
import traceback
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
import httpx
import opuslib
import websockets
from livekit import rtc
from livekit.rtc import AudioFrame, AudioSource
from websockets.exceptions import ConnectionClosedError
try:
from livekit import api as livekit_api
except ImportError:
livekit_api = None
# --- Network endpoints ------------------------------------------------------
# Defaults are the previously hard-coded values; each can be overridden per
# deployment without editing the file.
TOKEN_URL = os.getenv("BRIDGE_TOKEN_URL", "http://172.19.0.240:8000/getToken")
LIVEKIT_WS_URL = os.getenv("BRIDGE_LIVEKIT_WS_URL", "ws://172.19.0.240:7880")
ROOM_PREFIX = "test-livekit"
IDENTITY_PREFIX = "uv-livekit"

# --- Agent selection --------------------------------------------------------
LEGACY_AGENT_NAME = os.getenv("LIVEKIT_AGENT_NAME", "normal-agent")
DEFAULT_AGENT_MODE = os.getenv("LIVEKIT_DEFAULT_AGENT_MODE", "normal").strip().lower()
# Agent-mode -> LiveKit agent name.
AGENT_NAMES = {
    "normal": os.getenv("LIVEKIT_NORMAL_AGENT_NAME", LEGACY_AGENT_NAME),
    "beaver": os.getenv("LIVEKIT_BEAVER_AGENT_NAME", "beaver-agent"),
}
# Chat-mode (agent mode optionally combined with vision) -> LiveKit agent name.
CHAT_MODE_AGENT_NAMES = {
    "normal": AGENT_NAMES["normal"],
    "beaver": AGENT_NAMES["beaver"],
    "vision-normal": os.getenv("LIVEKIT_VISION_NORMAL_AGENT_NAME", "vision-normal-agent"),
    "vision-beaver": os.getenv("LIVEKIT_VISION_BEAVER_AGENT_NAME", "vision-beaver-agent"),
}
CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0"))
AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0"))

# --- Websocket server -------------------------------------------------------
WS_PORT = int(os.getenv("BRIDGE_WS_PORT", "8080"))
WS_MAX_QUEUE = int(os.getenv("BRIDGE_WS_MAX_QUEUE", "128"))
WS_MAX_SIZE = int(os.getenv("BRIDGE_WS_MAX_SIZE", str(8 * 1024 * 1024)))
# "token": the agent name is sent to the token service; "bridge": the bridge
# creates the agent dispatch itself.
AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").lower()
PROJECT_ROOT = Path(__file__).resolve().parent.parent
VISION_FRAME_SAVE_DIR = Path(os.getenv("VISION_FRAME_SAVE_DIR", str(PROJECT_ROOT / "vision_frames")))

# --- Audio ------------------------------------------------------------------
INPUT_SAMPLE_RATE = int(os.getenv("BRIDGE_INPUT_SAMPLE_RATE", "16000"))
OUTPUT_SAMPLE_RATE = int(os.getenv("BRIDGE_OUTPUT_SAMPLE_RATE", "24000"))
INPUT_FRAME_DURATION_MS = int(os.getenv("BRIDGE_INPUT_FRAME_DURATION_MS", "60"))
# Sample count of a 120 ms frame at the input rate.
INPUT_MAX_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * 120 // 1000
OUTPUT_FRAME_DURATION_MS = int(os.getenv("BRIDGE_OUTPUT_FRAME_DURATION_MS", "60"))
AUDIO_STATS_INTERVAL_SECONDS = float(os.getenv("BRIDGE_AUDIO_STATS_INTERVAL_SECONDS", "5.0"))
DOWNLINK_SEND_GAP_WARN_MS = float(os.getenv("BRIDGE_DOWNLINK_SEND_GAP_WARN_MS", "180.0"))
UPLINK_CAPTURE_TIMEOUT_SECONDS = float(os.getenv("BRIDGE_UPLINK_CAPTURE_TIMEOUT_SECONDS", "0.25"))

# --- TTS state machine ------------------------------------------------------
TTS_IDLE_TIMEOUT_SECONDS = float(os.getenv("TTS_IDLE_TIMEOUT_SECONDS", "1.2"))
TTS_MIN_ACTIVE_SECONDS = float(os.getenv("TTS_MIN_ACTIVE_SECONDS", "1.0"))
# Minimum absolute int16 sample value that counts as audible.
TTS_SILENCE_PEAK_THRESHOLD = int(os.getenv("TTS_SILENCE_PEAK_THRESHOLD", "96"))
TTS_PRE_ROLL_MS = int(os.getenv("TTS_PRE_ROLL_MS", "480"))
TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = int(os.getenv("TTS_START_CONSECUTIVE_AUDIBLE_FRAMES", "1"))
TTS_INTERRUPT_SILENCE_FRAMES = int(os.getenv("TTS_INTERRUPT_SILENCE_FRAMES", "3"))

# --- LiveKit data topics / attributes ---------------------------------------
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
MCP_TOPIC = "mcp"
AGENT_STATE_ATTRIBUTE = "lk.agent.state"

# --- Subtitles --------------------------------------------------------------
# Characters that end a sentence when choosing which subtitle to show.
# Full-width CJK punctuation (。！？；) and the ASCII forms (!?;) are both
# listed; the previous value had lost the full-width ！？； and duplicated "!?".
TTS_DISPLAY_SENTENCE_BREAKS = "。！？；!?;"
TTS_DISPLAY_SCROLL_WIDTH = int(os.getenv("TTS_DISPLAY_SCROLL_WIDTH", "18"))
TTS_DISPLAY_SCROLL_INTERVAL_SECONDS = float(os.getenv("TTS_DISPLAY_SCROLL_INTERVAL_SECONDS", "0.18"))
TTS_DISPLAY_SCROLL_GAP = " "

# --- Interrupt handling -----------------------------------------------------
TTS_INTERRUPT_SUPPRESS_SECONDS = float(os.getenv("TTS_INTERRUPT_SUPPRESS_SECONDS", "0.8"))
TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS = float(
    os.getenv("TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS", "0.25")
)
TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS = float(
    os.getenv("TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS", "8.0")
)

# --- Emotion tags -----------------------------------------------------------
# Matches an optional leading "<emotion=NAME>" / "emotion=NAME," tag; group 1
# is the emotion name, group 2 the remaining text.
EMOTION_TEXT_PATTERN = re.compile(
    r"^\s*<?\s*emotion\s*=\s*([^\s>,;]+)\s*>?[\s,;]*(.*)$",
    re.DOTALL,
)
EMOTION_TEST_SEQUENCE = [
    emotion.strip()
    for emotion in os.getenv("BRIDGE_EMOTION_TEST_SEQUENCE", "").split(",")
    if emotion.strip()
]
EMOTION_TEST_INTERVAL_SECONDS = float(os.getenv("BRIDGE_EMOTION_TEST_INTERVAL_SECONDS", "2.0"))
@dataclass
class DeviceSession:
    """State for one ESP32 websocket connection bridged to one LiveKit room.

    Holds the LiveKit room and microphone source, the agent selection for
    the connection, and the flags/timestamps of the TTS and subtitle state
    machine driven by ESP32LiveKitBridge.
    """

    # Identifier of the ESP32; used in logs, LiveKit payloads and (sanitised)
    # in saved vision-frame file names.
    device_id: str
    # ESP32 websocket connection; bridge methods treat None as "not connected".
    websocket: Any
    # Binary audio framing: 2 = 16-byte header, 3 = 4-byte header, other = raw Opus.
    protocol_version: int
    # LiveKit room name and the bridge's participant identity in it.
    room_name: str
    identity: str
    # Agent selection; presumably produced by _resolve_agent_selection — confirm at call site.
    chat_mode: str
    agent_mode: str
    agent_name: str
    vision_enabled: bool
    room: rtc.Room
    # LiveKit audio source into which uplink PCM frames are captured.
    mic_source: AudioSource
    # Set once an agent participant is seen in the room.
    agent_ready: asyncio.Event
    # Track sid (or "audio:<identity>" fallback) -> task forwarding that track to the device.
    forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)
    # True after "start" was sent to the device and before "stop".
    tts_active: bool = False
    # True after "thinking" was sent and before start/stop.
    tts_thinking: bool = False
    # Watchdog that stops TTS after a period without audible audio.
    tts_idle_task: Optional[asyncio.Task] = None
    # Subtitle scroll task.
    tts_display_task: Optional[asyncio.Task] = None
    # Incremented on every interrupt; background tasks bound to an older id exit.
    tts_stream_id: int = 0
    # Presumably the accumulated agent transcript for the current turn (reset on start/stop).
    tts_transcript_text: str = ""
    # Latest subtitle text and its "final" flag.
    tts_display_text: str = ""
    tts_display_final: bool = False
    # Emotion for the current turn; only reset in the visible code — confirm where it is set.
    tts_emotion: str = ""
    # time.monotonic() deadline before which TTS start/thinking are ignored after an interrupt.
    tts_suppressed_until: float = 0.0
    # time.monotonic() timestamps used by the idle watchdog.
    tts_started_at: float = 0.0
    tts_last_audible_at: float = 0.0
    # Set on interrupt; blocks TTS resumption until cleared elsewhere.
    tts_waiting_for_user_audio_after_interrupt: bool = False
    # time.monotonic() of the last interrupt; 0.0 means "never interrupted".
    last_interrupt_time: float = 0.0
    # time.monotonic() of the last audible uplink audio (presumably set by the uplink handler).
    last_uplink_audible_time: float = 0.0
    # In-flight bridge-mode agent dispatch, if any.
    agent_dispatch_task: Optional[asyncio.Task] = None
    # Session teardown flag; checked by the emotion test sequence.
    closed: bool = False
    # Uplink capture statistics; usage is outside the visible code — confirm.
    captured_frame_count: int = 0
    first_capture_log_time: float = 0.0
async def fetch_token(
    room_name: str,
    identity: str,
    agent_name: str,
    *,
    timeout: float = 15.0,
) -> str:
    """Request a LiveKit access token from the token service at TOKEN_URL.

    When AGENT_DISPATCH_MODE is "token", *agent_name* is sent as well. The
    token service presumably uses it to request agent dispatch — confirm
    against the service.

    Args:
        room_name: LiveKit room to join.
        identity: Participant identity for the bridge.
        agent_name: Agent to request; only sent in "token" dispatch mode.
        timeout: HTTP timeout in seconds for the request.

    Returns:
        The JWT string from the response's ``token`` field.

    Raises:
        httpx.HTTPError: on transport failures or non-2xx responses.
        ValueError: if the body is not a JSON object with a non-empty string
            ``token`` (JSON decode errors are ValueError subclasses too).
    """
    params = {"room": room_name, "identity": identity}
    if AGENT_DISPATCH_MODE == "token":
        params["agent_name"] = agent_name
    async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
        response = await client.get(TOKEN_URL, params=params)
        response.raise_for_status()
        payload = response.json()
    # A list/str/number body would otherwise crash on .get() with AttributeError.
    if not isinstance(payload, dict):
        raise ValueError(f"token response is not a JSON object: {payload!r}")
    token = payload.get("token")
    if not isinstance(token, str) or not token:
        raise ValueError(f"token response missing token field: {payload}")
    return token
class ESP32LiveKitBridge:
def __init__(self) -> None:
self.device_sessions: dict[str, DeviceSession] = {}
def _log_exception(self, prefix: str, exc: BaseException) -> None:
print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
if formatted_tb.strip():
print(formatted_tb.rstrip())
def _audio_duration_ms(self, sample_count: int, sample_rate: int) -> float:
if sample_rate <= 0:
return 0.0
return sample_count * 1000.0 / sample_rate
def _build_server_hello(self, session: DeviceSession) -> dict[str, Any]:
    """Build the server "hello" reply for an ESP32 handshake.

    The reply echoes the session's LiveKit room and identity. It also
    advertises the downlink audio format: mono Opus at OUTPUT_SAMPLE_RATE
    with OUTPUT_FRAME_DURATION_MS frames.
    """
    audio_params = dict(
        format="opus",
        sample_rate=OUTPUT_SAMPLE_RATE,
        channels=1,
        frame_duration=OUTPUT_FRAME_DURATION_MS,
    )
    session_info = dict(room=session.room_name, identity=session.identity)
    return {
        "type": "hello",
        "transport": "websocket",
        "session": session_info,
        "audio_params": audio_params,
    }
def _log_client_hello(self, session: DeviceSession, message: dict[str, Any]) -> None:
    """Log the uplink audio parameters a client announced in its hello.

    A warning is printed when the announced sample rate or channel count
    differs from the bridge's mono INPUT_SAMPLE_RATE. A warning is also
    printed when the frame duration differs from INPUT_FRAME_DURATION_MS.
    Messages without an ``audio_params`` dict are ignored.
    """
    params = message.get("audio_params")
    if not isinstance(params, dict):
        return
    fmt, sample_rate, channels, frame_duration = (
        params.get(key)
        for key in ("format", "sample_rate", "channels", "frame_duration")
    )
    print(
        "[client-audio] "
        f"device={session.device_id} format={fmt} sample_rate={sample_rate} "
        f"channels={channels} frame_duration={frame_duration}"
    )
    decoder_mismatch = sample_rate != INPUT_SAMPLE_RATE or channels != 1
    if decoder_mismatch:
        print(
            "[client-audio] warning: bridge uplink decode expects "
            f"{INPUT_SAMPLE_RATE}Hz mono, got {sample_rate}Hz channels={channels}"
        )
    if frame_duration != INPUT_FRAME_DURATION_MS:
        print(
            "[client-audio] warning: bridge expects "
            f"{INPUT_FRAME_DURATION_MS}ms uplink frames, got {frame_duration}ms"
        )
def _track_room_connect_task(
    self,
    session: DeviceSession,
    task: asyncio.Task[Any],
) -> None:
    """Done-callback for the background LiveKit connect task.

    A cancelled task is ignored. If the task raised, the error is logged.
    The device websocket, if present, is then closed with code 1011 so the
    client does not stay attached to a room that never connected.
    """
    if task.cancelled():
        return
    try:
        task.result()
    except Exception as exc:
        self._log_exception(
            f"LiveKit 房间连接后台任务失败: room={session.room_name}",
            exc,
        )
        ws = session.websocket
        if ws is None:
            return
        asyncio.create_task(ws.close(code=1011, reason="livekit connect failed"))
async def _capture_mic_frame(
    self,
    session: DeviceSession,
    pcm_bytes: bytes,
    num_samples: int,
) -> bool:
    """Feed one mono uplink PCM frame into ``session.mic_source``.

    The frame is built with the positional ``AudioFrame`` constructor. If
    that raises TypeError, ``AudioFrame.create`` is used instead and the
    bytes are copied into the frame buffer.

    Returns:
        True once the frame was captured. False (with a warning) when the
        capture did not finish within UPLINK_CAPTURE_TIMEOUT_SECONDS; the
        frame is then dropped.
    """
    try:
        frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
    except TypeError:
        frame = AudioFrame.create(
            sample_rate=INPUT_SAMPLE_RATE,
            num_channels=1,
            samples_per_channel=num_samples,
        )
        byte_view = memoryview(frame.data).cast("B")
        byte_view[:] = pcm_bytes
    try:
        await asyncio.wait_for(
            session.mic_source.capture_frame(frame),
            timeout=UPLINK_CAPTURE_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        print(
            "[uplink] warning: capture_frame timeout, dropping frame "
            f"device={session.device_id} samples={num_samples}"
        )
        return False
    else:
        return True
def _build_device_id(self, websocket: Any) -> str:
headers = websocket.request.headers
requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
if requested_id:
return requested_id
return f"ws-{id(websocket):x}"
def _build_session_names(self, device_id: str) -> tuple[str, str]:
    """Generate a fresh ``(room_name, identity)`` pair for a new device session.

    Both names share one random 8-hex-character tag, so the room and the
    bridge participant can be correlated in logs. *device_id* is only logged.
    """
    tag = uuid.uuid4().hex[:8]
    room_name, identity = (f"{prefix}-{tag}" for prefix in (ROOM_PREFIX, IDENTITY_PREFIX))
    print(f"[session] device={device_id} room={room_name} identity={identity}")
    return room_name, identity
def _resolve_agent_selection(self, headers: Any) -> tuple[str, str, str, bool]:
    """Decide which LiveKit agent a connecting device should be paired with.

    Resolution order:
      1. ``Chat-Mode`` / ``X-Chat-Mode`` when it names a known chat mode
         (normal, beaver, vision-normal, vision-beaver).
      2. Otherwise ``Agent-Mode`` / ``X-Agent-Mode``, then DEFAULT_AGENT_MODE,
         then "normal"; vision is disabled on this path.
      3. ``Agent-Name`` / ``X-Agent-Name`` overrides the agent name entirely,
         and the agent mode is then reported as "custom".
    Unknown agent modes fall back to "normal".

    Returns:
        ``(chat_mode, agent_mode, agent_name, vision_enabled)``.
    """
    def first_header(*names: str) -> Any:
        # First truthy header value among *names*, else None.
        for name in names:
            value = headers.get(name)
            if value:
                return value
        return None

    chat_mode = (first_header("Chat-Mode", "X-Chat-Mode") or "").strip().lower()
    known_chat_modes = {
        "normal": ("normal", False),
        "beaver": ("beaver", False),
        "vision-normal": ("normal", True),
        "vision-beaver": ("beaver", True),
    }
    selection = known_chat_modes.get(chat_mode)
    if selection is not None:
        agent_mode, vision_enabled = selection
    else:
        if chat_mode:
            print(f"未知 Chat-Mode={chat_mode!r},回退到 Agent-Mode")
        agent_mode = (
            first_header("Agent-Mode", "X-Agent-Mode")
            or DEFAULT_AGENT_MODE
            or "normal"
        ).strip().lower()
        vision_enabled = False
        chat_mode = agent_mode

    custom_name = first_header("Agent-Name", "X-Agent-Name")
    if custom_name:
        return chat_mode, "custom", custom_name, vision_enabled

    if agent_mode not in AGENT_NAMES:
        print(f"未知 Agent-Mode={agent_mode!r},回退到 normal")
        agent_mode = "normal"
        chat_mode = "vision-normal" if vision_enabled else "normal"

    agent_name = CHAT_MODE_AGENT_NAMES.get(chat_mode)
    if agent_name is None:
        agent_name = AGENT_NAMES[agent_mode]
    return chat_mode, agent_mode, agent_name, vision_enabled
def _is_agent_participant(self, participant: rtc.RemoteParticipant, agent_name: str) -> bool:
    """Heuristically decide whether *participant* is the LiveKit agent.

    A participant counts as the agent when its identity starts with
    ``"agent-"`` or contains *agent_name*.

    An empty *agent_name* is ignored. ``"" in identity`` is always True, so
    an empty name (e.g. an agent-name env var set to "") would otherwise
    flag every remote participant as the agent.
    """
    identity = getattr(participant, "identity", "") or ""
    if identity.startswith("agent-"):
        return True
    return bool(agent_name) and agent_name in identity
def _get_agent_identities(self, session: DeviceSession) -> list[str]:
    """Return identities of remote participants that look like the session's agent."""
    remotes = session.room.remote_participants.values()
    looks_like_agent = self._is_agent_participant
    return [p.identity for p in remotes if looks_like_agent(p, session.agent_name)]
def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
    """Optionally log which agent participants are currently in the room.

    Logging is off by default, matching the previous silent behaviour. Set
    ``BRIDGE_LOG_AGENT_CHECK`` to 1/true/yes/on to enable it. When disabled,
    the participant scan is skipped instead of being computed and thrown
    away.

    Args:
        session: Session whose room is inspected.
        source: Label of the event that triggered the check (logged only).
    """
    flag = os.getenv("BRIDGE_LOG_AGENT_CHECK", "").strip().lower()
    if flag not in ("1", "true", "yes", "on"):
        return
    agent_identities = self._get_agent_identities(session)
    print(
        f"[agent-check] source={source} room={session.room_name} "
        f"agent_count={len(agent_identities)} agents={agent_identities}"
    )
async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
    """Return True if LiveKit already has a dispatch for this room and agent.

    Returns False ("assume no dispatch") in these cases:
      * livekit-api is unavailable;
      * LIVEKIT_API_KEY / LIVEKIT_API_SECRET are unset;
      * the query fails (the error is logged).

    A dispatch matches when its room equals the session room and its agent
    name is empty/unset or equals the session's agent name.

    NOTE(review): livekit-api's ``AgentDispatchService.list_dispatch`` takes
    a room-name string and returns a plain list. The previous code passed a
    request object and only read list-valued attributes off the response,
    so it could not detect a dispatch with that SDK. The room-name form is
    tried first, then the request-object form; both response shapes are
    accepted. Confirm against the installed SDK version.
    """
    if livekit_api is None:
        return False
    api_key = os.getenv("LIVEKIT_API_KEY")
    api_secret = os.getenv("LIVEKIT_API_SECRET")
    if not api_key or not api_secret:
        return False
    lkapi = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=api_key,
        api_secret=api_secret,
    )
    try:
        service = lkapi.agent_dispatch
        try:
            response = await service.list_dispatch(session.room_name)
        except (TypeError, AttributeError):
            # Fallback for an SDK variant whose list_dispatch takes a request object.
            request_cls = getattr(livekit_api, "ListAgentDispatchRequest", None)
            if request_cls is None:
                raise
            response = await service.list_dispatch(request_cls(room=session.room_name))
        if isinstance(response, (list, tuple)):
            dispatches = list(response)
        else:
            dispatches = (
                getattr(response, "agent_dispatches", None)
                or getattr(response, "dispatches", None)
                or getattr(response, "items", None)
                or []
            )
        for dispatch in dispatches:
            dispatch_agent_name = getattr(dispatch, "agent_name", None)
            dispatch_room = getattr(dispatch, "room", None)
            if dispatch_room != session.room_name:
                continue
            # Protobuf string fields default to "" rather than None, so any
            # empty agent name is treated as "unspecified".
            if dispatch_agent_name and dispatch_agent_name != session.agent_name:
                continue
            print(
                f"检测到已有 dispatch: room={session.room_name} "
                f"agent={dispatch_agent_name}"
            )
            return True
    except Exception as exc:
        print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
    finally:
        await lkapi.aclose()
    return False
async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
    """Make sure an agent is, or is being, dispatched to the session's room.

    Acts only when AGENT_DISPATCH_MODE is "bridge"; otherwise it returns
    immediately.

    Cheap local checks run before the remote list_dispatch query: a
    dispatch already in flight, or an agent already in the room. They are
    repeated after the query, because another caller may have started a
    dispatch while this one was awaiting the query.
    """
    if AGENT_DISPATCH_MODE != "bridge":
        return

    def already_handled() -> bool:
        # True when a dispatch is in flight or an agent is already present.
        task = session.agent_dispatch_task
        if task is not None and not task.done():
            return True
        return any(
            self._is_agent_participant(participant, session.agent_name)
            for participant in session.room.remote_participants.values()
        )

    if already_handled():
        return
    if await self._has_existing_dispatch(session):
        return
    # State may have changed while awaiting the remote query.
    if already_handled():
        return
    session.agent_dispatch_task = asyncio.create_task(self._dispatch_agent(session))
    await session.agent_dispatch_task
async def _dispatch_agent(self, session: DeviceSession) -> None:
    """Create an agent dispatch, trying the livekit-api SDK before the lk CLI.

    The first strategy that reports success ends the attempt. Failures are
    logged, never raised.
    """
    print(f"准备 dispatch agent: room={session.room_name}, agent={session.agent_name}")
    try:
        for strategy in (self._dispatch_agent_with_sdk, self._dispatch_agent_with_cli):
            if await strategy(session):
                return
        print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
    except Exception as exc:
        print(f"Agent dispatch 失败: {exc}")
async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
    """Create the agent dispatch through the livekit-api SDK.

    Returns:
        False, without doing anything, when the SDK or the
        LIVEKIT_API_KEY / LIVEKIT_API_SECRET credentials are missing. True
        after the dispatch request succeeded. SDK errors propagate.
    """
    if livekit_api is None:
        return False
    api_key, api_secret = os.getenv("LIVEKIT_API_KEY"), os.getenv("LIVEKIT_API_SECRET")
    if not (api_key and api_secret):
        return False
    client = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=api_key,
        api_secret=api_secret,
    )
    try:
        session_metadata = {
            "source": "bridge_server",
            "identity": session.identity,
            "device_id": session.device_id,
            "chat_mode": session.chat_mode,
            "agent_mode": session.agent_mode,
            "agent_name": session.agent_name,
            "vision_enabled": session.vision_enabled,
        }
        request = livekit_api.CreateAgentDispatchRequest(
            agent_name=session.agent_name,
            room=session.room_name,
            metadata=json.dumps(session_metadata),
        )
        created = await client.agent_dispatch.create_dispatch(request)
        print(f"Agent dispatch 已创建: {created}")
    finally:
        await client.aclose()
    return True
async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
    """Create the agent dispatch by running ``lk dispatch create``.

    Captured stdout/stderr are echoed to the log.

    Returns:
        False when ``lk`` is not on PATH, when the command exits non-zero,
        or when it does not finish within CONNECT_TIMEOUT_SECONDS (the
        process is then killed). True on success.
    """
    lk_path = shutil.which("lk")
    if lk_path is None:
        return False
    metadata = json.dumps(
        {
            "source": "bridge_server",
            "identity": session.identity,
            "device_id": session.device_id,
            "chat_mode": session.chat_mode,
            "agent_mode": session.agent_mode,
            "agent_name": session.agent_name,
            "vision_enabled": session.vision_enabled,
        }
    )
    process = await asyncio.create_subprocess_exec(
        lk_path,
        "dispatch",
        "create",
        "--room",
        session.room_name,
        "--agent-name",
        session.agent_name,
        "--metadata",
        metadata,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            process.communicate(), timeout=CONNECT_TIMEOUT_SECONDS
        )
    except asyncio.TimeoutError:
        # Without a bound, a hung CLI would block ensure_agent_dispatched forever.
        with contextlib.suppress(ProcessLookupError):
            process.kill()
        await process.wait()
        print(f"lk dispatch create 超时({CONNECT_TIMEOUT_SECONDS:.1f}s),已终止进程")
        return False
    stdout_text = stdout.decode("utf-8", errors="replace").strip()
    stderr_text = stderr.decode("utf-8", errors="replace").strip()
    if stdout_text:
        print(f"lk dispatch 输出: {stdout_text}")
    if stderr_text:
        print(f"lk dispatch 错误输出: {stderr_text}")
    if process.returncode != 0:
        print(f"lk dispatch create 失败,退出码: {process.returncode}")
        return False
    print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={session.agent_name}")
    return True
async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
    """Send a JSON control event to the agent over LiveKit data messages.

    The event is addressed to the detected agent identities, or broadcast
    when none are detected. The first attempt uses the INTERRUPT_TOPIC
    topic. If ``publish_data`` rejects that keyword with TypeError, it is
    retried without ``topic``.

    Returns:
        True if a publish call succeeded, otherwise False.
    """
    local = getattr(session.room, "local_participant", None)
    if local is None:
        print("跳过发送 agent 控制事件local participant 尚未就绪")
        return False
    encoded = json.dumps(payload).encode("utf-8")
    targets = self._get_agent_identities(session)
    routing: dict[str, Any] = {"destination_identities": targets} if targets else {}
    signature_error: Optional[Exception] = None
    for options in ({"topic": INTERRUPT_TOPIC, **routing}, routing):
        try:
            await local.publish_data(encoded, **options)
        except TypeError as exc:
            signature_error = exc
            continue
        except Exception as exc:
            print(f"发送 agent 控制事件失败: {exc}")
            return False
        print(
            f"已发送 agent 控制事件: topic={options.get('topic', '<default>')} "
            f"targets={targets or 'broadcast'} payload={payload}"
        )
        return True
    if signature_error is not None:
        print(f"发送 agent 控制事件失败publish_data 签名不兼容: {signature_error}")
    return False
async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
    """Ask the agent to stop speaking; warn if the event could not be sent."""
    event = dict(
        type="interrupt",
        topic=INTERRUPT_TOPIC,
        reason=reason,
        room=session.room_name,
        identity=session.identity,
        device_id=session.device_id,
    )
    delivered = await self._publish_agent_event(session, event)
    if delivered:
        return
    print("警告: bridge 已停止 TTS但 agent 侧 interrupt 未确认送出")
def _save_vision_frame(self, session: DeviceSession, image: str) -> Optional[Path]:
    """Decode a base64 camera frame and write it under VISION_FRAME_SAVE_DIR.

    The file is named ``<epoch_ms>_<sanitised device id>.jpg``. A leading
    ``data:<mime>;base64,`` prefix and embedded whitespace/newlines are
    stripped first, because strict ``validate=True`` decoding rejects both.

    Returns:
        The written path, or None when decoding or writing fails (the
        failure is logged).
    """
    encoded = image.strip()
    if encoded.startswith("data:"):
        _, sep, rest = encoded.partition(",")
        if sep:
            encoded = rest
    encoded = "".join(encoded.split())
    try:
        image_bytes = base64.b64decode(encoded, validate=True)
    except Exception as exc:
        print(f"vision frame base64 解码失败: {exc}")
        return None
    if not image_bytes:
        print("vision frame 解码后为空,跳过保存")
        return None
    # Keep the device id filesystem-safe: only alphanumerics, '-' and '_'.
    safe_device_id = "".join(
        char if char.isalnum() or char in ("-", "_") else "_"
        for char in session.device_id
    )
    timestamp_ms = int(time.time() * 1000)
    path = VISION_FRAME_SAVE_DIR / f"{timestamp_ms}_{safe_device_id}.jpg"
    try:
        VISION_FRAME_SAVE_DIR.mkdir(parents=True, exist_ok=True)
        path.write_bytes(image_bytes)
    except OSError as exc:
        # Honour the Optional contract instead of raising into the caller's thread.
        print(f"vision frame 保存失败: path={path} error={exc}")
        return None
    return path
async def _publish_vision_frame(self, session: DeviceSession, message: dict[str, Any]) -> None:
    """Save an uplinked camera frame to disk and forward it to the agent.

    ``message["image"]`` must be a non-empty base64 string. The frame is
    saved via _save_vision_frame in a worker thread. It is then published
    as a ``vision_frame`` JSON data message on VISION_FRAME_TOPIC, retrying
    without ``topic`` if ``publish_data`` raises TypeError. Failures are
    logged only.
    """
    image = message.get("image")
    if not (isinstance(image, str) and image):
        print("收到 vision frame但 image 字段为空")
        return
    saved_path = await asyncio.to_thread(self._save_vision_frame, session, image)
    if saved_path is None:
        return
    print(f"已保存 vision frame: {saved_path}")
    local = getattr(session.room, "local_participant", None)
    if local is None:
        print("跳过发送 vision framelocal participant 尚未就绪")
        return
    frame_event = {
        "type": "vision_frame",
        "topic": VISION_FRAME_TOPIC,
        "room": session.room_name,
        "identity": session.identity,
        "device_id": session.device_id,
        "mime_type": message.get("mime_type", "image/jpeg"),
        "image": image,
        "saved_path": str(saved_path),
    }
    encoded = json.dumps(frame_event).encode("utf-8")
    targets = self._get_agent_identities(session)
    routing: dict[str, Any] = {"destination_identities": targets} if targets else {}
    signature_error: Optional[Exception] = None
    for options in ({"topic": VISION_FRAME_TOPIC, **routing}, routing):
        try:
            await local.publish_data(encoded, **options)
        except TypeError as exc:
            signature_error = exc
            continue
        except Exception as exc:
            print(f"发送 vision frame 失败: {exc}")
            return
        print(
            f"已发送 vision frame: bytes={len(encoded)} "
            f"targets={targets or 'broadcast'}"
        )
        return
    if signature_error is not None:
        print(f"发送 vision frame 失败publish_data 签名不兼容: {signature_error}")
async def _publish_mcp_message(self, session: DeviceSession, message: dict[str, Any]) -> None:
    """Forward an MCP message from the ESP32 to the agent over LiveKit data.

    ``message["payload"]`` must be a dict. It is wrapped in an ``mcp``
    envelope (room / identity / device_id) and published on MCP_TOPIC to
    the detected agent identities, or broadcast. If ``publish_data`` raises
    TypeError, it is retried without ``topic``. Failures are logged only.
    """
    payload = message.get("payload")
    if not isinstance(payload, dict):
        print(f"收到 ESP32 MCP 消息但缺少 payload: {message}")
        return
    local = getattr(session.room, "local_participant", None)
    if local is None:
        print("跳过发送 MCP 消息local participant 尚未就绪")
        return
    envelope = {
        "type": "mcp",
        "topic": MCP_TOPIC,
        "room": session.room_name,
        "identity": session.identity,
        "device_id": session.device_id,
        "payload": payload,
    }
    encoded = json.dumps(envelope, ensure_ascii=False).encode("utf-8")
    targets = self._get_agent_identities(session)
    routing: dict[str, Any] = {"destination_identities": targets} if targets else {}
    signature_error: Optional[Exception] = None
    for options in ({"topic": MCP_TOPIC, **routing}, routing):
        try:
            await local.publish_data(encoded, **options)
        except TypeError as exc:
            signature_error = exc
            continue
        except Exception as exc:
            print(f"发送 MCP 响应失败: {exc}")
            return
        print(
            f"已发送 MCP 响应: id={payload.get('id')} "
            f"targets={targets or 'broadcast'}"
        )
        return
    if signature_error is not None:
        print(f"发送 MCP 响应失败publish_data 签名不兼容: {signature_error}")
async def _forward_mcp_to_device(
    self,
    session: DeviceSession,
    payload: dict[str, Any],
    *,
    source_identity: str,
) -> None:
    """Relay an MCP request from a LiveKit participant to the ESP32 websocket.

    Skipped (and logged) when the device is not connected. Websocket send
    errors propagate to the caller.
    """
    ws = session.websocket
    if ws is None:
        print("跳过 MCP 请求ESP32 尚未连接")
        return
    envelope = {"type": "mcp", "payload": payload}
    await ws.send(json.dumps(envelope, ensure_ascii=False))
    print(
        f"已转发 MCP 请求到 ESP32: id={payload.get('id')} "
        f"method={payload.get('method')} source={source_identity}"
    )
async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
    """Send ``{"type": "tts", "state": state}`` to the ESP32; skip if disconnected."""
    ws = session.websocket
    if ws is None:
        print(f"跳过 tts {state}ESP32 尚未连接")
        return
    message = json.dumps({"type": "tts", "state": state})
    await ws.send(message)
    print(f"已发送 tts {state}: device={session.device_id}")
async def _send_emotion(self, session: DeviceSession, emotion: str) -> None:
    """Send ``{"type": "llm", "emotion": emotion}`` to the ESP32; skip if disconnected."""
    ws = session.websocket
    if ws is None:
        print(f"跳过 emotion {emotion}ESP32 尚未连接")
        return
    message = json.dumps({"type": "llm", "emotion": emotion})
    await ws.send(message)
    print(f"已发送 emotion: device={session.device_id} emotion={emotion}")
def _parse_emotion_text(self, text: str) -> tuple[Optional[str], str]:
    """Split an optional leading ``emotion=<name>`` tag off *text*.

    Returns:
        ``(emotion, remaining_text)`` with both parts stripped. When no tag
        is present, the emotion is None and the whole text is stripped.
    """
    matched = EMOTION_TEXT_PATTERN.match(text)
    if matched:
        return matched.group(1).strip(), matched.group(2).strip()
    return None, text.strip()
async def _run_emotion_test_sequence(self, session: DeviceSession) -> None:
    """Send each emotion from EMOTION_TEST_SEQUENCE to the device in turn.

    Emotions are spaced EMOTION_TEST_INTERVAL_SECONDS apart. Before each
    step the sequence stops if the websocket is gone or the session closed.
    """
    first = True
    for emotion in EMOTION_TEST_SEQUENCE:
        if session.websocket is None or session.closed:
            return
        if not first:
            await asyncio.sleep(EMOTION_TEST_INTERVAL_SECONDS)
        first = False
        await self._send_emotion(session, emotion)
async def _send_tts_text(self, session: DeviceSession, text: str, final: bool) -> None:
    """Send one subtitle line to the ESP32 as a ``sentence_start`` TTS message.

    Any leading ``emotion=...`` tag is removed first; the parsed emotion is
    not used here. Nothing is sent when the websocket is gone, and text that
    is empty after stripping is logged and skipped.
    """
    if session.websocket is None:
        return
    original = text
    _unused_emotion, cleaned = self._parse_emotion_text(text)
    if not cleaned:
        print(f"[tts->esp32] skip empty text: raw={original!r} final={final}")
        return
    print(f"[tts->esp32] text={cleaned!r} final={final}")
    message = {
        "type": "tts",
        "state": "sentence_start",
        "text": cleaned,
        "final": final,
    }
    await session.websocket.send(json.dumps(message))
def _cancel_tts_display_task(self, session: DeviceSession) -> None:
    """Cancel the session's subtitle-scroll task, if any, and forget it."""
    task, session.tts_display_task = session.tts_display_task, None
    if task is not None:
        task.cancel()
def _update_tts_display_text(self, session: DeviceSession, text: str, final: bool) -> None:
    """Store the latest subtitle text and get it onto the device display.

    Text that fits in TTS_DISPLAY_SCROLL_WIDTH characters is sent once in a
    background task, cancelling any scroll in progress. Longer text is
    handed to a scroll task, which is started only if none is running.
    """
    session.tts_display_text = text
    session.tts_display_final = final
    if len(text) <= TTS_DISPLAY_SCROLL_WIDTH:
        self._cancel_tts_display_task(session)
        asyncio.create_task(self._send_tts_text(session, text, final))
        return
    running = session.tts_display_task
    if running is not None and not running.done():
        return
    session.tts_display_task = asyncio.create_task(
        self._scroll_tts_display_text(session, session.tts_stream_id)
    )
async def _scroll_tts_display_text(self, session: DeviceSession, stream_id: int) -> None:
    """Marquee-scroll long subtitle text across the device display.

    Every TTS_DISPLAY_SCROLL_INTERVAL_SECONDS, a window of
    TTS_DISPLAY_SCROLL_WIDTH characters advances by one character. The
    window runs over ``session.tts_display_text`` (which may still be
    growing) plus TTS_DISPLAY_SCROLL_GAP, wrapping around at the end.

    The loop ends when:
      * the stream id changes or the websocket goes away;
      * the text becomes empty;
      * the text shrinks enough to fit (it is then sent once, unscrolled).

    Cancellation is swallowed; other errors are logged.
    """
    width = TTS_DISPLAY_SCROLL_WIDTH
    position = 0
    previous = ""
    try:
        while stream_id == session.tts_stream_id and session.websocket is not None:
            current = session.tts_display_text
            if not current:
                return
            if len(current) <= width:
                if current != previous:
                    await self._send_tts_text(session, current, session.tts_display_final)
                return
            track = current + TTS_DISPLAY_SCROLL_GAP
            if position >= len(track):
                position = 0
            # Append the head again so a window near the end wraps around.
            wrapped = track + track[:width]
            window = wrapped[position:position + width].rstrip()
            if window and window != previous:
                await self._send_tts_text(session, window, session.tts_display_final)
                previous = window
            position += 1
            await asyncio.sleep(TTS_DISPLAY_SCROLL_INTERVAL_SECONDS)
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        print(f"TTS 字幕滚动失败: {exc}")
async def _start_tts(self, session: DeviceSession) -> None:
    """Switch the device into TTS playback ("start") unless it is blocked.

    Skipped (and logged) in three cases:
      * playback is already active;
      * the post-interrupt listen guard still blocks resumption;
      * the interrupt suppression window is still open.
    Subtitle/emotion state from the previous turn is reset only when no
    subtitle text is pending.

    The active flag is claimed *before* awaiting the websocket send. This
    stops two concurrent callers from both passing the ``tts_active`` check
    and sending a duplicate "start". If the send raises, the previous flags
    are restored.
    """
    if session.tts_active:
        print("跳过 tts start当前已处于激活状态")
        return
    block_reason = self._tts_resume_block_reason(session, include_user_quiet=False)
    if block_reason is not None:
        print(f"跳过 tts start打断后仍在等待稳定聆听: {block_reason}")
        return
    if time.monotonic() < session.tts_suppressed_until:
        print("跳过 tts start中断后的残留音频仍在抑制窗口内")
        return
    if not session.tts_display_text:
        # No subtitle pending: clear leftovers from the previous turn.
        session.tts_transcript_text = ""
        session.tts_display_final = False
        session.tts_emotion = ""
        self._cancel_tts_display_task(session)
    now = time.monotonic()
    session.tts_started_at = now
    session.tts_last_audible_at = now
    was_thinking = session.tts_thinking
    session.tts_active = True
    session.tts_thinking = False
    try:
        await self._send_tts_state(session, "start")
    except BaseException:
        session.tts_active = False
        session.tts_thinking = was_thinking
        raise
async def _start_thinking(self, session: DeviceSession) -> None:
    """Put the device into the "thinking" state unless it is blocked.

    Skipped (and logged) in these cases:
      * TTS is already playing or thinking is already shown;
      * the post-interrupt listen guard still blocks resumption;
      * the interrupt suppression window is still open.

    ``tts_thinking`` is claimed before awaiting the send. This function is
    scheduled as a fire-and-forget task for every "thinking" attribute
    change, so concurrent copies would otherwise all pass the check and
    send duplicates. The flag is cleared again if the send raises.
    """
    if session.tts_active:
        print("跳过 tts thinking当前已处于 TTS 播放状态")
        return
    if session.tts_thinking:
        print("跳过 tts thinking当前已处于思考状态")
        return
    block_reason = self._tts_resume_block_reason(session, include_user_quiet=False)
    if block_reason is not None:
        print(f"跳过 tts thinking打断后仍在等待稳定聆听: {block_reason}")
        return
    if time.monotonic() < session.tts_suppressed_until:
        print("跳过 tts thinking中断后的残留音频仍在抑制窗口内")
        return
    session.tts_thinking = True
    try:
        await self._send_tts_state(session, "thinking")
    except BaseException:
        session.tts_thinking = False
        raise
def _tts_resume_block_reason(
    self,
    session: DeviceSession,
    now: Optional[float] = None,
    *,
    include_user_quiet: bool = True,
) -> Optional[str]:
    """Explain why TTS must not resume yet after an interrupt, or return None.

    Blocking reasons, checked in order:
      * the session is still waiting for user audio after the interrupt
        (``"waiting_for_user_audio_after_interrupt"``);
      * only with *include_user_quiet*: the interrupt is within
        TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS, the user spoke after it,
        and has been quiet for less than
        TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS
        (``"user_audio_quiet_for=<seconds>s"``).

    *now* defaults to ``time.monotonic()``.
    """
    if session.tts_waiting_for_user_audio_after_interrupt:
        return "waiting_for_user_audio_after_interrupt"
    interrupted_at = session.last_interrupt_time
    if interrupted_at <= 0.0:
        return None
    current = time.monotonic() if now is None else now
    within_window = current - interrupted_at <= TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS
    spoke_after_interrupt = session.last_uplink_audible_time >= interrupted_at
    if not (within_window and spoke_after_interrupt and include_user_quiet):
        return None
    quiet_for = current - session.last_uplink_audible_time
    if quiet_for >= TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS:
        return None
    return f"user_audio_quiet_for={quiet_for:.2f}s"
def _handle_agent_state(self, session: DeviceSession, participant: rtc.Participant) -> None:
    """React to the agent's AGENT_STATE_ATTRIBUTE participant attribute.

    Non-empty string states are logged. A "thinking" state schedules
    _start_thinking as a background task.
    """
    state = participant.attributes.get(AGENT_STATE_ATTRIBUTE)
    if not (isinstance(state, str) and state):
        return
    print(
        f"[agent-state] room={session.room_name} identity={participant.identity} state={state}"
    )
    if state != "thinking":
        return
    asyncio.create_task(self._start_thinking(session))
async def _stop_tts(self, session: DeviceSession) -> None:
    """Send "stop" and clear the playback, thinking and subtitle state.

    This is a logged no-op when neither playback nor thinking is active.
    The flags are cleared only after the stop message has been sent.
    """
    if not (session.tts_active or session.tts_thinking):
        print("跳过 tts stop当前未激活")
        return
    self._cancel_tts_display_task(session)
    await self._send_tts_state(session, "stop")
    session.tts_active = session.tts_thinking = False
    session.tts_started_at = session.tts_last_audible_at = 0.0
    session.tts_transcript_text = session.tts_display_text = ""
    session.tts_display_final = False
    session.tts_emotion = ""
async def _force_stop_tts(self, session: DeviceSession, reason: str) -> None:
    """Stop local TTS unconditionally, then send "stop".

    Cancels the subtitle and idle-watchdog tasks and clears all playback
    state before sending, then logs *reason*.
    """
    self._cancel_tts_display_task(session)
    watchdog, session.tts_idle_task = session.tts_idle_task, None
    if watchdog is not None:
        watchdog.cancel()
    session.tts_active = session.tts_thinking = False
    session.tts_started_at = session.tts_last_audible_at = 0.0
    session.tts_transcript_text = session.tts_display_text = ""
    session.tts_display_final = False
    session.tts_emotion = ""
    await self._send_tts_state(session, "stop")
    print(f"已强制停止本地 TTS: device={session.device_id} reason={reason}")
async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
    """Handle a user barge-in: stop local TTS and tell the agent to stop too.

    Steps:
      1. Bump the stream id so tasks bound to the old stream exit.
      2. Record the interrupt time.
      3. Open the suppression window (TTS_INTERRUPT_SUPPRESS_SECONDS).
      4. Require user audio before TTS may resume.
      5. Force-stop local TTS.
      6. Schedule the agent interrupt event in the background.
    """
    print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")
    interrupted_at = time.monotonic()
    session.tts_stream_id += 1
    session.last_interrupt_time = interrupted_at
    session.tts_suppressed_until = interrupted_at + TTS_INTERRUPT_SUPPRESS_SECONDS
    session.tts_waiting_for_user_audio_after_interrupt = True
    await self._force_stop_tts(session, reason)
    asyncio.create_task(self._send_agent_interrupt(session, reason))
def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
    """Mark audio as just heard and restart the idle watchdog for the current stream."""
    session.tts_last_audible_at = time.monotonic()
    previous = session.tts_idle_task
    if previous is not None:
        previous.cancel()
    watchdog = self._tts_idle_watchdog(session, session.tts_stream_id)
    session.tts_idle_task = asyncio.create_task(watchdog)
async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
    """Return the device to listening once TTS audio has gone quiet.

    Wakes after TTS_IDLE_TIMEOUT_SECONDS. It calls _stop_tts once both
    conditions hold:
      (a) at least TTS_IDLE_TIMEOUT_SECONDS have passed since
          ``session.tts_last_audible_at``;
      (b) playback has been active for at least TTS_MIN_ACTIVE_SECONDS.
    It exits without stopping if the stream id changed or TTS is no longer
    active. Cancellation (e.g. from _reset_tts_idle_timer) is swallowed.

    Fix: after waiting out the remaining time, the condition is re-checked
    straight away. The previous loop slept another full
    TTS_IDLE_TIMEOUT_SECONDS on top, stopping TTS up to one timeout late.
    """
    delay = TTS_IDLE_TIMEOUT_SECONDS
    try:
        while True:
            await asyncio.sleep(delay)
            if stream_id != session.tts_stream_id or not session.tts_active:
                return
            now = time.monotonic()
            idle_for = now - session.tts_last_audible_at
            active_for = now - session.tts_started_at
            remaining = max(
                TTS_IDLE_TIMEOUT_SECONDS - idle_for,
                TTS_MIN_ACTIVE_SECONDS - active_for,
            )
            if remaining > 0:
                # Sleep exactly the missing time, then re-check.
                delay = remaining
                continue
            print(
                "TTS 静音达到阈值,切回聆听状态: "
                f"idle={idle_for:.2f}s active={active_for:.2f}s"
            )
            await self._stop_tts(session)
            return
    except asyncio.CancelledError:
        pass
def _has_audible_audio(self, pcm_data: bytes) -> bool:
if len(pcm_data) < 2:
return False
sample_count = len(pcm_data) // 2
peak = 0
for i in range(sample_count):
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
if sample < 0:
sample = -sample
if sample > peak:
peak = sample
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
return True
return False
def _maybe_forward_remote_audio(
    self,
    session: DeviceSession,
    track: rtc.Track,
    publication: Optional[rtc.TrackPublication],
    participant: rtc.RemoteParticipant,
    source: str,
) -> None:
    """Start forwarding a remote audio track to the ESP32 unless it is already forwarded.

    Non-audio tracks are ignored. Tracks are keyed by the track sid, then
    the publication sid, then ``audio:<identity>``. A finished forwarding
    task under the same key is discarded and replaced. *source* is used
    only in log lines.
    """
    if track.kind != rtc.TrackKind.KIND_AUDIO:
        return
    track_sid = (
        getattr(track, "sid", None)
        or getattr(publication, "sid", None)
        or f"audio:{participant.identity}"
    )
    existing = session.forwarding_tracks.get(track_sid)
    if existing is not None:
        if not existing.done():
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return
        print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
        session.forwarding_tracks.pop(track_sid, None)
    session.forwarding_tracks[track_sid] = asyncio.create_task(
        self.forward_audio_to_esp32(
            session,
            track_sid,
            rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1),
        )
    )
    print(
        f"收到音频流: {participant.identity} sid={track_sid} "
        f"source={source} room={session.room_name}"
    )
def _scan_participant_audio_tracks(
    self,
    session: DeviceSession,
    participant: rtc.RemoteParticipant,
    source: str,
) -> None:
    """Offer each of *participant*'s publications that has a track to _maybe_forward_remote_audio."""
    publications = getattr(participant, "track_publications", None) or {}
    candidates = (
        (publication, getattr(publication, "track", None))
        for publication in publications.values()
    )
    for publication, track in candidates:
        if track is not None:
            self._maybe_forward_remote_audio(session, track, publication, participant, source)
def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
    """Strip the binary-protocol header from an uplink audio packet.

    Header layout by protocol version (network byte order):
      * v2: 16-byte header ``(version:u16, type:u16, reserved:u32,
        timestamp:u32, size:u32)``;
      * v3: 4-byte header ``(type:u8, reserved:u8, size:u16)``;
      * any other version: the packet is raw Opus and returned unchanged.
    Bytes beyond ``size`` are ignored.

    Raises:
        ValueError: if the packet is shorter than its header or declared
            size, if the v2 header version is not 2, or if the message type
            is not 0 (audio).
    """
    version = session.protocol_version
    if version == 2:
        header_len = 16
        if len(message) < header_len:
            raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
        header_version, msg_type, _reserved, _timestamp, size = struct.unpack_from(
            "!HHIII", message
        )
        if header_version != 2:
            raise ValueError(f"version 2 包头版本异常: {header_version}")
        if msg_type != 0:
            raise ValueError(f"version 2 音频包类型异常: {msg_type}")
        end = header_len + size
        if len(message) < end:
            raise ValueError(f"version 2 音频包长度不足: {len(message)} < {end}")
        return message[header_len:end]
    if version == 3:
        header_len = 4
        if len(message) < header_len:
            raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
        msg_type, _reserved, size = struct.unpack_from("!BBH", message)
        if msg_type != 0:
            raise ValueError(f"version 3 音频包类型异常: {msg_type}")
        end = header_len + size
        if len(message) < end:
            raise ValueError(f"version 3 音频包长度不足: {len(message)} < {end}")
        return message[header_len:end]
    return message
def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
    """Prefix a downlink Opus packet with the header for the session's protocol version.

    v2 gets the 16-byte ``!HHIII`` header (version 2, type 0, zero reserved
    and timestamp, payload size). v3 gets the 4-byte ``!BBH`` header; a v3
    payload over 65535 bytes raises struct.error. Other versions return the
    payload unchanged.
    """
    header_builders = {
        2: lambda size: struct.pack("!HHIII", 2, 0, 0, 0, size),
        3: lambda size: struct.pack("!BBH", 0, 0, size),
    }
    build_header = header_builders.get(session.protocol_version)
    if build_header is None:
        return payload
    return build_header(len(payload)) + payload
def _current_tts_display_text(self, text: str) -> str:
    """Pick the sentence of *text* that should currently be on screen.

    Whitespace runs are collapsed to single spaces. Sentence breaks are the
    characters in TTS_DISPLAY_SENTENCE_BREAKS.
      * Text that ends with a break: the last complete sentence is returned.
      * Otherwise: the unfinished tail after the last break.
      * Text with no breaks is returned whole.
    An empty selection falls back to the whole normalised text.
    """
    normalized = " ".join(text.split())
    if not normalized:
        return ""
    breaks = [
        index
        for index, char in enumerate(normalized)
        if char in TTS_DISPLAY_SENTENCE_BREAKS
    ]
    if not breaks:
        return normalized
    last = breaks[-1]
    if last == len(normalized) - 1:
        start = breaks[-2] + 1 if len(breaks) > 1 else 0
    else:
        start = last + 1
    return normalized[start:].strip() or normalized
def _register_room_handlers(self, session: DeviceSession) -> None:
    """Attach LiveKit room event handlers that bridge room activity to the device.

    The handlers:

    * mark ``session.agent_ready`` once an agent participant is present, scan
      its audio tracks and relay its state to ``_handle_agent_state``;
    * relay agent-state attribute changes to ``_handle_agent_state``;
    * forward MCP data packets (by topic or ``"type": "mcp"``) to the device;
    * turn transcriptions into device messages: agent text updates the
      on-device TTS text (and emotion), while final non-agent transcripts
      start the "thinking" state and are sent to the device as ``stt``;
    * hand subscribed/published remote tracks to ``_maybe_forward_remote_audio``.

    Coroutines started from these synchronous callbacks are kept in a
    closure-held set until they finish: asyncio only keeps weak references
    to tasks, so an unreferenced task may be garbage-collected mid-flight.
    Failures of those tasks are printed instead of being dropped.

    Args:
        session: The device session whose ``room`` receives the handlers.
    """
    background_tasks: set[asyncio.Task[Any]] = set()

    def on_background_task_done(task: asyncio.Task[Any]) -> None:
        background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            print(f"后台任务失败: room={session.room_name} error={exc!r}")

    def spawn(coro: Any) -> None:
        # Hold a strong reference until the task completes (see docstring).
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(on_background_task_done)

    @session.room.on("connection_state_changed")
    def on_connection_state_changed(state: int) -> None:
        # Intentionally a no-op hook for connection-state tracing.
        pass

    @session.room.on("connected")
    def on_connected() -> None:
        print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
        self._log_agent_participants(session, "connected")
        # Agents that joined before we connected produce no participant_connected event.
        for participant in session.room.remote_participants.values():
            if self._is_agent_participant(participant, session.agent_name):
                session.agent_ready.set()
                self._scan_participant_audio_tracks(session, participant, "connected_scan")
                self._handle_agent_state(session, participant)

    @session.room.on("participant_connected")
    def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
        role = "Agent" if self._is_agent_participant(participant, session.agent_name) else "Remote participant"
        print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
        self._log_agent_participants(session, "participant_connected")
        if self._is_agent_participant(participant, session.agent_name):
            session.agent_ready.set()
            self._scan_participant_audio_tracks(
                session, participant, "participant_connected_scan"
            )
            self._handle_agent_state(session, participant)

    @session.room.on("participant_disconnected")
    def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
        print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
        # Forget this participant's forwarders. NOTE(review): assumes keys end
        # with ":<identity>"; the tasks themselves are not cancelled here.
        session.forwarding_tracks = {
            track_sid: task
            for track_sid, task in session.forwarding_tracks.items()
            if not track_sid.endswith(f":{participant.identity}")
        }

    @session.room.on("participant_attributes_changed")
    def on_participant_attributes_changed(changed: list[str], participant: rtc.Participant) -> None:
        if AGENT_STATE_ATTRIBUTE not in changed:
            return
        if not isinstance(participant, rtc.RemoteParticipant):
            return
        if not self._is_agent_participant(participant, session.agent_name):
            return
        self._handle_agent_state(session, participant)

    @session.room.on("data_received")
    def on_data_received(data_packet: rtc.DataPacket) -> None:
        identity = data_packet.participant.identity if data_packet.participant else "未知"
        packet_topic = getattr(data_packet, "topic", None)
        try:
            decoded = data_packet.data.decode("utf-8")
            print(
                f"📩 [数据接收 | room={session.room_name} | {identity}]: "
                f"{decoded}"
            )
        except Exception:
            decoded = ""
        try:
            payload = json.loads(decoded) if decoded else None
        except Exception:
            payload = None
        if isinstance(payload, dict) and (
            packet_topic == MCP_TOPIC
            or payload.get("type") == "mcp"
            or payload.get("topic") == MCP_TOPIC
        ):
            mcp_payload = payload.get("payload")
            if isinstance(mcp_payload, dict):
                spawn(
                    self._forward_mcp_to_device(
                        session,
                        mcp_payload,
                        source_identity=identity,
                    )
                )
            else:
                print(f"收到 MCP 数据但缺少 payload: {payload}")

    @session.room.on("transcription_received")
    def on_transcription_received(
        segments: list[rtc.TranscriptionSegment],
        participant: rtc.Participant,
        track_pub: rtc.TrackPublication,
    ) -> None:
        identity = participant.identity if participant else "未知"
        is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(
            participant, session.agent_name
        )
        for segment in segments:
            status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
            print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")
            if is_agent:
                # Agent text is ignored while TTS output is suppressed.
                if time.monotonic() < session.tts_suppressed_until:
                    continue
                print(f"[livekit-llm] raw={segment.text!r} final={segment.final}")
                emotion, tts_text = self._parse_emotion_text(segment.text)
                print(
                    f"[livekit-llm] parsed emotion={emotion!r} "
                    f"tts_text={tts_text!r} final={segment.final}"
                )
                if emotion and emotion != session.tts_emotion:
                    session.tts_emotion = emotion
                    spawn(self._send_emotion(session, emotion))
                display_text = self._current_tts_display_text(tts_text)
                print(f"[livekit-llm] display_text={display_text!r} final={segment.final}")
                # Skip empty or unchanged text so the device is not re-sent the same line.
                if not display_text or display_text == session.tts_transcript_text:
                    continue
                session.tts_transcript_text = display_text
                if session.websocket is not None:
                    self._update_tts_display_text(session, display_text, segment.final)
            else:
                # Only final non-agent (user) transcripts are forwarded as STT.
                if not segment.final:
                    continue
                spawn(self._start_thinking(session))
                ws = session.websocket
                if ws is not None:
                    spawn(
                        ws.send(
                            json.dumps(
                                {
                                    "type": "stt",
                                    "text": segment.text,
                                    "final": segment.final,
                                }
                            )
                        )
                    )

    @session.room.on("track_subscribed")
    def on_track_subscribed(
        track: rtc.Track,
        publication: rtc.TrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        self._maybe_forward_remote_audio(session, track, publication, participant, "event")

    @session.room.on("track_published")
    def on_track_published(
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        # A publication may already carry its track; forward it right away if so.
        track = getattr(publication, "track", None)
        if track is not None:
            self._maybe_forward_remote_audio(session, track, publication, participant, "published")
async def _connect_session_room(self, session: DeviceSession) -> None:
    """Connect a device session to its LiveKit room and publish the device mic.

    Steps:

    1. Register room event handlers.
    2. Fetch an access token.
    3. Connect to ``LIVEKIT_WS_URL``.
    4. Make sure an agent is dispatched.
    5. Publish a mic track fed by ``session.mic_source``.
    6. Wait up to ``AGENT_READY_TIMEOUT_SECONDS`` for an agent participant.
       Only a warning is printed on timeout.

    ``_close_session`` may run while this is still in progress. Both
    ``close()`` and the duplicate-device path in ``handle_websocket`` call it
    without cancelling this task. If that happens, the room is disconnected
    again (the earlier disconnect may have raced with ``connect()``) and the
    remaining steps are skipped.

    Raises:
        Exception: whatever ``room.connect`` raises is logged and re-raised.
    """

    async def abandon_if_closed() -> bool:
        # True (after a best-effort disconnect) when the session was closed meanwhile.
        if not session.closed:
            return False
        try:
            await session.room.disconnect()
        except Exception as exc:
            print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
        return True

    self._register_room_handlers(session)
    token = await fetch_token(session.room_name, session.identity, session.agent_name)
    try:
        await session.room.connect(
            LIVEKIT_WS_URL,
            token,
            options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
        )
    except Exception as exc:
        self._log_exception(
            f"连接 LiveKit 房间失败: room={session.room_name}",
            exc,
        )
        print(
            "提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
            "请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
            "*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
        )
        raise
    print(f"已连接到 LiveKit 房间: {session.room.name}")
    if await abandon_if_closed():
        return
    self._log_agent_participants(session, "after_connect")
    await self.ensure_agent_dispatched(session)
    if await abandon_if_closed():
        return
    track = rtc.LocalAudioTrack.create_audio_track(
        f"esp32-mic-{session.device_id}",
        session.mic_source,
    )
    options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
    await session.room.local_participant.publish_track(track, options)
    self._log_agent_participants(session, "after_publish_mic")
    try:
        # _close_session also sets agent_ready, so a closed session wakes this up.
        await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        if session.closed:
            return
        print(f"⚠️ agent 等待超时: room={session.room_name}")
async def start(self) -> None:
    """Print the bridge's effective configuration, one ``[config]`` line per topic."""
    basic_settings = (
        f"websocket_port={WS_PORT}",
        f"websocket_max_queue={WS_MAX_QUEUE} websocket_max_size={WS_MAX_SIZE}",
        f"livekit_ws_url={LIVEKIT_WS_URL}",
        f"token_url={TOKEN_URL}",
        f"agent_dispatch_mode={AGENT_DISPATCH_MODE}",
    )
    for setting in basic_settings:
        print(f"[config] {setting}")

    audio_parts = (
        f"uplink_decode:{INPUT_SAMPLE_RATE}Hz/{INPUT_FRAME_DURATION_MS}ms",
        f"downlink_encode:{OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms",
        f"stats_interval:{AUDIO_STATS_INTERVAL_SECONDS}s",
        f"capture_timeout:{UPLINK_CAPTURE_TIMEOUT_SECONDS}s",
        f"tts_idle:{TTS_IDLE_TIMEOUT_SECONDS}s",
        f"tts_min_active:{TTS_MIN_ACTIVE_SECONDS}s",
        f"tts_start_frames:{TTS_START_CONSECUTIVE_AUDIBLE_FRAMES}",
        f"tts_pre_roll:{TTS_PRE_ROLL_MS}ms",
    )
    print("[config] audio=" + " ".join(audio_parts))

    agent_parts = [
        f"{mode}:{CHAT_MODE_AGENT_NAMES[mode]}"
        for mode in ("normal", "beaver", "vision-normal", "vision-beaver")
    ]
    agent_parts.append(f"default_mode:{DEFAULT_AGENT_MODE}")
    print("[config] agents=" + " ".join(agent_parts))

    if EMOTION_TEST_SEQUENCE:
        sequence_text = ",".join(EMOTION_TEST_SEQUENCE)
        print(
            f"[config] emotion_test_sequence={sequence_text} "
            f"interval={EMOTION_TEST_INTERVAL_SECONDS}s"
        )
async def close(self) -> None:
    """Close every registered device session, one after another."""
    # Snapshot first: closing sessions may lead to registry entries being removed.
    snapshot = tuple(self.device_sessions.values())
    for device_session in snapshot:
        await self._close_session(device_session)
async def _close_session(self, session: DeviceSession) -> None:
if session.closed:
return
session.closed = True
session.websocket = None
session.agent_ready.set()
session.tts_active = False
session.tts_thinking = False
session.tts_stream_id += 1
if session.tts_idle_task is not None:
session.tts_idle_task.cancel()
session.tts_idle_task = None
self._cancel_tts_display_task(session)
try:
await session.room.disconnect()
except Exception as exc:
print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
async def forward_audio_to_esp32(
    self,
    session: DeviceSession,
    track_sid: str,
    audio_stream: rtc.AudioStream,
) -> None:
    """Forward a remote LiveKit audio stream to the device as Opus packets.

    Runs until ``audio_stream`` ends. Processing errors are printed rather
    than raised. Each incoming frame passes these gates before forwarding:

    * interruption: if ``session.tts_stream_id`` no longer matches the
      snapshot taken here, buffered audio is dropped and TTS is stopped. Audio
      is then discarded until either ``TTS_INTERRUPT_SILENCE_FRAMES``
      consecutive silent frames arrive, or the user's post-interrupt speech
      has been quiet for ``TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS``.
    * suppression: frames before ``session.tts_suppressed_until`` are dropped.
    * resume blocking: ``_tts_resume_block_reason`` may veto forwarding.
    * start detection: while TTS is inactive, frames go into a pre-roll
      buffer (at most ``TTS_PRE_ROLL_MS``). TTS starts only after
      ``TTS_START_CONSECUTIVE_AUDIBLE_FRAMES`` consecutive audible frames, and
      the pre-roll is then flushed first.

    Forwarded PCM is cut into ``OUTPUT_FRAME_DURATION_MS`` chunks. Each chunk
    is Opus-encoded, wrapped by ``_wrap_opus_payload`` and sent over
    ``session.websocket``. Downlink stats are printed every
    ``AUDIO_STATS_INTERVAL_SECONDS``.

    Args:
        session: Device session that receives the audio.
        track_sid: Key of this forwarder in ``session.forwarding_tracks``. On
            exit the entry is removed if it still refers to the current task.
        audio_stream: Remote audio frames to forward.
    """
    encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
    # PCM waiting to be cut into fixed-size Opus frames and sent.
    pending_pcm = bytearray()
    # Most recent audio kept while TTS is inactive, so the start of an
    # utterance is not lost during the consecutive-audible-frames check.
    pre_roll_pcm = bytearray()
    # Byte sizes below multiply samples by 2, i.e. they assume 16-bit mono PCM
    # at OUTPUT_SAMPLE_RATE. NOTE(review): depends on how the AudioStream was
    # created by the caller; confirm.
    pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
    output_samples_per_opus_frame = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
    output_frame_bytes = output_samples_per_opus_frame * 2
    audible_frame_streak = 0
    silence_frame_streak = 0
    # While True, audio is discarded until a silence gap separates the
    # interrupted turn from the next one.
    waiting_for_post_interrupt_silence = False
    downlink_packets = 0
    downlink_audio_ms = 0.0
    last_downlink_stats_time = time.monotonic()
    last_send_time: Optional[float] = None
    # Snapshot of the session's TTS generation; a mismatch means "interrupted".
    stream_id = session.tts_stream_id
    print(
        f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
        f"track_sid={track_sid} stream_id={stream_id} "
        f"opus={OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms"
    )
    try:
        async for event in audio_stream:
            # The TTS generation changed (bumped elsewhere, e.g. by an abort or
            # session close): drop buffered audio and stop current playback.
            if stream_id != session.tts_stream_id:
                print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                stream_id = session.tts_stream_id
                if session.tts_active:
                    await self._stop_tts(session)
                continue
            frame = event.frame
            pcm_data = frame.data.tobytes()
            has_audible_audio = self._has_audible_audio(pcm_data)
            now = time.monotonic()
            # Output is suppressed for now: discard and require a silence gap afterwards.
            if now < session.tts_suppressed_until:
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                continue
            block_reason = self._tts_resume_block_reason(
                session,
                now,
                include_user_quiet=False,
            )
            if block_reason is not None:
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                if block_reason == "waiting_for_user_audio_after_interrupt":
                    waiting_for_post_interrupt_silence = True
                continue
            # Shortcut out of the silence wait: the user spoke after the last
            # interrupt and has now been quiet for the grace period.
            if (
                waiting_for_post_interrupt_silence
                and session.last_interrupt_time > 0.0
                and session.last_uplink_audible_time >= session.last_interrupt_time
                and now - session.last_uplink_audible_time
                >= TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS
            ):
                print("检测到用户打断后语音已结束,允许新 TTS 直接起播")
                waiting_for_post_interrupt_silence = False
                silence_frame_streak = 0
            if waiting_for_post_interrupt_silence:
                if has_audible_audio:
                    silence_frame_streak = 0
                    continue
                silence_frame_streak += 1
                if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
                    continue
                print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
                waiting_for_post_interrupt_silence = False
                silence_frame_streak = 0
                stream_id = session.tts_stream_id
                continue
            # Tracks whether this frame already went into pre_roll_pcm (which is
            # flushed into pending_pcm on start), to avoid appending it twice.
            current_frame_buffered = False
            if not session.tts_active:
                pre_roll_pcm.extend(pcm_data)
                current_frame_buffered = True
                # Keep only the newest pre_roll_max_bytes of pre-roll audio.
                if len(pre_roll_pcm) > pre_roll_max_bytes:
                    del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]
                if not has_audible_audio:
                    audible_frame_streak = 0
                    continue
                audible_frame_streak += 1
                if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
                    continue
                await self._start_tts(session)
                # _start_tts may decline to activate; keep buffering in that case.
                if not session.tts_active:
                    continue
                print(
                    "检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
                    f"(frames={audible_frame_streak})"
                )
                pending_pcm.extend(pre_roll_pcm)
                pre_roll_pcm.clear()
                audible_frame_streak = 0
            if has_audible_audio:
                self._reset_tts_idle_timer(session)
            if not current_frame_buffered:
                pending_pcm.extend(pcm_data)
            # Send whole Opus frames while this stream is still current and the
            # device is still connected; any remainder waits for the next frame.
            while (
                len(pending_pcm) >= output_frame_bytes
                and stream_id == session.tts_stream_id
                and session.websocket is not None
            ):
                try:
                    now = time.monotonic()
                    if last_send_time is not None:
                        send_gap_ms = (now - last_send_time) * 1000.0
                        if send_gap_ms > DOWNLINK_SEND_GAP_WARN_MS:
                            print(
                                "[downlink] warning: send gap "
                                f"{send_gap_ms:.1f}ms device={session.device_id} "
                                f"pending_ms={self._audio_duration_ms(len(pending_pcm) // 2, OUTPUT_SAMPLE_RATE):.1f}"
                            )
                    last_send_time = now
                    opus_packet = encoder.encode(
                        bytes(pending_pcm[:output_frame_bytes]),
                        output_samples_per_opus_frame,
                    )
                    del pending_pcm[:output_frame_bytes]
                    await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
                    downlink_packets += 1
                    downlink_audio_ms += OUTPUT_FRAME_DURATION_MS
                    if now - last_downlink_stats_time >= AUDIO_STATS_INTERVAL_SECONDS:
                        print(
                            "[downlink] "
                            f"device={session.device_id} packets={downlink_packets} "
                            f"audio_ms={downlink_audio_ms:.0f} "
                            f"pending_ms={self._audio_duration_ms(len(pending_pcm) // 2, OUTPUT_SAMPLE_RATE):.1f}"
                        )
                        downlink_packets = 0
                        downlink_audio_ms = 0.0
                        last_downlink_stats_time = now
                except Exception as exc:
                    # Leaves only this send loop; the next incoming frame retries.
                    # The failed chunk has already been removed from pending_pcm.
                    print(f"发送回 ESP32 失败: {exc}")
                    break
    except Exception as exc:
        print(f"音频流处理错误: {exc}")
    finally:
        print("🎧 TTS 音频结束")
        # Unregister only if the registry still points at this very task.
        task = session.forwarding_tracks.get(track_sid)
        current_task = asyncio.current_task()
        if task is current_task:
            session.forwarding_tracks.pop(track_sid, None)
        # Only tear down TTS state if no newer TTS generation has taken over.
        if stream_id == session.tts_stream_id and session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
            session.tts_idle_task = None
        if stream_id == session.tts_stream_id:
            await self._stop_tts(session)
async def handle_websocket(self, websocket: Any) -> None:
    """Serve one ESP32 WebSocket connection for its whole lifetime.

    Creates a ``DeviceSession``, replacing any existing session with the same
    device id, and sends the server hello. It then connects the session's
    LiveKit room in the background and processes device messages until the
    socket closes:

    * binary frames have their protocol header stripped
      (``_extract_opus_payload``), are Opus-decoded and captured as mic audio.
      Uplink stats are printed every ``AUDIO_STATS_INTERVAL_SECONDS``.
    * JSON text frames of type ``hello``, ``abort``, ``mcp`` and ``vision``
      (state ``frame``) go to their handlers.

    A malformed audio packet is counted as a decode error and skipped rather
    than ending the connection. The same applies to a JSON text frame that is
    not an object. If the session is closed while the socket is still open,
    message processing stops; this happens when a newer connection for the
    same device supersedes it, or on bridge shutdown. On exit the background
    tasks are cancelled and the session is closed. It is unregistered only if
    the registry still refers to this session.

    Args:
        websocket: Connection from ``websockets.serve``. Its
            ``request.headers`` supply ``Protocol-Version`` and the agent
            selection.
    """
    header_version = websocket.request.headers.get("Protocol-Version")
    try:
        protocol_version = int(header_version) if header_version else 1
    except ValueError:
        protocol_version = 1
    device_id = self._build_device_id(websocket)
    existing_session = self.device_sessions.get(device_id)
    if existing_session is not None:
        print(f"检测到重复 device_id关闭旧 session: device={device_id}")
        await self._close_session(existing_session)
        self.device_sessions.pop(device_id, None)
    chat_mode, agent_mode, agent_name, vision_enabled = self._resolve_agent_selection(websocket.request.headers)
    room_name, identity = self._build_session_names(device_id)
    session = DeviceSession(
        device_id=device_id,
        websocket=websocket,
        protocol_version=protocol_version,
        room_name=room_name,
        identity=identity,
        chat_mode=chat_mode,
        agent_mode=agent_mode,
        agent_name=agent_name,
        vision_enabled=vision_enabled,
        room=rtc.Room(),
        mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
        agent_ready=asyncio.Event(),
    )
    self.device_sessions[device_id] = session
    print(f"ESP32 已连接: device={device_id}")
    print(f"ESP32 协议版本: {session.protocol_version}")
    print(
        f"ESP32 mode: chat={session.chat_mode} "
        f"agent={session.agent_mode}/{session.agent_name} "
        f"vision={session.vision_enabled}"
    )
    session.tts_stream_id += 1
    opus_decoder = None
    uplink_packets = 0
    uplink_audio_ms = 0.0
    uplink_decode_errors = 0
    uplink_dropped_frames = 0
    last_uplink_stats_time = time.monotonic()
    room_connect_task: Optional[asyncio.Task[Any]] = None
    emotion_task: Optional[asyncio.Task[Any]] = None
    try:
        hello_msg = self._build_server_hello(session)
        await websocket.send(json.dumps(hello_msg))
        print(
            f"已发送 server hello: device={device_id} room={session.room_name} "
            f"audio={OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms"
        )
        # Keep a reference so the task cannot be garbage-collected mid-run.
        emotion_task = asyncio.create_task(self._run_emotion_test_sequence(session))
        room_connect_task = asyncio.create_task(self._connect_session_room(session))
        room_connect_task.add_done_callback(
            lambda task: self._track_room_connect_task(session, task)
        )
        async for message in websocket:
            if session.closed:
                # Superseded by a newer connection for this device, or shutting down.
                print(f"session 已关闭, 停止处理 ESP32 消息: device={device_id}")
                break
            if isinstance(message, bytes):
                if len(message) < 4:
                    print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
                    continue
                try:
                    # Header parsing sits inside the try so one malformed packet
                    # is counted as a decode error instead of ending the connection.
                    audio_data = self._extract_opus_payload(session, message)
                    if not audio_data:
                        continue
                    if opus_decoder is None:
                        opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
                    pcm_bytes = opus_decoder.decode(audio_data, INPUT_MAX_SAMPLES_PER_OPUS_FRAME)
                    # 16-bit mono PCM: two bytes per sample.
                    num_samples = len(pcm_bytes) // 2
                    if num_samples > 0:
                        session.captured_frame_count += 1
                        now = time.monotonic()
                        uplink_packets += 1
                        uplink_audio_ms += self._audio_duration_ms(num_samples, INPUT_SAMPLE_RATE)
                        if self._has_audible_audio(pcm_bytes):
                            session.last_uplink_audible_time = now
                            if session.tts_waiting_for_user_audio_after_interrupt:
                                session.tts_waiting_for_user_audio_after_interrupt = False
                                print(
                                    f"[uplink] detected user audio after interrupt: "
                                    f"device={session.device_id}"
                                )
                        if (
                            session.captured_frame_count <= 5
                            or now - session.first_capture_log_time >= 5.0
                        ):
                            session.first_capture_log_time = now
                        if now - last_uplink_stats_time >= AUDIO_STATS_INTERVAL_SECONDS:
                            print(
                                "[uplink] "
                                f"device={session.device_id} packets={uplink_packets} "
                                f"audio_ms={uplink_audio_ms:.0f} "
                                f"decode_errors={uplink_decode_errors} "
                                f"dropped_frames={uplink_dropped_frames}"
                            )
                            uplink_packets = 0
                            uplink_audio_ms = 0.0
                            uplink_decode_errors = 0
                            uplink_dropped_frames = 0
                            last_uplink_stats_time = now
                        if not await self._capture_mic_frame(session, pcm_bytes, num_samples):
                            uplink_dropped_frames += 1
                except Exception as exc:
                    uplink_decode_errors += 1
                    print(f"Opus audio decode error ({len(message)} bytes): {exc}")
            elif isinstance(message, str):
                try:
                    data = json.loads(message)
                    if not isinstance(data, dict):
                        # Valid JSON but not an object (e.g. a list); .get would fail.
                        print(f"收到未知的字符消息: {message}")
                        continue
                    msg_type = data.get("type")
                    if msg_type == "hello":
                        self._log_client_hello(session, data)
                    elif msg_type == "abort":
                        reason = data.get("reason")
                        abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
                        print(f"处理 ESP32 打断请求: reason={abort_reason}")
                        await self._abort_tts(session, abort_reason)
                    elif msg_type == "mcp":
                        await self._publish_mcp_message(session, data)
                    elif msg_type == "vision" and data.get("state") == "frame":
                        await self._publish_vision_frame(session, data)
                except json.JSONDecodeError:
                    print(f"收到未知的字符消息: {message}")
    except ConnectionClosedError as exc:
        print(f"ESP32 异常断开: {exc}")
    except Exception as exc:
        self._log_exception("WebSocket 其他错误", exc)
    finally:
        print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
        if emotion_task is not None and not emotion_task.done():
            # The emotion test sequence was started for this session, which is closing.
            emotion_task.cancel()
        if room_connect_task is not None and not room_connect_task.done():
            room_connect_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await room_connect_task
        await self._close_session(session)
        # A newer connection for the same device may already own this slot.
        if self.device_sessions.get(device_id) is session:
            self.device_sessions.pop(device_id, None)
async def main() -> None:
    """Run the ESP32 <-> LiveKit bridge WebSocket server until cancelled.

    All device sessions are closed on the way out, whatever the reason.
    """
    server_bridge = ESP32LiveKitBridge()
    try:
        await server_bridge.start()
        serve_options = {"max_queue": WS_MAX_QUEUE, "max_size": WS_MAX_SIZE}
        async with websockets.serve(
            server_bridge.handle_websocket,
            "0.0.0.0",
            WS_PORT,
            **serve_options,
        ):
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            # A future that never resolves keeps the server running.
            await asyncio.Future()
    finally:
        await server_bridge.close()
if __name__ == "__main__":
    # Any unexpected failure is reported on stderr with a non-zero exit status.
    exit_code = 0
    try:
        asyncio.run(main())
    except Exception as fatal:
        print(f"[error] {fatal}", file=sys.stderr)
        exit_code = 1
    if exit_code:
        sys.exit(exit_code)