"""Tests for the channel gateway layer.

Covers message routing through ``ChannelRuntime`` / ``ChannelManager``,
``run_gateway`` argument validation, ``AgentService`` error mapping, and
rollback of already-started channels when ``ChannelManager.start`` is cancelled.
"""

import asyncio
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
from beaver.interfaces.channels import ChannelManager, MemoryChannelAdapter
from beaver.interfaces.gateway.main import run_gateway
from beaver.interfaces.channels.runtime import ChannelRuntime
from beaver.services.agent_service import AgentService


@dataclass(slots=True)
class FakeResult:
    """Stand-in for the result object returned by ``submit_direct``.

    Field defaults are the values the tests below assert on after the result
    passes through ``AgentService.build_outbound_message``.
    """

    session_id: str
    run_id: str = "run-1"
    output_text: str = ""
    finish_reason: str = "stop"
    provider_name: str | None = "fake"
    model: str | None = "fake-model"
    usage: dict[str, Any] = field(default_factory=dict)
    task_id: str | None = "task-1"
    task_status: str | None = "awaiting_acceptance"
    validation_result: dict[str, Any] | None = None


class _DirectSubmitService:
    """Shared stub behaviour: route an inbound message through ``submit_direct``.

    Subclasses provide ``submit_direct``; the result is wrapped into an
    outbound message with ``AgentService.build_outbound_message``.
    """

    is_running = True

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        raise NotImplementedError

    async def handle_inbound_message(self, inbound: InboundMessage):
        result = await self.submit_direct(inbound.content, session_id=inbound.session_id)
        return AgentService.build_outbound_message(inbound, result)


class FakeService(_DirectSubmitService):
    """Service stub that immediately echoes the message back as ``echo:<text>``."""

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        return FakeResult(
            session_id=kwargs.get("session_id") or "s1",
            output_text=f"echo:{message}",
        )


class SlowService(_DirectSubmitService):
    """Service stub whose run sleeps for 10s, so tests can stop the runtime mid-run."""

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        await asyncio.sleep(10)
        return FakeResult(session_id=kwargs.get("session_id") or "s1")


class InvalidService:
    """Service stub deliberately missing ``handle_inbound_message``."""

    is_running = True


async def _wait_until(
    predicate: Callable[[], bool],
    *,
    attempts: int = 40,
    interval: float = 0.05,
) -> bool:
    """Poll ``predicate`` until it holds or ``attempts`` polls have elapsed.

    Sleeps ``interval`` seconds after each failed check. Returns the final
    truth value of the predicate; never raises on timeout, so callers decide
    whether a timeout is a failure.
    """
    for _ in range(attempts):
        if predicate():
            return True
        await asyncio.sleep(interval)
    return bool(predicate())


def test_gateway_routes_memory_channel_roundtrip(tmp_path) -> None:
    """An inbound memory-channel message is answered through the runtime."""

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()
        # Stop the runtime even if an assertion fails, so no background work is
        # left running when asyncio.run() tears the loop down.
        try:
            await channel.publish_text("hello", peer_id="s1", message_id="m1")
            delivered = await _wait_until(lambda: bool(channel.sent_messages))
            assert delivered, "no outbound message was delivered to the memory channel"
            message = channel.sent_messages[0]
            assert message.content == "echo:hello"
            assert message.session_id == "memory-dev:memory:s1"
            assert message.finish_reason == "stop"
            assert message.metadata["task_id"] == "task-1"
            assert message.metadata["task_status"] == "awaiting_acceptance"
            assert message.metadata["evidence_status"] == "recorded"
            assert message.metadata["validation_result"] is None
        finally:
            await runtime.stop()

    asyncio.run(run())


def test_channel_manager_dispatches_by_channel_id() -> None:
    """Outbound messages are delivered to the channel whose ``channel_id`` matches."""

    class CaptureChannel:
        channel_id = "webhook-dev"
        kind = "webhook"
        mode = "webhook"

        def __init__(self) -> None:
            self.sent: list[Any] = []

        async def start(self) -> None:
            pass

        async def stop(self) -> None:
            pass

        async def send(self, message: Any) -> None:
            self.sent.append(message)

    async def run() -> None:
        bus = MessageBus()
        channel = CaptureChannel()
        manager = ChannelManager(bus)
        manager.register(channel)
        await bus.publish_outbound(
            OutboundMessage(
                channel="webhook-dev",
                content="ok",
                session_id="webhook-dev:local:demo",
                finish_reason="stop",
            )
        )
        # The event is pre-set; the assertion below relies on dispatch_outbound
        # still delivering the already-queued message before it returns.
        stop_event = asyncio.Event()
        stop_event.set()
        await manager.dispatch_outbound(stop_event)
        assert channel.sent[0].content == "ok"

    asyncio.run(run())


def test_gateway_delivers_cancelled_outbound_to_channel(tmp_path) -> None:
    """Stopping the runtime during a slow run delivers a ``cancelled`` outbound message."""

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=SlowService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()
        try:
            await channel.publish_text("slow", peer_id="s1", message_id="m1")
            # Best effort: wait for a "direct_run_started" event before stopping.
            # A timeout here is deliberately not treated as a failure.
            await _wait_until(
                lambda: any(
                    event["kind"] == "direct_run_started"
                    for event in runtime.events.recent(limit=20)
                )
            )
        finally:
            # Stopping is the action under test, so it must happen before the
            # assertions below.
            await runtime.stop()
        assert channel.sent_messages
        assert channel.sent_messages[0].finish_reason == "cancelled"

    asyncio.run(run())


