beaver test
This commit is contained in:
228
beaver_terminal_client.py
Normal file
228
beaver_terminal_client.py
Normal file
@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger("beaver-terminal-client")
|
||||
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
|
||||
DEFAULT_TERMINAL_PEER_ID = "device-001"
|
||||
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
|
||||
|
||||
class BeaverTerminalError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageIdGenerator:
|
||||
peer_id: str
|
||||
initial_counter: int = 0
|
||||
instance_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.counter = self.initial_counter
|
||||
|
||||
def next_id(self) -> str:
|
||||
self.counter += 1
|
||||
if self.instance_id:
|
||||
return f"{self.peer_id}-{self.instance_id}-{self.counter:06d}"
|
||||
return f"{self.peer_id}-{self.counter:06d}"
|
||||
|
||||
|
||||
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "connect",
|
||||
"peer_id": peer_id,
|
||||
"device_name": device_name,
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
|
||||
|
||||
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "message",
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
}
|
||||
|
||||
|
||||
class BeaverTerminalClient:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
peer_id: str,
|
||||
device_name: str,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
message_ids: MessageIdGenerator | None = None,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._peer_id = peer_id
|
||||
self._device_name = device_name
|
||||
self._owned_session = http_session is None
|
||||
self._http_session = http_session
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._message_ids = message_ids or MessageIdGenerator(
|
||||
peer_id=peer_id,
|
||||
instance_id=uuid4().hex[:8],
|
||||
)
|
||||
self.session_id: str | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._close_websocket()
|
||||
session = self._ensure_http_session()
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||
)
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") != "connected":
|
||||
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||
session_id = frame.get("session_id")
|
||||
self.session_id = session_id if isinstance(session_id, str) else None
|
||||
|
||||
async def send_text(self, text: str) -> str:
|
||||
for attempt in range(2):
|
||||
if not self._websocket_is_open():
|
||||
await self.connect()
|
||||
|
||||
message_id = self._message_ids.next_id()
|
||||
message_frame = build_message_frame(message_id=message_id, text=text)
|
||||
|
||||
try:
|
||||
await self._send_json(message_frame)
|
||||
return await self._wait_for_reply(message_id)
|
||||
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
|
||||
if attempt == 1:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
"Beaver websocket closed before assistant reply"
|
||||
) from exc
|
||||
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
|
||||
await self.connect()
|
||||
|
||||
raise BeaverTerminalError("unreachable Beaver send state")
|
||||
|
||||
async def _wait_for_reply(self, message_id: str) -> str:
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
frame_type = frame.get("type")
|
||||
if frame_type == "ack" and frame.get("message_id") == message_id:
|
||||
reply = frame.get("reply")
|
||||
if isinstance(reply, str):
|
||||
return reply
|
||||
continue
|
||||
|
||||
if (
|
||||
frame_type == "message"
|
||||
and frame.get("role") == "assistant"
|
||||
and frame.get("message_id") == message_id
|
||||
):
|
||||
text = frame.get("text")
|
||||
if frame.get("finish_reason") == "error":
|
||||
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
|
||||
return text if isinstance(text, str) else ""
|
||||
|
||||
if frame_type == "error":
|
||||
error = frame.get("error")
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def ping(self) -> bool:
|
||||
await self._send_json({"type": "ping"})
|
||||
while True:
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") == "pong":
|
||||
return True
|
||||
if frame.get("type") == "error":
|
||||
error = frame.get("error")
|
||||
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._close_websocket()
|
||||
if self._owned_session and self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
async def _close_websocket(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
def _websocket_is_open(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
def _ensure_http_session(self) -> aiohttp.ClientSession:
|
||||
if self._http_session is None:
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
return self._http_session
|
||||
|
||||
async def _send_json(self, frame: dict[str, Any]) -> None:
|
||||
if self._ws is None:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
await self._ws.send_json(frame)
|
||||
|
||||
async def _receive_json(self) -> dict[str, Any]:
|
||||
if self._ws is None:
|
||||
raise BeaverTerminalError("Beaver websocket is not connected")
|
||||
|
||||
message = await self._ws.receive()
|
||||
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
|
||||
if message.type == aiohttp.WSMsgType.ERROR:
|
||||
raise BeaverTerminalConnectionClosed(
|
||||
f"Beaver websocket error: {self._ws.exception()!r}"
|
||||
)
|
||||
if message.type != aiohttp.WSMsgType.TEXT:
|
||||
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
|
||||
data = message.json()
|
||||
if not isinstance(data, dict):
|
||||
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
|
||||
return data
|
||||
|
||||
|
||||
def client_from_env() -> BeaverTerminalClient:
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
return BeaverTerminalClient(
|
||||
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
|
||||
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
|
||||
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
|
||||
)
|
||||
|
||||
|
||||
async def run_console() -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
client = client_from_env()
|
||||
try:
|
||||
await client.connect()
|
||||
logger.info("Connected to Beaver session_id=%s", client.session_id)
|
||||
while True:
|
||||
text = await asyncio.to_thread(input, "> ")
|
||||
text = text.strip()
|
||||
if not text:
|
||||
continue
|
||||
if text in {"quit", "exit"}:
|
||||
return
|
||||
try:
|
||||
reply = await client.send_text(text)
|
||||
except BeaverTerminalError as exc:
|
||||
logger.error("Beaver turn failed: %s", exc)
|
||||
continue
|
||||
print(reply)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_console())
|
||||
253
test_beaver_llm.py
Normal file
253
test_beaver_llm.py
Normal file
@ -0,0 +1,253 @@
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
try:
|
||||
from custom.beaver_llm import BeaverLLM, latest_user_text
|
||||
except ModuleNotFoundError:
|
||||
from beaver_llm import BeaverLLM, latest_user_text
|
||||
from livekit.agents import ChatContext
|
||||
|
||||
|
||||
def test_latest_user_text_uses_most_recent_user_message() -> None:
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="first")
|
||||
ctx.add_message(role="assistant", content="ignored")
|
||||
ctx.add_message(role="user", content=["second", "line"])
|
||||
|
||||
assert latest_user_text(ctx) == "second\nline"
|
||||
|
||||
|
||||
async def test_beaver_llm_can_connect_before_first_message(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
frame = json.loads(message.data)
|
||||
received.append(frame)
|
||||
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-1",
|
||||
"text": "beaver reply",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
beaver_llm = BeaverLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="livekit-room",
|
||||
device_name="livekit-custom-agent",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="hello beaver")
|
||||
|
||||
try:
|
||||
await beaver_llm.connect()
|
||||
assert beaver_llm.session_id == "terminal-dev:local:livekit-room"
|
||||
assert received == [
|
||||
{
|
||||
"type": "connect",
|
||||
"peer_id": "livekit-room",
|
||||
"device_name": "livekit-custom-agent",
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
]
|
||||
|
||||
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await beaver_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert collected.text == "beaver reply"
|
||||
assert received[1]["type"] == "message"
|
||||
assert received[1]["text"] == "hello beaver"
|
||||
|
||||
|
||||
async def test_beaver_llm_connect_can_send_warmup_message(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
frame = json.loads(message.data)
|
||||
received.append(frame)
|
||||
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-warmup",
|
||||
"text": "ready",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
beaver_llm = BeaverLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="livekit-room",
|
||||
device_name="livekit-custom-agent",
|
||||
)
|
||||
|
||||
try:
|
||||
warmup_reply = await beaver_llm.connect(warmup_text="初始化连接")
|
||||
finally:
|
||||
await beaver_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert warmup_reply == "ready"
|
||||
assert received[0] == {
|
||||
"type": "connect",
|
||||
"peer_id": "livekit-room",
|
||||
"device_name": "livekit-custom-agent",
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
assert received[1]["type"] == "message"
|
||||
assert received[1]["text"] == "初始化连接"
|
||||
|
||||
|
||||
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
frame = json.loads(message.data)
|
||||
received.append(frame)
|
||||
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-1",
|
||||
"text": "beaver reply",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
beaver_llm = BeaverLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="livekit-room",
|
||||
device_name="livekit-custom-agent",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="system", content="ignored instructions")
|
||||
ctx.add_message(role="user", content="hello beaver")
|
||||
|
||||
try:
|
||||
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
|
||||
finally:
|
||||
await beaver_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert collected.text == "beaver reply"
|
||||
assert received[0] == {
|
||||
"type": "connect",
|
||||
"peer_id": "livekit-room",
|
||||
"device_name": "livekit-custom-agent",
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
assert received[1]["type"] == "message"
|
||||
assert received[1]["message_id"].startswith("livekit-room-")
|
||||
assert received[1]["message_id"].endswith("-000001")
|
||||
assert received[1]["text"] == "hello beaver"
|
||||
426
test_beaver_terminal_client.py
Normal file
426
test_beaver_terminal_client.py
Normal file
@ -0,0 +1,426 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
raise SystemExit(pytest.main([__file__]))
|
||||
|
||||
try:
|
||||
from custom.beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
build_message_frame,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
from beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
build_message_frame,
|
||||
)
|
||||
|
||||
|
||||
def test_build_connect_frame_uses_stable_peer_id() -> None:
|
||||
frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
|
||||
|
||||
assert frame == {
|
||||
"type": "connect",
|
||||
"peer_id": "device-001",
|
||||
"device_name": "desk-terminal",
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
|
||||
|
||||
def test_build_message_frame_uses_message_id_and_text() -> None:
|
||||
frame = build_message_frame(message_id="device-001-000001", text="hello")
|
||||
|
||||
assert frame == {
|
||||
"type": "message",
|
||||
"message_id": "device-001-000001",
|
||||
"text": "hello",
|
||||
}
|
||||
|
||||
|
||||
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
|
||||
generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
|
||||
|
||||
assert generator.next_id() == "device-001-000008"
|
||||
assert generator.next_id() == "device-001-000009"
|
||||
assert generator.counter == 9
|
||||
|
||||
|
||||
def test_message_id_generator_can_include_instance_id() -> None:
|
||||
generator = MessageIdGenerator(peer_id="device-001", instance_id="abc123ef")
|
||||
|
||||
assert generator.next_id() == "device-001-abc123ef-000001"
|
||||
assert generator.next_id() == "device-001-abc123ef-000002"
|
||||
|
||||
|
||||
async def test_client_connects_sends_text_and_returns_assistant_reply(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
received: list[dict[str, object]] = []
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
frame = json.loads(message.data)
|
||||
received.append(frame)
|
||||
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-1",
|
||||
"text": "assistant reply",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
reply = await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
assert client.session_id == "terminal-dev:local:device-001"
|
||||
assert reply == "assistant reply"
|
||||
assert received == [
|
||||
{
|
||||
"type": "connect",
|
||||
"peer_id": "device-001",
|
||||
"device_name": "desk-terminal",
|
||||
"capabilities": ["text"],
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"message_id": "device-001-000001",
|
||||
"text": "hello",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
"accepted": False,
|
||||
"duplicate": True,
|
||||
"pending": False,
|
||||
"reply": "cached assistant reply",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
reply = await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
assert reply == "cached assistant reply"
|
||||
|
||||
|
||||
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json({"type": "error", "error": "text is required"})
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
with pytest.raises(BeaverTerminalError, match="text is required"):
|
||||
await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-1",
|
||||
"text": "failed turn",
|
||||
"finish_reason": "error",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
with pytest.raises(BeaverTerminalError, match="failed turn"):
|
||||
await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "ping":
|
||||
await ws.send_json({"type": "pong"})
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
assert await client.ping()
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
connect_peer_ids: list[str] = []
|
||||
message_ids: list[str] = []
|
||||
connection_count = 0
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
nonlocal connection_count
|
||||
connection_count += 1
|
||||
current_connection = connection_count
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
frame = json.loads(message.data)
|
||||
if frame["type"] == "connect":
|
||||
connect_peer_ids.append(frame["peer_id"])
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
message_ids.append(frame["message_id"])
|
||||
if current_connection == 1:
|
||||
await ws.close()
|
||||
continue
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:device-001",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-2",
|
||||
"text": "reply after reconnect",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await asyncio.sleep(0.01)
|
||||
reply = await client.send_text("hello")
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
assert reply == "reply after reconnect"
|
||||
assert connect_peer_ids == ["device-001", "device-001"]
|
||||
assert message_ids == ["device-001-000001", "device-001-000002"]
|
||||
Reference in New Issue
Block a user