Compare commits
2 Commits
52e6d3cd9c
...
beaver
| Author | SHA1 | Date | |
|---|---|---|---|
| 820dc44053 | |||
| 78b9138c17 |
14
.env.example
14
.env.example
@ -2,7 +2,10 @@
|
|||||||
LIVEKIT_URL=ws://localhost:7880
|
LIVEKIT_URL=ws://localhost:7880
|
||||||
LIVEKIT_API_KEY=
|
LIVEKIT_API_KEY=
|
||||||
LIVEKIT_API_SECRET=
|
LIVEKIT_API_SECRET=
|
||||||
CUSTOM_AGENT_NAME=my-agent
|
|
||||||
|
CUSTOM_AGENT_PROFILE=normal
|
||||||
|
# CUSTOM_AGENT_NAME=normal-agent
|
||||||
|
CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver
|
||||||
|
|
||||||
# Beaver terminal text WebSocket
|
# Beaver terminal text WebSocket
|
||||||
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
|
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
|
||||||
@ -17,9 +20,14 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh
|
|||||||
CUSTOM_ASR_HOTWORDS=
|
CUSTOM_ASR_HOTWORDS=
|
||||||
CUSTOM_ASR_ITN=
|
CUSTOM_ASR_ITN=
|
||||||
CUSTOM_ASR_CHUNK_MODE=
|
CUSTOM_ASR_CHUNK_MODE=
|
||||||
|
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
|
||||||
|
CUSTOM_ASR_MAX_SPEECH_DURATION=12
|
||||||
|
# Keep false if forced ASR turns should reply even while input audio continues.
|
||||||
|
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
|
||||||
|
|
||||||
# LLM backend: openai/openai-compatible or hermes_gateway/openclaw.
|
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
|
||||||
CUSTOM_LLM_PROVIDER=beaver
|
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
|
||||||
|
# CUSTOM_LLM_PROVIDER=beaver
|
||||||
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
|
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
|
||||||
|
|
||||||
# OpenAI-compatible LLM
|
# OpenAI-compatible LLM
|
||||||
|
|||||||
260
asr.py
260
asr.py
@ -1,11 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional, Union
|
from collections import deque
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from livekit import rtc
|
from livekit import rtc
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
NOT_GIVEN,
|
NOT_GIVEN,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIConnectOptions,
|
APIConnectOptions,
|
||||||
@ -17,9 +20,15 @@ from livekit.agents import (
|
|||||||
utils,
|
utils,
|
||||||
)
|
)
|
||||||
from livekit.agents.utils import is_given
|
from livekit.agents.utils import is_given
|
||||||
|
from livekit.agents.vad import VAD, VADEventType
|
||||||
|
|
||||||
logger = logging.getLogger("blackbox-asr")
|
logger = logging.getLogger("blackbox-asr")
|
||||||
|
|
||||||
|
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
|
||||||
|
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
|
||||||
|
)
|
||||||
|
STT_CAPABILITIES = stt.STTCapabilities
|
||||||
|
|
||||||
|
|
||||||
class BlackboxSTT(stt.STT):
|
class BlackboxSTT(stt.STT):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
|
|||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
model_name: str = "sensevoice",
|
model_name: str = "sensevoice",
|
||||||
language: Optional[str] = "auto",
|
language: str | None = "auto",
|
||||||
output_language: str = "zh",
|
output_language: str = "zh",
|
||||||
hotwords: Optional[str] = None,
|
hotwords: str | None = None,
|
||||||
itn: Optional[Union[bool, str]] = None,
|
itn: bool | str | None = None,
|
||||||
chunk_mode: Optional[Union[bool, str]] = None,
|
chunk_mode: bool | str | None = None,
|
||||||
timeout: float = 30.0,
|
timeout: float = 30.0,
|
||||||
http_session: Optional[aiohttp.ClientSession] = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
capabilities=stt.STTCapabilities(
|
capabilities=stt.STTCapabilities(
|
||||||
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
|
|||||||
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||||
|
|
||||||
|
|
||||||
def _form_value(value: Union[bool, str]) -> str:
|
def _form_value(value: bool | str) -> str:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return str(value).lower()
|
return str(value).lower()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BoundedStreamAdapter(stt.STT):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
stt: stt.STT,
|
||||||
|
vad: VAD,
|
||||||
|
max_speech_duration: float | None = 12.0,
|
||||||
|
pre_speech_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
capabilities=STT_CAPABILITIES(
|
||||||
|
streaming=True,
|
||||||
|
interim_results=False,
|
||||||
|
diarization=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._vad = vad
|
||||||
|
self._stt = stt
|
||||||
|
self._max_speech_duration = max_speech_duration
|
||||||
|
self._pre_speech_duration = pre_speech_duration
|
||||||
|
self._stt.on("metrics_collected", self._on_metrics_collected)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wrapped_stt(self) -> stt.STT:
|
||||||
|
return self._stt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._stt.model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> str:
|
||||||
|
return self._stt.provider
|
||||||
|
|
||||||
|
async def _recognize_impl(
|
||||||
|
self,
|
||||||
|
buffer: utils.AudioBuffer,
|
||||||
|
*,
|
||||||
|
language: NotGivenOr[str] = NOT_GIVEN,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
) -> stt.SpeechEvent:
|
||||||
|
return await self._stt.recognize(
|
||||||
|
buffer=buffer, language=language, conn_options=conn_options
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
language: NotGivenOr[str] = NOT_GIVEN,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
) -> stt.RecognizeStream:
|
||||||
|
return _BoundedStreamAdapterWrapper(
|
||||||
|
self,
|
||||||
|
vad=self._vad,
|
||||||
|
wrapped_stt=self._stt,
|
||||||
|
language=language,
|
||||||
|
conn_options=conn_options,
|
||||||
|
max_speech_duration=self._max_speech_duration,
|
||||||
|
pre_speech_duration=self._pre_speech_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
self.emit("metrics_collected", *args, **kwargs)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
self._stt.off("metrics_collected", self._on_metrics_collected)
|
||||||
|
|
||||||
|
|
||||||
|
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
adapter: BoundedStreamAdapter,
|
||||||
|
*,
|
||||||
|
vad: VAD,
|
||||||
|
wrapped_stt: stt.STT,
|
||||||
|
language: NotGivenOr[str],
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
max_speech_duration: float | None,
|
||||||
|
pre_speech_duration: float,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
|
||||||
|
self._vad = vad
|
||||||
|
self._wrapped_stt = wrapped_stt
|
||||||
|
self._wrapped_stt_conn_options = conn_options
|
||||||
|
self._language = language
|
||||||
|
self._max_speech_duration = max_speech_duration
|
||||||
|
self._pre_speech_duration = pre_speech_duration
|
||||||
|
|
||||||
|
async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
|
||||||
|
async for _ in event_aiter:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
vad_stream = self._vad.stream()
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
|
||||||
|
|
||||||
|
speech_active = False
|
||||||
|
segment_frames: list[rtc.AudioFrame] = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
pre_roll_frames: deque[rtc.AudioFrame] = deque()
|
||||||
|
pre_roll_duration = 0.0
|
||||||
|
|
||||||
|
def _frame_duration(frame: rtc.AudioFrame) -> float:
|
||||||
|
return frame.samples_per_channel / frame.sample_rate
|
||||||
|
|
||||||
|
def _append_pre_roll(frame: rtc.AudioFrame) -> None:
|
||||||
|
nonlocal pre_roll_duration
|
||||||
|
pre_roll_frames.append(frame)
|
||||||
|
pre_roll_duration += _frame_duration(frame)
|
||||||
|
|
||||||
|
while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
|
||||||
|
pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())
|
||||||
|
|
||||||
|
async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
|
||||||
|
if not frames:
|
||||||
|
return
|
||||||
|
|
||||||
|
if forced:
|
||||||
|
logger.info(
|
||||||
|
"Forcing ASR segment after %.2fs of continuous speech",
|
||||||
|
self._max_speech_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
|
||||||
|
await recognize_queue.put(frames)
|
||||||
|
|
||||||
|
async def _recognize_worker() -> None:
|
||||||
|
while True:
|
||||||
|
frames = await recognize_queue.get()
|
||||||
|
if frames is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
merged_frames = utils.merge_frames(frames)
|
||||||
|
try:
|
||||||
|
t_event = await self._wrapped_stt.recognize(
|
||||||
|
buffer=merged_frames,
|
||||||
|
language=self._language,
|
||||||
|
conn_options=self._wrapped_stt_conn_options,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("ASR segment recognition failed")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not t_event.alternatives or not t_event.alternatives[0].text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
stt.SpeechEvent(
|
||||||
|
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||||
|
alternatives=[t_event.alternatives[0]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _forward_input() -> None:
|
||||||
|
nonlocal segment_duration, segment_frames, speech_active
|
||||||
|
|
||||||
|
async for input_frame in self._input_ch:
|
||||||
|
if isinstance(input_frame, self._FlushSentinel):
|
||||||
|
vad_stream.flush()
|
||||||
|
continue
|
||||||
|
|
||||||
|
vad_stream.push_frame(input_frame)
|
||||||
|
forced_frames: list[rtc.AudioFrame] = []
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
if speech_active:
|
||||||
|
segment_frames.append(input_frame)
|
||||||
|
segment_duration += _frame_duration(input_frame)
|
||||||
|
if (
|
||||||
|
self._max_speech_duration is not None
|
||||||
|
and segment_duration >= self._max_speech_duration
|
||||||
|
):
|
||||||
|
forced_frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
else:
|
||||||
|
_append_pre_roll(input_frame)
|
||||||
|
|
||||||
|
if forced_frames:
|
||||||
|
await _enqueue_segment(forced_frames, forced=True)
|
||||||
|
|
||||||
|
vad_stream.end_input()
|
||||||
|
|
||||||
|
final_frames: list[rtc.AudioFrame] = []
|
||||||
|
async with lock:
|
||||||
|
if speech_active and segment_frames:
|
||||||
|
final_frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
speech_active = False
|
||||||
|
|
||||||
|
if final_frames:
|
||||||
|
await _enqueue_segment(final_frames)
|
||||||
|
|
||||||
|
async def _recognize_from_vad() -> None:
|
||||||
|
nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
|
||||||
|
|
||||||
|
async for event in vad_stream:
|
||||||
|
if event.type == VADEventType.START_OF_SPEECH:
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
|
||||||
|
)
|
||||||
|
async with lock:
|
||||||
|
if not speech_active:
|
||||||
|
speech_active = True
|
||||||
|
segment_frames = list(pre_roll_frames)
|
||||||
|
segment_duration = sum(_frame_duration(f) for f in segment_frames)
|
||||||
|
pre_roll_frames.clear()
|
||||||
|
pre_roll_duration = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.type != VADEventType.END_OF_SPEECH:
|
||||||
|
continue
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
frames = segment_frames
|
||||||
|
segment_frames = []
|
||||||
|
segment_duration = 0.0
|
||||||
|
speech_active = False
|
||||||
|
|
||||||
|
await _enqueue_segment(frames)
|
||||||
|
|
||||||
|
worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
|
||||||
|
asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
await recognize_queue.put(None)
|
||||||
|
await worker_task
|
||||||
|
finally:
|
||||||
|
await utils.aio.cancel_and_wait(*tasks, worker_task)
|
||||||
|
await vad_stream.aclose()
|
||||||
|
|||||||
@ -109,6 +109,7 @@ class BeaverLLMStream(llm.LLMStream):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||||
self._beaver_llm = beaver_llm
|
self._beaver_llm = beaver_llm
|
||||||
|
self._request_id = shortuuid("beaver_")
|
||||||
|
|
||||||
async def _run(self) -> None:
|
async def _run(self) -> None:
|
||||||
user_text = latest_user_text(self.chat_ctx)
|
user_text = latest_user_text(self.chat_ctx)
|
||||||
@ -116,9 +117,12 @@ class BeaverLLMStream(llm.LLMStream):
|
|||||||
reply = await self._beaver_llm._client.send_text(user_text)
|
reply = await self._beaver_llm._client.send_text(user_text)
|
||||||
|
|
||||||
if reply:
|
if reply:
|
||||||
|
self._send_text_chunk(reply)
|
||||||
|
|
||||||
|
def _send_text_chunk(self, text: str) -> None:
|
||||||
self._event_ch.send_nowait(
|
self._event_ch.send_nowait(
|
||||||
llm.ChatChunk(
|
llm.ChatChunk(
|
||||||
id=shortuuid("beaver_"),
|
id=self._request_id,
|
||||||
delta=llm.ChoiceDelta(role="assistant", content=reply),
|
delta=llm.ChoiceDelta(role="assistant", content=text),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -28,6 +28,7 @@ class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MessageIdGenerator:
|
class MessageIdGenerator:
|
||||||
peer_id: str
|
peer_id: str
|
||||||
|
nonce: str | None = None
|
||||||
initial_counter: int = 0
|
initial_counter: int = 0
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
@ -35,6 +36,8 @@ class MessageIdGenerator:
|
|||||||
|
|
||||||
def next_id(self) -> str:
|
def next_id(self) -> str:
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
|
if self.nonce:
|
||||||
|
return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
|
||||||
return f"{self.peer_id}-{self.counter:06d}"
|
return f"{self.peer_id}-{self.counter:06d}"
|
||||||
|
|
||||||
|
|
||||||
@ -77,6 +80,7 @@ class BeaverTerminalClient:
|
|||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
await self._close_websocket()
|
await self._close_websocket()
|
||||||
session = self._ensure_http_session()
|
session = self._ensure_http_session()
|
||||||
|
try:
|
||||||
self._ws = await session.ws_connect(self._url)
|
self._ws = await session.ws_connect(self._url)
|
||||||
await self._send_json(
|
await self._send_json(
|
||||||
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||||
@ -86,6 +90,12 @@ class BeaverTerminalClient:
|
|||||||
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||||
session_id = frame.get("session_id")
|
session_id = frame.get("session_id")
|
||||||
self.session_id = session_id if isinstance(session_id, str) else None
|
self.session_id = session_id if isinstance(session_id, str) else None
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
|
||||||
|
await self._cleanup_failed_connection()
|
||||||
|
raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
|
||||||
|
except Exception:
|
||||||
|
await self._cleanup_failed_connection()
|
||||||
|
raise
|
||||||
|
|
||||||
async def send_text(self, text: str) -> str:
|
async def send_text(self, text: str) -> str:
|
||||||
for attempt in range(2):
|
for attempt in range(2):
|
||||||
@ -153,6 +163,13 @@ class BeaverTerminalClient:
|
|||||||
await self._ws.close()
|
await self._ws.close()
|
||||||
self._ws = None
|
self._ws = None
|
||||||
|
|
||||||
|
async def _cleanup_failed_connection(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
|
self.session_id = None
|
||||||
|
if self._owned_session and self._http_session is not None:
|
||||||
|
await self._http_session.close()
|
||||||
|
self._http_session = None
|
||||||
|
|
||||||
def _websocket_is_open(self) -> bool:
|
def _websocket_is_open(self) -> bool:
|
||||||
return self._ws is not None and not self._ws.closed
|
return self._ws is not None and not self._ws.closed
|
||||||
|
|
||||||
|
|||||||
279
custom_agent.py
279
custom_agent.py
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -9,12 +10,13 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from beaver_llm import BeaverLLM
|
from beaver_llm import BeaverLLM
|
||||||
|
from beaver_terminal_client import BeaverTerminalError
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||||||
from memory import MemoryRecallClient
|
from memory import MemoryRecallClient
|
||||||
from tts import BlackboxTTS
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
from asr import BlackboxSTT
|
from asr import BlackboxSTT, BoundedStreamAdapter
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
AgentServer,
|
AgentServer,
|
||||||
@ -32,7 +34,6 @@ from livekit.agents import (
|
|||||||
llm,
|
llm,
|
||||||
metrics,
|
metrics,
|
||||||
room_io,
|
room_io,
|
||||||
stt,
|
|
||||||
)
|
)
|
||||||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||||||
from livekit.plugins import openai, silero
|
from livekit.plugins import openai, silero
|
||||||
@ -42,7 +43,6 @@ logger = logging.getLogger("custom-agent")
|
|||||||
|
|
||||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
|
||||||
|
|
||||||
ROOM_LOCATOR_INSTRUCTIONS = """
|
ROOM_LOCATOR_INSTRUCTIONS = """
|
||||||
你是一个房间物品定位助手。
|
你是一个房间物品定位助手。
|
||||||
@ -62,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
|
|||||||
|
|
||||||
EMOTION_INSTRUCTIONS = """
|
EMOTION_INSTRUCTIONS = """
|
||||||
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
||||||
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
|
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
|
||||||
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
@ -71,18 +71,121 @@ GENERAL_MODE = "general"
|
|||||||
VOICE_INPUT_MODE = "voice"
|
VOICE_INPUT_MODE = "voice"
|
||||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||||
AUTO_INPUT_MODE = "auto"
|
AUTO_INPUT_MODE = "auto"
|
||||||
|
INTERRUPT_TOPIC = "lk.interrupt"
|
||||||
VISION_FRAME_TOPIC = "vision.frame"
|
VISION_FRAME_TOPIC = "vision.frame"
|
||||||
|
DEFAULT_AGENT_PROFILE = "normal"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentProfile:
|
||||||
|
agent_name: str
|
||||||
|
llm_provider: str
|
||||||
|
input_mode: str
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_PROFILES = {
|
||||||
|
"normal": AgentProfile(
|
||||||
|
agent_name="normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode=VOICE_INPUT_MODE,
|
||||||
|
),
|
||||||
|
"beaver": AgentProfile(
|
||||||
|
agent_name="beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode=VOICE_INPUT_MODE,
|
||||||
|
),
|
||||||
|
"vision-normal": AgentProfile(
|
||||||
|
agent_name="vision-normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode=VISION_VOICE_INPUT_MODE,
|
||||||
|
),
|
||||||
|
"vision-beaver": AgentProfile(
|
||||||
|
agent_name="vision-beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode=VISION_VOICE_INPUT_MODE,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
AGENT_PROFILE_ALIASES = {
|
||||||
|
"default": "normal",
|
||||||
|
"openai": "normal",
|
||||||
|
"openai-compatible": "normal",
|
||||||
|
"llm": "normal",
|
||||||
|
"text": "normal",
|
||||||
|
"voice": "normal",
|
||||||
|
"vision": "vision-normal",
|
||||||
|
"vision-llm": "vision-normal",
|
||||||
|
"vision-openai": "vision-normal",
|
||||||
|
"vision-openai-compatible": "vision-normal",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_agent_profile(value: str | None) -> str:
|
||||||
|
if not value or not value.strip():
|
||||||
|
return DEFAULT_AGENT_PROFILE
|
||||||
|
|
||||||
|
normalized = value.strip().lower().replace("_", "-")
|
||||||
|
profile = AGENT_PROFILE_ALIASES.get(normalized, normalized)
|
||||||
|
if profile in AGENT_PROFILES:
|
||||||
|
return profile
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
|
||||||
|
value,
|
||||||
|
DEFAULT_AGENT_PROFILE,
|
||||||
|
)
|
||||||
|
return DEFAULT_AGENT_PROFILE
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_profile_from_name(agent_name: str | None) -> str | None:
|
||||||
|
if not agent_name or not agent_name.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized = agent_name.strip().lower().replace("_", "-")
|
||||||
|
for profile_name, profile in AGENT_PROFILES.items():
|
||||||
|
if normalized == profile.agent_name:
|
||||||
|
return profile_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _selected_agent_profile_name() -> str:
|
||||||
|
configured_profile = os.getenv("CUSTOM_AGENT_PROFILE")
|
||||||
|
if configured_profile and configured_profile.strip():
|
||||||
|
return _normalize_agent_profile(configured_profile)
|
||||||
|
|
||||||
|
inferred_profile = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME"))
|
||||||
|
if inferred_profile is not None:
|
||||||
|
return inferred_profile
|
||||||
|
|
||||||
|
return DEFAULT_AGENT_PROFILE
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_PROFILE_NAME = _selected_agent_profile_name()
|
||||||
|
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
|
||||||
|
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
|
||||||
|
|
||||||
DEFAULT_EMOTION = "neutral"
|
DEFAULT_EMOTION = "neutral"
|
||||||
EMOTION_LABELS = {
|
EMOTION_LABELS = {
|
||||||
"neutral",
|
|
||||||
"happy",
|
|
||||||
"sad",
|
|
||||||
"angry",
|
"angry",
|
||||||
|
"confident",
|
||||||
|
"confused",
|
||||||
|
"cool",
|
||||||
|
"crying",
|
||||||
|
"delicious",
|
||||||
|
"embarrassed",
|
||||||
|
"funny",
|
||||||
|
"happy",
|
||||||
|
"kissy",
|
||||||
|
"laughing",
|
||||||
|
"loving",
|
||||||
|
"neutral",
|
||||||
|
"relaxed",
|
||||||
|
"sad",
|
||||||
|
"shocked",
|
||||||
|
"silly",
|
||||||
|
"sleepy",
|
||||||
"surprised",
|
"surprised",
|
||||||
"fearful",
|
"thinking",
|
||||||
"calm",
|
"winking",
|
||||||
"concerned",
|
|
||||||
}
|
}
|
||||||
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
||||||
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
||||||
@ -292,6 +395,7 @@ class CustomAgent(Agent):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return _stream()
|
return _stream()
|
||||||
|
|
||||||
async def _observe_emotion_prefix(
|
async def _observe_emotion_prefix(
|
||||||
self, chunk: llm.ChatChunk | str | FlushSentinel
|
self, chunk: llm.ChatChunk | str | FlushSentinel
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||||
@ -547,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
|
|||||||
return chat_ctx
|
return chat_ctx
|
||||||
|
|
||||||
|
|
||||||
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
|
def _with_vision_as_latest_user_message(
|
||||||
|
chat_ctx: ChatContext, vision_frame: VisionFrame
|
||||||
|
) -> ChatContext:
|
||||||
chat_ctx = chat_ctx.copy()
|
chat_ctx = chat_ctx.copy()
|
||||||
image_content = llm.ImageContent(
|
image_content = llm.ImageContent(
|
||||||
image=vision_frame.image_data_url,
|
image=vision_frame.image_data_url,
|
||||||
@ -614,7 +720,25 @@ def _model_image_save_dir_from_env() -> Path | None:
|
|||||||
return Path(__file__).with_name("model_images")
|
return Path(__file__).with_name("model_images")
|
||||||
|
|
||||||
|
|
||||||
server = AgentServer()
|
def _agent_server_from_env() -> AgentServer:
|
||||||
|
configured_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
|
||||||
|
if configured_port is None:
|
||||||
|
return AgentServer()
|
||||||
|
|
||||||
|
try:
|
||||||
|
port = int(configured_port)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
|
||||||
|
port = 0
|
||||||
|
|
||||||
|
if port < 0 or port > 65535:
|
||||||
|
logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
|
||||||
|
port = 0
|
||||||
|
|
||||||
|
return AgentServer(port=port)
|
||||||
|
|
||||||
|
|
||||||
|
server = _agent_server_from_env()
|
||||||
|
|
||||||
|
|
||||||
def prewarm(proc: JobProcess) -> None:
|
def prewarm(proc: JobProcess) -> None:
|
||||||
@ -636,14 +760,17 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
||||||
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||||
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||||
|
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
|
||||||
|
|
||||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||||
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
||||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
|
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
|
||||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
INPUT_MODE = _normalize_input_mode(
|
||||||
|
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
|
||||||
|
)
|
||||||
if LLM_PROVIDER not in {
|
if LLM_PROVIDER not in {
|
||||||
"openai",
|
"openai",
|
||||||
"openai-compatible",
|
"openai-compatible",
|
||||||
@ -656,7 +783,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
||||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using LLM provider=%s model=%s base_url=%s",
|
"Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
|
||||||
|
AGENT_PROFILE_NAME,
|
||||||
|
AGENT_NAME or "<automatic>",
|
||||||
|
INPUT_MODE,
|
||||||
LLM_PROVIDER,
|
LLM_PROVIDER,
|
||||||
LLM_MODEL,
|
LLM_MODEL,
|
||||||
LLM_BASE_URL or "OpenAI default",
|
LLM_BASE_URL or "OpenAI default",
|
||||||
@ -673,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
||||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
||||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||||
|
beaver_warmup_text: str | None = None
|
||||||
|
|
||||||
blackbox_stt = BlackboxSTT(
|
blackbox_stt = BlackboxSTT(
|
||||||
url=ASR_URL,
|
url=ASR_URL,
|
||||||
@ -683,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
itn=os.getenv("CUSTOM_ASR_ITN"),
|
itn=os.getenv("CUSTOM_ASR_ITN"),
|
||||||
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
||||||
)
|
)
|
||||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
stt_stream = BoundedStreamAdapter(
|
||||||
|
stt=blackbox_stt,
|
||||||
|
vad=ctx.proc.userdata["vad"],
|
||||||
|
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
|
||||||
|
)
|
||||||
|
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
|
||||||
|
allow_interruptions = _env_bool(
|
||||||
|
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
|
||||||
|
ASR_MAX_SPEECH_DURATION <= 0,
|
||||||
|
)
|
||||||
|
|
||||||
if LLM_PROVIDER == "beaver":
|
if LLM_PROVIDER == "beaver":
|
||||||
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
|
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
|
||||||
if not beaver_url:
|
if not beaver_url:
|
||||||
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
|
raise RuntimeError(
|
||||||
|
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
|
||||||
|
)
|
||||||
|
|
||||||
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
beaver_peer_id = (
|
||||||
|
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
||||||
|
)
|
||||||
beaver_device_name = (
|
beaver_device_name = (
|
||||||
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
|
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
|
||||||
or "livekit-custom-agent"
|
or "livekit-custom-agent"
|
||||||
@ -702,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
|
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
|
||||||
)
|
)
|
||||||
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
|
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
|
||||||
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
|
|
||||||
text_llm = base_llm
|
text_llm = base_llm
|
||||||
vision_llm = base_llm
|
vision_llm = base_llm
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
|
||||||
beaver_url,
|
beaver_url,
|
||||||
beaver_peer_id,
|
beaver_peer_id,
|
||||||
beaver_device_name,
|
beaver_device_name,
|
||||||
ctx.room.name,
|
ctx.room.name,
|
||||||
base_llm.session_id,
|
|
||||||
bool(beaver_warmup_text and beaver_warmup_text.strip()),
|
bool(beaver_warmup_text and beaver_warmup_text.strip()),
|
||||||
len(warmup_reply) if warmup_reply is not None else 0,
|
|
||||||
)
|
)
|
||||||
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
||||||
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
||||||
@ -805,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
# 4. Silero VAD
|
# 4. Silero VAD
|
||||||
vad=ctx.proc.userdata["vad"],
|
vad=ctx.proc.userdata["vad"],
|
||||||
turn_handling=TurnHandlingOptions(
|
turn_handling=TurnHandlingOptions(
|
||||||
turn_detection=MultilingualModel(),
|
turn_detection=turn_detection,
|
||||||
interruption={
|
interruption={
|
||||||
|
"enabled": allow_interruptions,
|
||||||
"resume_false_interruption": True,
|
"resume_false_interruption": True,
|
||||||
"false_interruption_timeout": 1.0,
|
"false_interruption_timeout": 1.0,
|
||||||
},
|
},
|
||||||
@ -837,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
@ctx.room.on("data_received")
|
@ctx.room.on("data_received")
|
||||||
def _on_data_received(data_packet) -> None:
|
def _on_data_received(data_packet) -> None:
|
||||||
packet_topic = getattr(data_packet, "topic", None)
|
packet_topic = getattr(data_packet, "topic", None)
|
||||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
|
||||||
|
def _interrupt_done(fut: asyncio.Future[None]) -> None:
|
||||||
|
try:
|
||||||
|
fut.result()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Bridge interrupt failed")
|
||||||
|
|
||||||
|
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
|
||||||
|
reason = None if payload is None else payload.get("reason")
|
||||||
|
logger.info(
|
||||||
|
"Received bridge interrupt: topic=%s reason=%s",
|
||||||
|
packet_topic or "<default>",
|
||||||
|
reason if isinstance(reason, str) and reason else "<unspecified>",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
interrupt_fut = session.interrupt(force=True)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.exception("Bridge interrupt received before AgentSession was running")
|
||||||
|
return
|
||||||
|
interrupt_fut.add_done_callback(_interrupt_done)
|
||||||
|
|
||||||
|
if packet_topic == INTERRUPT_TOPIC:
|
||||||
|
payload: dict[str, object] | None = None
|
||||||
|
try:
|
||||||
|
decoded = json.loads(data_packet.data.decode("utf-8"))
|
||||||
|
if isinstance(decoded, dict):
|
||||||
|
payload = decoded
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to decode interrupt payload")
|
||||||
|
_handle_interrupt(payload)
|
||||||
return
|
return
|
||||||
|
|
||||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
||||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(data_packet.data.decode("utf-8"))
|
payload = json.loads(data_packet.data.decode("utf-8"))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to decode vision frame payload")
|
logger.exception("Failed to decode data payload")
|
||||||
|
return
|
||||||
|
|
||||||
|
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
|
||||||
|
_handle_interrupt(payload)
|
||||||
|
return
|
||||||
|
|
||||||
|
if INPUT_MODE == VOICE_INPUT_MODE:
|
||||||
|
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
||||||
return
|
return
|
||||||
|
|
||||||
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
|
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
|
||||||
@ -904,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
),
|
),
|
||||||
record=_recording_options_from_env(),
|
record=_recording_options_from_env(),
|
||||||
)
|
)
|
||||||
|
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
|
||||||
|
_start_beaver_background_warmup(
|
||||||
|
ctx=ctx,
|
||||||
|
beaver_llm=base_llm,
|
||||||
|
warmup_text=beaver_warmup_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _start_beaver_background_warmup(
|
||||||
|
*,
|
||||||
|
ctx: JobContext,
|
||||||
|
beaver_llm: BeaverLLM,
|
||||||
|
warmup_text: str | None,
|
||||||
|
) -> None:
|
||||||
|
async def _warmup() -> None:
|
||||||
|
try:
|
||||||
|
warmup_reply = await beaver_llm.connect(warmup_text=warmup_text)
|
||||||
|
except BeaverTerminalError:
|
||||||
|
logger.warning(
|
||||||
|
"Beaver background handshake failed; will retry on first user turn",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Unexpected Beaver background handshake failure")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
||||||
|
ctx.room.name,
|
||||||
|
beaver_llm.session_id,
|
||||||
|
bool(warmup_text and warmup_text.strip()),
|
||||||
|
len(warmup_reply) if warmup_reply is not None else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup")
|
||||||
|
|
||||||
|
async def _cancel_warmup() -> None:
|
||||||
|
if warmup_task.done():
|
||||||
|
return
|
||||||
|
warmup_task.cancel()
|
||||||
|
try:
|
||||||
|
await warmup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
ctx.add_shutdown_callback(_cancel_warmup)
|
||||||
|
|
||||||
|
|
||||||
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||||
|
|||||||
264
start_agent_profiles.py
Normal file
264
start_agent_profiles.py
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentProfile:
|
||||||
|
agent_name: str
|
||||||
|
llm_provider: str
|
||||||
|
input_mode: str
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_PROFILES = {
|
||||||
|
"normal": AgentProfile(
|
||||||
|
agent_name="normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode="voice",
|
||||||
|
),
|
||||||
|
"beaver": AgentProfile(
|
||||||
|
agent_name="beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode="voice",
|
||||||
|
),
|
||||||
|
"vision-normal": AgentProfile(
|
||||||
|
agent_name="vision-normal-agent",
|
||||||
|
llm_provider="openai-compatible",
|
||||||
|
input_mode="vision_voice",
|
||||||
|
),
|
||||||
|
"vision-beaver": AgentProfile(
|
||||||
|
agent_name="vision-beaver-agent",
|
||||||
|
llm_provider="beaver",
|
||||||
|
input_mode="vision_voice",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
DEFAULT_PROFILES = ("normal", "beaver")
|
||||||
|
|
||||||
|
|
||||||
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if normalized in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_profiles(value: str) -> list[str]:
|
||||||
|
if value.strip().lower() == "all":
|
||||||
|
return list(AGENT_PROFILES)
|
||||||
|
|
||||||
|
profiles: list[str] = []
|
||||||
|
for raw_profile in value.split(","):
|
||||||
|
profile = raw_profile.strip().lower().replace("_", "-")
|
||||||
|
if not profile:
|
||||||
|
continue
|
||||||
|
if profile not in AGENT_PROFILES:
|
||||||
|
valid = ", ".join([*AGENT_PROFILES, "all"])
|
||||||
|
raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
|
||||||
|
profiles.append(profile)
|
||||||
|
|
||||||
|
if not profiles:
|
||||||
|
raise ValueError("at least one profile is required")
|
||||||
|
return list(dict.fromkeys(profiles))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_http_port_base(value: str | None) -> int:
|
||||||
|
if value is None or not value.strip():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
port = int(value)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"invalid HTTP port base {value!r}") from exc
|
||||||
|
|
||||||
|
if port < 0 or port > 65535:
|
||||||
|
raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
def _profile_http_port(http_port_base: int, index: int) -> int:
|
||||||
|
if http_port_base == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
port = http_port_base + index
|
||||||
|
if port > 65535:
|
||||||
|
raise ValueError("HTTP port range exceeds 65535")
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
|
||||||
|
profile = AGENT_PROFILES[profile_name]
|
||||||
|
env = os.environ.copy()
|
||||||
|
env.update(
|
||||||
|
{
|
||||||
|
"CUSTOM_AGENT_PROFILE": profile_name,
|
||||||
|
"CUSTOM_AGENT_NAME": profile.agent_name,
|
||||||
|
"CUSTOM_LLM_PROVIDER": profile.llm_provider,
|
||||||
|
"CUSTOM_AGENT_INPUT_MODE": profile.input_mode,
|
||||||
|
"CUSTOM_AGENT_HTTP_PORT": str(http_port),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
|
||||||
|
while line := await stream.readline():
|
||||||
|
text = line.decode("utf-8", errors="replace").rstrip()
|
||||||
|
print(f"[{prefix}] {text}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _start_profile(
|
||||||
|
profile_name: str,
|
||||||
|
*,
|
||||||
|
mode: str,
|
||||||
|
http_port: int,
|
||||||
|
reload: bool,
|
||||||
|
) -> asyncio.subprocess.Process:
|
||||||
|
profile = AGENT_PROFILES[profile_name]
|
||||||
|
script_path = Path(__file__).with_name("custom_agent.py")
|
||||||
|
args = [sys.executable, str(script_path), mode]
|
||||||
|
if mode == "dev" and not reload:
|
||||||
|
args.append("--no-reload")
|
||||||
|
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
*args,
|
||||||
|
env=_child_env(profile_name, http_port=http_port),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
|
||||||
|
f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if process.stdout is not None:
|
||||||
|
asyncio.create_task(_pipe_output(profile_name, process.stdout))
|
||||||
|
if process.stderr is not None:
|
||||||
|
asyncio.create_task(_pipe_output(profile_name, process.stderr))
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
async def _terminate(processes: list[asyncio.subprocess.Process]) -> None:
|
||||||
|
for process in processes:
|
||||||
|
if process.returncode is None:
|
||||||
|
process.terminate()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(*(process.wait() for process in processes)),
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
for process in processes:
|
||||||
|
if process.returncode is None:
|
||||||
|
process.kill()
|
||||||
|
await asyncio.gather(*(process.wait() for process in processes))
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(
|
||||||
|
profiles: list[str],
|
||||||
|
*,
|
||||||
|
mode: str,
|
||||||
|
http_port_base: int,
|
||||||
|
reload: bool,
|
||||||
|
) -> int:
|
||||||
|
processes = [
|
||||||
|
await _start_profile(
|
||||||
|
profile,
|
||||||
|
mode=mode,
|
||||||
|
http_port=_profile_http_port(http_port_base, index),
|
||||||
|
reload=reload,
|
||||||
|
)
|
||||||
|
for index, profile in enumerate(profiles)
|
||||||
|
]
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
loop.add_signal_handler(sig, stop_event.set)
|
||||||
|
|
||||||
|
wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
|
||||||
|
stop_task = asyncio.create_task(stop_event.wait())
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
[*wait_tasks, stop_task],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
|
||||||
|
exit_code = 0
|
||||||
|
if stop_task not in done:
|
||||||
|
for task in done:
|
||||||
|
process = wait_tasks[task]
|
||||||
|
exit_code = task.result() or 0
|
||||||
|
print(f"agent profile process exited: pid={process.pid} code={exit_code}", flush=True)
|
||||||
|
|
||||||
|
await _terminate(processes)
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
|
||||||
|
parser.add_argument(
|
||||||
|
"mode",
|
||||||
|
nargs="?",
|
||||||
|
default="dev",
|
||||||
|
choices=("console", "dev", "start", "connect"),
|
||||||
|
help="custom_agent.py CLI mode to run for each profile",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profiles",
|
||||||
|
default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)),
|
||||||
|
help="comma-separated profiles to start, or 'all'",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--http-port-base",
|
||||||
|
default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"),
|
||||||
|
help=(
|
||||||
|
"base HTTP health-check port for profile workers; "
|
||||||
|
"0 lets the OS assign free ports"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reload",
|
||||||
|
action="store_true",
|
||||||
|
default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False),
|
||||||
|
help="enable auto-reload in dev mode",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
profiles = _parse_profiles(args.profiles)
|
||||||
|
http_port_base = _parse_http_port_base(args.http_port_base)
|
||||||
|
except ValueError as exc:
|
||||||
|
parser.error(str(exc))
|
||||||
|
|
||||||
|
raise SystemExit(
|
||||||
|
asyncio.run(
|
||||||
|
_run(
|
||||||
|
profiles,
|
||||||
|
mode=args.mode,
|
||||||
|
http_port_base=http_port_base,
|
||||||
|
reload=args.reload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -96,3 +97,73 @@ async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
|||||||
assert received[1]["message_id"].startswith("livekit-room-")
|
assert received[1]["message_id"].startswith("livekit-room-")
|
||||||
assert received[1]["message_id"].endswith("-000001")
|
assert received[1]["message_id"].endswith("-000001")
|
||||||
assert received[1]["text"] == "hello beaver"
|
assert received[1]["text"] == "hello beaver"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks: list[str] = []
|
||||||
|
async with beaver_llm.chat(chat_ctx=ctx) as stream:
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.delta and chunk.delta.content:
|
||||||
|
chunks.append(chunk.delta.content)
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert chunks == ["beaver reply"]
|
||||||
|
|||||||
@ -14,6 +14,7 @@ if __name__ == "__main__":
|
|||||||
try:
|
try:
|
||||||
from custom.beaver_terminal_client import (
|
from custom.beaver_terminal_client import (
|
||||||
BeaverTerminalClient,
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalConnectionClosed,
|
||||||
BeaverTerminalError,
|
BeaverTerminalError,
|
||||||
MessageIdGenerator,
|
MessageIdGenerator,
|
||||||
build_connect_frame,
|
build_connect_frame,
|
||||||
@ -22,6 +23,7 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
from beaver_terminal_client import (
|
from beaver_terminal_client import (
|
||||||
BeaverTerminalClient,
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalConnectionClosed,
|
||||||
BeaverTerminalError,
|
BeaverTerminalError,
|
||||||
MessageIdGenerator,
|
MessageIdGenerator,
|
||||||
build_connect_frame,
|
build_connect_frame,
|
||||||
@ -244,6 +246,36 @@ async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
|
|||||||
await runner.cleanup()
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_cleans_up_owned_session_when_connect_fails(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
async def websocket_handler(request: web.Request) -> web.Response:
|
||||||
|
return web.Response(status=200, text="not a websocket")
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
|
||||||
|
await client.connect()
|
||||||
|
assert client._http_session is None
|
||||||
|
assert client._ws is None
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
||||||
unused_tcp_port: int,
|
unused_tcp_port: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
7
tts.py
7
tts.py
@ -7,7 +7,6 @@ import time
|
|||||||
import wave
|
import wave
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS):
|
|||||||
*,
|
*,
|
||||||
url: str,
|
url: str,
|
||||||
model_name: str = "voxcpmtts",
|
model_name: str = "voxcpmtts",
|
||||||
params: Optional[Mapping[str, object]] = None,
|
params: Mapping[str, object] | None = None,
|
||||||
prompt_wav_path: Optional[str] = None,
|
prompt_wav_path: str | None = None,
|
||||||
prompt_wav_field: str = "prompt_wav",
|
prompt_wav_field: str = "prompt_wav",
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
http_session: Optional[aiohttp.ClientSession] = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
capabilities=tts.TTSCapabilities(streaming=False),
|
capabilities=tts.TTSCapabilities(streaming=False),
|
||||||
|
|||||||
Reference in New Issue
Block a user