Files
livekit_agents/beaver_terminal_client.py
2026-06-12 11:17:12 +08:00

239 lines
8.4 KiB
Python

from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import aiohttp
from dotenv import load_dotenv
# Module-level logger; it uses a fixed name rather than one derived from __name__.
logger = logging.getLogger("beaver-terminal-client")
# Fallback values used by client_from_env() when the matching environment
# variable (BEAVER_WS_URL / TERMINAL_PEER_ID / TERMINAL_DEVICE_NAME) is unset.
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
# Path of the ".env" file that sits next to this module; client_from_env()
# passes it to load_dotenv().
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
class BeaverTerminalError(RuntimeError):
    """Base class for every error raised by the Beaver terminal client."""
class BeaverTerminalConnectionClosed(BeaverTerminalError):
    """Raised when the Beaver websocket cannot be opened or closes unexpectedly."""
@dataclass
class MessageIdGenerator:
    """Generate sequential message ids for one peer.

    Ids have the form ``<peer_id>-<counter>`` or, when a non-empty ``nonce``
    is supplied, ``<peer_id>-<nonce>-<counter>``. The counter is zero-padded
    to six digits and the first id uses ``initial_counter + 1``.
    """

    peer_id: str
    nonce: str | None = None
    initial_counter: int = 0

    def __post_init__(self) -> None:
        # The running counter is ordinary instance state, not a dataclass
        # field, so it does not take part in __init__, __repr__ or __eq__.
        self.counter: int = self.initial_counter

    def next_id(self) -> str:
        """Advance the counter by one and return the id built from it."""
        self.counter += 1
        if self.nonce:
            return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
        return f"{self.peer_id}-{self.counter:06d}"
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
    """Return the ``connect`` frame sent right after the websocket opens.

    Args:
        peer_id: Identifier this terminal announces to the server.
        device_name: Device name included in the frame.

    Returns:
        A dict with ``type`` set to ``"connect"``, the two identifiers, and a
        ``capabilities`` list containing only ``"text"``.
    """
    return {
        "type": "connect",
        "peer_id": peer_id,
        "device_name": device_name,
        "capabilities": ["text"],
    }
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
    """Return the ``message`` frame that carries one user text turn.

    Args:
        message_id: Id placed in the frame; replies are matched against it.
        text: The text to send.

    Returns:
        A dict with ``type`` set to ``"message"`` plus the id and text.
    """
    return {
        "type": "message",
        "message_id": message_id,
        "text": text,
    }
