chore: initialize EverOS 1.0.0
An md-first memory extraction framework for AI agents. Markdown is the single source of truth; SQLite holds state, and LanceDB provides the rebuildable vector + BM25 + scalar index. The codebase follows a single-direction DDD layering (entrypoints -> service -> memory -> infra, with component / core / config cross-cutting) enforced by import-linter.

Engineering surface:
- Coding conventions live in .claude/rules/ (path-scoped) and workflows in .claude/skills/ (/commit, /new-branch, /pr).
- GitHub Actions CI runs make lint + test + integration; pre-commit mirrors the gates locally (ruff, hygiene hooks, gitlint commit-msg).
- Commit messages follow Conventional Commits, enforced by gitlint.
- make lint also enforces datetime two-zone discipline and OpenAPI drift.
This commit is contained in:
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
0
tests/unit/test_component/__init__.py
Normal file
0
tests/unit/test_component/__init__.py
Normal file
0
tests/unit/test_component/test_config/__init__.py
Normal file
0
tests/unit/test_component/test_config/__init__.py
Normal file
167
tests/unit/test_component/test_config/test_loader.py
Normal file
167
tests/unit/test_component/test_config/test_loader.py
Normal file
@ -0,0 +1,167 @@
|
||||
"""Unit tests for YamlConfigLoader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.config import YamlConfigLoader
|
||||
|
||||
|
||||
@pytest.fixture
def config_root(tmp_path: Path) -> Path:
    """Populate ``tmp_path`` with a small YAML config tree and return it.

    Layout::

        tmp_path/
          prompt_slots/
            episode.yaml
            atomic_fact.yaml
          custom_dir/
            alpha.yaml
    """
    slots_dir = tmp_path / "prompt_slots"
    slots_dir.mkdir()
    custom_dir = tmp_path / "custom_dir"
    custom_dir.mkdir()

    # Map each fixture file to the exact YAML text it should contain.
    yaml_by_path = {
        slots_dir / "episode.yaml": (
            "template: extract episode\nvariables:\n memcell: input memcell\n"
        ),
        slots_dir / "atomic_fact.yaml": "template: extract atomic fact\n",
        custom_dir / "alpha.yaml": "value: alpha\n",
    }
    for target, yaml_text in yaml_by_path.items():
        target.write_text(yaml_text, encoding="utf-8")
    return tmp_path
|
||||
|
||||
|
||||
def test_register_default_subdir(config_root: Path) -> None:
    """A category registered without ``subdir`` loads from the same-named directory."""
    expected = {
        "template": "extract episode",
        "variables": {"memcell": "input memcell"},
    }
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("prompt_slots")
    assert cfg_loader.find("prompt_slots", "episode") == expected
|
||||
|
||||
|
||||
def test_register_custom_subdir(config_root: Path) -> None:
    """A category may point at a directory whose name differs from the category."""
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("alphas", subdir="custom_dir")
    loaded = cfg_loader.find("alphas", "alpha")
    assert loaded == {"value": "alpha"}
|
||||
|
||||
|
||||
def test_constructor_categories_dict(config_root: Path) -> None:
    """Categories passed to the constructor are registered and usable immediately."""
    category_map = {"prompt_slots": None, "alphas": "custom_dir"}
    cfg_loader = YamlConfigLoader(root=config_root, categories=category_map)

    registered = sorted(cfg_loader.categories())
    assert registered == ["alphas", "prompt_slots"]
    assert cfg_loader.find("alphas", "alpha") == {"value": "alpha"}
|
||||
|
||||
|
||||
def test_find_unregistered_category_raises(config_root: Path) -> None:
    """``find`` on a category that was never registered raises ``KeyError``."""
    cfg_loader = YamlConfigLoader(root=config_root)
    with pytest.raises(KeyError, match="not registered"):
        cfg_loader.find("ghost", "x")
|
||||
|
||||
|
||||
def test_find_missing_file_raises(config_root: Path) -> None:
    """``find`` raises ``FileNotFoundError`` when the named YAML file is absent."""
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("prompt_slots")
    with pytest.raises(FileNotFoundError):
        cfg_loader.find("prompt_slots", "no_such")
|
||||
|
||||
|
||||
def test_find_non_mapping_top_level_raises(tmp_path: Path) -> None:
    """A YAML document whose top level is a list is rejected with ``TypeError``."""
    slots_dir = tmp_path / "prompt_slots"
    slots_dir.mkdir()
    # The document below parses to a list rather than a mapping.
    bad_file = slots_dir / "bad.yaml"
    bad_file.write_text("- one\n- two\n", encoding="utf-8")

    cfg_loader = YamlConfigLoader(root=tmp_path)
    cfg_loader.register_category("prompt_slots")
    with pytest.raises(TypeError, match="must be a mapping"):
        cfg_loader.find("prompt_slots", "bad")
|
||||
|
||||
|
||||
def test_find_empty_file_yields_empty_dict(tmp_path: Path) -> None:
    """A zero-byte YAML file is returned as an empty dict."""
    slots_dir = tmp_path / "prompt_slots"
    slots_dir.mkdir()
    blank_file = slots_dir / "blank.yaml"
    blank_file.write_text("", encoding="utf-8")

    cfg_loader = YamlConfigLoader(root=tmp_path)
    cfg_loader.register_category("prompt_slots")
    loaded = cfg_loader.find("prompt_slots", "blank")
    assert loaded == {}
|
||||
|
||||
|
||||
def test_list_returns_sorted_stems(config_root: Path) -> None:
    """``list`` returns the file stems of a category in sorted order."""
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("prompt_slots")
    stems = cfg_loader.list("prompt_slots")
    assert stems == ["atomic_fact", "episode"]
|
||||
|
||||
|
||||
def test_list_unregistered_category_raises(config_root: Path) -> None:
    """``list`` on a category that was never registered raises ``KeyError``."""
    cfg_loader = YamlConfigLoader(root=config_root)
    with pytest.raises(KeyError):
        cfg_loader.list("ghost")
|
||||
|
||||
|
||||
def test_list_empty_directory(tmp_path: Path) -> None:
    """Listing a registered category whose directory does not exist yields ``[]``."""
    cfg_loader = YamlConfigLoader(root=tmp_path)
    # "nope" is registered but never created on disk.
    cfg_loader.register_category("nope")
    stems = cfg_loader.list("nope")
    assert stems == []
|
||||
|
||||
|
||||
def test_cache_returns_same_object(config_root: Path) -> None:
    """Two consecutive lookups of the same entry return the identical dict object."""
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("prompt_slots")
    first_lookup = cfg_loader.find("prompt_slots", "episode")
    second_lookup = cfg_loader.find("prompt_slots", "episode")
    # Identity, not just equality: the second call must be served from cache.
    assert first_lookup is second_lookup
|
||||
|
||||
|
||||
def test_refresh_invalidates_cache_and_reloads(config_root: Path) -> None:
    """An argument-less ``refresh`` drops the cache so edits on disk become visible."""
    cfg_loader = YamlConfigLoader(root=config_root)
    cfg_loader.register_category("prompt_slots")
    original = cfg_loader.find("prompt_slots", "episode")

    # Rewrite the file; until refresh() is called the cached object is served.
    episode_file = config_root / "prompt_slots" / "episode.yaml"
    episode_file.write_text("template: MODIFIED\n", encoding="utf-8")
    stale = cfg_loader.find("prompt_slots", "episode")
    assert stale is original

    cfg_loader.refresh()
    reloaded = cfg_loader.find("prompt_slots", "episode")
    assert reloaded is not original
    assert reloaded == {"template": "MODIFIED"}
|
||||
|
||||
|
||||
def test_refresh_specific_entry(config_root: Path) -> None:
    """``refresh(category, name)`` reloads only the named entry.

    The refreshed entry must reflect the new file content, while a sibling
    entry in the same category keeps being served from cache.
    """
    loader = YamlConfigLoader(root=config_root)
    loader.register_category("prompt_slots")
    episode_before = loader.find("prompt_slots", "episode")
    atomic_before = loader.find("prompt_slots", "atomic_fact")

    (config_root / "prompt_slots" / "episode.yaml").write_text(
        "template: NEW\n", encoding="utf-8"
    )
    loader.refresh("prompt_slots", "episode")

    episode_after = loader.find("prompt_slots", "episode")
    assert episode_after != episode_before  # reloaded
    # Pin the reloaded content: "!=" alone would accept any changed value.
    assert episode_after == {"template": "NEW"}
    assert loader.find("prompt_slots", "atomic_fact") is atomic_before  # untouched
|
||||
|
||||
|
||||
def test_refresh_full_category(config_root: Path) -> None:
    """``refresh(category)`` invalidates that category and leaves others cached."""
    loader = YamlConfigLoader(
        root=config_root,
        categories={"prompt_slots": None, "alphas": "custom_dir"},
    )
    episode_before = loader.find("prompt_slots", "episode")
    alpha_before = loader.find("alphas", "alpha")

    loader.refresh("prompt_slots")

    # The refreshed category is re-read, so a new object is returned. The old
    # dict is still referenced here, so the two cannot share an identity.
    assert loader.find("prompt_slots", "episode") is not episode_before
    # alphas cache survives the prompt_slots refresh
    assert loader.find("alphas", "alpha") is alpha_before
|
||||
46
tests/unit/test_component/test_embedding/test_factory.py
Normal file
46
tests/unit/test_component/test_embedding/test_factory.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""``build_embedding_provider`` — settings validation + provider build."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from everos.component.embedding import (
|
||||
OpenAIEmbeddingProvider,
|
||||
build_embedding_provider,
|
||||
)
|
||||
from everos.config.settings import EmbeddingSettings
|
||||
|
||||
|
||||
def test_raises_when_model_missing() -> None:
    """A missing model is reported via the ``EVEROS_EMBEDDING__MODEL`` variable name."""
    settings = EmbeddingSettings(
        model=None, api_key=SecretStr("k"), base_url="https://x"
    )
    with pytest.raises(ValueError, match="EVEROS_EMBEDDING__MODEL"):
        build_embedding_provider(settings)
|
||||
|
||||
|
||||
def test_raises_when_api_key_missing() -> None:
    """A missing API key is reported via the ``EVEROS_EMBEDDING__API_KEY`` name."""
    settings = EmbeddingSettings(model="m", api_key=None, base_url="https://x")
    with pytest.raises(ValueError, match="EVEROS_EMBEDDING__API_KEY"):
        build_embedding_provider(settings)
|
||||
|
||||
|
||||
def test_raises_when_base_url_missing() -> None:
    """A missing base URL is reported via the ``EVEROS_EMBEDDING__BASE_URL`` name."""
    settings = EmbeddingSettings(model="m", api_key=SecretStr("k"), base_url=None)
    with pytest.raises(ValueError, match="EVEROS_EMBEDDING__BASE_URL"):
        build_embedding_provider(settings)
|
||||
|
||||
|
||||
def test_builds_openai_embedding_provider_with_default_dim() -> None:
    """Fully populated settings produce an ``OpenAIEmbeddingProvider``."""
    settings = EmbeddingSettings(
        model="m", api_key=SecretStr("k"), base_url="https://x"
    )
    provider = build_embedding_provider(settings)
    assert isinstance(provider, OpenAIEmbeddingProvider)
|
||||
|
||||
|
||||
def test_custom_dim_passes_through() -> None:
    """An explicit ``dim`` still builds an OpenAI provider; ``_dim`` is checked if present."""
    settings = EmbeddingSettings(
        model="m", api_key=SecretStr("k"), base_url="https://x"
    )
    provider = build_embedding_provider(settings, dim=512)
    assert isinstance(provider, OpenAIEmbeddingProvider)
    # The dimension is kept on a private attribute, so it is only inspected
    # when an attribute with this exact name exists.
    # NOTE(review): if ``_dim`` is renamed, this check silently stops running.
    if not hasattr(provider, "_dim"):
        return
    assert provider._dim == 512
|
||||
0
tests/unit/test_component/test_llm/__init__.py
Normal file
0
tests/unit/test_component/test_llm/__init__.py
Normal file
64
tests/unit/test_component/test_llm/test_client.py
Normal file
64
tests/unit/test_component/test_llm/test_client.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""get_llm_client — raises on missing credentials, caches on success."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from everos.component.llm import LLMNotConfiguredError
|
||||
from everos.config import Settings
|
||||
from everos.config.settings import LLMSettings
|
||||
|
||||
_client_mod = importlib.import_module("everos.component.llm.client")
|
||||
|
||||
|
||||
def _reset_singleton(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set the client module's ``_llm_client`` attribute to ``None`` for this test."""
    # raising=False: do not fail if the attribute is not defined on the module.
    monkeypatch.setattr(
        _client_mod,
        "_llm_client",
        None,
        raising=False,
    )
|
||||
|
||||
|
||||
def _patch_settings(
    monkeypatch: pytest.MonkeyPatch,
    *,
    api_key: str | None,
    base_url: str | None,
) -> None:
    """Swap the client module's ``load_settings`` for a stub with fixed LLM settings."""
    secret = None if api_key is None else SecretStr(api_key)
    llm_settings = LLMSettings(
        model="gpt-4o-mini",
        api_key=secret,
        base_url=base_url,
    )
    stub_settings = Settings(llm=llm_settings)
    monkeypatch.setattr(_client_mod, "load_settings", lambda: stub_settings)
|
||||
|
||||
|
||||
def test_raises_when_api_key_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without an API key, ``get_llm_client`` raises naming ``EVEROS_LLM__API_KEY``."""
    _reset_singleton(monkeypatch)
    _patch_settings(
        monkeypatch,
        api_key=None,
        base_url="https://example.test",
    )
    with pytest.raises(LLMNotConfiguredError, match="EVEROS_LLM__API_KEY"):
        _client_mod.get_llm_client()
|
||||
|
||||
|
||||
def test_raises_when_base_url_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without a base URL, ``get_llm_client`` raises naming ``EVEROS_LLM__BASE_URL``."""
    _reset_singleton(monkeypatch)
    _patch_settings(
        monkeypatch,
        api_key="sk-test",
        base_url=None,
    )
    with pytest.raises(LLMNotConfiguredError, match="EVEROS_LLM__BASE_URL"):
        _client_mod.get_llm_client()
|
||||
|
||||
|
||||
def test_returns_singleton_when_configured(monkeypatch: pytest.MonkeyPatch) -> None:
    """With full settings, repeated calls return the same object from ``build_client``."""
    _reset_singleton(monkeypatch)
    _patch_settings(monkeypatch, api_key="sk-test", base_url="https://example.test")
    fake_client = object()
    monkeypatch.setattr(_client_mod, "build_client", lambda cfg: fake_client)

    call_one = _client_mod.get_llm_client()
    call_two = _client_mod.get_llm_client()

    assert call_one is fake_client
    assert call_one is call_two
|
||||
28
tests/unit/test_component/test_llm/test_factory.py
Normal file
28
tests/unit/test_component/test_llm/test_factory.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""``build_llm_provider`` — settings validation + provider build."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from everos.component.llm import build_llm_provider
|
||||
from everos.component.llm.openai_provider import OpenAIProvider
|
||||
from everos.config.settings import LLMSettings
|
||||
|
||||
|
||||
def test_raises_when_api_key_missing() -> None:
    """A missing API key is reported via the ``EVEROS_LLM__API_KEY`` variable name."""
    settings = LLMSettings(model="m", api_key=None, base_url="https://x")
    with pytest.raises(ValueError, match="EVEROS_LLM__API_KEY"):
        build_llm_provider(settings)
|
||||
|
||||
|
||||
def test_raises_when_base_url_missing() -> None:
    """A missing base URL is reported via the ``EVEROS_LLM__BASE_URL`` variable name."""
    settings = LLMSettings(model="m", api_key=SecretStr("k"), base_url=None)
    with pytest.raises(ValueError, match="EVEROS_LLM__BASE_URL"):
        build_llm_provider(settings)
|
||||
|
||||
|
||||
def test_builds_openai_provider() -> None:
    """Fully populated settings produce an ``OpenAIProvider``."""
    settings = LLMSettings(model="m", api_key=SecretStr("k"), base_url="https://x")
    provider = build_llm_provider(settings)
    assert isinstance(provider, OpenAIProvider)
|
||||
0
tests/unit/test_component/test_rerank/__init__.py
Normal file
0
tests/unit/test_component/test_rerank/__init__.py
Normal file
254
tests/unit/test_component/test_rerank/test_deepinfra_provider.py
Normal file
254
tests/unit/test_component/test_rerank/test_deepinfra_provider.py
Normal file
@ -0,0 +1,254 @@
|
||||
"""DeepInfra rerank provider — happy path, batching, retries, errors.
|
||||
|
||||
httpx is faked via :class:`httpx.MockTransport`; the provider's
|
||||
``httpx.AsyncClient(timeout=...)`` ctx manager is monkeypatched to
|
||||
return a client wired to the transport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from everos.component.rerank import DeepInfraRerankProvider, RerankError
|
||||
|
||||
|
||||
def _patch_httpx(
    monkeypatch: pytest.MonkeyPatch,
    handler: Callable[[httpx.Request], httpx.Response],
) -> None:
    """Route every ``httpx.AsyncClient`` built by the provider module through ``handler``."""
    import everos.component.rerank.deepinfra_provider as mod

    mock_transport = httpx.MockTransport(handler)
    # Capture the real class before patching so the factory can still build it.
    original_client = httpx.AsyncClient

    def make_client(*args: object, **kwargs: object) -> httpx.AsyncClient:
        kwargs["transport"] = mock_transport
        return original_client(*args, **kwargs)  # type: ignore[arg-type]

    monkeypatch.setattr(mod.httpx, "AsyncClient", make_client)
|
||||
|
||||
|
||||
def _ok_response(scores: list[float]) -> httpx.Response:
    """Build a 200 response whose JSON body nests ``scores`` in a one-element list."""
    payload = {"scores": [scores]}
    return httpx.Response(200, json=payload)
|
||||
|
||||
|
||||
async def test_empty_documents_short_circuits(monkeypatch: pytest.MonkeyPatch) -> None:
    """Reranking an empty document list returns ``[]`` without any HTTP request."""
    requests_seen: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        requests_seen.append(req)
        return _ok_response([])

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m", api_key="k", base_url="https://api/v1"
    )
    ranked = await provider.rerank("q", [])
    assert ranked == []
    assert len(requests_seen) == 0
|
||||
|
||||
|
||||
async def test_scores_sorted_descending(monkeypatch: pytest.MonkeyPatch) -> None:
    """Results come back ordered from highest to lowest score."""
    _patch_httpx(monkeypatch, lambda _req: _ok_response([0.1, 0.9, 0.5]))
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        batch_size=10,
    )
    ranked = await provider.rerank("q", ["a", "b", "c"])
    order = [item.index for item in ranked]
    assert order == [1, 2, 0]
    assert ranked[0].score == pytest.approx(0.9)
|
||||
|
||||
|
||||
async def test_batching_merges_chunk_indices(monkeypatch: pytest.MonkeyPatch) -> None:
    """batch_size=2 with 3 documents → 2 chunks; merged indices respect offset."""
    chunks_sent: list[list[str]] = []

    def handler(req: httpx.Request) -> httpx.Response:
        documents = json.loads(req.content)["documents"]
        chunks_sent.append(documents)
        # Use each document's length as its score so the ordering is predictable.
        length_scores = [float(len(doc)) for doc in documents]
        return _ok_response(length_scores)

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        batch_size=2,
    )
    ranked = await provider.rerank("q", ["x", "yy", "zzz"])

    chunk_sizes = {len(chunk) for chunk in chunks_sent}
    assert chunk_sizes == {1, 2}
    # Descending by length: "zzz" (idx 2), then "yy" (idx 1), then "x" (idx 0).
    assert [item.index for item in ranked] == [2, 1, 0]
|
||||
|
||||
|
||||
async def test_url_appends_model(monkeypatch: pytest.MonkeyPatch) -> None:
    """The request URL is the base URL, minus its trailing slash, plus the model path."""
    requested: list[str] = []

    def handler(req: httpx.Request) -> httpx.Response:
        requested.append(str(req.url))
        return _ok_response([0.5])

    _patch_httpx(monkeypatch, handler)
    # The base URL deliberately ends with "/" to exercise the stripping.
    provider = DeepInfraRerankProvider(
        model="Qwen/Q",
        api_key="k",
        base_url="https://api.deepinfra.com/v1/inference/",
    )
    await provider.rerank("q", ["a"])
    assert requested == ["https://api.deepinfra.com/v1/inference/Qwen/Q"]
|
||||
|
||||
|
||||
async def test_4xx_raises_immediately(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 400 response raises ``RerankError`` after a single request (no retry)."""
    requests_seen: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        requests_seen.append(req)
        return httpx.Response(400, text="bad input")

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        max_retries=3,
    )
    with pytest.raises(RerankError, match="HTTP 400"):
        await provider.rerank("q", ["a"])
    assert len(requests_seen) == 1  # 4xx is not retried
|
||||
|
||||
|
||||
async def test_5xx_retries_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None:
    """Two 503 responses followed by a 200 succeed on the third request."""
    attempts: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        attempts.append(req)
        if len(attempts) >= 3:
            return _ok_response([0.7])
        return httpx.Response(503, text="busy")

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        max_retries=3,
    )
    ranked = await provider.rerank("q", ["a"])
    assert len(attempts) == 3
    assert ranked[0].score == pytest.approx(0.7)
|
||||
|
||||
|
||||
async def test_5xx_exhausts_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A server that always answers 500 eventually surfaces ``RerankError``."""
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(500, text="boom"))
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        max_retries=1,
    )
    with pytest.raises(RerankError, match="HTTP 500"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_429_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A single 429 is retried and the following 200 is returned."""
    attempts: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        attempts.append(req)
        if len(attempts) == 1:
            return httpx.Response(429, text="slow down")
        return _ok_response([0.4])

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        max_retries=3,
    )
    ranked = await provider.rerank("q", ["a"])
    assert len(attempts) == 2
    assert ranked[0].score == pytest.approx(0.4)
|
||||
|
||||
|
||||
async def test_transport_error_retries_then_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A persistent ``ConnectError`` surfaces as a "transport failure" ``RerankError``."""

    def always_down(_req: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("network down")

    _patch_httpx(monkeypatch, always_down)
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        max_retries=1,
    )
    with pytest.raises(RerankError, match="transport failure"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_malformed_scores_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 200 body without a ``scores`` key raises ``RerankError``."""
    bad_body = {"something_else": []}
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(200, json=bad_body))
    provider = DeepInfraRerankProvider(
        model="m", api_key="k", base_url="https://api/v1"
    )
    with pytest.raises(RerankError, match="missing scores"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_score_length_mismatch_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Two scores for three documents raises ``RerankError`` stating both counts."""
    short_body = {"scores": [[0.1, 0.2]]}
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(200, json=short_body))
    provider = DeepInfraRerankProvider(
        model="m",
        api_key="k",
        base_url="https://api/v1",
        batch_size=10,
    )
    with pytest.raises(RerankError, match="returned 2 scores, expected 3"):
        await provider.rerank("q", ["a", "b", "c"])
|
||||
|
||||
|
||||
async def test_payload_wraps_qwen3_template(monkeypatch: pytest.MonkeyPatch) -> None:
    """Query and documents are sent wrapped in the Qwen3-Reranker chat template.

    DeepInfra's inference API scores raw text, so the prompt scaffolding
    (system frame plus ``<Instruct>``/``<Query>``/``<Document>`` markers) has
    to be supplied client-side, otherwise the reranker returns uncalibrated
    scores.
    """
    sent_payload: dict[str, list[str]] = {}

    def handler(req: httpx.Request) -> httpx.Response:
        sent_payload.update(json.loads(req.content))
        return _ok_response([0.5])

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m", api_key="k", base_url="https://api/v1"
    )
    await provider.rerank("what did Alice eat?", ["pasta"], instruction="find facts")

    first_query = sent_payload["queries"][0]
    first_document = sent_payload["documents"][0]
    assert first_query.startswith("<|im_start|>system")
    assert "<Instruct>: find facts" in first_query
    assert "<Query>: what did Alice eat?" in first_query
    assert first_document.startswith("<Document>: pasta")
|
||||
|
||||
|
||||
async def test_default_instruction_when_none(monkeypatch: pytest.MonkeyPatch) -> None:
    """Omitting ``instruction`` sends the provider's default text, not a blank one."""
    sent_payload: dict[str, list[str]] = {}

    def handler(req: httpx.Request) -> httpx.Response:
        sent_payload.update(json.loads(req.content))
        return _ok_response([0.5])

    _patch_httpx(monkeypatch, handler)
    provider = DeepInfraRerankProvider(
        model="m", api_key="k", base_url="https://api/v1"
    )
    await provider.rerank("q", ["d"])
    first_query = sent_payload["queries"][0]
    assert "<Instruct>: Given a question and a passage" in first_query
|
||||
|
||||
|
||||
async def test_flat_scores_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """If response is ``{"scores": [s1, s2]}`` (flat), the unwrap still works.

    The flat list is accepted in place of the nested ``[[s1, s2]]`` shape and
    results are still returned highest score first.
    """

    def handler(_req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json={"scores": [0.3, 0.6]})

    _patch_httpx(monkeypatch, handler)
    p = DeepInfraRerankProvider(model="m", api_key="k", base_url="https://api/v1")
    results = await p.rerank("q", ["a", "b"])
    # Compare floats with a tolerance, matching the other score assertions
    # in this module.
    assert [r.score for r in results] == pytest.approx([0.6, 0.3])
|
||||
67
tests/unit/test_component/test_rerank/test_factory.py
Normal file
67
tests/unit/test_component/test_rerank/test_factory.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""``build_rerank_provider`` — settings validation + provider routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from everos.component.rerank import (
|
||||
DeepInfraRerankProvider,
|
||||
VllmRerankProvider,
|
||||
build_rerank_provider,
|
||||
)
|
||||
from everos.config.settings import RerankSettings
|
||||
|
||||
|
||||
def test_raises_when_model_missing() -> None:
    """A missing model is reported via the ``EVEROS_RERANK__MODEL`` variable name."""
    settings = RerankSettings(
        model=None, api_key=SecretStr("k"), base_url="https://x"
    )
    with pytest.raises(ValueError, match="EVEROS_RERANK__MODEL"):
        build_rerank_provider(settings)
|
||||
|
||||
|
||||
def test_raises_when_base_url_missing() -> None:
    """A missing base URL is reported via the ``EVEROS_RERANK__BASE_URL`` name."""
    settings = RerankSettings(model="m", api_key=SecretStr("k"), base_url=None)
    with pytest.raises(ValueError, match="EVEROS_RERANK__BASE_URL"):
        build_rerank_provider(settings)
|
||||
|
||||
|
||||
def test_deepinfra_requires_api_key() -> None:
    """The deepinfra provider refuses to build without an API key."""
    settings = RerankSettings(
        provider="deepinfra",
        model="m",
        api_key=None,
        base_url="https://x",
    )
    with pytest.raises(ValueError, match="EVEROS_RERANK__API_KEY"):
        build_rerank_provider(settings)
|
||||
|
||||
|
||||
def test_deepinfra_builds_provider() -> None:
    """``provider="deepinfra"`` with full settings yields a ``DeepInfraRerankProvider``."""
    settings = RerankSettings(
        provider="deepinfra",
        model="m",
        api_key=SecretStr("k"),
        base_url="https://api/v1/inference",
    )
    provider = build_rerank_provider(settings)
    assert isinstance(provider, DeepInfraRerankProvider)
|
||||
|
||||
|
||||
def test_vllm_accepts_empty_api_key() -> None:
    """vLLM self-hosted: a ``None`` api_key is allowed (no auth header)."""
    settings = RerankSettings(
        provider="vllm",
        model="m",
        api_key=None,
        base_url="http://localhost:8000/v1",
    )
    provider = build_rerank_provider(settings)
    assert isinstance(provider, VllmRerankProvider)
|
||||
|
||||
|
||||
def test_vllm_with_api_key() -> None:
    """``provider="vllm"`` with an API key also yields a ``VllmRerankProvider``."""
    settings = RerankSettings(
        provider="vllm",
        model="m",
        api_key=SecretStr("k"),
        base_url="http://localhost:8000/v1",
    )
    provider = build_rerank_provider(settings)
    assert isinstance(provider, VllmRerankProvider)
|
||||
187
tests/unit/test_component/test_rerank/test_vllm_provider.py
Normal file
187
tests/unit/test_component/test_rerank/test_vllm_provider.py
Normal file
@ -0,0 +1,187 @@
|
||||
"""vLLM rerank provider — auth header conditional, results parsing, retries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from everos.component.rerank import RerankError, VllmRerankProvider
|
||||
|
||||
|
||||
def _patch_httpx(
    monkeypatch: pytest.MonkeyPatch,
    handler: Callable[[httpx.Request], httpx.Response],
) -> None:
    """Route every ``httpx.AsyncClient`` built by the vLLM provider module through ``handler``."""
    import everos.component.rerank.vllm_provider as mod

    mock_transport = httpx.MockTransport(handler)
    # Capture the real class before patching so the factory can still build it.
    original_client = httpx.AsyncClient

    def make_client(*args: object, **kwargs: object) -> httpx.AsyncClient:
        kwargs["transport"] = mock_transport
        return original_client(*args, **kwargs)  # type: ignore[arg-type]

    monkeypatch.setattr(mod.httpx, "AsyncClient", make_client)
|
||||
|
||||
|
||||
def _ok_response(items: list[dict[str, float | int]]) -> httpx.Response:
    """Build a 200 response whose JSON body is ``{"results": items}``."""
    payload = {"results": items}
    return httpx.Response(200, json=payload)
|
||||
|
||||
|
||||
async def test_empty_documents_short_circuits(monkeypatch: pytest.MonkeyPatch) -> None:
    """Reranking an empty document list returns ``[]`` without any HTTP request."""
    requests_seen: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        requests_seen.append(req)
        return _ok_response([])

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(model="m", api_key="", base_url="http://x/v1")
    ranked = await provider.rerank("q", [])
    assert ranked == []
    assert len(requests_seen) == 0
|
||||
|
||||
|
||||
async def test_url_and_sort_desc(monkeypatch: pytest.MonkeyPatch) -> None:
    """Requests go to ``<base_url>/rerank`` and results are ordered by descending score."""
    requested: list[str] = []
    canned_results = [
        {"index": 0, "relevance_score": 0.1},
        {"index": 1, "relevance_score": 0.9},
        {"index": 2, "relevance_score": 0.5},
    ]

    def handler(req: httpx.Request) -> httpx.Response:
        requested.append(str(req.url))
        return _ok_response(canned_results)

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(
        model="m", api_key="k", base_url="http://localhost:8000/v1/"
    )
    ranked = await provider.rerank("q", ["a", "b", "c"])
    # The trailing slash is dropped before ``/rerank`` is appended.
    assert requested == ["http://localhost:8000/v1/rerank"]
    assert [item.index for item in ranked] == [1, 2, 0]
|
||||
|
||||
|
||||
async def test_auth_header_added_when_api_key_set(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A non-empty api_key is sent as an ``Authorization: Bearer`` header."""
    captured_headers: list[dict[str, str]] = []

    def handler(req: httpx.Request) -> httpx.Response:
        captured_headers.append(dict(req.headers))
        return _ok_response([{"index": 0, "relevance_score": 0.5}])

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(model="m", api_key="sk-abc", base_url="http://x/v1")
    await provider.rerank("q", ["a"])
    first_request_headers = captured_headers[0]
    assert first_request_headers.get("authorization") == "Bearer sk-abc"
|
||||
|
||||
|
||||
async def test_auth_header_omitted_when_api_key_empty(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An empty api_key results in no ``authorization`` header on the request."""
    captured_headers: list[dict[str, str]] = []

    def handler(req: httpx.Request) -> httpx.Response:
        captured_headers.append(dict(req.headers))
        return _ok_response([{"index": 0, "relevance_score": 0.5}])

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(model="m", api_key="", base_url="http://x/v1")
    await provider.rerank("q", ["a"])
    first_request_headers = captured_headers[0]
    assert "authorization" not in first_request_headers
|
||||
|
||||
|
||||
async def test_batching_offsets_indices(monkeypatch: pytest.MonkeyPatch) -> None:
    """With batch_size=2 and 3 docs, the second batch's result index 0 becomes 2."""
    import json

    def handler(req: httpx.Request) -> httpx.Response:
        chunk_docs = json.loads(req.content)["documents"]
        # Reply with chunk-local indices 0..len-1, as a per-chunk server would.
        chunk_results = [
            {"index": pos, "relevance_score": float(pos)}
            for pos in range(len(chunk_docs))
        ]
        return _ok_response(chunk_results)

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(
        model="m", api_key="", base_url="http://x/v1", batch_size=2
    )
    ranked = await provider.rerank("q", ["a", "b", "c"])
    # Chunk 1 contributes indices 0 and 1; chunk 2 contributes index 2.
    merged_indices = sorted(item.index for item in ranked)
    assert merged_indices == [0, 1, 2]
|
||||
|
||||
|
||||
async def test_4xx_raises_immediately(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 401 response raises ``RerankError`` after a single request."""
    attempts: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        attempts.append(req)
        return httpx.Response(401, text="unauthorized")

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(
        model="m",
        api_key="bad",
        base_url="http://x/v1",
        max_retries=3,
    )
    with pytest.raises(RerankError, match="HTTP 401"):
        await provider.rerank("q", ["a"])
    assert len(attempts) == 1
|
||||
|
||||
|
||||
async def test_5xx_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    """One 502 followed by a 200 succeeds on the second request."""
    attempts: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        attempts.append(req)
        if len(attempts) >= 2:
            return _ok_response([{"index": 0, "relevance_score": 0.42}])
        return httpx.Response(502, text="bad gw")

    _patch_httpx(monkeypatch, handler)
    provider = VllmRerankProvider(
        model="m", api_key="", base_url="http://x/v1", max_retries=3
    )
    ranked = await provider.rerank("q", ["a"])
    assert len(attempts) == 2
    assert ranked[0].score == pytest.approx(0.42)
|
||||
|
||||
|
||||
async def test_5xx_exhausts_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A server that always answers 500 eventually surfaces ``RerankError``."""
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(500, text="boom"))
    provider = VllmRerankProvider(
        model="m", api_key="", base_url="http://x/v1", max_retries=1
    )
    with pytest.raises(RerankError, match="HTTP 500"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_transport_error_exhausts(monkeypatch: pytest.MonkeyPatch) -> None:
    """A handler that always times out ends in a "transport failure" error."""

    def always_times_out(_req: httpx.Request) -> httpx.Response:
        raise httpx.ReadTimeout("timeout")

    _patch_httpx(monkeypatch, always_times_out)
    provider = VllmRerankProvider(
        model="m", api_key="", base_url="http://x/v1", max_retries=1
    )
    with pytest.raises(RerankError, match="transport failure"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_malformed_results_missing_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 200 body without a ``results`` key is rejected as ``RerankError``."""
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(200, json={"data": []}))
    provider = VllmRerankProvider(model="m", api_key="", base_url="http://x/v1")
    with pytest.raises(RerankError, match="missing results"):
        await provider.rerank("q", ["a"])
|
||||
|
||||
|
||||
async def test_malformed_result_entry(monkeypatch: pytest.MonkeyPatch) -> None:
    """A result entry lacking ``relevance_score`` is rejected as malformed."""
    payload = {"results": [{"index": 0}]}
    _patch_httpx(monkeypatch, lambda _req: httpx.Response(200, json=payload))
    provider = VllmRerankProvider(model="m", api_key="", base_url="http://x/v1")
    with pytest.raises(RerankError, match="malformed rerank result"):
        await provider.rerank("q", ["a"])
|
||||
98
tests/unit/test_component/test_tokenizer/test_jieba.py
Normal file
98
tests/unit/test_component/test_tokenizer/test_jieba.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Unit tests for :class:`JiebaTokenizer`.
|
||||
|
||||
Verify the contract that callers downstream depend on:
|
||||
|
||||
* clean token list (no whitespace, no empty strings),
|
||||
* CJK + ASCII pass-through under ``cut_for_search`` segmentation,
|
||||
* default stopword + ``min_length=2`` filter applied,
|
||||
* batch preserves order.
|
||||
|
||||
The tokenizer is symmetric — cascade write side and search query side
|
||||
both go through this code path, so changes here change BM25 recall on
|
||||
both ends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from everos.component.tokenizer import JiebaTokenizer, build_tokenizer
|
||||
|
||||
|
||||
def test_tokenize_returns_list_for_english() -> None:
    """Two plain ASCII words come back as exactly two tokens, in order."""
    tokenizer = JiebaTokenizer()
    assert tokenizer.tokenize("hello world") == ["hello", "world"]
|
||||
|
||||
|
||||
def test_tokenize_drops_pure_whitespace() -> None:
    """No emitted token may be empty or consist solely of whitespace."""
    blank_tokens = [
        token for token in JiebaTokenizer().tokenize("foo bar") if not token.strip()
    ]
    assert not blank_tokens
|
||||
|
||||
|
||||
def test_tokenize_empty_input() -> None:
    """The empty string tokenizes to an empty list."""
    tokens = JiebaTokenizer().tokenize("")
    assert tokens == []
|
||||
|
||||
|
||||
def test_tokenize_cjk_keeps_multichar_words() -> None:
    """Multi-character CJK compounds are kept; the single characters are not."""
    produced = set(JiebaTokenizer().tokenize("我爱北京天安门"))
    # 我 / 爱 are single characters, so the default min_length=2 filter
    # removes them (我 is additionally a default stopword).
    assert "我" not in produced
    assert "爱" not in produced
    # Compounds of two or more characters remain available to BM25.
    assert "北京" in produced
    assert produced & {"天安门", "天安"}
|
||||
|
||||
|
||||
def test_tokenize_drops_default_english_stopwords() -> None:
    """``the`` is removed while the content words of the sentence remain."""
    produced = set(JiebaTokenizer().tokenize("the quick brown fox"))
    assert "the" not in produced
    for kept in ("quick", "brown", "fox"):
        assert kept in produced
|
||||
|
||||
|
||||
def test_tokenize_drops_short_tokens_below_min_length() -> None:
    """With the default ``min_length=2``, one-letter ASCII tokens disappear."""
    produced = set(JiebaTokenizer().tokenize("a quick b run"))
    for dropped in ("a", "b"):
        assert dropped not in produced
    for kept in ("quick", "run"):
        assert kept in produced
|
||||
|
||||
|
||||
def test_tokenize_is_case_insensitive() -> None:
    """Upper- and mixed-case input is lowercased in the emitted tokens."""
    tokenizer = JiebaTokenizer()
    assert tokenizer.tokenize("HELLO World") == ["hello", "world"]
|
||||
|
||||
|
||||
def test_extra_stopwords_extend_defaults() -> None:
    """A word passed via ``extra_stopwords`` is filtered; other words are kept."""
    tokenizer = JiebaTokenizer(extra_stopwords=frozenset({"hello"}))
    produced = tokenizer.tokenize("hello world")
    assert "hello" not in produced
    assert "world" in produced
|
||||
|
||||
|
||||
def test_custom_min_token_length_relaxes_filter() -> None:
    """``min_token_length=1`` admits one-letter tokens that are not stopwords.

    The stopword filter is independent of the length filter: the article
    ``"a"`` is still expected to be removed, while ``"b"`` gets through.
    """
    relaxed = JiebaTokenizer(min_token_length=1)
    produced = relaxed.tokenize("a quick b")
    # 'a' is expected to be a default English stopword, so it stays filtered.
    assert "a" not in produced
    assert "b" in produced
    assert "quick" in produced
|
||||
|
||||
|
||||
def test_tokenize_batch_preserves_order() -> None:
    """``tokenize_batch`` yields one token list per input, in input order."""
    tk = JiebaTokenizer()
    texts = ["foo bar", "baz", ""]
    out = tk.tokenize_batch(texts)
    assert len(out) == 3
    assert out[2] == []
    # The two checks above also pass if the first two results are swapped,
    # so compare every position against the single-text tokenization.
    assert list(out) == [tk.tokenize(text) for text in texts]
|
||||
|
||||
|
||||
def test_build_tokenizer_returns_jieba_default() -> None:
    """The ``build_tokenizer`` factory hands back a ``JiebaTokenizer`` by default."""
    built = build_tokenizer()
    assert isinstance(built, JiebaTokenizer)
|
||||
0
tests/unit/test_component/test_utils/__init__.py
Normal file
0
tests/unit/test_component/test_utils/__init__.py
Normal file
1471
tests/unit/test_component/test_utils/test_datetime.py
Normal file
1471
tests/unit/test_component/test_utils/test_datetime.py
Normal file
File diff suppressed because it is too large
Load Diff
0
tests/unit/test_config/__init__.py
Normal file
0
tests/unit/test_config/__init__.py
Normal file
173
tests/unit/test_config/test_settings.py
Normal file
173
tests/unit/test_config/test_settings.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Unit tests for Settings loading."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.config import Settings, load_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Strip any EVEROS_* env vars from the host so tests are deterministic.

    Runs automatically before every test in this module and also clears the
    ``load_settings`` cache so each test builds its settings from scratch.
    """
    import os

    # Snapshot the matching keys first: deleting while iterating the live
    # environment mapping would mutate it mid-iteration.
    inherited = [key for key in os.environ if key.startswith("EVEROS_")]
    for key in inherited:
        monkeypatch.delenv(key, raising=False)
    load_settings.cache_clear()
|
||||
|
||||
|
||||
def test_load_settings_defaults_from_toml() -> None:
    """With no overrides, ``load_settings()`` reports the shipped defaults."""
    settings = load_settings()
    memory, sqlite = settings.memory, settings.sqlite
    # Values straight out of config/default.toml
    assert memory.root == Path("~/.everos")
    assert memory.timezone == "UTC"
    assert sqlite.journal_mode == "WAL"
    assert sqlite.synchronous == "NORMAL"
    assert sqlite.foreign_keys is True
    assert sqlite.temp_store == "MEMORY"
    assert sqlite.busy_timeout_ms == 5000
    assert sqlite.journal_size_limit_bytes == 64 * 1024 * 1024
    assert sqlite.cache_size_kb == 2048
    assert settings.lancedb.read_consistency_seconds is None
|
||||
|
||||
|
||||
def test_env_overrides_toml(monkeypatch: pytest.MonkeyPatch) -> None:
    """``EVEROS_SQLITE__*`` variables replace only the values they name."""
    overrides = {
        "EVEROS_SQLITE__BUSY_TIMEOUT_MS": "10000",
        "EVEROS_SQLITE__JOURNAL_MODE": "DELETE",
    }
    for name, value in overrides.items():
        monkeypatch.setenv(name, value)
    sqlite = Settings().sqlite
    assert sqlite.busy_timeout_ms == 10000
    assert sqlite.journal_mode == "DELETE"
    # Untouched values stay at TOML defaults.
    assert sqlite.synchronous == "NORMAL"
|
||||
|
||||
|
||||
def test_init_args_override_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """A value passed to ``Settings(...)`` wins over the same value in the env."""
    monkeypatch.setenv("EVEROS_SQLITE__BUSY_TIMEOUT_MS", "10000")
    from everos.config.settings import SqliteSettings

    explicit = SqliteSettings(busy_timeout_ms=99999)
    settings = Settings(sqlite=explicit)
    # init beats env
    assert settings.sqlite.busy_timeout_ms == 99999
|
||||
|
||||
|
||||
def test_invalid_journal_mode_rejected() -> None:
    """An unknown ``sqlite.journal_mode`` fails validation."""
    from pydantic import ValidationError

    bad_config = {"sqlite": {"journal_mode": "BOGUS"}}
    with pytest.raises(ValidationError):
        Settings.model_validate(bad_config)
|
||||
|
||||
|
||||
def test_negative_busy_timeout_rejected() -> None:
    """A negative ``sqlite.busy_timeout_ms`` fails validation."""
    from pydantic import ValidationError

    bad_config = {"sqlite": {"busy_timeout_ms": -1}}
    with pytest.raises(ValidationError):
        Settings.model_validate(bad_config)
|
||||
|
||||
|
||||
def test_lancedb_read_consistency_optional_float() -> None:
    """``lancedb.read_consistency_seconds`` accepts both a float and ``None``."""
    with_value = Settings.model_validate(
        {"lancedb": {"read_consistency_seconds": 5.0}}
    )
    assert with_value.lancedb.read_consistency_seconds == 5.0
    without_value = Settings.model_validate(
        {"lancedb": {"read_consistency_seconds": None}}
    )
    assert without_value.lancedb.read_consistency_seconds is None
|
||||
|
||||
|
||||
def test_memory_timezone_overridable_via_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """``EVEROS_MEMORY__TIMEZONE`` sets ``memory.timezone``."""
    monkeypatch.setenv("EVEROS_MEMORY__TIMEZONE", "Asia/Shanghai")
    memory = Settings().memory
    assert memory.timezone == "Asia/Shanghai"
|
||||
|
||||
|
||||
def test_memory_timezone_invalid_rejected() -> None:
    """A zone name that does not exist fails with an "invalid timezone" error."""
    from pydantic import ValidationError

    bad_config = {"memory": {"timezone": "Not/A/Real_Zone"}}
    with pytest.raises(ValidationError, match="invalid timezone"):
        Settings.model_validate(bad_config)
|
||||
|
||||
|
||||
def test_load_settings_is_cached() -> None:
    """Back-to-back calls share one object; ``cache_clear`` forces a new one."""
    first, second = load_settings(), load_settings()
    assert first is second
    load_settings.cache_clear()
    rebuilt = load_settings()
    assert rebuilt is not first
|
||||
|
||||
|
||||
def test_embedding_rerank_defaults() -> None:
    """By default embedding / rerank carry runtime knobs but no credentials."""
    # The autouse ``_isolate_env`` fixture removes EVEROS_* shell variables;
    # ``_env_file=None`` additionally keeps a developer's ``.env`` (which may
    # set MODEL / API_KEY / BASE_URL for live runs) out of this check.
    settings = Settings(_env_file=None)  # type: ignore[call-arg]
    embedding, rerank = settings.embedding, settings.rerank
    # Credentials must be set explicitly (no default).
    assert embedding.model is None
    assert embedding.api_key is None
    assert embedding.base_url is None
    # Runtime knobs come from default.toml.
    assert embedding.timeout_seconds == 30.0
    assert embedding.max_retries == 3
    assert embedding.batch_size == 10
    assert embedding.max_concurrent == 5
    # Rerank mirrors the shape.
    assert rerank.model is None
    assert rerank.timeout_seconds == 30.0
    assert rerank.batch_size == 10
|
||||
|
||||
|
||||
def test_embedding_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
    """``EVEROS_EMBEDDING__*`` variables populate the embedding settings."""
    overrides = {
        "EVEROS_EMBEDDING__MODEL": "intfloat/e5-large-v2",
        "EVEROS_EMBEDDING__BASE_URL": "http://localhost:8000/v1",
        "EVEROS_EMBEDDING__BATCH_SIZE": "32",
    }
    for name, value in overrides.items():
        monkeypatch.setenv(name, value)
    embedding = Settings().embedding
    assert embedding.model == "intfloat/e5-large-v2"
    assert embedding.base_url == "http://localhost:8000/v1"
    # The string "32" from the environment is expected back as an int.
    assert embedding.batch_size == 32
|
||||
|
||||
|
||||
def test_rerank_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
    """``EVEROS_RERANK__*`` variables populate the rerank settings."""
    overrides = {
        "EVEROS_RERANK__MODEL": "BAAI/bge-reranker-v2-m3",
        "EVEROS_RERANK__MAX_CONCURRENT": "8",
    }
    for name, value in overrides.items():
        monkeypatch.setenv(name, value)
    rerank = Settings().rerank
    assert rerank.model == "BAAI/bge-reranker-v2-m3"
    assert rerank.max_concurrent == 8
|
||||
|
||||
|
||||
def test_user_toml_override_via_env_path(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """A TOML file named by ``EVEROS_CONFIG_FILE`` overrides the values it sets."""
    config_path = tmp_path / "config.toml"
    config_path.write_text(
        '[sqlite]\nbusy_timeout_ms = 7777\n[memory]\ntimezone = "Asia/Tokyo"\n',
        encoding="utf-8",
    )
    monkeypatch.setenv("EVEROS_CONFIG_FILE", str(config_path))
    settings = Settings()
    assert settings.sqlite.busy_timeout_ms == 7777
    assert settings.memory.timezone == "Asia/Tokyo"
    # Values not touched by the user toml still come from the shipped default.
    assert settings.sqlite.journal_mode == "WAL"
|
||||
|
||||
|
||||
def test_user_toml_loses_to_env(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """When both set the same key, the env var wins over the user toml."""
    config_path = tmp_path / "config.toml"
    config_path.write_text("[sqlite]\nbusy_timeout_ms = 7777\n", encoding="utf-8")
    monkeypatch.setenv("EVEROS_CONFIG_FILE", str(config_path))
    monkeypatch.setenv("EVEROS_SQLITE__BUSY_TIMEOUT_MS", "9999")
    sqlite = Settings().sqlite
    assert sqlite.busy_timeout_ms == 9999
|
||||
|
||||
|
||||
def test_user_toml_missing_file_is_skipped(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Pointing ``EVEROS_CONFIG_FILE`` at a missing file does not raise."""
    missing = tmp_path / "nope.toml"
    monkeypatch.setenv("EVEROS_CONFIG_FILE", str(missing))
    sqlite = Settings().sqlite
    # Falls back to shipped defaults.
    assert sqlite.busy_timeout_ms == 5000
|
||||
0
tests/unit/test_core/__init__.py
Normal file
0
tests/unit/test_core/__init__.py
Normal file
0
tests/unit/test_core/test_lifespan/__init__.py
Normal file
0
tests/unit/test_core/test_lifespan/__init__.py
Normal file
88
tests/unit/test_core/test_lifespan/test_factory.py
Normal file
88
tests/unit/test_core/test_lifespan/test_factory.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""``build_lifespan`` — provider ordering, state storage, shutdown errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from everos.core.lifespan import LifespanProvider
|
||||
from everos.core.lifespan.factory import build_lifespan
|
||||
|
||||
|
||||
class _RecordingProvider(LifespanProvider):
    """Lifespan provider that appends ``start:<name>`` / ``stop:<name>`` to a log.

    ``returns`` is handed back from ``startup``; ``shutdown_raises`` makes
    ``shutdown`` raise ``RuntimeError`` after it has logged its entry.
    """

    def __init__(
        self,
        name: str,
        order: int,
        log: list[str],
        *,
        returns: object | None = None,
        shutdown_raises: bool = False,
    ) -> None:
        super().__init__(name=name, order=order)
        # The log list is shared with the test (not copied) so it can assert on it.
        self._log, self._returns = log, returns
        self._shutdown_raises = shutdown_raises

    async def startup(self, app: FastAPI) -> object | None:
        """Log the start event and return the configured startup value."""
        self._log.append(f"start:{self.name}")
        return self._returns

    async def shutdown(self, app: FastAPI) -> None:
        """Log the stop event, then raise if this provider was told to fail."""
        self._log.append(f"stop:{self.name}")
        if not self._shutdown_raises:
            return
        raise RuntimeError(f"{self.name} shutdown boom")
|
||||
|
||||
|
||||
async def test_startup_runs_in_order_ascending() -> None:
    """Providers start by ascending ``order``, not by list position."""
    events: list[str] = []
    # Deliberately listed out of order: a=2, b=1, c=3.
    providers = [
        _RecordingProvider("a", order=2, log=events),
        _RecordingProvider("b", order=1, log=events),
        _RecordingProvider("c", order=3, log=events),
    ]

    app = FastAPI()
    lifespan = build_lifespan(providers)
    async with lifespan(app):
        pass
    assert events[:3] == ["start:b", "start:a", "start:c"]
|
||||
|
||||
|
||||
async def test_shutdown_runs_in_reverse_order() -> None:
    """Providers stop in the opposite order to the one they started in."""
    events: list[str] = []
    first = _RecordingProvider("a", order=1, log=events)
    second = _RecordingProvider("b", order=2, log=events)

    app = FastAPI()
    lifespan = build_lifespan([first, second])
    async with lifespan(app):
        pass
    # The first two entries are the startups; the rest is the shutdown phase.
    assert events[2:] == ["stop:b", "stop:a"]
|
||||
|
||||
|
||||
async def test_non_none_startup_result_stored_in_state() -> None:
    """A startup return value is stored in ``app.state.lifespan_data`` by name."""
    marker = object()
    provider = _RecordingProvider("x", order=1, log=[], returns=marker)
    app = FastAPI()
    lifespan = build_lifespan([provider])
    async with lifespan(app):
        assert app.state.lifespan_data["x"] is marker
|
||||
|
||||
|
||||
async def test_none_startup_result_not_stored() -> None:
    """A provider whose startup returns ``None`` leaves no state entry."""
    provider = _RecordingProvider("nullone", order=1, log=[], returns=None)
    app = FastAPI()
    lifespan = build_lifespan([provider])
    async with lifespan(app):
        assert "nullone" not in app.state.lifespan_data
|
||||
|
||||
|
||||
async def test_shutdown_exception_swallowed_and_logged() -> None:
    """One provider raising in shutdown does not stop the others shutting down."""
    events: list[str] = []
    healthy = _RecordingProvider("a", order=1, log=events)
    failing = _RecordingProvider("b", order=2, log=events, shutdown_raises=True)

    app = FastAPI()
    lifespan = build_lifespan([healthy, failing])
    # Leaving the context must not propagate b's RuntimeError.
    async with lifespan(app):
        pass
    # Even though "b" raised, "a" still shut down.
    assert events[-1] == "stop:a"
    # b's shutdown did run: its entry is logged before it raises.
    assert "stop:b" in events
|
||||
35
tests/unit/test_core/test_lifespan/test_metrics_lifespan.py
Normal file
35
tests/unit/test_core/test_lifespan/test_metrics_lifespan.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""``MetricsLifespanProvider`` — startup returns registry, shutdown logs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
from prometheus_client import CollectorRegistry
|
||||
|
||||
from everos.core.lifespan.metrics_lifespan import MetricsLifespanProvider
|
||||
from everos.core.observability.metrics import (
|
||||
reset_metrics_registry,
|
||||
set_metrics_registry,
|
||||
)
|
||||
|
||||
|
||||
async def test_startup_returns_registry() -> None:
    """``startup`` returns the registry that is currently installed."""
    registry = CollectorRegistry()
    set_metrics_registry(registry)
    try:
        started = await MetricsLifespanProvider().startup(FastAPI())
        assert started is registry
    finally:
        # Always undo the override so other tests see the usual registry.
        reset_metrics_registry()
|
||||
|
||||
|
||||
async def test_shutdown_is_noop() -> None:
    """Smoke test: awaiting ``shutdown`` must complete without raising."""
    provider = MetricsLifespanProvider()
    app = FastAPI()
    await provider.shutdown(app)
|
||||
|
||||
|
||||
def test_provider_metadata() -> None:
    """The provider is named ``metrics`` and keeps the ``order`` it was given."""
    provider = MetricsLifespanProvider(order=42)
    assert (provider.name, provider.order) == ("metrics", 42)
|
||||
0
tests/unit/test_core/test_middleware/__init__.py
Normal file
0
tests/unit/test_core/test_middleware/__init__.py
Normal file
106
tests/unit/test_core/test_middleware/test_global_exception.py
Normal file
106
tests/unit/test_core/test_middleware/test_global_exception.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""``global_exception_handler`` — uniform error envelope per v1 API §1.
|
||||
|
||||
We mount the handler on a minimal FastAPI app with three error-emitting
|
||||
routes (HTTPException 4xx / 5xx, RequestValidationError, raw exception)
|
||||
and assert the envelope shape + status code each route produces.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from everos.core.middleware.global_exception import global_exception_handler
|
||||
|
||||
|
||||
# Request body for the POST /validate route below. ``name`` has no default,
# so a request that omits it is expected to fail validation.
class _Body(BaseModel):
    name: str
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
    """Build a FastAPI app with the global handler and four probe routes.

    Routes: ``/raise-400`` and ``/raise-500-http`` raise ``HTTPException``,
    ``/boom`` raises a plain ``RuntimeError``, and ``POST /validate``
    requires a ``_Body`` payload.
    """
    app = FastAPI()
    # One handler serves all three exception families.
    for exc_type in (HTTPException, RequestValidationError, Exception):
        app.add_exception_handler(exc_type, global_exception_handler)

    @app.get("/raise-400")
    async def raise_400() -> None:
        raise HTTPException(status_code=400, detail="bad input")

    @app.get("/raise-500-http")
    async def raise_500_http() -> None:
        raise HTTPException(status_code=503, detail="upstream dead")

    @app.get("/boom")
    async def boom() -> None:
        raise RuntimeError("hidden internals")

    @app.post("/validate")
    async def validate(_body: _Body) -> dict[str, str]:
        return {"ok": "yes"}

    return app
|
||||
|
||||
|
||||
@pytest.fixture
async def client() -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` that talks to the test app over ASGI."""
    # raise_app_exceptions=False — let the registered handler turn the
    # RuntimeError into a 500 response instead of re-raising into the test.
    transport = ASGITransport(app=_build_app(), raise_app_exceptions=False)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http
|
||||
|
||||
|
||||
def _assert_envelope(body: dict[str, object], *, code: str, path: str) -> None:
    """Check the envelope ``{request_id, error: {code, message, timestamp, path}}``.

    ``request_id`` and ``error.message`` must be non-empty strings,
    ``error.timestamp`` a string containing ``"T"``, and ``error.code`` /
    ``error.path`` must equal the expected ``code`` / ``path``.
    """
    request_id = body["request_id"]
    assert isinstance(request_id, str) and request_id
    error = body["error"]
    assert isinstance(error, dict)
    assert error["code"] == code
    message = error["message"]
    assert isinstance(message, str) and message
    timestamp = error["timestamp"]
    assert isinstance(timestamp, str) and "T" in timestamp
    assert error["path"] == path
|
||||
|
||||
|
||||
async def test_http_exception_4xx(client: AsyncClient) -> None:
    """A 400 ``HTTPException`` gives an ``HTTP_ERROR`` envelope with its detail."""
    response = await client.get("/raise-400")
    assert response.status_code == 400
    payload = response.json()
    _assert_envelope(payload, code="HTTP_ERROR", path="/raise-400")
    assert payload["error"]["message"] == "bad input"
|
||||
|
||||
|
||||
async def test_http_exception_5xx_uses_system_error(client: AsyncClient) -> None:
    """A 503 ``HTTPException`` keeps its status but reports ``SYSTEM_ERROR``."""
    response = await client.get("/raise-500-http")
    assert response.status_code == 503
    payload = response.json()
    _assert_envelope(payload, code="SYSTEM_ERROR", path="/raise-500-http")
    # The route's detail "upstream dead" must be replaced by the generic text.
    assert payload["error"]["message"] == "Internal server error"
|
||||
|
||||
|
||||
async def test_unhandled_exception_5xx(client: AsyncClient) -> None:
    """A raw ``RuntimeError`` becomes a 500 ``SYSTEM_ERROR`` envelope."""
    response = await client.get("/boom")
    assert response.status_code == 500
    payload = response.json()
    _assert_envelope(payload, code="SYSTEM_ERROR", path="/boom")
    assert payload["error"]["message"] == "Internal server error"
    # Must not leak the internal exception message.
    assert "hidden internals" not in response.text
|
||||
|
||||
|
||||
async def test_validation_error_returns_422(client: AsyncClient) -> None:
    """Posting a body without ``name`` gives a 422 ``HTTP_ERROR`` envelope."""
    # An empty JSON object omits the required ``name`` field.
    response = await client.post("/validate", json={})
    assert response.status_code == 422
    payload = response.json()
    _assert_envelope(payload, code="HTTP_ERROR", path="/validate")
    # The message must mention the offending field somewhere.
    message = payload["error"]["message"]
    assert "name" in message.lower()
|
||||
148
tests/unit/test_core/test_middleware/test_profile.py
Normal file
148
tests/unit/test_core/test_middleware/test_profile.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""``ProfileMiddleware`` — env gating, query-param gating, pyinstrument output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from everos.core.middleware.profile import ProfileMiddleware, _profiling_enabled
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
    """Build a one-route app (``GET /hello``) wrapped in ``ProfileMiddleware``."""
    application = FastAPI()
    application.add_middleware(ProfileMiddleware)

    @application.get("/hello")
    async def hello() -> dict[str, str]:
        return {"ok": "yes"}

    return application
|
||||
|
||||
|
||||
@pytest.fixture
def disable_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove both profiling switches (current and legacy) from the environment."""
    for switch in ("PROFILING_ENABLED", "PROFILING"):
        monkeypatch.delenv(switch, raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
def enable_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set ``PROFILING_ENABLED=true`` for the duration of the test."""
    monkeypatch.setenv(name="PROFILING_ENABLED", value="true")
|
||||
|
||||
|
||||
def test_profiling_enabled_truthy_variants(monkeypatch: pytest.MonkeyPatch) -> None:
    """Each accepted truthy spelling of ``PROFILING_ENABLED`` enables profiling."""
    truthy_values = ("1", "true", "TRUE", "yes")
    for raw in truthy_values:
        monkeypatch.setenv("PROFILING_ENABLED", raw)
        assert _profiling_enabled() is True
|
||||
|
||||
|
||||
def test_profiling_enabled_falsy_variants(monkeypatch: pytest.MonkeyPatch) -> None:
    """Falsy, empty and unrecognised values all leave profiling disabled."""
    falsy_values = ("0", "false", "no", "", "anything-else")
    for raw in falsy_values:
        monkeypatch.setenv("PROFILING_ENABLED", raw)
        assert _profiling_enabled() is False
|
||||
|
||||
|
||||
def test_profiling_falls_back_to_legacy_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``PROFILING_ENABLED`` unset, the legacy ``PROFILING`` var is honoured."""
    monkeypatch.delenv("PROFILING_ENABLED", raising=False)
    monkeypatch.setenv("PROFILING", "yes")
    enabled = _profiling_enabled()
    assert enabled is True
|
||||
|
||||
|
||||
@pytest.fixture
async def disabled_client(disable_env: None) -> AsyncIterator[AsyncClient]:
    """Yield a client for an app built after the profiling env vars were removed."""
    transport = ASGITransport(app=_build_app())
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http
|
||||
|
||||
|
||||
@pytest.fixture
async def enabled_client(enable_env: None) -> AsyncIterator[AsyncClient]:
    """Yield a client for an app built with ``PROFILING_ENABLED=true`` set."""
    transport = ASGITransport(app=_build_app())
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http
|
||||
|
||||
|
||||
async def test_disabled_passthrough(disabled_client: AsyncClient) -> None:
    """With profiling off, ``?profile=true`` still returns the route's JSON."""
    response = await disabled_client.get("/hello?profile=true")
    assert response.status_code == 200
    assert response.json() == {"ok": "yes"}
|
||||
|
||||
|
||||
async def test_enabled_without_query_passthrough(enabled_client: AsyncClient) -> None:
    """With profiling on but no ``?profile=true``, the route's JSON is returned."""
    response = await enabled_client.get("/hello")
    assert response.status_code == 200
    assert response.json() == {"ok": "yes"}
|
||||
|
||||
|
||||
async def test_enabled_with_query_returns_html(enabled_client: AsyncClient) -> None:
    """With ``?profile=true`` and pyinstrument available, response is HTML."""
    # Skip (rather than fail) where the optional profiler is not installed.
    pytest.importorskip("pyinstrument")

    resp = await enabled_client.get("/hello?profile=true")
    assert resp.status_code == 200
    assert "text/html" in resp.headers.get("content-type", "")
    # Accept either marker in the body: the word "pyinstrument" or an <html> tag.
    page = resp.text.lower()
    assert "pyinstrument" in page or "<html" in page
|
||||
|
||||
|
||||
async def test_enabled_with_query_returns_html_when_inner_raises(
    enabled_client: AsyncClient,
) -> None:
    """A route that raises under ``?profile=true`` still yields a 200 HTML page.

    ``enabled_client`` is requested only for its environment setup; the
    request goes through a separate app whose single route raises.
    """
    # Skip (rather than fail) where the optional profiler is not installed.
    pytest.importorskip("pyinstrument")

    # A tiny app whose route raises, so the middleware has to deal with an
    # exception coming out of the wrapped handler.
    app = FastAPI()
    app.add_middleware(ProfileMiddleware)

    @app.get("/bang")
    async def bang() -> None:
        raise RuntimeError("inner exception")

    async with AsyncClient(
        transport=ASGITransport(app=app, raise_app_exceptions=False),
        base_url="http://test",
    ) as c:
        resp = await c.get("/bang?profile=true")
    assert resp.status_code == 200
    assert "text/html" in resp.headers.get("content-type", "")
|
||||
|
||||
|
||||
async def test_enabled_without_pyinstrument(monkeypatch: pytest.MonkeyPatch) -> None:
    """When importing pyinstrument fails, the request passes through as JSON."""
    monkeypatch.setenv("PROFILING_ENABLED", "true")
    import builtins

    original_import = builtins.__import__

    def fail_pyinstrument(name: str, *args: object, **kwargs: object) -> object:
        # Every other module keeps importing normally.
        if name != "pyinstrument":
            return original_import(name, *args, **kwargs)
        raise ImportError("simulated")

    # The patch stays active for the whole test, so it is in force whenever
    # the middleware attempts its pyinstrument import.
    monkeypatch.setattr(builtins, "__import__", fail_pyinstrument)
    transport = ASGITransport(app=_build_app())

    async with AsyncClient(transport=transport, base_url="http://test") as http:
        response = await http.get("/hello?profile=true")
    assert response.status_code == 200
    assert response.json() == {"ok": "yes"}
|
||||
162
tests/unit/test_core/test_middleware/test_prometheus.py
Normal file
162
tests/unit/test_core/test_middleware/test_prometheus.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""``PrometheusMiddleware`` — increments counters / histograms, skips /metrics.
|
||||
|
||||
We isolate the test from the production registry by overriding it with a
|
||||
fresh :class:`prometheus_client.CollectorRegistry` for the duration of
|
||||
the test. The middleware was already imported with module-level Counter /
|
||||
Histogram bound to whatever the registry was at import time — those
|
||||
metric objects continue to record to the real registry. The test
|
||||
therefore reads via ``_http_requests_total`` directly rather than via
|
||||
``generate_metrics_response()``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from everos.core.middleware import prometheus as prom_mod
|
||||
|
||||
|
||||
def _sample_value(metric: object, **labels: str) -> float:
    """Return the first ``*_total`` sample value of the labeled child metric.

    Falls back to ``nan`` when the first collected family has no sample
    whose name ends in ``_total``.
    """
    child = metric.labels(**labels)._labeled  # type: ignore[attr-defined]
    samples = child.collect()[0].samples
    totals = (float(s.value) for s in samples if s.name.endswith("_total"))
    return next(totals, float("nan"))
|
||||
|
||||
|
||||
def _histogram_count(metric: object, **labels: str) -> float:
    """Return the first ``*_count`` sample value of the labeled child metric.

    Falls back to ``nan`` when the first collected family has no sample
    whose name ends in ``_count``.
    """
    child = metric.labels(**labels)._labeled  # type: ignore[attr-defined]
    samples = child.collect()[0].samples
    counts = (float(s.value) for s in samples if s.name.endswith("_count"))
    return next(counts, float("nan"))
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
    """Build an app with ``PrometheusMiddleware`` and two GET routes.

    ``/hello`` is a fixed path; ``/users/{user_id}`` has a path parameter.
    """
    application = FastAPI()
    application.add_middleware(prom_mod.PrometheusMiddleware)

    @application.get("/hello")
    async def hello() -> dict[str, str]:
        return {"ok": "yes"}

    @application.get("/users/{user_id}")
    async def get_user(user_id: str) -> dict[str, str]:
        return {"user": user_id}

    return application
|
||||
|
||||
|
||||
@pytest.fixture
async def client() -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` bound to a freshly built app over ASGI."""
    transport = ASGITransport(app=_build_app())
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http
|
||||
|
||||
|
||||
async def test_increments_counter_on_200(client: AsyncClient) -> None:
    """One successful ``GET /hello`` raises the request counter by exactly one."""
    labels = {"method": "GET", "path": "/hello", "status": "200"}
    # Compare before/after rather than an absolute value: the counter is
    # module-level and may already hold counts from other tests.
    baseline = _sample_value(prom_mod._http_requests_total, **labels)
    response = await client.get("/hello")
    assert response.status_code == 200
    current = _sample_value(prom_mod._http_requests_total, **labels)
    assert current == baseline + 1
|
||||
|
||||
|
||||
async def test_observes_duration_histogram(client: AsyncClient) -> None:
    """One ``GET /hello`` adds exactly one observation to the duration histogram."""
    labels = {"method": "GET", "path": "/hello"}
    baseline = _histogram_count(prom_mod._http_request_duration_seconds, **labels)
    await client.get("/hello")
    current = _histogram_count(prom_mod._http_request_duration_seconds, **labels)
    assert current == baseline + 1
|
||||
|
||||
|
||||
def test_skip_paths_constant_contains_known_endpoints() -> None:
    """Check the skip set by membership rather than by sending requests.

    Calling ``.labels(path='/metrics')`` on the global registry would
    create a zero-valued sample that leaks into the exposition format
    inspected by test_metrics_route, so the constant is asserted directly.
    """
    for endpoint in ("/metrics", "/health", "/healthz", "/favicon.ico"):
        assert endpoint in prom_mod._SKIP_PATHS
|
||||
|
||||
|
||||
async def test_path_params_normalized(client: AsyncClient) -> None:
    """A hit on ``/users/abc`` is counted under the template ``/users/{user_id}``."""
    series = {"method": "GET", "path": "/users/{user_id}", "status": "200"}
    baseline = _sample_value(prom_mod._http_requests_total, **series)
    response = await client.get("/users/abc")
    assert response.status_code == 200
    observed = _sample_value(prom_mod._http_requests_total, **series)
    assert observed == baseline + 1
|
||||
|
||||
|
||||
# ── _normalize_path direct tests (defensive fallback branches) ─────────
|
||||
|
||||
|
||||
def test_normalize_path_uses_path_params_fallback() -> None:
    """With an empty scope but populated ``path_params``, values become ``{name}``."""
    from types import SimpleNamespace

    from everos.core.middleware.prometheus import _normalize_path

    request_url = SimpleNamespace(path="/x/abc/y")
    stub_request = SimpleNamespace(
        scope={},
        url=request_url,
        path_params={"id": "abc"},
    )
    # The stub is duck-typed rather than a real request, hence the ignore.
    normalized = _normalize_path(stub_request)  # type: ignore[arg-type]
    assert normalized == "/x/{id}/y"
|
||||
|
||||
|
||||
def test_normalize_path_unmatched_fallback() -> None:
    """Empty scope and empty ``path_params`` yield the ``{unmatched}`` sentinel."""
    from types import SimpleNamespace

    from everos.core.middleware.prometheus import _normalize_path

    request_url = SimpleNamespace(path="/x")
    stub_request = SimpleNamespace(scope={}, url=request_url, path_params={})
    normalized = _normalize_path(stub_request)  # type: ignore[arg-type]
    assert normalized == "{unmatched}"
|
||||
|
||||
|
||||
def test_normalize_path_non_dict_scope_falls_through() -> None:
    """A ``scope`` that is not a dict still ends at the ``{unmatched}`` sentinel."""
    from types import SimpleNamespace

    from everos.core.middleware.prometheus import _normalize_path

    request_url = SimpleNamespace(path="/x")
    stub_request = SimpleNamespace(
        scope="not-a-dict",
        url=request_url,
        path_params={},
    )
    normalized = _normalize_path(stub_request)  # type: ignore[arg-type]
    assert normalized == "{unmatched}"
|
||||
0
tests/unit/test_core/test_observability/__init__.py
Normal file
0
tests/unit/test_core/test_observability/__init__.py
Normal file
74
tests/unit/test_core/test_observability/test_gauge.py
Normal file
74
tests/unit/test_core/test_observability/test_gauge.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""``Gauge`` / ``LabeledGauge`` — set / inc / dec; with & without labels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from prometheus_client import CollectorRegistry
|
||||
|
||||
from everos.core.observability.metrics import (
|
||||
Gauge,
|
||||
reset_metrics_registry,
|
||||
set_metrics_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def isolated_registry() -> Iterator[None]:
    """Install a fresh ``CollectorRegistry`` per test and reset it afterwards.

    Keeps gauge names declared in these tests from clashing with metrics
    registered elsewhere.
    """
    scratch = CollectorRegistry()
    set_metrics_registry(scratch)
    yield
    reset_metrics_registry()
|
||||
|
||||
|
||||
def _value(gauge: Gauge, **labels: str) -> float:
|
||||
"""Read the gauge's current scalar value (helper for assertions)."""
|
||||
labeled = (
|
||||
gauge.labels(**labels)._labeled # type: ignore[attr-defined]
|
||||
if labels
|
||||
else gauge._gauge # type: ignore[attr-defined]
|
||||
)
|
||||
for sample in labeled.collect()[0].samples:
|
||||
if sample.name.endswith("_gauge") or "_" in sample.name:
|
||||
return float(sample.value)
|
||||
return float("nan")
|
||||
|
||||
|
||||
def test_unlabeled_set_inc_dec() -> None:
    """``set`` / ``inc`` / ``dec`` on an unlabeled gauge move the value as expected."""
    depth = Gauge(name="queue_depth", description="rows pending")
    steps = (
        (lambda: depth.set(10), 10),
        (lambda: depth.inc(2), 12),
        (lambda: depth.dec(), 11),
        (lambda: depth.dec(5), 6),
    )
    for mutate, expected in steps:
        mutate()
        assert _value(depth) == expected
|
||||
|
||||
|
||||
def test_labeled_isolates_streams() -> None:
    """Updates to one label value leave the other label value's stream alone."""
    cache = Gauge(name="cache_size", description="entries", labelnames=("region",))
    cache.labels(region="us").set(100)
    cache.labels(region="eu").set(50)
    cache.labels(region="us").inc(5)
    cache.labels(region="eu").dec(10)
    us_value = _value(cache, region="us")
    assert us_value == 105
    eu_value = _value(cache, region="eu")
    assert eu_value == 40
|
||||
|
||||
|
||||
def test_namespace_subsystem_unit_render_in_metric_name() -> None:
    """Namespace, subsystem, name and unit all appear in the underlying metric name."""
    gauge = Gauge(
        name="depth",
        description="d",
        namespace="everos",
        subsystem="cascade",
        unit="rows",
    )
    gauge.set(7)
    # Inspect the private name of the wrapped collector.
    rendered = gauge._gauge._name  # type: ignore[attr-defined]
    for part in ("everos", "cascade", "depth", "rows"):
        assert part in rendered
|
||||
111
tests/unit/test_core/test_observability/test_logging_factory.py
Normal file
111
tests/unit/test_core/test_observability/test_logging_factory.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""``configure_logging`` + ``get_logger`` smoke tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import structlog
|
||||
|
||||
from everos.core.observability.logging.factory import configure_logging, get_logger
|
||||
|
||||
|
||||
def test_get_logger_returns_structlog_instance() -> None:
    """``get_logger`` returns an object exposing ``.info`` / ``.warning`` / ``.error``."""
    log = get_logger("test.module")
    # Duck-typed check of the BoundLogger surface.
    for method in ("info", "warning", "error"):
        assert hasattr(log, method)
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences so assertions are stable."""
|
||||
import re
|
||||
|
||||
return re.sub(r"\x1b\[[0-9;]*m", "", text)
|
||||
|
||||
|
||||
def test_configure_logging_accepts_known_levels() -> None:
    """``configure_logging`` accepts upper- and lower-case level names without raising."""
    level_names = ("DEBUG", "INFO", "WARNING", "ERROR", "info", "warn")
    for level_name in level_names:
        configure_logging(level=level_name)
|
||||
|
||||
|
||||
def test_configure_logging_handles_unknown_level_silently() -> None:
    """An unrecognised level name must not raise.

    The only assertion is the absence of an exception.
    """
    unknown_level = "NOPE"
    configure_logging(level=unknown_level)
|
||||
|
||||
|
||||
def test_configure_logging_emits_through_structlog(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An INFO event and its key/value pair both reach captured stdout."""
    configure_logging(level="INFO")
    get_logger("everos.test").info("hello", k="v")
    captured = capsys.readouterr()
    rendered = _strip_ansi(captured.out)
    assert "hello" in rendered
    # Key/value pairs render as ``key=value`` once colour codes are stripped.
    assert "k=v" in rendered
|
||||
|
||||
|
||||
def test_configure_logging_demotes_noisy_http_loggers_to_warning(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """httpx / httpcore / urllib3 loggers are pinned at WARNING.

    Checks both the configured level of each logger and that an INFO
    record emitted through ``httpx`` does not reach stdout.
    """
    import logging

    configure_logging(level="INFO")

    noisy_names = ("httpx", "httpcore", "urllib3")
    for name in noisy_names:
        noisy = logging.getLogger(name)
        assert noisy.level == logging.WARNING, (
            f"{name} logger must be pinned to WARNING, got "
            f"{logging.getLevelName(logging.getLogger(name).level)}"
        )

    # Behavioural half: the INFO record must be filtered out.
    httpx_logger = logging.getLogger("httpx")
    httpx_logger.info("HTTP Request: GET https://example.com 200 OK")
    rendered = _strip_ansi(capsys.readouterr().out)
    assert "HTTP Request" not in rendered
|
||||
|
||||
|
||||
def test_configure_logging_routes_stdlib_loggers_through_same_formatter(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Records from stdlib ``logging`` render in the same shape as structlog's.

    Emits an INFO record through ``logging.getLogger("uvicorn.access")``
    and checks that the message appears, the default
    ``INFO:logger.name:message`` prefix does not, and the level is
    rendered in brackets.
    """
    import logging

    configure_logging(level="INFO")
    foreign_logger = logging.getLogger("uvicorn.access")
    foreign_logger.info("foreign event")

    captured = capsys.readouterr()
    rendered = _strip_ansi(captured.out)
    assert "foreign event" in rendered
    # The stdlib default prefix must be gone.
    assert "INFO:uvicorn.access" not in rendered
    # Level appears in brackets, as for structlog-originated events.
    assert "[info" in rendered
|
||||
|
||||
|
||||
def test_get_logger_with_same_name_returns_equivalent(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Two ``get_logger`` calls with the same name yield equivalent loggers.

    Identity is not asserted; each logger only has to be a structlog
    bound logger or expose ``.info``. ``monkeypatch`` is requested but
    not used by the body.
    """
    configure_logging()
    a = get_logger("everos.cache.test")
    b = get_logger("everos.cache.test")
    logger_types = (structlog.stdlib.BoundLogger, structlog.BoundLoggerBase)
    # Check *both* loggers: the previous single expression short-circuited on
    # ``a`` and never looked at ``b`` when the isinstance check passed.
    for logger in (a, b):
        assert isinstance(logger, logger_types) or hasattr(logger, "info")
|
||||
0
tests/unit/test_core/test_persistence/__init__.py
Normal file
0
tests/unit/test_core/test_persistence/__init__.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""Unit tests for the LanceDB async connection factory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.config import LanceDBSettings
|
||||
from everos.core.persistence import MemoryRoot, open_lancedb_connection
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Return a ``MemoryRoot`` rooted at ``tmp_path`` after calling ``ensure()``."""
    root = MemoryRoot(tmp_path)
    root.ensure()
    return root
|
||||
|
||||
|
||||
async def test_connect_creates_lancedb_dir(memory_root: MemoryRoot) -> None:
    """Opening a connection recreates a missing LanceDB directory."""
    settings = LanceDBSettings()
    target = memory_root.lancedb_dir
    # Drop the directory if it already exists so the factory has to recreate it.
    if target.exists():
        target.rmdir()
    assert not target.exists()

    connection = await open_lancedb_connection(target, settings)
    try:
        assert target.is_dir()
        assert connection.is_open()
    finally:
        connection.close()  # AsyncConnection.close() is sync
|
||||
|
||||
|
||||
async def test_empty_connection_lists_no_tables(memory_root: MemoryRoot) -> None:
    """A freshly opened store reports no tables."""
    connection = await open_lancedb_connection(
        memory_root.lancedb_dir, LanceDBSettings()
    )
    try:
        # The response object carries the names under ``.tables``.
        listing = await connection.list_tables()
        assert list(listing.tables) == []
    finally:
        connection.close()
|
||||
|
||||
|
||||
async def test_read_consistency_seconds_translated_to_timedelta(
    memory_root: MemoryRoot,
) -> None:
    """``read_consistency_seconds=5.0`` is echoed back as a 5-second ``timedelta``."""
    import datetime as dt

    settings = LanceDBSettings(read_consistency_seconds=5.0)
    connection = await open_lancedb_connection(memory_root.lancedb_dir, settings)
    try:
        # The getter is async and should return the interval we configured.
        echoed = await connection.get_read_consistency_interval()
        assert echoed == dt.timedelta(seconds=5.0)
    finally:
        connection.close()
|
||||
|
||||
|
||||
async def test_default_consistency_is_none(memory_root: MemoryRoot) -> None:
    """With default settings the connection reports no read-consistency interval."""
    connection = await open_lancedb_connection(
        memory_root.lancedb_dir, LanceDBSettings()
    )
    try:
        echoed = await connection.get_read_consistency_interval()
        assert echoed is None
    finally:
        connection.close()
|
||||
|
||||
|
||||
async def test_index_cache_cap_is_plumbed_into_session(
    memory_root: MemoryRoot, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A ``Session`` must be handed to ``lancedb.connect_async``.

    ``lancedb.connect_async`` is replaced by a spy that records the
    ``session`` keyword argument and then delegates to the real function.
    The test only asserts that a ``lancedb.Session`` instance arrived;
    the configured cap itself is not read back here. Eviction behaviour
    under load is outside this unit test.
    """
    import lancedb

    settings = LanceDBSettings(index_cache_size_bytes=1024)
    seen: dict[str, object] = {}
    original_connect = lancedb.connect_async

    async def spy(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Record the keyword, then behave exactly like the real connect.
        seen["session"] = kwargs.get("session")
        return await original_connect(*args, **kwargs)

    monkeypatch.setattr(lancedb, "connect_async", spy)

    connection = await open_lancedb_connection(memory_root.lancedb_dir, settings)
    try:
        assert isinstance(seen.get("session"), lancedb.Session)
    finally:
        connection.close()
|
||||
@ -0,0 +1,175 @@
|
||||
"""FTS-layer normalisation contract tests.
|
||||
|
||||
``BaseLanceTable.ensure_fts_indexes`` builds the LanceDB FTS index with
|
||||
the following configuration::
|
||||
|
||||
base_tokenizer="whitespace"
|
||||
lower_case=True
|
||||
stem=True
|
||||
remove_stop_words=False
|
||||
ascii_folding=True
|
||||
language="English" (tantivy default)
|
||||
|
||||
The app-layer ``JiebaTokenizer`` already handles segmentation +
|
||||
stopword filtering, so these FTS-layer settings act as a *belt-and-
|
||||
braces* layer of normalisation. These tests probe the FTS layer
|
||||
*directly* (bypassing jieba) to verify each setting actually behaves
|
||||
as the docstring claims:
|
||||
|
||||
- lower_case=True → query case-insensitive against the raw-cased text
|
||||
- stem=True → query for the word root hits inflected forms
|
||||
- remove_stop_words=False → FTS layer does NOT drop stop-words; the
|
||||
app-layer JiebaTokenizer is the single source of truth for
|
||||
stop-word filtering (English + Chinese)
|
||||
- ascii_folding=True → diacritics on Latin chars normalised (café → cafe)
|
||||
- CJK pass-through → no stemming applied to CJK
|
||||
|
||||
Tests build a fresh in-memory-ish LanceDB store under ``tmp_path``,
|
||||
declare a minimal schema with one ``body`` column, and inspect query
|
||||
hits against handcrafted rows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import lancedb
|
||||
import pytest
|
||||
from lancedb import AsyncTable
|
||||
|
||||
from everos.core.persistence.lancedb import BaseLanceTable
|
||||
|
||||
|
||||
class _FtsSpec(BaseLanceTable):
    """Minimal schema with one BM25-indexed column for FTS-layer probes.

    The tests in this module seed rows of this shape and run keyword
    queries against the ``body`` column.
    """

    # Table name passed to ``create_table`` by the ``fts_table`` fixture.
    TABLE_NAME: ClassVar[str] = "fts_probe"
    # Columns declared for BM25 indexing; presumably consumed by
    # ``ensure_fts_indexes`` — confirm against ``BaseLanceTable``.
    BM25_FIELDS: ClassVar[list[str]] = ["body"]

    # Row identifier; ``_query_ids`` collects this field from query hits.
    id: str
    # Text that the keyword queries search.
    body: str
|
||||
|
||||
|
||||
@pytest.fixture
async def fts_table(tmp_path: Path) -> AsyncIterator[AsyncTable]:
    """Yield an empty ``_FtsSpec`` table in a fresh store under ``tmp_path``.

    The FTS index is not built here; each test seeds its rows first and
    then calls ``ensure_fts_indexes`` (via ``_seed_and_index``), because
    the index needs data to materialise sensibly.

    The connection is closed on teardown so the store's handles are
    released even when table creation or the test body fails.
    """
    conn = await lancedb.connect_async(str(tmp_path / "lancedb"))
    try:
        table = await conn.create_table(_FtsSpec.TABLE_NAME, schema=_FtsSpec)
        yield table
    finally:
        conn.close()  # AsyncConnection.close() is sync
|
||||
|
||||
|
||||
async def _seed_and_index(table: AsyncTable, rows: list[dict]) -> None:
    """Add ``rows`` as ``_FtsSpec`` records, then call ``ensure_fts_indexes``."""
    records = [_FtsSpec(**fields) for fields in rows]
    await table.add(records)
    await _FtsSpec.ensure_fts_indexes(table)
|
||||
|
||||
|
||||
async def _query_ids(table: AsyncTable, text: str) -> set[str]:
    """Keyword-search the ``body`` column (limit 10) and return the hit ids."""
    query = table.query().nearest_to_text(text, columns="body").limit(10)
    hits = await query.to_list()
    return {hit["id"] for hit in hits}
|
||||
|
||||
|
||||
# ── lower_case=True ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_lower_case_query_matches_uppercase_index(
    fts_table: AsyncTable,
) -> None:
    """A lower-case query ``hello`` finds the row stored as ``HELLO``."""
    seed = [
        {"id": "1", "body": "HELLO world"},
        {"id": "2", "body": "GOODBYE world"},
    ]
    await _seed_and_index(fts_table, seed)
    matched = await _query_ids(fts_table, "hello")
    assert matched == {"1"}
|
||||
|
||||
|
||||
# ── stem=True ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_stem_query_root_matches_inflected_forms(
    fts_table: AsyncTable,
) -> None:
    """The root ``counsel`` matches rows with ``counseling`` and ``counseled``."""
    seed = [
        {"id": "1", "body": "counseling session happened"},
        {"id": "2", "body": "counseled patient yesterday"},
        {"id": "3", "body": "unrelated content"},
    ]
    await _seed_and_index(fts_table, seed)
    matched = await _query_ids(fts_table, "counsel")
    assert matched == {"1", "2"}
|
||||
|
||||
|
||||
# ── remove_stop_words=False (app layer owns stop-words) ────────────────
|
||||
|
||||
|
||||
async def test_fts_layer_does_not_filter_stopwords(
    fts_table: AsyncTable,
) -> None:
    """The FTS layer leaves stop-words in place (``remove_stop_words=False``).

    Querying the stop-word ``the`` directly against the index must hit
    the row that contains it. Stop-word removal belongs to the app-layer
    :class:`JiebaTokenizer`, which this test deliberately bypasses to
    observe the FTS layer on its own.
    """
    seed = [
        {"id": "1", "body": "the cat sat on the mat"},
        {"id": "2", "body": "unrelated body text"},
    ]
    await _seed_and_index(fts_table, seed)
    matched = await _query_ids(fts_table, "the")
    assert matched == {"1"}
|
||||
|
||||
|
||||
# ── ascii_folding=True ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_ascii_folding_strips_diacritics(fts_table: AsyncTable) -> None:
    """The unaccented query ``cafe`` finds the row stored as ``café``."""
    seed = [
        {"id": "1", "body": "café latte"},
        {"id": "2", "body": "tea house"},
    ]
    await _seed_and_index(fts_table, seed)
    matched = await _query_ids(fts_table, "cafe")
    assert matched == {"1"}
|
||||
|
||||
|
||||
# ── CJK pass-through ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_cjk_terms_pass_through_untouched(fts_table: AsyncTable) -> None:
    """A whitespace-separated CJK term is matched exactly as written.

    The body text is pre-spaced to stand in for the app-layer tokenizer
    (``JiebaTokenizer``), which normally inserts spaces between CJK words
    before text reaches a whitespace-tokenised index.
    """
    seed = [
        {"id": "1", "body": "北京 天安门"},
        {"id": "2", "body": "上海 外滩"},
    ]
    await _seed_and_index(fts_table, seed)
    matched = await _query_ids(fts_table, "北京")
    assert matched == {"1"}
|
||||
@ -0,0 +1,649 @@
|
||||
"""Tests for :class:`LanceRepoBase` + :class:`LanceDailyLogRepoBase`.
|
||||
|
||||
Exercises the chassis-level query helpers shared by every business
|
||||
LanceDB repo: ``find_where`` / ``find_one_where`` / ``find_by_owner`` /
|
||||
``find_by_md_path`` (on :class:`LanceRepoBase`), and the daily-log
|
||||
slice ``find_by_owner_entry`` / ``find_by_session`` /
|
||||
``find_by_parent`` (on :class:`LanceDailyLogRepoBase`). Also covers
|
||||
``get_by_id`` + ``upsert`` so the chassis CRUD surface is end-to-end
|
||||
verified.
|
||||
|
||||
Uses a tmp LanceDB connection + a locally-defined daily-log-shaped
|
||||
table so the chassis can be exercised without depending on any
|
||||
specific business schema (episode / atomic_fact / …).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.config import LanceDBSettings
|
||||
from everos.core.persistence import (
|
||||
BaseLanceTable,
|
||||
MemoryRoot,
|
||||
Vector,
|
||||
open_lancedb_connection,
|
||||
)
|
||||
from everos.core.persistence.lancedb import (
|
||||
LanceDailyLogRepoBase,
|
||||
LanceRepoBase,
|
||||
)
|
||||
|
||||
|
||||
class _Note(BaseLanceTable):
    """Minimal daily-log-shaped table for chassis tests.

    Carries the columns the repo helpers exercised in this module filter
    on (owner, entry, session, parent, md_path) plus a text payload and
    a vector column.
    """

    TABLE_NAME: ClassVar[str] = "_note"

    # Primary identifier; ``_row`` builds it as ``"{owner}_{entry}"``.
    id: str
    # Owner the row belongs to.
    owner_id: str
    app_id: str = "default"
    project_id: str = "default"
    # Entry identifier within the owner's log.
    entry_id: str
    session_id: str
    # Parent reference, stored as a type/id pair.
    parent_type: str
    parent_id: str
    # Path of the backing markdown file.
    md_path: str
    text: str
    # Vector column; ``_row`` fills it with four floats.
    vector: Vector(4)  # type: ignore[valid-type]
|
||||
|
||||
|
||||
class _SearchNote(BaseLanceTable):
    """Schema with BM25_FIELDS declared — exercises FTS index setup.

    Keeps the display text and the pre-tokenised text in separate
    columns; only ``tokens`` is listed in ``BM25_FIELDS``.
    """

    TABLE_NAME: ClassVar[str] = "_search_note"
    # Only the pre-tokenised column is declared for BM25 indexing.
    BM25_FIELDS: ClassVar[list[str]] = ["tokens"]

    id: str
    text: str
    """Original surface form (display)."""

    tokens: str
    """Space-joined pre-tokenised text (BM25 index target)."""

    vector: Vector(4)  # type: ignore[valid-type]
|
||||
|
||||
|
||||
class _NoteRepo(LanceDailyLogRepoBase[_Note]):
    """Concrete daily-log repo bound to the ``_Note`` test schema."""

    schema = _Note
|
||||
|
||||
|
||||
def _row(
    *,
    owner: str,
    entry: str,
    session: str = "sess_a",
    parent_type: str = "memcell",
    parent_id: str = "mc_1",
    md_path: str | None = None,
    text: str = "x",
) -> _Note:
    """Build a ``_Note`` with sensible defaults for everything but owner/entry.

    The id is ``"{owner}_{entry}"``; a falsy ``md_path`` is replaced by
    ``users/{owner}/notes/{entry}.md``; the vector is a fixed unit vector.
    """
    resolved_path = md_path if md_path else f"users/{owner}/notes/{entry}.md"
    note_id = f"{owner}_{entry}"
    return _Note(
        id=note_id,
        owner_id=owner,
        entry_id=entry,
        session_id=session,
        parent_type=parent_type,
        parent_id=parent_id,
        md_path=resolved_path,
        text=text,
        vector=[1.0, 0.0, 0.0, 0.0],
    )
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_write_locks() -> None:
    """Drop the per-table write-lock pool before every test.

    ``LanceRepoBase`` lazily creates an ``asyncio.Lock`` per table name
    and keeps it in a class-level dict. Without a reset, a lock created
    in one test outlives that test's function-scoped event loop, and the
    next test fails with "bound to a different event loop".
    """
    # Runs as fixture setup (no yield), so each test starts with no locks.
    LanceRepoBase._reset_locks_for_tests()
|
||||
|
||||
|
||||
@pytest.fixture
async def repo(tmp_path: Path) -> _NoteRepo:
    """Return a ``_NoteRepo`` over a new ``_note`` table in a tmp store."""
    root = MemoryRoot(tmp_path)
    root.ensure()
    connection = await open_lancedb_connection(root.lancedb_dir, LanceDBSettings())
    note_table = await connection.create_table("_note", schema=_Note)
    return _NoteRepo(table=note_table)
|
||||
|
||||
|
||||
# ── add + get_by_id + count ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_add_and_count(repo: _NoteRepo) -> None:
    """Adding two rows makes ``count()`` return 2."""
    notes = [_row(owner="u1", entry="ep_1"), _row(owner="u1", entry="ep_2")]
    await repo.add(notes)
    row_count = await repo.count()
    assert row_count == 2
|
||||
|
||||
|
||||
async def test_get_by_id_returns_typed_instance(repo: _NoteRepo) -> None:
    """``get_by_id`` returns a ``_Note`` instance carrying the stored text."""
    note = _row(owner="u1", entry="ep_1", text="hello")
    await repo.add([note])
    fetched = await repo.get_by_id("u1_ep_1")
    assert fetched is not None
    assert isinstance(fetched, _Note)
    assert fetched.text == "hello"
|
||||
|
||||
|
||||
async def test_get_by_id_returns_none_when_missing(repo: _NoteRepo) -> None:
    """``get_by_id`` on an unknown id yields ``None``."""
    fetched = await repo.get_by_id("ghost")
    assert fetched is None
|
||||
|
||||
|
||||
# ── upsert ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_upsert_inserts_on_new(repo: _NoteRepo) -> None:
    """``upsert`` of an unseen id inserts the row."""
    note = _row(owner="u1", entry="ep_1", text="v1")
    await repo.upsert([note])
    fetched = await repo.get_by_id("u1_ep_1")
    assert fetched is not None
    assert fetched.text == "v1"
|
||||
|
||||
|
||||
async def test_upsert_updates_on_existing(repo: _NoteRepo) -> None:
    """``upsert`` of an existing id replaces the row instead of appending."""
    first_version = _row(owner="u1", entry="ep_1", text="v1")
    second_version = _row(owner="u1", entry="ep_1", text="v2")
    await repo.add([first_version])
    await repo.upsert([second_version])
    fetched = await repo.get_by_id("u1_ep_1")
    assert fetched is not None
    assert fetched.text == "v2"
    # Still a single row: the upsert updated in place.
    assert await repo.count() == 1
|
||||
|
||||
|
||||
# ── find_where / find_one_where ─────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_find_where_returns_typed_list(repo: _NoteRepo) -> None:
    """``find_where`` returns ``_Note`` instances for exactly the matching rows."""
    notes = [
        _row(owner="u1", entry="ep_1"),
        _row(owner="u1", entry="ep_2"),
        _row(owner="u2", entry="ep_3"),
    ]
    await repo.add(notes)
    matches = await repo.find_where("owner_id = 'u1'")
    assert len(matches) == 2
    assert all(isinstance(match, _Note) for match in matches)
    assert {match.entry_id for match in matches} == {"ep_1", "ep_2"}
|
||||
|
||||
|
||||
async def test_find_one_where_returns_first_match(repo: _NoteRepo) -> None:
    """``find_one_where`` returns the row matching the predicate."""
    note = _row(owner="u1", entry="ep_1")
    await repo.add([note])
    match = await repo.find_one_where("entry_id = 'ep_1'")
    assert match is not None
    assert match.entry_id == "ep_1"
|
||||
|
||||
|
||||
async def test_find_one_where_returns_none(repo: _NoteRepo) -> None:
    """``find_one_where`` yields ``None`` when nothing matches."""
    match = await repo.find_one_where("entry_id = 'ghost'")
    assert match is None
|
||||
|
||||
|
||||
# ── find_where_paginated ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_find_where_paginated_first_page(repo: _NoteRepo) -> None:
    """Five rows, page 1 of size 2, DESC by ``entry_id`` → top two rows, total 5."""
    seeded = [_row(owner="u1", entry=f"ep_{i}") for i in range(1, 6)]
    await repo.add(seeded)
    page_rows, match_count = await repo.find_where_paginated(
        "owner_id = 'u1'",
        sort_by="entry_id",
        descending=True,
        page=1,
        page_size=2,
    )
    assert match_count == 5
    assert [row.entry_id for row in page_rows] == ["ep_5", "ep_4"]
|
||||
|
||||
|
||||
async def test_find_where_paginated_last_page_partial(repo: _NoteRepo) -> None:
    """Five rows, page 3 of size 2 → only the single trailing row."""
    seeded = [_row(owner="u1", entry=f"ep_{i}") for i in range(1, 6)]
    await repo.add(seeded)
    page_rows, match_count = await repo.find_where_paginated(
        "owner_id = 'u1'",
        sort_by="entry_id",
        descending=True,
        page=3,
        page_size=2,
    )
    assert match_count == 5
    assert [row.entry_id for row in page_rows] == ["ep_1"]
|
||||
|
||||
|
||||
async def test_find_where_paginated_ascending_sort(repo: _NoteRepo) -> None:
    """With ``descending=False`` rows come back in ascending ``entry_id`` order."""
    seeded = [_row(owner="u1", entry=f"ep_{i}") for i in range(1, 4)]
    await repo.add(seeded)
    page_rows, match_count = await repo.find_where_paginated(
        "owner_id = 'u1'",
        sort_by="entry_id",
        descending=False,
        page=1,
        page_size=10,
    )
    assert match_count == 3
    assert [row.entry_id for row in page_rows] == ["ep_1", "ep_2", "ep_3"]
|
||||
|
||||
|
||||
async def test_find_where_paginated_empty_predicate(repo: _NoteRepo) -> None:
    """A predicate with no matches gives an empty page and a total of 0."""
    page_rows, match_count = await repo.find_where_paginated(
        "owner_id = 'ghost'",
        sort_by="entry_id",
        page=1,
        page_size=20,
    )
    assert page_rows == []
    assert match_count == 0
|
||||
|
||||
|
||||
async def test_find_where_paginated_filters_by_owner(repo: _NoteRepo) -> None:
    """The total counts rows matching the predicate, not every row in the table."""
    notes = [
        _row(owner="u1", entry="ep_1"),
        _row(owner="u1", entry="ep_2"),
        _row(owner="u2", entry="ep_3"),
    ]
    await repo.add(notes)
    page_rows, match_count = await repo.find_where_paginated(
        "owner_id = 'u1'",
        sort_by="entry_id",
        page=1,
        page_size=10,
    )
    assert match_count == 2
    assert {row.entry_id for row in page_rows} == {"ep_1", "ep_2"}
|
||||
|
||||
|
||||
async def test_find_where_paginated_truncates_above_max_fetch(
    repo: _NoteRepo,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """When matches exceed ``max_fetch`` a warning is logged and a page is still served.

    Ten rows match while ``max_fetch=5``. The reported total must remain
    the true match count (10), the page must hold ``page_size`` rows, and
    a "find_where_paginated truncated" warning must be visible to
    ``caplog``.
    """
    # Unit tests bypass the CLI entry point, so logging has to be configured
    # here for ``caplog`` to observe the structlog-emitted warning.
    from everos.core.observability.logging import configure_logging

    configure_logging(level="WARNING")

    seeded = [_row(owner="u1", entry=f"ep_{i:03d}") for i in range(1, 11)]
    await repo.add(seeded)
    with caplog.at_level("WARNING"):
        page_rows, match_count = await repo.find_where_paginated(
            "owner_id = 'u1'",
            sort_by="entry_id",
            page=1,
            page_size=3,
            max_fetch=5,
        )
    assert match_count == 10  # true match count
    assert len(page_rows) == 3
    assert "find_where_paginated truncated" in caplog.text
|
||||
|
||||
|
||||
# ── 5-table shared: find_by_owner / find_by_md_path ─────────────────────
|
||||
|
||||
|
||||
async def test_find_by_owner(repo: _NoteRepo) -> None:
    """``find_by_owner`` returns only the requested owner's rows."""
    notes = [
        _row(owner="u1", entry="ep_1"),
        _row(owner="u1", entry="ep_2"),
        _row(owner="u2", entry="ep_3"),
    ]
    await repo.add(notes)
    owned = await repo.find_by_owner("u1")
    assert {note.entry_id for note in owned} == {"ep_1", "ep_2"}
|
||||
|
||||
|
||||
async def test_find_by_md_path_round_trip(repo: _NoteRepo) -> None:
    """A row stored with an explicit ``md_path`` is found again by that path."""
    note_path = "users/u1/notes/ep_1.md"
    note = _row(owner="u1", entry="ep_1", md_path=note_path)
    await repo.add([note])
    fetched = await repo.find_by_md_path(note_path)
    assert fetched is not None
    assert fetched.entry_id == "ep_1"
|
||||
|
||||
|
||||
async def test_find_by_md_path_returns_none_when_missing(repo: _NoteRepo) -> None:
    """Looking up a path that was never stored yields ``None``."""
    missing = await repo.find_by_md_path("users/u1/notes/ghost.md")
    assert missing is None
|
||||
|
||||
|
||||
# ── daily-log: find_by_owner_entry / find_by_session / find_by_parent ───
|
||||
|
||||
|
||||
async def test_find_by_owner_entry(repo: _NoteRepo) -> None:
    """An (owner, entry) pair resolves to the row that was stored for it."""
    await repo.add([_row(owner="u1", entry="ep_7")])

    fetched = await repo.find_by_owner_entry("u1", "ep_7")
    assert fetched is not None
    assert fetched.entry_id == "ep_7"
|
||||
|
||||
|
||||
async def test_find_by_owner_entry_returns_none_when_missing(
    repo: _NoteRepo,
) -> None:
    """An unknown (owner, entry) pair yields ``None``."""
    missing = await repo.find_by_owner_entry("u1", "ghost")
    assert missing is None
|
||||
|
||||
|
||||
async def test_find_by_owner_entries_returns_only_matching_rows(
    repo: _NoteRepo,
) -> None:
    """Bulk lookup returns just the requested ``entry_id`` values for one owner."""
    seeded = [
        _row(owner="u1", entry="ep_1"),
        _row(owner="u1", entry="ep_2"),
        _row(owner="u1", entry="ep_3"),
        _row(owner="u2", entry="ep_1"),  # different owner — must not leak
    ]
    await repo.add(seeded)

    wanted = ["ep_1", "ep_3"]
    found = await repo.find_by_owner_entries("u1", wanted)

    found_ids = {note.entry_id for note in found}
    assert found_ids == {"ep_1", "ep_3"}
    # Every returned row must belong to the requested owner.
    for note in found:
        assert note.owner_id == "u1"
|
||||
|
||||
|
||||
async def test_find_by_owner_entries_empty_input_short_circuits(
    repo: _NoteRepo,
) -> None:
    """An empty id list yields ``[]`` rather than a ``WHERE entry_id IN ()`` query."""
    await repo.add([_row(owner="u1", entry="ep_1")])

    found = await repo.find_by_owner_entries("u1", [])
    assert found == []
|
||||
|
||||
|
||||
async def test_find_by_session(repo: _NoteRepo) -> None:
    """``find_by_session`` returns only rows recorded under the given session."""
    seeded = [
        _row(owner="u1", entry="ep_1", session="sess_a"),
        _row(owner="u1", entry="ep_2", session="sess_a"),
        _row(owner="u1", entry="ep_3", session="sess_b"),  # other session
    ]
    await repo.add(seeded)

    found = await repo.find_by_session("u1", "sess_a")
    found_ids = {note.entry_id for note in found}
    assert found_ids == {"ep_1", "ep_2"}
|
||||
|
||||
|
||||
async def test_find_by_parent(repo: _NoteRepo) -> None:
    """``find_by_parent`` returns only rows with the given parent type and id."""
    seeded = [
        _row(owner="u1", entry="ep_1", parent_type="memcell", parent_id="mc_x"),
        _row(owner="u1", entry="ep_2", parent_type="memcell", parent_id="mc_x"),
        _row(owner="u1", entry="ep_3", parent_type="other", parent_id="mc_y"),
    ]
    await repo.add(seeded)

    found = await repo.find_by_parent("memcell", "mc_x")
    found_ids = {note.entry_id for note in found}
    assert found_ids == {"ep_1", "ep_2"}
|
||||
|
||||
|
||||
# ── chassis fallback behaviour ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_table_lookup_not_implemented_when_no_override() -> None:
    """A repo built without ``table=`` and without ``_table_lookup`` must raise."""

    class _BareRepo(LanceRepoBase[_Note]):
        schema = _Note

    repo_without_table = _BareRepo()
    # The error message is expected to name the hook the subclass failed to provide.
    with pytest.raises(NotImplementedError, match="_table_lookup"):
        await repo_without_table.count()
|
||||
|
||||
|
||||
async def test_table_name_derived_from_schema() -> None:
    """``repo.table_name`` comes from ``schema.TABLE_NAME`` (single source of truth)."""

    class _R(LanceRepoBase[_Note]):
        schema = _Note

    derived_name = _R().table_name
    assert derived_name == "_note"  # equals _Note.TABLE_NAME
|
||||
|
||||
|
||||
# ── SQL-quote escape defence ────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── BaseLanceTable.ensure_fts_indexes ───────────────────────────────────
|
||||
|
||||
|
||||
async def test_ensure_fts_indexes_creates_index(tmp_path: Path) -> None:
    """Declared ``BM25_FIELDS`` becomes an FTS index after ensure.

    The connection is closed in a ``finally`` block so that a failing
    assertion does not leave the LanceDB handle open.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    conn = await open_lancedb_connection(mr.lancedb_dir, LanceDBSettings())
    try:
        table = await conn.create_table("_search_note", schema=_SearchNote)
        await table.add(
            [
                _SearchNote(
                    id="1",
                    text="hello world",
                    tokens="hello world",
                    vector=[1, 0, 0, 0],
                )
            ]
        )

        await _SearchNote.ensure_fts_indexes(table)

        indices = await table.list_indices()
        # ``idx.columns`` may be empty/None, hence the ``or []`` guard.
        indexed_cols = {col for idx in indices for col in (idx.columns or [])}
        assert "tokens" in indexed_cols
    finally:
        conn.close()
|
||||
|
||||
|
||||
async def test_ensure_fts_indexes_is_idempotent(tmp_path: Path) -> None:
    """Calling twice is safe — no error, no duplicate index.

    The connection is closed in a ``finally`` block so that a failing
    assertion does not leave the LanceDB handle open.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    conn = await open_lancedb_connection(mr.lancedb_dir, LanceDBSettings())
    try:
        table = await conn.create_table("_search_note", schema=_SearchNote)
        await table.add(
            [_SearchNote(id="1", text="hi", tokens="hi", vector=[1, 0, 0, 0])]
        )

        await _SearchNote.ensure_fts_indexes(table)
        first = await table.list_indices()
        await _SearchNote.ensure_fts_indexes(table)
        second = await table.list_indices()

        # Same number of indices after the second call → nothing was duplicated.
        assert len(first) == len(second)
    finally:
        conn.close()
|
||||
|
||||
|
||||
async def test_ensure_fts_indexes_noop_when_no_fields_declared(
    repo: _NoteRepo,
) -> None:
    """Schema without ``BM25_FIELDS`` is a no-op (no error, no new index).

    The index listing is snapshotted before and after the call so the test
    actually fails if an index is created, instead of only checking that the
    listing is iterable.
    """
    table = await repo._table()
    before = list(await table.list_indices())

    # _Note declares no BM25_FIELDS — calling the classmethod is a no-op.
    await _Note.ensure_fts_indexes(table)

    after = list(await table.list_indices())
    # Other indices may exist by default; only the count must stay the same.
    assert len(after) == len(before)
|
||||
|
||||
|
||||
# ── SQL-quote escape defence ────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── delete_by_md_path ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_delete_by_md_path_removes_matching_row(repo: _NoteRepo) -> None:
    """Deleting by path removes that path's row and reports how many went."""
    doomed_path = "users/u1/notes/ep_1.md"
    seeded = [
        _row(owner="u1", entry="ep_1", md_path=doomed_path),
        _row(owner="u1", entry="ep_2"),
    ]
    await repo.add(seeded)

    removed = await repo.delete_by_md_path(doomed_path)

    assert removed == 1
    lookup_after = await repo.find_by_md_path(doomed_path)
    assert lookup_after is None
    remaining = await repo.count()
    assert remaining == 1  # the other row survived
|
||||
|
||||
|
||||
async def test_delete_by_md_path_returns_zero_when_no_match(
    repo: _NoteRepo,
) -> None:
    """Deleting an unknown path reports zero and leaves existing rows alone."""
    await repo.add([_row(owner="u1", entry="ep_1")])

    removed = await repo.delete_by_md_path("users/u1/notes/ghost.md")
    assert removed == 0

    remaining = await repo.count()
    assert remaining == 1
|
||||
|
||||
|
||||
async def test_delete_by_md_path_removes_multiple_entries_one_file(
    repo: _NoteRepo,
) -> None:
    """Several rows sharing one md path are all removed by a single delete."""
    daily_log = "users/u1/notes/episode-2026-05-12.md"
    same_file_rows = [
        _row(owner="u1", entry="ep_1", md_path=daily_log),
        _row(owner="u1", entry="ep_2", md_path=daily_log),
        _row(owner="u1", entry="ep_3", md_path=daily_log),
    ]
    unrelated_row = _row(owner="u2", entry="ep_4")  # different path, untouched
    await repo.add([*same_file_rows, unrelated_row])

    removed = await repo.delete_by_md_path(daily_log)

    assert removed == 3
    remaining = await repo.count()
    assert remaining == 1
|
||||
|
||||
|
||||
async def test_delete_by_md_path_escapes_single_quotes(
    repo: _NoteRepo,
) -> None:
    """A single quote inside the path must not break the delete predicate."""
    quoted_path = "users/u1/notes/it's.md"
    await repo.add([_row(owner="u1", entry="ep_1", md_path=quoted_path)])

    removed = await repo.delete_by_md_path(quoted_path)
    assert removed == 1
|
||||
|
||||
|
||||
# ── SQL-quote escape defence (kept) ─────────────────────────────────────
|
||||
|
||||
|
||||
async def test_get_by_id_escapes_single_quotes(repo: _NoteRepo) -> None:
    """A single quote inside the id must not break the lookup predicate."""
    quoted_id = "u1_it's_fine"
    # Built directly (not via ``_row``) so the quoted id and entry are explicit.
    note = _Note(
        id=quoted_id,
        owner_id="u1",
        entry_id="it's_fine",
        session_id="s",
        parent_type="memcell",
        parent_id="mc_1",
        md_path="x",
        text="t",
        vector=[1.0, 0.0, 0.0, 0.0],
    )
    await repo.add([note])

    fetched = await repo.get_by_id(quoted_id)
    assert fetched is not None
    assert fetched.entry_id == "it's_fine"
|
||||
|
||||
|
||||
# ── Concurrency: per-table write lock ───────────────────────────────────
|
||||
|
||||
|
||||
async def test_concurrent_upsert_disjoint_ids_no_lost_update(
    repo: _NoteRepo,
) -> None:
    """Regression for Bug B: concurrent upserts of distinct rows must all land.

    Cascade ``asyncio.gather`` over rows of the same kind used to race on
    ``merge_insert`` and drop a write (observed on ``user_profile`` — pk =
    owner_id, two disjoint INSERTs leaving a single row in LanceDB). The
    per-table ``asyncio.Lock`` in :meth:`LanceRepoBase.upsert` must serialise
    those writes so that every submitted row is stored.
    """
    total = 16
    batch = [_row(owner=f"u_{idx}", entry=f"ep_{idx}") for idx in range(total)]

    # One single-row upsert per row, all scheduled at once.
    pending = [repo.upsert([row]) for row in batch]
    await asyncio.gather(*pending)

    stored = await repo.count()
    assert stored == total
    for idx in range(total):
        row_id = f"u_{idx}_ep_{idx}"
        fetched = await repo.get_by_id(row_id)
        assert fetched is not None, f"{row_id} disappeared after concurrent upsert"
|
||||
|
||||
|
||||
async def test_concurrent_upsert_same_id_last_writer_wins(
    repo: _NoteRepo,
) -> None:
    """Concurrent upserts on the *same* pk converge to exactly one row.

    Either text may win; what matters is that there is no duplicate row and
    no torn state once both upserts have completed.
    """
    first_version = _row(owner="u1", entry="ep_1", text="A")
    second_version = _row(owner="u1", entry="ep_1", text="B")

    await asyncio.gather(
        repo.upsert([first_version]),
        repo.upsert([second_version]),
    )

    stored = await repo.count()
    assert stored == 1
    survivor = await repo.get_by_id("u1_ep_1")
    assert survivor is not None
    assert survivor.text in {"A", "B"}
|
||||
|
||||
|
||||
async def test_read_not_blocked_by_write_lock(repo: _NoteRepo) -> None:
    """Reads must still resolve while the table's write lock is held.

    The test takes the table lock by hand and then performs a read under a
    timeout; the read would time out if it tried to take the same lock.
    """
    await repo.add([_row(owner="u1", entry="ep_1", text="seed")])

    table_lock = repo._write_lock(repo.table_name)
    async with table_lock:
        # With the lock held, the read has two seconds to complete.
        pending_read = repo.get_by_id("u1_ep_1")
        fetched = await asyncio.wait_for(pending_read, timeout=2.0)

    assert fetched is not None
    assert fetched.text == "seed"
|
||||
|
||||
|
||||
async def test_write_lock_is_per_table(tmp_path: Path) -> None:
    """Distinct tables share no lock — writes on table A do not stall
    writes on table B.

    Two repos bound to differently named tables must hand out different lock
    objects. The connection is closed in a ``finally`` block; previously it
    was opened and never closed.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    conn = await open_lancedb_connection(mr.lancedb_dir, LanceDBSettings())
    try:

        class _OtherNote(BaseLanceTable):
            TABLE_NAME: ClassVar[str] = "_other_note"
            id: str
            owner_id: str
            entry_id: str
            session_id: str
            parent_type: str
            parent_id: str
            md_path: str
            text: str
            vector: Vector(4)  # type: ignore[valid-type]

        class _OtherRepo(LanceDailyLogRepoBase[_OtherNote]):
            schema = _OtherNote

        table_a = await conn.create_table("_note_a", schema=_Note)
        table_b = await conn.create_table(_OtherNote.TABLE_NAME, schema=_OtherNote)

        class _NoteARepo(LanceDailyLogRepoBase[_Note]):
            schema = _Note

            # Overridden so this repo reports the name of the table it is
            # actually bound to rather than the schema's own table name.
            @property
            def table_name(self) -> str:
                return "_note_a"

        repo_a = _NoteARepo(table=table_a)
        repo_b = _OtherRepo(table=table_b)

        lock_a = repo_a._write_lock(repo_a.table_name)
        lock_b = repo_b._write_lock(repo_b.table_name)
        assert lock_a is not lock_b
    finally:
        conn.close()
|
||||
@ -0,0 +1,82 @@
|
||||
"""LanceDB IO toolkit — typical workflow demo.
|
||||
|
||||
End-to-end story for how to author + use a LanceDB-backed table in everos:
|
||||
|
||||
1. Define a table schema by subclassing :class:`BaseLanceTable` and
|
||||
declaring a ``Vector(N)`` column for the embedding.
|
||||
2. ``open_lancedb_connection`` to get an :class:`AsyncConnection`.
|
||||
3. ``conn.create_table(name, schema=Cls)`` to create the table from
|
||||
the Pydantic schema.
|
||||
4. ``table.add(rows)`` to insert.
|
||||
5. ``table.query().nearest_to(vec).limit(k).to_list()`` for vector
|
||||
search (BM25 + scalar filter can chain in the same query).
|
||||
6. ``table.count_rows()`` for size.
|
||||
7. Mutate via :func:`touch` + :meth:`AsyncTable.update` (LanceDB has
|
||||
no SQL ``onupdate`` equivalent — the app must bump ``updated_at``).
|
||||
8. ``table.delete(predicate)`` to remove rows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from everos.config import LanceDBSettings
|
||||
from everos.core.persistence import (
|
||||
BaseLanceTable,
|
||||
MemoryRoot,
|
||||
Vector,
|
||||
open_lancedb_connection,
|
||||
)
|
||||
|
||||
|
||||
class _DemoNote(BaseLanceTable):
    """Minimal LanceDB schema for the workflow demo in this module.

    It declares one text column and one 4-dimensional embedding column.
    """

    # Free-form text; the demo filters and deletes on it.
    text: str
    # Embedding column; 4 dimensions keep the fixture vectors short.
    vector: Vector(4)
|
||||
|
||||
|
||||
async def test_lancedb_typical_workflow(tmp_path: Path) -> None:
    """Walk through create → add → count → search → filter → delete → list.

    The connection is closed in a ``finally`` block so that a failing step
    does not leave the LanceDB handle open.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    settings = LanceDBSettings()

    # 1. Open async connection rooted at <memory_root>/.index/lancedb/
    conn = await open_lancedb_connection(mr.lancedb_dir, settings)
    try:
        # 2. Create the table from the BaseLanceTable schema
        table = await conn.create_table("_demo_notes", schema=_DemoNote)

        # 3. Insert rows (Pydantic instances; created_at / updated_at filled in
        #    by BaseLanceTable's default_factory).
        rows = [
            _DemoNote(text="hello world", vector=[1.0, 0.0, 0.0, 0.0]),
            _DemoNote(text="goodbye cruel world", vector=[0.0, 1.0, 0.0, 0.0]),
            _DemoNote(text="welcome aboard", vector=[1.0, 0.5, 0.0, 0.0]),
        ]
        await table.add(rows)

        # 4. Count
        assert await table.count_rows() == 3

        # 5. Vector search — nearest_to picks rows by ANN distance.
        results = (
            await table.query().nearest_to([0.95, 0.05, 0.0, 0.0]).limit(2).to_list()
        )
        assert len(results) == 2
        # The closest row to [0.95, 0.05, 0, 0] is "hello world" [1, 0, 0, 0]
        # ahead of "welcome aboard" [1, 0.5, 0, 0].
        assert results[0]["text"] == "hello world"

        # 6. Filter (scalar predicate). LanceDB SQL-like predicate string.
        only_hello = await table.query().where("text = 'hello world'").to_list()
        assert len(only_hello) == 1
        assert only_hello[0]["text"] == "hello world"

        # 7. Delete by predicate
        await table.delete("text = 'goodbye cruel world'")
        assert await table.count_rows() == 2

        # 8. List tables on the connection
        tables_response = await conn.list_tables()
        assert "_demo_notes" in list(tables_response.tables)
    finally:
        conn.close()
|
||||
96
tests/unit/test_core/test_persistence/test_locking.py
Normal file
96
tests/unit/test_core/test_persistence/test_locking.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""Unit tests for memory_root_lock async context manager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import LockError, MemoryRoot, memory_root_lock
|
||||
|
||||
|
||||
async def test_lock_creates_anchor_file(tmp_path: Path) -> None:
    """While the lock is held, the lock file exists on disk."""
    memory_root = MemoryRoot(tmp_path)
    async with memory_root_lock(memory_root):
        anchor_present = memory_root.lock_file.exists()
        assert anchor_present
|
||||
|
||||
|
||||
async def test_lock_acquire_release_acquire(tmp_path: Path) -> None:
    """The same process can take the lock again once it has released it."""
    memory_root = MemoryRoot(tmp_path)
    # Two back-to-back acquisitions; the second must not be blocked by the first.
    for _attempt in range(2):
        async with memory_root_lock(memory_root):
            pass
|
||||
|
||||
|
||||
def _hold_lock(memory_root_path: str, ready: object, release: object) -> None:
    """Subprocess entry point: take the lock, announce it, hold until released.

    :func:`memory_root_lock` is an async context manager, so the child runs
    its own event loop through :func:`anyio.run`. ``ready`` is set once the
    lock is held; the lock is kept until ``release`` is set or its five-second
    wait expires.
    """

    async def _acquire_and_wait() -> None:
        root = MemoryRoot(memory_root_path)
        async with memory_root_lock(root, blocking=True):
            ready.set()
            # Wait on a worker thread so the event loop itself is not blocked.
            await anyio.to_thread.run_sync(release.wait, 5)

    anyio.run(_acquire_and_wait)
|
||||
|
||||
|
||||
async def test_nonblocking_raises_when_held_by_other_process(tmp_path: Path) -> None:
    """Different process holding the lock → blocking=False raises LockError.

    A spawned child takes the lock and signals ``ready``; the parent then
    tries a non-blocking acquire. Cleanup always releases the child and, if
    it has to be terminated, joins it again so no zombie is left behind.
    """
    mr = MemoryRoot(tmp_path)
    ctx = multiprocessing.get_context("spawn")
    ready = ctx.Event()
    release = ctx.Event()
    proc = ctx.Process(target=_hold_lock, args=(str(mr.root), ready, release))
    proc.start()
    try:
        assert ready.wait(timeout=5), "subprocess failed to acquire lock"
        with pytest.raises(LockError):
            async with memory_root_lock(mr, blocking=False):
                pass
    finally:
        release.set()
        proc.join(timeout=5)
        if proc.is_alive():
            proc.terminate()
            # Reap the terminated child; terminate() alone does not wait for it.
            proc.join(timeout=5)
|
||||
|
||||
|
||||
async def test_blocking_waits_for_release(tmp_path: Path) -> None:
    """Different process holding lock + main process blocking=True waits.

    A helper thread tells the child to release after a short delay; the
    parent's blocking acquire must therefore take at least part of that
    delay. Cleanup always releases the child and, if it has to be
    terminated, joins it again so no zombie is left behind.
    """
    import threading

    mr = MemoryRoot(tmp_path)
    ctx = multiprocessing.get_context("spawn")
    ready = ctx.Event()
    release = ctx.Event()
    proc = ctx.Process(target=_hold_lock, args=(str(mr.root), ready, release))
    proc.start()
    try:
        assert ready.wait(timeout=5)
        # Schedule the subprocess to release shortly; main process should
        # acquire the lock after that.
        release_started = time.monotonic()

        def release_after_short_delay() -> None:
            time.sleep(0.2)
            release.set()

        threading.Thread(target=release_after_short_delay, daemon=True).start()
        async with memory_root_lock(mr, blocking=True):
            elapsed = time.monotonic() - release_started
            # Should have waited at least roughly the delay.
            assert elapsed >= 0.1
    finally:
        release.set()
        proc.join(timeout=5)
        if proc.is_alive():
            proc.terminate()
            # Reap the terminated child; terminate() alone does not wait for it.
            proc.join(timeout=5)
|
||||
@ -0,0 +1,68 @@
|
||||
"""Tests for Frontmatter base classes (chassis layer)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.core.persistence.markdown import (
|
||||
AgentScopedFrontmatter,
|
||||
BaseFrontmatter,
|
||||
UserScopedFrontmatter,
|
||||
)
|
||||
|
||||
|
||||
def test_base_requires_id_and_type() -> None:
    """Constructing the base frontmatter with no fields fails validation."""
    no_fields: dict = {}
    with pytest.raises(ValidationError):
        BaseFrontmatter(**no_fields)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_base_default_schema_version_is_one() -> None:
    """``schema_version`` is 1 when the caller does not supply it."""
    frontmatter = BaseFrontmatter(id="x", type="t")
    version = frontmatter.schema_version
    assert version == 1
|
||||
|
||||
|
||||
def test_base_extra_fields_allowed() -> None:
    """Undeclared (L2 / L3 / L4) fields are accepted and survive ``model_dump``."""
    extras = {
        "md_sha256": "abc",  # L2
        "last_indexed_at": "2026-04-22T10:00:00Z",
        "custom_user_field": "anything",  # L4
    }
    frontmatter = BaseFrontmatter(id="x", type="t", **extras)

    as_dict = frontmatter.model_dump()
    assert as_dict["md_sha256"] == "abc"
    assert as_dict["custom_user_field"] == "anything"
|
||||
|
||||
|
||||
def test_user_scoped_track_default() -> None:
    """User-scoped frontmatter defaults ``track`` to ``"user"``."""
    frontmatter = UserScopedFrontmatter(id="x", type="t", user_id="u_jason")
    track = frontmatter.track
    assert track == "user"
|
||||
|
||||
|
||||
def test_user_scoped_requires_user_id() -> None:
    """Omitting ``user_id`` from user-scoped frontmatter fails validation."""
    without_user_id = {"id": "x", "type": "t"}
    with pytest.raises(ValidationError):
        UserScopedFrontmatter(**without_user_id)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_agent_scoped_track_default() -> None:
    """Agent-scoped frontmatter defaults ``track`` to ``"agent"``."""
    frontmatter = AgentScopedFrontmatter(id="x", type="t", agent_id="agent_zhangsan")
    track = frontmatter.track
    assert track == "agent"
|
||||
|
||||
|
||||
def test_agent_scoped_requires_agent_id() -> None:
    """Omitting ``agent_id`` from agent-scoped frontmatter fails validation."""
    without_agent_id = {"id": "x", "type": "t"}
    with pytest.raises(ValidationError):
        AgentScopedFrontmatter(**without_agent_id)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_track_literal_rejects_invalid_value() -> None:
    """User-scoped frontmatter rejects a ``track`` value of ``"agent"``."""
    wrong_track = {"id": "x", "type": "t", "user_id": "u", "track": "agent"}
    with pytest.raises(ValidationError):
        UserScopedFrontmatter(**wrong_track)
|
||||
|
||||
|
||||
def test_scope_dir_classvars() -> None:
    """Each frontmatter class exposes its memory-root subdirectory as ``SCOPE_DIR``."""
    expectations = [
        (BaseFrontmatter, ""),  # scope-agnostic by default
        (UserScopedFrontmatter, "users"),
        (AgentScopedFrontmatter, "agents"),
    ]
    for frontmatter_cls, scope_dir in expectations:
        assert frontmatter_cls.SCOPE_DIR == scope_dir
|
||||
@ -0,0 +1,94 @@
|
||||
"""Unit tests for entry marker parsing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from everos.core.persistence import find_entry, split_entries
|
||||
|
||||
|
||||
def test_split_no_entries() -> None:
    """Text without any entry markers splits into an empty list."""
    markerless = "# heading\n\nbody."
    found = split_entries(markerless)
    assert found == []
|
||||
|
||||
|
||||
def test_split_single_entry() -> None:
    """One marked entry is found, with id, body and offsets all consistent."""
    lines = [
        "preamble",
        "<!-- entry:abc123 -->",
        "content here",
        "<!-- /entry:abc123 -->",
        "trailing",
        "",  # yields the final newline after "trailing"
    ]
    body = "\n".join(lines)

    entries = split_entries(body)
    assert len(entries) == 1

    entry = entries[0]
    assert entry.id == "abc123"
    assert entry.body == "content here"
    # offsets should bracket the markers
    marked_span = body[entry.start : entry.end]
    assert marked_span.startswith("<!-- entry:abc123 -->")
    assert marked_span.endswith("<!-- /entry:abc123 -->")
|
||||
|
||||
|
||||
def test_split_multiple_entries() -> None:
    """Two consecutive entries are returned in document order with their bodies."""
    first_block = "<!-- entry:e1 -->\nfirst\n<!-- /entry:e1 -->\n"
    second_block = "<!-- entry:e2 -->\nsecond\n<!-- /entry:e2 -->\n"
    body = first_block + second_block

    entries = split_entries(body)

    ids = [entry.id for entry in entries]
    assert ids == ["e1", "e2"]
    first, second = entries[0], entries[1]
    assert first.body == "first"
    assert second.body == "second"
|
||||
|
||||
|
||||
def test_split_unmatched_open() -> None:
    """An open marker with no close ends the scan; earlier entries are kept."""
    complete_entry = "<!-- entry:e1 -->\nok\n<!-- /entry:e1 -->\n"
    dangling_open = "<!-- entry:e2 -->\nno close\n"
    body = complete_entry + dangling_open

    entries = split_entries(body)
    ids = [entry.id for entry in entries]
    assert ids == ["e1"]
|
||||
|
||||
|
||||
def test_split_mismatched_id() -> None:
    """A close marker with a different id does not terminate the open entry."""
    body = "<!-- entry:e1 -->\ncontent\n<!-- /entry:other -->\n"
    found = split_entries(body)
    assert found == []
|
||||
|
||||
|
||||
def test_split_id_with_underscore_and_hyphen() -> None:
    """Entry ids may contain underscores and hyphens."""
    body = "<!-- entry:abc_def-123 -->\nx\n<!-- /entry:abc_def-123 -->\n"

    found = split_entries(body)
    assert len(found) == 1
    only_entry = found[0]
    assert only_entry.id == "abc_def-123"
|
||||
|
||||
|
||||
def test_split_offsets_consistent() -> None:
    """``start``/``end`` slice out exactly the open marker through the close marker."""
    body = "before\n<!-- entry:e1 -->\nx\n<!-- /entry:e1 -->\nafter\n"
    entry = split_entries(body)[0]

    sliced = body[entry.start : entry.end]
    assert sliced == "<!-- entry:e1 -->\nx\n<!-- /entry:e1 -->"
|
||||
|
||||
|
||||
def test_find_entry_found() -> None:
    """``find_entry`` returns the entry whose id matches, not merely the first one."""
    entry_a = "<!-- entry:a -->\nfirst\n<!-- /entry:a -->\n"
    entry_b = "<!-- entry:b -->\nsecond\n<!-- /entry:b -->\n"
    body = entry_a + entry_b

    match = find_entry(body, "b")
    assert match is not None
    assert match.id == "b"
    assert match.body == "second"
|
||||
|
||||
|
||||
def test_find_entry_not_found() -> None:
    """Searching for an id that is not present yields ``None``."""
    body = "<!-- entry:a -->\nx\n<!-- /entry:a -->\n"
    match = find_entry(body, "missing")
    assert match is None
|
||||
|
||||
|
||||
def test_find_entry_open_without_close() -> None:
    """An entry that is opened but never closed is not returned."""
    unterminated = "<!-- entry:a -->\nx\n"  # no close
    match = find_entry(unterminated, "a")
    assert match is None
|
||||
|
||||
|
||||
def test_split_entry_body_no_internal_newline_stripping() -> None:
    """Blank lines inside an entry survive; only the newline next to each
    marker is dropped from the body."""
    body = "<!-- entry:e1 -->\nline1\n\nline3\n<!-- /entry:e1 -->\n"
    entries = split_entries(body)
    first = entries[0]
    assert first.body == "line1\n\nline3"
|
||||
@ -0,0 +1,99 @@
|
||||
"""Tests for ``EntryId`` parse / format / next_for."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import EntryId
|
||||
|
||||
# ── format ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_pads_seq_to_eight_digits() -> None:
    """``format`` zero-pads a small ``seq`` to eight digits."""
    entry_id = EntryId(prefix="umc", date=dt.date(2026, 4, 22), seq=1)
    rendered = entry_id.format()
    assert rendered == "umc_20260422_00000001"
|
||||
|
||||
|
||||
def test_format_pads_seq_at_99999999() -> None:
    """The largest eight-digit ``seq`` is rendered without extra padding."""
    entry_id = EntryId(prefix="umc", date=dt.date(2026, 4, 22), seq=99_999_999)
    rendered = entry_id.format()
    assert rendered == "umc_20260422_99999999"
|
||||
|
||||
|
||||
def test_str_uses_format() -> None:
    """``str()`` on an ``EntryId`` yields the formatted id string."""
    entry_id = EntryId(prefix="ep", date=dt.date(2026, 1, 1), seq=42)
    as_text = str(entry_id)
    assert as_text == "ep_20260101_00000042"
|
||||
|
||||
|
||||
# ── parse ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_parse_round_trip() -> None:
    """Parsing a canonical id recovers every part, and formatting restores it."""
    canonical = "umc_20260422_00000001"
    parsed = EntryId.parse(canonical)

    parts = (parsed.prefix, parsed.date, parsed.seq)
    assert parts == ("umc", dt.date(2026, 4, 22), 1)
    assert parsed.format() == canonical
|
||||
|
||||
|
||||
def test_parse_handles_seq_above_pad_width() -> None:
    """A nine-digit ``seq`` parses, and ``format`` emits all nine digits."""
    oversized = "umc_20260422_150000000"
    parsed = EntryId.parse(oversized)
    assert parsed.seq == 150_000_000
    assert parsed.format() == "umc_20260422_150000000"
|
||||
|
||||
|
||||
def test_parse_accepts_legacy_four_digit_seq() -> None:
    """A legacy four-digit ``seq`` parses and is re-emitted with eight digits."""
    legacy = "umc_20260422_0001"
    parsed = EntryId.parse(legacy)
    assert parsed.seq == 1
    # format() returns the new 8-digit padding.
    upgraded = parsed.format()
    assert upgraded == "umc_20260422_00000001"
|
||||
|
||||
|
||||
def test_parse_accepts_legacy_three_digit_seq() -> None:
    """A three-digit ``seq`` parses and is re-emitted with eight digits."""
    legacy = "umc_20260422_001"
    parsed = EntryId.parse(legacy)
    assert parsed.seq == 1
    upgraded = parsed.format()
    assert upgraded == "umc_20260422_00000001"
|
||||
|
||||
|
||||
def test_parse_rejects_too_few_segments() -> None:
    """An id with no ``seq`` segment is rejected as an invalid format."""
    malformed = "umc_20260422"
    with pytest.raises(ValueError, match="invalid entry id format"):
        EntryId.parse(malformed)
|
||||
|
||||
|
||||
def test_parse_rejects_invalid_date() -> None:
    """A date segment containing non-digits is rejected."""
    malformed = "umc_2026XX22_00000001"
    with pytest.raises(ValueError, match="invalid date"):
        EntryId.parse(malformed)
|
||||
|
||||
|
||||
def test_parse_rejects_non_numeric_seq() -> None:
    """A ``seq`` segment that is not a number is rejected."""
    malformed = "umc_20260422_xxxx"
    with pytest.raises(ValueError, match="invalid seq"):
        EntryId.parse(malformed)
|
||||
|
||||
|
||||
def test_parse_rejects_empty_prefix() -> None:
    """An id that starts with the separator (no prefix) is rejected."""
    malformed = "_20260422_00000001"
    with pytest.raises(ValueError, match="empty prefix"):
        EntryId.parse(malformed)
|
||||
|
||||
|
||||
# ── next_for ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_next_for_seq_is_count_plus_one() -> None:
    """With two existing entries, the next id gets ``seq`` 3."""
    day = dt.date(2026, 4, 22)
    next_id = EntryId.next_for("umc", day, current_count=2)
    assert next_id.seq == 3
    assert next_id.format() == "umc_20260422_00000003"
|
||||
|
||||
|
||||
def test_next_for_starts_at_one_when_empty() -> None:
    """With no existing entries, the first id gets ``seq`` 1."""
    day = dt.date(2026, 4, 22)
    first_id = EntryId.next_for("umc", day, current_count=0)
    assert first_id.seq == 1
|
||||
|
||||
|
||||
def test_next_for_rejects_negative_count() -> None:
    """A negative ``current_count`` is rejected with a ``ValueError``."""
    day = dt.date(2026, 4, 22)
    with pytest.raises(ValueError, match="must be >= 0"):
        EntryId.next_for("umc", day, current_count=-1)
|
||||
@ -0,0 +1,168 @@
|
||||
"""Unit tests for frontmatter parse / dump + path_glob chassis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import (
|
||||
AgentScopedFrontmatter,
|
||||
BaseFrontmatter,
|
||||
DailyLogPathMixin,
|
||||
SkillPathMixin,
|
||||
UserScopedFrontmatter,
|
||||
dump_frontmatter,
|
||||
parse_frontmatter,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_no_frontmatter() -> None:
    """Text with no frontmatter block parses to empty meta and unchanged text."""
    source = "# Just a heading\n\nbody."
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({}, source)
|
||||
|
||||
|
||||
def test_parse_empty_frontmatter() -> None:
    """An empty frontmatter block yields empty meta and is stripped from the body."""
    source = "---\n---\n# body\n"
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({}, "# body\n")
|
||||
|
||||
|
||||
def test_parse_simple_frontmatter() -> None:
    """Scalar and inline-list values are parsed; the body follows the block."""
    source = "---\ntitle: Hello\ntags: [a, b]\n---\n# body\n"
    metadata, remainder = parse_frontmatter(source)

    expected_meta = {"title": "Hello", "tags": ["a", "b"]}
    assert metadata == expected_meta
    assert remainder == "# body\n"
|
||||
|
||||
|
||||
def test_parse_nested_frontmatter() -> None:
    """A nested YAML mapping is parsed into a nested dict."""
    source = "---\nuser:\n id: u_1\n name: Alice\n---\nbody"
    metadata, remainder = parse_frontmatter(source)

    expected_meta = {"user": {"id": "u_1", "name": "Alice"}}
    assert metadata == expected_meta
    assert remainder == "body"
|
||||
|
||||
|
||||
def test_parse_no_closing_delim() -> None:
    """Without a closing ``---`` the text is treated as having no frontmatter."""
    source = "---\ntitle: Hello\n# body without closing\n"
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({}, source)
|
||||
|
||||
|
||||
def test_parse_non_mapping_yaml() -> None:
    """A frontmatter block that parses to a list is ignored; the text is unchanged."""
    source = "---\n- item1\n- item2\n---\nbody\n"
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({}, source)
|
||||
|
||||
|
||||
def test_parse_opening_delim_no_newline() -> None:
    """``---`` not followed by a newline does not open a frontmatter block."""
    source = "---this is not frontmatter"
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({}, source)
|
||||
|
||||
|
||||
def test_parse_unicode_values() -> None:
    """Non-ASCII text in both the frontmatter and the body is preserved."""
    source = "---\ntitle: 你好\n---\n世界"
    metadata, remainder = parse_frontmatter(source)
    assert (metadata, remainder) == ({"title": "你好"}, "世界")
|
||||
|
||||
|
||||
def test_dump_empty_mapping_returns_empty_string() -> None:
    """Dumping an empty mapping yields an empty string."""
    rendered = dump_frontmatter({})
    assert rendered == ""
|
||||
|
||||
|
||||
def test_dump_simple_mapping() -> None:
    """The dump starts and ends with ``---`` lines and contains the key/value."""
    rendered = dump_frontmatter({"title": "Hello"})

    delimiter = "---\n"
    assert rendered.startswith(delimiter)
    assert rendered.endswith(delimiter)
    assert "title: Hello" in rendered
|
||||
|
||||
|
||||
def test_dump_preserves_key_order() -> None:
    """Keys are emitted in insertion order, not sorted alphabetically."""
    rendered = dump_frontmatter({"z": 1, "a": 2, "m": 3})

    # Drop the surrounding delimiter characters, then read the key from each line.
    yaml_body = rendered.strip("-\n")
    seen_keys = []
    for line in yaml_body.strip().splitlines():
        if ":" not in line:
            continue
        seen_keys.append(line.split(":", 1)[0])

    assert seen_keys == ["z", "a", "m"]
|
||||
|
||||
|
||||
def test_dump_unicode() -> None:
    """Non-ASCII values appear verbatim in the dump rather than escaped."""
    rendered = dump_frontmatter({"title": "你好"})
    assert "你好" in rendered  # allow_unicode keeps non-ASCII verbatim
|
||||
|
||||
|
||||
def test_round_trip() -> None:
    """Dumping metadata, appending a body, then parsing returns both unchanged."""
    original_meta = {"title": "Hello", "tags": ["a", "b"], "nested": {"k": "v"}}
    original_body = "# Body\n\nLine.\n"

    document = dump_frontmatter(original_meta) + original_body
    recovered_meta, recovered_body = parse_frontmatter(document)

    assert recovered_meta == original_meta
    assert recovered_body == original_body
|
||||
|
||||
|
||||
# ── path_glob chassis ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_base_path_glob_raises_not_implemented() -> None:
    """``path_glob`` on a schema without a path-strategy mixin raises.

    The error must be a ``NotImplementedError`` whose message mentions
    ``path_glob``.
    """

    class _PlainFm(BaseFrontmatter):
        type: Literal["_plain"] = "_plain"

    expected_error = pytest.raises(NotImplementedError, match="path_glob")
    with expected_error:
        _PlainFm.path_glob()
|
||||
|
||||
|
||||
def test_daily_log_path_glob_user_scope() -> None:
    """Daily-log mixin on a user-scoped schema yields
    ``*/*/users/*/<DIR_NAME>/<FILE_PREFIX>-*.md`` built from the ClassVars."""

    class _UserDaily(DailyLogPathMixin, UserScopedFrontmatter):
        DIR_NAME: ClassVar[str] = "demo"
        FILE_PREFIX: ClassVar[str] = "entry"
        type: Literal["_user_daily"] = "_user_daily"

    glob_pattern = _UserDaily.path_glob()
    assert glob_pattern == "*/*/users/*/demo/entry-*.md"
|
||||
|
||||
|
||||
def test_daily_log_path_glob_agent_scope() -> None:
    """The same daily-log mixin on an agent-scoped schema uses ``agents``
    as the scope directory in the resulting glob."""

    class _AgentDaily(DailyLogPathMixin, AgentScopedFrontmatter):
        DIR_NAME: ClassVar[str] = "cases"
        FILE_PREFIX: ClassVar[str] = "case"
        type: Literal["_agent_daily"] = "_agent_daily"

    glob_pattern = _AgentDaily.path_glob()
    assert glob_pattern == "*/*/agents/*/cases/case-*.md"
|
||||
|
||||
|
||||
def test_skill_path_glob() -> None:
    """``SkillPathMixin`` yields
    ``*/*/<scope>/*/<container>/<prefix>*/<main>`` from its ClassVars."""

    class _AgentSkill(SkillPathMixin, AgentScopedFrontmatter):
        SKILLS_CONTAINER_NAME: ClassVar[str] = "skills"
        SKILL_DIR_PREFIX: ClassVar[str] = "skill_"
        SKILL_MAIN_FILENAME: ClassVar[str] = "SKILL.md"
        type: Literal["_agent_skill"] = "_agent_skill"

    glob_pattern = _AgentSkill.path_glob()
    assert glob_pattern == "*/*/agents/*/skills/skill_*/SKILL.md"
|
||||
|
||||
|
||||
def test_strategy_mixin_overrides_base_via_mro() -> None:
    """Listing the strategy mixin before the scoped base makes its concrete
    ``path_glob`` the one that is resolved, instead of the raising default."""

    class _Daily(DailyLogPathMixin, UserScopedFrontmatter):
        DIR_NAME: ClassVar[str] = "x"
        FILE_PREFIX: ClassVar[str] = "y"
        type: Literal["_daily_mro"] = "_daily_mro"

    # The concrete implementation is reached: the call returns a string
    # rather than raising, and that string is an actual pattern.
    assert isinstance(_Daily.path_glob(), str)
    assert "NotImplementedError" not in _Daily.path_glob()
|
||||
@ -0,0 +1,66 @@
|
||||
"""Unit tests for MarkdownReader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from everos.core.persistence import MarkdownReader
|
||||
|
||||
|
||||
def test_parse_text_with_frontmatter_and_entries() -> None:
    """Parsing a document with frontmatter and one entry marker pair exposes
    the frontmatter mapping, the body and the single entry."""
    document = (
        "---\n"
        "title: Day Log\n"
        "date: 2026-04-22\n"
        "---\n"
        "# Header\n"
        "<!-- entry:e1 -->\n"
        "first entry\n"
        "<!-- /entry:e1 -->\n"
    )
    result = MarkdownReader.parse(document)
    # PyYAML auto-converts unquoted ISO dates to datetime.date.
    expected_frontmatter = {
        "title": "Day Log",
        "date": datetime.date(2026, 4, 22),
    }
    assert result.frontmatter == expected_frontmatter
    assert "# Header" in result.body
    assert len(result.entries) == 1
    only_entry = result.entries[0]
    assert only_entry.id == "e1"
    assert result.entries[0].body == "first entry"
|
||||
|
||||
|
||||
def test_parse_no_frontmatter_no_entries() -> None:
    """Plain markdown parses to empty frontmatter, the full text as body and
    no entries."""
    document = "# Just a header\n\nbody.\n"
    result = MarkdownReader.parse(document)
    assert result.frontmatter == {}
    assert result.body == document
    assert result.entries == []
|
||||
|
||||
|
||||
def test_parse_only_frontmatter() -> None:
    """A document that is only a frontmatter block has an empty body and no
    entries."""
    result = MarkdownReader.parse("---\nkey: value\n---\n")
    assert result.frontmatter == {"key": "value"}
    assert result.body == ""
    assert result.entries == []
|
||||
|
||||
|
||||
async def test_read_file(tmp_path: Path) -> None:
    """``MarkdownReader.read`` parses a file on disk into frontmatter and
    entries."""
    doc_path = tmp_path / "doc.md"
    doc_path.write_text(
        "---\nk: v\n---\n<!-- entry:x -->\nbody\n<!-- /entry:x -->\n",
        encoding="utf-8",
    )
    result = await MarkdownReader.read(doc_path)
    assert result.frontmatter == {"k": "v"}
    first_entry = result.entries[0]
    assert first_entry.id == "x"
|
||||
|
||||
|
||||
async def test_read_unicode_file(tmp_path: Path) -> None:
    """A UTF-8 file with non-ASCII frontmatter and body is read back intact."""
    doc_path = tmp_path / "zh.md"
    doc_path.write_text("---\ntitle: 你好\n---\n世界\n", encoding="utf-8")
    result = await MarkdownReader.read(doc_path)
    assert result.frontmatter == {"title": "你好"}
    assert result.body == "世界\n"
|
||||
@ -0,0 +1,214 @@
|
||||
"""Tests for the audit-form structured entry chassis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence.markdown import (
|
||||
StructuredEntry,
|
||||
parse_structured_entry,
|
||||
render_structured_entry,
|
||||
)
|
||||
|
||||
# ── render ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_render_with_header_inline_and_sections() -> None:
    """Rendering with all three parts emits an H2 header, one bold-key line
    per inline field and an H3 block per section."""
    inline_fields = {
        "type": "episode",
        "user_id": "u_jason",
        "group_id": "sp_1",
    }
    rendered = render_structured_entry(
        header="ep_20260422_001",
        inline=inline_fields,
        sections={"Summary": "first line\nsecond line"},
    )
    assert rendered.startswith("## ep_20260422_001\n\n")
    assert "**type**: episode" in rendered
    assert "**user_id**: u_jason" in rendered
    assert "**group_id**: sp_1" in rendered
    assert "### Summary\nfirst line\nsecond line" in rendered
|
||||
|
||||
|
||||
def test_render_inline_only_no_header_no_sections() -> None:
    """With only inline fields the output is exactly the bold-key line."""
    rendered = render_structured_entry(inline={"k": "v"})
    assert rendered == "**k**: v"
|
||||
|
||||
|
||||
def test_render_lists_use_bracket_notation() -> None:
    """Both list and tuple values render as ``[a, b]``."""
    inline_fields = {"participants": ["u_jason", "u_sarah"], "tags": ("a", "b")}
    rendered = render_structured_entry(inline=inline_fields)
    assert "**participants**: [u_jason, u_sarah]" in rendered
    assert "**tags**: [a, b]" in rendered
|
||||
|
||||
|
||||
def test_render_none_value_renders_empty() -> None:
    """A ``None`` inline value renders as an empty value after the key."""
    rendered = render_structured_entry(inline={"optional": None})
    assert rendered == "**optional**: "
|
||||
|
||||
|
||||
def test_render_scalar_uses_str() -> None:
    """Int, float and bool inline values render the way ``str()`` shows them."""
    inline_fields = {"count": 3, "ratio": 0.5, "active": True}
    rendered = render_structured_entry(inline=inline_fields)
    assert "**count**: 3" in rendered
    assert "**ratio**: 0.5" in rendered
    assert "**active**: True" in rendered
|
||||
|
||||
|
||||
# ── parse ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_parse_full_round_trip() -> None:
    """Rendering then parsing recovers header, inline fields and sections."""
    inline_fields = {"type": "episode", "user_id": "u_jason"}
    section_bodies = {"Summary": "the summary", "Body": "the body"}
    rendered = render_structured_entry(
        header="ep_001",
        inline=inline_fields,
        sections=section_bodies,
    )
    parsed = parse_structured_entry(rendered)
    assert parsed.header == "ep_001"
    assert parsed.inline == {"type": "episode", "user_id": "u_jason"}
    assert parsed.sections == {"Summary": "the summary", "Body": "the body"}
|
||||
|
||||
|
||||
def test_parse_no_header_yields_none() -> None:
    """Without an H2 line the parsed ``header`` is ``None``; inline fields
    and sections are still captured."""
    parsed = parse_structured_entry("**k**: v\n\n### Section\nbody")
    assert parsed.header is None
    assert parsed.inline == {"k": "v"}
    assert parsed.sections == {"Section": "body"}
|
||||
|
||||
|
||||
def test_parse_no_inline() -> None:
    """A header followed directly by a section parses to empty inline fields."""
    parsed = parse_structured_entry("## ep_001\n\n### Body\nonly section")
    assert parsed.header == "ep_001"
    assert parsed.inline == {}
    assert parsed.sections == {"Body": "only section"}
|
||||
|
||||
|
||||
def test_parse_no_sections() -> None:
    """A header plus inline fields with no H3 blocks parses to empty sections."""
    parsed = parse_structured_entry("## ep_001\n\n**k**: v")
    assert parsed.header == "ep_001"
    assert parsed.inline == {"k": "v"}
    assert parsed.sections == {}
|
||||
|
||||
|
||||
def test_parse_inline_value_with_colon_kept_verbatim() -> None:
    """Colons inside an inline value (e.g. a timestamp) are kept in the value."""
    parsed = parse_structured_entry("**timestamp**: 2026-04-22T10:03:11Z")
    assert parsed.inline == {"timestamp": "2026-04-22T10:03:11Z"}
|
||||
|
||||
|
||||
def test_parse_list_value_kept_as_string() -> None:
    """Parsing is type-agnostic by design: a bracketed list stays plain text."""
    parsed = parse_structured_entry("**participants**: [u_jason, u_sarah]")
    assert parsed.inline == {"participants": "[u_jason, u_sarah]"}
|
||||
|
||||
|
||||
def test_parse_section_with_multiline_body() -> None:
    """Every line under an H3 title belongs to that section's body."""
    parsed = parse_structured_entry("### Episode\nline 1\nline 2\nline 3")
    assert parsed.sections == {"Episode": "line 1\nline 2\nline 3"}
|
||||
|
||||
|
||||
def test_parse_section_titles_kept_verbatim() -> None:
    """Section titles are not re-cased: they are keyed exactly as written."""
    parsed = parse_structured_entry("### task_intent\ndoc text")
    assert "task_intent" in parsed.sections
|
||||
|
||||
|
||||
def test_parse_tolerates_stray_text_outside_blocks() -> None:
    """Prose lines in the head that match no anchor are silently dropped."""
    source = (
        "## ep_001\n\nrandom prose paragraph\n"
        "**k**: v\nmore stray text\n\n### Section\nbody"
    )
    parsed = parse_structured_entry(source)
    # Only the H2 line and ``**key**: ...`` lines are anchors; the stray
    # prose around them matches neither and is not captured anywhere.
    assert parsed.header == "ep_001"
    assert parsed.inline == {"k": "v"}
    assert parsed.sections == {"Section": "body"}
|
||||
|
||||
|
||||
def test_dataclass_immutable() -> None:
    """Assigning to a field of a ``StructuredEntry`` instance raises."""
    frozen_entry = StructuredEntry(id="", body="", start=0, end=0, header="x")
    expected_error = pytest.raises((AttributeError, TypeError))
    with expected_error:
        frozen_entry.header = "y"  # type: ignore[misc]
|
||||
|
||||
|
||||
def test_structured_entry_inherits_entry() -> None:
    """A ``StructuredEntry`` is an :class:`Entry` instance that holds both
    the marker fields (id/body/start/end) and the parsed audit-form fields."""
    from everos.core.persistence.markdown import Entry

    fields = {
        "id": "ep_001",
        "body": "b",
        "start": 0,
        "end": 10,
        "header": "ep_001",
        "inline": {"k": "v"},
        "sections": {"S": "x"},
    }
    structured = StructuredEntry(**fields)
    assert isinstance(structured, Entry)
    assert structured.id == "ep_001"
    assert structured.header == "ep_001"
|
||||
|
||||
|
||||
def test_entry_as_structured_preserves_marker_context() -> None:
    """``Entry.as_structured`` carries id/start/end over and adds the header,
    inline fields and sections parsed from the entry body."""
    from everos.core.persistence.markdown import Entry

    raw_body = "## ep_001\n\n**k**: v\n\n### Body\nthe body"
    plain = Entry(id="ep_001", body=raw_body, start=42, end=128)
    structured = plain.as_structured()
    assert isinstance(structured, StructuredEntry)
    # Marker context copied from the source entry.
    assert structured.id == "ep_001"
    assert structured.start == 42
    assert structured.end == 128
    # Fields parsed out of the body.
    assert structured.header == "ep_001"
    assert structured.inline == {"k": "v"}
    assert structured.sections == {"Body": "the body"}
|
||||
|
||||
|
||||
# ── round-trip with realistic Episode entry ─────────────────────────────
|
||||
|
||||
|
||||
def test_round_trip_episode_shape() -> None:
    """Render/parse round trip for an Episode-shaped entry (the shape shown
    in the wiki Memory Types doc)."""
    inline_fields = {
        "type": "episode",
        "user_id": "u_jason",
        "group_id": "sp_1",
        "session_id": "sess_abc123",
        "timestamp": "2026-04-22T10:03:11Z",
        "parent_type": "memcell",
        "parent_id": "mc_20260422_001",
        "participants": ["u_jason", "u_sarah"],
        "subject": "weekend planning",
    }
    section_bodies = {
        "Summary": "Jason and Sarah discussed weekend coffee plans.",
        "Episode": "At ten in the morning, while making coffee, Jason told Sarah...",
    }
    document = render_structured_entry(
        header="ep_20260422_001",
        inline=inline_fields,
        sections=section_bodies,
    )
    parsed = parse_structured_entry(document)
    assert parsed.header == "ep_20260422_001"
    # A list value comes back as its bracketed text form.
    assert parsed.inline["participants"] == "[u_jason, u_sarah]"
    # Plain string scalars survive unchanged.
    assert parsed.inline["session_id"] == "sess_abc123"
    assert parsed.sections == section_bodies
|
||||
@ -0,0 +1,87 @@
|
||||
"""Markdown IO toolkit — typical workflow demo.
|
||||
|
||||
Doubles as living documentation for how a caller assembles + reads a
|
||||
day-level markdown file with multiple ``<!-- entry:id -->`` records.
|
||||
|
||||
End-to-end story:
|
||||
1. Build a body that contains entry markers.
|
||||
2. Use ``MarkdownWriter.write_markdown`` to persist frontmatter + body
|
||||
atomically (tmp file + fsync + rename, all inside the target dir).
|
||||
3. Use ``MarkdownReader.read`` to parse the resulting file back into
|
||||
a ``ParsedMarkdown`` (frontmatter dict + raw body + list[Entry]).
|
||||
4. Verify each entry's id / body matches what was written.
|
||||
5. Look up a single entry by id with ``find_entry``.
|
||||
6. Round-trip: dump_frontmatter + body → parse_frontmatter recovers
|
||||
the original mapping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from everos.core.persistence import (
|
||||
MarkdownReader,
|
||||
MarkdownWriter,
|
||||
MemoryRoot,
|
||||
dump_frontmatter,
|
||||
find_entry,
|
||||
parse_frontmatter,
|
||||
)
|
||||
|
||||
|
||||
async def test_typical_day_log_write_then_read(tmp_path: Path) -> None:
    """Walk the write -> read -> lookup -> round-trip workflow for one day log."""
    memory_root = MemoryRoot(tmp_path)
    memory_root.ensure()
    writer = MarkdownWriter(memory_root)

    # Step 1: a day-level body holding two marked entries.
    body = (
        "# Day log\n"
        "\n"
        "<!-- entry:ep_001 -->\n"
        "**Title**: Met Alice\n"
        "We discussed the new project layout.\n"
        "<!-- /entry:ep_001 -->\n"
        "\n"
        "<!-- entry:ep_002 -->\n"
        "**Title**: Read paper X\n"
        "Key idea: end-to-end async pipelines.\n"
        "<!-- /entry:ep_002 -->\n"
    )
    frontmatter = {
        "type": "episodic_day_log",
        "date": "2026-04-22",
        "user_id": "u_jason",
        "tags": ["meeting", "research"],
    }

    # Step 2: persist through the writer and check nothing temporary remains.
    target = memory_root.users_dir() / "u_jason" / "episodic" / "2026-04-22.md"
    written_path = await writer.write_markdown(
        target, frontmatter=frontmatter, body=body
    )
    assert written_path == target
    assert target.is_file()
    temp_pattern = f".{target.name}.tmp.*"
    stray_temp_files = list(target.parent.glob(temp_pattern))
    assert stray_temp_files == []

    # Step 3: load the file back as a ParsedMarkdown.
    parsed = await MarkdownReader.read(target)

    # Step 4: frontmatter and both entries match what was written.
    assert parsed.frontmatter == frontmatter
    entry_ids = [entry.id for entry in parsed.entries]
    assert entry_ids == ["ep_001", "ep_002"]
    assert "Met Alice" in parsed.entries[0].body
    assert "Read paper X" in parsed.entries[1].body

    # Step 5: fetch one entry by id.
    second = find_entry(parsed.body, "ep_002")
    assert second is not None
    assert "async pipelines" in second.body

    # Step 6: dump + parse recovers the frontmatter mapping and the body.
    document = dump_frontmatter(frontmatter) + body
    recovered_meta, recovered_body = parse_frontmatter(document)
    assert recovered_meta == frontmatter
    assert recovered_body == body
|
||||
@ -0,0 +1,229 @@
|
||||
"""Unit tests for MarkdownWriter (atomic write)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import (
|
||||
EntryId,
|
||||
MarkdownReader,
|
||||
MarkdownWriter,
|
||||
MemoryRoot,
|
||||
)
|
||||
|
||||
|
||||
def _make_writer(tmp_path: Path) -> MarkdownWriter:
    """Build a ``MarkdownWriter`` whose ``MemoryRoot`` is *tmp_path*."""
    memory_root = MemoryRoot(tmp_path)
    return MarkdownWriter(memory_root)
|
||||
|
||||
|
||||
async def test_write_creates_file_with_content(tmp_path: Path) -> None:
    """``write`` returns the target path and the file holds the given text."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "users" / "u1" / "out.md"
    returned = await writer.write(destination, "hello\n")
    assert returned == destination
    assert destination.read_text(encoding="utf-8") == "hello\n"
|
||||
|
||||
|
||||
async def test_write_creates_parent_directories(tmp_path: Path) -> None:
    """Writing to a path under missing directories still produces the file."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "a" / "b" / "c" / "f.md"
    await writer.write(destination, "x")
    assert destination.is_file()
|
||||
|
||||
|
||||
async def test_write_overwrites_existing(tmp_path: Path) -> None:
    """``write`` replaces the content of a file that already exists."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "f.md"
    destination.write_text("old", encoding="utf-8")
    await writer.write(destination, "new")
    assert destination.read_text(encoding="utf-8") == "new"
|
||||
|
||||
|
||||
async def test_write_no_temp_file_left_after_success(tmp_path: Path) -> None:
    """After a successful write no ``.f.md.tmp.*`` sibling remains."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "f.md"
    await writer.write(destination, "ok")
    stray_names = [
        child.name
        for child in tmp_path.iterdir()  # noqa: ASYNC240 — sync iterdir over a pytest tmp_path is fine in tests
        if child.name.startswith(".f.md.tmp.")
    ]
    assert stray_names == []
|
||||
|
||||
|
||||
async def test_write_cleans_up_temp_on_failure(tmp_path: Path) -> None:
    """When ``os.replace`` raises, the error propagates, the temp file is
    removed and the target is never created."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "f.md"

    rename_failure = OSError("simulated rename failure")
    with (
        patch(
            "everos.core.persistence.markdown.writer.os.replace",
            side_effect=rename_failure,
        ),
        pytest.raises(OSError, match="simulated"),
    ):
        await writer.write(destination, "hello")

    # Neither a leftover temp sibling nor the target itself may exist.
    stray_names = [
        child.name
        for child in tmp_path.iterdir()  # noqa: ASYNC240 — sync iterdir over a pytest tmp_path is fine in tests
        if child.name.startswith(".f.md.tmp.")
    ]
    assert stray_names == []
    assert not destination.exists()
|
||||
|
||||
|
||||
async def test_write_markdown_assembles_frontmatter_and_body(tmp_path: Path) -> None:
    """``write_markdown`` writes a fenced frontmatter block, then the body."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "doc.md"
    await writer.write_markdown(
        destination,
        frontmatter={"title": "Hello"},
        body="# Body\n",
    )
    written = destination.read_text(encoding="utf-8")
    assert written.startswith("---\n")
    assert "title: Hello" in written
    # Trailing newlines are ignored when checking that the body comes last.
    assert written.rstrip("\n").endswith("# Body")
|
||||
|
||||
|
||||
async def test_write_markdown_round_trip(tmp_path: Path) -> None:
    """A file written by ``write_markdown`` reads back through
    ``MarkdownReader`` with the same frontmatter and entry body."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "rt.md"
    await writer.write_markdown(
        destination,
        frontmatter={"k": "v", "n": 1},
        body="<!-- entry:x -->\ncontent\n<!-- /entry:x -->\n",
    )
    result = await MarkdownReader.read(destination)
    assert result.frontmatter == {"k": "v", "n": 1}
    assert len(result.entries) == 1
    assert result.entries[0].body == "content"
|
||||
|
||||
|
||||
async def test_write_markdown_no_frontmatter(tmp_path: Path) -> None:
    """With no frontmatter argument the file contains only the body."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "plain.md"
    await writer.write_markdown(destination, body="just body\n")
    written = destination.read_text(encoding="utf-8")
    assert written == "just body\n"
|
||||
|
||||
|
||||
def test_memory_root_property_accessible(tmp_path: Path) -> None:
    """``writer.memory_root.root`` equals the resolved construction path."""
    writer = _make_writer(tmp_path)
    expected_root = tmp_path.resolve()
    assert writer.memory_root.root == expected_root
|
||||
|
||||
|
||||
# ── append_entry ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_append_entry_creates_file_when_missing(tmp_path: Path) -> None:
    """Appending to a path that does not exist creates the file with the
    given frontmatter and a single entry."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "memcells" / "memcell-2026-04-22.md"
    entry_id = EntryId(prefix="umc", date=dt.date(2026, 4, 22), seq=1)
    updates = {
        "file_type": "memcell_daily",
        "entry_count": 1,
    }
    returned = await writer.append_entry(
        destination,
        entry_body="hello world",
        entry_id=entry_id,
        frontmatter_updates=updates,
    )
    assert returned == destination
    result = await MarkdownReader.read(destination)
    assert result.frontmatter == {"file_type": "memcell_daily", "entry_count": 1}
    assert len(result.entries) == 1
    only_entry = result.entries[0]
    assert only_entry.id == "umc_20260422_00000001"
    assert result.entries[0].body == "hello world"
|
||||
|
||||
|
||||
async def test_append_entry_appends_to_existing(tmp_path: Path) -> None:
    """Two appends to one file leave both entries, in append order."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "log.md"
    day = dt.date(2026, 4, 22)
    await writer.append_entry(
        destination,
        entry_body="first",
        entry_id=EntryId(prefix="umc", date=day, seq=1),
        frontmatter_updates={"entry_count": 1},
    )
    await writer.append_entry(
        destination,
        entry_body="second",
        entry_id=EntryId(prefix="umc", date=day, seq=2),
        frontmatter_updates={"entry_count": 2},
    )
    result = await MarkdownReader.read(destination)
    entry_ids = [entry.id for entry in result.entries]
    assert entry_ids == [
        "umc_20260422_00000001",
        "umc_20260422_00000002",
    ]
    entry_bodies = [entry.body for entry in result.entries]
    assert entry_bodies == ["first", "second"]
|
||||
|
||||
|
||||
async def test_append_entry_merges_frontmatter_shallow(tmp_path: Path) -> None:
    """Frontmatter updates merge key-by-key: keys given again are replaced,
    keys not mentioned are kept."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "log.md"
    day = dt.date(2026, 4, 22)
    first_updates = {
        "file_type": "memcell_daily",
        "entry_count": 1,
        "last_appended_at": "2026-04-22T10:00:00Z",
    }
    await writer.append_entry(
        destination,
        entry_body="b",
        entry_id=EntryId(prefix="umc", date=day, seq=1),
        frontmatter_updates=first_updates,
    )
    # The second append replaces entry_count and last_appended_at and does
    # not mention file_type, which must survive.
    second_updates = {
        "entry_count": 2,
        "last_appended_at": "2026-04-22T10:05:00Z",
    }
    await writer.append_entry(
        destination,
        entry_body="b",
        entry_id=EntryId(prefix="umc", date=day, seq=2),
        frontmatter_updates=second_updates,
    )
    result = await MarkdownReader.read(destination)
    assert result.frontmatter == {
        "file_type": "memcell_daily",
        "entry_count": 2,
        "last_appended_at": "2026-04-22T10:05:00Z",
    }
|
||||
|
||||
|
||||
async def test_append_entry_without_frontmatter_updates_keeps_existing(
    tmp_path: Path,
) -> None:
    """Appending without ``frontmatter_updates`` leaves the existing
    frontmatter as it was while still adding the entry."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "log.md"
    await writer.write_markdown(
        destination, frontmatter={"file_type": "x", "n": 1}, body=""
    )
    entry_id = EntryId(prefix="umc", date=dt.date(2026, 4, 22), seq=1)
    await writer.append_entry(
        destination,
        entry_body="body",
        entry_id=entry_id,
    )
    result = await MarkdownReader.read(destination)
    assert result.frontmatter == {"file_type": "x", "n": 1}
    assert len(result.entries) == 1
|
||||
|
||||
|
||||
async def test_append_entry_round_trip_with_reader(tmp_path: Path) -> None:
    """Five sequential appends read back as five entries with matching ids,
    bodies and a final ``entry_count`` of 5."""
    writer = _make_writer(tmp_path)
    destination = tmp_path / "log.md"
    for index in range(5):
        sequence = index + 1
        await writer.append_entry(
            destination,
            entry_body=f"content {index}",
            entry_id=EntryId(prefix="umc", date=dt.date(2026, 4, 22), seq=sequence),
            frontmatter_updates={"entry_count": sequence},
        )
    result = await MarkdownReader.read(destination)
    assert len(result.entries) == 5
    assert result.frontmatter["entry_count"] == 5
    for index, entry in enumerate(result.entries):
        assert entry.id == f"umc_20260422_{index + 1:08d}"
        assert entry.body == f"content {index}"
|
||||
@ -0,0 +1,200 @@
|
||||
"""Regression tests for the MarkdownWriter read-modify-write race.
|
||||
|
||||
Before the per-path :class:`asyncio.Lock` was added, two concurrent tasks
|
||||
calling :meth:`MarkdownWriter.append_entry` against the same path would
|
||||
each load the file, append one entry block in memory, and write the
|
||||
merged file back — the second writer's read pre-dated the first
|
||||
writer's write, so it overwrote the first writer's append. Both
|
||||
``entry_count`` (frontmatter) and the entry block markers were lost in
|
||||
proportion to concurrency level.
|
||||
|
||||
These tests drive ``N`` concurrent appends against one ``(owner, date)``
|
||||
and assert that no entry is lost at any concurrency level. They cover
|
||||
both the single-entry ``append_entry`` path (taken by tests / external
|
||||
callers) and the batched ``append_entries`` path (taken by strategies
|
||||
after the per-owner batching migration).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import EntryId, MarkdownWriter, MemoryRoot
|
||||
from everos.infra.persistence.markdown.writers.atomic_fact_writer import (
|
||||
AtomicFactWriter,
|
||||
)
|
||||
|
||||
|
||||
def _scan_md(md_path: Path) -> tuple[int, int]:
    """Count entries in a markdown file two ways.

    Returns ``(entry_tag_count, frontmatter_entry_count)``: the number of
    ``<!-- entry:af_`` opening markers found in the text, and the integer
    from the first line starting with ``entry_count: <digits>`` (``-1``
    when no such line exists).
    """
    content = md_path.read_text(encoding="utf-8")
    declared = re.search(r"^entry_count: (\d+)", content, re.MULTILINE)
    if declared:
        declared_count = int(declared.group(1))
    else:
        declared_count = -1
    opening_tags = re.findall(r"<!-- entry:af_", content)
    return len(opening_tags), declared_count
|
||||
|
||||
|
||||
async def _drive_concurrent_appends(
    writer: AtomicFactWriter,
    owner: str,
    n: int,
    concurrency: int,
) -> None:
    """Run ``n`` single-entry ``append_entry`` calls for *owner*, allowing at
    most ``concurrency`` of them past the semaphore at a time."""
    gate = asyncio.Semaphore(concurrency)

    async def _append_one(idx: int) -> None:
        # Each call carries a distinct parent_id / Fact derived from idx.
        inline_fields = {
            "owner_id": owner,
            "session_id": "race_test",
            "timestamp": "2026-05-18T00:00:00+00:00",
            "parent_type": "memcell",
            "parent_id": f"mc_{idx:04d}",
        }
        async with gate:
            await writer.append_entry(
                owner,
                inline=inline_fields,
                sections={"Fact": f"fact-{idx:04d}"},
            )

    pending = [_append_one(i) for i in range(n)]
    await asyncio.gather(*pending)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("concurrency", [1, 2, 4, 8, 16])
async def test_append_entry_no_lost_updates_under_concurrency(
    tmp_path: Path, concurrency: int
) -> None:
    """Thirty concurrent ``append_entry`` calls for one owner must leave one
    file holding thirty entry markers and ``entry_count: 30``."""
    root = MemoryRoot(root=tmp_path)
    writer = AtomicFactWriter(root=root)
    owner = "race_user"
    n = 30

    await _drive_concurrent_appends(writer, owner, n, concurrency)

    owner_dir = root.users_dir() / owner
    md_files = list(owner_dir.rglob("*.md"))
    assert len(md_files) == 1, f"expected 1 md file, got {md_files}"
    tag_count, fm_count = _scan_md(md_files[0])

    # Marker count and frontmatter counter are checked separately so a
    # failure says which of the two drifted.
    assert tag_count == n, (
        f"lost {n - tag_count} entries at concurrency={concurrency} "
        f"(tag_count={tag_count}, expected={n})"
    )
    assert fm_count == n, (
        f"frontmatter entry_count drift at concurrency={concurrency} "
        f"(fm_count={fm_count}, expected={n})"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("concurrency", [1, 2, 4, 8, 16])
async def test_append_entries_batch_no_lost_updates_under_concurrency(
    tmp_path: Path, concurrency: int
) -> None:
    """Six concurrent ``append_entries`` batches of five items each must
    leave one file holding all thirty entries and ``entry_count: 30``."""
    root = MemoryRoot(root=tmp_path)
    writer = AtomicFactWriter(root=root)
    owner = "race_user_batched"
    batches = 6
    items_per_batch = 5
    total = batches * items_per_batch

    gate = asyncio.Semaphore(concurrency)

    async def _one_batch(batch_idx: int) -> None:
        async with gate:
            # Each item is an (inline fields, sections) pair.
            items = [
                (
                    {
                        "owner_id": owner,
                        "session_id": "race_test",
                        "timestamp": "2026-05-18T00:00:00+00:00",
                        "parent_type": "memcell",
                        "parent_id": f"mc_b{batch_idx:02d}_i{i:02d}",
                    },
                    {"Fact": f"batched-fact-b{batch_idx:02d}-{i:02d}"},
                )
                for i in range(items_per_batch)
            ]
            await writer.append_entries(owner, items)

    pending = [_one_batch(b) for b in range(batches)]
    await asyncio.gather(*pending)

    owner_dir = root.users_dir() / owner
    md_files = list(owner_dir.rglob("*.md"))
    assert len(md_files) == 1
    tag_count, fm_count = _scan_md(md_files[0])

    assert tag_count == total, (
        f"lost {total - tag_count} entries at concurrency={concurrency} "
        f"(tag_count={tag_count}, expected={total})"
    )
    assert fm_count == total, (
        f"frontmatter entry_count drift at concurrency={concurrency} "
        f"(fm_count={fm_count}, expected={total})"
    )
|
||||
|
||||
|
||||
async def test_lock_for_returns_same_lock_per_path(tmp_path: Path) -> None:
    """``lock_for`` must hand out one lock object per canonical path.

    BaseDailyWriter relies on this keying to serialise its multi-step
    read-compute-write sequence, so differently spelled paths to the same
    file have to collapse to one lock.
    """
    writer = MarkdownWriter(MemoryRoot(root=tmp_path))
    direct = tmp_path / "foo" / "bar.md"
    direct_again = tmp_path / "foo" / "bar.md"
    via_dotdot = tmp_path / "foo" / ".." / "foo" / "bar.md"

    lock_direct = writer.lock_for(direct)
    lock_direct_again = writer.lock_for(direct_again)
    lock_via_dotdot = writer.lock_for(via_dotdot)

    # Equal and ``..``-aliased spellings share the identical lock object.
    assert lock_direct is lock_direct_again
    assert lock_direct is lock_via_dotdot

    # A different file gets its own lock.
    sibling_lock = writer.lock_for(tmp_path / "foo" / "baz.md")
    assert sibling_lock is not lock_direct
|
||||
|
||||
|
||||
async def test_append_entries_empty_is_noop(tmp_path: Path) -> None:
    """An empty batch returns the target path and writes no entry content.

    The file is allowed to be absent; if it does exist, it must be either
    empty or an empty frontmatter block.
    """
    writer = MarkdownWriter(MemoryRoot(root=tmp_path))
    scratch = tmp_path / "scratch.md"
    returned = await writer.append_entries(scratch, [])
    assert returned == scratch
    # An empty body with no frontmatter updates may still go through
    # write_markdown, so accept an absent file or one with no content.
    if scratch.exists():
        assert scratch.read_text(encoding="utf-8") in ("", "---\n---\n\n")
|
||||
|
||||
|
||||
async def test_markdown_writer_append_entry_delegates_to_batch(
    tmp_path: Path,
) -> None:
    """``append_entry`` and a one-item ``append_entries`` write identical files.

    ``append_entry`` is documented as a wrapper for ``append_entries``;
    comparing the two outputs byte-for-byte protects callers from drift
    between the two paths.
    """
    # Local import: this module does not import datetime at the top, and a
    # plain import is clearer than the ``__import__("datetime")`` expression.
    import datetime as dt

    writer = MarkdownWriter(MemoryRoot(root=tmp_path))
    eid = EntryId.next_for("af", dt.date(2026, 5, 18), 0)
    body = "**fact**: hello"
    frontmatter_updates = {"id": "shared", "entry_count": 1}

    path_a = tmp_path / "a.md"
    path_b = tmp_path / "b.md"

    # Same entry and frontmatter through the single-entry API ...
    await writer.append_entry(
        path_a,
        entry_body=body,
        entry_id=eid,
        frontmatter_updates=dict(frontmatter_updates),
    )
    # ... and through the batch API with a one-element batch.
    await writer.append_entries(
        path_b,
        [(body, eid)],
        frontmatter_updates=dict(frontmatter_updates),
    )

    assert path_a.read_text(encoding="utf-8") == path_b.read_text(encoding="utf-8")
|
||||
126
tests/unit/test_core/test_persistence/test_memory_root.py
Normal file
126
tests/unit/test_core/test_persistence/test_memory_root.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""Unit tests for MemoryRoot path manager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import MemoryRoot
|
||||
|
||||
|
||||
def test_default_returns_home_everos(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no env override, ``MemoryRoot.default()`` is ``~/.everos``."""
    # An ambient EVEROS_MEMORY__ROOT may be present (e.g. the session-scoped
    # search-corpus fixture sets it for the whole run). Removing it here is
    # enough because the autouse _reset_settings_cache fixture clears the
    # load_settings cache, so the hard-coded default is what gets asserted.
    monkeypatch.delenv("EVEROS_MEMORY__ROOT", raising=False)
    default_root = MemoryRoot.default()
    expected = (Path.home() / ".everos").resolve()
    assert default_root.root == expected
|
||||
|
||||
|
||||
def test_accepts_str_path(tmp_path: Path) -> None:
    """A ``str`` root is accepted and exposed as the resolved path."""
    root_as_text = str(tmp_path)
    assert MemoryRoot(root_as_text).root == tmp_path.resolve()
|
||||
|
||||
|
||||
def test_accepts_pathlib_path(tmp_path: Path) -> None:
    """A ``pathlib.Path`` root is accepted and exposed as the resolved path."""
    expected = tmp_path.resolve()
    assert MemoryRoot(tmp_path).root == expected
|
||||
|
||||
|
||||
def test_user_visible_dirs_default_scope(tmp_path: Path) -> None:
    """With no app/project given, the dirs sit under the reserved default names."""
    mr = MemoryRoot(tmp_path)
    # Leaving out app/project selects the default space, which is stored on
    # disk under the reserved ``default_app`` / ``default_project`` names.
    default_space = mr.root / "default_app" / "default_project"
    accessors = (
        (mr.agents_dir, "agents"),
        (mr.users_dir, "users"),
        (mr.knowledge_dir, "knowledge"),
    )
    for resolve_dir, leaf in accessors:
        assert resolve_dir() == default_space / leaf
|
||||
|
||||
|
||||
def test_user_visible_dirs_named_scope(tmp_path: Path) -> None:
    """An explicit app/project pair is used verbatim as the directory names."""
    mr = MemoryRoot(tmp_path)
    # Non-default names map to themselves; no ``default_*`` rewrite happens.
    named_space = mr.root / "claude_code" / "oss"
    accessors = (
        (mr.agents_dir, "agents"),
        (mr.users_dir, "users"),
        (mr.knowledge_dir, "knowledge"),
    )
    for resolve_dir, leaf in accessors:
        assert resolve_dir("claude_code", "oss") == named_space / leaf
|
||||
|
||||
|
||||
def test_dotfile_paths(tmp_path: Path) -> None:
    """The internal dot-paths sit at fixed locations relative to the root."""
    mr = MemoryRoot(tmp_path)
    # Build expectations from the resolved directory, the same way the
    # constructor tests compare ``root``; an unresolved ``tmp_path`` would
    # produce a spurious mismatch if the temp dir is reached via a symlink.
    base = tmp_path.resolve()
    index = base / ".index"
    sqlite = index / "sqlite"
    assert mr.index_dir == index
    assert mr.lancedb_dir == index / "lancedb"
    assert mr.sqlite_dir == sqlite
    assert mr.system_db == sqlite / "system.db"
    assert mr.lock_file == base / ".lock"
    assert mr.tmp_dir == base / ".tmp"
|
||||
|
||||
|
||||
def test_ensure_creates_required_dirs(tmp_path: Path) -> None:
    """``ensure()`` creates the internal directories but no user-visible ones."""
    mr = MemoryRoot(tmp_path / "fresh")
    mr.ensure()
    internal_dirs = (
        mr.root,
        mr.index_dir,
        mr.sqlite_dir,
        mr.lancedb_dir,
        mr.tmp_dir,
    )
    for created in internal_dirs:
        assert created.is_dir()
    # User-visible dirs are NOT pre-created.
    for user_dir in (mr.agents_dir(), mr.users_dir(), mr.knowledge_dir()):
        assert not user_dir.exists()
|
||||
|
||||
|
||||
def test_ensure_is_idempotent(tmp_path: Path) -> None:
    """Calling ``ensure()`` twice in a row succeeds and the dirs still exist."""
    mr = MemoryRoot(tmp_path)
    for _ in range(2):
        # The second pass is the one under test: it must not fail.
        mr.ensure()
    assert mr.tmp_dir.is_dir()
|
||||
|
||||
|
||||
def test_ensure_materializes_ome_config_template(tmp_path: Path) -> None:
    """The first ``ensure()`` writes an editable ``ome.toml`` for the user.

    Before this existed, ``pip install everos && everos server start`` logged
    a warning (``config_reload_failed: No such file``) because the OME config
    reloader had nothing to read. The template is shipped at
    ``src/everos/config/default_ome.toml`` and copied byte-for-byte on the
    first run.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    assert mr.ome_config.is_file()
    # The materialised file must equal the shipped template exactly, so a
    # future diff cannot silently change what users see on first run.
    repo_root = Path(__file__).resolve().parents[4]
    shipped_template = repo_root / ("src/everos/config/default_ome.toml")
    assert mr.ome_config.read_bytes() == shipped_template.read_bytes()
|
||||
|
||||
|
||||
def test_ensure_preserves_user_edited_ome_config(tmp_path: Path) -> None:
    """A later ``ensure()`` leaves a user-modified ``ome.toml`` untouched.

    Materialising the template is gated on the file existing, not on its
    content matching; once the user has edited their overrides the file
    belongs to them.
    """
    mr = MemoryRoot(tmp_path)
    mr.ensure()
    user_bytes = b"# user-edited\n[strategies.extract_foresight]\nenabled = false\n"
    mr.ome_config.write_bytes(user_bytes)
    mr.ensure()
    assert mr.ome_config.read_bytes() == user_bytes
|
||||
|
||||
|
||||
def test_frozen_dataclass_hashable(tmp_path: Path) -> None:
    """Two ``MemoryRoot`` objects for one path are equal, hash alike and dedupe."""
    first = MemoryRoot(tmp_path)
    second = MemoryRoot(tmp_path)
    assert first == second
    assert hash(first) == hash(second)
    # Set deduplication works: both instances collapse into one member.
    assert {first, second} == {first}
|
||||
|
||||
|
||||
def test_user_expansion(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """A leading ``~`` in the root is expanded to the user's home directory."""
    # ``HOME`` drives ``~`` expansion on POSIX; on Windows the standard
    # library consults ``USERPROFILE`` instead, so both are redirected to keep
    # the test platform-independent.
    home = str(tmp_path)
    monkeypatch.setenv("HOME", home)
    monkeypatch.setenv("USERPROFILE", home)
    mr = MemoryRoot("~/custom")
    assert mr.root == (tmp_path / "custom").resolve()
|
||||
113
tests/unit/test_core/test_persistence/test_sqlite/test_engine.py
Normal file
113
tests/unit/test_core/test_persistence/test_sqlite/test_engine.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Unit tests for the SQLite async engine + PRAGMA listener.
|
||||
|
||||
Critical: verifies PRAGMAs are actually applied at the SQLite layer
|
||||
(not just declared in code). The whole reason for the listener is that
|
||||
PRAGMAs are per-connection and the SA pool reuses connections.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from everos.config import SqliteSettings
|
||||
from everos.core.persistence import (
|
||||
MemoryRoot,
|
||||
create_session_factory,
|
||||
create_system_engine,
|
||||
session_scope,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` at ``tmp_path`` on which ``ensure()`` has run."""
    prepared = MemoryRoot(tmp_path)
    prepared.ensure()
    return prepared
|
||||
|
||||
|
||||
async def test_engine_creates_db_file(memory_root: MemoryRoot) -> None:
    """After one statement runs through a session, ``system.db`` exists on disk."""
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    try:
        async with session_scope(factory) as s:
            await s.execute(text("SELECT 1"))
    finally:
        # Dispose even when the statement fails, as the other tests in this
        # module do, so a failure here does not leak the engine.
        await engine.dispose()
    assert memory_root.system_db.exists()
|
||||
|
||||
|
||||
async def test_pragmas_actually_applied_default_settings(
    memory_root: MemoryRoot,
) -> None:
    """Default PRAGMAs match what's in default.toml."""
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    expected_pragmas = (
        ("journal_mode", "wal"),
        ("synchronous", 1),  # 0=OFF 1=NORMAL 2=FULL 3=EXTRA
        ("foreign_keys", 1),  # 1=ON 0=OFF
        ("temp_store", 2),  # 0=DEFAULT 1=FILE 2=MEMORY
        ("busy_timeout", 5000),
        ("journal_size_limit", 64 * 1024 * 1024),
        ("cache_size", -2048),  # negative value = KB; positive = pages
    )
    try:
        async with session_scope(factory) as session:
            for pragma_name, wanted in expected_pragmas:
                assert _scalar(await _pragma(session, pragma_name)) == wanted
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_pragmas_respect_custom_settings(memory_root: MemoryRoot) -> None:
    """Every non-default ``SqliteSettings`` tunable shows up in its PRAGMA."""
    settings = SqliteSettings(
        journal_mode="DELETE",
        synchronous="FULL",
        foreign_keys=False,
        temp_store="FILE",
        busy_timeout_ms=10000,
        journal_size_limit_bytes=1024 * 1024,
        cache_size_kb=4096,
    )
    engine = create_system_engine(memory_root.system_db, settings)
    factory = create_session_factory(engine)
    try:
        async with session_scope(factory) as s:
            assert _scalar(await _pragma(s, "journal_mode")) == "delete"
            assert _scalar(await _pragma(s, "synchronous")) == 2  # FULL
            assert _scalar(await _pragma(s, "foreign_keys")) == 0
            assert _scalar(await _pragma(s, "temp_store")) == 1  # FILE
            assert _scalar(await _pragma(s, "busy_timeout")) == 10000
            # The custom journal size limit is configured above, so it is
            # verified too instead of being set and never read back.
            assert _scalar(await _pragma(s, "journal_size_limit")) == 1024 * 1024
            # cache_size: a negative value is a size in KB.
            assert _scalar(await _pragma(s, "cache_size")) == -4096
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_pragmas_applied_on_each_new_connection(
    memory_root: MemoryRoot,
) -> None:
    """Two sessions opened one after the other both report WAL journal mode.

    The intent is that the PRAGMA listener fires for every connection the
    pool hands out, not only the first.
    NOTE(review): whether the second session really gets a fresh connection
    depends on the engine's pool configuration, which is not visible here.
    """
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    try:
        # Two independent sessions mean at least two connection acquisitions,
        # and each of them must see WAL mode.
        for _ in range(2):
            async with session_scope(factory) as session:
                assert _scalar(await _pragma(session, "journal_mode")) == "wal"
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def _pragma(session, name: str):  # type: ignore[no-untyped-def]
    """Execute ``PRAGMA <name>`` on *session* and return the result object."""
    statement = text(f"PRAGMA {name}")
    return await session.execute(statement)
|
||||
|
||||
|
||||
def _scalar(result):  # type: ignore[no-untyped-def]
    """Return column 0 of the next row of *result*, or ``None`` if no row is left."""
    first_row = result.fetchone()
    if first_row is None:
        return None
    return first_row[0]
|
||||
@ -0,0 +1,126 @@
|
||||
"""ORM CRUD demo: full INSERT / SELECT / UPDATE / DELETE on a BaseTable.
|
||||
|
||||
Doubles as living documentation for how to author a SQLModel-backed
|
||||
business table inside the everos persistence stack:
|
||||
|
||||
1. Subclass ``BaseTable`` (gets ``created_at`` / ``updated_at`` for free).
|
||||
2. Build a session factory from a real engine.
|
||||
3. Use ``session_scope`` for the transaction lifecycle.
|
||||
4. Verify ``updated_at`` auto-bumps on UPDATE.
|
||||
|
||||
The local table name is prefixed with ``_`` so it cannot be confused with
|
||||
a real business table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlmodel import SQLModel, select
|
||||
|
||||
from everos.config import SqliteSettings
|
||||
from everos.core.persistence import (
|
||||
BaseTable,
|
||||
Field,
|
||||
MemoryRoot,
|
||||
create_session_factory,
|
||||
create_system_engine,
|
||||
session_scope,
|
||||
)
|
||||
|
||||
|
||||
class _DemoNote(BaseTable, table=True):
    """Throwaway note table that exists only for this module's CRUD walkthrough."""

    __tablename__ = "_demo_notes"  # type: ignore[assignment]

    # Left unset on construction; the test checks it is populated after
    # commit + refresh.
    id: int | None = Field(default=None, primary_key=True)
    # Required text payload of the note.
    body: str
    # Optional free-form tag string, ``None`` by default.
    tags: str | None = Field(default=None)
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Return a ``MemoryRoot`` for ``tmp_path`` after calling ``ensure()`` on it."""
    root = MemoryRoot(tmp_path)
    root.ensure()
    return root
|
||||
|
||||
|
||||
async def test_orm_full_crud_lifecycle(memory_root: MemoryRoot) -> None:
    """Drive one ``_DemoNote`` row through INSERT, SELECT, UPDATE and DELETE.

    Each step runs inside its own ``session_scope``; the engine is disposed
    in ``finally`` whatever the outcome.
    """
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    try:
        # -- Schema: create every table registered on the SQLModel metadata.
        async with engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)

        # -- INSERT: key and timestamps are populated after commit + refresh.
        async with session_scope(factory) as session:
            first = _DemoNote(body="hello")
            session.add(first)
            await session.commit()
            await session.refresh(first)
            assert first.id is not None
            assert first.created_at is not None
            assert first.updated_at is not None
            # default_factory runs once per field, so the two timestamps may
            # be a few microseconds apart on INSERT; only the order is fixed.
            assert first.created_at <= first.updated_at
            note_id = first.id
            initial_created = first.created_at
            initial_updated = first.updated_at

        # The same primary-key lookup is used by every later step.
        by_id = select(_DemoNote).where(_DemoNote.id == note_id)

        # -- SELECT (single row by id).
        async with session_scope(factory) as session:
            fetched = (await session.execute(by_id)).scalars().first()
            assert fetched is not None
            assert fetched.body == "hello"

        # -- SELECT (several rows, ordered by id).
        async with session_scope(factory) as session:
            for extra_body in ("second", "third"):
                session.add(_DemoNote(body=extra_body))
            await session.commit()

        async with session_scope(factory) as session:
            ordered = select(_DemoNote).order_by(_DemoNote.id)
            notes = (await session.execute(ordered)).scalars().all()
            assert [note.body for note in notes] == ["hello", "second", "third"]

        # -- UPDATE: ``updated_at`` must not move backwards.
        # A short pause gives onupdate a chance to produce a newer timestamp
        # than the insert; the comparison stays ``>=`` so it is robust on
        # fast machines.
        await asyncio.sleep(0.01)
        async with session_scope(factory) as session:
            target = (await session.execute(by_id)).scalars().first()
            assert target is not None
            target.body = "hello world"
            target.tags = "demo"
            await session.commit()
            await session.refresh(target)
            assert target.body == "hello world"
            assert target.tags == "demo"
            assert target.updated_at >= initial_updated  # bumped via onupdate
            assert target.created_at == initial_created  # unchanged on update

        # -- DELETE: remove the first row only.
        async with session_scope(factory) as session:
            doomed = (await session.execute(by_id)).scalars().first()
            assert doomed is not None
            await session.delete(doomed)
            await session.commit()

        async with session_scope(factory) as session:
            assert (await session.execute(by_id)).scalars().first() is None
            # The other rows survive.
            survivors = (await session.execute(select(_DemoNote))).scalars().all()
            assert {note.body for note in survivors} == {"second", "third"}
    finally:
        await engine.dispose()
|
||||
@ -0,0 +1,160 @@
|
||||
"""RepoBase CRUD demo + assertions.
|
||||
|
||||
Doubles as living documentation for how a service / memory layer caller
|
||||
uses the generic repository — no manual session handling. Exercises the
|
||||
explicit-factory constructor path; the lazy ``_factory_lookup`` hook is
|
||||
exercised indirectly via the lifespan + manager tests once business
|
||||
repos land under ``infra/.../repos/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from everos.config import SqliteSettings
|
||||
from everos.core.persistence import (
|
||||
BaseTable,
|
||||
Field,
|
||||
MemoryRoot,
|
||||
RepoBase,
|
||||
create_session_factory,
|
||||
create_system_engine,
|
||||
)
|
||||
|
||||
|
||||
class _DemoUser(BaseTable, table=True):
    """Throwaway user table referenced only by the repository tests here."""

    __tablename__ = "_demo_users"  # type: ignore[assignment]

    # Unset until the row is stored; the tests check it is populated after add.
    id: int | None = Field(default=None, primary_key=True)
    # Required display name used by the lookups in the tests.
    name: str
    # Activity flag; ``True`` unless given explicitly.
    active: bool = Field(default=True)
|
||||
|
||||
|
||||
class _DemoUserRepo(RepoBase[_DemoUser]):
    """``RepoBase`` specialisation bound to the ``_DemoUser`` demo table."""

    model = _DemoUser
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Build a ``MemoryRoot`` under ``tmp_path`` and run ``ensure()`` on it."""
    ready_root = MemoryRoot(tmp_path)
    ready_root.ensure()
    return ready_root
|
||||
|
||||
|
||||
async def _setup_repo(memory_root: MemoryRoot) -> tuple[_DemoUserRepo, object]:
    """Create the engine and schema, then return ``(repo, engine)``.

    The caller owns the returned engine and has to dispose it.
    """
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    session_factory = create_session_factory(engine)
    async with engine.begin() as connection:
        await connection.run_sync(SQLModel.metadata.create_all)
    repo = _DemoUserRepo(session_factory)
    return repo, engine
|
||||
|
||||
|
||||
async def test_repo_add_and_get(memory_root: MemoryRoot) -> None:
    """``add`` stores a row that ``get_by_id`` then returns with its defaults."""
    repo, engine = await _setup_repo(memory_root)
    try:
        stored = await repo.add(_DemoUser(name="alice"))
        assert stored.id is not None
        assert stored.created_at is not None

        loaded = await repo.get_by_id(stored.id)
        assert loaded is not None
        assert loaded.name == "alice"
        assert loaded.active is True
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_repo_add_many_and_list_all(memory_root: MemoryRoot) -> None:
    """``add_many`` stores every row; ``list_all`` and ``count`` see all three."""
    repo, engine = await _setup_repo(memory_root)
    try:
        batch = [
            _DemoUser(name="alice"),
            _DemoUser(name="bob"),
            _DemoUser(name="carol", active=False),
        ]
        stored = await repo.add_many(batch)
        assert all(user.id is not None for user in stored)

        listed = await repo.list_all()
        assert {user.name for user in listed} == {"alice", "bob", "carol"}

        total = await repo.count()
        assert total == 3
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_repo_find_where_and_find_one(memory_root: MemoryRoot) -> None:
    """``find_where`` filters by keyword; ``find_one`` gives a row or ``None``."""
    repo, engine = await _setup_repo(memory_root)
    try:
        seed_rows = [
            _DemoUser(name="alice", active=True),
            _DemoUser(name="bob", active=False),
            _DemoUser(name="carol", active=True),
        ]
        await repo.add_many(seed_rows)

        active_users = await repo.find_where(active=True)
        assert {user.name for user in active_users} == {"alice", "carol"}

        found = await repo.find_one(name="bob")
        assert found is not None
        assert found.active is False

        missing = await repo.find_one(name="no_such")
        assert missing is None
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_repo_update_bumps_updated_at(memory_root: MemoryRoot) -> None:
    """``update`` saves changed fields, keeps ``created_at``, never rewinds ``updated_at``."""
    repo, engine = await _setup_repo(memory_root)
    try:
        user = await repo.add(_DemoUser(name="alice"))
        updated_before = user.updated_at
        created_before = user.created_at

        # Short pause so the update can receive a later timestamp.
        await asyncio.sleep(0.01)
        user.name = "alice2"
        user.active = False
        result = await repo.update(user)

        assert result.name == "alice2"
        assert result.active is False
        assert result.updated_at >= updated_before  # bumped
        assert result.created_at == created_before
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_repo_delete_by_instance(memory_root: MemoryRoot) -> None:
    """``delete`` removes the given instance so it is no longer counted or found."""
    repo, engine = await _setup_repo(memory_root)
    try:
        victim = await repo.add(_DemoUser(name="alice"))
        assert await repo.count() == 1

        await repo.delete(victim)
        assert await repo.count() == 0
        assert await repo.get_by_id(victim.id) is None
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_repo_delete_by_id_returns_bool(memory_root: MemoryRoot) -> None:
    """``delete_by_id`` returns ``True`` only when it actually removed a row."""
    repo, engine = await _setup_repo(memory_root)
    try:
        stored = await repo.add(_DemoUser(name="alice"))

        first_attempt = await repo.delete_by_id(stored.id)
        assert first_attempt is True
        second_attempt = await repo.delete_by_id(stored.id)
        assert second_attempt is False  # already gone
        unknown_id_attempt = await repo.delete_by_id(99999)
        assert unknown_id_attempt is False  # never existed
    finally:
        await engine.dispose()
|
||||
@ -0,0 +1,78 @@
|
||||
"""Unit tests for session_scope semantics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from everos.config import SqliteSettings
|
||||
from everos.core.persistence import (
|
||||
MemoryRoot,
|
||||
create_session_factory,
|
||||
create_system_engine,
|
||||
session_scope,
|
||||
)
|
||||
|
||||
|
||||
class _Sample(SQLModel, table=True):
    """Minimal table the ``session_scope`` tests in this module write to."""

    __tablename__ = "_sample_session_scope"  # type: ignore[assignment]
    # Primary key, ``None`` by default on construction.
    id: int | None = Field(default=None, primary_key=True)
    # Text payload the tests insert and read back.
    note: str
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Yield-free fixture: a ``MemoryRoot`` at ``tmp_path`` with ``ensure()`` applied."""
    memory = MemoryRoot(tmp_path)
    memory.ensure()
    return memory
|
||||
|
||||
|
||||
async def test_session_scope_commits_on_success(memory_root: MemoryRoot) -> None:
    """A row committed inside one scope can be read from a later scope."""
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    try:
        async with engine.begin() as connection:
            await connection.run_sync(SQLModel.metadata.create_all)

        async with session_scope(factory) as writer:
            writer.add(_Sample(note="hello"))
            await writer.commit()

        async with session_scope(factory) as reader:
            result = await reader.execute(
                text("SELECT note FROM _sample_session_scope")
            )
            stored = result.fetchone()
            assert stored is not None
            assert stored[0] == "hello"
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def test_session_scope_rolls_back_on_exception(
    memory_root: MemoryRoot,
) -> None:
    """An exception raised inside the scope leaves no row behind."""
    engine = create_system_engine(memory_root.system_db, SqliteSettings())
    factory = create_session_factory(engine)
    try:
        async with engine.begin() as connection:
            await connection.run_sync(SQLModel.metadata.create_all)

        with pytest.raises(RuntimeError):
            async with session_scope(factory) as failing:
                failing.add(_Sample(note="should rollback"))
                # Nothing was committed, so the scope has to roll back when
                # the exception propagates.
                raise RuntimeError("boom")

        async with session_scope(factory) as reader:
            result = await reader.execute(
                text("SELECT COUNT(*) FROM _sample_session_scope")
            )
            tally = result.fetchone()
            assert tally is not None
            assert tally[0] == 0
    finally:
        await engine.dispose()
|
||||
0
tests/unit/test_entrypoints/__init__.py
Normal file
0
tests/unit/test_entrypoints/__init__.py
Normal file
0
tests/unit/test_entrypoints/test_api/__init__.py
Normal file
0
tests/unit/test_entrypoints/test_api/__init__.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""``CascadeLifespanProvider`` — startup builds orchestrator, shutdown stops it."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from everos.entrypoints.api.lifespans import cascade as cascade_lifespan_mod
|
||||
from everos.entrypoints.api.lifespans.cascade import CascadeLifespanProvider
|
||||
|
||||
|
||||
class _StubOrchestrator:
    """Orchestrator stand-in that only counts ``start`` / ``stop`` calls."""

    def __init__(self, *args: object, **kwargs: object) -> None:
        # Any constructor arguments are accepted and ignored.
        self.start_calls = self.stop_calls = 0

    async def start(self) -> None:
        """Record one ``start`` invocation."""
        self.start_calls = self.start_calls + 1

    async def stop(self) -> None:
        """Record one ``stop`` invocation."""
        self.stop_calls = self.stop_calls + 1
|
||||
|
||||
|
||||
def test_provider_metadata() -> None:
    """The provider reports the name ``cascade`` and the order it was given."""
    provider = CascadeLifespanProvider(order=42)
    assert (provider.name, provider.order) == ("cascade", 42)
|
||||
|
||||
|
||||
async def test_startup_constructs_and_starts_orchestrator(
    tmp_path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``startup`` builds exactly one orchestrator, starts it and returns it."""
    stub_env = {
        "EVEROS_MEMORY__ROOT": str(tmp_path),
        "EVEROS_EMBEDDING__MODEL": "stub-model",
        "EVEROS_EMBEDDING__BASE_URL": "http://stub.invalid/v1",
        "EVEROS_EMBEDDING__API_KEY": "stub-key",
    }
    for variable, value in stub_env.items():
        monkeypatch.setenv(variable, value)

    built: list[_StubOrchestrator] = []

    def make_stub(**kwargs: object) -> _StubOrchestrator:
        # Swapped in for ``CascadeOrchestrator``; remembers every instance.
        stub = _StubOrchestrator()
        built.append(stub)
        return stub

    monkeypatch.setattr(cascade_lifespan_mod, "CascadeOrchestrator", make_stub)

    provider = CascadeLifespanProvider()
    returned = await provider.startup(FastAPI())
    assert len(built) == 1
    assert returned is built[0]
    assert built[0].start_calls == 1
|
||||
|
||||
|
||||
async def test_shutdown_without_startup_is_noop() -> None:
    """``shutdown`` on a provider that never started completes without raising."""
    provider = CascadeLifespanProvider()
    app = FastAPI()
    await provider.shutdown(app)  # must not raise
|
||||
|
||||
|
||||
async def test_shutdown_stops_orchestrator_and_clears_reference(
    tmp_path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``shutdown`` stops the orchestrator once; a repeated call does nothing."""
    stub_env = {
        "EVEROS_MEMORY__ROOT": str(tmp_path),
        "EVEROS_EMBEDDING__MODEL": "stub-model",
        "EVEROS_EMBEDDING__BASE_URL": "http://stub.invalid/v1",
        "EVEROS_EMBEDDING__API_KEY": "stub-key",
    }
    for variable, value in stub_env.items():
        monkeypatch.setenv(variable, value)

    built: list[_StubOrchestrator] = []

    def make_stub(**kwargs: object) -> _StubOrchestrator:
        # Swapped in for ``CascadeOrchestrator``; remembers every instance.
        stub = _StubOrchestrator()
        built.append(stub)
        return stub

    monkeypatch.setattr(cascade_lifespan_mod, "CascadeOrchestrator", make_stub)

    provider = CascadeLifespanProvider()
    app = FastAPI()
    await provider.startup(app)
    await provider.shutdown(app)
    assert built[0].stop_calls == 1
    # The reference was cleared, so a second shutdown must not stop again.
    await provider.shutdown(app)
    assert built[0].stop_calls == 1
|
||||
@ -0,0 +1,45 @@
|
||||
"""LLMLifespanProvider — startup raises on missing credentials, otherwise resolves."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from everos.component.llm import LLMNotConfiguredError
|
||||
from everos.entrypoints.api.lifespans import LLMLifespanProvider
|
||||
|
||||
|
||||
async def test_startup_raises_when_credentials_missing() -> None:
    """``startup`` lets ``LLMNotConfiguredError`` from the client factory propagate."""
    provider = LLMLifespanProvider()
    app = FastAPI()

    failing_factory = patch(
        "everos.entrypoints.api.lifespans.llm.get_llm_client",
        side_effect=LLMNotConfiguredError("missing api_key"),
    )
    with failing_factory:
        with pytest.raises(LLMNotConfiguredError):
            await provider.startup(app)
|
||||
|
||||
|
||||
async def test_startup_returns_client_when_configured() -> None:
    """``startup`` returns exactly the object the client factory produced."""
    provider = LLMLifespanProvider()
    app = FastAPI()
    fake_client = object()

    working_factory = patch(
        "everos.entrypoints.api.lifespans.llm.get_llm_client",
        return_value=fake_client,
    )
    with working_factory:
        returned = await provider.startup(app)

    assert returned is fake_client
|
||||
|
||||
|
||||
async def test_shutdown_is_noop() -> None:
    """``shutdown`` completes without raising even though nothing was started."""
    app = FastAPI()
    # Should not raise; the algo client is stateless.
    await LLMLifespanProvider().shutdown(app)
|
||||
@ -0,0 +1,34 @@
|
||||
"""OmeLifespanProvider — startup wires engine, shutdown stops it."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from everos.entrypoints.api.lifespans import OmeLifespanProvider
|
||||
|
||||
|
||||
async def test_lifespan_starts_and_stops_engine(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``startup`` returns a started engine and ``shutdown`` stops it again."""
    from everos.core.persistence import MemoryRoot

    # Imported by dotted path so the test can patch the module-level
    # ``_ome_engine`` attribute on the module object itself.
    svc = importlib.import_module("everos.service.memorize")

    # Point the default memory root at the temp dir and start from a state
    # with no engine recorded on the service module.
    monkeypatch.setattr(
        MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
    )
    monkeypatch.setattr(svc, "_ome_engine", None, raising=False)

    provider = OmeLifespanProvider()
    app = FastAPI()

    engine = await provider.startup(app)
    try:
        assert engine is not None
        assert engine._started is True  # noqa: SLF001 — test introspection
    finally:
        # Always shut down, so a failed assertion above cannot leave a
        # started engine behind for the remaining tests.
        await provider.shutdown(app)
    assert engine._started is False  # noqa: SLF001
|
||||
@ -0,0 +1,72 @@
|
||||
"""SQLite + LanceDB lifespan providers — startup wires singletons, shutdown disposes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from everos.entrypoints.api.lifespans import (
|
||||
LanceDBLifespanProvider,
|
||||
SqliteLifespanProvider,
|
||||
)
|
||||
from everos.infra.persistence.lancedb import lancedb_manager
|
||||
from everos.infra.persistence.sqlite import sqlite_manager
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Point both managers at an isolated memory root and dispose them afterwards."""
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    # Clear whatever the module-level managers are currently holding so each
    # test starts with nothing wired.
    sqlite_manager._engine = sqlite_manager._session_factory = None
    lancedb_manager._conn = None
    lancedb_manager._tables.clear()
    yield
    # Teardown: release anything the test wired up.
    await sqlite_manager.dispose_engine()
    await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
async def test_sqlite_provider_startup_builds_engine_and_creates_schema(
    tmp_path: Path,
) -> None:
    """``startup`` returns the manager's engine and the database file appears."""
    provider = SqliteLifespanProvider()
    app = FastAPI()

    engine = await provider.startup(app)

    # The returned engine is the singleton the manager hands out.
    assert engine is sqlite_manager.get_engine()
    # Schema create_all opened the file, so it must exist now.
    db_file = tmp_path / ".index" / "sqlite" / "system.db"
    assert db_file.exists()
|
||||
|
||||
|
||||
async def test_sqlite_provider_shutdown_disposes_singleton() -> None:
    """``shutdown`` resets the manager's ``_engine`` to ``None``."""
    app = FastAPI()
    provider = SqliteLifespanProvider()
    await provider.startup(app)
    assert sqlite_manager._engine is not None

    await provider.shutdown(app)
    assert sqlite_manager._engine is None
|
||||
|
||||
|
||||
async def test_lancedb_provider_startup_opens_connection(tmp_path: Path) -> None:
    """``startup`` returns the manager's connection and the LanceDB dir exists."""
    provider = LanceDBLifespanProvider()
    app = FastAPI()

    conn = await provider.startup(app)

    # The returned connection is the singleton the manager hands out.
    wired = await lancedb_manager.get_connection()
    assert conn is wired
    lancedb_dir = tmp_path / ".index" / "lancedb"
    assert lancedb_dir.is_dir()
|
||||
|
||||
|
||||
async def test_lancedb_provider_shutdown_disposes_singleton() -> None:
    """``shutdown`` resets the manager's ``_conn`` to ``None``."""
    app = FastAPI()
    provider = LanceDBLifespanProvider()
    await provider.startup(app)
    assert lancedb_manager._conn is not None

    await provider.shutdown(app)
    assert lancedb_manager._conn is None
|
||||
@ -0,0 +1,157 @@
|
||||
"""422 validation paths for ``POST /api/v1/memory/get``.
|
||||
|
||||
These are route-layer error tests — they exercise:
|
||||
|
||||
- DTO-layer rejections (page_size cap, empty owner_id, missing /
|
||||
invalid memory_type, invalid sort_order, owner+memory_type mismatch)
|
||||
- service-layer ``compile_filters_for_get`` rejections (unknown filter
|
||||
field, malformed op shape)
|
||||
|
||||
No data is seeded; nothing reaches LanceDB. The full happy-path / data
|
||||
e2e suite (with seeded rows and 200 assertions) lives in
|
||||
``tests/integration/test_get_endpoint_e2e.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from everos.config import load_settings
|
||||
from everos.entrypoints.api.app import create_app
|
||||
from everos.infra.persistence.lancedb import lancedb_manager
|
||||
|
||||
# ``everos.service.__init__`` re-exports ``get`` shadowing the
|
||||
# submodule. Reach the real module via importlib so we can reset its
|
||||
# ``_manager`` lazy singleton.
|
||||
get_service_mod = import_module("everos.service.get")
|
||||
|
||||
|
||||
@pytest.fixture
async def client(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncIterator[AsyncClient]:
    """Yield an HTTP client bound to an app created with no lifespan providers.

    The singletons on the get path are reset before every test, and the
    LanceDB connection and settings cache are cleaned up afterwards.
    """
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    load_settings.cache_clear()

    # Forget any connection, table handles or manager cached earlier.
    lancedb_manager._conn = None
    lancedb_manager._tables.clear()
    get_service_mod._manager = None

    asgi_app = create_app(lifespan_providers=[])
    async with AsyncClient(
        transport=ASGITransport(app=asgi_app), base_url="http://test"
    ) as http:
        yield http

    await lancedb_manager.dispose_connection()
    load_settings.cache_clear()
|
||||
|
||||
|
||||
# ── DTO-layer 422 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_page_size_above_cap_returns_422(client: AsyncClient) -> None:
    """``page_size > 100`` violates the wiki cap → 422 at the DTO layer."""
    payload = {
        "user_id": "u1",
        "memory_type": "episode",
        "page_size": 200,
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_empty_user_id_returns_422(client: AsyncClient) -> None:
    """An empty ``user_id`` (``min_length=1``) is rejected with 422."""
    payload = {
        "user_id": "",
        "memory_type": "episode",
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_missing_memory_type_returns_422(client: AsyncClient) -> None:
    """A body without the required ``memory_type`` field yields 422."""
    payload = {"user_id": "u1"}
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_invalid_memory_type_value_returns_422(client: AsyncClient) -> None:
    """A ``memory_type`` outside the four-kind enum yields 422."""
    payload = {
        "user_id": "u1",
        # ``atomic_fact`` is not a top-level kind.
        "memory_type": "atomic_fact",
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_invalid_sort_order_returns_422(client: AsyncClient) -> None:
    """``sort_order`` is a tight Literal, so the uppercase ``"DESC"`` yields 422."""
    payload = {
        "user_id": "u1",
        "memory_type": "episode",
        "sort_order": "DESC",
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_owner_memory_type_mismatch_returns_422(client: AsyncClient) -> None:
    """A user owner combined with ``memory_type="agent_case"`` yields 422."""
    payload = {
        "user_id": "u1",
        "memory_type": "agent_case",
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
# ── service.compile_filters_for_get 422 ───────────────────────────────
|
||||
|
||||
|
||||
async def test_unknown_filter_field_returns_422(client: AsyncClient) -> None:
    """A filter field outside ``ALLOWED_FIELDS`` yields 422 mentioning "unsupported"."""
    payload = {
        "user_id": "u1",
        "memory_type": "episode",
        "filters": {"random_attr": "boom"},
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
    assert "unsupported" in response.text
|
||||
|
||||
|
||||
async def test_malformed_filter_in_op_returns_422(client: AsyncClient) -> None:
    """An ``in`` operator given a scalar instead of a list yields 422."""
    bad_filter = {"session_id": {"in": "not_a_list"}}
    payload = {
        "user_id": "u1",
        "memory_type": "episode",
        "filters": bad_filter,
    }
    response = await client.post("/api/v1/memory/get", json=payload)
    assert response.status_code == 422
|
||||
@ -0,0 +1,125 @@
|
||||
"""``GET /metrics`` — Prometheus exposition + middleware integration.
|
||||
|
||||
Verifies three contracts of the metrics path:
|
||||
|
||||
1. The route renders ``prometheus_client``-parseable exposition format.
|
||||
2. The ``PrometheusMiddleware`` actually bumps the per-route counter
|
||||
on a real round-trip (verified via before/after delta to avoid
|
||||
coupling to the global registry's cross-test accumulation).
|
||||
3. The ``_SKIP_PATHS`` set (``/metrics``, ``/health``) is honoured —
|
||||
those endpoints never appear in ``everos_http_requests_total``.
|
||||
|
||||
No lifespan / no LanceDB / no LLM needed — middleware lives at the ASGI
|
||||
layer above any of that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from prometheus_client.parser import text_string_to_metric_families
|
||||
|
||||
from everos.config import load_settings
|
||||
from everos.entrypoints.api.app import create_app
|
||||
|
||||
# ``prometheus_client.parser`` strips the ``_total`` counter suffix from
|
||||
# the *family* name but leaves *sample* names intact.
|
||||
_REQUESTS_FAMILY = "everos_http_requests"
|
||||
_REQUESTS_TOTAL = "everos_http_requests_total"
|
||||
|
||||
|
||||
@pytest.fixture
async def client(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` over a lifespan-free app built by ``create_app``.

    ``EVEROS_MEMORY__ROOT`` is pointed at ``tmp_path`` and the
    ``load_settings`` cache is cleared both before the app is built and
    after the client is closed.
    """
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    load_settings.cache_clear()

    app = create_app(lifespan_providers=[])
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http:
        yield http

    load_settings.cache_clear()
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _counter_value(text: str, path: str, status: str) -> float:
    """Add up the ``everos_http_requests_total`` samples for one path + status pair.

    ``text`` is a Prometheus exposition dump; samples from other
    families, or with a different sample name, are ignored.
    """
    running = 0.0
    for family in text_string_to_metric_families(text):
        if family.name == _REQUESTS_FAMILY:
            for sample in family.samples:
                if sample.name != _REQUESTS_TOTAL:
                    continue
                labels = sample.labels
                if labels.get("path") == path and labels.get("status") == status:
                    running += sample.value
    return running
|
||||
|
||||
|
||||
def _all_recorded_paths(text: str) -> set[str]:
    """Collect every ``path`` label value on ``everos_http_requests_total`` samples.

    A sample without a ``path`` label contributes the empty string.
    """
    return {
        sample.labels.get("path", "")
        for family in text_string_to_metric_families(text)
        if family.name == _REQUESTS_FAMILY
        for sample in family.samples
        if sample.name == _REQUESTS_TOTAL
    }
|
||||
|
||||
|
||||
# ── Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_metrics_endpoint_renders_prometheus_format(
    client: AsyncClient,
) -> None:
    """``GET /metrics`` answers 200 with parsable ``text/plain`` exposition output."""
    response = await client.get("/metrics")
    assert response.status_code == 200
    content_type = response.headers.get("content-type", "")
    assert "text/plain" in content_type

    # The body has to parse cleanly and contain the request counter family.
    family_names = set()
    for family in text_string_to_metric_families(response.text):
        family_names.add(family.name)
    assert _REQUESTS_FAMILY in family_names
|
||||
|
||||
|
||||
async def test_metrics_counter_increments_on_request(client: AsyncClient) -> None:
    """One rejected request bumps ``everos_http_requests_total`` by exactly one.

    The request is an empty-body POST that is answered with 422, so no
    LanceDB is needed. The counter is compared as a before/after delta
    rather than an absolute value.
    """
    route = "/api/v1/memory/get"

    async def scrape() -> float:
        """Read the current 422 count for ``route`` from ``/metrics``."""
        dump = await client.get("/metrics")
        return _counter_value(dump.text, route, "422")

    before = await scrape()

    bad = await client.post(route, json={})
    assert bad.status_code == 422

    after = await scrape()

    assert after - before == 1.0, f"counter not bumped: {before} → {after}"
|
||||
|
||||
|
||||
async def test_metrics_skip_paths_not_recorded(client: AsyncClient) -> None:
    """Hitting ``/health`` and ``/metrics`` leaves no trace in the request counter."""
    # Touch both endpoints first; were they not skipped, the dump taken
    # afterwards would list them.
    for skipped in ("/health", "/metrics"):
        await client.get(skipped)

    dump = await client.get("/metrics")
    recorded = _all_recorded_paths(dump.text)
    assert "/metrics" not in recorded, recorded
    assert "/health" not in recorded, recorded
|
||||
@ -0,0 +1,133 @@
|
||||
"""422 validation paths for ``POST /api/v1/memory/search``.
|
||||
|
||||
These exercise the request → DTO / route → service.compile_filters
|
||||
error paths *without* needing any seeded data or external services
|
||||
(no embedder / no LLM / no LanceDB rows). The full data-driven e2e
|
||||
suite lives in ``tests/integration/test_search_endpoint_e2e.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from everos.config import load_settings
|
||||
from everos.entrypoints.api.app import create_app
|
||||
from everos.infra.persistence.lancedb import lancedb_manager
|
||||
|
||||
search_service_mod = import_module("everos.service.search")
|
||||
|
||||
|
||||
@pytest.fixture
async def client(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` wired to a lifespan-free app with fresh search state.

    Setup points ``EVEROS_MEMORY__ROOT`` at ``tmp_path``, clears the
    ``load_settings`` cache, resets the LanceDB manager's cached
    connection / tables, and resets the search service module's lazy
    singletons together with their "resolved" flags. Teardown disposes
    the LanceDB connection and clears the settings cache again.
    """
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    load_settings.cache_clear()

    lancedb_manager._conn = None
    lancedb_manager._tables.clear()

    # Singletons go back to ``None``; their resolution flags to ``False``.
    module_state = {
        "_manager": None,
        "_embedding": None,
        "_reranker": None,
        "_llm_client": None,
        "_embedding_resolved": False,
        "_rerank_resolved": False,
        "_llm_resolved": False,
    }
    for attr_name, initial in module_state.items():
        setattr(search_service_mod, attr_name, initial)

    app = create_app(lifespan_providers=[])
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http:
        yield http

    await lancedb_manager.dispose_connection()
    load_settings.cache_clear()
|
||||
|
||||
|
||||
def _body(**overrides) -> dict:
|
||||
"""Minimal valid SearchRequest body; tests override one field to break it.
|
||||
|
||||
``method="keyword"`` is pinned because the SearchRequest DTO defaults
|
||||
to HYBRID, which ``SearchManager._validate_components`` rejects when
|
||||
no ``[embedding]`` provider is configured (the case in CI). Keyword
|
||||
needs no embedder, so DTO / compile_filters validation paths fire
|
||||
cleanly without external services — which is exactly what this file
|
||||
is supposed to exercise.
|
||||
"""
|
||||
base = {
|
||||
"user_id": "u1",
|
||||
"query": "hello",
|
||||
"method": "keyword",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
# ── DTO-layer 422 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_empty_query_returns_422(client: AsyncClient) -> None:
    """An empty ``query`` (``min_length=1``) is rejected with 422."""
    payload = _body(query="")
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_empty_user_id_returns_422(client: AsyncClient) -> None:
    """An empty ``user_id`` (``min_length=1``) is rejected with 422."""
    payload = _body(user_id="")
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_both_user_and_agent_id_returns_422(client: AsyncClient) -> None:
    """Supplying ``agent_id`` alongside the default ``user_id`` trips the xor validator."""
    payload = _body(agent_id="agent_x")
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_invalid_method_returns_422(client: AsyncClient) -> None:
    """A ``method`` that is not in the SearchMethod enum yields 422."""
    payload = _body(method="bm42")
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_top_k_zero_returns_422(client: AsyncClient) -> None:
    """``top_k=0`` is rejected: the validator accepts only -1 or 1..100."""
    payload = _body(top_k=0)
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_top_k_above_cap_returns_422(client: AsyncClient) -> None:
    """``top_k=101`` is one past the 100 cap and yields 422."""
    payload = _body(top_k=101)
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_radius_above_one_returns_422(client: AsyncClient) -> None:
    """``radius=1.5`` falls outside the allowed [0.0, 1.0] range and yields 422."""
    payload = _body(radius=1.5)
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
# ── service.compile_filters 422 ───────────────────────────────────────
|
||||
|
||||
|
||||
async def test_unknown_filter_field_returns_422(client: AsyncClient) -> None:
    """A filter field outside ``ALLOWED_FIELDS`` yields 422 mentioning "unsupported"."""
    payload = _body(filters={"random_attr": "boom"})
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
    assert "unsupported" in response.text
|
||||
|
||||
|
||||
async def test_reserved_owner_id_in_filters_returns_422(client: AsyncClient) -> None:
    """``owner_id`` is reserved for the top level; placing it inside filters yields 422."""
    payload = _body(filters={"owner_id": "spoof"})
    response = await client.post("/api/v1/memory/search", json=payload)
    assert response.status_code == 422
|
||||
0
tests/unit/test_entrypoints/test_cli/__init__.py
Normal file
0
tests/unit/test_entrypoints/test_cli/__init__.py
Normal file
98
tests/unit/test_entrypoints/test_cli/test_cascade_command.py
Normal file
98
tests/unit/test_entrypoints/test_cli/test_cascade_command.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""``everos cascade`` — structural smoke + pure helper tests.
|
||||
|
||||
The orchestrator paths require live sqlite + lancedb singletons; those
|
||||
are exercised by integration tests. Here we cover:
|
||||
|
||||
- subcommand registration (sync / status / fix)
|
||||
- ``--help`` exit codes
|
||||
- ``_resolve_relative`` (path arithmetic vs. memory root)
|
||||
- ``_print_failed_table`` (formatting of failed rows)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import typer
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from everos.entrypoints.cli.commands import cascade as cascade_mod
|
||||
|
||||
|
||||
def test_app_registers_three_commands() -> None:
    """The cascade sub-app exposes exactly ``sync``, ``status`` and ``fix``."""
    expected = {"sync", "status", "fix"}
    registered = {command.name for command in cascade_mod.app.registered_commands}
    assert registered == expected
|
||||
|
||||
|
||||
def test_help_exits_zero() -> None:
    """``--help`` on the cascade app exits 0 and lists all three subcommands."""
    outcome = CliRunner().invoke(cascade_mod.app, ["--help"])
    assert outcome.exit_code == 0
    for subcommand in ("sync", "status", "fix"):
        assert subcommand in outcome.stdout
|
||||
|
||||
|
||||
def test_resolve_relative_under_root(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A path below the memory root resolves to its root-relative POSIX form.

    The memory root is redirected to ``tmp_path`` through
    ``EVEROS_MEMORY__ROOT``; the ``load_settings`` cache is cleared so
    that value is picked up, and cleared again afterwards so a Settings
    object built for this temporary root cannot be served to later tests
    once monkeypatch has restored the environment.
    """
    from everos.config import load_settings

    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    load_settings.cache_clear()
    try:
        rel = cascade_mod._resolve_relative(tmp_path / "users" / "u1" / "x.md")
    finally:
        # Monkeypatch only undoes the env var; the cache must be dropped by hand.
        load_settings.cache_clear()

    assert rel == "users/u1/x.md"
|
||||
|
||||
|
||||
def test_resolve_relative_outside_root_raises(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A path outside the memory root is rejected with ``typer.BadParameter``.

    The memory root is set to ``tmp_path / "memory"`` and the probed file
    lives directly under ``tmp_path``. The ``load_settings`` cache is
    cleared before the call so the patched root is used, and again
    afterwards so a Settings object built for this temporary root cannot
    be served to later tests.
    """
    from everos.config import load_settings

    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path / "memory"))
    load_settings.cache_clear()

    other = tmp_path / "somewhere-else.md"
    try:
        with pytest.raises(typer.BadParameter, match="not under memory root"):
            cascade_mod._resolve_relative(other)
    finally:
        # Monkeypatch only undoes the env var; the cache must be dropped by hand.
        load_settings.cache_clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FailedRow:
|
||||
md_path: str
|
||||
retryable: bool
|
||||
retry_count: int
|
||||
last_attempt_at: object
|
||||
error: str | None
|
||||
|
||||
|
||||
def test_print_failed_table_formats_rows(capsys: pytest.CaptureFixture[str]) -> None:
    """``_print_failed_table`` prints a count line, a header and one line per row."""
    from datetime import UTC, datetime

    retryable_row = _FailedRow(
        md_path="users/u1/a.md",
        retryable=True,
        retry_count=2,
        last_attempt_at=datetime(2026, 1, 1, tzinfo=UTC),
        error="boom",
    )
    exhausted_row = _FailedRow(
        md_path="users/u2/b.md",
        retryable=False,
        retry_count=5,
        last_attempt_at=None,
        error=None,
    )

    cascade_mod._print_failed_table([retryable_row, exhausted_row])  # type: ignore[arg-type]
    out = capsys.readouterr().out

    for fragment in (
        "2 failed row(s):",
        "users/u1/a.md",
        "TRUE",
        "users/u2/b.md",
        "FALSE",
    ):
        assert fragment in out
    # Header row present
    assert "md_path" in out and "retries" in out
|
||||
213
tests/unit/test_entrypoints/test_cli/test_init_command.py
Normal file
213
tests/unit/test_entrypoints/test_cli/test_init_command.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""``everos init`` — CLI behavior + edge cases.
|
||||
|
||||
Covers:
|
||||
|
||||
- default ``./.env`` path, written with 0600 permissions
|
||||
- ``--to <path>`` creates parent dirs
|
||||
- ``--force`` overwrites; without it the command refuses with exit 1
|
||||
- ``--print`` writes to stdout, NOT to disk
|
||||
- ``--xdg`` and ``--to`` are mutually exclusive (exit 2)
|
||||
- ``--xdg`` honors ``XDG_CONFIG_HOME``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from everos.entrypoints.cli.main import app
|
||||
|
||||
|
||||
@pytest.fixture
def runner() -> CliRunner:
    """Provide a fresh Typer ``CliRunner`` for each test."""
    cli_runner = CliRunner()
    return cli_runner
|
||||
|
||||
|
||||
@pytest.fixture
def in_tmp(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Change the working directory to ``tmp_path`` and return it.

    With the cwd moved, the default ``./.env`` target resolves inside
    the temporary directory.
    """
    workdir = tmp_path
    monkeypatch.chdir(workdir)
    return workdir
|
||||
|
||||
|
||||
def test_default_writes_dotenv_in_cwd(runner: CliRunner, in_tmp: Path) -> None:
    """Bare ``init`` writes a non-empty ``.env`` containing the LLM key into the cwd."""
    outcome = runner.invoke(app, ["init"])
    assert outcome.exit_code == 0, outcome.output

    env_path = in_tmp / ".env"
    assert env_path.exists()
    assert env_path.stat().st_size > 0
    assert "EVEROS_LLM__API_KEY" in env_path.read_text()
|
||||
|
||||
|
||||
def test_default_file_permissions_are_0600(runner: CliRunner, in_tmp: Path) -> None:
    """The generated ``.env`` holds API keys, so its mode bits must be exactly 0600."""
    outcome = runner.invoke(app, ["init"])
    assert outcome.exit_code == 0

    env_stat = (in_tmp / ".env").stat()
    perms = stat.S_IMODE(env_stat.st_mode)
    assert perms == 0o600, f"expected 0o600, got {oct(perms)}"
|
||||
|
||||
|
||||
def test_refuses_overwrite_without_force(runner: CliRunner, in_tmp: Path) -> None:
    """Without ``--force`` an existing ``.env`` is left untouched and the exit code is 1."""
    env_path = in_tmp / ".env"
    original = "PREEXISTING=1\n"
    env_path.write_text(original)

    outcome = runner.invoke(app, ["init"])
    assert outcome.exit_code == 1
    combined = outcome.output + (outcome.stderr or "")
    assert "already exists" in combined
    # Original content must be preserved.
    assert env_path.read_text() == original
|
||||
|
||||
|
||||
def test_force_overwrites(runner: CliRunner, in_tmp: Path) -> None:
    """``--force`` replaces an existing ``.env`` with the generated template."""
    env_path = in_tmp / ".env"
    env_path.write_text("PREEXISTING=1\n")

    outcome = runner.invoke(app, ["init", "--force"])
    assert outcome.exit_code == 0

    contents = env_path.read_text()
    assert "PREEXISTING=1" not in contents
    assert "EVEROS_LLM__API_KEY" in contents
|
||||
|
||||
|
||||
def test_to_creates_parent_dirs(runner: CliRunner, in_tmp: Path) -> None:
    """``--to <path>`` writes the template even when parent directories are missing."""
    destination = in_tmp / "nested" / "subdir" / ".env"
    argv = ["init", "--to", str(destination)]

    outcome = runner.invoke(app, argv)
    assert outcome.exit_code == 0
    assert destination.exists()
    assert "EVEROS_LLM__API_KEY" in destination.read_text()
|
||||
|
||||
|
||||
def test_print_writes_stdout_not_disk(runner: CliRunner, in_tmp: Path) -> None:
    """``--print`` emits the template on the CLI output and creates no ``.env``."""
    outcome = runner.invoke(app, ["init", "--print"])
    assert outcome.exit_code == 0
    assert "EVEROS_LLM__API_KEY" in outcome.output
    # No disk side-effect.
    env_path = in_tmp / ".env"
    assert not env_path.exists()
|
||||
|
||||
|
||||
def test_xdg_writes_to_xdg_config_home(
    runner: CliRunner, in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``--xdg`` writes to ``$XDG_CONFIG_HOME/everos/.env`` when the variable is set."""
    config_home = in_tmp / "xdg"
    monkeypatch.setenv("XDG_CONFIG_HOME", str(config_home))

    outcome = runner.invoke(app, ["init", "--xdg"])
    assert outcome.exit_code == 0

    expected = config_home / "everos" / ".env"
    assert expected.exists()
|
||||
|
||||
|
||||
def test_xdg_falls_back_to_dot_config(
    runner: CliRunner, in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With ``XDG_CONFIG_HOME`` unset, ``--xdg`` targets ``~/.config/everos/.env``.

    ``$HOME`` is redirected to ``in_tmp`` so the test never touches a
    real user's ``~/.config``.
    """
    monkeypatch.delenv("XDG_CONFIG_HOME", raising=False)
    monkeypatch.setenv("HOME", str(in_tmp))

    outcome = runner.invoke(app, ["init", "--xdg"])
    assert outcome.exit_code == 0

    expected = in_tmp / ".config" / "everos" / ".env"
    assert expected.exists()
|
||||
|
||||
|
||||
def test_xdg_and_to_are_mutually_exclusive(runner: CliRunner, in_tmp: Path) -> None:
    """Passing ``--xdg`` together with ``--to`` is a usage error (exit code 2)."""
    argv = ["init", "--xdg", "--to", str(in_tmp / "other.env")]
    outcome = runner.invoke(app, argv)
    assert outcome.exit_code == 2
    combined = outcome.output + (outcome.stderr or "")
    assert "mutually exclusive" in combined
|
||||
|
||||
|
||||
def test_template_resource_is_packaged_under_everos_templates() -> None:
    """``env.template`` is shipped inside the ``everos.templates`` package.

    Guards the wheel/sdist layout: ``init_cmd`` reads
    ``everos.templates.env.template`` via ``importlib.resources``, so
    moving the file without updating ``_TEMPLATE_PACKAGE`` makes this
    test fail immediately.
    """
    from importlib import resources

    template = resources.files("everos.templates").joinpath("env.template")
    assert template.is_file()
    contents = template.read_text(encoding="utf-8")
    assert "EVEROS_LLM__API_KEY" in contents
|
||||
|
||||
|
||||
# ── 4-layer .env resolution for ``server start`` ────────────────────────
|
||||
|
||||
|
||||
def test_resolve_env_file_explicit_wins(in_tmp: Path) -> None:
    """An explicit ``--env-file`` path is returned even when ``./.env`` also exists."""
    from everos.entrypoints.cli.commands.server import _resolve_env_file

    # Seed the cwd ``.env`` as well, so the result proves the explicit path wins.
    (in_tmp / ".env").write_text("CWD=1\n")
    chosen = in_tmp / "explicit.env"
    chosen.write_text("X=1\n")

    assert _resolve_env_file(str(chosen)) == chosen
|
||||
|
||||
|
||||
def test_resolve_env_file_cwd_wins_over_xdg(
    in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With both present, ``./.env`` is chosen over ``$XDG_CONFIG_HOME/everos/.env``."""
    from everos.entrypoints.cli.commands.server import _resolve_env_file

    config_home = in_tmp / "xdg"
    xdg_dir = config_home / "everos"
    xdg_dir.mkdir(parents=True)
    (xdg_dir / ".env").write_text("XDG=1\n")
    monkeypatch.setenv("XDG_CONFIG_HOME", str(config_home))

    local_env = in_tmp / ".env"
    local_env.write_text("CWD=1\n")

    assert _resolve_env_file(None) == local_env
|
||||
|
||||
|
||||
def test_resolve_env_file_xdg_when_no_cwd(
    in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without a cwd ``.env``, the file under ``$XDG_CONFIG_HOME/everos`` is chosen."""
    from everos.entrypoints.cli.commands.server import _resolve_env_file

    config_home = in_tmp / "xdg"
    xdg_dir = config_home / "everos"
    xdg_dir.mkdir(parents=True)
    xdg_env = xdg_dir / ".env"
    xdg_env.write_text("XDG=1\n")
    monkeypatch.setenv("XDG_CONFIG_HOME", str(config_home))

    # No cwd/.env.
    assert _resolve_env_file(None) == xdg_env
|
||||
|
||||
|
||||
def test_resolve_env_file_everos_home_fallback(
    in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``~/.everos/.env`` is picked when neither a cwd nor an XDG file exists."""
    from everos.entrypoints.cli.commands.server import _resolve_env_file

    monkeypatch.delenv("XDG_CONFIG_HOME", raising=False)
    monkeypatch.setenv("HOME", str(in_tmp))

    home_env = in_tmp / ".everos" / ".env"
    home_env.parent.mkdir(parents=True)
    home_env.write_text("EVEROS_ROOT=1\n")

    assert _resolve_env_file(None) == home_env
|
||||
|
||||
|
||||
def test_resolve_env_file_none_when_no_layer_matches(
    in_tmp: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With all four layers absent the resolver returns ``None``.

    The server then falls back to the inherited process environment,
    which is the documented CI/container path.
    """
    from everos.entrypoints.cli.commands.server import _resolve_env_file

    monkeypatch.delenv("XDG_CONFIG_HOME", raising=False)
    monkeypatch.setenv("HOME", str(in_tmp))

    # Nothing in cwd, no XDG path, no ~/.everos/.
    cwd_env = in_tmp / ".env"
    assert not cwd_env.exists()
    assert _resolve_env_file(None) is None
|
||||
|
||||
|
||||
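# ``os`` is imported at the top of this module but no test uses it; the
# assignment below exists only to silence Ruff F401 (unused import). Delete
# both the import and this assignment once nothing here needs ``os``.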
|
||||
_ = os
|
||||
22
tests/unit/test_entrypoints/test_cli/test_main.py
Normal file
22
tests/unit/test_entrypoints/test_cli/test_main.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""CLI root app — verifies sub-typer wiring + ``--help`` exit code."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from everos.entrypoints.cli.main import app
|
||||
|
||||
|
||||
def test_help_exits_zero() -> None:
    """Root ``--help`` exits 0 and mentions the app name and both sub-apps."""
    outcome = CliRunner().invoke(app, ["--help"])
    assert outcome.exit_code == 0
    for expected in ("everos", "server", "cascade"):
        assert expected in outcome.stdout
|
||||
|
||||
|
||||
def test_no_args_shows_help_and_exits_nonzero() -> None:
    """Invoking the root app with no arguments prints usage and exits non-zero."""
    # ``no_args_is_help=True`` triggers a help exit with code 2 (typer default).
    outcome = CliRunner().invoke(app, [])
    assert outcome.exit_code != 0
    # stderr is only consulted when stdout does not already carry the usage text.
    if "Usage" not in outcome.stdout:
        assert "Usage" in outcome.stderr
|
||||
134
tests/unit/test_entrypoints/test_cli/test_server_command.py
Normal file
134
tests/unit/test_entrypoints/test_cli/test_server_command.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""``everos server start`` — argument resolution + uvicorn handoff.
|
||||
|
||||
Uvicorn ``run`` is the external boundary and is mocked. We assert the
|
||||
host/port/log_level resolution chain (CLI flag > env > default) and the
|
||||
KeyboardInterrupt / OSError exit paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from everos.entrypoints.cli.commands import server as server_mod
|
||||
from everos.entrypoints.cli.main import app as root_app
|
||||
|
||||
|
||||
@pytest.fixture
def captured(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
    """Replace ``uvicorn.run`` with a recorder and return the recording dict.

    After the CLI has been invoked, the returned dict holds the
    positional arguments under ``"args"`` and the keyword arguments
    under ``"kwargs"``.

    The host / port / log-level environment variables are removed so the
    default-resolution path does not depend on the developer's shell.
    The ``EVEROS_API__*`` names are the ones the tests in this module set
    to drive the env fallback, so those must be stripped; the shorter
    un-namespaced names are stripped as well, as before.
    """
    calls: dict[str, object] = {}

    def fake_run(*args: object, **kwargs: object) -> None:
        calls["args"] = args
        calls["kwargs"] = kwargs

    monkeypatch.setattr(server_mod.uvicorn, "run", fake_run)
    # Strip env so default resolution path is deterministic.
    for name in (
        "EVEROS_API__HOST",
        "EVEROS_API__PORT",
        "EVEROS_API__LOG_LEVEL",
        "EVEROS_HOST",
        "EVEROS_PORT",
        "EVEROS_LOG_LEVEL",
    ):
        monkeypatch.delenv(name, raising=False)
    return calls
|
||||
|
||||
|
||||
# Typer lifts single-command sub-apps to root; we invoke via the real
|
||||
# ``everos server start`` path through the assembled root app.
|
||||
|
||||
|
||||
def test_start_uses_default_host_port_log_level(captured: dict[str, object]) -> None:
    """With no flags and no env, uvicorn receives 127.0.0.1:8000 at ``info`` level."""
    argv = ["server", "start", "--env-file", "/nonexistent"]
    outcome = CliRunner().invoke(root_app, argv)
    assert outcome.exit_code == 0, outcome.stdout

    run_kwargs = captured["kwargs"]
    assert isinstance(run_kwargs, dict)
    assert run_kwargs["host"] == "127.0.0.1"
    assert run_kwargs["port"] == 8000
    assert run_kwargs["log_level"] == "info"
    assert run_kwargs["factory"] is True

    # The app is handed over as an import string for the factory.
    run_args = captured["args"]
    assert run_args == ("everos.entrypoints.api.app:create_app",)
|
||||
|
||||
|
||||
def test_start_cli_flags_override_env(
    captured: dict[str, object], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Explicit ``--host`` / ``--port`` / ``--log-level`` beat the ``EVEROS_API__*`` env."""
    monkeypatch.setenv("EVEROS_API__HOST", "1.2.3.4")
    monkeypatch.setenv("EVEROS_API__PORT", "9000")
    monkeypatch.setenv("EVEROS_API__LOG_LEVEL", "debug")

    argv = ["server", "start", "--env-file", "/nonexistent"]
    argv += ["--host", "127.0.0.1"]
    argv += ["--port", "8765"]
    argv += ["--log-level", "warning"]
    outcome = CliRunner().invoke(root_app, argv)
    assert outcome.exit_code == 0, outcome.stdout

    run_kwargs = captured["kwargs"]
    assert isinstance(run_kwargs, dict)
    assert run_kwargs["host"] == "127.0.0.1"
    assert run_kwargs["port"] == 8765
    assert run_kwargs["log_level"] == "warning"
|
||||
|
||||
|
||||
def test_start_falls_back_to_env_when_flags_omitted(
    captured: dict[str, object], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without CLI flags, host and port come from the ``EVEROS_API__*`` variables."""
    monkeypatch.setenv("EVEROS_API__HOST", "10.0.0.1")
    monkeypatch.setenv("EVEROS_API__PORT", "8765")

    argv = ["server", "start", "--env-file", "/nonexistent"]
    outcome = CliRunner().invoke(root_app, argv)
    assert outcome.exit_code == 0, outcome.stdout

    run_kwargs = captured["kwargs"]
    assert isinstance(run_kwargs, dict)
    assert run_kwargs["host"] == "10.0.0.1"
    assert run_kwargs["port"] == 8765
|
||||
|
||||
|
||||
def test_start_swallows_keyboard_interrupt(monkeypatch: pytest.MonkeyPatch) -> None:
    """A ``KeyboardInterrupt`` raised by ``uvicorn.run`` still ends in exit code 0."""

    def interrupted_run(*args: object, **kwargs: object) -> None:
        raise KeyboardInterrupt

    monkeypatch.setattr(server_mod.uvicorn, "run", interrupted_run)
    argv = ["server", "start", "--env-file", "/nonexistent"]
    outcome = CliRunner().invoke(root_app, argv)
    # KeyboardInterrupt path returns normally — exit 0.
    assert outcome.exit_code == 0
|
||||
|
||||
|
||||
def test_start_exits_one_on_os_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """An ``OSError`` raised by ``uvicorn.run`` makes the command exit with code 1."""

    def failing_run(*args: object, **kwargs: object) -> None:
        raise OSError("port in use")

    monkeypatch.setattr(server_mod.uvicorn, "run", failing_run)
    argv = ["server", "start", "--env-file", "/nonexistent"]
    outcome = CliRunner().invoke(root_app, argv)
    assert outcome.exit_code == 1
|
||||
|
||||
|
||||
def test_load_env_file_missing_path_is_noop(tmp_path) -> None:  # type: ignore[no-untyped-def]
    """``_load_env_file`` must not raise when the given path does not exist."""
    absent = tmp_path / "does-not-exist.env"
    server_mod._load_env_file(str(absent))
|
||||
|
||||
|
||||
def test_load_env_file_reads_present_file(
    tmp_path, monkeypatch: pytest.MonkeyPatch
) -> None:  # type: ignore[no-untyped-def]
    """``_load_env_file`` exports the variables of an existing file into ``os.environ``.

    Cleanup is done with ``os.environ.pop`` instead of a trailing
    ``monkeypatch.delenv``: ``delenv`` records the value present at call
    time (here ``"loaded"``) and monkeypatch would put it back at
    teardown, leaking the variable into every later test.
    """
    import os

    var = "EVEROS_TEST_DOTENV_VAR"
    # A pre-existing value is removed here and restored by monkeypatch at
    # teardown; when the variable is absent this records nothing.
    monkeypatch.delenv(var, raising=False)
    env_file = tmp_path / ".env"
    env_file.write_text("EVEROS_TEST_DOTENV_VAR=loaded\n")

    try:
        server_mod._load_env_file(str(env_file))
        assert os.environ.get(var) == "loaded"
    finally:
        # The loader wrote to os.environ behind monkeypatch's back; undo it by hand.
        os.environ.pop(var, None)
|
||||
0
tests/unit/test_infra/__init__.py
Normal file
0
tests/unit/test_infra/__init__.py
Normal file
0
tests/unit/test_infra/test_lancedb/__init__.py
Normal file
0
tests/unit/test_infra/test_lancedb/__init__.py
Normal file
72
tests/unit/test_infra/test_lancedb/test_lancedb_manager.py
Normal file
72
tests/unit/test_infra/test_lancedb/test_lancedb_manager.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""LanceDB manager singletons.
|
||||
|
||||
Verifies ``get_connection`` / ``get_table`` / ``dispose_connection``
|
||||
are idempotent and rebuild after dispose.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from lancedb.pydantic import Vector
|
||||
|
||||
from everos.core.persistence import BaseLanceTable
|
||||
from everos.infra.persistence.lancedb import lancedb_manager
|
||||
|
||||
|
||||
class _DemoVec(BaseLanceTable):
    """Demo schema — only used by this test module."""

    # Text payload stored alongside the vector.
    text: str
    # Vector column declared with dimension 3 via ``lancedb.pydantic.Vector``.
    vector: Vector(3)  # type: ignore[valid-type]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Point the singleton at an isolated memory-root and reset module state.

    Autouse: wraps every test in this module. After the test body runs, the
    connection is disposed through ``dispose_connection``.
    """
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    # Clear the module-level caches directly so each test starts cold.
    lancedb_manager._conn = None
    lancedb_manager._tables.clear()
    yield
    await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
async def test_get_connection_is_singleton() -> None:
    """Two consecutive ``get_connection`` calls hand back the same object."""
    first = await lancedb_manager.get_connection()
    second = await lancedb_manager.get_connection()
    assert first is second
|
||||
|
||||
|
||||
async def test_get_table_creates_then_caches() -> None:
    """A repeated ``get_table`` returns the cached handle kept in ``_tables``."""
    first = await lancedb_manager.get_table("demo", _DemoVec)
    second = await lancedb_manager.get_table("demo", _DemoVec)
    assert first is second
    assert "demo" in lancedb_manager._tables
|
||||
|
||||
|
||||
async def test_get_table_reopens_existing() -> None:
    """A second connection cycle must reopen (not recreate) the table."""
    await lancedb_manager.get_table("demo", _DemoVec)
    await lancedb_manager.dispose_connection()

    reopened = await lancedb_manager.get_table("demo", _DemoVec)
    assert reopened is not None
    # Round-trip a row to prove the schema survived.
    row = _DemoVec(text="hello", vector=[0.1, 0.2, 0.3])
    await reopened.add([row])
    assert await reopened.count_rows() == 1
|
||||
|
||||
|
||||
async def test_dispose_resets_state() -> None:
    """``dispose_connection`` clears the cached connection and table map."""
    await lancedb_manager.get_connection()
    await lancedb_manager.get_table("demo", _DemoVec)
    await lancedb_manager.dispose_connection()
    assert lancedb_manager._conn is None
    assert lancedb_manager._tables == {}
|
||||
|
||||
|
||||
async def test_dispose_is_idempotent() -> None:
    """``dispose_connection`` is safe before any build and when called twice."""
    await lancedb_manager.dispose_connection()  # nothing built yet
    await lancedb_manager.get_connection()
    await lancedb_manager.dispose_connection()
    await lancedb_manager.dispose_connection()  # second call must not raise
|
||||
@ -0,0 +1,153 @@
|
||||
"""Tests for :class:`everos.infra.persistence.lancedb._AgentSkillRepo`.
|
||||
|
||||
Real LanceDB under ``tmp_path`` (no mocks) — these tests exercise the
|
||||
SQL ``where`` predicate, cosine ``distance_type`` ranking, and
|
||||
``_distance`` stripping that the repo owns. Strategy-level routing
|
||||
across these methods is covered separately in
|
||||
``tests/unit/test_memory/test_strategies/test_extract_agent_skill.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.persistence.lancedb import (
|
||||
AgentSkill as LanceAgentSkill,
|
||||
)
|
||||
from everos.infra.persistence.lancedb import (
|
||||
agent_skill_repo,
|
||||
lancedb_manager,
|
||||
)
|
||||
|
||||
|
||||
def _skill_row(
    *,
    name: str,
    owner_id: str,
    cluster_id: str,
    vector: list[float],
) -> LanceAgentSkill:
    """Build a minimal AgentSkill row sufficient to land in LanceDB for repo tests.

    Args:
        name: Skill name; also drives the derived id, description, body and path.
        owner_id: Owning agent id; combined with ``name`` to form the row id.
        cluster_id: Cluster the row is filed under.
        vector: Embedding stored on the row.

    Returns:
        A row with fixed scores, an empty ``source_case_ids`` and a dummy digest.
    """
    description = f"desc {name}"
    body = f"body of {name}"
    return LanceAgentSkill(
        id=f"{owner_id}_{name}",
        owner_id=owner_id,
        owner_type="agent",
        name=name,
        description=description,
        description_tokens=description,
        content=body,
        content_tokens=body,
        confidence=0.7,
        maturity_score=0.6,
        source_case_ids=[],
        cluster_id=cluster_id,
        md_path=f"agents/{owner_id}/skills/{name}/SKILL.md",
        content_sha256="x" * 64,
        vector=vector,
    )
|
||||
|
||||
|
||||
@pytest.fixture
async def _real_lancedb(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Spin up a clean LanceDB rooted under ``tmp_path`` for one test.

    The memory root env var is redirected and the manager's cached connection
    and tables are cleared before the test; the connection is disposed after.
    """
    monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
    # Clear module-level caches so no handle from an earlier test is reused.
    lancedb_manager._conn = None
    lancedb_manager._tables.clear()
    yield
    await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
async def test_count_in_cluster_isolates_owner_and_cluster(
    _real_lancedb: None,
) -> None:
    """``count_in_cluster`` returns only rows matching both filters."""
    matching = [
        _skill_row(name="s1", owner_id="a", cluster_id="cl_x", vector=[0.1] * 1024),
        _skill_row(name="s2", owner_id="a", cluster_id="cl_x", vector=[0.2] * 1024),
    ]
    # Same owner, different cluster.
    wrong_cluster = _skill_row(
        name="other_cluster", owner_id="a", cluster_id="cl_y", vector=[0.3] * 1024
    )
    # Same cluster, different owner.
    wrong_owner = _skill_row(
        name="other_owner", owner_id="b", cluster_id="cl_x", vector=[0.4] * 1024
    )
    await agent_skill_repo.upsert([*matching, wrong_cluster, wrong_owner])

    counted = await agent_skill_repo.count_in_cluster(owner_id="a", cluster_id="cl_x")
    assert counted == 2
|
||||
|
||||
|
||||
async def test_find_in_cluster_returns_typed_rows_no_ranking(
    _real_lancedb: None,
) -> None:
    """Scalar fetch within one cluster; capped at ``limit`` regardless of order."""
    in_cluster = [
        _skill_row(name="s1", owner_id="a", cluster_id="cl_x", vector=[0.1] * 1024),
        _skill_row(name="s2", owner_id="a", cluster_id="cl_x", vector=[0.2] * 1024),
        _skill_row(name="s3", owner_id="a", cluster_id="cl_x", vector=[0.3] * 1024),
    ]
    outsider = _skill_row(
        name="other_cluster", owner_id="a", cluster_id="cl_y", vector=[0.4] * 1024
    )
    await agent_skill_repo.upsert([*in_cluster, outsider])

    rows = await agent_skill_repo.find_in_cluster(
        owner_id="a", cluster_id="cl_x", limit=2
    )
    assert len(rows) == 2
    # No ordering is asserted: any two of the three in-cluster rows are valid.
    assert {row.name for row in rows} <= {"s1", "s2", "s3"}
    assert all(row.owner_id == "a" and row.cluster_id == "cl_x" for row in rows)
|
||||
|
||||
|
||||
async def test_find_topk_relevant_in_cluster_ranks_by_cosine(
    _real_lancedb: None,
) -> None:
    """LanceDB native ``nearest_to + distance_type('cosine')`` ordering."""
    # Query direction, an orthogonal direction, and one in between.
    near = [1.0] + [0.0] * 1023
    far = [0.0] * 1023 + [1.0]
    medium = [0.7, 0.7] + [0.0] * 1022

    rows = [
        _skill_row(name="near", owner_id="a", cluster_id="cl_x", vector=near),
        _skill_row(name="far", owner_id="a", cluster_id="cl_x", vector=far),
        _skill_row(name="medium", owner_id="a", cluster_id="cl_x", vector=medium),
        # Different cluster — must not leak.
        _skill_row(name="other", owner_id="a", cluster_id="cl_y", vector=near),
        # Different owner — must not leak either.
        _skill_row(name="near", owner_id="b", cluster_id="cl_x", vector=near),
    ]
    await agent_skill_repo.upsert(rows)

    ranked = await agent_skill_repo.find_topk_relevant_in_cluster(
        owner_id="a", cluster_id="cl_x", query_vector=near, top_k=2
    )
    assert [row.name for row in ranked] == ["near", "medium"]
|
||||
|
||||
|
||||
async def test_find_topk_relevant_in_cluster_raises_on_empty_vector(
    _real_lancedb: None,
) -> None:
    """Empty ``query_vector`` is a caller-side error — the repo refuses."""
    seed = _skill_row(name="s1", owner_id="a", cluster_id="cl_x", vector=[0.1] * 1024)
    await agent_skill_repo.upsert([seed])

    with pytest.raises(ValueError, match="query_vector must be non-empty"):
        await agent_skill_repo.find_topk_relevant_in_cluster(
            owner_id="a", cluster_id="cl_x", query_vector=[], top_k=2
        )
|
||||
@ -0,0 +1,150 @@
|
||||
"""``content_sha256`` is a required field on every business lancedb table.
|
||||
|
||||
Cascade handler (16 doc §3.3) diffs by this digest to skip no-op
|
||||
re-embeds. Every business schema — including ``agent_skill`` — declares
|
||||
the field; daily-log kinds hash a per-handler subset of inline +
|
||||
section keys, agent_skill hashes the file-level content-bearing parts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.persistence.lancedb import (
|
||||
AgentCase,
|
||||
AgentSkill,
|
||||
AtomicFact,
|
||||
Episode,
|
||||
Foresight,
|
||||
)
|
||||
|
||||
_VEC = [0.0] * 1024
|
||||
_NOW = dt.datetime(2026, 5, 14, 10, 0, 0, tzinfo=dt.UTC)
|
||||
_SHA = "f" * 64
|
||||
|
||||
|
||||
def _episode() -> Episode:
    """Return a fully populated user-owned ``Episode`` row carrying ``_SHA``."""
    return Episode(
        id="u1_ep_1",
        entry_id="ep_20260514_0001",
        owner_id="u1",
        owner_type="user",
        session_id="s1",
        timestamp=_NOW,
        parent_type="memcell",
        parent_id="mc_1",
        sender_ids=["u1"],
        episode="hello world",
        episode_tokens="hello world",
        md_path="users/u1/episodes/episode-2026-05-14.md",
        content_sha256=_SHA,
        vector=_VEC,
    )
|
||||
|
||||
|
||||
def _atomic_fact() -> AtomicFact:
    """Return a fully populated user-owned ``AtomicFact`` row carrying ``_SHA``."""
    return AtomicFact(
        id="u1_af_1",
        entry_id="af_20260514_0001",
        owner_id="u1",
        owner_type="user",
        session_id="s1",
        timestamp=_NOW,
        parent_type="memcell",
        parent_id="mc_1",
        sender_ids=["u1"],
        fact="x is y",
        fact_tokens="x is y",
        md_path="users/u1/.atomic_facts/atomic_fact-2026-05-14.md",
        content_sha256=_SHA,
        vector=_VEC,
    )
|
||||
|
||||
|
||||
def _foresight() -> Foresight:
    """Return a fully populated user-owned ``Foresight`` row carrying ``_SHA``."""
    return Foresight(
        id="u1_fs_1",
        entry_id="fs_20260514_0001",
        owner_id="u1",
        owner_type="user",
        session_id="s1",
        timestamp=_NOW,
        parent_type="memcell",
        parent_id="mc_1",
        sender_ids=["u1"],
        foresight="user plans X",
        foresight_tokens="user plans X",
        md_path="users/u1/.foresights/foresight-2026-05-14.md",
        content_sha256=_SHA,
        vector=_VEC,
    )
|
||||
|
||||
|
||||
def _agent_case() -> AgentCase:
    """Return a fully populated agent-owned ``AgentCase`` row carrying ``_SHA``."""
    return AgentCase(
        id="a1_ac_1",
        entry_id="ac_20260514_0001",
        owner_id="a1",
        owner_type="agent",
        session_id="s1",
        timestamp=_NOW,
        parent_type="memcell",
        parent_id="mc_1",
        quality_score=0.9,
        task_intent="scan contract",
        task_intent_tokens="scan contract",
        approach="step 1; step 2",
        approach_tokens="step 1 step 2",
        md_path="agents/a1/.cases/agent_case-2026-05-14.md",
        content_sha256=_SHA,
        vector=_VEC,
    )
|
||||
|
||||
|
||||
def _agent_skill() -> AgentSkill:
    """Return a fully populated agent-owned ``AgentSkill`` row carrying ``_SHA``.

    Unlike the daily-log factories, this row has no ``entry_id``, session,
    timestamp or parent fields — only the skill's own content and scores.
    """
    return AgentSkill(
        id="a1_demo_skill",
        owner_id="a1",
        owner_type="agent",
        name="demo_skill",
        description="just a demo",
        description_tokens="just a demo",
        content="body content",
        content_tokens="body content",
        confidence=0.7,
        maturity_score=0.6,
        source_case_ids=[],
        # Follows the skill directory layout ``agents/<id>/skills/skill_<name>/``
        # so the fixture path looks like a real skill file, as the sibling
        # factories' paths do for their kinds.
        md_path="agents/a1/skills/skill_demo_skill/SKILL.md",
        content_sha256=_SHA,
        vector=_VEC,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "factory",
    [_episode, _atomic_fact, _foresight, _agent_case, _agent_skill],
    ids=["episode", "atomic_fact", "foresight", "agent_case", "agent_skill"],
)
def test_content_sha256_round_trip(factory) -> None:  # type: ignore[no-untyped-def]
    """The digest survives construction, ``model_dump`` and ``model_validate``."""
    original = factory()
    assert original.content_sha256 == _SHA

    payload = original.model_dump()
    assert payload["content_sha256"] == _SHA

    rebuilt = type(original).model_validate(payload)
    assert rebuilt.content_sha256 == _SHA
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "factory",
    [_episode, _atomic_fact, _foresight, _agent_case, _agent_skill],
    ids=["episode", "atomic_fact", "foresight", "agent_case", "agent_skill"],
)
def test_content_sha256_required(factory) -> None:  # type: ignore[no-untyped-def]
    """Dropping content_sha256 from the kwargs surfaces a ValidationError."""
    # Imported locally: this module has no top-level pydantic import.
    from pydantic import ValidationError

    row = factory()
    kwargs = row.model_dump()
    del kwargs["content_sha256"]
    # Require the specific validation failure for the dropped field, so an
    # unrelated error (typo, import problem, other field) cannot pass the test.
    with pytest.raises(ValidationError, match="content_sha256"):
        type(row).model_validate(kwargs)
|
||||
0
tests/unit/test_infra/test_markdown/__init__.py
Normal file
0
tests/unit/test_infra/test_markdown/__init__.py
Normal file
104
tests/unit/test_infra/test_markdown/test_mds/test_agent_skill.py
Normal file
104
tests/unit/test_infra/test_markdown/test_mds/test_agent_skill.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""Tests for :class:`AgentSkillFrontmatter` — the AgentSkill schema.
|
||||
|
||||
Lives under ``test_infra`` because :class:`AgentSkillFrontmatter` itself
|
||||
lives under ``infra/.../mds`` (it carries business fields + the
|
||||
directory-shape ClassVars). The schema-agnostic chassis tests live
|
||||
under ``test_core/test_persistence/test_markdown/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.infra.persistence.markdown import AgentSkillFrontmatter
|
||||
|
||||
|
||||
def _kwargs(**overrides: object) -> dict[str, object]:
|
||||
"""Minimal valid kwargs for AgentSkillFrontmatter."""
|
||||
base: dict[str, object] = {
|
||||
"id": "skill_contract_risk_scan",
|
||||
"agent_id": "agent_zhang_legal",
|
||||
"name": "contract_risk_scan",
|
||||
"description": "Scan a contract draft for risk clauses.",
|
||||
"confidence": 0.5,
|
||||
"maturity_score": 0.5,
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def test_skill_inherits_agent_scope() -> None:
    """Skills always live under ``agents/`` — track + SCOPE_DIR confirm."""
    assert AgentSkillFrontmatter.SCOPE_DIR == "agents"
    skill_fm = AgentSkillFrontmatter(**_kwargs())  # type: ignore[arg-type]
    assert (skill_fm.track, skill_fm.type) == ("agent", "agent_skill")
|
||||
|
||||
|
||||
def test_skill_requires_name_and_description() -> None:
    """Tier-1 prompt injection demands both fields — schema enforces."""
    for required in ("name", "description"):
        incomplete = _kwargs()
        incomplete.pop(required)
        with pytest.raises(ValidationError):
            AgentSkillFrontmatter(**incomplete)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_skill_requires_confidence_and_maturity_score() -> None:
    """LLM-emitted score fields are required (no default)."""
    for required in ("confidence", "maturity_score"):
        incomplete = _kwargs()
        incomplete.pop(required)
        with pytest.raises(ValidationError):
            AgentSkillFrontmatter(**incomplete)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_skill_optional_fields_default() -> None:
    """``source_case_ids`` defaults to empty list; ``cluster_id`` to None."""
    # Built from the minimal kwargs only, so both lineage fields are unset.
    fm = AgentSkillFrontmatter(**_kwargs())  # type: ignore[arg-type]
    assert fm.source_case_ids == []
    assert fm.cluster_id is None
|
||||
|
||||
|
||||
def test_skill_lineage_fields_round_trip() -> None:
    """``source_case_ids`` + ``cluster_id`` round-trip through model_dump."""
    lineage_kwargs = _kwargs(source_case_ids=["case_a", "case_b"], cluster_id="cl_x")
    skill_fm = AgentSkillFrontmatter(**lineage_kwargs)  # type: ignore[arg-type]

    payload = skill_fm.model_dump()
    assert payload["source_case_ids"] == ["case_a", "case_b"]
    assert payload["cluster_id"] == "cl_x"
|
||||
|
||||
|
||||
def test_skill_extra_fields_still_allowed() -> None:
    """L2 system metadata (md_sha256 / last_indexed_at) rides along."""
    extra_kwargs = _kwargs(
        md_sha256="deadbeef",
        last_indexed_at="2026-05-07T08:00:00Z",
    )
    skill_fm = AgentSkillFrontmatter(**extra_kwargs)  # type: ignore[arg-type]

    payload = skill_fm.model_dump()
    assert payload["md_sha256"] == "deadbeef"
    assert payload["last_indexed_at"] == "2026-05-07T08:00:00Z"
|
||||
|
||||
|
||||
def test_skill_directory_shape_classvars() -> None:
    """Path-shape ClassVars pin the wiki layout for the writer/reader pair."""
    expected = {
        "SKILLS_CONTAINER_NAME": "skills",
        "SKILL_DIR_PREFIX": "skill_",
        "SKILL_MAIN_FILENAME": "SKILL.md",
        "SKILL_REFERENCES_DIR_NAME": "references",
        "SKILL_SCRIPTS_DIR_NAME": "scripts",
    }
    for attr, value in expected.items():
        assert getattr(AgentSkillFrontmatter, attr) == value
|
||||
@ -0,0 +1,30 @@
|
||||
"""Tests that every business frontmatter class reports the expected
|
||||
``path_glob()`` — the cascade scanner reads these to enumerate eligible
|
||||
files, so a wrong glob silently drops a whole kind from cascade.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.persistence.markdown import (
|
||||
AgentCaseDailyFrontmatter,
|
||||
AgentSkillFrontmatter,
|
||||
AtomicFactDailyFrontmatter,
|
||||
EpisodeDailyFrontmatter,
|
||||
ForesightDailyFrontmatter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("schema", "expected"),
    [
        (EpisodeDailyFrontmatter, "*/*/users/*/episodes/episode-*.md"),
        (AtomicFactDailyFrontmatter, "*/*/users/*/.atomic_facts/atomic_fact-*.md"),
        (ForesightDailyFrontmatter, "*/*/users/*/.foresights/foresight-*.md"),
        (AgentCaseDailyFrontmatter, "*/*/agents/*/.cases/agent_case-*.md"),
        (AgentSkillFrontmatter, "*/*/agents/*/skills/skill_*/SKILL.md"),
    ],
)
def test_path_glob(schema: type, expected: str) -> None:
    """Each frontmatter class's ``path_glob()`` equals its pinned pattern."""
    assert schema.path_glob() == expected
|
||||
71
tests/unit/test_infra/test_markdown/test_mds/test_profile.py
Normal file
71
tests/unit/test_infra/test_markdown/test_mds/test_profile.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Tests for the profile frontmatter duck-typed shape.
|
||||
|
||||
Profile schemas have no shared base class — they only need a
|
||||
``PROFILE_FILENAME`` ClassVar plus inheritance from a scope mixin. This
|
||||
test exercises that contract via a local fixture class.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.core.persistence.markdown import UserScopedFrontmatter
|
||||
|
||||
|
||||
class _SampleUserProfileFM(UserScopedFrontmatter):
    """Local fixture: a user-track profile schema."""

    # Path-shape ClassVar that the profile contract expects on the schema.
    PROFILE_FILENAME: ClassVar[str] = "user.md"

    type: Literal["sample_user_profile"] = "sample_user_profile"
    # Required: the tests below expect a ValidationError when either is absent.
    display_name: str
    bio: str
    # Optional; the tests below expect ``[]`` when omitted.
    interests: list[str] = []
|
||||
|
||||
|
||||
def test_schema_inherits_user_scope() -> None:
    """A user-scoped profile reports ``track == "user"`` and the users dir."""
    profile = _SampleUserProfileFM(
        id="sample_user_profile_u_jason",
        type="sample_user_profile",
        user_id="u_jason",
        display_name="Jason",
        bio="hiker.",
    )
    assert (profile.track, profile.SCOPE_DIR) == ("user", "users")
|
||||
|
||||
|
||||
def test_profile_filename_classvar() -> None:
    """Path-shape ClassVar is duck-typed onto the schema directly."""
    # Read from the class, not an instance: no construction is needed.
    assert _SampleUserProfileFM.PROFILE_FILENAME == "user.md"
|
||||
|
||||
|
||||
def test_requires_display_name_and_bio() -> None:
    """Omitting either ``display_name`` or ``bio`` raises ``ValidationError``."""
    common = {"id": "x", "type": "sample_user_profile", "user_id": "u_jason"}

    with pytest.raises(ValidationError):
        _SampleUserProfileFM(  # type: ignore[arg-type]
            **common, bio="missing display_name"
        )
    with pytest.raises(ValidationError):
        _SampleUserProfileFM(  # type: ignore[arg-type]
            **common, display_name="missing bio"
        )
|
||||
|
||||
|
||||
def test_interests_default_empty() -> None:
    """``interests`` falls back to an empty list when not supplied."""
    profile = _SampleUserProfileFM(
        id="x",
        type="sample_user_profile",
        user_id="u_jason",
        display_name="Jason",
        bio="hiker.",
    )
    assert profile.interests == []
|
||||
@ -0,0 +1,129 @@
|
||||
"""Tests for :class:`AgentSkillReader` — typed read for the skill directory layout."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.persistence.markdown import (
|
||||
AgentSkillFrontmatter,
|
||||
AgentSkillReader,
|
||||
AgentSkillWriter,
|
||||
)
|
||||
|
||||
|
||||
def _make_fm(**overrides: object) -> AgentSkillFrontmatter:
    """Build an ``AgentSkillFrontmatter`` from defaults patched by ``overrides``."""
    fields: dict[str, object] = {
        "id": "agent_x_skill_alpha",
        "agent_id": "agent_x",
        "name": "alpha",
        "description": "A test skill.",
        "confidence": 0.5,
        "maturity_score": 0.5,
        **overrides,
    }
    return AgentSkillFrontmatter(**fields)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on the per-test temporary directory."""
    return MemoryRoot(tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
def writer(root: MemoryRoot) -> AgentSkillWriter:
    """Provide an ``AgentSkillWriter`` bound to the test's ``MemoryRoot``."""
    return AgentSkillWriter(root)
|
||||
|
||||
|
||||
@pytest.fixture
def reader(root: MemoryRoot) -> AgentSkillReader:
    """Provide an ``AgentSkillReader`` bound to the test's ``MemoryRoot``."""
    return AgentSkillReader(root)
|
||||
|
||||
|
||||
async def test_read_main_returns_typed_frontmatter_and_body(
    writer: AgentSkillWriter, reader: AgentSkillReader
) -> None:
    """``read_main`` yields the written frontmatter (typed) and the body."""
    written_fm = _make_fm(
        description="Contract risk scan.",
        confidence=0.88,
        maturity_score=0.82,
        source_case_ids=["case_a", "case_b"],
    )
    await writer.write_main(
        "agent_x", "alpha", frontmatter=written_fm, body="The body."
    )

    result = await reader.read_main("agent_x", "alpha", schema=AgentSkillFrontmatter)
    assert result is not None
    read_fm, read_body = result
    assert isinstance(read_fm, AgentSkillFrontmatter)
    assert read_fm.name == "alpha"
    assert read_fm.source_case_ids == ["case_a", "case_b"]
    assert read_fm.confidence == 0.88
    assert read_fm.maturity_score == 0.82
    assert read_body == "The body."
|
||||
|
||||
|
||||
async def test_read_main_returns_none_when_missing(reader: AgentSkillReader) -> None:
    """``read_main`` returns ``None`` for a skill that was never written."""
    missing = await reader.read_main("agent_x", "ghost", schema=AgentSkillFrontmatter)
    assert missing is None
|
||||
|
||||
|
||||
async def test_read_main_round_trip_through_extra_fields(
    writer: AgentSkillWriter, reader: AgentSkillReader
) -> None:
    """L2 / L4 ride-along fields survive a write+read cycle (extra="allow")."""
    written_fm = _make_fm(md_sha256="abc", custom_label="ride-along")
    await writer.write_main("agent_x", "alpha", frontmatter=written_fm, body="b")

    result = await reader.read_main("agent_x", "alpha", schema=AgentSkillFrontmatter)
    assert result is not None
    read_fm, _body = result
    payload = read_fm.model_dump()
    assert payload["md_sha256"] == "abc"
    assert payload["custom_label"] == "ride-along"
|
||||
|
||||
|
||||
async def test_read_main_validates_against_supplied_schema(
    writer: AgentSkillWriter, reader: AgentSkillReader
) -> None:
    """A stricter schema rejects loose existing data — proves typed parsing."""

    class _StricterSkillFM(AgentSkillFrontmatter):
        # Required field with no default — written file lacks it.
        priority: int

    await writer.write_main("agent_x", "alpha", frontmatter=_make_fm(), body="b")

    with pytest.raises(ValidationError):
        await reader.read_main("agent_x", "alpha", schema=_StricterSkillFM)
|
||||
|
||||
|
||||
async def test_read_reference_round_trip(
    writer: AgentSkillWriter, reader: AgentSkillReader
) -> None:
    """A reference written for a skill is read back unchanged."""
    text = "## term clauses\n..."
    await writer.write_reference("agent_x", "alpha", "termination", text)
    read_back = await reader.read_reference("agent_x", "alpha", "termination")
    assert read_back == text
|
||||
|
||||
|
||||
async def test_read_reference_returns_none_when_missing(
    reader: AgentSkillReader,
) -> None:
    """``read_reference`` returns ``None`` when nothing was written."""
    missing = await reader.read_reference("agent_x", "alpha", "ghost")
    assert missing is None
|
||||
|
||||
|
||||
async def test_read_script_round_trip(
    writer: AgentSkillWriter, reader: AgentSkillReader
) -> None:
    """A script written for a skill can be read back by file name."""
    await writer.write_script("agent_x", "alpha", "redline.py", "print('hi')\n")
    script = await reader.read_script("agent_x", "alpha", "redline.py")
    # The expected text omits the trailing newline that was written.
    assert script == "print('hi')"
|
||||
|
||||
|
||||
async def test_read_script_returns_none_when_missing(reader: AgentSkillReader) -> None:
    """``read_script`` returns ``None`` when nothing was written."""
    missing = await reader.read_script("agent_x", "alpha", "ghost.py")
    assert missing is None
|
||||
182
tests/unit/test_infra/test_markdown/test_readers/test_base.py
Normal file
182
tests/unit/test_infra/test_markdown/test_readers/test_base.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""Tests for ``BaseDailyReader`` chassis.
|
||||
|
||||
Symmetric to ``test_writers/test_base.py`` — exercises path resolution
|
||||
+ entry locating + structured-entry upgrading on a dummy schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import (
|
||||
EntryId,
|
||||
MemoryRoot,
|
||||
StructuredEntry,
|
||||
UserScopedFrontmatter,
|
||||
render_structured_entry,
|
||||
)
|
||||
from everos.infra.persistence.markdown.readers import BaseDailyReader
|
||||
from everos.infra.persistence.markdown.writers import BaseDailyWriter
|
||||
|
||||
|
||||
class _DemoFrontmatter(UserScopedFrontmatter):
    """Dummy user-scoped daily schema used to exercise the reader chassis."""

    # The tests below expect entry ids of the form ``demo_<yyyymmdd>_<seq>``.
    ENTRY_ID_PREFIX: ClassVar[str] = "demo"
    # Path-shape ClassVars; a schema lacking DIR_NAME / FILE_PREFIX is rejected.
    DIR_NAME: ClassVar[str] = "demos"
    FILE_PREFIX: ClassVar[str] = "demo"
    type: Literal["user_demo"] = "user_demo"
|
||||
|
||||
|
||||
class _DemoWriter(BaseDailyWriter):
    """Writer bound to ``_DemoFrontmatter``; seeds files for the reader tests."""

    schema = _DemoFrontmatter
|
||||
|
||||
|
||||
class _DemoReader(BaseDailyReader):
    """Reader under test, bound to ``_DemoFrontmatter``."""

    schema = _DemoFrontmatter
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on the per-test temporary directory."""
    return MemoryRoot(tmp_path)
|
||||
|
||||
|
||||
# ── construction ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_reader_rejects_missing_schema(root: MemoryRoot) -> None:
    """Instantiating a reader subclass that sets no ``schema`` raises TypeError."""

    class _NoSchemaReader(BaseDailyReader):
        pass

    # The error is expected at construction time and must mention "schema".
    with pytest.raises(TypeError, match="schema"):
        _NoSchemaReader(root)
|
||||
|
||||
|
||||
def test_reader_rejects_schema_missing_classvars(root: MemoryRoot) -> None:
    """A schema without the path-shape ClassVars is rejected at construction."""

    class _IncompleteFrontmatter(UserScopedFrontmatter):
        # Missing DIR_NAME / FILE_PREFIX.
        type: Literal["incomplete"] = "incomplete"

    class _IncompleteReader(BaseDailyReader):
        schema = _IncompleteFrontmatter

    with pytest.raises(TypeError, match="missing ClassVar"):
        _IncompleteReader(root)
|
||||
|
||||
|
||||
# ── read_for ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_read_for_returns_none_when_file_missing(root: MemoryRoot) -> None:
    """``read_for`` returns ``None`` when no file was written for that day."""
    day = dt.date(2026, 4, 22)
    result = await _DemoReader(root).read_for("u_jason", day)
    assert result is None
|
||||
|
||||
|
||||
async def test_read_for_returns_parsed_when_file_exists(
    tmp_path: Path, root: MemoryRoot
) -> None:
    """``read_for`` parses the day's file and exposes its single entry."""
    day = dt.date(2026, 4, 22)
    await _DemoWriter(root).append("u_jason", "first body", date=day)

    daily = await _DemoReader(root).read_for("u_jason", day)
    assert daily is not None
    assert len(daily.entries) == 1
    assert daily.entries[0].body == "first body"
|
||||
|
||||
|
||||
async def test_read_for_today_default(root: MemoryRoot) -> None:
    """Omitting ``date`` falls back to today_with_timezone()."""
    # Neither call passes a date, so both must agree on the default day.
    await _DemoWriter(root).append("u_jason", "today body")

    daily = await _DemoReader(root).read_for("u_jason")
    assert daily is not None
    assert daily.entries[0].body == "today body"
|
||||
|
||||
|
||||
# ── find_entry ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_find_entry_resolves_file_from_entry_id(root: MemoryRoot) -> None:
    """``find_entry`` locates the second entry of a day from its id alone."""
    day = dt.date(2026, 4, 22)
    seeder = _DemoWriter(root)
    await seeder.append("u_jason", "alpha", date=day)
    await seeder.append("u_jason", "beta", date=day)

    found = await _DemoReader(root).find_entry("u_jason", "demo_20260422_00000002")
    assert found is not None
    assert found.id == "demo_20260422_00000002"
    assert found.body == "beta"
|
||||
|
||||
|
||||
async def test_find_entry_returns_none_when_file_missing(root: MemoryRoot) -> None:
    """``find_entry`` returns ``None`` when the day's file does not exist."""
    found = await _DemoReader(root).find_entry("u_jason", "demo_20260422_00000001")
    assert found is None
|
||||
|
||||
|
||||
async def test_find_entry_returns_none_when_entry_missing(root: MemoryRoot) -> None:
    """``find_entry`` returns ``None`` when the file exists but the id does not."""
    await _DemoWriter(root).append("u_jason", "only", date=dt.date(2026, 4, 22))

    found = await _DemoReader(root).find_entry("u_jason", "demo_20260422_00000099")
    assert found is None
|
||||
|
||||
|
||||
async def test_find_entry_accepts_entryid_object(root: MemoryRoot) -> None:
    """``find_entry`` accepts an ``EntryId`` instance as well as a string id."""
    day = dt.date(2026, 4, 22)
    await _DemoWriter(root).append("u_jason", "alpha", date=day)

    entry_id = EntryId(prefix="demo", date=dt.date(2026, 4, 22), seq=1)
    found = await _DemoReader(root).find_entry("u_jason", entry_id)
    assert found is not None
    assert found.body == "alpha"
|
||||
|
||||
|
||||
# ── find_structured ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_find_structured_parses_audit_form(root: MemoryRoot) -> None:
    """``find_structured`` recovers header, inline keys and sections."""
    entry_id = "demo_20260422_00000001"
    rendered = render_structured_entry(
        header=entry_id,
        inline={"type": "demo", "user_id": "u_jason"},
        sections={"Body": "the body"},
    )
    await _DemoWriter(root).append("u_jason", rendered, date=dt.date(2026, 4, 22))

    parsed = await _DemoReader(root).find_structured("u_jason", entry_id)
    assert parsed is not None
    assert isinstance(parsed, StructuredEntry)
    assert parsed.id == "demo_20260422_00000001"
    assert parsed.header == "demo_20260422_00000001"
    assert parsed.inline == {"type": "demo", "user_id": "u_jason"}
    assert parsed.sections == {"Body": "the body"}
|
||||
|
||||
|
||||
async def test_find_structured_returns_none_when_missing(root: MemoryRoot) -> None:
    """``find_structured`` returns ``None`` when nothing was written."""
    found = await _DemoReader(root).find_structured(
        "u_jason", "demo_20260422_00000001"
    )
    assert found is None
|
||||
|
||||
|
||||
# ── path_for ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_path_for_matches_writer(tmp_path: Path, root: MemoryRoot) -> None:
    """Reader and writer resolve to the same path for the same schema."""
    day = dt.date(2026, 4, 22)
    reader_path = _DemoReader(root).path_for("u_jason", day)
    writer_path = _DemoWriter(root).path_for("u_jason", day)
    assert reader_path == writer_path
|
||||
|
||||
|
||||
def test_path_for_does_not_create_files(tmp_path: Path, root: MemoryRoot) -> None:
    """Resolving a path leaves the filesystem untouched."""
    demo_reader = _DemoReader(root)
    resolved = demo_reader.path_for("u_jason", dt.date(2026, 4, 22))
    assert not resolved.exists()
    users_dir = tmp_path / "users"
    assert not users_dir.exists()
|
||||
@ -0,0 +1,121 @@
|
||||
"""Tests for :class:`ProfileReader` — typed read for profile files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.core.persistence import MemoryRoot, UserScopedFrontmatter
|
||||
from everos.infra.persistence.markdown.readers import ProfileReader
|
||||
from everos.infra.persistence.markdown.writers import ProfileWriter
|
||||
|
||||
|
||||
class _UserProfileFM(UserScopedFrontmatter):
    # Demo user-scoped profile schema used only by the tests in this module.

    # File name the tests expect the profile to be stored under.
    PROFILE_FILENAME: ClassVar[str] = "user.md"

    type: Literal["demo_user_profile"] = "demo_user_profile"
    display_name: str = ""
    bio: str = ""
    # NOTE(review): mutable default; presumably the base model copies it per
    # instance — confirm against UserScopedFrontmatter.
    interests: list[str] = []
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on pytest's per-test temp directory."""
    memory_root = MemoryRoot(tmp_path)
    return memory_root
|
||||
|
||||
|
||||
@pytest.fixture
def writer(root: MemoryRoot) -> ProfileWriter:
    """Provide a ``ProfileWriter`` bound to the ``root`` fixture."""
    profile_writer = ProfileWriter(root)
    return profile_writer
|
||||
|
||||
|
||||
@pytest.fixture
def reader(root: MemoryRoot) -> ProfileReader:
    """Provide a ``ProfileReader`` bound to the ``root`` fixture."""
    profile_reader = ProfileReader(root)
    return profile_reader
|
||||
|
||||
|
||||
async def test_read_returns_typed_frontmatter_and_body(
    writer: ProfileWriter, reader: ProfileReader
) -> None:
    """A written profile reads back as typed frontmatter plus its body."""
    stored = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
        display_name="Jason",
        bio="weekend hiker.",
        interests=["hiking", "coffee"],
    )
    await writer.write("u_jason", frontmatter=stored, body="The body.")

    result = await reader.read("u_jason", schema=_UserProfileFM)
    assert result is not None
    loaded_fm, loaded_body = result
    assert isinstance(loaded_fm, _UserProfileFM)
    assert loaded_fm.display_name == "Jason"
    assert loaded_fm.interests == ["hiking", "coffee"]
    assert loaded_body == "The body."
|
||||
|
||||
|
||||
async def test_read_returns_none_when_missing(reader: ProfileReader) -> None:
    """Reading a profile that was never written yields ``None``."""
    result = await reader.read("u_ghost", schema=_UserProfileFM)
    assert result is None
|
||||
|
||||
|
||||
async def test_read_round_trip_through_extra_fields(
    writer: ProfileWriter, reader: ProfileReader
) -> None:
    """Ride-along fields marked ``extra`` below survive a write+read cycle."""
    stored = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
        md_sha256="abc",  # extra
        custom_label="ride-along",  # extra
    )
    await writer.write("u_jason", frontmatter=stored, body="b")

    result = await reader.read("u_jason", schema=_UserProfileFM)
    assert result is not None
    loaded_fm = result[0]
    as_dict = loaded_fm.model_dump()
    assert as_dict["md_sha256"] == "abc"
    assert as_dict["custom_label"] == "ride-along"
|
||||
|
||||
|
||||
async def test_read_validates_against_supplied_schema(
    writer: ProfileWriter, reader: ProfileReader
) -> None:
    """Reading with a schema the stored data does not satisfy raises."""

    class _StricterFM(UserScopedFrontmatter):
        PROFILE_FILENAME: ClassVar[str] = "user.md"

        type: Literal["demo_user_profile"] = "demo_user_profile"
        # No default on purpose: the file written below never sets this field.
        priority: int

    stored = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
    )
    await writer.write("u_jason", frontmatter=stored, body="b")

    with pytest.raises(ValidationError):
        await reader.read("u_jason", schema=_StricterFM)
|
||||
|
||||
|
||||
def test_path_for_matches_writer(
    tmp_path: Path,
    writer: ProfileWriter,
    reader: ProfileReader,
) -> None:
    """Reader and writer agree on the profile path for the same schema."""
    reader_path = reader.path_for("u_jason", schema=_UserProfileFM)
    writer_path = writer.path_for("u_jason", schema=_UserProfileFM)
    assert reader_path == writer_path
|
||||
|
||||
|
||||
def test_path_for_does_not_create_files(tmp_path: Path, reader: ProfileReader) -> None:
    """Resolving the profile path leaves the filesystem untouched."""
    resolved = reader.path_for("u_jason", schema=_UserProfileFM)
    assert not resolved.exists()
    users_dir = tmp_path / "users"
    assert not users_dir.exists()
|
||||
@ -0,0 +1,147 @@
|
||||
"""Tests for :class:`AgentSkillWriter` — directory + progressive disclosure."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import MarkdownReader, MemoryRoot
|
||||
from everos.infra.persistence.markdown import (
|
||||
AgentSkillFrontmatter,
|
||||
AgentSkillWriter,
|
||||
)
|
||||
|
||||
|
||||
def _make_fm(**overrides: object) -> AgentSkillFrontmatter:
    """Return an ``AgentSkillFrontmatter`` from test defaults plus ``overrides``.

    Keys given in ``overrides`` replace the default of the same name.
    """
    defaults: dict[str, object] = {
        "id": "agent_x_skill_alpha",
        "agent_id": "agent_x",
        "name": "alpha",
        "description": "A test skill.",
        "confidence": 0.5,
        "maturity_score": 0.5,
    }
    merged = {**defaults, **overrides}
    return AgentSkillFrontmatter(**merged)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on pytest's per-test temp directory."""
    memory_root = MemoryRoot(tmp_path)
    return memory_root
|
||||
|
||||
|
||||
@pytest.fixture
def writer(root: MemoryRoot) -> AgentSkillWriter:
    """Provide an ``AgentSkillWriter`` bound to the ``root`` fixture."""
    skill_writer = AgentSkillWriter(root)
    return skill_writer
|
||||
|
||||
|
||||
async def test_write_main_creates_directory_layout(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """``write_main`` returns and creates ``<agent>/skills/skill_<name>/SKILL.md``."""
    frontmatter = _make_fm()
    written = await writer.write_main(
        "agent_x", "alpha", frontmatter=frontmatter, body="Step 1: do thing."
    )
    skill_dir = root.agents_dir() / "agent_x" / "skills" / "skill_alpha"
    expected = skill_dir / "SKILL.md"
    assert written == expected
    assert expected.is_file()
|
||||
|
||||
|
||||
async def test_write_main_writes_frontmatter_and_body(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """Frontmatter fields and body written by ``write_main`` read back intact."""
    frontmatter = _make_fm(
        description="Contract risk scan.",
        confidence=0.88,
        maturity_score=0.82,
        source_case_ids=["case_a", "case_b"],
        cluster_id="cl_x",
    )
    await writer.write_main(
        "agent_x", "alpha", frontmatter=frontmatter, body="The body."
    )
    skill_file = root.agents_dir() / "agent_x" / "skills" / "skill_alpha" / "SKILL.md"
    document = await MarkdownReader.read(skill_file)

    stored = document.frontmatter
    assert stored["name"] == "alpha"
    assert stored["description"] == "Contract risk scan."
    assert stored["confidence"] == 0.88
    assert stored["maturity_score"] == 0.82
    assert stored["source_case_ids"] == ["case_a", "case_b"]
    assert stored["cluster_id"] == "cl_x"
    # Trailing newlines are ignored for the body comparison.
    assert document.body.rstrip("\n") == "The body."
|
||||
|
||||
|
||||
async def test_write_main_is_upsert_full_replace(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """A second ``write_main`` replaces frontmatter and body; nothing is appended."""
    first = _make_fm(description="v1", maturity_score=0.4)
    await writer.write_main("agent_x", "alpha", frontmatter=first, body="body v1")

    second = _make_fm(description="v2", maturity_score=0.7)
    await writer.write_main("agent_x", "alpha", frontmatter=second, body="body v2")

    skill_file = root.agents_dir() / "agent_x" / "skills" / "skill_alpha" / "SKILL.md"
    document = await MarkdownReader.read(skill_file)
    assert document.frontmatter["description"] == "v2"
    assert document.frontmatter["maturity_score"] == 0.7
    assert document.body.rstrip("\n") == "body v2"
    # The first version's body must leave no trace.
    assert "body v1" not in document.body
|
||||
|
||||
|
||||
async def test_write_reference_uses_md_extension(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """A reference named without a suffix lands at ``references/<name>.md``."""
    written = await writer.write_reference(
        "agent_x", "alpha", "termination_clauses", "## Termination\n..."
    )
    skill_dir = root.agents_dir() / "agent_x" / "skills" / "skill_alpha"
    expected = skill_dir / "references" / "termination_clauses.md"
    assert written == expected
    content = written.read_text(encoding="utf-8")
    assert content.startswith("## Termination")
|
||||
|
||||
|
||||
async def test_write_script_keeps_full_filename(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """A script is stored under ``scripts/`` with its file name unchanged."""
    written = await writer.write_script(
        "agent_x", "alpha", "redline.py", "print('hi')\n"
    )
    skill_dir = root.agents_dir() / "agent_x" / "skills" / "skill_alpha"
    expected = skill_dir / "scripts" / "redline.py"
    assert written == expected
    content = written.read_text(encoding="utf-8")
    assert content == "print('hi')\n"
|
||||
|
||||
|
||||
def test_main_path_does_not_create_anything(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """``main_path`` only resolves a path; the agents directory stays absent."""
    resolved = writer.main_path("agent_x", "alpha")
    assert resolved.name == "SKILL.md"
    agents_dir = root.agents_dir()
    assert not agents_dir.exists()
|
||||
|
||||
|
||||
async def test_write_main_normalises_trailing_newline(
    root: MemoryRoot, writer: AgentSkillWriter
) -> None:
    """A body passed without a trailing newline is written with one appended."""
    frontmatter = _make_fm()
    await writer.write_main(
        "agent_x", "alpha", frontmatter=frontmatter, body="no-newline-end"
    )
    skill_file = root.agents_dir() / "agent_x" / "skills" / "skill_alpha" / "SKILL.md"
    raw_text = skill_file.read_text(encoding="utf-8")
    assert raw_text.endswith("no-newline-end\n")
|
||||
182
tests/unit/test_infra/test_markdown/test_writers/test_base.py
Normal file
182
tests/unit/test_infra/test_markdown/test_writers/test_base.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""Tests for ``BaseDailyWriter`` skeleton.
|
||||
|
||||
Uses a dummy ``UserScopedFrontmatter`` subclass to exercise the path
|
||||
resolution + entry-id construction + today-by-default logic without
|
||||
pulling in any concrete business schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.utils.datetime import today_with_timezone
|
||||
from everos.core.persistence import (
|
||||
AgentScopedFrontmatter,
|
||||
MarkdownReader,
|
||||
MemoryRoot,
|
||||
UserScopedFrontmatter,
|
||||
)
|
||||
from everos.infra.persistence.markdown.writers import BaseDailyWriter
|
||||
|
||||
|
||||
class _UserDemoFrontmatter(UserScopedFrontmatter):
    # Dummy user-scoped schema for exercising BaseDailyWriter.

    # The tests below expect ids like "demo_20260422_00000001" ...
    ENTRY_ID_PREFIX: ClassVar[str] = "demo"
    # ... stored in files like "<scope>/demos/demo-2026-04-22.md".
    DIR_NAME: ClassVar[str] = "demos"
    FILE_PREFIX: ClassVar[str] = "demo"

    type: Literal["user_demo"] = "user_demo"
|
||||
|
||||
|
||||
class _AgentDemoFrontmatter(AgentScopedFrontmatter):
    # Dummy agent-scoped schema for exercising BaseDailyWriter.

    # Different id prefix from the user schema; same directory and file prefix.
    ENTRY_ID_PREFIX: ClassVar[str] = "ademo"
    DIR_NAME: ClassVar[str] = "demos"
    FILE_PREFIX: ClassVar[str] = "demo"

    type: Literal["agent_demo"] = "agent_demo"
|
||||
|
||||
|
||||
class _UserDemoWriter(BaseDailyWriter):
    # Daily writer bound to the user-scoped demo schema.
    schema = _UserDemoFrontmatter
|
||||
|
||||
|
||||
class _AgentDemoWriter(BaseDailyWriter):
    # Daily writer bound to the agent-scoped demo schema.
    schema = _AgentDemoFrontmatter
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on pytest's per-test temp directory."""
    memory_root = MemoryRoot(tmp_path)
    return memory_root
|
||||
|
||||
|
||||
def test_constructor_rejects_missing_schema(root: MemoryRoot) -> None:
    """Instantiating a subclass that sets no ``schema`` raises ``TypeError``."""

    class _NoSchema(BaseDailyWriter):
        pass

    # The error message is expected to mention "schema".
    with pytest.raises(TypeError, match="schema"):
        _NoSchema(root)
|
||||
|
||||
|
||||
def test_constructor_rejects_schema_missing_classvars(root: MemoryRoot) -> None:
    """A schema lacking the path/id ClassVars is rejected with ``TypeError``."""

    class _IncompleteFrontmatter(UserScopedFrontmatter):
        # Deliberately omits ENTRY_ID_PREFIX / DIR_NAME / FILE_PREFIX.
        type: Literal["incomplete"] = "incomplete"

    class _IncompleteWriter(BaseDailyWriter):
        schema = _IncompleteFrontmatter

    # The error message is expected to name ENTRY_ID_PREFIX.
    with pytest.raises(TypeError, match="ENTRY_ID_PREFIX"):
        _IncompleteWriter(root)
|
||||
|
||||
|
||||
async def test_append_writes_to_user_track(root: MemoryRoot) -> None:
    """Appending with a user-scoped schema writes under the users directory."""
    day = dt.date(2026, 4, 22)
    demo_writer = _UserDemoWriter(root)
    entry_id = await demo_writer.append("u_jason", "first", date=day)

    # The returned id carries prefix, date and a sequence starting at 1.
    assert entry_id.prefix == "demo"
    assert entry_id.date == dt.date(2026, 4, 22)
    assert entry_id.seq == 1

    log_file = root.users_dir() / "u_jason" / "demos" / "demo-2026-04-22.md"
    assert log_file.exists()
    document = await MarkdownReader.read(log_file)
    first_entry = document.entries[0]
    assert first_entry.id == "demo_20260422_00000001"
    assert first_entry.body == "first"
|
||||
|
||||
|
||||
async def test_append_writes_to_agent_track(root: MemoryRoot) -> None:
    """Appending with an agent-scoped schema writes under the agents directory."""
    demo_writer = _AgentDemoWriter(root)
    entry_id = await demo_writer.append(
        "agent_zhangsan", "trace", date=dt.date(2026, 4, 22)
    )
    assert entry_id.prefix == "ademo"
    agent_dir = root.agents_dir() / "agent_zhangsan"
    log_file = agent_dir / "demos" / "demo-2026-04-22.md"
    assert log_file.exists()
|
||||
|
||||
|
||||
async def test_append_increments_seq_across_calls(root: MemoryRoot) -> None:
    """Three appends to the same day yield sequence numbers 1, 2, 3."""
    demo_writer = _UserDemoWriter(root)
    day = dt.date(2026, 4, 22)
    sequence_numbers = []
    for index in range(3):
        entry_id = await demo_writer.append("u_jason", f"body {index}", date=day)
        sequence_numbers.append(entry_id.seq)
    assert sequence_numbers == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_append_date_defaults_to_today(root: MemoryRoot) -> None:
    """Omitting ``date`` falls back to today_with_timezone().

    ``today_with_timezone()`` is sampled both before and after the append so
    the assertion cannot fail spuriously when the calendar day rolls over
    while the test is running.
    """
    writer = _UserDemoWriter(root)
    earliest = today_with_timezone()
    eid = await writer.append("u_jason", "body")
    latest = today_with_timezone()

    # The entry date must fall on a day observed around the call.
    assert earliest <= eid.date <= latest

    # The file name is derived from the same date the entry id reports.
    file_name = f"demo-{eid.date.isoformat()}.md"
    expected = root.users_dir() / "u_jason" / "demos" / file_name
    assert expected.exists()
|
||||
|
||||
|
||||
async def test_append_passes_frontmatter_updates(root: MemoryRoot) -> None:
    """Caller-supplied ``frontmatter_updates`` end up in the file frontmatter."""
    demo_writer = _UserDemoWriter(root)
    updates = {"file_type": "user_demo_daily", "entry_count": 1}
    await demo_writer.append(
        "u_jason",
        "body",
        date=dt.date(2026, 4, 22),
        frontmatter_updates=updates,
    )
    log_file = root.users_dir() / "u_jason" / "demos" / "demo-2026-04-22.md"
    document = await MarkdownReader.read(log_file)
    assert document.frontmatter["file_type"] == "user_demo_daily"
    assert document.frontmatter["entry_count"] == 1
|
||||
|
||||
|
||||
async def test_current_count_hook_can_be_overridden(root: MemoryRoot) -> None:
    """The value returned by an overridden ``_current_count`` drives ``seq``."""

    class _ConstantCount(BaseDailyWriter):
        schema = _UserDemoFrontmatter

        async def _current_count(self, path):  # noqa: ANN001
            # Report 41 existing entries no matter what is on disk.
            return 41

    constant_writer = _ConstantCount(root)
    entry_id = await constant_writer.append(
        "u_jason", "body", date=dt.date(2026, 4, 22)
    )
    # One past the reported count: 41 + 1.
    assert entry_id.seq == 42
|
||||
|
||||
|
||||
async def test_frontmatter_updates_hook_supplies_defaults(root: MemoryRoot) -> None:
    """Values returned by an overridden ``_frontmatter_updates`` reach the file."""

    class _WithDefaults(BaseDailyWriter):
        schema = _UserDemoFrontmatter

        def _frontmatter_updates(self, scope_id, date, *, next_count):  # noqa: ANN001
            hook_values = {
                "user_id": scope_id,
                "entry_count": next_count,
                "marker": "from-hook",
            }
            return hook_values

    hook_writer = _WithDefaults(root)
    await hook_writer.append("u_jason", "body", date=dt.date(2026, 4, 22))

    log_file = root.users_dir() / "u_jason" / "demos" / "demo-2026-04-22.md"
    document = await MarkdownReader.read(log_file)
    stored = document.frontmatter
    assert stored["marker"] == "from-hook"
    assert stored["entry_count"] == 1
    assert stored["user_id"] == "u_jason"
|
||||
|
||||
|
||||
async def test_explicit_frontmatter_updates_skip_hook(root: MemoryRoot) -> None:
    """An explicit ``frontmatter_updates`` argument wins over the hook's value."""

    class _WithDefaults(BaseDailyWriter):
        schema = _UserDemoFrontmatter

        def _frontmatter_updates(self, scope_id, date, *, next_count):  # noqa: ANN001
            # Must not show up in the file when the caller passes its own dict.
            return {"marker": "from-hook"}

    hook_writer = _WithDefaults(root)
    explicit = {"marker": "explicit"}
    await hook_writer.append(
        "u_jason",
        "body",
        date=dt.date(2026, 4, 22),
        frontmatter_updates=explicit,
    )
    log_file = root.users_dir() / "u_jason" / "demos" / "demo-2026-04-22.md"
    document = await MarkdownReader.read(log_file)
    assert document.frontmatter["marker"] == "explicit"
|
||||
@ -0,0 +1,344 @@
|
||||
"""Tests for AtomicFact / Foresight / AgentCase daily-log writers.
|
||||
|
||||
The 4 daily-log kinds (episode + these 3) all share ``BaseDailyWriter``
|
||||
plumbing — exhaustive chassis tests live in ``test_base.py`` and
|
||||
``test_episode_writer.py`` indirectly via the e2e flows. Here we focus
|
||||
on the per-kind path resolution + frontmatter shape that each
|
||||
subclass owns: ``schema``, ``_frontmatter_updates``, and the
|
||||
writer ↔ reader round-trip on a fresh tmp memory_root.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import MarkdownReader, MemoryRoot
|
||||
from everos.infra.persistence.markdown import (
|
||||
AgentCaseReader,
|
||||
AgentCaseWriter,
|
||||
AtomicFactReader,
|
||||
AtomicFactWriter,
|
||||
ForesightReader,
|
||||
ForesightWriter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def memory_root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` on the temp directory after calling ``ensure()``."""
    prepared = MemoryRoot(tmp_path)
    prepared.ensure()
    return prepared
|
||||
|
||||
|
||||
# ── AtomicFact ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_atomic_fact_writer_round_trip(memory_root: MemoryRoot) -> None:
    """Write one atomic fact, then check frontmatter, entry and reader lookup."""
    fact_writer = AtomicFactWriter(memory_root)
    day = _dt.date(2026, 5, 15)
    inline_fields = {
        "owner_id": "u1",
        "session_id": "s1",
        "timestamp": "2026-05-15T10:00:00+00:00",
        "parent_id": "mc_1",
        "sender_ids": ["u1"],
    }
    entry_id = await fact_writer.append_entry(
        "u1",
        inline=inline_fields,
        sections={"Fact": "Alice prefers Italian."},
        date=day,
    )
    facts_dir = memory_root.users_dir() / "u1" / ".atomic_facts"
    log_file = facts_dir / "atomic_fact-2026-05-15.md"
    document = await MarkdownReader.read(log_file)

    # File-level frontmatter.
    stored = document.frontmatter
    assert stored["id"] == "atomic_fact_log_u1_2026-05-15"
    assert stored["type"] == "atomic_fact_daily"
    assert stored["file_type"] == "atomic_fact_daily"
    assert stored["user_id"] == "u1"
    assert stored["track"] == "user"
    assert stored["date"] == "2026-05-15"
    assert stored["entry_count"] == 1

    # The single entry and its structured view.
    assert len(document.entries) == 1
    only_entry = document.entries[0]
    assert only_entry.id == entry_id.format()
    as_structured = only_entry.as_structured()
    assert as_structured.inline["owner_id"] == "u1"
    assert as_structured.inline["parent_id"] == "mc_1"
    assert as_structured.sections["Fact"] == "Alice prefers Italian."

    # The reader resolves the same path and finds the same entry.
    fact_reader = AtomicFactReader(memory_root)
    assert fact_reader.path_for("u1", day) == log_file
    looked_up = await fact_reader.find_structured("u1", entry_id)
    assert looked_up is not None
    assert looked_up.sections["Fact"] == "Alice prefers Italian."
|
||||
|
||||
|
||||
async def test_atomic_fact_writer_appends_multiple(memory_root: MemoryRoot) -> None:
    """Two appends on the same day yield distinct ids; the second ends in 0002."""
    fact_writer = AtomicFactWriter(memory_root)
    day = _dt.date(2026, 5, 15)

    first_inline = {
        "owner_id": "u1",
        "session_id": "s1",
        "timestamp": "2026-05-15T10:00:00+00:00",
        "parent_id": "mc_1",
    }
    first_id = await fact_writer.append_entry(
        "u1", inline=first_inline, sections={"Fact": "fact 1"}, date=day
    )

    second_inline = {
        "owner_id": "u1",
        "session_id": "s1",
        "timestamp": "2026-05-15T11:00:00+00:00",
        "parent_id": "mc_2",
    }
    second_id = await fact_writer.append_entry(
        "u1", inline=second_inline, sections={"Fact": "fact 2"}, date=day
    )

    assert first_id.format() != second_id.format()
    assert second_id.format().endswith("0002")
|
||||
|
||||
|
||||
# ── Foresight ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_foresight_writer_round_trip(memory_root: MemoryRoot) -> None:
    """Write one foresight entry, then check frontmatter, entry and reader lookup."""
    foresight_writer = ForesightWriter(memory_root)
    day = _dt.date(2026, 5, 15)
    inline_fields = {
        "owner_id": "u1",
        "session_id": "s1",
        "timestamp": "2026-05-15T10:00:00+00:00",
        "parent_id": "mc_1",
        "start_time": "2026-05-15T12:00:00+00:00",
        "end_time": "2026-05-15T13:00:00+00:00",
        "duration_days": 1,
    }
    section_map = {
        "Foresight": "User will book lunch at noon.",
        "Evidence": "Past calendar pattern.",
    }
    entry_id = await foresight_writer.append_entry(
        "u1", inline=inline_fields, sections=section_map, date=day
    )
    foresights_dir = memory_root.users_dir() / "u1" / ".foresights"
    log_file = foresights_dir / "foresight-2026-05-15.md"
    document = await MarkdownReader.read(log_file)

    stored = document.frontmatter
    assert stored["id"] == "foresight_log_u1_2026-05-15"
    assert stored["type"] == "foresight_daily"

    as_structured = document.entries[0].as_structured()
    assert as_structured.sections["Foresight"] == "User will book lunch at noon."
    assert as_structured.sections["Evidence"] == "Past calendar pattern."
    # The integer written above is expected back as the string "1".
    assert as_structured.inline["duration_days"] == "1"
    assert as_structured.inline["start_time"].startswith("2026-05-15T12:00:00")

    foresight_reader = ForesightReader(memory_root)
    looked_up = await foresight_reader.find_structured("u1", entry_id)
    assert looked_up is not None
    assert looked_up.sections["Evidence"] == "Past calendar pattern."
|
||||
|
||||
|
||||
# ── AgentCase ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agent_case_writer_round_trip(memory_root: MemoryRoot) -> None:
    """Write one agent case, then check frontmatter, entry and reader lookup."""
    case_writer = AgentCaseWriter(memory_root)
    day = _dt.date(2026, 5, 15)
    inline_fields = {
        "owner_id": "a1",
        "session_id": "s1",
        "timestamp": "2026-05-15T10:00:00+00:00",
        "parent_id": "mc_agent",
        "quality_score": 0.87,
    }
    section_map = {
        "TaskIntent": "Scan contract for indemnity gaps.",
        "Approach": "1. read sections;\n2. flag clauses;\n3. cross-check cap.",
        "KeyInsight": "Indemnity cap missing in section 4.",
    }
    entry_id = await case_writer.append_entry(
        "a1", inline=inline_fields, sections=section_map, date=day
    )
    cases_dir = memory_root.agents_dir() / "a1" / ".cases"
    log_file = cases_dir / "agent_case-2026-05-15.md"
    document = await MarkdownReader.read(log_file)

    stored = document.frontmatter
    assert stored["id"] == "agent_case_log_a1_2026-05-15"
    assert stored["type"] == "agent_case_daily"
    assert stored["agent_id"] == "a1"
    assert stored["track"] == "agent"

    as_structured = document.entries[0].as_structured()
    # The float written above is expected back as the string "0.87".
    assert as_structured.inline["quality_score"] == "0.87"
    assert as_structured.sections["TaskIntent"].startswith("Scan contract")
    assert as_structured.sections["Approach"].startswith("1. read sections")
    assert as_structured.sections["KeyInsight"].startswith("Indemnity cap missing")

    case_reader = AgentCaseReader(memory_root)
    assert case_reader.path_for("a1", day) == log_file
    looked_up = await case_reader.find_structured("a1", entry_id)
    assert looked_up is not None
    assert looked_up.sections["TaskIntent"].startswith("Scan contract")
|
||||
|
||||
|
||||
# ── round-trip with cascade handler (md → LanceDB row mapping) ─────────────
|
||||
|
||||
|
||||
async def test_atomic_fact_writer_output_feeds_handler(
    memory_root: MemoryRoot,
) -> None:
    """Markdown produced by the writer can be turned into an AtomicFactHandler row."""
    from everos.component.embedding import EmbeddingProvider
    from everos.component.tokenizer import Tokenizer
    from everos.memory.cascade.handlers import AtomicFactHandler, HandlerDeps
    from everos.memory.cascade.handlers._daily_log_base import ParsedEntry

    class _T(Tokenizer):
        # Whitespace tokenizer stub.
        def tokenize(self, t):  # type: ignore[no-untyped-def]
            return [piece for piece in t.split() if piece]

        def tokenize_batch(self, ts):  # type: ignore[no-untyped-def]
            return [self.tokenize(text) for text in ts]

    class _E(EmbeddingProvider):
        # Embedding stub returning an all-zero vector of length ``dim``.
        dim = 1024

        async def embed(self, t):  # type: ignore[no-untyped-def]
            return [0.0] * self.dim

        async def embed_batch(self, ts):  # type: ignore[no-untyped-def]
            return [await self.embed(text) for text in ts]

    day = _dt.date(2026, 5, 15)
    inline_fields = {
        "owner_id": "u1",
        "session_id": "s1",
        "timestamp": "2026-05-15T10:00:00+00:00",
        "parent_id": "mc_1",
        "sender_ids": ["u1"],
    }
    entry_id = await AtomicFactWriter(memory_root).append_entry(
        "u1",
        inline=inline_fields,
        sections={"Fact": "Alice prefers Italian."},
        date=day,
    )
    facts_dir = memory_root.users_dir() / "u1" / ".atomic_facts"
    log_file = facts_dir / "atomic_fact-2026-05-15.md"
    # Path relative to the memory root, in POSIX form.
    relative_path = log_file.relative_to(memory_root.root).as_posix()
    document = await MarkdownReader.read(log_file)
    first_entry = document.entries[0]

    deps = HandlerDeps(memory_root=memory_root, embedder=_E(), tokenizer=_T())
    handler = AtomicFactHandler(deps)
    as_structured = first_entry.as_structured()
    content_hash = handler._content_sha256(as_structured)
    parsed_entry = ParsedEntry(first_entry.id, as_structured, content_hash)
    row = await handler._build_row(
        owner_id="u1", owner_type="user", md_path=relative_path, entry=parsed_entry
    )

    assert row.id == f"u1_{entry_id.format()}"
    assert row.fact == "Alice prefers Italian."
    assert row.parent_id == "mc_1"
    assert row.sender_ids == ["u1"]
    assert len(row.vector) == 1024
|
||||
|
||||
|
||||
# ── Display-tz contract for frontmatter timestamps (Gap #5) ────────────
|
||||
|
||||
|
||||
async def test_atomic_fact_frontmatter_last_appended_at_carries_display_tz_offset(
    memory_root: MemoryRoot,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``last_appended_at`` in markdown frontmatter renders in the display tz.

    Markdown frontmatter is a display-side artefact (users read the file
    directly), so ``last_appended_at`` must use
    :func:`get_now_with_timezone` not :func:`get_utc_now`. Pins that
    contract end-to-end: configure ``EVEROS_MEMORY__TIMEZONE=Asia/Shanghai``,
    write an entry, read the .md file, assert the literal string ends
    with ``+08:00``.

    Repeats the same check for ``ForesightWriter`` and
    ``AgentCaseWriter`` — they share ``BaseDailyWriter`` plumbing so a
    regression on one would likely affect all three, but pinning each
    rules out per-subclass shadowing of ``_frontmatter_updates``.

    The settings and display-tz caches are cleared again on exit so values
    computed under the patched environment are not reused by later tests.
    """
    from everos.component.utils import datetime as _dt_module
    from everos.config import load_settings

    monkeypatch.setenv("EVEROS_MEMORY__TIMEZONE", "Asia/Shanghai")
    # Drop anything cached before the env var was patched.
    load_settings.cache_clear()
    _dt_module._display_tz.cache_clear()

    try:
        today = _dt.date(2026, 5, 15)

        # AtomicFact
        af_writer = AtomicFactWriter(memory_root)
        await af_writer.append_entry(
            "u1",
            inline={
                "owner_id": "u1",
                "session_id": "s1",
                "timestamp": "2026-05-15T10:00:00+00:00",
                "parent_id": "mc_1",
                "sender_ids": ["u1"],
            },
            sections={"Fact": "x"},
            date=today,
        )
        af_path = (
            memory_root.users_dir()
            / "u1"
            / ".atomic_facts"
            / "atomic_fact-2026-05-15.md"
        )
        af_fm = (await MarkdownReader.read(af_path)).frontmatter
        assert af_fm["last_appended_at"].endswith("+08:00"), af_fm["last_appended_at"]

        # Foresight
        fs_writer = ForesightWriter(memory_root)
        await fs_writer.append_entry(
            "u1",
            inline={
                "owner_id": "u1",
                "session_id": "s1",
                "timestamp": "2026-05-15T10:00:00+00:00",
                "scope": "today",
                "horizon_days": 1,
            },
            sections={"Foresight": "x"},
            date=today,
        )
        fs_path = (
            memory_root.users_dir() / "u1" / ".foresights" / "foresight-2026-05-15.md"
        )
        fs_fm = (await MarkdownReader.read(fs_path)).frontmatter
        assert fs_fm["last_appended_at"].endswith("+08:00"), fs_fm["last_appended_at"]

        # AgentCase
        ac_writer = AgentCaseWriter(memory_root)
        await ac_writer.append_entry(
            "a1",
            inline={
                "owner_id": "a1",
                "session_id": "s1",
                "timestamp": "2026-05-15T10:00:00+00:00",
                "quality_score": 0.9,
            },
            sections={"Task intent": "x", "Approach": "y"},
            date=today,
        )
        ac_path = (
            memory_root.agents_dir() / "a1" / ".cases" / "agent_case-2026-05-15.md"
        )
        ac_fm = (await MarkdownReader.read(ac_path)).frontmatter
        assert ac_fm["last_appended_at"].endswith("+08:00"), ac_fm["last_appended_at"]
    finally:
        # monkeypatch restores the env var after this function returns, but it
        # does not touch these caches; without this they would presumably keep
        # serving the Asia/Shanghai values to subsequent tests.
        load_settings.cache_clear()
        _dt_module._display_tz.cache_clear()
|
||||
@ -0,0 +1,166 @@
|
||||
"""Tests for :class:`ProfileWriter` — single-file rewrite layout."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import (
|
||||
AgentScopedFrontmatter,
|
||||
BaseFrontmatter,
|
||||
MarkdownReader,
|
||||
MemoryRoot,
|
||||
UserScopedFrontmatter,
|
||||
)
|
||||
from everos.infra.persistence.markdown.writers import ProfileWriter
|
||||
|
||||
|
||||
class _UserProfileFM(UserScopedFrontmatter):
    # Demo user-scoped profile schema used only by the tests in this module.

    # The tests below expect the profile at "<users>/<user_id>/user.md".
    PROFILE_FILENAME: ClassVar[str] = "user.md"

    type: Literal["demo_user_profile"] = "demo_user_profile"
    display_name: str = ""
    bio: str = ""
|
||||
|
||||
|
||||
class _AgentProfileFM(AgentScopedFrontmatter):
    # Demo agent-scoped profile schema used only by the tests in this module.

    # The tests below expect the profile at "<agents>/<agent_id>/agent.md".
    PROFILE_FILENAME: ClassVar[str] = "agent.md"

    type: Literal["demo_agent_profile"] = "demo_agent_profile"
    name: str = ""
|
||||
|
||||
|
||||
@pytest.fixture
def root(tmp_path: Path) -> MemoryRoot:
    """Provide a ``MemoryRoot`` built on pytest's per-test temp directory."""
    memory_root = MemoryRoot(tmp_path)
    return memory_root
|
||||
|
||||
|
||||
@pytest.fixture
def writer(root: MemoryRoot) -> ProfileWriter:
    """Provide a ``ProfileWriter`` bound to the ``root`` fixture."""
    profile_writer = ProfileWriter(root)
    return profile_writer
|
||||
|
||||
|
||||
async def test_write_creates_user_profile(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """Writing a user-scoped profile creates ``<users>/u_jason/user.md``."""
    profile = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
        display_name="Jason",
        bio="hiker.",
    )
    written = await writer.write(
        "u_jason", frontmatter=profile, body="Long-form profile."
    )
    profile_file = root.users_dir() / "u_jason" / "user.md"
    assert written == profile_file
    assert profile_file.is_file()
|
||||
|
||||
|
||||
async def test_write_creates_agent_profile(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """Writing an agent-scoped profile creates ``<agents>/agent_x/agent.md``."""
    profile = _AgentProfileFM(
        id="demo_agent_profile_agent_x",
        type="demo_agent_profile",
        agent_id="agent_x",
        name="zhang_legal",
    )
    written = await writer.write(
        "agent_x", frontmatter=profile, body="Agent playbook."
    )
    profile_file = root.agents_dir() / "agent_x" / "agent.md"
    assert written == profile_file
    assert profile_file.is_file()
|
||||
|
||||
|
||||
async def test_write_writes_frontmatter_and_body(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """Frontmatter fields and body written by ``write`` read back intact."""
    profile = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
        display_name="Jason",
        bio="weekend hiker.",
    )
    await writer.write("u_jason", frontmatter=profile, body="The body.")

    profile_file = root.users_dir() / "u_jason" / "user.md"
    document = await MarkdownReader.read(profile_file)
    assert document.frontmatter["display_name"] == "Jason"
    assert document.frontmatter["bio"] == "weekend hiker."
    # Trailing newlines are ignored for the body comparison.
    assert document.body.rstrip("\n") == "The body."
|
||||
|
||||
|
||||
async def test_write_is_upsert_full_replace(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """Writing twice for the same user replaces the file; nothing is appended."""
    revisions = (
        ("Jason v1", "v1", "body v1"),
        ("Jason v2", "v2", "body v2"),
    )
    for display_name, bio, body in revisions:
        frontmatter = _UserProfileFM(
            id="demo_user_profile_u_jason",
            type="demo_user_profile",
            user_id="u_jason",
            display_name=display_name,
            bio=bio,
        )
        await writer.write("u_jason", frontmatter=frontmatter, body=body)

    profile_path = root.users_dir() / "u_jason" / "user.md"
    document = await MarkdownReader.read(profile_path)
    # Only the second revision may be visible in frontmatter and body.
    assert document.frontmatter["display_name"] == "Jason v2"
    assert document.frontmatter["bio"] == "v2"
    assert document.body.rstrip("\n") == "body v2"
    assert "v1" not in document.body
|
||||
|
||||
|
||||
def test_path_for_does_not_create_files(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """Resolving a profile path via ``path_for`` leaves the filesystem untouched."""
    resolved = writer.path_for("u_jason", schema=_UserProfileFM)
    expected = root.users_dir() / "u_jason" / "user.md"
    assert resolved == expected
    # Neither the file nor the users directory may have been created.
    assert not resolved.exists()
    assert not root.users_dir().exists()
|
||||
|
||||
|
||||
async def test_write_normalises_trailing_newline(
    root: MemoryRoot, writer: ProfileWriter
) -> None:
    """A body without a final newline is stored with one appended."""
    frontmatter = _UserProfileFM(
        id="demo_user_profile_u_jason",
        type="demo_user_profile",
        user_id="u_jason",
    )
    await writer.write("u_jason", frontmatter=frontmatter, body="no-newline-end")

    profile_path = root.users_dir() / "u_jason" / "user.md"
    content = profile_path.read_text(encoding="utf-8")
    assert content.endswith("no-newline-end\n")
|
||||
|
||||
|
||||
async def test_write_rejects_schema_missing_profile_filename(
    writer: ProfileWriter,
) -> None:
    """A schema lacking the ``PROFILE_FILENAME`` ClassVar raises ``TypeError``."""

    class _BadSchema(UserScopedFrontmatter):
        type: Literal["bad"] = "bad"

    bad_frontmatter = _BadSchema(id="x", type="bad", user_id="u_jason")
    expect_type_error = pytest.raises(TypeError, match="PROFILE_FILENAME")
    with expect_type_error:
        await writer.write("u_jason", frontmatter=bad_frontmatter, body="body")
|
||||
|
||||
|
||||
async def test_write_rejects_schema_missing_scope_dir(writer: ProfileWriter) -> None:
    """A schema built without a scope mixin is refused with a ``SCOPE_DIR`` error."""

    class _ScopelessSchema(BaseFrontmatter):
        PROFILE_FILENAME: ClassVar[str] = "profile.md"
        type: Literal["scopeless"] = "scopeless"

    scopeless_frontmatter = _ScopelessSchema(id="x", type="scopeless")
    expect_type_error = pytest.raises(TypeError, match="SCOPE_DIR")
    with expect_type_error:
        await writer.write("x", frontmatter=scopeless_frontmatter, body="body")
|
||||
0
tests/unit/test_infra/test_ome/__init__.py
Normal file
0
tests/unit/test_infra/test_ome/__init__.py
Normal file
159
tests/unit/test_infra/test_ome/test_config.py
Normal file
159
tests/unit/test_infra/test_ome/test_config.py
Normal file
@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.infra.ome.config import (
|
||||
CounterOverride,
|
||||
OMEConfig,
|
||||
StrategyOverride,
|
||||
TomlRoot,
|
||||
)
|
||||
|
||||
|
||||
def test_ome_config_defaults() -> None:
    """A bare ``OMEConfig()`` exposes the expected default for each checked field."""
    from everos.core.persistence.memory_root import MemoryRoot

    config = OMEConfig()

    # Both job stores default to the paths published by the default MemoryRoot.
    assert config.jobstore_path == MemoryRoot.default().ome_db
    assert config.aps_jobstore_path == MemoryRoot.default().ome_aps_db

    numeric_defaults = {
        "max_concurrent_runs": 20,
        "max_retries": 1,
        "max_records_per_strategy": 1000,
        "crash_recovery_timeout_seconds": 1800,
        "config_watch_debounce_ms": 1600,
    }
    for field_name, expected in numeric_defaults.items():
        assert getattr(config, field_name) == expected

    assert config.config_path is None
    assert config.config_watch is True
|
||||
|
||||
|
||||
def test_aps_jobstore_path_derives_sibling_of_jobstore_path(tmp_path: object) -> None:
    """With only ``jobstore_path`` given, the APS db is its ``<stem>.aps.db`` sibling.

    This keeps callers that pass a custom path (such as tests using
    ``tmp_path``) on an isolated APS file instead of the global default.
    """
    from pathlib import Path

    jobstore = Path(str(tmp_path)) / "custom_dir" / "my_ome.db"
    config = OMEConfig(jobstore_path=jobstore)
    expected_aps = jobstore.with_name("my_ome.aps.db")
    assert config.aps_jobstore_path == expected_aps
|
||||
|
||||
|
||||
def test_aps_jobstore_path_respects_explicit_value(tmp_path: object) -> None:
    """An explicit ``aps_jobstore_path`` is kept as given, not re-derived."""
    from pathlib import Path

    base_dir = Path(str(tmp_path))
    jobstore = base_dir / "ome.db"
    explicit_aps = base_dir / "elsewhere" / "scheduler.db"
    config = OMEConfig(jobstore_path=jobstore, aps_jobstore_path=explicit_aps)
    assert config.aps_jobstore_path == explicit_aps
|
||||
|
||||
|
||||
def test_ome_config_rejects_unknown_field() -> None:
    """Passing a keyword that is not a declared field fails validation."""
    unexpected = {"unknown_field": 1}
    with pytest.raises(ValidationError):
        OMEConfig(**unexpected)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_ome_config_rejects_zero_concurrency() -> None:
    """``max_concurrent_runs=0`` is rejected at construction time."""
    expect_invalid = pytest.raises(ValidationError)
    with expect_invalid:
        OMEConfig(max_concurrent_runs=0)
|
||||
|
||||
|
||||
def test_toml_root_parses_strategy_override() -> None:
    """A strategy table with a nested gate table validates into typed overrides."""
    import tomllib

    toml_text = """
[strategies.cluster_memcells]
enabled = true
max_retries = 3

[strategies.cluster_memcells.gate]
threshold = 10
event_field = "user_id"
"""
    toml_root = TomlRoot.model_validate(tomllib.loads(toml_text))
    override = toml_root.strategies["cluster_memcells"]

    assert isinstance(override, StrategyOverride)
    assert override.enabled is True
    assert override.max_retries == 3

    gate = override.gate
    assert isinstance(gate, CounterOverride)
    assert gate.threshold == 10
    assert gate.event_field == "user_id"
|
||||
|
||||
|
||||
def test_toml_root_forbids_unknown_strategy_field() -> None:
    """An unrecognised key inside a strategy table fails validation."""
    import tomllib

    toml_text = """
[strategies.x]
unknown_key = 1
"""
    document = tomllib.loads(toml_text)
    expect_invalid = pytest.raises(ValidationError)
    with expect_invalid:
        TomlRoot.model_validate(document)
|
||||
|
||||
|
||||
def test_strategy_override_accepts_cron_field() -> None:
    """A valid crontab string is stored unchanged on the override."""
    expr = "0 3 * * *"
    override = StrategyOverride(cron=expr)
    assert override.cron == expr
|
||||
|
||||
|
||||
def test_strategy_override_accepts_idle_seconds() -> None:
    """A positive ``idle_seconds`` is stored unchanged on the override."""
    idle = 30
    override = StrategyOverride(idle_seconds=idle)
    assert override.idle_seconds == idle
|
||||
|
||||
|
||||
def test_strategy_override_accepts_scan_interval_seconds() -> None:
    """A positive ``scan_interval_seconds`` is stored unchanged on the override."""
    interval = 15
    override = StrategyOverride(scan_interval_seconds=interval)
    assert override.scan_interval_seconds == interval
|
||||
|
||||
|
||||
def test_strategy_override_rejects_zero_idle_seconds() -> None:
    """``idle_seconds=0`` fails validation."""
    expect_invalid = pytest.raises(ValidationError)
    with expect_invalid:
        StrategyOverride(idle_seconds=0)
|
||||
|
||||
|
||||
def test_strategy_override_rejects_zero_scan_interval() -> None:
    """``scan_interval_seconds=0`` fails validation."""
    expect_invalid = pytest.raises(ValidationError)
    with expect_invalid:
        StrategyOverride(scan_interval_seconds=0)
|
||||
|
||||
|
||||
def test_strategy_override_defaults_are_none() -> None:
    """An empty override leaves every trigger-related field as ``None``."""
    override = StrategyOverride()
    for field_name in ("cron", "idle_seconds", "scan_interval_seconds"):
        assert getattr(override, field_name) is None
|
||||
|
||||
|
||||
def test_counter_override_rejects_empty_event_field() -> None:
    """An empty ``event_field`` fails validation with an error naming the field."""
    expect_invalid = pytest.raises(ValidationError, match="event_field")
    with expect_invalid:
        CounterOverride(event_field="")
|
||||
|
||||
|
||||
def test_strategy_override_rejects_invalid_cron_at_construction() -> None:
    """A malformed crontab is refused when the override is built.

    Validating at construction time keeps a TOML reload from bringing an
    invalid crontab into the system.
    """
    expect_invalid = pytest.raises(ValidationError, match="cron")
    with expect_invalid:
        StrategyOverride(cron="not a cron")
|
||||
|
||||
|
||||
def test_strategy_override_rejects_inconsistent_idle_pair() -> None:
    """Both idle fields in one payload must satisfy scan_interval <= idle // 2.

    This mirrors the constraint of the Idle trigger; 20 > 30 // 2 so the
    pair below must be rejected.
    """
    expect_invalid = pytest.raises(ValidationError, match="scan_interval_seconds")
    with expect_invalid:
        StrategyOverride(idle_seconds=30, scan_interval_seconds=20)
|
||||
|
||||
|
||||
def test_strategy_override_accepts_consistent_idle_pair() -> None:
    """An idle pair with scan_interval equal to idle // 2 is accepted as given."""
    override = StrategyOverride(idle_seconds=60, scan_interval_seconds=30)
    observed = (override.idle_seconds, override.scan_interval_seconds)
    assert observed == (60, 30)
|
||||
|
||||
|
||||
def test_strategy_override_accepts_single_idle_field() -> None:
    """Overriding only one idle field is allowed at construction.

    The cross-field check is deferred to post-merge time (in
    ``apply_overrides``), when both values are known.
    """
    override = StrategyOverride(scan_interval_seconds=999)
    assert override.scan_interval_seconds == 999
    assert override.idle_seconds is None
|
||||
407
tests/unit/test_infra/test_ome/test_config_reloader.py
Normal file
407
tests/unit/test_infra/test_ome/test_config_reloader.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""Tests for ConfigReloader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome._background.config_reloader import (
|
||||
ConfigReloader,
|
||||
apply_overrides,
|
||||
)
|
||||
from everos.infra.ome._dispatch.registry import StrategyRegistry
|
||||
from everos.infra.ome.config import CounterOverride, StrategyOverride, TomlRoot
|
||||
from everos.infra.ome.context import StrategyContext
|
||||
from everos.infra.ome.decorator import offline_strategy
|
||||
from everos.infra.ome.engine import OfflineEngine
|
||||
from everos.infra.ome.events import BaseEvent
|
||||
from everos.infra.ome.gates import Counter
|
||||
from everos.infra.ome.triggers import Cron, Idle, Immediate
|
||||
|
||||
|
||||
class _E(BaseEvent):
    """Field-less event used as the ``Immediate`` trigger source in these tests."""
|
||||
|
||||
|
||||
class _EventUid(BaseEvent):
    """Event carrying a ``user_id``, the field the test ``Idle`` trigger keys on."""

    user_id: str
|
||||
|
||||
|
||||
def _make(name: str, **kw: Any) -> Any:
    """Build a no-op strategy named ``name`` with an ``Immediate`` trigger on ``_E``.

    Extra keyword arguments are forwarded to ``offline_strategy``.
    """
    trigger = Immediate(on=[_E])
    decorate = offline_strategy(name=name, trigger=trigger, emits=[], **kw)

    async def f(event: Any, ctx: StrategyContext) -> None:
        return None

    return decorate(f)
|
||||
|
||||
|
||||
def _make_cron(name: str, expr: str = "0 3 * * *", **kw: Any) -> Any:
    """Build a no-op strategy named ``name`` driven by a ``Cron`` trigger.

    ``expr`` is the crontab expression; extra keyword arguments are
    forwarded to ``offline_strategy``.
    """
    trigger = Cron(expr=expr)
    decorate = offline_strategy(name=name, trigger=trigger, emits=[], **kw)

    async def f(event: Any, ctx: StrategyContext) -> None:
        return None

    return decorate(f)
|
||||
|
||||
|
||||
def _make_idle(name: str, **kw: Any) -> Any:
    """Build a no-op strategy named ``name`` driven by an ``Idle`` trigger.

    The trigger listens to ``_EventUid``, keys on ``user_id`` and uses
    ``idle_seconds=30`` with ``scan_interval_seconds=10``. Extra keyword
    arguments are forwarded to ``offline_strategy``.
    """
    trigger = Idle(
        on=[_EventUid],
        event_field="user_id",
        idle_seconds=30,
        scan_interval_seconds=10,
    )
    decorate = offline_strategy(name=name, trigger=trigger, emits=[], **kw)

    async def f(event: Any, ctx: StrategyContext) -> None:
        return None

    return decorate(f)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_engine() -> MagicMock:
    """Return a ``MagicMock`` specced on ``OfflineEngine``.

    The spec makes a misspelled mocked method name fail instead of being
    silently accepted.
    """
    engine = MagicMock(spec=OfflineEngine)
    return engine
|
||||
|
||||
|
||||
def test_apply_overrides_replaces_enabled(fake_engine: MagicMock) -> None:
    """An ``enabled`` override flips the registered strategy's flag."""
    registry = StrategyRegistry()
    registry.register(_make("s", enabled=True))
    override = StrategyOverride(enabled=False)
    overrides = TomlRoot(strategies={"s": override})
    apply_overrides(registry, overrides, fake_engine)
    assert registry.get("s").enabled is False
|
||||
|
||||
|
||||
def test_apply_overrides_max_retries(fake_engine: MagicMock) -> None:
    """A ``max_retries`` override replaces the value set at registration."""
    registry = StrategyRegistry()
    registry.register(_make("s", max_retries=1))
    override = StrategyOverride(max_retries=5)
    overrides = TomlRoot(strategies={"s": override})
    apply_overrides(registry, overrides, fake_engine)
    assert registry.get("s").max_retries == 5
|
||||
|
||||
|
||||
def test_apply_overrides_counter_partial(fake_engine: MagicMock) -> None:
    """Overriding only ``threshold`` leaves the gate's ``event_field`` as it was."""
    registry = StrategyRegistry()
    original_gate = Counter(threshold=3, event_field="user_id")
    registry.register(_make("s", gate=original_gate))
    gate_override = CounterOverride(threshold=10)
    overrides = TomlRoot(strategies={"s": StrategyOverride(gate=gate_override)})

    apply_overrides(registry, overrides, fake_engine)

    merged_gate = registry.get("s").gate
    assert merged_gate.threshold == 10
    assert merged_gate.event_field == "user_id"  # untouched
|
||||
|
||||
|
||||
def test_apply_overrides_unknown_strategy_ignored(fake_engine: MagicMock) -> None:
    """Overrides naming an unregistered strategy are skipped without side effects.

    Besides not raising, the only registered strategy must keep its
    ``enabled`` value and no scheduler reschedule may be requested.
    """
    reg = StrategyRegistry()
    reg.register(_make("s"))
    # Snapshot rather than hard-code: the decorator's default is not fixed here.
    enabled_before = reg.get("s").enabled
    root = TomlRoot(strategies={"unknown": StrategyOverride(enabled=False)})

    apply_overrides(reg, root, fake_engine)  # must not raise

    assert reg.get("s").enabled == enabled_before
    fake_engine.reschedule_cron_job.assert_not_called()
    fake_engine.reschedule_idle_job.assert_not_called()
|
||||
|
||||
|
||||
def test_apply_overrides_updates_cron_expr(fake_engine: MagicMock) -> None:
    """A cron override replaces the trigger expression and reschedules once."""
    new_expr = "*/5 * * * *"
    registry = StrategyRegistry()
    registry.register(_make_cron("s", "0 3 * * *"))
    overrides = TomlRoot(strategies={"s": StrategyOverride(cron=new_expr)})

    apply_overrides(registry, overrides, fake_engine)

    trigger = registry.get("s").trigger
    assert isinstance(trigger, Cron)
    assert trigger.expr == new_expr
    fake_engine.reschedule_cron_job.assert_called_once_with("s", new_expr)
|
||||
|
||||
|
||||
def test_apply_overrides_skips_atomic_group_on_reschedule_failure(
    fake_engine: MagicMock,
) -> None:
    """A runtime failure in ``reschedule_cron_job`` rolls back the atomic group.

    ``StrategyOverride.cron`` is syntactically validated at parse time, but
    the reschedule call can still fail at runtime (APS internal error,
    scheduler stopped, etc.); the rollback must cover those failures too.
    """
    old_expr, new_expr = "0 3 * * *", "*/5 * * * *"
    registry = StrategyRegistry()
    registry.register(_make_cron("s", old_expr, enabled=True, max_retries=1))
    fake_engine.reschedule_cron_job.side_effect = RuntimeError("APS error")
    override = StrategyOverride(enabled=False, cron=new_expr, max_retries=99)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    # enabled applied independently
    assert meta.enabled is False
    # atomic group rolled back: cron unchanged, max_retries unchanged
    assert meta.trigger.expr == old_expr
    assert meta.max_retries == 1
    fake_engine.reschedule_cron_job.assert_called_once_with("s", new_expr)
|
||||
|
||||
|
||||
def test_apply_overrides_skips_atomic_group_on_cron_type_mismatch(
    fake_engine: MagicMock,
) -> None:
    """A cron override on an ``Immediate`` strategy is skipped; ``enabled`` applies."""
    registry = StrategyRegistry()
    registry.register(_make("s", enabled=True))  # Immediate strategy
    override = StrategyOverride(enabled=False, cron="0 3 * * *")
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    assert meta.enabled is False
    assert isinstance(meta.trigger, Immediate)
    fake_engine.reschedule_cron_job.assert_not_called()
|
||||
|
||||
|
||||
def test_apply_overrides_updates_idle_seconds_and_scan_interval(
    fake_engine: MagicMock,
) -> None:
    """Overriding both idle fields updates the trigger and reschedules the idle job."""
    registry = StrategyRegistry()
    registry.register(_make_idle("s"))
    override = StrategyOverride(idle_seconds=120, scan_interval_seconds=15)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    trigger = registry.get("s").trigger
    assert trigger.idle_seconds == 120
    assert trigger.scan_interval_seconds == 15
    reschedule = fake_engine.reschedule_idle_job
    reschedule.assert_called_once_with("s", scan_interval_seconds=15)
|
||||
|
||||
|
||||
def test_apply_overrides_updates_only_idle_seconds_does_not_reschedule_aps(
    fake_engine: MagicMock,
) -> None:
    """Changing only ``idle_seconds`` must not reschedule the APS job.

    ``idle_seconds`` is consumed by the dispatcher / engine on each scan,
    not by the APS IntervalTrigger, and a reschedule would reset the
    pending tick.
    """
    registry = StrategyRegistry()
    registry.register(_make_idle("s"))
    override = StrategyOverride(idle_seconds=120)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    trigger = registry.get("s").trigger
    assert trigger.idle_seconds == 120
    fake_engine.reschedule_idle_job.assert_not_called()
|
||||
|
||||
|
||||
def test_apply_overrides_skips_atomic_group_on_idle_type_mismatch(
    fake_engine: MagicMock,
) -> None:
    """An idle override on a ``Cron`` strategy is skipped with no reschedule."""
    registry = StrategyRegistry()
    registry.register(_make_cron("s"))  # Cron strategy
    override = StrategyOverride(idle_seconds=60)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    assert isinstance(registry.get("s").trigger, Cron)
    for reschedule in (
        fake_engine.reschedule_cron_job,
        fake_engine.reschedule_idle_job,
    ):
        reschedule.assert_not_called()
|
||||
|
||||
|
||||
def test_apply_overrides_rollback_on_aps_reschedule_failure(
    fake_engine: MagicMock,
) -> None:
    """When the APS reschedule raises, cron and max_retries keep their old values."""
    fake_engine.reschedule_cron_job.side_effect = RuntimeError("APS exploded")

    old_expr = "0 3 * * *"
    registry = StrategyRegistry()
    registry.register(_make_cron("s", old_expr, enabled=True, max_retries=1))
    override = StrategyOverride(enabled=False, cron="*/5 * * * *", max_retries=99)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    # enabled applied (Step 1, before atomic group)
    assert meta.enabled is False
    # atomic group rolled back: cron + max_retries unchanged
    assert meta.trigger.expr == old_expr
    assert meta.max_retries == 1
|
||||
|
||||
|
||||
def test_apply_overrides_enabled_survives_reschedule_failure(
    fake_engine: MagicMock,
) -> None:
    """``enabled=false`` still applies when the paired cron update fails.

    Disabling is emergency-stop semantics, so it must not depend on the
    reschedule succeeding.
    """
    old_expr = "0 3 * * *"
    registry = StrategyRegistry()
    registry.register(_make_cron("s", old_expr, enabled=True))
    fake_engine.reschedule_cron_job.side_effect = RuntimeError("APS error")
    override = StrategyOverride(enabled=False, cron="*/5 * * * *")
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    assert meta.enabled is False
    assert meta.trigger.expr == old_expr
|
||||
|
||||
|
||||
def test_apply_overrides_strategy_isolation(fake_engine: MagicMock) -> None:
    """A reschedule failure for one strategy leaves the other's update intact."""
    registry = StrategyRegistry()
    for strategy_name, initial_expr in (("a", "0 3 * * *"), ("b", "0 4 * * *")):
        registry.register(_make_cron(strategy_name, initial_expr))

    def _reschedule(name: str, expr: str) -> None:
        # Only strategy "b" simulates a scheduler failure.
        if name != "b":
            return
        raise RuntimeError("simulated APS failure for b")

    fake_engine.reschedule_cron_job.side_effect = _reschedule
    overrides = TomlRoot(
        strategies={
            "a": StrategyOverride(cron="*/5 * * * *"),
            "b": StrategyOverride(cron="*/7 * * * *"),
        }
    )

    apply_overrides(registry, overrides, fake_engine)

    # "a" took its new expression; "b" kept the one it was registered with.
    assert registry.get("a").trigger.expr == "*/5 * * * *"
    assert registry.get("b").trigger.expr == "0 4 * * *"
|
||||
|
||||
|
||||
def test_apply_overrides_atomic_group_no_partial_application(
    fake_engine: MagicMock,
) -> None:
    """An atomic-group failure also reverts ``max_retries`` and the gate."""
    old_expr = "0 3 * * *"
    original_gate = Counter(threshold=3, event_field="user_id")
    registry = StrategyRegistry()
    registry.register(
        _make_cron("s", old_expr, max_retries=1, gate=original_gate)
    )
    fake_engine.reschedule_cron_job.side_effect = RuntimeError("APS error")
    override = StrategyOverride(
        cron="*/5 * * * *",
        max_retries=99,
        gate=CounterOverride(threshold=100),
    )
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    assert meta.trigger.expr == old_expr
    assert meta.max_retries == 1
    assert meta.gate.threshold == 3
|
||||
|
||||
|
||||
def test_apply_overrides_succeeds_on_combined_enabled_and_trigger(
    fake_engine: MagicMock,
) -> None:
    """``enabled`` and ``cron`` given together both apply when reschedule succeeds."""
    new_expr = "*/5 * * * *"
    registry = StrategyRegistry()
    registry.register(_make_cron("s", "0 3 * * *", enabled=True))
    override = StrategyOverride(enabled=False, cron=new_expr)
    overrides = TomlRoot(strategies={"s": override})

    apply_overrides(registry, overrides, fake_engine)

    meta = registry.get("s")
    assert meta.enabled is False
    assert meta.trigger.expr == new_expr
    fake_engine.reschedule_cron_job.assert_called_once_with("s", new_expr)
|
||||
|
||||
|
||||
def test_atomic_group_skipped_when_introducing_gate_without_threshold(
    fake_engine: MagicMock,
) -> None:
    """N5: a gate introduced via cooldown alone (no threshold) is rejected.

    It must not be silently defaulted to threshold=1 ('fire every event').
    """
    registry = StrategyRegistry()
    registry.register(_make("s"))  # no gate
    assert registry.get("s").gate is None

    gate_override = CounterOverride(cooldown_seconds=60)
    overrides = TomlRoot(strategies={"s": StrategyOverride(gate=gate_override)})

    apply_overrides(registry, overrides, fake_engine)

    # Atomic group rolled back: still no gate.
    assert registry.get("s").gate is None
|
||||
|
||||
|
||||
def test_atomic_group_accepts_introducing_gate_with_explicit_threshold(
    fake_engine: MagicMock,
) -> None:
    """N5 happy path: an explicit threshold may add a gate to a gateless strategy.

    Supplying the threshold is the user opt-in we require.
    """
    registry = StrategyRegistry()
    registry.register(_make("s"))
    assert registry.get("s").gate is None

    gate_override = CounterOverride(threshold=5, cooldown_seconds=60)
    overrides = TomlRoot(strategies={"s": StrategyOverride(gate=gate_override)})

    apply_overrides(registry, overrides, fake_engine)

    new_gate = registry.get("s").gate
    assert new_gate is not None
    assert new_gate.threshold == 5
    assert new_gate.cooldown_seconds == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_twice_raises(tmp_path: Path) -> None:
    """N7: a second ``start()`` while running raises ``RuntimeError``.

    This surfaces the caller bug instead of silently dropping the original
    task reference and racing two watchers.
    """
    config_path = tmp_path / "ome.toml"
    # Explicit encoding, matching the other text I/O in this test suite.
    config_path.write_text("", encoding="utf-8")
    reloader = ConfigReloader(
        config_path=config_path,
        registry=StrategyRegistry(),
        engine=MagicMock(spec=OfflineEngine),
    )
    reloader.start()
    try:
        with pytest.raises(RuntimeError, match=r"already started"):
            reloader.start()
    finally:
        # Always stop the reloader started above, even if the check fails.
        await reloader.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_after_stop_is_allowed(tmp_path: Path) -> None:
    """N7: ``start()`` works again once the reloader has been stopped.

    The already-started check only fires while a task is live, so callers
    can restart the reloader after ``stop()``.
    """
    config_path = tmp_path / "ome.toml"
    # Explicit encoding, matching the other text I/O in this test suite.
    config_path.write_text("", encoding="utf-8")
    reloader = ConfigReloader(
        config_path=config_path,
        registry=StrategyRegistry(),
        engine=MagicMock(spec=OfflineEngine),
    )
    reloader.start()
    await reloader.stop()
    # Must not raise.
    reloader.start()
    await reloader.stop()
|
||||
24
tests/unit/test_infra/test_ome/test_context.py
Normal file
24
tests/unit/test_infra/test_ome/test_context.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import structlog
|
||||
|
||||
from everos.infra.ome.context import StrategyContext
|
||||
|
||||
|
||||
def test_strategy_context_is_protocol() -> None:
    """``StrategyContext`` is declared as a ``typing.Protocol`` subclass."""
    is_protocol = issubclass(StrategyContext, Protocol)  # type: ignore[arg-type]
    assert is_protocol
|
||||
|
||||
|
||||
def test_strategy_context_runtime_attributes() -> None:
    """A plain class with ``run_id``, ``logger`` and ``emit`` serves as a context."""

    class _Impl:
        run_id = "r1"
        logger = structlog.get_logger("test")

        async def emit(self, event: object) -> None:
            return None

    context: StrategyContext = _Impl()
    assert context.run_id == "r1"
    assert callable(context.emit)
|
||||
111
tests/unit/test_infra/test_ome/test_counter_store.py
Normal file
111
tests/unit/test_infra/test_ome/test_counter_store.py
Normal file
@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome._stores.counter import CounterStore
|
||||
from everos.infra.ome._stores.storage import OMEStorage
|
||||
|
||||
|
||||
@pytest.fixture
async def store(tmp_path: Path) -> CounterStore:
    """Provide a ``CounterStore`` over an initialised ``OMEStorage`` in ``tmp_path``."""
    db_file = tmp_path / "ome.db"
    backing = OMEStorage(db_path=db_file)
    await backing.init()
    return CounterStore(storage=backing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_increments_until_threshold(store: CounterStore) -> None:
    """Hits 1-4 report their running count without passing; hit 5 passes."""
    gate = {"threshold": 5, "cooldown_seconds": 0}
    for expected_count in range(1, 5):
        passed, count = await store.incr_and_check("s", "u1", **gate)
        assert passed is False
        assert count == expected_count

    passed, count = await store.incr_and_check("s", "u1", **gate)
    assert passed is True
    assert count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resets_after_pass(store: CounterStore) -> None:
    """The hit after a pass reports a count of 1 and does not pass."""
    gate = {"threshold": 5, "cooldown_seconds": 0}
    for _ in range(5):
        await store.incr_and_check("s", "u1", **gate)

    passed, count = await store.incr_and_check("s", "u1", **gate)
    assert passed is False
    assert count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cooldown_blocks_pass(store: CounterStore) -> None:
    """With a 10 s cooldown, re-reaching the threshold at once does not pass."""
    gate = {"threshold": 5, "cooldown_seconds": 10}
    # First pass
    for _ in range(5):
        await store.incr_and_check("s", "u1", **gate)
    # Threshold met again immediately, but cooldown blocks
    for _ in range(5):
        passed, _count = await store.incr_and_check("s", "u1", **gate)
        assert passed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_buckets_are_isolated(store: CounterStore) -> None:
    """Hits recorded for bucket ``u1`` do not advance the counter of ``u2``."""
    gate = {"threshold": 5, "cooldown_seconds": 0}
    for _ in range(5):
        await store.incr_and_check("s", "u1", **gate)

    passed, count = await store.incr_and_check("s", "u2", **gate)
    assert count == 1
    assert passed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_progress_query(store: CounterStore) -> None:
    """``get_progress`` reports the number of hits recorded so far."""
    for _ in range(2):
        await store.incr_and_check("s", "u1", threshold=5, cooldown_seconds=0)
    progress = await store.get_progress("s", "u1")
    assert progress == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_returned_counter_reflects_actual_value_when_threshold_lowered(
    store: CounterStore,
) -> None:
    """The returned count is the real tally, not the (lowered) threshold.

    When the threshold drops via hot-reload after hits have accumulated,
    diagnostics rely on seeing the actual count at the trigger moment.
    """
    # Accumulate 7 hits under a high threshold; none pass.
    for _ in range(7):
        passed, _count = await store.incr_and_check(
            "s", "u1", threshold=10, cooldown_seconds=0
        )
        assert passed is False

    # Threshold is "lowered" to 5 (config hot-reload semantics).
    # Counter goes 7 -> 8, which is past the new threshold.
    lowered = {"threshold": 5, "cooldown_seconds": 0}
    passed, count = await store.incr_and_check("s", "u1", **lowered)
    assert passed is True
    assert count == 8  # actual count, not threshold (=5)
|
||||
149
tests/unit/test_infra/test_ome/test_crash_recovery.py
Normal file
149
tests/unit/test_infra/test_ome/test_crash_recovery.py
Normal file
@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.utils.datetime import get_now_with_timezone, to_iso_format
|
||||
from everos.infra.ome._background.crash_recovery import scan_and_resume
|
||||
from everos.infra.ome._stores.run_record import RunRecordStore
|
||||
from everos.infra.ome._stores.storage import OMEStorage
|
||||
from everos.infra.ome.records import RunStatus
|
||||
|
||||
|
||||
@pytest.fixture
async def rec_store(tmp_path: Path) -> RunRecordStore:
    """Provide a ``RunRecordStore`` over an initialised ``OMEStorage`` in ``tmp_path``."""
    db_file = tmp_path / "ome.db"
    backing = OMEStorage(db_path=db_file)
    await backing.init()
    return RunRecordStore(storage=backing, max_records_per_strategy=1000)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_marks_old_running_as_crashed(rec_store: RunRecordStore) -> None:
    """A RUNNING row started two hours ago is marked CRASHED and re-queued once.

    The re-queued job must carry a different run id while keeping the
    strategy name, event topic, event payload and retry budget of the row.
    """
    stale_run_id = "r_old"
    await rec_store.mark_running(
        run_id=stale_run_id,
        strategy_name="s",
        attempt=0,
        event_topic="x:E",
        event_payload="{}",
        max_retries_snapshot=1,
    )
    # Backdate the row so it is older than the 1800 s timeout used below.
    async with rec_store._storage.connect() as conn:
        two_hours_ago = get_now_with_timezone() - timedelta(hours=2)
        await conn.execute(
            "UPDATE run_record SET started_at = ? WHERE run_id = ?",
            (to_iso_format(two_hours_ago), stale_run_id),
        )
        await conn.commit()

    requeued: list = []

    async def add_job_hook(name, run_id, event_topic, event_payload, max_retries):
        requeued.append((name, run_id, event_topic, event_payload, max_retries))

    await scan_and_resume(
        run_record_store=rec_store,
        timeout_seconds=1800,
        add_job=add_job_hook,
    )

    record = await rec_store.get(stale_run_id)
    assert record.status == RunStatus.CRASHED
    assert len(requeued) == 1
    name, new_run_id, topic, payload, retries = requeued[0]
    assert name == "s"
    assert new_run_id != stale_run_id
    assert topic == "x:E"
    assert payload == "{}"
    assert retries == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recent_running_skipped(rec_store: RunRecordStore) -> None:
    """A freshly started RUNNING row stays RUNNING and is not re-queued."""
    fresh_run_id = "r_fresh"
    await rec_store.mark_running(
        run_id=fresh_run_id,
        strategy_name="s",
        attempt=0,
        event_topic="x:E",
        event_payload="{}",
        max_retries_snapshot=1,
    )
    requeued: list = []

    async def add_job_hook(*args, **kw):
        # Record positional arguments of any unexpected re-queue.
        requeued.append(args)

    await scan_and_resume(
        run_record_store=rec_store,
        timeout_seconds=1800,
        add_job=add_job_hook,
    )

    record = await rec_store.get(fresh_run_id)
    assert record.status == RunStatus.RUNNING
    assert requeued == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_timeout", [0, -1])
@pytest.mark.asyncio
async def test_scan_and_resume_non_positive_timeout_raises(
    rec_store: RunRecordStore, bad_timeout: int
) -> None:
    """N6: zero and negative timeouts raise ``ValueError`` instead of a no-op."""

    async def _noop_add_job(*_args: object, **_kwargs: object) -> None:
        return None

    expect_value_error = pytest.raises(
        ValueError, match=r"timeout_seconds must be > 0"
    )
    with expect_value_error:
        await scan_and_resume(
            run_record_store=rec_store,
            timeout_seconds=bad_timeout,
            add_job=_noop_add_job,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_job_failure_does_not_abort_loop(
    rec_store: RunRecordStore,
) -> None:
    """An ``add_job`` failure on one stale row does not block its sibling.

    ``mark_crashed`` runs before ``add_job``, so both rows end up CRASHED
    even when ``add_job`` fails for one of them. This pins the at-most-once
    contract documented in the module docstring.
    """
    stale_run_ids = ("r_old_1", "r_old_2")
    for run_id in stale_run_ids:
        await rec_store.mark_running(
            run_id=run_id,
            strategy_name="s",
            attempt=0,
            event_topic="x:E",
            event_payload="{}",
            max_retries_snapshot=1,
        )
    # Backdate both rows so they are older than the 1800 s timeout below.
    async with rec_store._storage.connect() as conn:
        two_hours_ago = get_now_with_timezone() - timedelta(hours=2)
        await conn.execute(
            "UPDATE run_record SET started_at = ? WHERE run_id IN (?, ?)",
            (to_iso_format(two_hours_ago), "r_old_1", "r_old_2"),
        )
        await conn.commit()

    attempts: list[tuple] = []

    async def flaky_add_job(name, run_id, event_topic, event_payload, max_retries):
        attempts.append((name, run_id, event_topic, event_payload, max_retries))
        # Fail only on the first invocation.
        if len(attempts) == 1:
            raise RuntimeError("APS jobstore unavailable")

    await scan_and_resume(
        run_record_store=rec_store,
        timeout_seconds=1800,
        add_job=flaky_add_job,
    )

    first = await rec_store.get("r_old_1")
    second = await rec_store.get("r_old_2")
    assert first.status == RunStatus.CRASHED
    assert second.status == RunStatus.CRASHED
    assert len(attempts) == 2
|
||||
81
tests/unit/test_infra/test_ome/test_decorator.py
Normal file
81
tests/unit/test_infra/test_ome/test_decorator.py
Normal file
@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome.context import StrategyContext
|
||||
from everos.infra.ome.decorator import StrategyMeta, offline_strategy
|
||||
from everos.infra.ome.events import BaseEvent
|
||||
from everos.infra.ome.gates import Counter
|
||||
from everos.infra.ome.triggers import Immediate
|
||||
|
||||
|
||||
class _E(BaseEvent):
    """Test event with a single ``user_id`` field."""

    user_id: str
|
||||
|
||||
|
||||
def test_decorator_attaches_metadata() -> None:
    """With only name/trigger/emits given, the optional metadata stays at defaults."""

    @offline_strategy(name="x", trigger=Immediate(on=[_E]), emits=[_E])
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    meta: StrategyMeta = s._ome_strategy_meta  # type: ignore[attr-defined]
    assert meta.name == "x"
    assert meta.emits == frozenset({_E})
    # Optional knobs that were not passed are expected to be None.
    for unset_value in (meta.gate, meta.applies_to, meta.max_retries):
        assert unset_value is None
    assert meta.enabled is True
    assert meta.func is s
|
||||
|
||||
|
||||
def test_decorator_with_full_params() -> None:
    """Every optional decorator argument is carried over into the metadata."""
    options = dict(
        name="cluster",
        trigger=Immediate(on=[_E]),
        emits=[_E],
        applies_to="user_id",
        gate=Counter(threshold=5),
        max_retries=3,
        enabled=False,
    )

    @offline_strategy(**options)
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    meta = s._ome_strategy_meta  # type: ignore[attr-defined]
    assert meta.applies_to == "user_id"
    assert meta.gate.threshold == 5
    assert meta.max_retries == 3
    assert meta.enabled is False
|
||||
|
||||
|
||||
def test_decorator_callable_applies_to() -> None:
    """A callable ``applies_to`` is stored on the metadata as the same object."""

    def is_paid(e: _E) -> bool:
        return e.user_id.startswith("paid_")

    decorate = offline_strategy(
        name="paid_only",
        trigger=Immediate(on=[_E]),
        emits=[_E],
        applies_to=is_paid,
    )

    @decorate
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    meta = s._ome_strategy_meta  # type: ignore[attr-defined]
    assert meta.applies_to is is_paid
|
||||
|
||||
|
||||
def test_decorator_rejects_blank_name() -> None:
    """Decorating with an empty ``name`` raises ValueError."""
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:

        @offline_strategy(name="", trigger=Immediate(on=[_E]), emits=[_E])
        async def _s(event: _E, ctx: StrategyContext) -> None:
            return None
|
||||
|
||||
|
||||
def test_decorator_rejects_non_async_function() -> None:
    """Decorating a plain (non-coroutine) function raises TypeError."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:

        @offline_strategy(name="x", trigger=Immediate(on=[_E]), emits=[_E])
        def _s(event: _E, ctx: StrategyContext) -> None:  # not async
            return None
|
||||
215
tests/unit/test_infra/test_ome/test_dispatcher.py
Normal file
215
tests/unit/test_infra/test_ome/test_dispatcher.py
Normal file
@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome._dispatch.dispatcher import EventDispatcher
|
||||
from everos.infra.ome._dispatch.registry import StrategyRegistry
|
||||
from everos.infra.ome._stores.counter import CounterStore
|
||||
from everos.infra.ome._stores.storage import OMEStorage
|
||||
from everos.infra.ome.context import StrategyContext
|
||||
from everos.infra.ome.decorator import offline_strategy
|
||||
from everos.infra.ome.events import BaseEvent, CronTick
|
||||
from everos.infra.ome.gates import Counter
|
||||
from everos.infra.ome.triggers import Cron, Immediate
|
||||
|
||||
|
||||
class _E(BaseEvent):
    """Test event with a ``user_id`` field, dispatched throughout this module."""

    user_id: str
|
||||
|
||||
|
||||
def _make_strategy(name: str, **kw):
    """Return a no-op strategy named *name* that triggers immediately on ``_E``.

    Extra keyword arguments are forwarded unchanged to ``offline_strategy``.
    """
    decorate = offline_strategy(
        name=name, trigger=Immediate(on=[_E]), emits=[], **kw
    )

    async def _f(event: Any, ctx: StrategyContext) -> None:
        return None

    return decorate(_f)
|
||||
|
||||
|
||||
@pytest.fixture
async def dispatcher(tmp_path: Path) -> EventDispatcher:
    """Build an EventDispatcher with an empty registry and a fresh counter store.

    The counter store sits on an ``OMEStorage`` whose database file lives
    under ``tmp_path``.
    """
    db_file = tmp_path / "ome.db"
    storage = OMEStorage(db_path=db_file)
    await storage.init()
    return EventDispatcher(
        registry=StrategyRegistry(),
        counter_store=CounterStore(storage=storage),
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_passes_when_no_gate(
    dispatcher: EventDispatcher,
) -> None:
    """A strategy without gate or filter is routed for a matching event."""
    strategy = _make_strategy("s_pass")
    dispatcher._registry.register(strategy)

    routes = await dispatcher.dispatch(_E(user_id="u1"))

    routed_names = [meta.name for meta, _ in routes]
    assert routed_names == ["s_pass"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_skips_disabled(dispatcher: EventDispatcher) -> None:
    """A strategy registered with ``enabled=False`` produces no route."""
    disabled = _make_strategy("s_off", enabled=False)
    dispatcher._registry.register(disabled)

    routes = await dispatcher.dispatch(_E(user_id="u1"))

    assert routes == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_applies_to_string(
    dispatcher: EventDispatcher,
) -> None:
    """String ``applies_to``: empty field value gives no route, a set value gives one."""
    strategy = _make_strategy("s", applies_to="user_id")
    dispatcher._registry.register(strategy)

    blank_event = _E(user_id="")
    filled_event = _E(user_id="u1")
    routes_empty = await dispatcher.dispatch(blank_event)
    routes_set = await dispatcher.dispatch(filled_event)

    assert routes_empty == []
    assert len(routes_set) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_applies_to_callable(
    dispatcher: EventDispatcher,
) -> None:
    """Callable ``applies_to``: only events the predicate accepts are routed."""

    def is_paid(e: _E) -> bool:
        return e.user_id.startswith("paid_")

    dispatcher._registry.register(_make_strategy("s", applies_to=is_paid))

    free_routes = await dispatcher.dispatch(_E(user_id="free_a"))
    assert free_routes == []
    paid_routes = await dispatcher.dispatch(_E(user_id="paid_a"))
    assert len(paid_routes) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_counter_gate(dispatcher: EventDispatcher) -> None:
    """Counter(threshold=3): the first two events are held back, the third routes."""
    gate = Counter(threshold=3, event_field="user_id")
    dispatcher._registry.register(_make_strategy("s", gate=gate))

    below_threshold = 2
    for _ in range(below_threshold):
        routes = await dispatcher.dispatch(_E(user_id="u1"))
        assert routes == []

    routes = await dispatcher.dispatch(_E(user_id="u1"))
    assert len(routes) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inspect_returns_route_info(
    dispatcher: EventDispatcher,
) -> None:
    """``inspect`` reports counter progress (1, 3) and ``will_run=False`` on first event."""
    gate = Counter(threshold=3, event_field="user_id")
    dispatcher._registry.register(_make_strategy("s", gate=gate))

    infos = await dispatcher.inspect(_E(user_id="u1"))

    assert len(infos) == 1
    info = infos[0]
    assert info.counter_progress == (1, 3)
    assert info.will_run is False
|
||||
|
||||
|
||||
def _make_cron_strategy(name: str):
    """Return a no-op strategy named *name* with a ``0 * * * *`` cron trigger."""
    decorate = offline_strategy(
        name=name, trigger=Cron(expr="0 * * * *"), emits=[]
    )

    async def _f(event: Any, ctx: StrategyContext) -> None:
        return None

    return decorate(_f)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_routes_engine_tick_to_named_strategy_only(
    dispatcher: EventDispatcher,
) -> None:
    """A CronTick naming ``cron_a`` is routed to ``cron_a`` and not to ``cron_b``."""
    for cron_name in ("cron_a", "cron_b"):
        dispatcher._registry.register(_make_cron_strategy(cron_name))

    tick = CronTick(strategy_name="cron_a")
    routes = await dispatcher.dispatch(tick)

    routed_names = [meta.name for meta, _ in routes]
    assert routed_names == ["cron_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inspect_engine_tick_skips_non_target_strategy(
    dispatcher: EventDispatcher,
) -> None:
    """``inspect`` of a CronTick for ``cron_b`` reports only ``cron_b``."""
    for cron_name in ("cron_a", "cron_b"):
        dispatcher._registry.register(_make_cron_strategy(cron_name))

    tick = CronTick(strategy_name="cron_b")
    infos = await dispatcher.inspect(tick)

    inspected_names = [info.strategy_name for info in infos]
    assert inspected_names == ["cron_b"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_force_enabled_bypasses_enabled_gate(
    dispatcher: EventDispatcher,
) -> None:
    """``force_enabled=True`` routes a strategy that a normal dispatch skips as disabled."""
    dispatcher._registry.register(_make_strategy("s_off", enabled=False))

    normal_routes = await dispatcher.dispatch(_E(user_id="u1"))
    assert normal_routes == []

    routes = await dispatcher.dispatch(_E(user_id="u1"), force_enabled=True)
    routed_names = [meta.name for meta, _ in routes]
    assert routed_names == ["s_off"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_force_enabled_still_applies_applies_to_and_counter(
    dispatcher: EventDispatcher,
) -> None:
    """``force_enabled`` overrides only ``enabled``; ``applies_to`` and the gate still filter."""
    strategy = _make_strategy(
        "s",
        enabled=False,
        applies_to="user_id",
        gate=Counter(threshold=2, event_field="user_id"),
    )
    dispatcher._registry.register(strategy)

    # Empty user_id: rejected by applies_to.
    blank_routes = await dispatcher.dispatch(_E(user_id=""), force_enabled=True)
    assert blank_routes == []

    # First matching event: still below the counter threshold of 2.
    first_routes = await dispatcher.dispatch(_E(user_id="u1"), force_enabled=True)
    assert first_routes == []

    # Second matching event: threshold reached, the strategy is routed.
    routes = await dispatcher.dispatch(_E(user_id="u1"), force_enabled=True)
    assert len(routes) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_strategy_filter_scopes_to_single_strategy(
    dispatcher: EventDispatcher,
) -> None:
    """``strategy_filter="s_a"`` restricts routing to ``s_a`` among two candidates."""
    for strategy_name in ("s_a", "s_b"):
        dispatcher._registry.register(_make_strategy(strategy_name))

    routes = await dispatcher.dispatch(_E(user_id="u1"), strategy_filter="s_a")

    routed_names = [meta.name for meta, _ in routes]
    assert routed_names == ["s_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_strategy_filter_unknown_raises(
    dispatcher: EventDispatcher,
) -> None:
    """A ``strategy_filter`` naming an unregistered strategy raises KeyError."""
    dispatcher._registry.register(_make_strategy("s_a"))

    event = _E(user_id="u1")
    expect_key_error = pytest.raises(KeyError)
    with expect_key_error:
        await dispatcher.dispatch(event, strategy_filter="missing")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_isolates_faulty_applies_to_callable(
    dispatcher: EventDispatcher,
) -> None:
    """An ``applies_to`` callable that raises must not break sibling routing.

    Two strategies subscribe to the same event class; the one whose
    predicate raises is expected to be dropped while the other still routes.
    """

    def _boom(_e: _E) -> bool:
        raise RuntimeError("applies_to is buggy")

    buggy = _make_strategy("s_buggy", applies_to=_boom)
    healthy = _make_strategy("s_healthy")
    dispatcher._registry.register(buggy)
    dispatcher._registry.register(healthy)

    routes = await dispatcher.dispatch(_E(user_id="u1"))

    # s_buggy is treated as not-applies; s_healthy still routes.
    routed_names = [meta.name for meta, _ in routes]
    assert routed_names == ["s_healthy"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inspect_isolates_faulty_applies_to_callable(
    dispatcher: EventDispatcher,
) -> None:
    """``inspect`` reports a raising ``applies_to`` as not passing, siblings as passing."""

    def _boom(_e: _E) -> bool:
        raise RuntimeError("applies_to is buggy")

    buggy = _make_strategy("s_buggy", applies_to=_boom)
    healthy = _make_strategy("s_healthy")
    dispatcher._registry.register(buggy)
    dispatcher._registry.register(healthy)

    infos = await dispatcher.inspect(_E(user_id="u1"))

    by_name = {info.strategy_name: info for info in infos}
    assert by_name["s_buggy"].applies_to_pass is False
    assert by_name["s_healthy"].applies_to_pass is True
|
||||
186
tests/unit/test_infra/test_ome/test_e2e.py
Normal file
186
tests/unit/test_infra/test_ome/test_e2e.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""End-to-end pipeline test exercising the chain emit semantics.
|
||||
|
||||
MemCellSaved -> atomic (leaf strategy)
|
||||
EpisodeSaved -> cluster -> ClusteringCompleted -> profile (Counter threshold=3)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome import (
|
||||
BaseEvent,
|
||||
Counter,
|
||||
Cron,
|
||||
CronTick,
|
||||
Immediate,
|
||||
StrategyContext,
|
||||
offline_strategy,
|
||||
)
|
||||
from everos.infra.ome.engine import _cron_entry
|
||||
from everos.infra.ome.testing import StrategyTestHarness
|
||||
|
||||
|
||||
class MemCellSaved(BaseEvent):
    """Test event carrying a ``user_id`` and a ``cell_id``."""

    user_id: str
    cell_id: str
|
||||
|
||||
|
||||
class EpisodeSaved(BaseEvent):
    """Test event carrying a ``user_id`` and an ``episode_text``."""

    user_id: str
    episode_text: str
|
||||
|
||||
|
||||
class ClusteringCompleted(BaseEvent):
    """Test event carrying only a ``user_id``; emitted by the cluster strategy."""

    user_id: str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chain_emit_without_counter_gate() -> None:
    """Chain test without a Counter gate on the profile strategy.

    Each EpisodeSaved drives cluster, which emits ClusteringCompleted,
    which in turn drives profile, so both run once per emitted episode.
    """
    log: list[tuple[str, str]] = []

    @offline_strategy(
        name="cluster_e2e",
        trigger=Immediate(on=[EpisodeSaved]),
        emits=[ClusteringCompleted],
    )
    async def cluster(event: EpisodeSaved, ctx: StrategyContext) -> None:
        log.append(("cluster", event.user_id))
        await ctx.emit(ClusteringCompleted(user_id=event.user_id))

    @offline_strategy(
        name="profile_e2e",
        trigger=Immediate(on=[ClusteringCompleted]),
        emits=[],
    )
    async def profile(event: ClusteringCompleted, ctx: StrategyContext) -> None:
        log.append(("profile", event.user_id))

    # (episode text, pause after emitting it)
    episodes = (("t1", 0.15), ("t2", 0.15), ("t3", 0.2))

    async with StrategyTestHarness() as harness:
        for strategy in (cluster, profile):
            harness.register(strategy)
        await harness.start()
        # Emit 3 episodes -> cluster runs 3x -> emits ClusteringCompleted 3x ->
        # profile runs 3x (no counter gate).
        for text, pause in episodes:
            await harness.emit(EpisodeSaved(user_id="u1", episode_text=text))
            await asyncio.sleep(pause)
        await harness.drain(timeout=15)

        cluster_runs = await harness.list_runs("cluster_e2e")
        profile_runs = await harness.list_runs("profile_e2e")

    cluster_calls = [c for c in log if c[0] == "cluster"]
    profile_calls = [c for c in log if c[0] == "profile"]
    assert len(cluster_calls) == 3, (
        f"Expected 3 cluster, got {len(cluster_calls)}: {log}"
    )
    assert len(profile_calls) == 3, (
        f"Expected 3 profile, got {len(profile_calls)}: {log}"
    )
    assert len(cluster_runs) == 3
    assert len(profile_runs) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chain_pipeline_runs_full_path() -> None:
    """Full chain: atomic (leaf), cluster, and a Counter-gated profile strategy."""
    log: list[tuple[str, str]] = []

    @offline_strategy(name="atomic_e2e", trigger=Immediate(on=[MemCellSaved]), emits=[])
    async def atomic(event: MemCellSaved, ctx: StrategyContext) -> None:
        log.append(("atomic", event.cell_id))

    @offline_strategy(
        name="cluster_e2e",
        trigger=Immediate(on=[EpisodeSaved]),
        emits=[ClusteringCompleted],
    )
    async def cluster(event: EpisodeSaved, ctx: StrategyContext) -> None:
        log.append(("cluster", event.user_id))
        await ctx.emit(ClusteringCompleted(user_id=event.user_id))

    @offline_strategy(
        name="profile_e2e",
        trigger=Immediate(on=[ClusteringCompleted]),
        emits=[],
        gate=Counter(threshold=3, event_field="user_id"),
    )
    async def profile(event: ClusteringCompleted, ctx: StrategyContext) -> None:
        log.append(("profile", event.user_id))

    # (episode text, pause after emitting it)
    episodes = (("t1", 0.15), ("t2", 0.15), ("t3", 0.2))

    async with StrategyTestHarness() as harness:
        for strategy in (atomic, cluster, profile):
            harness.register(strategy)
        await harness.start()
        # Two memcells (each fires atomic).
        for cell_id in ("c1", "c2"):
            await harness.emit(MemCellSaved(user_id="u1", cell_id=cell_id))
            await asyncio.sleep(0.15)
        # Three episodes -> cluster runs 3x -> ClusteringCompleted 3x ->
        # profile Counter at threshold=3 fires once.
        for text, pause in episodes:
            await harness.emit(EpisodeSaved(user_id="u1", episode_text=text))
            await asyncio.sleep(pause)
        await harness.drain(timeout=15)

        # Validate using run records
        atomic_runs = await harness.list_runs("atomic_e2e")
        cluster_runs = await harness.list_runs("cluster_e2e")
        profile_runs = await harness.list_runs("profile_e2e")

    atomic_calls = [c for c in log if c[0] == "atomic"]
    cluster_calls = [c for c in log if c[0] == "cluster"]
    profile_calls = [c for c in log if c[0] == "profile"]
    assert len(atomic_calls) == 2, (
        f"Expected 2 atomic calls, got {len(atomic_calls)}: {log}"
    )
    assert len(cluster_calls) == 3, (
        f"Expected 3 cluster calls, got {len(cluster_calls)}: {log}"
    )
    assert len(profile_calls) == 1, (
        f"Expected 1 profile call, got {len(profile_calls)}: {log}"
    )
    assert len(atomic_runs) == 2
    assert len(cluster_runs) == 3
    assert len(profile_runs) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cron_strategy_executes_when_cron_entry_fires() -> None:
    """Check that invoking the cron entry point runs the cron strategy.

    Scheduler timing is taken out of the picture: the module-level
    ``_cron_entry`` function (what APS would invoke on schedule) is called
    directly, which shows that the registry/dispatcher/runner chain
    reaches a cron-triggered strategy.
    """
    seen: list[str] = []

    @offline_strategy(name="cron_e2e", trigger=Cron(expr="0 * * * *"), emits=[])
    async def cron_job(event: CronTick, ctx: StrategyContext) -> None:
        seen.append(event.strategy_name)

    async with StrategyTestHarness() as harness:
        harness.register(cron_job)
        await harness.start()
        # Directly invoke what APS would call; bypass scheduler timing.
        engine_id = harness._engine._engine_id  # noqa: SLF001
        await _cron_entry(engine_id, "cron_e2e")
        await harness.drain(timeout=5)
        runs = await harness.list_runs("cron_e2e")

    assert seen == ["cron_e2e"]
    assert len(runs) == 1
|
||||
623
tests/unit/test_infra/test_ome/test_engine.py
Normal file
623
tests/unit/test_infra/test_ome/test_engine.py
Normal file
@ -0,0 +1,623 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.ome.config import OMEConfig
|
||||
from everos.infra.ome.context import StrategyContext
|
||||
from everos.infra.ome.decorator import offline_strategy
|
||||
from everos.infra.ome.engine import OfflineEngine
|
||||
from everos.infra.ome.events import BaseEvent
|
||||
from everos.infra.ome.exceptions import (
|
||||
EngineLockHeldError,
|
||||
OMEError,
|
||||
StartupValidationError,
|
||||
)
|
||||
from everos.infra.ome.records import RunStatus
|
||||
from everos.infra.ome.triggers import Cron, Idle, Immediate
|
||||
|
||||
|
||||
class _E(BaseEvent):
    """Field-less test event used as the default trigger in this module."""
|
||||
|
||||
|
||||
class _A(BaseEvent):
    """Field-less test event; the upstream event in the chain/cycle tests."""
|
||||
|
||||
|
||||
class _B(BaseEvent):
    """Field-less test event; the downstream event in the chain/cycle tests."""
|
||||
|
||||
|
||||
@pytest.fixture
def cfg(tmp_path: Path) -> OMEConfig:
    """Engine config whose jobstore file lives under ``tmp_path``, config watching off."""
    jobstore_file = tmp_path / "ome.db"
    return OMEConfig(jobstore_path=jobstore_file, config_watch=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_register_and_start(cfg: OMEConfig) -> None:
    """Smoke test: register one strategy, then start and stop the engine."""

    @offline_strategy(name="s", trigger=Immediate(on=[_E]), emits=[])
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    await offline_engine.start()
    await offline_engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_register_after_start_raises(cfg: OMEConfig) -> None:
    """Registering a strategy on an already started engine raises OMEError."""
    offline_engine = OfflineEngine(config=cfg)
    await offline_engine.start()
    try:

        @offline_strategy(name="s", trigger=Immediate(on=[_E]), emits=[])
        async def s(event: _E, ctx: StrategyContext) -> None:
            return None

        expect_ome_error = pytest.raises(OMEError)
        with expect_ome_error:
            offline_engine.register(s)
    finally:
        await offline_engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_lock_prevents_double_open(cfg: OMEConfig) -> None:
    """Starting a second engine on the same config raises EngineLockHeldError."""
    first = OfflineEngine(config=cfg)
    await first.start()
    try:
        second = OfflineEngine(config=cfg)
        expect_lock_error = pytest.raises(EngineLockHeldError)
        with expect_lock_error:
            await second.start()
    finally:
        await first.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_validates_dag_at_start(tmp_path: Path) -> None:
    """Two strategies emitting each other's trigger event make ``start`` fail.

    ``s1`` turns ``_A`` into ``_B`` and ``s2`` turns ``_B`` into ``_A``;
    start-up is expected to raise StartupValidationError mentioning a cycle.
    """
    cfg = OMEConfig(jobstore_path=tmp_path / "ome.db", config_watch=False)

    @offline_strategy(name="s1", trigger=Immediate(on=[_A]), emits=[_B])
    async def _s1(e: Any, ctx: StrategyContext) -> None:
        return None

    @offline_strategy(name="s2", trigger=Immediate(on=[_B]), emits=[_A])
    async def _s2(e: Any, ctx: StrategyContext) -> None:
        return None

    offline_engine = OfflineEngine(config=cfg)
    for strategy in (_s1, _s2):
        offline_engine.register(strategy)

    expect_cycle_error = pytest.raises(StartupValidationError, match=r"(?i)cycle")
    with expect_cycle_error:
        await offline_engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_emit_drives_strategy(cfg: OMEConfig) -> None:
    """One emitted ``_E`` results in exactly one call of the subscribed strategy."""
    seen: list[_E] = []

    @offline_strategy(name="collector", trigger=Immediate(on=[_E]), emits=[])
    async def s(event: _E, ctx: StrategyContext) -> None:
        seen.append(event)

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    await offline_engine.start()
    try:
        await offline_engine.emit(_E())
        # Poll because APScheduler offers no completion signal; retry up to ~2.5s.
        polls_left = 50
        while polls_left and not seen:
            await asyncio.sleep(0.05)
            polls_left -= 1
    finally:
        await offline_engine.stop()
    assert len(seen) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_chain_emit_through_ctx(cfg: OMEConfig) -> None:
    """An event emitted via ``ctx.emit`` inside a strategy reaches its subscriber."""
    seen_b: list = []

    @offline_strategy(name="a_to_b", trigger=Immediate(on=[_A]), emits=[_B])
    async def s_a(event: _A, ctx: StrategyContext) -> None:
        await ctx.emit(_B())

    @offline_strategy(name="b_collector", trigger=Immediate(on=[_B]), emits=[])
    async def s_b(event: _B, ctx: StrategyContext) -> None:
        seen_b.append(event)

    offline_engine = OfflineEngine(config=cfg)
    for strategy in (s_a, s_b):
        offline_engine.register(strategy)
    await offline_engine.start()
    try:
        await offline_engine.emit(_A())
        # Wait up to 50 x 0.05 s for the downstream strategy to record _B.
        polls_left = 50
        while polls_left and not seen_b:
            await asyncio.sleep(0.05)
            polls_left -= 1
    finally:
        await offline_engine.stop()
    assert len(seen_b) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_strategy_calling_engine_emit_directly_is_rejected(
    cfg: OMEConfig,
) -> None:
    """Follow-up events from strategy code have to go through ``ctx.emit``.

    A strategy that calls ``engine.emit`` itself is expected to fail with
    EngineCallFromStrategyError (a StrategyContractError), so that Runner
    skips the retry budget and dead-letters on the first attempt —
    re-running the same buggy code can't fix a programming bug.
    """
    engine = OfflineEngine(config=cfg)

    @offline_strategy(name="bad", trigger=Immediate(on=[_A]), emits=[_B])
    async def bad_strategy(event: _A, ctx: StrategyContext) -> None:
        # Holding an engine reference is the normal pattern for external
        # triggers; using it from INSIDE a strategy is the convention
        # violation this test wants to catch.
        await engine.emit(_B())

    async def _dead_lettered() -> bool:
        # True once the newest "bad" run record reports DEAD_LETTER.
        current = await engine.list_runs("bad")
        return bool(current) and current[0].status == RunStatus.DEAD_LETTER

    engine.register(bad_strategy)
    await engine.start()
    try:
        await engine.emit(_A())
        for _ in range(50):
            if await _dead_lettered():
                break
            await asyncio.sleep(0.05)
        runs = await engine.list_runs("bad")
    finally:
        await engine.stop()

    assert runs, "expected at least one run record"
    # Permanent error → exactly one attempt, no retry.
    assert len(runs) == 1
    final = runs[0]
    assert final.status == RunStatus.DEAD_LETTER
    error_text = final.error or ""
    assert "EngineCallFromStrategyError" in error_text
    assert "emit" in error_text
|
||||
|
||||
|
||||
# Module-level engine slot. It stands in for strategies that obtain the
# engine through globals, DI, or an import rather than a closure. The
# guard is contextvars-based, so this path is caught the same way as the
# closure case.
_MODULE_ENGINE: OfflineEngine | None = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_strategy_reaching_engine_via_module_global_is_rejected(
    cfg: OMEConfig,
) -> None:
    """The guard is contextvars-based: it doesn't matter how the strategy
    got the engine reference (closure, module singleton, DI container).

    The strategy below reads the engine from the module global
    ``_MODULE_ENGINE`` and calls ``emit`` on it; the run is expected to be
    dead-lettered after a single attempt with EngineCallFromStrategyError.
    """
    global _MODULE_ENGINE
    engine = OfflineEngine(config=cfg)
    _MODULE_ENGINE = engine
    # Outer try/finally: the module global is reset even when register()
    # or start() raise, so a failure here cannot leak a stale engine into
    # later tests of this module.
    try:

        @offline_strategy(name="bad_global", trigger=Immediate(on=[_A]), emits=[_B])
        async def bad_strategy(event: _A, ctx: StrategyContext) -> None:
            # Deliberately goes through the module global, not the local.
            assert _MODULE_ENGINE is not None
            await _MODULE_ENGINE.emit(_B())

        engine.register(bad_strategy)
        await engine.start()
        try:
            await engine.emit(_A())
            # Poll up to 50 x 0.05 s for the run to reach DEAD_LETTER.
            for _ in range(50):
                runs = await engine.list_runs("bad_global")
                if runs and runs[0].status == RunStatus.DEAD_LETTER:
                    break
                await asyncio.sleep(0.05)
            runs = await engine.list_runs("bad_global")
        finally:
            await engine.stop()
    finally:
        _MODULE_ENGINE = None

    assert len(runs) == 1
    assert runs[0].status == RunStatus.DEAD_LETTER
    assert "EngineCallFromStrategyError" in (runs[0].error or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_strategy_calling_other_engine_methods_is_rejected(
    cfg: OMEConfig,
) -> None:
    """The guard is not limited to ``emit``: other public engine methods are
    rejected too — strategies must interact with the engine only via
    (event, ctx).
    """
    engine = OfflineEngine(config=cfg)

    @offline_strategy(name="bad_lookup", trigger=Immediate(on=[_A]), emits=[])
    async def bad_strategy(event: _A, ctx: StrategyContext) -> None:
        # trigger_manual is another public engine method that strategies
        # must not call directly.
        await engine.trigger_manual("bad_lookup")

    async def _dead_lettered() -> bool:
        # True once the newest "bad_lookup" run record reports DEAD_LETTER.
        current = await engine.list_runs("bad_lookup")
        return bool(current) and current[0].status == RunStatus.DEAD_LETTER

    engine.register(bad_strategy)
    await engine.start()
    try:
        await engine.emit(_A())
        for _ in range(50):
            if await _dead_lettered():
                break
            await asyncio.sleep(0.05)
        runs = await engine.list_runs("bad_lookup")
    finally:
        await engine.stop()

    assert len(runs) == 1
    only_run = runs[0]
    assert only_run.status == RunStatus.DEAD_LETTER
    error_text = only_run.error or ""
    assert "EngineCallFromStrategyError" in error_text
    assert "trigger_manual" in error_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_manual_with_default_event_uses_manual_tick(
    cfg: OMEConfig,
) -> None:
    """``trigger_manual`` without an explicit event runs a ManualTick subscriber once."""
    seen: list = []

    from everos.infra.ome.events import ManualTick

    decorate = offline_strategy(
        name="manual_only",
        trigger=Immediate(on=[ManualTick]),
        emits=[],
    )

    @decorate
    async def s(event: ManualTick, ctx: StrategyContext) -> None:
        seen.append(event)

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    await offline_engine.start()
    try:
        await offline_engine.trigger_manual("manual_only")
        # Wait up to 50 x 0.05 s for the strategy to record the event.
        polls_left = 50
        while polls_left and not seen:
            await asyncio.sleep(0.05)
            polls_left -= 1
    finally:
        await offline_engine.stop()
    assert len(seen) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_manual_force_bypasses_enabled(
    cfg: OMEConfig,
) -> None:
    """``trigger_manual(..., force=True)`` runs a strategy registered with ``enabled=False``."""
    seen: list = []
    from everos.infra.ome.events import ManualTick

    decorate = offline_strategy(
        name="off",
        trigger=Immediate(on=[ManualTick]),
        emits=[],
        enabled=False,
    )

    @decorate
    async def s(event: ManualTick, ctx: StrategyContext) -> None:
        seen.append(event)

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    await offline_engine.start()
    try:
        await offline_engine.trigger_manual("off", force=True)
        # Wait up to 50 x 0.05 s for the strategy to record the event.
        polls_left = 50
        while polls_left and not seen:
            await asyncio.sleep(0.05)
            polls_left -= 1
    finally:
        await offline_engine.stop()
    assert len(seen) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_dead_letter_callback_invoked(cfg: OMEConfig) -> None:
    """The ``on_dead_letter`` callback fires once for an always-failing strategy.

    The strategy is registered with ``max_retries=0`` and raises on every call.
    """
    calls: list = []

    @offline_strategy(
        name="bad_dl", trigger=Immediate(on=[_E]), emits=[], max_retries=0
    )
    async def s(event: _E, ctx: StrategyContext) -> None:
        raise RuntimeError("always-fail")

    def _record_run_id(rec) -> None:
        calls.append(rec.run_id)

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    offline_engine.on_dead_letter(_record_run_id)
    await offline_engine.start()
    try:
        await offline_engine.emit(_E())
        # Wait up to 50 x 0.05 s for the callback to be invoked.
        polls_left = 50
        while polls_left and not calls:
            await asyncio.sleep(0.05)
            polls_left -= 1
    finally:
        await offline_engine.stop()
    assert len(calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inspect_dispatch_returns_routes(cfg: OMEConfig) -> None:
    """``inspect_dispatch`` reports one route with ``will_run=True`` for an ungated strategy."""

    @offline_strategy(name="s_t24a", trigger=Immediate(on=[_E]), emits=[])
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    offline_engine = OfflineEngine(config=cfg)
    offline_engine.register(s)
    await offline_engine.start()
    try:
        infos = await offline_engine.inspect_dispatch(_E())
        assert len(infos) == 1
        only_info = infos[0]
        assert only_info.will_run is True
    finally:
        await offline_engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_run_status_and_list(cfg: OMEConfig) -> None:
    """``list_runs`` and ``get_run_status`` both report the single successful run."""

    @offline_strategy(name="s_t24b", trigger=Immediate(on=[_E]), emits=[])
    async def s(event: _E, ctx: StrategyContext) -> None:
        return None

    offline_engine = OfflineEngine(config=cfg)

    async def _succeeded() -> bool:
        # True once the newest "s_t24b" run record reports "success".
        current = await offline_engine.list_runs("s_t24b")
        return bool(current) and current[0].status.value == "success"

    offline_engine.register(s)
    await offline_engine.start()
    try:
        await offline_engine.emit(_E())
        # Poll because APScheduler offers no completion signal; up to ~2.5s.
        for _ in range(50):
            if await _succeeded():
                break
            await asyncio.sleep(0.05)
        runs = await offline_engine.list_runs("s_t24b")
        assert len(runs) == 1
        run_id = runs[0].run_id
        rec = await offline_engine.get_run_status(run_id)
        assert rec is not None
        assert rec.status.value == "success"
    finally:
        await offline_engine.stop()
|
||||
|
||||
|
||||
class _EventWithUid(BaseEvent):
    """Test-only event carrying a ``user_id``.

    Used below as the ``on`` event of an ``Idle`` trigger whose
    ``event_field`` is ``"user_id"``.
    """

    # Field named by ``event_field`` in the Idle trigger under test.
    user_id: str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_reschedule_cron_job_updates_aps(cfg: OMEConfig) -> None:
    """``reschedule_cron_job`` replaces the scheduler job's cron trigger.

    The strategy is registered with ``0 3 * * *`` and rescheduled to
    ``*/5 * * * *``; the job's minute field must then read ``*/5``.
    """

    @offline_strategy(name="cron_s", trigger=Cron(expr="0 3 * * *"), emits=[])
    async def s(event: Any, ctx: StrategyContext) -> None:
        # No-op body: only the scheduling metadata is under test.
        return None

    engine = OfflineEngine(config=cfg)
    engine.register(s)
    await engine.start()
    try:
        from apscheduler.triggers.cron import CronTrigger

        engine.reschedule_cron_job("cron_s", "*/5 * * * *")

        cron_job = engine._scheduler.get_job("cron::cron_s")
        trigger = cron_job.trigger
        assert isinstance(trigger, CronTrigger)

        # CronTrigger stores parsed crontab fields; minute step=5 means "*/5".
        minute = next(fld for fld in trigger.fields if fld.name == "minute")
        assert str(minute) == "*/5"
    finally:
        await engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_engine_reschedule_idle_job_updates_interval(cfg: OMEConfig) -> None:
    """``reschedule_idle_job`` changes the idle scan job's interval.

    The strategy starts with ``scan_interval_seconds=30`` and is
    rescheduled to 10; the job trigger's interval must then be 10 seconds.
    """
    idle_trigger = Idle(
        on=[_EventWithUid],
        event_field="user_id",
        idle_seconds=60,
        scan_interval_seconds=30,
    )

    @offline_strategy(name="idle_s", trigger=idle_trigger, emits=[])
    async def s(event: Any, ctx: StrategyContext) -> None:
        # No-op body: only the scheduling metadata is under test.
        return None

    engine = OfflineEngine(config=cfg)
    engine.register(s)
    await engine.start()
    try:
        engine.reschedule_idle_job("idle_s", scan_interval_seconds=10)

        scan_job = engine._scheduler.get_job("idle::idle_s")
        # IntervalTrigger.interval is a timedelta.
        period = scan_job.trigger.interval
        assert period.total_seconds() == 10
    finally:
        await engine.stop()
|
||||
|
||||
|
||||
def test_reschedule_cron_job_before_start_raises(cfg: OMEConfig) -> None:
    """Calling ``reschedule_cron_job`` on an engine that was never started
    raises ``OMEError`` matching ``"engine not started"``.
    """
    unstarted = OfflineEngine(config=cfg)
    expected_message = "engine not started"
    with pytest.raises(OMEError, match=expected_message):
        unstarted.reschedule_cron_job("x", "* * * * *")
|
||||
|
||||
|
||||
def test_reschedule_idle_job_before_start_raises(cfg: OMEConfig) -> None:
    """Calling ``reschedule_idle_job`` on an engine that was never started
    raises ``OMEError`` matching ``"engine not started"``.
    """
    unstarted = OfflineEngine(config=cfg)
    expected_message = "engine not started"
    with pytest.raises(OMEError, match=expected_message):
        unstarted.reschedule_idle_job("x", scan_interval_seconds=30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_failure_cleans_up_engines_and_scheduler(
    cfg: OMEConfig, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``start()`` must roll back fully when it fails part-way through.

    ``scan_and_resume`` is patched to raise, so the failure lands between
    scheduler start and ``_started = True``. Afterwards the engine must be
    gone from the module-level ``_ENGINES`` registry, hold no scheduler and
    no lock handle, and a fresh ``OfflineEngine`` must be able to start on
    the same jobstore.
    """
    from everos.infra.ome import engine as engine_mod

    async def _explode(*args: Any, **kwargs: Any) -> None:
        raise RuntimeError("crash recovery exploded")

    monkeypatch.setattr(engine_mod, "scan_and_resume", _explode)

    failed = OfflineEngine(config=cfg)
    with pytest.raises(RuntimeError, match="crash recovery exploded"):
        await failed.start()

    # Every piece of partially-initialised state must have been undone.
    assert failed._engine_id not in engine_mod._ENGINES
    assert failed._scheduler is None
    assert failed._started is False
    assert failed._lock_handle is None

    # Restore the real ``scan_and_resume`` before proving that a second
    # engine can start and stop on the same config.
    monkeypatch.undo()
    fresh = OfflineEngine(config=cfg)
    await fresh.start()
    await fresh.stop()
|
||||
|
||||
|
||||
# ── active_runs / wait_idle ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_idle_returns_true_when_no_runs(cfg: OMEConfig) -> None:
    """With nothing emitted, ``_active_runs`` is 0 and ``wait_idle``
    returns ``True`` within its timeout.
    """
    engine = OfflineEngine(config=cfg)
    await engine.start()
    try:
        assert engine._active_runs == 0
        is_idle = await engine.wait_idle(timeout=0.5)
        assert is_idle is True
    finally:
        await engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_idle_blocks_until_strategy_finishes(cfg: OMEConfig) -> None:
    """While a strategy is mid-flight ``wait_idle`` times out with ``False``;
    once the strategy is released it returns ``True`` and the active-run
    counter is back at 0.
    """
    gate = asyncio.Event()
    in_flight = asyncio.Event()

    @offline_strategy(name="slow", trigger=Immediate(on=[_E]), emits=[])
    async def slow(event: _E, ctx: StrategyContext) -> None:
        # Signal that the body is running, then park until the test opens
        # the gate.
        in_flight.set()
        await gate.wait()

    engine = OfflineEngine(config=cfg)
    engine.register(slow)
    await engine.start()
    try:
        await engine.emit(_E())
        await asyncio.wait_for(in_flight.wait(), timeout=2.0)

        # Strategy is now mid-flight.
        assert engine._active_runs >= 1
        idle_while_running = await engine.wait_idle(timeout=0.2)
        assert idle_while_running is False

        # Release the strategy and verify wait_idle resolves.
        gate.set()
        idle_after_release = await engine.wait_idle(timeout=2.0)
        assert idle_after_release is True
        assert engine._active_runs == 0
    finally:
        # Open the gate on failure paths too, before stopping the engine.
        gate.set()
        await engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stop_waits_for_in_flight_run_to_complete(cfg: OMEConfig) -> None:
    """stop() must not cancel in-flight strategies. Pre-fix this used
    scheduler.shutdown(wait=True) which APS 3.x AsyncIOExecutor cancels
    silently; post-fix stop() drains through wait_idle first.

    The body is wrapped in ``try/finally`` (like the sibling tests) so that
    an early failure still releases the blocked strategy and shuts the
    engine down instead of leaking it into later tests.
    """
    completed: list[str] = []
    started = asyncio.Event()
    release = asyncio.Event()

    @offline_strategy(name="slow_to_finish", trigger=Immediate(on=[_E]), emits=[])
    async def slow(event: _E, ctx: StrategyContext) -> None:
        started.set()
        await release.wait()
        # Only reached if the run was allowed to finish rather than cancelled.
        completed.append("done")

    engine = OfflineEngine(config=cfg)
    engine.register(slow)
    await engine.start()

    # Tracks whether stop() has already been launched, so the cleanup path
    # never calls it a second time.
    stop_task: asyncio.Task[None] | None = None
    try:
        await engine.emit(_E())
        await asyncio.wait_for(started.wait(), timeout=2.0)

        # Stop concurrently with the in-flight strategy; release it after a
        # tick so stop() has to actually wait.
        stop_task = asyncio.create_task(engine.stop())
        await asyncio.sleep(0.05)
        assert not stop_task.done()
        release.set()
        await asyncio.wait_for(stop_task, timeout=5.0)
        assert completed == ["done"]
    finally:
        # Unblock the strategy on every path so a pending stop() can drain.
        release.set()
        if stop_task is None:
            # Failed before stop() was launched: shut the engine down here.
            await engine.stop()
        elif not stop_task.done():
            # stop() was launched but interrupted mid-test: let it finish.
            await stop_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_active_runs_decrements_on_strategy_exception(cfg: OMEConfig) -> None:
    """A failing strategy still gives its active-run slot back.

    With ``max_retries=0`` the raising strategy ends in DEAD_LETTER;
    ``wait_idle`` must still resolve and ``_active_runs`` must be 0 —
    dispatch_run's finally guarantees the -1.
    """

    @offline_strategy(name="boom", trigger=Immediate(on=[_E]), emits=[])
    async def boom(event: _E, ctx: StrategyContext) -> None:
        raise RuntimeError("strategy boom")

    # Same jobstore as the fixture config, but with retries disabled.
    no_retry_cfg = OMEConfig(
        jobstore_path=cfg.jobstore_path,
        config_watch=False,
        max_retries=0,
    )
    engine = OfflineEngine(config=no_retry_cfg)
    engine.register(boom)
    await engine.start()
    try:
        await engine.emit(_E())
        drained = await engine.wait_idle(timeout=2.0)
        assert drained is True

        recorded = await engine.list_runs("boom")
        assert len(recorded) == 1
        assert recorded[0].status == RunStatus.DEAD_LETTER
        assert engine._active_runs == 0
    finally:
        await engine.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_run_rolls_back_counter_on_add_job_failure(
    cfg: OMEConfig, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A failed ``add_job`` must not leave the active-run counter raised.

    If APScheduler ``add_job`` raises, the matching dispatch_run never
    runs — _enqueue_run must roll back the pre-emptive +1 itself.
    """

    @offline_strategy(name="s", trigger=Immediate(on=[_E]), emits=[])
    async def s(event: _E, ctx: StrategyContext) -> None:
        # No-op body: the run is never scheduled in this test.
        return None

    engine = OfflineEngine(config=cfg)
    engine.register(s)
    await engine.start()
    try:

        def _failing_add_job(*args: Any, **kwargs: Any) -> None:
            raise RuntimeError("add_job exploded")

        monkeypatch.setattr(engine._scheduler, "add_job", _failing_add_job)
        with pytest.raises(RuntimeError, match="add_job exploded"):
            await engine.emit(_E())

        assert engine._active_runs == 0
        still_idle = await engine.wait_idle(timeout=0.5)
        assert still_idle is True
    finally:
        # Restore the real add_job before stopping the engine.
        monkeypatch.undo()
        await engine.stop()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user