222 lines
7.6 KiB
Python
222 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
from dotenv import load_dotenv
|
|
|
|
# Module-wide logger; the name is a fixed string rather than __name__.
logger = logging.getLogger("beaver-terminal-client")

# Fallback connection settings used when the corresponding environment
# variables (BEAVER_WS_URL, TERMINAL_PEER_ID, TERMINAL_DEVICE_NAME) are unset.
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"

# Optional dotenv file that lives in the same directory as this module.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
|
|
|
|
|
class BeaverTerminalError(RuntimeError):
    """Root of this client's exception hierarchy.

    Raised for protocol problems (unexpected or non-JSON frames, a missing
    handshake reply) and for errors reported by the Beaver server.
    """
|
|
|
|
|
|
class BeaverTerminalConnectionClosed(BeaverTerminalError):
    """The websocket closed or errored while a frame was expected."""
|
|
|
|
|
|
@dataclass
class MessageIdGenerator:
    """Produce sequential, peer-scoped message ids like ``<peer_id>-000001``.

    Attributes:
        peer_id: Prefix embedded in every generated id.
        initial_counter: Starting value of the counter; the first id returned
            by ``next_id`` uses ``initial_counter + 1``.
    """

    peer_id: str
    initial_counter: int = 0

    def __post_init__(self) -> None:
        # ``counter`` is a plain instance attribute rather than a dataclass
        # field, so it takes no part in __init__, __repr__ or __eq__.
        self.counter = self.initial_counter

    def next_id(self) -> str:
        """Advance the counter and return the id built from its new value."""
        advanced = self.counter + 1
        self.counter = advanced
        # Zero-padded to at least six digits; wider values are not truncated.
        return "{}-{:06d}".format(self.peer_id, advanced)
|
|
|
|
|
|
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
    """Return the handshake frame a terminal sends right after connecting.

    The frame identifies the peer and device and advertises text-only
    capability. A fresh dict (and capabilities list) is built on every call.
    """
    frame: dict[str, Any] = dict(type="connect", peer_id=peer_id, device_name=device_name)
    frame["capabilities"] = ["text"]
    return frame
|
|
|
|
|
|
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
    """Return the frame that carries one user message to the server."""
    frame: dict[str, Any] = dict(type="message")
    frame["message_id"] = message_id
    frame["text"] = text
    return frame
