initial commit
This commit is contained in:
171
custom_agent.py
Normal file
171
custom_agent.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from livekit import rtc
|
||||||
|
from livekit.agents import (
|
||||||
|
Agent,
|
||||||
|
AgentServer,
|
||||||
|
AgentSession,
|
||||||
|
APIConnectOptions,
|
||||||
|
JobContext,
|
||||||
|
JobProcess,
|
||||||
|
LanguageCode,
|
||||||
|
MetricsCollectedEvent,
|
||||||
|
NOT_GIVEN,
|
||||||
|
NotGivenOr,
|
||||||
|
TurnHandlingOptions,
|
||||||
|
cli,
|
||||||
|
metrics,
|
||||||
|
room_io,
|
||||||
|
stt,
|
||||||
|
text_transforms,
|
||||||
|
utils,
|
||||||
|
)
|
||||||
|
from livekit.plugins import silero, openai
|
||||||
|
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||||
|
|
||||||
|
logger = logging.getLogger("custom-agent")
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class SenseVoiceSTT(stt.STT):
|
||||||
|
def __init__(self, url: str):
|
||||||
|
super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False))
|
||||||
|
self._url = url
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return "sensevoice"
|
||||||
|
|
||||||
|
async def _recognize_impl(
|
||||||
|
self,
|
||||||
|
buffer: utils.AudioBuffer,
|
||||||
|
*,
|
||||||
|
language: NotGivenOr[str] = NOT_GIVEN,
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
) -> stt.SpeechEvent:
|
||||||
|
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
data = aiohttp.FormData()
|
||||||
|
data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
|
||||||
|
data.add_field('model_name', 'sensevoice')
|
||||||
|
|
||||||
|
lang = language if language is not NOT_GIVEN else 'auto'
|
||||||
|
data.add_field('language', lang)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(self._url, data=data, timeout=30) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise Exception(f"ASR server returned status {resp.status}")
|
||||||
|
|
||||||
|
result = await resp.json()
|
||||||
|
if not result.get("result"):
|
||||||
|
return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)
|
||||||
|
|
||||||
|
text = result["result"][0].get("clean_text", "")
|
||||||
|
logger.info(f"SenseVoice ASR Result: {text}")
|
||||||
|
return stt.SpeechEvent(
|
||||||
|
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||||
|
alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SenseVoice ASR error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
class CustomAgent(Agent):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
|
||||||
|
"Keep your responses concise and friendly."
|
||||||
|
"You are interacting with the user via a local ASR and LLM pipeline.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_enter(self) -> None:
|
||||||
|
self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||||
|
|
||||||
|
server = AgentServer()
|
||||||
|
|
||||||
|
def prewarm(proc: JobProcess) -> None:
|
||||||
|
# Load Silero VAD as requested
|
||||||
|
proc.userdata["vad"] = silero.VAD.load()
|
||||||
|
|
||||||
|
server.setup_fnc = prewarm
|
||||||
|
|
||||||
|
@server.rtc_session(agent_name="my-agent")
|
||||||
|
async def entrypoint(ctx: JobContext) -> None:
|
||||||
|
ctx.log_context_fields = {
|
||||||
|
"room": ctx.room.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configuration for custom local endpoints
|
||||||
|
# These can be set in your .env file
|
||||||
|
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
|
||||||
|
|
||||||
|
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||||
|
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
|
||||||
|
|
||||||
|
VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
|
||||||
|
PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")
|
||||||
|
|
||||||
|
# Initialize SenseVoice STT and wrap with StreamAdapter
|
||||||
|
sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
|
||||||
|
stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from openai import AsyncClient as OpenAIAsyncClient
|
||||||
|
|
||||||
|
# Create a custom HTTP client that disables SSL verification
|
||||||
|
http_client = httpx.AsyncClient(verify=False)
|
||||||
|
|
||||||
|
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||||
|
openai_client = OpenAIAsyncClient(
|
||||||
|
api_key="sk-orez64WkG1NkfksB5j_hGA",
|
||||||
|
base_url=MINIMAX_BASE_URL,
|
||||||
|
http_client=http_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tts_voxcpm import VoxCPMTTS
|
||||||
|
|
||||||
|
session: AgentSession = AgentSession(
|
||||||
|
# 1. Custom SenseVoice ASR (STT) with StreamAdapter
|
||||||
|
stt=stt_stream,
|
||||||
|
# 2. Minimax LLM - Using OpenAI plugin with local base_url
|
||||||
|
llm=openai.LLM(
|
||||||
|
model=MINIMAX_MODEL,
|
||||||
|
client=openai_client,
|
||||||
|
),
|
||||||
|
# 3. VoxCPM TTS - Custom implementation for blackbox API
|
||||||
|
tts=VoxCPMTTS(
|
||||||
|
url=VOXCPM_URL,
|
||||||
|
prompt_wav_path=PROMPT_WAV,
|
||||||
|
),
|
||||||
|
# 4. Silero VAD
|
||||||
|
vad=ctx.proc.userdata["vad"],
|
||||||
|
turn_handling=TurnHandlingOptions(
|
||||||
|
turn_detection=MultilingualModel(),
|
||||||
|
interruption={
|
||||||
|
"resume_false_interruption": True,
|
||||||
|
"false_interruption_timeout": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
preemptive_generation=True,
|
||||||
|
aec_warmup_duration=3.0,
|
||||||
|
tts_text_transforms=[
|
||||||
|
"filter_emoji",
|
||||||
|
"filter_markdown",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@session.on("metrics_collected")
|
||||||
|
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||||
|
metrics.log_metrics(ev.metrics)
|
||||||
|
|
||||||
|
await session.start(
|
||||||
|
agent=CustomAgent(),
|
||||||
|
room=ctx.room,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli.run_app(server)
|
||||||
188
test_agent.py
Normal file
188
test_agent.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import uuid
|
||||||
|
import wave
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
from livekit import rtc
|
||||||
|
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("test-agent")
|
||||||
|
|
||||||
|
TOKEN_URL = "http://localhost:8000/getToken"
|
||||||
|
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
|
||||||
|
ROOM_NAME = "test-room20"
|
||||||
|
WAV_FILE = "2food.wav"
|
||||||
|
TEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
class TestState:
|
||||||
|
def __init__(self):
|
||||||
|
self.agent_connected = False
|
||||||
|
self.tts_received = False
|
||||||
|
self.tts_count = 0
|
||||||
|
|
||||||
|
test_state = TestState()
|
||||||
|
|
||||||
|
|
||||||
|
def get_token(agent_name="my-agent"):
|
||||||
|
try:
|
||||||
|
resp = requests.get(
|
||||||
|
TOKEN_URL,
|
||||||
|
params={
|
||||||
|
"room": ROOM_NAME,
|
||||||
|
"identity": f"test-{uuid.uuid4().hex[:6]}",
|
||||||
|
"agent_name": agent_name,
|
||||||
|
},
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["token"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 获取token失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_wav(room, wav_path):
|
||||||
|
wav_path = Path(wav_path)
|
||||||
|
if not wav_path.exists():
|
||||||
|
logger.error(f"❌ WAV文件不存在: {wav_path}")
|
||||||
|
raise FileNotFoundError(f"文件不存在: {wav_path}")
|
||||||
|
|
||||||
|
logger.info(f"📂 开始上传: {wav_path}")
|
||||||
|
|
||||||
|
with wave.open(str(wav_path), "rb") as wf:
|
||||||
|
sample_rate = wf.getframerate()
|
||||||
|
num_channels = wf.getnchannels()
|
||||||
|
sample_width = wf.getsampwidth()
|
||||||
|
|
||||||
|
logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
|
||||||
|
|
||||||
|
source = AudioSource(sample_rate, num_channels)
|
||||||
|
track = LocalAudioTrack.create_audio_track("mic", source)
|
||||||
|
|
||||||
|
await room.local_participant.publish_track(track)
|
||||||
|
logger.info("📡 已发布音轨")
|
||||||
|
|
||||||
|
frame_duration = 0.02
|
||||||
|
samples_per_frame = int(sample_rate * frame_duration)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = wf.readframes(samples_per_frame)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16)
|
||||||
|
if len(audio) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples_per_channel = len(audio) // num_channels
|
||||||
|
|
||||||
|
frame = AudioFrame(
|
||||||
|
data=data,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
num_channels=num_channels,
|
||||||
|
samples_per_channel=samples_per_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
await source.capture_frame(frame)
|
||||||
|
await asyncio.sleep(frame_duration)
|
||||||
|
|
||||||
|
logger.info("✅ WAV推流完成")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_agent():
|
||||||
|
try:
|
||||||
|
logger.info("🔑 正在获取token...")
|
||||||
|
token = get_token()
|
||||||
|
logger.info("✅ Token获取成功")
|
||||||
|
|
||||||
|
room = rtc.Room()
|
||||||
|
|
||||||
|
@room.on("participant_connected")
|
||||||
|
def on_participant_connected(participant):
|
||||||
|
logger.info(f"✅ 参与者加入: {participant.identity}")
|
||||||
|
if "agent" in participant.identity.lower():
|
||||||
|
test_state.agent_connected = True
|
||||||
|
logger.info("🎉 Agent已连接!")
|
||||||
|
|
||||||
|
@room.on("participant_disconnected")
|
||||||
|
def on_participant_disconnected(participant):
|
||||||
|
logger.info(f"❌ 参与者离开: {participant.identity}")
|
||||||
|
|
||||||
|
@room.on("track_subscribed")
|
||||||
|
def on_track_subscribed(track, publication, participant):
|
||||||
|
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||||
|
test_state.tts_count += 1
|
||||||
|
logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
|
||||||
|
test_state.tts_received = True
|
||||||
|
|
||||||
|
logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
|
||||||
|
await room.connect(WS_URL, token)
|
||||||
|
logger.info("✅ 已连接到房间")
|
||||||
|
logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")
|
||||||
|
|
||||||
|
logger.info("⏳ 等待Agent连接...")
|
||||||
|
for i in range(10):
|
||||||
|
if test_state.agent_connected:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if not test_state.agent_connected:
|
||||||
|
logger.warning("⚠️ Agent未连接")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("🎙️ 正在上传测试音频...")
|
||||||
|
await publish_wav(room, WAV_FILE)
|
||||||
|
|
||||||
|
logger.info("⏳ 等待Agent响应...")
|
||||||
|
for i in range(TEST_TIMEOUT):
|
||||||
|
if test_state.tts_received:
|
||||||
|
logger.info("✅ 收到Agent TTS响应!")
|
||||||
|
break
|
||||||
|
if i % 5 == 0:
|
||||||
|
logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
logger.info("\n" + "="*60)
|
||||||
|
logger.info("✅ 测试结果")
|
||||||
|
logger.info("="*60)
|
||||||
|
logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
|
||||||
|
logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
|
||||||
|
logger.info(f"TTS音频次数: {test_state.tts_count} 次")
|
||||||
|
logger.info("="*60)
|
||||||
|
|
||||||
|
await room.disconnect()
|
||||||
|
logger.info("✅ 已断开连接\n")
|
||||||
|
|
||||||
|
return test_state.agent_connected and test_state.tts_received
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 测试失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
logger.info("🚀 开始测试custom_agent...\n")
|
||||||
|
success = await test_agent()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("✅ 测试成功!custom_agent 正常工作")
|
||||||
|
logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到,")
|
||||||
|
logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
logger.error("❌ 测试失败")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
exit(exit_code)
|
||||||
53
test_asr.py
Normal file
53
test_asr.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import wave
|
||||||
|
from custom_agent import SenseVoiceSTT
|
||||||
|
from livekit import rtc
|
||||||
|
from livekit.agents import utils
|
||||||
|
|
||||||
|
# 设置日志级别以查看输出
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("test-asr")
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
# 替换为你本地的一个音频文件路径
|
||||||
|
audio_path = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
|
# 初始化 ASR
|
||||||
|
stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox")
|
||||||
|
|
||||||
|
print(f"Testing ASR connectivity with file: {audio_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取音频文件
|
||||||
|
with wave.open(audio_path, 'rb') as wf:
|
||||||
|
frames = wf.readframes(wf.getnframes())
|
||||||
|
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
|
||||||
|
# 实际上 SenseVoiceSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
|
||||||
|
# 所以我们需要传递一个包含 AudioFrame 的 list
|
||||||
|
|
||||||
|
# 这里我们模拟一个 Frame
|
||||||
|
frame = rtc.AudioFrame(
|
||||||
|
data=frames,
|
||||||
|
sample_rate=wf.getframerate(),
|
||||||
|
num_channels=wf.getnchannels(),
|
||||||
|
samples_per_channel=wf.getnframes()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 recognize
|
||||||
|
result = await stt.recognize(buffer=[frame])
|
||||||
|
|
||||||
|
if result.alternatives:
|
||||||
|
print(f"\n--- ASR Result ---")
|
||||||
|
print(f"Text: {result.alternatives[0].text}")
|
||||||
|
print(f"------------------\n")
|
||||||
|
else:
|
||||||
|
print("ASR returned no text.")
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Audio file not found at {audio_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test())
|
||||||
130
test_livekit.py
Normal file
130
test_livekit.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
from livekit import rtc
|
||||||
|
|
||||||
|
import wave
|
||||||
|
import numpy as np
|
||||||
|
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
|
||||||
|
|
||||||
|
TOKEN_URL = "http://localhost:8000/getToken"
|
||||||
|
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" # 你的 LiveKit Server 地址
|
||||||
|
|
||||||
|
ROOM_NAME = "test-room20"
|
||||||
|
import uuid
|
||||||
|
IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
|
||||||
|
# IDENTITY = "test-user0"
|
||||||
|
|
||||||
|
|
||||||
|
def get_token():
|
||||||
|
resp = requests.get(
|
||||||
|
TOKEN_URL,
|
||||||
|
params={
|
||||||
|
"room": ROOM_NAME,
|
||||||
|
"identity": IDENTITY,
|
||||||
|
"agent_name": "my-agent", # 关键!!!
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = resp.json()
|
||||||
|
return data["token"]
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
token = get_token()
|
||||||
|
|
||||||
|
room = rtc.Room()
|
||||||
|
|
||||||
|
@room.on("participant_connected")
|
||||||
|
def on_participant_connected(participant):
|
||||||
|
print(f"✅ 有人加入房间: {participant.identity}")
|
||||||
|
|
||||||
|
@room.on("participant_disconnected")
|
||||||
|
def on_participant_disconnected(participant):
|
||||||
|
print(f"❌ 有人离开房间: {participant.identity}")
|
||||||
|
|
||||||
|
print("🔌 正在连接房间...")
|
||||||
|
await room.connect(WS_URL, token)
|
||||||
|
|
||||||
|
print("✅ 已连接房间:", ROOM_NAME)
|
||||||
|
print("当前房间成员:")
|
||||||
|
for p in room.remote_participants.values():
|
||||||
|
print(" -", p.identity)
|
||||||
|
|
||||||
|
@room.on("data_received")
|
||||||
|
def on_data_received(data, participant, kind, topic):
|
||||||
|
try:
|
||||||
|
msg = data.decode()
|
||||||
|
print(f"📩 来自 {participant.identity}: {msg}")
|
||||||
|
except:
|
||||||
|
print("📩 收到二进制数据")
|
||||||
|
|
||||||
|
@room.on("track_subscribed")
|
||||||
|
def on_track_subscribed(track, publication, participant):
|
||||||
|
print(f"🎧 订阅轨道: {participant.identity}")
|
||||||
|
|
||||||
|
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||||
|
print("👉 TTS 音频来了")
|
||||||
|
|
||||||
|
# 等一下确保连接稳定
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await room.local_participant.publish_data(
|
||||||
|
b"hello",
|
||||||
|
reliable=True,
|
||||||
|
topic="chat"
|
||||||
|
)
|
||||||
|
# 上传 wav
|
||||||
|
await publish_wav(room, "2food.wav")
|
||||||
|
|
||||||
|
await room.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_wav(room, wav_path):
|
||||||
|
print("🎵 开始上传本地 wav:", wav_path)
|
||||||
|
|
||||||
|
wf = wave.open(wav_path, "rb")
|
||||||
|
|
||||||
|
sample_rate = wf.getframerate()
|
||||||
|
num_channels = wf.getnchannels()
|
||||||
|
sample_width = wf.getsampwidth()
|
||||||
|
|
||||||
|
print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
|
||||||
|
|
||||||
|
# 创建音频源
|
||||||
|
source = AudioSource(sample_rate, num_channels)
|
||||||
|
|
||||||
|
# 创建本地音轨
|
||||||
|
track = LocalAudioTrack.create_audio_track("mic", source)
|
||||||
|
|
||||||
|
# 发布轨道
|
||||||
|
await room.local_participant.publish_track(track)
|
||||||
|
print("📡 已发布音轨")
|
||||||
|
|
||||||
|
frame_duration = 0.02 # 20ms
|
||||||
|
samples_per_frame = int(sample_rate * frame_duration)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = wf.readframes(samples_per_frame)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 用于计算长度
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16)
|
||||||
|
|
||||||
|
if len(audio) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples_per_channel = len(audio) // num_channels
|
||||||
|
|
||||||
|
frame = AudioFrame(
|
||||||
|
data=data, # ✅ 关键:用 bytes
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
num_channels=num_channels,
|
||||||
|
samples_per_channel=samples_per_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
await source.capture_frame(frame)
|
||||||
|
await asyncio.sleep(frame_duration)
|
||||||
|
print("✅ wav 推流结束")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
71
test_minimax.py
Normal file
71
test_minimax.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from livekit.agents.llm import ChatContext
|
||||||
|
from livekit.plugins import openai
|
||||||
|
|
||||||
|
# Configure logging to see what's happening
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("test-minimax")
|
||||||
|
|
||||||
|
async def test_minimax():
|
||||||
|
print("Loading .env...")
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Configuration from environment or defaults from custom_agent.py
|
||||||
|
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||||
|
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
|
||||||
|
# Using the hardcoded key from custom_agent.py as a fallback if not in .env
|
||||||
|
API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from openai import AsyncClient as OpenAIAsyncClient
|
||||||
|
|
||||||
|
print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}")
|
||||||
|
|
||||||
|
# Create a custom HTTP client that disables SSL verification
|
||||||
|
http_client = httpx.AsyncClient(verify=False)
|
||||||
|
|
||||||
|
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||||
|
openai_client = OpenAIAsyncClient(
|
||||||
|
api_key=API_KEY,
|
||||||
|
base_url=MINIMAX_BASE_URL,
|
||||||
|
http_client=http_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = openai.LLM(
|
||||||
|
model=MINIMAX_MODEL,
|
||||||
|
client=openai_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Creating ChatContext...")
|
||||||
|
chat_ctx = ChatContext()
|
||||||
|
chat_ctx.add_message(
|
||||||
|
content="Hello! Can you introduce yourself? Please reply in Chinese.",
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n--- Testing Streaming Chat ---")
|
||||||
|
print(f"Request: {chat_ctx.items[-1].content}")
|
||||||
|
print("Response: ", end="", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("\nCalling llm.chat()...")
|
||||||
|
stream = llm.chat(chat_ctx=chat_ctx)
|
||||||
|
print("Iterating over stream...")
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.delta and chunk.delta.content:
|
||||||
|
print(chunk.delta.content, end="", flush=True)
|
||||||
|
print("\n--- Test Completed Successfully ---")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"\nTest failed with error: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Starting...")
|
||||||
|
try:
|
||||||
|
asyncio.run(asyncio.wait_for(test_minimax(), timeout=30))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("\nTest timed out after 30 seconds.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nAn error occurred: {e}")
|
||||||
50
test_voxcpm.py
Normal file
50
test_voxcpm.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from tts_voxcpm import VoxCPMTTS
|
||||||
|
from livekit.agents import tts
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
async def test_tts():
|
||||||
|
# Use the URL from the user's curl command
|
||||||
|
url = "http://10.6.80.21:5002/tts-blackbox"
|
||||||
|
|
||||||
|
# Check if we have a real wav file to test with
|
||||||
|
# In the earlier find_by_name, we found tests/change-sophie.wav
|
||||||
|
prompt_wav = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
if not os.path.exists(prompt_wav):
|
||||||
|
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
|
||||||
|
|
||||||
|
print(f"Testing VoxCPMTTS with URL: {url}")
|
||||||
|
print(f"Using prompt wav: {prompt_wav}")
|
||||||
|
|
||||||
|
vox_tts = VoxCPMTTS(
|
||||||
|
url=url,
|
||||||
|
prompt_wav_path=prompt_wav
|
||||||
|
)
|
||||||
|
|
||||||
|
text = "你好,这是一段测试文本"
|
||||||
|
print(f"Synthesizing text: {text}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = vox_tts.synthesize(text)
|
||||||
|
audio_frame = await stream.collect()
|
||||||
|
|
||||||
|
print(f"Successfully synthesized audio!")
|
||||||
|
print(f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?")
|
||||||
|
# Actually AudioFrame has duration or samples
|
||||||
|
print(f"Samples: {len(audio_frame.data) // 2}")
|
||||||
|
|
||||||
|
# Save to file for manual check if possible
|
||||||
|
with open("test_output.wav", "wb") as f:
|
||||||
|
# This won't be a valid WAV yet if it's just raw PCM,
|
||||||
|
# but if collect() returns combined frames, we can use to_wav_bytes()
|
||||||
|
f.write(audio_frame.to_wav_bytes())
|
||||||
|
print("Saved output to test_output.wav")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TTS test failed: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_tts())
|
||||||
118
tts_voxcpm.py
Normal file
118
tts_voxcpm.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger("voxcpm-tts")
|
||||||
|
|
||||||
|
class VoxCPMTTS(tts.TTS):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
model_name: str = "voxcpmtts",
|
||||||
|
prompt_text: str = "澳门有乜嘢好食嘅",
|
||||||
|
prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
|
||||||
|
cfg_value: str = "2.0",
|
||||||
|
inference_timesteps: str = "10",
|
||||||
|
do_normalize: str = "true",
|
||||||
|
denoise: str = "true",
|
||||||
|
retry_badcase: str = "true",
|
||||||
|
retry_badcase_max_times: str = "3",
|
||||||
|
retry_badcase_ratio_threshold: str = "6.0",
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
capabilities=tts.TTSCapabilities(streaming=False),
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
num_channels=1,
|
||||||
|
)
|
||||||
|
self._url = url
|
||||||
|
self._opts = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"streaming": "false",
|
||||||
|
"prompt_text": prompt_text,
|
||||||
|
"cfg_value": str(cfg_value),
|
||||||
|
"inference_timesteps": str(inference_timesteps),
|
||||||
|
"do_normalize": str(do_normalize),
|
||||||
|
"denoise": str(denoise),
|
||||||
|
"retry_badcase": str(retry_badcase),
|
||||||
|
"retry_badcase_max_times": str(retry_badcase_max_times),
|
||||||
|
"retry_badcase_ratio_threshold": str(retry_badcase_ratio_threshold),
|
||||||
|
}
|
||||||
|
self._prompt_wav_path = prompt_wav_path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._opts["model_name"]
|
||||||
|
|
||||||
|
def synthesize(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
*,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
) -> tts.ChunkedStream:
|
||||||
|
return VoxCPMStream(
|
||||||
|
self, text, self._url, self._opts, self._prompt_wav_path, conn_options=conn_options
|
||||||
|
)
|
||||||
|
|
||||||
|
class VoxCPMStream(tts.ChunkedStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts: VoxCPMTTS,
|
||||||
|
text: str,
|
||||||
|
url: str,
|
||||||
|
opts: dict,
|
||||||
|
prompt_wav_path: str,
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
):
|
||||||
|
super().__init__(tts=tts, input_text=text, conn_options=conn_options)
|
||||||
|
self._url = url
|
||||||
|
self._opts = opts
|
||||||
|
self._prompt_wav_path = prompt_wav_path
|
||||||
|
|
||||||
|
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
||||||
|
# Initialize emitter early to avoid "AudioEmitter isn't started" error on failure
|
||||||
|
output_emitter.initialize(
|
||||||
|
request_id="",
|
||||||
|
sample_rate=self._tts.sample_rate,
|
||||||
|
num_channels=self._tts.num_channels,
|
||||||
|
mime_type="audio/wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
data = aiohttp.FormData()
|
||||||
|
data.add_field("text", self.input_text)
|
||||||
|
for k, v in self._opts.items():
|
||||||
|
data.add_field(k, v)
|
||||||
|
|
||||||
|
# Open the prompt wav file if it exists
|
||||||
|
f = None
|
||||||
|
if os.path.exists(self._prompt_wav_path):
|
||||||
|
f = open(self._prompt_wav_path, "rb")
|
||||||
|
data.add_field("prompt_wav", f, filename="prompt.wav", content_type="audio/wav")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set a reasonable timeout for synthesis
|
||||||
|
async with session.post(
|
||||||
|
self._url, data=data, timeout=aiohttp.ClientTimeout(total=60)
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
err_text = await resp.text()
|
||||||
|
logger.error(f"VoxCPM TTS error: {resp.status} {err_text}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read the entire audio data (since streaming=false)
|
||||||
|
audio_data = await resp.read()
|
||||||
|
|
||||||
|
output_emitter.push(audio_data)
|
||||||
|
output_emitter.flush()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VoxCPM TTS request failed: {e}")
|
||||||
|
finally:
|
||||||
|
if f:
|
||||||
|
f.close()
|
||||||
Reference in New Issue
Block a user