From 879c73bfee6d09dbe6ea01a24ff666a4052cba11 Mon Sep 17 00:00:00 2001 From: 0Xiao0 <511201264@qq.com> Date: Tue, 2 Jun 2026 10:18:19 +0800 Subject: [PATCH] fix: reconnect beaver terminal websocket before retrying turn --- beaver_terminal_client.py | 50 ++++++++++++++-- test_beaver_terminal_client.py | 103 ++++++++++++++++++++++++++++++--- 2 files changed, 140 insertions(+), 13 deletions(-) diff --git a/beaver_terminal_client.py b/beaver_terminal_client.py index 45d5598..039b967 100644 --- a/beaver_terminal_client.py +++ b/beaver_terminal_client.py @@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError): pass +class BeaverTerminalConnectionClosed(BeaverTerminalError): + pass + + @dataclass class MessageIdGenerator: peer_id: str @@ -71,6 +75,7 @@ class BeaverTerminalClient: self.session_id: str | None = None async def connect(self) -> None: + await self._close_websocket() session = self._ensure_http_session() self._ws = await session.ws_connect(self._url) await self._send_json( @@ -84,12 +89,29 @@ class BeaverTerminalClient: async def send_text(self, text: str) -> str: message_id = self._message_ids.next_id() - await self._send_json(build_message_frame(message_id=message_id, text=text)) + message_frame = build_message_frame(message_id=message_id, text=text) + for attempt in range(2): + if not self._websocket_is_open(): + await self.connect() + + try: + await self._send_json(message_frame) + return await self._wait_for_reply(message_id) + except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc: + if attempt == 1: + raise BeaverTerminalConnectionClosed( + "Beaver websocket closed before assistant reply" + ) from exc + logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id") + await self.connect() + + raise BeaverTerminalError("unreachable Beaver send state") + + async def _wait_for_reply(self, message_id: str) -> str: while True: frame = await self._receive_json() frame_type = frame.get("type") - if frame_type == "ack" and frame.get("message_id") == message_id: reply = frame.get("reply") if isinstance(reply, str): @@ -121,13 +143,19 @@ class BeaverTerminalClient: raise BeaverTerminalError(error if isinstance(error, str) else "unknown error") async def close(self) -> None: - if self._ws is not None: - await self._ws.close() - self._ws = None + await self._close_websocket() if self._owned_session and self._http_session is not None: await self._http_session.close() self._http_session = None + async def _close_websocket(self) -> None: + if self._ws is not None: + await self._ws.close() + self._ws = None + + def _websocket_is_open(self) -> bool: + return self._ws is not None and not self._ws.closed + def _ensure_http_session(self) -> aiohttp.ClientSession: if self._http_session is None: self._http_session = aiohttp.ClientSession() @@ -143,6 +171,12 @@ class BeaverTerminalClient: raise BeaverTerminalError("Beaver websocket is not connected") message = await self._ws.receive() + if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): + raise BeaverTerminalConnectionClosed("Beaver websocket closed") + if message.type == aiohttp.WSMsgType.ERROR: + raise BeaverTerminalConnectionClosed( + f"Beaver websocket error: {self._ws.exception()!r}" + ) if message.type != aiohttp.WSMsgType.TEXT: raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}") data = message.json() @@ -173,7 +207,11 @@ async def run_console() -> None: continue if text in {"quit", "exit"}: return - reply = await client.send_text(text) + try: + reply = await client.send_text(text) + except BeaverTerminalError as exc: + logger.error("Beaver turn failed: %s", exc) + continue print(reply) finally: await client.close() diff --git a/test_beaver_terminal_client.py b/test_beaver_terminal_client.py index d5bbcef..dc6ee8a 100644 --- a/test_beaver_terminal_client.py +++ b/test_beaver_terminal_client.py @@ -1,16 +1,32 @@ +import asyncio import json +import sys +from pathlib import Path import aiohttp import pytest from aiohttp import web -from custom.beaver_terminal_client import ( - BeaverTerminalClient, - BeaverTerminalError, - MessageIdGenerator, - build_connect_frame, - build_message_frame, -) +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + raise SystemExit(pytest.main([__file__])) + +try: + from custom.beaver_terminal_client import ( + BeaverTerminalClient, + BeaverTerminalError, + MessageIdGenerator, + build_connect_frame, + build_message_frame, + ) +except ModuleNotFoundError: + from beaver_terminal_client import ( + BeaverTerminalClient, + BeaverTerminalError, + MessageIdGenerator, + build_connect_frame, + build_message_frame, + ) def test_build_connect_frame_uses_stable_peer_id() -> None: @@ -318,3 +334,76 @@ async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) - finally: await client.close() await runner.cleanup() + + +async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send( + unused_tcp_port: int, +) -> None: + connect_peer_ids: list[str] = [] + connection_count = 0 + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + nonlocal connection_count + connection_count += 1 + current_connection = connection_count + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + frame = json.loads(message.data) + if frame["type"] == "connect": + connect_peer_ids.append(frame["peer_id"]) + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:device-001", + } + ) + if current_connection == 1: + await ws.close() + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:device-001", + "accepted": True, + } + ) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-2", + "text": "reply after reconnect", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + ) + + try: + await client.connect() + await asyncio.sleep(0.01) + reply = await client.send_text("hello") + finally: + await client.close() + await runner.cleanup() + + assert reply == "reply after reconnect" + assert connect_peer_ids == ["device-001", "device-001"]