def test_gateway_rejects_channel_manager_and_channels_together() -> None:
    """``run_gateway`` refuses to accept both ``channel_manager`` and ``channels``."""

    async def run() -> None:
        bus = MessageBus()

        class CaptureChannel:
            channel_id = "memory-dev"
            kind = "memory"
            mode = "webhook"

            async def start(self) -> None:
                pass

            async def stop(self) -> None:
                pass

            async def send(self, message: Any) -> None:
                pass

        try:
            await run_gateway(
                service=FakeService(),
                manage_service_lifecycle=False,
                bus=bus,
                channel_manager=ChannelManager(bus),
                channels=[CaptureChannel()],
                stop_event=asyncio.Event(),
            )
        except ValueError as exc:
            assert "either channel_manager or channels" in str(exc)
        else:
            raise AssertionError("expected ValueError")

    asyncio.run(run())


def test_gateway_fails_fast_for_service_without_handle_inbound_message() -> None:
    """``run_gateway`` raises ``TypeError`` for a service lacking ``handle_inbound_message``."""

    async def run() -> None:
        try:
            await run_gateway(
                service=InvalidService(),
                manage_service_lifecycle=False,
                bus=MessageBus(),
                stop_event=asyncio.Event(),
            )
        except TypeError as exc:
            assert "handle_inbound_message" in str(exc)
        else:
            raise AssertionError("expected TypeError")

    asyncio.run(run())


def test_agent_service_maps_inbound_error_to_structured_outbound() -> None:
    """A failing ``submit_direct`` becomes an ``error`` outbound message, not an exception."""

    async def run() -> None:
        service = AgentService()

        async def failing_submit_direct(message: str, **kwargs: Any) -> FakeResult:
            raise RuntimeError("boom")

        service.submit_direct = failing_submit_direct  # type: ignore[method-assign]
        outbound = await service.handle_inbound_message(
            InboundMessage(
                channel="memory",
                content="hello",
                session_id="s1",
                metadata={"source": "test"},
            )
        )
        assert outbound.finish_reason == "error"
        assert outbound.session_id == "s1"
        assert outbound.metadata["error"] == "boom"
        assert outbound.metadata["inbound_metadata"] == {"source": "test"}

    asyncio.run(run())


def test_agent_service_maps_stopped_runtime_to_stopped_outbound() -> None:
    """A "not accepting new tasks" error maps to a ``stopped`` outbound message."""

    async def run() -> None:
        service = AgentService()

        async def stopped_submit_direct(message: str, **kwargs: Any) -> FakeResult:
            raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

        service.submit_direct = stopped_submit_direct  # type: ignore[method-assign]
        outbound = await service.handle_inbound_message(
            InboundMessage(channel="memory", content="hello", session_id="s1")
        )
        assert outbound.finish_reason == "stopped"
        assert "not accepting new tasks" in outbound.metadata["error"]

    asyncio.run(run())


def test_channel_manager_keeps_unknown_channel_outbound_undeliverable() -> None:
    """Outbound messages for unregistered channels are recorded in ``undeliverable``."""

    async def run() -> None:
        bus = MessageBus()
        manager = ChannelManager(bus)
        stop_event = asyncio.Event()
        await bus.publish_outbound(
            AgentService.build_outbound_message(
                InboundMessage(channel="missing", content="hello", session_id="missing:1"),
                FakeResult(session_id="missing:1", output_text="ok"),
            )
        )
        stop_event.set()
        await manager.dispatch_outbound(stop_event)
        assert len(manager.undeliverable) == 1
        assert manager.undeliverable[0].channel == "missing"
        assert manager.undeliverable[0].session_id == "missing:1"

    asyncio.run(run())


def test_memory_channel_adapts_payload_to_channel_identity_session_id(tmp_path) -> None:
    """External payloads are normalised into an inbound message with a channel-scoped session id.

    The expected session id is composed from the adapter's ``channel_id``,
    ``account_id`` and the payload's ``chat_id``.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(
            runtime,
            channel_id="telegram-main",
            kind="telegram",
            account_id="bot-main",
        )
        inbound = await channel.publish_external_text(
            "hello",
            chat_id="chat-1",
            message_id="message-1",
            raw_payload={"platform": "telegram", "text": "hello"},
        )
        queued = await bus.consume_inbound()
        assert queued is inbound
        assert queued.channel == "telegram-main"
        assert queued.session_id == "telegram-main:bot-main:chat-1"
        assert queued.channel_identity is not None
        assert queued.channel_identity.kind == "telegram"
        assert queued.metadata["chat_id"] == "chat-1"
        assert queued.metadata["message_id"] == "message-1"
        assert queued.metadata["raw_channel_payload"] == {"platform": "telegram", "text": "hello"}

    asyncio.run(run())


def test_channel_manager_start_cancellation_rolls_back_started_channels() -> None:
    """Cancelling ``ChannelManager.start`` stops channels that had already started."""

    class StartedChannel:
        channel_id = "started"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            self.stopped = False

        async def start(self) -> None:
            pass

        async def stop(self) -> None:
            self.stopped = True

        async def send(self, message: Any) -> None:
            pass

    class BlockingChannel:
        channel_id = "blocking"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            # Set once start() is entered so the test knows when to cancel.
            self.entered = asyncio.Event()

        async def start(self) -> None:
            self.entered.set()
            await asyncio.sleep(10)

        async def stop(self) -> None:
            pass

        async def send(self, message: Any) -> None:
            pass

    async def run() -> None:
        bus = MessageBus()
        started = StartedChannel(bus)
        blocking = BlockingChannel(bus)
        manager = ChannelManager(bus)
        manager.register(started)
        manager.register(blocking)
        task = asyncio.create_task(manager.start())
        await blocking.entered.wait()
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        else:
            raise AssertionError("expected cancellation")
        assert started.stopped

    asyncio.run(run())