1690 lines
69 KiB
Python
1690 lines
69 KiB
Python
import asyncio
|
||
import base64
|
||
import contextlib
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import struct
|
||
import sys
|
||
import time
|
||
import traceback
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
|
||
import httpx
|
||
import opuslib
|
||
import websockets
|
||
from livekit import rtc
|
||
from livekit.rtc import AudioFrame, AudioSource
|
||
from websockets.exceptions import ConnectionClosedError
|
||
|
||
try:
|
||
from livekit import api as livekit_api
|
||
except ImportError:
|
||
livekit_api = None
|
||
|
||
# ---------------------------------------------------------------------------
# Bridge configuration. Every value is read once, at import time.
# ---------------------------------------------------------------------------

# LiveKit endpoints. The previous hard-coded addresses remain the defaults;
# the BRIDGE_* variables allow other deployments to override them. An empty
# variable falls back to the default instead of producing an empty URL.
TOKEN_URL = os.getenv("BRIDGE_TOKEN_URL") or "http://172.19.0.240:8000/getToken"
LIVEKIT_WS_URL = os.getenv("BRIDGE_LIVEKIT_WS_URL") or "ws://172.19.0.240:7880"
ROOM_PREFIX = "test-livekit"
IDENTITY_PREFIX = "uv-livekit"

# Agent naming. LEGACY_AGENT_NAME is the fallback for the "normal" agent.
LEGACY_AGENT_NAME = os.getenv("LIVEKIT_AGENT_NAME", "normal-agent")
DEFAULT_AGENT_MODE = os.getenv("LIVEKIT_DEFAULT_AGENT_MODE", "normal").strip().lower()
AGENT_NAMES = {
    "normal": os.getenv("LIVEKIT_NORMAL_AGENT_NAME", LEGACY_AGENT_NAME),
    "beaver": os.getenv("LIVEKIT_BEAVER_AGENT_NAME", "beaver-agent"),
}
CHAT_MODE_AGENT_NAMES = {
    "normal": AGENT_NAMES["normal"],
    "beaver": AGENT_NAMES["beaver"],
    "vision-normal": os.getenv("LIVEKIT_VISION_NORMAL_AGENT_NAME", "vision-normal-agent"),
    "vision-beaver": os.getenv("LIVEKIT_VISION_BEAVER_AGENT_NAME", "vision-beaver-agent"),
}

# Connection and WebSocket server settings.
CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0"))
AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0"))
WS_PORT = 8080
WS_MAX_QUEUE = int(os.getenv("BRIDGE_WS_MAX_QUEUE", "128"))
WS_MAX_SIZE = int(os.getenv("BRIDGE_WS_MAX_SIZE", str(8 * 1024 * 1024)))
# Stripped as well as lower-cased (matching DEFAULT_AGENT_MODE) so that stray
# whitespace in the environment cannot silently disable both the "token" and
# the "bridge" code paths.
AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").strip().lower()
PROJECT_ROOT = Path(__file__).resolve().parent.parent
VISION_FRAME_SAVE_DIR = Path(os.getenv("VISION_FRAME_SAVE_DIR", str(PROJECT_ROOT / "vision_frames")))

# Audio framing.
INPUT_SAMPLE_RATE = int(os.getenv("BRIDGE_INPUT_SAMPLE_RATE", "16000"))
OUTPUT_SAMPLE_RATE = int(os.getenv("BRIDGE_OUTPUT_SAMPLE_RATE", "24000"))
INPUT_FRAME_DURATION_MS = int(os.getenv("BRIDGE_INPUT_FRAME_DURATION_MS", "60"))
# Number of samples in 120 ms of input audio at INPUT_SAMPLE_RATE.
INPUT_MAX_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * 120 // 1000
OUTPUT_FRAME_DURATION_MS = int(os.getenv("BRIDGE_OUTPUT_FRAME_DURATION_MS", "60"))
AUDIO_STATS_INTERVAL_SECONDS = float(os.getenv("BRIDGE_AUDIO_STATS_INTERVAL_SECONDS", "5.0"))
DOWNLINK_SEND_GAP_WARN_MS = float(os.getenv("BRIDGE_DOWNLINK_SEND_GAP_WARN_MS", "180.0"))
UPLINK_CAPTURE_TIMEOUT_SECONDS = float(os.getenv("BRIDGE_UPLINK_CAPTURE_TIMEOUT_SECONDS", "0.25"))

# TTS activity detection.
TTS_IDLE_TIMEOUT_SECONDS = float(os.getenv("TTS_IDLE_TIMEOUT_SECONDS", "1.2"))
TTS_MIN_ACTIVE_SECONDS = float(os.getenv("TTS_MIN_ACTIVE_SECONDS", "1.0"))
# Absolute 16-bit sample value at or above which audio counts as audible.
TTS_SILENCE_PEAK_THRESHOLD = 96
TTS_PRE_ROLL_MS = int(os.getenv("TTS_PRE_ROLL_MS", "480"))
TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = int(os.getenv("TTS_START_CONSECUTIVE_AUDIBLE_FRAMES", "1"))
TTS_INTERRUPT_SILENCE_FRAMES = 3

# LiveKit data-channel topics and participant attributes.
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
MCP_TOPIC = "mcp"
AGENT_STATE_ATTRIBUTE = "lk.agent.state"

# Subtitle display.
# NOTE(review): the repeated ASCII punctuation here and in EMOTION_TEXT_PATTERN
# suggests full-width characters may have been normalised somewhere upstream;
# the literals are kept exactly as found - confirm against the original file.
TTS_DISPLAY_SENTENCE_BREAKS = "。!?!?;;"
TTS_DISPLAY_SCROLL_WIDTH = int(os.getenv("TTS_DISPLAY_SCROLL_WIDTH", "18"))
TTS_DISPLAY_SCROLL_INTERVAL_SECONDS = float(os.getenv("TTS_DISPLAY_SCROLL_INTERVAL_SECONDS", "0.18"))
TTS_DISPLAY_SCROLL_GAP = " "

# Interrupt handling windows (seconds).
TTS_INTERRUPT_SUPPRESS_SECONDS = float(os.getenv("TTS_INTERRUPT_SUPPRESS_SECONDS", "0.8"))
TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS = float(
    os.getenv("TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS", "0.25")
)
TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS = float(
    os.getenv("TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS", "8.0")
)

# Matches an optional leading "emotion=<name>" tag (optionally wrapped in
# angle brackets); group 1 is the emotion name, group 2 the remaining text.
EMOTION_TEXT_PATTERN = re.compile(
    r"^\s*<?\s*emotion\s*=\s*([^\s>,,;;]+)\s*>?[\s,,;;]*(.*)$",
    re.DOTALL,
)

# Optional comma-separated list of emotions for _run_emotion_test_sequence.
EMOTION_TEST_SEQUENCE = [
    emotion.strip()
    for emotion in os.getenv("BRIDGE_EMOTION_TEST_SEQUENCE", "").split(",")
    if emotion.strip()
]
EMOTION_TEST_INTERVAL_SECONDS = float(os.getenv("BRIDGE_EMOTION_TEST_INTERVAL_SECONDS", "2.0"))
|
||
|
||
|
||
@dataclass
class DeviceSession:
    """Mutable per-device state shared by the bridge's coroutines.

    The first twelve fields are required and positional; everything after
    them has a default and is mutated as the session progresses.
    """

    # --- identity / transport (required) ---
    device_id: str
    websocket: Any
    protocol_version: int
    room_name: str
    identity: str
    chat_mode: str
    agent_mode: str
    agent_name: str
    vision_enabled: bool
    room: rtc.Room
    mic_source: AudioSource
    agent_ready: asyncio.Event

    # Audio forwarding tasks, keyed by track sid.
    forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)

    # --- TTS playback state ---
    tts_active: bool = False  # set by _start_tts, cleared by _stop_tts/_force_stop_tts
    tts_thinking: bool = False  # set by _start_thinking
    tts_idle_task: Optional[asyncio.Task] = None  # watchdog from _reset_tts_idle_timer
    tts_display_task: Optional[asyncio.Task] = None  # subtitle scroll task
    tts_stream_id: int = 0  # bumped by _abort_tts; background tasks compare against it
    tts_transcript_text: str = ""
    tts_display_text: str = ""
    tts_display_final: bool = False
    tts_emotion: str = ""
    tts_suppressed_until: float = 0.0  # time.monotonic() deadline set by _abort_tts
    tts_started_at: float = 0.0  # time.monotonic() value, 0.0 when idle
    tts_last_audible_at: float = 0.0  # time.monotonic() value, 0.0 when idle

    # --- interrupt bookkeeping ---
    tts_waiting_for_user_audio_after_interrupt: bool = False  # set True by _abort_tts
    last_interrupt_time: float = 0.0  # time.monotonic() of the latest _abort_tts
    last_uplink_audible_time: float = 0.0

    # --- lifecycle / diagnostics ---
    agent_dispatch_task: Optional[asyncio.Task] = None
    closed: bool = False
    captured_frame_count: int = 0
    first_capture_log_time: float = 0.0
|
||
|
||
|
||
async def fetch_token(room_name: str, identity: str, agent_name: str) -> str:
    """Request a LiveKit access token from the token service at ``TOKEN_URL``.

    Args:
        room_name: Room the token should grant access to.
        identity: Participant identity to embed in the token.
        agent_name: Sent as the ``agent_name`` query parameter only when
            ``AGENT_DISPATCH_MODE`` is ``"token"``.

    Returns:
        The non-empty ``token`` string from the JSON response.

    Raises:
        httpx.HTTPStatusError: If the service answers with an error status.
        ValueError: If the body is not a JSON object carrying a non-empty
            string ``token`` field.
    """
    params = {"room": room_name, "identity": identity}
    if AGENT_DISPATCH_MODE == "token":
        params["agent_name"] = agent_name

    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        response = await client.get(TOKEN_URL, params=params)
        response.raise_for_status()

    payload: Any = response.json()
    # A JSON array/string/number has no ``.get``; report it with the same
    # ValueError as a missing field instead of leaking an AttributeError.
    token = payload.get("token") if isinstance(payload, dict) else None
    if not isinstance(token, str) or not token:
        raise ValueError(f"token response missing token field: {payload}")

    return token
|
||
|
||
|
||
class ESP32LiveKitBridge:
|
||
def __init__(self) -> None:
|
||
self.device_sessions: dict[str, DeviceSession] = {}
|
||
|
||
def _log_exception(self, prefix: str, exc: BaseException) -> None:
|
||
print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
|
||
formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||
if formatted_tb.strip():
|
||
print(formatted_tb.rstrip())
|
||
|
||
def _audio_duration_ms(self, sample_count: int, sample_rate: int) -> float:
|
||
if sample_rate <= 0:
|
||
return 0.0
|
||
return sample_count * 1000.0 / sample_rate
|
||
|
||
def _build_server_hello(self, session: DeviceSession) -> dict[str, Any]:
    """Build the server-side ``hello`` message for ``session``.

    The message advertises the session's room/identity and the downlink
    audio format (mono Opus at ``OUTPUT_SAMPLE_RATE`` with
    ``OUTPUT_FRAME_DURATION_MS`` frames).
    """
    session_info = {
        "room": session.room_name,
        "identity": session.identity,
    }
    downlink_audio = {
        "format": "opus",
        "sample_rate": OUTPUT_SAMPLE_RATE,
        "channels": 1,
        "frame_duration": OUTPUT_FRAME_DURATION_MS,
    }
    hello: dict[str, Any] = {
        "type": "hello",
        "transport": "websocket",
        "session": session_info,
        "audio_params": downlink_audio,
    }
    return hello
|
||
|
||
def _log_client_hello(self, session: DeviceSession, message: dict[str, Any]) -> None:
|
||
audio_params = message.get("audio_params")
|
||
if not isinstance(audio_params, dict):
|
||
return
|
||
|
||
sample_rate = audio_params.get("sample_rate")
|
||
frame_duration = audio_params.get("frame_duration")
|
||
channels = audio_params.get("channels")
|
||
fmt = audio_params.get("format")
|
||
print(
|
||
"[client-audio] "
|
||
f"device={session.device_id} format={fmt} sample_rate={sample_rate} "
|
||
f"channels={channels} frame_duration={frame_duration}"
|
||
)
|
||
|
||
if sample_rate != INPUT_SAMPLE_RATE or channels != 1:
|
||
print(
|
||
"[client-audio] warning: bridge uplink decode expects "
|
||
f"{INPUT_SAMPLE_RATE}Hz mono, got {sample_rate}Hz channels={channels}"
|
||
)
|
||
if frame_duration != INPUT_FRAME_DURATION_MS:
|
||
print(
|
||
"[client-audio] warning: bridge expects "
|
||
f"{INPUT_FRAME_DURATION_MS}ms uplink frames, got {frame_duration}ms"
|
||
)
|
||
|
||
def _track_room_connect_task(
|
||
self,
|
||
session: DeviceSession,
|
||
task: asyncio.Task[Any],
|
||
) -> None:
|
||
if task.cancelled():
|
||
return
|
||
try:
|
||
task.result()
|
||
except Exception as exc:
|
||
self._log_exception(
|
||
f"LiveKit 房间连接后台任务失败: room={session.room_name}",
|
||
exc,
|
||
)
|
||
websocket = session.websocket
|
||
if websocket is not None:
|
||
asyncio.create_task(websocket.close(code=1011, reason="livekit connect failed"))
|
||
|
||
async def _capture_mic_frame(
    self,
    session: DeviceSession,
    pcm_bytes: bytes,
    num_samples: int,
) -> bool:
    """Push one mono PCM frame into the session's LiveKit microphone source.

    Args:
        session: Session whose ``mic_source`` receives the frame.
        pcm_bytes: Raw PCM payload for the frame.
        num_samples: Samples per channel contained in ``pcm_bytes``.

    Returns:
        ``True`` if the frame was captured, ``False`` if ``capture_frame`` did
        not finish within ``UPLINK_CAPTURE_TIMEOUT_SECONDS`` (the frame is
        dropped and a warning printed).
    """
    try:
        audio_frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
    except TypeError:
        # The positional constructor was rejected: allocate an empty frame
        # and copy the payload into its buffer instead.
        audio_frame = AudioFrame.create(
            sample_rate=INPUT_SAMPLE_RATE,
            num_channels=1,
            samples_per_channel=num_samples,
        )
        memoryview(audio_frame.data).cast("B")[:] = pcm_bytes

    try:
        await asyncio.wait_for(
            session.mic_source.capture_frame(audio_frame),
            timeout=UPLINK_CAPTURE_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        print(
            "[uplink] warning: capture_frame timeout, dropping frame "
            f"device={session.device_id} samples={num_samples}"
        )
        return False
    return True
|
||
|
||
def _build_device_id(self, websocket: Any) -> str:
|
||
headers = websocket.request.headers
|
||
requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
|
||
if requested_id:
|
||
return requested_id
|
||
return f"ws-{id(websocket):x}"
|
||
|
||
def _build_session_names(self, device_id: str) -> tuple[str, str]:
    """Generate a fresh ``(room_name, identity)`` pair for a device.

    Both names share the same random 8-hex-character tag; ``device_id`` is
    only used in the log line.
    """
    tag = uuid.uuid4().hex[:8]
    names = (f"{ROOM_PREFIX}-{tag}", f"{IDENTITY_PREFIX}-{tag}")
    print(f"[session] device={device_id} room={names[0]} identity={names[1]}")
    return names
|
||
|
||
def _resolve_agent_selection(self, headers: Any) -> tuple[str, str, str, bool]:
    """Resolve which agent a connection should talk to from its headers.

    Resolution order:
      1. ``Chat-Mode`` / ``X-Chat-Mode`` if it names a known chat mode.
      2. Otherwise ``Agent-Mode`` / ``X-Agent-Mode`` / ``DEFAULT_AGENT_MODE``
         / ``"normal"``, with vision disabled.
      3. ``Agent-Name`` / ``X-Agent-Name`` overrides the agent name and
         reports the agent mode as ``"custom"``.
      4. An unknown agent mode falls back to ``"normal"``.

    Returns:
        ``(chat_mode, agent_mode, agent_name, vision_enabled)``.
    """
    chat_mode = (
        headers.get("Chat-Mode") or headers.get("X-Chat-Mode") or ""
    ).strip().lower()

    # chat mode -> (agent mode, vision enabled)
    known_chat_modes = {
        "normal": ("normal", False),
        "beaver": ("beaver", False),
        "vision-normal": ("normal", True),
        "vision-beaver": ("beaver", True),
    }

    resolved = known_chat_modes.get(chat_mode)
    if resolved is not None:
        agent_mode, vision_enabled = resolved
    else:
        if chat_mode:
            print(f"未知 Chat-Mode={chat_mode!r},回退到 Agent-Mode")
        raw_agent_mode = (
            headers.get("Agent-Mode")
            or headers.get("X-Agent-Mode")
            or DEFAULT_AGENT_MODE
            or "normal"
        )
        agent_mode = raw_agent_mode.strip().lower()
        vision_enabled = False
        chat_mode = agent_mode

    explicit_name = headers.get("Agent-Name") or headers.get("X-Agent-Name")
    if explicit_name:
        return chat_mode, "custom", explicit_name, vision_enabled

    if agent_mode not in AGENT_NAMES:
        print(f"未知 Agent-Mode={agent_mode!r},回退到 normal")
        agent_mode = "normal"
        chat_mode = "vision-normal" if vision_enabled else "normal"

    # Prefer the chat-mode specific agent name; otherwise use the per-mode name.
    if chat_mode in CHAT_MODE_AGENT_NAMES:
        agent_name = CHAT_MODE_AGENT_NAMES[chat_mode]
    else:
        agent_name = AGENT_NAMES[agent_mode]
    return chat_mode, agent_mode, agent_name, vision_enabled
|
||
|
||
def _is_agent_participant(self, participant: rtc.RemoteParticipant, agent_name: str) -> bool:
|
||
identity = getattr(participant, "identity", "") or ""
|
||
return identity.startswith("agent-") or agent_name in identity
|
||
|
||
def _get_agent_identities(self, session: DeviceSession) -> list[str]:
|
||
return [
|
||
participant.identity
|
||
for participant in session.room.remote_participants.values()
|
||
if self._is_agent_participant(participant, session.agent_name)
|
||
]
|
||
|
||
def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
|
||
agent_identities = self._get_agent_identities(session)
|
||
# print(
|
||
# f"[agent-check] source={source} room={session.room_name} "
|
||
# f"agent_count={len(agent_identities)} agents={agent_identities}"
|
||
# )
|
||
|
||
async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
    """Check whether an agent dispatch already exists for the session's room.

    Returns ``False`` without querying when the ``livekit.api`` module is
    unavailable, when ``LIVEKIT_API_KEY`` / ``LIVEKIT_API_SECRET`` are not
    set, or when the module has no ``ListAgentDispatchRequest``. A failing
    query is logged and treated as "no dispatch".

    A dispatch matches when its ``room`` equals the session's room and its
    ``agent_name`` is either the session's agent name or ``None``.
    """
    if livekit_api is None:
        return False

    key = os.getenv("LIVEKIT_API_KEY")
    secret = os.getenv("LIVEKIT_API_SECRET")
    if not (key and secret):
        return False

    list_request_cls = getattr(livekit_api, "ListAgentDispatchRequest", None)
    if list_request_cls is None:
        return False

    client = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=key,
        api_secret=secret,
    )
    try:
        listing = await client.agent_dispatch.list_dispatch(
            list_request_cls(room=session.room_name)
        )

        # The response attribute name is probed in this order; the first
        # non-empty one wins.
        entries = None
        for attr_name in ("agent_dispatches", "dispatches", "items"):
            entries = getattr(listing, attr_name, None)
            if entries:
                break

        for entry in entries or []:
            entry_agent = getattr(entry, "agent_name", None)
            entry_room = getattr(entry, "room", None)
            if entry_room != session.room_name:
                continue
            if entry_agent is not None and entry_agent != session.agent_name:
                continue
            print(
                f"检测到已有 dispatch: room={session.room_name} "
                f"agent={entry_agent}"
            )
            return True
    except Exception as exc:
        print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
    finally:
        await client.aclose()

    return False
