Compare commits

...

2 Commits

Author SHA1 Message Date
820dc44053 fix: fix voice interupt 2026-06-12 11:17:12 +08:00
78b9138c17 feat: four mode agent 2026-06-04 15:54:09 +08:00
9 changed files with 921 additions and 55 deletions

View File

@ -2,7 +2,10 @@
LIVEKIT_URL=ws://localhost:7880
LIVEKIT_API_KEY=
LIVEKIT_API_SECRET=
CUSTOM_AGENT_NAME=my-agent
CUSTOM_AGENT_PROFILE=normal
# CUSTOM_AGENT_NAME=normal-agent
CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver
# Beaver terminal text WebSocket
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
@ -17,9 +20,14 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
CUSTOM_ASR_MAX_SPEECH_DURATION=12
# Keep false if forced ASR turns should reply even while input audio continues.
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
# LLM backend: openai/openai-compatible or hermes_gateway/openclaw.
CUSTOM_LLM_PROVIDER=beaver
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
# CUSTOM_LLM_PROVIDER=beaver
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
# OpenAI-compatible LLM

260
asr.py
View File

@ -1,11 +1,14 @@
import asyncio
import logging
from typing import Any, Optional, Union
from collections import deque
from collections.abc import AsyncIterable
from typing import Any
import aiohttp
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
@ -17,9 +20,15 @@ from livekit.agents import (
utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType
logger = logging.getLogger("blackbox-asr")
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
STT_CAPABILITIES = stt.STTCapabilities
class BlackboxSTT(stt.STT):
def __init__(
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
url: str,
*,
model_name: str = "sensevoice",
language: Optional[str] = "auto",
language: str | None = "auto",
output_language: str = "zh",
hotwords: Optional[str] = None,
itn: Optional[Union[bool, str]] = None,
chunk_mode: Optional[Union[bool, str]] = None,
hotwords: str | None = None,
itn: bool | str | None = None,
chunk_mode: bool | str | None = None,
timeout: float = 30.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=stt.STTCapabilities(
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
def _form_value(value: bool | str) -> str:
if isinstance(value, bool):
return str(value).lower()
return value
class BoundedStreamAdapter(stt.STT):
    """Expose a wrapped STT through a streaming interface segmented by a VAD.

    The adapter advertises ``streaming=True`` with no interim results and no
    diarization. Streams it creates are ``_BoundedStreamAdapterWrapper``
    instances, which receive ``max_speech_duration`` and
    ``pre_speech_duration`` unchanged.

    Args:
        stt: The STT whose ``recognize`` is used for every segment.
        vad: Voice activity detector used to delimit speech.
        max_speech_duration: Upper bound (seconds) passed to each stream;
            ``None`` is passed through as-is.
        pre_speech_duration: Amount of audio (seconds) passed to each stream
            as its pre-speech buffer size.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        # The ``stt`` parameter shadows the module of the same name inside this
        # method, hence the module-level STT_CAPABILITIES alias.
        capabilities = STT_CAPABILITIES(
            streaming=True,
            interim_results=False,
            diarization=False,
        )
        super().__init__(capabilities=capabilities)

        self._stt = stt
        self._vad = vad
        self._pre_speech_duration = pre_speech_duration
        self._max_speech_duration = max_speech_duration

        # Re-emit the wrapped STT's metrics from this adapter.
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The STT instance this adapter delegates to."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name reported by the wrapped STT."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider name reported by the wrapped STT."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Forward a one-shot recognition request to the wrapped STT."""
        event = await self._stt.recognize(
            buffer=buffer,
            language=language,
            conn_options=conn_options,
        )
        return event

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a new bounded recognition stream bound to this adapter."""
        wrapper = _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )
        return wrapper

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Relay a ``metrics_collected`` event from the wrapped STT."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Detach the metrics listener from the wrapped STT.

        Only the listener is removed here; nothing else is closed.
        """
        self._stt.off("metrics_collected", self._on_metrics_collected)
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Recognition stream that cuts VAD-delimited audio into bounded segments.

    Incoming frames are pushed to a VAD stream. While the VAD reports speech,
    frames accumulate into a segment; the segment is handed to the wrapped
    STT's ``recognize`` when the VAD reports end of speech, when the
    accumulated duration reaches ``max_speech_duration``, or when the input
    ends. Recognition runs on a separate worker so input forwarding is not
    blocked while a segment is being recognized.
    """

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        # The stream itself uses the no-retry options; the caller's
        # ``conn_options`` are kept for the per-segment recognize() calls.
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._vad = vad
        self._language = language
        self._pre_speech_duration = pre_speech_duration
        self._max_speech_duration = max_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        """Consume the event iterator without doing anything with the events."""
        async for _event in event_aiter:
            continue

    async def _run(self) -> None:
        """Drive the VAD, segment bookkeeping and recognition worker."""
        vad_stream = self._vad.stream()
        state_lock = asyncio.Lock()
        # ``None`` is the sentinel telling the recognition worker to stop.
        pending: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()

        # Shared state, mutated by the two producer coroutines under ``state_lock``.
        in_speech = False
        current_frames: list[rtc.AudioFrame] = []
        current_duration = 0.0
        backlog: deque[rtc.AudioFrame] = deque()
        backlog_duration = 0.0

        def seconds_of(frame: rtc.AudioFrame) -> float:
            return frame.samples_per_channel / frame.sample_rate

        def remember_pre_speech(frame: rtc.AudioFrame) -> None:
            # Keep a rolling window of recent non-speech audio, trimmed from
            # the left once it exceeds the configured pre-speech duration.
            nonlocal backlog_duration
            backlog.append(frame)
            backlog_duration += seconds_of(frame)
            while backlog and backlog_duration > self._pre_speech_duration:
                backlog_duration -= seconds_of(backlog.popleft())

        async def submit_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            # Announce end-of-speech and queue the audio for recognition.
            # Empty segments are dropped silently (no event is emitted).
            if not frames:
                return
            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )
            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await pending.put(frames)

        async def recognition_loop() -> None:
            # Recognize queued segments one at a time, in order.
            while (frames := await pending.get()) is not None:
                audio = utils.merge_frames(frames)
                try:
                    result = await self._wrapped_stt.recognize(
                        buffer=audio,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # A failed segment is logged and skipped; the stream keeps running.
                    logger.exception("ASR segment recognition failed")
                    continue

                best = result.alternatives[0] if result.alternatives else None
                if best is None or not best.text:
                    continue
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[best],
                    )
                )

        async def pump_input() -> None:
            nonlocal current_duration, current_frames, in_speech
            limit = self._max_speech_duration

            async for item in self._input_ch:
                if isinstance(item, self._FlushSentinel):
                    vad_stream.flush()
                    continue

                vad_stream.push_frame(item)
                overflow: list[rtc.AudioFrame] = []
                async with state_lock:
                    if not in_speech:
                        remember_pre_speech(item)
                    else:
                        current_frames.append(item)
                        current_duration += seconds_of(item)
                        if limit is not None and current_duration >= limit:
                            # Cut the segment here; speech stays "active" so
                            # following frames start a fresh segment.
                            overflow, current_frames = current_frames, []
                            current_duration = 0.0
                # Enqueue outside the lock so the VAD consumer is not blocked.
                if overflow:
                    await submit_segment(overflow, forced=True)

            vad_stream.end_input()

            # Input is exhausted: hand over whatever speech is still buffered.
            leftover: list[rtc.AudioFrame] = []
            async with state_lock:
                if in_speech and current_frames:
                    leftover, current_frames = current_frames, []
                    current_duration = 0.0
                in_speech = False
            if leftover:
                await submit_segment(leftover)

        async def consume_vad() -> None:
            nonlocal backlog_duration, current_duration, current_frames, in_speech

            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with state_lock:
                        if not in_speech:
                            in_speech = True
                            # Seed the new segment with the buffered pre-speech audio.
                            current_frames = list(backlog)
                            current_duration = sum(seconds_of(f) for f in current_frames)
                            backlog.clear()
                            backlog_duration = 0.0
                elif event.type == VADEventType.END_OF_SPEECH:
                    async with state_lock:
                        finished, current_frames = current_frames, []
                        current_duration = 0.0
                        in_speech = False
                    await submit_segment(finished)

        recognizer = asyncio.create_task(recognition_loop(), name="bounded_asr_recognize")
        producers = [
            asyncio.create_task(pump_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(consume_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*producers)
            # Both producers are done: let the worker drain the queue and exit.
            await pending.put(None)
            await recognizer
        finally:
            await utils.aio.cancel_and_wait(*producers, recognizer)
            await vad_stream.aclose()

View File

@ -109,6 +109,7 @@ class BeaverLLMStream(llm.LLMStream):
) -> None:
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
self._beaver_llm = beaver_llm
self._request_id = shortuuid("beaver_")
async def _run(self) -> None:
user_text = latest_user_text(self.chat_ctx)
@ -116,9 +117,12 @@ class BeaverLLMStream(llm.LLMStream):
reply = await self._beaver_llm._client.send_text(user_text)
if reply:
self._event_ch.send_nowait(
llm.ChatChunk(
id=shortuuid("beaver_"),
delta=llm.ChoiceDelta(role="assistant", content=reply),
)
self._send_text_chunk(reply)
def _send_text_chunk(self, text: str) -> None:
    """Emit ``text`` as a single assistant chunk carrying this stream's request id."""
    delta = llm.ChoiceDelta(role="assistant", content=text)
    chunk = llm.ChatChunk(id=self._request_id, delta=delta)
    self._event_ch.send_nowait(chunk)

View File

@ -28,6 +28,7 @@ class BeaverTerminalConnectionClosed(BeaverTerminalError):
@dataclass
class MessageIdGenerator:
peer_id: str
nonce: str | None = None
initial_counter: int = 0
def __post_init__(self) -> None:
@ -35,6 +36,8 @@ class MessageIdGenerator:
def next_id(self) -> str:
self.counter += 1
if self.nonce:
return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
return f"{self.peer_id}-{self.counter:06d}"
@ -77,15 +80,22 @@ class BeaverTerminalClient:
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
self._ws = await session.ws_connect(self._url)
await self._send_json(
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
)
frame = await self._receive_json()
if frame.get("type") != "connected":
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
session_id = frame.get("session_id")
self.session_id = session_id if isinstance(session_id, str) else None
try:
self._ws = await session.ws_connect(self._url)
await self._send_json(
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
)
frame = await self._receive_json()
if frame.get("type") != "connected":
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
session_id = frame.get("session_id")
self.session_id = session_id if isinstance(session_id, str) else None
except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
await self._cleanup_failed_connection()
raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
except Exception:
await self._cleanup_failed_connection()
raise
async def send_text(self, text: str) -> str:
for attempt in range(2):
@ -153,6 +163,13 @@ class BeaverTerminalClient:
await self._ws.close()
self._ws = None
async def _cleanup_failed_connection(self) -> None:
await self._close_websocket()
self.session_id = None
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import json
import logging
@ -9,12 +10,13 @@ from dataclasses import dataclass
from pathlib import Path
from beaver_llm import BeaverLLM
from beaver_terminal_client import BeaverTerminalError
from dotenv import load_dotenv
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from asr import BlackboxSTT, BoundedStreamAdapter
from livekit.agents import (
Agent,
AgentServer,
@ -32,7 +34,6 @@ from livekit.agents import (
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
@ -42,7 +43,6 @@ logger = logging.getLogger("custom-agent")
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
@ -62,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
@ -71,18 +71,121 @@ GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
DEFAULT_AGENT_PROFILE = "normal"
@dataclass(frozen=True)
class AgentProfile:
    """Immutable bundle of per-profile defaults for an agent worker."""

    # Name the worker uses when CUSTOM_AGENT_NAME is not set.
    agent_name: str
    # Default for CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Default for CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
AGENT_PROFILES = {
"normal": AgentProfile(
agent_name="normal-agent",
llm_provider="openai-compatible",
input_mode=VOICE_INPUT_MODE,
),
"beaver": AgentProfile(
agent_name="beaver-agent",
llm_provider="beaver",
input_mode=VOICE_INPUT_MODE,
),
"vision-normal": AgentProfile(
agent_name="vision-normal-agent",
llm_provider="openai-compatible",
input_mode=VISION_VOICE_INPUT_MODE,
),
"vision-beaver": AgentProfile(
agent_name="vision-beaver-agent",
llm_provider="beaver",
input_mode=VISION_VOICE_INPUT_MODE,
),
}
AGENT_PROFILE_ALIASES = {
"default": "normal",
"openai": "normal",
"openai-compatible": "normal",
"llm": "normal",
"text": "normal",
"voice": "normal",
"vision": "vision-normal",
"vision-llm": "vision-normal",
"vision-openai": "vision-normal",
"vision-openai-compatible": "vision-normal",
}
def _normalize_agent_profile(value: str | None) -> str:
    """Map a user-supplied profile string to a key of ``AGENT_PROFILES``.

    The value is trimmed, lowercased and has underscores turned into hyphens,
    then resolved through ``AGENT_PROFILE_ALIASES``. Missing or blank input
    yields ``DEFAULT_AGENT_PROFILE``; an unknown value does too, after a
    warning is logged.
    """
    cleaned = (value or "").strip()
    if not cleaned:
        return DEFAULT_AGENT_PROFILE

    key = cleaned.lower().replace("_", "-")
    key = AGENT_PROFILE_ALIASES.get(key, key)
    if key not in AGENT_PROFILES:
        logger.warning(
            "Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
            value,
            DEFAULT_AGENT_PROFILE,
        )
        return DEFAULT_AGENT_PROFILE
    return key
def _agent_profile_from_name(agent_name: str | None) -> str | None:
    """Return the profile key whose ``agent_name`` matches, or ``None``.

    The input is trimmed, lowercased and has underscores turned into hyphens
    before comparison. Missing or blank input returns ``None``.
    """
    wanted = (agent_name or "").strip().lower().replace("_", "-")
    if not wanted:
        return None
    matches = (
        profile_name
        for profile_name, profile in AGENT_PROFILES.items()
        if profile.agent_name == wanted
    )
    return next(matches, None)
def _selected_agent_profile_name() -> str:
    """Pick the active profile key from the environment.

    Precedence: a non-blank ``CUSTOM_AGENT_PROFILE`` (normalized), then a
    profile inferred from ``CUSTOM_AGENT_NAME``, then ``DEFAULT_AGENT_PROFILE``.
    """
    explicit = os.getenv("CUSTOM_AGENT_PROFILE")
    if explicit and explicit.strip():
        return _normalize_agent_profile(explicit)

    inferred = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME"))
    return DEFAULT_AGENT_PROFILE if inferred is None else inferred
AGENT_PROFILE_NAME = _selected_agent_profile_name()
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"confident",
"confused",
"cool",
"crying",
"delicious",
"embarrassed",
"funny",
"happy",
"kissy",
"laughing",
"loving",
"neutral",
"relaxed",
"sad",
"shocked",
"silly",
"sleepy",
"surprised",
"fearful",
"calm",
"concerned",
"thinking",
"winking",
}
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
@ -292,6 +395,7 @@ class CustomAgent(Agent):
yield chunk
return _stream()
async def _observe_emotion_prefix(
self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
@ -547,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
def _with_vision_as_latest_user_message(
chat_ctx: ChatContext, vision_frame: VisionFrame
) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
@ -614,7 +720,25 @@ def _model_image_save_dir_from_env() -> Path | None:
return Path(__file__).with_name("model_images")
server = AgentServer()
def _agent_server_from_env() -> AgentServer:
    """Build the ``AgentServer``, honoring ``CUSTOM_AGENT_HTTP_PORT`` when set.

    With the variable unset the server is created with its defaults. A value
    that is not an integer, or lies outside 0-65535, is replaced by 0 after a
    warning is logged.
    """
    raw_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
    if raw_port is None:
        return AgentServer()

    try:
        port = int(raw_port)
    except ValueError:
        logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        port = 0

    if not 0 <= port <= 65535:
        logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        port = 0

    return AgentServer(port=port)
server = _agent_server_from_env()
def prewarm(proc: JobProcess) -> None:
@ -636,14 +760,17 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
INPUT_MODE = _normalize_input_mode(
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
)
if LLM_PROVIDER not in {
"openai",
"openai-compatible",
@ -656,7 +783,10 @@ async def entrypoint(ctx: JobContext) -> None:
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
logger.info(
"Using LLM provider=%s model=%s base_url=%s",
"Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
AGENT_PROFILE_NAME,
AGENT_NAME or "<automatic>",
INPUT_MODE,
LLM_PROVIDER,
LLM_MODEL,
LLM_BASE_URL or "OpenAI default",
@ -673,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
beaver_warmup_text: str | None = None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -683,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
itn=os.getenv("CUSTOM_ASR_ITN"),
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
)
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
stt_stream = BoundedStreamAdapter(
stt=blackbox_stt,
vad=ctx.proc.userdata["vad"],
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
)
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
allow_interruptions = _env_bool(
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
ASR_MAX_SPEECH_DURATION <= 0,
)
if LLM_PROVIDER == "beaver":
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
if not beaver_url:
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
raise RuntimeError(
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
)
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
beaver_peer_id = (
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
)
beaver_device_name = (
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
or "livekit-custom-agent"
@ -702,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
)
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
text_llm = base_llm
vision_llm = base_llm
logger.info(
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
beaver_url,
beaver_peer_id,
beaver_device_name,
ctx.room.name,
base_llm.session_id,
bool(beaver_warmup_text and beaver_warmup_text.strip()),
len(warmup_reply) if warmup_reply is not None else 0,
)
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
@ -805,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
# 4. Silero VAD
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=MultilingualModel(),
turn_detection=turn_detection,
interruption={
"enabled": allow_interruptions,
"resume_false_interruption": True,
"false_interruption_timeout": 1.0,
},
@ -837,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
def _interrupt_done(fut: asyncio.Future[None]) -> None:
try:
fut.result()
except Exception:
logger.exception("Bridge interrupt failed")
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
reason = None if payload is None else payload.get("reason")
logger.info(
"Received bridge interrupt: topic=%s reason=%s",
packet_topic or "<default>",
reason if isinstance(reason, str) and reason else "<unspecified>",
)
try:
interrupt_fut = session.interrupt(force=True)
except RuntimeError:
logger.exception("Bridge interrupt received before AgentSession was running")
return
interrupt_fut.add_done_callback(_interrupt_done)
if packet_topic == INTERRUPT_TOPIC:
payload: dict[str, object] | None = None
try:
decoded = json.loads(data_packet.data.decode("utf-8"))
if isinstance(decoded, dict):
payload = decoded
except Exception:
logger.exception("Failed to decode interrupt payload")
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
logger.exception("Failed to decode data payload")
return
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
@ -904,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
),
record=_recording_options_from_env(),
)
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
_start_beaver_background_warmup(
ctx=ctx,
beaver_llm=base_llm,
warmup_text=beaver_warmup_text,
)
def _start_beaver_background_warmup(
    *,
    ctx: JobContext,
    beaver_llm: BeaverLLM,
    warmup_text: str | None,
) -> None:
    """Connect the Beaver LLM in a background task.

    Failures are logged and swallowed so they never propagate out of the task.
    A shutdown callback registered on ``ctx`` cancels the task if it is still
    running at that point.
    """
    warmup_requested = bool(warmup_text and warmup_text.strip())

    async def _handshake() -> None:
        try:
            reply = await beaver_llm.connect(warmup_text=warmup_text)
        except BeaverTerminalError:
            logger.warning(
                "Beaver background handshake failed; will retry on first user turn",
                exc_info=True,
            )
        except Exception:
            logger.exception("Unexpected Beaver background handshake failure")
        else:
            logger.info(
                "Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
                ctx.room.name,
                beaver_llm.session_id,
                warmup_requested,
                0 if reply is None else len(reply),
            )

    handshake_task = asyncio.create_task(_handshake(), name="beaver_background_warmup")

    async def _cancel_warmup() -> None:
        # Nothing to do once the handshake has finished on its own.
        if not handshake_task.done():
            handshake_task.cancel()
            try:
                await handshake_task
            except asyncio.CancelledError:
                pass

    ctx.add_shutdown_callback(_cancel_warmup)
def _tts_params_from_env(model_name: str) -> dict[str, str]:

264
start_agent_profiles.py Normal file
View File

@ -0,0 +1,264 @@
from __future__ import annotations
import argparse
import asyncio
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from dotenv import load_dotenv
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
@dataclass(frozen=True)
class AgentProfile:
    """Immutable settings exported to a child agent process for one profile."""

    # Exported to the child as CUSTOM_AGENT_NAME.
    agent_name: str
    # Exported to the child as CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Exported to the child as CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
AGENT_PROFILES = {
"normal": AgentProfile(
agent_name="normal-agent",
llm_provider="openai-compatible",
input_mode="voice",
),
"beaver": AgentProfile(
agent_name="beaver-agent",
llm_provider="beaver",
input_mode="voice",
),
"vision-normal": AgentProfile(
agent_name="vision-normal-agent",
llm_provider="openai-compatible",
input_mode="vision_voice",
),
"vision-beaver": AgentProfile(
agent_name="vision-beaver-agent",
llm_provider="beaver",
input_mode="vision_voice",
),
}
DEFAULT_PROFILES = ("normal", "beaver")
def _env_bool(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:
return default
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
def _parse_profiles(value: str) -> list[str]:
    """Parse a comma-separated profile list into unique profile keys.

    ``"all"`` (any case) selects every key of ``AGENT_PROFILES``. Otherwise
    each entry is trimmed, lowercased and has underscores turned into hyphens;
    blank entries are ignored and duplicates keep their first position.

    Raises:
        ValueError: If an entry is not a known profile, or no profile remains.
    """
    if value.strip().lower() == "all":
        return list(AGENT_PROFILES)

    # A dict is used as an insertion-ordered set.
    selected: dict[str, None] = {}
    for raw_profile in value.split(","):
        name = raw_profile.strip().lower().replace("_", "-")
        if not name:
            continue
        if name not in AGENT_PROFILES:
            valid = ", ".join([*AGENT_PROFILES, "all"])
            raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
        selected.setdefault(name, None)

    if not selected:
        raise ValueError("at least one profile is required")
    return list(selected)
def _parse_http_port_base(value: str | None) -> int:
if value is None or not value.strip():
return 0
try:
port = int(value)
except ValueError as exc:
raise ValueError(f"invalid HTTP port base {value!r}") from exc
if port < 0 or port > 65535:
raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
return port
def _profile_http_port(http_port_base: int, index: int) -> int:
if http_port_base == 0:
return 0
port = http_port_base + index
if port > 65535:
raise ValueError("HTTP port range exceeds 65535")
return port
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
    """Build the environment for a child agent process.

    Starts from a copy of the current environment and overrides the
    profile-specific ``CUSTOM_*`` variables.
    """
    profile = AGENT_PROFILES[profile_name]
    overrides = {
        "CUSTOM_AGENT_PROFILE": profile_name,
        "CUSTOM_AGENT_NAME": profile.agent_name,
        "CUSTOM_LLM_PROVIDER": profile.llm_provider,
        "CUSTOM_AGENT_INPUT_MODE": profile.input_mode,
        "CUSTOM_AGENT_HTTP_PORT": str(http_port),
    }
    return {**os.environ, **overrides}
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
while line := await stream.readline():
text = line.decode("utf-8", errors="replace").rstrip()
print(f"[{prefix}] {text}", flush=True)
# Strong references to the output-forwarding tasks. The event loop only keeps
# weak references to tasks, so a task nobody holds may be garbage-collected
# before it finishes, which would stop draining the child's pipes.
_OUTPUT_TASKS: set[asyncio.Task[None]] = set()


def _forward_output(prefix: str, stream: asyncio.StreamReader) -> None:
    """Start a tracked background task that pipes ``stream`` to stdout."""
    task = asyncio.create_task(_pipe_output(prefix, stream))
    _OUTPUT_TASKS.add(task)
    # Drop the reference once the stream hits EOF and the task completes.
    task.add_done_callback(_OUTPUT_TASKS.discard)


async def _start_profile(
    profile_name: str,
    *,
    mode: str,
    http_port: int,
    reload: bool,
) -> asyncio.subprocess.Process:
    """Spawn ``custom_agent.py`` for one profile and forward its output.

    Args:
        profile_name: Key into ``AGENT_PROFILES``.
        mode: CLI mode passed to ``custom_agent.py``.
        http_port: Value exported to the child as ``CUSTOM_AGENT_HTTP_PORT``.
        reload: When false and ``mode`` is ``"dev"``, ``--no-reload`` is added.

    Returns:
        The started subprocess. Its stdout and stderr are piped and copied to
        this process's stdout, prefixed with the profile name.
    """
    profile = AGENT_PROFILES[profile_name]
    script_path = Path(__file__).with_name("custom_agent.py")
    args = [sys.executable, str(script_path), mode]
    if mode == "dev" and not reload:
        args.append("--no-reload")

    process = await asyncio.create_subprocess_exec(
        *args,
        env=_child_env(profile_name, http_port=http_port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(
        f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
        f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
        flush=True,
    )

    for stream in (process.stdout, process.stderr):
        if stream is not None:
            _forward_output(profile_name, stream)
    return process
async def _terminate(processes: list[asyncio.subprocess.Process]) -> None:
for process in processes:
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(
asyncio.gather(*(process.wait() for process in processes)),
timeout=10.0,
)
except asyncio.TimeoutError:
for process in processes:
if process.returncode is None:
process.kill()
await asyncio.gather(*(process.wait() for process in processes))
async def _run(
    profiles: list[str],
    *,
    mode: str,
    http_port_base: int,
    reload: bool,
) -> int:
    """Run one child agent per profile until one exits or a signal arrives.

    Children are started in order. The function then waits for either the
    first child exit or SIGINT/SIGTERM, terminates every child, and returns.

    Args:
        profiles: Profile keys to start.
        mode: CLI mode forwarded to each child.
        http_port_base: Base port; each child gets ``base + index`` (0 stays 0).
        reload: Forwarded to ``_start_profile``.

    Returns:
        0 when stopped by a signal; otherwise the exit code of a child that
        ended the wait, preferring a non-zero code if several exited together.
    """
    processes: list[asyncio.subprocess.Process] = []
    stop_task: asyncio.Task[bool] | None = None
    exit_code = 0

    try:
        # Spawn inside the try block so a failure part-way through still
        # tears down the children that were already started.
        for index, profile in enumerate(profiles):
            processes.append(
                await _start_profile(
                    profile,
                    mode=mode,
                    http_port=_profile_http_port(http_port_base, index),
                    reload=reload,
                )
            )

        stop_event = asyncio.Event()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, stop_event.set)
            except NotImplementedError:
                # Event loops without signal support; such platforms rely on
                # KeyboardInterrupt plus the finally clause below.
                pass

        wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
        stop_task = asyncio.create_task(stop_event.wait())
        done, _ = await asyncio.wait(
            [*wait_tasks, stop_task],
            return_when=asyncio.FIRST_COMPLETED,
        )

        if stop_task not in done:
            for task in done:
                process = wait_tasks[task]
                code = task.result() or 0
                print(f"agent profile process exited: pid={process.pid} code={code}", flush=True)
                # Keep the first failure; a clean exit seen afterwards in the
                # same batch must not mask it.
                if exit_code == 0:
                    exit_code = code
        return exit_code
    finally:
        if stop_task is not None:
            stop_task.cancel()
        await _terminate(processes)
def main() -> None:
    """Parse the command line, start the requested profiles and exit.

    The process exit status is the value returned by ``_run``. Invalid
    ``--profiles`` or ``--http-port-base`` values are reported through the
    argument parser's error path.
    """
    default_profiles = os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES))
    default_port_base = os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0")
    default_reload = _env_bool("CUSTOM_AGENT_DEV_RELOAD", False)

    parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
    parser.add_argument(
        "mode",
        nargs="?",
        default="dev",
        choices=("console", "dev", "start", "connect"),
        help="custom_agent.py CLI mode to run for each profile",
    )
    parser.add_argument(
        "--profiles",
        default=default_profiles,
        help="comma-separated profiles to start, or 'all'",
    )
    parser.add_argument(
        "--http-port-base",
        default=default_port_base,
        help="base HTTP health-check port for profile workers; 0 lets the OS assign free ports",
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        default=default_reload,
        help="enable auto-reload in dev mode",
    )
    options = parser.parse_args()

    try:
        selected = _parse_profiles(options.profiles)
        port_base = _parse_http_port_base(options.http_port_base)
    except ValueError as exc:
        # parser.error() prints usage and exits, so the names above are
        # always bound past this point.
        parser.error(str(exc))

    status = asyncio.run(
        _run(
            selected,
            mode=options.mode,
            http_port_base=port_base,
            reload=options.reload,
        )
    )
    raise SystemExit(status)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,4 @@
import asyncio
import json
import aiohttp
@ -96,3 +97,73 @@ async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
assert received[1]["message_id"].startswith("livekit-room-")
assert received[1]["message_id"].endswith("-000001")
assert received[1]["text"] == "hello beaver"
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
    unused_tcp_port: int,
) -> None:
    """A delayed assistant reply is the only chunk the LLM stream yields."""
    session_id = "terminal-dev:local:livekit-room"

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            frame_type = frame["type"]

            if frame_type == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": session_id,
                    }
                )
                continue
            if frame_type != "message":
                continue

            ack = {
                "type": "ack",
                "message_id": frame["message_id"],
                "session_id": session_id,
                "accepted": True,
            }
            reply = {
                "type": "message",
                "role": "assistant",
                "message_id": frame["message_id"],
                "run_id": "run-1",
                "text": "beaver reply",
                "finish_reason": "stop",
            }
            await ws.send_json(ack)
            # Delay between the ack and the actual reply.
            await asyncio.sleep(0.05)
            await ws.send_json(reply)

        return ws

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", unused_tcp_port).start()

    beaver_llm = BeaverLLM(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id="livekit-room",
        device_name="livekit-custom-agent",
    )
    ctx = ChatContext.empty()
    ctx.add_message(role="user", content="hello beaver")

    try:
        chunks: list[str] = []
        async with beaver_llm.chat(chat_ctx=ctx) as stream:
            async for chunk in stream:
                delta = chunk.delta
                if delta and delta.content:
                    chunks.append(delta.content)
    finally:
        await beaver_llm.aclose()
        await runner.cleanup()

    assert chunks == ["beaver reply"]

View File

@ -14,6 +14,7 @@ if __name__ == "__main__":
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
@ -22,6 +23,7 @@ try:
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalConnectionClosed,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
@ -244,6 +246,36 @@ async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
await runner.cleanup()
async def test_client_cleans_up_owned_session_when_connect_fails(
    unused_tcp_port: int,
) -> None:
    """A failed websocket upgrade leaves no session or socket on the client."""

    async def websocket_handler(request: web.Request) -> web.Response:
        # Plain HTTP response instead of a websocket upgrade.
        return web.Response(status=200, text="not a websocket")

    app = web.Application()
    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", unused_tcp_port).start()

    peer_id = "device-001"
    client = BeaverTerminalClient(
        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
        peer_id=peer_id,
        device_name="desk-terminal",
        message_ids=MessageIdGenerator(peer_id=peer_id),
    )

    try:
        with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
            await client.connect()
        assert client._http_session is None
        assert client._ws is None
    finally:
        await client.close()
        await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
unused_tcp_port: int,
) -> None:

7
tts.py
View File

@ -7,7 +7,6 @@ import time
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import Optional
import aiohttp
@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS):
*,
url: str,
model_name: str = "voxcpmtts",
params: Optional[Mapping[str, object]] = None,
prompt_wav_path: Optional[str] = None,
params: Mapping[str, object] | None = None,
prompt_wav_path: str | None = None,
prompt_wav_field: str = "prompt_wav",
sample_rate: int = 16000,
num_channels: int = 1,
timeout: float = 60.0,
http_session: Optional[aiohttp.ClientSession] = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(streaming=False),