223 lines
8.1 KiB
Python
223 lines
8.1 KiB
Python
"""Persistent channel connection stores."""
|
|
|
|
from __future__ import annotations

import hmac
import json
import os
import time
from pathlib import Path
from threading import Lock, RLock
from typing import Any
from uuid import uuid4

from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now
|
|
|
|
|
|
class ChannelConnectionStore:
    """Persist ``ChannelConnection`` records in a single JSON file.

    The file is a JSON object whose ``"connections"`` key maps each
    ``connection_id`` to the record's ``to_dict()`` form. Every write rewrites
    the whole document into a sibling ``<name>.tmp`` file, which is then
    renamed over ``path``.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        # Re-entrant so update_status() can hold the lock across its own
        # get()/update() calls (update() takes the lock itself). This only
        # serializes threads within this process; it is not a file lock.
        self._lock = RLock()

    def create(
        self,
        *,
        kind: str,
        mode: str,
        display_name: str,
        account_id: str,
        owner_user_id: str | None,
        auth_type: str,
        runtime_config: dict[str, Any] | None = None,
        capabilities: list[str] | None = None,
        credentials_ref: str | None = None,
    ) -> ChannelConnection:
        """Create, persist and return a new connection in ``"draft"`` status.

        ``connection_id`` is ``conn_<uuid4 hex>``; ``channel_id`` is the
        slugified ``kind`` followed by 8 hex characters. An empty
        ``display_name`` falls back to ``channel_id``; missing
        ``runtime_config``/``capabilities`` become ``{}``/``[]``.
        """
        with self._lock:
            data = self._load()
            connection_id = f"conn_{uuid4().hex}"
            channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}"
            now = iso_now()
            connection = ChannelConnection(
                connection_id=connection_id,
                owner_user_id=owner_user_id,
                channel_id=channel_id,
                kind=kind,
                mode=mode,
                display_name=display_name or channel_id,
                account_id=account_id,
                status="draft",
                auth_type=auth_type,
                credentials_ref=credentials_ref,
                runtime_config=runtime_config or {},
                capabilities=capabilities or [],
                created_at=now,
                updated_at=now,
            )
            data["connections"][connection_id] = connection.to_dict()
            self._save(data)
            return connection

    def get(self, connection_id: str) -> ChannelConnection:
        """Return the stored connection with ``connection_id``.

        Raises:
            KeyError: if no entry exists or the stored entry is not an object.
        """
        data = self._load()
        raw = data["connections"].get(connection_id)
        if not isinstance(raw, dict):
            raise KeyError(connection_id)
        return ChannelConnection.from_dict(raw)

    def list(self) -> list[ChannelConnection]:
        """Return every stored connection, skipping entries that are not objects."""
        data = self._load()
        return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)]

    def update(self, connection: ChannelConnection) -> ChannelConnection:
        """Overwrite an existing connection and return it.

        Stamps ``connection.updated_at`` in place before saving.

        Raises:
            KeyError: if ``connection.connection_id`` is not already stored.
        """
        with self._lock:
            data = self._load()
            if connection.connection_id not in data["connections"]:
                raise KeyError(connection.connection_id)
            connection.updated_at = iso_now()
            data["connections"][connection.connection_id] = connection.to_dict()
            self._save(data)
            return connection

    def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection:
        """Set ``status`` and ``last_error`` on a stored connection.

        ``last_seen_at`` is refreshed when the new status is ``"connected"`` or
        ``"running"``.

        Raises:
            ValueError: if ``status`` is not in ``CONNECTION_STATUSES``.
            KeyError: if the connection does not exist.
        """
        if status not in CONNECTION_STATUSES:
            raise ValueError(f"Unsupported connection status: {status}")
        # Hold the lock across the read and the write. Otherwise another thread
        # could save between get() and update(), and its change would be lost.
        with self._lock:
            connection = self.get(connection_id)
            connection.status = status
            connection.last_error = last_error
            if status in {"connected", "running"}:
                connection.last_seen_at = iso_now()
            return self.update(connection)

    def revoke(self, connection_id: str) -> ChannelConnection:
        """Mark a connection as ``"revoked"`` and clear its ``last_error``."""
        return self.update_status(connection_id, status="revoked", last_error=None)

    def _load(self) -> dict[str, Any]:
        """Read the store file, falling back to ``{"connections": {}}``.

        A missing file, an unreadable file, invalid JSON, or a document without
        a ``"connections"`` object all yield the empty store.
        """
        # NOTE(review): because corrupt/unreadable content is treated as empty,
        # the next create() rewrites the file and drops any records it held.
        # Consider failing loudly on write paths instead.
        if not self.path.exists():
            return {"connections": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"connections": {}}
        if not isinstance(data, dict) or not isinstance(data.get("connections"), dict):
            return {"connections": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` to a temp file, then rename it over ``path``."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
|
|
|
|
|
|
class CredentialStore:
    """Persist channel credentials in a JSON file, addressed by opaque refs.

    The file is a JSON object whose ``"credentials"`` key maps each
    ``cred_<uuid4 hex>`` ref to ``{"kind", "values", "created_at"}``. Values
    are stored as plaintext strings, so the file is written owner-only.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self._lock = Lock()

    def put(self, *, kind: str, values: dict[str, Any]) -> str:
        """Store ``values`` under a freshly generated ref and return that ref.

        Entries whose value is ``None`` are dropped, as are entries whose key or
        value is blank after ``str()`` conversion. Everything else is stored as
        ``str``.
        """
        cleaned = self._clean_values(values)
        ref = f"cred_{uuid4().hex}"
        with self._lock:
            data = self._load()
            data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()}
            self._save(data)
        return ref

    def get(self, ref: str) -> dict[str, str]:
        """Return the stored values for ``ref`` as a ``str -> str`` dict.

        An entry whose ``"values"`` is not an object yields ``{}``.

        Raises:
            KeyError: if ``ref`` is unknown or its entry is not an object.
        """
        data = self._load()
        item = data["credentials"].get(ref)
        if not isinstance(item, dict):
            raise KeyError(ref)
        values = item.get("values")
        if not isinstance(values, dict):
            return {}
        return {str(key): str(value) for key, value in values.items()}

    def redacted(self, ref: str | None) -> dict[str, str]:
        """Return the keys stored under ``ref`` with every value masked as ``"***"``.

        A falsy or unknown ``ref`` yields ``{}``.
        """
        if not ref:
            return {}
        try:
            values = self.get(ref)
        except KeyError:
            return {}
        return {key: "***" for key in values}

    @staticmethod
    def _clean_values(values: dict[str, Any]) -> dict[str, str]:
        """Stringify ``values``, dropping ``None`` values and blank keys/values."""
        cleaned: dict[str, str] = {}
        for key, value in values.items():
            # str(None) would otherwise be persisted as the literal credential "None".
            if value is None:
                continue
            name, text = str(key), str(value)
            if name.strip() and text.strip():
                cleaned[name] = text
        return cleaned

    def _load(self) -> dict[str, Any]:
        """Read the store file, falling back to ``{"credentials": {}}``.

        A missing file, an unreadable file, invalid JSON, or a document without
        a ``"credentials"`` object all yield the empty store.
        """
        # NOTE(review): because corrupt/unreadable content is treated as empty,
        # the next put() rewrites the file and drops every stored credential.
        if not self.path.exists():
            return {"credentials": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"credentials": {}}
        if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict):
            return {"credentials": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` owner-only to a temp file, then rename it over ``path``."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        payload = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        # Secrets are stored in plaintext, so don't rely on the process umask.
        # os.open's mode only applies when it creates the file, so chmod also
        # tightens a stale tmp file left behind by an earlier run.
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
        with os.fdopen(os.open(tmp_path, flags, 0o600), "w", encoding="utf-8") as handle:
            os.chmod(tmp_path, 0o600)
            handle.write(payload)
        # The rename keeps the tmp file's inode, and with it the 0o600 mode.
        tmp_path.replace(self.path)
|
|
|
|
|
|
class PairingTokenStore:
    """Persist one-shot pairing sessions in a JSON file.

    The file is a JSON object whose ``"sessions"`` key maps each
    ``pairing_session_id`` to the session's ``to_dict()`` form. Tokens are
    bearer secrets stored in plaintext, so the file is written owner-only.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self._lock = Lock()

    def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession:
        """Create and persist a ``"pending"`` session expiring ``ttl_seconds`` from now."""
        now_ms = _now_ms()
        session = PairingSession(
            pairing_session_id=f"pair_{uuid4().hex}",
            kind=kind,
            scope=scope,
            token=f"pair_{uuid4().hex}",
            status="pending",
            expires_at_ms=now_ms + int(ttl_seconds * 1000),
            created_at_ms=now_ms,
        )
        with self._lock:
            data = self._load()
            data["sessions"][session.pairing_session_id] = session.to_dict()
            self._save(data)
        return session

    def consume(self, token: str, *, expected_kind: str) -> PairingSession | None:
        """Mark the pending, unexpired session matching ``token`` as consumed.

        Returns the consumed session. Returns ``None`` when no session of
        ``expected_kind`` has this token, or when the matching session is no
        longer ``"pending"`` or has expired.
        """
        with self._lock:
            data = self._load()
            for key, raw in data["sessions"].items():
                # Skip malformed entries, as the other stores do.
                if not isinstance(raw, dict):
                    continue
                session = PairingSession.from_dict(raw)
                # kind is not secret; check it before the constant-time token compare.
                if session.kind != expected_kind or not self._token_matches(session.token, token):
                    continue
                if session.status != "pending" or session.expires_at_ms <= _now_ms():
                    return None
                session.status = "consumed"
                # Reassigning an existing key does not resize the dict mid-iteration.
                data["sessions"][key] = session.to_dict()
                self._save(data)
                return session
        return None

    @staticmethod
    def _token_matches(stored: object, supplied: object) -> bool:
        """Compare two tokens in constant time; non-``str`` operands never match."""
        if not isinstance(stored, str) or not isinstance(supplied, str):
            return False
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return hmac.compare_digest(stored.encode("utf-8"), supplied.encode("utf-8"))

    def _load(self) -> dict[str, Any]:
        """Read the store file, falling back to ``{"sessions": {}}``.

        A missing file, an unreadable file, invalid JSON, or a document without
        a ``"sessions"`` object all yield the empty store.
        """
        if not self.path.exists():
            return {"sessions": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"sessions": {}}
        if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict):
            return {"sessions": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` owner-only to a temp file, then rename it over ``path``."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        payload = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        # Tokens are secrets: create the file 0o600 rather than using the umask
        # default, and chmod in case a stale tmp file already existed.
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
        with os.fdopen(os.open(tmp_path, flags, 0o600), "w", encoding="utf-8") as handle:
            os.chmod(tmp_path, 0o600)
            handle.write(payload)
        tmp_path.replace(self.path)
|
|
|
|
|
|
def _now_ms() -> int:
|
|
return int(time.time() * 1000)
|
|
|
|
|
|
def _slug(value: str) -> str:
|
|
text = "".join(char if char.isalnum() else "-" for char in str(value).strip().lower())
|
|
return "-".join(part for part in text.split("-") if part) or "channel"
|