|
||
|
||
async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
    """Dispatch an agent into the session's room unless one is already coming.

    Only acts when ``AGENT_DISPATCH_MODE`` is ``"bridge"``. Skips the dispatch
    when a matching dispatch already exists, when an agent participant is
    already in the room, or when a previous dispatch task is still running.
    Otherwise a dispatch task is started, stored on the session and awaited.
    """
    if AGENT_DISPATCH_MODE != "bridge":
        return

    if await self._has_existing_dispatch(session):
        return

    agent_in_room = any(
        self._is_agent_participant(remote, session.agent_name)
        for remote in session.room.remote_participants.values()
    )
    if agent_in_room:
        return

    pending = session.agent_dispatch_task
    if pending is not None and not pending.done():
        # A dispatch is already in flight; do not start a second one.
        return

    dispatch_task = asyncio.create_task(self._dispatch_agent(session))
    session.agent_dispatch_task = dispatch_task
    await dispatch_task
|
||
|
||
async def _dispatch_agent(self, session: DeviceSession) -> None:
|
||
print(f"准备 dispatch agent: room={session.room_name}, agent={session.agent_name}")
|
||
|
||
try:
|
||
if await self._dispatch_agent_with_sdk(session):
|
||
return
|
||
|
||
if await self._dispatch_agent_with_cli(session):
|
||
return
|
||
|
||
print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
|
||
except Exception as exc:
|
||
print(f"Agent dispatch 失败: {exc}")
|
||
|
||
async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
    """Create an agent dispatch through the ``livekit.api`` SDK.

    Returns:
        ``False`` when the SDK module is unavailable or when
        ``LIVEKIT_API_KEY`` / ``LIVEKIT_API_SECRET`` are not set; ``True``
        once the dispatch request has completed.

    Exceptions raised by the SDK call propagate to the caller; the API client
    is closed in every case.
    """
    if livekit_api is None:
        return False

    key = os.getenv("LIVEKIT_API_KEY")
    secret = os.getenv("LIVEKIT_API_SECRET")
    if not (key and secret):
        return False

    # Context handed to the agent alongside the dispatch.
    metadata = json.dumps(
        {
            "source": "bridge_server",
            "identity": session.identity,
            "device_id": session.device_id,
            "chat_mode": session.chat_mode,
            "agent_mode": session.agent_mode,
            "agent_name": session.agent_name,
            "vision_enabled": session.vision_enabled,
        }
    )

    client = livekit_api.LiveKitAPI(
        url=LIVEKIT_WS_URL,
        api_key=key,
        api_secret=secret,
    )
    try:
        request = livekit_api.CreateAgentDispatchRequest(
            agent_name=session.agent_name,
            room=session.room_name,
            metadata=metadata,
        )
        created = await client.agent_dispatch.create_dispatch(request)
        print(f"Agent dispatch 已创建: {created}")
    finally:
        await client.aclose()

    return True
|
||
|
||
async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
    """Create an agent dispatch by running ``lk dispatch create``.

    Returns:
        ``False`` when no ``lk`` executable is on ``PATH`` or the command
        exits with a non-zero status; ``True`` otherwise. The command's
        stdout and stderr are echoed when non-empty.
    """
    lk_executable = shutil.which("lk")
    if lk_executable is None:
        return False

    metadata = json.dumps(
        {
            "source": "bridge_server",
            "identity": session.identity,
            "device_id": session.device_id,
            "chat_mode": session.chat_mode,
            "agent_mode": session.agent_mode,
            "agent_name": session.agent_name,
            "vision_enabled": session.vision_enabled,
        }
    )
    # Argument list (no shell involved), so values need no quoting.
    command = [
        lk_executable,
        "dispatch",
        "create",
        "--room",
        session.room_name,
        "--agent-name",
        session.agent_name,
        "--metadata",
        metadata,
    ]
    proc = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    raw_out, raw_err = await proc.communicate()
    out_text = raw_out.decode("utf-8", errors="replace").strip()
    err_text = raw_err.decode("utf-8", errors="replace").strip()

    if out_text:
        print(f"lk dispatch 输出: {out_text}")
    if err_text:
        print(f"lk dispatch 错误输出: {err_text}")

    if proc.returncode != 0:
        print(f"lk dispatch create 失败,退出码: {proc.returncode}")
        return False

    print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={session.agent_name}")
    return True
|
||
|
||
async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
|
||
participant = getattr(session.room, "local_participant", None)
|
||
if participant is None:
|
||
print("跳过发送 agent 控制事件,local participant 尚未就绪")
|
||
return False
|
||
|
||
data = json.dumps(payload).encode("utf-8")
|
||
agent_identities = self._get_agent_identities(session)
|
||
kwargs: dict[str, Any] = {}
|
||
if agent_identities:
|
||
kwargs["destination_identities"] = agent_identities
|
||
|
||
last_error: Optional[Exception] = None
|
||
for attempt in ({"topic": INTERRUPT_TOPIC, **kwargs}, kwargs):
|
||
try:
|
||
await participant.publish_data(data, **attempt)
|
||
print(
|
||
f"已发送 agent 控制事件: topic={attempt.get('topic', '<default>')} "
|
||
f"targets={agent_identities or 'broadcast'} payload={payload}"
|
||
)
|
||
return True
|
||
except TypeError as exc:
|
||
last_error = exc
|
||
except Exception as exc:
|
||
print(f"发送 agent 控制事件失败: {exc}")
|
||
return False
|
||
|
||
if last_error is not None:
|
||
print(f"发送 agent 控制事件失败,publish_data 签名不兼容: {last_error}")
|
||
return False
|
||
|
||
async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
    """Send an ``interrupt`` control event to the agent and warn on failure.

    Args:
        session: Session whose room/identity/device id are included.
        reason: Free-form reason copied into the event.
    """
    event = dict(
        type="interrupt",
        topic=INTERRUPT_TOPIC,
        reason=reason,
        room=session.room_name,
        identity=session.identity,
        device_id=session.device_id,
    )
    delivered = await self._publish_agent_event(session, event)
    if delivered:
        return
    print("警告: bridge 已停止 TTS,但 agent 侧 interrupt 未确认送出")
