"""Async WebSocket client and interactive console for a Beaver terminal channel.

The module exposes :class:`BeaverTerminalClient`, which performs a small
JSON handshake over a WebSocket, sends text messages and waits for the
matching assistant reply, plus :func:`run_console`, a line-oriented REPL
built on top of it.
"""

from __future__ import annotations

import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from uuid import uuid4

import aiohttp
from dotenv import load_dotenv

logger = logging.getLogger("beaver-terminal-client")

# Defaults used by client_from_env() when the matching environment
# variables are not set.
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
# A ".env" file located next to this module; loaded by client_from_env().
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")


class BeaverTerminalError(RuntimeError):
    """Base class for every error raised by this client."""


class BeaverTerminalConnectionClosed(BeaverTerminalError):
    """Raised when the websocket closes or errors while a frame is expected."""


@dataclass
class MessageIdGenerator:
    """Produce sequential, zero-padded message ids for one peer.

    Ids look like ``<peer_id>-<instance_id>-000001`` when ``instance_id`` is
    truthy, otherwise ``<peer_id>-000001``.
    """

    peer_id: str
    initial_counter: int = 0
    instance_id: str | None = None

    def __post_init__(self) -> None:
        # Running counter; deliberately not a dataclass field so it does not
        # take part in the generated __init__/__repr__/__eq__.
        self.counter = self.initial_counter

    def next_id(self) -> str:
        """Increment the counter and return the next message id."""
        self.counter += 1
        if self.instance_id:
            return f"{self.peer_id}-{self.instance_id}-{self.counter:06d}"
        return f"{self.peer_id}-{self.counter:06d}"


def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
    """Return the JSON-serialisable ``connect`` handshake frame."""
    return {
        "type": "connect",
        "peer_id": peer_id,
        "device_name": device_name,
        "capabilities": ["text"],
    }


def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
    """Return the JSON-serialisable ``message`` frame carrying ``text``."""
    return {
        "type": "message",
        "message_id": message_id,
        "text": text,
    }


