211 lines
7.7 KiB
Python
211 lines
7.7 KiB
Python
"""Sidecar-backed channel connectors."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Any
|
|
|
|
from .models import ChannelRuntimeSpec, ValidationResult
|
|
from .sidecar_client import ConnectorSidecarClient
|
|
from .store import ChannelConnectionStore, CredentialStore
|
|
|
|
# Session option keys that are copied (after normalization) into a
# connection's runtime config, grouped by how _policy_runtime_config coerces them.
POLICY_CONFIG_KEYS = (
    # allow-lists -> list of non-empty strings
    {"allowFrom", "groupAllowFrom"}
    # mention flags -> bool
    | {"requireMentionInGroups", "respondToMentionAll"}
    # stripped free-form string
    | {"dmMode"}
    # size / batching limits -> positive int
    | {"maxMessageChars", "textBatchDelayMs", "textBatchMaxMessages", "textBatchMaxChars"}
)
|
|
|
|
|
|
class ExternalConnectorBase:
    """Channel connector whose pairing and transport are delegated to a sidecar.

    Subclasses set ``kind`` (the platform name sent to the sidecar and recorded
    on connections and credentials) and ``capabilities`` (copied onto every
    connection this connector creates).
    """

    kind = ""
    # Overridden by subclasses; start_session copies it, so this shared list is never mutated.
    capabilities: list[str] = []

    def __init__(
        self,
        *,
        connection_store: ChannelConnectionStore,
        credential_store: CredentialStore,
        sidecar_client: ConnectorSidecarClient | Any,
        sidecar_base_url: str,
    ) -> None:
        self.connection_store = connection_store
        self.credential_store = credential_store
        self.sidecar_client = sidecar_client
        self.sidecar_base_url = sidecar_base_url
        # Read once from the environment when the connector is constructed.
        self.callback_base_url = _callback_base_url()

    async def start_session(
        self,
        *,
        display_name: str,
        owner_user_id: str | None,
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Create a connection record and start a pairing session on the sidecar.

        The new connection uses mode ``"sidecar"`` and auth type
        ``"connector_session"``. Its runtime config holds the sidecar base URL
        plus the normalized policy options, and it is marked ``"pairing"``
        before the sidecar is contacted.

        Args:
            display_name: Human-readable name; falls back to ``kind`` when empty.
            owner_user_id: Owner recorded on the connection (may be ``None``).
            options: Caller options, forwarded verbatim to the sidecar and
                filtered into the runtime config; ``None`` is treated as ``{}``.

        Returns:
            The sidecar's session view, with ``connectionId`` and ``channelId`` added.

        Raises:
            Whatever the sidecar client's ``start_session`` raises. The
            connection is first marked ``"error"`` so it does not stay in
            ``"pairing"`` indefinitely.
        """
        options = dict(options or {})
        runtime_config: dict[str, Any] = {"sidecarBaseUrl": self.sidecar_base_url}
        runtime_config.update(_policy_runtime_config(options))
        connection = self.connection_store.create(
            kind=self.kind,
            mode="sidecar",
            display_name=display_name or self.kind,
            account_id="",
            owner_user_id=owner_user_id,
            auth_type="connector_session",
            runtime_config=runtime_config,
            capabilities=list(self.capabilities),
        )
        connection = self.connection_store.update_status(
            connection.connection_id, status="pairing", last_error=None
        )
        payload = {
            "kind": self.kind,
            "connectionId": connection.connection_id,
            "channelId": connection.channel_id,
            "displayName": connection.display_name,
            "callbackBaseUrl": self.callback_base_url,
            "options": options,
        }
        try:
            view = dict(await self.sidecar_client.start_session(payload))
        except Exception as exc:
            # Without this the connection would remain "pairing" with no session to poll.
            self.connection_store.update_status(
                connection.connection_id,
                status="error",
                last_error=str(exc) or type(exc).__name__,
            )
            raise
        # The sidecar-issued session id is how poll_session finds this connection later.
        connection.pairing_session_id = str(view.get("sessionId") or "")
        connection = self.connection_store.update(connection)
        connection = self._apply_session_view(connection, view)
        view["connectionId"] = connection.connection_id
        view["channelId"] = connection.channel_id
        return view

    async def poll_session(self, session_id: str) -> dict[str, Any]:
        """Fetch the sidecar's view of ``session_id`` and apply it locally.

        Returns:
            The sidecar view with ``connectionId`` and ``channelId`` added.

        Raises:
            KeyError: if no stored connection has this pairing session id
                (checked after the sidecar has been queried).
        """
        view = dict(await self.sidecar_client.get_session(session_id))
        connection = self._connection_for_session(session_id)
        connection = self._apply_session_view(connection, view)
        view["connectionId"] = connection.connection_id
        view["channelId"] = connection.channel_id
        return view

    def _apply_session_view(self, connection: Any, view: dict[str, Any]) -> Any:
        """Persist the state implied by a sidecar session view.

        Behavior by ``view["status"]``:

        * ``"connected"``: copy ``accountId``/``displayName`` when present and
          store ``metadata.stateRef`` (if any) in the credential store. Then
          mark the connection ``"connected"``.
        * ``"expired"``/``"error"``/``"cancelled"``: mark the connection
          ``"error"`` with the view's ``error`` text, or the status name.
        * Anything else: no change.

        Returns:
            The connection re-read from the connection store.
        """
        status = str(view.get("status") or "")
        if status == "connected":
            connection.account_id = str(view.get("accountId") or connection.account_id)
            connection.display_name = str(view.get("displayName") or connection.display_name)
            metadata = view.get("metadata")
            if not isinstance(metadata, dict):
                metadata = {}
            state_ref = metadata.get("stateRef")
            if state_ref:
                # NOTE(review): this runs on every "connected" view, so repeated polls
                # may store a new credential each time - confirm CredentialStore.put semantics.
                connection.credentials_ref = self.credential_store.put(
                    kind=self.kind, values={"stateRef": state_ref}
                )
            self.connection_store.update(connection)
            self.connection_store.update_status(
                connection.connection_id, status="connected", last_error=None
            )
        elif status in {"expired", "error", "cancelled"}:
            self.connection_store.update_status(
                connection.connection_id,
                status="error",
                last_error=str(view.get("error") or status),
            )
        return self.connection_store.get(connection.connection_id)

    async def validate(self, connection_id: str) -> ValidationResult:
        """Report whether the stored connection is usable.

        ``"connected"`` and ``"running"`` connections are both reported as
        ``ok`` with status ``"connected"``. Any other status is reported as-is,
        together with the connection's last error.
        """
        connection = self.connection_store.get(connection_id)
        if connection.status in {"connected", "running"}:
            return ValidationResult(
                ok=True,
                status="connected",
                account_id=connection.account_id,
                display_name=connection.display_name,
            )
        return ValidationResult(ok=False, status=connection.status, error=connection.last_error)

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Build the HTTP ``external_connector`` runtime spec for a connected connection.

        The spec config merges the stored runtime config with ``platformKind``
        and ``connectionId``. ``sidecarBaseUrl`` falls back to this
        connector's base URL when the stored value is missing or empty.

        Raises:
            ValueError: if the connection is not ``"connected"`` or ``"running"``.
        """
        connection = self.connection_store.get(connection_id)
        if connection.status not in {"connected", "running"}:
            raise ValueError(f"Connection is not connected: {connection.connection_id}")
        runtime_config = dict(connection.runtime_config or {})
        return ChannelRuntimeSpec(
            channel_id=connection.channel_id,
            kind="external_connector",
            mode="http",
            account_id=connection.account_id,
            display_name=connection.display_name,
            config={
                "platformKind": self.kind,
                "connectionId": connection.connection_id,
                **runtime_config,
                "sidecarBaseUrl": runtime_config.get("sidecarBaseUrl") or self.sidecar_base_url,
            },
            secrets_ref=None,
        )

    async def revoke(self, connection_id: str) -> None:
        """Ask the sidecar to log out ``connection_id``.

        Only the sidecar is contacted; local connection status and
        credentials are not modified here.
        """
        await self.sidecar_client.logout(connection_id)

    def _connection_for_session(self, session_id: str):
        """Return the stored connection whose ``pairing_session_id`` equals ``session_id``.

        Raises:
            KeyError: if ``session_id`` is empty or no connection matches.
        """
        # start_session stores "" when the sidecar omits sessionId, so an empty id
        # would otherwise match an arbitrary such connection.
        if not session_id:
            raise KeyError(session_id)
        for connection in self.connection_store.list():
            if connection.pairing_session_id == session_id:
                return connection
        raise KeyError(session_id)
|
|
|
|
|
|
class WeixinConnector(ExternalConnectorBase):
    """Sidecar-backed connector for the ``weixin`` platform.

    Advertises text and media exchange plus direct-message support.
    """

    kind = "weixin"
    capabilities = [
        "receive_text",
        "send_text",
        "receive_media",
        "direct_messages",
    ]
|
|
|
|
|
|
class FeishuConnector(ExternalConnectorBase):
    """Sidecar-backed connector for the ``feishu`` platform.

    Advertises text and media exchange plus group-chat support.
    """

    kind = "feishu"
    capabilities = [
        "receive_text",
        "send_text",
        "receive_media",
        "groups",
    ]
|
|
|
|
|
|
def _policy_runtime_config(options: dict[str, Any]) -> dict[str, Any]:
    """Extract and normalize the policy settings found in ``options``.

    Only keys in ``POLICY_CONFIG_KEYS`` are considered. Each present key is
    coerced according to its group:

    * ``allowFrom`` / ``groupAllowFrom``: list of non-empty strings, kept
      only when the list is non-empty.
    * Size/batching limits: positive ``int``, dropped otherwise.
    * ``requireMentionInGroups`` / ``respondToMentionAll``: ``bool``.
    * Any other key (currently ``dmMode``): stripped string, dropped when empty.

    Keys are visited in sorted order. The resulting dict therefore has a
    stable key order, independent of the interpreter's string-hash seed that
    governs ``set`` iteration.

    Args:
        options: Caller-supplied options; ``None`` or empty yields ``{}``.

    Returns:
        A new dict containing only the normalized policy entries.
    """
    list_keys = {"allowFrom", "groupAllowFrom"}
    int_keys = {"maxMessageChars", "textBatchDelayMs", "textBatchMaxMessages", "textBatchMaxChars"}
    bool_keys = {"requireMentionInGroups", "respondToMentionAll"}

    result: dict[str, Any] = {}
    if not options:
        return result
    for key in sorted(POLICY_CONFIG_KEYS):
        if key not in options:
            continue
        value = options[key]
        if key in list_keys:
            items = _string_list(value)
            if items:
                result[key] = items
        elif key in int_keys:
            number = _positive_int(value)
            if number is not None:
                result[key] = number
        elif key in bool_keys:
            result[key] = _bool(value)
        else:
            text = str(value or "").strip()
            if text:
                result[key] = text
    return result
|
|
|
|
|
|
def _callback_base_url() -> str:
|
|
for name in ("EXTERNAL_CONNECTOR_CALLBACK_BASE_URL", "BEAVER_CONNECTOR_CALLBACK_BASE_URL"):
|
|
value = os.environ.get(name, "").strip()
|
|
if value:
|
|
return value.rstrip("/")
|
|
return ""
|
|
|
|
|
|
def _string_list(value: Any) -> list[str]:
|
|
if isinstance(value, str):
|
|
raw_items = value.replace("\n", ",").split(",")
|
|
elif isinstance(value, list):
|
|
raw_items = value
|
|
else:
|
|
raw_items = []
|
|
return [str(item).strip() for item in raw_items if str(item).strip()]
|
|
|
|
|
|
def _positive_int(value: Any) -> int | None:
|
|
try:
|
|
number = int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
return number if number > 0 else None
|
|
|
|
|
|
def _bool(value: Any) -> bool:
|
|
if isinstance(value, bool):
|
|
return value
|
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|