347 lines
11 KiB
Python
347 lines
11 KiB
Python
import asyncio
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
|
from beaver.interfaces.channels import ChannelManager, MemoryChannelAdapter
|
|
from beaver.interfaces.gateway.main import run_gateway
|
|
from beaver.interfaces.channels.runtime import ChannelRuntime
|
|
from beaver.services.agent_service import AgentService
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class FakeResult:
|
|
session_id: str
|
|
run_id: str = "run-1"
|
|
output_text: str = ""
|
|
finish_reason: str = "stop"
|
|
provider_name: str | None = "fake"
|
|
model: str | None = "fake-model"
|
|
usage: dict[str, Any] = field(default_factory=dict)
|
|
task_id: str | None = "task-1"
|
|
task_status: str | None = "awaiting_acceptance"
|
|
validation_result: dict[str, Any] | None = None
|
|
|
|
|
|
class FakeService:
    """Agent-service double that replies to every message with ``echo:<text>``."""

    # NOTE(review): presumably consulted by the runtime/gateway to treat the
    # service as live — confirm against ChannelRuntime / run_gateway.
    is_running = True

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Return a successful result whose output echoes *message*.

        Uses the ``session_id`` keyword when it is truthy, otherwise ``"s1"``.
        """
        session = kwargs.get("session_id") or "s1"
        reply = f"echo:{message}"
        return FakeResult(session_id=session, output_text=reply)

    async def handle_inbound_message(self, inbound: InboundMessage):
        """Feed *inbound* through ``submit_direct`` and wrap the result as outbound."""
        outcome = await self.submit_direct(inbound.content, session_id=inbound.session_id)
        return AgentService.build_outbound_message(inbound, outcome)
|
|
|
|
|
|
class SlowService:
    """Agent-service double whose runs stay pending for ten seconds.

    The long sleep keeps a run in flight so a test can stop the runtime while
    it is still executing.
    """

    is_running = True

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Sleep for 10 s, then return a default result for the given session.

        The session falls back to ``"s1"`` when no truthy ``session_id`` is passed.
        """
        await asyncio.sleep(10)
        session = kwargs.get("session_id") or "s1"
        return FakeResult(session_id=session)

    async def handle_inbound_message(self, inbound: InboundMessage):
        """Feed *inbound* through ``submit_direct`` and wrap the result as outbound."""
        outcome = await self.submit_direct(inbound.content, session_id=inbound.session_id)
        return AgentService.build_outbound_message(inbound, outcome)
|
|
|
|
|
|
class InvalidService:
    """Service double that deliberately has no ``handle_inbound_message``.

    Passed to ``run_gateway`` to check that such services are rejected.
    """

    is_running = True
|
|
|
|
|
|
def test_gateway_routes_memory_channel_roundtrip(tmp_path) -> None:
    """Round-trip ``"hello"`` through a memory channel backed by FakeService.

    Publishes a text message, polls (up to ~2 s) for the reply delivered back
    to the channel, and checks its content, session id, finish reason and the
    task metadata carried on the outbound message.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()

        # Always stop the runtime, even when an assertion fails, so its
        # background work is not left running while asyncio.run() tears down.
        try:
            await channel.publish_text("hello", peer_id="s1", message_id="m1")
            for _ in range(40):
                if channel.sent_messages:
                    break
                await asyncio.sleep(0.05)

            assert channel.sent_messages
            message = channel.sent_messages[0]
            assert message.content == "echo:hello"
            assert message.session_id == "memory-dev:memory:s1"
            assert message.finish_reason == "stop"
            assert message.metadata["task_id"] == "task-1"
            assert message.metadata["task_status"] == "awaiting_acceptance"
            assert message.metadata["evidence_status"] == "recorded"
            assert message.metadata["validation_result"] is None
        finally:
            await runtime.stop()

    asyncio.run(run())
|
|
|
|
|
|
def test_channel_manager_dispatches_by_channel_id() -> None:
    """ChannelManager routes an outbound message to the channel whose id matches."""

    class CaptureChannel:
        """Minimal channel stub that records every message passed to ``send``."""

        channel_id = "webhook-dev"
        kind = "webhook"
        mode = "webhook"

        def __init__(self) -> None:
            self.sent: list[Any] = []

        async def start(self) -> None:
            pass

        async def stop(self) -> None:
            pass

        async def send(self, message: Any) -> None:
            self.sent.append(message)

    async def run() -> None:
        bus = MessageBus()
        channel = CaptureChannel()
        manager = ChannelManager(bus)
        manager.register(channel)

        await bus.publish_outbound(
            OutboundMessage(
                channel="webhook-dev",
                content="ok",
                session_id="webhook-dev:local:demo",
                finish_reason="stop",
            )
        )
        # The event is set up front: this test relies on dispatch_outbound
        # delivering the already-queued message and then returning.
        stop_event = asyncio.Event()
        stop_event.set()

        await manager.dispatch_outbound(stop_event)

        # One message was published, so exactly one delivery is expected;
        # checking the length first gives a clear failure instead of IndexError.
        assert len(channel.sent) == 1
        assert channel.sent[0].content == "ok"

    asyncio.run(run())
|
|
|
|
|
|
def test_gateway_delivers_cancelled_outbound_to_channel(tmp_path) -> None:
    """Stopping the runtime mid-run still delivers a ``cancelled`` reply.

    SlowService keeps the run pending for 10 s. The test waits (up to ~2 s)
    for a ``direct_run_started`` event, stops the runtime, and expects the
    channel to have received an outbound with ``finish_reason == "cancelled"``.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=SlowService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()

        await channel.publish_text("slow", peer_id="s1", message_id="m1")

        def run_has_started() -> bool:
            recent = runtime.events.recent(limit=20)
            return any(event["kind"] == "direct_run_started" for event in recent)

        # Best-effort wait: up to 40 polls 0.05 s apart, then stop regardless.
        # NOTE(review): if the event never shows up the test still proceeds.
        attempts = 0
        while attempts < 40 and not run_has_started():
            await asyncio.sleep(0.05)
            attempts += 1

        await runtime.stop()

        assert channel.sent_messages
        assert channel.sent_messages[0].finish_reason == "cancelled"

    asyncio.run(run())