class BeaverTerminalClient:
    """WebSocket client that exchanges JSON frames with a Beaver endpoint.

    Args:
        url: WebSocket URL to connect to.
        peer_id: Identifier sent in the ``connect`` frame and used as the
            prefix of generated message ids.
        device_name: Device name sent in the ``connect`` frame.
        http_session: Optional ``aiohttp.ClientSession`` to reuse. When it is
            omitted the client creates its own session lazily and closes it
            in :meth:`close`; a caller-supplied session is never closed here.
        message_ids: Optional id generator. Defaults to a generator whose
            ``instance_id`` is a random 8-hex-character string.

    Attributes:
        session_id: The ``session_id`` string from the last successful
            handshake, or ``None`` when none was received.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        http_session: aiohttp.ClientSession | None = None,
        message_ids: MessageIdGenerator | None = None,
    ) -> None:
        self._url = url
        self._peer_id = peer_id
        self._device_name = device_name
        self._owned_session = http_session is None
        self._http_session = http_session
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._message_ids = message_ids or MessageIdGenerator(
            peer_id=peer_id,
            instance_id=uuid4().hex[:8],
        )
        self.session_id: str | None = None

    async def connect(self) -> None:
        """Open (or re-open) the websocket and perform the handshake.

        Any previously open websocket is closed first. After the socket is
        opened a ``connect`` frame is sent and a ``connected`` frame is
        expected in response.

        Raises:
            BeaverTerminalError: If the first frame received is not of type
                ``connected``, or is not a valid JSON object text frame.
            BeaverTerminalConnectionClosed: If the socket closes or errors
                during the handshake.
        """
        await self._close_websocket()
        session = self._ensure_http_session()
        self._ws = await session.ws_connect(self._url)
        try:
            await self._send_json(
                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
            )
            frame = await self._receive_json()
            if frame.get("type") != "connected":
                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
        except BaseException:
            # A failed (or cancelled) handshake must not leave an open but
            # un-handshaked socket behind: _websocket_is_open() would report
            # True and later sends would go out on it. Re-raised unchanged.
            self.session_id = None
            await self._close_websocket()
            raise
        session_id = frame.get("session_id")
        self.session_id = session_id if isinstance(session_id, str) else None

    async def send_text(self, text: str) -> str:
        """Send ``text`` and return the assistant's reply.

        The turn is attempted at most twice. If the connection drops during
        the first attempt the client reconnects and resends the text under a
        freshly generated message id.

        Raises:
            BeaverTerminalConnectionClosed: If the connection also drops on
                the second attempt.
            BeaverTerminalError: If the server reports an error for the turn.
        """
        for attempt in range(2):
            if not self._websocket_is_open():
                await self.connect()

            message_id = self._message_ids.next_id()
            message_frame = build_message_frame(message_id=message_id, text=text)
            try:
                await self._send_json(message_frame)
                return await self._wait_for_reply(message_id)
            except (
                aiohttp.ClientConnectionError,
                # NOTE(review): some aiohttp releases appear to surface a
                # write to a closing transport as the builtin
                # ConnectionResetError rather than a ClientConnectionError
                # subclass; treat it as a dropped connection as well.
                ConnectionResetError,
                BeaverTerminalConnectionClosed,
            ) as exc:
                if attempt == 1:
                    raise BeaverTerminalConnectionClosed(
                        "Beaver websocket closed before assistant reply"
                    ) from exc
                logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
                await self.connect()
        # The loop always returns or raises on its second iteration.
        raise BeaverTerminalError("unreachable Beaver send state")

    async def _wait_for_reply(self, message_id: str) -> str:
        """Read frames until one resolves the turn identified by ``message_id``.

        Resolution rules, in order:
          * an ``ack`` frame for this id with a string ``reply`` returns it
            (an ack without a string reply is skipped);
          * an assistant ``message`` frame for this id returns its ``text``
            (``""`` when not a string), or raises if ``finish_reason`` is
            ``"error"``;
          * any ``error`` frame raises.
        All other frames are ignored.
        """
        while True:
            frame = await self._receive_json()
            frame_type = frame.get("type")
            if frame_type == "ack" and frame.get("message_id") == message_id:
                reply = frame.get("reply")
                if isinstance(reply, str):
                    return reply
                continue
            if (
                frame_type == "message"
                and frame.get("role") == "assistant"
                and frame.get("message_id") == message_id
            ):
                text = frame.get("text")
                if frame.get("finish_reason") == "error":
                    raise BeaverTerminalError(
                        text if isinstance(text, str) else "assistant turn failed"
                    )
                return text if isinstance(text, str) else ""
            if frame_type == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def ping(self) -> bool:
        """Send a ``ping`` frame and wait for ``pong``.

        Frames that are neither ``pong`` nor ``error`` are read and dropped.

        Returns:
            ``True`` once a ``pong`` frame arrives.

        Raises:
            BeaverTerminalError: If the websocket is not connected or an
                ``error`` frame arrives first.
        """
        await self._send_json({"type": "ping"})
        while True:
            frame = await self._receive_json()
            if frame.get("type") == "pong":
                return True
            if frame.get("type") == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def close(self) -> None:
        """Close the websocket and, if this client created it, the HTTP session."""
        try:
            await self._close_websocket()
        finally:
            # Release the owned session even if closing the websocket failed.
            if self._owned_session and self._http_session is not None:
                session, self._http_session = self._http_session, None
                await session.close()

    async def _close_websocket(self) -> None:
        """Close the current websocket, if any, and forget it."""
        # Drop the reference first so a failing close() cannot leave a
        # half-closed socket attached to the client.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    def _websocket_is_open(self) -> bool:
        """Return True when a websocket exists and is not marked closed."""
        return self._ws is not None and not self._ws.closed

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating one on first use."""
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _send_json(self, frame: dict[str, Any]) -> None:
        """Send ``frame`` as JSON; raise if no websocket is attached."""
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        await self._ws.send_json(frame)

    async def _receive_json(self) -> dict[str, Any]:
        """Receive one frame and return it as a JSON object.

        Raises:
            BeaverTerminalConnectionClosed: On a close/closing/closed or
                error websocket message.
            BeaverTerminalError: If no websocket is attached, the message is
                not a text frame, the payload is not valid JSON, or the
                decoded JSON is not an object.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        message = await self._ws.receive()
        if message.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSED,
            aiohttp.WSMsgType.CLOSING,
        ):
            raise BeaverTerminalConnectionClosed("Beaver websocket closed")
        if message.type == aiohttp.WSMsgType.ERROR:
            raise BeaverTerminalConnectionClosed(
                f"Beaver websocket error: {self._ws.exception()!r}"
            )
        if message.type != aiohttp.WSMsgType.TEXT:
            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
        try:
            data = message.json()
        except ValueError as exc:
            # Keep malformed payloads inside the client's error hierarchy so
            # callers handling BeaverTerminalError also survive bad frames.
            raise BeaverTerminalError(
                f"expected Beaver JSON frame, received {message.data!r}"
            ) from exc
        if not isinstance(data, dict):
            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
        return data


def client_from_env() -> BeaverTerminalClient:
    """Build a client from environment variables.

    Loads ``CUSTOM_ENV_PATH`` via ``load_dotenv`` and then reads
    ``BEAVER_WS_URL``, ``TERMINAL_PEER_ID`` and ``TERMINAL_DEVICE_NAME``,
    falling back to the module-level defaults.
    """
    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
    return BeaverTerminalClient(
        url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
        peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
        device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
    )


async def run_console() -> None:
    """Run an interactive prompt that sends each line and prints the reply.

    Blank lines are ignored; ``quit``, ``exit`` or end-of-input leave the
    loop. A failed turn is logged and the prompt continues. The client is
    always closed on the way out.
    """
    logging.basicConfig(level=logging.INFO)
    client = client_from_env()
    try:
        await client.connect()
        logger.info("Connected to Beaver session_id=%s", client.session_id)
        while True:
            try:
                # input() blocks, so run it in a worker thread.
                text = await asyncio.to_thread(input, "> ")
            except EOFError:
                # Ctrl-D / closed stdin: exit cleanly instead of crashing.
                return
            text = text.strip()
            if not text:
                continue
            if text in {"quit", "exit"}:
                return
            try:
                reply = await client.send_text(text)
            except (BeaverTerminalError, aiohttp.ClientError) as exc:
                # aiohttp.ClientError covers a failed reconnect attempt made
                # inside send_text(); it should not kill the console either.
                logger.error("Beaver turn failed: %s", exc)
                continue
            print(reply)
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(run_console())