Compare commits

..

1 Commits

Author SHA1 Message Date
f368e156f0 beaver test 2026-06-03 17:26:46 +08:00
13 changed files with 1475 additions and 1135 deletions

View File

@ -1,74 +0,0 @@
# LiveKit connection
LIVEKIT_URL=ws://localhost:7880
LIVEKIT_API_KEY=
LIVEKIT_API_SECRET=
CUSTOM_AGENT_NAME=my-agent
# ASR blackbox
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
CUSTOM_ASR_MODEL=qwen
CUSTOM_ASR_LANGUAGE=Chinese
CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# OpenAI-compatible LLM
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
# CUSTOM_LLM_MODEL=Qwen3.6-35B
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Qwen-VL
CUSTOM_LLM_API_KEY=
CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_SAVE_MODEL_IMAGES=false
# CUSTOM_TEXT_LLM_MODEL=
# CUSTOM_VISION_LLM_MODEL=
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_VERIFY_SSL=false
# TTS blackbox
CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
CUSTOM_TTS_MODEL=voxcpmtts
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
CUSTOM_TTS_STREAMING=true
# CUSTOM_TTS_PROMPT_TEXT=澳门有乜嘢好食嘅
# VoxCPM TTS parameters
VOXCPM_CFG_VALUE=2.0
VOXCPM_INFERENCE_TIMESTEPS=10
VOXCPM_DO_NORMALIZE=true
VOXCPM_DENOISE=true
VOXCPM_RETRY_BADCASE=true
VOXCPM_RETRY_BADCASE_MAX_TIMES=3
VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD=6.0
# MeloTTS parameters
CUSTOM_TTS_SPEED=1.0
# CosyVoice parameters
CUSTOM_TTS_SPK_ID=
CUSTOM_TTS_MODE=
CUSTOM_TTS_INSTRUCT_TEXT=
# GPT-SoVITS parameters
CUSTOM_TTS_TEXT_LANG=zh
CUSTOM_TTS_PROMPT_LANG=zh
CUSTOM_TTS_TEXT_SPLIT_METHOD=cut0
CUSTOM_TTS_BATCH_SIZE=1
CUSTOM_TTS_MEDIA_TYPE=wav
CUSTOM_TTS_REF_AUDIO_PATH=
CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
CUSTOM_MEMORY_TIMEOUT=2
CUSTOM_MEMORY_MAX_CHARS=2000
CUSTOM_MEMORY_API_KEY=
CUSTOM_PREEMPTIVE_GENERATION=true

3
.gitignore vendored
View File

@ -1,3 +0,0 @@
__pycache__/
.env
model_images/

228
beaver_terminal_client.py Normal file
View File

@ -0,0 +1,228 @@
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from uuid import uuid4
import aiohttp
from dotenv import load_dotenv
logger = logging.getLogger("beaver-terminal-client")
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
class BeaverTerminalError(RuntimeError):
    """Base class for every failure raised by the Beaver terminal client."""


class BeaverTerminalConnectionClosed(BeaverTerminalError):
    """The Beaver websocket closed or errored before the exchange finished."""
@dataclass
class MessageIdGenerator:
    """Hand out sequential message ids scoped to one peer.

    Ids have the form ``<peer_id>-<counter>`` or, when ``instance_id`` is
    truthy, ``<peer_id>-<instance_id>-<counter>``. The counter is zero-padded
    to six digits and starts one above ``initial_counter``.
    """

    peer_id: str
    initial_counter: int = 0
    instance_id: str | None = None

    def __post_init__(self) -> None:
        # The running counter is a plain attribute, not a dataclass field.
        self.counter = self.initial_counter

    def next_id(self) -> str:
        """Increment the counter and return the id for the next message."""
        self.counter += 1
        prefix = (
            f"{self.peer_id}-{self.instance_id}"
            if self.instance_id
            else f"{self.peer_id}"
        )
        return f"{prefix}-{self.counter:06d}"
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
    """Return the ``connect`` frame announcing this terminal as a text-only peer."""
    frame: dict[str, Any] = {"type": "connect"}
    frame["peer_id"] = peer_id
    frame["device_name"] = device_name
    frame["capabilities"] = ["text"]
    return frame
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
    """Return the ``message`` frame carrying one user utterance."""
    frame: dict[str, Any] = {"type": "message"}
    frame["message_id"] = message_id
    frame["text"] = text
    return frame
class BeaverTerminalClient:
    """Text-only websocket client for a Beaver terminal channel.

    The client opens a websocket to ``url``, sends a ``connect`` frame that
    identifies it by ``peer_id`` and ``device_name``, and expects a
    ``connected`` frame back. Each user turn is then sent as a ``message``
    frame and answered by an ``ack`` or assistant ``message`` frame that
    carries the same ``message_id``.

    Attributes:
        session_id: The ``session_id`` string from the most recent
            ``connected`` frame, or ``None`` when it was missing, not a
            string, or no handshake has completed.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        http_session: aiohttp.ClientSession | None = None,
        message_ids: MessageIdGenerator | None = None,
    ) -> None:
        """Store connection settings; no network activity happens here.

        Args:
            url: Websocket URL of the Beaver terminal channel.
            peer_id: Identifier sent in the ``connect`` frame and used as the
                prefix of generated message ids.
            device_name: Device name sent in the ``connect`` frame.
            http_session: Optional session to reuse. When omitted, the client
                creates its own session lazily and closes it in ``close()``.
            message_ids: Optional id generator. When omitted, one is created
                with a random 8-hex-character ``instance_id``.
        """
        self._url = url
        self._peer_id = peer_id
        self._device_name = device_name
        # Only sessions the client created itself are closed by close().
        self._owned_session = http_session is None
        self._http_session = http_session
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._message_ids = message_ids or MessageIdGenerator(
            peer_id=peer_id,
            instance_id=uuid4().hex[:8],
        )
        self.session_id: str | None = None

    async def connect(self) -> None:
        """Open a fresh websocket and complete the ``connect`` handshake.

        Any existing websocket is closed first. If the handshake does not
        complete, the new websocket is closed again so that a half-initialised
        connection is never reported as open.

        Raises:
            BeaverTerminalError: The first frame received was not ``connected``
                or was not a valid JSON object.
            BeaverTerminalConnectionClosed: The socket closed during the
                handshake.
        """
        await self._close_websocket()
        # A stale id from a previous connection must not survive a failed reconnect.
        self.session_id = None
        session = self._ensure_http_session()
        self._ws = await session.ws_connect(self._url)
        handshake_ok = False
        try:
            await self._send_json(
                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
            )
            frame = await self._receive_json()
            if frame.get("type") != "connected":
                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
            handshake_ok = True
        finally:
            if not handshake_ok:
                # Without this, _websocket_is_open() would stay True and the next
                # send_text() would skip the handshake entirely.
                await self._close_websocket()
        session_id = frame.get("session_id")
        self.session_id = session_id if isinstance(session_id, str) else None

    async def send_text(self, text: str) -> str:
        """Send one user utterance and return the assistant's reply text.

        At most two attempts are made. If the websocket closes during the
        first attempt, the client reconnects with the same peer id and sends
        the text again under a new message id.

        Raises:
            BeaverTerminalConnectionClosed: The websocket closed on the second
                attempt as well.
            BeaverTerminalError: Beaver reported an error for the turn.
        """
        for attempt in range(2):
            if not self._websocket_is_open():
                await self.connect()
            message_id = self._message_ids.next_id()
            message_frame = build_message_frame(message_id=message_id, text=text)
            try:
                await self._send_json(message_frame)
                return await self._wait_for_reply(message_id)
            except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
                if attempt == 1:
                    raise BeaverTerminalConnectionClosed(
                        "Beaver websocket closed before assistant reply"
                    ) from exc
                logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
                await self.connect()
        raise BeaverTerminalError("unreachable Beaver send state")

    async def _wait_for_reply(self, message_id: str) -> str:
        """Read frames until the reply for ``message_id`` arrives.

        An ``ack`` frame for the id ends the wait only when it carries a string
        ``reply``. An assistant ``message`` frame for the id ends the wait with
        its ``text`` (empty string when ``text`` is not a string). Frames that
        match neither are ignored.

        Raises:
            BeaverTerminalError: An ``error`` frame arrived, or the assistant
                frame had ``finish_reason == "error"``.
        """
        while True:
            frame = await self._receive_json()
            frame_type = frame.get("type")
            if frame_type == "ack" and frame.get("message_id") == message_id:
                reply = frame.get("reply")
                if isinstance(reply, str):
                    return reply
                # An ack without a string reply only acknowledges receipt; keep waiting.
                continue
            if (
                frame_type == "message"
                and frame.get("role") == "assistant"
                and frame.get("message_id") == message_id
            ):
                text = frame.get("text")
                if frame.get("finish_reason") == "error":
                    raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
                return text if isinstance(text, str) else ""
            if frame_type == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def ping(self) -> bool:
        """Send a ``ping`` frame and return ``True`` once a ``pong`` frame arrives.

        Frames other than ``pong`` and ``error`` that arrive in between are
        read and discarded.

        Raises:
            BeaverTerminalError: An ``error`` frame arrived first, or the
                websocket is not connected.
        """
        await self._send_json({"type": "ping"})
        while True:
            frame = await self._receive_json()
            if frame.get("type") == "pong":
                return True
            if frame.get("type") == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def close(self) -> None:
        """Close the websocket and, if the client created it, the HTTP session."""
        await self._close_websocket()
        if self._owned_session and self._http_session is not None:
            await self._http_session.close()
            self._http_session = None

    async def _close_websocket(self) -> None:
        """Close the current websocket, if any, and forget it."""
        ws, self._ws = self._ws, None
        if ws is not None:
            # The reference is dropped before awaiting so a failing close()
            # cannot leave a dead socket attached to the client.
            await ws.close()

    def _websocket_is_open(self) -> bool:
        """Return whether a websocket exists and has not been closed."""
        return self._ws is not None and not self._ws.closed

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating one on first use."""
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _send_json(self, frame: dict[str, Any]) -> None:
        """Serialise ``frame`` as JSON and send it over the websocket.

        Raises:
            BeaverTerminalError: No websocket is connected.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        await self._ws.send_json(frame)

    async def _receive_json(self) -> dict[str, Any]:
        """Receive one websocket message and decode it as a JSON object.

        Raises:
            BeaverTerminalConnectionClosed: The message was a close or error
                notification.
            BeaverTerminalError: No websocket is connected, the message was not
                a text frame, the text was not valid JSON, or the decoded value
                was not a JSON object.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        message = await self._ws.receive()
        if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
            raise BeaverTerminalConnectionClosed("Beaver websocket closed")
        if message.type == aiohttp.WSMsgType.ERROR:
            raise BeaverTerminalConnectionClosed(
                f"Beaver websocket error: {self._ws.exception()!r}"
            )
        if message.type != aiohttp.WSMsgType.TEXT:
            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
        try:
            data = message.json()
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError; surface it through the
            # client's own hierarchy so callers need only one except clause.
            raise BeaverTerminalError("Beaver sent a text frame that is not valid JSON") from exc
        if not isinstance(data, dict):
            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
        return data
