fix: reconnect beaver terminal websocket before retrying turn
This commit is contained in:
@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageIdGenerator:
|
||||
peer_id: str
|
||||
@ -71,6 +75,7 @@ class BeaverTerminalClient:
|
||||
self.session_id: str | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._close_websocket()
|
||||
session = self._ensure_http_session()
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
@ -84,12 +89,29 @@ class BeaverTerminalClient:
|
||||
|
||||
async def send_text(self, text: str) -> str:
|
||||
message_id = self._message_ids.next_id()
|
||||
await self._send_json(build_message_frame(message_id=message_id, text=text))
|
||||
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||
|
||||
for attempt in range(2):
|
||||
if not self._websocket_is_open():
|
||||
await self.connect()
|
||||
|
||||
try:
|
||||
await self._send_json(message_frame)
|
||||
return await self._wait_for_reply(message_id)
|
||||
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||
if attempt == 1:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
"Beaver websocket closed before assistant reply"
|
||||
) from exc
|
||||
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||
await self.connect()
|
||||
|
||||
raise BeaverTerminalError("unreachable Beaver send state")
|
||||
|
||||
async def _wait_for_reply(self, message_id: str) -> str:
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||
reply = frame.get("reply")
|
||||
if isinstance(reply, str):
|
||||
@ -121,13 +143,19 @@ class BeaverTerminalClient:
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
await self._close_websocket()
|
||||
if self._owned_session and self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
async def _close_websocket(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
def _websocket_is_open(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
@ -143,6 +171,12 @@ class BeaverTerminalClient:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
|
||||
message = await self._ws.receive()
|
||||
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||
if message.type == aiohttp.WSMsgType.ERROR:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||
)
|
||||
if message.type != aiohttp.WSMsgType.TEXT:
|
||||
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||
data = message.json()
|
||||
@ -173,7 +207,11 @@ async def run_console() -> None:
|
||||
continue
|
||||
if text in {"quit", "exit"}:
|
||||
return
|
||||
try:
|
||||
reply = await client.send_text(text)
|
||||
except BeaverTerminalError as exc:
|
||||
logger.error("Beaver turn failed: %s", exc)
|
||||
continue
|
||||
print(reply)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
@ -1,16 +1,32 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from custom.beaver_terminal_client import (
|
||||
if __name__ == "__main__":
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
raise SystemExit(pytest.main([__file__]))
|
||||
|
||||
try:
|
||||
from custom.beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
build_message_frame,
|
||||
)
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
from beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
build_message_frame,
|
||||
)
|
||||
|
||||
|
||||
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
||||
@ -318,3 +334,76 @@ async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
connect_peer_ids: list[str] = []
|
||||
connection_count = 0
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
nonlocal connection_count
|
||||
connection_count += 1
|
||||
current_connection = connection_count
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
connect_peer_ids.append(frame["peer_id"])
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
if current_connection == 1:
|
||||
await ws.close()
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-2",
|
||||
"text": "reply after reconnect",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await asyncio.sleep(0.01)
|
||||
reply = await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
assert reply == "reply after reconnect"
|
||||
assert connect_peer_ids == ["device-001", "device-001"]
|
||||
|
||||
Reference in New Issue
Block a user