feat: implement channel runtime connectors
This commit is contained in:
@ -0,0 +1,29 @@
|
||||
"""Channel connection setup layer."""
|
||||
|
||||
from .connectors import ChannelConnector, ChannelConnectorRegistry
|
||||
from .dedupe import ConnectorMessageDedupeRecord, DedupeBeginResult, MessageDedupeStore
|
||||
from .external import ExternalConnectorBase, FeishuConnector, WeixinConnector
|
||||
from .models import ChannelConnection, ChannelRuntimeSpec, PairingSession, ValidationResult
|
||||
from .sidecar_client import ConnectorSidecarClient
|
||||
from .store import ChannelConnectionStore, CredentialStore, PairingTokenStore
|
||||
from .telegram import TelegramConnector
|
||||
|
||||
__all__ = [
|
||||
"ChannelConnector",
|
||||
"ChannelConnectorRegistry",
|
||||
"ConnectorMessageDedupeRecord",
|
||||
"DedupeBeginResult",
|
||||
"MessageDedupeStore",
|
||||
"ExternalConnectorBase",
|
||||
"FeishuConnector",
|
||||
"WeixinConnector",
|
||||
"ConnectorSidecarClient",
|
||||
"ChannelConnection",
|
||||
"ChannelRuntimeSpec",
|
||||
"PairingSession",
|
||||
"ValidationResult",
|
||||
"ChannelConnectionStore",
|
||||
"CredentialStore",
|
||||
"PairingTokenStore",
|
||||
"TelegramConnector",
|
||||
]
|
||||
@ -0,0 +1,93 @@
|
||||
"""Channel connector registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
|
||||
class ChannelConnector(Protocol):
|
||||
kind: str
|
||||
|
||||
async def validate(self, connection_id: str) -> ValidationResult:
|
||||
...
|
||||
|
||||
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||
...
|
||||
|
||||
async def revoke(self, connection_id: str) -> None:
|
||||
...
|
||||
|
||||
|
||||
class ChannelConnectorRegistry:
|
||||
def __init__(self, *, connection_store: ChannelConnectionStore, credential_store: CredentialStore) -> None:
|
||||
self.connection_store = connection_store
|
||||
self.credential_store = credential_store
|
||||
self._connectors: dict[str, ChannelConnector] = {}
|
||||
|
||||
def register(self, connector: ChannelConnector) -> None:
|
||||
kind = connector.kind.strip()
|
||||
if not kind:
|
||||
raise ValueError("Connector kind is required")
|
||||
if kind in self._connectors:
|
||||
raise ValueError(f"Connector already registered: {kind}")
|
||||
self._connectors[kind] = connector
|
||||
|
||||
def connectors(self) -> list[dict[str, str]]:
|
||||
return [{"kind": kind} for kind in sorted(self._connectors)]
|
||||
|
||||
def connector_for_kind(self, kind: str) -> ChannelConnector:
|
||||
return self._connector(kind)
|
||||
|
||||
async def validate(self, connection_id: str) -> ValidationResult:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
connector = self._connector(connection.kind)
|
||||
result = await connector.validate(connection_id)
|
||||
self.connection_store.update_status(
|
||||
connection_id,
|
||||
status=result.status,
|
||||
last_error=result.error,
|
||||
)
|
||||
return result
|
||||
|
||||
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
return await self._connector(connection.kind).materialize_runtime(connection_id)
|
||||
|
||||
async def materialize_connected_runtime_specs(self) -> list[ChannelRuntimeSpec]:
|
||||
specs: list[ChannelRuntimeSpec] = []
|
||||
for connection in self.connection_store.list():
|
||||
if connection.status not in {"connected", "running"}:
|
||||
continue
|
||||
specs.append(await self._connector(connection.kind).materialize_runtime(connection.connection_id))
|
||||
return specs
|
||||
|
||||
async def materialize_channel_configs(self) -> dict[str, ChannelConfig]:
|
||||
channels: dict[str, ChannelConfig] = {}
|
||||
for spec in await self.materialize_connected_runtime_specs():
|
||||
secrets = self.credential_store.get(spec.secrets_ref) if spec.secrets_ref else {}
|
||||
channels[spec.channel_id] = ChannelConfig(
|
||||
enabled=True,
|
||||
kind=spec.kind,
|
||||
mode=spec.mode,
|
||||
account_id=spec.account_id,
|
||||
display_name=spec.display_name,
|
||||
config=dict(spec.config),
|
||||
secrets=secrets,
|
||||
)
|
||||
return channels
|
||||
|
||||
async def revoke(self, connection_id: str) -> None:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
await self._connector(connection.kind).revoke(connection_id)
|
||||
self.connection_store.revoke(connection_id)
|
||||
|
||||
def _connector(self, kind: str) -> ChannelConnector:
|
||||
connector = self._connectors.get(kind)
|
||||
if connector is None:
|
||||
raise KeyError(f"Connector not registered: {kind}")
|
||||
return connector
|
||||
@ -0,0 +1,144 @@
|
||||
"""Bridge event dedupe store for external connector retries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _parse_iso(value: str) -> datetime:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConnectorMessageDedupeRecord:
|
||||
dedupe_key: str
|
||||
connection_id: str
|
||||
event_id: str
|
||||
status: str
|
||||
first_seen_at: str
|
||||
updated_at: str
|
||||
delivery_attempts: int
|
||||
message_id: str | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ConnectorMessageDedupeRecord":
|
||||
return cls(
|
||||
dedupe_key=str(data.get("dedupe_key") or ""),
|
||||
connection_id=str(data.get("connection_id") or ""),
|
||||
event_id=str(data.get("event_id") or ""),
|
||||
status=str(data.get("status") or "processing"),
|
||||
first_seen_at=str(data.get("first_seen_at") or _iso_now()),
|
||||
updated_at=str(data.get("updated_at") or _iso_now()),
|
||||
delivery_attempts=int(data.get("delivery_attempts") or 0),
|
||||
message_id=str(data["message_id"]) if data.get("message_id") is not None else None,
|
||||
last_error=str(data["last_error"]) if data.get("last_error") is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DedupeBeginResult:
|
||||
should_process: bool
|
||||
dedupe_key: str
|
||||
status: str
|
||||
http_status: int
|
||||
retry_after_seconds: int | None
|
||||
record: ConnectorMessageDedupeRecord
|
||||
|
||||
|
||||
class MessageDedupeStore:
|
||||
def __init__(self, path: Path, *, processing_ttl_seconds: int = 60) -> None:
|
||||
self.path = Path(path)
|
||||
self.processing_ttl_seconds = int(processing_ttl_seconds)
|
||||
self._lock = Lock()
|
||||
|
||||
def begin(self, *, connection_id: str, event_id: str, delivery_attempt: int) -> DedupeBeginResult:
|
||||
dedupe_key = f"{connection_id}:{event_id}"
|
||||
now = _iso_now()
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
raw = data["records"].get(dedupe_key)
|
||||
if isinstance(raw, dict):
|
||||
record = ConnectorMessageDedupeRecord.from_dict(raw)
|
||||
if record.status == "completed":
|
||||
return DedupeBeginResult(False, dedupe_key, record.status, 200, None, record)
|
||||
if record.status == "processing" and not self._is_stale(record, now):
|
||||
return DedupeBeginResult(False, dedupe_key, record.status, 409, 5, record)
|
||||
record.status = "processing"
|
||||
record.updated_at = now
|
||||
record.delivery_attempts = max(record.delivery_attempts + 1, int(delivery_attempt))
|
||||
record.last_error = None
|
||||
else:
|
||||
record = ConnectorMessageDedupeRecord(
|
||||
dedupe_key=dedupe_key,
|
||||
connection_id=connection_id,
|
||||
event_id=event_id,
|
||||
status="processing",
|
||||
first_seen_at=now,
|
||||
updated_at=now,
|
||||
delivery_attempts=max(1, int(delivery_attempt)),
|
||||
)
|
||||
data["records"][dedupe_key] = record.to_dict()
|
||||
self._save(data)
|
||||
return DedupeBeginResult(True, dedupe_key, record.status, 200, None, record)
|
||||
|
||||
def complete(self, dedupe_key: str, *, message_id: str | None) -> ConnectorMessageDedupeRecord:
|
||||
return self._mark(dedupe_key, status="completed", message_id=message_id, error=None)
|
||||
|
||||
def fail(self, dedupe_key: str, *, error: str) -> ConnectorMessageDedupeRecord:
|
||||
return self._mark(dedupe_key, status="failed", message_id=None, error=error)
|
||||
|
||||
def _mark(
|
||||
self,
|
||||
dedupe_key: str,
|
||||
*,
|
||||
status: str,
|
||||
message_id: str | None,
|
||||
error: str | None,
|
||||
) -> ConnectorMessageDedupeRecord:
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
raw = data["records"].get(dedupe_key)
|
||||
if not isinstance(raw, dict):
|
||||
raise KeyError(dedupe_key)
|
||||
record = ConnectorMessageDedupeRecord.from_dict(raw)
|
||||
record.status = status
|
||||
record.updated_at = _iso_now()
|
||||
record.message_id = message_id or record.message_id
|
||||
record.last_error = error
|
||||
data["records"][dedupe_key] = record.to_dict()
|
||||
self._save(data)
|
||||
return record
|
||||
|
||||
def _is_stale(self, record: ConnectorMessageDedupeRecord, now: str) -> bool:
|
||||
age = (_parse_iso(now) - _parse_iso(record.updated_at)).total_seconds()
|
||||
return age >= self.processing_ttl_seconds
|
||||
|
||||
def _load(self) -> dict[str, Any]:
|
||||
if not self.path.exists():
|
||||
return {"records": {}}
|
||||
try:
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {"records": {}}
|
||||
if not isinstance(data, dict) or not isinstance(data.get("records"), dict):
|
||||
return {"records": {}}
|
||||
return data
|
||||
|
||||
def _save(self, data: dict[str, Any]) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
tmp_path.replace(self.path)
|
||||
@ -0,0 +1,131 @@
|
||||
"""Sidecar-backed channel connectors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .sidecar_client import ConnectorSidecarClient
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
|
||||
class ExternalConnectorBase:
|
||||
kind = ""
|
||||
capabilities: list[str] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
connection_store: ChannelConnectionStore,
|
||||
credential_store: CredentialStore,
|
||||
sidecar_client: ConnectorSidecarClient | Any,
|
||||
sidecar_base_url: str,
|
||||
) -> None:
|
||||
self.connection_store = connection_store
|
||||
self.credential_store = credential_store
|
||||
self.sidecar_client = sidecar_client
|
||||
self.sidecar_base_url = sidecar_base_url
|
||||
|
||||
async def start_session(
|
||||
self,
|
||||
*,
|
||||
display_name: str,
|
||||
owner_user_id: str | None,
|
||||
options: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
connection = self.connection_store.create(
|
||||
kind=self.kind,
|
||||
mode="sidecar",
|
||||
display_name=display_name or self.kind,
|
||||
account_id="",
|
||||
owner_user_id=owner_user_id,
|
||||
auth_type="connector_session",
|
||||
runtime_config={"sidecarBaseUrl": self.sidecar_base_url},
|
||||
capabilities=list(self.capabilities),
|
||||
)
|
||||
connection = self.connection_store.update_status(connection.connection_id, status="pairing", last_error=None)
|
||||
payload = {
|
||||
"kind": self.kind,
|
||||
"connectionId": connection.connection_id,
|
||||
"channelId": connection.channel_id,
|
||||
"displayName": connection.display_name,
|
||||
"callbackBaseUrl": "",
|
||||
"options": dict(options),
|
||||
}
|
||||
view = dict(await self.sidecar_client.start_session(payload))
|
||||
connection.pairing_session_id = str(view.get("sessionId") or "")
|
||||
self.connection_store.update(connection)
|
||||
view["connectionId"] = connection.connection_id
|
||||
view["channelId"] = connection.channel_id
|
||||
return view
|
||||
|
||||
async def poll_session(self, session_id: str) -> dict[str, Any]:
|
||||
view = dict(await self.sidecar_client.get_session(session_id))
|
||||
connection = self._connection_for_session(session_id)
|
||||
status = str(view.get("status") or "")
|
||||
if status == "connected":
|
||||
connection.account_id = str(view.get("accountId") or connection.account_id)
|
||||
connection.display_name = str(view.get("displayName") or connection.display_name)
|
||||
metadata = view.get("metadata") if isinstance(view.get("metadata"), dict) else {}
|
||||
state_ref = metadata.get("stateRef")
|
||||
if state_ref:
|
||||
connection.credentials_ref = self.credential_store.put(kind=self.kind, values={"stateRef": state_ref})
|
||||
self.connection_store.update(connection)
|
||||
self.connection_store.update_status(connection.connection_id, status="connected", last_error=None)
|
||||
elif status in {"expired", "error", "cancelled"}:
|
||||
self.connection_store.update_status(
|
||||
connection.connection_id,
|
||||
status="error",
|
||||
last_error=str(view.get("error") or status),
|
||||
)
|
||||
view["connectionId"] = connection.connection_id
|
||||
view["channelId"] = connection.channel_id
|
||||
return view
|
||||
|
||||
async def validate(self, connection_id: str) -> ValidationResult:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
if connection.status in {"connected", "running"}:
|
||||
return ValidationResult(
|
||||
ok=True,
|
||||
status="connected",
|
||||
account_id=connection.account_id,
|
||||
display_name=connection.display_name,
|
||||
)
|
||||
return ValidationResult(ok=False, status=connection.status, error=connection.last_error)
|
||||
|
||||
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
if connection.status not in {"connected", "running"}:
|
||||
raise ValueError(f"Connection is not connected: {connection.connection_id}")
|
||||
return ChannelRuntimeSpec(
|
||||
channel_id=connection.channel_id,
|
||||
kind="external_connector",
|
||||
mode="http",
|
||||
account_id=connection.account_id,
|
||||
display_name=connection.display_name,
|
||||
config={
|
||||
"platformKind": self.kind,
|
||||
"connectionId": connection.connection_id,
|
||||
"sidecarBaseUrl": connection.runtime_config.get("sidecarBaseUrl") or self.sidecar_base_url,
|
||||
},
|
||||
secrets_ref=None,
|
||||
)
|
||||
|
||||
async def revoke(self, connection_id: str) -> None:
|
||||
await self.sidecar_client.logout(connection_id)
|
||||
|
||||
def _connection_for_session(self, session_id: str):
|
||||
for connection in self.connection_store.list():
|
||||
if connection.pairing_session_id == session_id:
|
||||
return connection
|
||||
raise KeyError(session_id)
|
||||
|
||||
|
||||
class WeixinConnector(ExternalConnectorBase):
|
||||
kind = "weixin"
|
||||
capabilities = ["receive_text", "send_text", "receive_media", "direct_messages"]
|
||||
|
||||
|
||||
class FeishuConnector(ExternalConnectorBase):
|
||||
kind = "feishu"
|
||||
capabilities = ["receive_text", "send_text", "receive_media", "groups"]
|
||||
@ -0,0 +1,117 @@
|
||||
"""Channel connection setup models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
CONNECTION_STATUSES = {"draft", "pairing", "connected", "running", "degraded", "error", "revoked"}
|
||||
|
||||
|
||||
def iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelConnection:
|
||||
connection_id: str
|
||||
owner_user_id: str | None
|
||||
channel_id: str
|
||||
kind: str
|
||||
mode: str
|
||||
display_name: str
|
||||
account_id: str
|
||||
status: str
|
||||
auth_type: str
|
||||
credentials_ref: str | None = None
|
||||
connector_ref: str | None = None
|
||||
pairing_session_id: str | None = None
|
||||
runtime_config: dict[str, Any] = field(default_factory=dict)
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
created_at: str = field(default_factory=iso_now)
|
||||
updated_at: str = field(default_factory=iso_now)
|
||||
last_seen_at: str | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ChannelConnection":
|
||||
return cls(
|
||||
connection_id=str(data.get("connection_id") or ""),
|
||||
owner_user_id=_optional_string(data.get("owner_user_id")),
|
||||
channel_id=str(data.get("channel_id") or ""),
|
||||
kind=str(data.get("kind") or ""),
|
||||
mode=str(data.get("mode") or ""),
|
||||
display_name=str(data.get("display_name") or ""),
|
||||
account_id=str(data.get("account_id") or ""),
|
||||
status=str(data.get("status") or "draft"),
|
||||
auth_type=str(data.get("auth_type") or ""),
|
||||
credentials_ref=_optional_string(data.get("credentials_ref")),
|
||||
connector_ref=_optional_string(data.get("connector_ref")),
|
||||
pairing_session_id=_optional_string(data.get("pairing_session_id")),
|
||||
runtime_config=dict(data.get("runtime_config") or {}),
|
||||
capabilities=[str(item) for item in data.get("capabilities") or []],
|
||||
created_at=str(data.get("created_at") or iso_now()),
|
||||
updated_at=str(data.get("updated_at") or iso_now()),
|
||||
last_seen_at=_optional_string(data.get("last_seen_at")),
|
||||
last_error=_optional_string(data.get("last_error")),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PairingSession:
|
||||
pairing_session_id: str
|
||||
kind: str
|
||||
scope: str
|
||||
token: str
|
||||
status: str
|
||||
expires_at_ms: int
|
||||
created_at_ms: int
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PairingSession":
|
||||
return cls(
|
||||
pairing_session_id=str(data.get("pairing_session_id") or ""),
|
||||
kind=str(data.get("kind") or ""),
|
||||
scope=str(data.get("scope") or ""),
|
||||
token=str(data.get("token") or ""),
|
||||
status=str(data.get("status") or "pending"),
|
||||
expires_at_ms=int(data.get("expires_at_ms") or 0),
|
||||
created_at_ms=int(data.get("created_at_ms") or 0),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelRuntimeSpec:
|
||||
channel_id: str
|
||||
kind: str
|
||||
mode: str
|
||||
account_id: str
|
||||
display_name: str
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
secrets_ref: str | None = None
|
||||
external_endpoint: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ValidationResult:
|
||||
ok: bool
|
||||
status: str
|
||||
account_id: str | None = None
|
||||
display_name: str | None = None
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _optional_string(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
@ -0,0 +1,39 @@
|
||||
"""HTTP client for the generic external connector sidecar."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class ConnectorSidecarClient:
|
||||
def __init__(self, *, base_url: str, token: str, timeout_seconds: float = 20.0) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.timeout_seconds = float(timeout_seconds)
|
||||
|
||||
async def get_connectors(self) -> list[dict[str, Any]]:
|
||||
return await self._request("GET", "/connectors")
|
||||
|
||||
async def start_session(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return await self._request("POST", "/connector-sessions", json=payload)
|
||||
|
||||
async def get_session(self, session_id: str) -> dict[str, Any]:
|
||||
return await self._request("GET", f"/connector-sessions/{session_id}")
|
||||
|
||||
async def cancel_session(self, session_id: str) -> dict[str, Any]:
|
||||
return await self._request("POST", f"/connector-sessions/{session_id}/cancel", json={})
|
||||
|
||||
async def logout(self, connection_id: str) -> dict[str, Any]:
|
||||
return await self._request("POST", f"/connections/{connection_id}/logout", json={})
|
||||
|
||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return await self._request("POST", "/send", json=payload)
|
||||
|
||||
async def _request(self, method: str, path: str, *, json: dict[str, Any] | None = None) -> Any:
|
||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||
async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
|
||||
response = await client.request(method, f"{self.base_url}{path}", json=json, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@ -0,0 +1,222 @@
|
||||
"""Persistent channel connection stores."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now
|
||||
|
||||
|
||||
class ChannelConnectionStore:
|
||||
def __init__(self, path: Path) -> None:
|
||||
self.path = Path(path)
|
||||
self._lock = Lock()
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
kind: str,
|
||||
mode: str,
|
||||
display_name: str,
|
||||
account_id: str,
|
||||
owner_user_id: str | None,
|
||||
auth_type: str,
|
||||
runtime_config: dict[str, Any] | None = None,
|
||||
capabilities: list[str] | None = None,
|
||||
credentials_ref: str | None = None,
|
||||
) -> ChannelConnection:
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
connection_id = f"conn_{uuid4().hex}"
|
||||
channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}"
|
||||
now = iso_now()
|
||||
connection = ChannelConnection(
|
||||
connection_id=connection_id,
|
||||
owner_user_id=owner_user_id,
|
||||
channel_id=channel_id,
|
||||
kind=kind,
|
||||
mode=mode,
|
||||
display_name=display_name or channel_id,
|
||||
account_id=account_id,
|
||||
status="draft",
|
||||
auth_type=auth_type,
|
||||
credentials_ref=credentials_ref,
|
||||
runtime_config=runtime_config or {},
|
||||
capabilities=capabilities or [],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
data["connections"][connection_id] = connection.to_dict()
|
||||
self._save(data)
|
||||
return connection
|
||||
|
||||
def get(self, connection_id: str) -> ChannelConnection:
|
||||
data = self._load()
|
||||
raw = data["connections"].get(connection_id)
|
||||
if not isinstance(raw, dict):
|
||||
raise KeyError(connection_id)
|
||||
return ChannelConnection.from_dict(raw)
|
||||
|
||||
def list(self) -> list[ChannelConnection]:
|
||||
data = self._load()
|
||||
return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)]
|
||||
|
||||
def update(self, connection: ChannelConnection) -> ChannelConnection:
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
if connection.connection_id not in data["connections"]:
|
||||
raise KeyError(connection.connection_id)
|
||||
connection.updated_at = iso_now()
|
||||
data["connections"][connection.connection_id] = connection.to_dict()
|
||||
self._save(data)
|
||||
return connection
|
||||
|
||||
def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection:
|
||||
if status not in CONNECTION_STATUSES:
|
||||
raise ValueError(f"Unsupported connection status: {status}")
|
||||
connection = self.get(connection_id)
|
||||
connection.status = status
|
||||
connection.last_error = last_error
|
||||
if status in {"connected", "running"}:
|
||||
connection.last_seen_at = iso_now()
|
||||
return self.update(connection)
|
||||
|
||||
def revoke(self, connection_id: str) -> ChannelConnection:
|
||||
return self.update_status(connection_id, status="revoked", last_error=None)
|
||||
|
||||
def _load(self) -> dict[str, Any]:
|
||||
if not self.path.exists():
|
||||
return {"connections": {}}
|
||||
try:
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {"connections": {}}
|
||||
if not isinstance(data, dict) or not isinstance(data.get("connections"), dict):
|
||||
return {"connections": {}}
|
||||
return data
|
||||
|
||||
def _save(self, data: dict[str, Any]) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
tmp_path.replace(self.path)
|
||||
|
||||
|
||||
class CredentialStore:
|
||||
def __init__(self, path: Path) -> None:
|
||||
self.path = Path(path)
|
||||
self._lock = Lock()
|
||||
|
||||
def put(self, *, kind: str, values: dict[str, Any]) -> str:
|
||||
cleaned = {str(key): str(value) for key, value in values.items() if str(key).strip() and str(value).strip()}
|
||||
ref = f"cred_{uuid4().hex}"
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()}
|
||||
self._save(data)
|
||||
return ref
|
||||
|
||||
def get(self, ref: str) -> dict[str, str]:
|
||||
data = self._load()
|
||||
item = data["credentials"].get(ref)
|
||||
if not isinstance(item, dict):
|
||||
raise KeyError(ref)
|
||||
values = item.get("values")
|
||||
if not isinstance(values, dict):
|
||||
return {}
|
||||
return {str(key): str(value) for key, value in values.items()}
|
||||
|
||||
def redacted(self, ref: str | None) -> dict[str, str]:
|
||||
if not ref:
|
||||
return {}
|
||||
try:
|
||||
values = self.get(ref)
|
||||
except KeyError:
|
||||
return {}
|
||||
return {key: "***" for key in values}
|
||||
|
||||
def _load(self) -> dict[str, Any]:
|
||||
if not self.path.exists():
|
||||
return {"credentials": {}}
|
||||
try:
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {"credentials": {}}
|
||||
if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict):
|
||||
return {"credentials": {}}
|
||||
return data
|
||||
|
||||
def _save(self, data: dict[str, Any]) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
tmp_path.replace(self.path)
|
||||
|
||||
|
||||
class PairingTokenStore:
|
||||
def __init__(self, path: Path) -> None:
|
||||
self.path = Path(path)
|
||||
self._lock = Lock()
|
||||
|
||||
def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession:
|
||||
now_ms = _now_ms()
|
||||
session = PairingSession(
|
||||
pairing_session_id=f"pair_{uuid4().hex}",
|
||||
kind=kind,
|
||||
scope=scope,
|
||||
token=f"pair_{uuid4().hex}",
|
||||
status="pending",
|
||||
expires_at_ms=now_ms + int(ttl_seconds * 1000),
|
||||
created_at_ms=now_ms,
|
||||
)
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
data["sessions"][session.pairing_session_id] = session.to_dict()
|
||||
self._save(data)
|
||||
return session
|
||||
|
||||
def consume(self, token: str, *, expected_kind: str) -> PairingSession | None:
|
||||
with self._lock:
|
||||
data = self._load()
|
||||
for key, raw in data["sessions"].items():
|
||||
session = PairingSession.from_dict(raw)
|
||||
if session.token != token or session.kind != expected_kind:
|
||||
continue
|
||||
if session.status != "pending" or session.expires_at_ms <= _now_ms():
|
||||
return None
|
||||
session.status = "consumed"
|
||||
data["sessions"][key] = session.to_dict()
|
||||
self._save(data)
|
||||
return session
|
||||
return None
|
||||
|
||||
def _load(self) -> dict[str, Any]:
|
||||
if not self.path.exists():
|
||||
return {"sessions": {}}
|
||||
try:
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {"sessions": {}}
|
||||
if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict):
|
||||
return {"sessions": {}}
|
||||
return data
|
||||
|
||||
def _save(self, data: dict[str, Any]) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
tmp_path.replace(self.path)
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def _slug(value: str) -> str:
|
||||
text = "".join(char if char.isalnum() else "-" for char in str(value).strip().lower())
|
||||
return "-".join(part for part in text.split("-") if part) or "channel"
|
||||
@ -0,0 +1,92 @@
|
||||
"""Telegram channel connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
|
||||
class TelegramConnector:
|
||||
kind = "telegram"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
connection_store: ChannelConnectionStore,
|
||||
credential_store: CredentialStore,
|
||||
client_factory: Callable[[str], Any] | None = None,
|
||||
) -> None:
|
||||
self.connection_store = connection_store
|
||||
self.credential_store = credential_store
|
||||
self.client_factory = client_factory or _default_client_factory
|
||||
|
||||
async def validate(self, connection_id: str) -> ValidationResult:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
token = self._bot_token(connection.credentials_ref)
|
||||
try:
|
||||
client = self.client_factory(token)
|
||||
raw = await client.get_me()
|
||||
bot_id = _value(raw, "id")
|
||||
username = _value(raw, "username")
|
||||
first_name = _value(raw, "first_name") or "Telegram Bot"
|
||||
account_id = f"telegram:{bot_id}" if bot_id else connection.account_id
|
||||
display_name = f"{first_name} (@{username})" if username else first_name
|
||||
connection.account_id = account_id
|
||||
connection.display_name = display_name
|
||||
connection.capabilities = ["receive_text", "send_text", "receive_media", "groups"]
|
||||
self.connection_store.update(connection)
|
||||
return ValidationResult(
|
||||
ok=True,
|
||||
status="connected",
|
||||
account_id=account_id,
|
||||
display_name=display_name,
|
||||
metadata={"username": username} if username else {},
|
||||
)
|
||||
except Exception as exc:
|
||||
return ValidationResult(ok=False, status="error", error=str(exc))
|
||||
|
||||
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||
connection = self.connection_store.get(connection_id)
|
||||
if connection.status not in {"connected", "running"}:
|
||||
raise ValueError(f"Connection is not connected: {connection.connection_id}")
|
||||
return ChannelRuntimeSpec(
|
||||
channel_id=connection.channel_id,
|
||||
kind=connection.kind,
|
||||
mode=connection.mode,
|
||||
account_id=connection.account_id,
|
||||
display_name=connection.display_name,
|
||||
config=dict(connection.runtime_config),
|
||||
secrets_ref=connection.credentials_ref,
|
||||
)
|
||||
|
||||
async def revoke(self, connection_id: str) -> None:
|
||||
# Telegram bot tokens do not have a Beaver-managed platform revoke action.
|
||||
# The registry owns local connection state transitions.
|
||||
return None
|
||||
|
||||
def _bot_token(self, credentials_ref: str | None) -> str:
|
||||
if not credentials_ref:
|
||||
raise ValueError("Telegram credentials are missing")
|
||||
token = self.credential_store.get(credentials_ref).get("botToken")
|
||||
if not token:
|
||||
raise ValueError("botToken is required")
|
||||
return token
|
||||
|
||||
|
||||
def _value(raw: Any, key: str) -> str:
|
||||
if isinstance(raw, dict):
|
||||
value = raw.get(key)
|
||||
else:
|
||||
value = getattr(raw, key, None)
|
||||
return str(value).strip() if value is not None else ""
|
||||
|
||||
|
||||
def _default_client_factory(token: str) -> Any:
|
||||
try:
|
||||
from telegram import Bot
|
||||
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||
raise RuntimeError("Install beaver-backend[telegram] to validate Telegram connections") from exc
|
||||
return Bot(token=token)
|
||||
Reference in New Issue
Block a user