def client_from_env() -> BeaverTerminalClient:
    """Build a client from ``BEAVER_WS_URL``, ``TERMINAL_PEER_ID`` and ``TERMINAL_DEVICE_NAME``.

    The ``.env`` file next to this module is loaded first; any variable that
    is still unset falls back to the module-level default.
    """
    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
    ws_url = os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL)
    peer = os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID)
    device = os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME)
    return BeaverTerminalClient(url=ws_url, peer_id=peer, device_name=device)
async def run_console() -> None:
    """Run an interactive console that forwards each typed line to Beaver.

    Blank lines are skipped. ``quit`` or ``exit`` ends the loop, as does end
    of input on stdin (Ctrl-D or an exhausted pipe). A failed turn is logged
    and the loop continues. The client is always closed on the way out.
    """
    logging.basicConfig(level=logging.INFO)
    client = client_from_env()
    try:
        await client.connect()
        logger.info("Connected to Beaver session_id=%s", client.session_id)
        while True:
            try:
                # input() blocks, so it runs in a worker thread to keep the
                # event loop free.
                text = await asyncio.to_thread(input, "> ")
            except EOFError:
                # stdin is closed: treat it like "quit" rather than crashing
                # with a traceback.
                return
            text = text.strip()
            if not text:
                continue
            if text in {"quit", "exit"}:
                return
            try:
                reply = await client.send_text(text)
            except BeaverTerminalError as exc:
                logger.error("Beaver turn failed: %s", exc)
                continue
            print(reply)
    finally:
        await client.close()
if __name__ == "__main__":
asyncio.run(run_console())

View File

@ -1,40 +1,28 @@
import base64
import json
import logging
import os
import re
import time
from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
ChatContext,
ChatMessage,
FlushSentinel,
JobContext,
JobProcess,
MetricsCollectedEvent,
ModelSettings,
RecordingOptions,
TurnHandlingOptions,
cli,
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
from tts import BlackboxTTS
logger = logging.getLogger("custom-agent")
@ -42,576 +30,19 @@ CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
VISION_FRAME_TOPIC = "vision.frame"
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"surprised",
"fearful",
"calm",
"concerned",
}
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
TTS_EMOTION_LINE_RE = re.compile(
r"^\s*(?:emotion|情绪)\s*[:=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!\s-]*",
re.IGNORECASE,
)
MAX_EMOTION_PREFIX_CHARS = 80
@dataclass
class VisionFrame:
    """One image received from the room, ready to attach to an LLM request."""

    # Image as a data URL; VisionFrameStore.update builds it as
    # "data:<mime_type>;base64,<image>".
    image_data_url: str
    # time.monotonic() reading taken when the frame was stored; used for the
    # staleness check.
    received_at: float
    # MIME type supplied with the frame (the data handler defaults it to
    # "image/jpeg").
    mime_type: str
    # Optional "saved_path" value forwarded from the sender's payload; only
    # logged here.
    saved_path: str | None = None
class VisionFrameStore:
    """Keep the single most recent vision frame and release it at most once."""

    def __init__(self, *, max_age_seconds: float) -> None:
        self._latest_frame: VisionFrame | None = None
        self._max_age_seconds = max_age_seconds

    def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
        """Replace the stored frame with a new base64-encoded image."""
        data_url = f"data:{mime_type};base64,{image}"
        stamped_at = time.monotonic()
        self._latest_frame = VisionFrame(
            image_data_url=data_url,
            received_at=stamped_at,
            mime_type=mime_type,
            saved_path=saved_path,
        )

    def consume_fresh(self) -> VisionFrame | None:
        """Take the stored frame, returning it only if it is not too old.

        The store is emptied whether or not the frame is returned. A frame
        older than ``max_age_seconds`` is logged and dropped.
        """
        pending, self._latest_frame = self._latest_frame, None
        if pending is None:
            return None
        age = time.monotonic() - pending.received_at
        if age > self._max_age_seconds:
            logger.info("Dropping stale vision frame: age=%.3fs", age)
            return None
        return pending
class CustomAgent(Agent):
def __init__(
self,
*,
memory_client: MemoryRecallClient | None = None,
vision_store: VisionFrameStore | None = None,
input_mode: str = AUTO_INPUT_MODE,
text_llm: llm.LLM | None = None,
vision_llm: llm.LLM | None = None,
model_image_save_dir: Path | None = None,
) -> None:
super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
self._memory_client = memory_client
self._vision_store = vision_store
self._input_mode = input_mode
self._text_llm = text_llm
self._vision_llm = vision_llm
self._model_image_save_dir = model_image_save_dir
self.current_emotion = DEFAULT_EMOTION
self._emotion_prefix_buffer = ""
self._emotion_prefix_done = True
def __init__(self) -> None:
super().__init__(
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
"Keep your responses concise and friendly."
"You are interacting with the user via a local ASR and LLM pipeline.",
)
async def on_enter(self) -> None:
# self.session.generate_reply(instructions="greet the user and introduce yourself")
pass
async def llm_node(
self,
chat_ctx: ChatContext,
tools: list[llm.Tool],
model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
llm_node_started_at = time.perf_counter()
user_query = _latest_user_text(chat_ctx)
mode = _select_mode(user_query)
vision_frame = self._consume_vision_frame()
logger.info(
"Selected agent mode: %s input_mode=%s has_image=%s",
mode,
self._input_mode,
vision_frame is not None,
)
chat_ctx = chat_ctx.copy()
update_chat_instructions(
chat_ctx,
instructions=_with_emotion_instructions(
ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
),
add_if_missing=True,
)
if mode == ROOM_LOCATOR_MODE:
memory_context = await self._recall_room_memory(chat_ctx)
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
if vision_frame is not None:
self._save_model_vision_frame(vision_frame)
chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)
llm_result = self._run_selected_llm(
chat_ctx,
tools,
model_settings,
has_image=vision_frame is not None,
)
if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed)
return llm_result
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
first_chunk_at: float | None = None
chunk_count = 0
self._emotion_prefix_buffer = ""
self._emotion_prefix_done = False
try:
async for chunk in llm_result:
chunk_count += 1
if first_chunk_at is None:
first_chunk_at = time.perf_counter()
logger.info(
"LLM first chunk after %.3fs",
first_chunk_at - llm_node_started_at,
)
async for output_chunk in self._observe_emotion_prefix(chunk):
yield output_chunk
finally:
finished_at = time.perf_counter()
logger.info(
"LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
finished_at - llm_node_started_at,
(first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0,
chunk_count,
)
return _instrumented_stream()
def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings)
def _consume_vision_frame(self) -> VisionFrame | None:
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
return None
return self._vision_store.consume_fresh()
def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
if self._model_image_save_dir is None:
return
try:
_, b64_data = vision_frame.image_data_url.split(",", 1)
image_bytes = base64.b64decode(b64_data, validate=True)
except Exception:
logger.exception("Failed to decode model vision frame for debug save")
return
extension = _image_extension_from_mime_type(vision_frame.mime_type)
timestamp_ms = int(time.time() * 1000)
path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
try:
self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
path.write_bytes(image_bytes)
except Exception:
logger.exception("Failed to save model vision frame: path=%s", path)
return
logger.info(
"Saved model vision frame: path=%s bytes=%s source_path=%s",
path,
len(image_bytes),
vision_frame.saved_path,
)
def _run_selected_llm(
self,
chat_ctx: ChatContext,
tools: list[llm.Tool],
model_settings: ModelSettings,
*,
has_image: bool,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
selected_llm = self._vision_llm if has_image else self._text_llm
if selected_llm is None:
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
activity = self._get_activity_or_raise()
tool_choice = model_settings.tool_choice
conn_options = activity.session.conn_options.llm_conn_options
async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
async with selected_llm.chat(
chat_ctx=chat_ctx,
tools=tools,
tool_choice=tool_choice,
conn_options=conn_options,
) as stream:
async for chunk in stream:
yield chunk
return _stream()
async def _observe_emotion_prefix(
self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
if isinstance(chunk, str):
self._consume_emotion_prefix(chunk)
yield chunk
return
if isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
self._consume_emotion_prefix(chunk.delta.content)
yield chunk
return
yield chunk
def _consume_emotion_prefix(self, content: str) -> None:
if self._emotion_prefix_done:
return
self._emotion_prefix_buffer += content
match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer)
if match:
emotion = match.group(1).lower()
if emotion not in EMOTION_LABELS:
logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion)
emotion = DEFAULT_EMOTION
self.current_emotion = emotion
self._emotion_prefix_done = True
self._emotion_prefix_buffer = ""
logger.info("LLM emotion selected: %s", emotion)
return
candidate = self._emotion_prefix_buffer.lstrip().lower()
might_still_be_prefix = (
not candidate
or "<emotion=".startswith(candidate)
or (candidate.startswith("<emotion=") and ">" not in candidate)
)
if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
return
self._emotion_prefix_done = True
self._emotion_prefix_buffer = ""
logger.warning("LLM response did not start with an emotion prefix")
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None:
return ""
user_query = _latest_user_text(chat_ctx)
if not user_query:
return ""
started_at = time.perf_counter()
try:
recalled = await self._memory_client.recall(user_query)
elapsed = time.perf_counter() - started_at
logger.info(
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
elapsed,
len(user_query),
len(recalled),
)
return recalled
except Exception:
logger.exception(
"Unexpected memory recall failure after %.3fs",
time.perf_counter() - started_at,
)
return ""
def _select_mode(user_query: str) -> str:
    """Choose the agent mode for a user query.

    Returns ``ROOM_LOCATOR_MODE`` when the normalized query is non-empty and
    is recognised as a room-locator question; otherwise ``GENERAL_MODE``.
    """
    normalized = _normalize_text(user_query)
    if normalized and _is_room_locator_query(normalized):
        return ROOM_LOCATOR_MODE
    return GENERAL_MODE
