fix: reconnect beaver terminal websocket before retrying turn
This commit is contained in:
@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageIdGenerator:
|
||||
peer_id: str
|
||||
@ -71,6 +75,7 @@ class BeaverTerminalClient:
|
||||
self.session_id: str | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._close_websocket()
|
||||
session = self._ensure_http_session()
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
@ -84,12 +89,29 @@ class BeaverTerminalClient:
|
||||
|
||||
async def send_text(self, text: str) -> str:
|
||||
message_id = self._message_ids.next_id()
|
||||
await self._send_json(build_message_frame(message_id=message_id, text=text))
|
||||
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||
|
||||
for attempt in range(2):
|
||||
if not self._websocket_is_open():
|
||||
await self.connect()
|
||||
|
||||
try:
|
||||
await self._send_json(message_frame)
|
||||
return await self._wait_for_reply(message_id)
|
||||
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||
if attempt == 1:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
"Beaver websocket closed before assistant reply"
|
||||
) from exc
|
||||
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||
await self.connect()
|
||||
|
||||
raise BeaverTerminalError("unreachable Beaver send state")
|
||||
|
||||
async def _wait_for_reply(self, message_id: str) -> str:
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||
reply = frame.get("reply")
|
||||
if isinstance(reply, str):
|
||||
@ -121,13 +143,19 @@ class BeaverTerminalClient:
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
await self._close_websocket()
|
||||
if self._owned_session and self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
async def _close_websocket(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
def _websocket_is_open(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
@ -143,6 +171,12 @@ class BeaverTerminalClient:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
|
||||
message = await self._ws.receive()
|
||||
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||
if message.type == aiohttp.WSMsgType.ERROR:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||
)
|
||||
if message.type != aiohttp.WSMsgType.TEXT:
|
||||
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||
data = message.json()
|
||||
@ -173,7 +207,11 @@ async def run_console() -> None:
|
||||
continue
|
||||
if text in {"quit", "exit"}:
|
||||
return
|
||||
reply = await client.send_text(text)
|
||||
try:
|
||||
reply = await client.send_text(text)
|
||||
except BeaverTerminalError as exc:
|
||||
logger.error("Beaver turn failed: %s", exc)
|
||||
continue
|
||||
print(reply)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
Reference in New Issue
Block a user