beaver test

This commit is contained in:
0Xiao0
2026-06-03 17:26:46 +08:00
parent 409c7c9de0
commit f368e156f0
3 changed files with 907 additions and 0 deletions

228
beaver_terminal_client.py Normal file
View File

@ -0,0 +1,228 @@
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from uuid import uuid4
import aiohttp
from dotenv import load_dotenv
logger = logging.getLogger("beaver-terminal-client")
DEFAULT_BEAVER_WS_URL = "ws://127.0.0.1:8080/api/channels/terminal-dev/ws"
DEFAULT_TERMINAL_PEER_ID = "device-001"
DEFAULT_TERMINAL_DEVICE_NAME = "desk-terminal"
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
class BeaverTerminalError(RuntimeError):
pass
class BeaverTerminalConnectionClosed(BeaverTerminalError):
pass
@dataclass
class MessageIdGenerator:
peer_id: str
initial_counter: int = 0
instance_id: str | None = None
def __post_init__(self) -> None:
self.counter = self.initial_counter
def next_id(self) -> str:
self.counter += 1
if self.instance_id:
return f"{self.peer_id}-{self.instance_id}-{self.counter:06d}"
return f"{self.peer_id}-{self.counter:06d}"
def build_connect_frame(*, peer_id: str, device_name: str) -> dict[str, Any]:
return {
"type": "connect",
"peer_id": peer_id,
"device_name": device_name,
"capabilities": ["text"],
}
def build_message_frame(*, message_id: str, text: str) -> dict[str, Any]:
return {
"type": "message",
"message_id": message_id,
"text": text,
}
class BeaverTerminalClient:
def __init__(
self,
*,
url: str,
peer_id: str,
device_name: str,
http_session: aiohttp.ClientSession | None = None,
message_ids: MessageIdGenerator | None = None,
) -> None:
self._url = url
self._peer_id = peer_id
self._device_name = device_name
self._owned_session = http_session is None
self._http_session = http_session
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._message_ids = message_ids or MessageIdGenerator(
peer_id=peer_id,
instance_id=uuid4().hex[:8],
)
self.session_id: str | None = None
async def connect(self) -> None:
await self._close_websocket()
session = self._ensure_http_session()
self._ws = await session.ws_connect(self._url)
await self._send_json(
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
)
frame = await self._receive_json()
if frame.get("type") != "connected":
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
session_id = frame.get("session_id")
self.session_id = session_id if isinstance(session_id, str) else None
async def send_text(self, text: str) -> str:
for attempt in range(2):
if not self._websocket_is_open():
await self.connect()
message_id = self._message_ids.next_id()
message_frame = build_message_frame(message_id=message_id, text=text)
try:
await self._send_json(message_frame)
return await self._wait_for_reply(message_id)
except (aiohttp.ClientConnectionError, BeaverTerminalConnectionClosed) as exc:
if attempt == 1:
raise BeaverTerminalConnectionClosed(
"Beaver websocket closed before assistant reply"
) from exc
logger.info("Beaver websocket closed mid-turn; reconnecting with same peer_id")
await self.connect()
raise BeaverTerminalError("unreachable Beaver send state")
async def _wait_for_reply(self, message_id: str) -> str:
while True:
frame = await self._receive_json()
frame_type = frame.get("type")
if frame_type == "ack" and frame.get("message_id") == message_id:
reply = frame.get("reply")
if isinstance(reply, str):
return reply
continue
if (
frame_type == "message"
and frame.get("role") == "assistant"
and frame.get("message_id") == message_id
):
text = frame.get("text")
if frame.get("finish_reason") == "error":
raise BeaverTerminalError(text if isinstance(text, str) else "assistant turn failed")
return text if isinstance(text, str) else ""
if frame_type == "error":
error = frame.get("error")
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def ping(self) -> bool:
await self._send_json({"type": "ping"})
while True:
frame = await self._receive_json()
if frame.get("type") == "pong":
return True
if frame.get("type") == "error":
error = frame.get("error")
raise BeaverTerminalError(error if isinstance(error, str) else "unknown error")
async def close(self) -> None:
await self._close_websocket()
if self._owned_session and self._http_session is not None:
await self._http_session.close()
self._http_session = None
async def _close_websocket(self) -> None:
if self._ws is not None:
await self._ws.close()
self._ws = None
def _websocket_is_open(self) -> bool:
return self._ws is not None and not self._ws.closed
def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
return self._http_session
async def _send_json(self, frame: dict[str, Any]) -> None:
if self._ws is None:
raise BeaverTerminalError("Beaver websocket is not connected")
await self._ws.send_json(frame)
async def _receive_json(self) -> dict[str, Any]:
if self._ws is None:
raise BeaverTerminalError("Beaver websocket is not connected")
message = await self._ws.receive()
if message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
raise BeaverTerminalConnectionClosed("Beaver websocket closed")
if message.type == aiohttp.WSMsgType.ERROR:
raise BeaverTerminalConnectionClosed(
f"Beaver websocket error: {self._ws.exception()!r}"
)
if message.type != aiohttp.WSMsgType.TEXT:
raise BeaverTerminalError(f"expected Beaver text frame, received {message.type!r}")
data = message.json()
if not isinstance(data, dict):
raise BeaverTerminalError(f"expected Beaver JSON object, received {data!r}")
return data
def client_from_env() -> BeaverTerminalClient:
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
return BeaverTerminalClient(
url=os.getenv("BEAVER_WS_URL", DEFAULT_BEAVER_WS_URL),
peer_id=os.getenv("TERMINAL_PEER_ID", DEFAULT_TERMINAL_PEER_ID),
device_name=os.getenv("TERMINAL_DEVICE_NAME", DEFAULT_TERMINAL_DEVICE_NAME),
)
async def run_console() -> None:
logging.basicConfig(level=logging.INFO)
client = client_from_env()
try:
await client.connect()
logger.info("Connected to Beaver session_id=%s", client.session_id)
while True:
text = await asyncio.to_thread(input, "> ")
text = text.strip()
if not text:
continue
if text in {"quit", "exit"}:
return
try:
reply = await client.send_text(text)
except BeaverTerminalError as exc:
logger.error("Beaver turn failed: %s", exc)
continue
print(reply)
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(run_console())

