feat: implement channel runtime connectors
This commit is contained in:
@ -0,0 +1,301 @@
|
||||
"""Text-only terminal WebSocket channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
except ModuleNotFoundError: # pragma: no cover - import-only fallback
|
||||
class WebSocketDisconnect(Exception):
|
||||
"""Fallback disconnect exception for skeleton import environments."""
|
||||
|
||||
class WebSocket: # type: ignore[override]
|
||||
"""Fallback websocket annotation shim."""
|
||||
|
||||
|
||||
def _clean(value: Any) -> str:
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TerminalConnection:
|
||||
websocket: WebSocket
|
||||
peer_id: str
|
||||
session_id: str
|
||||
thread_id: str | None = None
|
||||
user_id: str | None = None
|
||||
device_name: str = ""
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class TerminalWebSocketAdapter:
|
||||
"""Accept text terminal websocket frames and deliver final assistant replies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channel_id: str,
|
||||
kind: str,
|
||||
mode: str,
|
||||
account_id: str,
|
||||
display_name: str = "",
|
||||
inbound_sink: ChannelInboundSink,
|
||||
event_recorder: Callable[..., None] | None = None,
|
||||
heartbeat_seconds: float = 30,
|
||||
max_message_chars: int = 20000,
|
||||
) -> None:
|
||||
self.channel_id = channel_id
|
||||
self.kind = kind
|
||||
self.mode = mode
|
||||
self.account_id = account_id
|
||||
self.display_name = display_name or channel_id
|
||||
self.inbound_sink = inbound_sink
|
||||
self.event_recorder = event_recorder
|
||||
self.heartbeat_seconds = max(1.0, float(heartbeat_seconds))
|
||||
self.max_message_chars = max(1, int(max_message_chars))
|
||||
self.started = False
|
||||
self._connections_by_session: dict[str, TerminalConnection] = {}
|
||||
self._session_by_peer: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
self.started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.started = False
|
||||
for connection in list(self._connections_by_session.values()):
|
||||
with suppress(Exception):
|
||||
await connection.websocket.close(code=1001)
|
||||
self._connections_by_session.clear()
|
||||
self._session_by_peer.clear()
|
||||
|
||||
def status_extra(self) -> dict[str, Any]:
|
||||
return {"connected_peers": len(self._connections_by_session)}
|
||||
|
||||
async def handle_websocket(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
connection: TerminalConnection | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
payload = await websocket.receive_json()
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except ValueError:
|
||||
await websocket.send_json({"type": "error", "error": "Invalid websocket JSON payload"})
|
||||
continue
|
||||
if not isinstance(payload, dict):
|
||||
await websocket.send_json({"type": "error", "error": "Websocket payload must be a JSON object"})
|
||||
continue
|
||||
|
||||
frame_type = _clean(payload.get("type")).lower()
|
||||
if frame_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
continue
|
||||
if frame_type == "connect":
|
||||
connection = await self._handle_connect(websocket, payload, current=connection)
|
||||
continue
|
||||
if frame_type == "message":
|
||||
if connection is None:
|
||||
await websocket.send_json({"type": "error", "error": "connect is required before message"})
|
||||
continue
|
||||
await self._handle_message(websocket, connection, payload)
|
||||
continue
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"error": f"Unsupported websocket frame type: {frame_type or '<empty>'}",
|
||||
}
|
||||
)
|
||||
finally:
|
||||
if connection is not None:
|
||||
self._remove_connection(connection)
|
||||
self._record(
|
||||
kind="terminal_disconnected",
|
||||
session_id=connection.session_id,
|
||||
metadata={"peer_id": connection.peer_id, "device_name": connection.device_name},
|
||||
)
|
||||
|
||||
async def _handle_connect(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
current: TerminalConnection | None,
|
||||
) -> TerminalConnection | None:
|
||||
peer_id = _clean(payload.get("peer_id"))
|
||||
if not peer_id:
|
||||
await websocket.send_json({"type": "error", "error": "peer_id is required"})
|
||||
return current
|
||||
|
||||
thread_id = _clean(payload.get("thread_id")) or None
|
||||
user_id = _clean(payload.get("user_id")) or None
|
||||
device_name = _clean(payload.get("device_name"))
|
||||
capabilities = [str(item) for item in payload.get("capabilities") or [] if item is not None]
|
||||
identity = ChannelIdentity(
|
||||
channel_id=self.channel_id,
|
||||
kind=self.kind,
|
||||
account_id=self.account_id,
|
||||
peer_id=peer_id,
|
||||
thread_id=thread_id,
|
||||
peer_type="terminal",
|
||||
user_id=user_id,
|
||||
)
|
||||
session_id = identity.session_id()
|
||||
connection = TerminalConnection(
|
||||
websocket=websocket,
|
||||
peer_id=peer_id,
|
||||
session_id=session_id,
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
device_name=device_name,
|
||||
capabilities=capabilities,
|
||||
)
|
||||
|
||||
if current is not None and current.session_id != session_id:
|
||||
self._remove_connection(current)
|
||||
old = self._connections_by_session.get(session_id)
|
||||
if old is not None and old.websocket is not websocket:
|
||||
with suppress(Exception):
|
||||
await old.websocket.close(code=1000)
|
||||
self._connections_by_session[session_id] = connection
|
||||
self._session_by_peer[peer_id] = session_id
|
||||
self._record(
|
||||
kind="terminal_connected",
|
||||
session_id=session_id,
|
||||
metadata={"peer_id": peer_id, "device_name": device_name, "capabilities": capabilities},
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": self.channel_id,
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
return connection
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
connection: TerminalConnection,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
message_id = _clean(payload.get("message_id"))
|
||||
text = _clean(payload.get("text"))
|
||||
if not message_id:
|
||||
await websocket.send_json({"type": "error", "error": "message_id is required"})
|
||||
return
|
||||
if not text:
|
||||
await websocket.send_json({"type": "error", "error": "text is required"})
|
||||
return
|
||||
if len(text) > self.max_message_chars:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"error": f"text exceeds max_message_chars ({self.max_message_chars})",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
thread_id = _clean(payload.get("thread_id")) or connection.thread_id
|
||||
user_id = _clean(payload.get("user_id")) or connection.user_id
|
||||
identity = ChannelIdentity(
|
||||
channel_id=self.channel_id,
|
||||
kind=self.kind,
|
||||
account_id=self.account_id,
|
||||
peer_id=connection.peer_id,
|
||||
thread_id=thread_id,
|
||||
peer_type="terminal",
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
inbound = InboundMessage(
|
||||
channel=self.channel_id,
|
||||
content=text,
|
||||
content_type="text",
|
||||
user_id=user_id,
|
||||
channel_identity=identity,
|
||||
metadata={
|
||||
"terminal": {
|
||||
"peer_id": connection.peer_id,
|
||||
"device_name": connection.device_name,
|
||||
"capabilities": connection.capabilities,
|
||||
}
|
||||
},
|
||||
)
|
||||
accept = await self.inbound_sink.accept_inbound(inbound)
|
||||
ack: dict[str, Any] = {
|
||||
"type": "ack",
|
||||
"message_id": message_id,
|
||||
"session_id": accept.session_id or identity.session_id(),
|
||||
"accepted": accept.accepted,
|
||||
}
|
||||
if accept.duplicate:
|
||||
ack["duplicate"] = True
|
||||
ack["pending"] = accept.pending
|
||||
record = accept.record or {}
|
||||
if record.get("reply"):
|
||||
ack["reply"] = record["reply"]
|
||||
if accept.error or record.get("error"):
|
||||
ack["error"] = accept.error or record.get("error")
|
||||
await websocket.send_json(ack)
|
||||
|
||||
async def send(self, message: OutboundMessage) -> None:
|
||||
session_id = message.session_id
|
||||
if not session_id and message.channel_identity is not None:
|
||||
session_id = message.channel_identity.session_id()
|
||||
connection = self._connections_by_session.get(session_id or "")
|
||||
if connection is None:
|
||||
message.metadata["delivery_status"] = "unclaimed"
|
||||
return
|
||||
|
||||
payload = {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||
"run_id": message.run_id,
|
||||
"text": message.content,
|
||||
"finish_reason": message.finish_reason,
|
||||
}
|
||||
try:
|
||||
await connection.websocket.send_json(payload)
|
||||
except Exception:
|
||||
message.metadata["delivery_status"] = "unclaimed"
|
||||
self._remove_connection(connection)
|
||||
|
||||
def _remove_connection(self, connection: TerminalConnection) -> None:
|
||||
current = self._connections_by_session.get(connection.session_id)
|
||||
if current is connection:
|
||||
self._connections_by_session.pop(connection.session_id, None)
|
||||
if self._session_by_peer.get(connection.peer_id) == connection.session_id:
|
||||
self._session_by_peer.pop(connection.peer_id, None)
|
||||
|
||||
def _record(
|
||||
self,
|
||||
*,
|
||||
kind: str,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
status: str = "ok",
|
||||
error: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if self.event_recorder is None:
|
||||
return
|
||||
self.event_recorder(
|
||||
channel_id=self.channel_id,
|
||||
kind=kind,
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
status=status,
|
||||
error=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
Reference in New Issue
Block a user