Compare commits

...

17 Commits

Author SHA1 Message Date
52e6d3cd9c feat: connect to beaver 2026-06-03 17:25:08 +08:00
0a50f25dfa feat: beaver first commit 2026-06-02 14:07:56 +08:00
34cf1b9736 feat: support beaver llm provider mode 2026-06-02 11:50:14 +08:00
af261d3b63 fix: use fresh beaver message id after reconnect 2026-06-02 10:33:03 +08:00
879c73bfee fix: reconnect beaver terminal websocket before retrying turn 2026-06-02 10:18:19 +08:00
b1cad592e2 feat: add beaver terminal websocket client 2026-06-01 17:06:23 +08:00
e7529dc47b feat: add hermes gateway llm adapter 2026-06-01 16:50:00 +08:00
7efd9eba98 feat: add emotion prompt 2026-06-01 09:46:04 +08:00
e097323176 feat: support vlm chat 2026-05-25 17:17:38 +08:00
2064db15dc add env 2026-05-22 15:36:38 +08:00
f272053a95 fix: prompt 2026-05-22 14:46:10 +08:00
fba51a5257 perf: improve speed 2026-05-15 10:44:31 +08:00
b18c5b40da fix: tts parameters 2026-05-14 15:33:20 +08:00
89011fed81 fix: memory recall fuction prompt 2026-05-14 11:18:04 +08:00
3a2f5c4252 feat: memory recall fuction 2026-05-14 10:16:08 +08:00
746053fd58 fix 2026-05-13 15:35:04 +08:00
6ec16bf68e remove test files from main 2026-05-11 11:30:28 +08:00
16 changed files with 2760 additions and 570 deletions

83
.env.example Normal file
View File

@ -0,0 +1,83 @@
# LiveKit connection
LIVEKIT_URL=ws://localhost:7880
LIVEKIT_API_KEY=
LIVEKIT_API_SECRET=
CUSTOM_AGENT_NAME=my-agent
# Beaver terminal text WebSocket
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
TERMINAL_PEER_ID=device-001
TERMINAL_DEVICE_NAME=desk-terminal
# ASR blackbox
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
CUSTOM_ASR_MODEL=qwen
CUSTOM_ASR_LANGUAGE=Chinese
CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# LLM backend: openai/openai-compatible or hermes_gateway/openclaw.
CUSTOM_LLM_PROVIDER=beaver
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
# OpenAI-compatible LLM
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
# CUSTOM_LLM_MODEL=Qwen3.6-35B
# CUSTOM_LLM_API_KEY=sk-
# CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_LLM_BASE_URL=http:/localhost/v1
CUSTOM_LLM_MODEL=Mistral-Medium-3.5-128B
CUSTOM_LLM_API_KEY=sk-
CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_SAVE_MODEL_IMAGES=true
# CUSTOM_TEXT_LLM_MODEL=
# CUSTOM_VISION_LLM_MODEL=
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash
# CUSTOM_LLM_API_KEY=sk-
# CUSTOM_LLM_VERIFY_SSL=false
# TTS blackbox
CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
CUSTOM_TTS_MODEL=voxcpmtts
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
CUSTOM_TTS_STREAMING=true
# CUSTOM_TTS_PROMPT_TEXT=澳门有乜嘢好食嘅
# VoxCPM TTS parameters
VOXCPM_CFG_VALUE=2.0
VOXCPM_INFERENCE_TIMESTEPS=10
VOXCPM_DO_NORMALIZE=true
VOXCPM_DENOISE=true
VOXCPM_RETRY_BADCASE=true
VOXCPM_RETRY_BADCASE_MAX_TIMES=3
VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD=6.0
# MeloTTS parameters
CUSTOM_TTS_SPEED=1.0
# CosyVoice parameters
CUSTOM_TTS_SPK_ID=
CUSTOM_TTS_MODE=
CUSTOM_TTS_INSTRUCT_TEXT=
# GPT-SoVITS parameters
CUSTOM_TTS_TEXT_LANG=zh
CUSTOM_TTS_PROMPT_LANG=zh
CUSTOM_TTS_TEXT_SPLIT_METHOD=cut0
CUSTOM_TTS_BATCH_SIZE=1
CUSTOM_TTS_MEDIA_TYPE=wav
CUSTOM_TTS_REF_AUDIO_PATH=
CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
CUSTOM_MEMORY_TIMEOUT=2
CUSTOM_MEMORY_MAX_CHARS=2000
CUSTOM_MEMORY_API_KEY=
CUSTOM_PREEMPTIVE_GENERATION=false

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
__pycache__/
.env
model_images/

124
beaver_llm.py Normal file
View File

