"""Text-only terminal WebSocket channel adapter."""

from __future__ import annotations

from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any

from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
from beaver.interfaces.channels.base import ChannelInboundSink

try:
    from fastapi import WebSocket
    from starlette.websockets import WebSocketDisconnect
except ModuleNotFoundError:  # pragma: no cover - import-only fallback

    class WebSocketDisconnect(Exception):
        """Fallback disconnect exception for skeleton import environments."""

    class WebSocket:  # type: ignore[override]
        """Fallback websocket annotation shim."""


def _clean(value: Any) -> str:
    """Return ``value`` as a stripped string, mapping missing/empty values to ``""``.

    Numbers are always stringified so that a legitimate ``0`` (for example a
    client that numbers its messages from zero) is not mistaken for a missing
    field. Every other value keeps the falsy-means-empty behaviour
    (``None``, ``False``, ``""`` and empty containers all become ``""``).
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(value)
    return str(value or "").strip()


@dataclass(slots=True)
class TerminalConnection:
    """State for one websocket that has completed a ``connect`` frame."""

    websocket: WebSocket
    # Peer id used for session routing (possibly derived from device_name).
    peer_id: str
    # Session id computed from the ChannelIdentity at connect time.
    session_id: str
    thread_id: str | None = None
    user_id: str | None = None
    device_name: str = ""
    capabilities: list[str] = field(default_factory=list)


class TerminalWebSocketAdapter:
    """Accept text terminal websocket frames and deliver final assistant replies.

    Client frames are JSON objects dispatched on their ``type`` field:

    * ``ping`` -> replies ``{"type": "pong"}``.
    * ``connect`` -> registers the socket under a session id and replies with a
      ``{"type": "connected", ...}`` frame; ``peer_id`` is required.
    * ``message`` -> forwards ``text`` to the inbound sink and replies with an
      ``{"type": "ack", ...}`` frame; requires a prior ``connect``.

    Assistant replies are pushed through :meth:`send` to whichever socket is
    currently registered for the message's session. At most one socket is kept
    per session: a newer ``connect`` for the same session closes the older one.
    """

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str = "",
        inbound_sink: ChannelInboundSink,
        event_recorder: Callable[..., None] | None = None,
        heartbeat_seconds: float = 30,
        max_message_chars: int = 20000,
        session_peer_from_device_name: bool = False,
    ) -> None:
        """Store channel configuration; no I/O happens until :meth:`handle_websocket`.

        Args:
            channel_id: Channel identifier stamped on identities and events.
            kind: Channel kind stamped on every ``ChannelIdentity``.
            mode: Stored as-is; not read elsewhere in this class.
            account_id: Account id stamped on every ``ChannelIdentity``.
            display_name: Human-readable name; defaults to ``channel_id``.
            inbound_sink: Receiver for inbound messages (``accept_inbound``).
            event_recorder: Optional callback invoked by :meth:`_record`.
            heartbeat_seconds: Clamped to at least 1 second. Not read elsewhere
                in this class (presumably consumed by the hosting transport;
                TODO confirm).
            max_message_chars: Maximum accepted message length (at least 1).
            session_peer_from_device_name: When true, a non-empty
                ``device_name`` replaces ``peer_id`` for session routing.
        """
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name or channel_id
        self.inbound_sink = inbound_sink
        self.event_recorder = event_recorder
        self.heartbeat_seconds = max(1.0, float(heartbeat_seconds))
        self.max_message_chars = max(1, int(max_message_chars))
        self.session_peer_from_device_name = bool(session_peer_from_device_name)
        self.started = False
        # session_id -> the single live connection registered for it.
        self._connections_by_session: dict[str, TerminalConnection] = {}
        # peer_id -> session_id of that peer's most recent connection.
        self._session_by_peer: dict[str, str] = {}

    async def start(self) -> None:
        """Mark the adapter as started."""
        self.started = True

    async def stop(self) -> None:
        """Mark the adapter stopped, close every registered socket and forget them."""
        self.started = False
        for connection in list(self._connections_by_session.values()):
            # Best-effort: a socket that is already gone must not abort shutdown.
            with suppress(Exception):
                await connection.websocket.close(code=1001)
        self._connections_by_session.clear()
        self._session_by_peer.clear()

    def status_extra(self) -> dict[str, Any]:
        """Return adapter-specific status fields (number of registered sessions)."""
        return {"connected_peers": len(self._connections_by_session)}

    async def handle_websocket(self, websocket: WebSocket) -> None:
        """Serve one client socket until it disconnects.

        Malformed frames are answered with an ``{"type": "error", ...}`` frame
        instead of closing the socket. On exit the socket's registration (if
        any) is removed and a ``terminal_disconnected`` event is recorded.
        """
        await websocket.accept()
        connection: TerminalConnection | None = None
        try:
            while True:
                try:
                    payload = await websocket.receive_json()
                except WebSocketDisconnect:
                    break
                except ValueError:
                    await websocket.send_json({"type": "error", "error": "Invalid websocket JSON payload"})
                    continue
                except (KeyError, TypeError):
                    # receive_json() reads the ASGI message's "text" field; a binary
                    # frame either lacks it (KeyError) or carries None (TypeError).
                    # Reject the frame rather than letting it tear down the socket.
                    await websocket.send_json({"type": "error", "error": "Websocket frames must be JSON text"})
                    continue
                if not isinstance(payload, dict):
                    await websocket.send_json({"type": "error", "error": "Websocket payload must be a JSON object"})
                    continue

                frame_type = _clean(payload.get("type")).lower()
                if frame_type == "ping":
                    await websocket.send_json({"type": "pong"})
                    continue
                if frame_type == "connect":
                    connection = await self._handle_connect(websocket, payload, current=connection)
                    continue
                if frame_type == "message":
                    if connection is None:
                        await websocket.send_json({"type": "error", "error": "connect is required before message"})
                        continue
                    await self._handle_message(websocket, connection, payload)
                    continue
                await websocket.send_json(
                    {
                        "type": "error",
                        "error": f"Unsupported websocket frame type: {frame_type or ''}",
                    }
                )
        finally:
            if connection is not None:
                self._remove_connection(connection)
                self._record(
                    kind="terminal_disconnected",
                    session_id=connection.session_id,
                    metadata={"peer_id": connection.peer_id, "device_name": connection.device_name},
                )

    async def _handle_connect(
        self,
        websocket: WebSocket,
        payload: dict[str, Any],
        *,
        current: TerminalConnection | None,
    ) -> TerminalConnection | None:
        """Register ``websocket`` for the session described by a ``connect`` frame.

        Args:
            websocket: The socket that sent the frame.
            payload: The decoded ``connect`` frame.
            current: The connection previously registered by this socket, if any.

        Returns:
            The new connection, or ``current`` unchanged when the frame is
            rejected (an error frame is sent to the client in that case).
        """
        raw_peer_id = _clean(payload.get("peer_id"))
        if not raw_peer_id:
            await websocket.send_json({"type": "error", "error": "peer_id is required"})
            return current

        raw_capabilities = payload.get("capabilities") or []
        if not isinstance(raw_capabilities, (list, tuple)):
            # A scalar would either crash the comprehension below (e.g. an int) or
            # be split into single characters (a string); reject it explicitly.
            await websocket.send_json({"type": "error", "error": "capabilities must be a list"})
            return current
        capabilities = [str(item) for item in raw_capabilities if item is not None]

        thread_id = _clean(payload.get("thread_id")) or None
        user_id = _clean(payload.get("user_id")) or None
        device_name = _clean(payload.get("device_name"))
        peer_id = self._session_peer_id(raw_peer_id, device_name)

        identity = ChannelIdentity(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            thread_id=thread_id,
            peer_type="terminal",
            user_id=user_id,
        )
        session_id = identity.session_id()
        connection = TerminalConnection(
            websocket=websocket,
            peer_id=peer_id,
            session_id=session_id,
            thread_id=thread_id,
            user_id=user_id,
            device_name=device_name,
            capabilities=capabilities,
        )

        # Re-connecting this socket into a different session releases its old one.
        if current is not None and current.session_id != session_id:
            self._remove_connection(current)
        # Another socket holding the same session is evicted (one socket per session).
        old = self._connections_by_session.get(session_id)
        if old is not None and old.websocket is not websocket:
            with suppress(Exception):
                await old.websocket.close(code=1000)

        self._connections_by_session[session_id] = connection
        self._session_by_peer[peer_id] = session_id
        self._record(
            kind="terminal_connected",
            session_id=session_id,
            metadata={
                "peer_id": peer_id,
                "raw_peer_id": raw_peer_id,
                "device_name": device_name,
                "capabilities": capabilities,
            },
        )
        await websocket.send_json(
            {
                "type": "connected",
                "channel_id": self.channel_id,
                "session_id": session_id,
            }
        )
        return connection

    async def _handle_message(
        self,
        websocket: WebSocket,
        connection: TerminalConnection,
        payload: dict[str, Any],
    ) -> None:
        """Validate a ``message`` frame, hand it to the inbound sink and send an ack.

        ``message_id`` and non-empty ``text`` are required, and ``text`` may not
        exceed ``max_message_chars``; violations are answered with an error frame.
        Per-message ``thread_id``/``user_id`` override the connection's values.
        """
        message_id = _clean(payload.get("message_id"))
        text = _clean(payload.get("text"))
        if not message_id:
            await websocket.send_json({"type": "error", "error": "message_id is required"})
            return
        if not text:
            await websocket.send_json({"type": "error", "error": "text is required"})
            return
        if len(text) > self.max_message_chars:
            await websocket.send_json(
                {
                    "type": "error",
                    "error": f"text exceeds max_message_chars ({self.max_message_chars})",
                }
            )
            return

        thread_id = _clean(payload.get("thread_id")) or connection.thread_id
        user_id = _clean(payload.get("user_id")) or connection.user_id
        identity = ChannelIdentity(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=connection.peer_id,
            thread_id=thread_id,
            peer_type="terminal",
            user_id=user_id,
            message_id=message_id,
        )
        inbound = InboundMessage(
            channel=self.channel_id,
            content=text,
            content_type="text",
            user_id=user_id,
            channel_identity=identity,
            metadata={
                "terminal": {
                    "peer_id": connection.peer_id,
                    "device_name": connection.device_name,
                    # Copy so downstream mutation cannot alter the live connection.
                    "capabilities": list(connection.capabilities),
                }
            },
        )
        accept = await self.inbound_sink.accept_inbound(inbound)

        ack: dict[str, Any] = {
            "type": "ack",
            "message_id": message_id,
            "session_id": accept.session_id or identity.session_id(),
            "accepted": accept.accepted,
        }
        if accept.duplicate:
            ack["duplicate"] = True
            ack["pending"] = accept.pending
        # A previously stored record may already carry the reply or an error.
        record = accept.record or {}
        if record.get("reply"):
            ack["reply"] = record["reply"]
        if accept.error or record.get("error"):
            ack["error"] = accept.error or record.get("error")
        await websocket.send_json(ack)

    async def send(self, message: OutboundMessage) -> None:
        """Push an assistant reply to the socket registered for its session.

        The session comes from ``message.session_id`` or, failing that, from
        ``message.channel_identity``. When no socket is registered, or the write
        fails, ``message.metadata["delivery_status"]`` is set to ``"unclaimed"``;
        a socket whose write failed is also unregistered.
        """
        session_id = message.session_id
        if not session_id and message.channel_identity is not None:
            session_id = message.channel_identity.session_id()
        connection = self._connections_by_session.get(session_id or "")
        if connection is None:
            message.metadata["delivery_status"] = "unclaimed"
            return

        payload = {
            "type": "message",
            "role": "assistant",
            "message_id": message.channel_identity.message_id if message.channel_identity else message.message_id,
            "run_id": message.run_id,
            "text": message.content,
            "finish_reason": message.finish_reason,
        }
        try:
            await connection.websocket.send_json(payload)
        except Exception:
            # Delivery is best-effort: report it as unclaimed and drop the dead socket.
            message.metadata["delivery_status"] = "unclaimed"
            self._remove_connection(connection)

    def _remove_connection(self, connection: TerminalConnection) -> None:
        """Unregister ``connection`` if it still owns its session.

        A connection that has been replaced by a newer ``connect`` for the same
        session no longer owns anything, so removing it must leave both the
        session entry and the peer mapping of its replacement untouched.
        """
        if self._connections_by_session.get(connection.session_id) is not connection:
            return
        self._connections_by_session.pop(connection.session_id, None)
        if self._session_by_peer.get(connection.peer_id) == connection.session_id:
            self._session_by_peer.pop(connection.peer_id, None)

    def _record(
        self,
        *,
        kind: str,
        session_id: str | None = None,
        message_id: str | None = None,
        status: str = "ok",
        error: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Forward a channel event to ``event_recorder`` when one is configured."""
        if self.event_recorder is None:
            return
        self.event_recorder(
            channel_id=self.channel_id,
            kind=kind,
            session_id=session_id,
            message_id=message_id,
            status=status,
            error=error,
            metadata=metadata,
        )

    def _session_peer_id(self, peer_id: str, device_name: str) -> str:
        """Return the peer id used for session routing.

        With ``session_peer_from_device_name`` enabled and a non-empty device
        name, the id is ``"device-<sanitized device name>"``; otherwise it is
        ``peer_id`` unchanged.
        """
        if self.session_peer_from_device_name and device_name:
            return f"device-{_clean_session_part(device_name)}"
        return peer_id


def _clean_session_part(value: str) -> str:
    """Collapse whitespace runs to ``-`` and replace ``:`` with ``_``; ``"unknown"`` if empty."""
    cleaned = "-".join(str(value or "").strip().split())
    return cleaned.replace(":", "_") or "unknown"