fix: reconnect beaver terminal websocket before retrying turn

This commit is contained in:
0Xiao0
2026-06-02 10:18:19 +08:00
parent b1cad592e2
commit 879c73bfee
2 changed files with 140 additions and 13 deletions

View File

@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError):
pass
class BeaverTerminalConnectionClosed(BeaverTerminalError):
pass
@dataclass
class MessageIdGenerator:
peer_id: str
@ -71,6 +75,7 @@ class BeaverTerminalClient:
self.session_id: str | None = None
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
self._ws = await session.ws_connect(self._url)
await self._send_json(
@ -84,12 +89,29 @@ class BeaverTerminalClient:
async def send_text(self, text: str) -> str:
message_id = self._message_ids.next_id()
await self._send_json(build_message_frame(message_id=message_id, text=text))
message_frame = build_message_frame(message_id=message_id, text=text)
for attempt in range(2):
if not self._websocket_is_open():
await self.connect()
try:
await self._send_json(message_frame)
return await self._wait_for_reply(message_id)
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
if attempt == 1:
raise BeaverTerminalConnectionClosed(
"Beaver websocket closed before assistant reply"
) from exc
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
await self.connect()
raise BeaverTerminalError("unreachable Beaver send state")
async def _wait_for_reply(self, message_id: str) -> str:
while True:
frame = await self._receive_json()
frame_type = frame.get("type")
if frame_type == "ack" and frame.get("message_id") == message_id:
reply = frame.get("reply")
if isinstance(reply, str):
@ -121,13 +143,19 @@ class BeaverTerminalClient:
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def close(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
await self._close_websocket()
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
async def _close_websocket(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed
def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
@ -143,6 +171,12 @@ class BeaverTerminalClient:
raise BeaverTerminalError("Beaver websocket is not connected")
message = await self._ws.receive()
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
if message.type == aiohttp.WSMsgType.ERROR:
raise BeaverTerminalConnectionClosed(
f"Beaver websocket error: {self._ws.exception()!r}"
)
if message.type != aiohttp.WSMsgType.TEXT:
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
data = message.json()
@ -173,7 +207,11 @@ async def run_console() -> None:
continue
if text in {"quit", "exit"}:
return
reply = await client.send_text(text)
try:
reply = await client.send_text(text)
except BeaverTerminalError as exc:
logger.error("Beaver turn failed: %s", exc)
continue
print(reply)
finally:
await client.close()