Files

204 lines
7.9 KiB
Python

from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Any
from uuid import uuid4
def iso_now() -> str:
    """Return the current UTC time as an ISO 8601 string.

    The value is timezone-aware, so the string carries an explicit
    ``+00:00`` offset.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
@dataclass(slots=True)
class ConnectorSessionState:
    """Serializable snapshot of one connector session.

    Instances round-trip through plain dictionaries via ``to_dict`` and
    ``from_dict``.
    """

    session_id: str
    kind: str
    connection_id: str
    channel_id: str
    display_name: str
    status: str
    options: dict[str, Any] = field(default_factory=dict)
    qr_code: str | None = None
    qr_image: str | None = None
    instructions: list[str] = field(default_factory=list)
    account_id: str | None = None
    error: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    # Both timestamps default to the moment the instance is constructed.
    created_at: str = field(default_factory=iso_now)
    updated_at: str = field(default_factory=iso_now)

    def to_dict(self) -> dict[str, Any]:
        """Return the session as a plain dictionary keyed by field name."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConnectorSessionState":
        """Build a session from a dictionary, coercing and defaulting fields.

        Missing or falsy required strings become ``""`` (``status`` becomes
        ``"pending"``); optional strings stay ``None`` only when the stored
        value is ``None`` or absent; missing timestamps become the current
        time.
        """

        def required_text(key: str, fallback: str = "") -> str:
            # Falsy values (None, "", 0) collapse to the fallback.
            return str(data.get(key) or fallback)

        def optional_text(key: str) -> str | None:
            # Only None/absent maps to None; an empty string stays "".
            value = data.get(key)
            return None if value is None else str(value)

        raw_instructions = data.get("instructions") or []
        return cls(
            session_id=required_text("session_id"),
            kind=required_text("kind"),
            connection_id=required_text("connection_id"),
            channel_id=required_text("channel_id"),
            display_name=required_text("display_name"),
            status=required_text("status", "pending"),
            options=dict(data.get("options") or {}),
            qr_code=optional_text("qr_code"),
            qr_image=optional_text("qr_image"),
            instructions=[str(entry) for entry in raw_instructions],
            account_id=optional_text("account_id"),
            error=optional_text("error"),
            metadata=dict(data.get("metadata") or {}),
            created_at=str(data.get("created_at") or iso_now()),
            updated_at=str(data.get("updated_at") or iso_now()),
        )