def _with_emotion_instructions(instructions: str) -> str:
    """Append the emotion-tag instructions to ``instructions``, separated by a blank line."""
    separator = "\n\n"
    return f"{instructions}{separator}{EMOTION_INSTRUCTIONS}"
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
    """Yield the text stream with emotion markup removed, for speech synthesis.

    The start of the stream is buffered until ``_strip_leading_tts_emotion``
    reports that it has decided whether a leading emotion prefix is present.
    After that, each chunk is passed through with inline ``<emotion=...>``
    tags removed. Empty chunks are never yielded.
    """
    prefix_buffer = ""
    scanning_prefix = True
    async for chunk in text:
        if not chunk:
            continue
        if scanning_prefix:
            # Accumulate chunks: a prefix tag may be split across several of them.
            prefix_buffer += chunk
            cleaned, done = _strip_leading_tts_emotion(prefix_buffer)
            if not done:
                continue
            scanning_prefix = False
            prefix_buffer = ""
            if cleaned:
                yield _strip_inline_tts_emotion(cleaned)
            continue
        cleaned = _strip_inline_tts_emotion(chunk)
        if cleaned:
            yield cleaned
    # The stream ended while still undecided: force a decision so buffered
    # text is not lost.
    if scanning_prefix and prefix_buffer:
        cleaned, _ = _strip_leading_tts_emotion(prefix_buffer, force=True)
        cleaned = _strip_inline_tts_emotion(cleaned)
        if cleaned:
            yield cleaned
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
    """Try to remove an emotion prefix from the start of ``text``.

    Args:
        text: The text buffered so far from the start of the stream.
        force: When true, never ask for more input; decide with what is there.

    Returns:
        A ``(remaining_text, done)`` pair. ``done`` is ``False`` only when the
        buffer could still grow into an emotion prefix, in which case
        ``remaining_text`` is empty and the caller should keep buffering.
    """
    # Case 1: a complete "<emotion=...>" tag at the start.
    match = TTS_EMOTION_MARKUP_RE.match(text)
    if match:
        return text[match.end() :], True
    # Case 2: a complete "emotion: x" / "情绪: x" style line at the start.
    match = TTS_EMOTION_LINE_RE.match(text)
    if match:
        return text[match.end() :], True
    candidate = text.lstrip().lower()
    # Case 3: nothing matched yet, but the buffer may be a truncated prefix.
    might_still_be_emotion = (
        not candidate
        or "<emotion=".startswith(candidate)
        or (candidate.startswith("<emotion") and ">" not in candidate)
        or "emotion".startswith(candidate)
        or (candidate.startswith("emotion") and len(candidate) <= MAX_EMOTION_PREFIX_CHARS)
        or "情绪".startswith(candidate)
    )
    # The length cap bounds how long output can be withheld while waiting.
    if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
        return "", False
    return text, True
def _strip_inline_tts_emotion(text: str) -> str:
    """Remove every ``<emotion=...>`` tag (and trailing whitespace) from ``text``."""
    without_tags = TTS_EMOTION_MARKUP_RE.sub("", text)
    return without_tags
def _is_room_locator_query(normalized_text: str) -> bool:
    """Decide whether a normalized query asks where something is in the room.

    The decision is keyword based:

    * any software-related hint rules the query out;
    * a spatial hint together with a room-context hint accepts it;
    * a spatial hint alone accepts it only when the query is at most 12
      characters long.

    Args:
        normalized_text: Query with whitespace removed and lower-cased.

    Returns:
        ``True`` when the query should be handled in room-locator mode.
    """
    # NOTE: this tuple must never contain "". The empty string is a substring
    # of every string, so a single empty entry makes the room-context check
    # always succeed and turns every spatial phrase into a room query.
    # NOTE(review): the empty entries removed here may have been real hints
    # lost upstream — confirm against the original list.
    room_context_hints = (
        "房间",
        "屋里",
        "屋子",
        "室内",
        "客厅",
        "卧室",
        "书房",
        "厨房",
        "餐厅",
        "沙发",
        "电视",
        "空调",
        "书架",
        "冰箱",
        "茶几",
        "电脑",
        "相机",
        "植物",
    )
    spatial_hints = (
        "在哪里",
        "在哪",
        "位置",
        "方位",
        "旁边",
        "左边",
        "右边",
        "前面",
        "后面",
        "上面",
        "下面",
        "附近",
        "对面",
        "靠近",
        "挨着",
        "隔着",
    )
    software_hints = (
        "python",
        "代码",
        "函数",
        "class",
        "bug",
        "日志",
        "logging",
        "api",
        "server",
        "agent",
        "prompt",
        "模型",
        "数据库",
        "git",
        "uv",
        "ruff",
        "mypy",
    )
    # Questions about code often contain words like "位置"; exclude them first.
    if any(hint in normalized_text for hint in software_hints):
        return False
    has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
    has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
    if has_spatial_hint and has_room_context_hint:
        return True
    # Very short spatial questions ("在哪") are accepted without room context.
    if has_spatial_hint and len(normalized_text) <= 12:
        return True
    return False
def _normalize_text(text: str) -> str:
    """Return ``text`` with all whitespace removed, then lower-cased."""
    # Whitespace is stripped before lower-casing, keeping the original order
    # of operations.
    squeezed = "".join(ch for ch in text if not ch.isspace())
    return squeezed.lower()