253
test_beaver_llm.py Normal file
View File

@ -0,0 +1,253 @@
import json
import aiohttp
from aiohttp import web
try:
from custom.beaver_llm import BeaverLLM, latest_user_text
except ModuleNotFoundError:
from beaver_llm import BeaverLLM, latest_user_text
from livekit.agents import ChatContext
def test_latest_user_text_uses_most_recent_user_message() -> None:
ctx = ChatContext.empty()
ctx.add_message(role="user", content="first")
ctx.add_message(role="assistant", content="ignored")
ctx.add_message(role="user", content=["second", "line"])
assert latest_user_text(ctx) == "second\nline"
async def test_beaver_llm_can_connect_before_first_message(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:livekit-room",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:livekit-room",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "beaver reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
beaver_llm = BeaverLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="livekit-room",
device_name="livekit-custom-agent",
)
ctx = ChatContext.empty()
ctx.add_message(role="user", content="hello beaver")
try:
await beaver_llm.connect()
assert beaver_llm.session_id == "terminal-dev:local:livekit-room"
assert received == [
{
"type": "connect",
"peer_id": "livekit-room",
"device_name": "livekit-custom-agent",
"capabilities": ["text"],
}
]
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
finally:
await beaver_llm.aclose()
await runner.cleanup()
assert collected.text == "beaver reply"
assert received[1]["type"] == "message"
assert received[1]["text"] == "hello beaver"
async def test_beaver_llm_connect_can_send_warmup_message(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:livekit-room",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:livekit-room",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-warmup",
"text": "ready",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
beaver_llm = BeaverLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="livekit-room",
device_name="livekit-custom-agent",
)
try:
warmup_reply = await beaver_llm.connect(warmup_text="初始化连接")
finally:
await beaver_llm.aclose()
await runner.cleanup()
assert warmup_reply == "ready"
assert received[0] == {
"type": "connect",
"peer_id": "livekit-room",
"device_name": "livekit-custom-agent",
"capabilities": ["text"],
}
assert received[1]["type"] == "message"
assert received[1]["text"] == "初始化连接"
async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:livekit-room",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:livekit-room",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "beaver reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
beaver_llm = BeaverLLM(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="livekit-room",
device_name="livekit-custom-agent",
)
ctx = ChatContext.empty()
ctx.add_message(role="system", content="ignored instructions")
ctx.add_message(role="user", content="hello beaver")
try:
collected = await beaver_llm.chat(chat_ctx=ctx).collect()
finally:
await beaver_llm.aclose()
await runner.cleanup()
assert collected.text == "beaver reply"
assert received[0] == {
"type": "connect",
"peer_id": "livekit-room",
"device_name": "livekit-custom-agent",
"capabilities": ["text"],
}
assert received[1]["type"] == "message"
assert received[1]["message_id"].startswith("livekit-room-")
assert received[1]["message_id"].endswith("-000001")
assert received[1]["text"] == "hello beaver"

