fix: reconnect beaver terminal websocket before retrying turn
This commit is contained in:
@ -21,6 +21,10 @@ class BeaverTerminalError(RuntimeError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageIdGenerator:
|
class MessageIdGenerator:
|
||||||
peer_id: str
|
peer_id: str
|
||||||
@ -71,6 +75,7 @@ class BeaverTerminalClient:
|
|||||||
self.session_id: str | None = None
|
self.session_id: str | None = None
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
|
await self._close_websocket()
|
||||||
session = self._ensure_http_session()
|
session = self._ensure_http_session()
|
||||||
self._ws = await session.ws_connect(self._url)
|
self._ws = await session.ws_connect(self._url)
|
||||||
await self._send_json(
|
await self._send_json(
|
||||||
@ -84,12 +89,29 @@ class BeaverTerminalClient:
|
|||||||
|
|
||||||
async def send_text(self, text: str) -> str:
|
async def send_text(self, text: str) -> str:
|
||||||
message_id = self._message_ids.next_id()
|
message_id = self._message_ids.next_id()
|
||||||
await self._send_json(build_message_frame(message_id=message_id, text=text))
|
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||||
|
|
||||||
|
for attempt in range(2):
|
||||||
|
if not self._websocket_is_open():
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_json(message_frame)
|
||||||
|
return await self._wait_for_reply(message_id)
|
||||||
|
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||||
|
if attempt == 1:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
"Beaver websocket closed before assistant reply"
|
||||||
|
) from exc
|
||||||
|
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
raise BeaverTerminalError("unreachable Beaver send state")
|
||||||
|
|
||||||
|
async def _wait_for_reply(self, message_id: str) -> str:
|
||||||
while True:
|
while True:
|
||||||
frame = await self._receive_json()
|
frame = await self._receive_json()
|
||||||
frame_type = frame.get("type")
|
frame_type = frame.get("type")
|
||||||
|
|
||||||
if frame_type == "ack" and frame.get("message_id") == message_id:
|
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||||
reply = frame.get("reply")
|
reply = frame.get("reply")
|
||||||
if isinstance(reply, str):
|
if isinstance(reply, str):
|
||||||
@ -121,13 +143,19 @@ class BeaverTerminalClient:
|
|||||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self._ws is not None:
|
await self._close_websocket()
|
||||||
await self._ws.close()
|
|
||||||
self._ws = None
|
|
||||||
if self._owned_session and self._http_session is not None:
|
if self._owned_session and self._http_session is not None:
|
||||||
await self._http_session.close()
|
await self._http_session.close()
|
||||||
self._http_session = None
|
self._http_session = None
|
||||||
|
|
||||||
|
async def _close_websocket(self) -> None:
|
||||||
|
if self._ws is not None:
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
|
||||||
|
def _websocket_is_open(self) -> bool:
|
||||||
|
return self._ws is not None and not self._ws.closed
|
||||||
|
|
||||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||||
if self._http_session is None:
|
if self._http_session is None:
|
||||||
self._http_session = aiohttp.ClientSession()
|
self._http_session = aiohttp.ClientSession()
|
||||||
@ -143,6 +171,12 @@ class BeaverTerminalClient:
|
|||||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||||
|
|
||||||
message = await self._ws.receive()
|
message = await self._ws.receive()
|
||||||
|
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
|
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||||
|
if message.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
raise BeaverTerminalConnectionClosed(
|
||||||
|
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||||
|
)
|
||||||
if message.type != aiohttp.WSMsgType.TEXT:
|
if message.type != aiohttp.WSMsgType.TEXT:
|
||||||
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||||
data = message.json()
|
data = message.json()
|
||||||
@ -173,7 +207,11 @@ async def run_console() -> None:
|
|||||||
continue
|
continue
|
||||||
if text in {"quit", "exit"}:
|
if text in {"quit", "exit"}:
|
||||||
return
|
return
|
||||||
|
try:
|
||||||
reply = await client.send_text(text)
|
reply = await client.send_text(text)
|
||||||
|
except BeaverTerminalError as exc:
|
||||||
|
logger.error("Beaver turn failed: %s", exc)
|
||||||
|
continue
|
||||||
print(reply)
|
print(reply)
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
|
|
||||||
|
try:
|
||||||
from custom.beaver_terminal_client import (
|
from custom.beaver_terminal_client import (
|
||||||
BeaverTerminalClient,
|
BeaverTerminalClient,
|
||||||
BeaverTerminalError,
|
BeaverTerminalError,
|
||||||
@ -11,6 +19,14 @@ from custom.beaver_terminal_client import (
|
|||||||
build_connect_frame,
|
build_connect_frame,
|
||||||
build_message_frame,
|
build_message_frame,
|
||||||
)
|
)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from beaver_terminal_client import (
|
||||||
|
BeaverTerminalClient,
|
||||||
|
BeaverTerminalError,
|
||||||
|
MessageIdGenerator,
|
||||||
|
build_connect_frame,
|
||||||
|
build_message_frame,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
||||||
@ -318,3 +334,76 @@ async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -
|
|||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
await runner.cleanup()
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
|
||||||
|
unused_tcp_port: int,
|
||||||
|
) -> None:
|
||||||
|
connect_peer_ids: list[str] = []
|
||||||
|
connection_count = 0
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
nonlocal connection_count
|
||||||
|
connection_count += 1
|
||||||
|
current_connection = connection_count
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for message in ws:
|
||||||
|
frame = json.loads(message.data)
|
||||||
|
if frame["type"] == "connect":
|
||||||
|
connect_peer_ids.append(frame["peer_id"])
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": "terminal-dev",
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if current_connection == 1:
|
||||||
|
await ws.close()
|
||||||
|
elif frame["type"] == "message":
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"session_id": "terminal-dev:local:device-001",
|
||||||
|
"accepted": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await ws.send_json(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": frame["message_id"],
|
||||||
|
"run_id": "run-2",
|
||||||
|
"text": "reply after reconnect",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
client = BeaverTerminalClient(
|
||||||
|
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||||
|
peer_id="device-001",
|
||||||
|
device_name="desk-terminal",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.connect()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
reply = await client.send_text("hello")
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
assert reply == "reply after reconnect"
|
||||||
|
assert connect_peer_ids == ["device-001", "device-001"]
|
||||||
|
|||||||
Reference in New Issue
Block a user