def _latest_user_text(chat_ctx: ChatContext) -> str:
    """Return the stripped text of the most recent user message, or ``""`` if none exists."""
    user_messages = (
        item
        for item in reversed(chat_ctx.items)
        if isinstance(item, ChatMessage) and item.role == "user"
    )
    latest = next(user_messages, None)
    if latest is None:
        return ""
    return (latest.text_content or "").strip()
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of ``chat_ctx`` whose latest user message holds ``memory_context``.

    The content of the most recent user message is replaced outright. When
    the context has no user message, a new one is appended instead. The
    original context is not modified.
    """
    updated = chat_ctx.copy()
    target = next(
        (
            position
            for position in reversed(range(len(updated.items)))
            if isinstance(updated.items[position], ChatMessage)
            and updated.items[position].role == "user"
        ),
        None,
    )
    if target is None:
        updated.items.append(ChatMessage(role="user", content=[memory_context]))
        return updated
    replacement = updated.items[target].model_copy(deep=True)
    replacement.content = [memory_context]
    updated.items[target] = replacement
    return updated
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
    """Return a copy of ``chat_ctx`` with the frame's image added to the latest user message.

    The image is appended to the existing content of the most recent user
    message. When the context has no user message, a new one containing only
    the image is appended. The original context is not modified.
    """
    updated = chat_ctx.copy()
    image_content = llm.ImageContent(
        image=vision_frame.image_data_url,
        mime_type=vision_frame.mime_type,
        inference_detail="auto",
    )
    target = next(
        (
            position
            for position in reversed(range(len(updated.items)))
            if isinstance(updated.items[position], ChatMessage)
            and updated.items[position].role == "user"
        ),
        None,
    )
    if target is None:
        updated.items.append(ChatMessage(role="user", content=[image_content]))
        return updated
    replacement = updated.items[target].model_copy(deep=True)
    replacement.content = [*replacement.content, image_content]
    updated.items[target] = replacement
    return updated
def _normalize_input_mode(value: str | None) -> str:
    """Map a configured input-mode string onto one of the canonical modes.

    Matching ignores case and surrounding whitespace and treats ``-`` as
    ``_``. A missing or empty value yields ``AUTO_INPUT_MODE``. An
    unrecognised value is logged as a warning and also yields
    ``AUTO_INPUT_MODE``.
    """
    if not value:
        return AUTO_INPUT_MODE
    key = value.strip().lower().replace("-", "_")
    if key in ("image_voice", "image", "vision", "vision_voice", "voice_image"):
        return VISION_VOICE_INPUT_MODE
    if key in ("audio", "voice"):
        return VOICE_INPUT_MODE
    if key != "auto":
        logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
    return AUTO_INPUT_MODE
def _image_extension_from_mime_type(mime_type: str) -> str:
    """Return the file extension for an image MIME type, defaulting to ``.jpg``.

    The MIME type is compared case-insensitively after trimming whitespace.
    """
    known_extensions = {
        "image/png": ".png",
        "image/webp": ".webp",
        "image/gif": ".gif",
    }
    return known_extensions.get(mime_type.strip().lower(), ".jpg")
def _model_image_save_dir_from_env() -> Path | None:
    """Return the directory for saving model input images, or ``None`` when saving is off.

    Saving is controlled by ``CUSTOM_SAVE_MODEL_IMAGES`` (default on). The
    directory comes from ``CUSTOM_MODEL_IMAGE_SAVE_DIR`` with ``~`` expanded,
    falling back to a ``model_images`` folder next to this module.
    """
    if _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
        configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
        if not configured:
            return Path(__file__).with_name("model_images")
        return Path(configured).expanduser()
    return None
server = AgentServer()
@ -635,27 +66,19 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
if not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
if not MINIMAX_API_KEY:
raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
)
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -671,50 +94,30 @@ async def entrypoint(ctx: JobContext) -> None:
import httpx
from openai import AsyncClient as OpenAIAsyncClient
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
# Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False)
if LLM_BASE_URL:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
base_url=LLM_BASE_URL,
http_client=http_client,
)
else:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
http_client=http_client,
)
base_llm = openai.LLM(
model=LLM_MODEL,
client=openai_client,
)
text_llm = (
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
if TEXT_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_llm = (
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
if VISION_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_store = VisionFrameStore(
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
# Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient(
api_key=MINIMAX_API_KEY,
base_url=MINIMAX_BASE_URL,
http_client=http_client,
)
session: AgentSession = AgentSession(
# 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream,
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
llm=base_llm,
# 2. Minimax LLM - Using OpenAI plugin with local base_url
llm=openai.LLM(
model=MINIMAX_MODEL,
client=openai_client,
),
# 3. TTS blackbox
tts=BlackboxTTS(
url=TTS_URL,
model_name=TTS_MODEL,
params=_tts_params_from_env(TTS_MODEL),
prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
sample_rate=TTS_SAMPLE_RATE,
num_channels=TTS_NUM_CHANNELS,
),
@ -727,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None:
"false_interruption_timeout": 1.0,
},
),
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
preemptive_generation=False,
aec_warmup_duration=3.0,
tts_text_transforms=[
"filter_emoji",
@ -739,78 +142,8 @@ async def entrypoint(ctx: JobContext) -> None:
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
metrics.log_metrics(ev.metrics)
@session.on("conversation_item_added")
def _on_conversation_item_added(event) -> None:
item = getattr(event, "item", None)
if not isinstance(item, ChatMessage):
return
if item.role == "user" and item.metrics:
logger.info("User turn metrics: %s", item.metrics)
elif item.role == "assistant" and item.metrics:
logger.info("Assistant turn metrics: %s", item.metrics)
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
return
image = payload.get("image")
if not isinstance(image, str) or not image:
logger.warning("Received vision frame without image data")
return
mime_type = payload.get("mime_type")
if not isinstance(mime_type, str) or not mime_type:
mime_type = "image/jpeg"
saved_path = payload.get("saved_path")
vision_store.update(
image=image,
mime_type=mime_type,
saved_path=saved_path if isinstance(saved_path, str) else None,
)
logger.info(
"Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
mime_type,
len(image),
saved_path,
)
memory_client = (
MemoryRecallClient(
url=MEMORY_URL,
timeout=MEMORY_TIMEOUT,
max_chars=MEMORY_MAX_CHARS,
api_key=MEMORY_API_KEY,
)
if MEMORY_URL
else None
)
await session.start(
agent=CustomAgent(
memory_client=memory_client,
vision_store=vision_store,
input_mode=INPUT_MODE,
text_llm=text_llm,
vision_llm=vision_llm,
model_image_save_dir=_model_image_save_dir_from_env(),
),
agent=CustomAgent(),
room=ctx.room,
room_options=room_io.RoomOptions(
audio_output=room_io.AudioOutputOptions(
@ -827,55 +160,49 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
model_name = model_name.lower()
if model_name == "voxcpmtts":
_set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
_set_if_present(
params,
"prompt_text",
os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
)
_set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
_set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
_set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
_set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
_set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
_set_if_present(
params,
"retry_badcase_max_times",
os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
)
_set_if_present(
params,
"retry_badcase_ratio_threshold",
os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
params.update(
{
"streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
"prompt_text": os.getenv(
"CUSTOM_TTS_PROMPT_TEXT",
os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
),
"cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
"inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
"do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
"denoise": os.getenv("VOXCPM_DENOISE", "true"),
"retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
"retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
"retry_badcase_ratio_threshold": os.getenv(
"VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
),
}
)
elif model_name == "melotts":
_set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0")
elif model_name == "cosyvoicetts":
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
elif model_name == "sovitstts":
_set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
_set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
_set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
_set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
_set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
_set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
params.update(
{
"text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
"prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
"text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
"batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
"media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
"streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
}
)
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
return params
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
    """Return the prompt-WAV path configured for the VoxCPM model, if any.

    Only ``voxcpmtts`` (compared case-insensitively) uses a prompt WAV; any
    other model name yields ``None``. ``CUSTOM_TTS_PROMPT_WAV`` wins over
    ``VOXCPM_PROMPT_WAV``; unset or empty variables are skipped.
    """
    if model_name.lower() == "voxcpmtts":
        for env_name in ("CUSTOM_TTS_PROMPT_WAV", "VOXCPM_PROMPT_WAV"):
            configured = os.getenv(env_name)
            if configured:
                return configured
    return None
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
    """Store ``value`` under ``key`` in ``params`` unless it is ``None`` or empty.

    ``params`` is mutated in place. When ``value`` is falsy nothing is
    written, so an existing entry for ``key`` is left untouched.
    """
    if value:
        params[key] = value
@ -891,17 +218,6 @@ def _env_int(name: str, default: int) -> int:
return default
def _env_float(name: str, default: float) -> float:
    """Read environment variable ``name`` as a float.

    Returns ``default`` when the variable is unset or empty, and also when
    its value cannot be parsed (a warning is logged in that case).
    """
    raw = os.getenv(name)
    if raw:
        try:
            return float(raw)
        except ValueError:
            logger.warning("Invalid float for %s=%r, using %s", name, raw, default)
    return default
def _env_bool(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:

292
memory.py
View File

@ -1,292 +0,0 @@
from __future__ import annotations
import asyncio
import json
import logging
import re
from typing import Any
import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
logger = logging.getLogger("memory-recall")
# Question and filler words that _query_terms() strips from a user query
# before the remainder is matched against room-graph objects and relations.
# NOTE(review): the empty-string entries look like stopwords whose characters
# were lost somewhere upstream -- confirm against the original file. As
# written they collapse into one "" member, and replacing "" with "" leaves a
# query unchanged.
_LOCATION_STOPWORDS = {
    "哪里",
    "在哪",
    "在哪里",
    "哪儿",
    "位置",
    "什么地方",
    "帮我找",
    "帮我寻找",
    "找一下",
    "",
    "请问",
    "",
    "",
    "",
}
class MemoryRecallClient:
    """HTTP client that fetches the room graph and renders it as prompt text.

    The most recent room graph that was a JSON object is cached so that a
    timeout, connection error or error status can fall back to it.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client; no network traffic happens here.

        Args:
            url: Endpoint that is queried with a GET request.
            timeout: Total request timeout in seconds.
            max_chars: Upper bound on the length of the returned memory text.
            api_key: Optional bearer token sent in the ``Authorization`` header.
            http_session: Optional session to reuse; when omitted, one is
                obtained from ``utils.http_context`` on first use.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last dict payload received; used when a later request fails.
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the configured session, creating it lazily if needed."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Fetch the room graph and format it for ``query``.

        Returns an empty string for a blank query. Timeouts, connection
        errors and non-200 responses are logged and answered from the cached
        room graph (or ``""`` when nothing has been cached yet).
        """
        query = query.strip()
        if not query:
            return ""

        headers = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        try:
            async with self._ensure_session().get(
                self._url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self._timeout),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"Memory recall error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    data = await resp.json()
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    # Wrong content type or a malformed body: fall back to the
                    # raw text instead of letting the decode error escape.
                    data = await resp.text()

            # Only a JSON object can be formatted later, so never let an
            # unusable payload replace a previously cached room graph.
            if isinstance(data, dict):
                self._cached_payload = data
            return self._format_memory(data, query)
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)

    def _format_memory(self, data: Any, query: str) -> str:
        """Format ``data`` for ``query`` and cap the result at ``max_chars``."""
        memory = _format_room_graph_memory(data, query)
        if len(memory) > self._max_chars:
            memory = memory[: self._max_chars].rstrip()
        return memory

    def _format_cached_memory(self, query: str) -> str:
        """Format the cached payload, or return ``""`` when there is none."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)