|
||
|
||
def _save_vision_frame(self, session: DeviceSession, image: str) -> Optional[Path]:
|
||
try:
|
||
image_bytes = base64.b64decode(image, validate=True)
|
||
except Exception as exc:
|
||
print(f"vision frame base64 解码失败: {exc}")
|
||
return None
|
||
|
||
safe_device_id = "".join(
|
||
char if char.isalnum() or char in ("-", "_") else "_"
|
||
for char in session.device_id
|
||
)
|
||
timestamp_ms = int(time.time() * 1000)
|
||
VISION_FRAME_SAVE_DIR.mkdir(parents=True, exist_ok=True)
|
||
path = VISION_FRAME_SAVE_DIR / f"{timestamp_ms}_{safe_device_id}.jpg"
|
||
path.write_bytes(image_bytes)
|
||
return path
|
||
|
||
async def _publish_vision_frame(self, session: DeviceSession, message: dict[str, Any]) -> None:
|
||
image = message.get("image")
|
||
if not isinstance(image, str) or not image:
|
||
print("收到 vision frame,但 image 字段为空")
|
||
return
|
||
|
||
saved_path = await asyncio.to_thread(self._save_vision_frame, session, image)
|
||
if saved_path is None:
|
||
return
|
||
print(f"已保存 vision frame: {saved_path}")
|
||
|
||
participant = getattr(session.room, "local_participant", None)
|
||
if participant is None:
|
||
print("跳过发送 vision frame,local participant 尚未就绪")
|
||
return
|
||
|
||
payload = {
|
||
"type": "vision_frame",
|
||
"topic": VISION_FRAME_TOPIC,
|
||
"room": session.room_name,
|
||
"identity": session.identity,
|
||
"device_id": session.device_id,
|
||
"mime_type": message.get("mime_type", "image/jpeg"),
|
||
"image": image,
|
||
"saved_path": str(saved_path),
|
||
}
|
||
data = json.dumps(payload).encode("utf-8")
|
||
agent_identities = self._get_agent_identities(session)
|
||
kwargs: dict[str, Any] = {}
|
||
if agent_identities:
|
||
kwargs["destination_identities"] = agent_identities
|
||
|
||
last_error: Optional[Exception] = None
|
||
for attempt in ({"topic": VISION_FRAME_TOPIC, **kwargs}, kwargs):
|
||
try:
|
||
await participant.publish_data(data, **attempt)
|
||
print(
|
||
f"已发送 vision frame: bytes={len(data)} "
|
||
f"targets={agent_identities or 'broadcast'}"
|
||
)
|
||
return
|
||
except TypeError as exc:
|
||
last_error = exc
|
||
except Exception as exc:
|
||
print(f"发送 vision frame 失败: {exc}")
|
||
return
|
||
|
||
if last_error is not None:
|
||
print(f"发送 vision frame 失败,publish_data 签名不兼容: {last_error}")
|
||
|
||
async def _publish_mcp_message(self, session: DeviceSession, message: dict[str, Any]) -> None:
|
||
payload = message.get("payload")
|
||
if not isinstance(payload, dict):
|
||
print(f"收到 ESP32 MCP 消息但缺少 payload: {message}")
|
||
return
|
||
|
||
participant = getattr(session.room, "local_participant", None)
|
||
if participant is None:
|
||
print("跳过发送 MCP 消息,local participant 尚未就绪")
|
||
return
|
||
|
||
outbound = {
|
||
"type": "mcp",
|
||
"topic": MCP_TOPIC,
|
||
"room": session.room_name,
|
||
"identity": session.identity,
|
||
"device_id": session.device_id,
|
||
"payload": payload,
|
||
}
|
||
data = json.dumps(outbound, ensure_ascii=False).encode("utf-8")
|
||
agent_identities = self._get_agent_identities(session)
|
||
kwargs: dict[str, Any] = {}
|
||
if agent_identities:
|
||
kwargs["destination_identities"] = agent_identities
|
||
|
||
last_error: Optional[Exception] = None
|
||
for attempt in ({"topic": MCP_TOPIC, **kwargs}, kwargs):
|
||
try:
|
||
await participant.publish_data(data, **attempt)
|
||
print(
|
||
f"已发送 MCP 响应: id={payload.get('id')} "
|
||
f"targets={agent_identities or 'broadcast'}"
|
||
)
|
||
return
|
||
except TypeError as exc:
|
||
last_error = exc
|
||
except Exception as exc:
|
||
print(f"发送 MCP 响应失败: {exc}")
|
||
return
|
||
|
||
if last_error is not None:
|
||
print(f"发送 MCP 响应失败,publish_data 签名不兼容: {last_error}")
|
||
|
||
async def _forward_mcp_to_device(
|
||
self,
|
||
session: DeviceSession,
|
||
payload: dict[str, Any],
|
||
*,
|
||
source_identity: str,
|
||
) -> None:
|
||
if session.websocket is None:
|
||
print("跳过 MCP 请求,ESP32 尚未连接")
|
||
return
|
||
|
||
await session.websocket.send(
|
||
json.dumps(
|
||
{
|
||
"type": "mcp",
|
||
"payload": payload,
|
||
},
|
||
ensure_ascii=False,
|
||
)
|
||
)
|
||
print(
|
||
f"已转发 MCP 请求到 ESP32: id={payload.get('id')} "
|
||
f"method={payload.get('method')} source={source_identity}"
|
||
)
|
||
|
||
async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
|
||
if session.websocket is None:
|
||
print(f"跳过 tts {state},ESP32 尚未连接")
|
||
return
|
||
await session.websocket.send(json.dumps({"type": "tts", "state": state}))
|
||
print(f"已发送 tts {state}: device={session.device_id}")
|
||
|
||
async def _send_emotion(self, session: DeviceSession, emotion: str) -> None:
|
||
if session.websocket is None:
|
||
print(f"跳过 emotion {emotion},ESP32 尚未连接")
|
||
return
|
||
await session.websocket.send(json.dumps({"type": "llm", "emotion": emotion}))
|
||
print(f"已发送 emotion: device={session.device_id} emotion={emotion}")
|
||
|
||
def _parse_emotion_text(self, text: str) -> tuple[Optional[str], str]:
    """Split an optional leading emotion tag from TTS text.

    Returns:
        ``(emotion, remaining_text)`` when ``EMOTION_TEXT_PATTERN`` matches,
        otherwise ``(None, text)``. Both parts are whitespace-stripped.
    """
    found = EMOTION_TEXT_PATTERN.match(text)
    if found is not None:
        emotion_name, spoken_text = found.group(1), found.group(2)
        return emotion_name.strip(), spoken_text.strip()
    return None, text.strip()
|
||
|
||
async def _run_emotion_test_sequence(self, session: DeviceSession) -> None:
    """Send each emotion of ``EMOTION_TEST_SEQUENCE`` to the device in order.

    Consecutive emotions are separated by ``EMOTION_TEST_INTERVAL_SECONDS``.
    The sequence stops as soon as the session has no websocket or is marked
    closed. Does nothing when the sequence is empty.
    """
    if not EMOTION_TEST_SEQUENCE:
        return

    def session_gone() -> bool:
        return session.websocket is None or session.closed

    for index, emotion in enumerate(EMOTION_TEST_SEQUENCE):
        if index > 0:
            if session_gone():
                return
            await asyncio.sleep(EMOTION_TEST_INTERVAL_SECONDS)
        # Re-check after the sleep: the session may have been closed while
        # this coroutine was waiting, and sending then would hit a dead
        # websocket.
        if session_gone():
            return
        await self._send_emotion(session, emotion)
|
||
|
||
async def _send_tts_text(self, session: DeviceSession, text: str, final: bool) -> None:
|
||
if session.websocket is None:
|
||
return
|
||
raw_text = text
|
||
_emotion, text = self._parse_emotion_text(text)
|
||
if not text:
|
||
print(f"[tts->esp32] skip empty text: raw={raw_text!r} final={final}")
|
||
return
|
||
print(f"[tts->esp32] text={text!r} final={final}")
|
||
await session.websocket.send(
|
||
json.dumps(
|
||
{
|
||
"type": "tts",
|
||
"state": "sentence_start",
|
||
"text": text,
|
||
"final": final,
|
||
}
|
||
)
|
||
)
|
||
|
||
def _cancel_tts_display_task(self, session: DeviceSession) -> None:
|
||
if session.tts_display_task is not None:
|
||
session.tts_display_task.cancel()
|
||
session.tts_display_task = None
|
||
|
||
def _update_tts_display_text(self, session: DeviceSession, text: str, final: bool) -> None:
    """Record new subtitle text and make sure it reaches the device.

    Text that fits within ``TTS_DISPLAY_SCROLL_WIDTH`` characters is sent
    once (any running scroll task is cancelled first). Longer text is handed
    to a scroll task, which is started only if none is currently running.
    """
    session.tts_display_text = text
    session.tts_display_final = final

    fits_on_screen = len(text) <= TTS_DISPLAY_SCROLL_WIDTH
    if fits_on_screen:
        self._cancel_tts_display_task(session)
        # Fire-and-forget: the send happens on the event loop after return.
        asyncio.create_task(self._send_tts_text(session, text, final))
        return

    scroller = session.tts_display_task
    if scroller is not None and not scroller.done():
        # The running scroller reads session.tts_display_text on each tick.
        return
    session.tts_display_task = asyncio.create_task(
        self._scroll_tts_display_text(session, session.tts_stream_id)
    )