class BeaverTerminalClient:
    """Text client that talks to a Beaver terminal channel over a websocket.

    The client opens the websocket, sends a ``connect`` frame, and then
    exchanges JSON frames: each ``send_text`` call sends one ``message`` frame
    and waits for the reply carrying the same ``message_id``.

    An ``aiohttp.ClientSession`` passed to the constructor is never closed by
    this class; a session the client creates itself is closed by ``close`` and
    after a failed connection attempt.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        http_session: aiohttp.ClientSession | None = None,
        message_ids: MessageIdGenerator | None = None,
    ) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            url: Websocket URL to connect to.
            peer_id: Identifier sent in the ``connect`` frame and used as the
                prefix of generated message ids.
            device_name: Device name sent in the ``connect`` frame.
            http_session: Optional session to reuse. When omitted, a session
                is created lazily and owned (closed) by this client.
            message_ids: Optional id generator; defaults to a new
                ``MessageIdGenerator`` for ``peer_id``.
        """
        self._url = url
        self._peer_id = peer_id
        self._device_name = device_name
        # Only a session created by this client may be closed by it.
        self._owned_session = http_session is None
        self._http_session = http_session
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
        # Set from the server's "connected" frame; None until then.
        self.session_id: str | None = None

    async def connect(self) -> None:
        """Open the websocket (closing any previous one) and do the handshake.

        Raises:
            BeaverTerminalConnectionClosed: On an aiohttp client error, a
                timeout, or the socket closing during the handshake.
            BeaverTerminalError: If the first frame is not ``connected``.
        """
        await self._close_websocket()
        session = self._ensure_http_session()
        try:
            self._ws = await session.ws_connect(self._url)
            await self._send_json(
                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
            )
            frame = await self._receive_json()
            if frame.get("type") != "connected":
                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
            session_id = frame.get("session_id")
            self.session_id = session_id if isinstance(session_id, str) else None
        except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
            await self._cleanup_failed_connection()
            raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
        except Exception:
            # Any other failure still must not leave a half-open connection.
            await self._cleanup_failed_connection()
            raise

    async def send_text(self, text: str) -> str:
        """Send ``text`` and return the assistant's reply.

        Makes at most two attempts. If the connection drops during the first
        attempt the client reconnects and retries; every attempt uses a newly
        generated message id.

        Raises:
            BeaverTerminalConnectionClosed: If the connection also drops on
                the second attempt, or a reconnect fails.
            BeaverTerminalError: For error frames and protocol violations.
        """
        for attempt in range(2):
            if not self._websocket_is_open():
                await self.connect()
            message_id = self._message_ids.next_id()
            message_frame = build_message_frame(message_id=message_id, text=text)
            try:
                await self._send_json(message_frame)
                return await self._wait_for_reply(message_id)
            except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
                if attempt == 1:
                    raise BeaverTerminalConnectionClosed(
                        "Beaver websocket closed before assistant reply"
                    ) from exc
                logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
                await self.connect()
        raise BeaverTerminalError("unreachable Beaver send state")

    async def _wait_for_reply(self, message_id: str) -> str:
        """Read frames until one answers ``message_id`` and return its text.

        An ``ack`` frame for the id that carries a string ``reply`` ends the
        wait; an ``ack`` without one is skipped. An assistant ``message``
        frame for the id returns its ``text`` (``""`` when not a string), or
        raises when ``finish_reason`` is ``"error"``. Any ``error`` frame
        raises. All other frames are ignored.
        """
        while True:
            frame = await self._receive_json()
            frame_type = frame.get("type")
            if frame_type == "ack" and frame.get("message_id") == message_id:
                reply = frame.get("reply")
                if isinstance(reply, str):
                    return reply
                continue
            if (
                frame_type == "message"
                and frame.get("role") == "assistant"
                and frame.get("message_id") == message_id
            ):
                text = frame.get("text")
                if frame.get("finish_reason") == "error":
                    raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
                return text if isinstance(text, str) else ""
            if frame_type == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def ping(self) -> bool:
        """Send a ``ping`` frame and return True once a ``pong`` arrives.

        Frames other than ``pong`` and ``error`` read in the meantime are
        discarded. An ``error`` frame raises ``BeaverTerminalError``.
        """
        await self._send_json({"type": "ping"})
        while True:
            frame = await self._receive_json()
            if frame.get("type") == "pong":
                return True
            if frame.get("type") == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def close(self) -> None:
        """Close the websocket and, if this client created it, the HTTP session."""
        try:
            await self._close_websocket()
        finally:
            # Release the session even when closing the websocket raised.
            await self._close_owned_session()

    async def _close_websocket(self) -> None:
        """Close the current websocket, if any, and forget it."""
        # Drop the reference before awaiting so a failing close() cannot leave
        # a stale websocket attached to the client.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def _close_owned_session(self) -> None:
        """Close and forget the HTTP session when this client created it."""
        if self._owned_session and self._http_session is not None:
            session, self._http_session = self._http_session, None
            await session.close()

    async def _cleanup_failed_connection(self) -> None:
        """Reset all connection state after a failed ``connect`` attempt."""
        self.session_id = None
        try:
            await self._close_websocket()
        finally:
            await self._close_owned_session()

    def _websocket_is_open(self) -> bool:
        """Return True when a websocket exists and is not marked closed."""
        return self._ws is not None and not self._ws.closed

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating one on first use."""
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _send_json(self, frame: dict[str, Any]) -> None:
        """Send ``frame`` as JSON; raise if there is no websocket."""
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        await self._ws.send_json(frame)

    async def _receive_json(self) -> dict[str, Any]:
        """Receive one websocket message and return it as a JSON object.

        Raises:
            BeaverTerminalConnectionClosed: For close/closing/closed and
                error messages.
            BeaverTerminalError: When not connected, when the message is not
                a text frame, when the text is not valid JSON, or when the
                decoded value is not a dict.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        message = await self._ws.receive()
        if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
            raise BeaverTerminalConnectionClosed("Beaver websocket closed")
        if message.type == aiohttp.WSMsgType.ERROR:
            raise BeaverTerminalConnectionClosed(
                f"Beaver websocket error: {self._ws.exception()!r}"
            )
        if message.type != aiohttp.WSMsgType.TEXT:
            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
        try:
            data = message.json()
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError; report malformed payloads
            # through the client's own error type like other protocol faults.
            raise BeaverTerminalError("expected Beaver JSON text frame, received malformed JSON") from exc
        if not isinstance(data, dict):
            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
        return data
def client_from_env() -> BeaverTerminalClient:
    """Build a client from environment variables.

    Loads the ``.env`` file at ``CUSTOM_ENV_PATH`` first, then reads
    ``BEAVER_WS_URL``, ``TERMINAL_PEER_ID`` and ``TERMINAL_DEVICE_NAME``,
    falling back to the module defaults for any variable that is unset.
    """
    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
    return BeaverTerminalClient(
        url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
        peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
        device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
    )
async def run_console() -> None:
    """Run an interactive prompt that forwards each line to Beaver.

    Blank lines are ignored, ``quit`` / ``exit`` or end of input leaves the
    loop, and a failed turn is logged without ending the session. The client
    is always closed on the way out.
    """
    logging.basicConfig(level=logging.INFO)
    client = client_from_env()
    try:
        await client.connect()
        logger.info("Connected to Beaver session_id=%s", client.session_id)
        while True:
            try:
                # input() blocks, so run it in a worker thread.
                text = await asyncio.to_thread(input, "> ")
            except EOFError:
                # stdin was closed (Ctrl-D or exhausted piped input): exit
                # the same way as "quit" instead of crashing with a traceback.
                return
            text = text.strip()
            if not text:
                continue
            if text in {"quit", "exit"}:
                return
            try:
                reply = await client.send_text(text)
            except BeaverTerminalError as exc:
                logger.error("Beaver turn failed: %s", exc)
                continue
            print(reply)
    finally:
        await client.close()
# Start the interactive console only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(run_console())