def _format_room_graph_memory(payload: Any, query: str) -> str:
    """Render a room-graph payload as the item-location prompt for ``query``.

    Returns ``""`` when ``payload`` is not a dict (a warning is logged) or
    when it has no objects, relations or summary. Otherwise the objects and
    relations relevant to the query are embedded as compact JSON; if nothing
    matched, the first 12 objects / 24 relations are used instead.
    """
    if not isinstance(payload, dict):
        logger.warning("Unsupported room graph response: %s", payload)
        return ""

    objects = payload.get("objects", [])
    relations = payload.get("relations", [])
    summary = payload.get("summary", "")
    if not (objects or relations or summary):
        return ""

    def _dump(items: Any) -> str:
        # Compact JSON that keeps non-ASCII text readable.
        return json.dumps(items, ensure_ascii=False, separators=(",", ":"))

    query_terms = _query_terms(query)
    relevant_objects, relevant_relations = _relevant_room_graph(
        objects=objects,
        relations=relations,
        query_terms=query_terms,
    )
    objects_text = _dump(relevant_objects or _compact_items(objects, limit=12))
    relations_text = _dump(relevant_relations or _compact_items(relations, limit=24))

    prompt = f"""
你是一个物品定位助手。
目标物品:{query}
相关物品:{objects_text}
相关空间关系:{relations_text}
房间概览:{summary}
回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()

    total_objects = len(objects) if isinstance(objects, list) else 0
    total_relations = len(relations) if isinstance(relations, list) else 0
    logger.info(
        "Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
        query_terms,
        len(relevant_objects),
        total_objects,
        len(relevant_relations),
        total_relations,
        len(prompt),
    )
    return prompt
def _query_terms(query: str) -> list[str]:
    """Extract up to four search terms from a location question.

    The first term is the query with whitespace, punctuation and every
    location stopword removed. It is followed by each run of two or more
    CJK/alphanumeric characters in the raw query that is neither a stopword
    nor already listed.
    """
    normalized = re.sub(r"[\s?。!,、,.!]", "", query)
    # Stopwords overlap ("在哪里" contains both "在哪" and "哪里"), so the removal
    # order changes the result. Iterating the set directly made that order
    # depend on string hashing; remove the longest words first, with ties
    # broken alphabetically, so the outcome is stable.
    for word in sorted(_LOCATION_STOPWORDS, key=lambda w: (-len(w), w)):
        if word:
            normalized = normalized.replace(word, "")
    terms = [normalized] if normalized else []
    for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
        if token not in _LOCATION_STOPWORDS and token not in terms:
            terms.append(token)
    return terms[:4]
def _relevant_room_graph(
    *,
    objects: Any,
    relations: Any,
    query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
    """Select the objects and relations that relate to ``query_terms``.

    An object matches when any term occurs in its compact JSON text. A
    relation is kept when a term occurs in its text or when it references a
    matched object id; every object referenced by a kept relation is then
    added as well. Results are compacted and capped at 16 objects and 32
    relations.

    Returns ``([], [])`` when ``objects`` or ``relations`` is not a list or
    when there are no query terms.
    """
    if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
        return [], []

    matched_ids: set[str] = set()
    matched_objects: list[Any] = []
    object_by_id: dict[str, Any] = {}

    for obj in objects:
        obj_id = _object_id(obj)
        if obj_id:
            object_by_id[obj_id] = obj

        obj_text = _compact_text(obj)
        if any(term and term in obj_text for term in query_terms):
            matched_objects.append(obj)
            if obj_id:
                matched_ids.add(obj_id)

    relevant_relations: list[Any] = []
    related_ids: set[str] = set(matched_ids)
    for relation in relations:
        relation_text = _compact_text(relation)
        relation_ids = _ids_in_value(relation)
        if (
            any(term and term in relation_text for term in query_terms)
            or bool(matched_ids.intersection(relation_ids))
        ):
            relevant_relations.append(relation)
            related_ids.update(relation_ids)

    relevant_objects = list(matched_objects)
    seen_object_keys = {_object_key(obj) for obj in relevant_objects}
    # Walk the ids in sorted order: iterating the set directly made the order
    # of the appended objects (and which ones survive the cap below) vary
    # between interpreter runs.
    for obj_id in sorted(related_ids):
        obj = object_by_id.get(obj_id)
        if obj is None:
            continue
        key = _object_key(obj)
        if key not in seen_object_keys:
            relevant_objects.append(obj)
            seen_object_keys.add(key)

    return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
def _compact_items(items: Any, *, limit: int) -> list[Any]:
if not isinstance(items, list):
return []
return [_compact_item(item) for item in items[:limit]]
def _compact_item(item: Any) -> Any:
    """Reduce a dict to its identifying and descriptive fields.

    Non-dict values are returned unchanged. For a dict, only the preferred
    keys whose value is neither ``None`` nor ``""`` are kept, in the fixed
    order below; if none of them is present the original dict is returned.
    """
    if not isinstance(item, dict):
        return item

    slim: dict[str, Any] = {}
    for field in (
        "id",
        "name",
        "label",
        "class",
        "category",
        "type",
        "text",
        "source",
        "target",
        "subject",
        "object",
        "relation",
        "predicate",
        "description",
    ):
        if field not in item:
            continue
        field_value = item[field]
        if field_value not in (None, ""):
            slim[field] = field_value

    if slim:
        return slim
    return item
def _object_id(obj: Any) -> str | None:
    """Return the first usable identifier of a room-graph object.

    The keys ``id``, ``object_id``, ``uuid``, ``name`` and ``label`` are
    tried in that order; the first value that is a non-empty string or an
    int is returned as a string. ``None`` is returned for non-dict input or
    when no such value exists.
    """
    if not isinstance(obj, dict):
        return None
    for key in ("id", "object_id", "uuid", "name", "label"):
        value = obj.get(key)
        # An empty string is not an identifier: keep looking so that e.g. an
        # object with id "" but a real name is still addressable.
        if isinstance(value, (str, int)) and value != "":
            return str(value)
    return None
def _object_key(obj: Any) -> str:
    """Return a de-duplication key: the object's id, else its compact JSON text."""
    identifier = _object_id(obj)
    if identifier:
        return identifier
    return _compact_text(obj)
def _ids_in_value(value: Any) -> set[str]:
    """Collect every object id referenced anywhere inside ``value``.

    Dicts and lists are walked recursively. Under an id-like key, a string or
    int is taken as an id directly and a nested dict contributes its own
    object id. Any other value type yields an empty set.
    """
    # NOTE(review): indentation was lost in this copy of the file; the
    # recursive call is assumed to run for every dict value, not only for
    # values under id-like keys -- confirm against the original.
    id_keys = {"id", "object_id", "source", "target", "subject", "object", "from", "to"}
    found: set[str] = set()

    if isinstance(value, list):
        for element in value:
            found |= _ids_in_value(element)
        return found
    if not isinstance(value, dict):
        return found

    for field, child in value.items():
        if field in id_keys:
            if isinstance(child, (str, int)):
                found.add(str(child))
            elif isinstance(child, dict):
                child_id = _object_id(child)
                if child_id:
                    found.add(child_id)
        found |= _ids_in_value(child)
    return found
def _compact_text(value: Any) -> str:
    """Serialize ``value`` as JSON without padding, keeping non-ASCII text as is."""
    tight_separators = (",", ":")
    return json.dumps(value, separators=tight_separators, ensure_ascii=False)

188
test_agent.py Normal file
View File

@ -0,0 +1,188 @@
import asyncio
import requests
import logging
from pathlib import Path
import uuid
import wave
import numpy as np
from datetime import datetime
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("test-agent")
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
ROOM_NAME = "test-room20"
WAV_FILE = "2food.wav"
TEST_TIMEOUT = 30
class TestState:
    """Mutable flags recording what the test client has observed so far."""

    # The class name starts with "Test", so pytest would try to collect it and
    # warn about the __init__ constructor; mark it as a helper instead.
    __test__ = False

    def __init__(self):
        self.agent_connected = False  # set once an agent participant joins
        self.tts_received = False  # set once an audio track is subscribed
        self.tts_count = 0  # number of audio tracks subscribed so far


test_state = TestState()
def get_token(agent_name="my-agent"):
    """Request a room token from the token server.

    A random ``test-xxxxxx`` identity is generated for every call. Any
    failure (network error, HTTP error status, missing ``token`` field) is
    logged and re-raised.
    """
    try:
        query = {
            "room": ROOM_NAME,
            "identity": f"test-{uuid.uuid4().hex[:6]}",
            "agent_name": agent_name,
        }
        response = requests.get(TOKEN_URL, params=query, timeout=5)
        response.raise_for_status()
        token = response.json()["token"]
    except Exception as e:
        logger.error(f"❌ 获取token失败: {e}")
        raise
    return token