|
||
|
||
async def _scroll_tts_display_text(self, session: DeviceSession, stream_id: int) -> None:
|
||
offset = 0
|
||
last_sent = ""
|
||
try:
|
||
while stream_id == session.tts_stream_id and session.websocket is not None:
|
||
text = session.tts_display_text
|
||
if not text:
|
||
return
|
||
|
||
if len(text) <= TTS_DISPLAY_SCROLL_WIDTH:
|
||
if text != last_sent:
|
||
await self._send_tts_text(session, text, session.tts_display_final)
|
||
return
|
||
|
||
scroll_text = text + TTS_DISPLAY_SCROLL_GAP
|
||
if offset >= len(scroll_text):
|
||
offset = 0
|
||
|
||
looped_text = scroll_text + scroll_text[:TTS_DISPLAY_SCROLL_WIDTH]
|
||
visible_text = looped_text[offset:offset + TTS_DISPLAY_SCROLL_WIDTH].rstrip()
|
||
if visible_text and visible_text != last_sent:
|
||
await self._send_tts_text(
|
||
session,
|
||
visible_text,
|
||
session.tts_display_final,
|
||
)
|
||
last_sent = visible_text
|
||
|
||
offset += 1
|
||
await asyncio.sleep(TTS_DISPLAY_SCROLL_INTERVAL_SECONDS)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
except Exception as exc:
|
||
print(f"TTS 字幕滚动失败: {exc}")
|
||
|
||
async def _start_tts(self, session: DeviceSession) -> None:
|
||
if session.tts_active:
|
||
print("跳过 tts start,当前已处于激活状态")
|
||
return
|
||
block_reason = self._tts_resume_block_reason(session, include_user_quiet=False)
|
||
if block_reason is not None:
|
||
print(f"跳过 tts start,打断后仍在等待稳定聆听: {block_reason}")
|
||
return
|
||
if time.monotonic() < session.tts_suppressed_until:
|
||
print("跳过 tts start,中断后的残留音频仍在抑制窗口内")
|
||
return
|
||
if not session.tts_display_text:
|
||
session.tts_transcript_text = ""
|
||
session.tts_display_final = False
|
||
session.tts_emotion = ""
|
||
self._cancel_tts_display_task(session)
|
||
now = time.monotonic()
|
||
session.tts_started_at = now
|
||
session.tts_last_audible_at = now
|
||
await self._send_tts_state(session, "start")
|
||
session.tts_active = True
|
||
session.tts_thinking = False
|
||
|
||
async def _start_thinking(self, session: DeviceSession) -> None:
|
||
if session.tts_active:
|
||
print("跳过 tts thinking,当前已处于 TTS 播放状态")
|
||
return
|
||
if session.tts_thinking:
|
||
print("跳过 tts thinking,当前已处于思考状态")
|
||
return
|
||
block_reason = self._tts_resume_block_reason(session, include_user_quiet=False)
|
||
if block_reason is not None:
|
||
print(f"跳过 tts thinking,打断后仍在等待稳定聆听: {block_reason}")
|
||
return
|
||
if time.monotonic() < session.tts_suppressed_until:
|
||
print("跳过 tts thinking,中断后的残留音频仍在抑制窗口内")
|
||
return
|
||
await self._send_tts_state(session, "thinking")
|
||
session.tts_thinking = True
|
||
|
||
def _tts_resume_block_reason(
|
||
self,
|
||
session: DeviceSession,
|
||
now: Optional[float] = None,
|
||
*,
|
||
include_user_quiet: bool = True,
|
||
) -> Optional[str]:
|
||
if now is None:
|
||
now = time.monotonic()
|
||
|
||
if session.tts_waiting_for_user_audio_after_interrupt:
|
||
return "waiting_for_user_audio_after_interrupt"
|
||
|
||
if session.last_interrupt_time <= 0.0:
|
||
return None
|
||
|
||
since_interrupt = now - session.last_interrupt_time
|
||
if since_interrupt > TTS_POST_INTERRUPT_LISTEN_WINDOW_SECONDS:
|
||
return None
|
||
|
||
if session.last_uplink_audible_time < session.last_interrupt_time:
|
||
return None
|
||
|
||
if not include_user_quiet:
|
||
return None
|
||
|
||
quiet_for = now - session.last_uplink_audible_time
|
||
if quiet_for < TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS:
|
||
return f"user_audio_quiet_for={quiet_for:.2f}s"
|
||
|
||
return None
|
||
|
||
def _handle_agent_state(self, session: DeviceSession, participant: rtc.Participant) -> None:
    """React to the agent state published in a participant's attributes.

    Reads ``AGENT_STATE_ATTRIBUTE``; a missing, empty or non-string value is
    ignored. The state is logged, and ``"thinking"`` schedules
    ``_start_thinking`` as a background task.
    """
    agent_state = participant.attributes.get(AGENT_STATE_ATTRIBUTE)
    if not (isinstance(agent_state, str) and agent_state):
        return

    print(
        f"[agent-state] room={session.room_name} identity={participant.identity} state={agent_state}"
    )
    if agent_state != "thinking":
        return
    asyncio.create_task(self._start_thinking(session))
|
||
|
||
async def _stop_tts(self, session: DeviceSession) -> None:
|
||
if not session.tts_active and not session.tts_thinking:
|
||
print("跳过 tts stop,当前未激活")
|
||
return
|
||
self._cancel_tts_display_task(session)
|
||
await self._send_tts_state(session, "stop")
|
||
session.tts_active = False
|
||
session.tts_thinking = False
|
||
session.tts_started_at = 0.0
|
||
session.tts_last_audible_at = 0.0
|
||
session.tts_transcript_text = ""
|
||
session.tts_display_text = ""
|
||
session.tts_display_final = False
|
||
session.tts_emotion = ""
|
||
|
||
async def _force_stop_tts(self, session: DeviceSession, reason: str) -> None:
|
||
self._cancel_tts_display_task(session)
|
||
if session.tts_idle_task is not None:
|
||
session.tts_idle_task.cancel()
|
||
session.tts_idle_task = None
|
||
session.tts_active = False
|
||
session.tts_thinking = False
|
||
session.tts_started_at = 0.0
|
||
session.tts_last_audible_at = 0.0
|
||
session.tts_transcript_text = ""
|
||
session.tts_display_text = ""
|
||
session.tts_display_final = False
|
||
session.tts_emotion = ""
|
||
await self._send_tts_state(session, "stop")
|
||
print(f"已强制停止本地 TTS: device={session.device_id} reason={reason}")
|
||
|
||
async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
    """Handle an interrupt request: stop local TTS and notify the agent.

    Bumps the stream id (so tasks bound to the old stream stop), records the
    interrupt time, opens the suppression window of
    ``TTS_INTERRUPT_SUPPRESS_SECONDS``, and marks the session as waiting for
    user audio. The agent-side interrupt is sent as a background task.
    """
    print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")

    interrupted_at = time.monotonic()
    session.tts_stream_id += 1
    session.last_interrupt_time = interrupted_at
    session.tts_suppressed_until = interrupted_at + TTS_INTERRUPT_SUPPRESS_SECONDS
    session.tts_waiting_for_user_audio_after_interrupt = True

    await self._force_stop_tts(session, reason)
    # Not awaited: the local stop must not wait for the agent round-trip.
    asyncio.create_task(self._send_agent_interrupt(session, reason))
|
||
|
||
def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
|
||
session.tts_last_audible_at = time.monotonic()
|
||
if session.tts_idle_task is not None:
|
||
session.tts_idle_task.cancel()
|
||
session.tts_idle_task = asyncio.create_task(
|
||
self._tts_idle_watchdog(session, session.tts_stream_id)
|
||
)
|
||
|
||
async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
    """Stop TTS once the downlink has been silent for long enough.

    TTS is stopped when both conditions hold: no audible audio for
    ``TTS_IDLE_TIMEOUT_SECONDS`` and TTS active for at least
    ``TTS_MIN_ACTIVE_SECONDS``. The watchdog exits silently when the stream
    id has changed, TTS is no longer active, or the task is cancelled.

    Args:
        session: Device session being watched.
        stream_id: Value of ``session.tts_stream_id`` when the watchdog was
            created; a mismatch means this watchdog is stale.
    """
    delay = TTS_IDLE_TIMEOUT_SECONDS
    try:
        while True:
            await asyncio.sleep(delay)
            if stream_id != session.tts_stream_id or not session.tts_active:
                return

            now = time.monotonic()
            idle_for = now - session.tts_last_audible_at
            active_for = now - session.tts_started_at
            remaining = max(
                TTS_IDLE_TIMEOUT_SECONDS - idle_for,
                TTS_MIN_ACTIVE_SECONDS - active_for,
            )
            if remaining > 0:
                # Sleep only the time still missing, then re-check. The
                # previous loop slept `remaining` and then a full idle
                # timeout again, which delayed the stop.
                delay = remaining
                continue

            print(
                "TTS 静音达到阈值,切回聆听状态: "
                f"idle={idle_for:.2f}s active={active_for:.2f}s"
            )
            await self._stop_tts(session)
            return
    except asyncio.CancelledError:
        # Cancellation is the normal way this watchdog is retired.
        pass
def _has_audible_audio(self, pcm_data: bytes) -> bool:
    """Return True when any sample reaches ``TTS_SILENCE_PEAK_THRESHOLD``.

    ``pcm_data`` is read as signed 16-bit little-endian samples; a trailing
    odd byte is ignored.

    Args:
        pcm_data: Raw PCM bytes.

    Returns:
        True if the absolute value of at least one sample is greater than or
        equal to the threshold, False otherwise (including for inputs shorter
        than one sample).
    """
    usable_bytes = len(pcm_data) - (len(pcm_data) % 2)
    if usable_bytes < 2:
        return False

    # iter_unpack decodes the samples in C and any() stops at the first hit,
    # matching the early exit of the former per-sample loop.
    samples = struct.iter_unpack("<h", memoryview(pcm_data)[:usable_bytes])
    return any(abs(sample) >= TTS_SILENCE_PEAK_THRESHOLD for (sample,) in samples)
def _maybe_forward_remote_audio(
    self,
    session: DeviceSession,
    track: rtc.Track,
    publication: Optional[rtc.TrackPublication],
    participant: rtc.RemoteParticipant,
    source: str,
) -> None:
    """Start forwarding a remote audio track to the device, at most once.

    Non-audio tracks are ignored. Tracks are keyed by the track sid, falling
    back to the publication sid and then to ``audio:<identity>``. A track
    that already has a running forwarding task is skipped; a finished task
    is replaced by a new one.

    Args:
        session: Device session that receives the audio.
        track: Remote track to forward.
        publication: Publication of the track, if known.
        participant: Remote participant that owns the track.
        source: Label of the code path that found the track (logging only).
    """
    if track.kind != rtc.TrackKind.KIND_AUDIO:
        return

    track_sid = (
        getattr(track, "sid", None)
        or getattr(publication, "sid", None)
        or f"audio:{participant.identity}"
    )

    existing_task = session.forwarding_tracks.get(track_sid)
    if existing_task is not None:
        if not existing_task.done():
            print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
            return
        print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
        session.forwarding_tracks.pop(track_sid, None)

    audio_stream = rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
    task = asyncio.create_task(
        self.forward_audio_to_esp32(session, track_sid, audio_stream)
    )
    session.forwarding_tracks[track_sid] = task
    print(
        f"收到音频流: {participant.identity} sid={track_sid} "
        f"source={source} room={session.room_name}"
    )
