feat: support defferent models
This commit is contained in:
@ -5,40 +5,52 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import httpx
|
import httpx
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import queue
|
import queue
|
||||||
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
import opuslib
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from livekit import rtc
|
from livekit import rtc
|
||||||
from livekit.rtc import AudioSource, AudioFrame
|
from livekit.rtc import AudioSource, AudioFrame
|
||||||
from websockets.exceptions import ConnectionClosedError
|
from websockets.exceptions import ConnectionClosedError
|
||||||
import http.server
|
|
||||||
import multipart
|
|
||||||
from urllib.parse import parse_qs
|
|
||||||
|
|
||||||
# 配置信息
|
try:
|
||||||
# TOKEN_URL = "http://10.6.80.130:8000/v1/token"
|
from livekit import api as livekit_api
|
||||||
# LIVEKIT_WS_URL = "ws://10.6.80.130:8000/"
|
except ImportError:
|
||||||
# ROOM = "vera-room"
|
livekit_api = None
|
||||||
# IDENTITY = "vera-1"
|
|
||||||
# TOKEN_URL = "https://omnichat.bwgdi.com/v1/token"
|
|
||||||
TOKEN_URL = "http://10.6.80.130:8000/getToken"
|
TOKEN_URL = "http://10.6.80.130:8000/getToken"
|
||||||
LIVEKIT_WS_URL = "wss://test-b2zm4kva.livekit.cloud"
|
LIVEKIT_WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
|
||||||
# LIVEKIT_WS_URL = "wss://rtc.bwgdi.com/"
|
|
||||||
ROOM = "test-livekit-room2"
|
ROOM = "test-livekit-room2"
|
||||||
IDENTITY = "uv-livekit-hardcoded"
|
IDENTITY = "uv-livekit-hardcoded"
|
||||||
|
AGENT_NAME = "my-agent"
|
||||||
import uuid
|
import uuid
|
||||||
# IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
|
# IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
|
||||||
CONNECT_TIMEOUT_SECONDS = 10.0
|
CONNECT_TIMEOUT_SECONDS = 10.0
|
||||||
WS_PORT = 8080
|
WS_PORT = 8080
|
||||||
|
AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT = True
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
INPUT_SAMPLE_RATE = 16000
|
||||||
|
OUTPUT_SAMPLE_RATE = 24000
|
||||||
|
INPUT_FRAME_DURATION_MS = 60
|
||||||
|
INPUT_SAMPLES_PER_OPUS_FRAME = INPUT_SAMPLE_RATE * INPUT_FRAME_DURATION_MS // 1000
|
||||||
|
OUTPUT_FRAME_DURATION_MS = 20
|
||||||
|
OUTPUT_SAMPLES_PER_OPUS_FRAME = OUTPUT_SAMPLE_RATE * OUTPUT_FRAME_DURATION_MS // 1000
|
||||||
|
TTS_IDLE_TIMEOUT_SECONDS = 0.8
|
||||||
|
TTS_SILENCE_PEAK_THRESHOLD = 96
|
||||||
|
TTS_PRE_ROLL_MS = 200
|
||||||
|
SAVE_AGENT_TTS_WAV = True
|
||||||
|
AGENT_TTS_WAV_PREFIX = "agent_tts"
|
||||||
|
|
||||||
async def fetch_token() -> str:
|
async def fetch_token() -> str:
|
||||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
TOKEN_URL,
|
TOKEN_URL,
|
||||||
params={"room": ROOM, "identity": IDENTITY, "agent_name": "my-agent"},
|
params={"room": ROOM, "identity": IDENTITY, "agent_name": AGENT_NAME},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@ -57,12 +69,248 @@ class ESP32LiveKitBridge:
|
|||||||
self.room = rtc.Room()
|
self.room = rtc.Room()
|
||||||
# 创建一个音频源,用于将 ESP32 的声音推送到 LiveKit
|
# 创建一个音频源,用于将 ESP32 的声音推送到 LiveKit
|
||||||
# 注意:采样率需与 ESP32 发送的一致,通常是 16000 或 24000
|
# 注意:采样率需与 ESP32 发送的一致,通常是 16000 或 24000
|
||||||
self.mic_source = AudioSource(sample_rate=SAMPLE_RATE, num_channels=1)
|
self.mic_source = AudioSource(sample_rate=INPUT_SAMPLE_RATE, num_channels=1)
|
||||||
self.esp_ws = None # 保存 WebSocket 连接
|
self.esp_ws = None # 保存 WebSocket 连接
|
||||||
self.audio_queue = queue.Queue()
|
self.audio_queue = queue.Queue()
|
||||||
self.wav_writer_thread: Optional[threading.Thread] = None
|
self.wav_writer_thread: Optional[threading.Thread] = None
|
||||||
self.stop_event = threading.Event()
|
self.stop_event = threading.Event()
|
||||||
self.agent_ready = asyncio.Event() # Moved here to be accessible earlier if needed
|
self.agent_ready = asyncio.Event() # Moved here to be accessible earlier if needed
|
||||||
|
self.protocol_version = 1
|
||||||
|
self.tts_active = False
|
||||||
|
self.forwarding_tracks: set[str] = set()
|
||||||
|
self.tts_idle_task: Optional[asyncio.Task] = None
|
||||||
|
self.tts_stream_id = 0
|
||||||
|
self.agent_dispatch_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def ensure_agent_dispatched(self) -> None:
|
||||||
|
if not AUTO_DISPATCH_AGENT_ON_ESP32_CONNECT:
|
||||||
|
return
|
||||||
|
|
||||||
|
for participant in self.room.remote_participants.values():
|
||||||
|
if AGENT_NAME in participant.identity:
|
||||||
|
print(f"Agent 已在房间中,跳过 dispatch: {participant.identity}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.agent_dispatch_task is not None and not self.agent_dispatch_task.done():
|
||||||
|
print("Agent dispatch 正在进行中,跳过重复请求")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.agent_dispatch_task = asyncio.create_task(self._dispatch_agent())
|
||||||
|
await self.agent_dispatch_task
|
||||||
|
|
||||||
|
async def _dispatch_agent(self) -> None:
|
||||||
|
print(f"准备 dispatch agent: room={ROOM}, agent={AGENT_NAME}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if await self._dispatch_agent_with_sdk():
|
||||||
|
return
|
||||||
|
|
||||||
|
if await self._dispatch_agent_with_cli():
|
||||||
|
return
|
||||||
|
|
||||||
|
print("Agent dispatch 未执行:未找到 livekit-api 环境变量,也无法使用 lk CLI")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Agent dispatch 失败: {e}")
|
||||||
|
|
||||||
|
async def _dispatch_agent_with_sdk(self) -> bool:
|
||||||
|
if livekit_api is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
api_key = os.getenv("LIVEKIT_API_KEY")
|
||||||
|
api_secret = os.getenv("LIVEKIT_API_SECRET")
|
||||||
|
if not api_key or not api_secret:
|
||||||
|
return False
|
||||||
|
|
||||||
|
lkapi = livekit_api.LiveKitAPI(
|
||||||
|
url=LIVEKIT_WS_URL,
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
dispatch = await lkapi.agent_dispatch.create_dispatch(
|
||||||
|
livekit_api.CreateAgentDispatchRequest(
|
||||||
|
agent_name=AGENT_NAME,
|
||||||
|
room=ROOM,
|
||||||
|
metadata=json.dumps({"source": "bridge_server", "identity": IDENTITY}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"Agent dispatch 已创建: {dispatch}")
|
||||||
|
finally:
|
||||||
|
await lkapi.aclose()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _dispatch_agent_with_cli(self) -> bool:
|
||||||
|
lk_path = shutil.which("lk")
|
||||||
|
if lk_path is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
lk_path,
|
||||||
|
"dispatch",
|
||||||
|
"create",
|
||||||
|
"--room",
|
||||||
|
ROOM,
|
||||||
|
"--agent-name",
|
||||||
|
AGENT_NAME,
|
||||||
|
"--metadata",
|
||||||
|
json.dumps({"source": "bridge_server", "identity": IDENTITY}),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await process.communicate()
|
||||||
|
stdout_text = stdout.decode("utf-8", errors="replace").strip()
|
||||||
|
stderr_text = stderr.decode("utf-8", errors="replace").strip()
|
||||||
|
|
||||||
|
if stdout_text:
|
||||||
|
print(f"lk dispatch 输出: {stdout_text}")
|
||||||
|
if stderr_text:
|
||||||
|
print(f"lk dispatch 错误输出: {stderr_text}")
|
||||||
|
|
||||||
|
if process.returncode != 0:
|
||||||
|
print(f"lk dispatch create 失败,退出码: {process.returncode}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"Agent dispatch 已通过 lk CLI 创建: room={ROOM}, agent={AGENT_NAME}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_tts_state(self, state: str) -> None:
|
||||||
|
if self.esp_ws is None:
|
||||||
|
print(f"跳过 tts {state},ESP32 尚未连接")
|
||||||
|
return
|
||||||
|
await self.esp_ws.send(json.dumps({"type": "tts", "state": state}))
|
||||||
|
print(f"已发送 tts {state}")
|
||||||
|
|
||||||
|
async def _start_tts(self) -> None:
|
||||||
|
if self.tts_active:
|
||||||
|
print("跳过 tts start,当前已处于激活状态")
|
||||||
|
return
|
||||||
|
await self._send_tts_state("start")
|
||||||
|
if self.esp_ws is not None:
|
||||||
|
self.tts_active = True
|
||||||
|
|
||||||
|
async def _stop_tts(self) -> None:
|
||||||
|
if not self.tts_active:
|
||||||
|
print("跳过 tts stop,当前未激活")
|
||||||
|
return
|
||||||
|
await self._send_tts_state("stop")
|
||||||
|
self.tts_active = False
|
||||||
|
|
||||||
|
async def _abort_tts(self, reason: str = "client_abort") -> None:
|
||||||
|
print(f"收到打断请求,停止当前 TTS: reason={reason}")
|
||||||
|
self.tts_stream_id += 1
|
||||||
|
if self.tts_idle_task is not None:
|
||||||
|
self.tts_idle_task.cancel()
|
||||||
|
self.tts_idle_task = None
|
||||||
|
await self._stop_tts()
|
||||||
|
|
||||||
|
def _reset_tts_idle_timer(self) -> None:
|
||||||
|
if self.tts_idle_task is not None:
|
||||||
|
self.tts_idle_task.cancel()
|
||||||
|
self.tts_idle_task = asyncio.create_task(self._tts_idle_watchdog(self.tts_stream_id))
|
||||||
|
|
||||||
|
async def _tts_idle_watchdog(self, stream_id: int) -> None:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(TTS_IDLE_TIMEOUT_SECONDS)
|
||||||
|
if stream_id != self.tts_stream_id:
|
||||||
|
return
|
||||||
|
print(f"TTS 空闲超过 {TTS_IDLE_TIMEOUT_SECONDS}s,切回聆听状态")
|
||||||
|
await self._stop_tts()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _has_audible_audio(self, pcm_data: bytes) -> bool:
|
||||||
|
if len(pcm_data) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
sample_count = len(pcm_data) // 2
|
||||||
|
peak = 0
|
||||||
|
for i in range(sample_count):
|
||||||
|
sample = int.from_bytes(pcm_data[i * 2:i * 2 + 2], "little", signed=True)
|
||||||
|
if sample < 0:
|
||||||
|
sample = -sample
|
||||||
|
if sample > peak:
|
||||||
|
peak = sample
|
||||||
|
if peak >= TTS_SILENCE_PEAK_THRESHOLD:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _maybe_forward_remote_audio(
|
||||||
|
self,
|
||||||
|
track: rtc.Track,
|
||||||
|
publication: Optional[rtc.TrackPublication],
|
||||||
|
participant: rtc.RemoteParticipant,
|
||||||
|
source: str,
|
||||||
|
) -> None:
|
||||||
|
if track.kind != rtc.TrackKind.KIND_AUDIO:
|
||||||
|
return
|
||||||
|
|
||||||
|
track_sid = getattr(track, "sid", None) or getattr(publication, "sid", None) or f"audio:{participant.identity}"
|
||||||
|
if track_sid in self.forwarding_tracks:
|
||||||
|
print(f"跳过重复音频轨: {participant.identity} sid={track_sid} source={source}")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.forwarding_tracks.add(track_sid)
|
||||||
|
print(f"收到音频流: {participant.identity} sid={track_sid} source={source}")
|
||||||
|
asyncio.create_task(
|
||||||
|
self.forward_audio_to_esp32(
|
||||||
|
rtc.AudioStream(track, sample_rate=OUTPUT_SAMPLE_RATE, num_channels=1)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _scan_existing_remote_audio_tracks(self) -> None:
|
||||||
|
participants = list(self.room.remote_participants.values())
|
||||||
|
print(f"开始扫描远端音频轨,participants={len(participants)}")
|
||||||
|
for participant in participants:
|
||||||
|
publications = getattr(participant, "track_publications", {}) or {}
|
||||||
|
print(f"扫描远端参与者: {participant.identity}, publications={len(publications)}")
|
||||||
|
for pub_sid, publication in publications.items():
|
||||||
|
track = getattr(publication, "track", None)
|
||||||
|
kind = getattr(publication, "kind", None)
|
||||||
|
print(
|
||||||
|
f"检查 publication: participant={participant.identity} pub_sid={pub_sid} "
|
||||||
|
f"kind={kind} has_track={track is not None}"
|
||||||
|
)
|
||||||
|
if track is not None:
|
||||||
|
self._maybe_forward_remote_audio(track, publication, participant, "scan")
|
||||||
|
|
||||||
|
def _extract_opus_payload(self, message: bytes) -> bytes:
|
||||||
|
if self.protocol_version == 2:
|
||||||
|
if len(message) < 16:
|
||||||
|
raise ValueError(f"version 2 音频包过短: {len(message)} bytes")
|
||||||
|
version, msg_type, _reserved, _timestamp, payload_size = struct.unpack(
|
||||||
|
"!HHIII", message[:16]
|
||||||
|
)
|
||||||
|
if version != 2:
|
||||||
|
raise ValueError(f"version 2 包头版本异常: {version}")
|
||||||
|
if msg_type != 0:
|
||||||
|
raise ValueError(f"version 2 音频包类型异常: {msg_type}")
|
||||||
|
if len(message) < 16 + payload_size:
|
||||||
|
raise ValueError(f"version 2 音频包长度不足: {len(message)} < {16 + payload_size}")
|
||||||
|
return message[16:16 + payload_size]
|
||||||
|
|
||||||
|
if self.protocol_version == 3:
|
||||||
|
if len(message) < 4:
|
||||||
|
raise ValueError(f"version 3 音频包过短: {len(message)} bytes")
|
||||||
|
msg_type, _reserved, payload_size = struct.unpack("!BBH", message[:4])
|
||||||
|
if msg_type != 0:
|
||||||
|
raise ValueError(f"version 3 音频包类型异常: {msg_type}")
|
||||||
|
if len(message) < 4 + payload_size:
|
||||||
|
raise ValueError(f"version 3 音频包长度不足: {len(message)} < {4 + payload_size}")
|
||||||
|
return message[4:4 + payload_size]
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _wrap_opus_payload(self, payload: bytes) -> bytes:
|
||||||
|
if self.protocol_version == 2:
|
||||||
|
header = struct.pack("!HHIII", 2, 0, 0, 0, len(payload))
|
||||||
|
return header + payload
|
||||||
|
|
||||||
|
if self.protocol_version == 3:
|
||||||
|
header = struct.pack("!BBH", 0, 0, len(payload))
|
||||||
|
return header + payload
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
def _wav_writer_loop(self):
|
def _wav_writer_loop(self):
|
||||||
import wave
|
import wave
|
||||||
@ -71,7 +319,7 @@ class ESP32LiveKitBridge:
|
|||||||
with wave.open("bridge_debug.wav", "wb") as wav_file:
|
with wave.open("bridge_debug.wav", "wb") as wav_file:
|
||||||
wav_file.setnchannels(1)
|
wav_file.setnchannels(1)
|
||||||
wav_file.setsampwidth(2) # 16-bit
|
wav_file.setsampwidth(2) # 16-bit
|
||||||
wav_file.setframerate(SAMPLE_RATE)
|
wav_file.setframerate(INPUT_SAMPLE_RATE)
|
||||||
while not self.stop_event.is_set() or not self.audio_queue.empty():
|
while not self.stop_event.is_set() or not self.audio_queue.empty():
|
||||||
try:
|
try:
|
||||||
# 使用 timeout 避免永久阻塞,以便检查 stop_event
|
# 使用 timeout 避免永久阻塞,以便检查 stop_event
|
||||||
@ -93,13 +341,23 @@ class ESP32LiveKitBridge:
|
|||||||
def on_connected():
|
def on_connected():
|
||||||
print("✅ 成功连接到 LiveKit 房间")
|
print("✅ 成功连接到 LiveKit 房间")
|
||||||
if self.room.remote_participants:
|
if self.room.remote_participants:
|
||||||
|
self._scan_existing_remote_audio_tracks()
|
||||||
self.agent_ready.set()
|
self.agent_ready.set()
|
||||||
|
|
||||||
@self.room.on("participant_connected")
|
@self.room.on("participant_connected")
|
||||||
def on_participant_connected(p: rtc.RemoteParticipant):
|
def on_participant_connected(p: rtc.RemoteParticipant):
|
||||||
print(f"👋 Agent ({p.identity}) 已加入房间")
|
print(f"👋 Agent ({p.identity}) 已加入房间")
|
||||||
|
self._scan_existing_remote_audio_tracks()
|
||||||
self.agent_ready.set()
|
self.agent_ready.set()
|
||||||
|
|
||||||
|
@self.room.on("participant_disconnected")
|
||||||
|
def on_participant_disconnected(p: rtc.RemoteParticipant):
|
||||||
|
print(f"👋 远端参与者离开房间: {p.identity}")
|
||||||
|
self.forwarding_tracks = {
|
||||||
|
track_sid for track_sid in self.forwarding_tracks
|
||||||
|
if not track_sid.endswith(f":{p.identity}")
|
||||||
|
}
|
||||||
|
|
||||||
@self.room.on("data_received")
|
@self.room.on("data_received")
|
||||||
def on_data_received(data_packet: rtc.DataPacket):
|
def on_data_received(data_packet: rtc.DataPacket):
|
||||||
identity = data_packet.participant.identity if data_packet.participant else "未知"
|
identity = data_packet.participant.identity if data_packet.participant else "未知"
|
||||||
@ -145,6 +403,7 @@ class ESP32LiveKitBridge:
|
|||||||
print(f"已连接到 LiveKit 房间: {self.room.name}")
|
print(f"已连接到 LiveKit 房间: {self.room.name}")
|
||||||
print(f"[livekit] local_identity={self.room.local_participant.identity}")
|
print(f"[livekit] local_identity={self.room.local_participant.identity}")
|
||||||
print(f"[livekit] local_sid={self.room.local_participant.sid}")
|
print(f"[livekit] local_sid={self.room.local_participant.sid}")
|
||||||
|
print(f"[livekit] remote_participants={list(self.room.remote_participants.keys())}")
|
||||||
|
|
||||||
# 2. 发布麦克风轨道 (ESP32 -> LiveKit)
|
# 2. 发布麦克风轨道 (ESP32 -> LiveKit)
|
||||||
track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
|
track = rtc.LocalAudioTrack.create_audio_track("esp32-mic", self.mic_source)
|
||||||
@ -154,11 +413,10 @@ class ESP32LiveKitBridge:
|
|||||||
# 3. 监听房间内的音频 (LiveKit -> ESP32)
|
# 3. 监听房间内的音频 (LiveKit -> ESP32)
|
||||||
@self.room.on("track_subscribed")
|
@self.room.on("track_subscribed")
|
||||||
def on_track_subscribed(track, publication, participant):
|
def on_track_subscribed(track, publication, participant):
|
||||||
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
self._maybe_forward_remote_audio(track, publication, participant, "event")
|
||||||
print(f"收到音频流: {participant.identity}")
|
|
||||||
asyncio.create_task(self.forward_audio_to_esp32(rtc.AudioStream(track, sample_rate=SAMPLE_RATE, num_channels=1)))
|
|
||||||
|
|
||||||
print("等待 agent 加入...")
|
print("等待 agent 加入...")
|
||||||
|
self._scan_existing_remote_audio_tracks()
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
|
await asyncio.wait_for(self.agent_ready.wait(), timeout=10)
|
||||||
print("✅ agent 已就绪")
|
print("✅ agent 已就绪")
|
||||||
@ -172,53 +430,121 @@ class ESP32LiveKitBridge:
|
|||||||
await self.room.disconnect()
|
await self.room.disconnect()
|
||||||
|
|
||||||
async def forward_audio_to_esp32(self, audio_stream):
|
async def forward_audio_to_esp32(self, audio_stream):
|
||||||
"""从 LiveKit 接收音频,通过 WebSocket 发回给 ESP32"""
|
"""从 LiveKit 接收音频并发给 ESP32"""
|
||||||
import opuslib
|
|
||||||
import json
|
encoder = opuslib.Encoder(OUTPUT_SAMPLE_RATE, 1, 'voip')
|
||||||
# 创建下行 Opus 编码器
|
pending_pcm = bytearray()
|
||||||
encoder = opuslib.Encoder(SAMPLE_RATE, 1, 'voip')
|
pre_roll_pcm = bytearray()
|
||||||
|
pre_roll_max_bytes = OUTPUT_SAMPLE_RATE * TTS_PRE_ROLL_MS // 1000 * 2
|
||||||
# 1. 告知 ESP32 开始说话,切换 UI 到“说话中”并准备解码
|
tts_wav_file = None
|
||||||
if self.esp_ws:
|
tts_wav_path = None
|
||||||
await self.esp_ws.send(json.dumps({"type": "tts", "state": "start"}))
|
stream_id = self.tts_stream_id + 1
|
||||||
|
self.tts_stream_id = stream_id
|
||||||
|
|
||||||
|
if SAVE_AGENT_TTS_WAV:
|
||||||
|
tts_wav_path = f"{AGENT_TTS_WAV_PREFIX}_{int(time.time())}.wav"
|
||||||
|
tts_wav_file = wave.open(tts_wav_path, "wb")
|
||||||
|
tts_wav_file.setnchannels(1)
|
||||||
|
tts_wav_file.setsampwidth(2)
|
||||||
|
tts_wav_file.setframerate(OUTPUT_SAMPLE_RATE)
|
||||||
|
print(f"开始保存 Agent TTS 音频: {tts_wav_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in audio_stream:
|
async for event in audio_stream:
|
||||||
if self.esp_ws:
|
if stream_id != self.tts_stream_id:
|
||||||
|
print("检测到更新的 TTS 流,停止旧流转发")
|
||||||
|
break
|
||||||
|
|
||||||
|
frame = event.frame
|
||||||
|
pcm_data = frame.data.tobytes()
|
||||||
|
if tts_wav_file is not None:
|
||||||
|
tts_wav_file.writeframes(pcm_data)
|
||||||
|
|
||||||
|
has_audible_audio = self._has_audible_audio(pcm_data)
|
||||||
|
current_frame_buffered = False
|
||||||
|
if not self.tts_active:
|
||||||
|
pre_roll_pcm.extend(pcm_data)
|
||||||
|
current_frame_buffered = True
|
||||||
|
if len(pre_roll_pcm) > pre_roll_max_bytes:
|
||||||
|
del pre_roll_pcm[:len(pre_roll_pcm) - pre_roll_max_bytes]
|
||||||
|
|
||||||
|
if not has_audible_audio:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._start_tts()
|
||||||
|
if not self.tts_active:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("检测到可听 TTS 音频,切换到 Speaking 并开始转发")
|
||||||
|
pending_pcm.extend(pre_roll_pcm)
|
||||||
|
pre_roll_pcm.clear()
|
||||||
|
|
||||||
|
if has_audible_audio:
|
||||||
|
self._reset_tts_idle_timer()
|
||||||
|
|
||||||
|
# LiveKit 下行帧长度不一定等于 ESP32 协议里声明的帧长。
|
||||||
|
# 这里聚合成固定 20ms PCM,再编码为 Opus 发给设备,减少卡顿。
|
||||||
|
if not current_frame_buffered:
|
||||||
|
pending_pcm.extend(pcm_data)
|
||||||
|
frame_bytes = OUTPUT_SAMPLES_PER_OPUS_FRAME * 2 # mono 16-bit PCM
|
||||||
|
|
||||||
|
while len(pending_pcm) >= frame_bytes and self.esp_ws and stream_id == self.tts_stream_id:
|
||||||
try:
|
try:
|
||||||
# AudioStream 迭代产生的是 AudioFrameEvent,需要从中提取 frame
|
opus_packet = encoder.encode(
|
||||||
frame = event.frame
|
bytes(pending_pcm[:frame_bytes]),
|
||||||
# 将 PCM 编码为 Opus 才能发给 ESP32
|
OUTPUT_SAMPLES_PER_OPUS_FRAME,
|
||||||
pcm_data = frame.data.tobytes()
|
)
|
||||||
|
del pending_pcm[:frame_bytes]
|
||||||
# 使用当前帧的实际采样数进行编码
|
# print(f"发送 TTS 音频包: opus={len(opus_packet)} bytes, protocol={self.protocol_version}")
|
||||||
opus_packet = encoder.encode(pcm_data, frame.samples_per_channel)
|
await self.esp_ws.send(self._wrap_opus_payload(opus_packet))
|
||||||
await self.esp_ws.send(opus_packet)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"发送回 ESP32 失败: {e}")
|
print(f"发送回 ESP32 失败: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"音频流处理错误: {e}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 2. 音频流结束,告知 ESP32 停止说话,切换回聆听或闲置状态
|
print("🎧 TTS 音频结束")
|
||||||
if self.esp_ws:
|
if tts_wav_file is not None:
|
||||||
await self.esp_ws.send(json.dumps({"type": "tts", "state": "stop"}))
|
tts_wav_file.close()
|
||||||
|
print(f"Agent TTS 音频已保存: {tts_wav_path}")
|
||||||
|
if stream_id == self.tts_stream_id and self.tts_idle_task is not None:
|
||||||
|
self.tts_idle_task.cancel()
|
||||||
|
self.tts_idle_task = None
|
||||||
|
if stream_id == self.tts_stream_id:
|
||||||
|
await self._stop_tts()
|
||||||
|
|
||||||
async def handle_websocket(self, websocket):
|
async def handle_websocket(self, websocket):
|
||||||
"""处理来自 ESP32 的 WebSocket 连接"""
|
"""处理来自 ESP32 的 WebSocket 连接"""
|
||||||
self.esp_ws = websocket
|
self.esp_ws = websocket
|
||||||
|
header_version = websocket.request.headers.get("Protocol-Version")
|
||||||
|
try:
|
||||||
|
self.protocol_version = int(header_version) if header_version else 1
|
||||||
|
except ValueError:
|
||||||
|
self.protocol_version = 1
|
||||||
print("ESP32 已连接")
|
print("ESP32 已连接")
|
||||||
|
print(f"ESP32 协议版本: {self.protocol_version}")
|
||||||
|
self.tts_stream_id += 1
|
||||||
|
self.tts_active = False
|
||||||
|
if self.tts_idle_task is not None:
|
||||||
|
self.tts_idle_task.cancel()
|
||||||
|
self.tts_idle_task = None
|
||||||
opus_decoder = None
|
opus_decoder = None
|
||||||
try:
|
try:
|
||||||
|
await self.ensure_agent_dispatched()
|
||||||
|
|
||||||
# 发送 hello 告诉 ESP32 握手成功
|
# 发送 hello 告诉 ESP32 握手成功
|
||||||
hello_msg = {
|
hello_msg = {
|
||||||
"type": "hello",
|
"type": "hello",
|
||||||
"transport": "websocket",
|
"transport": "websocket",
|
||||||
"audio_params": {
|
"audio_params": {
|
||||||
"format": "opus", # 明确要求 ESP32 发送 Opus
|
"format": "opus", # 明确要求 ESP32 发送 Opus
|
||||||
"sample_rate": SAMPLE_RATE,
|
"sample_rate": OUTPUT_SAMPLE_RATE,
|
||||||
"channels": 1,
|
"channels": 1,
|
||||||
"frame_duration": 60
|
"frame_duration": OUTPUT_FRAME_DURATION_MS
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
import json
|
|
||||||
await websocket.send(json.dumps(hello_msg))
|
await websocket.send(json.dumps(hello_msg))
|
||||||
|
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
@ -229,18 +555,14 @@ class ESP32LiveKitBridge:
|
|||||||
print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
|
print(f"收到过短的字节消息 ({len(message)} bytes),跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# ESP32 默认使用 websocket_protocol version=1 (见 websocket_protocol.cc)
|
audio_data = self._extract_opus_payload(message)
|
||||||
# 这个版本下,没有 4 字节的 header,接收到的就是原生的 Opus 数据帧。
|
# print(f"收到音频包长度: {len(message)}, Opus 负载长度: {len(audio_data)}")
|
||||||
# 直接丢给 opuslib 解码即可。
|
|
||||||
audio_data = message
|
|
||||||
print(f"收到音频包长度: {len(message)}")
|
|
||||||
if audio_data:
|
if audio_data:
|
||||||
try:
|
try:
|
||||||
# Create Opus decoder if not exists
|
# Create Opus decoder if not exists
|
||||||
if opus_decoder is None:
|
if opus_decoder is None:
|
||||||
import opuslib
|
print(f"初始化 Opus 解码器: {INPUT_SAMPLE_RATE}Hz, mono")
|
||||||
print(f"初始化 Opus 解码器: {SAMPLE_RATE}Hz, mono")
|
opus_decoder = opuslib.Decoder(INPUT_SAMPLE_RATE, 1)
|
||||||
opus_decoder = opuslib.Decoder(SAMPLE_RATE, 1)
|
|
||||||
|
|
||||||
# 启动音频保存线程
|
# 启动音频保存线程
|
||||||
self.stop_event.clear()
|
self.stop_event.clear()
|
||||||
@ -249,8 +571,8 @@ class ESP32LiveKitBridge:
|
|||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
# Decode Opus packet.
|
# Decode Opus packet.
|
||||||
# Frame size for 60ms is SAMPLE_RATE * 0.06
|
# ESP32 -> bridge 保持按设备上行配置 60ms 解码。
|
||||||
frame_size = int(SAMPLE_RATE * 0.06)
|
frame_size = INPUT_SAMPLES_PER_OPUS_FRAME
|
||||||
pcm_bytes = opus_decoder.decode(audio_data, frame_size)
|
pcm_bytes = opus_decoder.decode(audio_data, frame_size)
|
||||||
|
|
||||||
# 将音频数据放入队列由后台线程保存
|
# 将音频数据放入队列由后台线程保存
|
||||||
@ -260,20 +582,26 @@ class ESP32LiveKitBridge:
|
|||||||
if num_samples > 0:
|
if num_samples > 0:
|
||||||
try:
|
try:
|
||||||
# Use the more robust frame creation logic from test_client_wav.py
|
# Use the more robust frame creation logic from test_client_wav.py
|
||||||
frame = AudioFrame(pcm_bytes, SAMPLE_RATE, 1, num_samples)
|
frame = AudioFrame(pcm_bytes, INPUT_SAMPLE_RATE, 1, num_samples)
|
||||||
await self.mic_source.capture_frame(frame)
|
await self.mic_source.capture_frame(frame)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Fallback for different SDK versions
|
# Fallback for different SDK versions
|
||||||
frame = AudioFrame.create(sample_rate=SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
|
frame = AudioFrame.create(sample_rate=INPUT_SAMPLE_RATE, num_channels=1, samples_per_channel=num_samples)
|
||||||
memoryview(frame.data).cast('B')[:] = pcm_bytes
|
memoryview(frame.data).cast('B')[:] = pcm_bytes
|
||||||
await self.mic_source.capture_frame(frame)
|
await self.mic_source.capture_frame(frame)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Opus audio decode error ({len(message)} bytes): {e}")
|
print(f"Opus audio decode error ({len(message)} bytes): {e}")
|
||||||
elif isinstance(message, str):
|
elif isinstance(message, str):
|
||||||
import json
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
print(f"收到 ESP32 JSON 消息: {data}")
|
print(f"收到 ESP32 JSON 消息: {data}")
|
||||||
|
msg_type = data.get("type")
|
||||||
|
if msg_type == "abort":
|
||||||
|
reason = data.get("reason")
|
||||||
|
if reason is None:
|
||||||
|
await self._abort_tts("button_abort")
|
||||||
|
else:
|
||||||
|
print(f"忽略非按键打断: reason={reason}")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"收到未知的字符消息: {message}")
|
print(f"收到未知的字符消息: {message}")
|
||||||
except ConnectionClosedError as e:
|
except ConnectionClosedError as e:
|
||||||
@ -283,6 +611,11 @@ class ESP32LiveKitBridge:
|
|||||||
finally:
|
finally:
|
||||||
print("ESP32 断开连接")
|
print("ESP32 断开连接")
|
||||||
self.esp_ws = None
|
self.esp_ws = None
|
||||||
|
self.tts_stream_id += 1
|
||||||
|
self.tts_active = False
|
||||||
|
if self.tts_idle_task is not None:
|
||||||
|
self.tts_idle_task.cancel()
|
||||||
|
self.tts_idle_task = None
|
||||||
if hasattr(self, "wav_writer_thread") and self.wav_writer_thread:
|
if hasattr(self, "wav_writer_thread") and self.wav_writer_thread:
|
||||||
self.stop_event.set()
|
self.stop_event.set()
|
||||||
# 我们不一定需要 join,因为是 daemon=True
|
# 我们不一定需要 join,因为是 daemon=True
|
||||||
|
|||||||
Reference in New Issue
Block a user