feat: support beaver llm provider mode
This commit is contained in:
@ -9,6 +9,13 @@ BEAVER_WS_URL=ws://127.0.0.1:8080/api/channels/terminal-dev/ws
|
|||||||
TERMINAL_PEER_ID=device-001
|
TERMINAL_PEER_ID=device-001
|
||||||
TERMINAL_DEVICE_NAME=desk-terminal
|
TERMINAL_DEVICE_NAME=desk-terminal
|
||||||
|
|
||||||
|
# Beaver as a LiveKit LLM backend: asr -> beaver -> tts.
|
||||||
|
# CUSTOM_LLM_PROVIDER=beaver
|
||||||
|
# CUSTOM_BEAVER_WS_URL=ws://127.0.0.1:8080/api/channels/terminal-dev/ws
|
||||||
|
# CUSTOM_BEAVER_PEER_ID=livekit-agent-001
|
||||||
|
# CUSTOM_BEAVER_DEVICE_NAME=livekit-custom-agent
|
||||||
|
# CUSTOM_BEAVER_MODEL=beaver-terminal
|
||||||
|
|
||||||
# ASR blackbox
|
# ASR blackbox
|
||||||
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
|
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
|
||||||
CUSTOM_ASR_MODEL=qwen
|
CUSTOM_ASR_MODEL=qwen
|
||||||
|
|||||||
97
beaver_llm.py
Normal file
97
beaver_llm.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from custom.beaver_terminal_client import BeaverTerminalClient
|
||||||
|
from livekit.agents import llm
|
||||||
|
from livekit.agents.types import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
NOT_GIVEN,
|
||||||
|
APIConnectOptions,
|
||||||
|
NotGivenOr,
|
||||||
|
)
|
||||||
|
from livekit.agents.utils import shortuuid
|
||||||
|
|
||||||
|
|
||||||
|
def latest_user_text(chat_ctx: llm.ChatContext) -> str:
|
||||||
|
for message in reversed(chat_ctx.messages()):
|
||||||
|
if message.role != "user":
|
||||||
|
continue
|
||||||
|
return _content_to_text(message.content)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_text(content: Sequence[llm.ChatContent]) -> str:
|
||||||
|
text_parts = [item for item in content if isinstance(item, str)]
|
||||||
|
return "\n".join(text_parts)
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverLLM(llm.LLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
peer_id: str,
|
||||||
|
device_name: str,
|
||||||
|
model_name: str = "beaver-terminal",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._client = BeaverTerminalClient(url=url, peer_id=peer_id, device_name=device_name)
|
||||||
|
self._model_name = model_name
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> str:
|
||||||
|
return "beaver"
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool] | None = None,
|
||||||
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
|
||||||
|
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
|
||||||
|
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
||||||
|
) -> llm.LLMStream:
|
||||||
|
return BeaverLLMStream(
|
||||||
|
self,
|
||||||
|
chat_ctx=chat_ctx,
|
||||||
|
tools=tools or [],
|
||||||
|
conn_options=conn_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._client.close()
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverLLMStream(llm.LLMStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
beaver_llm: BeaverLLM,
|
||||||
|
*,
|
||||||
|
chat_ctx: llm.ChatContext,
|
||||||
|
tools: list[llm.Tool],
|
||||||
|
conn_options: APIConnectOptions,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||||
|
self._beaver_llm = beaver_llm
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
user_text = latest_user_text(self.chat_ctx)
|
||||||
|
async with self._beaver_llm._lock:
|
||||||
|
reply = await self._beaver_llm._client.send_text(user_text)
|
||||||
|
|
||||||
|
if reply:
|
||||||
|
self._event_ch.send_nowait(
|
||||||
|
llm.ChatChunk(
|
||||||
|
id=shortuuid("beaver_"),
|
||||||
|
delta=llm.ChoiceDelta(role="assistant", content=reply),
|
||||||
|
)
|
||||||
|
)
|
||||||
@ -8,6 +8,7 @@ from collections.abc import AsyncIterable
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from beaver_llm import BeaverLLM
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||||||
from memory import MemoryRecallClient
|
from memory import MemoryRecallClient
|
||||||
@ -643,7 +644,14 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
||||||
if LLM_PROVIDER not in {"openai", "openai-compatible", "hermes", "hermes_gateway", "openclaw"}:
|
if LLM_PROVIDER not in {
|
||||||
|
"openai",
|
||||||
|
"openai-compatible",
|
||||||
|
"hermes",
|
||||||
|
"hermes_gateway",
|
||||||
|
"openclaw",
|
||||||
|
"beaver",
|
||||||
|
}:
|
||||||
raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
|
raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
|
||||||
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
||||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||||
@ -677,7 +685,36 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
)
|
)
|
||||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||||
|
|
||||||
if LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
if LLM_PROVIDER == "beaver":
|
||||||
|
beaver_url = os.getenv("CUSTOM_BEAVER_WS_URL") or os.getenv("BEAVER_WS_URL", "").strip()
|
||||||
|
if not beaver_url:
|
||||||
|
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
|
||||||
|
|
||||||
|
beaver_peer_id = (
|
||||||
|
os.getenv("CUSTOM_BEAVER_PEER_ID")
|
||||||
|
or os.getenv("BEAVER_PEER_ID")
|
||||||
|
or f"livekit-{ctx.room.name}"
|
||||||
|
)
|
||||||
|
beaver_device_name = (
|
||||||
|
os.getenv("CUSTOM_BEAVER_DEVICE_NAME")
|
||||||
|
or os.getenv("BEAVER_DEVICE_NAME")
|
||||||
|
or "livekit-custom-agent"
|
||||||
|
)
|
||||||
|
base_llm = BeaverLLM(
|
||||||
|
url=beaver_url,
|
||||||
|
peer_id=beaver_peer_id,
|
||||||
|
device_name=beaver_device_name,
|
||||||
|
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
|
||||||
|
)
|
||||||
|
text_llm = base_llm
|
||||||
|
vision_llm = base_llm
|
||||||
|
logger.info(
|
||||||
|
"Using Beaver gateway url=%s peer_id=%s device_name=%s",
|
||||||
|
beaver_url,
|
||||||
|
beaver_peer_id,
|
||||||
|
beaver_device_name,
|
||||||
|
)
|
||||||
|
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
||||||
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
||||||
if not gateway_url:
|
if not gateway_url:
|
||||||
raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")
|
raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")
|
||||||
|
|||||||
98
test_beaver_llm.py
Normal file
98
test_beaver_llm.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from custom.beaver_llm import BeaverLLM, latest_user_text
|
||||||
|
from livekit.agents import ChatContext
|
||||||
|
|
||||||
|
|
||||||
|
def test_latest_user_text_uses_most_recent_user_message() -> None:
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="user", content="first")
|
||||||
|
ctx.add_message(role="assistant", content="ignored")
|
||||||
|
ctx.add_message(role="user", content=["second", "line"])
|
||||||
|
|
||||||
|
assert latest_user_text(ctx) == "second\nline"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
received: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
assert message.type == aiohttp.WSMsgType.TEXT
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
received.append(frame)
|
||||||
|
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:livekit-room",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-1",
|
||||||
|
"text": "beaver reply",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
beaver_llm = BeaverLLM(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="livekit-room",
|
||||||
|
device_name="livekit-custom-agent",
|
||||||
|
)
|
||||||
|
ctx = ChatContext.empty()
|
||||||
|
ctx.add_message(role="system", content="ignored instructions")
|
||||||
|
ctx.add_message(role="user", content="hello beaver")
|
||||||
|
|
||||||
|
try:
|
||||||
|
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||||
|
finally:
|
||||||
|
await beaver_llm.aclose()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert collected.text == "beaver reply"
|
||||||
|
assert received == [
|
||||||
|
{
|
||||||
|
"type": "connect",
|
||||||
|
"peer_id": "livekit-room",
|
||||||
|
"device_name": "livekit-custom-agent",
|
||||||
|
"capabilities": ["text"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"message_id": "livekit-room-000001",
|
||||||
|
"text": "hello beaver",
|
||||||
|
},
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user