# File: beaver_project/app-instance/backend/beaver/interfaces/channels/runtime.py
# (File-viewer header: 527 lines, 21 KiB, Python.)

"""Channel runtime host for adapter lifecycle and bus-first routing."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from beaver.foundation.config.schema import ChannelConfig
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
from beaver.interfaces.channels.base import ChannelAdapter
from beaver.interfaces.channels.manager import ChannelManager
from beaver.interfaces.channels.state import ChannelDedupeStore, ChannelEventLog
from beaver.services.agent_service import AgentService
def _iso_now() -> str:
    """Return the current time in UTC, formatted as an ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _channel_capabilities(kind: str, mode: str) -> list[str]:
    """Return the capability labels advertised for a channel kind/mode pair.

    Only the "terminal" kind depends on ``mode`` (it must be "websocket");
    every other known kind maps to a fixed list. Unknown combinations yield
    an empty list. A new list is built on every call, so callers may mutate
    the result freely.
    """
    if kind == "terminal":
        if mode == "websocket":
            return ["receive_text", "send_text", "persistent_connection"]
        return []

    group_chat = ("receive_text", "send_text", "receive_media", "groups")
    by_kind = {
        "webhook": ("receive_text", "send_text", "sync_webhook_response"),
        "feishu": group_chat,
        "qqbot": group_chat,
        "telegram": group_chat,
        "weixin": ("receive_text", "send_text", "receive_media", "direct_messages"),
    }
    return list(by_kind.get(kind, ()))