async def publish_wav(room, wav_path):
    """Publish a WAV file into ``room`` as an audio track named ``mic``.

    The file is sent in 20 ms frames, sleeping one frame duration after each
    so that playback runs at roughly real-time speed.

    Raises:
        FileNotFoundError: if ``wav_path`` does not exist.
        ValueError: if the file is not 16-bit PCM.
    """
    wav_path = Path(wav_path)
    if not wav_path.exists():
        logger.error(f"❌ WAV文件不存在: {wav_path}")
        raise FileNotFoundError(f"文件不存在: {wav_path}")

    logger.info(f"📂 开始上传: {wav_path}")

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()

        logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")

        # The frame-size arithmetic below (and the int16 view the original
        # code used) is only correct for 2-byte samples; fail clearly instead
        # of publishing garbled audio.
        if sample_width != 2:
            raise ValueError(
                f"Only 16-bit PCM WAV files are supported, got {sample_width * 8}-bit"
            )

        source = AudioSource(sample_rate, num_channels)
        track = LocalAudioTrack.create_audio_track("mic", source)

        await room.local_participant.publish_track(track)
        logger.info("📡 已发布音轨")

        frame_duration = 0.02
        samples_per_frame = int(sample_rate * frame_duration)
        bytes_per_sample_frame = sample_width * num_channels

        while True:
            data = wf.readframes(samples_per_frame)
            if not data:
                break

            # The last chunk of the file may be shorter than a full frame.
            samples_per_channel = len(data) // bytes_per_sample_frame
            if samples_per_channel == 0:
                continue

            frame = AudioFrame(
                data=data,
                sample_rate=sample_rate,
                num_channels=num_channels,
                samples_per_channel=samples_per_channel,
            )

            await source.capture_frame(frame)
            await asyncio.sleep(frame_duration)

    logger.info("✅ WAV推流完成")
async def test_agent():
    """Run one end-to-end check against the agent and return True on success.

    Steps: fetch a token, join the room, wait up to 10 s for a participant
    whose identity contains "agent", stream the test WAV, then wait up to
    ``TEST_TIMEOUT`` seconds for an audio track from the agent. Any exception
    is logged and reported as failure. The room is always disconnected once
    it has been connected.
    """
    room = None
    connected = False
    try:
        logger.info("🔑 正在获取token...")
        token = get_token()
        logger.info("✅ Token获取成功")

        room = rtc.Room()

        @room.on("participant_connected")
        def on_participant_connected(participant):
            logger.info(f"✅ 参与者加入: {participant.identity}")
            if "agent" in participant.identity.lower():
                test_state.agent_connected = True
                logger.info("🎉 Agent已连接")

        @room.on("participant_disconnected")
        def on_participant_disconnected(participant):
            logger.info(f"❌ 参与者离开: {participant.identity}")

        @room.on("track_subscribed")
        def on_track_subscribed(track, publication, participant):
            if track.kind == rtc.TrackKind.KIND_AUDIO:
                test_state.tts_count += 1
                logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
                test_state.tts_received = True

        logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
        await room.connect(WS_URL, token)
        connected = True
        logger.info("✅ 已连接到房间")
        logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")

        # An agent that joined before us never triggers participant_connected,
        # so look at the participants already in the room as well.
        for participant in room.remote_participants.values():
            if "agent" in participant.identity.lower():
                test_state.agent_connected = True

        logger.info("⏳ 等待Agent连接...")
        for _ in range(10):
            if test_state.agent_connected:
                break
            await asyncio.sleep(1)

        if not test_state.agent_connected:
            logger.warning("⚠️ Agent未连接")
            return False

        logger.info("🎙️ 正在上传测试音频...")
        await publish_wav(room, WAV_FILE)

        logger.info("⏳ 等待Agent响应...")
        for i in range(TEST_TIMEOUT):
            if test_state.tts_received:
                logger.info("✅ 收到Agent TTS响应!")
                break
            if i % 5 == 0:
                logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
            await asyncio.sleep(1)

        await asyncio.sleep(2)

        logger.info("\n" + "="*60)
        logger.info("✅ 测试结果")
        logger.info("="*60)
        # Both branches of these markers were empty strings, so the summary
        # showed nothing; use the ✅/❌ markers found elsewhere in this file.
        logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
        logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
        logger.info(f"TTS音频次数: {test_state.tts_count}")
        logger.info("="*60)

        return test_state.agent_connected and test_state.tts_received

    except Exception as e:
        logger.error(f"❌ 测试失败: {e}", exc_info=True)
        return False
    finally:
        # Leave the room on every path (success, early return, exception).
        if connected:
            try:
                await room.disconnect()
                logger.info("✅ 已断开连接\n")
            except Exception:
                logger.warning("Failed to disconnect from room", exc_info=True)
async def main():
    """Run the agent check once and return a process exit code (0 = success)."""
    logger.info("🚀 开始测试custom_agent...\n")

    if not await test_agent():
        logger.error("❌ 测试失败")
        return 1

    logger.info("✅ 测试成功custom_agent 正常工作")
    logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到")
    logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
    return 0
if __name__ == "__main__":
    # exit() is a helper installed by the site module and is not guaranteed to
    # exist (e.g. under `python -S`); raising SystemExit works everywhere.
    exit_code = asyncio.run(main())
    raise SystemExit(exit_code)

55
test_asr.py Normal file
View File

@ -0,0 +1,55 @@
import asyncio
import logging
import wave
from asr import BlackboxSTT
from livekit import rtc
# 设置日志级别以查看输出
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-asr")
async def test():
    """Send one local WAV file through BlackboxSTT and print the transcript."""
    # Replace with the path of an audio file on your machine.
    audio_path = "/home/verachen/Music/voice/2food.wav"

    # Set up the ASR client.
    stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice")

    print(f"Testing ASR connectivity with file: {audio_path}")

    try:
        # Read the whole file together with its format parameters.
        with wave.open(audio_path, "rb") as wf:
            pcm_bytes = wf.readframes(wf.getnframes())
            rate = wf.getframerate()
            channels = wf.getnchannels()
            total_frames = wf.getnframes()

        # recognize() takes a list of AudioFrame objects, so wrap the whole
        # recording in a single frame.
        frame = rtc.AudioFrame(
            data=pcm_bytes,
            sample_rate=rate,
            num_channels=channels,
            samples_per_channel=total_frames,
        )

        # Run recognition.
        result = await stt.recognize(buffer=[frame])

        if not result.alternatives:
            print("ASR returned no text.")
        else:
            print("\n--- ASR Result ---")
            print(f"Text: {result.alternatives[0].text}")
            print("------------------\n")

    except FileNotFoundError:
        print(f"Error: Audio file not found at {audio_path}")
    except Exception as e:
        print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(test())

253
test_beaver_llm.py Normal file
View File

@ -0,0 +1,253 @@
import json
import aiohttp
from aiohttp import web
try:
from custom.beaver_llm import BeaverLLM, latest_user_text
except ModuleNotFoundError:
from beaver_llm import BeaverLLM, latest_user_text
from livekit.agents import ChatContext
def test_latest_user_text_uses_most_recent_user_message() -> None:
    """The newest user message wins, and list content is joined with newlines."""
    chat_ctx = ChatContext.empty()
    for role, content in (
        ("user", "first"),
        ("assistant", "ignored"),
        ("user", ["second", "line"]),
    ):
        chat_ctx.add_message(role=role, content=content)

    assert latest_user_text(chat_ctx) == "second\nline"
async def test_beaver_llm_can_connect_before_first_message(
    unused_tcp_port: int,
) -> None:
    """Check that ``connect()`` performs the handshake before any chat turn.

    A local aiohttp websocket server plays the Beaver side: it answers a
    ``connect`` frame with ``connected`` and every ``message`` frame with an
    ``ack`` followed by an assistant reply. ``unused_tcp_port`` is the port
    that server binds on 127.0.0.1.
    """
    # Every frame the fake server receives, in arrival order.
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:livekit-room",
                    }
                )
            elif frame["type"] == "message":
                # Acknowledge first, then send the assistant reply.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:livekit-room",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "beaver reply",
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )
    ctx = ChatContext.empty()
    ctx.add_message(role="user", content="hello beaver")

    try:
        await beaver_llm.connect()
        # After connect() only the handshake frame may have been sent.
        assert beaver_llm.session_id == "terminal-dev:local:livekit-room"
        assert received == [
            {
                "type": "connect",
                "peer_id": "livekit-room",
                "device_name": "livekit-custom-agent",
                "capabilities": ["text"],
            }
        ]

        collected = await beaver_llm.chat(chat_ctx=ctx).collect()
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert collected.text == "beaver reply"
    assert received[1]["type"] == "message"
    assert received[1]["text"] == "hello beaver"
