From b1cad592e203b180c3d006124f5fe665863008f2 Mon Sep 17 00:00:00 2001
From: 0Xiao0 <511201264@qq.com>
Date: Mon, 1 Jun 2026 17:06:23 +0800
Subject: [PATCH] feat: add beaver terminal websocket client

---
Review notes (this section is ignored by git am):
- connect() now closes the socket when the handshake fails.
- _receive_json() reports malformed JSON as BeaverTerminalError.
- run_console() exits cleanly on end of input.
- send_text() no longer rebinds its ``text`` parameter.
- Docstrings added throughout beaver_terminal_client.py; the blob id on its
  index line is that of the pre-review content.

 .env.example                   |   5 +
 beaver_terminal_client.py      | 265 +++++++++++++++++++++++++++
 test_beaver_terminal_client.py | 320 +++++++++++++++++++++++++++++++++
 3 files changed, 590 insertions(+)
 create mode 100644 beaver_terminal_client.py
 create mode 100644 test_beaver_terminal_client.py

diff --git a/.env.example b/.env.example
index a159380..88add51 100644
--- a/.env.example
+++ b/.env.example
@@ -4,6 +4,11 @@ LIVEKIT_API_KEY=
 LIVEKIT_API_SECRET=
 CUSTOM_AGENT_NAME=my-agent
 
+# Beaver terminal text WebSocket
+BEAVER_WS_URL=ws://127.0.0.1:8080/api/channels/terminal-dev/ws
+TERMINAL_PEER_ID=device-001
+TERMINAL_DEVICE_NAME=desk-terminal
+
 # ASR blackbox
 CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
 CUSTOM_ASR_MODEL=qwen