|
|
|
|
|
|
def test_gateway_rejects_channel_manager_and_channels_together() -> None:
    """run_gateway raises ValueError when given both ``channel_manager`` and ``channels``."""

    async def run() -> None:
        bus = MessageBus()

        class CaptureChannel:
            """Inert channel stub; only its presence in ``channels`` matters."""

            channel_id = "memory-dev"
            kind = "memory"
            mode = "webhook"

            async def start(self) -> None:
                pass

            async def stop(self) -> None:
                pass

            async def send(self, message: Any) -> None:
                pass

        try:
            # The stop event is never set. If the argument check regressed,
            # run_gateway might block indefinitely, so a bounded wait turns that
            # hang into a visible failure (TimeoutError) instead of stalling the suite.
            await asyncio.wait_for(
                run_gateway(
                    service=FakeService(),
                    manage_service_lifecycle=False,
                    bus=bus,
                    channel_manager=ChannelManager(bus),
                    channels=[CaptureChannel()],
                    stop_event=asyncio.Event(),
                ),
                timeout=5,
            )
        except ValueError as exc:
            assert "either channel_manager or channels" in str(exc)
        else:
            raise AssertionError("expected ValueError")

    asyncio.run(run())
|
|
|
|
|
|
def test_gateway_fails_fast_for_service_without_handle_inbound_message() -> None:
    """run_gateway raises TypeError for a service lacking ``handle_inbound_message``."""

    async def run() -> None:
        try:
            # The stop event is never set. If the service check regressed,
            # run_gateway might block indefinitely, so a bounded wait turns that
            # hang into a visible failure (TimeoutError) instead of stalling the suite.
            await asyncio.wait_for(
                run_gateway(
                    service=InvalidService(),
                    manage_service_lifecycle=False,
                    bus=MessageBus(),
                    stop_event=asyncio.Event(),
                ),
                timeout=5,
            )
        except TypeError as exc:
            assert "handle_inbound_message" in str(exc)
        else:
            raise AssertionError("expected TypeError")

    asyncio.run(run())
|
|
|
|
|
|
def test_agent_service_maps_inbound_error_to_structured_outbound() -> None:
    """A RuntimeError from submit_direct becomes an ``error`` outbound, not a raise."""

    async def exploding_submit(message: str, **kwargs: Any) -> FakeResult:
        raise RuntimeError("boom")

    async def run() -> None:
        service = AgentService()
        # Swap in a submit_direct that always fails.
        service.submit_direct = exploding_submit  # type: ignore[method-assign]

        inbound = InboundMessage(channel="memory", content="hello", session_id="s1", metadata={"source": "test"})
        outbound = await service.handle_inbound_message(inbound)

        meta = outbound.metadata
        assert outbound.finish_reason == "error"
        assert outbound.session_id == "s1"
        assert meta["error"] == "boom"
        # The inbound metadata is echoed back under its own key.
        assert meta["inbound_metadata"] == {"source": "test"}

    asyncio.run(run())