@dataclass(slots=True)
class SendBeginResult:
should_send: bool
dedupe_key: str
status: str
http_status: int
retry_after_seconds: int | None = None
provider_message_id: str | None = None
class SidecarStateStore:
    """JSON-file-backed store for connector sessions and send de-duplication.

    The file holds one object with two maps: ``"sessions"`` (keyed by session
    id) and ``"sends"`` (keyed by ``"<connection_id>:<request_id>"``).

    Every read-modify-write cycle runs under an in-process ``threading.Lock``.
    Pure reads (``get_session``, ``list_sessions``,
    ``find_session_by_connection_id``) do not take the lock.
    """

    def __init__(self, path: Path, *, send_processing_ttl_seconds: int = 60) -> None:
        """Create a store bound to *path*.

        Args:
            path: Location of the JSON state file; it does not need to exist.
            send_processing_ttl_seconds: Age after which a send record still
                marked ``"processing"`` may be claimed again.
        """
        self.path = Path(path)
        self.send_processing_ttl_seconds = int(send_processing_ttl_seconds)
        self._lock = Lock()

    def create_session(
        self,
        *,
        kind: str,
        connection_id: str,
        channel_id: str,
        display_name: str,
        options: dict[str, Any],
    ) -> ConnectorSessionState:
        """Persist and return a new session with status ``"pending"``.

        The session id is ``"cs_"`` followed by a random UUID4 hex string.
        *options* is shallow-copied before being stored.
        """
        session = ConnectorSessionState(
            session_id=f"cs_{uuid4().hex}",
            kind=kind,
            connection_id=connection_id,
            channel_id=channel_id,
            display_name=display_name,
            status="pending",
            options=dict(options),
        )
        with self._lock:
            data = self._load()
            data["sessions"][session.session_id] = session.to_dict()
            self._save(data)
        return session

    def get_session(self, session_id: str) -> ConnectorSessionState:
        """Return the stored session with *session_id*.

        Raises:
            KeyError: If no dictionary record exists under that id.
        """
        raw = self._load()["sessions"].get(session_id)
        if not isinstance(raw, dict):
            raise KeyError(session_id)
        return ConnectorSessionState.from_dict(raw)

    def list_sessions(self) -> list[ConnectorSessionState]:
        """Return every stored session, skipping non-dictionary records."""
        return [
            ConnectorSessionState.from_dict(raw)
            for raw in self._load()["sessions"].values()
            if isinstance(raw, dict)
        ]

    def find_session_by_connection_id(self, connection_id: str) -> ConnectorSessionState:
        """Return the most recently updated session for *connection_id*.

        Raises:
            KeyError: If no session has that connection id.
        """
        matches: list[ConnectorSessionState] = []
        for raw in self._load()["sessions"].values():
            if not isinstance(raw, dict):
                continue
            session = ConnectorSessionState.from_dict(raw)
            if session.connection_id == connection_id:
                matches.append(session)
        if not matches:
            raise KeyError(connection_id)
        # Timestamps are compared as strings; the stable sort keeps the
        # later-stored record last when two timestamps are equal.
        matches.sort(key=lambda item: item.updated_at)
        return matches[-1]

    def update_session(self, session_id: str, **updates: Any) -> ConnectorSessionState:
        """Apply field updates to a stored session and persist the result.

        Only keys naming a ``ConnectorSessionState`` field are applied; any
        other key is silently ignored. ``updated_at`` is always overwritten
        with the current time.

        Raises:
            KeyError: If no dictionary record exists under *session_id*.
        """
        # Match against the declared dataclass fields rather than hasattr():
        # hasattr() is also true for methods and dunders, and assigning to
        # those on a slots class raises AttributeError.
        field_names = {spec.name for spec in fields(ConnectorSessionState)}
        with self._lock:
            data = self._load()
            raw = data["sessions"].get(session_id)
            if not isinstance(raw, dict):
                raise KeyError(session_id)
            session = ConnectorSessionState.from_dict(raw)
            for key, value in updates.items():
                if key in field_names:
                    setattr(session, key, value)
            session.updated_at = iso_now()
            data["sessions"][session_id] = session.to_dict()
            self._save(data)
        return session

    def begin_send(self, *, connection_id: str, request_id: str) -> SendBeginResult:
        """Try to claim the send identified by *connection_id* and *request_id*.

        Returns:
            * ``should_send=False``, status ``"completed"``, HTTP 200 and the
              stored provider message id (``""`` when none is stored) if the
              send already completed.
            * ``should_send=False``, status ``"processing"``, HTTP 409 and a
              5-second retry hint if another claim is still fresh.
            * ``should_send=True``, status ``"processing"``, HTTP 200
              otherwise; a new ``"processing"`` record replaces any previous
              one (failed, stale, or of unknown status).
        """
        dedupe_key = f"{connection_id}:{request_id}"
        with self._lock:
            data = self._load()
            existing = data["sends"].get(dedupe_key)
            if isinstance(existing, dict):
                status = str(existing.get("status") or "processing")
                if status == "completed":
                    provider_message_id = str(existing.get("provider_message_id") or "")
                    return SendBeginResult(False, dedupe_key, "completed", 200, None, provider_message_id)
                if status == "processing" and not self._send_is_stale(existing):
                    return SendBeginResult(False, dedupe_key, "processing", 409, 5)
            data["sends"][dedupe_key] = {
                "connection_id": connection_id,
                "request_id": request_id,
                "status": "processing",
                "updated_at": iso_now(),
            }
            self._save(data)
        return SendBeginResult(True, dedupe_key, "processing", 200)

    def complete_send(self, dedupe_key: str, *, provider_message_id: str | None) -> None:
        """Mark the send under *dedupe_key* as completed.

        A record is created if none exists for the key.
        """
        with self._lock:
            data = self._load()
            item = dict(data["sends"].get(dedupe_key) or {})
            item.update(
                {
                    "status": "completed",
                    "provider_message_id": provider_message_id,
                    "updated_at": iso_now(),
                }
            )
            data["sends"][dedupe_key] = item
            self._save(data)

    def fail_send(self, dedupe_key: str, *, error: str | None) -> None:
        """Mark the send under *dedupe_key* as failed and record *error*.

        A record is created if none exists for the key.
        """
        with self._lock:
            data = self._load()
            item = dict(data["sends"].get(dedupe_key) or {})
            item.update({"status": "failed", "last_error": error, "updated_at": iso_now()})
            data["sends"][dedupe_key] = item
            self._save(data)

    def _send_is_stale(self, item: dict[str, Any]) -> bool:
        """Return True when a send record's claim is old enough to retake.

        A record whose ``updated_at`` is missing or cannot be parsed is
        treated as stale: without a usable timestamp the claim could never
        expire and the key would stay blocked. A timestamp without an offset
        is interpreted as UTC.
        """
        raw_stamp = item.get("updated_at")
        if not raw_stamp:
            return True
        try:
            updated = datetime.fromisoformat(str(raw_stamp).replace("Z", "+00:00"))
        except ValueError:
            return True
        if updated.tzinfo is None:
            # Aware and naive datetimes cannot be subtracted from each other.
            updated = updated.replace(tzinfo=timezone.utc)
        age_seconds = (datetime.now(timezone.utc) - updated).total_seconds()
        return age_seconds >= self.send_processing_ttl_seconds

    def _load(self) -> dict[str, Any]:
        """Read the state file and return it with both top-level maps present.

        Best effort: a missing, unreadable, or malformed file yields an empty
        state instead of raising, and a non-dictionary ``"sessions"`` or
        ``"sends"`` entry is replaced by an empty dictionary.
        """
        if not self.path.exists():
            return {"sessions": {}, "sends": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"sessions": {}, "sends": {}}
        if not isinstance(data, dict):
            return {"sessions": {}, "sends": {}}
        for section in ("sessions", "sends"):
            if not isinstance(data.get(section), dict):
                data[section] = {}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write *data* to the state file, creating parent directories.

        The JSON is written to a sibling ``<name>.tmp`` file first and then
        renamed over the target, so the target is never left half-written.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        payload = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(self.path)