async def test_beaver_llm_connect_can_send_warmup_message(
    unused_tcp_port: int,
) -> None:
    """Check that ``connect(warmup_text=...)`` sends a turn and returns its reply.

    The fake server answers ``connect`` with ``connected`` and every
    ``message`` frame with an ``ack`` followed by the assistant text "ready".
    """
    # Every frame the fake server receives, in arrival order.
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:livekit-room",
                    }
                )
            elif frame["type"] == "message":
                # Acknowledge first, then send the assistant reply.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:livekit-room",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-warmup",
                        "text": "ready",
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )

    try:
        warmup_reply = await beaver_llm.connect(warmup_text="初始化连接")
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert warmup_reply == "ready"
    # Handshake frame first, warm-up message second.
    assert received[0] == {
        "type": "connect",
        "peer_id": "livekit-room",
        "device_name": "livekit-custom-agent",
        "capabilities": ["text"],
    }
    assert received[1]["type"] == "message"
    assert received[1]["text"] == "初始化连接"
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
    unused_tcp_port: int,
) -> None:
    """Check that ``chat()`` without a prior ``connect()`` still works.

    The chat context holds a system message and a user message; the test
    expects the handshake frame first, then a message frame carrying only the
    user text, with a message id of the form ``livekit-room-...-000001``.
    """
    # Every frame the fake server receives, in arrival order.
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:livekit-room",
                    }
                )
            elif frame["type"] == "message":
                # Acknowledge first, then send the assistant reply.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:livekit-room",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "beaver reply",
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )
    ctx = ChatContext.empty()
    ctx.add_message(role="system", content="ignored instructions")
    ctx.add_message(role="user", content="hello beaver")

    try:
        collected = await beaver_llm.chat(chat_ctx=ctx).collect()
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert collected.text == "beaver reply"
    assert received[0] == {
        "type": "connect",
        "peer_id": "livekit-room",
        "device_name": "livekit-custom-agent",
        "capabilities": ["text"],
    }
    assert received[1]["type"] == "message"
    assert received[1]["message_id"].startswith("livekit-room-")
    assert received[1]["message_id"].endswith("-000001")
    assert received[1]["text"] == "hello beaver"

View File

@ -0,0 +1,426 @@
import asyncio
import json
import sys
from pathlib import Path
import aiohttp
import pytest
from aiohttp import web
if __name__ == "__main__":
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
raise SystemExit(pytest.main([__file__]))
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
def test_build_connect_frame_uses_stable_peer_id() -> None:
    """The connect frame carries the peer id, device name and text capability."""
    expected = {
        "type": "connect",
        "peer_id": "device-001",
        "device_name": "desk-terminal",
        "capabilities": ["text"],
    }

    assert build_connect_frame(peer_id="device-001", device_name="desk-terminal") == expected
def test_build_message_frame_uses_message_id_and_text() -> None:
    """The message frame carries exactly the given message id and text."""
    expected = {
        "type": "message",
        "message_id": "device-001-000001",
        "text": "hello",
    }

    assert build_message_frame(message_id="device-001-000001", text="hello") == expected
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
    """Ids continue from ``initial_counter`` and ``counter`` tracks the last one."""
    ids = MessageIdGenerator(peer_id="device-001", initial_counter=7)

    for expected_id in ("device-001-000008", "device-001-000009"):
        assert ids.next_id() == expected_id
    assert ids.counter == 9
def test_message_id_generator_can_include_instance_id() -> None:
    """With an instance id, ids have the form ``<peer>-<instance>-<counter>``."""
    ids = MessageIdGenerator(peer_id="device-001", instance_id="abc123ef")

    for expected_id in ("device-001-abc123ef-000001", "device-001-abc123ef-000002"):
        assert ids.next_id() == expected_id
async def test_client_connects_sends_text_and_returns_assistant_reply(
    unused_tcp_port: int,
) -> None:
    """Happy path: handshake, one text turn, assistant reply returned.

    The fake server answers ``connect`` with ``connected`` and a ``message``
    frame with an ``ack`` followed by the assistant reply. The test also pins
    the exact frames the client sent.
    """
    # Every frame the fake server receives, in arrival order.
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                # Acknowledge first, then send the assistant reply.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "assistant reply",
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    assert client.session_id == "terminal-dev:local:device-001"
    assert reply == "assistant reply"
    assert received == [
        {
            "type": "connect",
            "peer_id": "device-001",
            "device_name": "desk-terminal",
            "capabilities": ["text"],
        },
        {
            "type": "message",
            "message_id": "device-001-000001",
            "text": "hello",
        },
    ]
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
    """A duplicate ``ack`` that carries a ``reply`` is returned as the answer.

    The fake server never sends an assistant message: it answers the message
    frame with a single ack marked ``duplicate`` and not ``pending`` that
    includes the cached reply text.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": False,
                        "duplicate": True,
                        "pending": False,
                        "reply": "cached assistant reply",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    assert reply == "cached assistant reply"
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
    """An ``error`` frame makes ``send_text`` raise ``BeaverTerminalError``.

    The fake server answers the message frame with an error frame; the
    raised exception must mention the server's error text.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json({"type": "error", "error": "text is required"})

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        with pytest.raises(BeaverTerminalError, match="text is required"):
            await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
    unused_tcp_port: int,
) -> None:
    """An assistant message with ``finish_reason == "error"`` must raise.

    The fake server acks the message and then sends an assistant message
    whose finish reason is ``error``; ``send_text`` is expected to raise
    ``BeaverTerminalError`` mentioning that message's text.
    """

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                # Acknowledge first, then report the failed turn.
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-1",
                        "text": "failed turn",
                        "finish_reason": "error",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        with pytest.raises(BeaverTerminalError, match="failed turn"):
            await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
    """``ping()`` returns a truthy value when the server answers with ``pong``."""

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "ping":
                await ws.send_json({"type": "pong"})

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        assert await client.ping()
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
    unused_tcp_port: int,
) -> None:
    """The client reconnects with the same peer id after the server drops it.

    The fake server closes the first connection as soon as it receives a
    message frame; on the second connection it acks and replies. The test
    expects two handshakes with the same peer id and two message ids.
    """
    connect_peer_ids: list[str] = []  # peer_id of every connect frame seen
    message_ids: list[str] = []  # message_id of every message frame seen
    connection_count = 0

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal connection_count
        connection_count += 1
        # Remember which connection this handler instance serves.
        current_connection = connection_count
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            frame = json.loads(message.data)
            if frame["type"] == "connect":
                connect_peer_ids.append(frame["peer_id"])
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": "terminal-dev:local:device-001",
                    }
                )
            elif frame["type"] == "message":
                message_ids.append(frame["message_id"])
                if current_connection == 1:
                    # First connection: drop the socket without answering.
                    await ws.close()
                    continue
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": "terminal-dev:local:device-001",
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": "run-2",
                        "text": "reply after reconnect",
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
    await site.start()

    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="device-001",
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id="device-001"),
    )

    try:
        await client.connect()
        # NOTE(review): presumably gives the event loop a moment to settle
        # after the handshake before sending -- confirm why this is needed.
        await asyncio.sleep(0.01)
        reply = await client.send_text("hello")
    finally:
        await client.close()
        await runner.cleanup()

    assert reply == "reply after reconnect"
    assert connect_peer_ids == ["device-001", "device-001"]
    # The retried send used a fresh message id rather than reusing the first.
    assert message_ids == ["device-001-000001", "device-001-000002"]

130
test_livekit.py Normal file
View File

@ -0,0 +1,130 @@
import asyncio
import requests
from livekit import rtc
import wave
import numpy as np
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" # 你的 LiveKit Server 地址
ROOM_NAME = "test-room20"
import uuid
IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
# IDENTITY = "test-user0"
def get_token():
    """Request a room token for ``IDENTITY`` from the token server.

    Raises:
        requests.RequestException: on network failure, timeout or an HTTP
            error status.
        KeyError: if the response JSON has no ``token`` field.
    """
    resp = requests.get(
        TOKEN_URL,
        params={
            "room": ROOM_NAME,
            "identity": IDENTITY,
            "agent_name": "my-agent",  # crucial
        },
        # Without a timeout an unreachable token server blocks forever.
        timeout=5,
    )
    # Surface HTTP errors here instead of failing later on a missing key.
    resp.raise_for_status()
    data = resp.json()
    return data["token"]
async def main():
    """Join the room, send a chat data packet, stream the test WAV, then leave.

    All event handlers are registered before connecting so that nothing
    emitted during or right after the connection is missed, and the room is
    disconnected even when a later step fails.
    """
    token = get_token()

    room = rtc.Room()

    @room.on("participant_connected")
    def on_participant_connected(participant):
        print(f"✅ 有人加入房间: {participant.identity}")

    @room.on("participant_disconnected")
    def on_participant_disconnected(participant):
        print(f"❌ 有人离开房间: {participant.identity}")

    @room.on("data_received")
    def on_data_received(data, participant, kind, topic):
        try:
            msg = data.decode()
            print(f"📩 来自 {participant.identity}: {msg}")
        except Exception:
            # Still best-effort, but no longer swallows KeyboardInterrupt or
            # SystemExit the way a bare "except:" did.
            print("📩 收到二进制数据")

    @room.on("track_subscribed")
    def on_track_subscribed(track, publication, participant):
        print(f"🎧 订阅轨道: {participant.identity}")

        if track.kind == rtc.TrackKind.KIND_AUDIO:
            print("👉 TTS 音频来了")

    print("🔌 正在连接房间...")
    await room.connect(WS_URL, token)

    try:
        print("✅ 已连接房间:", ROOM_NAME)
        print("当前房间成员:")
        for p in room.remote_participants.values():
            print(" -", p.identity)

        # Wait briefly so the connection is stable.
        await asyncio.sleep(1)
        await room.local_participant.publish_data(
            b"hello",
            reliable=True,
            topic="chat"
        )

        # Upload the WAV file.
        await publish_wav(room, "2food.wav")
    finally:
        await room.disconnect()
