112 lines
3.8 KiB
Python
112 lines
3.8 KiB
Python
"""Cross-process workspace write lock with in-process reentrancy."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Iterator
|
|
|
|
if os.name == "nt": # pragma: no cover - exercised on Windows only
|
|
import msvcrt
|
|
else: # pragma: no cover - import branch is platform-specific
|
|
import fcntl
|
|
|
|
|
|
class WorkspaceWriteLockBusy(RuntimeError):
    """Signal that the shared workspace write lock is currently unavailable.

    Raised by ``WorkspaceWriteLock.acquire`` when a non-blocking attempt finds
    the lock taken, or when a timed attempt runs past its deadline.  The
    message is the machine-readable code ``"plugin_write_busy"``.
    """
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _HeldLock:
|
|
rlock: threading.RLock
|
|
handle: object | None = None
|
|
owner_thread: int | None = None
|
|
depth: int = 0
|
|
|
|
|
|
# Process-wide registry: resolved lock-file path -> its _HeldLock record.
# Entries are only ever added, so every WorkspaceWriteLock pointing at the
# same file shares one record (and therefore one RLock).
_HELD_BY_PATH: dict[Path, _HeldLock] = {}

# Serialises look-up-or-insert on _HELD_BY_PATH so that concurrent callers
# for the same path never end up with two different records.
_REGISTRY_GUARD = threading.Lock()
|
|
|
|
|
|
class WorkspaceWriteLock:
    """Exclusive, cross-process write lock scoped to one workspace directory.

    The lock file lives at ``<workspace>/.beaver/locks/plugin-skill-write.lock``
    and is locked with ``fcntl.flock`` on POSIX or ``msvcrt.locking`` on
    Windows.  Threads of the same process are additionally serialised through
    a per-path ``threading.RLock``.  A thread that already holds the lock may
    re-enter :meth:`acquire`; nested entries only bump a depth counter instead
    of locking the file a second time.
    """

    # Delay between attempts while waiting on a lock held by someone else.
    _POLL_INTERVAL_SECONDS = 0.05

    def __init__(self, workspace: str | Path) -> None:
        """Bind the lock to *workspace*; nothing is created until acquire()."""
        self.workspace = Path(workspace)
        self.path = self.workspace / ".beaver" / "locks" / "plugin-skill-write.lock"

    @contextmanager
    def acquire(
        self,
        *,
        timeout_seconds: float | None = None,
        blocking: bool = True,
    ) -> Iterator[None]:
        """Hold the workspace write lock for the body of a ``with`` statement.

        Args:
            timeout_seconds: Upper bound, in seconds, on the total time spent
                waiting for both the in-process and the OS-level lock.
                ``None`` waits indefinitely.  Only relevant when *blocking*
                is true.
            blocking: When false, fail immediately instead of waiting if the
                lock is held by another thread or another process.

        Raises:
            WorkspaceWriteLockBusy: The lock could not be obtained without
                blocking (``blocking=False``) or within *timeout_seconds*.
        """
        deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds
        held = self._held_lock()
        thread_id = threading.get_ident()

        # Wait on the in-process lock under the same blocking/timeout contract
        # as the OS lock; a plain ``with held.rlock`` would wait unboundedly
        # even for ``blocking=False``.
        if not self._acquire_thread_lock(held.rlock, deadline=deadline, blocking=blocking):
            raise WorkspaceWriteLockBusy("plugin_write_busy")
        try:
            if held.owner_thread == thread_id and held.depth > 0:
                # Re-entry by the owning thread: the file is already locked.
                held.depth += 1
                try:
                    yield
                finally:
                    held.depth -= 1
                return

            self.path.parent.mkdir(parents=True, exist_ok=True)
            handle = self.path.open("a+b")
            try:
                self._acquire_os_lock(
                    handle,
                    timeout_seconds=self._remaining(deadline),
                    blocking=blocking,
                )
                held.handle = handle
                held.owner_thread = thread_id
                held.depth = 1
                try:
                    yield
                finally:
                    held.depth = 0
                    held.owner_thread = None
                    held.handle = None
                    self._release_os_lock(handle)
            finally:
                # Closing also drops any OS lock still attached to the handle.
                handle.close()
        finally:
            held.rlock.release()

    def _held_lock(self) -> _HeldLock:
        """Return the shared per-path record, creating it on first use."""
        resolved = self.path.resolve()
        with _REGISTRY_GUARD:
            held = _HELD_BY_PATH.get(resolved)
            if held is None:
                held = _HeldLock(rlock=threading.RLock())
                _HELD_BY_PATH[resolved] = held
            return held

    @staticmethod
    def _remaining(deadline: float | None) -> float | None:
        """Seconds left until *deadline* (never negative), or None if unbounded."""
        if deadline is None:
            return None
        return max(0.0, deadline - time.monotonic())

    @staticmethod
    def _acquire_thread_lock(
        rlock: threading.RLock, *, deadline: float | None, blocking: bool
    ) -> bool:
        """Acquire *rlock* honouring *blocking* and *deadline*; return success."""
        if not blocking:
            return rlock.acquire(blocking=False)
        remaining = WorkspaceWriteLock._remaining(deadline)
        if remaining is None:
            return rlock.acquire()
        return rlock.acquire(timeout=remaining)

    @staticmethod
    def _acquire_os_lock(handle: object, *, timeout_seconds: float | None, blocking: bool) -> None:
        """Take an exclusive OS-level lock on the first byte of *handle*.

        Raises:
            WorkspaceWriteLockBusy: ``blocking`` is false and the lock is held,
                or *timeout_seconds* elapsed before it became free.
        """
        deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds
        # Only an unbounded blocking wait may use the OS's own blocking call;
        # with a deadline we must poll, otherwise the deadline is never checked.
        poll = not blocking or deadline is not None
        while True:
            try:
                if os.name == "nt":  # pragma: no cover
                    # msvcrt.locking locks from the current file position; an
                    # "a+b" handle may sit at EOF, so pin it to byte 0 to match
                    # the region unlocked by _release_os_lock.
                    handle.seek(0)  # type: ignore[attr-defined]
                    mode = msvcrt.LK_NBLCK if poll else msvcrt.LK_LOCK
                    msvcrt.locking(handle.fileno(), mode, 1)  # type: ignore[attr-defined]
                else:
                    flags = fcntl.LOCK_EX
                    if poll:
                        flags |= fcntl.LOCK_NB
                    fcntl.flock(handle.fileno(), flags)  # type: ignore[attr-defined]
                return
            except OSError as exc:  # BlockingIOError is a subclass of OSError
                if not blocking:
                    raise WorkspaceWriteLockBusy("plugin_write_busy") from exc
                delay = WorkspaceWriteLock._POLL_INTERVAL_SECONDS
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise WorkspaceWriteLockBusy("plugin_write_busy") from exc
                    # Do not sleep past the deadline.
                    delay = min(delay, remaining)
                time.sleep(delay)

    @staticmethod
    def _release_os_lock(handle: object) -> None:
        """Release the OS-level lock taken by _acquire_os_lock on *handle*."""
        if os.name == "nt":  # pragma: no cover
            # Unlock the same single byte at offset 0 that was locked.
            handle.seek(0)  # type: ignore[attr-defined]
            msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)  # type: ignore[attr-defined]
        else:
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)  # type: ignore[attr-defined]