feat: add beaver terminal websocket client
This commit is contained in:
183
beaver_terminal_client.py
Normal file
183
beaver_terminal_client.py
Normal file
@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger("beaver-terminal-client")
|
||||
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
|
||||
DEFAULT_TERMINAL_PEER_ID = "device-001"
|
||||
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
|
||||
|
||||
class BeaverTerminalError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageIdGenerator:
|
||||
peer_id: str
|
||||
initial_counter: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.counter = self.initial_counter
|
||||
|
||||
def next_id(self) -> str:
|
||||
self.counter += 1
|
||||
return f"{self.peer_id}-{self.counter:06d}"
|
||||
|
||||
|
||||
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "connect",
|
||||
"peer_id": peer_id,
|
||||
"device_name": device_name,
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
|
||||
|
||||
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "message",
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
}
|
||||
|
||||
|
||||
class BeaverTerminalClient:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
peer_id: str,
|
||||
device_name: str,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
message_ids: MessageIdGenerator | None = None,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._peer_id = peer_id
|
||||
self._device_name = device_name
|
||||
self._owned_session = http_session is None
|
||||
self._http_session = http_session
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
|
||||
self.session_id: str | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
session = self._ensure_http_session()
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||
)
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") != "connected":
|
||||
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||
session_id = frame.get("session_id")
|
||||
self.session_id = session_id if isinstance(session_id, str) else None
|
||||
|
||||
async def send_text(self, text: str) -> str:
|
||||
message_id = self._message_ids.next_id()
|
||||
await self._send_json(build_message_frame(message_id=message_id, text=text))
|
||||
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||
reply = frame.get("reply")
|
||||
if isinstance(reply, str):
|
||||
return reply
|
||||
continue
|
||||
|
||||
if (
|
||||
frame_type == "message"
|
||||
and frame.get("role") == "assistant"
|
||||
and frame.get("message_id") == message_id
|
||||
):
|
||||
text = frame.get("text")
|
||||
if frame.get("finish_reason") == "error":
|
||||
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
|
||||
return text if isinstance(text, str) else ""
|
||||
|
||||
if frame_type == "error":
|
||||
error = frame.get("error")
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def ping(self) -> bool:
|
||||
await self._send_json({"type": "ping"})
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") == "pong":
|
||||
return True
|
||||
if frame.get("type") == "error":
|
||||
error = frame.get("error")
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._owned_session and self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
return self._http_session
|
||||
|
||||
async def _send_json(self, frame: dict[str, Any]) -> None:
|
||||
if self._ws is None:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
await self._ws.send_json(frame)
|
||||
|
||||
async def _receive_json(self) -> dict[str, Any]:
|
||||
if self._ws is None:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
|
||||
message = await self._ws.receive()
|
||||
if message.type != aiohttp.WSMsgType.TEXT:
|
||||
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||
data = message.json()
|
||||
if not isinstance(data, dict):
|
||||
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
|
||||
return data
|
||||
|
||||
|
||||
def client_from_env() -> BeaverTerminalClient:
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
return BeaverTerminalClient(
|
||||
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
|
||||
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
|
||||
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
|
||||
)
|
||||
|
||||
|
||||
async def run_console() -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
client = client_from_env()
|
||||
try:
|
||||
await client.connect()
|
||||
logger.info("Connected to Beaver session_id=%s", client.session_id)
|
||||
while True:
|
||||
text = await asyncio.to_thread(input, "> ")
|
||||
text = text.strip()
|
||||
if not text:
|
||||
continue
|
||||
if text in {"quit", "exit"}:
|
||||
return
|
||||
reply = await client.send_text(text)
|
||||
print(reply)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_console())
|
||||
Reference in New Issue
Block a user