Files
beaver_project/app-instance/backend/tests/unit/test_gateway_channels.py
steven_li 7020f2d67f feat(agent-service): 添加直接模式下的消息处理支持
当代理服务处于非运行状态时,现在会使用 process_direct 方法来处理入站消息,
而不是依赖 submit_direct 方法。这使得服务在两种模式下都能正确处理消息。

添加了新的DirectModeInboundService和RunningInboundService测试类来验证
不同模式下的行为,并增加了相应的集成测试用例。
2026-06-16 11:05:08 +08:00

381 lines
12 KiB
Python

import asyncio
from dataclasses import dataclass, field
from typing import Any
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
from beaver.interfaces.channels import ChannelManager, MemoryChannelAdapter
from beaver.interfaces.gateway.main import run_gateway
from beaver.interfaces.channels.runtime import ChannelRuntime
from beaver.services.agent_service import AgentService
@dataclass(slots=True)
class FakeResult:
session_id: str
run_id: str = "run-1"
output_text: str = ""
finish_reason: str = "stop"
provider_name: str | None = "fake"
model: str | None = "fake-model"
usage: dict[str, Any] = field(default_factory=dict)
task_id: str | None = "task-1"
task_status: str | None = "awaiting_acceptance"
validation_result: dict[str, Any] | None = None
class FakeService:
    """Service stub that answers every message with an ``echo:`` reply."""

    is_running = True

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Return a ``FakeResult`` whose output text is ``echo:<message>``.

        The session id comes from the ``session_id`` keyword argument and
        falls back to ``"s1"`` when it is missing or falsy.
        """
        session = kwargs.get("session_id") or "s1"
        return FakeResult(session_id=session, output_text=f"echo:{message}")

    async def handle_inbound_message(self, inbound: InboundMessage):
        """Run ``submit_direct`` on the inbound content and wrap the result as an outbound message."""
        reply = await self.submit_direct(inbound.content, session_id=inbound.session_id)
        return AgentService.build_outbound_message(inbound, reply)
class SlowService:
    """Service stub whose ``submit_direct`` sleeps for ten seconds before replying."""

    is_running = True

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Sleep for ten seconds, then return a default ``FakeResult``.

        The session id comes from the ``session_id`` keyword argument and
        falls back to ``"s1"`` when it is missing or falsy.
        """
        await asyncio.sleep(10)
        session = kwargs.get("session_id") or "s1"
        return FakeResult(session_id=session)

    async def handle_inbound_message(self, inbound: InboundMessage):
        """Run ``submit_direct`` on the inbound content and wrap the result as an outbound message."""
        reply = await self.submit_direct(inbound.content, session_id=inbound.session_id)
        return AgentService.build_outbound_message(inbound, reply)
class InvalidService:
    """Service stub that exposes ``is_running`` but defines no ``handle_inbound_message``."""

    is_running = True
class DirectModeInboundService(AgentService):
    """``AgentService`` variant that always reports itself as not running.

    ``submit_direct`` raises unconditionally, while ``process_direct``
    returns a reply prefixed with ``direct:``, so the prefix shows which of
    the two methods produced a given result.
    """

    @property
    def is_running(self) -> bool:
        """Always ``False``."""
        return False

    async def submit_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Always raise ``RuntimeError``."""
        raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")

    async def process_direct(self, message: str, **kwargs: Any) -> FakeResult:
        """Return a ``FakeResult`` whose output text is ``direct:<message>``.

        The session id comes from the ``session_id`` keyword argument and
        falls back to ``"s1"`` when it is missing or falsy.
        """
        session = kwargs.get("session_id") or "s1"
        return FakeResult(session_id=session, output_text=f"direct:{message}")
class RunningInboundService(AgentService):
    """``AgentService`` variant whose ``is_running`` property is always ``True``."""

    @property
    def is_running(self) -> bool:
        """Always ``True``."""
        return True
def test_gateway_routes_memory_channel_roundtrip(tmp_path) -> None:
    """Publish text on a memory channel and check the reply delivered back to it.

    The runtime is built around ``FakeService``, so the expected reply is
    ``echo:hello``. The test also checks the session id, finish reason and
    task-related metadata on the delivered message.

    Args:
        tmp_path: directory handed to ``ChannelRuntime`` as its workspace.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()
        try:
            await channel.publish_text("hello", peer_id="s1", message_id="m1")
            # Poll up to 40 times, 0.05 s apart, for the reply to arrive.
            for _ in range(40):
                if channel.sent_messages:
                    break
                await asyncio.sleep(0.05)
            assert channel.sent_messages
            message = channel.sent_messages[0]
            assert message.content == "echo:hello"
            assert message.session_id == "memory-dev:memory:s1"
            assert message.finish_reason == "stop"
            assert message.metadata["task_id"] == "task-1"
            assert message.metadata["task_status"] == "awaiting_acceptance"
            assert message.metadata["evidence_status"] == "recorded"
            assert message.metadata["validation_result"] is None
        finally:
            # Stop the runtime even when an assertion above fails, so a
            # failing test does not leave the started runtime un-stopped.
            await runtime.stop()

    asyncio.run(run())
def test_channel_manager_dispatches_by_channel_id() -> None:
    """An outbound message addressed to ``webhook-dev`` reaches the channel registered under that id."""

    class CaptureChannel:
        """Minimal channel that records every message passed to ``send``."""

        channel_id = "webhook-dev"
        kind = "webhook"
        mode = "webhook"

        def __init__(self) -> None:
            self.sent = []

        async def start(self) -> None:
            pass

        async def stop(self) -> None:
            pass

        async def send(self, message: Any) -> None:
            self.sent.append(message)

    async def run() -> None:
        bus = MessageBus()
        capture = CaptureChannel()
        manager = ChannelManager(bus)
        manager.register(capture)

        outbound = OutboundMessage(
            channel="webhook-dev",
            content="ok",
            session_id="webhook-dev:local:demo",
            finish_reason="stop",
        )
        await bus.publish_outbound(outbound)

        # The event is set before dispatching; presumably this lets
        # dispatch_outbound handle the queued message and then return.
        already_stopped = asyncio.Event()
        already_stopped.set()
        await manager.dispatch_outbound(already_stopped)

        assert capture.sent[0].content == "ok"

    asyncio.run(run())
def test_gateway_delivers_cancelled_outbound_to_channel(tmp_path) -> None:
    """Stopping the runtime during a slow run delivers a ``cancelled`` message to the channel.

    Args:
        tmp_path: directory handed to ``ChannelRuntime`` as its workspace.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=SlowService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(runtime)
        runtime.manager.register(channel)
        await runtime.start()
        await channel.publish_text("slow", peer_id="s1", message_id="m1")

        def run_has_started() -> bool:
            # True once a "direct_run_started" event is among the 20 most recent events.
            for event in runtime.events.recent(limit=20):
                if event["kind"] == "direct_run_started":
                    return True
            return False

        # Poll up to 40 times, 0.05 s apart, before stopping the runtime.
        attempts = 0
        while attempts < 40 and not run_has_started():
            await asyncio.sleep(0.05)
            attempts += 1

        await runtime.stop()

        assert channel.sent_messages
        first = channel.sent_messages[0]
        assert first.finish_reason == "cancelled"

    asyncio.run(run())
def test_gateway_rejects_channel_manager_and_channels_together() -> None:
    """``run_gateway`` raises ``ValueError`` when given both ``channel_manager`` and ``channels``."""

    async def run() -> None:
        bus = MessageBus()

        class CaptureChannel:
            """Do-nothing channel used only to populate the ``channels`` argument."""

            channel_id = "memory-dev"
            kind = "memory"
            mode = "webhook"

            async def start(self) -> None:
                pass

            async def stop(self) -> None:
                pass

            async def send(self, message: Any) -> None:
                pass

        rejected = None
        try:
            await run_gateway(
                service=FakeService(),
                manage_service_lifecycle=False,
                bus=bus,
                channel_manager=ChannelManager(bus),
                channels=[CaptureChannel()],
                stop_event=asyncio.Event(),
            )
        except ValueError as exc:
            rejected = exc

        # Returning normally means the conflicting arguments were accepted.
        if rejected is None:
            raise AssertionError("expected ValueError")
        assert "either channel_manager or channels" in str(rejected)

    asyncio.run(run())
def test_gateway_fails_fast_for_service_without_handle_inbound_message() -> None:
    """``run_gateway`` raises ``TypeError`` naming ``handle_inbound_message`` for a service that lacks it."""

    async def run() -> None:
        rejected = None
        try:
            await run_gateway(
                service=InvalidService(),
                manage_service_lifecycle=False,
                bus=MessageBus(),
                stop_event=asyncio.Event(),
            )
        except TypeError as exc:
            rejected = exc

        # Returning normally means the invalid service was accepted.
        if rejected is None:
            raise AssertionError("expected TypeError")
        assert "handle_inbound_message" in str(rejected)

    asyncio.run(run())
def test_agent_service_maps_inbound_error_to_structured_outbound() -> None:
    """A ``RuntimeError`` from ``submit_direct`` becomes an outbound message with ``finish_reason == "error"``.

    The outbound metadata must carry the error text and the inbound
    message's own metadata.
    """

    async def run() -> None:
        service = RunningInboundService()

        async def explode(message: str, **kwargs: Any) -> FakeResult:
            raise RuntimeError("boom")

        # Replace the bound method on this instance only.
        service.submit_direct = explode  # type: ignore[method-assign]

        inbound = InboundMessage(channel="memory", content="hello", session_id="s1", metadata={"source": "test"})
        outbound = await service.handle_inbound_message(inbound)

        assert outbound.finish_reason == "error"
        assert outbound.session_id == "s1"
        assert outbound.metadata["error"] == "boom"
        assert outbound.metadata["inbound_metadata"] == {"source": "test"}

    asyncio.run(run())
def test_agent_service_maps_stopped_runtime_to_stopped_outbound() -> None:
    """A "not accepting new tasks after stop()" error yields ``finish_reason == "stopped"``."""

    async def run() -> None:
        service = RunningInboundService()

        async def refuse_after_stop(message: str, **kwargs: Any) -> FakeResult:
            raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

        # Replace the bound method on this instance only.
        service.submit_direct = refuse_after_stop  # type: ignore[method-assign]

        inbound = InboundMessage(channel="memory", content="hello", session_id="s1")
        outbound = await service.handle_inbound_message(inbound)

        assert outbound.finish_reason == "stopped"
        assert "not accepting new tasks" in outbound.metadata["error"]

    asyncio.run(run())
def test_agent_service_handles_inbound_in_direct_mode() -> None:
    """With ``is_running`` false, the inbound message gets the ``direct:`` reply from ``process_direct``."""

    async def run() -> None:
        service = DirectModeInboundService()
        inbound = InboundMessage(channel="memory", content="hello", session_id="s1")

        outbound = await service.handle_inbound_message(inbound)

        assert outbound.finish_reason == "stop"
        assert outbound.content == "direct:hello"

    asyncio.run(run())
def test_channel_manager_keeps_unknown_channel_outbound_undeliverable() -> None:
    """An outbound message for an unregistered channel ends up in ``manager.undeliverable``."""

    async def run() -> None:
        bus = MessageBus()
        manager = ChannelManager(bus)
        stop_event = asyncio.Event()

        # No channel named "missing" is ever registered on the manager.
        inbound = InboundMessage(channel="missing", content="hello", session_id="missing:1")
        result = FakeResult(session_id="missing:1", output_text="ok")
        outbound = AgentService.build_outbound_message(inbound, result)
        await bus.publish_outbound(outbound)

        stop_event.set()
        await manager.dispatch_outbound(stop_event)

        assert len(manager.undeliverable) == 1
        kept = manager.undeliverable[0]
        assert kept.channel == "missing"
        assert kept.session_id == "missing:1"

    asyncio.run(run())
def test_memory_channel_adapts_payload_to_channel_identity_session_id(tmp_path) -> None:
    """External text published on a telegram-flavoured memory channel is queued with channel-scoped identity.

    The queued inbound message must be the same object the adapter
    returned, with session id ``telegram-main:bot-main:chat-1`` and
    metadata carrying the chat id, message id and raw payload.

    Args:
        tmp_path: directory handed to ``ChannelRuntime`` as its workspace.
    """

    async def run() -> None:
        bus = MessageBus()
        runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
        channel = MemoryChannelAdapter(
            runtime,
            channel_id="telegram-main",
            kind="telegram",
            account_id="bot-main",
        )

        published = await channel.publish_external_text(
            "hello",
            chat_id="chat-1",
            message_id="message-1",
            raw_payload={"platform": "telegram", "text": "hello"},
        )
        dequeued = await bus.consume_inbound()

        # The bus hands back the very object the adapter returned.
        assert dequeued is published
        assert dequeued.channel == "telegram-main"
        assert dequeued.session_id == "telegram-main:bot-main:chat-1"

        identity = dequeued.channel_identity
        assert identity is not None
        assert identity.kind == "telegram"

        metadata = dequeued.metadata
        assert metadata["chat_id"] == "chat-1"
        assert metadata["message_id"] == "message-1"
        assert metadata["raw_channel_payload"] == {"platform": "telegram", "text": "hello"}

    asyncio.run(run())
def test_channel_manager_start_cancellation_rolls_back_started_channels() -> None:
    """Cancelling ``ChannelManager.start()`` mid-way stops the channel that had already started.

    Two channels are registered: one whose ``start`` returns immediately and
    one whose ``start`` blocks. The start task is cancelled once the
    blocking channel has entered ``start``; the first channel must then
    report that ``stop`` was called on it.
    """

    class StartedChannel:
        """Channel that starts instantly and remembers whether ``stop`` ran."""

        channel_id = "started"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            self.stopped = False

        async def start(self) -> None:
            pass

        async def stop(self) -> None:
            self.stopped = True

        async def send(self, message: Any) -> None:
            pass

    class BlockingChannel:
        """Channel whose ``start`` signals entry and then sleeps for ten seconds."""

        channel_id = "blocking"
        kind = "memory"
        mode = "webhook"

        def __init__(self, bus: MessageBus) -> None:
            self.bus = bus
            self.entered = asyncio.Event()

        async def start(self) -> None:
            self.entered.set()
            await asyncio.sleep(10)

        async def stop(self) -> None:
            pass

        async def send(self, message: Any) -> None:
            pass

    async def run() -> None:
        bus = MessageBus()
        started = StartedChannel(bus)
        blocking = BlockingChannel(bus)
        manager = ChannelManager(bus)
        manager.register(started)
        manager.register(blocking)

        start_task = asyncio.create_task(manager.start())
        # Cancel only after the blocking channel is inside start().
        await blocking.entered.wait()
        start_task.cancel()

        was_cancelled = False
        try:
            await start_task
        except asyncio.CancelledError:
            was_cancelled = True
        if not was_cancelled:
            raise AssertionError("expected cancellation")

        assert started.stopped

    asyncio.run(run())