"""Persistent channel connection stores.""" from __future__ import annotations import json import time from pathlib import Path from threading import Lock from typing import Any from uuid import uuid4 from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now class ChannelConnectionStore: def __init__(self, path: Path) -> None: self.path = Path(path) self._lock = Lock() def create( self, *, kind: str, mode: str, display_name: str, account_id: str, owner_user_id: str | None, auth_type: str, runtime_config: dict[str, Any] | None = None, capabilities: list[str] | None = None, credentials_ref: str | None = None, ) -> ChannelConnection: with self._lock: data = self._load() connection_id = f"conn_{uuid4().hex}" channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}" now = iso_now() connection = ChannelConnection( connection_id=connection_id, owner_user_id=owner_user_id, channel_id=channel_id, kind=kind, mode=mode, display_name=display_name or channel_id, account_id=account_id, status="draft", auth_type=auth_type, credentials_ref=credentials_ref, runtime_config=runtime_config or {}, capabilities=capabilities or [], created_at=now, updated_at=now, ) data["connections"][connection_id] = connection.to_dict() self._save(data) return connection def get(self, connection_id: str) -> ChannelConnection: data = self._load() raw = data["connections"].get(connection_id) if not isinstance(raw, dict): raise KeyError(connection_id) return ChannelConnection.from_dict(raw) def list(self) -> list[ChannelConnection]: data = self._load() return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)] def update(self, connection: ChannelConnection) -> ChannelConnection: with self._lock: data = self._load() if connection.connection_id not in data["connections"]: raise KeyError(connection.connection_id) connection.updated_at = iso_now() data["connections"][connection.connection_id] = connection.to_dict() self._save(data) return connection def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection: if status not in CONNECTION_STATUSES: raise ValueError(f"Unsupported connection status: {status}") connection = self.get(connection_id) connection.status = status connection.last_error = last_error if status in {"connected", "running"}: connection.last_seen_at = iso_now() return self.update(connection) def revoke(self, connection_id: str) -> ChannelConnection: return self.update_status(connection_id, status="revoked", last_error=None) def _load(self) -> dict[str, Any]: if not self.path.exists(): return {"connections": {}} try: data = json.loads(self.path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return {"connections": {}} if not isinstance(data, dict) or not isinstance(data.get("connections"), dict): return {"connections": {}} return data def _save(self, data: dict[str, Any]) -> None: self.path.parent.mkdir(parents=True, exist_ok=True) tmp_path = self.path.with_name(f"{self.path.name}.tmp") tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") tmp_path.replace(self.path) class CredentialStore: def __init__(self, path: Path) -> None: self.path = Path(path) self._lock = Lock() def put(self, *, kind: str, values: dict[str, Any]) -> str: cleaned = {str(key): str(value) for key, value in values.items() if str(key).strip() and str(value).strip()} ref = f"cred_{uuid4().hex}" with self._lock: data = self._load() data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()} self._save(data) return ref def get(self, ref: str) -> dict[str, str]: data = self._load() item = data["credentials"].get(ref) if not isinstance(item, dict): raise KeyError(ref) values = item.get("values") if not isinstance(values, dict): return {} return {str(key): str(value) for key, value in values.items()} def redacted(self, ref: str | None) -> dict[str, str]: if not ref: return {} try: values = self.get(ref) except KeyError: return {} return {key: "***" for key in values} def _load(self) -> dict[str, Any]: if not self.path.exists(): return {"credentials": {}} try: data = json.loads(self.path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return {"credentials": {}} if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict): return {"credentials": {}} return data def _save(self, data: dict[str, Any]) -> None: self.path.parent.mkdir(parents=True, exist_ok=True) tmp_path = self.path.with_name(f"{self.path.name}.tmp") tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") tmp_path.replace(self.path) class PairingTokenStore: def __init__(self, path: Path) -> None: self.path = Path(path) self._lock = Lock() def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession: now_ms = _now_ms() session = PairingSession( pairing_session_id=f"pair_{uuid4().hex}", kind=kind, scope=scope, token=f"pair_{uuid4().hex}", status="pending", expires_at_ms=now_ms + int(ttl_seconds * 1000), created_at_ms=now_ms, ) with self._lock: data = self._load() data["sessions"][session.pairing_session_id] = session.to_dict() self._save(data) return session def consume(self, token: str, *, expected_kind: str) -> PairingSession | None: with self._lock: data = self._load() for key, raw in data["sessions"].items(): session = PairingSession.from_dict(raw) if session.token != token or session.kind != expected_kind: continue if session.status != "pending" or session.expires_at_ms <= _now_ms(): return None session.status = "consumed" data["sessions"][key] = session.to_dict() self._save(data) return session return None def _load(self) -> dict[str, Any]: if not self.path.exists(): return {"sessions": {}} try: data = json.loads(self.path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return {"sessions": {}} if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict): return {"sessions": {}} return data def _save(self, data: dict[str, Any]) -> None: self.path.parent.mkdir(parents=True, exist_ok=True) tmp_path = self.path.with_name(f"{self.path.name}.tmp") tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") tmp_path.replace(self.path) def _now_ms() -> int: return int(time.time() * 1000) def _slug(value: str) -> str: text = "".join(char if char.isalnum() else "-" for char in str(value).strip().lower()) return "-".join(part for part in text.split("-") if part) or "channel"