View File

@ -0,0 +1,426 @@
import asyncio
import json
import sys
from pathlib import Path
import aiohttp
import pytest
from aiohttp import web
if __name__ == "__main__":
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
raise SystemExit(pytest.main([__file__]))
try:
from custom.beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
except ModuleNotFoundError:
from beaver_terminal_client import (
BeaverTerminalClient,
BeaverTerminalError,
MessageIdGenerator,
build_connect_frame,
build_message_frame,
)
def test_build_connect_frame_uses_stable_peer_id() -> None:
frame = build_connect_frame(peer_id="device-001", device_name="desk-terminal")
assert frame == {
"type": "connect",
"peer_id": "device-001",
"device_name": "desk-terminal",
"capabilities": ["text"],
}
def test_build_message_frame_uses_message_id_and_text() -> None:
frame = build_message_frame(message_id="device-001-000001", text="hello")
assert frame == {
"type": "message",
"message_id": "device-001-000001",
"text": "hello",
}
def test_message_id_generator_uses_monotonic_peer_counter() -> None:
generator = MessageIdGenerator(peer_id="device-001", initial_counter=7)
assert generator.next_id() == "device-001-000008"
assert generator.next_id() == "device-001-000009"
assert generator.counter == 9
def test_message_id_generator_can_include_instance_id() -> None:
generator = MessageIdGenerator(peer_id="device-001", instance_id="abc123ef")
assert generator.next_id() == "device-001-abc123ef-000001"
assert generator.next_id() == "device-001-abc123ef-000002"
async def test_client_connects_sends_text_and_returns_assistant_reply(
unused_tcp_port: int,
) -> None:
received: list[dict[str, object]] = []
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
assert message.type == aiohttp.WSMsgType.TEXT
frame = json.loads(message.data)
received.append(frame)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "assistant reply",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert client.session_id == "terminal-dev:local:device-001"
assert reply == "assistant reply"
assert received == [
{
"type": "connect",
"peer_id": "device-001",
"device_name": "desk-terminal",
"capabilities": ["text"],
},
{
"type": "message",
"message_id": "device-001-000001",
"text": "hello",
},
]
async def test_client_returns_cached_duplicate_reply(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": False,
"duplicate": True,
"pending": False,
"reply": "cached assistant reply",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert reply == "cached assistant reply"
async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json({"type": "error", "error": "text is required"})
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
with pytest.raises(BeaverTerminalError, match="text is required"):
await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
unused_tcp_port: int,
) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-1",
"text": "failed turn",
"finish_reason": "error",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
with pytest.raises(BeaverTerminalError, match="failed turn"):
await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
async def test_client_ping_sends_ping_and_waits_for_pong(unused_tcp_port: int) -> None:
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "ping":
await ws.send_json({"type": "pong"})
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
assert await client.ping()
finally:
await client.close()
await runner.cleanup()
async def test_client_reconnects_with_same_peer_id_when_socket_closes_before_send(
unused_tcp_port: int,
) -> None:
connect_peer_ids: list[str] = []
message_ids: list[str] = []
connection_count = 0
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
nonlocal connection_count
connection_count += 1
current_connection = connection_count
ws = web.WebSocketResponse()
await ws.prepare(request)
async for message in ws:
frame = json.loads(message.data)
if frame["type"] == "connect":
connect_peer_ids.append(frame["peer_id"])
await ws.send_json(
{
"type": "connected",
"channel_id": "terminal-dev",
"session_id": "terminal-dev:local:device-001",
}
)
elif frame["type"] == "message":
message_ids.append(frame["message_id"])
if current_connection == 1:
await ws.close()
continue
await ws.send_json(
{
"type": "ack",
"message_id": frame["message_id"],
"session_id": "terminal-dev:local:device-001",
"accepted": True,
}
)
await ws.send_json(
{
"type": "message",
"role": "assistant",
"message_id": frame["message_id"],
"run_id": "run-2",
"text": "reply after reconnect",
"finish_reason": "stop",
}
)
return ws
app = web.Application()
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
await site.start()
client = BeaverTerminalClient(
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
peer_id="device-001",
device_name="desk-terminal",
message_ids=MessageIdGenerator(peer_id="device-001"),
)
try:
await client.connect()
await asyncio.sleep(0.01)
reply = await client.send_text("hello")
finally:
await client.close()
await runner.cleanup()
assert reply == "reply after reconnect"
assert connect_peer_ids == ["device-001", "device-001"]
assert message_ids == ["device-001-000001", "device-001-000002"]