95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
"""Channel manager for routing gateway outbound messages."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Awaitable, Callable
|
|
from contextlib import suppress
|
|
|
|
from beaver.foundation.events import MessageBus, OutboundMessage
|
|
|
|
from .base import ChannelAdapter
|
|
|
|
|
|
class ChannelManager:
    """Start/stop channel adapters and dispatch outbound messages to them.

    Adapters are keyed by their ``channel_id``. Outbound messages consumed
    from the bus are routed to the adapter whose id equals
    ``message.channel``. Messages that cannot be delivered (unknown channel,
    or the adapter's ``send`` raised an ``Exception``) are appended to
    :attr:`undeliverable`.
    """

    def __init__(self, bus: MessageBus) -> None:
        self.bus = bus
        # Registered adapters keyed by ``channel_id``. Dict insertion order is
        # the start order; ``stop()`` walks it in reverse.
        self.channels: dict[str, ChannelAdapter] = {}
        # Every message that could not be delivered. Never trimmed by this class.
        self.undeliverable: list[OutboundMessage] = []
        # Set to True only after ``start()`` brought every adapter up; reset by ``stop()``.
        self.started = False

    def register(self, channel: ChannelAdapter) -> None:
        """Add *channel* under its ``channel_id``.

        The adapter is not started here, even if the manager is already started.

        Raises:
            ValueError: if an adapter with the same ``channel_id`` is registered.
        """
        if channel.channel_id in self.channels:
            raise ValueError(f"Channel already registered: {channel.channel_id}")
        self.channels[channel.channel_id] = channel

    def unregister(self, channel_id: str) -> ChannelAdapter | None:
        """Remove and return the adapter for *channel_id*, or ``None`` if absent.

        The removed adapter is not stopped; that is left to the caller.
        """
        return self.channels.pop(channel_id, None)

    def replace_registered(self, channel: ChannelAdapter) -> ChannelAdapter | None:
        """Register *channel*, replacing any adapter with the same ``channel_id``.

        Returns the previously registered adapter (or ``None``). The old
        adapter is not stopped and the new one is not started here. A replaced
        id keeps its original position in the start/stop order.
        """
        old = self.channels.get(channel.channel_id)
        self.channels[channel.channel_id] = channel
        return old

    async def start(self) -> None:
        """Start every registered adapter, in registration order.

        If any adapter's ``start`` raises (including cancellation), the
        adapters already started are stopped in reverse order, with errors
        from those rollback stops suppressed. The original exception is then
        re-raised and :attr:`started` stays ``False``.
        """
        started: list[ChannelAdapter] = []
        # Snapshot the adapters (as ``stop()`` does): each ``await`` below can
        # let other tasks (un)register channels, and mutating the dict while
        # iterating its live view raises RuntimeError.
        to_start = tuple(self.channels.values())
        try:
            for channel in to_start:
                await channel.start()
                started.append(channel)
        except BaseException:
            # Roll back so a partial start leaves no adapter running.
            for channel in reversed(started):
                with suppress(BaseException):
                    await channel.stop()
            raise
        else:
            self.started = True

    async def stop(self) -> None:
        """Stop every registered adapter, in reverse registration order.

        Every adapter gets a ``stop`` call even if an earlier one fails with an
        ``Exception``. :attr:`started` is reset to ``False`` before any failure
        is reported.

        Raises:
            RuntimeError: if one or more adapters failed to stop; chained to
                the first failure.
        """
        errors: list[Exception] = []
        for channel in reversed(tuple(self.channels.values())):
            try:
                await channel.stop()
            except Exception as exc:  # pragma: no cover - defensive cleanup path
                errors.append(exc)
        self.started = False
        if errors:
            raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]

    async def dispatch_outbound(
        self,
        stop_event: asyncio.Event,
        *,
        on_delivered: Callable[[OutboundMessage], Awaitable[None]] | None = None,
        on_failed: Callable[[OutboundMessage, Exception | None], Awaitable[None]] | None = None,
        poll_interval: float = 0.25,
    ) -> None:
        """Route bus outbound messages until stopped and the queue is drained.

        The loop exits only once *stop_event* is set AND the bus reports no
        pending outbound messages, so messages queued before shutdown are
        still dispatched.

        Args:
            stop_event: Signals that the loop should finish once drained.
            on_delivered: Awaited with the message after a successful ``send``.
            on_failed: Awaited with the message and the ``send`` exception,
                or with ``None`` when no adapter matches ``message.channel``.
            poll_interval: Seconds to wait for a message before re-checking
                *stop_event*; bounds shutdown latency while idle.

        Raises:
            ValueError: if *poll_interval* is not positive.

        Exceptions raised by the callbacks themselves are not caught and end
        the loop.
        """
        if poll_interval <= 0:
            # With a non-positive timeout, ``wait_for`` gives up on the consume
            # coroutine before it can run, so the loop would never drain.
            raise ValueError(f"poll_interval must be positive, got {poll_interval!r}")

        while not (stop_event.is_set() and self.bus.outbound_size == 0):
            try:
                message = await asyncio.wait_for(
                    self.bus.consume_outbound(), timeout=poll_interval
                )
            except asyncio.TimeoutError:
                # Nothing arrived in time; loop back to re-check the stop condition.
                continue

            channel = self.channels.get(message.channel)
            if channel is None:
                self.undeliverable.append(message)
                if on_failed is not None:
                    await on_failed(message, None)
                continue

            try:
                await channel.send(message)
            except Exception as exc:  # pragma: no cover - defensive channel isolation
                # One failing adapter must not stop routing for the others.
                self.undeliverable.append(message)
                if on_failed is not None:
                    await on_failed(message, exc)
            else:
                if on_delivered is not None:
                    await on_delivered(message)