fix: reconnect beaver terminal websocket before retrying turn

This commit is contained in:
0Xiao0
2026-06-02 10:18:19 +08:00
parent b1cad592e2
commit 879c73bfee
2 changed files with 140 additions and 13 deletions

View File

@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError):
pass
class BeaverTerminalConnectionClosed(BeaverTerminalError):
pass
@dataclass
class MessageIdGenerator:
peer_id: str
@ -71,6 +75,7 @@ class BeaverTerminalClient:
self.session_id: str | None = None
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
self._ws = await session.ws_connect(self._url)
await self._send_json(
@ -84,12 +89,29 @@ class BeaverTerminalClient:
async def send_text(self, text: str) -> str:
message_id = self._message_ids.next_id()
await self._send_json(build_message_frame(message_id=message_id, text=text))
message_frame = build_message_frame(message_id=message_id, text=text)
for attempt in range(2):
if not self._websocket_is_open():
await self.connect()
try:
await self._send_json(message_frame)
return await self._wait_for_reply(message_id)
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
if attempt == 1:
raise BeaverTerminalConnectionClosed(
"Beaver websocket closed before assistant reply"
) from exc
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
await self.connect()
raise BeaverTerminalError("unreachable Beaver send state")
async def _wait_for_reply(self, message_id: str) -> str:
while True:
frame = await self._receive_json()
frame_type = frame.get("type")
if frame_type == "ack" and frame.get("message_id") == message_id:
reply = frame.get("reply")
if isinstance(reply, str):
@ -121,13 +143,19 @@ class BeaverTerminalClient:
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def close(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
await self._close_websocket()
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
async def _close_websocket(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed
def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
@ -143,6 +171,12 @@ class BeaverTerminalClient:
raise BeaverTerminalError("Beaver websocket is not connected")
message = await self._ws.receive()
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
if message.type == aiohttp.WSMsgType.ERROR:
raise BeaverTerminalConnectionClosed(
f"Beaver websocket error: {self._ws.exception()!r}"
)
if message.type != aiohttp.WSMsgType.TEXT:
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
data = message.json()
@ -173,7 +207,11 @@ async def run_console() -> None:
continue
if text in {"quit", "exit"}:
return
try:
reply = await client.send_text(text)
except BeaverTerminalError as exc:
logger.error("Beaver turn failed: %s", exc)
continue
print(reply)
finally:
await client.close()

View File

@ -1,16 +1,32 @@
import asyncio
import json
import sys
from pathlib import Path
import aiohttp
import pytest
from aiohttp import web
from custom.beaver_terminal_client import (
if __name__ == "__main__":
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
raise SystemExit(pytest.main([__file__]))
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
)
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
def test_build_connect_frame_uses_stable_peer_id() -> None:
@ -318,3 +334,76 @@ async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -
finally:
await client.close()
await runner.cleanup()
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
unused_tcp_port: int,
) -> None:
connect_peer_ids: list[str] = []
connection_count = 0
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
nonlocal connection_count
connection_count += 1
current_connection = connection_count
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
connect_peer_ids.append(frame["peer_id"])
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
if current_connection == 1:
await ws.close()
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-2",
"text": "reply after reconnect",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
)
try:
await client.connect()
await asyncio.sleep(0.01)
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert reply == "reply after reconnect"
assert connect_peer_ids == ["device-001", "device-001"]