def _scan_participant_audio_tracks(
    self,
    session: DeviceSession,
    participant: rtc.RemoteParticipant,
    source: str,
) -> None:
    """Offer every already-available track of ``participant`` for forwarding.

    Publications without an attached track are skipped; filtering by track
    kind is left to ``_maybe_forward_remote_audio``.

    Args:
        session: Device session that receives the audio.
        participant: Remote participant whose publications are scanned.
        source: Label passed through for logging.
    """
    publications = getattr(participant, "track_publications", None) or {}
    for publication in publications.values():
        track = getattr(publication, "track", None)
        if track is not None:
            self._maybe_forward_remote_audio(session, track, publication, participant, source)
def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
    """Strip the binary framing of ``message`` and return the Opus payload.

    Framing depends on ``session.protocol_version``:

    * 2: 16-byte network-order header ``!HHIII`` (version, type, reserved,
      timestamp, payload size) followed by the payload.
    * 3: 4-byte network-order header ``!BBH`` (type, reserved, payload size)
      followed by the payload.
    * anything else: the message is the payload.

    Args:
        session: Device session that supplies the protocol version.
        message: Raw binary WebSocket message.

    Returns:
        The Opus payload bytes.

    Raises:
        ValueError: If the message is shorter than its header, carries an
            unexpected version or type, or is shorter than the declared
            payload size.
    """
    if session.protocol_version == 2:
        if len(message) < 16:
            raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
        version, msg_type, _reserved, _timestamp, payload_size = struct.unpack_from(
            "!HHIII", message
        )
        if version != 2:
            raise ValueError(f"version 2 包头版本异常: {version}")
        if msg_type != 0:
            raise ValueError(f"version 2 音频包类型异常: {msg_type}")
        if len(message) < 16 + payload_size:
            raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
        return message[16:16 + payload_size]

    if session.protocol_version == 3:
        if len(message) < 4:
            raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
        msg_type, _reserved, payload_size = struct.unpack_from("!BBH", message)
        if msg_type != 0:
            raise ValueError(f"version 3 音频包类型异常: {msg_type}")
        if len(message) < 4 + payload_size:
            raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
        return message[4:4 + payload_size]

    return message
def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
    """Frame an Opus packet for the device's protocol version.

    Mirrors ``_extract_opus_payload``: version 2 prepends a ``!HHIII``
    header (version 2, type 0, reserved 0, timestamp 0, payload size),
    version 3 prepends a ``!BBH`` header (type 0, reserved 0, payload size),
    and any other version sends the payload unframed.

    Args:
        session: Device session that supplies the protocol version.
        payload: Encoded Opus packet.

    Returns:
        The bytes to send over the WebSocket.
    """
    if session.protocol_version == 2:
        return struct.pack("!HHIII", 2, 0, 0, 0, len(payload)) + payload

    if session.protocol_version == 3:
        return struct.pack("!BBH", 0, 0, len(payload)) + payload

    return payload
def _current_tts_display_text(self, text: str) -> str:
    """Return the sentence of ``text`` that should currently be displayed.

    Whitespace is collapsed first. With no sentence break the whole text is
    returned. If the text ends with a break, the last complete sentence is
    returned; otherwise the unfinished tail after the last break. Breaks are
    the characters in ``TTS_DISPLAY_SENTENCE_BREAKS``.

    Args:
        text: Accumulated transcript text.

    Returns:
        The selected sentence, the whole normalized text if the selection is
        empty, or ``""`` for blank input.
    """
    normalized = " ".join(text.split())
    if not normalized:
        return ""

    # Track the positions of the last two sentence breaks.
    last_break = -1
    previous_break = -1
    for index, char in enumerate(normalized):
        if char in TTS_DISPLAY_SENTENCE_BREAKS:
            previous_break = last_break
            last_break = index

    if last_break == -1:
        return normalized

    ends_with_break = last_break == len(normalized) - 1
    start = (previous_break if ends_with_break else last_break) + 1
    return normalized[start:].strip() or normalized
