Files
beaver_project/app-instance/backend/beaver/interfaces/channels/connections/store.py

223 lines
8.1 KiB
Python

"""Persistent channel connection stores."""
from __future__ import annotations

import json
import os
import secrets
import time
from pathlib import Path
from threading import Lock
from typing import Any
from uuid import uuid4

from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now
class ChannelConnectionStore:
    """JSON-file-backed store of ``ChannelConnection`` records.

    The file holds ``{"connections": {connection_id: <connection dict>}}``.
    Every mutating method performs its whole load-modify-save cycle while
    holding ``self._lock``. Reads take no lock; they rely on ``_save`` swapping
    the file in with a single rename.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        # Serializes load-modify-save cycles within this instance only; it does
        # not coordinate with other instances or processes sharing the path.
        self._lock = Lock()

    def create(
        self,
        *,
        kind: str,
        mode: str,
        display_name: str,
        account_id: str,
        owner_user_id: str | None,
        auth_type: str,
        runtime_config: dict[str, Any] | None = None,
        capabilities: list[str] | None = None,
        credentials_ref: str | None = None,
    ) -> ChannelConnection:
        """Persist a new connection in ``"draft"`` status and return it.

        A fresh ``conn_<hex>`` connection id and a ``<slug(kind)>-<8 hex>``
        channel id are generated. An empty ``display_name`` falls back to the
        channel id; ``runtime_config``/``capabilities`` default to empty.
        """
        with self._lock:
            data = self._load()
            connection_id = f"conn_{uuid4().hex}"
            channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}"
            now = iso_now()
            connection = ChannelConnection(
                connection_id=connection_id,
                owner_user_id=owner_user_id,
                channel_id=channel_id,
                kind=kind,
                mode=mode,
                display_name=display_name or channel_id,
                account_id=account_id,
                status="draft",
                auth_type=auth_type,
                credentials_ref=credentials_ref,
                runtime_config=runtime_config or {},
                capabilities=capabilities or [],
                created_at=now,
                updated_at=now,
            )
            data["connections"][connection_id] = connection.to_dict()
            self._save(data)
        return connection

    def get(self, connection_id: str) -> ChannelConnection:
        """Return the stored connection.

        Raises:
            KeyError: if ``connection_id`` is absent or its entry is not a dict.
        """
        return self._lookup(self._load(), connection_id)

    def list(self) -> list[ChannelConnection]:
        """Return every well-formed (dict) connection entry; others are skipped."""
        data = self._load()
        return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)]

    def update(self, connection: ChannelConnection) -> ChannelConnection:
        """Overwrite an existing connection and return it.

        ``connection.updated_at`` is set to the current time on the passed
        object before it is written.

        Raises:
            KeyError: if ``connection.connection_id`` is not already stored.
        """
        with self._lock:
            data = self._load()
            if connection.connection_id not in data["connections"]:
                raise KeyError(connection.connection_id)
            self._write(data, connection)
        return connection

    def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection:
        """Set ``status``/``last_error`` on a stored connection and return it.

        ``last_seen_at`` is refreshed when the new status is ``"connected"`` or
        ``"running"``. The read, modification and write all happen under one
        lock hold, so a concurrent ``update`` cannot be silently overwritten.

        Raises:
            ValueError: if ``status`` is not in ``CONNECTION_STATUSES``.
            KeyError: if the connection is absent or malformed.
        """
        if status not in CONNECTION_STATUSES:
            raise ValueError(f"Unsupported connection status: {status}")
        with self._lock:
            data = self._load()
            connection = self._lookup(data, connection_id)
            connection.status = status
            connection.last_error = last_error
            if status in {"connected", "running"}:
                connection.last_seen_at = iso_now()
            self._write(data, connection)
        return connection

    def revoke(self, connection_id: str) -> ChannelConnection:
        """Mark the connection ``"revoked"`` and clear its last error."""
        return self.update_status(connection_id, status="revoked", last_error=None)

    @staticmethod
    def _lookup(data: dict[str, Any], connection_id: str) -> ChannelConnection:
        """Decode one connection from loaded ``data``; KeyError if absent/malformed."""
        raw = data["connections"].get(connection_id)
        if not isinstance(raw, dict):
            raise KeyError(connection_id)
        return ChannelConnection.from_dict(raw)

    def _write(self, data: dict[str, Any], connection: ChannelConnection) -> None:
        """Stamp ``updated_at``, put ``connection`` into ``data`` and persist.

        Caller must already hold ``self._lock`` (the lock is not reentrant).
        """
        connection.updated_at = iso_now()
        data["connections"][connection.connection_id] = connection.to_dict()
        self._save(data)

    def _load(self) -> dict[str, Any]:
        """Read the store file, returning an empty store if missing or unusable.

        NOTE(review): an unreadable or corrupt file is treated as empty, so the
        next create/update saves only the new entry and drops everything that
        was in the old file. Consider backing up the file or failing loudly instead.
        """
        if not self.path.exists():
            return {"connections": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"connections": {}}
        if not isinstance(data, dict) or not isinstance(data.get("connections"), dict):
            return {"connections": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` to a sibling ``.tmp`` file, then rename it over the store."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
class CredentialStore:
    """JSON-file-backed store of channel credentials, addressed by opaque refs.

    The file holds ``{"credentials": {ref: {"kind", "values", "created_at"}}}``.
    Values are stored as plain text, so ``_save`` creates the file with
    owner-only permissions.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        # Serializes load-modify-save cycles within this instance only.
        self._lock = Lock()

    def put(self, *, kind: str, values: dict[str, Any]) -> str:
        """Store ``values`` under a new ``cred_<hex>`` reference and return it.

        Keys and values are stringified. Entries with a ``None`` value, or with
        a key or value that is blank after stripping, are dropped.
        """
        cleaned = self._clean_values(values)
        ref = f"cred_{uuid4().hex}"
        with self._lock:
            data = self._load()
            data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()}
            self._save(data)
        return ref

    @staticmethod
    def _clean_values(values: dict[str, Any]) -> dict[str, str]:
        """Stringify ``values``, dropping ``None`` and blank keys/values."""
        cleaned: dict[str, str] = {}
        for key, value in values.items():
            if value is None:
                # str(None) is "None", which would otherwise be stored as a secret.
                continue
            key_text, value_text = str(key), str(value)
            if key_text.strip() and value_text.strip():
                cleaned[key_text] = value_text
        return cleaned

    def get(self, ref: str) -> dict[str, str]:
        """Return the stored values for ``ref`` as a str->str dict.

        Returns ``{}`` if the entry has no dict of values.

        Raises:
            KeyError: if ``ref`` is absent or its entry is not a dict.
        """
        data = self._load()
        item = data["credentials"].get(ref)
        if not isinstance(item, dict):
            raise KeyError(ref)
        values = item.get("values")
        if not isinstance(values, dict):
            return {}
        return {str(key): str(value) for key, value in values.items()}

    def redacted(self, ref: str | None) -> dict[str, str]:
        """Return the credential's keys with every value masked as ``"***"``.

        An empty/``None`` ref or an unknown ref yields ``{}``.
        """
        if not ref:
            return {}
        try:
            values = self.get(ref)
        except KeyError:
            return {}
        return {key: "***" for key in values}

    def _load(self) -> dict[str, Any]:
        """Read the store file, returning an empty store if missing or unusable.

        NOTE(review): a corrupt file is treated as empty, so the next ``put``
        overwrites it and loses every previously stored credential.
        """
        if not self.path.exists():
            return {"credentials": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"credentials": {}}
        if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict):
            return {"credentials": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` to an owner-only temp file, then rename it over the store.

        The temp file is created with mode 0o600 (further reduced by the umask).
        The rename carries that mode over, so plaintext secrets are not left
        readable by other local users.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        # os.open only applies the mode when it creates the file, so remove any
        # leftover temp file (which may carry looser permissions) first.
        tmp_path.unlink(missing_ok=True)
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        tmp_path.replace(self.path)
class PairingTokenStore:
    """JSON-file-backed store of one-shot pairing sessions.

    The file holds ``{"sessions": {pairing_session_id: <session dict>}}``.
    Tokens act as bearer secrets, so ``_save`` writes the file owner-only.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        # Serializes load-modify-save cycles within this instance only.
        self._lock = Lock()

    def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession:
        """Persist a new ``"pending"`` session and return it.

        The session expires ``ttl_seconds`` after creation. Both the session id
        and the token take the form ``pair_<hex>``.
        """
        now_ms = _now_ms()
        session = PairingSession(
            pairing_session_id=f"pair_{uuid4().hex}",
            kind=kind,
            scope=scope,
            token=f"pair_{uuid4().hex}",
            status="pending",
            expires_at_ms=now_ms + int(ttl_seconds * 1000),
            created_at_ms=now_ms,
        )
        with self._lock:
            data = self._load()
            data["sessions"][session.pairing_session_id] = session.to_dict()
            self._save(data)
        return session

    def consume(self, token: str, *, expected_kind: str) -> PairingSession | None:
        """Redeem ``token`` once and return the session, now marked ``"consumed"``.

        Returns ``None`` if ``token`` is not a string or no session of
        ``expected_kind`` holds it. Also returns ``None`` if the matching
        session is no longer pending or has expired; nothing is written in
        that case.
        """
        if not isinstance(token, str):
            return None
        token_bytes = token.encode("utf-8")
        with self._lock:
            data = self._load()
            for key, raw in data["sessions"].items():
                if not isinstance(raw, dict):
                    # Malformed entry: skip it rather than crash in from_dict.
                    continue
                session = PairingSession.from_dict(raw)
                stored = session.token
                # Constant-time comparison so response timing does not reveal how
                # much of a guessed token is correct.
                if not isinstance(stored, str) or not secrets.compare_digest(stored.encode("utf-8"), token_bytes):
                    continue
                if session.kind != expected_kind:
                    continue
                if session.status != "pending" or session.expires_at_ms <= _now_ms():
                    return None
                session.status = "consumed"
                # Replacing the value of an existing key is safe mid-iteration;
                # we return right after.
                data["sessions"][key] = session.to_dict()
                self._save(data)
                return session
        return None

    def _load(self) -> dict[str, Any]:
        """Read the store file, returning an empty store if missing or unusable."""
        if not self.path.exists():
            return {"sessions": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"sessions": {}}
        if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict):
            return {"sessions": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` to an owner-only temp file, then rename it over the store.

        The temp file is created with mode 0o600 (further reduced by the umask),
        and the rename carries that mode over to the store file.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        # os.open only applies the mode when it creates the file, so remove any
        # leftover temp file (which may carry looser permissions) first.
        tmp_path.unlink(missing_ok=True)
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        tmp_path.replace(self.path)
def _now_ms() -> int:
return int(time.time() * 1000)
def _slug(value: str) -> str:
text = "".join(char if char.isalnum() else "-" for char in str(value).strip().lower())
return "-".join(part for part in text.split("-") if part) or "channel"