@dataclass(slots=True)
class ChannelAcceptResult:
accepted: bool
duplicate: bool = False
pending: bool = False
rejected: bool = False
session_id: str | None = None
dedupe_key: str | None = None
record: dict[str, Any] | None = None
error: str | None = None
class ChannelRuntime:
"""Own channel adapters, state, and the inbound/outbound bus bridge."""
def __init__(
self,
*,
service: AgentService,
workspace: Path,
channels: dict[str, ChannelConfig],
bus: MessageBus | None = None,
) -> None:
self.service = service
self.workspace = Path(workspace)
self.bus = bus or MessageBus()
self.manager = ChannelManager(self.bus)
self.channel_configs = dict(channels)
self.adapters: dict[str, ChannelAdapter] = {}
self.states: dict[str, dict[str, Any]] = {}
state_dir = self.workspace / "state" / "channels"
retention = self._default_dedupe_retention_hours()
self.dedupe = ChannelDedupeStore(state_dir / "dedupe.json", retention_hours=retention)
self.events = ChannelEventLog(state_dir / "events.jsonl")
self._bridge_task: asyncio.Task[None] | None = None
self._dispatch_task: asyncio.Task[None] | None = None
self._stop_event = asyncio.Event()
self._dispatch_stop_event = asyncio.Event()
self._lifecycle_lock = asyncio.Lock()
async def start(self) -> None:
self._stop_event.clear()
self._dispatch_stop_event.clear()
for channel_id, cfg in self.channel_configs.items():
if not cfg.enabled:
self.states[channel_id] = {"state": "disabled", "last_error": None}
continue
try:
adapter = self._build_adapter(channel_id, cfg)
self.adapters[channel_id] = adapter
self.manager.register(adapter)
await adapter.start()
self.states[channel_id] = {
"state": "running",
"last_error": None,
"started_at": _iso_now(),
}
self.events.record(channel_id=channel_id, kind="adapter_started")
except Exception as exc: # pragma: no cover - defensive startup isolation
self.states[channel_id] = {"state": "error", "last_error": str(exc)}
self.events.record(
channel_id=channel_id,
kind="adapter_error",
status="error",
error=str(exc),
)
self._bridge_task = asyncio.create_task(self._bridge_inbound_to_agent())
self._dispatch_task = asyncio.create_task(
self.manager.dispatch_outbound(
self._dispatch_stop_event,
on_delivered=self._record_outbound_delivered,
on_failed=self._record_outbound_failed,
)
)
async def stop(self) -> None:
self._stop_event.set()
if self._bridge_task is not None:
self._bridge_task.cancel()
try:
await self._bridge_task
except asyncio.CancelledError:
pass
self._dispatch_stop_event.set()
if self._dispatch_task is not None:
try:
await asyncio.wait_for(self._dispatch_task, timeout=1.0)
except asyncio.TimeoutError:
self._dispatch_task.cancel()
try:
await self._dispatch_task
except asyncio.CancelledError:
pass
await self.manager.stop()
for channel_id in self.adapters:
self.events.record(channel_id=channel_id, kind="adapter_stopped")
async def add_channel(self, channel_id: str, config: ChannelConfig) -> None:
async with self._lifecycle_lock:
current = self.channel_configs.get(channel_id)
if current == config and channel_id in self.adapters:
return
if not config.enabled:
await self._remove_channel_locked(channel_id)
self.channel_configs[channel_id] = config
self.states[channel_id] = {"state": "disabled", "last_error": None}
return
adapter = self._build_adapter(channel_id, config)
await adapter.start()
old_adapter = self.adapters.get(channel_id)
self.manager.replace_registered(adapter)
self.adapters[channel_id] = adapter
self.channel_configs[channel_id] = config
self.states[channel_id] = {"state": "running", "last_error": None, "started_at": _iso_now()}
self.events.record(channel_id=channel_id, kind="adapter_started")
if old_adapter is not None and old_adapter is not adapter:
await old_adapter.stop()
async def remove_channel(self, channel_id: str) -> None:
async with self._lifecycle_lock:
await self._remove_channel_locked(channel_id)
async def _remove_channel_locked(self, channel_id: str) -> None:
adapter = self.adapters.pop(channel_id, None)
self.manager.unregister(channel_id)
self.channel_configs.pop(channel_id, None)
if adapter is not None:
await adapter.stop()
self.events.record(channel_id=channel_id, kind="adapter_stopped")
self.states[channel_id] = {"state": "removed", "last_error": None}
async def accept_inbound(self, message: InboundMessage) -> ChannelAcceptResult:
identity = message.channel_identity
if identity is None:
self.events.record(
channel_id=message.channel,
kind="inbound_rejected",
status="error",
error="channel_identity is required",
)
return ChannelAcceptResult(
accepted=False,
rejected=True,
error="channel_identity is required",
)
validation_error = identity.validation_error()
if validation_error:
self.events.record(
channel_id=identity.channel_id,
kind="inbound_rejected",
status="error",
error=validation_error,
)
return ChannelAcceptResult(accepted=False, rejected=True, error=validation_error)
expected_session_id = identity.session_id()
if message.session_id != expected_session_id:
self.events.record(
channel_id=identity.channel_id,
kind="session_id_normalized",
session_id=expected_session_id,
message_id=identity.message_id,
)
message.session_id = expected_session_id
message.channel = identity.channel_id
dedupe_key = identity.dedupe_key()
if dedupe_key:
write = self.dedupe.mark_processing(
dedupe_key=dedupe_key,
session_id=expected_session_id,
message_id=identity.message_id or "",
)
if not write.created:
record = write.record or {}
self.events.record(
channel_id=identity.channel_id,
kind="inbound_duplicate",
session_id=expected_session_id,
message_id=identity.message_id,
status=str(record.get("status") or "processing"),
)
return ChannelAcceptResult(
accepted=False,
duplicate=True,
pending=record.get("status") == "processing",
session_id=expected_session_id,
dedupe_key=dedupe_key,
record=record,
)
self.events.record(
channel_id=identity.channel_id,
kind="inbound_accepted",
session_id=expected_session_id,
message_id=identity.message_id,
text=message.content,
)
await self.bus.publish_inbound(message)
return ChannelAcceptResult(
accepted=True,
session_id=expected_session_id,
dedupe_key=dedupe_key,
)
def statuses(self) -> list[dict[str, Any]]:
items: list[dict[str, Any]] = []
recent = self.events.recent(limit=500)
last_by_channel = {event["channel_id"]: event for event in recent if event.get("channel_id")}
for channel_id, cfg in self.channel_configs.items():
state = self.states.get(channel_id, {"state": "configured", "last_error": None})
capabilities = _channel_capabilities(cfg.kind, cfg.mode)
webhook_url = None
websocket_url = None
connected_peers = 0
if cfg.kind == "webhook":
webhook_url = f"/api/channels/{channel_id}/webhook"
elif cfg.kind == "terminal" and cfg.mode == "websocket":
websocket_url = f"/api/channels/{channel_id}/ws"
adapter = self.adapters.get(channel_id)
if adapter is not None and hasattr(adapter, "status_extra"):
extra = adapter.status_extra() # type: ignore[attr-defined]
connected_peers = int(extra.get("connected_peers") or 0)
items.append(
{
"channel_id": channel_id,
"name": channel_id,
"kind": cfg.kind,
"mode": cfg.mode,
"display_name": cfg.display_name or channel_id,
"enabled": cfg.enabled,
"state": state.get("state", "configured"),
"account_id": cfg.account_id,
"last_error": state.get("last_error"),
"started_at": state.get("started_at"),
"last_event_at": last_by_channel.get(channel_id, {}).get("created_at"),
"capabilities": capabilities,
"webhook_url": webhook_url,
"websocket_url": websocket_url,
"connected_peers": connected_peers,
}
)
return items
def recent_events(self, channel_id: str, *, limit: int = 100) -> list[dict[str, Any]]:
return self.events.recent(channel_id=channel_id, limit=limit)
def record_event(
self,
*,
channel_id: str,
kind: str,
session_id: str | None = None,
message_id: str | None = None,
run_id: str | None = None,
status: str = "ok",
error: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
self.events.record(
channel_id=channel_id,
kind=kind,
session_id=session_id,
message_id=message_id,
run_id=run_id,
status=status,
error=error,
metadata=metadata,
)
def _build_adapter(self, channel_id: str, cfg: ChannelConfig) -> ChannelAdapter:
if cfg.kind == "webhook" and cfg.mode == "webhook":
from beaver.interfaces.channels.generic_webhook import GenericWebhookAdapter
return GenericWebhookAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
response_timeout_seconds=float(cfg.config.get("response_timeout_seconds") or 1800),
)
if cfg.kind == "terminal" and cfg.mode == "websocket":
from beaver.interfaces.channels.terminal_websocket import TerminalWebSocketAdapter
return TerminalWebSocketAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
event_recorder=self.record_event,
heartbeat_seconds=float(cfg.config.get("heartbeat_seconds") or 30),
max_message_chars=int(cfg.config.get("max_message_chars") or 20000),
)
if cfg.kind == "telegram" and cfg.mode in {"polling", "webhook"}:
from beaver.interfaces.channels.platforms.telegram import TelegramAdapter
return TelegramAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
secrets=cfg.secrets,
config=cfg.config,
event_recorder=self.record_event,
)
if cfg.kind == "feishu" and cfg.mode in {"websocket", "webhook"}:
from beaver.interfaces.channels.platforms.feishu import FeishuAdapter
return FeishuAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
secrets=cfg.secrets,
config=cfg.config,
event_recorder=self.record_event,
)
if cfg.kind == "qqbot" and cfg.mode == "websocket":
from beaver.interfaces.channels.platforms.qqbot import QQBotAdapter
return QQBotAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
secrets=cfg.secrets,
config=cfg.config,
event_recorder=self.record_event,
)
if cfg.kind == "weixin" and cfg.mode == "polling":
from beaver.interfaces.channels.platforms.weixin import WeixinAdapter
return WeixinAdapter(
channel_id=channel_id,
kind=cfg.kind,
mode=cfg.mode,
account_id=cfg.account_id,
display_name=cfg.display_name,
inbound_sink=self,
secrets=cfg.secrets,
config=cfg.config,
event_recorder=self.record_event,
)
if cfg.kind == "external_connector" and cfg.mode == "http":
import os
from beaver.interfaces.channels.connections.sidecar_client import ConnectorSidecarClient
from beaver.interfaces.channels.external_connector import ExternalConnectorChannel
base_url = str(cfg.config.get("sidecarBaseUrl") or os.getenv("EXTERNAL_CONNECTOR_BASE_URL") or "").strip()
token = os.getenv("EXTERNAL_CONNECTOR_TOKEN", "")
platform_kind = str(cfg.config.get("platformKind") or "").strip()
connection_id = str(cfg.config.get("connectionId") or "").strip()
if not base_url:
raise ValueError("external connector sidecarBaseUrl is required")
if not platform_kind:
raise ValueError("external connector platformKind is required")
if not connection_id:
raise ValueError("external connector connectionId is required")
return ExternalConnectorChannel(
channel_id=channel_id,
platform_kind=platform_kind,
connection_id=connection_id,
account_id=cfg.account_id,
display_name=cfg.display_name,
sidecar_client=ConnectorSidecarClient(base_url=base_url, token=token),
)
raise ValueError(f"Unsupported channel kind/mode: {cfg.kind}/{cfg.mode}")
async def _bridge_inbound_to_agent(self) -> None:
current_inbound: InboundMessage | None = None
while not self._stop_event.is_set():
try:
current_inbound = await asyncio.wait_for(self.bus.consume_inbound(), timeout=0.25)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
raise
inbound = current_inbound
identity = inbound.channel_identity
try:
self.events.record(
channel_id=inbound.channel,
kind="direct_run_started",
session_id=inbound.session_id,
message_id=identity.message_id if identity else inbound.message_id,
)
outbound = await self.service.handle_inbound_message(inbound)
except asyncio.CancelledError:
outbound = AgentService.build_outbound_error(
inbound,
detail="Channel runtime stopped before completing the inbound message",
finish_reason="cancelled",
)
self._mark_dedupe_result(inbound, outbound)
await self.bus.publish_outbound(outbound)
current_inbound = None
raise
except Exception as exc:
self.events.record(
channel_id=inbound.channel,
kind="direct_run_failed",
session_id=inbound.session_id,
message_id=identity.message_id if identity else inbound.message_id,
status="error",
error=str(exc),
)
outbound = AgentService.build_outbound_error(
inbound,
detail=str(exc),
finish_reason="error",
)
else:
self.events.record(
channel_id=outbound.channel,
kind="direct_run_finished",
session_id=outbound.session_id,
message_id=identity.message_id if identity else inbound.message_id,
run_id=outbound.run_id,
)
self._mark_dedupe_result(inbound, outbound)
await self.bus.publish_outbound(outbound)
current_inbound = None
def _mark_dedupe_result(self, inbound: InboundMessage, outbound: OutboundMessage) -> None:
identity = inbound.channel_identity
dedupe_key = identity.dedupe_key() if identity else None
if not dedupe_key:
return
cfg = self.channel_configs.get(identity.channel_id)
max_reply_chars = int((cfg.config if cfg else {}).get("max_cached_reply_chars") or 20000)
max_error_chars = int((cfg.config if cfg else {}).get("max_cached_error_chars") or 4000)
if outbound.finish_reason == "error":
self.dedupe.mark_error(
dedupe_key=dedupe_key,
error=outbound.content,
max_error_chars=max_error_chars,
)
else:
self.dedupe.mark_done(
dedupe_key=dedupe_key,
run_id=outbound.run_id,
reply=outbound.content,
max_reply_chars=max_reply_chars,
)
async def _record_outbound_delivered(self, message: OutboundMessage) -> None:
kind = "outbound_unclaimed" if message.metadata.get("delivery_status") == "unclaimed" else "outbound_delivered"
self.events.record(
channel_id=message.channel,
kind=kind,
session_id=message.session_id,
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
run_id=message.run_id,
)
async def _record_outbound_failed(self, message: OutboundMessage, exc: Exception | None) -> None:
self.events.record(
channel_id=message.channel,
kind="outbound_delivery_failed",
session_id=message.session_id,
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
run_id=message.run_id,
status="error",
error=str(exc) if exc else "channel not registered",
)
def _default_dedupe_retention_hours(self) -> int:
for cfg in self.channel_configs.values():
value = cfg.config.get("dedupe_retention_hours")
if value is not None:
return int(value)
return 48