@ -0,0 +1,124 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Sequence
from typing import Any
try:
from beaver_terminal_client import BeaverTerminalClient
except ModuleNotFoundError:
from custom.beaver_terminal_client import BeaverTerminalClient
from livekit.agents import llm
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectOptions,
NotGivenOr,
)
from livekit.agents.utils import shortuuid
logger = logging.getLogger("beaver-llm")
def latest_user_text(chat_ctx: llm.ChatContext) -> str:
for message in reversed(chat_ctx.messages()):
if message.role != "user":
continue
return _content_to_text(message.content)
return ""
def _content_to_text(content: Sequence[llm.ChatContent]) -> str:
text_parts = [item for item in content if isinstance(item, str)]
return "\n".join(text_parts)
class BeaverLLM(llm.LLM):
def __init__(
self,
*,
url: str,
peer_id: str,
device_name: str,
model_name: str = "beaver-terminal",
) -> None:
super().__init__()
self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name)
self._model_name = model_name
self._lock = asyncio.Lock()
@property
def model(self) -> str:
return self._model_name
@property
def provider(self) -> str:
return "beaver"
@property
def session_id(self) -> str | None:
return self._client.session_id
async def connect(self, *, warmup_text: str | None = None) -> str | None:
warmup_reply: str | None = None
async with self._lock:
await self._client.connect()
if warmup_text and warmup_text.strip():
warmup_reply = await self._client.send_text(warmup_text.strip())
if warmup_reply is None:
logger.info("Beaver handshake completed session_id=%s", self.session_id)
else:
logger.info(
"Beaver handshake warmup completed session_id=%s reply_len=%s",
self.session_id,
len(warmup_reply),
)
return warmup_reply
def chat(
self,
*,
chat_ctx: llm.ChatContext,
tools: list[llm.Tool] | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> llm.LLMStream:
return BeaverLLMStream(
self,
chat_ctx=chat_ctx,
tools=tools or [],
conn_options=conn_options,
)
async def aclose(self) -> None:
await self._client.close()
class BeaverLLMStream(llm.LLMStream):
def __init__(
self,
beaver_llm: BeaverLLM,
*,
chat_ctx: llm.ChatContext,
tools: list[llm.Tool],
conn_options: APIConnectOptions,
) -> None:
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
self._beaver_llm = beaver_llm
async def _run(self) -> None:
user_text = latest_user_text(self.chat_ctx)
async with self._beaver_llm._lock:
reply = await self._beaver_llm._client.send_text(user_text)
if reply:
self._event_ch.send_nowait(
llm.ChatChunk(
id=shortuuid("beaver_"),
delta=llm.ChoiceDelta(role="assistant", content=reply),
)
)

221
beaver_terminal_client.py Normal file
View File

@ -0,0 +1,221 @@
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import aiohttp
from dotenv import load_dotenv
logger = logging.getLogger("beaver-terminal-client")
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
class BeaverTerminalError(RuntimeError):
pass
class BeaverTerminalConnectionClosed(BeaverTerminalError):
pass
@dataclass
class MessageIdGenerator:
peer_id: str
initial_counter: int = 0
def __post_init__(self) -> None:
self.counter = self.initial_counter
def next_id(self) -> str:
self.counter += 1
return f"{self.peer_id}-{self.counter:06d}"
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
return {
"type": "connect",
"peer_id": peer_id,
"device_name": device_name,
"capabilities": ["text"],
}
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
return {
"type": "message",
"message_id": message_id,
"text": text,
}
class BeaverTerminalClient:
def __init__(
self,
*,
url: str,
peer_id: str,
device_name: str,
http_session: aiohttp.ClientSession | None = None,
message_ids: MessageIdGenerator | None = None,
) -> None:
self._url = url
self._peer_id = peer_id
self._device_name = device_name
self._owned_session = http_session is None
self._http_session = http_session
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
self.session_id: str | None = None
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
self._ws = await session.ws_connect(self._url)
await self._send_json(
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
)
frame = await self._receive_json()
if frame.get("type") != "connected":
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
session_id = frame.get("session_id")
self.session_id = session_id if isinstance(session_id, str) else None
async def send_text(self, text: str) -> str:
for attempt in range(2):
if not self._websocket_is_open():
await self.connect()
message_id = self._message_ids.next_id()
message_frame = build_message_frame(message_id=message_id, text=text)
try:
await self._send_json(message_frame)
return await self._wait_for_reply(message_id)
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
if attempt == 1:
raise BeaverTerminalConnectionClosed(
"Beaver websocket closed before assistant reply"
) from exc
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
await self.connect()
raise BeaverTerminalError("unreachable Beaver send state")
async def _wait_for_reply(self, message_id: str) -> str:
while True:
frame = await self._receive_json()
frame_type = frame.get("type")
if frame_type == "ack" and frame.get("message_id") == message_id:
reply = frame.get("reply")
if isinstance(reply, str):
return reply
continue
if (
frame_type == "message"
and frame.get("role") == "assistant"
and frame.get("message_id") == message_id
):
text = frame.get("text")
if frame.get("finish_reason") == "error":
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
return text if isinstance(text, str) else ""
if frame_type == "error":
error = frame.get("error")
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def ping(self) -> bool:
await self._send_json({"type": "ping"})
while True:
frame = await self._receive_json()
if frame.get("type") == "pong":
return True
if frame.get("type") == "error":
error = frame.get("error")
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def close(self) -> None:
await self._close_websocket()
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
async def _close_websocket(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed
def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
return self._http_session
async def _send_json(self, frame: dict[str, Any]) -> None:
if self._ws is None:
raise BeaverTerminalError("Beaver websocket is not connected")
await self._ws.send_json(frame)
async def _receive_json(self) -> dict[str, Any]:
if self._ws is None:
raise BeaverTerminalError("Beaver websocket is not connected")
message = await self._ws.receive()
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
if message.type == aiohttp.WSMsgType.ERROR:
raise BeaverTerminalConnectionClosed(
f"Beaver websocket error: {self._ws.exception()!r}"
)
if message.type != aiohttp.WSMsgType.TEXT:
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
data = message.json()
if not isinstance(data, dict):
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
return data
def client_from_env() -> BeaverTerminalClient:
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
return BeaverTerminalClient(
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
)
async def run_console() -> None:
logging.basicConfig(level=logging.INFO)
client = client_from_env()
try:
await client.connect()
logger.info("Connected to Beaver session_id=%s", client.session_id)
while True:
text = await asyncio.to_thread(input, "> ")
text = text.strip()
if not text:
continue
if text in {"quit", "exit"}:
return
try:
reply = await client.send_text(text)
except BeaverTerminalError as exc:
logger.error("Beaver turn failed: %s", exc)
continue
print(reply)
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(run_console())

File diff suppressed because it is too large Load Diff

391
hermes_gateway.py Normal file
View File

@ -0,0 +1,391 @@
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
import aiohttp
from livekit.agents import llm
from livekit.agents._exceptions import APIConnectionError
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectOptions,
NotGivenOr,
)
from livekit.agents.utils import shortuuid
@dataclass
class GatewaySessionState:
room_name: str
agent_id: str | None = None
session_key: str | None = None
session_mode: str = "per_room"
def __post_init__(self) -> None:
if self.session_mode != "per_room":
raise ValueError("Hermes gateway only supports CUSTOM_HERMES_SESSION_MODE=per_room")
if self.session_key is None:
suffix = self.agent_id or "default"
self.session_key = f"livekit:{self.room_name}:{suffix}"
class HermesGatewayLLM(llm.LLM):
def __init__(
self,
*,
url: str,
token: str | None,
state: GatewaySessionState,
agent_id: str | None = None,
model_name: str = "hermes-agent",
request_timeout: float = 30.0,
) -> None:
super().__init__()
self._url = url
self._token = token
self._state = state
self._agent_id = agent_id
self._model_name = model_name
self._request_timeout = request_timeout
self._http_session: aiohttp.ClientSession | None = None
@property
def model(self) -> str:
return self._model_name
@property
def provider(self) -> str:
return "hermes-gateway"
def chat(
self,
*,
chat_ctx: llm.ChatContext,
tools: list[llm.Tool] | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> llm.LLMStream:
return HermesGatewayLLMStream(
self,
chat_ctx=chat_ctx,
tools=tools or [],
conn_options=conn_options,
)
def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
timeout = aiohttp.ClientTimeout(total=self._request_timeout)
self._http_session = aiohttp.ClientSession(timeout=timeout)
return self._http_session
async def aclose(self) -> None:
if self._http_session is not None:
await self._http_session.close()
self._http_session = None
class HermesGatewayLLMStream(llm.LLMStream):
def __init__(
self,
llm: HermesGatewayLLM,
*,
chat_ctx: llm.ChatContext,
tools: list[llm.Tool],
conn_options: APIConnectOptions,
) -> None:
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
self._llm = llm
async def _run(self) -> None:
request_id = shortuuid("gwreq_")
async with self._llm._ensure_http_session().ws_connect(self._llm._url) as ws:
await _connect(ws, token=self._llm._token)
await _send_rpc(
ws,
method="sessions.create",
params={
"key": self._llm._state.session_key,
"sessionKey": self._llm._state.session_key,
"agentId": self._llm._agent_id,
"metadata": {
"source": "livekit",
"room": self._llm._state.room_name,
},
"idempotencyKey": self._llm._state.session_key,
},
request_id=shortuuid("gwcreate_"),
)
await _send_rpc(
ws,
method="sessions.send",
params={
"key": self._llm._state.session_key,
"sessionKey": self._llm._state.session_key,
"agentId": self._llm._agent_id,
"messages": chat_context_to_gateway_messages(self.chat_ctx),
"stream": True,
"idempotencyKey": request_id,
},
request_id=request_id,
wait_response=False,
)
streamed_text = ""
async for frame in _iter_gateway_frames(ws):
if is_error_response(frame, request_id=request_id):
raise APIConnectionError(_gateway_error_message(frame), retryable=False)
text = extract_text_delta(frame)
if text:
if text.startswith(streamed_text):
text = text[len(streamed_text) :]
if text:
streamed_text += text
self._event_ch.send_nowait(
llm.ChatChunk(
id=request_id,
delta=llm.ChoiceDelta(role="assistant", content=text),
)
)
if is_terminal_event(frame, request_id=request_id):
return
def build_connect_params(*, token: str | None) -> dict[str, Any]:
params: dict[str, Any] = {
"minProtocol": 3,
"maxProtocol": 4,
"client": {
"id": "gateway-client",
"version": "livekit-custom-agent",
"platform": "python",
"mode": "backend",
},
"role": "operator",
"scopes": ["operator.read", "operator.write"],
"caps": [],
"commands": [],
"permissions": {},
"locale": "zh-CN",
"userAgent": "livekit-custom-agent",
}
if token:
params["auth"] = {"token": token}
return params
def chat_context_to_gateway_messages(chat_ctx: llm.ChatContext) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []
for message in chat_ctx.messages():
content = _message_content_to_gateway_content(message.content)
if content is None:
continue
messages.append({"role": message.role, "content": content})
return messages
def extract_text_delta(frame: dict[str, Any]) -> str:
payload = frame.get("payload")
if not isinstance(payload, dict):
payload = frame
for path in (
("delta", "content"),
("delta", "text"),
("message", "delta", "content"),
("message", "delta", "text"),
("message", "content"),
("content",),
("text",),
):
value = _get_nested(payload, path)
text = _content_to_text(value)
if text:
return text
return ""
def is_terminal_event(frame: dict[str, Any], *, request_id: str) -> bool:
if frame.get("type") == "res" and frame.get("id") == request_id:
return True
event = frame.get("event")
if event in {
"agent.done",
"agent.completed",
"agent.error",
"session.message.completed",
"session.run.completed",
"sessions.run.completed",
"run.completed",
"run.failed",
}:
return True
payload = frame.get("payload")
if isinstance(payload, dict) and payload.get("done") is True:
return True
return False
def is_error_response(frame: dict[str, Any], *, request_id: str) -> bool:
if (
frame.get("type") == "res"
and frame.get("id") == request_id
and frame.get("ok") is False
):
return True
return frame.get("event") in {
"agent.error",
"session.error",
"session.run.failed",
"sessions.run.failed",
"run.failed",
}
def _gateway_error_message(frame: dict[str, Any]) -> str:
error = frame.get("error")
if isinstance(error, str):
return f"OpenClaw gateway request failed: {error}"
if isinstance(error, dict):
message = error.get("message") or error.get("error")
if isinstance(message, str):
return f"OpenClaw gateway request failed: {message}"
payload = frame.get("payload")
if isinstance(payload, dict):
message = payload.get("message") or payload.get("error")
if isinstance(message, str):
return f"OpenClaw gateway request failed: {message}"
return f"OpenClaw gateway request failed: {frame!r}"
async def _connect(ws: aiohttp.ClientWebSocketResponse, *, token: str | None) -> None:
first = await _receive_json(ws)
if first.get("event") != "connect.challenge":
raise RuntimeError(f"expected connect.challenge, received {first!r}")
request_id = shortuuid("gwconnect_")
await _send_rpc(
ws,
method="connect",
params=build_connect_params(token=token),
request_id=request_id,
wait_response=False,
)
response = await _wait_for_response(ws, request_id=request_id)
if not response.get("ok"):
raise RuntimeError(f"OpenClaw gateway connect failed: {response.get('error')!r}")
async def _send_rpc(
ws: aiohttp.ClientWebSocketResponse,
*,
method: str,
params: dict[str, Any],
request_id: str,
wait_response: bool = True,
) -> dict[str, Any] | None:
await ws.send_str(
json.dumps(
{
"type": "req",
"id": request_id,
"method": method,
"params": _drop_none(params),
}
)
)
if not wait_response:
return None
response = await _wait_for_response(ws, request_id=request_id)
if not response.get("ok", False):
raise RuntimeError(f"OpenClaw gateway RPC {method} failed: {response.get('error')!r}")
return response
async def _wait_for_response(
ws: aiohttp.ClientWebSocketResponse, *, request_id: str
) -> dict[str, Any]:
async for frame in _iter_gateway_frames(ws):
if frame.get("type") == "res" and frame.get("id") == request_id:
return frame
raise RuntimeError(f"OpenClaw gateway closed before response {request_id}")
async def _iter_gateway_frames(
ws: aiohttp.ClientWebSocketResponse,
) -> AsyncIterator[dict[str, Any]]:
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
yield json.loads(message.data)
elif message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
return
elif message.type == aiohttp.WSMsgType.ERROR:
raise RuntimeError(f"OpenClaw gateway websocket error: {ws.exception()!r}")
async def _receive_json(ws: aiohttp.ClientWebSocketResponse) -> dict[str, Any]:
message = await ws.receive()
if message.type != aiohttp.WSMsgType.TEXT:
raise RuntimeError(f"expected gateway text frame, received {message.type!r}")
return json.loads(message.data)
def _message_content_to_gateway_content(content: list[llm.ChatContent]) -> Any:
parts: list[dict[str, Any]] = []
for item in content:
if isinstance(item, str):
if item:
parts.append({"type": "text", "text": item})
elif isinstance(item, llm.ImageContent) and isinstance(item.image, str):
parts.append({"type": "image_url", "image_url": {"url": item.image}})
if not parts:
return None
if len(parts) == 1 and parts[0]["type"] == "text":
return parts[0]["text"]
return parts
def _content_to_text(value: Any) -> str:
if isinstance(value, str):
return value
if isinstance(value, list):
text_parts: list[str] = []
for item in value:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
text_parts.append(text)
return "".join(text_parts)
return ""
def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any:
current: Any = data
for key in path:
if not isinstance(current, dict):
return None
current = current.get(key)
return current
def _drop_none(value: Any) -> Any:
if isinstance(value, dict):
return {key: _drop_none(item) for key, item in value.items() if item is not None}
if isinstance(value, list):
return [_drop_none(item) for item in value]
return value

292
memory.py Normal file
View File

@ -0,0 +1,292 @@
from __future__ import annotations
import asyncio
import json
import logging
import re
from typing import Any
import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
logger = logging.getLogger("memory-recall")
_LOCATION_STOPWORDS = {
"哪里",
"在哪",
"在哪里",
"哪儿",
"位置",
"什么地方",
"帮我找",
"帮我寻找",
"找一下",
"",
"请问",
"",
"",
"",
}
class MemoryRecallClient:
def __init__(
self,
*,
url: str,
timeout: float = 5.0,
max_chars: int = 2000,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
self._url = url
self._timeout = timeout
self._max_chars = max_chars
self._api_key = api_key
self._http_session = http_session
self._cached_payload: Any | None = None
def _ensure_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = utils.http_context.http_session()
return self._http_session
async def recall(self, query: str) -> str:
query = query.strip()
if not query:
return ""
headers = {}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
try:
async with self._ensure_session().get(
self._url,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self._timeout),
) as resp:
if resp.status != 200:
error_text = await resp.text()
raise APIStatusError(
message=f"Memory recall error: {error_text}",
status_code=resp.status,
request_id=None,
body=error_text,
)
try:
data = await resp.json()
except aiohttp.ContentTypeError:
data = await resp.text()
self._cached_payload = data
return self._format_memory(data, query)
except asyncio.TimeoutError:
logger.warning(
"Memory recall timed out after %.1fs, using cached room graph", self._timeout
)
return self._format_cached_memory(query)
except aiohttp.ClientError as e:
logger.warning("Memory recall connection error: %s, using cached room graph", e)
return self._format_cached_memory(query)
except (APIConnectionError, APIStatusError, APITimeoutError) as e:
logger.warning("Memory recall failed: %s, using cached room graph", e)
return self._format_cached_memory(query)
def _format_memory(self, data: Any, query: str) -> str:
memory = _format_room_graph_memory(data, query)
if len(memory) > self._max_chars:
memory = memory[: self._max_chars].rstrip()
return memory
def _format_cached_memory(self, query: str) -> str:
if self._cached_payload is None:
return ""
return self._format_memory(self._cached_payload, query)
def _format_room_graph_memory(payload: Any, query: str) -> str:
if not isinstance(payload, dict):
logger.warning("Unsupported room graph response: %s", payload)
return ""
objects = payload.get("objects", [])
relations = payload.get("relations", [])
summary = payload.get("summary", "")
if not objects and not relations and not summary:
return ""
query_terms = _query_terms(query)
relevant_objects, relevant_relations = _relevant_room_graph(
objects=objects,
relations=relations,
query_terms=query_terms,
)
objects_text = json.dumps(
relevant_objects or _compact_items(objects, limit=12),
ensure_ascii=False,
separators=(",", ":"),
)
relations_text = json.dumps(
relevant_relations or _compact_items(relations, limit=24),
ensure_ascii=False,
separators=(",", ":"),
)
prompt = f"""
你是一个物品定位助手。
目标物品:{query}
相关物品:{objects_text}
相关空间关系:{relations_text}
房间概览:{summary}
回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()
logger.info(
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
query_terms,
len(relevant_objects),
len(objects) if isinstance(objects, list) else 0,
len(relevant_relations),
len(relations) if isinstance(relations, list) else 0,
len(prompt),
)
return prompt
def _query_terms(query: str) -> list[str]:
normalized = re.sub(r"[\s?。!,、,.!]", "", query)
for word in _LOCATION_STOPWORDS:
normalized = normalized.replace(word, "")
terms = [normalized] if normalized else []
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
if token not in _LOCATION_STOPWORDS and token not in terms:
terms.append(token)
return terms[:4]
def _relevant_room_graph(
*,
objects: Any,
relations: Any,
query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
return [], []
matched_ids: set[str] = set()
matched_objects: list[Any] = []
object_by_id: dict[str, Any] = {}
for obj in objects:
obj_id = _object_id(obj)
if obj_id:
object_by_id[obj_id] = obj
obj_text = _compact_text(obj)
if any(term and term in obj_text for term in query_terms):
matched_objects.append(obj)
if obj_id:
matched_ids.add(obj_id)
relevant_relations: list[Any] = []
related_ids: set[str] = set(matched_ids)
for relation in relations:
relation_text = _compact_text(relation)
relation_ids = _ids_in_value(relation)
if (
any(term and term in relation_text for term in query_terms)
or bool(matched_ids.intersection(relation_ids))
):
relevant_relations.append(relation)
related_ids.update(relation_ids)
relevant_objects = list(matched_objects)
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
for obj_id in related_ids:
obj = object_by_id.get(obj_id)
key = _object_key(obj)
if obj is not None and key not in seen_object_keys:
relevant_objects.append(obj)
seen_object_keys.add(key)
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
def _compact_items(items: Any, *, limit: int) -> list[Any]:
if not isinstance(items, list):
return []
return [_compact_item(item) for item in items[:limit]]
def _compact_item(item: Any) -> Any:
if not isinstance(item, dict):
return item
preferred_keys = (
"id",
"name",
"label",
"class",
"category",
"type",
"text",
"source",
"target",
"subject",
"object",
"relation",
"predicate",
"description",
)
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
return compact or item
def _object_id(obj: Any) -> str | None:
if not isinstance(obj, dict):
return None
for key in ("id", "object_id", "uuid", "name", "label"):
value = obj.get(key)
if isinstance(value, (str, int)):
return str(value)
return None
def _object_key(obj: Any) -> str:
return _object_id(obj) or _compact_text(obj)
def _ids_in_value(value: Any) -> set[str]:
ids: set[str] = set()
if isinstance(value, dict):
for key, item in value.items():
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
if isinstance(item, (str, int)):
ids.add(str(item))
elif isinstance(item, dict):
obj_id = _object_id(item)
if obj_id:
ids.add(obj_id)
ids.update(_ids_in_value(item))
elif isinstance(value, list):
for item in value:
ids.update(_ids_in_value(item))
return ids
def _compact_text(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))

View File

@ -1,188 +0,0 @@
import asyncio
import requests
import logging
from pathlib import Path
import uuid
import wave
import numpy as np
from datetime import datetime
from livekit import rtc
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("test-agent")
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud"
ROOM_NAME = "test-room20"
WAV_FILE = "2food.wav"
TEST_TIMEOUT = 30
class TestState:
def __init__(self):
self.agent_connected = False
self.tts_received = False
self.tts_count = 0
test_state = TestState()
def get_token(agent_name="my-agent"):
try:
resp = requests.get(
TOKEN_URL,
params={
"room": ROOM_NAME,
"identity": f"test-{uuid.uuid4().hex[:6]}",
"agent_name": agent_name,
},
timeout=5
)
resp.raise_for_status()
return resp.json()["token"]
except Exception as e:
logger.error(f"❌ 获取token失败: {e}")
raise
async def publish_wav(room, wav_path):
wav_path = Path(wav_path)
if not wav_path.exists():
logger.error(f"❌ WAV文件不存在: {wav_path}")
raise FileNotFoundError(f"文件不存在: {wav_path}")
logger.info(f"📂 开始上传: {wav_path}")
with wave.open(str(wav_path), "rb") as wf:
sample_rate = wf.getframerate()
num_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
logger.info(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
source = AudioSource(sample_rate, num_channels)
track = LocalAudioTrack.create_audio_track("mic", source)
await room.local_participant.publish_track(track)
logger.info("📡 已发布音轨")
frame_duration = 0.02
samples_per_frame = int(sample_rate * frame_duration)
while True:
data = wf.readframes(samples_per_frame)
if not data:
break
audio = np.frombuffer(data, dtype=np.int16)
if len(audio) == 0:
continue
samples_per_channel = len(audio) // num_channels
frame = AudioFrame(
data=data,
sample_rate=sample_rate,
num_channels=num_channels,
samples_per_channel=samples_per_channel,
)
await source.capture_frame(frame)
await asyncio.sleep(frame_duration)
logger.info("✅ WAV推流完成")
async def test_agent():
try:
logger.info("🔑 正在获取token...")
token = get_token()
logger.info("✅ Token获取成功")
room = rtc.Room()
@room.on("participant_connected")
def on_participant_connected(participant):
logger.info(f"✅ 参与者加入: {participant.identity}")
if "agent" in participant.identity.lower():
test_state.agent_connected = True
logger.info("🎉 Agent已连接")
@room.on("participant_disconnected")
def on_participant_disconnected(participant):
logger.info(f"❌ 参与者离开: {participant.identity}")
@room.on("track_subscribed")
def on_track_subscribed(track, publication, participant):
if track.kind == rtc.TrackKind.KIND_AUDIO:
test_state.tts_count += 1
logger.info(f"🎵 收到TTS音频! (第 {test_state.tts_count} 次)")
test_state.tts_received = True
logger.info(f"🔌 正在连接房间 {ROOM_NAME}...")
await room.connect(WS_URL, token)
logger.info("✅ 已连接到房间")
logger.info(f"🆔 本地参与者ID: {room.local_participant.identity}")
logger.info("⏳ 等待Agent连接...")
for i in range(10):
if test_state.agent_connected:
break
await asyncio.sleep(1)
if not test_state.agent_connected:
logger.warning("⚠️ Agent未连接")
return False
logger.info("🎙️ 正在上传测试音频...")
await publish_wav(room, WAV_FILE)
logger.info("⏳ 等待Agent响应...")
for i in range(TEST_TIMEOUT):
if test_state.tts_received:
logger.info("✅ 收到Agent TTS响应!")
break
if i % 5 == 0:
logger.info(f" 等待中... ({i+1}/{TEST_TIMEOUT}秒)")
await asyncio.sleep(1)
await asyncio.sleep(2)
logger.info("\n" + "="*60)
logger.info("✅ 测试结果")
logger.info("="*60)
logger.info(f"Agent连接: {'' if test_state.agent_connected else ''}")
logger.info(f"收到TTS响应: {'' if test_state.tts_received else ''}")
logger.info(f"TTS音频次数: {test_state.tts_count}")
logger.info("="*60)
await room.disconnect()
logger.info("✅ 已断开连接\n")
return test_state.agent_connected and test_state.tts_received
except Exception as e:
logger.error(f"❌ 测试失败: {e}", exc_info=True)
return False
async def main():
logger.info("🚀 开始测试custom_agent...\n")
success = await test_agent()
if success:
logger.info("✅ 测试成功custom_agent 正常工作")
logger.info("💡 提示: Agent内部的转录和响应日志只能在Agent自身看到")
logger.info(" 或通过 agent-starter-react 这样的客户端交互查看")
return 0
else:
logger.error("❌ 测试失败")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
exit(exit_code)

View File

@ -1,55 +0,0 @@
import asyncio
import logging
import wave
from asr import BlackboxSTT
from livekit import rtc
# 设置日志级别以查看输出
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-asr")
async def test():
# 替换为你本地的一个音频文件路径
audio_path = "/home/verachen/Music/voice/2food.wav"
# 初始化 ASR
stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice")
print(f"Testing ASR connectivity with file: {audio_path}")
try:
# 读取音频文件
with wave.open(audio_path, "rb") as wf:
frames = wf.readframes(wf.getnframes())
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
# 实际上 BlackboxSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
# 所以我们需要传递一个包含 AudioFrame 的 list
# 这里我们模拟一个 Frame
frame = rtc.AudioFrame(
data=frames,
sample_rate=wf.getframerate(),
num_channels=wf.getnchannels(),
samples_per_channel=wf.getnframes(),
)
# 调用 recognize
result = await stt.recognize(buffer=[frame])
if result.alternatives:
print("\n--- ASR Result ---")
print(f"Text: {result.alternatives[0].text}")
print("------------------\n")
else:
print("ASR returned no text.")
except FileNotFoundError:
print(f"Error: Audio file not found at {audio_path}")
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(test())

98
test_beaver_llm.py Normal file
View File

@ -0,0 +1,98 @@
import json
import aiohttp
from aiohttp import web
try:
from custom.beaver_llm import BeaverLLM, latest_user_text
except ModuleNotFoundError:
from beaver_llm import BeaverLLM, latest_user_text
from livekit.agents import ChatContext
def test_latest_user_text_uses_most_recent_user_message() -> None:
ctx = ChatContext.empty()
ctx.add_message(role="user", content="first")
ctx.add_message(role="assistant", content="ignored")
ctx.add_message(role="user", content=["second", "line"])
assert latest_user_text(ctx) == "second\nline"
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:livekit-room",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:livekit-room",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "beaver reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
beaver_llm = BeaverLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="livekit-room",
device_name="livekit-custom-agent",
)
ctx = ChatContext.empty()
ctx.add_message(role="system", content="ignored instructions")
ctx.add_message(role="user", content="hello beaver")
try:
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
finally:
await beaver_llm.aclose()
await runner.cleanup()
assert collected.text == "beaver reply"
assert received[0] == {
"type": "connect",
"peer_id": "livekit-room",
"device_name": "livekit-custom-agent",
"capabilities": ["text"],
}
assert received[1]["type"] == "message"
assert received[1]["message_id"].startswith("livekit-room-")
assert received[1]["message_id"].endswith("-000001")
assert received[1]["text"] == "hello beaver"

View File

@ -0,0 +1,426 @@
import asyncio
import json
import sys
from pathlib import Path
import aiohttp
import pytest
from aiohttp import web
if __name__ == "__main__":
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
raise SystemExit(pytest.main([__file__]))
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
def test_build_connect_frame_uses_stable_peer_id() -> None:
frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
assert frame == {
"type": "connect",
"peer_id": "device-001",
"device_name": "desk-terminal",
"capabilities": ["text"],
}
def test_build_message_frame_uses_message_id_and_text() -> None:
frame = build_message_frame(message_id="device-001-000001", text="hello")
assert frame == {
"type": "message",
"message_id": "device-001-000001",
"text": "hello",
}
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
assert generator.next_id() == "device-001-000008"
assert generator.next_id() == "device-001-000009"
assert generator.counter == 9
def test_message_id_generator_can_include_nonce() -> None:
generator = MessageIdGenerator(peer_id="device-001", nonce="run12345")
assert generator.next_id() == "device-001-run12345-000001"
assert generator.next_id() == "device-001-run12345-000002"
async def test_client_connects_sends_text_and_returns_assistant_reply(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "assistant reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert client.session_id == "terminal-dev:local:device-001"
assert reply == "assistant reply"
assert received == [
{
"type": "connect",
"peer_id": "device-001",
"device_name": "desk-terminal",
"capabilities": ["text"],
},
{
"type": "message",
"message_id": "device-001-000001",
"text": "hello",
},
]
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": False,
"duplicate": True,
"pending": False,
"reply": "cached assistant reply",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert reply == "cached assistant reply"
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json({"type": "error", "error": "text is required"})
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
with pytest.raises(BeaverTerminalError, match="text is required"):
await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
unused_tcp_port: int,
) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "failed turn",
"finish_reason": "error",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
with pytest.raises(BeaverTerminalError, match="failed turn"):
await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "ping":
await ws.send_json({"type": "pong"})
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
assert await client.ping()
finally:
await client.close()
await runner.cleanup()
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
unused_tcp_port: int,
) -> None:
connect_peer_ids: list[str] = []
message_ids: list[str] = []
connection_count = 0
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
nonlocal connection_count
connection_count += 1
current_connection = connection_count
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
connect_peer_ids.append(frame["peer_id"])
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
message_ids.append(frame["message_id"])
if current_connection == 1:
await ws.close()
continue
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-2",
"text": "reply after reconnect",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
await asyncio.sleep(0.01)
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert reply == "reply after reconnect"
assert connect_peer_ids == ["device-001", "device-001"]
assert message_ids == ["device-001-000001", "device-001-000002"]

262
test_hermes_gateway.py Normal file
View File

@ -0,0 +1,262 @@
import json
import aiohttp
import pytest
from aiohttp import web
from custom.hermes_gateway import (
GatewaySessionState,
HermesGatewayLLM,
build_connect_params,
chat_context_to_gateway_messages,
extract_text_delta,
is_error_response,
is_terminal_event,
)
from livekit.agents import ChatContext, llm
from livekit.agents._exceptions import APIConnectionError
def test_chat_context_to_gateway_messages_preserves_text_and_images() -> None:
ctx = ChatContext.empty()
ctx.add_message(role="system", content="system prompt")
ctx.add_message(role="user", content=["look here", llm.ImageContent(image="data:image/png;base64,abc")])
messages = chat_context_to_gateway_messages(ctx)
assert messages == [
{"role": "system", "content": "system prompt"},
{
"role": "user",
"content": [
{"type": "text", "text": "look here"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
},
]
def test_extract_text_delta_accepts_common_gateway_event_shapes() -> None:
assert (
extract_text_delta(
{"type": "event", "event": "agent", "payload": {"delta": {"content": "hi"}}}
)
== "hi"
)
assert extract_text_delta({"type": "event", "event": "agent", "payload": {"text": " there"}}) == " there"
assert (
extract_text_delta(
{
"type": "event",
"event": "session.message.delta",
"payload": {"message": {"content": [{"type": "text", "text": "!"}]}},
}
)
== "!"
)
def test_per_room_session_state_reuses_stable_session_key() -> None:
state = GatewaySessionState(room_name="kitchen-room", agent_id="helper")
assert state.session_key == "livekit:kitchen-room:helper"
state.session_key = "gateway-session-123"
assert state.session_key == "gateway-session-123"
def test_build_connect_params_uses_backend_operator_defaults() -> None:
params = build_connect_params(token="secret-token")
assert params["client"] == {
"id": "gateway-client",
"version": "livekit-custom-agent",
"platform": "python",
"mode": "backend",
}
assert params["role"] == "operator"
assert params["scopes"] == ["operator.read", "operator.write"]
assert params["auth"] == {"token": "secret-token"}
assert "device" not in params
def test_gateway_response_helpers_match_only_current_send_request() -> None:
assert is_terminal_event({"type": "res", "id": "send-1", "ok": True}, request_id="send-1")
assert is_error_response({"type": "res", "id": "send-1", "ok": False}, request_id="send-1")
assert not is_terminal_event({"type": "res", "id": "connect-1", "ok": True}, request_id="send-1")
assert not is_error_response({"type": "res", "id": "connect-1", "ok": False}, request_id="send-1")
def test_hermes_llm_reports_provider_and_model() -> None:
state = GatewaySessionState(room_name="kitchen", agent_id="helper")
gateway_llm = HermesGatewayLLM(
url="ws://gateway.test/ws",
token="token",
state=state,
agent_id="helper",
model_name="hermes-agent",
)
assert gateway_llm.provider == "hermes-gateway"
assert gateway_llm.model == "hermes-agent"
def test_gateway_session_state_rejects_non_per_room_mode() -> None:
with pytest.raises(ValueError, match="per_room"):
GatewaySessionState(room_name="kitchen", agent_id="helper", session_mode="per_turn")
async def test_llm_stream_sends_gateway_rpcs_and_yields_text(unused_tcp_port: int) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
payload = json.loads(message.data)
received.append(payload)
method = payload.get("method")
request_id = payload.get("id")
if method == "connect":
await ws.send_json({"type": "res", "id": request_id, "ok": True})
elif method == "sessions.create":
await ws.send_json(
{
"type": "res",
"id": request_id,
"ok": True,
"result": {"sessionKey": "livekit:kitchen:helper"},
}
)
elif method == "sessions.send":
await ws.send_json(
{
"type": "event",
"event": "agent",
"payload": {"delta": {"content": "你好"}},
}
)
await ws.send_json(
{
"type": "res",
"id": request_id,
"ok": True,
"result": {"usage": {"prompt_tokens": 3, "completion_tokens": 1}},
}
)
await ws.close()
return ws
app = web.Application()
app.router.add_get("/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
gateway_llm = HermesGatewayLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
token="secret-token",
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
agent_id="helper",
)
ctx = ChatContext.empty()
ctx.add_message(role="user", content="杯子在哪里")
try:
collected = await gateway_llm.chat(chat_ctx=ctx).collect()
finally:
await gateway_llm.aclose()
await runner.cleanup()
assert collected.text == "你好"
assert [item["method"] for item in received] == ["connect", "sessions.create", "sessions.send"]
send_request = received[2]
assert send_request["params"]["sessionKey"] == "livekit:kitchen:helper"
assert send_request["params"]["messages"] == [{"role": "user", "content": "杯子在哪里"}]
def test_extract_text_delta_reads_final_message_content() -> None:
assert (
extract_text_delta(
{
"type": "event",
"event": "session.message.completed",
"payload": {
"message": {
"content": [
{"type": "text", "text": "完整回复"},
]
}
},
}
)
== "完整回复"
)
def test_is_error_response_accepts_error_events() -> None:
assert is_error_response(
{"type": "event", "event": "agent.error", "payload": {"error": "boom"}},
request_id="send-1",
)
assert is_error_response(
{"type": "event", "event": "run.failed", "payload": {"message": "boom"}},
request_id="send-1",
)
async def test_llm_stream_maps_gateway_error_events_to_api_connection_error(
unused_tcp_port: int,
) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
await ws.send_json({"type": "event", "event": "connect.challenge", "payload": {}})
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
payload = json.loads(message.data)
method = payload.get("method")
request_id = payload.get("id")
if method == "connect":
await ws.send_json({"type": "res", "id": request_id, "ok": True})
elif method == "sessions.create":
await ws.send_json({"type": "res", "id": request_id, "ok": True})
elif method == "sessions.send":
await ws.send_json(
{
"type": "event",
"event": "run.failed",
"payload": {"message": "gateway exploded"},
}
)
await ws.close()
return ws
app = web.Application()
app.router.add_get("/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
gateway_llm = HermesGatewayLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/ws",
token=None,
state=GatewaySessionState(room_name="kitchen", agent_id="helper"),
agent_id="helper",
)
ctx = ChatContext.empty()
ctx.add_message(role="user", content="hello")
try:
with pytest.raises(APIConnectionError, match="gateway exploded"):
await gateway_llm.chat(chat_ctx=ctx).collect()
finally:
await gateway_llm.aclose()
await runner.cleanup()

View File

@ -1,130 +0,0 @@
import asyncio
import requests
from livekit import rtc
import wave
import numpy as np
from livekit.rtc import AudioSource, AudioFrame, LocalAudioTrack
TOKEN_URL = "http://localhost:8000/getToken"
WS_URL = "wss://esp32-vt80c4y6.livekit.cloud" # 你的 LiveKit Server 地址
ROOM_NAME = "test-room20"
import uuid
IDENTITY = f"uv-{uuid.uuid4().hex[:6]}"
# IDENTITY = "test-user0"
def get_token():
resp = requests.get(
TOKEN_URL,
params={
"room": ROOM_NAME,
"identity": IDENTITY,
"agent_name": "my-agent", # 关键!!!
},
)
data = resp.json()
return data["token"]
async def main():
token = get_token()
room = rtc.Room()
@room.on("participant_connected")
def on_participant_connected(participant):
print(f"✅ 有人加入房间: {participant.identity}")
@room.on("participant_disconnected")
def on_participant_disconnected(participant):
print(f"❌ 有人离开房间: {participant.identity}")
print("🔌 正在连接房间...")
await room.connect(WS_URL, token)
print("✅ 已连接房间:", ROOM_NAME)
print("当前房间成员:")
for p in room.remote_participants.values():
print(" -", p.identity)
@room.on("data_received")
def on_data_received(data, participant, kind, topic):
try:
msg = data.decode()
print(f"📩 来自 {participant.identity}: {msg}")
except:
print("📩 收到二进制数据")
@room.on("track_subscribed")
def on_track_subscribed(track, publication, participant):
print(f"🎧 订阅轨道: {participant.identity}")
if track.kind == rtc.TrackKind.KIND_AUDIO:
print("👉 TTS 音频来了")
# 等一下确保连接稳定
await asyncio.sleep(1)
await room.local_participant.publish_data(
b"hello",
reliable=True,
topic="chat"
)
# 上传 wav
await publish_wav(room, "2food.wav")
await room.disconnect()
async def publish_wav(room, wav_path):
print("🎵 开始上传本地 wav:", wav_path)
wf = wave.open(wav_path, "rb")
sample_rate = wf.getframerate()
num_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
print(f"📊 WAV信息: {sample_rate}Hz, {num_channels}ch, {sample_width*8}bit")
# 创建音频源
source = AudioSource(sample_rate, num_channels)
# 创建本地音轨
track = LocalAudioTrack.create_audio_track("mic", source)
# 发布轨道
await room.local_participant.publish_track(track)
print("📡 已发布音轨")
frame_duration = 0.02 # 20ms
samples_per_frame = int(sample_rate * frame_duration)
while True:
data = wf.readframes(samples_per_frame)
if not data:
break
# 用于计算长度
audio = np.frombuffer(data, dtype=np.int16)
if len(audio) == 0:
continue
samples_per_channel = len(audio) // num_channels
frame = AudioFrame(
data=data, # ✅ 关键:用 bytes
sample_rate=sample_rate,
num_channels=num_channels,
samples_per_channel=samples_per_channel,
)
await source.capture_frame(frame)
await asyncio.sleep(frame_duration)
print("✅ wav 推流结束")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,71 +0,0 @@
import asyncio
import os
import logging
from dotenv import load_dotenv
from livekit.agents.llm import ChatContext
from livekit.plugins import openai
# Configure logging to see what's happening
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-minimax")
async def test_minimax():
print("Loading .env...")
load_dotenv()
# Configuration from environment or defaults from custom_agent.py
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "MiniMaxAI")
# Using the hardcoded key from custom_agent.py as a fallback if not in .env
API_KEY = os.getenv("MINIMAX_API_KEY", "sk-orez64WkG1NkfksB5j_hGA")
import httpx
from openai import AsyncClient as OpenAIAsyncClient
print(f"Connecting to Minimax at {MINIMAX_BASE_URL} using model {MINIMAX_MODEL}")
# Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False)
# Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient(
api_key=API_KEY,
base_url=MINIMAX_BASE_URL,
http_client=http_client,
)
llm = openai.LLM(
model=MINIMAX_MODEL,
client=openai_client,
)
print("Creating ChatContext...")
chat_ctx = ChatContext()
chat_ctx.add_message(
content="Hello! Can you introduce yourself? Please reply in Chinese.",
role="user",
)
print(f"\n--- Testing Streaming Chat ---")
print(f"Request: {chat_ctx.items[-1].content}")
print("Response: ", end="", flush=True)
try:
print("\nCalling llm.chat()...")
stream = llm.chat(chat_ctx=chat_ctx)
print("Iterating over stream...")
async for chunk in stream:
if chunk.delta and chunk.delta.content:
print(chunk.delta.content, end="", flush=True)
print("\n--- Test Completed Successfully ---")
except Exception as e:
logger.error(f"\nTest failed with error: {e}")
if __name__ == "__main__":
print("Starting...")
try:
asyncio.run(asyncio.wait_for(test_minimax(), timeout=30))
except asyncio.TimeoutError:
print("\nTest timed out after 30 seconds.")
except Exception as e:
print(f"\nAn error occurred: {e}")

View File

@ -1,66 +0,0 @@
import asyncio
import logging
import os
from tts import BlackboxTTS
logging.basicConfig(level=logging.INFO)
async def test_tts():
# Use the URL from the user's curl command
url = "http://10.6.80.21:5002/tts-blackbox"
# Check if we have a real wav file to test with
# In the earlier find_by_name, we found tests/change-sophie.wav
prompt_wav = "/home/verachen/Music/voice/2food.wav"
if not os.path.exists(prompt_wav):
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
print(f"Testing BlackboxTTS with URL: {url}")
print(f"Using prompt wav: {prompt_wav}")
blackbox_tts = BlackboxTTS(
url=url,
model_name="voxcpmtts",
prompt_wav_path=prompt_wav,
params={
"streaming": "false",
"prompt_text": "澳门有乜嘢好食嘅",
"cfg_value": "2.0",
"inference_timesteps": "10",
"do_normalize": "true",
"denoise": "true",
"retry_badcase": "true",
"retry_badcase_max_times": "3",
"retry_badcase_ratio_threshold": "6.0",
},
)
text = "你好,这是一段测试文本"
print(f"Synthesizing text: {text}")
try:
stream = blackbox_tts.synthesize(text)
audio_frame = await stream.collect()
print("Successfully synthesized audio!")
print(
f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?"
)
# Actually AudioFrame has duration or samples
print(f"Samples: {len(audio_frame.data) // 2}")
# Save to file for manual check if possible
with open("test_output.wav", "wb") as f:
# This won't be a valid WAV yet if it's just raw PCM,
# but if collect() returns combined frames, we can use to_wav_bytes()
f.write(audio_frame.to_wav_bytes())
print("Saved output to test_output.wav")
except Exception as e:
print(f"TTS test failed: {e}")
if __name__ == "__main__":
asyncio.run(test_tts())

24
tts.py
View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import logging
import os
import time
import wave
from collections.abc import Mapping
from io import BytesIO
@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
self._tts: BlackboxTTS = tts
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
started_at = time.perf_counter()
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("text", self.input_text)
form.add_field("model_name", self._tts._model_name)
@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
content_type = resp.headers.get("Content-Type", "audio/wav")
logged_wav_format = False
wav_header_probe = bytearray()
first_audio_at: float | None = None
chunk_count = 0
total_bytes = 0
output_emitter.initialize(
request_id=utils.shortuuid(),
sample_rate=self._tts.sample_rate,
@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
async for data, _ in resp.content.iter_chunks():
if data:
chunk_count += 1
total_bytes += len(data)
if first_audio_at is None:
first_audio_at = time.perf_counter()
logger.info(
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
first_audio_at - started_at,
len(self.input_text),
len(data),
)
if not logged_wav_format:
wav_header_probe.extend(data)
logged_wav_format = _log_wav_format(
@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
logged_wav_format = True
output_emitter.push(data)
output_emitter.flush()
finished_at = time.perf_counter()
logger.info(
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
finished_at - started_at,
(first_audio_at - started_at) if first_audio_at else -1.0,
chunk_count,
total_bytes,
len(self.input_text),
)
except asyncio.TimeoutError as e:
raise APITimeoutError("TTS blackbox request timed out") from e
except aiohttp.ClientError as e: