diff --git a/.env.example b/.env.example index 88add51..5016d4c 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,13 @@ BEAVER_WS_URL=ws://127.0.0.1:8080/api/channels/terminal-dev/ws TERMINAL_PEER_ID=device-001 TERMINAL_DEVICE_NAME=desk-terminal +# Beaver as a LiveKit LLM backend: asr -> beaver -> tts. +# CUSTOM_LLM_PROVIDER=beaver +# CUSTOM_BEAVER_WS_URL=ws://127.0.0.1:8080/api/channels/terminal-dev/ws +# CUSTOM_BEAVER_PEER_ID=livekit-agent-001 +# CUSTOM_BEAVER_DEVICE_NAME=livekit-custom-agent +# CUSTOM_BEAVER_MODEL=beaver-terminal + # ASR blackbox CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox CUSTOM_ASR_MODEL=qwen diff --git a/beaver_llm.py b/beaver_llm.py new file mode 100644 index 0000000..e516951 --- /dev/null +++ b/beaver_llm.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from typing import Any + +from custom.beaver_terminal_client import BeaverTerminalClient +from livekit.agents import llm +from livekit.agents.types import ( + DEFAULT_API_CONNECT_OPTIONS, + NOT_GIVEN, + APIConnectOptions, + NotGivenOr, +) +from livekit.agents.utils import shortuuid + + +def latest_user_text(chat_ctx: llm.ChatContext) -> str: + for message in reversed(chat_ctx.messages()): + if message.role != "user": + continue + return _content_to_text(message.content) + return "" + + +def _content_to_text(content: Sequence[llm.ChatContent]) -> str: + text_parts = [item for item in content if isinstance(item, str)] + return "\n".join(text_parts) + + +class BeaverLLM(llm.LLM): + def __init__( + self, + *, + url: str, + peer_id: str, + device_name: str, + model_name: str = "beaver-terminal", + ) -> None: + super().__init__() + self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name) + self._model_name = model_name + self._lock = asyncio.Lock() + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "beaver" + + def chat( + self, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool] | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN, + tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN, + extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + ) -> llm.LLMStream: + return BeaverLLMStream( + self, + chat_ctx=chat_ctx, + tools=tools or [], + conn_options=conn_options, + ) + + async def aclose(self) -> None: + await self._client.close() + + +class BeaverLLMStream(llm.LLMStream): + def __init__( + self, + beaver_llm: BeaverLLM, + *, + chat_ctx: llm.ChatContext, + tools: list[llm.Tool], + conn_options: APIConnectOptions, + ) -> None: + super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) + self._beaver_llm = beaver_llm + + async def _run(self) -> None: + user_text = latest_user_text(self.chat_ctx) + async with self._beaver_llm._lock: + reply = await self._beaver_llm._client.send_text(user_text) + + if reply: + self._event_ch.send_nowait( + llm.ChatChunk( + id=shortuuid("beaver_"), + delta=llm.ChoiceDelta(role="assistant", content=reply), + ) + ) diff --git a/custom_agent.py b/custom_agent.py index d6e9ecd..5f85fad 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -8,6 +8,7 @@ from collections.abc import AsyncIterable from dataclasses import dataclass from pathlib import Path +from beaver_llm import BeaverLLM from dotenv import load_dotenv from hermes_gateway import GatewaySessionState, HermesGatewayLLM from memory import MemoryRecallClient @@ -643,7 +644,14 @@ async def entrypoint(ctx: JobContext) -> None: TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL) VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL) INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) - if LLM_PROVIDER not in {"openai", "openai-compatible", "hermes", "hermes_gateway", "openclaw"}: + if LLM_PROVIDER not in { + "openai", + "openai-compatible", + "hermes", + "hermes_gateway", + "openclaw", + "beaver", + }: raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}") if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") @@ -677,7 +685,36 @@ async def entrypoint(ctx: JobContext) -> None: ) stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"]) - if LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}: + if LLM_PROVIDER == "beaver": + beaver_url = os.getenv("CUSTOM_BEAVER_WS_URL") or os.getenv("BEAVER_WS_URL", "").strip() + if not beaver_url: + raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}") + + beaver_peer_id = ( + os.getenv("CUSTOM_BEAVER_PEER_ID") + or os.getenv("BEAVER_PEER_ID") + or f"livekit-{ctx.room.name}" + ) + beaver_device_name = ( + os.getenv("CUSTOM_BEAVER_DEVICE_NAME") + or os.getenv("BEAVER_DEVICE_NAME") + or "livekit-custom-agent" + ) + base_llm = BeaverLLM( + url=beaver_url, + peer_id=beaver_peer_id, + device_name=beaver_device_name, + model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"), + ) + text_llm = base_llm + vision_llm = base_llm + logger.info( + "Using Beaver gateway url=%s peer_id=%s device_name=%s", + beaver_url, + beaver_peer_id, + beaver_device_name, + ) + elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}: gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip() if not gateway_url: raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}") diff --git a/test_beaver_llm.py b/test_beaver_llm.py new file mode 100644 index 0000000..868e461 --- /dev/null +++ b/test_beaver_llm.py @@ -0,0 +1,98 @@ +import json + +import aiohttp +from aiohttp import web + +from custom.beaver_llm import BeaverLLM, latest_user_text +from livekit.agents import ChatContext + + +def test_latest_user_text_uses_most_recent_user_message() -> None: + ctx = ChatContext.empty() + ctx.add_message(role="user", content="first") + ctx.add_message(role="assistant", content="ignored") + ctx.add_message(role="user", content=["second", "line"]) + + assert latest_user_text(ctx) == "second\nline" + + +async def test_beaver_llm_sends_latest_user_text_and_returns_reply( + unused_tcp_port: int, +) -> None: + received: list[dict[str, object]] = [] + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + frame = json.loads(message.data) + received.append(frame) + + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:livekit-room", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:livekit-room", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "beaver reply", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + beaver_llm = BeaverLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="livekit-room", + device_name="livekit-custom-agent", + ) + ctx = ChatContext.empty() + ctx.add_message(role="system", content="ignored instructions") + ctx.add_message(role="user", content="hello beaver") + + try: + collected = await beaver_llm.chat(chat_ctx=ctx).collect() + finally: + await beaver_llm.aclose() + await runner.cleanup() + + assert collected.text == "beaver reply" + assert received == [ + { + "type": "connect", + "peer_id": "livekit-room", + "device_name": "livekit-custom-agent", + "capabilities": ["text"], + }, + { + "type": "message", + "message_id": "livekit-room-000001", + "text": "hello beaver", + }, + ]