"""Tests for ``latest_user_text`` and for ``BeaverLLM`` against an in-process aiohttp WebSocket server."""
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import aiohttp
from aiohttp import web
from livekit.agents import ChatContext

try:
    from custom.beaver_llm import BeaverLLM, latest_user_text
except ModuleNotFoundError:
    from beaver_llm import BeaverLLM, latest_user_text

def test_latest_user_text_uses_most_recent_user_message() -> None:
    """The newest user turn wins, and list content is joined with newlines."""
    history = ChatContext.empty()
    # The earlier user turn and the assistant turn must both be passed over.
    history.add_message(role="user", content="first")
    history.add_message(role="assistant", content="ignored")
    history.add_message(role="user", content=["second", "line"])

    expected = "second\nline"
    assert latest_user_text(history) == expected

_PEER_ID = "livekit-room"
_DEVICE_NAME = "livekit-custom-agent"
_SESSION_ID = "terminal-dev:local:livekit-room"
_CHANNEL_WS_PATH = "/api/channels/terminal-dev/ws"

# The first frame BeaverLLM is expected to send once the socket is open.
_EXPECTED_CONNECT_FRAME: dict[str, object] = {
    "type": "connect",
    "peer_id": _PEER_ID,
    "device_name": _DEVICE_NAME,
    "capabilities": ["text"],
}


@asynccontextmanager
async def _fake_beaver_server(
    port: int,
    *,
    run_id: str,
    reply_text: str,
) -> AsyncIterator[tuple[str, list[dict[str, object]]]]:
    """Serve a minimal fake Beaver channel WebSocket on ``127.0.0.1:port``.

    The server answers a ``connect`` frame with ``connected``. It answers every
    ``message`` frame with an ``ack`` followed by one assistant ``message``
    carrying *reply_text* and *run_id*. Every frame the client sends is
    decoded and appended, in order, to the yielded ``received`` list.

    Args:
        port: Local TCP port to bind.
        run_id: ``run_id`` placed in each assistant reply.
        reply_text: ``text`` placed in each assistant reply.

    Yields:
        ``(url, received)``: the channel URL to give BeaverLLM and the list of
        decoded client frames. The list stays readable after the server exits.
    """
    received: list[dict[str, object]] = []

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for message in ws:
            assert message.type == aiohttp.WSMsgType.TEXT
            frame = json.loads(message.data)
            received.append(frame)

            if frame["type"] == "connect":
                await ws.send_json(
                    {
                        "type": "connected",
                        "channel_id": "terminal-dev",
                        "session_id": _SESSION_ID,
                    }
                )
            elif frame["type"] == "message":
                await ws.send_json(
                    {
                        "type": "ack",
                        "message_id": frame["message_id"],
                        "session_id": _SESSION_ID,
                        "accepted": True,
                    }
                )
                await ws.send_json(
                    {
                        "type": "message",
                        "role": "assistant",
                        "message_id": frame["message_id"],
                        "run_id": run_id,
                        "text": reply_text,
                        "finish_reason": "stop",
                    }
                )

        return ws

    app = web.Application()
    app.router.add_get(_CHANNEL_WS_PATH, websocket_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, "127.0.0.1", port)
        await site.start()
        yield f"http://127.0.0.1:{port}{_CHANNEL_WS_PATH}", received
    finally:
        # Cleanup runs even if the test body (or the client's aclose()) raised,
        # so the listening socket is always released.
        await runner.cleanup()


def _make_beaver_llm(url: str) -> BeaverLLM:
    """Build a BeaverLLM client with the peer/device identity shared by these tests."""
    return BeaverLLM(url=url, peer_id=_PEER_ID, device_name=_DEVICE_NAME)


async def test_beaver_llm_can_connect_before_first_message(
    unused_tcp_port: int,
) -> None:
    """connect() may be called before chat(); the next frame sent is the user message."""
    async with _fake_beaver_server(
        unused_tcp_port, run_id="run-1", reply_text="beaver reply"
    ) as (url, received):
        beaver_llm = _make_beaver_llm(url)
        ctx = ChatContext.empty()
        ctx.add_message(role="user", content="hello beaver")

        try:
            await beaver_llm.connect()
            assert beaver_llm.session_id == _SESSION_ID
            assert received == [_EXPECTED_CONNECT_FRAME]

            collected = await beaver_llm.chat(chat_ctx=ctx).collect()
        finally:
            await beaver_llm.aclose()

    assert collected.text == "beaver reply"
    assert len(received) >= 2, received
    assert received[1]["type"] == "message"
    assert received[1]["text"] == "hello beaver"


async def test_beaver_llm_connect_can_send_warmup_message(
    unused_tcp_port: int,
) -> None:
    """connect(warmup_text=...) sends that text as a message and returns the reply text."""
    async with _fake_beaver_server(
        unused_tcp_port, run_id="run-warmup", reply_text="ready"
    ) as (url, received):
        beaver_llm = _make_beaver_llm(url)
        try:
            warmup_reply = await beaver_llm.connect(warmup_text="初始化连接")
        finally:
            await beaver_llm.aclose()

    assert warmup_reply == "ready"
    assert len(received) >= 2, received
    assert received[0] == _EXPECTED_CONNECT_FRAME
    assert received[1]["type"] == "message"
    assert received[1]["text"] == "初始化连接"


async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
    unused_tcp_port: int,
) -> None:
    """chat() connects lazily and sends only the latest user text as the message."""
    async with _fake_beaver_server(
        unused_tcp_port, run_id="run-1", reply_text="beaver reply"
    ) as (url, received):
        beaver_llm = _make_beaver_llm(url)
        ctx = ChatContext.empty()
        ctx.add_message(role="system", content="ignored instructions")
        ctx.add_message(role="user", content="hello beaver")

        try:
            collected = await beaver_llm.chat(chat_ctx=ctx).collect()
        finally:
            await beaver_llm.aclose()

    assert collected.text == "beaver reply"
    assert len(received) >= 2, received
    assert received[0] == _EXPECTED_CONNECT_FRAME

    message_frame = received[1]
    assert message_frame["type"] == "message"
    message_id = message_frame["message_id"]
    # Narrow from `object` so the str checks below are type-correct.
    assert isinstance(message_id, str), message_id
    # Expected shape is "<peer_id>-...-<counter>", with this being message #1.
    assert message_id.startswith(f"{_PEER_ID}-")
    assert message_id.endswith("-000001")
    assert message_frame["text"] == "hello beaver"