|
|
|
|
|
|
def test_agent_service_maps_stopped_runtime_to_stopped_outbound() -> None:
    """A "not accepting new tasks" failure maps to a ``stopped`` outbound."""

    async def refusing_submit(message: str, **kwargs: Any) -> FakeResult:
        raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

    async def run() -> None:
        service = AgentService()
        # Simulate an agent loop that has already been stopped.
        service.submit_direct = refusing_submit  # type: ignore[method-assign]

        inbound = InboundMessage(channel="memory", content="hello", session_id="s1")
        outbound = await service.handle_inbound_message(inbound)

        assert outbound.finish_reason == "stopped"
        assert "not accepting new tasks" in outbound.metadata["error"]

    asyncio.run(run())
|
|
|
|
|
|
def test_channel_manager_keeps_unknown_channel_outbound_undeliverable() -> None:
    """An outbound for an unregistered channel is kept in ``manager.undeliverable``."""

    async def run() -> None:
        bus = MessageBus()
        manager = ChannelManager(bus)  # deliberately has no channels registered

        inbound = InboundMessage(channel="missing", content="hello", session_id="missing:1")
        result = FakeResult(session_id="missing:1", output_text="ok")
        await bus.publish_outbound(AgentService.build_outbound_message(inbound, result))

        # Set before dispatching so dispatch_outbound processes the queued
        # message and returns, as the assertions below rely on.
        done = asyncio.Event()
        done.set()
        await manager.dispatch_outbound(done)

        assert len(manager.undeliverable) == 1
        dropped = manager.undeliverable[0]
        assert dropped.channel == "missing"
        assert dropped.session_id == "missing:1"

    asyncio.run(run())
|
|
|
|
|
|
def test_memory_channel_adapts_payload_to_channel_identity_session_id(tmp_path) -> None:
    """External text gets a ``<channel_id>:<account_id>:<chat_id>`` session id.

    Also checks that the queued inbound carries the channel identity and keeps
    the chat id, message id and raw platform payload in its metadata.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime, channel_id="telegram-main", kind="telegram", account_id="bot-main")

        inbound = await channel.publish_external_text(
            "hello",
            chat_id="chat-1",
            message_id="message-1",
            raw_payload={"platform": "telegram", "text": "hello"},
        )

        # The exact object returned by the adapter is what lands on the bus.
        queued = await bus.consume_inbound()
        assert queued is inbound

        assert queued.channel == "telegram-main"
        assert queued.session_id == "telegram-main:bot-main:chat-1"

        identity = queued.channel_identity
        assert identity is not None
        assert identity.kind == "telegram"

        meta = queued.metadata
        assert meta["chat_id"] == "chat-1"
        assert meta["message_id"] == "message-1"
        assert meta["raw_channel_payload"] == {"platform": "telegram", "text": "hello"}

    asyncio.run(run())
|
|
|
|
|
|
def test_channel_manager_start_cancellation_rolls_back_started_channels() -> None:
    """Cancelling ChannelManager.start() stops channels that had already started."""

    class StartedChannel:
        """Channel whose start() returns immediately and which records stop()."""

        channel_id = "started"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            self.stopped = False

        async def start(self) -> None:
            return None

        async def stop(self) -> None:
            self.stopped = True

        async def send(self, message: Any) -> None:
            return None

    class BlockingChannel:
        """Channel whose start() signals entry, then blocks until cancelled."""

        channel_id = "blocking"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            self.entered = asyncio.Event()

        async def start(self) -> None:
            self.entered.set()
            await asyncio.sleep(10)

        async def stop(self) -> None:
            return None

        async def send(self, message: Any) -> None:
            return None

    async def run() -> None:
        bus = MessageBus()
        started = StartedChannel(bus)
        blocking = BlockingChannel(bus)
        manager = ChannelManager(bus)
        # Registration order matters: "started" must come before the blocker.
        for registered in (started, blocking):
            manager.register(registered)

        start_task = asyncio.create_task(manager.start())
        # Cancel only once start() is stuck inside the blocking channel.
        await blocking.entered.wait()
        start_task.cancel()

        was_cancelled = False
        try:
            await start_task
        except asyncio.CancelledError:
            was_cancelled = True
        if not was_cancelled:
            raise AssertionError("expected cancellation")

        assert started.stopped

    asyncio.run(run())
|