diff --git a/beaver_terminal_client.py b/beaver_terminal_client.py
new file mode 100644
index 0000000..45d5598
--- /dev/null
+++ b/beaver_terminal_client.py
@@ -0,0 +1,265 @@
+"""Text WebSocket client for the Beaver terminal channel.
+
+Provides frame builders, a sequential request/response client built on
+aiohttp, and an interactive console entry point configured from the
+environment.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import aiohttp
+from dotenv import load_dotenv
+
+logger = logging.getLogger("beaver-terminal-client")
+DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
+DEFAULT_TERMINAL_PEER_ID = "device-001"
+DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
+CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
+
+
+class BeaverTerminalError(RuntimeError):
+    """Raised for protocol violations and error frames on the Beaver websocket."""
+
+
+@dataclass
+class MessageIdGenerator:
+    """Produce sequential message ids of the form ``<peer_id>-<counter:06d>``."""
+
+    peer_id: str
+    initial_counter: int = 0
+
+    def __post_init__(self) -> None:
+        # ``counter`` is ordinary instance state, not a dataclass field; it
+        # holds the number used by the most recently issued id.
+        self.counter = self.initial_counter
+
+    def next_id(self) -> str:
+        """Advance the counter and return the id built from the new value."""
+        self.counter += 1
+        return f"{self.peer_id}-{self.counter:06d}"
+
+
+def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
+    """Return the ``connect`` handshake frame advertising the ``text`` capability."""
+    return {
+        "type": "connect",
+        "peer_id": peer_id,
+        "device_name": device_name,
+        "capabilities": ["text"],
+    }
+
+
+def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
+    """Return a ``message`` frame carrying ``text`` under ``message_id``."""
+    return {
+        "type": "message",
+        "message_id": message_id,
+        "text": text,
+    }
+
+
+class BeaverTerminalClient:
+    """Sequential text client for one Beaver terminal websocket.
+
+    Each request sends one frame and then reads frames until a matching
+    response arrives; frames that match nothing are discarded.
+    """
+
+    def __init__(
+        self,
+        *,
+        url: str,
+        peer_id: str,
+        device_name: str,
+        http_session: aiohttp.ClientSession | None = None,
+        message_ids: MessageIdGenerator | None = None,
+    ) -> None:
+        """Store connection settings; no I/O happens until ``connect()``.
+
+        When ``http_session`` is omitted a session is created on demand and
+        closed by ``close()``; a caller-supplied session is left open.
+        """
+        self._url = url
+        self._peer_id = peer_id
+        self._device_name = device_name
+        self._owned_session = http_session is None
+        self._http_session = http_session
+        self._ws: aiohttp.ClientWebSocketResponse | None = None
+        self._message_ids = message_ids or MessageIdGenerator(peer_id=peer_id)
+        self.session_id: str | None = None
+
+    async def connect(self) -> None:
+        """Open the websocket and complete the ``connect``/``connected`` handshake.
+
+        Raises:
+            BeaverTerminalError: if the first frame is not a ``connected`` frame.
+        """
+        session = self._ensure_http_session()
+        self._ws = await session.ws_connect(self._url)
+        try:
+            await self._send_json(
+                build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
+            )
+            frame = await self._receive_json()
+            if frame.get("type") != "connected":
+                raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
+        except Exception:
+            # Do not leave a half-open socket behind when the handshake fails.
+            await self.close()
+            raise
+        session_id = frame.get("session_id")
+        self.session_id = session_id if isinstance(session_id, str) else None
+
+    async def send_text(self, text: str) -> str:
+        """Send ``text`` and wait for the reply to that message.
+
+        Returns:
+            The ``reply`` string of the matching ``ack`` if it has one,
+            otherwise the text of the matching assistant ``message`` frame
+            (``""`` when that text is not a string).
+
+        Raises:
+            BeaverTerminalError: on an ``error`` frame, or when the assistant
+                frame reports ``finish_reason == "error"``.
+        """
+        message_id = self._message_ids.next_id()
+        await self._send_json(build_message_frame(message_id=message_id, text=text))
+
+        while True:
+            frame = await self._receive_json()
+            frame_type = frame.get("type")
+
+            if frame_type == "ack" and frame.get("message_id") == message_id:
+                reply = frame.get("reply")
+                if isinstance(reply, str):
+                    return reply
+                # Ack without a reply string: keep reading for the assistant frame.
+                continue
+
+            if (
+                frame_type == "message"
+                and frame.get("role") == "assistant"
+                and frame.get("message_id") == message_id
+            ):
+                reply_text = frame.get("text")
+                if frame.get("finish_reason") == "error":
+                    raise BeaverTerminalError(
+                        reply_text if isinstance(reply_text, str) else "assistant turn failed"
+                    )
+                return reply_text if isinstance(reply_text, str) else ""
+
+            if frame_type == "error":
+                error = frame.get("error")
+                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
+
+    async def ping(self) -> bool:
+        """Send a ``ping`` frame and return ``True`` once a ``pong`` arrives.
+
+        Raises:
+            BeaverTerminalError: if an ``error`` frame arrives first.
+        """
+        await self._send_json({"type": "ping"})
+        while True:
+            frame = await self._receive_json()
+            if frame.get("type") == "pong":
+                return True
+            if frame.get("type") == "error":
+                error = frame.get("error")
+                raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
+
+    async def close(self) -> None:
+        """Close the websocket and, if this client created it, the HTTP session."""
+        if self._ws is not None:
+            await self._ws.close()
+            self._ws = None
+        if self._owned_session and self._http_session is not None:
+            await self._http_session.close()
+            self._http_session = None
+
+    def _ensure_http_session(self) -> aiohttp.ClientSession:
+        """Return the HTTP session, creating one on first use."""
+        if self._http_session is None:
+            self._http_session = aiohttp.ClientSession()
+        return self._http_session
+
+    async def _send_json(self, frame: dict[str, Any]) -> None:
+        """Send ``frame`` as JSON; raise ``BeaverTerminalError`` if not connected."""
+        if self._ws is None:
+            raise BeaverTerminalError("Beaver websocket is not connected")
+        await self._ws.send_json(frame)
+
+    async def _receive_json(self) -> dict[str, Any]:
+        """Receive one text frame and return it decoded as a JSON object.
+
+        Raises:
+            BeaverTerminalError: if not connected, the frame is not a text
+                frame, or its payload is not a valid JSON object.
+        """
+        if self._ws is None:
+            raise BeaverTerminalError("Beaver websocket is not connected")
+
+        message = await self._ws.receive()
+        if message.type != aiohttp.WSMsgType.TEXT:
+            raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
+        try:
+            data = message.json()
+        except ValueError as err:
+            # json.JSONDecodeError is a ValueError; surface it as a protocol error.
+            raise BeaverTerminalError("received malformed Beaver JSON frame") from err
+        if not isinstance(data, dict):
+            raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
+        return data
+
+
+def client_from_env() -> BeaverTerminalClient:
+    """Build a ``BeaverTerminalClient`` from environment variables.
+
+    Loads the ``.env`` file next to this module, then reads
+    ``BEAVER_WS_URL``, ``TERMINAL_PEER_ID`` and ``TERMINAL_DEVICE_NAME``,
+    falling back to the module-level defaults for unset variables.
+    """
+    load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
+    return BeaverTerminalClient(
+        url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
+        peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
+        device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
+    )
+
+
+async def run_console() -> None:
+    """Run an interactive prompt that sends each line and prints the reply.
+
+    Blank lines are skipped; ``quit``, ``exit`` or end of input ends the
+    loop, and the client is closed on the way out.
+    """
+    logging.basicConfig(level=logging.INFO)
+    client = client_from_env()
+    try:
+        await client.connect()
+        logger.info("Connected to Beaver session_id=%s", client.session_id)
+        while True:
+            try:
+                text = await asyncio.to_thread(input, "> ")
+            except EOFError:
+                # stdin closed (Ctrl-D or exhausted pipe): leave without a traceback.
+                return
+            text = text.strip()
+            if not text:
+                continue
+            if text in {"quit", "exit"}:
+                return
+            reply = await client.send_text(text)
+            print(reply)
+    finally:
+        await client.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(run_console())
diff --git a/test_beaver_terminal_client.py b/test_beaver_terminal_client.py
new file mode 100644
index 0000000..d5bbcef
--- /dev/null
+++ b/test_beaver_terminal_client.py
@@ -0,0 +1,320 @@
+import json
+
+import aiohttp
+import pytest
+from aiohttp import web
+
+from custom.beaver_terminal_client import (
+    BeaverTerminalClient,
+    BeaverTerminalError,
+    MessageIdGenerator,
+    build_connect_frame,
+    build_message_frame,
+)
+
+
+def test_build_connect_frame_uses_stable_peer_id() -> None:
+    frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
+
+    assert frame == {
+        "type": "connect",
+        "peer_id": "device-001",
+        "device_name": "desk-terminal",
+        "capabilities": ["text"],
+    }
+
+
+def test_build_message_frame_uses_message_id_and_text() -> None:
+    frame = build_message_frame(message_id="device-001-000001", text="hello")
+
+    assert frame == {
+        "type": "message",
+        "message_id": "device-001-000001",
+        "text": "hello",
+    }
+
+
+def test_message_id_generator_uses_monotonic_peer_counter() -> None:
+    generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
+
+    assert generator.next_id() == "device-001-000008"
+    assert generator.next_id() == "device-001-000009"
+    assert generator.counter == 9
+
+
+async def test_client_connects_sends_text_and_returns_assistant_reply(
+    unused_tcp_port: int,
+) -> None:
+    received: list[dict[str, object]] = []
+
+    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for message in ws:
+            assert message.type == aiohttp.WSMsgType.TEXT
+            frame = json.loads(message.data)
+            received.append(frame)
+
+            if frame["type"] == "connect":
+                await ws.send_json(
+                    {
+                        "type": "connected",
+                        "channel_id": "terminal-dev",
+                        "session_id": "terminal-dev:local:device-001",
+                    }
+                )
+            elif frame["type"] == "message":
+                await ws.send_json(
+                    {
+                        "type": "ack",
+                        "message_id": frame["message_id"],
+                        "session_id": "terminal-dev:local:device-001",
+                        "accepted": True,
+                    }
+                )
+                await ws.send_json(
+                    {
+                        "type": "message",
+                        "role": "assistant",
+                        "message_id": frame["message_id"],
+                        "run_id": "run-1",
+                        "text": "assistant reply",
+                        "finish_reason": "stop",
+                    }
+                )
+
+        return ws
+
+    app = web.Application()
+    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
+    await site.start()
+
+    client = BeaverTerminalClient(
+        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
+        peer_id="device-001",
+        device_name="desk-terminal",
+    )
+
+    try:
+        await client.connect()
+        reply = await client.send_text("hello")
+    finally:
+        await client.close()
+        await runner.cleanup()
+
+    assert client.session_id == "terminal-dev:local:device-001"
+    assert reply == "assistant reply"
+    assert received == [
+        {
+            "type": "connect",
+            "peer_id": "device-001",
+            "device_name": "desk-terminal",
+            "capabilities": ["text"],
+        },
+        {
+            "type": "message",
+            "message_id": "device-001-000001",
+            "text": "hello",
+        },
+    ]
+
+
+async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
+    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for message in ws:
+            frame = json.loads(message.data)
+            if frame["type"] == "connect":
+                await ws.send_json(
+                    {
+                        "type": "connected",
+                        "channel_id": "terminal-dev",
+                        "session_id": "terminal-dev:local:device-001",
+                    }
+                )
+            elif frame["type"] == "message":
+                await ws.send_json(
+                    {
+                        "type": "ack",
+                        "message_id": frame["message_id"],
+                        "session_id": "terminal-dev:local:device-001",
+                        "accepted": False,
+                        "duplicate": True,
+                        "pending": False,
+                        "reply": "cached assistant reply",
+                    }
+                )
+
+        return ws
+
+    app = web.Application()
+    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
+    await site.start()
+
+    client = BeaverTerminalClient(
+        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
+        peer_id="device-001",
+        device_name="desk-terminal",
+    )
+
+    try:
+        await client.connect()
+        reply = await client.send_text("hello")
+    finally:
+        await client.close()
+        await runner.cleanup()
+
+    assert reply == "cached assistant reply"
+
+
+async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
+    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for message in ws:
+            frame = json.loads(message.data)
+            if frame["type"] == "connect":
+                await ws.send_json(
+                    {
+                        "type": "connected",
+                        "channel_id": "terminal-dev",
+                        "session_id": "terminal-dev:local:device-001",
+                    }
+                )
+            elif frame["type"] == "message":
+                await ws.send_json({"type": "error", "error": "text is required"})
+
+        return ws
+
+    app = web.Application()
+    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
+    await site.start()
+
+    client = BeaverTerminalClient(
+        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
+        peer_id="device-001",
+        device_name="desk-terminal",
+    )
+
+    try:
+        await client.connect()
+        with pytest.raises(BeaverTerminalError, match="text is required"):
+            await client.send_text("hello")
+    finally:
+        await client.close()
+        await runner.cleanup()
+
+
+async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
+    unused_tcp_port: int,
+) -> None:
+    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for message in ws:
+            frame = json.loads(message.data)
+            if frame["type"] == "connect":
+                await ws.send_json(
+                    {
+                        "type": "connected",
+                        "channel_id": "terminal-dev",
+                        "session_id": "terminal-dev:local:device-001",
+                    }
+                )
+            elif frame["type"] == "message":
+                await ws.send_json(
+                    {
+                        "type": "ack",
+                        "message_id": frame["message_id"],
+                        "session_id": "terminal-dev:local:device-001",
+                        "accepted": True,
+                    }
+                )
+                await ws.send_json(
+                    {
+                        "type": "message",
+                        "role": "assistant",
+                        "message_id": frame["message_id"],
+                        "run_id": "run-1",
+                        "text": "failed turn",
+                        "finish_reason": "error",
+                    }
+                )
+
+        return ws
+
+    app = web.Application()
+    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
+    await site.start()
+
+    client = BeaverTerminalClient(
+        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
+        peer_id="device-001",
+        device_name="desk-terminal",
+    )
+
+    try:
+        await client.connect()
+        with pytest.raises(BeaverTerminalError, match="failed turn"):
+            await client.send_text("hello")
+    finally:
+        await client.close()
+        await runner.cleanup()
+
+
+async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
+    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for message in ws:
+            frame = json.loads(message.data)
+            if frame["type"] == "connect":
+                await ws.send_json(
+                    {
+                        "type": "connected",
+                        "channel_id": "terminal-dev",
+                        "session_id": "terminal-dev:local:device-001",
+                    }
+                )
+            elif frame["type"] == "ping":
+                await ws.send_json({"type": "pong"})
+
+        return ws
+
+    app = web.Application()
+    app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
+    await site.start()
+
+    client = BeaverTerminalClient(
+        url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
+        peer_id="device-001",
+        device_name="desk-terminal",
+    )
+
+    try:
+        await client.connect()
+        assert await client.ping()
+    finally:
+        await client.close()
+        await runner.cleanup()