129 lines
3.7 KiB
Python
129 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
try:
|
|
from beaver_terminal_client import BeaverTerminalClient
|
|
except ModuleNotFoundError:
|
|
from custom.beaver_terminal_client import BeaverTerminalClient
|
|
from livekit.agents import llm
|
|
from livekit.agents.types import (
|
|
DEFAULT_API_CONNECT_OPTIONS,
|
|
NOT_GIVEN,
|
|
APIConnectOptions,
|
|
NotGivenOr,
|
|
)
|
|
from livekit.agents.utils import shortuuid
|
|
|
|
# Module-level logger shared by the classes below; it is registered under the
# fixed name "beaver-llm" rather than the module's import path.
logger = logging.getLogger("beaver-llm")
|
|
|
|
|
|
def latest_user_text(chat_ctx: llm.ChatContext) -> str:
    """Return the text of the most recent user message in *chat_ctx*.

    Messages are scanned from newest to oldest and the first one whose role
    is ``"user"`` wins; its content is flattened with ``_content_to_text``.

    Args:
        chat_ctx: Chat context whose ``messages()`` are inspected.

    Returns:
        The flattened text of the newest user message, or ``""`` when the
        context holds no user message at all.
    """
    newest_user_message = next(
        (entry for entry in reversed(chat_ctx.messages()) if entry.role == "user"),
        None,
    )
    if newest_user_message is None:
        return ""
    return _content_to_text(newest_user_message.content)
|
|
|
|
|
|
def _content_to_text(content: Sequence[llm.ChatContent]) -> str:
    """Collapse the textual parts of a message's content into one string.

    Args:
        content: Content items of a chat message. Only ``str`` items are
            kept; items of any other type are ignored. A bare ``str`` is
            also accepted and returned unchanged.

    Returns:
        The kept text parts joined with newlines, or ``""`` when there are
        none.
    """
    # A plain string is itself a sequence of one-character strings, so without
    # this guard it would be taken apart and re-joined with a newline between
    # every character.
    if isinstance(content, str):
        return content
    return "\n".join(part for part in content if isinstance(part, str))
|
|
|
|
|
|
class BeaverLLM(llm.LLM):
    """``llm.LLM`` implementation that forwards text to a ``BeaverTerminalClient``.

    One ``asyncio.Lock`` is held around the client calls made in
    ``connect()``; the streams returned by ``chat()`` receive this instance
    and use the same lock for their own request.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        model_name: str = "beaver-terminal",
    ) -> None:
        """Create the underlying client and the lock that guards it.

        Args:
            url: Passed through to ``BeaverTerminalClient``.
            peer_id: Passed through to ``BeaverTerminalClient``.
            device_name: Passed through to ``BeaverTerminalClient``.
            model_name: Value reported by the ``model`` property.
        """
        super().__init__()
        self._model_name = model_name
        self._lock = asyncio.Lock()
        self._client = BeaverTerminalClient(
            url=url,
            peer_id=peer_id,
            device_name=device_name,
        )

    @property
    def model(self) -> str:
        """Model name supplied at construction time."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Provider identifier; always ``"beaver"``."""
        return "beaver"

    @property
    def session_id(self) -> str | None:
        """Session id as currently reported by the underlying client."""
        return self._client.session_id

    async def connect(self, *, warmup_text: str | None = None) -> str | None:
        """Connect the client and optionally send one warmup message.

        Args:
            warmup_text: Text to send right after connecting. It is stripped
                first; ``None``, empty, or whitespace-only text sends nothing.

        Returns:
            The client's reply to the warmup message, or ``None`` when no
            warmup was sent (or the client returned ``None`` for it).
        """
        prompt = warmup_text.strip() if warmup_text else ""
        reply: str | None = None

        # Connecting and the warmup exchange happen under one lock acquisition.
        async with self._lock:
            await self._client.connect()
            if prompt:
                reply = await self._client.send_text(prompt)

        if reply is not None:
            logger.info(
                "Beaver handshake warmup completed session_id=%s reply_len=%s",
                self.session_id,
                len(reply),
            )
            return reply

        logger.info("Beaver handshake completed session_id=%s", self.session_id)
        return None

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> llm.LLMStream:
        """Create a ``BeaverLLMStream`` for *chat_ctx*.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are
        accepted but not forwarded to the stream.

        Args:
            chat_ctx: Conversation handed to the stream.
            tools: Tools handed to the stream; a falsy value becomes ``[]``.
            conn_options: Connection options handed to the stream.

        Returns:
            The newly created stream.
        """
        stream_tools = tools if tools else []
        stream = BeaverLLMStream(
            self,
            chat_ctx=chat_ctx,
            tools=stream_tools,
            conn_options=conn_options,
        )
        return stream

    async def aclose(self) -> None:
        """Close the underlying client."""
        await self._client.close()
|
|
|
|
|
|
class BeaverLLMStream(llm.LLMStream):
    """Stream that answers one ``chat()`` call with at most one text chunk.

    The newest user message of the chat context is sent to the owning
    ``BeaverLLM``'s client, and a non-empty reply is emitted as a single
    assistant ``ChatChunk``.
    """

    def __init__(
        self,
        beaver_llm: BeaverLLM,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        conn_options: APIConnectOptions,
    ) -> None:
        """Bind the stream to its owning LLM and allocate a request id.

        Args:
            beaver_llm: LLM whose client and lock are used by ``_run``.
            chat_ctx: Conversation to answer.
            tools: Forwarded to the base ``llm.LLMStream`` constructor.
            conn_options: Forwarded to the base ``llm.LLMStream`` constructor.
        """
        super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        self._beaver_llm = beaver_llm
        # Used as the id of every chunk this stream emits.
        self._request_id = shortuuid("beaver_")

    async def _run(self) -> None:
        """Send the newest user text to Beaver and emit the reply, if any."""
        user_text = latest_user_text(self.chat_ctx)

        # BeaverLLM.connect() refuses to send blank warmup text; apply the same
        # rule here so a context without user text never produces an empty
        # request to the client.
        if not user_text.strip():
            logger.warning(
                "Beaver request skipped: no user text in chat context request_id=%s",
                self._request_id,
            )
            return

        # The owning LLM's lock is shared with connect(), so only one client
        # call is in flight at a time.
        async with self._beaver_llm._lock:
            reply = await self._beaver_llm._client.send_text(user_text)

        # An empty or None reply produces no chunk at all.
        if reply:
            self._send_text_chunk(reply)

    def _send_text_chunk(self, text: str) -> None:
        """Emit *text* as one assistant ``ChatChunk`` on the event channel."""
        self._event_ch.send_nowait(
            llm.ChatChunk(
                id=self._request_id,
                delta=llm.ChoiceDelta(role="assistant", content=text),
            )
        )
|