from __future__ import annotations import asyncio import logging import os from dataclasses import dataclass from pathlib import Path from typing import Any import aiohttp from dotenv import load_dotenv logger = logging.getLogger("beaver-terminal-client") DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws" DEFAULT_TERMINAL_PEER_ID = "device-001" DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal" CUSTOM_ENV_PATH = Path(__file__).with_name(".env") class BeaverTerminalError(RuntimeError): pass class BeaverTerminalConnectionClosed(BeaverTerminalError): pass @dataclass class MessageIdGenerator: peer_id: str initial_counter: int = 0 def __post_init__(self) -> None: self.counter = self.initial_counter def next_id(self) -> str: self.counter += 1 return f"{self.peer_id}-{self.counter:06d}" def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]: return { "type": "connect", "peer_id": peer_id, "device_name": device_name, "capabilities": ["text"], } def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]: return { "type": "message", "message_id": message_id, "text": text, } class BeaverTerminalClient: def __init__( self, *, url: str, peer_id: str, device_name: str, http_session: aiohttp.ClientSession | None = None, message_ids: MessageIdGenerator | None = None, ) -> None: self._url = url self._peer_id = peer_id self._device_name = device_name self._owned_session = http_session is None self._http_session = http_session self._ws: aiohttp.ClientWebSocketResponse | None = None self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id) self.session_id: str | None = None async def connect(self) -> None: await self._close_websocket() session = self._ensure_http_session() self._ws = await session.ws_connect(self._url) await self._send_json( build_connect_frame(peer_id=self._peer_id, device_name=self._device_name) ) frame = await self._receive_json() if frame.get("type") != "connected": raise BeaverTerminalError(f"expected connected frame, received {frame!r}") session_id = frame.get("session_id") self.session_id = session_id if isinstance(session_id, str) else None async def send_text(self, text: str) -> str: for attempt in range(2): if not self._websocket_is_open(): await self.connect() message_id = self._message_ids.next_id() message_frame = build_message_frame(message_id=message_id, text=text) try: await self._send_json(message_frame) return await self._wait_for_reply(message_id) except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc: if attempt == 1: raise BeaverTerminalConnectionClosed( "Beaver websocket closed before assistant reply" ) from exc logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id") await self.connect() raise BeaverTerminalError("unreachable Beaver send state") async def _wait_for_reply(self, message_id: str) -> str: while True: frame = await self._receive_json() frame_type = frame.get("type") if frame_type == "ack" and frame.get("message_id") == message_id: reply = frame.get("reply") if isinstance(reply, str): return reply continue if ( frame_type == "message" and frame.get("role") == "assistant" and frame.get("message_id") == message_id ): text = frame.get("text") if frame.get("finish_reason") == "error": raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed") return text if isinstance(text, str) else "" if frame_type == "error": error = frame.get("error") raise BeaverTerminalError(error if isinstance(error, str) else "unknown error") async def ping(self) -> bool: await self._send_json({"type": "ping"}) while True: frame = await self._receive_json() if frame.get("type") == "pong": return True if frame.get("type") == "error": error = frame.get("error") raise BeaverTerminalError(error if isinstance(error, str) else "unknown error") async def close(self) -> None: await self._close_websocket() if self._owned_session and self._http_session is not None: await self._http_session.close() self._http_session = None async def _close_websocket(self) -> None: if self._ws is not None: await self._ws.close() self._ws = None def _websocket_is_open(self) -> bool: return self._ws is not None and not self._ws.closed def _ensure_http_session(self) -> aiohttp.ClientSession: if self._http_session is None: self._http_session = aiohttp.ClientSession() return self._http_session async def _send_json(self, frame: dict[str, Any]) -> None: if self._ws is None: raise BeaverTerminalError("Beaver websocket is not connected") await self._ws.send_json(frame) async def _receive_json(self) -> dict[str, Any]: if self._ws is None: raise BeaverTerminalError("Beaver websocket is not connected") message = await self._ws.receive() if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): raise BeaverTerminalConnectionClosed("Beaver websocket closed") if message.type == aiohttp.WSMsgType.ERROR: raise BeaverTerminalConnectionClosed( f"Beaver websocket error: {self._ws.exception()!r}" ) if message.type != aiohttp.WSMsgType.TEXT: raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}") data = message.json() if not isinstance(data, dict): raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}") return data def client_from_env() -> BeaverTerminalClient: load_dotenv(dotenv_path=CUSTOM_ENV_PATH) return BeaverTerminalClient( url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL), peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID), device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME), ) async def run_console() -> None: logging.basicConfig(level=logging.INFO) client = client_from_env() try: await client.connect() logger.info("Connected to Beaver session_id=%s", client.session_id) while True: text = await asyncio.to_thread(input, "> ") text = text.strip() if not text: continue if text in {"quit", "exit"}: return try: reply = await client.send_text(text) except BeaverTerminalError as exc: logger.error("Beaver turn failed: %s", exc) continue print(reply) finally: await client.close() if __name__ == "__main__": asyncio.run(run_console())