|
|
|
|
|
|
class BeaverTerminalClient:
    """Text-only websocket client for a Beaver terminal channel.

    An instance holds at most one websocket at a time. ``connect()`` performs
    the ``connect``/``connected`` handshake; ``send_text()`` sends one user
    message and waits for the matching reply, reconnecting once (with the same
    peer id) if the connection drops mid-turn.
    """

    def __init__(
        self,
        *,
        url: str,
        peer_id: str,
        device_name: str,
        http_session: aiohttp.ClientSession | None = None,
        message_ids: MessageIdGenerator | None = None,
    ) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            url: Websocket URL of the Beaver channel.
            peer_id: Identifier sent in the handshake; also the prefix of
                generated message ids when ``message_ids`` is omitted.
            device_name: Device name sent in the handshake.
            http_session: Optional caller-managed aiohttp session. When
                omitted, a session is created lazily and closed by ``close()``.
            message_ids: Optional message-id generator.
        """
        self._url = url
        self._peer_id = peer_id
        self._device_name = device_name
        # Only a session this client created itself is closed in close().
        self._owned_session = http_session is None
        self._http_session = http_session
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
        # Populated from the server's ``connected`` frame; None otherwise.
        self.session_id: str | None = None

    async def connect(self) -> None:
        """(Re)open the websocket and complete the Beaver handshake.

        Any existing websocket is closed first. On success ``session_id`` is
        taken from the ``connected`` frame (None if absent or not a string).
        If the handshake fails the new websocket is closed again, so a later
        ``send_text()`` reconnects instead of reusing a half-open socket.

        Raises:
            BeaverTerminalError: The first frame was not ``connected``.
            BeaverTerminalConnectionClosed: The socket closed mid-handshake.
            Errors raised by ``ws_connect`` (aiohttp client errors) propagate.
        """
        await self._close_websocket()
        self.session_id = None
        session = self._ensure_http_session()
        self._ws = await session.ws_connect(self._url)
        try:
            await self._send_json(
                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
            )
            frame = await self._receive_json()
            if frame.get("type") != "connected":
                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
        except BaseException:
            # _websocket_is_open() only checks the socket, not the handshake,
            # so a socket left open here would be reused without a handshake.
            await self._close_websocket()
            raise
        session_id = frame.get("session_id")
        self.session_id = session_id if isinstance(session_id, str) else None

    async def send_text(self, text: str) -> str:
        """Send one user message and return the assistant's reply text.

        One message id is allocated per call. If the connection drops before
        a reply arrives, the client reconnects once and resends the identical
        frame (same message id).

        Raises:
            BeaverTerminalConnectionClosed: The connection dropped on both
                attempts.
            BeaverTerminalError: The server reported an error for the turn.
        """
        message_id = self._message_ids.next_id()
        message_frame = build_message_frame(message_id=message_id, text=text)

        for attempt in range(2):
            if not self._websocket_is_open():
                await self.connect()

            try:
                await self._send_json(message_frame)
                return await self._wait_for_reply(message_id)
            # ConnectionResetError: older aiohttp raises it (not a
            # ClientConnectionError) when writing to a closing transport.
            except (
                aiohttp.ClientConnectionError,
                ConnectionResetError,
                BeaverTerminalConnectionClosed,
            ) as exc:
                if attempt == 1:
                    raise BeaverTerminalConnectionClosed(
                        "Beaver websocket closed before assistant reply"
                    ) from exc
                logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
                await self.connect()

        raise BeaverTerminalError("unreachable Beaver send state")

    async def _wait_for_reply(self, message_id: str) -> str:
        """Read frames until the reply to ``message_id`` arrives.

        Returns the string ``reply`` of an ``ack`` frame for this id, or the
        ``text`` of an assistant ``message`` frame for this id ("" when the
        text is missing or not a string). Acks without a string reply and
        unrelated frames are skipped.

        Raises:
            BeaverTerminalError: The assistant message has
                ``finish_reason == "error"``, or any ``error`` frame arrives
                (regardless of its message id).
        """
        while True:
            frame = await self._receive_json()
            frame_type = frame.get("type")
            if frame_type == "ack" and frame.get("message_id") == message_id:
                reply = frame.get("reply")
                if isinstance(reply, str):
                    return reply
                continue

            if (
                frame_type == "message"
                and frame.get("role") == "assistant"
                and frame.get("message_id") == message_id
            ):
                text = frame.get("text")
                if frame.get("finish_reason") == "error":
                    raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
                return text if isinstance(text, str) else ""

            if frame_type == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def ping(self) -> bool:
        """Send a ``ping`` frame and wait for ``pong``; returns True.

        Does not connect on its own. Frames other than ``pong``/``error``
        are discarded while waiting.

        Raises:
            BeaverTerminalError: Not connected, or an ``error`` frame arrived.
        """
        await self._send_json({"type": "ping"})
        while True:
            frame = await self._receive_json()
            if frame.get("type") == "pong":
                return True
            if frame.get("type") == "error":
                error = frame.get("error")
                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")

    async def close(self) -> None:
        """Close the websocket and, if this client created it, the HTTP session.

        The owned session is closed even if closing the websocket raises.
        """
        try:
            await self._close_websocket()
        finally:
            if self._owned_session and self._http_session is not None:
                session, self._http_session = self._http_session, None
                await session.close()

    async def _close_websocket(self) -> None:
        """Close the current websocket, if any, and forget it."""
        # Detach first so a failing close() cannot leave a stale handle.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    def _websocket_is_open(self) -> bool:
        """True when a websocket exists and has not been closed."""
        return self._ws is not None and not self._ws.closed

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating one lazily if none is set."""
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _send_json(self, frame: dict[str, Any]) -> None:
        """Serialize ``frame`` as JSON and send it on the websocket."""
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")
        await self._ws.send_json(frame)

    async def _receive_json(self) -> dict[str, Any]:
        """Receive one text frame and decode it as a JSON object.

        Raises:
            BeaverTerminalConnectionClosed: The socket closed or errored.
            BeaverTerminalError: Not connected, a non-text frame arrived, the
                payload was not valid JSON, or it was not a JSON object.
        """
        if self._ws is None:
            raise BeaverTerminalError("Beaver websocket is not connected")

        message = await self._ws.receive()
        if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
            raise BeaverTerminalConnectionClosed("Beaver websocket closed")
        if message.type == aiohttp.WSMsgType.ERROR:
            raise BeaverTerminalConnectionClosed(
                f"Beaver websocket error: {self._ws.exception()!r}"
            )
        if message.type != aiohttp.WSMsgType.TEXT:
            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
        try:
            data = message.json()
        except ValueError as exc:
            # json.JSONDecodeError subclasses ValueError; keep callers'
            # error handling within the BeaverTerminalError hierarchy.
            raise BeaverTerminalError(f"invalid Beaver JSON frame: {message.data!r}") from exc
        if not isinstance(data, dict):
            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
        return data
|
|
|
|
|
|
def client_from_env() -> BeaverTerminalClient:
    """Build a client from environment variables, falling back to defaults.

    The dotenv file at CUSTOM_ENV_PATH is loaded first; ``override`` is not
    passed, so (per python-dotenv's default) variables already set in the
    process environment take precedence over the file.
    """
    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
    env = os.environ
    settings = {
        "url": env.get("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
        "peer_id": env.get("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
        "device_name": env.get("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
    }
    return BeaverTerminalClient(**settings)
|
|
|
|
|
|
async def run_console() -> None:
    """Interactive loop: read lines from stdin, send each to Beaver, print replies.

    Blank lines are skipped; ``quit``/``exit`` or end of input (EOF) ends the
    session. A failed turn is logged and the prompt continues. The client is
    always closed on exit. A failure of the initial connect propagates.
    """
    logging.basicConfig(level=logging.INFO)
    client = client_from_env()
    try:
        await client.connect()
        logger.info("Connected to Beaver session_id=%s", client.session_id)
        while True:
            try:
                line = await asyncio.to_thread(input, "> ")
            except EOFError:
                # stdin closed (Ctrl-D or exhausted pipe): exit cleanly
                # instead of crashing with a traceback.
                return
            text = line.strip()
            if not text:
                continue
            if text in {"quit", "exit"}:
                return
            try:
                reply = await client.send_text(text)
            except (BeaverTerminalError, aiohttp.ClientError, ConnectionError) as exc:
                # send_text() may reconnect; a failed reconnect or write
                # surfaces as an aiohttp/OS connection error rather than
                # BeaverTerminalError. Keep the console alive either way.
                logger.error("Beaver turn failed: %s", exc)
                continue
            print(reply)
    finally:
        await client.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Only runs when executed as a script, never on import.
    console = run_console()
    asyncio.run(console)
|