async def publish_wav(room, wav_path):
    """Publish the WAV file at ``wav_path`` into ``room`` as track ``mic``.

    Audio is sent in 20 ms frames, sleeping one frame duration after each so
    that playback runs at roughly real-time speed. The sample count is
    derived by viewing the bytes as int16, i.e. the file is assumed to be
    16-bit PCM.
    """
    print("🎵 开始上传本地 wav:", wav_path)

    # Context manager: the file used to stay open after streaming finished
    # (or failed).
    with wave.open(wav_path, "rb") as wf:
        sample_rate = wf.getframerate()
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()

        print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")

        # Create the audio source.
        source = AudioSource(sample_rate, num_channels)

        # Create the local audio track.
        track = LocalAudioTrack.create_audio_track("mic", source)

        # Publish the track.
        await room.local_participant.publish_track(track)

        print("📡 已发布音轨")

        frame_duration = 0.02  # 20ms
        samples_per_frame = int(sample_rate * frame_duration)

        while True:
            data = wf.readframes(samples_per_frame)
            if not data:
                break

            # Only used to work out how many samples this chunk holds.
            audio = np.frombuffer(data, dtype=np.int16)

            if len(audio) == 0:
                continue

            samples_per_channel = len(audio) // num_channels

            frame = AudioFrame(
                data=data,  # crucial: pass the raw bytes
                sample_rate=sample_rate,
                num_channels=num_channels,
                samples_per_channel=samples_per_channel,
            )

            await source.capture_frame(frame)
            await asyncio.sleep(frame_duration)

    print("✅ wav 推流结束")
if __name__ == "__main__":
asyncio.run(main())

71
test_minimax.py Normal file
View File

@ -0,0 +1,71 @@
import asyncio
import os
import logging
from dotenv import load_dotenv
from livekit.agents.llm import ChatContext
from livekit.plugins import openai
# Module logger for this test script; the root logger is configured to emit
# INFO-level records and above so progress and failures are visible.
logger = logging.getLogger("test-minimax")
logging.basicConfig(level=logging.INFO)
async def test_minimax():
    """Smoke-test a streaming chat completion against the MiniMax endpoint.

    Reads ``MINIMAX_LLM_BASE_URL``, ``MINIMAX_LLM_MODEL`` and
    ``MINIMAX_API_KEY`` from the environment (after loading ``.env``),
    sends a single user message through ``openai.LLM`` and prints the
    streamed reply to stdout.

    Any exception raised while streaming is logged (with traceback) rather
    than propagated, so the coroutine returns ``None`` in both the success
    and the failure case.
    """
    print("Loading .env...")
    load_dotenv()

    # Configuration from environment or defaults from custom_agent.py
    base_url = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
    model = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
    # SECURITY(review): this fallback embeds a credential in source control.
    # It is kept only so existing runs without MINIMAX_API_KEY keep working;
    # the key should be rotated and the default removed.
    api_key = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")

    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    print(f"Connecting to Minimax at {base_url} using model {model}")

    # SECURITY(review): TLS certificate verification is disabled here, as in
    # the original test. Acceptable only against a trusted internal endpoint.
    # `async with` closes the connection pool on every exit path, including
    # cancellation by the caller's timeout.
    async with httpx.AsyncClient(verify=False) as http_client:
        # Create the OpenAI AsyncClient with the custom HTTP client
        openai_client = OpenAIAsyncClient(
            api_key=api_key,
            base_url=base_url,
            http_client=http_client,
        )

        llm = openai.LLM(
            model=model,
            client=openai_client,
        )

        print("Creating ChatContext...")
        chat_ctx = ChatContext()
        chat_ctx.add_message(
            content="Hello! Can you introduce yourself? Please reply in Chinese.",
            role="user",
        )

        print("\n--- Testing Streaming Chat ---")
        print(f"Request: {chat_ctx.items[-1].content}")
        print("Response: ", end="", flush=True)

        try:
            print("\nCalling llm.chat()...")
            stream = llm.chat(chat_ctx=chat_ctx)
            print("Iterating over stream...")
            async for chunk in stream:
                # Chunks without a text delta are skipped.
                if chunk.delta and chunk.delta.content:
                    print(chunk.delta.content, end="", flush=True)
            print("\n--- Test Completed Successfully ---")
        except Exception as e:
            # Best-effort smoke test: report the failure with its traceback
            # instead of crashing the script.
            logger.exception("\nTest failed with error: %s", e)
if __name__ == "__main__":
    print("Starting...")
    # Wrap the test in a 30-second deadline; nothing runs until asyncio.run()
    # starts the event loop below.
    bounded_test = asyncio.wait_for(test_minimax(), timeout=30)
    try:
        asyncio.run(bounded_test)
    except asyncio.TimeoutError:
        print("\nTest timed out after 30 seconds.")
    except Exception as exc:
        print(f"\nAn error occurred: {exc}")

66
test_voxcpm.py Normal file
View File

@ -0,0 +1,66 @@
import asyncio
import logging
import os
from tts import BlackboxTTS
# Configure the root logger to emit INFO-level records and above.
logging.basicConfig(level=logging.INFO)
async def test_tts():
    """Smoke-test ``BlackboxTTS`` with the VoxCPM model in non-streaming mode.

    Synthesizes one short Chinese sentence, prints the duration and
    per-channel sample count of the collected audio, and writes it to
    ``test_output.wav`` in the current directory.

    Any exception raised during synthesis or while writing the file is
    printed rather than propagated, so the coroutine always returns ``None``.
    """
    # Use the URL from the user's curl command
    url = "http://10.6.80.21:5002/tts-blackbox"

    # Reference voice sample sent as the prompt wav.
    prompt_wav = "/home/verachen/Music/voice/2food.wav"
    if not os.path.exists(prompt_wav):
        # The original "fallback" reassigned the very same path, so a missing
        # file went unnoticed until the request failed. Make it visible.
        print(f"Warning: prompt wav not found: {prompt_wav}")

    print(f"Testing BlackboxTTS with URL: {url}")
    print(f"Using prompt wav: {prompt_wav}")

    blackbox_tts = BlackboxTTS(
        url=url,
        model_name="voxcpmtts",
        prompt_wav_path=prompt_wav,
        params={
            "streaming": "false",
            "prompt_text": "澳门有乜嘢好食嘅",
            "cfg_value": "2.0",
            "inference_timesteps": "10",
            "do_normalize": "true",
            "denoise": "true",
            "retry_badcase": "true",
            "retry_badcase_max_times": "3",
            "retry_badcase_ratio_threshold": "6.0",
        },
    )

    text = "你好,这是一段测试文本"
    print(f"Synthesizing text: {text}")

    try:
        stream = blackbox_tts.synthesize(text)
        audio_frame = await stream.collect()
        print("Successfully synthesized audio!")

        # Duration is samples-per-channel divided by the sample rate; the
        # previous expression multiplied by the rate and was not a duration.
        samples = audio_frame.samples_per_channel
        duration_s = samples / audio_frame.sample_rate if audio_frame.sample_rate else 0.0
        print(f"Audio duration: {duration_s:.3f}s")
        print(f"Samples per channel: {samples}")

        # Save to file for manual check; to_wav_bytes() is used so the
        # output carries a WAV header instead of being raw PCM.
        with open("test_output.wav", "wb") as f:
            f.write(audio_frame.to_wav_bytes())
        print("Saved output to test_output.wav")

    except Exception as e:
        # Best-effort smoke test: report and return instead of crashing.
        print(f"TTS test failed: {e}")
if __name__ == "__main__":
    # Script entry point: create the TTS smoke-test coroutine and run it to
    # completion on a new event loop.
    tts_check = test_tts()
    asyncio.run(tts_check)

24
tts.py
View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import logging
import os
import time
import wave
from collections.abc import Mapping
from io import BytesIO
@ -89,7 +88,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
self._tts: BlackboxTTS = tts
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
started_at = time.perf_counter()
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("text", self.input_text)
form.add_field("model_name", self._tts._model_name)
@ -133,9 +131,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
content_type = resp.headers.get("Content-Type", "audio/wav")
logged_wav_format = False
wav_header_probe = bytearray()
first_audio_at: float | None = None
chunk_count = 0
total_bytes = 0
output_emitter.initialize(
request_id=utils.shortuuid(),
sample_rate=self._tts.sample_rate,
@ -145,16 +140,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
async for data, _ in resp.content.iter_chunks():
if data:
chunk_count += 1
total_bytes += len(data)
if first_audio_at is None:
first_audio_at = time.perf_counter()
logger.info(
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
first_audio_at - started_at,
len(self.input_text),
len(data),
)
if not logged_wav_format:
wav_header_probe.extend(data)
logged_wav_format = _log_wav_format(
@ -171,15 +156,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
logged_wav_format = True
output_emitter.push(data)
output_emitter.flush()
finished_at = time.perf_counter()
logger.info(
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
finished_at - started_at,
(first_audio_at - started_at) if first_audio_at else -1.0,
chunk_count,
total_bytes,
len(self.input_text),
)
except asyncio.TimeoutError as e:
raise APITimeoutError("TTS blackbox request timed out") from e
except aiohttp.ClientError as e: