Files
beaver_project/app-instance/backend/beaver/interfaces/channels/state.py

199 lines
6.3 KiB
Python

"""Persistent channel runtime state."""
from __future__ import annotations
import json
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Any
from uuid import uuid4
def _now_ms() -> int:
return int(time.time() * 1000)
def _iso_now() -> str:
return datetime.now(timezone.utc).isoformat()
@dataclass(slots=True)
class DedupeWriteResult:
created: bool
record: dict[str, Any] | None = None
class ChannelDedupeStore:
    """JSON-file-backed store of per-message dedupe records.

    The file holds ``{"records": {dedupe_key: record}}``. A record moves from
    ``"processing"`` (``mark_processing``) to ``"done"`` or ``"error"``
    (``mark_done`` / ``mark_error``). Records whose last update is older than
    the retention window are pruned on ``get`` and ``mark_processing``.

    File access within one instance is serialized by ``self._lock``; that lock
    is per-instance and does not coordinate other instances or processes that
    use the same path.
    """

    def __init__(self, path: Path, *, retention_hours: int = 48) -> None:
        """Create a store persisted at *path*.

        Args:
            path: Location of the JSON store file (created on first write).
            retention_hours: How long a record is kept after its last update;
                values below 1 are clamped to 1 hour.
        """
        self.path = path
        self.retention_ms = max(1, int(retention_hours)) * 60 * 60 * 1000
        self._lock = Lock()

    def get(self, dedupe_key: str) -> dict[str, Any] | None:
        """Return the unexpired record for *dedupe_key*, or ``None``."""
        with self._lock:
            data = self._load()
            # Only rewrite the file when pruning actually removed something;
            # a plain lookup should not cost a full-file write.
            if self._prune_unlocked(data, _now_ms()):
                self._save(data)
            return data["records"].get(dedupe_key)

    def mark_processing(self, *, dedupe_key: str, session_id: str, message_id: str) -> DedupeWriteResult:
        """Claim *dedupe_key* with a ``"processing"`` record if none is live.

        Returns:
            ``DedupeWriteResult(created=True, record=<new record>)`` when a new
            record was written; otherwise ``created=False`` with the existing
            record (whatever its status).
        """
        with self._lock:
            data = self._load()
            now_ms = _now_ms()
            pruned = self._prune_unlocked(data, now_ms)
            existing = data["records"].get(dedupe_key)
            if existing is not None:
                if pruned:
                    self._save(data)
                return DedupeWriteResult(created=False, record=existing)
            record = {
                "dedupe_key": dedupe_key,
                "status": "processing",
                "session_id": session_id,
                "message_id": message_id,
                "run_id": None,
                "reply": None,
                "error": None,
                "created_at_ms": now_ms,
                "updated_at_ms": now_ms,
            }
            data["records"][dedupe_key] = record
            self._save(data)
            return DedupeWriteResult(created=True, record=record)

    def mark_done(
        self,
        *,
        dedupe_key: str,
        run_id: str | None,
        reply: str,
        max_reply_chars: int,
    ) -> None:
        """Mark *dedupe_key* as ``"done"``, storing *reply* truncated to *max_reply_chars*.

        A negative *max_reply_chars* is treated as 0 (empty stored reply).
        """
        self._mark_result(
            dedupe_key=dedupe_key,
            status="done",
            run_id=run_id,
            reply=reply[: max(0, int(max_reply_chars))],
            error=None,
        )

    def mark_error(self, *, dedupe_key: str, error: str, max_error_chars: int) -> None:
        """Mark *dedupe_key* as ``"error"``, storing *error* truncated to *max_error_chars*.

        A negative *max_error_chars* is treated as 0 (empty stored error).
        """
        self._mark_result(
            dedupe_key=dedupe_key,
            status="error",
            run_id=None,
            reply=None,
            error=error[: max(0, int(max_error_chars))],
        )

    def _mark_result(
        self,
        *,
        dedupe_key: str,
        status: str,
        run_id: str | None,
        reply: str | None,
        error: str | None,
    ) -> None:
        """Write a terminal status onto the record for *dedupe_key*, creating it if absent."""
        with self._lock:
            data = self._load()
            # One timestamp so a freshly created record has created == updated.
            now_ms = _now_ms()
            record = data["records"].get(dedupe_key)
            if not isinstance(record, dict):
                # Missing (e.g. already pruned) or malformed entry: start a fresh
                # record carrying the same identity keys mark_processing writes.
                record = {
                    "dedupe_key": dedupe_key,
                    "session_id": None,
                    "message_id": None,
                    "created_at_ms": now_ms,
                }
                data["records"][dedupe_key] = record
            record.update(
                {
                    "status": status,
                    "run_id": run_id,
                    "reply": reply,
                    "error": error,
                    "updated_at_ms": now_ms,
                }
            )
            self._save(data)

    def _load(self) -> dict[str, Any]:
        """Read the store file; return an empty store if it is missing, unreadable or malformed."""
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # OSError covers a missing file (FileNotFoundError); ValueError covers
            # both json.JSONDecodeError and UnicodeDecodeError from corrupt bytes.
            return {"records": {}}
        if not isinstance(data, dict) or not isinstance(data.get("records"), dict):
            return {"records": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Persist *data* by writing a sibling ``.tmp`` file and renaming it over the store."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        # Path.replace renames over the destination, so readers never see a
        # half-written store file.
        tmp_path.replace(self.path)

    def _prune_unlocked(self, data: dict[str, Any], now_ms: int) -> int:
        """Remove expired or malformed records from *data* in place.

        The caller must already hold ``self._lock``. Returns the number of
        records removed so callers can skip saving when nothing changed.
        """
        records = data.get("records", {})
        expired_before = now_ms - self.retention_ms
        removed = 0
        for key, record in list(records.items()):
            if self._record_timestamp_ms(record) < expired_before:
                del records[key]
                removed += 1
        return removed

    @staticmethod
    def _record_timestamp_ms(record: Any) -> int:
        """Return a record's last-activity time in ms, or 0 if it cannot be determined."""
        if not isinstance(record, dict):
            return 0
        try:
            return int(record.get("updated_at_ms") or record.get("created_at_ms") or 0)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric / non-finite timestamps: treat as infinitely old.
            return 0
class ChannelEventLog:
def __init__(self, path: Path) -> None:
self.path = path
self._lock = Lock()
def record(
self,
*,
channel_id: str,
kind: str,
session_id: str | None = None,
message_id: str | None = None,
run_id: str | None = None,
status: str = "ok",
error: str | None = None,
text: str | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
entry = {
"event_id": uuid4().hex,
"channel_id": channel_id,
"kind": kind,
"session_id": session_id,
"message_id": message_id,
"run_id": run_id,
"status": status,
"error": error,
"text_preview": (text or "")[:120] if text else None,
"text_length": len(text) if text else 0,
"metadata": metadata or {},
"created_at": _iso_now(),
}
with self._lock:
self.path.parent.mkdir(parents=True, exist_ok=True)
with self.path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(entry, ensure_ascii=False) + "\n")
return entry
def recent(self, *, channel_id: str | None = None, limit: int = 100) -> list[dict[str, Any]]:
if not self.path.exists():
return []
lines = self.path.read_text(encoding="utf-8").splitlines()
items: list[dict[str, Any]] = []
for line in reversed(lines):
try:
item = json.loads(line)
except json.JSONDecodeError:
continue
if channel_id and item.get("channel_id") != channel_id:
continue
items.append(item)
if len(items) >= max(1, int(limit)):
break
return list(reversed(items))