204 lines
7.9 KiB
Python
204 lines
7.9 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import asdict, dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
|
|
def iso_now() -> str:
    """Return the current time in UTC as an ISO-8601 string with an explicit offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
@dataclass(slots=True)
class ConnectorSessionState:
    """Snapshot of one connector session as persisted by the sidecar store.

    The record round-trips through plain dicts: ``to_dict`` produces the
    serialisable form and ``from_dict`` rebuilds an instance from such a dict,
    filling in defaults for anything that is missing.
    """

    # Identity / routing fields (required when constructing directly).
    session_id: str
    kind: str
    connection_id: str
    channel_id: str
    display_name: str
    # Lifecycle status; from_dict falls back to "pending" when it is absent.
    status: str
    # Caller-supplied options, stored as an opaque mapping.
    options: dict[str, Any] = field(default_factory=dict)
    # Optional text values; from_dict keeps them None only when absent or None.
    qr_code: str | None = None
    qr_image: str | None = None
    instructions: list[str] = field(default_factory=list)
    account_id: str | None = None
    error: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    # ISO-8601 timestamps, generated by iso_now() when not supplied.
    created_at: str = field(default_factory=iso_now)
    updated_at: str = field(default_factory=iso_now)

    def to_dict(self) -> dict[str, Any]:
        """Return the session as a plain dict (nested containers are copied)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConnectorSessionState":
        """Build a session from a dict such as one produced by ``to_dict``.

        Missing or falsy text fields become ``""`` (``status`` becomes
        ``"pending"``). Optional text fields are ``None`` only when absent or
        ``None``. Containers default to empty and timestamps default to now.
        Values are coerced with ``str``/``dict``/``list`` rather than validated.
        """

        def required_text(key: str, fallback: str = "") -> str:
            # Falsy values (None, "", 0, ...) collapse to the fallback.
            return str(data.get(key) or fallback)

        def optional_text(key: str) -> str | None:
            # Only None/absent maps to None; other values are stringified.
            value = data.get(key)
            return None if value is None else str(value)

        return cls(
            session_id=required_text("session_id"),
            kind=required_text("kind"),
            connection_id=required_text("connection_id"),
            channel_id=required_text("channel_id"),
            display_name=required_text("display_name"),
            status=required_text("status", "pending"),
            options=dict(data.get("options") or {}),
            qr_code=optional_text("qr_code"),
            qr_image=optional_text("qr_image"),
            instructions=list(map(str, data.get("instructions") or [])),
            account_id=optional_text("account_id"),
            error=optional_text("error"),
            metadata=dict(data.get("metadata") or {}),
            created_at=str(data.get("created_at") or iso_now()),
            updated_at=str(data.get("updated_at") or iso_now()),
        )
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class SendBeginResult:
|
|
should_send: bool
|
|
dedupe_key: str
|
|
status: str
|
|
http_status: int
|
|
retry_after_seconds: int | None = None
|
|
provider_message_id: str | None = None
|
|
|
|
|
|
class SidecarStateStore:
    """JSON-file-backed store for connector sessions and send-dedupe records.

    The whole state is one JSON document shaped like
    ``{"sessions": {session_id: {...}}, "sends": {dedupe_key: {...}}}``.
    Mutating methods run their load-modify-save cycle while holding an
    in-process ``threading.Lock``. Read-only methods load the file without the
    lock. Writes go to a sibling ``<name>.tmp`` file that is then moved over the
    real path with ``Path.replace``.
    """

    def __init__(self, path: Path, *, send_processing_ttl_seconds: int = 60) -> None:
        """Create a store backed by ``path``.

        Args:
            path: Location of the JSON state file; it need not exist yet.
            send_processing_ttl_seconds: Age after which a send still marked
                "processing" is treated as abandoned, so ``begin_send`` lets a
                new attempt claim it.
        """
        self.path = Path(path)
        self.send_processing_ttl_seconds = int(send_processing_ttl_seconds)
        self._lock = Lock()

    def create_session(
        self,
        *,
        kind: str,
        connection_id: str,
        channel_id: str,
        display_name: str,
        options: dict[str, Any],
    ) -> ConnectorSessionState:
        """Create, persist and return a new session with status "pending".

        The id is ``"cs_"`` followed by a random UUID4 hex string. ``options``
        is shallow-copied so later mutation by the caller does not affect the
        returned session.
        """
        session = ConnectorSessionState(
            session_id=f"cs_{uuid4().hex}",
            kind=kind,
            connection_id=connection_id,
            channel_id=channel_id,
            display_name=display_name,
            status="pending",
            options=dict(options),
        )
        with self._lock:
            data = self._load()
            data["sessions"][session.session_id] = session.to_dict()
            self._save(data)
        return session

    def get_session(self, session_id: str) -> ConnectorSessionState:
        """Return the stored session ``session_id``.

        Raises:
            KeyError: if no session dict is stored under ``session_id``.
        """
        data = self._load()
        raw = data["sessions"].get(session_id)
        if not isinstance(raw, dict):
            raise KeyError(session_id)
        return ConnectorSessionState.from_dict(raw)

    def list_sessions(self) -> list[ConnectorSessionState]:
        """Return every stored session, skipping entries that are not dicts."""
        data = self._load()
        return [
            ConnectorSessionState.from_dict(raw)
            for raw in data["sessions"].values()
            if isinstance(raw, dict)
        ]

    def find_session_by_connection_id(self, connection_id: str) -> ConnectorSessionState:
        """Return the most recently updated session for ``connection_id``.

        Recency compares the ``updated_at`` strings. The sort is stable, so on
        a tie the session that appears last in the stored mapping wins.

        Raises:
            KeyError: if no session has this ``connection_id``.
        """
        data = self._load()
        matches: list[ConnectorSessionState] = []
        for raw in data["sessions"].values():
            if not isinstance(raw, dict):
                continue
            session = ConnectorSessionState.from_dict(raw)
            if session.connection_id == connection_id:
                matches.append(session)
        if not matches:
            raise KeyError(connection_id)
        matches.sort(key=lambda item: item.updated_at)
        return matches[-1]

    def update_session(self, session_id: str, **updates: Any) -> ConnectorSessionState:
        """Apply ``updates`` to a stored session, persist it and return it.

        Only keywords naming a dataclass field of ``ConnectorSessionState`` are
        applied; anything else is ignored. ``session_id`` is ignored as well,
        because changing it would desync the record from the key it is stored
        under. ``updated_at`` is always refreshed to the current time.

        Raises:
            KeyError: if no session dict is stored under ``session_id``.
        """
        # Local import: the module header imports only asdict/dataclass/field.
        from dataclasses import fields

        with self._lock:
            data = self._load()
            raw = data["sessions"].get(session_id)
            if not isinstance(raw, dict):
                raise KeyError(session_id)
            session = ConnectorSessionState.from_dict(raw)
            # hasattr() would also match methods such as to_dict/from_dict, and
            # assigning those on a slots dataclass raises AttributeError.
            settable = {f.name for f in fields(session)} - {"session_id"}
            for key, value in updates.items():
                if key in settable:
                    setattr(session, key, value)
            session.updated_at = iso_now()
            data["sessions"][session_id] = session.to_dict()
            self._save(data)
        return session

    def begin_send(self, *, connection_id: str, request_id: str) -> SendBeginResult:
        """Try to claim the send identified by ``connection_id`` + ``request_id``.

        The dedupe key is ``"<connection_id>:<request_id>"``. Outcomes:

        * already completed: ``should_send=False``, status "completed", HTTP
          200, plus the recorded provider message id ("" if none was recorded);
        * "processing" and not stale: ``should_send=False``, status
          "processing", HTTP 409, retry after 5 seconds;
        * anything else (no record, failed, stale, other status): the record
          is (re)marked "processing" and ``should_send=True`` with HTTP 200.
        """
        dedupe_key = f"{connection_id}:{request_id}"
        with self._lock:
            data = self._load()
            existing = data["sends"].get(dedupe_key)
            if isinstance(existing, dict):
                status = str(existing.get("status") or "processing")
                if status == "completed":
                    provider_message_id = str(existing.get("provider_message_id") or "")
                    return SendBeginResult(False, dedupe_key, "completed", 200, None, provider_message_id)
                if status == "processing" and not self._send_is_stale(existing):
                    return SendBeginResult(False, dedupe_key, "processing", 409, 5)
            data["sends"][dedupe_key] = {
                "connection_id": connection_id,
                "request_id": request_id,
                "status": "processing",
                "updated_at": iso_now(),
            }
            self._save(data)
        return SendBeginResult(True, dedupe_key, "processing", 200)

    def complete_send(self, dedupe_key: str, *, provider_message_id: str | None) -> None:
        """Mark ``dedupe_key`` as completed, recording ``provider_message_id``.

        Existing fields of the record are kept. If no record exists, a new
        one is created holding only the completion fields.
        """
        with self._lock:
            data = self._load()
            item = dict(data["sends"].get(dedupe_key) or {})
            item.update({"status": "completed", "provider_message_id": provider_message_id, "updated_at": iso_now()})
            data["sends"][dedupe_key] = item
            self._save(data)

    def fail_send(self, dedupe_key: str, *, error: str | None) -> None:
        """Mark ``dedupe_key`` as failed, storing ``error`` under "last_error".

        A failed record does not block ``begin_send``; the next attempt is
        allowed to claim it again.
        """
        with self._lock:
            data = self._load()
            item = dict(data["sends"].get(dedupe_key) or {})
            item.update({"status": "failed", "last_error": error, "updated_at": iso_now()})
            data["sends"][dedupe_key] = item
            self._save(data)

    def _send_is_stale(self, item: dict[str, Any]) -> bool:
        """Return True when a send record is at least the processing TTL old.

        A record whose ``updated_at`` is missing or unparsable counts as stale.
        Its age is unknown, and calling it fresh would block the dedupe key
        forever (or, for unparsable values, raise out of ``begin_send``).
        """
        raw = item.get("updated_at")
        if not raw:
            return True
        try:
            # fromisoformat only accepts a trailing "Z" from Python 3.11 on.
            updated = datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
        except ValueError:
            return True
        if updated.tzinfo is None:
            # This store writes UTC-aware timestamps; treat a naive value as UTC
            # instead of failing on naive-minus-aware subtraction.
            updated = updated.replace(tzinfo=timezone.utc)
        age = datetime.now(timezone.utc) - updated
        return age.total_seconds() >= self.send_processing_ttl_seconds

    def _load(self) -> dict[str, Any]:
        """Read the state file, always returning dicts under "sessions" and "sends".

        A missing, unreadable, non-UTF-8 or non-JSON file, or a JSON document
        whose top level is not an object, yields an empty state. Missing or
        non-dict sections are replaced with empty dicts.
        """
        # NOTE(review): a corrupt file is silently replaced by the next save.
        if not self.path.exists():
            return {"sessions": {}, "sends": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {"sessions": {}, "sends": {}}
        if not isinstance(data, dict):
            return {"sessions": {}, "sends": {}}
        if not isinstance(data.get("sessions"), dict):
            data["sessions"] = {}
        if not isinstance(data.get("sends"), dict):
            data["sends"] = {}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` as indented UTF-8 JSON via a temp file plus replace."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
|