def _register_room_handlers(self, session: DeviceSession) -> None:
    """Attach the LiveKit room event callbacks for ``session``.

    All callbacks are synchronous closures over ``session``; asynchronous
    work is scheduled with ``asyncio.create_task``.

    Args:
        session: Device session whose ``room`` receives the handlers.
    """

    @session.room.on("connection_state_changed")
    def on_connection_state_changed(state: int) -> None:
        """Deliberately a no-op hook for connection state changes."""
        pass

    @session.room.on("connected")
    def on_connected() -> None:
        """Pick up agents that were already in the room when we connected."""
        print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
        self._log_agent_participants(session, "connected")
        for participant in session.room.remote_participants.values():
            # NOTE(review): indentation was lost in the extracted source; the
            # scan and state handling are assumed to apply to agents only,
            # matching the guard in on_participant_attributes_changed.
            if self._is_agent_participant(participant, session.agent_name):
                session.agent_ready.set()
                self._scan_participant_audio_tracks(session, participant, "connected_scan")
                self._handle_agent_state(session, participant)

    @session.room.on("participant_connected")
    def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
        """Log the new participant and mark the agent as ready when it joins."""
        role = "Agent" if self._is_agent_participant(participant, session.agent_name) else "Remote participant"
        print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
        self._log_agent_participants(session, "participant_connected")
        # NOTE(review): same nesting assumption as in on_connected.
        if self._is_agent_participant(participant, session.agent_name):
            session.agent_ready.set()
            self._scan_participant_audio_tracks(
                session, participant, "participant_connected_scan"
            )
            self._handle_agent_state(session, participant)

    @session.room.on("participant_disconnected")
    def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
        """Forget forwarding entries keyed by the departed identity."""
        print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
        # Only keys ending in ":<identity>" (the "audio:<identity>" fallback
        # key) are dropped; the tasks themselves are not cancelled here.
        session.forwarding_tracks = {
            track_sid: task
            for track_sid, task in session.forwarding_tracks.items()
            if not track_sid.endswith(f":{participant.identity}")
        }

    @session.room.on("participant_attributes_changed")
    def on_participant_attributes_changed(changed: list[str], participant: rtc.Participant) -> None:
        """React to agent state changes published as participant attributes."""
        if AGENT_STATE_ATTRIBUTE not in changed:
            return
        if not isinstance(participant, rtc.RemoteParticipant):
            return
        if not self._is_agent_participant(participant, session.agent_name):
            return
        self._handle_agent_state(session, participant)

    @session.room.on("data_received")
    def on_data_received(data_packet: rtc.DataPacket) -> None:
        """Log incoming data packets and forward MCP payloads to the device."""
        identity = data_packet.participant.identity if data_packet.participant else "未知"
        packet_topic = getattr(data_packet, "topic", None)
        try:
            decoded = data_packet.data.decode("utf-8")
            print(
                f"📩 [数据接收 | room={session.room_name} | {identity}]: "
                f"{decoded}"
            )
        except Exception:
            # Undecodable packets are treated as empty, not as errors.
            decoded = ""

        try:
            payload = json.loads(decoded) if decoded else None
        except Exception:
            payload = None

        if isinstance(payload, dict) and (
            packet_topic == MCP_TOPIC
            or payload.get("type") == "mcp"
            or payload.get("topic") == MCP_TOPIC
        ):
            mcp_payload = payload.get("payload")
            if isinstance(mcp_payload, dict):
                asyncio.create_task(
                    self._forward_mcp_to_device(
                        session,
                        mcp_payload,
                        source_identity=identity,
                    )
                )
            else:
                print(f"收到 MCP 数据但缺少 payload: {payload}")

    @session.room.on("transcription_received")
    def on_transcription_received(
        segments: list[rtc.TranscriptionSegment],
        participant: rtc.Participant,
        track_pub: rtc.TrackPublication,
    ) -> None:
        """Route transcriptions: agent text to the TTS display, user text to STT."""
        identity = participant.identity if participant else "未知"
        is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(
            participant, session.agent_name
        )
        for segment in segments:
            status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
            print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")
            if is_agent:
                # Drop agent text that arrives inside the post-interrupt
                # suppression window.
                if time.monotonic() < session.tts_suppressed_until:
                    continue
                print(f"[livekit-llm] raw={segment.text!r} final={segment.final}")
                emotion, tts_text = self._parse_emotion_text(segment.text)
                print(
                    f"[livekit-llm] parsed emotion={emotion!r} "
                    f"tts_text={tts_text!r} final={segment.final}"
                )
                if emotion and emotion != session.tts_emotion:
                    session.tts_emotion = emotion
                    asyncio.create_task(self._send_emotion(session, emotion))
                display_text = self._current_tts_display_text(tts_text)
                print(f"[livekit-llm] display_text={display_text!r} final={segment.final}")
                if not display_text or display_text == session.tts_transcript_text:
                    continue
                session.tts_transcript_text = display_text
            else:
                # Only final user transcripts are reported to the device.
                if not segment.final:
                    continue
                display_text = segment.text
                # NOTE(review): assumed to belong to the user branch only;
                # indentation was lost in the extracted source.
                asyncio.create_task(self._start_thinking(session))

            if session.websocket is not None:
                ws = session.websocket
                if is_agent:
                    self._update_tts_display_text(session, display_text, segment.final)
                else:
                    asyncio.create_task(
                        ws.send(
                            json.dumps(
                                {
                                    "type": "stt",
                                    "text": segment.text,
                                    "final": segment.final,
                                }
                            )
                        )
                    )

    @session.room.on("track_subscribed")
    def on_track_subscribed(
        track: rtc.Track,
        publication: rtc.TrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        """Start forwarding a newly subscribed track."""
        self._maybe_forward_remote_audio(session, track, publication, participant, "event")

    @session.room.on("track_published")
    def on_track_published(
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ) -> None:
        """Start forwarding a published track if it is already attached."""
        track = getattr(publication, "track", None)
        if track is not None:
            self._maybe_forward_remote_audio(session, track, publication, participant, "published")
async def _connect_session_room(self, session: DeviceSession) -> None:
    """Connect ``session`` to its LiveKit room and publish the mic track.

    Steps: register room handlers, fetch a token, connect, make sure an
    agent is dispatched, publish the device microphone track, then wait up
    to ``AGENT_READY_TIMEOUT_SECONDS`` for the agent to join.

    Args:
        session: Device session to connect.

    Raises:
        Exception: Whatever ``session.room.connect`` raised, re-raised after
            logging. Errors from the other steps propagate unchanged.
    """
    self._register_room_handlers(session)

    token = await fetch_token(session.room_name, session.identity, session.agent_name)

    try:
        await session.room.connect(
            LIVEKIT_WS_URL,
            token,
            options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
        )
    except Exception as exc:
        self._log_exception(
            f"连接 LiveKit 房间失败: room={session.room_name}",
            exc,
        )
        print(
            "提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
            "请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
            "*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
        )
        raise

    print(f"已连接到 LiveKit 房间: {session.room.name}")
    self._log_agent_participants(session, "after_connect")

    await self.ensure_agent_dispatched(session)

    track = rtc.LocalAudioTrack.create_audio_track(
        f"esp32-mic-{session.device_id}",
        session.mic_source,
    )
    options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
    await session.room.local_participant.publish_track(track, options)
    self._log_agent_participants(session, "after_publish_mic")

    try:
        await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        # A session closed while waiting is not worth a warning.
        if not session.closed:
            print(f"⚠️ agent 等待超时: room={session.room_name}")
async def start(self) -> None:
    """Print the effective configuration at startup.

    Performs no connection work; it only logs the WebSocket, LiveKit, audio
    and agent settings, plus the emotion test sequence when one is set.
    """
    print(f"[config] websocket_port={WS_PORT}")
    print(f"[config] websocket_max_queue={WS_MAX_QUEUE} websocket_max_size={WS_MAX_SIZE}")
    print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
    print(f"[config] token_url={TOKEN_URL}")
    print(f"[config] agent_dispatch_mode={AGENT_DISPATCH_MODE}")
    print(
        "[config] audio="
        f"uplink_decode:{INPUT_SAMPLE_RATE}Hz/{INPUT_FRAME_DURATION_MS}ms "
        f"downlink_encode:{OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms "
        f"stats_interval:{AUDIO_STATS_INTERVAL_SECONDS}s "
        f"capture_timeout:{UPLINK_CAPTURE_TIMEOUT_SECONDS}s "
        f"tts_idle:{TTS_IDLE_TIMEOUT_SECONDS}s "
        f"tts_min_active:{TTS_MIN_ACTIVE_SECONDS}s "
        f"tts_start_frames:{TTS_START_CONSECUTIVE_AUDIBLE_FRAMES} "
        f"tts_pre_roll:{TTS_PRE_ROLL_MS}ms"
    )
    print(
        "[config] agents="
        f"normal:{CHAT_MODE_AGENT_NAMES['normal']} "
        f"beaver:{CHAT_MODE_AGENT_NAMES['beaver']} "
        f"vision-normal:{CHAT_MODE_AGENT_NAMES['vision-normal']} "
        f"vision-beaver:{CHAT_MODE_AGENT_NAMES['vision-beaver']} "
        f"default_mode:{DEFAULT_AGENT_MODE}"
    )
    if EMOTION_TEST_SEQUENCE:
        print(
            "[config] emotion_test_sequence="
            f"{','.join(EMOTION_TEST_SEQUENCE)} "
            f"interval={EMOTION_TEST_INTERVAL_SECONDS}s"
        )
async def close(self) -> None:
    """Close every registered device session, one after another."""
    # Iterate over a snapshot: the registry may change while we await.
    for session in list(self.device_sessions.values()):
        await self._close_session(session)
async def _close_session(self, session: DeviceSession) -> None:
    """Mark ``session`` closed, stop its TTS bookkeeping and leave the room.

    Idempotent: a session that is already closed is left untouched. Errors
    from disconnecting the room are logged and swallowed.

    Args:
        session: Device session to close.
    """
    if session.closed:
        return

    session.closed = True
    session.websocket = None
    # Release anyone blocked waiting for the agent to join.
    session.agent_ready.set()
    session.tts_active = False
    session.tts_thinking = False
    # Bumping the stream id lets running forwarders see the interruption.
    session.tts_stream_id += 1
    if session.tts_idle_task is not None:
        session.tts_idle_task.cancel()
        session.tts_idle_task = None
    self._cancel_tts_display_task(session)

    try:
        await session.room.disconnect()
    except Exception as exc:
        print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
async def forward_audio_to_esp32(
    self,
    session: DeviceSession,
    track_sid: str,
    audio_stream: rtc.AudioStream,
) -> None:
    """Encode a remote audio stream to Opus and send it to the device.

    For each incoming frame the loop:

    1. Detects an interrupt (``tts_stream_id`` changed), drops buffered
       audio and waits for a silence gap before playing again.
    2. Drops audio while inside the suppression window or while
       ``_tts_resume_block_reason`` reports a block.
    3. While TTS is not active, keeps a pre-roll of at most
       ``TTS_PRE_ROLL_MS`` and starts TTS only after
       ``TTS_START_CONSECUTIVE_AUDIBLE_FRAMES`` consecutive audible frames.
    4. Once active, re-chunks PCM into ``OUTPUT_FRAME_DURATION_MS`` frames,
       encodes them and sends them over the WebSocket.

    On exit the task removes its own ``forwarding_tracks`` entry and, if it
    still owns the current stream id, cancels the idle watchdog and stops
    TTS.

    Args:
        session: Device session that receives the audio.
        track_sid: Key of this task in ``session.forwarding_tracks``.
        audio_stream: Stream of audio frames to forward.
    """
    encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
    pending_pcm = bytearray()
    pre_roll_pcm = bytearray()
    # Sizes below assume 2 bytes per sample.
    pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
    output_samples_per_opus_frame = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
    output_frame_bytes = output_samples_per_opus_frame * 2
    audible_frame_streak = 0
    silence_frame_streak = 0
    waiting_for_post_interrupt_silence = False
    downlink_packets = 0
    downlink_audio_ms = 0.0
    last_downlink_stats_time = time.monotonic()
    last_send_time: Optional[float] = None
    stream_id = session.tts_stream_id
    print(
        f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
        f"track_sid={track_sid} stream_id={stream_id} "
        f"opus={OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms"
    )

    try:
        async for event in audio_stream:
            # Interrupted since the last frame: drop everything buffered and
            # require a silence gap before the next utterance may start.
            if stream_id != session.tts_stream_id:
                print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                stream_id = session.tts_stream_id
                if session.tts_active:
                    await self._stop_tts(session)
                continue

            frame = event.frame
            pcm_data = frame.data.tobytes()
            has_audible_audio = self._has_audible_audio(pcm_data)

            now = time.monotonic()
            if now < session.tts_suppressed_until:
                # Inside the post-interrupt suppression window.
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                waiting_for_post_interrupt_silence = True
                continue

            block_reason = self._tts_resume_block_reason(
                session,
                now,
                include_user_quiet=False,
            )
            if block_reason is not None:
                pending_pcm.clear()
                pre_roll_pcm.clear()
                audible_frame_streak = 0
                silence_frame_streak = 0
                if block_reason == "waiting_for_user_audio_after_interrupt":
                    waiting_for_post_interrupt_silence = True
                continue

            # The user spoke after the interrupt and has been quiet for the
            # grace period: new TTS may start without a downlink silence gap.
            if (
                waiting_for_post_interrupt_silence
                and session.last_interrupt_time > 0.0
                and session.last_uplink_audible_time >= session.last_interrupt_time
                and now - session.last_uplink_audible_time
                >= TTS_POST_INTERRUPT_USER_AUDIO_GRACE_SECONDS
            ):
                print("检测到用户打断后语音已结束,允许新 TTS 直接起播")
                waiting_for_post_interrupt_silence = False
                silence_frame_streak = 0

            if waiting_for_post_interrupt_silence:
                if has_audible_audio:
                    silence_frame_streak = 0
                    continue

                silence_frame_streak += 1
                if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
                    continue

                print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
                waiting_for_post_interrupt_silence = False
                silence_frame_streak = 0
                stream_id = session.tts_stream_id
                continue

            current_frame_buffered = False
            if not session.tts_active:
                # Keep a bounded pre-roll so the start of the utterance is
                # not lost while waiting for enough audible frames.
                pre_roll_pcm.extend(pcm_data)
                current_frame_buffered = True
                if len(pre_roll_pcm) > pre_roll_max_bytes:
                    del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]

                if not has_audible_audio:
                    audible_frame_streak = 0
                    continue

                audible_frame_streak += 1
                if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
                    continue

                await self._start_tts(session)
                if not session.tts_active:
                    continue

                print(
                    "检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
                    f"(frames={audible_frame_streak})"
                )
                pending_pcm.extend(pre_roll_pcm)
                pre_roll_pcm.clear()
                audible_frame_streak = 0

            if has_audible_audio:
                self._reset_tts_idle_timer(session)

            # The pre-roll already contains this frame when it was buffered
            # above; do not append it twice.
            if not current_frame_buffered:
                pending_pcm.extend(pcm_data)

            while (
                len(pending_pcm) >= output_frame_bytes
                and stream_id == session.tts_stream_id
                and session.websocket is not None
            ):
                try:
                    now = time.monotonic()
                    if last_send_time is not None:
                        send_gap_ms = (now - last_send_time) * 1000.0
                        if send_gap_ms > DOWNLINK_SEND_GAP_WARN_MS:
                            print(
                                "[downlink] warning: send gap "
                                f"{send_gap_ms:.1f}ms device={session.device_id} "
                                f"pending_ms={self._audio_duration_ms(len(pending_pcm) // 2, OUTPUT_SAMPLE_RATE):.1f}"
                            )
                    last_send_time = now

                    opus_packet = encoder.encode(
                        bytes(pending_pcm[:output_frame_bytes]),
                        output_samples_per_opus_frame,
                    )
                    del pending_pcm[:output_frame_bytes]
                    await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
                    downlink_packets += 1
                    downlink_audio_ms += OUTPUT_FRAME_DURATION_MS
                    if now - last_downlink_stats_time >= AUDIO_STATS_INTERVAL_SECONDS:
                        print(
                            "[downlink] "
                            f"device={session.device_id} packets={downlink_packets} "
                            f"audio_ms={downlink_audio_ms:.0f} "
                            f"pending_ms={self._audio_duration_ms(len(pending_pcm) // 2, OUTPUT_SAMPLE_RATE):.1f}"
                        )
                        downlink_packets = 0
                        downlink_audio_ms = 0.0
                        last_downlink_stats_time = now
                except Exception as exc:
                    # Stop draining for this frame only; the outer loop
                    # keeps consuming the stream.
                    print(f"发送回 ESP32 失败: {exc}")
                    break

    except Exception as exc:
        print(f"音频流处理错误: {exc}")
    finally:
        print("🎧 TTS 音频结束")
        # Remove the registry entry only if it still points at this task; a
        # replacement task may already have been registered under the key.
        task = session.forwarding_tracks.get(track_sid)
        current_task = asyncio.current_task()
        if task is current_task:
            session.forwarding_tracks.pop(track_sid, None)
        if stream_id == session.tts_stream_id and session.tts_idle_task is not None:
            session.tts_idle_task.cancel()
            session.tts_idle_task = None
        if stream_id == session.tts_stream_id:
            await self._stop_tts(session)
async def handle_websocket(self, websocket: Any) -> None:
    """Serve one device WebSocket connection from handshake to teardown.

    Creates a ``DeviceSession`` (replacing any existing session with the
    same device id), sends the server hello, starts the LiveKit connection
    in the background and then processes messages:

    * binary messages are Opus audio: unframed, decoded and captured into
      the session microphone source;
    * text messages are JSON commands (``hello``, ``abort``, ``mcp``,
      ``vision`` frames).

    On exit the pending room connection is cancelled, the session is closed
    and its registry entry is removed.

    Args:
        websocket: Accepted server-side WebSocket connection.
    """
    header_version = websocket.request.headers.get("Protocol-Version")
    try:
        protocol_version = int(header_version) if header_version else 1
    except ValueError:
        protocol_version = 1

    device_id = self._build_device_id(websocket)
    existing_session = self.device_sessions.get(device_id)
    if existing_session is not None:
        print(f"检测到重复 device_id,关闭旧 session: device={device_id}")
        await self._close_session(existing_session)
        self.device_sessions.pop(device_id, None)

    chat_mode, agent_mode, agent_name, vision_enabled = self._resolve_agent_selection(websocket.request.headers)
    room_name, identity = self._build_session_names(device_id)
    session = DeviceSession(
        device_id=device_id,
        websocket=websocket,
        protocol_version=protocol_version,
        room_name=room_name,
        identity=identity,
        chat_mode=chat_mode,
        agent_mode=agent_mode,
        agent_name=agent_name,
        vision_enabled=vision_enabled,
        room=rtc.Room(),
        mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
        agent_ready=asyncio.Event(),
    )
    self.device_sessions[device_id] = session

    print(f"ESP32 已连接: device={device_id}")
    print(f"ESP32 协议版本: {session.protocol_version}")
    print(
        f"ESP32 mode: chat={session.chat_mode} "
        f"agent={session.agent_mode}/{session.agent_name} "
        f"vision={session.vision_enabled}"
    )
    session.tts_stream_id += 1
    opus_decoder = None
    uplink_packets = 0
    uplink_audio_ms = 0.0
    uplink_decode_errors = 0
    uplink_dropped_frames = 0
    last_uplink_stats_time = time.monotonic()
    room_connect_task: Optional[asyncio.Task[Any]] = None

    try:
        hello_msg = self._build_server_hello(session)
        await websocket.send(json.dumps(hello_msg))
        print(
            f"已发送 server hello: device={device_id} room={session.room_name} "
            f"audio={OUTPUT_SAMPLE_RATE}Hz/{OUTPUT_FRAME_DURATION_MS}ms"
        )
        asyncio.create_task(self._run_emotion_test_sequence(session))

        room_connect_task = asyncio.create_task(self._connect_session_room(session))
        room_connect_task.add_done_callback(
            lambda task: self._track_room_connect_task(session, task)
        )

        async for message in websocket:
            if session.closed:
                # The session was closed elsewhere (newer connection for the
                # same device, or shutdown); stop serving this socket.
                break

            if isinstance(message, bytes):
                if len(message) < 4:
                    print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
                    continue

                try:
                    audio_data = self._extract_opus_payload(session, message)
                except ValueError as exc:
                    # One malformed frame is counted like a decode error
                    # instead of tearing down the whole connection.
                    uplink_decode_errors += 1
                    print(f"Opus audio decode error ({len(message)} bytes): {exc}")
                    continue
                if not audio_data:
                    continue

                try:
                    if opus_decoder is None:
                        opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)

                    pcm_bytes = opus_decoder.decode(audio_data, INPUT_MAX_SAMPLES_PER_OPUS_FRAME)

                    num_samples = len(pcm_bytes) // 2
                    if num_samples > 0:
                        session.captured_frame_count += 1
                        now = time.monotonic()
                        uplink_packets += 1
                        uplink_audio_ms += self._audio_duration_ms(num_samples, INPUT_SAMPLE_RATE)
                        if self._has_audible_audio(pcm_bytes):
                            session.last_uplink_audible_time = now
                            if session.tts_waiting_for_user_audio_after_interrupt:
                                session.tts_waiting_for_user_audio_after_interrupt = False
                                print(
                                    f"[uplink] detected user audio after interrupt: "
                                    f"device={session.device_id}"
                                )
                        if (
                            session.captured_frame_count <= 5
                            or now - session.first_capture_log_time >= 5.0
                        ):
                            session.first_capture_log_time = now
                        if now - last_uplink_stats_time >= AUDIO_STATS_INTERVAL_SECONDS:
                            print(
                                "[uplink] "
                                f"device={session.device_id} packets={uplink_packets} "
                                f"audio_ms={uplink_audio_ms:.0f} "
                                f"decode_errors={uplink_decode_errors} "
                                f"dropped_frames={uplink_dropped_frames}"
                            )
                            uplink_packets = 0
                            uplink_audio_ms = 0.0
                            uplink_decode_errors = 0
                            uplink_dropped_frames = 0
                            last_uplink_stats_time = now
                        # NOTE(review): assumed to sit inside the
                        # num_samples > 0 branch; indentation was lost in
                        # the extracted source.
                        if not await self._capture_mic_frame(session, pcm_bytes, num_samples):
                            uplink_dropped_frames += 1
                except Exception as exc:
                    uplink_decode_errors += 1
                    print(f"Opus audio decode error ({len(message)} bytes): {exc}")
            elif isinstance(message, str):
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    print(f"收到未知的字符消息: {message}")
                    continue
                if not isinstance(data, dict):
                    # Valid JSON that is not an object cannot carry a "type".
                    print(f"收到未知的字符消息: {message}")
                    continue

                msg_type = data.get("type")
                if msg_type == "hello":
                    self._log_client_hello(session, data)
                elif msg_type == "abort":
                    reason = data.get("reason")
                    abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
                    print(f"处理 ESP32 打断请求: reason={abort_reason}")
                    await self._abort_tts(session, abort_reason)
                elif msg_type == "mcp":
                    await self._publish_mcp_message(session, data)
                elif msg_type == "vision" and data.get("state") == "frame":
                    await self._publish_vision_frame(session, data)
    except ConnectionClosedError as exc:
        print(f"ESP32 异常断开: {exc}")
    except Exception as exc:
        self._log_exception("WebSocket 其他错误", exc)
    finally:
        print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
        if room_connect_task is not None and not room_connect_task.done():
            room_connect_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await room_connect_task
        await self._close_session(session)
        # Only remove our own entry: a newer connection for the same device
        # may already have registered its session under this id.
        if self.device_sessions.get(device_id) is session:
            self.device_sessions.pop(device_id, None)
async def main() -> None:
    """Run the bridge WebSocket server until the task is cancelled.

    Listens on all interfaces at ``WS_PORT`` and closes every device session
    on the way out, whatever the reason for stopping.
    """
    bridge = ESP32LiveKitBridge()
    try:
        await bridge.start()
        async with websockets.serve(
            bridge.handle_websocket,
            "0.0.0.0",
            WS_PORT,
            max_queue=WS_MAX_QUEUE,
            max_size=WS_MAX_SIZE,
        ):
            print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
            # A future that never completes keeps the server alive.
            await asyncio.Future()
    finally:
        await bridge.close()
if __name__ == "__main__":
    # Script entry point: report fatal errors on stderr and exit non-zero.
    try:
        asyncio.run(main())
    except Exception as exc:
        print(f"[error] {exc}", file=sys.stderr)
        sys.exit(1)