feat: ws connect
This commit is contained in:
919
main/bridge_server.py
Normal file
919
main/bridge_server.py
Normal file
@ -0,0 +1,919 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
import opuslib
|
||||
import websockets
|
||||
from livekit import rtc
|
||||
from livekit.rtc import AudioFrame, AudioSource
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
try:
|
||||
from livekit import api as livekit_api
|
||||
except ImportError:
|
||||
livekit_api = None
|
||||
|
||||
TOKEN_URL = "http://10.6.80.130:8000/getToken"
|
||||
LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud"
|
||||
ROOM_PREFIX = "test-livekit"
|
||||
IDENTITY_PREFIX = "uv-livekit"
|
||||
AGENT_NAME = "my-agent"
|
||||
CONNECT_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_CONNECT_TIMEOUT_SECONDS", "20.0"))
|
||||
AGENT_READY_TIMEOUT_SECONDS = float(os.getenv("LIVEKIT_AGENT_READY_TIMEOUT_SECONDS", "10.0"))
|
||||
WS_PORT = 8080
|
||||
AGENT_DISPATCH_MODE = os.getenv("AGENT_DISPATCH_MODE", "token").lower()
|
||||
|
||||
INPUT_SAMPLE_RATE = 16000
|
||||
OUTPUT_SAMPLE_RATE = 24000
|
||||
INPUT_FRAME_DURATION_MS = 60
|
||||
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
|
||||
OUTPUT_FRAME_DURATION_MS = 20
|
||||
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
|
||||
TTS_IDLE_TIMEOUT_SECONDS = 0.8
|
||||
TTS_SILENCE_PEAK_THRESHOLD = 96
|
||||
TTS_PRE_ROLL_MS = 200
|
||||
TTS_START_CONSECUTIVE_AUDIBLE_FRAMES = 3
|
||||
TTS_INTERRUPT_SILENCE_FRAMES = 3
|
||||
INTERRUPT_TOPIC = "lk.interrupt"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceSession:
|
||||
device_id: str
|
||||
websocket: Any
|
||||
protocol_version: int
|
||||
room_name: str
|
||||
identity: str
|
||||
room: rtc.Room
|
||||
mic_source: AudioSource
|
||||
agent_ready: asyncio.Event
|
||||
forwarding_tracks: dict[str, asyncio.Task[Any]] = field(default_factory=dict)
|
||||
tts_active: bool = False
|
||||
tts_idle_task: Optional[asyncio.Task] = None
|
||||
tts_stream_id: int = 0
|
||||
tts_transcript_text: str = ""
|
||||
agent_dispatch_task: Optional[asyncio.Task] = None
|
||||
closed: bool = False
|
||||
captured_frame_count: int = 0
|
||||
first_capture_log_time: float = 0.0
|
||||
|
||||
|
||||
async def fetch_token(room_name: str, identity: str) -> str:
|
||||
params = {"room": room_name, "identity": identity}
|
||||
if AGENT_DISPATCH_MODE == "token":
|
||||
params["agent_name"] = AGENT_NAME
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
response = await client.get(TOKEN_URL, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
payload: dict[str, Any] = response.json()
|
||||
token = payload.get("token")
|
||||
if not isinstance(token, str) or not token:
|
||||
raise ValueError(f"token response missing token field: {payload}")
|
||||
|
||||
print(f"[token] room={payload.get('room')} identity={payload.get('identity')}")
|
||||
print(f"[token] jwt_prefix={token[:16]}... len={len(token)}")
|
||||
return token
|
||||
|
||||
|
||||
class ESP32LiveKitBridge:
|
||||
def __init__(self) -> None:
|
||||
self.device_sessions: dict[str, DeviceSession] = {}
|
||||
|
||||
def _log_exception(self, prefix: str, exc: BaseException) -> None:
|
||||
print(f"{prefix}: type={type(exc).__name__} detail={exc!r}")
|
||||
formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||
if formatted_tb.strip():
|
||||
print(formatted_tb.rstrip())
|
||||
|
||||
def _build_device_id(self, websocket: Any) -> str:
|
||||
headers = websocket.request.headers
|
||||
requested_id = headers.get("X-Device-Id") or headers.get("Device-Id")
|
||||
if requested_id:
|
||||
return requested_id
|
||||
return f"ws-{id(websocket):x}"
|
||||
|
||||
def _build_session_names(self, device_id: str) -> tuple[str, str]:
|
||||
session_tag = uuid.uuid4().hex[:8]
|
||||
room_name = f"{ROOM_PREFIX}-{session_tag}"
|
||||
identity = f"{IDENTITY_PREFIX}-{session_tag}"
|
||||
print(f"[session] device={device_id} room={room_name} identity={identity}")
|
||||
return room_name, identity
|
||||
|
||||
def _is_agent_participant(self, participant: rtc.RemoteParticipant) -> bool:
|
||||
identity = getattr(participant, "identity", "") or ""
|
||||
return identity.startswith("agent-") or AGENT_NAME in identity
|
||||
|
||||
def _get_agent_identities(self, session: DeviceSession) -> list[str]:
|
||||
return [
|
||||
participant.identity
|
||||
for participant in session.room.remote_participants.values()
|
||||
if self._is_agent_participant(participant)
|
||||
]
|
||||
|
||||
def _log_agent_participants(self, session: DeviceSession, source: str) -> None:
|
||||
agent_identities = self._get_agent_identities(session)
|
||||
print(
|
||||
f"[agent-check] source={source} room={session.room_name} "
|
||||
f"agent_count={len(agent_identities)} agents={agent_identities}"
|
||||
)
|
||||
|
||||
async def _has_existing_dispatch(self, session: DeviceSession) -> bool:
|
||||
if livekit_api is None:
|
||||
return False
|
||||
|
||||
api_key = os.getenv("LIVEKIT_API_KEY")
|
||||
api_secret = os.getenv("LIVEKIT_API_SECRET")
|
||||
if not api_key or not api_secret:
|
||||
return False
|
||||
|
||||
request_cls = getattr(livekit_api, "ListAgentDispatchRequest", None)
|
||||
if request_cls is None:
|
||||
return False
|
||||
|
||||
lkapi = livekit_api.LiveKitAPI(
|
||||
url=LIVEKIT_WS_URL,
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
)
|
||||
try:
|
||||
response = await lkapi.agent_dispatch.list_dispatch(
|
||||
request_cls(room=session.room_name)
|
||||
)
|
||||
dispatches = (
|
||||
getattr(response, "agent_dispatches", None)
|
||||
or getattr(response, "dispatches", None)
|
||||
or getattr(response, "items", None)
|
||||
or []
|
||||
)
|
||||
for dispatch in dispatches:
|
||||
dispatch_agent_name = getattr(dispatch, "agent_name", None)
|
||||
dispatch_room = getattr(dispatch, "room", None)
|
||||
if dispatch_room == session.room_name and (
|
||||
dispatch_agent_name == AGENT_NAME or dispatch_agent_name is None
|
||||
):
|
||||
print(
|
||||
f"检测到已有 dispatch: room={session.room_name} "
|
||||
f"agent={dispatch_agent_name}"
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
print(f"list_dispatch 查询失败,继续按无 dispatch 处理: {exc}")
|
||||
finally:
|
||||
await lkapi.aclose()
|
||||
|
||||
return False
|
||||
|
||||
async def ensure_agent_dispatched(self, session: DeviceSession) -> None:
|
||||
if AGENT_DISPATCH_MODE != "bridge":
|
||||
print(f"跳过 bridge 手动 dispatch: mode={AGENT_DISPATCH_MODE}")
|
||||
return
|
||||
|
||||
if await self._has_existing_dispatch(session):
|
||||
return
|
||||
|
||||
for participant in session.room.remote_participants.values():
|
||||
if self._is_agent_participant(participant):
|
||||
print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
|
||||
return
|
||||
|
||||
if session.agent_dispatch_task is not None and not session.agent_dispatch_task.done():
|
||||
print("Agent dispatch 正在进行中,跳过重复请求")
|
||||
return
|
||||
|
||||
session.agent_dispatch_task = asyncio.create_task(self._dispatch_agent(session))
|
||||
await session.agent_dispatch_task
|
||||
|
||||
async def _dispatch_agent(self, session: DeviceSession) -> None:
|
||||
print(f"准备 dispatch agent: room={session.room_name}, agent={AGENT_NAME}")
|
||||
|
||||
try:
|
||||
if await self._dispatch_agent_with_sdk(session):
|
||||
return
|
||||
|
||||
if await self._dispatch_agent_with_cli(session):
|
||||
return
|
||||
|
||||
print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
|
||||
except Exception as exc:
|
||||
print(f"Agent dispatch 失败: {exc}")
|
||||
|
||||
async def _dispatch_agent_with_sdk(self, session: DeviceSession) -> bool:
|
||||
if livekit_api is None:
|
||||
return False
|
||||
|
||||
api_key = os.getenv("LIVEKIT_API_KEY")
|
||||
api_secret = os.getenv("LIVEKIT_API_SECRET")
|
||||
if not api_key or not api_secret:
|
||||
return False
|
||||
|
||||
lkapi = livekit_api.LiveKitAPI(
|
||||
url=LIVEKIT_WS_URL,
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
)
|
||||
try:
|
||||
dispatch = await lkapi.agent_dispatch.create_dispatch(
|
||||
livekit_api.CreateAgentDispatchRequest(
|
||||
agent_name=AGENT_NAME,
|
||||
room=session.room_name,
|
||||
metadata=json.dumps(
|
||||
{
|
||||
"source": "bridge_server",
|
||||
"identity": session.identity,
|
||||
"device_id": session.device_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
print(f"Agent dispatch 已创建: {dispatch}")
|
||||
finally:
|
||||
await lkapi.aclose()
|
||||
|
||||
return True
|
||||
|
||||
async def _dispatch_agent_with_cli(self, session: DeviceSession) -> bool:
|
||||
lk_path = shutil.which("lk")
|
||||
if lk_path is None:
|
||||
return False
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
lk_path,
|
||||
"dispatch",
|
||||
"create",
|
||||
"--room",
|
||||
session.room_name,
|
||||
"--agent-name",
|
||||
AGENT_NAME,
|
||||
"--metadata",
|
||||
json.dumps(
|
||||
{
|
||||
"source": "bridge_server",
|
||||
"identity": session.identity,
|
||||
"device_id": session.device_id,
|
||||
}
|
||||
),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout_text = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_text = stderr.decode("utf-8", errors="replace").strip()
|
||||
|
||||
if stdout_text:
|
||||
print(f"lk dispatch 输出: {stdout_text}")
|
||||
if stderr_text:
|
||||
print(f"lk dispatch 错误输出: {stderr_text}")
|
||||
|
||||
if process.returncode != 0:
|
||||
print(f"lk dispatch create 失败,退出码: {process.returncode}")
|
||||
return False
|
||||
|
||||
print(f"Agent dispatch 已通过 lk CLI 创建: room={session.room_name}, agent={AGENT_NAME}")
|
||||
return True
|
||||
|
||||
async def _publish_agent_event(self, session: DeviceSession, payload: dict[str, Any]) -> bool:
|
||||
participant = getattr(session.room, "local_participant", None)
|
||||
if participant is None:
|
||||
print("跳过发送 agent 控制事件,local participant 尚未就绪")
|
||||
return False
|
||||
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
agent_identities = self._get_agent_identities(session)
|
||||
kwargs: dict[str, Any] = {}
|
||||
if agent_identities:
|
||||
kwargs["destination_identities"] = agent_identities
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in ({"topic": INTERRUPT_TOPIC, **kwargs}, kwargs):
|
||||
try:
|
||||
await participant.publish_data(data, **attempt)
|
||||
print(
|
||||
f"已发送 agent 控制事件: topic={attempt.get('topic', '<default>')} "
|
||||
f"targets={agent_identities or 'broadcast'} payload={payload}"
|
||||
)
|
||||
return True
|
||||
except TypeError as exc:
|
||||
last_error = exc
|
||||
except Exception as exc:
|
||||
print(f"发送 agent 控制事件失败: {exc}")
|
||||
return False
|
||||
|
||||
if last_error is not None:
|
||||
print(f"发送 agent 控制事件失败,publish_data 签名不兼容: {last_error}")
|
||||
return False
|
||||
|
||||
async def _send_agent_interrupt(self, session: DeviceSession, reason: str) -> None:
|
||||
payload = {
|
||||
"type": "interrupt",
|
||||
"topic": INTERRUPT_TOPIC,
|
||||
"reason": reason,
|
||||
"room": session.room_name,
|
||||
"identity": session.identity,
|
||||
"device_id": session.device_id,
|
||||
}
|
||||
ok = await self._publish_agent_event(session, payload)
|
||||
if not ok:
|
||||
print("警告: bridge 已停止 TTS,但 agent 侧 interrupt 未确认送出")
|
||||
|
||||
async def _send_tts_state(self, session: DeviceSession, state: str) -> None:
|
||||
if session.websocket is None:
|
||||
print(f"跳过 tts {state},ESP32 尚未连接")
|
||||
return
|
||||
await session.websocket.send(json.dumps({"type": "tts", "state": state}))
|
||||
print(f"已发送 tts {state}: device={session.device_id}")
|
||||
|
||||
async def _start_tts(self, session: DeviceSession) -> None:
|
||||
if session.tts_active:
|
||||
print("跳过 tts start,当前已处于激活状态")
|
||||
return
|
||||
session.tts_transcript_text = ""
|
||||
await self._send_tts_state(session, "start")
|
||||
session.tts_active = True
|
||||
|
||||
async def _stop_tts(self, session: DeviceSession) -> None:
|
||||
if not session.tts_active:
|
||||
print("跳过 tts stop,当前未激活")
|
||||
return
|
||||
await self._send_tts_state(session, "stop")
|
||||
session.tts_active = False
|
||||
session.tts_transcript_text = ""
|
||||
|
||||
async def _abort_tts(self, session: DeviceSession, reason: str = "client_abort") -> None:
|
||||
print(f"收到打断请求,停止当前 TTS: device={session.device_id} reason={reason}")
|
||||
session.tts_stream_id += 1
|
||||
if session.tts_idle_task is not None:
|
||||
session.tts_idle_task.cancel()
|
||||
session.tts_idle_task = None
|
||||
await self._send_agent_interrupt(session, reason)
|
||||
await self._stop_tts(session)
|
||||
|
||||
def _reset_tts_idle_timer(self, session: DeviceSession) -> None:
|
||||
if session.tts_idle_task is not None:
|
||||
session.tts_idle_task.cancel()
|
||||
session.tts_idle_task = asyncio.create_task(
|
||||
self._tts_idle_watchdog(session, session.tts_stream_id)
|
||||
)
|
||||
|
||||
async def _tts_idle_watchdog(self, session: DeviceSession, stream_id: int) -> None:
|
||||
try:
|
||||
await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
|
||||
if stream_id != session.tts_stream_id:
|
||||
return
|
||||
print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
|
||||
await self._stop_tts(session)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _has_audible_audio(self, pcm_data: bytes) -> bool:
|
||||
if len(pcm_data) < 2:
|
||||
return False
|
||||
|
||||
sample_count = len(pcm_data) // 2
|
||||
peak = 0
|
||||
for i in range(sample_count):
|
||||
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
|
||||
if sample < 0:
|
||||
sample = -sample
|
||||
if sample > peak:
|
||||
peak = sample
|
||||
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _maybe_forward_remote_audio(
|
||||
self,
|
||||
session: DeviceSession,
|
||||
track: rtc.Track,
|
||||
publication: Optional[rtc.TrackPublication],
|
||||
participant: rtc.RemoteParticipant,
|
||||
source: str,
|
||||
) -> None:
|
||||
if track.kind != rtc.TrackKind.KIND_AUDIO:
|
||||
return
|
||||
|
||||
track_sid = (
|
||||
getattr(track, "sid", None)
|
||||
or getattr(publication, "sid", None)
|
||||
or f"audio:{participant.identity}"
|
||||
)
|
||||
existing_task = session.forwarding_tracks.get(track_sid)
|
||||
if existing_task is not None and not existing_task.done():
|
||||
print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
|
||||
return
|
||||
if existing_task is not None and existing_task.done():
|
||||
print(f"检测到已结束的音频转发任务,重新创建: sid={track_sid}")
|
||||
session.forwarding_tracks.pop(track_sid, None)
|
||||
|
||||
task = asyncio.create_task(
|
||||
self.forward_audio_to_esp32(
|
||||
session,
|
||||
track_sid,
|
||||
rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1),
|
||||
)
|
||||
)
|
||||
session.forwarding_tracks[track_sid] = task
|
||||
print(
|
||||
f"收到音频流: {participant.identity} sid={track_sid} "
|
||||
f"source={source} room={session.room_name}"
|
||||
)
|
||||
|
||||
def _scan_participant_audio_tracks(
|
||||
self,
|
||||
session: DeviceSession,
|
||||
participant: rtc.RemoteParticipant,
|
||||
source: str,
|
||||
) -> None:
|
||||
publications = getattr(participant, "track_publications", None) or {}
|
||||
for publication in publications.values():
|
||||
track = getattr(publication, "track", None)
|
||||
if track is None:
|
||||
continue
|
||||
self._maybe_forward_remote_audio(session, track, publication, participant, source)
|
||||
|
||||
def _extract_opus_payload(self, session: DeviceSession, message: bytes) -> bytes:
|
||||
if session.protocol_version == 2:
|
||||
if len(message) < 16:
|
||||
raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
|
||||
version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
|
||||
"!HHIII", message[:16]
|
||||
)
|
||||
if version != 2:
|
||||
raise ValueError(f"version 2 包头版本异常: {version}")
|
||||
if msg_type != 0:
|
||||
raise ValueError(f"version 2 音频包类型异常: {msg_type}")
|
||||
if len(message) < 16 + payload_size:
|
||||
raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
|
||||
return message[16:16 + payload_size]
|
||||
|
||||
if session.protocol_version == 3:
|
||||
if len(message) < 4:
|
||||
raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
|
||||
msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
|
||||
if msg_type != 0:
|
||||
raise ValueError(f"version 3 音频包类型异常: {msg_type}")
|
||||
if len(message) < 4 + payload_size:
|
||||
raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
|
||||
return message[4:4 + payload_size]
|
||||
|
||||
return message
|
||||
|
||||
def _wrap_opus_payload(self, session: DeviceSession, payload: bytes) -> bytes:
|
||||
if session.protocol_version == 2:
|
||||
header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
|
||||
return header + payload
|
||||
|
||||
if session.protocol_version == 3:
|
||||
header = struct.pack("!BBH", 0, 0, len(payload))
|
||||
return header + payload
|
||||
|
||||
return payload
|
||||
|
||||
def _register_room_handlers(self, session: DeviceSession) -> None:
|
||||
@session.room.on("connection_state_changed")
|
||||
def on_connection_state_changed(state: int) -> None:
|
||||
print(f"[livekit] room={session.room_name} state={rtc.ConnectionState.Name(state)}")
|
||||
|
||||
@session.room.on("connected")
|
||||
def on_connected() -> None:
|
||||
print(f"✅ 成功连接到 LiveKit 房间: room={session.room_name}")
|
||||
self._log_agent_participants(session, "connected")
|
||||
for participant in session.room.remote_participants.values():
|
||||
if self._is_agent_participant(participant):
|
||||
session.agent_ready.set()
|
||||
self._scan_participant_audio_tracks(session, participant, "connected_scan")
|
||||
|
||||
@session.room.on("participant_connected")
|
||||
def on_participant_connected(participant: rtc.RemoteParticipant) -> None:
|
||||
role = "Agent" if self._is_agent_participant(participant) else "Remote participant"
|
||||
print(f"👋 {role} ({participant.identity}) 已加入房间: room={session.room_name}")
|
||||
self._log_agent_participants(session, "participant_connected")
|
||||
if self._is_agent_participant(participant):
|
||||
session.agent_ready.set()
|
||||
self._scan_participant_audio_tracks(
|
||||
session, participant, "participant_connected_scan"
|
||||
)
|
||||
|
||||
@session.room.on("participant_disconnected")
|
||||
def on_participant_disconnected(participant: rtc.RemoteParticipant) -> None:
|
||||
print(f"👋 远端参与者离开房间: room={session.room_name} identity={participant.identity}")
|
||||
session.forwarding_tracks = {
|
||||
track_sid: task
|
||||
for track_sid, task in session.forwarding_tracks.items()
|
||||
if not track_sid.endswith(f":{participant.identity}")
|
||||
}
|
||||
|
||||
@session.room.on("data_received")
|
||||
def on_data_received(data_packet: rtc.DataPacket) -> None:
|
||||
identity = data_packet.participant.identity if data_packet.participant else "未知"
|
||||
try:
|
||||
print(
|
||||
f"📩 [数据接收 | room={session.room_name} | {identity}]: "
|
||||
f"{data_packet.data.decode('utf-8')}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@session.room.on("transcription_received")
|
||||
def on_transcription_received(
|
||||
segments: list[rtc.TranscriptionSegment],
|
||||
participant: rtc.Participant,
|
||||
track_pub: rtc.TrackPublication,
|
||||
) -> None:
|
||||
identity = participant.identity if participant else "未知"
|
||||
is_agent = isinstance(participant, rtc.RemoteParticipant) and self._is_agent_participant(participant)
|
||||
for segment in segments:
|
||||
status = "✅ 最终结果" if segment.final else "⏳ 正在思考/中间结果"
|
||||
print(f"🗣️ [{status} | room={session.room_name} | {identity}]: {segment.text}")
|
||||
if session.websocket is not None:
|
||||
ws = session.websocket
|
||||
if is_agent:
|
||||
session.tts_transcript_text = segment.text
|
||||
asyncio.create_task(
|
||||
ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "tts",
|
||||
"state": "sentence_start",
|
||||
"text": session.tts_transcript_text,
|
||||
"final": segment.final,
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(
|
||||
ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stt",
|
||||
"text": segment.text,
|
||||
"final": segment.final,
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@session.room.on("track_subscribed")
|
||||
def on_track_subscribed(
|
||||
track: rtc.Track,
|
||||
publication: rtc.TrackPublication,
|
||||
participant: rtc.RemoteParticipant,
|
||||
) -> None:
|
||||
self._maybe_forward_remote_audio(session, track, publication, participant, "event")
|
||||
|
||||
@session.room.on("track_published")
|
||||
def on_track_published(
|
||||
publication: rtc.RemoteTrackPublication,
|
||||
participant: rtc.RemoteParticipant,
|
||||
) -> None:
|
||||
track_sid = getattr(publication, "sid", None)
|
||||
print(
|
||||
f"📡 远端音轨已发布: room={session.room_name} identity={participant.identity} "
|
||||
f"track_sid={track_sid}"
|
||||
)
|
||||
track = getattr(publication, "track", None)
|
||||
if track is not None:
|
||||
self._maybe_forward_remote_audio(session, track, publication, participant, "published")
|
||||
|
||||
async def _connect_session_room(self, session: DeviceSession) -> None:
|
||||
self._register_room_handlers(session)
|
||||
|
||||
print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
|
||||
print(f"[config] token_url={TOKEN_URL}")
|
||||
print(f"[config] room={session.room_name} identity={session.identity}")
|
||||
print(f"[config] livekit_connect_timeout={CONNECT_TIMEOUT_SECONDS}")
|
||||
token = await fetch_token(session.room_name, session.identity)
|
||||
|
||||
try:
|
||||
await session.room.connect(
|
||||
LIVEKIT_WS_URL,
|
||||
token,
|
||||
options=rtc.RoomOptions(connect_timeout=CONNECT_TIMEOUT_SECONDS),
|
||||
)
|
||||
except Exception as exc:
|
||||
self._log_exception(
|
||||
f"连接 LiveKit 房间失败: room={session.room_name}",
|
||||
exc,
|
||||
)
|
||||
print(
|
||||
"提示: token 已下发但 WebRTC 未建连时,通常与 ICE/TURN/防火墙/NAT 有关;"
|
||||
"请重点检查到 *.livekit.cloud:443、*.turn.livekit.cloud:443、"
|
||||
"*.host.livekit.cloud:3478、UDP 50000-60000、TCP 7881 的出站连通性。"
|
||||
)
|
||||
raise
|
||||
|
||||
print(f"已连接到 LiveKit 房间: {session.room.name}")
|
||||
print(f"[livekit] local_identity={session.room.local_participant.identity}")
|
||||
print(f"[livekit] local_sid={session.room.local_participant.sid}")
|
||||
print(f"[livekit] remote_participants={list(session.room.remote_participants.keys())}")
|
||||
self._log_agent_participants(session, "after_connect")
|
||||
|
||||
await self.ensure_agent_dispatched(session)
|
||||
|
||||
track = rtc.LocalAudioTrack.create_audio_track(
|
||||
f"esp32-mic-{session.device_id}",
|
||||
session.mic_source,
|
||||
)
|
||||
options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
|
||||
publication = await session.room.local_participant.publish_track(track, options)
|
||||
publication_sid = getattr(publication, "sid", None)
|
||||
track_sid = getattr(track, "sid", None)
|
||||
print(
|
||||
f"已发布 ESP32 mic track: room={session.room_name} "
|
||||
f"track_sid={track_sid} publication_sid={publication_sid}"
|
||||
)
|
||||
self._log_agent_participants(session, "after_publish_mic")
|
||||
|
||||
print(f"等待 agent 加入: room={session.room_name}")
|
||||
try:
|
||||
await asyncio.wait_for(session.agent_ready.wait(), timeout=AGENT_READY_TIMEOUT_SECONDS)
|
||||
print(f"✅ agent 已就绪: room={session.room_name}")
|
||||
except asyncio.TimeoutError:
|
||||
print(f"⚠️ agent 等待超时: room={session.room_name}")
|
||||
|
||||
async def start(self) -> None:
|
||||
print(f"[config] websocket_port={WS_PORT}")
|
||||
print(f"[config] livekit_ws_url={LIVEKIT_WS_URL}")
|
||||
print(f"[config] token_url={TOKEN_URL}")
|
||||
print(f"[config] agent_dispatch_mode={AGENT_DISPATCH_MODE}")
|
||||
|
||||
async def close(self) -> None:
|
||||
for session in list(self.device_sessions.values()):
|
||||
await self._close_session(session)
|
||||
|
||||
async def _close_session(self, session: DeviceSession) -> None:
|
||||
if session.closed:
|
||||
return
|
||||
session.closed = True
|
||||
session.websocket = None
|
||||
session.tts_active = False
|
||||
session.tts_stream_id += 1
|
||||
if session.tts_idle_task is not None:
|
||||
session.tts_idle_task.cancel()
|
||||
session.tts_idle_task = None
|
||||
try:
|
||||
await session.room.disconnect()
|
||||
except Exception as exc:
|
||||
print(f"断开 LiveKit 房间失败: room={session.room_name} error={exc}")
|
||||
|
||||
async def forward_audio_to_esp32(
|
||||
self,
|
||||
session: DeviceSession,
|
||||
track_sid: str,
|
||||
audio_stream: rtc.AudioStream,
|
||||
) -> None:
|
||||
encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, "voip")
|
||||
pending_pcm = bytearray()
|
||||
pre_roll_pcm = bytearray()
|
||||
pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
|
||||
audible_frame_streak = 0
|
||||
silence_frame_streak = 0
|
||||
waiting_for_post_interrupt_silence = False
|
||||
stream_id = session.tts_stream_id
|
||||
print(
|
||||
f"启动 TTS 转发: device={session.device_id} room={session.room_name} "
|
||||
f"track_sid={track_sid} stream_id={stream_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in audio_stream:
|
||||
if stream_id != session.tts_stream_id:
|
||||
print("检测到 TTS 被中断,停止当前播放并等待静音分隔")
|
||||
pending_pcm.clear()
|
||||
pre_roll_pcm.clear()
|
||||
audible_frame_streak = 0
|
||||
silence_frame_streak = 0
|
||||
waiting_for_post_interrupt_silence = True
|
||||
stream_id = session.tts_stream_id
|
||||
if session.tts_active:
|
||||
await self._stop_tts(session)
|
||||
continue
|
||||
|
||||
frame = event.frame
|
||||
pcm_data = frame.data.tobytes()
|
||||
has_audible_audio = self._has_audible_audio(pcm_data)
|
||||
|
||||
if waiting_for_post_interrupt_silence:
|
||||
if has_audible_audio:
|
||||
silence_frame_streak = 0
|
||||
continue
|
||||
|
||||
silence_frame_streak += 1
|
||||
if silence_frame_streak < TTS_INTERRUPT_SILENCE_FRAMES:
|
||||
continue
|
||||
|
||||
print("检测到 interrupt 后静音分隔,允许下一轮 TTS 重新起播")
|
||||
waiting_for_post_interrupt_silence = False
|
||||
silence_frame_streak = 0
|
||||
stream_id = session.tts_stream_id
|
||||
continue
|
||||
|
||||
current_frame_buffered = False
|
||||
if not session.tts_active:
|
||||
pre_roll_pcm.extend(pcm_data)
|
||||
current_frame_buffered = True
|
||||
if len(pre_roll_pcm) > pre_roll_max_bytes:
|
||||
del pre_roll_pcm[: len(pre_roll_pcm) - pre_roll_max_bytes]
|
||||
|
||||
if not has_audible_audio:
|
||||
audible_frame_streak = 0
|
||||
continue
|
||||
|
||||
audible_frame_streak += 1
|
||||
if audible_frame_streak < TTS_START_CONSECUTIVE_AUDIBLE_FRAMES:
|
||||
continue
|
||||
|
||||
await self._start_tts(session)
|
||||
if not session.tts_active:
|
||||
continue
|
||||
|
||||
print(
|
||||
"检测到连续可听 TTS 音频,切换到 Speaking 并开始转发 "
|
||||
f"(frames={audible_frame_streak})"
|
||||
)
|
||||
pending_pcm.extend(pre_roll_pcm)
|
||||
pre_roll_pcm.clear()
|
||||
audible_frame_streak = 0
|
||||
|
||||
if has_audible_audio:
|
||||
self._reset_tts_idle_timer(session)
|
||||
|
||||
if not current_frame_buffered:
|
||||
pending_pcm.extend(pcm_data)
|
||||
frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2
|
||||
|
||||
while (
|
||||
len(pending_pcm) >= frame_bytes
|
||||
and stream_id == session.tts_stream_id
|
||||
and session.websocket is not None
|
||||
):
|
||||
try:
|
||||
opus_packet = encoder.encode(
|
||||
bytes(pending_pcm[:frame_bytes]),
|
||||
OUTPUT_SAMPLES_PER_OPUS_FRAME,
|
||||
)
|
||||
del pending_pcm[:frame_bytes]
|
||||
await session.websocket.send(self._wrap_opus_payload(session, opus_packet))
|
||||
except Exception as exc:
|
||||
print(f"发送回 ESP32 失败: {exc}")
|
||||
break
|
||||
|
||||
except Exception as exc:
|
||||
print(f"音频流处理错误: {exc}")
|
||||
finally:
|
||||
print("🎧 TTS 音频结束")
|
||||
task = session.forwarding_tracks.get(track_sid)
|
||||
current_task = asyncio.current_task()
|
||||
if task is current_task:
|
||||
session.forwarding_tracks.pop(track_sid, None)
|
||||
if stream_id == session.tts_stream_id and session.tts_idle_task is not None:
|
||||
session.tts_idle_task.cancel()
|
||||
session.tts_idle_task = None
|
||||
if stream_id == session.tts_stream_id:
|
||||
await self._stop_tts(session)
|
||||
|
||||
async def handle_websocket(self, websocket: Any) -> None:
|
||||
header_version = websocket.request.headers.get("Protocol-Version")
|
||||
try:
|
||||
protocol_version = int(header_version) if header_version else 1
|
||||
except ValueError:
|
||||
protocol_version = 1
|
||||
|
||||
device_id = self._build_device_id(websocket)
|
||||
existing_session = self.device_sessions.get(device_id)
|
||||
if existing_session is not None:
|
||||
print(f"检测到重复 device_id,关闭旧 session: device={device_id}")
|
||||
await self._close_session(existing_session)
|
||||
self.device_sessions.pop(device_id, None)
|
||||
|
||||
room_name, identity = self._build_session_names(device_id)
|
||||
session = DeviceSession(
|
||||
device_id=device_id,
|
||||
websocket=websocket,
|
||||
protocol_version=protocol_version,
|
||||
room_name=room_name,
|
||||
identity=identity,
|
||||
room=rtc.Room(),
|
||||
mic_source=AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1),
|
||||
agent_ready=asyncio.Event(),
|
||||
)
|
||||
self.device_sessions[device_id] = session
|
||||
|
||||
print(f"ESP32 已连接: device={device_id}")
|
||||
print(f"ESP32 协议版本: {session.protocol_version}")
|
||||
session.tts_stream_id += 1
|
||||
opus_decoder = None
|
||||
|
||||
try:
|
||||
await self._connect_session_room(session)
|
||||
|
||||
hello_msg = {
|
||||
"type": "hello",
|
||||
"transport": "websocket",
|
||||
"session": {
|
||||
"room": session.room_name,
|
||||
"identity": session.identity,
|
||||
},
|
||||
"audio_params": {
|
||||
"format": "opus",
|
||||
"sample_rate": OUTPUT_SAMPLE_RATE,
|
||||
"channels": 1,
|
||||
"frame_duration": OUTPUT_FRAME_DURATION_MS,
|
||||
},
|
||||
}
|
||||
await websocket.send(json.dumps(hello_msg))
|
||||
|
||||
async for message in websocket:
|
||||
if isinstance(message, bytes):
|
||||
if len(message) < 4:
|
||||
print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
|
||||
continue
|
||||
|
||||
audio_data = self._extract_opus_payload(session, message)
|
||||
if not audio_data:
|
||||
continue
|
||||
|
||||
try:
|
||||
if opus_decoder is None:
|
||||
print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
|
||||
opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
|
||||
|
||||
pcm_bytes = opus_decoder.decode(audio_data, INPUT_SAMPLES_PER_OPUS_FRAME)
|
||||
|
||||
num_samples = len(pcm_bytes) // 2
|
||||
if num_samples > 0:
|
||||
session.captured_frame_count += 1
|
||||
now = time.monotonic()
|
||||
if (
|
||||
session.captured_frame_count <= 5
|
||||
or now - session.first_capture_log_time >= 5.0
|
||||
):
|
||||
session.first_capture_log_time = now
|
||||
print(
|
||||
f"[uplink] capture_frame count={session.captured_frame_count} "
|
||||
f"bytes={len(pcm_bytes)} samples={num_samples} "
|
||||
f"room={session.room_name}"
|
||||
)
|
||||
try:
|
||||
frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
|
||||
await session.mic_source.capture_frame(frame)
|
||||
except TypeError:
|
||||
frame = AudioFrame.create(
|
||||
sample_rate=INPUT_SAMPLE_RATE,
|
||||
num_channels=1,
|
||||
samples_per_channel=num_samples,
|
||||
)
|
||||
memoryview(frame.data).cast("B")[:] = pcm_bytes
|
||||
await session.mic_source.capture_frame(frame)
|
||||
except Exception as exc:
|
||||
print(f"Opus audio decode error ({len(message)} bytes): {exc}")
|
||||
elif isinstance(message, str):
|
||||
try:
|
||||
data = json.loads(message)
|
||||
print(f"收到 ESP32 JSON 消息: {data}")
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "abort":
|
||||
reason = data.get("reason")
|
||||
abort_reason = reason if isinstance(reason, str) and reason else "button_abort"
|
||||
print(f"处理 ESP32 打断请求: reason={abort_reason}")
|
||||
await self._abort_tts(session, abort_reason)
|
||||
except json.JSONDecodeError:
|
||||
print(f"收到未知的字符消息: {message}")
|
||||
except ConnectionClosedError as exc:
|
||||
print(f"ESP32 异常断开: {exc}")
|
||||
except Exception as exc:
|
||||
self._log_exception("WebSocket 其他错误", exc)
|
||||
finally:
|
||||
print(f"ESP32 断开连接: device={device_id} room={session.room_name}")
|
||||
await self._close_session(session)
|
||||
self.device_sessions.pop(device_id, None)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
bridge = ESP32LiveKitBridge()
|
||||
try:
|
||||
await bridge.start()
|
||||
async with websockets.serve(bridge.handle_websocket, "0.0.0.0", WS_PORT):
|
||||
print(f"WebSocket 服务器运行在端口 {WS_PORT},等待 ESP32 连接...")
|
||||
await asyncio.Future()
|
||||
finally:
|
||||
await bridge.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except Exception as exc:
|
||||
print(f"[error] {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user