Compare commits
1 Commits
7efd9eba98
...
test
| Author | SHA1 | Date | |
|---|---|---|---|
| f368e156f0 |
74
.env.example
74
.env.example
@ -1,74 +0,0 @@
|
|||||||
# LiveKit connection
|
|
||||||
LIVEKIT_URL=ws://localhost:7880
|
|
||||||
LIVEKIT_API_KEY=
|
|
||||||
LIVEKIT_API_SECRET=
|
|
||||||
CUSTOM_AGENT_NAME=my-agent
|
|
||||||
|
|
||||||
# ASR blackbox
|
|
||||||
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
|
|
||||||
CUSTOM_ASR_MODEL=qwen
|
|
||||||
CUSTOM_ASR_LANGUAGE=Chinese
|
|
||||||
CUSTOM_ASR_OUTPUT_LANGUAGE=zh
|
|
||||||
CUSTOM_ASR_HOTWORDS=
|
|
||||||
CUSTOM_ASR_ITN=
|
|
||||||
CUSTOM_ASR_CHUNK_MODE=
|
|
||||||
|
|
||||||
# OpenAI-compatible LLM
|
|
||||||
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
|
|
||||||
# CUSTOM_LLM_MODEL=Qwen3.6-35B
|
|
||||||
# CUSTOM_LLM_API_KEY=
|
|
||||||
# CUSTOM_LLM_VERIFY_SSL=false
|
|
||||||
|
|
||||||
CUSTOM_LLM_BASE_URL=http://localhost/v1
|
|
||||||
CUSTOM_LLM_MODEL=Qwen-VL
|
|
||||||
CUSTOM_LLM_API_KEY=
|
|
||||||
CUSTOM_LLM_VERIFY_SSL=false
|
|
||||||
CUSTOM_SAVE_MODEL_IMAGES=false
|
|
||||||
|
|
||||||
# CUSTOM_TEXT_LLM_MODEL=
|
|
||||||
# CUSTOM_VISION_LLM_MODEL=
|
|
||||||
|
|
||||||
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
|
|
||||||
# CUSTOM_LLM_MODEL=deepseek-v4-flash
|
|
||||||
# CUSTOM_LLM_API_KEY=
|
|
||||||
# CUSTOM_LLM_VERIFY_SSL=false
|
|
||||||
|
|
||||||
|
|
||||||
# TTS blackbox
|
|
||||||
CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
|
|
||||||
CUSTOM_TTS_MODEL=voxcpmtts
|
|
||||||
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
|
|
||||||
CUSTOM_TTS_STREAMING=true
|
|
||||||
# CUSTOM_TTS_PROMPT_TEXT=澳门有乜嘢好食嘅
|
|
||||||
|
|
||||||
# VoxCPM TTS parameters
|
|
||||||
VOXCPM_CFG_VALUE=2.0
|
|
||||||
VOXCPM_INFERENCE_TIMESTEPS=10
|
|
||||||
VOXCPM_DO_NORMALIZE=true
|
|
||||||
VOXCPM_DENOISE=true
|
|
||||||
VOXCPM_RETRY_BADCASE=true
|
|
||||||
VOXCPM_RETRY_BADCASE_MAX_TIMES=3
|
|
||||||
VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD=6.0
|
|
||||||
|
|
||||||
# MeloTTS parameters
|
|
||||||
CUSTOM_TTS_SPEED=1.0
|
|
||||||
|
|
||||||
# CosyVoice parameters
|
|
||||||
CUSTOM_TTS_SPK_ID=
|
|
||||||
CUSTOM_TTS_MODE=
|
|
||||||
CUSTOM_TTS_INSTRUCT_TEXT=
|
|
||||||
|
|
||||||
# GPT-SoVITS parameters
|
|
||||||
CUSTOM_TTS_TEXT_LANG=zh
|
|
||||||
CUSTOM_TTS_PROMPT_LANG=zh
|
|
||||||
CUSTOM_TTS_TEXT_SPLIT_METHOD=cut0
|
|
||||||
CUSTOM_TTS_BATCH_SIZE=1
|
|
||||||
CUSTOM_TTS_MEDIA_TYPE=wav
|
|
||||||
CUSTOM_TTS_REF_AUDIO_PATH=
|
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
|
|
||||||
CUSTOM_MEMORY_TIMEOUT=2
|
|
||||||
CUSTOM_MEMORY_MAX_CHARS=2000
|
|
||||||
CUSTOM_MEMORY_API_KEY=
|
|
||||||
CUSTOM_PREEMPTIVE_GENERATION=true
|
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
__pycache__/
|
|
||||||
.env
|
|
||||||
model_images/
|
|
||||||
228
beaver_terminal_client.py
Normal file
228
beaver_terminal_client.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
logger = logging.getLogger("beaver-terminal-client")
|
||||||
|
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
|
||||||
|
DEFAULT_TERMINAL_PEER_ID = "device-001"
|
||||||
|
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
|
||||||
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageIdGenerator:
|
||||||
|
peer_id: str
|
||||||
|
initial_counter: int = 0
|
||||||
|
instance_id: str | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.counter = self.initial_counter
|
||||||
|
|
||||||
|
def next_id(self) -> str:
|
||||||
|
self.counter += 1
|
||||||
|
if self.instance_id:
|
||||||
|
return f"{self.peer_id}-{self.instance_id}-{self.counter:06d}"
|
||||||
|
return f"{self.peer_id}-{self.counter:06d}"
|
||||||
|
|
||||||
|
|
||||||
|
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": peer_id,
|
||||||
|
"device_name": device_name,
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "message",
|
||||||
|
"message_id": message_id,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
peer_id: str,
|
||||||
|
device_name: str,
|
||||||
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
|
message_ids: MessageIdGenerator | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._url = url
|
||||||
|
self._peer_id = peer_id
|
||||||
|
self._device_name = device_name
|
||||||
|
self._owned_session = http_session is None
|
||||||
|
self._http_session = http_session
|
||||||
|
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||||
|
self._message_ids = message_ids or MessageIdGenerator(
|
||||||
|
peer_id=peer_id,
|
||||||
|
instance_id=uuid4().hex[:8],
|
||||||
|
)
|
||||||
|
self.session_id: str | None = None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
session = self._ensure_http_session()
|
||||||
|
self._ws = await session.ws_connect(self._url)
|
||||||
|
await self._send_json(
|
||||||
|
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||||
|
)
|
||||||
|
frame = await self._receive_json()
|
||||||
|
if frame.get("type") != "connected":
|
||||||
|
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||||
|
session_id = frame.get("session_id")
|
||||||
|
self.session_id = session_id if isinstance(session_id, str) else None
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> str:
|
||||||
|
for attempt in range(2):
|
||||||
|
if not self._websocket_is_open():
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
message_id = self._message_ids.next_id()
|
||||||
|
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_json(message_frame)
|
||||||
|
return await self._wait_for_reply(message_id)
|
||||||
|
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||||
|
if attempt == 1:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
"Beaver websocket closed before assistant reply"
|
||||||
|
) from exc
|
||||||
|
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
raise BeaverTerminalError("unreachable Beaver send state")
|
||||||
|
|
||||||
|
async def _wait_for_reply(self, message_id: str) -> str:
|
||||||
|
while True:
|
||||||
|
frame = await self._receive_json()
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||||
|
reply = frame.get("reply")
|
||||||
|
if isinstance(reply, str):
|
||||||
|
return reply
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (
|
||||||
|
frame_type == "message"
|
||||||
|
and frame.get("role") == "assistant"
|
||||||
|
and frame.get("message_id") == message_id
|
||||||
|
):
|
||||||
|
text = frame.get("text")
|
||||||
|
if frame.get("finish_reason") == "error":
|
||||||
|
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
|
||||||
|
return text if isinstance(text, str) else ""
|
||||||
|
|
||||||
|
if frame_type == "error":
|
||||||
|
error = frame.get("error")
|
||||||
|
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||||
|
|
||||||
|
async def ping(self) -> bool:
|
||||||
|
await self._send_json({"type": "ping"})
|
||||||
|
while True:
|
||||||
|
frame = await self._receive_json()
|
||||||
|
if frame.get("type") == "pong":
|
||||||
|
return True
|
||||||
|
if frame.get("type") == "error":
|
||||||
|
error = frame.get("error")
|
||||||
|
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
if self._owned_session and self._http_session is not None:
|
||||||
|
await self._http_session.close()
|
||||||
|
self._http_session = None
|
||||||
|
|
||||||
|
async def _close_websocket(self) -> None:
|
||||||
|
if self._ws is not None:
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
|
||||||
|
def _websocket_is_open(self) -> bool:
|
||||||
|
return self._ws is not None and not self._ws.closed
|
||||||
|
|
||||||
|
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._http_session is None:
|
||||||
|
self._http_session = aiohttp.ClientSession()
|
||||||
|
return self._http_session
|
||||||
|
|
||||||
|
async def _send_json(self, frame: dict[str, Any]) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||||
|
await self._ws.send_json(frame)
|
||||||
|
|
||||||
|
async def _receive_json(self) -> dict[str, Any]:
|
||||||
|
if self._ws is None:
|
||||||
|
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||||
|
|
||||||
|
message = await self._ws.receive()
|
||||||
|
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
|
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||||
|
if message.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||||
|
)
|
||||||
|
if message.type != aiohttp.WSMsgType.TEXT:
|
||||||
|
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||||
|
data = message.json()
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def client_from_env() -> BeaverTerminalClient:
|
||||||
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
|
return BeaverTerminalClient(
|
||||||
|
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
|
||||||
|
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
|
||||||
|
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_console() -> None:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
client = client_from_env()
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
logger.info("Connected to Beaver session_id=%s", client.session_id)
|
||||||
|
while True:
|
||||||
|
text = await asyncio.to_thread(input, "> ")
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
if text in {"quit", "exit"}:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
reply = await client.send_text(text)
|
||||||
|
except BeaverTerminalError as exc:
|
||||||
|
logger.error("Beaver turn failed: %s", exc)
|
||||||
|
continue
|
||||||
|
print(reply)
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(run_console())
|
||||||
796
custom_agent.py
796
custom_agent.py
@ -1,40 +1,28 @@
|
|||||||
import base64
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from collections.abc import AsyncIterable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from memory import MemoryRecallClient
|
|
||||||
from tts import BlackboxTTS
|
|
||||||
|
|
||||||
from asr import BlackboxSTT
|
from asr import BlackboxSTT
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
AgentServer,
|
AgentServer,
|
||||||
AgentSession,
|
AgentSession,
|
||||||
ChatContext,
|
|
||||||
ChatMessage,
|
|
||||||
FlushSentinel,
|
|
||||||
JobContext,
|
JobContext,
|
||||||
JobProcess,
|
JobProcess,
|
||||||
MetricsCollectedEvent,
|
MetricsCollectedEvent,
|
||||||
ModelSettings,
|
|
||||||
RecordingOptions,
|
RecordingOptions,
|
||||||
TurnHandlingOptions,
|
TurnHandlingOptions,
|
||||||
cli,
|
cli,
|
||||||
llm,
|
|
||||||
metrics,
|
metrics,
|
||||||
room_io,
|
room_io,
|
||||||
stt,
|
stt,
|
||||||
)
|
)
|
||||||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
|
||||||
from livekit.plugins import openai, silero
|
from livekit.plugins import openai, silero
|
||||||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||||
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
logger = logging.getLogger("custom-agent")
|
logger = logging.getLogger("custom-agent")
|
||||||
|
|
||||||
@ -42,576 +30,19 @@ CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
|||||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||||
|
|
||||||
ROOM_LOCATOR_INSTRUCTIONS = """
|
|
||||||
你是一个房间物品定位助手。
|
|
||||||
当用户询问房间内某个物品的位置时:
|
|
||||||
- 只用一句中文回答
|
|
||||||
- 描述目标物品和其他物品的相对位置关系
|
|
||||||
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
|
|
||||||
- 不要解释推理过程
|
|
||||||
如果用户的问题与房间物品定位无关,则正常回答用户问题。
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
GENERAL_INSTRUCTIONS = """
|
|
||||||
你是一个智能语音助手。
|
|
||||||
正常回答用户问题。
|
|
||||||
回答自然、简洁、准确。
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
EMOTION_INSTRUCTIONS = """
|
|
||||||
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
|
||||||
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
|
|
||||||
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
ROOM_LOCATOR_MODE = "room_locator"
|
|
||||||
GENERAL_MODE = "general"
|
|
||||||
VOICE_INPUT_MODE = "voice"
|
|
||||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
|
||||||
AUTO_INPUT_MODE = "auto"
|
|
||||||
VISION_FRAME_TOPIC = "vision.frame"
|
|
||||||
|
|
||||||
DEFAULT_EMOTION = "neutral"
|
|
||||||
EMOTION_LABELS = {
|
|
||||||
"neutral",
|
|
||||||
"happy",
|
|
||||||
"sad",
|
|
||||||
"angry",
|
|
||||||
"surprised",
|
|
||||||
"fearful",
|
|
||||||
"calm",
|
|
||||||
"concerned",
|
|
||||||
}
|
|
||||||
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
|
||||||
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
|
||||||
TTS_EMOTION_LINE_RE = re.compile(
|
|
||||||
r"^\s*(?:emotion|情绪)\s*[::=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!!\s-]*",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
MAX_EMOTION_PREFIX_CHARS = 80
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VisionFrame:
|
|
||||||
image_data_url: str
|
|
||||||
received_at: float
|
|
||||||
mime_type: str
|
|
||||||
saved_path: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class VisionFrameStore:
|
|
||||||
def __init__(self, *, max_age_seconds: float) -> None:
|
|
||||||
self._max_age_seconds = max_age_seconds
|
|
||||||
self._latest_frame: VisionFrame | None = None
|
|
||||||
|
|
||||||
def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
|
|
||||||
self._latest_frame = VisionFrame(
|
|
||||||
image_data_url=f"data:{mime_type};base64,{image}",
|
|
||||||
received_at=time.monotonic(),
|
|
||||||
mime_type=mime_type,
|
|
||||||
saved_path=saved_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
def consume_fresh(self) -> VisionFrame | None:
|
|
||||||
frame = self._latest_frame
|
|
||||||
if frame is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
age = time.monotonic() - frame.received_at
|
|
||||||
self._latest_frame = None
|
|
||||||
if age > self._max_age_seconds:
|
|
||||||
logger.info("Dropping stale vision frame: age=%.3fs", age)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return frame
|
|
||||||
|
|
||||||
|
|
||||||
class CustomAgent(Agent):
|
class CustomAgent(Agent):
|
||||||
def __init__(
|
def __init__(self) -> None:
|
||||||
self,
|
super().__init__(
|
||||||
*,
|
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
|
||||||
memory_client: MemoryRecallClient | None = None,
|
"Keep your responses concise and friendly."
|
||||||
vision_store: VisionFrameStore | None = None,
|
"You are interacting with the user via a local ASR and LLM pipeline.",
|
||||||
input_mode: str = AUTO_INPUT_MODE,
|
)
|
||||||
text_llm: llm.LLM | None = None,
|
|
||||||
vision_llm: llm.LLM | None = None,
|
|
||||||
model_image_save_dir: Path | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
|
|
||||||
self._memory_client = memory_client
|
|
||||||
self._vision_store = vision_store
|
|
||||||
self._input_mode = input_mode
|
|
||||||
self._text_llm = text_llm
|
|
||||||
self._vision_llm = vision_llm
|
|
||||||
self._model_image_save_dir = model_image_save_dir
|
|
||||||
self.current_emotion = DEFAULT_EMOTION
|
|
||||||
self._emotion_prefix_buffer = ""
|
|
||||||
self._emotion_prefix_done = True
|
|
||||||
|
|
||||||
async def on_enter(self) -> None:
|
async def on_enter(self) -> None:
|
||||||
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def llm_node(
|
|
||||||
self,
|
|
||||||
chat_ctx: ChatContext,
|
|
||||||
tools: list[llm.Tool],
|
|
||||||
model_settings: ModelSettings,
|
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
|
||||||
llm_node_started_at = time.perf_counter()
|
|
||||||
|
|
||||||
user_query = _latest_user_text(chat_ctx)
|
|
||||||
mode = _select_mode(user_query)
|
|
||||||
vision_frame = self._consume_vision_frame()
|
|
||||||
logger.info(
|
|
||||||
"Selected agent mode: %s input_mode=%s has_image=%s",
|
|
||||||
mode,
|
|
||||||
self._input_mode,
|
|
||||||
vision_frame is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_ctx = chat_ctx.copy()
|
|
||||||
update_chat_instructions(
|
|
||||||
chat_ctx,
|
|
||||||
instructions=_with_emotion_instructions(
|
|
||||||
ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
|
|
||||||
),
|
|
||||||
add_if_missing=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if mode == ROOM_LOCATOR_MODE:
|
|
||||||
memory_context = await self._recall_room_memory(chat_ctx)
|
|
||||||
if memory_context:
|
|
||||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
|
||||||
|
|
||||||
if vision_frame is not None:
|
|
||||||
self._save_model_vision_frame(vision_frame)
|
|
||||||
chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)
|
|
||||||
|
|
||||||
llm_result = self._run_selected_llm(
|
|
||||||
chat_ctx,
|
|
||||||
tools,
|
|
||||||
model_settings,
|
|
||||||
has_image=vision_frame is not None,
|
|
||||||
)
|
|
||||||
if not hasattr(llm_result, "__aiter__"):
|
|
||||||
elapsed = time.perf_counter() - llm_node_started_at
|
|
||||||
logger.info("LLM node completed without streaming in %.3fs", elapsed)
|
|
||||||
return llm_result
|
|
||||||
|
|
||||||
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
|
||||||
first_chunk_at: float | None = None
|
|
||||||
chunk_count = 0
|
|
||||||
self._emotion_prefix_buffer = ""
|
|
||||||
self._emotion_prefix_done = False
|
|
||||||
try:
|
|
||||||
async for chunk in llm_result:
|
|
||||||
chunk_count += 1
|
|
||||||
if first_chunk_at is None:
|
|
||||||
first_chunk_at = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"LLM first chunk after %.3fs",
|
|
||||||
first_chunk_at - llm_node_started_at,
|
|
||||||
)
|
|
||||||
async for output_chunk in self._observe_emotion_prefix(chunk):
|
|
||||||
yield output_chunk
|
|
||||||
finally:
|
|
||||||
finished_at = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
|
|
||||||
finished_at - llm_node_started_at,
|
|
||||||
(first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0,
|
|
||||||
chunk_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _instrumented_stream()
|
|
||||||
|
|
||||||
def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
|
|
||||||
return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings)
|
|
||||||
|
|
||||||
def _consume_vision_frame(self) -> VisionFrame | None:
|
|
||||||
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
|
|
||||||
return None
|
|
||||||
return self._vision_store.consume_fresh()
|
|
||||||
|
|
||||||
def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
|
|
||||||
if self._model_image_save_dir is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
_, b64_data = vision_frame.image_data_url.split(",", 1)
|
|
||||||
image_bytes = base64.b64decode(b64_data, validate=True)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to decode model vision frame for debug save")
|
|
||||||
return
|
|
||||||
|
|
||||||
extension = _image_extension_from_mime_type(vision_frame.mime_type)
|
|
||||||
timestamp_ms = int(time.time() * 1000)
|
|
||||||
path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
path.write_bytes(image_bytes)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to save model vision frame: path=%s", path)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Saved model vision frame: path=%s bytes=%s source_path=%s",
|
|
||||||
path,
|
|
||||||
len(image_bytes),
|
|
||||||
vision_frame.saved_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_selected_llm(
|
|
||||||
self,
|
|
||||||
chat_ctx: ChatContext,
|
|
||||||
tools: list[llm.Tool],
|
|
||||||
model_settings: ModelSettings,
|
|
||||||
*,
|
|
||||||
has_image: bool,
|
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
|
||||||
selected_llm = self._vision_llm if has_image else self._text_llm
|
|
||||||
if selected_llm is None:
|
|
||||||
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
|
||||||
|
|
||||||
activity = self._get_activity_or_raise()
|
|
||||||
tool_choice = model_settings.tool_choice
|
|
||||||
conn_options = activity.session.conn_options.llm_conn_options
|
|
||||||
|
|
||||||
async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
|
||||||
async with selected_llm.chat(
|
|
||||||
chat_ctx=chat_ctx,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
conn_options=conn_options,
|
|
||||||
) as stream:
|
|
||||||
async for chunk in stream:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return _stream()
|
|
||||||
async def _observe_emotion_prefix(
|
|
||||||
self, chunk: llm.ChatChunk | str | FlushSentinel
|
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
|
||||||
if isinstance(chunk, str):
|
|
||||||
self._consume_emotion_prefix(chunk)
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
|
|
||||||
self._consume_emotion_prefix(chunk.delta.content)
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
def _consume_emotion_prefix(self, content: str) -> None:
|
|
||||||
if self._emotion_prefix_done:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._emotion_prefix_buffer += content
|
|
||||||
match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer)
|
|
||||||
if match:
|
|
||||||
emotion = match.group(1).lower()
|
|
||||||
if emotion not in EMOTION_LABELS:
|
|
||||||
logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion)
|
|
||||||
emotion = DEFAULT_EMOTION
|
|
||||||
|
|
||||||
self.current_emotion = emotion
|
|
||||||
self._emotion_prefix_done = True
|
|
||||||
self._emotion_prefix_buffer = ""
|
|
||||||
logger.info("LLM emotion selected: %s", emotion)
|
|
||||||
return
|
|
||||||
|
|
||||||
candidate = self._emotion_prefix_buffer.lstrip().lower()
|
|
||||||
might_still_be_prefix = (
|
|
||||||
not candidate
|
|
||||||
or "<emotion=".startswith(candidate)
|
|
||||||
or (candidate.startswith("<emotion=") and ">" not in candidate)
|
|
||||||
)
|
|
||||||
if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._emotion_prefix_done = True
|
|
||||||
self._emotion_prefix_buffer = ""
|
|
||||||
logger.warning("LLM response did not start with an emotion prefix")
|
|
||||||
|
|
||||||
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
|
||||||
if self._memory_client is None:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
user_query = _latest_user_text(chat_ctx)
|
|
||||||
if not user_query:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
started_at = time.perf_counter()
|
|
||||||
try:
|
|
||||||
recalled = await self._memory_client.recall(user_query)
|
|
||||||
elapsed = time.perf_counter() - started_at
|
|
||||||
logger.info(
|
|
||||||
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
|
|
||||||
elapsed,
|
|
||||||
len(user_query),
|
|
||||||
len(recalled),
|
|
||||||
)
|
|
||||||
return recalled
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Unexpected memory recall failure after %.3fs",
|
|
||||||
time.perf_counter() - started_at,
|
|
||||||
)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _select_mode(user_query: str) -> str:
|
|
||||||
normalized = _normalize_text(user_query)
|
|
||||||
if not normalized:
|
|
||||||
return GENERAL_MODE
|
|
||||||
|
|
||||||
if _is_room_locator_query(normalized):
|
|
||||||
return ROOM_LOCATOR_MODE
|
|
||||||
|
|
||||||
return GENERAL_MODE
|
|
||||||
|
|
||||||
|
|
||||||
def _with_emotion_instructions(instructions: str) -> str:
|
|
||||||
return f"{instructions}\n\n{EMOTION_INSTRUCTIONS}"
|
|
||||||
|
|
||||||
|
|
||||||
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
|
|
||||||
prefix_buffer = ""
|
|
||||||
scanning_prefix = True
|
|
||||||
|
|
||||||
async for chunk in text:
|
|
||||||
if not chunk:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if scanning_prefix:
|
|
||||||
prefix_buffer += chunk
|
|
||||||
cleaned, done = _strip_leading_tts_emotion(prefix_buffer)
|
|
||||||
if not done:
|
|
||||||
continue
|
|
||||||
|
|
||||||
scanning_prefix = False
|
|
||||||
prefix_buffer = ""
|
|
||||||
if cleaned:
|
|
||||||
yield _strip_inline_tts_emotion(cleaned)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cleaned = _strip_inline_tts_emotion(chunk)
|
|
||||||
if cleaned:
|
|
||||||
yield cleaned
|
|
||||||
|
|
||||||
if scanning_prefix and prefix_buffer:
|
|
||||||
cleaned, _ = _strip_leading_tts_emotion(prefix_buffer, force=True)
|
|
||||||
cleaned = _strip_inline_tts_emotion(cleaned)
|
|
||||||
if cleaned:
|
|
||||||
yield cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
|
|
||||||
match = TTS_EMOTION_MARKUP_RE.match(text)
|
|
||||||
if match:
|
|
||||||
return text[match.end() :], True
|
|
||||||
|
|
||||||
match = TTS_EMOTION_LINE_RE.match(text)
|
|
||||||
if match:
|
|
||||||
return text[match.end() :], True
|
|
||||||
|
|
||||||
candidate = text.lstrip().lower()
|
|
||||||
might_still_be_emotion = (
|
|
||||||
not candidate
|
|
||||||
or "<emotion=".startswith(candidate)
|
|
||||||
or (candidate.startswith("<emotion") and ">" not in candidate)
|
|
||||||
or "emotion".startswith(candidate)
|
|
||||||
or (candidate.startswith("emotion") and len(candidate) <= MAX_EMOTION_PREFIX_CHARS)
|
|
||||||
or "情绪".startswith(candidate)
|
|
||||||
)
|
|
||||||
if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
|
|
||||||
return "", False
|
|
||||||
|
|
||||||
return text, True
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_inline_tts_emotion(text: str) -> str:
|
|
||||||
return TTS_EMOTION_MARKUP_RE.sub("", text)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_room_locator_query(normalized_text: str) -> bool:
|
|
||||||
room_context_hints = (
|
|
||||||
"房间",
|
|
||||||
"屋里",
|
|
||||||
"屋子",
|
|
||||||
"室内",
|
|
||||||
"客厅",
|
|
||||||
"卧室",
|
|
||||||
"书房",
|
|
||||||
"厨房",
|
|
||||||
"餐厅",
|
|
||||||
"沙发",
|
|
||||||
"桌",
|
|
||||||
"椅",
|
|
||||||
"床",
|
|
||||||
"门",
|
|
||||||
"窗",
|
|
||||||
"柜",
|
|
||||||
"电视",
|
|
||||||
"空调",
|
|
||||||
"书架",
|
|
||||||
"灯",
|
|
||||||
"冰箱",
|
|
||||||
"茶几",
|
|
||||||
"电脑",
|
|
||||||
"包",
|
|
||||||
"瓶",
|
|
||||||
"相机",
|
|
||||||
"植物",
|
|
||||||
)
|
|
||||||
spatial_hints = (
|
|
||||||
"在哪里",
|
|
||||||
"在哪",
|
|
||||||
"位置",
|
|
||||||
"方位",
|
|
||||||
"旁边",
|
|
||||||
"左边",
|
|
||||||
"右边",
|
|
||||||
"前面",
|
|
||||||
"后面",
|
|
||||||
"上面",
|
|
||||||
"下面",
|
|
||||||
"附近",
|
|
||||||
"对面",
|
|
||||||
"靠近",
|
|
||||||
"挨着",
|
|
||||||
"隔着",
|
|
||||||
)
|
|
||||||
software_hints = (
|
|
||||||
"python",
|
|
||||||
"代码",
|
|
||||||
"函数",
|
|
||||||
"class",
|
|
||||||
"bug",
|
|
||||||
"日志",
|
|
||||||
"logging",
|
|
||||||
"api",
|
|
||||||
"server",
|
|
||||||
"agent",
|
|
||||||
"prompt",
|
|
||||||
"模型",
|
|
||||||
"数据库",
|
|
||||||
"git",
|
|
||||||
"uv",
|
|
||||||
"ruff",
|
|
||||||
"mypy",
|
|
||||||
)
|
|
||||||
|
|
||||||
if any(hint in normalized_text for hint in software_hints):
|
|
||||||
return False
|
|
||||||
|
|
||||||
has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
|
|
||||||
has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
|
|
||||||
|
|
||||||
if has_spatial_hint and has_room_context_hint:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if has_spatial_hint and len(normalized_text) <= 12:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_text(text: str) -> str:
|
|
||||||
return "".join(text.split()).lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
|
||||||
for item in reversed(chat_ctx.items):
|
|
||||||
if isinstance(item, ChatMessage) and item.role == "user":
|
|
||||||
return (item.text_content or "").strip()
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
|
|
||||||
chat_ctx = chat_ctx.copy()
|
|
||||||
for index in range(len(chat_ctx.items) - 1, -1, -1):
|
|
||||||
item = chat_ctx.items[index]
|
|
||||||
if isinstance(item, ChatMessage) and item.role == "user":
|
|
||||||
user_msg = item.model_copy(deep=True)
|
|
||||||
user_msg.content = [memory_context]
|
|
||||||
chat_ctx.items[index] = user_msg
|
|
||||||
return chat_ctx
|
|
||||||
|
|
||||||
chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
|
|
||||||
return chat_ctx
|
|
||||||
|
|
||||||
|
|
||||||
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
|
|
||||||
chat_ctx = chat_ctx.copy()
|
|
||||||
image_content = llm.ImageContent(
|
|
||||||
image=vision_frame.image_data_url,
|
|
||||||
mime_type=vision_frame.mime_type,
|
|
||||||
inference_detail="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
for index in range(len(chat_ctx.items) - 1, -1, -1):
|
|
||||||
item = chat_ctx.items[index]
|
|
||||||
if isinstance(item, ChatMessage) and item.role == "user":
|
|
||||||
user_msg = item.model_copy(deep=True)
|
|
||||||
content = list(user_msg.content)
|
|
||||||
content.append(image_content)
|
|
||||||
user_msg.content = content
|
|
||||||
chat_ctx.items[index] = user_msg
|
|
||||||
return chat_ctx
|
|
||||||
|
|
||||||
chat_ctx.items.append(ChatMessage(role="user", content=[image_content]))
|
|
||||||
return chat_ctx
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_input_mode(value: str | None) -> str:
|
|
||||||
if not value:
|
|
||||||
return AUTO_INPUT_MODE
|
|
||||||
|
|
||||||
normalized = value.strip().lower().replace("-", "_")
|
|
||||||
aliases = {
|
|
||||||
"image_voice": VISION_VOICE_INPUT_MODE,
|
|
||||||
"image": VISION_VOICE_INPUT_MODE,
|
|
||||||
"vision": VISION_VOICE_INPUT_MODE,
|
|
||||||
"vision_voice": VISION_VOICE_INPUT_MODE,
|
|
||||||
"voice_image": VISION_VOICE_INPUT_MODE,
|
|
||||||
"audio": VOICE_INPUT_MODE,
|
|
||||||
"voice": VOICE_INPUT_MODE,
|
|
||||||
"auto": AUTO_INPUT_MODE,
|
|
||||||
}
|
|
||||||
mode = aliases.get(normalized)
|
|
||||||
if mode is not None:
|
|
||||||
return mode
|
|
||||||
|
|
||||||
logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
|
|
||||||
return AUTO_INPUT_MODE
|
|
||||||
|
|
||||||
|
|
||||||
def _image_extension_from_mime_type(mime_type: str) -> str:
|
|
||||||
normalized = mime_type.strip().lower()
|
|
||||||
if normalized == "image/png":
|
|
||||||
return ".png"
|
|
||||||
if normalized == "image/webp":
|
|
||||||
return ".webp"
|
|
||||||
if normalized == "image/gif":
|
|
||||||
return ".gif"
|
|
||||||
return ".jpg"
|
|
||||||
|
|
||||||
|
|
||||||
def _model_image_save_dir_from_env() -> Path | None:
|
|
||||||
if not _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
|
|
||||||
return None
|
|
||||||
|
|
||||||
configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
|
|
||||||
if configured:
|
|
||||||
return Path(configured).expanduser()
|
|
||||||
|
|
||||||
return Path(__file__).with_name("model_images")
|
|
||||||
|
|
||||||
|
|
||||||
server = AgentServer()
|
server = AgentServer()
|
||||||
|
|
||||||
|
|
||||||
@ -635,27 +66,19 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||||
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||||
|
|
||||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
|
||||||
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
if not MINIMAX_API_KEY:
|
||||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
|
||||||
if not LLM_API_KEY:
|
|
||||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
|
||||||
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
|
|
||||||
|
|
||||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||||
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
|
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
|
||||||
)
|
)
|
||||||
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
||||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||||
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
|
||||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
|
||||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
|
||||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
|
||||||
|
|
||||||
blackbox_stt = BlackboxSTT(
|
blackbox_stt = BlackboxSTT(
|
||||||
url=ASR_URL,
|
url=ASR_URL,
|
||||||
@ -671,50 +94,30 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
import httpx
|
import httpx
|
||||||
from openai import AsyncClient as OpenAIAsyncClient
|
from openai import AsyncClient as OpenAIAsyncClient
|
||||||
|
|
||||||
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
|
# Create a custom HTTP client that disables SSL verification
|
||||||
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
|
http_client = httpx.AsyncClient(verify=False)
|
||||||
|
|
||||||
if LLM_BASE_URL:
|
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||||
openai_client = OpenAIAsyncClient(
|
openai_client = OpenAIAsyncClient(
|
||||||
api_key=LLM_API_KEY,
|
api_key=MINIMAX_API_KEY,
|
||||||
base_url=LLM_BASE_URL,
|
base_url=MINIMAX_BASE_URL,
|
||||||
http_client=http_client,
|
http_client=http_client,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
openai_client = OpenAIAsyncClient(
|
|
||||||
api_key=LLM_API_KEY,
|
|
||||||
http_client=http_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
base_llm = openai.LLM(
|
|
||||||
model=LLM_MODEL,
|
|
||||||
client=openai_client,
|
|
||||||
)
|
|
||||||
text_llm = (
|
|
||||||
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
|
|
||||||
if TEXT_LLM_MODEL != LLM_MODEL
|
|
||||||
else base_llm
|
|
||||||
)
|
|
||||||
vision_llm = (
|
|
||||||
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
|
|
||||||
if VISION_LLM_MODEL != LLM_MODEL
|
|
||||||
else base_llm
|
|
||||||
)
|
|
||||||
vision_store = VisionFrameStore(
|
|
||||||
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
session: AgentSession = AgentSession(
|
session: AgentSession = AgentSession(
|
||||||
# 1. Custom ASR blackbox with StreamAdapter
|
# 1. Custom ASR blackbox with StreamAdapter
|
||||||
stt=stt_stream,
|
stt=stt_stream,
|
||||||
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
|
# 2. Minimax LLM - Using OpenAI plugin with local base_url
|
||||||
llm=base_llm,
|
llm=openai.LLM(
|
||||||
|
model=MINIMAX_MODEL,
|
||||||
|
client=openai_client,
|
||||||
|
),
|
||||||
# 3. TTS blackbox
|
# 3. TTS blackbox
|
||||||
tts=BlackboxTTS(
|
tts=BlackboxTTS(
|
||||||
url=TTS_URL,
|
url=TTS_URL,
|
||||||
model_name=TTS_MODEL,
|
model_name=TTS_MODEL,
|
||||||
params=_tts_params_from_env(TTS_MODEL),
|
params=_tts_params_from_env(TTS_MODEL),
|
||||||
prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
|
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
|
||||||
sample_rate=TTS_SAMPLE_RATE,
|
sample_rate=TTS_SAMPLE_RATE,
|
||||||
num_channels=TTS_NUM_CHANNELS,
|
num_channels=TTS_NUM_CHANNELS,
|
||||||
),
|
),
|
||||||
@ -727,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
"false_interruption_timeout": 1.0,
|
"false_interruption_timeout": 1.0,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
|
preemptive_generation=False,
|
||||||
aec_warmup_duration=3.0,
|
aec_warmup_duration=3.0,
|
||||||
tts_text_transforms=[
|
tts_text_transforms=[
|
||||||
"filter_emoji",
|
"filter_emoji",
|
||||||
@ -739,78 +142,8 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||||
metrics.log_metrics(ev.metrics)
|
metrics.log_metrics(ev.metrics)
|
||||||
|
|
||||||
@session.on("conversation_item_added")
|
|
||||||
def _on_conversation_item_added(event) -> None:
|
|
||||||
item = getattr(event, "item", None)
|
|
||||||
if not isinstance(item, ChatMessage):
|
|
||||||
return
|
|
||||||
|
|
||||||
if item.role == "user" and item.metrics:
|
|
||||||
logger.info("User turn metrics: %s", item.metrics)
|
|
||||||
elif item.role == "assistant" and item.metrics:
|
|
||||||
logger.info("Assistant turn metrics: %s", item.metrics)
|
|
||||||
|
|
||||||
@ctx.room.on("data_received")
|
|
||||||
def _on_data_received(data_packet) -> None:
|
|
||||||
packet_topic = getattr(data_packet, "topic", None)
|
|
||||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
|
||||||
return
|
|
||||||
|
|
||||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
|
||||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = json.loads(data_packet.data.decode("utf-8"))
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to decode vision frame payload")
|
|
||||||
return
|
|
||||||
|
|
||||||
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
|
|
||||||
return
|
|
||||||
|
|
||||||
image = payload.get("image")
|
|
||||||
if not isinstance(image, str) or not image:
|
|
||||||
logger.warning("Received vision frame without image data")
|
|
||||||
return
|
|
||||||
|
|
||||||
mime_type = payload.get("mime_type")
|
|
||||||
if not isinstance(mime_type, str) or not mime_type:
|
|
||||||
mime_type = "image/jpeg"
|
|
||||||
|
|
||||||
saved_path = payload.get("saved_path")
|
|
||||||
vision_store.update(
|
|
||||||
image=image,
|
|
||||||
mime_type=mime_type,
|
|
||||||
saved_path=saved_path if isinstance(saved_path, str) else None,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
|
|
||||||
mime_type,
|
|
||||||
len(image),
|
|
||||||
saved_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
memory_client = (
|
|
||||||
MemoryRecallClient(
|
|
||||||
url=MEMORY_URL,
|
|
||||||
timeout=MEMORY_TIMEOUT,
|
|
||||||
max_chars=MEMORY_MAX_CHARS,
|
|
||||||
api_key=MEMORY_API_KEY,
|
|
||||||
)
|
|
||||||
if MEMORY_URL
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.start(
|
await session.start(
|
||||||
agent=CustomAgent(
|
agent=CustomAgent(),
|
||||||
memory_client=memory_client,
|
|
||||||
vision_store=vision_store,
|
|
||||||
input_mode=INPUT_MODE,
|
|
||||||
text_llm=text_llm,
|
|
||||||
vision_llm=vision_llm,
|
|
||||||
model_image_save_dir=_model_image_save_dir_from_env(),
|
|
||||||
),
|
|
||||||
room=ctx.room,
|
room=ctx.room,
|
||||||
room_options=room_io.RoomOptions(
|
room_options=room_io.RoomOptions(
|
||||||
audio_output=room_io.AudioOutputOptions(
|
audio_output=room_io.AudioOutputOptions(
|
||||||
@ -827,55 +160,49 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
|||||||
model_name = model_name.lower()
|
model_name = model_name.lower()
|
||||||
|
|
||||||
if model_name == "voxcpmtts":
|
if model_name == "voxcpmtts":
|
||||||
_set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
|
params.update(
|
||||||
_set_if_present(
|
{
|
||||||
params,
|
"streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||||
"prompt_text",
|
"prompt_text": os.getenv(
|
||||||
os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
|
"CUSTOM_TTS_PROMPT_TEXT",
|
||||||
)
|
os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
|
||||||
_set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
|
),
|
||||||
_set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
|
"cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
|
||||||
_set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
|
"inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
|
||||||
_set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
|
"do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
|
||||||
_set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
|
"denoise": os.getenv("VOXCPM_DENOISE", "true"),
|
||||||
_set_if_present(
|
"retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
|
||||||
params,
|
"retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
|
||||||
"retry_badcase_max_times",
|
"retry_badcase_ratio_threshold": os.getenv(
|
||||||
os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
|
"VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
|
||||||
)
|
),
|
||||||
_set_if_present(
|
}
|
||||||
params,
|
|
||||||
"retry_badcase_ratio_threshold",
|
|
||||||
os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
|
|
||||||
)
|
)
|
||||||
elif model_name == "melotts":
|
elif model_name == "melotts":
|
||||||
_set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
|
params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0")
|
||||||
elif model_name == "cosyvoicetts":
|
elif model_name == "cosyvoicetts":
|
||||||
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
||||||
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
||||||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||||
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
||||||
elif model_name == "sovitstts":
|
elif model_name == "sovitstts":
|
||||||
_set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
|
params.update(
|
||||||
_set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
|
{
|
||||||
_set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
|
"text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
|
||||||
_set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
|
"prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
|
||||||
_set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
|
"text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
|
||||||
_set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
|
"batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
|
||||||
|
"media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
|
||||||
|
"streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||||
|
}
|
||||||
|
)
|
||||||
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
||||||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
|
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
|
||||||
if model_name.lower() != "voxcpmtts":
|
|
||||||
return None
|
|
||||||
|
|
||||||
return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None
|
|
||||||
|
|
||||||
|
|
||||||
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
|
|
||||||
if value:
|
if value:
|
||||||
params[key] = value
|
params[key] = value
|
||||||
|
|
||||||
@ -891,17 +218,6 @@ def _env_int(name: str, default: int) -> int:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def _env_float(name: str, default: float) -> float:
|
|
||||||
value = os.getenv(name)
|
|
||||||
if not value:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _env_bool(name: str, default: bool) -> bool:
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
value = os.getenv(name)
|
value = os.getenv(name)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
292
memory.py
292
memory.py
@ -1,292 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
|
|
||||||
|
|
||||||
logger = logging.getLogger("memory-recall")
|
|
||||||
|
|
||||||
_LOCATION_STOPWORDS = {
|
|
||||||
"哪里",
|
|
||||||
"在哪",
|
|
||||||
"在哪里",
|
|
||||||
"哪儿",
|
|
||||||
"位置",
|
|
||||||
"什么地方",
|
|
||||||
"帮我找",
|
|
||||||
"帮我寻找",
|
|
||||||
"找一下",
|
|
||||||
"找",
|
|
||||||
"请问",
|
|
||||||
"请",
|
|
||||||
"吗",
|
|
||||||
"呢",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryRecallClient:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
url: str,
|
|
||||||
timeout: float = 5.0,
|
|
||||||
max_chars: int = 2000,
|
|
||||||
api_key: str | None = None,
|
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
|
||||||
) -> None:
|
|
||||||
self._url = url
|
|
||||||
self._timeout = timeout
|
|
||||||
self._max_chars = max_chars
|
|
||||||
self._api_key = api_key
|
|
||||||
self._http_session = http_session
|
|
||||||
self._cached_payload: Any | None = None
|
|
||||||
|
|
||||||
def _ensure_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._http_session is None:
|
|
||||||
self._http_session = utils.http_context.http_session()
|
|
||||||
return self._http_session
|
|
||||||
|
|
||||||
async def recall(self, query: str) -> str:
|
|
||||||
query = query.strip()
|
|
||||||
if not query:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
headers = {}
|
|
||||||
if self._api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self._ensure_session().get(
|
|
||||||
self._url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=self._timeout),
|
|
||||||
) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
error_text = await resp.text()
|
|
||||||
raise APIStatusError(
|
|
||||||
message=f"Memory recall error: {error_text}",
|
|
||||||
status_code=resp.status,
|
|
||||||
request_id=None,
|
|
||||||
body=error_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await resp.json()
|
|
||||||
except aiohttp.ContentTypeError:
|
|
||||||
data = await resp.text()
|
|
||||||
|
|
||||||
self._cached_payload = data
|
|
||||||
return self._format_memory(data, query)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
"Memory recall timed out after %.1fs, using cached room graph", self._timeout
|
|
||||||
)
|
|
||||||
return self._format_cached_memory(query)
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
logger.warning("Memory recall connection error: %s, using cached room graph", e)
|
|
||||||
return self._format_cached_memory(query)
|
|
||||||
except (APIConnectionError, APIStatusError, APITimeoutError) as e:
|
|
||||||
logger.warning("Memory recall failed: %s, using cached room graph", e)
|
|
||||||
return self._format_cached_memory(query)
|
|
||||||
|
|
||||||
def _format_memory(self, data: Any, query: str) -> str:
|
|
||||||
memory = _format_room_graph_memory(data, query)
|
|
||||||
if len(memory) > self._max_chars:
|
|
||||||
memory = memory[: self._max_chars].rstrip()
|
|
||||||
return memory
|
|
||||||
|
|
||||||
def _format_cached_memory(self, query: str) -> str:
|
|
||||||
if self._cached_payload is None:
|
|
||||||
return ""
|
|
||||||
return self._format_memory(self._cached_payload, query)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_room_graph_memory(payload: Any, query: str) -> str:
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
logger.warning("Unsupported room graph response: %s", payload)
|
|
||||||
return ""
|
|
||||||
objects = payload.get("objects", [])
|
|
||||||
relations = payload.get("relations", [])
|
|
||||||
summary = payload.get("summary", "")
|
|
||||||
|
|
||||||
if not objects and not relations and not summary:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
query_terms = _query_terms(query)
|
|
||||||
relevant_objects, relevant_relations = _relevant_room_graph(
|
|
||||||
objects=objects,
|
|
||||||
relations=relations,
|
|
||||||
query_terms=query_terms,
|
|
||||||
)
|
|
||||||
|
|
||||||
objects_text = json.dumps(
|
|
||||||
relevant_objects or _compact_items(objects, limit=12),
|
|
||||||
ensure_ascii=False,
|
|
||||||
separators=(",", ":"),
|
|
||||||
)
|
|
||||||
relations_text = json.dumps(
|
|
||||||
relevant_relations or _compact_items(relations, limit=24),
|
|
||||||
ensure_ascii=False,
|
|
||||||
separators=(",", ":"),
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = f"""
|
|
||||||
你是一个物品定位助手。
|
|
||||||
|
|
||||||
目标物品:{query}
|
|
||||||
相关物品:{objects_text}
|
|
||||||
相关空间关系:{relations_text}
|
|
||||||
房间概览:{summary}
|
|
||||||
|
|
||||||
回答要求:
|
|
||||||
1. 只说明它和其他物品的位置关系。
|
|
||||||
2. 不要编造不存在的关系。
|
|
||||||
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
|
|
||||||
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
|
|
||||||
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
|
|
||||||
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
|
|
||||||
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
|
|
||||||
query_terms,
|
|
||||||
len(relevant_objects),
|
|
||||||
len(objects) if isinstance(objects, list) else 0,
|
|
||||||
len(relevant_relations),
|
|
||||||
len(relations) if isinstance(relations, list) else 0,
|
|
||||||
len(prompt),
|
|
||||||
)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def _query_terms(query: str) -> list[str]:
|
|
||||||
normalized = re.sub(r"[\s??。!,、,.!]", "", query)
|
|
||||||
for word in _LOCATION_STOPWORDS:
|
|
||||||
normalized = normalized.replace(word, "")
|
|
||||||
|
|
||||||
terms = [normalized] if normalized else []
|
|
||||||
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
|
|
||||||
if token not in _LOCATION_STOPWORDS and token not in terms:
|
|
||||||
terms.append(token)
|
|
||||||
return terms[:4]
|
|
||||||
|
|
||||||
|
|
||||||
def _relevant_room_graph(
|
|
||||||
*,
|
|
||||||
objects: Any,
|
|
||||||
relations: Any,
|
|
||||||
query_terms: list[str],
|
|
||||||
) -> tuple[list[Any], list[Any]]:
|
|
||||||
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
matched_ids: set[str] = set()
|
|
||||||
matched_objects: list[Any] = []
|
|
||||||
object_by_id: dict[str, Any] = {}
|
|
||||||
|
|
||||||
for obj in objects:
|
|
||||||
obj_id = _object_id(obj)
|
|
||||||
if obj_id:
|
|
||||||
object_by_id[obj_id] = obj
|
|
||||||
|
|
||||||
obj_text = _compact_text(obj)
|
|
||||||
if any(term and term in obj_text for term in query_terms):
|
|
||||||
matched_objects.append(obj)
|
|
||||||
if obj_id:
|
|
||||||
matched_ids.add(obj_id)
|
|
||||||
|
|
||||||
relevant_relations: list[Any] = []
|
|
||||||
related_ids: set[str] = set(matched_ids)
|
|
||||||
for relation in relations:
|
|
||||||
relation_text = _compact_text(relation)
|
|
||||||
relation_ids = _ids_in_value(relation)
|
|
||||||
if (
|
|
||||||
any(term and term in relation_text for term in query_terms)
|
|
||||||
or bool(matched_ids.intersection(relation_ids))
|
|
||||||
):
|
|
||||||
relevant_relations.append(relation)
|
|
||||||
related_ids.update(relation_ids)
|
|
||||||
|
|
||||||
relevant_objects = list(matched_objects)
|
|
||||||
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
|
|
||||||
for obj_id in related_ids:
|
|
||||||
obj = object_by_id.get(obj_id)
|
|
||||||
key = _object_key(obj)
|
|
||||||
if obj is not None and key not in seen_object_keys:
|
|
||||||
relevant_objects.append(obj)
|
|
||||||
seen_object_keys.add(key)
|
|
||||||
|
|
||||||
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
|
|
||||||
|
|
||||||
|
|
||||||
def _compact_items(items: Any, *, limit: int) -> list[Any]:
|
|
||||||
if not isinstance(items, list):
|
|
||||||
return []
|
|
||||||
return [_compact_item(item) for item in items[:limit]]
|
|
||||||
|
|
||||||
|
|
||||||
def _compact_item(item: Any) -> Any:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
return item
|
|
||||||
|
|
||||||
preferred_keys = (
|
|
||||||
"id",
|
|
||||||
"name",
|
|
||||||
"label",
|
|
||||||
"class",
|
|
||||||
"category",
|
|
||||||
"type",
|
|
||||||
"text",
|
|
||||||
"source",
|
|
||||||
"target",
|
|
||||||
"subject",
|
|
||||||
"object",
|
|
||||||
"relation",
|
|
||||||
"predicate",
|
|
||||||
"description",
|
|
||||||
)
|
|
||||||
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
|
|
||||||
return compact or item
|
|
||||||
|
|
||||||
|
|
||||||
def _object_id(obj: Any) -> str | None:
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
return None
|
|
||||||
for key in ("id", "object_id", "uuid", "name", "label"):
|
|
||||||
value = obj.get(key)
|
|
||||||
if isinstance(value, (str, int)):
|
|
||||||
return str(value)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _object_key(obj: Any) -> str:
|
|
||||||
return _object_id(obj) or _compact_text(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def _ids_in_value(value: Any) -> set[str]:
|
|
||||||
ids: set[str] = set()
|
|
||||||
if isinstance(value, dict):
|
|
||||||
for key, item in value.items():
|
|
||||||
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
|
|
||||||
if isinstance(item, (str, int)):
|
|
||||||
ids.add(str(item))
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
obj_id = _object_id(item)
|
|
||||||
if obj_id:
|
|
||||||
ids.add(obj_id)
|
|
||||||
ids.update(_ids_in_value(item))
|
|
||||||
elif isinstance(value, list):
|
|
||||||
for item in value:
|
|
||||||
ids.update(_ids_in_value(item))
|
|
||||||
return ids
|
|
||||||
|
|
||||||
|
|
||||||
def _compact_text(value: Any) -> str:
|
|
||||||
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
|
|
||||||
188
test_agent.py
Normal file
188
test_agent.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import uuid
|
||||||
|
import wave
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
from livekit import rtc
|
||||||
|
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("test-agent")
|
||||||
|
|
||||||
|
TOKEN_URL = "http://localhost:8000/getToken"
|
||||||
|
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
|
||||||
|
ROOM_NAME = "test-room20"
|
||||||
|
WAV_FILE = "2food.wav"
|
||||||
|
TEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
class TestState:
|
||||||
|
def __init__(self):
|
||||||
|
self.agent_connected = False
|
||||||
|
self.tts_received = False
|
||||||
|
self.tts_count = 0
|
||||||
|
|
||||||
|
test_state = TestState()
|
||||||
|
|
||||||
|
|
||||||
|
def get_token(agent_name="my-agent"):
|
||||||
|
try:
|
||||||
|
resp = requests.get(
|
||||||
|
TOKEN_URL,
|
||||||
|
params={
|
||||||
|
"room": ROOM_NAME,
|
||||||
|
"identity": f"test-{uuid.uuid4().hex[:6]}",
|
||||||
|
"agent_name": agent_name,
|
||||||
|
},
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["token"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 获取token失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_wav(room, wav_path):
|
||||||
|
wav_path = Path(wav_path)
|
||||||
|
if not wav_path.exists():
|
||||||
|
logger.error(f"❌ WAV文件不存在: {wav_path}")
|
||||||
|
raise FileNotFoundError(f"文件不存在: {wav_path}")
|
||||||
|
|
||||||
|
logger.info(f"📂 开始上传: {wav_path}")
|
||||||
|
|
||||||
|
with wave.open(str(wav_path), "rb") as wf:
|
||||||
|
sample_rate = wf.getframerate()
|
||||||
|
num_channels = wf.getnchannels()
|
||||||
|
sample_width = wf.getsampwidth()
|
||||||
|
|
||||||
|
logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
|
||||||
|
|
||||||
|
source = AudioSource(sample_rate, num_channels)
|
||||||
|
track = LocalAudioTrack.create_audio_track("mic", source)
|
||||||
|
|
||||||
|
await room.local_participant.publish_track(track)
|
||||||
|
logger.info("📡 已发布音轨")
|
||||||
|
|
||||||
|
frame_duration = 0.02
|
||||||
|
samples_per_frame = int(sample_rate * frame_duration)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = wf.readframes(samples_per_frame)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16)
|
||||||
|
if len(audio) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples_per_channel = len(audio) // num_channels
|
||||||
|
|
||||||
|
frame = AudioFrame(
|
||||||
|
data=data,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
num_channels=num_channels,
|
||||||
|
samples_per_channel=samples_per_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
await source.capture_frame(frame)
|
||||||
|
await asyncio.sleep(frame_duration)
|
||||||
|
|
||||||
|
logger.info("✅ WAV推流完成")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_agent():
|
||||||
|
try:
|
||||||
|
logger.info("🔑 正在获取token...")
|
||||||
|
token = get_token()
|
||||||
|
logger.info("✅ Token获取成功")
|
||||||
|
|
||||||
|
room = rtc.Room()
|
||||||
|
|
||||||
|
@room.on("participant_connected")
|
||||||
|
def on_participant_connected(participant):
|
||||||
|
logger.info(f"✅ 参与者加入: {participant.identity}")
|
||||||
|
if "agent" in participant.identity.lower():
|
||||||
|
test_state.agent_connected = True
|
||||||
|
logger.info("🎉 Agent已连接!")
|
||||||
|
|
||||||
|
@room.on("participant_disconnected")
|
||||||
|
def on_participant_disconnected(participant):
|
||||||
|
logger.info(f"❌ 参与者离开: {participant.identity}")
|
||||||
|
|
||||||
|
@room.on("track_subscribed")
|
||||||
|
def on_track_subscribed(track, publication, participant):
|
||||||
|
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||||
|
test_state.tts_count += 1
|
||||||
|
logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
|
||||||
|
test_state.tts_received = True
|
||||||
|
|
||||||
|
logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
|
||||||
|
await room.connect(WS_URL, token)
|
||||||
|
logger.info("✅ 已连接到房间")
|
||||||
|
logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")
|
||||||
|
|
||||||
|
logger.info("⏳ 等待Agent连接...")
|
||||||
|
for i in range(10):
|
||||||
|
if test_state.agent_connected:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if not test_state.agent_connected:
|
||||||
|
logger.warning("⚠️ Agent未连接")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("🎙️ 正在上传测试音频...")
|
||||||
|
await publish_wav(room, WAV_FILE)
|
||||||
|
|
||||||
|
logger.info("⏳ 等待Agent响应...")
|
||||||
|
for i in range(TEST_TIMEOUT):
|
||||||
|
if test_state.tts_received:
|
||||||
|
logger.info("✅ 收到Agent TTS响应!")
|
||||||
|
break
|
||||||
|
if i % 5 == 0:
|
||||||
|
logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
logger.info("\n" + "="*60)
|
||||||
|
logger.info("✅ 测试结果")
|
||||||
|
logger.info("="*60)
|
||||||
|
logger.info(f"Agent连接: {'✅' if test_state.agent_connected else '❌'}")
|
||||||
|
logger.info(f"收到TTS响应: {'✅' if test_state.tts_received else '❌'}")
|
||||||
|
logger.info(f"TTS音频次数: {test_state.tts_count} 次")
|
||||||
|
logger.info("="*60)
|
||||||
|
|
||||||
|
await room.disconnect()
|
||||||
|
logger.info("✅ 已断开连接\n")
|
||||||
|
|
||||||
|
return test_state.agent_connected and test_state.tts_received
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 测试失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
logger.info("🚀 开始测试custom_agent...\n")
|
||||||
|
success = await test_agent()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("✅ 测试成功!custom_agent 正常工作")
|
||||||
|
logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到,")
|
||||||
|
logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
logger.error("❌ 测试失败")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
exit(exit_code)
|
||||||
55
test_asr.py
Normal file
55
test_asr.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from asr import BlackboxSTT
|
||||||
|
from livekit import rtc
|
||||||
|
|
||||||
|
# 设置日志级别以查看输出
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("test-asr")
|
||||||
|
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
# 替换为你本地的一个音频文件路径
|
||||||
|
audio_path = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
|
# 初始化 ASR
|
||||||
|
stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice")
|
||||||
|
|
||||||
|
print(f"Testing ASR connectivity with file: {audio_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取音频文件
|
||||||
|
with wave.open(audio_path, "rb") as wf:
|
||||||
|
frames = wf.readframes(wf.getnframes())
|
||||||
|
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
|
||||||
|
# 实际上 BlackboxSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
|
||||||
|
# 所以我们需要传递一个包含 AudioFrame 的 list
|
||||||
|
|
||||||
|
# 这里我们模拟一个 Frame
|
||||||
|
frame = rtc.AudioFrame(
|
||||||
|
data=frames,
|
||||||
|
sample_rate=wf.getframerate(),
|
||||||
|
num_channels=wf.getnchannels(),
|
||||||
|
samples_per_channel=wf.getnframes(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 recognize
|
||||||
|
result = await stt.recognize(buffer=[frame])
|
||||||
|
|
||||||
|
if result.alternatives:
|
||||||
|
print("\n--- ASR Result ---")
|
||||||
|
print(f"Text: {result.alternatives[0].text}")
|
||||||
|
print("------------------\n")
|
||||||
|
else:
|
||||||
|
print("ASR returned no text.")
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Audio file not found at {audio_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test())
|
||||||
253
test_beaver_llm.py
Normal file
253
test_beaver_llm.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
try:
|
||||||
|
from custom.beaver_llm import BeaverLLM, latest_user_text
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from beaver_llm import BeaverLLM, latest_user_text
|
||||||
|
from livekit.agents import ChatContext
|
||||||
|
|
||||||
|
|
||||||
|
def test_latest_user_text_uses_most_recent_user_message() -> None:
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="first")
|
||||||
|
ctx.add_message(role="assistant", content="ignored")
|
||||||
|
ctx.add_message(role="user", content=["second", "line"])
|
||||||
|
|
||||||
|
assert latest_user_text(ctx) == "second\nline"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_can_connect_before_first_message(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await beaver_llm.connect()
|
||||||
|
assert beaver_llm.session_id == "terminal-dev:local:livekit-room"
|
||||||
|
assert received == [
|
||||||
|
{
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "livekit-room",
|
||||||
|
"device_name": "livekit-custom-agent",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert collected.text == "beaver reply"
|
||||||
|
assert received[1]["type"] == "message"
|
||||||
|
assert received[1]["text"] == "hello beaver"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_connect_can_send_warmup_message(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-warmup",
|
||||||
|
"text": "ready",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
warmup_reply = await beaver_llm.connect(warmup_text="初始化连接")
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert warmup_reply == "ready"
|
||||||
|
assert received[0] == {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "livekit-room",
|
||||||
|
"device_name": "livekit-custom-agent",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
assert received[1]["type"] == "message"
|
||||||
|
assert received[1]["text"] == "初始化连接"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="system", content="ignored instructions")
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert collected.text == "beaver reply"
|
||||||
|
assert received[0] == {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "livekit-room",
|
||||||
|
"device_name": "livekit-custom-agent",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
assert received[1]["type"] == "message"
|
||||||
|
assert received[1]["message_id"].startswith("livekit-room-")
|
||||||
|
assert received[1]["message_id"].endswith("-000001")
|
||||||
|
assert received[1]["text"] == "hello beaver"
|
||||||
426
test_beaver_terminal_client.py
Normal file
426
test_beaver_terminal_client.py
Normal file
@ -0,0 +1,426 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from custom.beaver_terminal_client import (
|
||||||
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalError,
|
||||||
|
MessageIdGenerator,
|
||||||
|
build_connect_frame,
|
||||||
|
build_message_frame,
|
||||||
|
)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from beaver_terminal_client import (
|
||||||
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalError,
|
||||||
|
MessageIdGenerator,
|
||||||
|
build_connect_frame,
|
||||||
|
build_message_frame,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
||||||
|
frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
|
||||||
|
|
||||||
|
assert frame == {
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "device-001",
|
||||||
|
"device_name": "desk-terminal",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_message_frame_uses_message_id_and_text() -> None:
|
||||||
|
frame = build_message_frame(message_id="device-001-000001", text="hello")
|
||||||
|
|
||||||
|
assert frame == {
|
||||||
|
"type": "message",
|
||||||
|
"message_id": "device-001-000001",
|
||||||
|
"text": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
|
||||||
|
generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
|
||||||
|
|
||||||
|
assert generator.next_id() == "device-001-000008"
|
||||||
|
assert generator.next_id() == "device-001-000009"
|
||||||
|
assert generator.counter == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id_generator_can_include_instance_id() -> None:
|
||||||
|
generator = MessageIdGenerator(peer_id="device-001", instance_id="abc123ef")
|
||||||
|
|
||||||
|
assert generator.next_id() == "device-001-abc123ef-000001"
|
||||||
|
assert generator.next_id() == "device-001-abc123ef-000002"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_connects_sends_text_and_returns_assistant_reply(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "assistant reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert client.session_id == "terminal-dev:local:device-001"
|
||||||
|
assert reply == "assistant reply"
|
||||||
|
assert received == [
|
||||||
|
{
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "device-001",
|
||||||
|
"device_name": "desk-terminal",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"message_id": "device-001-000001",
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": False,
|
||||||
|
"duplicate": True,
|
||||||
|
"pending": False,
|
||||||
|
"reply": "cached assistant reply",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert reply == "cached assistant reply"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json({"type": "error", "error": "text is required"})
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
with pytest.raises(BeaverTerminalError, match="text is required"):
|
||||||
|
await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "failed turn",
|
||||||
|
"finish_reason": "error",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
with pytest.raises(BeaverTerminalError, match="failed turn"):
|
||||||
|
await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "ping":
|
||||||
|
await ws.send_json({"type": "pong"})
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
assert await client.ping()
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
connect_peer_ids: list[str] = []
|
||||||
|
message_ids: list[str] = []
|
||||||
|
connection_count = 0
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
nonlocal connection_count
|
||||||
|
connection_count += 1
|
||||||
|
current_connection = connection_count
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
connect_peer_ids.append(frame["peer_id"])
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
message_ids.append(frame["message_id"])
|
||||||
|
if current_connection == 1:
|
||||||
|
await ws.close()
|
||||||
|
continue
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-2",
|
||||||
|
"text": "reply after reconnect",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert reply == "reply after reconnect"
|
||||||
|
assert connect_peer_ids == ["device-001", "device-001"]
|
||||||
|
assert message_ids == ["device-001-000001", "device-001-000002"]
|
||||||
130
test_livekit.py
Normal file
130
test_livekit.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
from livekit import rtc
|
||||||
|
|
||||||
|
import wave
|
||||||
|
import numpy as np
|
||||||
|
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
|
||||||
|
|
||||||
|
TOKEN_URL = "http://localhost:8000/getToken"
|
||||||
|
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" # 你的 LiveKit Server 地址
|
||||||
|
|
||||||
|
ROOM_NAME = "test-room20"
|
||||||
|
import uuid
|
||||||
|
IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
|
||||||
|
# IDENTITY = "test-user0"
|
||||||
|
|
||||||
|
|
||||||
|
def get_token():
|
||||||
|
resp = requests.get(
|
||||||
|
TOKEN_URL,
|
||||||
|
params={
|
||||||
|
"room": ROOM_NAME,
|
||||||
|
"identity": IDENTITY,
|
||||||
|
"agent_name": "my-agent", # 关键!!!
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = resp.json()
|
||||||
|
return data["token"]
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
token = get_token()
|
||||||
|
|
||||||
|
room = rtc.Room()
|
||||||
|
|
||||||
|
@room.on("participant_connected")
|
||||||
|
def on_participant_connected(participant):
|
||||||
|
print(f"✅ 有人加入房间: {participant.identity}")
|
||||||
|
|
||||||
|
@room.on("participant_disconnected")
|
||||||
|
def on_participant_disconnected(participant):
|
||||||
|
print(f"❌ 有人离开房间: {participant.identity}")
|
||||||
|
|
||||||
|
print("🔌 正在连接房间...")
|
||||||
|
await room.connect(WS_URL, token)
|
||||||
|
|
||||||
|
print("✅ 已连接房间:", ROOM_NAME)
|
||||||
|
print("当前房间成员:")
|
||||||
|
for p in room.remote_participants.values():
|
||||||
|
print(" -", p.identity)
|
||||||
|
|
||||||
|
@room.on("data_received")
|
||||||
|
def on_data_received(data, participant, kind, topic):
|
||||||
|
try:
|
||||||
|
msg = data.decode()
|
||||||
|
print(f"📩 来自 {participant.identity}: {msg}")
|
||||||
|
except:
|
||||||
|
print("📩 收到二进制数据")
|
||||||
|
|
||||||
|
@room.on("track_subscribed")
|
||||||
|
def on_track_subscribed(track, publication, participant):
|
||||||
|
print(f"🎧 订阅轨道: {participant.identity}")
|
||||||
|
|
||||||
|
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||||
|
print("👉 TTS 音频来了")
|
||||||
|
|
||||||
|
# 等一下确保连接稳定
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await room.local_participant.publish_data(
|
||||||
|
b"hello",
|
||||||
|
reliable=True,
|
||||||
|
topic="chat"
|
||||||
|
)
|
||||||
|
# 上传 wav
|
||||||
|
await publish_wav(room, "2food.wav")
|
||||||
|
|
||||||
|
await room.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_wav(room, wav_path):
|
||||||
|
print("🎵 开始上传本地 wav:", wav_path)
|
||||||
|
|
||||||
|
wf = wave.open(wav_path, "rb")
|
||||||
|
|
||||||
|
sample_rate = wf.getframerate()
|
||||||
|
num_channels = wf.getnchannels()
|
||||||
|
sample_width = wf.getsampwidth()
|
||||||
|
|
||||||
|
print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
|
||||||
|
|
||||||
|
# 创建音频源
|
||||||
|
source = AudioSource(sample_rate, num_channels)
|
||||||
|
|
||||||
|
# 创建本地音轨
|
||||||
|
track = LocalAudioTrack.create_audio_track("mic", source)
|
||||||
|
|
||||||
|
# 发布轨道
|
||||||
|
await room.local_participant.publish_track(track)
|
||||||
|
print("📡 已发布音轨")
|
||||||
|
|
||||||
|
frame_duration = 0.02 # 20ms
|
||||||
|
samples_per_frame = int(sample_rate * frame_duration)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = wf.readframes(samples_per_frame)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 用于计算长度
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16)
|
||||||
|
|
||||||
|
if len(audio) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples_per_channel = len(audio) // num_channels
|
||||||
|
|
||||||
|
frame = AudioFrame(
|
||||||
|
data=data, # ✅ 关键:用 bytes
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
num_channels=num_channels,
|
||||||
|
samples_per_channel=samples_per_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
await source.capture_frame(frame)
|
||||||
|
await asyncio.sleep(frame_duration)
|
||||||
|
print("✅ wav 推流结束")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
71
test_minimax.py
Normal file
71
test_minimax.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from livekit.agents.llm import ChatContext
|
||||||
|
from livekit.plugins import openai
|
||||||
|
|
||||||
|
# Configure logging to see what's happening
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("test-minimax")
|
||||||
|
|
||||||
|
async def test_minimax():
|
||||||
|
print("Loading .env...")
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Configuration from environment or defaults from custom_agent.py
|
||||||
|
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||||
|
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
|
||||||
|
# Using the hardcoded key from custom_agent.py as a fallback if not in .env
|
||||||
|
API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from openai import AsyncClient as OpenAIAsyncClient
|
||||||
|
|
||||||
|
print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}")
|
||||||
|
|
||||||
|
# Create a custom HTTP client that disables SSL verification
|
||||||
|
http_client = httpx.AsyncClient(verify=False)
|
||||||
|
|
||||||
|
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||||
|
openai_client = OpenAIAsyncClient(
|
||||||
|
api_key=API_KEY,
|
||||||
|
base_url=MINIMAX_BASE_URL,
|
||||||
|
http_client=http_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = openai.LLM(
|
||||||
|
model=MINIMAX_MODEL,
|
||||||
|
client=openai_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Creating ChatContext...")
|
||||||
|
chat_ctx = ChatContext()
|
||||||
|
chat_ctx.add_message(
|
||||||
|
content="Hello! Can you introduce yourself? Please reply in Chinese.",
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n--- Testing Streaming Chat ---")
|
||||||
|
print(f"Request: {chat_ctx.items[-1].content}")
|
||||||
|
print("Response: ", end="", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("\nCalling llm.chat()...")
|
||||||
|
stream = llm.chat(chat_ctx=chat_ctx)
|
||||||
|
print("Iterating over stream...")
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.delta and chunk.delta.content:
|
||||||
|
print(chunk.delta.content, end="", flush=True)
|
||||||
|
print("\n--- Test Completed Successfully ---")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"\nTest failed with error: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Starting...")
|
||||||
|
try:
|
||||||
|
asyncio.run(asyncio.wait_for(test_minimax(), timeout=30))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("\nTest timed out after 30 seconds.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nAn error occurred: {e}")
|
||||||
66
test_voxcpm.py
Normal file
66
test_voxcpm.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts():
|
||||||
|
# Use the URL from the user's curl command
|
||||||
|
url = "http://10.6.80.21:5002/tts-blackbox"
|
||||||
|
|
||||||
|
# Check if we have a real wav file to test with
|
||||||
|
# In the earlier find_by_name, we found tests/change-sophie.wav
|
||||||
|
prompt_wav = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
if not os.path.exists(prompt_wav):
|
||||||
|
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
|
||||||
|
|
||||||
|
print(f"Testing BlackboxTTS with URL: {url}")
|
||||||
|
print(f"Using prompt wav: {prompt_wav}")
|
||||||
|
|
||||||
|
blackbox_tts = BlackboxTTS(
|
||||||
|
url=url,
|
||||||
|
model_name="voxcpmtts",
|
||||||
|
prompt_wav_path=prompt_wav,
|
||||||
|
params={
|
||||||
|
"streaming": "false",
|
||||||
|
"prompt_text": "澳门有乜嘢好食嘅",
|
||||||
|
"cfg_value": "2.0",
|
||||||
|
"inference_timesteps": "10",
|
||||||
|
"do_normalize": "true",
|
||||||
|
"denoise": "true",
|
||||||
|
"retry_badcase": "true",
|
||||||
|
"retry_badcase_max_times": "3",
|
||||||
|
"retry_badcase_ratio_threshold": "6.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
text = "你好,这是一段测试文本"
|
||||||
|
print(f"Synthesizing text: {text}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = blackbox_tts.synthesize(text)
|
||||||
|
audio_frame = await stream.collect()
|
||||||
|
|
||||||
|
print("Successfully synthesized audio!")
|
||||||
|
print(
|
||||||
|
f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?"
|
||||||
|
)
|
||||||
|
# Actually AudioFrame has duration or samples
|
||||||
|
print(f"Samples: {len(audio_frame.data) // 2}")
|
||||||
|
|
||||||
|
# Save to file for manual check if possible
|
||||||
|
with open("test_output.wav", "wb") as f:
|
||||||
|
# This won't be a valid WAV yet if it's just raw PCM,
|
||||||
|
# but if collect() returns combined frames, we can use to_wav_bytes()
|
||||||
|
f.write(audio_frame.to_wav_bytes())
|
||||||
|
print("Saved output to test_output.wav")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TTS test failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_tts())
|
||||||
24
tts.py
24
tts.py
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import wave
|
import wave
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -89,7 +88,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
self._tts: BlackboxTTS = tts
|
self._tts: BlackboxTTS = tts
|
||||||
|
|
||||||
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
||||||
started_at = time.perf_counter()
|
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
form = aiohttp.FormData(default_to_multipart=True)
|
||||||
form.add_field("text", self.input_text)
|
form.add_field("text", self.input_text)
|
||||||
form.add_field("model_name", self._tts._model_name)
|
form.add_field("model_name", self._tts._model_name)
|
||||||
@ -133,9 +131,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
content_type = resp.headers.get("Content-Type", "audio/wav")
|
content_type = resp.headers.get("Content-Type", "audio/wav")
|
||||||
logged_wav_format = False
|
logged_wav_format = False
|
||||||
wav_header_probe = bytearray()
|
wav_header_probe = bytearray()
|
||||||
first_audio_at: float | None = None
|
|
||||||
chunk_count = 0
|
|
||||||
total_bytes = 0
|
|
||||||
output_emitter.initialize(
|
output_emitter.initialize(
|
||||||
request_id=utils.shortuuid(),
|
request_id=utils.shortuuid(),
|
||||||
sample_rate=self._tts.sample_rate,
|
sample_rate=self._tts.sample_rate,
|
||||||
@ -145,16 +140,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
|
|
||||||
async for data, _ in resp.content.iter_chunks():
|
async for data, _ in resp.content.iter_chunks():
|
||||||
if data:
|
if data:
|
||||||
chunk_count += 1
|
|
||||||
total_bytes += len(data)
|
|
||||||
if first_audio_at is None:
|
|
||||||
first_audio_at = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
|
|
||||||
first_audio_at - started_at,
|
|
||||||
len(self.input_text),
|
|
||||||
len(data),
|
|
||||||
)
|
|
||||||
if not logged_wav_format:
|
if not logged_wav_format:
|
||||||
wav_header_probe.extend(data)
|
wav_header_probe.extend(data)
|
||||||
logged_wav_format = _log_wav_format(
|
logged_wav_format = _log_wav_format(
|
||||||
@ -171,15 +156,6 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
logged_wav_format = True
|
logged_wav_format = True
|
||||||
output_emitter.push(data)
|
output_emitter.push(data)
|
||||||
output_emitter.flush()
|
output_emitter.flush()
|
||||||
finished_at = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
|
|
||||||
finished_at - started_at,
|
|
||||||
(first_audio_at - started_at) if first_audio_at else -1.0,
|
|
||||||
chunk_count,
|
|
||||||
total_bytes,
|
|
||||||
len(self.input_text),
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError as e:
|
except asyncio.TimeoutError as e:
|
||||||
raise APITimeoutError("TTS blackbox request timed out") from e
|
raise APITimeoutError("TTS blackbox request timed out") from e
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user