chore: initialize EverOS 1.0.0
md-first memory extraction framework for AI agents. Markdown is the single source of truth; SQLite holds state and LanceDB provides the rebuildable vector + BM25 + scalar index. The codebase follows a single-direction DDD layering (entrypoints -> service -> memory -> infra, with component / core / config cross-cutting) enforced by import-linter. Engineering surface: - Coding conventions in .claude/rules/ (path-scoped) and workflows in .claude/skills/ (/commit, /new-branch, /pr). - GitHub Actions CI runs make lint + test + integration; pre-commit mirrors the gates locally (ruff, hygiene hooks, gitlint commit-msg). - Commit messages follow Conventional Commits, enforced by gitlint. - make lint also enforces datetime two-zone discipline and OpenAPI drift.
This commit is contained in:
0
tests/unit/test_memory/__init__.py
Normal file
0
tests/unit/test_memory/__init__.py
Normal file
0
tests/unit/test_memory/test_cascade/__init__.py
Normal file
0
tests/unit/test_memory/test_cascade/__init__.py
Normal file
331
tests/unit/test_memory/test_cascade/test_handler_agent_skill.py
Normal file
331
tests/unit/test_memory/test_cascade/test_handler_agent_skill.py
Normal file
@ -0,0 +1,331 @@
|
||||
"""Tests for :class:`AgentSkillHandler` — whole-file skill reconcile.
|
||||
|
||||
Skill is the only kind that doesn't go through ``BaseDailyLogHandler``:
|
||||
no entries, no per-entry diff. The digest is ``content_sha256`` over
|
||||
the whole skill (name + description + body + references_content +
|
||||
confidence + maturity_score); the handler reads ``SKILL.md`` + every
|
||||
``references/*.md`` sibling and upserts one row per skill. These
|
||||
tests build the directory layout on disk and verify the resulting
|
||||
row + the delete path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.persistence.lancedb import AgentSkill
|
||||
from everos.infra.persistence.markdown import AgentSkillWriter
|
||||
from everos.memory.cascade.handlers import AgentSkillHandler, HandlerDeps
|
||||
|
||||
|
||||
class _StubTokenizer(Tokenizer):
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return [tok for tok in text.split() if tok]
|
||||
|
||||
def tokenize_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [self.tokenize(t) for t in texts]
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
dim = 1024
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
return [0.0] * self.dim
|
||||
|
||||
async def embed_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [await self.embed(t) for t in texts]
|
||||
|
||||
|
||||
class _FakeSkillRepo:
|
||||
def __init__(self) -> None:
|
||||
self.rows: dict[str, AgentSkill] = {}
|
||||
self.upserts: list[list[AgentSkill]] = []
|
||||
self.deletes: list[str] = []
|
||||
self.predicate_deletes: list[str] = []
|
||||
|
||||
async def get_by_id(self, row_id: str) -> AgentSkill | None:
|
||||
return self.rows.get(row_id)
|
||||
|
||||
async def upsert(self, rows: list[AgentSkill]) -> None:
|
||||
self.upserts.append(list(rows))
|
||||
for row in rows:
|
||||
self.rows[row.id] = row
|
||||
|
||||
async def delete_by_md_path(self, md_path: str) -> int:
|
||||
self.deletes.append(md_path)
|
||||
return 1
|
||||
|
||||
async def find_where(self, predicate: str, *, limit: int) -> list[AgentSkill]:
|
||||
"""In-memory equivalent — handles only the
|
||||
``md_path = '...' AND id != '...'`` shape the handler emits."""
|
||||
if "md_path = " in predicate and "id != " in predicate:
|
||||
md_lit = predicate.split("md_path = '")[1].split("'", 1)[0]
|
||||
id_lit = predicate.split("id != '")[1].split("'", 1)[0]
|
||||
return [
|
||||
r for r in self.rows.values() if r.md_path == md_lit and r.id != id_lit
|
||||
][:limit]
|
||||
raise NotImplementedError(f"fake repo doesn't handle {predicate!r}")
|
||||
|
||||
async def delete(self, predicate: str) -> None:
|
||||
self.predicate_deletes.append(predicate)
|
||||
if "md_path = " in predicate and "id != " in predicate:
|
||||
md_lit = predicate.split("md_path = '")[1].split("'", 1)[0]
|
||||
id_lit = predicate.split("id != '")[1].split("'", 1)[0]
|
||||
self.rows = {
|
||||
rid: row
|
||||
for rid, row in self.rows.items()
|
||||
if not (row.md_path == md_lit and row.id != id_lit)
|
||||
}
|
||||
return
|
||||
raise NotImplementedError(f"fake repo doesn't handle {predicate!r}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_root(tmp_path: Path) -> MemoryRoot:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
return mr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_repo(monkeypatch: pytest.MonkeyPatch) -> _FakeSkillRepo:
|
||||
"""Patch the module-level repo the handler references."""
|
||||
from everos.memory.cascade.handlers import agent_skill as skill_mod
|
||||
|
||||
repo = _FakeSkillRepo()
|
||||
monkeypatch.setattr(skill_mod, "agent_skill_repo", repo)
|
||||
return repo
|
||||
|
||||
|
||||
async def _write_skill(
|
||||
memory_root: MemoryRoot, agent_id: str, name: str, *, body: str
|
||||
) -> str:
|
||||
"""Create a SKILL.md via the real writer, return the relative md_path."""
|
||||
from everos.infra.persistence.markdown import AgentSkillFrontmatter
|
||||
|
||||
writer = AgentSkillWriter(memory_root)
|
||||
fm = AgentSkillFrontmatter(
|
||||
id=f"skill_{name}",
|
||||
agent_id=agent_id,
|
||||
name=name,
|
||||
description="Scan a contract draft for risk clauses.",
|
||||
confidence=0.8,
|
||||
maturity_score=0.6,
|
||||
source_case_ids=["ac_1", "ac_2"],
|
||||
)
|
||||
await writer.write_main(agent_id, name, frontmatter=fm, body=body)
|
||||
return f"default_app/default_project/agents/{agent_id}/skills/skill_{name}/SKILL.md"
|
||||
|
||||
|
||||
async def test_handle_added_or_modified_upserts_typed_row(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
md_path = await _write_skill(
|
||||
memory_root, "a1", "contract_scan", body="step one\nstep two\n"
|
||||
)
|
||||
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.deleted == 0
|
||||
row = fake_repo.upserts[0][0]
|
||||
assert row.id == "a1_contract_scan"
|
||||
assert row.owner_id == "a1"
|
||||
assert row.owner_type == "agent"
|
||||
assert row.name == "contract_scan"
|
||||
assert row.description.startswith("Scan a contract draft")
|
||||
assert row.description_tokens.startswith("Scan a contract draft")
|
||||
assert row.confidence == pytest.approx(0.8)
|
||||
assert row.maturity_score == pytest.approx(0.6)
|
||||
assert row.source_case_ids == ["ac_1", "ac_2"]
|
||||
assert row.md_path == md_path
|
||||
assert len(row.vector) == 1024
|
||||
# Body content lands in the ``content`` column.
|
||||
assert "step one" in row.content
|
||||
assert "step two" in row.content
|
||||
|
||||
|
||||
async def test_references_md_concatenated_into_content(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
"""references/*.md siblings are appended to ``content`` deterministically."""
|
||||
md_path = await _write_skill(memory_root, "a1", "skill_x", body="main body text")
|
||||
# Drop two reference files into the skill dir.
|
||||
refs_dir = (
|
||||
memory_root.root
|
||||
/ "default_app/default_project/agents/a1/skills/skill_skill_x/references"
|
||||
)
|
||||
refs_dir.mkdir(parents=True, exist_ok=True)
|
||||
(refs_dir / "b.md").write_text("reference B content\n", encoding="utf-8")
|
||||
(refs_dir / "a.md").write_text("reference A content\n", encoding="utf-8")
|
||||
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
content = fake_repo.upserts[0][0].content
|
||||
|
||||
# Body comes first, references sorted by filename (a.md then b.md).
|
||||
assert content.index("main body text") < content.index("reference A content")
|
||||
assert content.index("reference A content") < content.index("reference B content")
|
||||
|
||||
|
||||
async def test_renaming_skill_via_frontmatter_clears_old_row(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
"""User edits SKILL.md frontmatter.name; the LanceDB row id changes.
|
||||
|
||||
skill_id is derived from ``frontmatter.name`` (``<owner_id>_<name>``).
|
||||
When the user edits the name in place — common when refining a skill
|
||||
title without moving the file — the new id differs from the old, so
|
||||
a plain ``upsert([new_row])`` would leave the old row behind and a
|
||||
subsequent search would return both. The handler must sweep the
|
||||
stale row by ``md_path = ? AND id != new_id`` before the upsert.
|
||||
"""
|
||||
# First pass: write the original SKILL.md and let cascade index it.
|
||||
md_path = await _write_skill(memory_root, "a1", "old_name", body="step one\n")
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
assert fake_repo.rows == {"a1_old_name": fake_repo.rows["a1_old_name"]}
|
||||
|
||||
# Second pass: simulate the user editing frontmatter.name in place
|
||||
# (md_path unchanged, only the name field flips).
|
||||
absolute = memory_root.root / md_path
|
||||
text = absolute.read_text(encoding="utf-8")
|
||||
absolute.write_text(text.replace("name: old_name", "name: new_name"))
|
||||
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.deleted == 1
|
||||
# Old id is gone, new id is present, exactly one row survives.
|
||||
assert list(fake_repo.rows.keys()) == ["a1_new_name"]
|
||||
# The sweep predicate references the *new* id with the same md_path.
|
||||
assert fake_repo.predicate_deletes == [
|
||||
f"md_path = '{md_path}' AND id != 'a1_new_name'"
|
||||
]
|
||||
|
||||
|
||||
async def test_first_create_does_not_call_orphan_sweep(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
"""First write of a SKILL.md issues an upsert but no orphan delete.
|
||||
|
||||
The sweep clause only kicks in when there's a prior row at the same
|
||||
md_path under a different id (the rename case). For a fresh skill
|
||||
we should not bother LanceDB with an empty delete predicate either.
|
||||
"""
|
||||
md_path = await _write_skill(memory_root, "a1", "fresh_skill", body="x")
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.deleted == 0
|
||||
# The handler does call find_where on first-pass (prior is None),
|
||||
# but the empty result short-circuits the delete.
|
||||
assert fake_repo.predicate_deletes == []
|
||||
|
||||
|
||||
async def test_content_edit_skips_orphan_lookup(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
"""When the name is unchanged (prior row exists under the same id),
|
||||
the handler must not pay for the orphan find — there can't be any.
|
||||
"""
|
||||
md_path = await _write_skill(memory_root, "a1", "stable_name", body="v1\n")
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
|
||||
# Edit the body so digest drifts (forces upsert path, not skip).
|
||||
absolute = memory_root.root / md_path
|
||||
absolute.write_text(
|
||||
absolute.read_text(encoding="utf-8").replace("v1", "v2"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.deleted == 0
|
||||
# Same id, no orphan sweep issued.
|
||||
assert fake_repo.predicate_deletes == []
|
||||
|
||||
|
||||
async def test_handle_deleted_calls_delete_by_md_path(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
outcome = await handler.handle_deleted("agents/a1/skills/skill_x/SKILL.md")
|
||||
assert outcome.deleted == 1
|
||||
assert outcome.upserted == 0
|
||||
assert fake_repo.deletes == ["agents/a1/skills/skill_x/SKILL.md"]
|
||||
|
||||
|
||||
async def test_missing_name_raises(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeSkillRepo
|
||||
) -> None:
|
||||
"""A SKILL.md whose frontmatter lacks ``name`` surfaces as ValueError."""
|
||||
# Hand-write a malformed SKILL.md (no `name`).
|
||||
skill_dir = memory_root.root / "agents/a1/skills/skill_broken"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"id: skill_broken\n"
|
||||
"type: agent_skill\n"
|
||||
"agent_id: a1\n"
|
||||
"track: agent\n"
|
||||
"description: x\n"
|
||||
"confidence: 0.5\n"
|
||||
"maturity_score: 0.5\n"
|
||||
"---\nbody\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
handler = AgentSkillHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
with pytest.raises(ValueError, match="name"):
|
||||
await handler.handle_added_or_modified("agents/a1/skills/skill_broken/SKILL.md")
|
||||
260
tests/unit/test_memory/test_cascade/test_handler_episode.py
Normal file
260
tests/unit/test_memory/test_cascade/test_handler_episode.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""Tests for :class:`EpisodeHandler` — md → LanceDB row reconcile.
|
||||
|
||||
Uses a real on-disk md file (via :class:`EpisodeWriter`) to exercise
|
||||
the parse → diff → upsert path. The lancedb repo is faked since the
|
||||
production singleton would need a live LanceDB connection; this keeps
|
||||
the test in-memory while still validating row construction and the
|
||||
3-way diff branch behaviour.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.persistence.lancedb import Episode
|
||||
from everos.infra.persistence.markdown import EpisodeWriter
|
||||
from everos.memory.cascade.handlers import HandlerDeps
|
||||
from everos.memory.cascade.handlers.episode import EpisodeHandler
|
||||
|
||||
|
||||
class _StubTokenizer(Tokenizer):
|
||||
"""Returns the input split on whitespace — deterministic for assertions."""
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return [tok for tok in text.split() if tok]
|
||||
|
||||
def tokenize_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [self.tokenize(t) for t in texts]
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
"""Returns a fixed 1024-dim vector; records call count."""
|
||||
|
||||
dim = 1024
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
self.calls += 1
|
||||
return [0.1] * self.dim
|
||||
|
||||
async def embed_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [await self.embed(t) for t in texts]
|
||||
|
||||
|
||||
class _FakeEpisodeRepo:
|
||||
"""Recording repo — captures upserts / deletes the handler issues."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.upserts: list[list[Episode]] = []
|
||||
self.deletes: list[str] = []
|
||||
self.rows: list[Episode] = []
|
||||
|
||||
async def find_where(self, where: str, *, limit: int = 100) -> list[Episode]:
|
||||
# Honour only the md_path = '...' filter the handler emits.
|
||||
prefix = "md_path = '"
|
||||
if where.startswith(prefix):
|
||||
md_path = where[len(prefix) :].rstrip("'")
|
||||
return [r for r in self.rows if r.md_path == md_path]
|
||||
return []
|
||||
|
||||
async def upsert(self, rows: list[Episode]) -> None:
|
||||
self.upserts.append(list(rows))
|
||||
# Reflect into ``self.rows`` so a follow-up find_where sees the state.
|
||||
by_id = {r.id: r for r in self.rows}
|
||||
for r in rows:
|
||||
by_id[r.id] = r
|
||||
self.rows = list(by_id.values())
|
||||
|
||||
async def delete(self, predicate: str) -> None:
|
||||
self.deletes.append(predicate)
|
||||
|
||||
async def delete_by_md_path(self, md_path: str) -> int:
|
||||
before = len(self.rows)
|
||||
self.rows = [r for r in self.rows if r.md_path != md_path]
|
||||
return before - len(self.rows)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_root(tmp_path: Path) -> MemoryRoot:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
return mr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_repo(monkeypatch: pytest.MonkeyPatch) -> _FakeEpisodeRepo:
|
||||
"""Swap the class-level ``lance_repo`` on EpisodeHandler.
|
||||
|
||||
After the BaseDailyLogHandler refactor, the repo binding is a
|
||||
ClassVar resolved at class-definition time; patching the module
|
||||
attribute would no longer reach the handler's call sites.
|
||||
"""
|
||||
from everos.memory.cascade.handlers.episode import EpisodeHandler
|
||||
|
||||
repo = _FakeEpisodeRepo()
|
||||
monkeypatch.setattr(EpisodeHandler, "lance_repo", repo)
|
||||
return repo
|
||||
|
||||
|
||||
async def _write_one_entry(writer: EpisodeWriter, owner_id: str, body: str) -> str:
|
||||
"""Append a single episode entry, return the md path (relative)."""
|
||||
today = _dt.date(2026, 5, 14)
|
||||
await writer.append_entry(
|
||||
owner_id,
|
||||
inline={
|
||||
"owner_id": owner_id,
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_type": "memcell",
|
||||
"parent_id": "mc_test_parent",
|
||||
"sender_ids": [owner_id],
|
||||
},
|
||||
sections={"Subject": "Test", "Summary": "Stub", "Content": body},
|
||||
date=today,
|
||||
)
|
||||
return (
|
||||
f"default_app/default_project/users/{owner_id}/episodes/episode-2026-05-14.md"
|
||||
)
|
||||
|
||||
|
||||
def _build_handler(
|
||||
memory_root: MemoryRoot,
|
||||
) -> tuple[EpisodeHandler, _StubEmbedder]:
|
||||
embedder = _StubEmbedder()
|
||||
deps = HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=embedder,
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
return EpisodeHandler(deps), embedder
|
||||
|
||||
|
||||
# ── happy path ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_added_entry_upserts_typed_row(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeEpisodeRepo
|
||||
) -> None:
|
||||
writer = EpisodeWriter(memory_root)
|
||||
rel = await _write_one_entry(writer, "u1", "hello world")
|
||||
|
||||
handler, embedder = _build_handler(memory_root)
|
||||
outcome = await handler.handle_added_or_modified(rel)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.deleted == 0
|
||||
assert outcome.skipped == 0
|
||||
assert embedder.calls == 1
|
||||
assert len(fake_repo.upserts) == 1
|
||||
row = fake_repo.upserts[0][0]
|
||||
assert row.owner_id == "u1"
|
||||
assert row.owner_type == "user"
|
||||
# Scope is parsed back from the md path's <app>/<project> prefix.
|
||||
assert row.app_id == "default"
|
||||
assert row.project_id == "default"
|
||||
assert row.session_id == "s1"
|
||||
assert row.parent_id == "mc_test_parent"
|
||||
assert row.parent_type == "memcell"
|
||||
assert row.episode == "hello world"
|
||||
assert row.episode_tokens == "hello world"
|
||||
assert row.subject == "Test"
|
||||
assert row.md_path == rel
|
||||
assert row.entry_id.startswith("ep_")
|
||||
assert row.id == f"u1_{row.entry_id}"
|
||||
assert len(row.vector) == 1024
|
||||
|
||||
|
||||
async def test_unchanged_entry_is_skipped_no_embed_call(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeEpisodeRepo
|
||||
) -> None:
|
||||
"""Second handle run with no md change → skipped + no embed call."""
|
||||
writer = EpisodeWriter(memory_root)
|
||||
rel = await _write_one_entry(writer, "u1", "hello world")
|
||||
|
||||
handler, embedder = _build_handler(memory_root)
|
||||
await handler.handle_added_or_modified(rel) # first pass populates fake repo
|
||||
fake_repo.upserts.clear()
|
||||
embedder.calls = 0
|
||||
|
||||
outcome = await handler.handle_added_or_modified(rel)
|
||||
assert outcome.skipped == 1
|
||||
assert outcome.upserted == 0
|
||||
assert embedder.calls == 0
|
||||
assert fake_repo.upserts == []
|
||||
|
||||
|
||||
async def test_modified_entry_reembeds(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeEpisodeRepo
|
||||
) -> None:
|
||||
"""Changing the entry body bumps the sha → re-embed + upsert."""
|
||||
writer = EpisodeWriter(memory_root)
|
||||
rel = await _write_one_entry(writer, "u1", "original content")
|
||||
|
||||
handler, embedder = _build_handler(memory_root)
|
||||
await handler.handle_added_or_modified(rel)
|
||||
# Tamper with the row's stored sha so the next pass sees a mismatch.
|
||||
fake_repo.rows[0] = fake_repo.rows[0].model_copy(
|
||||
update={"content_sha256": "0" * 64}
|
||||
)
|
||||
fake_repo.upserts.clear()
|
||||
embedder.calls = 0
|
||||
|
||||
outcome = await handler.handle_added_or_modified(rel)
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.skipped == 0
|
||||
assert embedder.calls == 1
|
||||
|
||||
|
||||
# ── deletion paths ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_handle_deleted_wipes_md_path_rows(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeEpisodeRepo
|
||||
) -> None:
|
||||
writer = EpisodeWriter(memory_root)
|
||||
rel = await _write_one_entry(writer, "u1", "hello")
|
||||
handler, _embedder = _build_handler(memory_root)
|
||||
await handler.handle_added_or_modified(rel)
|
||||
assert fake_repo.rows # populated
|
||||
|
||||
outcome = await handler.handle_deleted(rel)
|
||||
assert outcome.deleted == 1
|
||||
assert fake_repo.rows == []
|
||||
|
||||
|
||||
# ── error path ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_missing_timestamp_raises_value_error(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeEpisodeRepo
|
||||
) -> None:
|
||||
"""Malformed inline surfaces as ValueError — worker treats unrecoverable."""
|
||||
writer = EpisodeWriter(memory_root)
|
||||
# Manually bypass the writer to drop timestamp.
|
||||
today = _dt.date(2026, 5, 14)
|
||||
await writer.append_entry(
|
||||
"u1",
|
||||
inline={"owner_id": "u1", "session_id": "s1"}, # no timestamp
|
||||
sections={"Content": "x"},
|
||||
date=today,
|
||||
)
|
||||
rel = "default_app/default_project/users/u1/episodes/episode-2026-05-14.md"
|
||||
|
||||
handler, _embedder = _build_handler(memory_root)
|
||||
with pytest.raises(ValueError, match="timestamp"):
|
||||
await handler.handle_added_or_modified(rel)
|
||||
|
||||
|
||||
# ── unused noqa suppressor (keep imports tidy) ──────────────────────────
|
||||
|
||||
|
||||
_: Any = None
|
||||
260
tests/unit/test_memory/test_cascade/test_handler_user_profile.py
Normal file
260
tests/unit/test_memory/test_cascade/test_handler_user_profile.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""Tests for :class:`UserProfileHandler` — single-file profile reconcile.
|
||||
|
||||
UserProfile is the second single-file kind (after AgentSkill) — one
|
||||
``users/<user_id>/user.md`` per user, replaced wholesale on edit. The
|
||||
handler upserts one row per profile and skips when the
|
||||
content-bearing digest (summary + JSON buckets) is unchanged. These
|
||||
tests verify the upsert / skip path, the JSON encoding of
|
||||
``explicit_info`` / ``implicit_traits``, and the missing-user_id guard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.persistence.lancedb import UserProfile
|
||||
from everos.infra.persistence.markdown import ProfileWriter, UserProfileFrontmatter
|
||||
from everos.memory.cascade.handlers import HandlerDeps, UserProfileHandler
|
||||
|
||||
|
||||
class _StubTokenizer(Tokenizer):
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return [tok for tok in text.split() if tok]
|
||||
|
||||
def tokenize_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [self.tokenize(t) for t in texts]
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
"""Profile handler does not embed; the stub stays as a no-op so the
|
||||
shared :class:`HandlerDeps` shape is satisfied."""
|
||||
|
||||
dim = 1024
|
||||
|
||||
async def embed(self, text: str) -> list[float]: # pragma: no cover
|
||||
raise AssertionError("UserProfileHandler must not call the embedder")
|
||||
|
||||
async def embed_batch( # pragma: no cover
|
||||
self,
|
||||
texts, # type: ignore[no-untyped-def]
|
||||
):
|
||||
raise AssertionError("UserProfileHandler must not call the embedder")
|
||||
|
||||
|
||||
class _FakeProfileRepo:
|
||||
def __init__(self) -> None:
|
||||
self.rows: dict[str, UserProfile] = {}
|
||||
self.upserts: list[list[UserProfile]] = []
|
||||
self.deletes: list[str] = []
|
||||
|
||||
async def get_by_id(self, row_id: str) -> UserProfile | None:
|
||||
return self.rows.get(row_id)
|
||||
|
||||
async def upsert(self, rows: list[UserProfile]) -> None:
|
||||
self.upserts.append(list(rows))
|
||||
for row in rows:
|
||||
self.rows[row.id] = row
|
||||
|
||||
async def delete_by_md_path(self, md_path: str) -> int:
|
||||
self.deletes.append(md_path)
|
||||
before = len(self.rows)
|
||||
self.rows = {rid: r for rid, r in self.rows.items() if r.md_path != md_path}
|
||||
return before - len(self.rows)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_root(tmp_path: Path) -> MemoryRoot:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
return mr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_repo(monkeypatch: pytest.MonkeyPatch) -> _FakeProfileRepo:
|
||||
from everos.memory.cascade.handlers import user_profile as up_mod
|
||||
|
||||
repo = _FakeProfileRepo()
|
||||
monkeypatch.setattr(up_mod, "user_profile_repo", repo)
|
||||
return repo
|
||||
|
||||
|
||||
async def _write_profile(
|
||||
memory_root: MemoryRoot,
|
||||
user_id: str,
|
||||
*,
|
||||
summary: str,
|
||||
explicit_info: list,
|
||||
implicit_traits: list,
|
||||
profile_timestamp_ms: int = 1_700_000_000_000,
|
||||
) -> str:
|
||||
writer = ProfileWriter(memory_root)
|
||||
fm = UserProfileFrontmatter(
|
||||
id=f"user_profile_{user_id}",
|
||||
user_id=user_id,
|
||||
summary=summary,
|
||||
explicit_info=explicit_info,
|
||||
implicit_traits=implicit_traits,
|
||||
profile_timestamp_ms=profile_timestamp_ms,
|
||||
)
|
||||
await writer.write(user_id, frontmatter=fm, body="display text")
|
||||
return f"default_app/default_project/users/{user_id}/user.md"
|
||||
|
||||
|
||||
def _handler(memory_root: MemoryRoot) -> UserProfileHandler:
|
||||
return UserProfileHandler(
|
||||
HandlerDeps(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def test_first_pass_upserts_typed_row(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
md_path = await _write_profile(
|
||||
memory_root,
|
||||
"u_alice",
|
||||
summary="Alice likes long hikes and prefers oat milk.",
|
||||
explicit_info=[{"fact": "lives in tokyo"}, "renew passport"],
|
||||
implicit_traits=[{"trait": "introverted"}],
|
||||
)
|
||||
outcome = await _handler(memory_root).handle_added_or_modified(md_path)
|
||||
|
||||
assert outcome.upserted == 1
|
||||
assert outcome.skipped == 0
|
||||
row = fake_repo.upserts[0][0]
|
||||
assert row.id == "u_alice"
|
||||
assert row.owner_id == "u_alice"
|
||||
assert row.owner_type == "user"
|
||||
assert row.summary.startswith("Alice")
|
||||
assert row.md_path == md_path
|
||||
# Heterogeneous buckets land as canonical JSON strings.
|
||||
assert json.loads(row.explicit_info_json) == [
|
||||
{"fact": "lives in tokyo"},
|
||||
"renew passport",
|
||||
]
|
||||
assert json.loads(row.implicit_traits_json) == [{"trait": "introverted"}]
|
||||
assert row.profile_timestamp_ms == 1_700_000_000_000
|
||||
|
||||
|
||||
async def test_second_pass_with_same_content_skips(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
md_path = await _write_profile(
|
||||
memory_root,
|
||||
"u_alice",
|
||||
summary="Stable summary.",
|
||||
explicit_info=["a"],
|
||||
implicit_traits=["b"],
|
||||
)
|
||||
handler = _handler(memory_root)
|
||||
first = await handler.handle_added_or_modified(md_path)
|
||||
assert first.upserted == 1
|
||||
|
||||
# Re-run with no edits — digest matches, handler must skip.
|
||||
second = await handler.handle_added_or_modified(md_path)
|
||||
assert second.upserted == 0
|
||||
assert second.skipped == 1
|
||||
# Only the first pass touched the repo.
|
||||
assert len(fake_repo.upserts) == 1
|
||||
|
||||
|
||||
async def test_timestamp_only_drift_skips(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
"""Re-synthesis bumps ``profile_timestamp_ms`` even when the content
|
||||
is byte-identical; the digest excludes the timestamp so cascade
|
||||
skips re-upsert and avoids a wasted index write."""
|
||||
md_path = await _write_profile(
|
||||
memory_root,
|
||||
"u_alice",
|
||||
summary="Same summary.",
|
||||
explicit_info=["x"],
|
||||
implicit_traits=["y"],
|
||||
profile_timestamp_ms=1_700_000_000_000,
|
||||
)
|
||||
handler = _handler(memory_root)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
|
||||
# Bump only profile_timestamp_ms.
|
||||
absolute = memory_root.root / md_path
|
||||
absolute.write_text(
|
||||
absolute.read_text(encoding="utf-8").replace("1700000000000", "1800000000000"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
assert outcome.upserted == 0
|
||||
assert outcome.skipped == 1
|
||||
|
||||
|
||||
async def test_summary_edit_triggers_upsert(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
md_path = await _write_profile(
|
||||
memory_root,
|
||||
"u_alice",
|
||||
summary="Original summary.",
|
||||
explicit_info=[],
|
||||
implicit_traits=[],
|
||||
)
|
||||
handler = _handler(memory_root)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
assert len(fake_repo.upserts) == 1
|
||||
|
||||
absolute = memory_root.root / md_path
|
||||
absolute.write_text(
|
||||
absolute.read_text(encoding="utf-8").replace(
|
||||
"Original summary.", "New shiny summary."
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
outcome = await handler.handle_added_or_modified(md_path)
|
||||
assert outcome.upserted == 1
|
||||
assert fake_repo.upserts[1][0].summary == "New shiny summary."
|
||||
|
||||
|
||||
async def test_missing_user_id_raises(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
bad_dir = memory_root.root / "users" / "u_x"
|
||||
bad_dir.mkdir(parents=True, exist_ok=True)
|
||||
(bad_dir / "user.md").write_text(
|
||||
"---\n"
|
||||
"id: user_profile_u_x\n"
|
||||
"type: user_profile\n"
|
||||
"track: user\n"
|
||||
"summary: x\n"
|
||||
"---\nbody\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="user_id"):
|
||||
await _handler(memory_root).handle_added_or_modified("users/u_x/user.md")
|
||||
|
||||
|
||||
async def test_handle_deleted_drops_row(
|
||||
memory_root: MemoryRoot, fake_repo: _FakeProfileRepo
|
||||
) -> None:
|
||||
md_path = await _write_profile(
|
||||
memory_root,
|
||||
"u_alice",
|
||||
summary="bye",
|
||||
explicit_info=[],
|
||||
implicit_traits=[],
|
||||
)
|
||||
handler = _handler(memory_root)
|
||||
await handler.handle_added_or_modified(md_path)
|
||||
assert "u_alice" in fake_repo.rows
|
||||
|
||||
outcome = await handler.handle_deleted(md_path)
|
||||
assert outcome.deleted == 1
|
||||
assert fake_repo.deletes == [md_path]
|
||||
assert "u_alice" not in fake_repo.rows
|
||||
@ -0,0 +1,261 @@
|
||||
"""Per-kind ``_build_row`` mapping for the 3 non-Episode daily-log handlers.
|
||||
|
||||
The diff loop (read → sha256 → 3-way diff → upsert/delete) lives on
|
||||
:class:`BaseDailyLogHandler` and is exercised by
|
||||
``test_handler_episode.py``. These tests focus on the kind-specific
|
||||
:meth:`_build_row` mapping — given a synthesised ``ParsedEntry``, do
|
||||
the right LanceDB columns get populated?
|
||||
|
||||
Each kind gets one happy-path test (all fields present) plus a
|
||||
focused error-path test (missing required inline field). Sharing one
|
||||
file avoids 3 nearly-identical fixture stacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.core.persistence import MemoryRoot, StructuredEntry
|
||||
from everos.memory.cascade.handlers import (
|
||||
AgentCaseHandler,
|
||||
AtomicFactHandler,
|
||||
ForesightHandler,
|
||||
HandlerDeps,
|
||||
)
|
||||
from everos.memory.cascade.handlers._daily_log_base import ParsedEntry
|
||||
|
||||
|
||||
class _StubTokenizer(Tokenizer):
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return [tok for tok in text.split() if tok]
|
||||
|
||||
def tokenize_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [self.tokenize(t) for t in texts]
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
dim = 1024
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
return [0.0] * self.dim
|
||||
|
||||
async def embed_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [await self.embed(t) for t in texts]
|
||||
|
||||
|
||||
def _deps(tmp_path) -> HandlerDeps: # type: ignore[no-untyped-def]
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
return HandlerDeps(
|
||||
memory_root=mr,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
|
||||
|
||||
def _entry(
|
||||
entry_id: str,
|
||||
inline: dict[str, str],
|
||||
sections: dict[str, str],
|
||||
*,
|
||||
sha: str = "f" * 64,
|
||||
) -> ParsedEntry:
|
||||
return ParsedEntry(
|
||||
entry_id=entry_id,
|
||||
structured=StructuredEntry(
|
||||
id=entry_id,
|
||||
body="",
|
||||
start=0,
|
||||
end=0,
|
||||
header=None,
|
||||
inline=inline,
|
||||
sections=sections,
|
||||
),
|
||||
content_sha256=sha,
|
||||
)
|
||||
|
||||
|
||||
# ── AtomicFact ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_atomic_fact_build_row_maps_inline_and_section(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = AtomicFactHandler(_deps(tmp_path))
|
||||
row = await handler._build_row(
|
||||
owner_id="u1",
|
||||
owner_type="user",
|
||||
md_path="users/u1/.atomic_facts/atomic_fact-2026-05-14.md",
|
||||
entry=_entry(
|
||||
"af_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "u1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
"sender_ids": "[u1, u2]",
|
||||
},
|
||||
sections={"Fact": "the user prefers dark mode"},
|
||||
),
|
||||
)
|
||||
assert row.id == "u1_af_20260514_0001"
|
||||
assert row.fact == "the user prefers dark mode"
|
||||
assert row.fact_tokens == "the user prefers dark mode"
|
||||
assert row.parent_id == "mc_1"
|
||||
assert row.sender_ids == ["u1", "u2"]
|
||||
assert row.timestamp == _dt.datetime(2026, 5, 14, 10, 0, tzinfo=_dt.UTC)
|
||||
assert row.md_path.endswith("atomic_fact-2026-05-14.md")
|
||||
assert len(row.vector) == 1024
|
||||
|
||||
|
||||
async def test_atomic_fact_missing_timestamp_raises(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = AtomicFactHandler(_deps(tmp_path))
|
||||
with pytest.raises(ValueError, match="timestamp"):
|
||||
await handler._build_row(
|
||||
owner_id="u1",
|
||||
owner_type="user",
|
||||
md_path="x.md",
|
||||
entry=_entry(
|
||||
"af_20260514_0001",
|
||||
inline={"owner_id": "u1", "session_id": "s1"},
|
||||
sections={"Fact": "x"},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ── Foresight ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_foresight_build_row_with_evidence(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = ForesightHandler(_deps(tmp_path))
|
||||
row = await handler._build_row(
|
||||
owner_id="u1",
|
||||
owner_type="user",
|
||||
md_path="users/u1/.foresights/foresight-2026-05-14.md",
|
||||
entry=_entry(
|
||||
"fs_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "u1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
"start_time": "2026-05-14T11:00:00+00:00",
|
||||
"end_time": "2026-05-14T13:00:00+00:00",
|
||||
"duration_days": "2",
|
||||
},
|
||||
sections={
|
||||
"Foresight": "user will book lunch",
|
||||
"Evidence": "calendar invite mentions 12pm",
|
||||
},
|
||||
),
|
||||
)
|
||||
assert row.foresight == "user will book lunch"
|
||||
assert row.foresight_tokens == "user will book lunch"
|
||||
assert row.evidence == "calendar invite mentions 12pm"
|
||||
assert row.evidence_tokens == "calendar invite mentions 12pm"
|
||||
assert row.start_time == _dt.datetime(2026, 5, 14, 11, 0, tzinfo=_dt.UTC)
|
||||
assert row.end_time == _dt.datetime(2026, 5, 14, 13, 0, tzinfo=_dt.UTC)
|
||||
assert row.duration_days == 2
|
||||
|
||||
|
||||
async def test_foresight_optional_evidence_left_none(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = ForesightHandler(_deps(tmp_path))
|
||||
row = await handler._build_row(
|
||||
owner_id="u1",
|
||||
owner_type="user",
|
||||
md_path="x.md",
|
||||
entry=_entry(
|
||||
"fs_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "u1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
},
|
||||
sections={"Foresight": "user will book lunch"},
|
||||
),
|
||||
)
|
||||
assert row.evidence is None
|
||||
assert row.evidence_tokens is None
|
||||
assert row.start_time is None
|
||||
assert row.end_time is None
|
||||
assert row.duration_days is None
|
||||
|
||||
|
||||
# ── AgentCase ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agent_case_build_row_maps_intent_approach_insight(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = AgentCaseHandler(_deps(tmp_path))
|
||||
row = await handler._build_row(
|
||||
owner_id="a1",
|
||||
owner_type="agent",
|
||||
md_path="agents/a1/.cases/agent_case-2026-05-14.md",
|
||||
entry=_entry(
|
||||
"ac_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "a1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
"quality_score": "0.87",
|
||||
},
|
||||
sections={
|
||||
"TaskIntent": "scan contract for risk clauses",
|
||||
"Approach": "1. read pages 1-5; 2. flag indemnity",
|
||||
"KeyInsight": "indemnity cap missing",
|
||||
},
|
||||
),
|
||||
)
|
||||
assert row.task_intent == "scan contract for risk clauses"
|
||||
assert row.task_intent_tokens == "scan contract for risk clauses"
|
||||
assert row.approach.startswith("1. read pages")
|
||||
assert row.key_insight == "indemnity cap missing"
|
||||
assert row.quality_score == pytest.approx(0.87)
|
||||
assert row.owner_type == "agent"
|
||||
|
||||
|
||||
async def test_agent_case_optional_insight_left_none(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = AgentCaseHandler(_deps(tmp_path))
|
||||
row = await handler._build_row(
|
||||
owner_id="a1",
|
||||
owner_type="agent",
|
||||
md_path="x.md",
|
||||
entry=_entry(
|
||||
"ac_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "a1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
"quality_score": "0.5",
|
||||
},
|
||||
sections={
|
||||
"TaskIntent": "x",
|
||||
"Approach": "y",
|
||||
},
|
||||
),
|
||||
)
|
||||
assert row.key_insight is None
|
||||
|
||||
|
||||
async def test_agent_case_missing_quality_score_raises(tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
handler = AgentCaseHandler(_deps(tmp_path))
|
||||
with pytest.raises(ValueError, match="quality_score"):
|
||||
await handler._build_row(
|
||||
owner_id="a1",
|
||||
owner_type="agent",
|
||||
md_path="x.md",
|
||||
entry=_entry(
|
||||
"ac_20260514_0001",
|
||||
inline={
|
||||
"owner_id": "a1",
|
||||
"session_id": "s1",
|
||||
"timestamp": "2026-05-14T10:00:00+00:00",
|
||||
"parent_id": "mc_1",
|
||||
},
|
||||
sections={"TaskIntent": "x", "Approach": "y"},
|
||||
),
|
||||
)
|
||||
106
tests/unit/test_memory/test_cascade/test_orchestrator.py
Normal file
106
tests/unit/test_memory/test_cascade/test_orchestrator.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""``CascadeOrchestrator`` — idempotent start/stop, queue_summary forwards."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import build_tokenizer
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.persistence.lancedb import (
|
||||
dispose_connection,
|
||||
ensure_business_indexes,
|
||||
)
|
||||
from everos.infra.persistence.sqlite import dispose_engine, get_engine
|
||||
from everos.memory.cascade import CascadeConfig, CascadeOrchestrator
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
dim = 1024
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
return [0.0] * self.dim
|
||||
|
||||
async def embed_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [[0.0] * self.dim for _ in texts]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def runtime(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> AsyncIterator[MemoryRoot]:
|
||||
"""Boot sqlite + lancedb against a tmp memory_root."""
|
||||
monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
|
||||
monkeypatch.setenv("EVEROS_EMBEDDING__MODEL", "stub-model")
|
||||
monkeypatch.setenv("EVEROS_EMBEDDING__BASE_URL", "http://stub.invalid/v1")
|
||||
monkeypatch.setenv("EVEROS_EMBEDDING__API_KEY", "stub-key")
|
||||
|
||||
await dispose_connection()
|
||||
await dispose_engine()
|
||||
engine = get_engine()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await ensure_business_indexes()
|
||||
yield MemoryRoot.default()
|
||||
await dispose_connection()
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
def _make_orchestrator(memory_root: MemoryRoot) -> CascadeOrchestrator:
|
||||
return CascadeOrchestrator(
|
||||
memory_root=memory_root,
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=build_tokenizer(),
|
||||
config=CascadeConfig(
|
||||
scan_interval_seconds=60.0,
|
||||
worker_batch_size=10,
|
||||
worker_max_retry=1,
|
||||
worker_poll_interval_seconds=0.05,
|
||||
worker_retry_backoff_seconds=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_double_start_is_idempotent(runtime: MemoryRoot) -> None:
|
||||
"""Calling start twice does not relaunch tasks."""
|
||||
orch = _make_orchestrator(runtime)
|
||||
await orch.start()
|
||||
# Capture watcher identity to verify the second start doesn't replace it.
|
||||
first_watcher = orch._watcher
|
||||
await orch.start()
|
||||
assert orch._watcher is first_watcher
|
||||
await orch.stop()
|
||||
|
||||
|
||||
async def test_stop_before_start_is_noop(runtime: MemoryRoot) -> None:
|
||||
orch = _make_orchestrator(runtime)
|
||||
await orch.stop() # must not raise; nothing to do
|
||||
|
||||
|
||||
async def test_double_stop_is_idempotent(runtime: MemoryRoot) -> None:
|
||||
orch = _make_orchestrator(runtime)
|
||||
await orch.start()
|
||||
await orch.stop()
|
||||
await orch.stop() # second stop is a no-op
|
||||
|
||||
|
||||
async def test_queue_summary_returns_empty_on_fresh_runtime(
|
||||
runtime: MemoryRoot,
|
||||
) -> None:
|
||||
orch = _make_orchestrator(runtime)
|
||||
summary = await orch.queue_summary()
|
||||
assert summary.pending == 0
|
||||
assert summary.done == 0
|
||||
assert summary.failed_retryable == 0
|
||||
assert summary.failed_permanent == 0
|
||||
|
||||
|
||||
async def test_drain_once_returns_zero_on_empty_queue(
|
||||
runtime: MemoryRoot,
|
||||
) -> None:
|
||||
orch = _make_orchestrator(runtime)
|
||||
assert await orch.drain_once() == 0
|
||||
137
tests/unit/test_memory/test_cascade/test_reconciler.py
Normal file
137
tests/unit/test_memory/test_cascade/test_reconciler.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Tests for :func:`reconcile` — pure scan vs state diff.
|
||||
|
||||
The reconciler is pure (no IO), so each scenario is just a few
|
||||
dataclass instances in / decisions out. Covers the 4 cases:
|
||||
``added`` / ``modified`` / ``deleted`` / ``no-op``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from everos.memory.cascade.reconciler import PriorState, reconcile
|
||||
from everos.memory.cascade.types import ScanInput
|
||||
|
||||
|
||||
def _scan(path: str, mtime: float = 1.0, kind: str = "episode") -> ScanInput:
|
||||
return ScanInput(md_path=path, mtime=mtime, kind=kind)
|
||||
|
||||
|
||||
def _state(
|
||||
path: str,
|
||||
*,
|
||||
mtime: float = 1.0,
|
||||
kind: str = "episode",
|
||||
status: str = "done",
|
||||
change_type: str = "modified",
|
||||
) -> PriorState:
|
||||
return PriorState(
|
||||
md_path=path,
|
||||
kind=kind,
|
||||
mtime=mtime,
|
||||
status=status,
|
||||
change_type=change_type,
|
||||
)
|
||||
|
||||
|
||||
def test_added_path_emits_added_decision() -> None:
|
||||
decisions = reconcile([_scan("a.md")], state={})
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "added")]
|
||||
|
||||
|
||||
def test_modified_mtime_emits_modified_decision() -> None:
|
||||
decisions = reconcile(
|
||||
[_scan("a.md", mtime=2.0)],
|
||||
state={"a.md": _state("a.md", mtime=1.0)},
|
||||
)
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "modified")]
|
||||
|
||||
|
||||
def test_done_state_with_matching_mtime_is_skipped() -> None:
|
||||
"""Quiet sweeps must stay quiet — no upsert churn."""
|
||||
decisions = reconcile(
|
||||
[_scan("a.md", mtime=1.0)],
|
||||
state={"a.md": _state("a.md", mtime=1.0, status="done")},
|
||||
)
|
||||
assert decisions == []
|
||||
|
||||
|
||||
def test_pending_state_with_matching_mtime_still_emits_modified() -> None:
|
||||
"""Pending / failed states are NOT terminal — re-emit so worker re-runs."""
|
||||
decisions = reconcile(
|
||||
[_scan("a.md", mtime=1.0)],
|
||||
state={"a.md": _state("a.md", mtime=1.0, status="pending")},
|
||||
)
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "modified")]
|
||||
|
||||
|
||||
def test_deleted_path_emits_deleted_decision() -> None:
|
||||
decisions = reconcile(
|
||||
[],
|
||||
state={"a.md": _state("a.md", status="pending")},
|
||||
)
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "deleted")]
|
||||
|
||||
|
||||
def test_deleted_path_already_done_as_delete_is_skipped() -> None:
|
||||
"""A done row that is itself a successful delete cycle — don't re-emit."""
|
||||
decisions = reconcile(
|
||||
[],
|
||||
state={
|
||||
"a.md": _state("a.md", status="done", change_type="deleted"),
|
||||
},
|
||||
)
|
||||
assert decisions == []
|
||||
|
||||
|
||||
def test_done_added_row_with_missing_path_is_recovered_as_deleted() -> None:
|
||||
"""Watcher missed an unlink (e.g. fseventsd drop / daemon restart).
|
||||
|
||||
The state row is ``status='done'`` from the previous add cycle, but
|
||||
the file is gone from disk. The scanner MUST re-emit a 'deleted'
|
||||
decision — otherwise LanceDB keeps stale rows for the orphan path
|
||||
until something else triggers an enqueue.
|
||||
"""
|
||||
decisions = reconcile(
|
||||
[],
|
||||
state={
|
||||
"a.md": _state("a.md", status="done", change_type="added"),
|
||||
},
|
||||
)
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "deleted")]
|
||||
|
||||
|
||||
def test_done_modified_row_with_missing_path_is_recovered_as_deleted() -> None:
|
||||
"""Same as the added variant, but the prior cycle was a modification."""
|
||||
decisions = reconcile(
|
||||
[],
|
||||
state={
|
||||
"a.md": _state("a.md", status="done", change_type="modified"),
|
||||
},
|
||||
)
|
||||
assert [(d.md_path, d.change_type) for d in decisions] == [("a.md", "deleted")]
|
||||
|
||||
|
||||
def test_mixed_scenario_preserves_order() -> None:
|
||||
decisions = reconcile(
|
||||
[
|
||||
_scan("new.md"),
|
||||
_scan("changed.md", mtime=2.0),
|
||||
_scan("unchanged.md", mtime=1.0),
|
||||
],
|
||||
state={
|
||||
"changed.md": _state(
|
||||
"changed.md", mtime=1.0, status="done", change_type="modified"
|
||||
),
|
||||
"unchanged.md": _state(
|
||||
"unchanged.md", mtime=1.0, status="done", change_type="modified"
|
||||
),
|
||||
"gone.md": _state("gone.md", status="pending", change_type="modified"),
|
||||
},
|
||||
)
|
||||
by_path = {d.md_path: d.change_type for d in decisions}
|
||||
assert by_path == {
|
||||
"new.md": "added",
|
||||
"changed.md": "modified",
|
||||
"gone.md": "deleted",
|
||||
}
|
||||
# Order: added/modified in scan order, deleted at the tail.
|
||||
assert decisions[-1].md_path == "gone.md"
|
||||
83
tests/unit/test_memory/test_cascade/test_registry.py
Normal file
83
tests/unit/test_memory/test_cascade/test_registry.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""Tests for the cascade kind registry.
|
||||
|
||||
Verify the 5 registered kinds' globs match the right paths and reject
|
||||
noise (random ``.md``, swp files, profile-style paths). ``match_kind``
|
||||
must walk the registry in declared order and pick the first matching
|
||||
spec.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.cascade import KIND_REGISTRY, match_kind
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path", "expected_kind"),
|
||||
[
|
||||
(
|
||||
"default_app/default_project/users/u1/episodes/episode-2026-05-14.md",
|
||||
"episode",
|
||||
),
|
||||
("claude_code/oss/users/u_jason/episodes/episode-2026-01-01.md", "episode"),
|
||||
(
|
||||
"default_app/default_project/users/u1/.atomic_facts/atomic_fact-2026-05-14.md",
|
||||
"atomic_fact",
|
||||
),
|
||||
(
|
||||
"default_app/default_project/users/u1/.foresights/foresight-2026-05-14.md",
|
||||
"foresight",
|
||||
),
|
||||
(
|
||||
"default_app/default_project/agents/a1/.cases/agent_case-2026-05-14.md",
|
||||
"agent_case",
|
||||
),
|
||||
(
|
||||
"default_app/default_project/agents/a1/skills/skill_contract_risk_scan/SKILL.md",
|
||||
"agent_skill",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_kind_recognises_registered_paths(path: str, expected_kind: str) -> None:
|
||||
spec = match_kind(path)
|
||||
assert spec is not None
|
||||
assert spec.name == expected_kind
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"users/u1/profile/user.md",
|
||||
"users/u1/random.md",
|
||||
"users/u1/episodes/draft.txt", # wrong extension
|
||||
".cache/foo.md",
|
||||
"users/u1/episodes/episode-2026-05-14.md.swp", # swap file
|
||||
"agents/a1/skills/skill_x/references/notes.md", # reference, not main
|
||||
# Valid episode shape but MISSING the <app>/<project> prefix — must be
|
||||
# rejected so a prefix-less path can never silently match (the scanner
|
||||
# would otherwise find nothing while the watcher matched, a split brain).
|
||||
"users/u1/episodes/episode-2026-05-14.md",
|
||||
],
|
||||
)
|
||||
def test_match_kind_rejects_unregistered_paths(path: str) -> None:
|
||||
assert match_kind(path) is None
|
||||
|
||||
|
||||
def test_registry_has_exactly_six_kinds() -> None:
|
||||
"""The registry pins the cascade surface — no silent registration."""
|
||||
names = [s.name for s in KIND_REGISTRY]
|
||||
assert names == [
|
||||
"episode",
|
||||
"atomic_fact",
|
||||
"foresight",
|
||||
"agent_case",
|
||||
"agent_skill",
|
||||
"user_profile",
|
||||
]
|
||||
|
||||
|
||||
def test_kind_spec_path_glob_reads_off_schema() -> None:
|
||||
"""Path glob is owned by the frontmatter schema, not duplicated here."""
|
||||
for spec in KIND_REGISTRY:
|
||||
assert spec.path_glob() == spec.frontmatter_schema.path_glob()
|
||||
127
tests/unit/test_memory/test_cascade/test_scanner_unit.py
Normal file
127
tests/unit/test_memory/test_cascade/test_scanner_unit.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Unit tests for :class:`CascadeScanner` lifecycle + ``_collect_scan_inputs``.
|
||||
|
||||
The reconcile-against-state flow is integration territory; this file
|
||||
covers the no-real-DB-needed pieces: idempotent start/stop and the
|
||||
sync-thread walker's resilience to broken files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.memory.cascade.scanner import CascadeScanner, _collect_scan_inputs
|
||||
|
||||
|
||||
async def test_double_start_is_idempotent(tmp_path: Path) -> None:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
scanner = CascadeScanner(mr, scan_interval_seconds=60.0)
|
||||
await scanner.start()
|
||||
first_task = scanner._task
|
||||
await scanner.start() # second start: no-op
|
||||
assert scanner._task is first_task
|
||||
await scanner.stop()
|
||||
|
||||
|
||||
async def test_stop_before_start_is_noop(tmp_path: Path) -> None:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
scanner = CascadeScanner(mr, scan_interval_seconds=60.0)
|
||||
await scanner.stop() # must not raise
|
||||
|
||||
|
||||
async def test_double_stop_is_idempotent(tmp_path: Path) -> None:
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
scanner = CascadeScanner(mr, scan_interval_seconds=60.0)
|
||||
await scanner.start()
|
||||
await scanner.stop()
|
||||
await scanner.stop() # second stop: no-op
|
||||
|
||||
|
||||
def test_collect_scan_inputs_skips_dangling_symlinks(tmp_path: Path) -> None:
|
||||
"""A symlink whose target was deleted yields ``stat`` OSError → skipped."""
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
# Build a real .md under a registered kind path (with the <app>/<project>
|
||||
# scope prefix the glob requires), then add a broken symlink next to it to
|
||||
# exercise the OSError branch.
|
||||
user_dir = (
|
||||
tmp_path / "default_app" / "default_project" / "users" / "u1" / "episodes"
|
||||
)
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
real = user_dir / "episode-2026-01-01.md"
|
||||
real.write_text("ok")
|
||||
broken = user_dir / "episode-2026-01-02.md"
|
||||
target = tmp_path / "deleted-target"
|
||||
target.write_text("temp")
|
||||
broken.symlink_to(target)
|
||||
target.unlink() # Now ``broken`` is a dangling symlink.
|
||||
|
||||
inputs = _collect_scan_inputs(tmp_path)
|
||||
paths = {i.md_path for i in inputs}
|
||||
assert real.relative_to(tmp_path).as_posix() in paths
|
||||
# Dangling symlink was silently skipped.
|
||||
assert broken.relative_to(tmp_path).as_posix() not in paths
|
||||
|
||||
|
||||
def test_collect_scan_inputs_raises_on_transient_oserror(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Non-ENOENT stat errors (EMFILE / EACCES / EIO) must propagate.
|
||||
|
||||
Regression guard for the 2026-05-28 incident where FD exhaustion
|
||||
during a scan made every healthy md look "deleted" to reconcile().
|
||||
The fix in ``_collect_scan_inputs`` swallows only ``FileNotFoundError``
|
||||
and re-raises any other ``OSError`` so the reconciler never sees a
|
||||
partial scan.
|
||||
"""
|
||||
user_dir = (
|
||||
tmp_path / "default_app" / "default_project" / "users" / "u1" / "episodes"
|
||||
)
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
real = user_dir / "episode-2026-01-01.md"
|
||||
real.write_text("ok")
|
||||
|
||||
real_stat = Path.stat
|
||||
|
||||
def boom_stat(self: Path, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
# Only fail on the .md file — let glob / directory walks succeed.
|
||||
if self.suffix == ".md":
|
||||
raise OSError(24, "Too many open files")
|
||||
return real_stat(self, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(Path, "stat", boom_stat)
|
||||
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
_collect_scan_inputs(tmp_path)
|
||||
# errno 24 = EMFILE on every POSIX system we care about.
|
||||
assert exc_info.value.errno == 24
|
||||
|
||||
|
||||
async def test_run_loop_swallows_scan_exception(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A failure in ``scan_once`` is logged but the loop keeps going."""
|
||||
mr = MemoryRoot(tmp_path)
|
||||
mr.ensure()
|
||||
scanner = CascadeScanner(mr, scan_interval_seconds=0.05)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def fake_scan() -> list: # type: ignore[type-arg]
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise RuntimeError("simulated scanner failure")
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(scanner, "scan_once", fake_scan)
|
||||
await scanner.start()
|
||||
# Let the loop iterate at least twice (interval is 50ms).
|
||||
await asyncio.sleep(0.2)
|
||||
await scanner.stop()
|
||||
assert call_count["n"] >= 2 # second call ran despite first throwing
|
||||
36
tests/unit/test_memory/test_cascade/test_watcher_helpers.py
Normal file
36
tests/unit/test_memory/test_cascade/test_watcher_helpers.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Unit tests for the pure helpers in :mod:`everos.memory.cascade.watcher`.
|
||||
|
||||
The :class:`CascadeWatcher` itself needs a running event loop + real
|
||||
filesystem to test end-to-end (see ``tests/integration/``). The pure
|
||||
helpers can be exercised in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from everos.memory.cascade.watcher import _relative_to_root, _safe_mtime
|
||||
|
||||
|
||||
def test_relative_to_root_within(tmp_path: Path) -> None:
|
||||
target = tmp_path / "users" / "u1" / "x.md"
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text("x")
|
||||
assert _relative_to_root(tmp_path, str(target)) == "users/u1/x.md"
|
||||
|
||||
|
||||
def test_relative_to_root_outside(tmp_path: Path) -> None:
|
||||
"""A path outside the memory root returns ``None``."""
|
||||
outside = tmp_path.parent / "completely-different" / "y.md"
|
||||
assert _relative_to_root(tmp_path, str(outside)) is None
|
||||
|
||||
|
||||
def test_safe_mtime_missing_path_returns_zero(tmp_path: Path) -> None:
|
||||
missing = tmp_path / "does-not-exist.md"
|
||||
assert _safe_mtime(str(missing)) == 0.0
|
||||
|
||||
|
||||
def test_safe_mtime_existing_path_returns_positive(tmp_path: Path) -> None:
|
||||
f = tmp_path / "f.md"
|
||||
f.write_text("ok")
|
||||
assert _safe_mtime(str(f)) > 0
|
||||
573
tests/unit/test_memory/test_cascade/test_worker.py
Normal file
573
tests/unit/test_memory/test_cascade/test_worker.py
Normal file
@ -0,0 +1,573 @@
|
||||
"""Tests for :class:`CascadeWorker` retry classification + optimize scheduler.
|
||||
|
||||
The pure-function pieces (registry / reconciler) get coverage in
|
||||
their own files. Here we focus on the worker's branch behaviour
|
||||
without touching the real handler / lancedb stack:
|
||||
|
||||
- ``RecoverableError`` retries up to ``max_retry`` and then marks
|
||||
``retryable=TRUE``.
|
||||
- Any other exception marks ``retryable=FALSE`` immediately.
|
||||
- Successful handler ⇒ ``mark_done``.
|
||||
- Unknown kind ⇒ ``mark_failed(retryable=False)``.
|
||||
|
||||
A second group covers the per-kind throttle + trailing-edge
|
||||
optimize scheduler that fires LanceDB ``optimize()`` outside the
|
||||
drain loop — coalescing under burst writes, re-running when dirty
|
||||
is re-raised mid-optimize, and flushing on drain-until-empty / stop.
|
||||
|
||||
The repo singleton is monkey-patched onto a recording fake so the
|
||||
test stays in-memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import time
|
||||
import unittest.mock as mock
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.cascade.errors import RecoverableError, UnrecoverableError
|
||||
from everos.memory.cascade.handlers import Handler, HandlerDeps
|
||||
from everos.memory.cascade.types import HandlerOutcome
|
||||
from everos.memory.cascade.worker import CascadeWorker
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Row:
|
||||
"""Minimal MdChangeState shape the worker reads off."""
|
||||
|
||||
md_path: str
|
||||
kind: str = "episode"
|
||||
change_type: str = "added"
|
||||
retry_count: int = 0
|
||||
|
||||
|
||||
class _FakeRepo:
|
||||
"""Records every state-machine transition the worker drives."""
|
||||
|
||||
def __init__(self, batch: list[_Row]) -> None:
|
||||
self.batch = list(batch)
|
||||
self.done: list[str] = []
|
||||
self.failed: list[tuple[str, bool, str, int]] = []
|
||||
|
||||
async def claim_pending_batch(self, _limit: int) -> list[_Row]:
|
||||
items, self.batch = self.batch, []
|
||||
return items
|
||||
|
||||
async def mark_done(self, md_path: str) -> None:
|
||||
self.done.append(md_path)
|
||||
|
||||
async def mark_failed(
|
||||
self,
|
||||
md_path: str,
|
||||
*,
|
||||
retryable: bool,
|
||||
error: str,
|
||||
new_retry_count: int,
|
||||
) -> None:
|
||||
self.failed.append((md_path, retryable, error, new_retry_count))
|
||||
|
||||
|
||||
class _OkHandler(Handler):
|
||||
def __init__(self) -> None: # noqa: D401 — no deps needed
|
||||
pass
|
||||
|
||||
async def handle_added_or_modified(self, md_path: str) -> HandlerOutcome:
|
||||
return HandlerOutcome(
|
||||
md_path=md_path, kind="episode", upserted=1, deleted=0, skipped=0
|
||||
)
|
||||
|
||||
async def handle_deleted(self, md_path: str) -> HandlerOutcome:
|
||||
return HandlerOutcome(
|
||||
md_path=md_path, kind="episode", upserted=0, deleted=1, skipped=0
|
||||
)
|
||||
|
||||
|
||||
class _RecoverableHandler(_OkHandler):
|
||||
"""Always raises RecoverableError."""
|
||||
|
||||
async def handle_added_or_modified(self, md_path: str) -> HandlerOutcome:
|
||||
raise RecoverableError("embedding 503")
|
||||
|
||||
|
||||
class _UnrecoverableHandler(_OkHandler):
|
||||
async def handle_added_or_modified(self, md_path: str) -> HandlerOutcome:
|
||||
raise UnrecoverableError("YAML parse error")
|
||||
|
||||
|
||||
class _BareExceptionHandler(_OkHandler):
|
||||
async def handle_added_or_modified(self, md_path: str) -> HandlerOutcome:
|
||||
raise RuntimeError("unexpected boom")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_repo(monkeypatch: pytest.MonkeyPatch) -> _FakeRepo:
|
||||
"""Drop a fake repo onto the module the worker imports."""
|
||||
from everos.memory.cascade import worker as worker_mod
|
||||
|
||||
repo = _FakeRepo(batch=[])
|
||||
monkeypatch.setattr(worker_mod, "md_change_state_repo", repo)
|
||||
return repo
|
||||
|
||||
|
||||
async def test_ok_handler_marks_done(patched_repo: _FakeRepo) -> None:
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker({"episode": _OkHandler()}, retry_backoff_seconds=0)
|
||||
await w.drain_once()
|
||||
assert patched_repo.done == ["a.md"]
|
||||
assert patched_repo.failed == []
|
||||
|
||||
|
||||
async def test_recoverable_handler_marks_retryable_after_max_retry(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker(
|
||||
{"episode": _RecoverableHandler()}, max_retry=2, retry_backoff_seconds=0
|
||||
)
|
||||
await w.drain_once()
|
||||
assert patched_repo.done == []
|
||||
assert len(patched_repo.failed) == 1
|
||||
path, retryable, _err, retry_count = patched_repo.failed[0]
|
||||
assert path == "a.md"
|
||||
assert retryable is True
|
||||
assert retry_count == 2 # 2 retries after the initial attempt
|
||||
|
||||
|
||||
async def test_unrecoverable_handler_marks_permanent(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker({"episode": _UnrecoverableHandler()}, retry_backoff_seconds=0)
|
||||
await w.drain_once()
|
||||
_path, retryable, err, _retry = patched_repo.failed[0]
|
||||
assert retryable is False
|
||||
assert "UnrecoverableError" in err or "YAML parse error" in err
|
||||
|
||||
|
||||
async def test_bare_exception_marked_permanent(patched_repo: _FakeRepo) -> None:
|
||||
"""Anything that isn't RecoverableError counts as unrecoverable."""
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker({"episode": _BareExceptionHandler()}, retry_backoff_seconds=0)
|
||||
await w.drain_once()
|
||||
_path, retryable, _err, _retry = patched_repo.failed[0]
|
||||
assert retryable is False
|
||||
|
||||
|
||||
async def test_unknown_kind_marks_permanent_without_handler(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
patched_repo.batch = [_Row(md_path="a.md", kind="mystery")]
|
||||
w = CascadeWorker({"episode": _OkHandler()}, retry_backoff_seconds=0)
|
||||
await w.drain_once()
|
||||
assert patched_repo.failed[0][1] is False
|
||||
assert "no handler" in patched_repo.failed[0][2]
|
||||
|
||||
|
||||
async def test_drain_until_empty_loops_until_no_batch(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""Worker keeps draining until claim returns an empty list."""
|
||||
|
||||
rows = [_Row(md_path=f"a{i}.md") for i in range(3)]
|
||||
|
||||
class _ChunkedRepo(_FakeRepo):
|
||||
async def claim_pending_batch(self, _limit: int) -> list[_Row]:
|
||||
if not self.batch:
|
||||
return []
|
||||
head, self.batch = self.batch[:1], self.batch[1:]
|
||||
return head
|
||||
|
||||
chunked = _ChunkedRepo(rows)
|
||||
from everos.memory.cascade import worker as worker_mod
|
||||
|
||||
with mock.patch.object(worker_mod, "md_change_state_repo", chunked):
|
||||
w = CascadeWorker({"episode": _OkHandler()}, retry_backoff_seconds=0)
|
||||
total = await w.drain_until_empty()
|
||||
assert total == 3
|
||||
assert len(chunked.done) == 3
|
||||
|
||||
|
||||
def test_worker_handler_deps_construct_with_real_classes() -> None:
|
||||
"""Sanity: HandlerDeps accepts the real provider Protocols."""
|
||||
# No instantiation needed — just verifies the dataclass shape.
|
||||
assert {"memory_root", "embedder", "tokenizer"} == {
|
||||
f.name for f in HandlerDeps.__dataclass_fields__.values()
|
||||
}
|
||||
|
||||
|
||||
# ── Optimize scheduler tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeLanceRepo:
|
||||
"""Records every optimize() / rebuild_indexes() call.
|
||||
|
||||
``optimize_delay`` / ``rebuild_delay`` simulate slow operations.
|
||||
``rebuild_raises`` makes ``rebuild_indexes`` raise (for crash-safety tests).
|
||||
Each ``optimize`` call's ``cleanup_older_than`` is preserved so
|
||||
prune-cadence tests can assert which calls took the heavy path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
optimize_delay: float = 0.0,
|
||||
rebuild_delay: float = 0.0,
|
||||
rebuild_raises: bool = False,
|
||||
) -> None:
|
||||
self.optimize_calls: list[float] = []
|
||||
self.optimize_cleanup_args: list[dt.timedelta | None] = []
|
||||
self.rebuild_calls: list[float] = []
|
||||
self.optimize_delay = optimize_delay
|
||||
self.rebuild_delay = rebuild_delay
|
||||
self.rebuild_raises = rebuild_raises
|
||||
|
||||
async def optimize(self, *, cleanup_older_than: dt.timedelta | None = None) -> None:
|
||||
if self.optimize_delay > 0:
|
||||
await asyncio.sleep(self.optimize_delay)
|
||||
self.optimize_calls.append(time.monotonic())
|
||||
self.optimize_cleanup_args.append(cleanup_older_than)
|
||||
|
||||
async def rebuild_indexes(self) -> None:
|
||||
if self.rebuild_delay > 0:
|
||||
await asyncio.sleep(self.rebuild_delay)
|
||||
if self.rebuild_raises:
|
||||
raise RuntimeError("rebuild boom")
|
||||
self.rebuild_calls.append(time.monotonic())
|
||||
|
||||
|
||||
class _OkHandlerWithRepo(_OkHandler):
|
||||
"""OK handler exposing a fake ``lance_repo`` for scheduler tests."""
|
||||
|
||||
def __init__(self, repo: _FakeLanceRepo) -> None:
|
||||
super().__init__()
|
||||
self.lance_repo = repo
|
||||
|
||||
|
||||
async def test_schedule_optimize_noop_when_handler_has_no_lance_repo(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""Test stubs without ``lance_repo`` should not even register state."""
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandler()},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.05,
|
||||
)
|
||||
w._schedule_optimize("episode")
|
||||
assert "episode" not in w._optimizer_states
|
||||
|
||||
|
||||
async def test_schedule_optimize_collapses_burst_within_throttle_window(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""A burst of synchronous schedules creates at most one in-flight task.
|
||||
|
||||
The first call starts the optimize; subsequent calls during the
|
||||
same window only flip ``dirty``. With no time advance between
|
||||
schedules, the runner sees ``dirty=False`` after the first run
|
||||
and exits — total optimize() calls collapse to one.
|
||||
"""
|
||||
fake = _FakeLanceRepo()
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.05,
|
||||
)
|
||||
for _ in range(10):
|
||||
w._schedule_optimize("episode")
|
||||
await w._flush_optimizers()
|
||||
assert fake.optimize_calls, "expected at least one optimize"
|
||||
assert len(fake.optimize_calls) == 1, (
|
||||
f"burst should collapse, got {len(fake.optimize_calls)} calls"
|
||||
)
|
||||
|
||||
|
||||
async def test_schedule_optimize_reruns_when_dirty_set_during_optimize(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""A write that lands mid-optimize re-raises ``dirty`` and triggers a re-run.
|
||||
|
||||
Uses an artificially slow optimize so the second schedule fires
|
||||
while the first run is still in flight. Trailing-edge semantics
|
||||
guarantee the second run happens after the throttle interval.
|
||||
"""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.05)
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
)
|
||||
w._schedule_optimize("episode")
|
||||
await asyncio.sleep(0.01) # ensure first task is mid-optimize
|
||||
w._schedule_optimize("episode")
|
||||
await w._flush_optimizers()
|
||||
assert len(fake.optimize_calls) == 2
|
||||
|
||||
|
||||
async def test_concurrent_schedules_keep_one_task_per_kind(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""LanceDB manifest contention guard: per-kind in-flight task is unique."""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.05)
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
)
|
||||
w._schedule_optimize("episode")
|
||||
first_task = w._optimizer_states["episode"].task
|
||||
# Re-schedule while first task is still in flight; slot must not
|
||||
# be replaced.
|
||||
for _ in range(5):
|
||||
w._schedule_optimize("episode")
|
||||
assert w._optimizer_states["episode"].task is first_task
|
||||
await w._flush_optimizers()
|
||||
|
||||
|
||||
async def test_flush_optimizers_awaits_pending_task(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""flush_optimizers blocks until in-flight optimize commits and clears slot."""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.05)
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
)
|
||||
w._schedule_optimize("episode")
|
||||
assert w._optimizer_states["episode"].task is not None
|
||||
await w._flush_optimizers()
|
||||
assert fake.optimize_calls, "flush should not return before optimize ran"
|
||||
assert w._optimizer_states["episode"].task is None
|
||||
|
||||
|
||||
async def test_drain_until_empty_flushes_optimizers_before_returning(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""CLI ``cascade sync`` expects FTS to be current when the call returns."""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.03)
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
)
|
||||
await w.drain_until_empty()
|
||||
assert patched_repo.done == ["a.md"]
|
||||
assert len(fake.optimize_calls) == 1
|
||||
assert w._optimizer_states["episode"].task is None
|
||||
|
||||
|
||||
async def test_drain_once_does_not_block_on_optimize(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""drain_once is fire-and-forget — it must return before optimize commits."""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.2)
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
)
|
||||
started = time.monotonic()
|
||||
await w.drain_once()
|
||||
drain_elapsed = time.monotonic() - started
|
||||
# drain returned long before the 0.2s optimize would finish
|
||||
assert drain_elapsed < 0.1, f"drain blocked on optimize: {drain_elapsed:.3f}s"
|
||||
assert not fake.optimize_calls, "optimize should still be in flight"
|
||||
await w._flush_optimizers()
|
||||
assert len(fake.optimize_calls) == 1
|
||||
|
||||
|
||||
async def test_stop_waits_for_in_flight_optimize(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""stop() must give an in-flight optimize a chance to commit cleanly."""
|
||||
fake = _FakeLanceRepo(optimize_delay=0.05)
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
optimize_heartbeat_seconds=10.0,
|
||||
# Park rebuild interval — startup sweep still fires but we wait
|
||||
# for it before testing optimize semantics.
|
||||
optimize_rebuild_interval_seconds=10.0,
|
||||
)
|
||||
await w.start()
|
||||
# Let the startup rebuild sweep complete (instant for the fake repo)
|
||||
# before scheduling optimize — otherwise optimize would queue behind it.
|
||||
await asyncio.sleep(0.02)
|
||||
assert fake.rebuild_calls, "startup rebuild should have fired by now"
|
||||
w._schedule_optimize("episode")
|
||||
await asyncio.sleep(0.01) # let optimize start
|
||||
await w.stop()
|
||||
assert len(fake.optimize_calls) == 1
|
||||
|
||||
|
||||
async def test_optimize_failure_does_not_crash_drain_loop(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""Repo.optimize() raising should be logged but never propagate."""
|
||||
|
||||
class _FailingRepo:
|
||||
async def optimize(self) -> None:
|
||||
raise RuntimeError("simulated lancedb manifest conflict")
|
||||
|
||||
class _HandlerWithFailingRepo(_OkHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lance_repo = _FailingRepo()
|
||||
|
||||
patched_repo.batch = [_Row(md_path="a.md")]
|
||||
w = CascadeWorker(
|
||||
{"episode": _HandlerWithFailingRepo()},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.02,
|
||||
)
|
||||
# If the failure propagated, drain_until_empty would raise.
|
||||
await w.drain_until_empty()
|
||||
assert patched_repo.done == ["a.md"]
|
||||
assert patched_repo.failed == []
|
||||
|
||||
|
||||
async def test_heartbeat_schedules_every_handler_kind(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""The heartbeat sweeps all kinds, even ones nobody wrote to.
|
||||
|
||||
Drives the heartbeat manually via a short interval and asserts
|
||||
that ``optimize`` ran for both kinds at least once.
|
||||
"""
|
||||
fake_a = _FakeLanceRepo()
|
||||
fake_b = _FakeLanceRepo()
|
||||
w = CascadeWorker(
|
||||
{
|
||||
"episode": _OkHandlerWithRepo(fake_a),
|
||||
"atomic_fact": _OkHandlerWithRepo(fake_b),
|
||||
},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
optimize_heartbeat_seconds=0.05,
|
||||
)
|
||||
await w.start()
|
||||
# Let at least one heartbeat tick happen.
|
||||
await asyncio.sleep(0.12)
|
||||
await w.stop()
|
||||
assert fake_a.optimize_calls, "heartbeat should have scheduled episode"
|
||||
assert fake_b.optimize_calls, "heartbeat should have scheduled atomic_fact"
|
||||
|
||||
|
||||
async def test_optimize_prunes_on_first_call_then_throttles(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""First optimize() per kind passes ``cleanup_older_than``; subsequent
|
||||
calls within ``optimize_prune_interval_seconds`` do not.
|
||||
|
||||
Rationale lives in ``DEFAULT_OPTIMIZE_PRUNE_INTERVAL_SECONDS``:
|
||||
LanceDB ``optimize()`` without ``cleanup_older_than`` leaves stale
|
||||
physical files on disk; passing it on every 1-second optimize tick
|
||||
is wasteful, but never passing it leaks files until FDs exhaust.
|
||||
A separate cadence — prune ≪ optimize — balances the two.
|
||||
"""
|
||||
fake = _FakeLanceRepo()
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
optimize_prune_interval_seconds=10.0, # long — second call should NOT prune
|
||||
)
|
||||
# First call: state has never pruned, must include cleanup_older_than.
|
||||
w._schedule_optimize("episode")
|
||||
await w._flush_optimizers()
|
||||
assert len(fake.optimize_calls) == 1
|
||||
assert fake.optimize_cleanup_args[0] is not None, (
|
||||
"first optimize must prune to catch up from prior session"
|
||||
)
|
||||
assert fake.optimize_cleanup_args[0] == dt.timedelta(seconds=10.0)
|
||||
|
||||
# Second call within the prune window: light path (no cleanup).
|
||||
await asyncio.sleep(0.02) # exceed optimize throttle (0.01), not prune (10)
|
||||
w._schedule_optimize("episode")
|
||||
await w._flush_optimizers()
|
||||
assert len(fake.optimize_calls) == 2
|
||||
assert fake.optimize_cleanup_args[1] is None, (
|
||||
"second optimize within prune window should skip cleanup_older_than"
|
||||
)
|
||||
|
||||
|
||||
# ── Rebuild scheduler tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_rebuild_runs_on_startup_for_every_kind(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""The first rebuild sweep fires on worker start, before any interval.
|
||||
|
||||
Otherwise a daemon that restarts more often than the rebuild
|
||||
interval would never bound accumulated UUIDs.
|
||||
"""
|
||||
fake_a = _FakeLanceRepo()
|
||||
fake_b = _FakeLanceRepo()
|
||||
w = CascadeWorker(
|
||||
{
|
||||
"episode": _OkHandlerWithRepo(fake_a),
|
||||
"atomic_fact": _OkHandlerWithRepo(fake_b),
|
||||
},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
optimize_heartbeat_seconds=10.0, # park heartbeat
|
||||
optimize_rebuild_interval_seconds=10.0, # only the startup sweep should fire
|
||||
)
|
||||
await w.start()
|
||||
# Allow the startup sweep to complete; the next tick is 10s away.
|
||||
await asyncio.sleep(0.1)
|
||||
await w.stop()
|
||||
# Exactly one rebuild per kind: the startup sweep. Next interval is 10s.
|
||||
assert len(fake_a.rebuild_calls) == 1
|
||||
assert len(fake_b.rebuild_calls) == 1
|
||||
|
||||
|
||||
async def test_rebuild_runs_periodically(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""After the startup sweep, rebuild repeats every interval."""
|
||||
fake = _FakeLanceRepo()
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
optimize_heartbeat_seconds=10.0,
|
||||
optimize_rebuild_interval_seconds=0.05, # ~tick every 50ms in this test
|
||||
)
|
||||
await w.start()
|
||||
await asyncio.sleep(0.2) # ~4 ticks plus startup sweep
|
||||
await w.stop()
|
||||
# Startup sweep + at least 2 interval-driven sweeps.
|
||||
assert len(fake.rebuild_calls) >= 3, (
|
||||
f"expected ≥3 rebuilds (1 startup + ≥2 periodic), got {len(fake.rebuild_calls)}"
|
||||
)
|
||||
|
||||
|
||||
async def test_rebuild_failure_does_not_crash_daemon(
|
||||
patched_repo: _FakeRepo,
|
||||
) -> None:
|
||||
"""A throwing rebuild is logged and absorbed; the worker keeps running."""
|
||||
fake = _FakeLanceRepo(rebuild_raises=True)
|
||||
w = CascadeWorker(
|
||||
{"episode": _OkHandlerWithRepo(fake)},
|
||||
retry_backoff_seconds=0,
|
||||
optimize_min_interval_seconds=0.01,
|
||||
optimize_heartbeat_seconds=0.05,
|
||||
optimize_rebuild_interval_seconds=10.0,
|
||||
)
|
||||
await w.start()
|
||||
# Give startup rebuild a chance to throw, then heartbeat to keep optimizing.
|
||||
await asyncio.sleep(0.12)
|
||||
# Optimize should still progress despite rebuild errors.
|
||||
assert fake.optimize_calls, "heartbeat optimize should run even when rebuild fails"
|
||||
await w.stop()
|
||||
# Worker is still alive (stop() returned cleanly).
|
||||
assert w._task is None
|
||||
85
tests/unit/test_memory/test_events.py
Normal file
85
tests/unit/test_memory/test_events.py
Normal file
@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from everalgo.types import ChatMessage, MemCell
|
||||
|
||||
from everos.memory.events import AgentPipelineStarted, UserPipelineStarted
|
||||
|
||||
|
||||
def _sample_memcell() -> MemCell:
|
||||
return MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hello",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u1",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="assistant",
|
||||
content="hi back",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="agent",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_001_000,
|
||||
)
|
||||
|
||||
|
||||
def test_user_pipeline_started_topic_is_module_qualified() -> None:
|
||||
assert UserPipelineStarted.topic() == "everos.memory.events:UserPipelineStarted"
|
||||
|
||||
|
||||
def test_agent_pipeline_started_topic_is_module_qualified() -> None:
|
||||
assert AgentPipelineStarted.topic() == "everos.memory.events:AgentPipelineStarted"
|
||||
|
||||
|
||||
def test_user_pipeline_started_roundtrip_json() -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a", session_id="s1", memcell=_sample_memcell()
|
||||
)
|
||||
restored = UserPipelineStarted.model_validate_json(event.model_dump_json())
|
||||
assert restored.memcell_id == "mc_a"
|
||||
assert restored.session_id == "s1"
|
||||
|
||||
|
||||
def test_user_pipeline_started_is_frozen_and_extra_forbid() -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=_sample_memcell(),
|
||||
)
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
UserPipelineStarted( # type: ignore[call-arg]
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=_sample_memcell(),
|
||||
extra_field=1,
|
||||
)
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
event.memcell_id = "mc_b" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_user_pipeline_started_carries_memcell() -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=_sample_memcell(),
|
||||
)
|
||||
assert event.memcell.items[0].content == "hello"
|
||||
assert event.memcell.items[1].sender_id == "agent"
|
||||
|
||||
|
||||
def test_user_pipeline_started_nested_roundtrip_json() -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=_sample_memcell(),
|
||||
)
|
||||
restored = UserPipelineStarted.model_validate_json(event.model_dump_json())
|
||||
assert restored.memcell.items[0].id == "m1"
|
||||
assert restored.memcell.items[1].content == "hi back"
|
||||
assert restored.memcell.timestamp == 1_700_000_001_000
|
||||
0
tests/unit/test_memory/test_extract/__init__.py
Normal file
0
tests/unit/test_memory/test_extract/__init__.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Tests for ingest content coercion + text derivation (tagged rendering)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from everos.memory.extract.ingest.multimodal import (
|
||||
coerce_items,
|
||||
derive_text,
|
||||
normalise_content,
|
||||
)
|
||||
|
||||
|
||||
def test_coerce_str_to_text_item() -> None:
|
||||
assert coerce_items("hi") == [{"type": "text", "text": "hi"}]
|
||||
|
||||
|
||||
def test_derive_text_renders_parsed_nontext_as_tag() -> None:
|
||||
items = [
|
||||
{"type": "text", "text": "before"},
|
||||
{"type": "image", "name": "p.png", "parsed_content": "OCR TEXT"},
|
||||
{"type": "text", "text": "after"},
|
||||
]
|
||||
text, non_text = derive_text(items)
|
||||
|
||||
assert "[IMAGE: p.png]\nOCR TEXT" in text
|
||||
assert text.startswith("before")
|
||||
assert text.endswith("after")
|
||||
assert non_text == 0
|
||||
|
||||
|
||||
def test_derive_text_counts_unparsed_nontext() -> None:
|
||||
text, non_text = derive_text([{"type": "image", "uri": "x"}])
|
||||
assert text == ""
|
||||
assert non_text == 1
|
||||
|
||||
|
||||
def test_derive_text_tag_without_name() -> None:
|
||||
text, _ = derive_text([{"type": "pdf", "parsed_content": "DOC"}])
|
||||
assert text == "[PDF]\nDOC"
|
||||
|
||||
|
||||
def test_normalise_content_text_only_unchanged() -> None:
|
||||
items, text, non_text = normalise_content("hello")
|
||||
assert items == [{"type": "text", "text": "hello"}]
|
||||
assert text == "hello"
|
||||
assert non_text == 0
|
||||
@ -0,0 +1,38 @@
|
||||
"""Tests for the multimodal capability guard."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.core.errors import MultimodalNotEnabledError
|
||||
from everos.memory.extract.parser import availability
|
||||
|
||||
|
||||
def test_has_unparsed_multimodal_true_for_unparsed_nontext() -> None:
|
||||
items = [{"type": "text", "text": "hi"}, {"type": "image", "uri": "x"}]
|
||||
assert availability.has_unparsed_multimodal(items) is True
|
||||
|
||||
|
||||
def test_has_unparsed_multimodal_false_when_all_text() -> None:
|
||||
items = [{"type": "text", "text": "hi"}]
|
||||
assert availability.has_unparsed_multimodal(items) is False
|
||||
|
||||
|
||||
def test_has_unparsed_multimodal_false_when_already_parsed() -> None:
|
||||
items = [{"type": "image", "uri": "x", "parsed_content": "ocr"}]
|
||||
assert availability.has_unparsed_multimodal(items) is False
|
||||
|
||||
|
||||
def test_require_multimodal_raises_when_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(availability, "multimodal_available", lambda: False)
|
||||
with pytest.raises(MultimodalNotEnabledError):
|
||||
availability.require_multimodal()
|
||||
|
||||
|
||||
def test_require_multimodal_ok_when_available(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(availability, "multimodal_available", lambda: True)
|
||||
availability.require_multimodal() # must not raise
|
||||
183
tests/unit/test_memory/test_extract/test_parser/test_enrich.py
Normal file
183
tests/unit/test_memory/test_extract/test_parser/test_enrich.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""Tests for enrich_content_items (everalgo.parser.aparse is monkeypatched)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
# ``everalgo.parser`` ships under the ``[multimodal]`` extra (see
|
||||
# pyproject.toml). CI doesn't install that extra by default, and these
|
||||
# tests monkeypatch ``everalgo.parser.aparse`` — which requires the
|
||||
# module to actually be importable, otherwise ``monkeypatch.setattr``
|
||||
# fails at resolve-time. Skip the whole module when the optional
|
||||
# dependency isn't present; we still run when ``multimodal`` is installed.
|
||||
pytest.importorskip("everalgo.parser")
|
||||
|
||||
from everalgo.llm import LLMError # noqa: E402
|
||||
from everalgo.types import ParsedContent # noqa: E402
|
||||
|
||||
from everos.core.errors import UnsupportedModalityError # noqa: E402
|
||||
from everos.memory.extract.parser import enrich_content_items # noqa: E402
|
||||
|
||||
|
||||
def _img_item() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "image",
|
||||
"base64": base64.b64encode(b"\x89PNG").decode(),
|
||||
"ext": "png",
|
||||
}
|
||||
|
||||
|
||||
def _html_b64_item() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "html",
|
||||
"base64": base64.b64encode(b"<html><body>v9.9.9</body></html>").decode(),
|
||||
"ext": "html",
|
||||
}
|
||||
|
||||
|
||||
def _html_uri_item() -> dict[str, Any]:
|
||||
return {"type": "html", "uri": "https://example.com/page.html"}
|
||||
|
||||
|
||||
async def test_enrich_backfills_parsed_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
return ParsedContent(text="OCR RESULT")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
items: list[dict[str, Any]] = [{"type": "text", "text": "hi"}, _img_item()]
|
||||
await enrich_content_items(items, llm=object(), max_concurrency=2)
|
||||
|
||||
assert items[1]["parsed_content"] == "OCR RESULT"
|
||||
assert items[1]["parse_status"] == "success"
|
||||
assert "parsed_content" not in items[0] # text item untouched
|
||||
|
||||
|
||||
async def test_enrich_unsupported_modality_raises(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
raise NotImplementedError("video deferred")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
with pytest.raises(UnsupportedModalityError):
|
||||
await enrich_content_items([_img_item()], llm=object())
|
||||
|
||||
|
||||
async def test_enrich_transient_llm_error_degrades(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
raise LLMError("provider down")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
items = [_img_item()]
|
||||
await enrich_content_items(items, llm=object()) # must not raise
|
||||
|
||||
assert items[0]["parse_status"] == "failed"
|
||||
assert "parsed_content" not in items[0]
|
||||
|
||||
|
||||
async def test_enrich_html_base64_routes_as_html_bytes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A type=html base64 item reaches the parser as html-extension bytes.
|
||||
|
||||
Locks the "normal HTML file call" contract: base64 + ext=html maps to
|
||||
a RawFile the parser dispatches as HTML (vs the 415 that a text-only
|
||||
html item produces — see test_ingest for that negative path).
|
||||
"""
|
||||
seen: dict[str, Any] = {}
|
||||
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
seen["extension"] = raw_file.extension
|
||||
seen["content"] = raw_file.content
|
||||
return ParsedContent(text="HTML PARSED")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
items = [_html_b64_item()]
|
||||
await enrich_content_items(items, llm=object())
|
||||
|
||||
assert items[0]["parsed_content"] == "HTML PARSED"
|
||||
assert items[0]["parse_status"] == "success"
|
||||
assert seen["extension"] == "html"
|
||||
assert b"v9.9.9" in seen["content"]
|
||||
|
||||
|
||||
async def test_enrich_http_uri_routes_as_uri(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""An http(s) uri item reaches the parser as a uri RawFile (no bytes).
|
||||
|
||||
Proves everos forwards uri-backed items to the parser, which is what
|
||||
drives everalgo's URL-fetch dispatch path (http/https only; file:// is
|
||||
rejected downstream).
|
||||
"""
|
||||
seen: dict[str, Any] = {}
|
||||
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
seen["uri"] = raw_file.uri
|
||||
seen["content"] = raw_file.content
|
||||
return ParsedContent(text="URL PARSED")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
items = [_html_uri_item()]
|
||||
await enrich_content_items(items, llm=object())
|
||||
|
||||
assert items[0]["parsed_content"] == "URL PARSED"
|
||||
assert items[0]["parse_status"] == "success"
|
||||
assert seen["uri"] == "https://example.com/page.html"
|
||||
assert seen["content"] == b""
|
||||
|
||||
|
||||
async def test_enrich_html_text_only_raises_unsupported(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""type=html carrying only ``text`` (no uri/base64) is undispatchable.
|
||||
|
||||
Any non-text item is routed to the parser, which needs a fetchable or
|
||||
decodable payload; a bare ``text`` has neither, so it surfaces as a
|
||||
MultimodalError (the route maps it to HTTP 415). To inline HTML *as
|
||||
text*, callers must use ``type="text"`` instead.
|
||||
"""
|
||||
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
return ParsedContent(text="should-not-be-reached")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
with pytest.raises(UnsupportedModalityError):
|
||||
await enrich_content_items(
|
||||
[{"type": "html", "text": "<p>hi</p>"}], llm=object()
|
||||
)
|
||||
|
||||
|
||||
async def test_enrich_file_uri_hydrates_and_parses(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Any,
|
||||
) -> None:
|
||||
"""A ``file://`` item is read locally and handed to the parser as bytes.
|
||||
|
||||
Proves EverOS hydrates the file (everalgo never sees the path / fs) — the
|
||||
parser receives ``content`` bytes, not a uri.
|
||||
"""
|
||||
seen: dict[str, Any] = {}
|
||||
|
||||
async def fake_aparse(raw_file: Any, *, llm: Any) -> ParsedContent:
|
||||
seen["content"] = raw_file.content
|
||||
seen["uri"] = raw_file.uri
|
||||
return ParsedContent(text="FILE PARSED")
|
||||
|
||||
monkeypatch.setattr("everalgo.parser.aparse", fake_aparse)
|
||||
f = tmp_path / "doc.html"
|
||||
f.write_bytes(b"<html>hello</html>")
|
||||
items = [{"type": "html", "uri": f"file://{f}"}]
|
||||
await enrich_content_items(items, llm=object())
|
||||
|
||||
assert items[0]["parsed_content"] == "FILE PARSED"
|
||||
assert items[0]["parse_status"] == "success"
|
||||
assert seen["content"] == b"<html>hello</html>" # hydrated, not a pointer
|
||||
assert seen["uri"] == ""
|
||||
105
tests/unit/test_memory/test_extract/test_parser/test_mapping.py
Normal file
105
tests/unit/test_memory/test_extract/test_parser/test_mapping.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Tests for ContentItem -> everalgo RawFile mapping + file:// hydration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.config import load_settings
|
||||
from everos.memory.extract.parser.mapping import build_raw_file, to_raw_file
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_settings_cache():
|
||||
"""file:// guardrails read settings; keep the lru_cache from leaking
|
||||
env overrides across tests."""
|
||||
load_settings.cache_clear()
|
||||
yield
|
||||
load_settings.cache_clear()
|
||||
|
||||
|
||||
def test_uri_item_maps_to_rawfile_uri() -> None:
|
||||
rf = to_raw_file({"type": "image", "uri": "https://x/y.png"})
|
||||
assert rf.uri == "https://x/y.png"
|
||||
assert rf.content == b""
|
||||
|
||||
|
||||
def test_base64_item_decodes_and_lowercases_extension() -> None:
|
||||
raw = b"\x89PNG\r\n"
|
||||
rf = to_raw_file(
|
||||
{"type": "image", "base64": base64.b64encode(raw).decode(), "ext": ".PNG"}
|
||||
)
|
||||
assert rf.content == raw
|
||||
assert rf.extension == "png"
|
||||
|
||||
|
||||
def test_item_without_uri_or_base64_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_raw_file({"type": "image"})
|
||||
|
||||
|
||||
# ── build_raw_file: file:// hydration + guardrails ──────────────────────
|
||||
|
||||
|
||||
async def test_build_raw_file_delegates_http_uri() -> None:
|
||||
"""http(s) uris stay in uri form (everalgo fetches), not hydrated."""
|
||||
rf = await build_raw_file({"type": "html", "uri": "https://example.com"})
|
||||
assert rf.uri == "https://example.com"
|
||||
assert rf.content == b""
|
||||
|
||||
|
||||
async def test_build_raw_file_hydrates_file_uri(tmp_path: Path) -> None:
|
||||
"""file:// is read locally into a hydrated RawFile (content + ext)."""
|
||||
f = tmp_path / "notes.html"
|
||||
f.write_bytes(b"<html><body>v9.9.9</body></html>")
|
||||
rf = await build_raw_file({"type": "html", "uri": f"file://{f}"})
|
||||
assert rf.content == b"<html><body>v9.9.9</body></html>"
|
||||
assert rf.extension == "html"
|
||||
assert rf.uri == "" # hydrated, not a pointer
|
||||
|
||||
|
||||
async def test_build_raw_file_file_uri_ext_hint_wins(tmp_path: Path) -> None:
|
||||
f = tmp_path / "blob" # no suffix
|
||||
f.write_bytes(b"%PDF-1.4 ...")
|
||||
rf = await build_raw_file({"type": "pdf", "uri": f"file://{f}", "ext": "pdf"})
|
||||
assert rf.extension == "pdf"
|
||||
|
||||
|
||||
async def test_build_raw_file_missing_file_raises(tmp_path: Path) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await build_raw_file({"type": "pdf", "uri": f"file://{tmp_path}/nope.pdf"})
|
||||
|
||||
|
||||
async def test_build_raw_file_oversize_raises(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
f = tmp_path / "big.html"
|
||||
f.write_bytes(b"x" * 100)
|
||||
monkeypatch.setenv("EVEROS_MULTIMODAL__FILE_URI_MAX_BYTES", "10")
|
||||
load_settings.cache_clear()
|
||||
with pytest.raises(ValueError, match="too large"):
|
||||
await build_raw_file({"type": "html", "uri": f"file://{f}"})
|
||||
|
||||
|
||||
async def test_build_raw_file_outside_allowlist_raises(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
f = tmp_path / "secret.html"
|
||||
f.write_bytes(b"<html></html>")
|
||||
monkeypatch.setenv("EVEROS_MULTIMODAL__FILE_URI_ALLOW_DIRS", '["/some/other/root"]')
|
||||
load_settings.cache_clear()
|
||||
with pytest.raises(ValueError, match="outside the allowed roots"):
|
||||
await build_raw_file({"type": "html", "uri": f"file://{f}"})
|
||||
|
||||
|
||||
async def test_build_raw_file_inside_allowlist_ok(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
f = tmp_path / "ok.html"
|
||||
f.write_bytes(b"<html>ok</html>")
|
||||
monkeypatch.setenv("EVEROS_MULTIMODAL__FILE_URI_ALLOW_DIRS", f'["{tmp_path}"]')
|
||||
load_settings.cache_clear()
|
||||
rf = await build_raw_file({"type": "html", "uri": f"file://{f}"})
|
||||
assert rf.content == b"<html>ok</html>"
|
||||
@ -0,0 +1,61 @@
|
||||
"""``AgentMemoryPipeline.run`` — empty short-circuit + per-cell event emit."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from everalgo.types import ChatMessage, MemCell
|
||||
|
||||
from everos.memory import IngestResult
|
||||
from everos.memory.events import AgentPipelineStarted
|
||||
from everos.memory.extract.pipeline.agent_memory import AgentMemoryPipeline
|
||||
|
||||
|
||||
class _FakeEngine:
|
||||
"""Captures emitted events; mirrors ``OfflineEngine.emit`` async signature."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: list[AgentPipelineStarted] = []
|
||||
|
||||
async def emit(self, event: AgentPipelineStarted) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
def _make_cell(n_items: int, ts: int = 1_700_000_000_000) -> MemCell:
|
||||
items = [
|
||||
ChatMessage(
|
||||
id=f"m{i}",
|
||||
role="user",
|
||||
sender_id="u1",
|
||||
sender_name="u",
|
||||
content="hi",
|
||||
timestamp=ts,
|
||||
)
|
||||
for i in range(n_items)
|
||||
]
|
||||
return MemCell(items=items, timestamp=ts)
|
||||
|
||||
|
||||
async def test_empty_cells_short_circuit() -> None:
|
||||
engine = _FakeEngine()
|
||||
pipeline = AgentMemoryPipeline(engine) # type: ignore[arg-type]
|
||||
ingested = IngestResult(session_id="s1", messages=[])
|
||||
out = await pipeline.run(ingested, cells=[], memcell_ids=[])
|
||||
assert out.track == "agent_memory"
|
||||
assert out.status == "accumulated"
|
||||
assert out.message_count == 0
|
||||
assert engine.events == []
|
||||
|
||||
|
||||
async def test_emits_one_event_per_cell() -> None:
|
||||
engine = _FakeEngine()
|
||||
pipeline = AgentMemoryPipeline(engine) # type: ignore[arg-type]
|
||||
ingested = IngestResult(session_id="s1", messages=[])
|
||||
cells = [_make_cell(n_items=2), _make_cell(n_items=3)]
|
||||
memcell_ids = ["mc_a", "mc_b"]
|
||||
out = await pipeline.run(ingested, cells=cells, memcell_ids=memcell_ids)
|
||||
|
||||
assert out.track == "agent_memory"
|
||||
assert out.status == "extracted"
|
||||
assert out.message_count == 5 # 2 + 3
|
||||
assert [e.memcell_id for e in engine.events] == ["mc_a", "mc_b"]
|
||||
assert all(e.session_id == "s1" for e in engine.events)
|
||||
assert all(isinstance(e, AgentPipelineStarted) for e in engine.events)
|
||||
@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from everalgo.types import ChatMessage, MemCell
|
||||
from everalgo.types import Episode as AlgoEpisode
|
||||
|
||||
from everos.core.persistence import EntryId
|
||||
from everos.memory import IngestResult
|
||||
from everos.memory.events import EpisodeExtracted, UserPipelineStarted
|
||||
from everos.memory.extract.pipeline.user_memory import UserMemoryPipeline
|
||||
from everos.memory.models import CanonicalMessage
|
||||
|
||||
|
||||
def _sample_memcell() -> MemCell:
|
||||
return MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hello",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u1",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
|
||||
|
||||
class _CapturingEngine:
|
||||
def __init__(self) -> None:
|
||||
self.emitted: list[object] = []
|
||||
|
||||
async def emit(self, event: object) -> None:
|
||||
self.emitted.append(event)
|
||||
|
||||
|
||||
async def test_emit_pipeline_started_routes_through_engine() -> None:
|
||||
engine = _CapturingEngine()
|
||||
pipeline = UserMemoryPipeline(
|
||||
episode_writer=MagicMock(),
|
||||
prompt_loader=MagicMock(),
|
||||
llm_client=MagicMock(),
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
cell = _sample_memcell()
|
||||
await pipeline._emit_pipeline_started( # noqa: SLF001 — test introspection
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
app_id="claude_code",
|
||||
project_id="oss",
|
||||
cell=cell,
|
||||
)
|
||||
|
||||
started = [e for e in engine.emitted if isinstance(e, UserPipelineStarted)]
|
||||
assert len(started) == 1
|
||||
assert started[0].memcell_id == "mc_a"
|
||||
assert started[0].session_id == "s1"
|
||||
assert started[0].app_id == "claude_code"
|
||||
assert started[0].project_id == "oss"
|
||||
assert started[0].memcell is cell
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_episode_extracted_after_md_write() -> None:
|
||||
"""Each per-sender Episode write emits EpisodeExtracted with the md entry id."""
|
||||
engine = _CapturingEngine()
|
||||
episode_writer = MagicMock()
|
||||
episode_writer.append_entry = AsyncMock(
|
||||
return_value=EntryId(prefix="ep", date=_dt.date(2026, 5, 17), seq=1)
|
||||
)
|
||||
episode_writer.path_for = MagicMock(
|
||||
return_value="users/u1/episodes/episode-2026-05-17.md"
|
||||
)
|
||||
prompt_loader = MagicMock()
|
||||
prompt_loader.load = MagicMock(return_value="<prompt>")
|
||||
llm_client = MagicMock()
|
||||
|
||||
pipeline = UserMemoryPipeline(
|
||||
episode_writer=episode_writer,
|
||||
prompt_loader=prompt_loader,
|
||||
llm_client=llm_client,
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
cell = _sample_memcell()
|
||||
ingested = IngestResult(
|
||||
session_id="s1",
|
||||
messages=[
|
||||
CanonicalMessage(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
sender_id="u1",
|
||||
role="user",
|
||||
timestamp=_dt.datetime.fromtimestamp(1_700_000_000, tz=_dt.UTC),
|
||||
text="hello",
|
||||
)
|
||||
],
|
||||
)
|
||||
algo_ep = AlgoEpisode(
|
||||
owner_id="u1", episode="they said hello", timestamp=1_700_000_000_000
|
||||
)
|
||||
with patch.object( # noqa: SLF001
|
||||
pipeline._ep_ext, "aextract", new=AsyncMock(return_value=algo_ep)
|
||||
):
|
||||
outcome = await pipeline.run(
|
||||
ingested=ingested,
|
||||
cells=[cell],
|
||||
memcell_ids=["mc_a"],
|
||||
per_cell_all_senders=[["u1"]],
|
||||
)
|
||||
|
||||
assert outcome.status == "extracted"
|
||||
extracted = [e for e in engine.emitted if isinstance(e, EpisodeExtracted)]
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0].memcell_id == "mc_a"
|
||||
assert extracted[0].episode_entry_id == "ep_20260517_00000001"
|
||||
assert extracted[0].episode_text == "they said hello"
|
||||
assert extracted[0].episode_timestamp_ms == 1_700_000_000_000
|
||||
assert extracted[0].owner_id == "u1"
|
||||
0
tests/unit/test_memory/test_get/__init__.py
Normal file
0
tests/unit/test_memory/test_get/__init__.py
Normal file
177
tests/unit/test_memory/test_get/test_dto.py
Normal file
177
tests/unit/test_memory/test_get/test_dto.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Tests for ``memory.get.dto``.
|
||||
|
||||
Pydantic-side guarantees the manager / route can rely on:
|
||||
|
||||
* ``GetRequest`` defaults match the wiki spec (``page=1`` /
|
||||
``page_size=20`` / ``sort_by="timestamp"`` / ``sort_order="desc"``)
|
||||
* ``page_size`` upper bound (1–100)
|
||||
* ``owner_type`` × ``memory_type`` strict pairing
|
||||
* Unknown fields on the request are rejected (``extra="forbid"``)
|
||||
|
||||
Filter DSL coverage lives in ``test_memory/test_search/test_filters.py``
|
||||
since ``/get`` shares :class:`everos.memory.search.FilterNode`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.memory.get.dto import (
|
||||
GetMemoryType,
|
||||
GetRequest,
|
||||
)
|
||||
|
||||
# ── GetRequest defaults / shape ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_request_defaults_match_wiki() -> None:
|
||||
"""``page`` / ``page_size`` / ``sort_by`` / ``sort_order`` come from the wiki."""
|
||||
req = GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
)
|
||||
assert req.page == 1
|
||||
assert req.page_size == 20
|
||||
assert req.sort_by == "timestamp"
|
||||
assert req.sort_order == "desc"
|
||||
assert req.filters is None
|
||||
|
||||
|
||||
def test_get_request_page_size_upper_bound() -> None:
|
||||
"""101 → ValidationError (wiki cap is 100)."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
page_size=101,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_page_size_lower_bound() -> None:
|
||||
"""0 → ValidationError (page_size ≥ 1)."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
page_size=0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_page_lower_bound() -> None:
|
||||
"""0 → ValidationError (page ≥ 1; 1-indexed)."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
page=0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_unknown_field() -> None:
|
||||
"""``extra='forbid'`` — typos surface as a 422, not silent drops."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
unknown_extra=True, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_empty_user_id() -> None:
|
||||
"""``user_id`` carries ``min_length=1`` — empty string is 422."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(
|
||||
user_id="",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_missing_memory_type() -> None:
|
||||
"""``memory_type`` is required — omission is 422."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest( # type: ignore[call-arg]
|
||||
user_id="u1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_missing_owner_identity() -> None:
|
||||
"""Neither ``user_id`` nor ``agent_id`` → xor validator rejects."""
|
||||
with pytest.raises(ValidationError, match="exactly one of"):
|
||||
GetRequest( # type: ignore[call-arg]
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_both_user_and_agent_id() -> None:
|
||||
"""Both ``user_id`` and ``agent_id`` set → xor validator rejects."""
|
||||
with pytest.raises(ValidationError, match="exactly one of"):
|
||||
GetRequest(
|
||||
user_id="u1",
|
||||
agent_id="agent_x",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_invalid_memory_type_value() -> None:
|
||||
"""A value outside the four-kind enum is 422."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest.model_validate(
|
||||
{
|
||||
"user_id": "u1",
|
||||
"memory_type": "atomic_fact", # not a top-level kind
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_get_request_rejects_invalid_sort_order() -> None:
|
||||
"""``sort_order`` is a tight Literal — typos / casing variants are 422."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest.model_validate(
|
||||
{
|
||||
"user_id": "u1",
|
||||
"memory_type": "episode",
|
||||
"sort_order": "DESC", # must be lowercase
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ── owner_type × memory_type pairing ─────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"id_field, memory_type",
|
||||
[
|
||||
("user_id", GetMemoryType.EPISODE),
|
||||
("user_id", GetMemoryType.PROFILE),
|
||||
("agent_id", GetMemoryType.AGENT_CASE),
|
||||
("agent_id", GetMemoryType.AGENT_SKILL),
|
||||
],
|
||||
)
|
||||
def test_get_request_allows_valid_owner_memory_pair(
|
||||
id_field: str,
|
||||
memory_type: GetMemoryType,
|
||||
) -> None:
|
||||
"""The four valid (owner-kind, memory_type) combinations."""
|
||||
req = GetRequest(**{id_field: "u1"}, memory_type=memory_type)
|
||||
assert req.memory_type is memory_type
|
||||
expected_owner_type = "user" if id_field == "user_id" else "agent"
|
||||
assert req.owner_type == expected_owner_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"id_field, memory_type",
|
||||
[
|
||||
("user_id", GetMemoryType.AGENT_CASE),
|
||||
("user_id", GetMemoryType.AGENT_SKILL),
|
||||
("agent_id", GetMemoryType.EPISODE),
|
||||
("agent_id", GetMemoryType.PROFILE),
|
||||
],
|
||||
)
|
||||
def test_get_request_rejects_cross_owner_memory_pair(
|
||||
id_field: str,
|
||||
memory_type: GetMemoryType,
|
||||
) -> None:
|
||||
"""Cross-pairs (user_id+agent_case etc.) are 422 at the DTO layer."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetRequest(**{id_field: "u1"}, memory_type=memory_type)
|
||||
212
tests/unit/test_memory/test_get/test_filters_adapter.py
Normal file
212
tests/unit/test_memory/test_get/test_filters_adapter.py
Normal file
@ -0,0 +1,212 @@
|
||||
"""Tests for ``memory.get.filters_adapter.compile_filters_for_get``.
|
||||
|
||||
The adapter is a thin wrapper around
|
||||
:func:`everos.memory.search.compile_filters` — these tests pin the
|
||||
behaviour /get callers depend on:
|
||||
|
||||
* base clause shape (``owner_id = '...' AND owner_type = '...'``)
|
||||
* flat multi-field → implicit ``AND``
|
||||
* reserved field (``owner_id`` / ``owner_type`` inside ``filters``)
|
||||
→ :class:`FilterError`
|
||||
* unknown field → :class:`FilterError`
|
||||
* top-level ``AND`` / ``OR`` combinators are accepted (parity with
|
||||
``/search`` — the wiki §附录 C restriction was dropped 2026-05-16)
|
||||
* ``timestamp`` range (multi-op map) renders ``AND``-folded clauses
|
||||
* ``sender_id`` is an array column → ``array_has(...)`` rendering
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.get.filters_adapter import compile_filters_for_get
|
||||
from everos.memory.search import FilterError, FilterNode
|
||||
|
||||
|
||||
def test_no_filters_emits_base_clause() -> None:
|
||||
"""``filters=None`` → owner + app/project scope clauses AND-joined."""
|
||||
where = compile_filters_for_get(None, owner_id="u1", owner_type="user")
|
||||
assert where == (
|
||||
"owner_id = 'u1' AND owner_type = 'user' "
|
||||
"AND app_id = 'default' AND project_id = 'default'"
|
||||
)
|
||||
|
||||
|
||||
def test_owner_id_quote_is_escaped() -> None:
|
||||
"""SQL-standard double-quote escape on ``owner_id``."""
|
||||
where = compile_filters_for_get(None, owner_id="o'reilly", owner_type="user")
|
||||
assert where == (
|
||||
"owner_id = 'o''reilly' AND owner_type = 'user' "
|
||||
"AND app_id = 'default' AND project_id = 'default'"
|
||||
)
|
||||
|
||||
|
||||
def test_flat_multi_field_renders_implicit_and() -> None:
|
||||
"""Multiple top-level fields → implicit ``AND`` between predicates."""
|
||||
node = FilterNode.model_validate({"session_id": "sess_a", "parent_id": "mc_x"})
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
# Field iteration order follows insertion order, so both are present.
|
||||
assert "owner_id = 'u1'" in where
|
||||
assert "owner_type = 'user'" in where
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "parent_id = 'mc_x'" in where
|
||||
# 4 base scope clauses + 2 filter fields = 6 clauses → 5 ' AND ' joins.
|
||||
assert where.count(" AND ") == 5
|
||||
|
||||
|
||||
def test_reserved_owner_id_in_filters_raises() -> None:
|
||||
"""``owner_id`` inside ``filters`` is a hard error (must be top level)."""
|
||||
node = FilterNode.model_validate({"owner_id": "u1"})
|
||||
with pytest.raises(FilterError, match="reserved"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_reserved_owner_type_in_filters_raises() -> None:
|
||||
"""``owner_type`` inside ``filters`` is also reserved."""
|
||||
node = FilterNode.model_validate({"owner_type": "user"})
|
||||
with pytest.raises(FilterError, match="reserved"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_unsupported_field_raises() -> None:
|
||||
"""Any field outside the shared allow-list → :class:`FilterError`."""
|
||||
node = FilterNode.model_validate({"random_attr": "x"})
|
||||
with pytest.raises(FilterError, match="unsupported"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_timestamp_range_renders_and_folded() -> None:
|
||||
"""Multi-op map on one field folds with ``AND`` (reused from /search)."""
|
||||
node = FilterNode.model_validate(
|
||||
{"timestamp": {"gte": 1704067200000, "lt": 1735689600000}}
|
||||
)
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "timestamp >= TIMESTAMP '" in where
|
||||
assert "timestamp < TIMESTAMP '" in where
|
||||
# The two clauses are AND-joined inside one parenthesised group.
|
||||
assert "(timestamp >= TIMESTAMP" in where
|
||||
assert " AND timestamp < TIMESTAMP" in where
|
||||
|
||||
|
||||
def test_sender_id_in_list_renders_array_has() -> None:
|
||||
"""``sender_id`` is an array column — ``in`` → ``array_has(...) OR ...``."""
|
||||
node = FilterNode.model_validate({"sender_id": {"in": ["alice", "bob"]}})
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "array_has(sender_ids, 'alice')" in where
|
||||
assert "array_has(sender_ids, 'bob')" in where
|
||||
|
||||
|
||||
def test_sender_id_eq_shorthand_renders_array_has() -> None:
|
||||
"""Equality shorthand on an array column → single ``array_has``."""
|
||||
node = FilterNode.model_validate({"sender_id": "alice"})
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "array_has(sender_ids, 'alice')" in where
|
||||
|
||||
|
||||
def test_parent_id_eq_shorthand_renders_scalar_eq() -> None:
|
||||
"""``parent_id`` is a scalar string column → plain ``=``."""
|
||||
node = FilterNode.model_validate({"parent_id": "mc_42"})
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "parent_id = 'mc_42'" in where
|
||||
|
||||
|
||||
def test_top_level_and_renders_grouped_clause() -> None:
|
||||
"""``AND`` combinator compiles like /search — parens-grouped fragments."""
|
||||
node = FilterNode.model_validate(
|
||||
{"AND": [{"session_id": "sess_a"}, {"parent_id": "mc_x"}]}
|
||||
)
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
# Base clause is always first; combinator output appended.
|
||||
assert where.startswith("owner_id = 'u1' AND owner_type = 'user' AND ")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "parent_id = 'mc_x'" in where
|
||||
|
||||
|
||||
def test_top_level_or_renders_grouped_clause() -> None:
|
||||
"""``OR`` combinator emits parens-grouped ``OR`` between sibling preds."""
|
||||
node = FilterNode.model_validate(
|
||||
{"OR": [{"session_id": "sess_a"}, {"session_id": "sess_b"}]}
|
||||
)
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "session_id = 'sess_b'" in where
|
||||
assert " OR " in where
|
||||
|
||||
|
||||
def test_ne_operator_renders_not_equal() -> None:
|
||||
"""``ne`` op compiles to ``!=`` on str fields."""
|
||||
node = FilterNode.model_validate({"session_id": {"ne": "sess_internal"}})
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "session_id != 'sess_internal'" in where
|
||||
|
||||
|
||||
def test_timestamp_iso_string_renders_literal() -> None:
|
||||
"""ISO 8601 string is accepted as a timestamp literal (alongside epoch ms)."""
|
||||
node = FilterNode.model_validate(
|
||||
{"timestamp": {"gte": "2026-01-04T00:00:00+00:00"}}
|
||||
)
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "timestamp >= TIMESTAMP '2026-01-04T00:00:00+00:00'" in where
|
||||
|
||||
|
||||
def test_nested_and_inside_or() -> None:
|
||||
"""``AND`` nested inside ``OR`` — combinators compose recursively."""
|
||||
node = FilterNode.model_validate(
|
||||
{
|
||||
"OR": [
|
||||
{"AND": [{"session_id": "sess_a"}, {"parent_id": "mc_x"}]},
|
||||
{"session_id": "sess_b"},
|
||||
]
|
||||
}
|
||||
)
|
||||
where = compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "parent_id = 'mc_x'" in where
|
||||
assert "session_id = 'sess_b'" in where
|
||||
assert " OR " in where
|
||||
assert " AND " in where
|
||||
|
||||
|
||||
# ── Malformed value shapes ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_in_op_with_non_list_rejected() -> None:
|
||||
"""``in`` requires a non-empty list — a scalar is a hard error."""
|
||||
node = FilterNode.model_validate({"session_id": {"in": "not_a_list"}})
|
||||
with pytest.raises(FilterError, match="non-empty list"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_in_op_with_empty_list_rejected() -> None:
|
||||
"""``in: []`` is invalid — must contain at least one value."""
|
||||
node = FilterNode.model_validate({"session_id": {"in": []}})
|
||||
with pytest.raises(FilterError, match="non-empty list"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_empty_operator_map_rejected() -> None:
|
||||
"""``{}`` as a field value (no op) is a hard error."""
|
||||
node = FilterNode.model_validate({"timestamp": {}})
|
||||
with pytest.raises(FilterError, match="empty operator map"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_unknown_op_rejected() -> None:
|
||||
"""``between`` / other non-allow-listed ops surface as :class:`FilterError`."""
|
||||
node = FilterNode.model_validate({"timestamp": {"between": [1, 2]}})
|
||||
with pytest.raises(FilterError, match="operator"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_sender_id_gt_rejected() -> None:
|
||||
"""``gt`` on an ``array_str`` column is not supported (semantics unclear)."""
|
||||
node = FilterNode.model_validate({"sender_id": {"gt": "alice"}})
|
||||
with pytest.raises(FilterError, match="not supported on array"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
|
||||
|
||||
def test_non_string_in_str_field_rejected() -> None:
|
||||
"""``session_id`` is a str field — passing an int is a typed error."""
|
||||
node = FilterNode.model_validate({"session_id": {"in": [1, 2]}})
|
||||
with pytest.raises(FilterError, match="must be a string"):
|
||||
compile_filters_for_get(node, owner_id="u1", owner_type="user")
|
||||
350
tests/unit/test_memory/test_get/test_manager.py
Normal file
350
tests/unit/test_memory/test_get/test_manager.py
Normal file
@ -0,0 +1,350 @@
|
||||
"""Unit tests for :class:`GetManager` with in-memory stub repos.
|
||||
|
||||
These tests exercise the dispatch / shape / sort-override logic without
|
||||
LanceDB. Each repo is replaced by a minimal stub that records the call
|
||||
and returns canned rows; the manager's job is to:
|
||||
|
||||
* dispatch on ``memory_type`` to the matching repo,
|
||||
* compile filters once and pass the same ``where`` to the repo,
|
||||
* shape rows into the correct ``GetItem`` (lossless except score),
|
||||
* silently override ``sort_by`` to ``updated_at`` for ``agent_skill``
|
||||
(the table has no ``timestamp`` column),
|
||||
* fetch the owner's single profile row (KV-by-owner) and shape it into
|
||||
``GetProfileItem``, or return ``[]`` on a cold-start miss.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.persistence.lancedb import (
|
||||
AgentCase,
|
||||
AgentSkill,
|
||||
Episode,
|
||||
UserProfile,
|
||||
)
|
||||
from everos.memory.get import (
|
||||
GetManager,
|
||||
GetMemoryType,
|
||||
GetRequest,
|
||||
)
|
||||
from everos.memory.search import FilterNode
|
||||
|
||||
# ── Stub repos ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CallRecord:
|
||||
where: str = ""
|
||||
sort_by: str = ""
|
||||
descending: bool = True
|
||||
page: int = 0
|
||||
page_size: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StubRepo:
|
||||
"""Records the call and returns ``(rows, total)`` verbatim."""
|
||||
|
||||
rows: list[Any] = field(default_factory=list)
|
||||
total: int = 0
|
||||
last: _CallRecord = field(default_factory=_CallRecord)
|
||||
|
||||
async def find_where_paginated(
|
||||
self,
|
||||
where: str,
|
||||
*,
|
||||
sort_by: str,
|
||||
descending: bool = True,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
max_fetch: int = 20000,
|
||||
) -> tuple[list[Any], int]:
|
||||
self.last = _CallRecord(
|
||||
where=where,
|
||||
sort_by=sort_by,
|
||||
descending=descending,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return list(self.rows), self.total
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ProfileStubRepo:
|
||||
"""Stub ``user_profile_repo`` — returns its configured row by id."""
|
||||
|
||||
row: Any = None
|
||||
last_id: str | None = None
|
||||
|
||||
async def get_by_id(self, id_: str) -> Any:
|
||||
self.last_id = id_
|
||||
return self.row
|
||||
|
||||
|
||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts(day: int = 1) -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, day, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _episode_row(entry: str) -> Episode:
|
||||
return Episode(
|
||||
id=f"u1_{entry}",
|
||||
entry_id=entry,
|
||||
owner_id="u1",
|
||||
owner_type="user",
|
||||
session_id="sess_a",
|
||||
timestamp=_ts(),
|
||||
parent_type="memcell",
|
||||
parent_id="mc_1",
|
||||
sender_ids=["u1", "assistant"],
|
||||
subject=f"subj {entry}",
|
||||
summary=f"summary {entry}",
|
||||
episode=f"body of {entry}",
|
||||
episode_tokens=f"body of {entry}",
|
||||
md_path=f"users/u1/episodes/{entry}.md",
|
||||
content_sha256="abc",
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
def _agent_case_row(entry: str) -> AgentCase:
|
||||
return AgentCase(
|
||||
id=f"a1_{entry}",
|
||||
entry_id=entry,
|
||||
owner_id="a1",
|
||||
owner_type="agent",
|
||||
session_id="sess_x",
|
||||
timestamp=_ts(),
|
||||
parent_type="memcell",
|
||||
parent_id="mc_99",
|
||||
quality_score=0.8,
|
||||
task_intent=f"intent {entry}",
|
||||
task_intent_tokens=f"intent {entry}",
|
||||
approach=f"approach {entry}",
|
||||
approach_tokens=f"approach {entry}",
|
||||
key_insight=None,
|
||||
md_path=f"agents/a1/cases/{entry}.md",
|
||||
content_sha256="abc",
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
def _agent_skill_row(name: str) -> AgentSkill:
|
||||
return AgentSkill(
|
||||
id=f"a1_{name}",
|
||||
owner_id="a1",
|
||||
owner_type="agent",
|
||||
name=name,
|
||||
description=f"desc {name}",
|
||||
description_tokens=f"desc {name}",
|
||||
content=f"content {name}",
|
||||
content_tokens=f"content {name}",
|
||||
confidence=0.9,
|
||||
maturity_score=0.7,
|
||||
source_case_ids=["a1_ac_1"],
|
||||
md_path=f"agents/a1/skills/{name}/SKILL.md",
|
||||
content_sha256="abc",
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
def _user_profile_row(owner: str = "u1") -> UserProfile:
|
||||
return UserProfile(
|
||||
id=owner,
|
||||
owner_id=owner,
|
||||
owner_type="user",
|
||||
app_id="default",
|
||||
project_id="default",
|
||||
summary=f"{owner} loves climbing in Yosemite",
|
||||
explicit_info_json='[{"category": "Hobby", "description": "climbing"}]',
|
||||
implicit_traits_json='[{"trait": "Outdoorsy"}]',
|
||||
profile_timestamp_ms=1780304400000,
|
||||
md_path=f"users/{owner}/user.md",
|
||||
content_sha256="abc",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def profile_repo() -> _ProfileStubRepo:
|
||||
return _ProfileStubRepo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(
|
||||
profile_repo: _ProfileStubRepo,
|
||||
) -> tuple[GetManager, _StubRepo, _StubRepo, _StubRepo]:
|
||||
ep = _StubRepo()
|
||||
ac = _StubRepo()
|
||||
sk = _StubRepo()
|
||||
mgr = GetManager(
|
||||
episode_repo=ep, # type: ignore[arg-type]
|
||||
agent_case_repo=ac, # type: ignore[arg-type]
|
||||
agent_skill_repo=sk, # type: ignore[arg-type]
|
||||
user_profile_repo=profile_repo, # type: ignore[arg-type]
|
||||
)
|
||||
return mgr, ep, ac, sk
|
||||
|
||||
|
||||
# ── Episode dispatch ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_episodic_memory_populates_episodes_and_counts(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
mgr, ep, _, _ = manager
|
||||
ep.rows = [_episode_row("ep_1"), _episode_row("ep_2")]
|
||||
ep.total = 17 # filtered total may exceed the page
|
||||
req = GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
)
|
||||
resp = await mgr.get(req)
|
||||
|
||||
assert len(resp.request_id) == 32 and all(
|
||||
c in "0123456789abcdef" for c in resp.request_id
|
||||
)
|
||||
assert resp.data.total_count == 17
|
||||
assert resp.data.count == 2
|
||||
assert [item.id for item in resp.data.episodes] == ["u1_ep_1", "u1_ep_2"]
|
||||
assert resp.data.profiles == []
|
||||
assert resp.data.agent_cases == []
|
||||
assert resp.data.agent_skills == []
|
||||
# The shaper maps the lance row's owner_id onto the item's user_id field.
|
||||
assert all(item.user_id == "u1" for item in resp.data.episodes)
|
||||
|
||||
|
||||
async def test_episodic_memory_passes_where_and_sort_to_repo(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
"""The compiled ``where`` must include owner_id + filter clauses."""
|
||||
mgr, ep, _, _ = manager
|
||||
req = GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.EPISODE,
|
||||
sort_by="timestamp",
|
||||
sort_order="asc",
|
||||
page=2,
|
||||
page_size=10,
|
||||
filters=FilterNode.model_validate({"session_id": "sess_a"}),
|
||||
)
|
||||
await mgr.get(req)
|
||||
assert "owner_id = 'u1'" in ep.last.where
|
||||
assert "owner_type = 'user'" in ep.last.where
|
||||
assert "session_id = 'sess_a'" in ep.last.where
|
||||
assert ep.last.sort_by == "timestamp"
|
||||
assert ep.last.descending is False # asc
|
||||
assert ep.last.page == 2
|
||||
assert ep.last.page_size == 10
|
||||
|
||||
|
||||
# ── Profile dispatch ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_profile_miss_returns_empty(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
"""Cold start (no profile row yet) → empty list + total_count=0."""
|
||||
mgr, ep, ac, sk = manager # profile_repo.row defaults to None
|
||||
req = GetRequest(
|
||||
user_id="u1",
|
||||
memory_type=GetMemoryType.PROFILE,
|
||||
)
|
||||
resp = await mgr.get(req)
|
||||
assert resp.data.profiles == []
|
||||
assert resp.data.total_count == 0
|
||||
assert resp.data.count == 0
|
||||
# The profile path never touches the paginated (episode/case/skill) repos.
|
||||
assert ep.last.where == ""
|
||||
assert ac.last.where == ""
|
||||
assert sk.last.where == ""
|
||||
|
||||
|
||||
async def test_profile_hit_shapes_row_into_item(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
profile_repo: _ProfileStubRepo,
|
||||
) -> None:
|
||||
"""A present profile row is fetched by owner and shaped + json-decoded."""
|
||||
mgr, *_ = manager
|
||||
profile_repo.row = _user_profile_row("u1")
|
||||
req = GetRequest(user_id="u1", memory_type=GetMemoryType.PROFILE)
|
||||
resp = await mgr.get(req)
|
||||
|
||||
assert resp.data.total_count == 1
|
||||
assert resp.data.count == 1
|
||||
assert len(resp.data.profiles) == 1
|
||||
item = resp.data.profiles[0]
|
||||
assert item.id == "u1"
|
||||
assert item.user_id == "u1"
|
||||
# KV fetch keys on owner_id.
|
||||
assert profile_repo.last_id == "u1"
|
||||
# json buckets are decoded back into structured profile_data.
|
||||
assert item.profile_data["summary"] == "u1 loves climbing in Yosemite"
|
||||
assert item.profile_data["explicit_info"] == [
|
||||
{"category": "Hobby", "description": "climbing"}
|
||||
]
|
||||
assert item.profile_data["implicit_traits"] == [{"trait": "Outdoorsy"}]
|
||||
assert item.profile_data["profile_timestamp_ms"] == 1780304400000
|
||||
|
||||
|
||||
# ── Agent case dispatch ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agent_case_populates_agent_cases(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
mgr, _, ac, _ = manager
|
||||
ac.rows = [_agent_case_row("ac_1"), _agent_case_row("ac_2")]
|
||||
ac.total = 2
|
||||
req = GetRequest(
|
||||
agent_id="a1",
|
||||
memory_type=GetMemoryType.AGENT_CASE,
|
||||
)
|
||||
resp = await mgr.get(req)
|
||||
assert resp.data.total_count == 2
|
||||
assert resp.data.count == 2
|
||||
assert [item.id for item in resp.data.agent_cases] == ["a1_ac_1", "a1_ac_2"]
|
||||
assert resp.data.episodes == []
|
||||
assert resp.data.agent_skills == []
|
||||
|
||||
|
||||
# ── Agent skill dispatch — sort_by silent override ──────────────────────
|
||||
|
||||
|
||||
async def test_agent_skill_sort_by_silently_overridden_to_updated_at(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
"""``agent_skill`` always sorts by ``updated_at`` (no ``timestamp`` column)."""
|
||||
mgr, _, _, sk = manager
|
||||
sk.rows = [_agent_skill_row("planner")]
|
||||
sk.total = 1
|
||||
req = GetRequest(
|
||||
agent_id="a1",
|
||||
memory_type=GetMemoryType.AGENT_SKILL,
|
||||
# User passes the default — should be silently downgraded.
|
||||
sort_by="timestamp",
|
||||
)
|
||||
resp = await mgr.get(req)
|
||||
assert sk.last.sort_by == "updated_at"
|
||||
assert resp.data.total_count == 1
|
||||
assert resp.data.agent_skills[0].name == "planner"
|
||||
|
||||
|
||||
async def test_agent_skill_explicit_updated_at_is_respected(
|
||||
manager: tuple[GetManager, _StubRepo, _StubRepo, _StubRepo],
|
||||
) -> None:
|
||||
"""``updated_at`` passes through unchanged (no double-override surprise)."""
|
||||
mgr, _, _, sk = manager
|
||||
req = GetRequest(
|
||||
agent_id="a1",
|
||||
memory_type=GetMemoryType.AGENT_SKILL,
|
||||
sort_by="updated_at",
|
||||
)
|
||||
await mgr.get(req)
|
||||
assert sk.last.sort_by == "updated_at"
|
||||
196
tests/unit/test_memory/test_models.py
Normal file
196
tests/unit/test_memory/test_models.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""Unit tests for memory domain models — focused on ``from_algo`` factories.
|
||||
|
||||
The factories carry the load-bearing contract: algo's emitted business
|
||||
fields survive, everos's engineering metadata (session_id / sender_ids
|
||||
/ parent_id) gets injected, and any algo-side ``parent_id`` (smuggled
|
||||
through ``extra='allow'``) is dropped in favour of the caller's value.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from everalgo.types import (
|
||||
AgentCase as AlgoAgentCase,
|
||||
)
|
||||
from everalgo.types import (
|
||||
AtomicFact as AlgoAtomicFact,
|
||||
)
|
||||
from everalgo.types import (
|
||||
Episode as AlgoEpisode,
|
||||
)
|
||||
from everalgo.types import (
|
||||
Foresight as AlgoForesight,
|
||||
)
|
||||
|
||||
from everos.memory.models import AgentCase, AtomicFact, Episode, Foresight
|
||||
|
||||
|
||||
def test_atomic_fact_from_algo_carries_business_fields_and_metadata() -> None:
|
||||
algo = AlgoAtomicFact(
|
||||
owner_id="u_alice",
|
||||
content="alice likes hiking",
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
fact = AtomicFact.from_algo(
|
||||
algo,
|
||||
owner_id="u_alice",
|
||||
session_id="s_42",
|
||||
parent_id="mc_abc",
|
||||
)
|
||||
assert fact.owner_id == "u_alice"
|
||||
assert fact.fact == "alice likes hiking"
|
||||
assert fact.timestamp == 1_700_000_000_000
|
||||
assert fact.session_id == "s_42"
|
||||
assert fact.parent_id == "mc_abc"
|
||||
assert not hasattr(fact, "sender_ids")
|
||||
|
||||
|
||||
def test_atomic_fact_from_algo_drops_algo_side_parent_id() -> None:
|
||||
# Smuggle a parent_id through extra='allow' on the algo side.
|
||||
algo = AlgoAtomicFact.model_validate(
|
||||
{
|
||||
"owner_id": "u_alice",
|
||||
"content": "x",
|
||||
"timestamp": 1_700_000_000_000,
|
||||
"parent_id": "ALGO_STALE",
|
||||
}
|
||||
)
|
||||
fact = AtomicFact.from_algo(
|
||||
algo, owner_id="u_alice", session_id="s1", parent_id="mc_real"
|
||||
)
|
||||
# Caller-supplied parent_id wins; algo-side value is discarded.
|
||||
assert fact.parent_id == "mc_real"
|
||||
|
||||
|
||||
def test_atomic_fact_from_algo_owner_id_override_for_fan_out() -> None:
|
||||
"""One LLM template fans out to many owners — caller's owner_id wins."""
|
||||
algo = AlgoAtomicFact(
|
||||
owner_id="PLACEHOLDER", # subject-agnostic prompt placeholder
|
||||
content="likes hiking",
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
fact_alice = AtomicFact.from_algo(
|
||||
algo, owner_id="u_alice", session_id="s1", parent_id="mc_a"
|
||||
)
|
||||
fact_bob = AtomicFact.from_algo(
|
||||
algo, owner_id="u_bob", session_id="s1", parent_id="mc_a"
|
||||
)
|
||||
assert fact_alice.owner_id == "u_alice"
|
||||
assert fact_bob.owner_id == "u_bob"
|
||||
# Same source template body survives the fan-out.
|
||||
assert fact_alice.fact == fact_bob.fact == "likes hiking"
|
||||
|
||||
|
||||
def test_foresight_from_algo_preserves_optional_time_window() -> None:
|
||||
algo = AlgoForesight(
|
||||
owner_id="u_alice",
|
||||
foresight="plans trip to tokyo",
|
||||
evidence="said so",
|
||||
timestamp=1_700_000_000_000,
|
||||
start_time="2026-06-01",
|
||||
duration_days=7,
|
||||
)
|
||||
fs = Foresight.from_algo(algo, session_id="s1", parent_id="mc_a")
|
||||
assert fs.foresight == "plans trip to tokyo"
|
||||
assert fs.evidence == "said so"
|
||||
assert fs.start_time == "2026-06-01"
|
||||
assert fs.duration_days == 7
|
||||
assert fs.end_time is None
|
||||
assert fs.parent_id == "mc_a"
|
||||
assert not hasattr(fs, "sender_ids")
|
||||
|
||||
|
||||
def test_foresight_from_algo_drops_algo_side_parent_id() -> None:
|
||||
algo = AlgoForesight.model_validate(
|
||||
{
|
||||
"owner_id": "u_alice",
|
||||
"foresight": "x",
|
||||
"evidence": "y",
|
||||
"timestamp": 1_700_000_000_000,
|
||||
"parent_id": "ALGO_STALE",
|
||||
}
|
||||
)
|
||||
fs = Foresight.from_algo(algo, session_id="s1", parent_id="mc_real")
|
||||
assert fs.parent_id == "mc_real"
|
||||
|
||||
|
||||
def test_foresight_from_algo_preserves_algo_owner_id() -> None:
|
||||
"""Per-sender extraction: algo emits the correct owner_id."""
|
||||
algo = AlgoForesight(
|
||||
owner_id="u_bob",
|
||||
foresight="trip to tokyo",
|
||||
evidence="said so",
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
fs = Foresight.from_algo(algo, session_id="s1", parent_id="mc_a")
|
||||
assert fs.owner_id == "u_bob"
|
||||
|
||||
|
||||
def test_agent_case_from_algo_injects_owner_and_drops_algo_id() -> None:
|
||||
"""Algo emits a uuid `id` + no owner; everos injects agent_id, drops uuid."""
|
||||
algo = AlgoAgentCase(
|
||||
id=uuid.uuid4().hex,
|
||||
timestamp=1_700_000_000_000,
|
||||
task_intent="summarise doc",
|
||||
approach="read + condense",
|
||||
quality_score=0.75,
|
||||
key_insight="batch-then-summarise",
|
||||
)
|
||||
case = AgentCase.from_algo(
|
||||
algo, owner_id="agent_42", session_id="s1", parent_id="mc_a"
|
||||
)
|
||||
assert case.owner_id == "agent_42"
|
||||
assert case.task_intent == "summarise doc"
|
||||
assert case.approach == "read + condense"
|
||||
assert case.quality_score == 0.75
|
||||
assert case.key_insight == "batch-then-summarise"
|
||||
assert case.session_id == "s1"
|
||||
assert case.parent_id == "mc_a"
|
||||
# algo's uuid `id` is not surfaced on the domain model.
|
||||
assert not hasattr(case, "id") or case.id != algo.id # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_agent_case_from_algo_normalises_empty_key_insight_to_none() -> None:
|
||||
"""algo emits `""` when there's nothing to insight; domain normalises to None."""
|
||||
algo = AlgoAgentCase(
|
||||
id=uuid.uuid4().hex,
|
||||
timestamp=1_700_000_000_000,
|
||||
task_intent="ti",
|
||||
approach="ap",
|
||||
quality_score=0.5,
|
||||
key_insight="",
|
||||
)
|
||||
case = AgentCase.from_algo(
|
||||
algo, owner_id="agent_42", session_id="s1", parent_id="mc_a"
|
||||
)
|
||||
assert case.key_insight is None
|
||||
|
||||
|
||||
def test_episode_from_algo_owner_id_caller_supplied() -> None:
|
||||
"""Caller supplies ``owner_id``; algo's value (None or otherwise) is dropped.
|
||||
|
||||
The pipeline runs the algo once with ``sender_id=None`` (generic
|
||||
EPISODE_GENERATION_PROMPT) and then fans the same algo Episode out
|
||||
to one domain Episode per user sender, each rooted at its own owner.
|
||||
"""
|
||||
algo = AlgoEpisode(owner_id=None, episode="hello", timestamp=1_700_000_000_000)
|
||||
ep_alice = Episode.from_algo(
|
||||
algo,
|
||||
owner_id="u_alice",
|
||||
session_id="s1",
|
||||
sender_ids=["u_alice", "u_bob"],
|
||||
parent_id="mc_a",
|
||||
)
|
||||
ep_bob = Episode.from_algo(
|
||||
algo,
|
||||
owner_id="u_bob",
|
||||
session_id="s1",
|
||||
sender_ids=["u_alice", "u_bob"],
|
||||
parent_id="mc_a",
|
||||
)
|
||||
assert ep_alice.owner_id == "u_alice"
|
||||
assert ep_bob.owner_id == "u_bob"
|
||||
assert ep_alice.episode == ep_bob.episode == "hello"
|
||||
assert ep_alice.parent_id == ep_bob.parent_id == "mc_a"
|
||||
assert ep_alice.session_id == ep_bob.session_id == "s1"
|
||||
61
tests/unit/test_memory/test_prompt_slots/test_loader.py
Normal file
61
tests/unit/test_memory/test_prompt_slots/test_loader.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""``PromptLoader.load`` — returns template iff enabled + non-empty."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from everos.memory.prompt_slots.loader import PromptLoader
|
||||
|
||||
|
||||
def _write_slot(root: Path, name: str, content: str) -> None:
|
||||
slot_dir = root / "prompt_slots"
|
||||
slot_dir.mkdir(parents=True, exist_ok=True)
|
||||
(slot_dir / f"{name}.yaml").write_text(content)
|
||||
|
||||
|
||||
def test_returns_none_when_file_missing(tmp_path: Path) -> None:
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("boundary_detection") is None
|
||||
|
||||
|
||||
def test_returns_none_when_disabled(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "enabled: false\ntemplate: 'hello'\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
|
||||
|
||||
def test_returns_none_when_enabled_key_absent(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "template: 'hello'\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
|
||||
|
||||
def test_returns_none_when_template_empty(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "enabled: true\ntemplate: ''\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
|
||||
|
||||
def test_returns_none_when_template_whitespace(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "enabled: true\ntemplate: ' '\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
|
||||
|
||||
def test_returns_none_when_template_missing(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "enabled: true\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
|
||||
|
||||
def test_returns_template_when_enabled_and_non_empty(tmp_path: Path) -> None:
|
||||
_write_slot(tmp_path, "x", "enabled: true\ntemplate: 'detect now'\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") == "detect now"
|
||||
|
||||
|
||||
def test_template_must_be_string(tmp_path: Path) -> None:
|
||||
"""Non-string ``template`` (e.g. accidental int) is treated as None."""
|
||||
_write_slot(tmp_path, "x", "enabled: true\ntemplate: 42\n")
|
||||
loader = PromptLoader(tmp_path)
|
||||
assert loader.load("x") is None
|
||||
0
tests/unit/test_memory/test_search/__init__.py
Normal file
0
tests/unit/test_memory/test_search/__init__.py
Normal file
27
tests/unit/test_memory/test_search/conftest.py
Normal file
27
tests/unit/test_memory/test_search/conftest.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Shared fixtures for ``memory.search`` unit tests.
|
||||
|
||||
The project default is ``EVEROS_SEARCH__VECTOR_STRATEGY=maxsim_atomic`` —
|
||||
that path queries both the ``atomic_fact`` table and the ``episode`` table
|
||||
to do MaxSim. The existing VECTOR-route tests in ``test_manager.py`` were
|
||||
written against the legacy single-vector ``episode`` path and stub only the
|
||||
episode recaller (atomic_fact recaller is a no-data stub).
|
||||
|
||||
Force the legacy ``episode`` strategy by default for these tests so they
|
||||
keep asserting against the dense-recall path they were designed to cover.
|
||||
MaxSim-specific tests opt back into ``maxsim_atomic`` by overriding the env
|
||||
var inside their own body.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.config.settings import load_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _force_episode_vector_strategy(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("EVEROS_SEARCH__VECTOR_STRATEGY", "episode")
|
||||
load_settings.cache_clear()
|
||||
yield
|
||||
load_settings.cache_clear()
|
||||
59
tests/unit/test_memory/test_search/test_adapter.py
Normal file
59
tests/unit/test_memory/test_search/test_adapter.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Unit tests for ``memory.search.adapter.resolve_pipeline``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.search.adapter import resolve_pipeline
|
||||
from everos.memory.search.dto import SearchMethod
|
||||
|
||||
|
||||
def test_keyword_skips_everalgo() -> None:
|
||||
fm, cfg = resolve_pipeline(SearchMethod.KEYWORD, "episode")
|
||||
assert fm is None
|
||||
assert cfg is None
|
||||
|
||||
|
||||
def test_vector_skips_everalgo() -> None:
|
||||
fm, cfg = resolve_pipeline(SearchMethod.VECTOR, "episode")
|
||||
assert fm is None
|
||||
assert cfg is None
|
||||
|
||||
|
||||
def test_hybrid_episode_picks_hierarchy() -> None:
|
||||
fm, cfg = resolve_pipeline(SearchMethod.HYBRID, "episode")
|
||||
assert fm == "hierarchy"
|
||||
assert cfg is None
|
||||
|
||||
|
||||
def test_hybrid_atomic_fact_picks_hierarchy() -> None:
|
||||
fm, _cfg = resolve_pipeline(SearchMethod.HYBRID, "atomic_fact")
|
||||
assert fm == "hierarchy"
|
||||
|
||||
|
||||
def test_hybrid_case_picks_vector_anchored() -> None:
|
||||
fm, cfg = resolve_pipeline(SearchMethod.HYBRID, "agent_case")
|
||||
assert fm == "vector_anchored"
|
||||
assert cfg is None
|
||||
|
||||
|
||||
def test_hybrid_skill_picks_skill_hybrid() -> None:
|
||||
fm, _cfg = resolve_pipeline(SearchMethod.HYBRID, "agent_skill")
|
||||
assert fm == "skill_hybrid"
|
||||
|
||||
|
||||
def test_agentic_method_raises_value_error() -> None:
|
||||
"""AGENTIC (a valid enum member) raises ValueError from resolve_pipeline.
|
||||
|
||||
Distinct from ``test_unsupported_method_raises`` which passes an arbitrary
|
||||
non-enum string. This test verifies the manager's contract: AGENTIC must be
|
||||
intercepted before resolve_pipeline is called, and resolve_pipeline defends
|
||||
against it with a ValueError even for the known enum member.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="unsupported method"):
|
||||
resolve_pipeline(SearchMethod.AGENTIC, "episode")
|
||||
|
||||
|
||||
def test_unsupported_method_raises() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported method"):
|
||||
resolve_pipeline("not-a-method", "episode") # type: ignore[arg-type]
|
||||
338
tests/unit/test_memory/test_search/test_agentic.py
Normal file
338
tests/unit/test_memory/test_search/test_agentic.py
Normal file
@ -0,0 +1,338 @@
|
||||
"""Unit tests for ``memory.search.agentic.search_episodes_agentic``.
|
||||
|
||||
White-box: patches ``aagentic_retrieve`` to assert benchmark hyperparameters
|
||||
are wired correctly, plus a shaping test to verify id remapping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from everalgo.clustering import Cluster
|
||||
from everalgo.rank.protocols import AgenticDecision
|
||||
from everalgo.testing.fake_llm import FakeLLMClient
|
||||
from everalgo.types import Candidate
|
||||
|
||||
from everos.component.utils.datetime import from_timestamp
|
||||
from everos.memory.search.agentic import (
|
||||
_restore_shaper_metadata,
|
||||
_to_everalgo_doc_metadata,
|
||||
search_episodes_agentic,
|
||||
)
|
||||
from everos.memory.search.dto import SearchEpisodeItem
|
||||
|
||||
# ── Stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _mc_candidate(mc_id: str, ep_id: str, score: float = 0.8) -> Candidate:
|
||||
"""Candidate keyed by memcell_id (as returned by amaxsim/fetch_all_for_owner)."""
|
||||
return Candidate(
|
||||
id=mc_id,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"episode_id": ep_id,
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess_a",
|
||||
"timestamp": _ts(),
|
||||
"sender_ids": ["alice"],
|
||||
"subject": "Alice eats oat milk",
|
||||
"summary": "Alice food preferences",
|
||||
"episode": "Alice prefers oat milk in her coffee",
|
||||
"parent_id": mc_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _StubEpisodeRecaller:
|
||||
kind: ClassVar[str] = "episode"
|
||||
everalgo_memory_type: ClassVar[str] = "episodic"
|
||||
text_field: ClassVar[str] = "episode"
|
||||
|
||||
def __init__(
|
||||
self, all_docs: list[Candidate], by_parent: dict[str, Candidate]
|
||||
) -> None:
|
||||
self._all_docs = all_docs
|
||||
self._by_parent = by_parent
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return []
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._all_docs)
|
||||
|
||||
async def fetch_by_parent_ids(
|
||||
self, parent_ids: Sequence[str], where: str
|
||||
) -> list[Candidate]:
|
||||
"""Returns Candidate with id=episode_id (real LanceDB id)."""
|
||||
return [self._by_parent[p] for p in parent_ids if p in self._by_parent]
|
||||
|
||||
async def fetch_all_for_owner(self, where: str) -> list[Candidate]:
|
||||
"""Returns Candidate with id=memcell_id and metadata['episode_id']."""
|
||||
return list(self._all_docs)
|
||||
|
||||
|
||||
class _StubFactRecaller:
|
||||
kind: ClassVar[str] = "atomic_fact"
|
||||
everalgo_memory_type: ClassVar[str] = "episodic"
|
||||
text_field: ClassVar[str] = "fact"
|
||||
|
||||
def __init__(self, facts: list[Candidate]) -> None:
|
||||
self._facts = facts
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._facts)
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._facts)
|
||||
|
||||
|
||||
class _StubReranker:
|
||||
async def rerank(
|
||||
self, query: str, passages: list[str], *, instruction: str | None = None
|
||||
) -> list[Any]:
|
||||
class _R:
|
||||
def __init__(self, idx: int) -> None:
|
||||
self.index = idx
|
||||
self.score = 1.0 - idx * 0.1
|
||||
|
||||
return [_R(i) for i in range(len(passages))]
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mc_cand() -> Candidate:
|
||||
return _mc_candidate("mc_1", "ep_1")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def ep_recaller(mc_cand: Candidate) -> _StubEpisodeRecaller:
|
||||
ep_raw = Candidate(
|
||||
id="ep_1",
|
||||
score=0.0,
|
||||
source="vector",
|
||||
metadata=mc_cand.metadata,
|
||||
)
|
||||
return _StubEpisodeRecaller(
|
||||
all_docs=[mc_cand],
|
||||
by_parent={"mc_1": ep_raw},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fact_cand() -> Candidate:
|
||||
return Candidate(
|
||||
id="f_1",
|
||||
score=0.9,
|
||||
source="vector",
|
||||
metadata={"parent_id": "mc_1", "fact": "Alice prefers oat milk"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fact_recaller(fact_cand: Candidate) -> _StubFactRecaller:
|
||||
return _StubFactRecaller([fact_cand])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def clusters() -> list[Cluster]:
|
||||
# ``cluster_repo.list_for_owner`` is mocked in every test, so cluster
|
||||
# contents are never exercised by everalgo; we only need a valid instance
|
||||
# that satisfies the everalgo ``Cluster`` schema (ndarray centroid + last_ts).
|
||||
return [
|
||||
Cluster(
|
||||
id="cl_1",
|
||||
members=["mc_1"],
|
||||
centroid=np.zeros(4, dtype=np.float32),
|
||||
last_ts=0,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agentic_search_wires_benchmark_hyperparams(
|
||||
ep_recaller: _StubEpisodeRecaller,
|
||||
fact_recaller: _StubFactRecaller,
|
||||
clusters: list[Cluster],
|
||||
) -> None:
|
||||
"""aagentic_retrieve must be called with the exact benchmark hyperparams."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def fake_aagentic(
|
||||
query: str,
|
||||
*,
|
||||
base_retrieve: Any,
|
||||
llm: Any,
|
||||
rerank_fn: Any,
|
||||
round2_retrieve: Any,
|
||||
round2_cap: int,
|
||||
top_n: int,
|
||||
round1_top_n: int,
|
||||
round1_rerank_top_n: int,
|
||||
refinement_strategy: str,
|
||||
multi_query_count: int,
|
||||
rrf_k: int,
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
captured.update(
|
||||
top_n=top_n,
|
||||
round1_top_n=round1_top_n,
|
||||
round1_rerank_top_n=round1_rerank_top_n,
|
||||
round2_cap=round2_cap,
|
||||
multi_query_count=multi_query_count,
|
||||
rrf_k=rrf_k,
|
||||
refinement_strategy=refinement_strategy,
|
||||
has_round2=round2_retrieve is not None,
|
||||
)
|
||||
return [], AgenticDecision(is_multi_round=False)
|
||||
|
||||
async def fake_embed(q: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
with (
|
||||
patch("everos.memory.search.agentic.aagentic_retrieve", fake_aagentic),
|
||||
patch(
|
||||
"everos.memory.search.agentic.cluster_repo.list_for_owner",
|
||||
AsyncMock(return_value=clusters),
|
||||
),
|
||||
):
|
||||
await search_episodes_agentic(
|
||||
"What did Alice eat?",
|
||||
owner_id="alice",
|
||||
where="owner_id = 'alice' AND owner_type = 'user'",
|
||||
episode_recaller=ep_recaller,
|
||||
atomic_fact_recaller=fact_recaller,
|
||||
embed_query_fn=fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert captured["top_n"] == 10
|
||||
assert captured["round1_top_n"] == 50
|
||||
assert captured["round1_rerank_top_n"] == 10
|
||||
assert captured["round2_cap"] == 40
|
||||
assert captured["multi_query_count"] == 3
|
||||
assert captured["rrf_k"] == 40
|
||||
assert captured["refinement_strategy"] == "multi_query"
|
||||
assert captured["has_round2"] is True
|
||||
|
||||
|
||||
async def test_agentic_search_loads_user_memory_clusters(
|
||||
ep_recaller: _StubEpisodeRecaller,
|
||||
fact_recaller: _StubFactRecaller,
|
||||
) -> None:
|
||||
"""cluster_repo.list_for_owner must be called with kind='user_memory'."""
|
||||
mock_list = AsyncMock(return_value=[])
|
||||
|
||||
async def fake_embed(q: str) -> list[float]:
|
||||
return [0.1] * 4
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.search.agentic.aagentic_retrieve",
|
||||
AsyncMock(return_value=([], AgenticDecision(is_multi_round=False))),
|
||||
),
|
||||
patch("everos.memory.search.agentic.cluster_repo.list_for_owner", mock_list),
|
||||
):
|
||||
await search_episodes_agentic(
|
||||
"q",
|
||||
owner_id="alice",
|
||||
where="owner_id = 'alice' AND owner_type = 'user'",
|
||||
episode_recaller=ep_recaller,
|
||||
atomic_fact_recaller=fact_recaller,
|
||||
embed_query_fn=fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with("alice", "user_memory")
|
||||
|
||||
|
||||
async def test_agentic_search_shapes_candidates_with_episode_id(
|
||||
ep_recaller: _StubEpisodeRecaller,
|
||||
fact_recaller: _StubFactRecaller,
|
||||
clusters: list[Cluster],
|
||||
mc_cand: Candidate,
|
||||
) -> None:
|
||||
"""SearchEpisodeItem.id must be episode_id (not memcell_id) after retrieve."""
|
||||
|
||||
async def fake_aagentic(
|
||||
*_: Any, **__: Any
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
return [mc_cand], AgenticDecision(is_multi_round=False)
|
||||
|
||||
async def fake_embed(q: str) -> list[float]:
|
||||
return [0.1] * 4
|
||||
|
||||
with (
|
||||
patch("everos.memory.search.agentic.aagentic_retrieve", fake_aagentic),
|
||||
patch(
|
||||
"everos.memory.search.agentic.cluster_repo.list_for_owner",
|
||||
AsyncMock(return_value=clusters),
|
||||
),
|
||||
):
|
||||
result = await search_episodes_agentic(
|
||||
"What did Alice eat?",
|
||||
owner_id="alice",
|
||||
where="owner_id = 'alice' AND owner_type = 'user'",
|
||||
episode_recaller=ep_recaller,
|
||||
atomic_fact_recaller=fact_recaller,
|
||||
embed_query_fn=fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SearchEpisodeItem)
|
||||
assert result[0].id == "ep_1", (
|
||||
f"Expected episode_id='ep_1' but got {result[0].id!r}. "
|
||||
"Shaper must remap from memcell_id via metadata['episode_id']."
|
||||
)
|
||||
|
||||
|
||||
# ── Metadata bridge to the everalgo _format_docs contract ──────────────────
|
||||
|
||||
|
||||
def test_to_everalgo_doc_metadata_injects_text_and_ms_timestamp() -> None:
|
||||
"""Bridge adds `text` (episode body) + ms-epoch `timestamp` for _format_docs.
|
||||
|
||||
Without this the sufficiency / multi-query LLM prompt falls back to the
|
||||
memcell id as the doc body and renders the date as "N/A". ``episode`` is
|
||||
left untouched so the reranker / shaper (both expecting a str) keep working.
|
||||
"""
|
||||
original = _ts()
|
||||
md = {
|
||||
"episode": "Alice prefers oat milk",
|
||||
"timestamp": original,
|
||||
"subject": "Alice eats oat milk",
|
||||
}
|
||||
out = _to_everalgo_doc_metadata(md)
|
||||
assert out["text"] == "Alice prefers oat milk"
|
||||
assert out["episode"] == "Alice prefers oat milk" # untouched for rerank/shaper
|
||||
assert isinstance(out["timestamp"], int)
|
||||
assert from_timestamp(out["timestamp"]) == original
|
||||
|
||||
|
||||
def test_restore_shaper_metadata_reverts_ms_timestamp_to_datetime() -> None:
|
||||
"""The ms-epoch timestamp is reverted to the datetime the shaper requires."""
|
||||
original = _ts()
|
||||
bridged = _to_everalgo_doc_metadata({"episode": "x", "timestamp": original})
|
||||
restored = _restore_shaper_metadata(bridged)
|
||||
assert isinstance(restored["timestamp"], _dt.datetime)
|
||||
assert restored["timestamp"] == original
|
||||
272
tests/unit/test_memory/test_search/test_agentic_agent.py
Normal file
272
tests/unit/test_memory/test_search/test_agentic_agent.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""Unit tests for ``memory.search.agentic_agent``.
|
||||
|
||||
White-box: patches ``aagentic_retrieve`` to assert benchmark hyperparameters
|
||||
are wired correctly, plus a shaping test to verify DTOs are built correctly.
|
||||
|
||||
The skill verify step has been removed from production code; this test
|
||||
module covers the agentic retrieve flow only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from typing import Any, ClassVar
|
||||
from unittest.mock import patch
|
||||
|
||||
from everalgo.rank.protocols import AgenticDecision
|
||||
from everalgo.testing.fake_llm import FakeLLMClient
|
||||
from everalgo.types import Candidate
|
||||
|
||||
from everos.memory.search.agentic_agent import (
|
||||
search_agent_cases_agentic,
|
||||
search_agent_skills_agentic,
|
||||
)
|
||||
from everos.memory.search.dto import SearchAgentCaseItem, SearchAgentSkillItem
|
||||
|
||||
# ── Stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _case_candidate(cid: str, score: float = 0.8) -> Candidate:
|
||||
return Candidate(
|
||||
id=cid,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"session_id": "sess_b",
|
||||
"timestamp": _ts(),
|
||||
"task_intent": f"intent {cid}",
|
||||
"approach": f"approach {cid}",
|
||||
"quality_score": 0.8,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _skill_candidate(sid: str, score: float = 0.75) -> Candidate:
|
||||
return Candidate(
|
||||
id=sid,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"name": f"skill_{sid}",
|
||||
"description": f"desc {sid}",
|
||||
"content": f"content {sid}",
|
||||
"confidence": 0.9,
|
||||
"maturity_score": 0.6,
|
||||
"source_case_ids": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _StubCaseRecaller:
|
||||
kind: ClassVar[str] = "agent_case"
|
||||
everalgo_memory_type: ClassVar[str] = "case"
|
||||
text_field: ClassVar[str] = "task_intent"
|
||||
|
||||
def __init__(self, dense: list[Candidate]) -> None:
|
||||
self._dense = dense
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
|
||||
class _StubSkillRecaller:
|
||||
kind: ClassVar[str] = "agent_skill"
|
||||
everalgo_memory_type: ClassVar[str] = "skill"
|
||||
text_field: ClassVar[str] = "description"
|
||||
|
||||
def __init__(self, dense: list[Candidate]) -> None:
|
||||
self._dense = dense
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
|
||||
class _StubReranker:
|
||||
async def rerank(self, query: str, passages: list[str]) -> list[Any]:
|
||||
class _R:
|
||||
def __init__(self, idx: int) -> None:
|
||||
self.index = idx
|
||||
self.score = 1.0 - idx * 0.1
|
||||
|
||||
return [_R(i) for i in range(len(passages))]
|
||||
|
||||
|
||||
async def _fake_embed(q: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_search_agent_cases_agentic_calls_aagentic_retrieve_with_benchmark_params() -> ( # noqa: E501
|
||||
None
|
||||
):
|
||||
"""Verify aagentic_retrieve called with benchmark hyperparams for agent_case."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def fake_aagentic(
|
||||
query: str,
|
||||
*,
|
||||
base_retrieve: Any,
|
||||
llm: Any,
|
||||
rerank_fn: Any,
|
||||
round2_retrieve: Any,
|
||||
round2_cap: Any,
|
||||
top_n: int,
|
||||
round1_top_n: int,
|
||||
round1_rerank_top_n: int,
|
||||
refinement_strategy: str,
|
||||
multi_query_count: int,
|
||||
rrf_k: int,
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
captured.update(
|
||||
top_n=top_n,
|
||||
round1_top_n=round1_top_n,
|
||||
round1_rerank_top_n=round1_rerank_top_n,
|
||||
round2_cap=round2_cap,
|
||||
round2_retrieve_is_none=round2_retrieve is None,
|
||||
multi_query_count=multi_query_count,
|
||||
rrf_k=rrf_k,
|
||||
refinement_strategy=refinement_strategy,
|
||||
)
|
||||
return [], AgenticDecision(is_multi_round=False)
|
||||
|
||||
with patch("everos.memory.search.agentic_agent.aagentic_retrieve", fake_aagentic):
|
||||
await search_agent_cases_agentic(
|
||||
"How did agent handle login failure?",
|
||||
where="owner_id = 'agent_a' AND owner_type = 'agent'",
|
||||
case_recaller=_StubCaseRecaller([]),
|
||||
embed_query_fn=_fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert captured["top_n"] == 10
|
||||
assert captured["round1_top_n"] == 20
|
||||
assert captured["round1_rerank_top_n"] == 10
|
||||
assert captured["round2_cap"] == 40
|
||||
assert captured["round2_retrieve_is_none"] is True
|
||||
assert captured["multi_query_count"] == 3
|
||||
assert captured["rrf_k"] == 60
|
||||
assert captured["refinement_strategy"] == "multi_query"
|
||||
|
||||
|
||||
async def test_search_agent_skills_agentic_calls_aagentic_retrieve_with_benchmark_params() -> ( # noqa: E501
|
||||
None
|
||||
):
|
||||
"""Verify aagentic_retrieve called with benchmark hyperparams for agent_skill."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def fake_aagentic(
|
||||
query: str,
|
||||
*,
|
||||
base_retrieve: Any,
|
||||
llm: Any,
|
||||
rerank_fn: Any,
|
||||
round2_retrieve: Any,
|
||||
round2_cap: Any,
|
||||
top_n: int,
|
||||
round1_top_n: int,
|
||||
round1_rerank_top_n: int,
|
||||
refinement_strategy: str,
|
||||
multi_query_count: int,
|
||||
rrf_k: int,
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
captured.update(
|
||||
top_n=top_n,
|
||||
round1_top_n=round1_top_n,
|
||||
round1_rerank_top_n=round1_rerank_top_n,
|
||||
round2_cap=round2_cap,
|
||||
round2_retrieve_is_none=round2_retrieve is None,
|
||||
multi_query_count=multi_query_count,
|
||||
rrf_k=rrf_k,
|
||||
refinement_strategy=refinement_strategy,
|
||||
)
|
||||
return [], AgenticDecision(is_multi_round=False)
|
||||
|
||||
with patch("everos.memory.search.agentic_agent.aagentic_retrieve", fake_aagentic):
|
||||
await search_agent_skills_agentic(
|
||||
"What skill handles auth token refresh?",
|
||||
where="owner_id = 'agent_a' AND owner_type = 'agent'",
|
||||
skill_recaller=_StubSkillRecaller([]),
|
||||
embed_query_fn=_fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
assert captured["top_n"] == 5
|
||||
assert captured["round1_top_n"] == 20
|
||||
assert captured["round1_rerank_top_n"] == 10
|
||||
assert captured["round2_cap"] == 40
|
||||
assert captured["round2_retrieve_is_none"] is True
|
||||
assert captured["multi_query_count"] == 3
|
||||
assert captured["rrf_k"] == 60
|
||||
assert captured["refinement_strategy"] == "multi_query"
|
||||
|
||||
|
||||
async def test_search_agent_cases_agentic_shapes_result() -> None:
|
||||
"""Output must be list[SearchAgentCaseItem] built from aagentic_retrieve results."""
|
||||
cand = _case_candidate("c_1")
|
||||
|
||||
async def fake_aagentic(
|
||||
*_: Any, **__: Any
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
return [cand], AgenticDecision(is_multi_round=False)
|
||||
|
||||
with patch("everos.memory.search.agentic_agent.aagentic_retrieve", fake_aagentic):
|
||||
result = await search_agent_cases_agentic(
|
||||
"intent query",
|
||||
where="owner_id = 'agent_a' AND owner_type = 'agent'",
|
||||
case_recaller=_StubCaseRecaller([cand]),
|
||||
embed_query_fn=_fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SearchAgentCaseItem)
|
||||
assert result[0].id == "c_1"
|
||||
assert result[0].task_intent == "intent c_1"
|
||||
|
||||
|
||||
async def test_search_agent_skills_agentic_shapes_result() -> None:
|
||||
"""Output must be list[SearchAgentSkillItem] from aagentic_retrieve results."""
|
||||
cand = _skill_candidate("s_1")
|
||||
|
||||
async def fake_aagentic(
|
||||
*_: Any, **__: Any
|
||||
) -> tuple[list[Candidate], AgenticDecision]:
|
||||
return [cand], AgenticDecision(is_multi_round=False)
|
||||
|
||||
with patch("everos.memory.search.agentic_agent.aagentic_retrieve", fake_aagentic):
|
||||
result = await search_agent_skills_agentic(
|
||||
"skill query",
|
||||
where="owner_id = 'agent_a' AND owner_type = 'agent'",
|
||||
skill_recaller=_StubSkillRecaller([cand]),
|
||||
embed_query_fn=_fake_embed,
|
||||
reranker=_StubReranker(),
|
||||
llm=FakeLLMClient(responses=[]),
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SearchAgentSkillItem)
|
||||
assert result[0].id == "s_1"
|
||||
assert result[0].name == "skill_s_1"
|
||||
163
tests/unit/test_memory/test_search/test_callbacks.py
Normal file
163
tests/unit/test_memory/test_search/test_callbacks.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""Unit tests for ``memory.search.callbacks``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from everalgo.types import Candidate
|
||||
|
||||
from everos.memory.search.callbacks import (
|
||||
_SKILL_RERANK_INSTRUCTION,
|
||||
build_rerank_fn,
|
||||
build_skill_rerank_fn,
|
||||
)
|
||||
|
||||
|
||||
class _StubReranker:
|
||||
"""Returns candidates in original order with scores 1.0, 0.9, 0.8, ...
|
||||
|
||||
Records the ``instruction`` and ``passages`` from the most recent call so
|
||||
tests can assert that callback factories forward the right arguments.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.last_instruction: str | None = None
|
||||
self.last_passages: list[str] | None = None
|
||||
|
||||
async def rerank(
|
||||
self, query: str, passages: list[str], *, instruction: str | None = None
|
||||
) -> list[Any]:
|
||||
self.last_instruction = instruction
|
||||
self.last_passages = list(passages)
|
||||
|
||||
class _R:
|
||||
def __init__(self, index: int, score: float) -> None:
|
||||
self.index = index
|
||||
self.score = score
|
||||
|
||||
return [_R(i, 1.0 - i * 0.1) for i in range(len(passages))]
|
||||
|
||||
|
||||
def _cand(cid: str, episode_text: str = "body") -> Candidate:
|
||||
return Candidate(
|
||||
id=cid,
|
||||
score=0.5,
|
||||
source="vector",
|
||||
metadata={"episode": episode_text},
|
||||
)
|
||||
|
||||
|
||||
async def test_build_rerank_fn_returns_two_arg_callable() -> None:
|
||||
"""build_rerank_fn must return a 2-arg async callable matching RerankFn."""
|
||||
rerank_fn = build_rerank_fn(_StubReranker(), text_field="episode")
|
||||
sig = inspect.signature(rerank_fn)
|
||||
params = list(sig.parameters)
|
||||
assert params == ["query", "candidates"], f"Expected 2-arg fn, got params: {params}"
|
||||
|
||||
|
||||
async def test_build_rerank_fn_returns_all_candidates_without_truncation() -> None:
|
||||
"""rerank_fn must return ALL reranked candidates; caller slices."""
|
||||
rerank_fn = build_rerank_fn(_StubReranker(), text_field="episode")
|
||||
cands = [_cand(f"c{i}") for i in range(5)]
|
||||
result = await rerank_fn("what did Alice eat?", cands)
|
||||
assert len(result) == 5
|
||||
|
||||
|
||||
async def test_build_rerank_fn_attaches_scores_from_provider() -> None:
|
||||
"""rerank_fn updates Candidate.score from RerankProvider results."""
|
||||
rerank_fn = build_rerank_fn(_StubReranker(), text_field="episode")
|
||||
cands = [_cand("a"), _cand("b")]
|
||||
result = await rerank_fn("q", cands)
|
||||
assert all(isinstance(c.score, float) for c in result)
|
||||
assert result[0].score == pytest.approx(1.0)
|
||||
assert result[1].score == pytest.approx(0.9)
|
||||
|
||||
|
||||
async def test_build_rerank_fn_handles_empty_candidates() -> None:
|
||||
"""Empty candidate list returns empty list without calling the provider."""
|
||||
rerank_fn = build_rerank_fn(_StubReranker(), text_field="episode")
|
||||
result = await rerank_fn("q", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
async def test_build_rerank_fn_forwards_instruction() -> None:
|
||||
"""The task instruction is forwarded verbatim to the provider."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_rerank_fn(stub, text_field="episode", instruction="find facts")
|
||||
await rerank_fn("q", [_cand("a")])
|
||||
assert stub.last_instruction == "find facts"
|
||||
|
||||
|
||||
# ── build_skill_rerank_fn ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _skill_cand(cid: str, *, name: str = "", description: str = "") -> Candidate:
|
||||
return Candidate(
|
||||
id=cid,
|
||||
score=0.5,
|
||||
source="vector",
|
||||
metadata={"name": name, "description": description},
|
||||
)
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_emits_shaped_passage() -> None:
|
||||
"""Passage = ``"Agent Skill: {name} - {description}"`` when both present."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
await rerank_fn(
|
||||
"q",
|
||||
[_skill_cand("s1", name="refactor_auth", description="split provider lookup")],
|
||||
)
|
||||
assert stub.last_passages == ["Agent Skill: refactor_auth - split provider lookup"]
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_omits_dash_when_description_missing() -> None:
|
||||
"""When description is empty, drop ``" - {description}"`` suffix."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
await rerank_fn("q", [_skill_cand("s1", name="refactor_auth", description="")])
|
||||
assert stub.last_passages == ["Agent Skill: refactor_auth"]
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_falls_back_when_name_missing() -> None:
|
||||
"""When name is empty, passage degrades to bare description."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
await rerank_fn("q", [_skill_cand("s1", name="", description="just text")])
|
||||
assert stub.last_passages == ["just text"]
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_forwards_skill_instruction() -> None:
|
||||
"""The skill-specific instruction is hard-wired into the call."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
await rerank_fn("q", [_skill_cand("s1", name="x", description="y")])
|
||||
assert stub.last_instruction == _SKILL_RERANK_INSTRUCTION
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_handles_empty_candidates() -> None:
|
||||
"""Empty candidate list skips the provider call entirely."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
result = await rerank_fn("q", [])
|
||||
assert result == []
|
||||
assert stub.last_passages is None # provider never called
|
||||
|
||||
|
||||
async def test_build_skill_rerank_fn_attaches_scores_and_preserves_metadata() -> None:
|
||||
"""Reranked candidates carry the provider's score and original metadata."""
|
||||
stub = _StubReranker()
|
||||
rerank_fn = build_skill_rerank_fn(stub)
|
||||
cands = [
|
||||
_skill_cand("a", name="alpha", description="d-a"),
|
||||
_skill_cand("b", name="beta", description="d-b"),
|
||||
]
|
||||
result = await rerank_fn("q", cands)
|
||||
assert [c.id for c in result] == ["a", "b"]
|
||||
assert result[0].score == pytest.approx(1.0)
|
||||
assert result[1].score == pytest.approx(0.9)
|
||||
# metadata round-trips intact — the shape function only reads it, never mutates.
|
||||
assert result[0].metadata["name"] == "alpha"
|
||||
assert result[1].metadata["description"] == "d-b"
|
||||
135
tests/unit/test_memory/test_search/test_dto.py
Normal file
135
tests/unit/test_memory/test_search/test_dto.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""Unit tests for ``memory.search.dto`` validation rules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from everos.memory.search import (
|
||||
SearchData,
|
||||
SearchMethod,
|
||||
SearchRequest,
|
||||
SearchResponse,
|
||||
)
|
||||
|
||||
|
||||
def _minimal_request_kwargs() -> dict:
|
||||
return {
|
||||
"user_id": "alice",
|
||||
"query": "hello",
|
||||
}
|
||||
|
||||
|
||||
def test_enable_llm_rerank_defaults_to_false() -> None:
|
||||
"""HYBRID should NOT auto-trigger LLM Phase-5 rerank by default.
|
||||
|
||||
The caller opts in explicitly when they want the extra LLM pass;
|
||||
leaving it off keeps a default HYBRID call cheap (no LLM ``chat``).
|
||||
"""
|
||||
req = SearchRequest(**_minimal_request_kwargs())
|
||||
assert req.enable_llm_rerank is False
|
||||
|
||||
|
||||
def test_enable_llm_rerank_accepts_true() -> None:
|
||||
req = SearchRequest(**_minimal_request_kwargs(), enable_llm_rerank=True)
|
||||
assert req.enable_llm_rerank is True
|
||||
|
||||
|
||||
def test_minimal_request_uses_hybrid_default() -> None:
|
||||
req = SearchRequest(**_minimal_request_kwargs())
|
||||
assert req.method == SearchMethod.HYBRID
|
||||
assert req.top_k == -1
|
||||
assert req.include_profile is False
|
||||
assert req.filters is None
|
||||
assert req.radius is None
|
||||
|
||||
|
||||
def test_top_k_zero_rejected() -> None:
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
SearchRequest(**_minimal_request_kwargs(), top_k=0)
|
||||
assert "top_k" in str(exc.value)
|
||||
|
||||
|
||||
def test_top_k_above_100_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(**_minimal_request_kwargs(), top_k=101)
|
||||
|
||||
|
||||
def test_top_k_below_minus_one_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(**_minimal_request_kwargs(), top_k=-2)
|
||||
|
||||
|
||||
def test_top_k_minus_one_accepted() -> None:
|
||||
req = SearchRequest(**_minimal_request_kwargs(), top_k=-1)
|
||||
assert req.top_k == -1
|
||||
|
||||
|
||||
def test_top_k_in_range_accepted() -> None:
|
||||
req = SearchRequest(**_minimal_request_kwargs(), top_k=50)
|
||||
assert req.top_k == 50
|
||||
|
||||
|
||||
def test_radius_out_of_range_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(**_minimal_request_kwargs(), radius=1.5)
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(**_minimal_request_kwargs(), radius=-0.1)
|
||||
|
||||
|
||||
def test_neither_user_id_nor_agent_id_rejected() -> None:
|
||||
"""The xor validator requires exactly one of user_id / agent_id."""
|
||||
with pytest.raises(ValidationError, match="exactly one of"):
|
||||
SearchRequest(query="hello") # neither set
|
||||
|
||||
|
||||
def test_both_user_id_and_agent_id_rejected() -> None:
|
||||
"""The xor validator rejects ambiguous owner identity."""
|
||||
with pytest.raises(ValidationError, match="exactly one of"):
|
||||
SearchRequest(user_id="alice", agent_id="agent_x", query="hello")
|
||||
|
||||
|
||||
def test_empty_query_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(user_id="alice", query="")
|
||||
|
||||
|
||||
def test_empty_user_id_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(user_id="", query="hello")
|
||||
|
||||
|
||||
def test_extra_top_level_field_rejected() -> None:
|
||||
"""``extra='forbid'`` keeps the contract tight."""
|
||||
with pytest.raises(ValidationError):
|
||||
SearchRequest(
|
||||
**_minimal_request_kwargs(),
|
||||
unexpected_field="x", # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
|
||||
def test_filters_extra_keys_allowed() -> None:
|
||||
"""FilterNode is open-shape; safety is enforced in the compiler."""
|
||||
req = SearchRequest(
|
||||
**_minimal_request_kwargs(),
|
||||
filters={"session_id": "sess_a", "AND": [{"timestamp": {"gte": 1}}]},
|
||||
)
|
||||
assert req.filters is not None
|
||||
dumped = req.filters.model_dump(exclude_none=True)
|
||||
assert dumped["session_id"] == "sess_a"
|
||||
assert dumped["AND"][0]["timestamp"]["gte"] == 1
|
||||
|
||||
|
||||
def test_response_default_arrays_present() -> None:
|
||||
"""Every ``data.*`` array must exist so callers can iterate unconditionally."""
|
||||
resp = SearchResponse(request_id="0" * 32, data=SearchData())
|
||||
assert resp.data.episodes == []
|
||||
assert resp.data.profiles == []
|
||||
assert resp.data.agent_cases == []
|
||||
assert resp.data.agent_skills == []
|
||||
|
||||
|
||||
def test_method_enum_serialises_to_lowercase() -> None:
|
||||
req = SearchRequest(**_minimal_request_kwargs(), method="agentic") # type: ignore[arg-type]
|
||||
assert req.method == SearchMethod.AGENTIC
|
||||
assert req.method.value == "agentic"
|
||||
244
tests/unit/test_memory/test_search/test_filters.py
Normal file
244
tests/unit/test_memory/test_search/test_filters.py
Normal file
@ -0,0 +1,244 @@
|
||||
"""Unit tests for the Filters DSL compiler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.search import (
|
||||
FilterError,
|
||||
FilterNode,
|
||||
compile_filters,
|
||||
)
|
||||
|
||||
# ── Base injection ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_filters_emits_base_clause() -> None:
|
||||
where = compile_filters(None, owner_id="alice", owner_type="user")
|
||||
assert where == (
|
||||
"owner_id = 'alice' AND owner_type = 'user' "
|
||||
"AND app_id = 'default' AND project_id = 'default'"
|
||||
)
|
||||
|
||||
|
||||
def test_owner_type_agent_pinned() -> None:
|
||||
where = compile_filters(None, owner_id="alice", owner_type="agent")
|
||||
assert "owner_type = 'agent'" in where
|
||||
|
||||
|
||||
def test_app_project_scope_pinned() -> None:
|
||||
where = compile_filters(
|
||||
None,
|
||||
owner_id="alice",
|
||||
owner_type="user",
|
||||
app_id="claude_code",
|
||||
project_id="oss",
|
||||
)
|
||||
assert "app_id = 'claude_code'" in where
|
||||
assert "project_id = 'oss'" in where
|
||||
|
||||
|
||||
def test_owner_id_with_quote_is_escaped() -> None:
|
||||
where = compile_filters(None, owner_id="al'ice", owner_type="user")
|
||||
assert "owner_id = 'al''ice'" in where
|
||||
|
||||
|
||||
# ── Equality / shorthand ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_flat_equality_shorthand() -> None:
|
||||
node = FilterNode(session_id="sess_a") # type: ignore[call-arg]
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
|
||||
|
||||
def test_multiple_flat_fields_join_with_and() -> None:
|
||||
node = FilterNode.model_validate({"session_id": "sess_a", "parent_type": "memcell"})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "parent_type = 'memcell'" in where
|
||||
|
||||
|
||||
# ── Operators ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_timestamp_gte_renders_timestamp_literal() -> None:
|
||||
node = FilterNode.model_validate({"timestamp": {"gte": 1704067200000}})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "timestamp >= TIMESTAMP '" in where
|
||||
|
||||
|
||||
def test_timestamp_range_folds_with_and() -> None:
|
||||
node = FilterNode.model_validate(
|
||||
{"timestamp": {"gte": 1704067200000, "lt": 1740614399000}}
|
||||
)
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "timestamp >= TIMESTAMP '" in where
|
||||
assert "timestamp < TIMESTAMP '" in where
|
||||
# Operators on the same field are wrapped in a single group.
|
||||
assert " AND " in where
|
||||
|
||||
|
||||
def test_in_operator_string_field() -> None:
|
||||
node = FilterNode.model_validate({"parent_type": {"in": ["memcell", "episode"]}})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "parent_type IN ('memcell', 'episode')" in where
|
||||
|
||||
|
||||
def test_in_operator_requires_non_empty_list() -> None:
|
||||
node = FilterNode.model_validate({"parent_type": {"in": []}})
|
||||
with pytest.raises(FilterError):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_invalid_operator_rejected() -> None:
|
||||
node = FilterNode.model_validate({"timestamp": {"between": [1, 2]}})
|
||||
with pytest.raises(FilterError, match="operator"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
# ── Combinators ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_and_combinator() -> None:
|
||||
node = FilterNode.model_validate(
|
||||
{
|
||||
"AND": [
|
||||
{"timestamp": {"gte": 1704067200000}},
|
||||
{"timestamp": {"lt": 1740614399000}},
|
||||
]
|
||||
}
|
||||
)
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "timestamp >= TIMESTAMP '" in where
|
||||
assert "timestamp < TIMESTAMP '" in where
|
||||
assert " AND " in where
|
||||
|
||||
|
||||
def test_or_combinator() -> None:
|
||||
node = FilterNode.model_validate(
|
||||
{
|
||||
"OR": [
|
||||
{"parent_type": "memcell"},
|
||||
{"parent_type": "episode"},
|
||||
]
|
||||
}
|
||||
)
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert " OR " in where
|
||||
assert "parent_type = 'memcell'" in where
|
||||
assert "parent_type = 'episode'" in where
|
||||
|
||||
|
||||
def test_nested_and_inside_or() -> None:
|
||||
node = FilterNode.model_validate(
|
||||
{
|
||||
"OR": [
|
||||
{"AND": [{"parent_type": "memcell"}, {"session_id": "sa"}]},
|
||||
{"parent_type": "episode"},
|
||||
]
|
||||
}
|
||||
)
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "parent_type = 'memcell'" in where
|
||||
assert "session_id = 'sa'" in where
|
||||
assert "parent_type = 'episode'" in where
|
||||
assert " OR " in where
|
||||
assert " AND " in where
|
||||
|
||||
|
||||
def test_flat_field_alongside_and_combinator() -> None:
|
||||
node = FilterNode.model_validate(
|
||||
{
|
||||
"session_id": "sess_a",
|
||||
"AND": [{"timestamp": {"gte": 1}}],
|
||||
}
|
||||
)
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "session_id = 'sess_a'" in where
|
||||
assert "timestamp >= TIMESTAMP '" in where
|
||||
|
||||
|
||||
# ── Array field (sender_id → sender_ids) ────────────────────────────────
|
||||
|
||||
|
||||
def test_sender_id_eq_uses_array_has() -> None:
|
||||
node = FilterNode.model_validate({"sender_id": "u_jason"})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "array_has(sender_ids, 'u_jason')" in where
|
||||
|
||||
|
||||
def test_sender_id_in_expands_to_or_array_has() -> None:
|
||||
node = FilterNode.model_validate({"sender_id": {"in": ["u_a", "u_b"]}})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "array_has(sender_ids, 'u_a')" in where
|
||||
assert "array_has(sender_ids, 'u_b')" in where
|
||||
assert " OR " in where
|
||||
|
||||
|
||||
def test_sender_id_gt_rejected() -> None:
|
||||
node = FilterNode.model_validate({"sender_id": {"gt": "x"}})
|
||||
with pytest.raises(FilterError, match="not supported on array"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
# ── Safety ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_field_rejected() -> None:
|
||||
node = FilterNode.model_validate({"secret_field": "x"})
|
||||
with pytest.raises(FilterError, match="unsupported filter field"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_owner_id_in_filters_rejected() -> None:
|
||||
node = FilterNode.model_validate({"owner_id": "mallory"})
|
||||
with pytest.raises(FilterError, match="reserved"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_owner_type_in_filters_rejected() -> None:
|
||||
node = FilterNode.model_validate({"owner_type": "agent"})
|
||||
with pytest.raises(FilterError, match="reserved"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_string_with_single_quote_escaped() -> None:
|
||||
node = FilterNode.model_validate({"session_id": "ses's"})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert "session_id = 'ses''s'" in where
|
||||
|
||||
|
||||
def test_timestamp_string_with_quote_rejected() -> None:
|
||||
"""ISO strings with embedded quotes can break the literal — reject loudly."""
|
||||
node = FilterNode.model_validate({"timestamp": {"gte": "2024-01'-01T00:00:00"}})
|
||||
with pytest.raises(FilterError, match="contains a quote"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_in_value_type_check() -> None:
|
||||
node = FilterNode.model_validate({"parent_type": {"in": [1, 2]}})
|
||||
with pytest.raises(FilterError, match="must be a string"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_bool_for_timestamp_rejected() -> None:
|
||||
node = FilterNode.model_validate({"timestamp": {"gte": True}})
|
||||
with pytest.raises(FilterError, match="timestamp value"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_empty_operator_map_rejected() -> None:
|
||||
node = FilterNode.model_validate({"timestamp": {}})
|
||||
with pytest.raises(FilterError, match="empty operator map"):
|
||||
compile_filters(node, owner_id="alice", owner_type="user")
|
||||
|
||||
|
||||
def test_empty_and_array_skips_combinator() -> None:
|
||||
"""Empty AND/OR arrays compile to no clauses — only the base remains."""
|
||||
node = FilterNode.model_validate({"AND": []})
|
||||
where = compile_filters(node, owner_id="alice", owner_type="user")
|
||||
assert where == (
|
||||
"owner_id = 'alice' AND owner_type = 'user' "
|
||||
"AND app_id = 'default' AND project_id = 'default'"
|
||||
)
|
||||
278
tests/unit/test_memory/test_search/test_hierarchy.py
Normal file
278
tests/unit/test_memory/test_search/test_hierarchy.py
Normal file
@ -0,0 +1,278 @@
|
||||
"""Unit tests for ``memory.search.hierarchy``.
|
||||
|
||||
White-box surfaces accessed:
|
||||
- ``_hierarchy_eviction_pass`` (internal, tested directly for unit coverage)
|
||||
- ``hierarchy_retrieve_episodes`` (public function, tested with stubbed I/O)
|
||||
|
||||
All I/O (fact_recaller, episode_recaller) is injected via AsyncMock stubs.
|
||||
No LanceDB or network calls are made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from everalgo.types import Candidate, FactCandidate
|
||||
|
||||
from everos.memory.search.hierarchy import (
|
||||
_hierarchy_eviction_pass,
|
||||
hierarchy_retrieve_episodes,
|
||||
)
|
||||
|
||||
# ── Fixtures / helpers ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _episode_candidate(
|
||||
*,
|
||||
ep_id: str = "ep-1",
|
||||
score: float = 0.7,
|
||||
memcell_id: str = "mc-1",
|
||||
) -> Candidate:
|
||||
return Candidate(
|
||||
id=ep_id,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"parent_id": memcell_id,
|
||||
"owner_id": "u1",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess-1",
|
||||
"timestamp": _ts(),
|
||||
"episode": "Some episode text.",
|
||||
"sender_ids": ["u1"],
|
||||
"subject": "Test subject",
|
||||
"summary": "Test summary",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _fact_candidate(
|
||||
*,
|
||||
fact_id: str = "fact-1",
|
||||
parent_episode_id: str = "ep-1",
|
||||
score: float = 0.9,
|
||||
) -> FactCandidate:
|
||||
return FactCandidate(
|
||||
id=fact_id,
|
||||
parent_episode_id=parent_episode_id,
|
||||
score=score,
|
||||
metadata={"fact": "Some fact text."},
|
||||
)
|
||||
|
||||
|
||||
def _make_recallers(
|
||||
*,
|
||||
dense_facts: list[Candidate] | None = None,
|
||||
fetched_episodes: list[Candidate] | None = None,
|
||||
facts_for_episodes: dict[str, list[FactCandidate]] | None = None,
|
||||
) -> tuple[MagicMock, MagicMock]:
|
||||
"""Build stubbed fact_recaller and episode_recaller."""
|
||||
fact_recaller = MagicMock()
|
||||
fact_recaller.dense_recall = AsyncMock(return_value=dense_facts or [])
|
||||
fact_recaller.facts_for_episodes = AsyncMock(return_value=facts_for_episodes or {})
|
||||
|
||||
episode_recaller = MagicMock()
|
||||
episode_recaller.fetch_by_parent_ids = AsyncMock(
|
||||
return_value=fetched_episodes or []
|
||||
)
|
||||
|
||||
return fact_recaller, episode_recaller
|
||||
|
||||
|
||||
# ── _hierarchy_eviction_pass unit tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestHierarchyEvictionPass:
|
||||
def test_fact_wins_emits_atomic_fact_scored_item(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.5)
|
||||
fact = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.9)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [fact]})
|
||||
|
||||
assert len(result) == 1
|
||||
item = result[0]
|
||||
assert item.item_type == "atomic_fact"
|
||||
assert item.id == "fact-1"
|
||||
assert item.score == pytest.approx(0.9)
|
||||
|
||||
def test_episode_wins_emits_episode_scored_item(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.8)
|
||||
fact = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.6)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [fact]})
|
||||
|
||||
assert len(result) == 1
|
||||
item = result[0]
|
||||
assert item.item_type == "episode"
|
||||
assert item.id == "ep-1"
|
||||
assert item.score == pytest.approx(0.8)
|
||||
|
||||
def test_no_facts_emits_episode(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.7)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {})
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].item_type == "episode"
|
||||
assert result[0].id == "ep-1"
|
||||
|
||||
def test_ordering_preserved_matches_input_order(self) -> None:
|
||||
ep_a = _episode_candidate(ep_id="ep-a", score=0.9, memcell_id="mc-a")
|
||||
ep_b = _episode_candidate(ep_id="ep-b", score=0.8, memcell_id="mc-b")
|
||||
ep_c = _episode_candidate(ep_id="ep-c", score=0.7, memcell_id="mc-c")
|
||||
merged = [ep_a, ep_b, ep_c]
|
||||
|
||||
result = _hierarchy_eviction_pass(merged, {})
|
||||
|
||||
assert [r.id for r in result] == ["ep-a", "ep-b", "ep-c"]
|
||||
|
||||
def test_parent_episode_id_set_on_evicted_fact(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.4)
|
||||
fact = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.8)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [fact]})
|
||||
|
||||
assert result[0].parent_episode_id == "ep-1"
|
||||
|
||||
def test_episode_wins_parent_episode_id_is_none(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.9)
|
||||
fact = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.5)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [fact]})
|
||||
|
||||
assert result[0].parent_episode_id is None
|
||||
|
||||
def test_multiple_episodes_mixed_eviction(self) -> None:
|
||||
ep1 = _episode_candidate(ep_id="ep-1", score=0.5, memcell_id="mc-1")
|
||||
ep2 = _episode_candidate(ep_id="ep-2", score=0.8, memcell_id="mc-2")
|
||||
ep3 = _episode_candidate(ep_id="ep-3", score=0.6, memcell_id="mc-3")
|
||||
fact1 = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.9)
|
||||
fact2 = _fact_candidate(fact_id="fact-2", parent_episode_id="ep-2", score=0.4)
|
||||
|
||||
result = _hierarchy_eviction_pass(
|
||||
[ep1, ep2, ep3],
|
||||
{"ep-1": [fact1], "ep-2": [fact2]},
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0].item_type == "atomic_fact"
|
||||
assert result[0].id == "fact-1"
|
||||
assert result[1].item_type == "episode"
|
||||
assert result[1].id == "ep-2"
|
||||
assert result[2].item_type == "episode"
|
||||
assert result[2].id == "ep-3"
|
||||
|
||||
def test_best_fact_is_first_element_used_for_comparison(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.7)
|
||||
best_fact = _fact_candidate(
|
||||
fact_id="fact-best", parent_episode_id="ep-1", score=0.8
|
||||
)
|
||||
second_fact = _fact_candidate(
|
||||
fact_id="fact-second", parent_episode_id="ep-1", score=0.3
|
||||
)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [best_fact, second_fact]})
|
||||
|
||||
assert result[0].item_type == "atomic_fact"
|
||||
assert result[0].id == "fact-best"
|
||||
|
||||
def test_fact_score_equal_to_episode_score_episode_wins(self) -> None:
|
||||
episode = _episode_candidate(ep_id="ep-1", score=0.7)
|
||||
fact = _fact_candidate(fact_id="fact-1", parent_episode_id="ep-1", score=0.7)
|
||||
|
||||
result = _hierarchy_eviction_pass([episode], {"ep-1": [fact]})
|
||||
|
||||
assert result[0].item_type == "episode"
|
||||
|
||||
|
||||
# ── hierarchy_retrieve_episodes integration-style unit tests ─────────────
|
||||
|
||||
|
||||
class TestHierarchyRetrieveEpisodes:
|
||||
"""Integration-style unit tests with fully stubbed I/O.
|
||||
|
||||
amaxsim_retrieve and rrf are exercised with real implementations but
|
||||
all LanceDB / network calls are replaced by AsyncMock.
|
||||
"""
|
||||
|
||||
async def test_empty_sparse_dense_returns_empty_list(self) -> None:
|
||||
fact_recaller, episode_recaller = _make_recallers()
|
||||
|
||||
result = await hierarchy_retrieve_episodes(
|
||||
query="test query",
|
||||
sparse=[],
|
||||
dense=[],
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
fact_recaller=fact_recaller,
|
||||
episode_recaller=episode_recaller,
|
||||
where="owner_id = 'u1'",
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
async def test_happy_path_episode_wins_no_nested_facts(self) -> None:
|
||||
ep = _episode_candidate(ep_id="ep-1", score=0.8, memcell_id="mc-1")
|
||||
|
||||
fact_recaller, episode_recaller = _make_recallers(
|
||||
dense_facts=[],
|
||||
fetched_episodes=[],
|
||||
facts_for_episodes={},
|
||||
)
|
||||
|
||||
result = await hierarchy_retrieve_episodes(
|
||||
query="test query",
|
||||
sparse=[ep],
|
||||
dense=[ep],
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
fact_recaller=fact_recaller,
|
||||
episode_recaller=episode_recaller,
|
||||
where="owner_id = 'u1'",
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
episode_item = result[0]
|
||||
assert episode_item.id == "ep-1"
|
||||
assert episode_item.atomic_facts == []
|
||||
|
||||
async def test_happy_path_fact_evicts_episode_nested_in_result(self) -> None:
|
||||
ep = _episode_candidate(ep_id="ep-2", score=0.6, memcell_id="mc-2")
|
||||
fact = _fact_candidate(fact_id="fact-2", parent_episode_id="ep-2", score=0.95)
|
||||
|
||||
fact_recaller, episode_recaller = _make_recallers(
|
||||
dense_facts=[
|
||||
Candidate(
|
||||
id="fact-2",
|
||||
score=0.95,
|
||||
source="vector",
|
||||
metadata={"parent_id": "mc-2"},
|
||||
)
|
||||
],
|
||||
fetched_episodes=[ep],
|
||||
facts_for_episodes={"ep-2": [fact]},
|
||||
)
|
||||
|
||||
result = await hierarchy_retrieve_episodes(
|
||||
query="test query",
|
||||
sparse=[ep],
|
||||
dense=[ep],
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
fact_recaller=fact_recaller,
|
||||
episode_recaller=episode_recaller,
|
||||
where="owner_id = 'u1'",
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
episode_item = result[0]
|
||||
assert episode_item.atomic_facts != []
|
||||
nested_fact = episode_item.atomic_facts[0]
|
||||
assert nested_fact.id == "fact-2"
|
||||
assert nested_fact.score == pytest.approx(0.95)
|
||||
930
tests/unit/test_memory/test_search/test_manager.py
Normal file
930
tests/unit/test_memory/test_search/test_manager.py
Normal file
@ -0,0 +1,930 @@
|
||||
"""Unit tests for ``SearchManager`` with in-memory stub recallers.
|
||||
|
||||
These tests exercise the orchestration without touching LanceDB. Every
|
||||
recaller is replaced by a hand-rolled stub that returns a small
|
||||
candidate list; the manager's job is to:
|
||||
|
||||
* honour the ``owner_type`` hard partition,
|
||||
* run KEYWORD as sparse-only and leave ``atomic_facts`` empty,
|
||||
* run VECTOR as dense-only (and refuse when no embedding is wired),
|
||||
* let HYBRID run without an LLM by default; require LLM only when the
|
||||
caller sets ``enable_llm_rerank=True``,
|
||||
* refuse AGENTIC when reranker / LLM prerequisites are missing,
|
||||
* delegate AGENTIC to ``search_episodes_agentic`` and return its result.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import pytest
|
||||
from everalgo.types import Candidate, FactCandidate
|
||||
|
||||
from everos.memory.search.dto import SearchMethod, SearchRequest
|
||||
from everos.memory.search.manager import SearchManager
|
||||
|
||||
# ── Stubs ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _episode_row(
|
||||
eid: str, score: float = 0.8, memcell_id: str | None = None
|
||||
) -> Candidate:
|
||||
return Candidate(
|
||||
id=eid,
|
||||
score=score,
|
||||
source="keyword",
|
||||
metadata={
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess_a",
|
||||
"timestamp": _ts(),
|
||||
"sender_ids": ["alice"],
|
||||
"subject": f"subj {eid}",
|
||||
"summary": f"summary {eid}",
|
||||
"episode": f"body {eid}",
|
||||
"parent_id": memcell_id if memcell_id is not None else f"mc_{eid}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _case_row(cid: str) -> Candidate:
|
||||
return Candidate(
|
||||
id=cid,
|
||||
score=0.7,
|
||||
source="keyword",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"session_id": "sess_b",
|
||||
"timestamp": _ts(),
|
||||
"task_intent": f"intent {cid}",
|
||||
"approach": f"approach {cid}",
|
||||
"quality_score": 0.8,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _skill_row(sid: str) -> Candidate:
|
||||
return Candidate(
|
||||
id=sid,
|
||||
score=0.65,
|
||||
source="keyword",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"name": f"skill_{sid}",
|
||||
"description": f"desc {sid}",
|
||||
"content": f"content {sid}",
|
||||
"confidence": 0.9,
|
||||
"maturity_score": 0.6,
|
||||
"source_case_ids": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _StubEpisodeRecaller:
|
||||
kind: ClassVar[str] = "episode"
|
||||
everalgo_memory_type: ClassVar[str] = "episodic"
|
||||
text_field: ClassVar[str] = "episode"
|
||||
|
||||
def __init__(self, sparse: list[Candidate], dense: list[Candidate]) -> None:
|
||||
self._sparse = sparse
|
||||
self._dense = dense
|
||||
self.last_where: str | None = None
|
||||
|
||||
async def sparse_recall(
|
||||
self, query: str, where: str, *, limit: int
|
||||
) -> list[Candidate]:
|
||||
self.last_where = where
|
||||
return list(self._sparse[:limit])
|
||||
|
||||
async def dense_recall(
|
||||
self, vector: Sequence[float], where: str, *, limit: int
|
||||
) -> list[Candidate]:
|
||||
self.last_where = where
|
||||
return list(self._dense[:limit])
|
||||
|
||||
async def fetch_by_parent_ids(
|
||||
self, parent_ids: Sequence[str], where: str
|
||||
) -> list[Candidate]:
|
||||
# Index dense rows by their parent_id (memcell id) so the maxsim
|
||||
# path's reverse-resolve has something to return.
|
||||
by_parent = {str(c.metadata.get("parent_id", "")): c for c in self._dense}
|
||||
return [by_parent[p] for p in parent_ids if p in by_parent]
|
||||
|
||||
|
||||
class _StubAtomicFactRecaller:
|
||||
kind: ClassVar[str] = "atomic_fact"
|
||||
everalgo_memory_type: ClassVar[str] = "episodic"
|
||||
text_field: ClassVar[str] = "fact"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
facts_map: dict[str, list[FactCandidate]] | None = None,
|
||||
dense: list[Candidate] | None = None,
|
||||
) -> None:
|
||||
self._facts_map = facts_map or {}
|
||||
self._dense = dense or []
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return []
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
async def facts_for_episodes(
|
||||
self,
|
||||
ep_to_memcell: Mapping[str, str],
|
||||
where: str,
|
||||
*,
|
||||
per_episode: int,
|
||||
query_vector: Any = None,
|
||||
) -> dict[str, list[FactCandidate]]:
|
||||
# ``query_vector`` accepted to match the real recaller signature
|
||||
# Accepted to match the real recaller signature; stub doesn't use it.
|
||||
return {
|
||||
eid: self._facts_map.get(eid, [])[:per_episode] for eid in ep_to_memcell
|
||||
}
|
||||
|
||||
|
||||
class _StubAgentCaseRecaller:
|
||||
kind: ClassVar[str] = "agent_case"
|
||||
everalgo_memory_type: ClassVar[str] = "case"
|
||||
text_field: ClassVar[str] = "task_intent"
|
||||
|
||||
def __init__(self, sparse: list[Candidate], dense: list[Candidate]) -> None:
|
||||
self._sparse = sparse
|
||||
self._dense = dense
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._sparse)
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
|
||||
class _StubAgentSkillRecaller:
|
||||
kind: ClassVar[str] = "agent_skill"
|
||||
everalgo_memory_type: ClassVar[str] = "skill"
|
||||
text_field: ClassVar[str] = "description"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sparse: list[Candidate],
|
||||
dense: list[Candidate],
|
||||
by_case: list[Candidate] | None = None,
|
||||
) -> None:
|
||||
self._sparse = sparse
|
||||
self._dense = dense
|
||||
# Bridge recall fixture: reverse-resolved skills (``fetch_by_case_ids``).
|
||||
# Default empty — only the bridge tests populate this.
|
||||
self._by_case = by_case or []
|
||||
|
||||
async def sparse_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._sparse)
|
||||
|
||||
async def dense_recall(self, *_: Any, **__: Any) -> list[Candidate]:
|
||||
return list(self._dense)
|
||||
|
||||
async def fetch_by_case_ids(
|
||||
self, case_ids: Sequence[str], where: str, *, limit: int
|
||||
) -> list[Candidate]:
|
||||
return list(self._by_case)
|
||||
|
||||
|
||||
class _StubProfileRecaller:
|
||||
async def fetch(self, owner_id: str) -> list:
|
||||
return []
|
||||
|
||||
|
||||
class _StubEmbedding:
|
||||
def __init__(self, dim: int = 4) -> None:
|
||||
self.dim = dim
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
return [0.0] * self.dim
|
||||
|
||||
async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
|
||||
return [[0.0] * self.dim for _ in texts]
|
||||
|
||||
|
||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_manager(
|
||||
*,
|
||||
episode_sparse: list[Candidate] | None = None,
|
||||
episode_dense: list[Candidate] | None = None,
|
||||
case_sparse: list[Candidate] | None = None,
|
||||
case_dense: list[Candidate] | None = None,
|
||||
skill_sparse: list[Candidate] | None = None,
|
||||
skill_dense: list[Candidate] | None = None,
|
||||
skill_by_case: list[Candidate] | None = None,
|
||||
facts_map: dict[str, list[FactCandidate]] | None = None,
|
||||
atomic_fact_dense: list[Candidate] | None = None,
|
||||
embedding: _StubEmbedding | None = None,
|
||||
reranker: Any = None,
|
||||
llm_client: Any = None,
|
||||
) -> SearchManager:
|
||||
ep_recaller = _StubEpisodeRecaller(episode_sparse or [], episode_dense or [])
|
||||
return SearchManager(
|
||||
episode_recaller=ep_recaller,
|
||||
atomic_fact_recaller=_StubAtomicFactRecaller(facts_map, atomic_fact_dense),
|
||||
agent_case_recaller=_StubAgentCaseRecaller(case_sparse or [], case_dense or []),
|
||||
agent_skill_recaller=_StubAgentSkillRecaller(
|
||||
skill_sparse or [], skill_dense or [], skill_by_case
|
||||
),
|
||||
profile_recaller=_StubProfileRecaller(),
|
||||
embedding=embedding,
|
||||
reranker=reranker,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
|
||||
def _user_req(
|
||||
method: SearchMethod = SearchMethod.KEYWORD, **kwargs: Any
|
||||
) -> SearchRequest:
|
||||
return SearchRequest(user_id="alice", query="hi", method=method, **kwargs)
|
||||
|
||||
|
||||
def _agent_req(
|
||||
method: SearchMethod = SearchMethod.KEYWORD, **kwargs: Any
|
||||
) -> SearchRequest:
|
||||
return SearchRequest(agent_id="agent_a", query="hi", method=method, **kwargs)
|
||||
|
||||
|
||||
# ── KEYWORD: user owner ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_user_keyword_returns_episodes_only() -> None:
|
||||
mgr = _build_manager(episode_sparse=[_episode_row("ep_1")])
|
||||
resp = await mgr.search(_user_req())
|
||||
assert len(resp.request_id) == 32 and all(
|
||||
c in "0123456789abcdef" for c in resp.request_id
|
||||
)
|
||||
assert len(resp.data.episodes) == 1
|
||||
assert resp.data.episodes[0].id == "ep_1"
|
||||
assert resp.data.episodes[0].user_id == "alice"
|
||||
assert resp.data.episodes[0].type == "Conversation"
|
||||
# Agent paths stay empty.
|
||||
assert resp.data.agent_cases == []
|
||||
assert resp.data.agent_skills == []
|
||||
assert resp.data.profiles == []
|
||||
|
||||
|
||||
async def test_user_keyword_leaves_atomic_facts_empty() -> None:
|
||||
"""KEYWORD never back-fills facts — only HYBRID produces relevance-scored facts.
|
||||
|
||||
Even if the facts repository would return rows for the matched
|
||||
episode, the keyword path must leave ``atomic_facts=[]``: there is
|
||||
no per-query score for those facts, so emitting them would muddy
|
||||
the contract (mirrors enterprise where event_log is a separate
|
||||
memory_type, not auto-attached to episodic results).
|
||||
"""
|
||||
fact = FactCandidate(
|
||||
id="f1",
|
||||
parent_episode_id="ep_1",
|
||||
score=0.0,
|
||||
metadata={"fact": "Alice prefers oat milk"},
|
||||
)
|
||||
mgr = _build_manager(
|
||||
episode_sparse=[_episode_row("ep_1")],
|
||||
facts_map={"ep_1": [fact]},
|
||||
)
|
||||
resp = await mgr.search(_user_req())
|
||||
ep = resp.data.episodes[0]
|
||||
assert ep.atomic_facts == []
|
||||
|
||||
|
||||
async def test_user_keyword_no_results() -> None:
|
||||
resp = await _build_manager().search(_user_req())
|
||||
assert resp.data.episodes == []
|
||||
|
||||
|
||||
async def test_user_keyword_filters_compile_pinned_owner() -> None:
|
||||
"""``compile_filters`` should pin owner_id / owner_type on the where."""
|
||||
recaller = _StubEpisodeRecaller([_episode_row("ep_1")], [])
|
||||
mgr = SearchManager(
|
||||
episode_recaller=recaller,
|
||||
atomic_fact_recaller=_StubAtomicFactRecaller(),
|
||||
agent_case_recaller=_StubAgentCaseRecaller([], []),
|
||||
agent_skill_recaller=_StubAgentSkillRecaller([], []),
|
||||
profile_recaller=_StubProfileRecaller(),
|
||||
embedding=None,
|
||||
reranker=None,
|
||||
llm_client=None,
|
||||
)
|
||||
await mgr.search(_user_req())
|
||||
assert recaller.last_where is not None
|
||||
assert "owner_id = 'alice'" in recaller.last_where
|
||||
assert "owner_type = 'user'" in recaller.last_where
|
||||
|
||||
|
||||
# ── VECTOR: requires embedding ────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_vector_method_requires_embedding() -> None:
|
||||
mgr = _build_manager() # embedding=None by default
|
||||
with pytest.raises(RuntimeError, match="embedding"):
|
||||
await mgr.search(_user_req(method=SearchMethod.VECTOR))
|
||||
|
||||
|
||||
async def test_vector_method_runs_dense_only_with_embedding() -> None:
|
||||
mgr = _build_manager(
|
||||
episode_sparse=[_episode_row("should_not_appear")],
|
||||
episode_dense=[_episode_row("ep_dense")],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR))
|
||||
assert [e.id for e in resp.data.episodes] == ["ep_dense"]
|
||||
|
||||
|
||||
async def test_vector_radius_filter_drops_below_threshold() -> None:
|
||||
mgr = _build_manager(
|
||||
episode_dense=[
|
||||
_episode_row("ep_low", score=0.3),
|
||||
_episode_row("ep_high", score=0.9),
|
||||
],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, radius=0.5))
|
||||
assert [e.id for e in resp.data.episodes] == ["ep_high"]
|
||||
|
||||
|
||||
async def test_unlimited_mode_applies_default_radius_for_vector() -> None:
|
||||
"""``top_k=-1`` without an explicit radius gets the project default 0.5.
|
||||
|
||||
Mirrors enterprise's auto-floor behaviour — unlimited mode must not
|
||||
return arbitrarily low-similarity tail.
|
||||
"""
|
||||
mgr = _build_manager(
|
||||
episode_dense=[
|
||||
_episode_row("ep_low", score=0.3), # below default 0.5 → dropped
|
||||
_episode_row("ep_mid", score=0.55), # above default → kept
|
||||
_episode_row("ep_high", score=0.9),
|
||||
],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, top_k=-1))
|
||||
assert [e.id for e in resp.data.episodes] == ["ep_mid", "ep_high"]
|
||||
|
||||
|
||||
async def test_unlimited_mode_explicit_radius_overrides_default() -> None:
|
||||
"""Caller-supplied radius (even ``0.0``) wins over the unlimited default."""
|
||||
mgr = _build_manager(
|
||||
episode_dense=[
|
||||
_episode_row("ep_low", score=0.2),
|
||||
_episode_row("ep_high", score=0.9),
|
||||
],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, top_k=-1, radius=0.1))
|
||||
# 0.1 threshold keeps both rows (the default 0.5 would have dropped ep_low).
|
||||
assert {e.id for e in resp.data.episodes} == {"ep_low", "ep_high"}
|
||||
|
||||
|
||||
async def test_normal_mode_keeps_full_pool_when_no_radius() -> None:
|
||||
"""``top_k > 0`` without a radius applies no threshold — truncation handles tail."""
|
||||
mgr = _build_manager(
|
||||
episode_dense=[
|
||||
_episode_row("ep_low", score=0.2),
|
||||
_episode_row("ep_high", score=0.9),
|
||||
],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, top_k=10))
|
||||
# No radius default in normal mode → both kept.
|
||||
assert {e.id for e in resp.data.episodes} == {"ep_low", "ep_high"}
|
||||
|
||||
|
||||
# ── VECTOR + maxsim_atomic strategy ─────────────────────────────────────
|
||||
|
||||
|
||||
def _atomic_fact_row(fid: str, *, parent_id: str, score: float) -> Candidate:
|
||||
"""Atomic-fact candidate emitted by ``AtomicFactRecaller.dense_recall``."""
|
||||
return Candidate(
|
||||
id=fid,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess_a",
|
||||
"timestamp": _ts(),
|
||||
"sender_ids": ["alice"],
|
||||
"parent_id": parent_id,
|
||||
"fact": f"fact {fid}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_vector_maxsim_atomic_max_pools_facts_to_episodes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""``vector_strategy=maxsim_atomic`` should ANN atomic_facts → max-pool by
|
||||
memcell parent → reverse-resolve to episode, ordering episodes by the
|
||||
per-memcell maximum fact score."""
|
||||
from everos.config.settings import load_settings
|
||||
|
||||
monkeypatch.setenv("EVEROS_SEARCH__VECTOR_STRATEGY", "maxsim_atomic")
|
||||
load_settings.cache_clear()
|
||||
# Two episodes; each has two atomic facts under it. The max fact score
|
||||
# per memcell is what should end up as the episode's score.
|
||||
mgr = _build_manager(
|
||||
episode_dense=[
|
||||
_episode_row("ep_A", memcell_id="mc_A"),
|
||||
_episode_row("ep_B", memcell_id="mc_B"),
|
||||
],
|
||||
atomic_fact_dense=[
|
||||
_atomic_fact_row("f_A1", parent_id="mc_A", score=0.95),
|
||||
_atomic_fact_row("f_A2", parent_id="mc_A", score=0.40),
|
||||
_atomic_fact_row("f_B1", parent_id="mc_B", score=0.75),
|
||||
],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, top_k=5))
|
||||
eps = resp.data.episodes
|
||||
# Both episodes returned, ordered by max-pool score desc.
|
||||
assert [e.id for e in eps] == ["ep_A", "ep_B"]
|
||||
assert eps[0].score == pytest.approx(0.95) # max(0.95, 0.40)
|
||||
assert eps[1].score == pytest.approx(0.75)
|
||||
|
||||
|
||||
async def test_vector_maxsim_atomic_returns_empty_when_no_facts(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""No fact recall → no memcells to score → empty episode list."""
|
||||
from everos.config.settings import load_settings
|
||||
|
||||
monkeypatch.setenv("EVEROS_SEARCH__VECTOR_STRATEGY", "maxsim_atomic")
|
||||
load_settings.cache_clear()
|
||||
mgr = _build_manager(
|
||||
episode_dense=[_episode_row("ep_A", memcell_id="mc_A")],
|
||||
atomic_fact_dense=[],
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.VECTOR, top_k=5))
|
||||
assert resp.data.episodes == []
|
||||
|
||||
|
||||
# ── HYBRID / AGENTIC: prerequisite errors ──────────────────────────────
|
||||
|
||||
|
||||
async def test_hybrid_requires_embedding() -> None:
|
||||
mgr = _build_manager()
|
||||
with pytest.raises(RuntimeError, match="embedding"):
|
||||
await mgr.search(_user_req(method=SearchMethod.HYBRID))
|
||||
|
||||
|
||||
async def test_hybrid_does_not_require_llm_by_default() -> None:
|
||||
"""HYBRID no longer auto-pulls LLM. With enable_llm_rerank=False the
|
||||
fusion-only path (RRF / LR) should run without an LLM client."""
|
||||
mgr = _build_manager(embedding=_StubEmbedding())
|
||||
# Should not raise: no LLM needed when caller opts out of Phase-5 rerank.
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.HYBRID))
|
||||
assert resp.data.episodes == [] # empty stub recallers → empty result
|
||||
|
||||
|
||||
async def test_hybrid_requires_llm_when_enable_llm_rerank_true() -> None:
|
||||
"""Setting ``enable_llm_rerank=True`` makes the LLM mandatory."""
|
||||
mgr = _build_manager(embedding=_StubEmbedding())
|
||||
with pytest.raises(RuntimeError, match="enable_llm_rerank"):
|
||||
await mgr.search(_user_req(method=SearchMethod.HYBRID, enable_llm_rerank=True))
|
||||
|
||||
|
||||
async def test_user_hybrid_episode_fuses_and_evicts_facts() -> None:
|
||||
"""HYBRID episode path: hierarchy pipeline (RRF -> MaxSim -> merge -> eviction).
|
||||
|
||||
ep_1 has a fact scoring higher than the RRF score -> fact evicts episode.
|
||||
ep_2 has no facts -> episode emitted as-is.
|
||||
"""
|
||||
ep1 = _episode_row("ep_1", score=0.8, memcell_id="mc_1")
|
||||
ep2 = _episode_row("ep_2", score=0.7, memcell_id="mc_2")
|
||||
fact1 = FactCandidate(
|
||||
id="f1",
|
||||
parent_episode_id="ep_1",
|
||||
score=0.95,
|
||||
metadata={"fact": "Alice prefers oat milk"},
|
||||
)
|
||||
mgr = _build_manager(
|
||||
episode_sparse=[ep1, ep2],
|
||||
episode_dense=[ep1, ep2],
|
||||
facts_map={"ep_1": [fact1]},
|
||||
embedding=_StubEmbedding(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.HYBRID, top_k=10))
|
||||
eps = resp.data.episodes
|
||||
assert len(eps) >= 1
|
||||
ep1_result = next((e for e in eps if e.id == "ep_1"), None)
|
||||
assert ep1_result is not None
|
||||
assert len(ep1_result.atomic_facts) == 1
|
||||
assert ep1_result.atomic_facts[0].id == "f1"
|
||||
|
||||
|
||||
async def test_agentic_requires_reranker_and_llm() -> None:
|
||||
mgr = _build_manager(embedding=_StubEmbedding())
|
||||
with pytest.raises(RuntimeError, match="rerank provider"):
|
||||
await mgr.search(_user_req(method=SearchMethod.AGENTIC))
|
||||
|
||||
|
||||
async def test_agent_hybrid_requires_reranker_without_llm_rerank() -> None:
|
||||
"""``owner_type='agent'`` + HYBRID + ``enable_llm_rerank=False`` reaches
|
||||
the skill cross-encoder lane (``skill_hybrid``: rrf → cross-encoder),
|
||||
so a missing rerank provider must fail-fast with a config hint rather
|
||||
than crash deep inside the rerank callback.
|
||||
"""
|
||||
mgr = _build_manager(embedding=_StubEmbedding())
|
||||
with pytest.raises(RuntimeError, match="rerank provider"):
|
||||
await mgr.search(_agent_req(method=SearchMethod.HYBRID))
|
||||
|
||||
|
||||
async def test_agent_hybrid_with_llm_rerank_does_not_need_reranker() -> None:
|
||||
"""The LLM-rerank lane skips the cross-encoder and dispatches through
|
||||
``arank`` instead, so a missing reranker is fine as long as the LLM
|
||||
client is configured. Empty stub recallers → empty result; the call
|
||||
must not raise on the reranker-absence path.
|
||||
"""
|
||||
mgr = _build_manager(embedding=_StubEmbedding(), llm_client=_StubLLM())
|
||||
resp = await mgr.search(
|
||||
_agent_req(method=SearchMethod.HYBRID, enable_llm_rerank=True)
|
||||
)
|
||||
assert resp.data.agent_skills == []
|
||||
assert resp.data.agent_cases == []
|
||||
|
||||
|
||||
class _StubReranker:
|
||||
"""Minimal reranker stub — returns trivial scores."""
|
||||
|
||||
async def rerank(self, query: str, documents: Sequence[str]) -> list[Any]:
|
||||
from everos.component.rerank.protocol import RerankResult
|
||||
|
||||
return [RerankResult(index=i, score=1.0) for i in range(len(documents))]
|
||||
|
||||
|
||||
class _StubLLM:
|
||||
"""Minimal LLM stub — satisfies protocol without making real calls."""
|
||||
|
||||
async def chat(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return ""
|
||||
|
||||
|
||||
async def test_agentic_episode_delegates_to_search_episodes_agentic(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""AGENTIC method delegates to search_episodes_agentic and returns its result."""
|
||||
import datetime as _dt
|
||||
|
||||
from everos.memory.search.dto import SearchEpisodeItem
|
||||
|
||||
fake_result = [
|
||||
SearchEpisodeItem(
|
||||
id="ep_1",
|
||||
score=0.9,
|
||||
session_id="s",
|
||||
user_id="alice",
|
||||
timestamp=_dt.datetime(2026, 1, 1, tzinfo=_dt.UTC),
|
||||
sender_ids=["alice"],
|
||||
subject="s",
|
||||
summary="s",
|
||||
episode="body",
|
||||
type="Conversation",
|
||||
atomic_facts=[],
|
||||
)
|
||||
]
|
||||
|
||||
async def _fake_agentic(*args: Any, **kwargs: Any) -> list[SearchEpisodeItem]:
|
||||
return fake_result
|
||||
|
||||
monkeypatch.setattr(
|
||||
"everos.memory.search.manager.search_episodes_agentic", _fake_agentic
|
||||
)
|
||||
|
||||
mgr = _build_manager(
|
||||
embedding=_StubEmbedding(),
|
||||
reranker=_StubReranker(),
|
||||
llm_client=_StubLLM(),
|
||||
)
|
||||
resp = await mgr.search(_user_req(method=SearchMethod.AGENTIC))
|
||||
assert resp.data.episodes == fake_result
|
||||
|
||||
|
||||
# ── AGENT owner hard partition ─────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agent_keyword_returns_cases_and_skills_only() -> None:
|
||||
mgr = _build_manager(
|
||||
case_sparse=[_case_row("c_1")],
|
||||
skill_sparse=[_skill_row("s_1")],
|
||||
)
|
||||
resp = await mgr.search(_agent_req())
|
||||
assert resp.data.episodes == []
|
||||
assert resp.data.profiles == []
|
||||
assert [c.id for c in resp.data.agent_cases] == ["c_1"]
|
||||
assert [s.id for s in resp.data.agent_skills] == ["s_1"]
|
||||
|
||||
|
||||
async def test_agent_owner_ignores_include_profile() -> None:
|
||||
"""Profile is user-only at this revision."""
|
||||
mgr = _build_manager()
|
||||
resp = await mgr.search(_agent_req(include_profile=True))
|
||||
assert resp.data.profiles == []
|
||||
|
||||
|
||||
# ── Top-k behaviour ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_top_k_truncates_results() -> None:
|
||||
rows = [_episode_row(f"ep_{i}", score=1.0 - i * 0.01) for i in range(10)]
|
||||
mgr = _build_manager(episode_sparse=rows)
|
||||
resp = await mgr.search(_user_req(top_k=3))
|
||||
assert [e.id for e in resp.data.episodes] == ["ep_0", "ep_1", "ep_2"]
|
||||
|
||||
|
||||
async def test_top_k_minus_one_caps_at_100() -> None:
|
||||
rows = [_episode_row(f"ep_{i}") for i in range(120)]
|
||||
mgr = _build_manager(episode_sparse=rows)
|
||||
resp = await mgr.search(_user_req(top_k=-1))
|
||||
assert len(resp.data.episodes) == 100
|
||||
|
||||
|
||||
# ── AGENTIC agent_case / agent_skill delegation ───────────────────────────
|
||||
|
||||
|
||||
async def test_agentic_agent_cases_delegates_to_search_agent_cases_agentic(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""AGENTIC method for agent owner delegates to search_agent_cases_agentic."""
|
||||
import datetime as _dt
|
||||
|
||||
from everos.memory.search.dto import SearchAgentCaseItem
|
||||
|
||||
fake_cases = [
|
||||
SearchAgentCaseItem(
|
||||
id="c_1",
|
||||
agent_id="agent_a",
|
||||
session_id="sess_b",
|
||||
timestamp=_dt.datetime(2026, 1, 1, tzinfo=_dt.UTC),
|
||||
task_intent="handle login",
|
||||
approach="retry with backoff",
|
||||
quality_score=0.9,
|
||||
score=0.85,
|
||||
)
|
||||
]
|
||||
|
||||
async def _fake_cases_agentic(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> list[SearchAgentCaseItem]:
|
||||
return fake_cases
|
||||
|
||||
monkeypatch.setattr(
|
||||
"everos.memory.search.manager.search_agent_cases_agentic",
|
||||
_fake_cases_agentic,
|
||||
)
|
||||
|
||||
mgr = _build_manager(
|
||||
embedding=_StubEmbedding(),
|
||||
reranker=_StubReranker(),
|
||||
llm_client=_StubLLM(),
|
||||
)
|
||||
resp = await mgr.search(_agent_req(method=SearchMethod.AGENTIC))
|
||||
assert resp.data.agent_cases == fake_cases
|
||||
|
||||
|
||||
async def test_agentic_agent_skills_delegates_to_search_agent_skills_agentic(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""AGENTIC method for agent owner delegates to search_agent_skills_agentic."""
|
||||
|
||||
from everos.memory.search.dto import SearchAgentSkillItem
|
||||
|
||||
fake_skills = [
|
||||
SearchAgentSkillItem(
|
||||
id="s_1",
|
||||
agent_id="agent_a",
|
||||
name="auth_refresh",
|
||||
description="Refreshes auth tokens",
|
||||
content="Retry with new token",
|
||||
confidence=0.9,
|
||||
maturity_score=0.7,
|
||||
source_case_ids=[],
|
||||
score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
async def _fake_skills_agentic(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> list[SearchAgentSkillItem]:
|
||||
return fake_skills
|
||||
|
||||
monkeypatch.setattr(
|
||||
"everos.memory.search.manager.search_agent_skills_agentic",
|
||||
_fake_skills_agentic,
|
||||
)
|
||||
|
||||
mgr = _build_manager(
|
||||
embedding=_StubEmbedding(),
|
||||
reranker=_StubReranker(),
|
||||
llm_client=_StubLLM(),
|
||||
)
|
||||
resp = await mgr.search(_agent_req(method=SearchMethod.AGENTIC))
|
||||
assert resp.data.agent_skills == fake_skills
|
||||
|
||||
|
||||
# ── _merge_by_id_max / _case_bridged_skills helpers ──────────────────────
|
||||
|
||||
|
||||
def test_merge_by_id_max_keeps_higher_score_on_collision() -> None:
|
||||
"""Same-id collision → keep the higher score; non-colliding rows are
|
||||
unioned. Used to fold bridge candidates into the direct dense pool.
|
||||
"""
|
||||
from everos.memory.search.manager import _merge_by_id_max
|
||||
|
||||
primary = [
|
||||
Candidate(id="s1", score=0.5, source="vector", metadata={"src": "primary"}),
|
||||
Candidate(id="s2", score=0.7, source="vector", metadata={"src": "primary"}),
|
||||
]
|
||||
extra = [
|
||||
Candidate(id="s1", score=0.9, source="vector", metadata={"src": "bridge"}),
|
||||
Candidate(id="s2", score=0.3, source="vector", metadata={"src": "bridge"}),
|
||||
Candidate(id="s3", score=0.6, source="vector", metadata={"src": "bridge"}),
|
||||
]
|
||||
merged = {c.id: c for c in _merge_by_id_max(primary, extra)}
|
||||
# s1 collision → bridge wins (0.9 > 0.5); s2 collision → primary wins
|
||||
# (0.7 > 0.3); s3 fresh-from-bridge is added.
|
||||
assert merged["s1"].score == 0.9
|
||||
assert merged["s1"].metadata["src"] == "bridge"
|
||||
assert merged["s2"].score == 0.7
|
||||
assert merged["s2"].metadata["src"] == "primary"
|
||||
assert merged["s3"].score == 0.6
|
||||
|
||||
|
||||
async def test_case_bridged_skills_max_pools_score_across_source_cases() -> None:
|
||||
"""Each bridged skill inherits the highest score among its matched
|
||||
source cases (mirrors the ``maxsim_atomic`` fact→episode pooling).
|
||||
Source cases not present in the bridge pool are ignored.
|
||||
"""
|
||||
skill_row = Candidate(
|
||||
id="agent_a_skill_x",
|
||||
score=0.0, # bridge ignores the recaller-side score
|
||||
source="vector",
|
||||
metadata={"source_case_ids": ["c1", "c2", "c3"], "name": "x"},
|
||||
)
|
||||
mgr = _build_manager(skill_by_case=[skill_row])
|
||||
bridge_cases = [
|
||||
Candidate(id="c1", score=0.4, source="vector", metadata={}),
|
||||
Candidate(id="c2", score=0.9, source="vector", metadata={}), # max wins
|
||||
Candidate(id="c_other", score=0.7, source="vector", metadata={}),
|
||||
]
|
||||
bridged = await mgr._case_bridged_skills(bridge_cases, where="", top_k=5)
|
||||
assert len(bridged) == 1
|
||||
assert bridged[0].id == "agent_a_skill_x"
|
||||
# c1=0.4 and c2=0.9 are in the bridge pool; c3 is not → max-pool == 0.9.
|
||||
assert bridged[0].score == pytest.approx(0.9)
|
||||
# Metadata (incl. ``source_case_ids``) rides through so downstream
|
||||
# shaping doesn't need a second fetch.
|
||||
assert bridged[0].metadata["source_case_ids"] == ["c1", "c2", "c3"]
|
||||
|
||||
|
||||
async def test_case_bridged_skills_returns_empty_for_none_or_empty_input() -> None:
|
||||
"""No bridge cases ⇒ no bridge recall (skip the reverse fetch entirely).
|
||||
This is the cross-encoder lane / KEYWORD / VECTOR contract.
|
||||
"""
|
||||
mgr = _build_manager(skill_by_case=[_skill_row("s1")]) # noise the stub
|
||||
assert await mgr._case_bridged_skills(None, where="", top_k=5) == []
|
||||
assert await mgr._case_bridged_skills([], where="", top_k=5) == []
|
||||
|
||||
|
||||
# ── Agent HYBRID lane selection ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_agent_hybrid_no_llm_rerank_runs_cross_encoder_lane(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""``enable_llm_rerank=False`` for agent HYBRID must dispatch through
|
||||
``search_agent_skills_hybrid`` (rrf → cross-encoder lane) with the
|
||||
configured reranker, not through generic ``arank``.
|
||||
"""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_hybrid(
|
||||
query: str,
|
||||
*,
|
||||
sparse: list[Candidate],
|
||||
dense: list[Candidate],
|
||||
reranker: Any,
|
||||
top_k: int,
|
||||
) -> list:
|
||||
captured.update(
|
||||
query=query, sparse=sparse, dense=dense, reranker=reranker, top_k=top_k
|
||||
)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"everos.memory.search.manager.search_agent_skills_hybrid", _fake_hybrid
|
||||
)
|
||||
stub_reranker = _StubReranker()
|
||||
mgr = _build_manager(embedding=_StubEmbedding(), reranker=stub_reranker)
|
||||
await mgr.search(_agent_req(method=SearchMethod.HYBRID))
|
||||
|
||||
assert captured["query"] == "hi"
|
||||
# Manager forwards its configured reranker to the cross-encoder lane.
|
||||
assert captured["reranker"] is stub_reranker
|
||||
# Agent kinds cap unlimited-mode top_k at _AGENT_TOP_K_CAP (10).
|
||||
assert captured["top_k"] == 10
|
||||
|
||||
|
||||
async def test_agent_hybrid_llm_rerank_dispatches_arank_for_case_then_skill(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""LLM rerank lane: ``_search_cases_and_skills`` runs serially —
|
||||
``arank`` is called once with ``memory_type="case"`` and once with
|
||||
``memory_type="skill"``, both with ``enable_rerank=True`` + the LLM
|
||||
client. Order matters: the case call must precede the skill call so
|
||||
its results can feed the bridge.
|
||||
"""
|
||||
from everalgo.types import RankOutput
|
||||
|
||||
calls: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
async def _fake_arank(rank_input: Any, **kwargs: Any) -> RankOutput:
|
||||
calls.append((rank_input.memory_type, kwargs))
|
||||
return RankOutput(items=[], metadata={})
|
||||
|
||||
monkeypatch.setattr("everos.memory.search.manager.arank", _fake_arank)
|
||||
mgr = _build_manager(embedding=_StubEmbedding(), llm_client=_StubLLM())
|
||||
await mgr.search(_agent_req(method=SearchMethod.HYBRID, enable_llm_rerank=True))
|
||||
|
||||
# Two dispatches in the documented serial order.
|
||||
assert [c[0] for c in calls] == ["case", "skill"]
|
||||
# Both runs opt into rerank with the LLM client wired in.
|
||||
for _mt, kw in calls:
|
||||
assert kw["enable_rerank"] is True
|
||||
assert kw["llm"] is mgr._llm
|
||||
assert kw["rerank_top_k"] == 10 # _AGENT_TOP_K_CAP
|
||||
|
||||
|
||||
async def test_agent_hybrid_llm_rerank_merges_bridged_skills_into_dense_pool(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""The bridge must surface into the skill dispatch: skills resolved
|
||||
by ``fetch_by_case_ids`` are max-pooled into the dense candidates that
|
||||
``arank`` sees on the second call, while the direct skill recall pool
|
||||
is preserved.
|
||||
"""
|
||||
from everalgo.types import RankOutput, ScoredItem
|
||||
|
||||
case_result = ScoredItem(
|
||||
id="agent_a_c1",
|
||||
score=0.85,
|
||||
item_type="case",
|
||||
# Shaper requires owner_type="agent" + timestamp + intent/approach;
|
||||
# otherwise the case is dropped and bridge_cases comes back empty.
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"session_id": "sess_b",
|
||||
"timestamp": _ts(),
|
||||
"task_intent": "intent c1",
|
||||
"approach": "approach c1",
|
||||
"quality_score": 0.8,
|
||||
},
|
||||
)
|
||||
skill_direct = _skill_row("s_direct")
|
||||
skill_bridged = Candidate(
|
||||
id="s_bridged",
|
||||
score=0.0,
|
||||
source="vector",
|
||||
metadata={"source_case_ids": ["agent_a_c1"], "name": "s_bridged"},
|
||||
)
|
||||
|
||||
seen_skill_dense: dict[str, list[Candidate]] = {}
|
||||
|
||||
async def _fake_arank(rank_input: Any, **_: Any) -> RankOutput:
|
||||
if rank_input.memory_type == "case":
|
||||
return RankOutput(items=[case_result], metadata={})
|
||||
# skill call — capture the merged dense pool the manager built.
|
||||
seen_skill_dense["dense"] = list(rank_input.dense_candidates)
|
||||
return RankOutput(items=[], metadata={})
|
||||
|
||||
monkeypatch.setattr("everos.memory.search.manager.arank", _fake_arank)
|
||||
mgr = _build_manager(
|
||||
embedding=_StubEmbedding(),
|
||||
llm_client=_StubLLM(),
|
||||
skill_sparse=[],
|
||||
skill_dense=[skill_direct],
|
||||
skill_by_case=[skill_bridged],
|
||||
)
|
||||
await mgr.search(_agent_req(method=SearchMethod.HYBRID, enable_llm_rerank=True))
|
||||
|
||||
dense_ids = {c.id for c in seen_skill_dense["dense"]}
|
||||
# Direct dense recall is preserved AND the case-bridged skill is unioned.
|
||||
assert dense_ids == {"s_direct", "s_bridged"}
|
||||
# The bridged skill inherits the matched case's score (0.85 from c1).
|
||||
by_id = {c.id: c for c in seen_skill_dense["dense"]}
|
||||
assert by_id["s_bridged"].score == pytest.approx(0.85)
|
||||
145
tests/unit/test_memory/test_search/test_recall_agent_skill.py
Normal file
145
tests/unit/test_memory/test_search/test_recall_agent_skill.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Real-LanceDB tests for ``AgentSkillRecaller.fetch_by_case_ids``.
|
||||
|
||||
The case→skill bridge reverse-resolves skills by ``source_case_ids``
|
||||
membership using DataFusion's ``array_has`` on a ``list<utf8>`` column.
|
||||
These tests exercise the actual SQL ``where`` predicate (no recaller
|
||||
stubs):
|
||||
|
||||
* OR-composition over multiple case ids,
|
||||
* hits respect the partition filter (``where`` passed by the caller),
|
||||
* empty case-id input short-circuits without a LanceDB call,
|
||||
* case ids containing single quotes round-trip safely via the ``_q``
|
||||
escaper.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.infra.persistence.lancedb import (
|
||||
AgentSkill as LanceAgentSkill,
|
||||
)
|
||||
from everos.infra.persistence.lancedb import (
|
||||
agent_skill_repo,
|
||||
lancedb_manager,
|
||||
)
|
||||
from everos.memory.search.recall.agent_skill import AgentSkillRecaller
|
||||
from everos.memory.search.recall.base import RecallerDeps
|
||||
|
||||
|
||||
class _WhitespaceTokenizer(Tokenizer):
|
||||
"""Bridge reverse-fetch never tokenises; satisfy the deps contract."""
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return text.split()
|
||||
|
||||
|
||||
def _skill_row(
|
||||
*,
|
||||
name: str,
|
||||
owner_id: str,
|
||||
source_case_ids: list[str],
|
||||
) -> LanceAgentSkill:
|
||||
return LanceAgentSkill(
|
||||
id=f"{owner_id}_{name}",
|
||||
owner_id=owner_id,
|
||||
owner_type="agent",
|
||||
name=name,
|
||||
description=f"desc {name}",
|
||||
description_tokens=f"desc {name}",
|
||||
content=f"body of {name}",
|
||||
content_tokens=f"body of {name}",
|
||||
confidence=0.7,
|
||||
maturity_score=0.6,
|
||||
source_case_ids=source_case_ids,
|
||||
cluster_id=None,
|
||||
md_path=f"agents/{owner_id}/skills/{name}/SKILL.md",
|
||||
content_sha256="x" * 64,
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Isolate LanceDB under tmp memory root per test."""
|
||||
monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
|
||||
lancedb_manager._conn = None
|
||||
lancedb_manager._tables.clear()
|
||||
yield
|
||||
await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
def _recaller() -> AgentSkillRecaller:
|
||||
return AgentSkillRecaller(RecallerDeps(tokenizer=_WhitespaceTokenizer()))
|
||||
|
||||
|
||||
_OWNER_WHERE = "owner_id = 'agt' AND owner_type = 'agent'"
|
||||
|
||||
|
||||
async def test_fetch_by_case_ids_matches_any_lineage_case() -> None:
|
||||
"""OR over case ids: a skill surfaces when its ``source_case_ids``
|
||||
contains at least one queried case."""
|
||||
await agent_skill_repo.upsert(
|
||||
[
|
||||
_skill_row(name="s1", owner_id="agt", source_case_ids=["c_a", "c_b"]),
|
||||
_skill_row(name="s2", owner_id="agt", source_case_ids=["c_c"]),
|
||||
_skill_row(name="s3", owner_id="agt", source_case_ids=["c_d"]),
|
||||
]
|
||||
)
|
||||
|
||||
got = await _recaller().fetch_by_case_ids(["c_a", "c_c"], _OWNER_WHERE, limit=10)
|
||||
|
||||
assert sorted(c.id for c in got) == ["agt_s1", "agt_s2"]
|
||||
|
||||
|
||||
async def test_fetch_by_case_ids_respects_owner_partition() -> None:
|
||||
"""The ``where`` clause is AND-composed with ``array_has(...)`` — a
|
||||
skill in a different owner partition must not leak through."""
|
||||
await agent_skill_repo.upsert(
|
||||
[
|
||||
_skill_row(name="s1", owner_id="agt", source_case_ids=["c_a"]),
|
||||
_skill_row(name="s1", owner_id="other", source_case_ids=["c_a"]),
|
||||
]
|
||||
)
|
||||
|
||||
got = await _recaller().fetch_by_case_ids(["c_a"], _OWNER_WHERE, limit=10)
|
||||
|
||||
assert [c.id for c in got] == ["agt_s1"]
|
||||
|
||||
|
||||
async def test_fetch_by_case_ids_returns_empty_for_no_ids() -> None:
|
||||
"""Empty input short-circuits — no LanceDB query is issued."""
|
||||
got = await _recaller().fetch_by_case_ids([], _OWNER_WHERE, limit=10)
|
||||
assert got == []
|
||||
|
||||
|
||||
async def test_fetch_by_case_ids_escapes_single_quotes() -> None:
|
||||
"""A case id with a single quote must not break the SQL literal.
|
||||
|
||||
The ``_q`` escaper turns ``'`` into ``''`` (SQL standard); without it
|
||||
the where-clause would close the string literal prematurely.
|
||||
"""
|
||||
quoted_id = "ac_o'brien_0001"
|
||||
await agent_skill_repo.upsert(
|
||||
[_skill_row(name="s1", owner_id="agt", source_case_ids=[quoted_id])]
|
||||
)
|
||||
|
||||
got = await _recaller().fetch_by_case_ids([quoted_id], _OWNER_WHERE, limit=10)
|
||||
|
||||
assert [c.id for c in got] == ["agt_s1"]
|
||||
|
||||
|
||||
async def test_fetch_by_case_ids_carries_source_case_ids_in_metadata() -> None:
|
||||
"""The full ``source_case_ids`` list must ride back in metadata so the
|
||||
manager's max-pool can score against the caller's case_score map."""
|
||||
await agent_skill_repo.upsert(
|
||||
[_skill_row(name="s1", owner_id="agt", source_case_ids=["c_a", "c_b", "c_c"])]
|
||||
)
|
||||
|
||||
got = await _recaller().fetch_by_case_ids(["c_a"], _OWNER_WHERE, limit=10)
|
||||
|
||||
assert len(got) == 1
|
||||
assert sorted(got[0].metadata["source_case_ids"]) == ["c_a", "c_b", "c_c"]
|
||||
264
tests/unit/test_memory/test_search/test_recall_atomic_fact.py
Normal file
264
tests/unit/test_memory/test_search/test_recall_atomic_fact.py
Normal file
@ -0,0 +1,264 @@
|
||||
"""Real-LanceDB tests for ``AtomicFactRecaller.facts_for_episodes``.
|
||||
|
||||
The MRAG bridge is the only path that links facts back to episodes, and
|
||||
the previous ``parent_type='episode' AND parent_id IN (episode_ids)``
|
||||
query never matched: cascade writes facts with
|
||||
``parent_type='memcell'``, ``parent_id=memcell_id``. The fixed version
|
||||
takes an ``episode → memcell`` map from the caller, queries by the
|
||||
deduped memcell set, and re-buckets results under every episode that
|
||||
shares each memcell.
|
||||
|
||||
These tests exercise the real LanceDB query path (no recaller stubs):
|
||||
- shared memcell → fact appears under both episodes,
|
||||
- distinct memcells → facts bucket exclusively to their owning episode,
|
||||
- empty / unknown memcells → empty result, no LanceDB call surprise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.infra.persistence.lancedb import (
|
||||
AtomicFact,
|
||||
ParentType,
|
||||
atomic_fact_repo,
|
||||
lancedb_manager,
|
||||
)
|
||||
from everos.memory.search.recall.atomic_fact import AtomicFactRecaller
|
||||
from everos.memory.search.recall.base import RecallerDeps
|
||||
|
||||
|
||||
class _WhitespaceTokenizer(Tokenizer):
|
||||
"""Trivial tokenizer — the bridge doesn't touch text tokenisation."""
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return text.split()
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _fact_row(
|
||||
*,
|
||||
fid: str,
|
||||
memcell_id: str,
|
||||
fact: str,
|
||||
owner_id: str = "alice",
|
||||
) -> AtomicFact:
|
||||
return AtomicFact(
|
||||
id=fid,
|
||||
entry_id=fid.split("_", 1)[1] if "_" in fid else fid,
|
||||
owner_id=owner_id,
|
||||
owner_type="user",
|
||||
session_id="sess_1",
|
||||
timestamp=_ts(),
|
||||
parent_type=ParentType.MEMCELL.value,
|
||||
parent_id=memcell_id,
|
||||
sender_ids=[owner_id],
|
||||
fact=fact,
|
||||
fact_tokens=fact,
|
||||
md_path=f"users/{owner_id}/.atomic_facts/atomic_fact-2026-01-01.md",
|
||||
content_sha256="x" * 64,
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Isolate LanceDB to a tmp memory root per test."""
|
||||
monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
|
||||
lancedb_manager._conn = None
|
||||
lancedb_manager._tables.clear()
|
||||
yield
|
||||
await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
def _recaller() -> AtomicFactRecaller:
|
||||
return AtomicFactRecaller(RecallerDeps(tokenizer=_WhitespaceTokenizer()))
|
||||
|
||||
|
||||
async def test_facts_for_episodes_buckets_by_shared_memcell() -> None:
|
||||
"""Two episodes sharing one memcell both see the same fact pool.
|
||||
|
||||
Episode-level fan-out (Episode pipeline runs once per cell but emits
|
||||
one Episode per user sender) gives multiple LanceDB episode rows
|
||||
pointing at the same memcell. The bridge must surface every fact
|
||||
that hangs off that memcell under both episode ids.
|
||||
"""
|
||||
await atomic_fact_repo.upsert(
|
||||
[
|
||||
_fact_row(fid="alice_af_1", memcell_id="mc_shared", fact="likes hiking"),
|
||||
_fact_row(fid="alice_af_2", memcell_id="mc_shared", fact="lives in tokyo"),
|
||||
_fact_row(fid="alice_af_3", memcell_id="mc_other", fact="prefers oat milk"),
|
||||
]
|
||||
)
|
||||
|
||||
ep_to_memcell = {
|
||||
"alice_ep_a": "mc_shared",
|
||||
"alice_ep_b": "mc_shared",
|
||||
"alice_ep_c": "mc_other",
|
||||
}
|
||||
where = "owner_id = 'alice' AND owner_type = 'user'"
|
||||
out = await _recaller().facts_for_episodes(ep_to_memcell, where, per_episode=10)
|
||||
|
||||
assert sorted(out.keys()) == ["alice_ep_a", "alice_ep_b", "alice_ep_c"]
|
||||
assert sorted(f.id for f in out["alice_ep_a"]) == ["alice_af_1", "alice_af_2"]
|
||||
assert sorted(f.id for f in out["alice_ep_b"]) == ["alice_af_1", "alice_af_2"]
|
||||
assert [f.id for f in out["alice_ep_c"]] == ["alice_af_3"]
|
||||
# parent_episode_id is the *bucket* episode, not the underlying memcell:
|
||||
# the same fact_1 surfaces twice with different parent_episode_id values.
|
||||
fact1_in_a = next(f for f in out["alice_ep_a"] if f.id == "alice_af_1")
|
||||
fact1_in_b = next(f for f in out["alice_ep_b"] if f.id == "alice_af_1")
|
||||
assert fact1_in_a.parent_episode_id == "alice_ep_a"
|
||||
assert fact1_in_b.parent_episode_id == "alice_ep_b"
|
||||
|
||||
|
||||
async def test_facts_for_episodes_returns_empty_for_no_episodes() -> None:
|
||||
out = await _recaller().facts_for_episodes({}, "owner_id = 'alice'", per_episode=10)
|
||||
assert out == {}
|
||||
|
||||
|
||||
async def test_facts_for_episodes_skips_unknown_memcells() -> None:
|
||||
"""Episodes whose memcell has no facts simply don't appear in the result."""
|
||||
await atomic_fact_repo.upsert(
|
||||
[_fact_row(fid="alice_af_1", memcell_id="mc_a", fact="hello")]
|
||||
)
|
||||
|
||||
out = await _recaller().facts_for_episodes(
|
||||
{"alice_ep_a": "mc_a", "alice_ep_b": "mc_missing"},
|
||||
"owner_id = 'alice' AND owner_type = 'user'",
|
||||
per_episode=10,
|
||||
)
|
||||
assert "alice_ep_a" in out
|
||||
assert "alice_ep_b" not in out
|
||||
assert [f.id for f in out["alice_ep_a"]] == ["alice_af_1"]
|
||||
|
||||
|
||||
async def test_facts_for_episodes_filters_by_where_clause() -> None:
|
||||
"""The caller's where clause is preserved (e.g. owner pinning)."""
|
||||
await atomic_fact_repo.upsert(
|
||||
[
|
||||
_fact_row(
|
||||
fid="alice_af_1",
|
||||
memcell_id="mc_a",
|
||||
fact="alice fact",
|
||||
owner_id="alice",
|
||||
),
|
||||
_fact_row(
|
||||
fid="bob_af_1",
|
||||
memcell_id="mc_a",
|
||||
fact="bob fact",
|
||||
owner_id="bob",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
out = await _recaller().facts_for_episodes(
|
||||
{"alice_ep_a": "mc_a"},
|
||||
"owner_id = 'alice' AND owner_type = 'user'",
|
||||
per_episode=10,
|
||||
)
|
||||
assert [f.id for f in out["alice_ep_a"]] == ["alice_af_1"]
|
||||
|
||||
|
||||
async def test_facts_for_episodes_drops_empty_memcell_ids() -> None:
|
||||
"""Episodes whose parent_id is missing (empty string) are dropped silently.
|
||||
|
||||
Real-world cause: a candidate row that lost its ``parent_id`` (data
|
||||
corruption, manual edit). The bridge must not crash and must not
|
||||
emit ``parent_id IN ('')`` — which would match every empty-string
|
||||
row in the table.
|
||||
"""
|
||||
await atomic_fact_repo.upsert(
|
||||
[_fact_row(fid="alice_af_1", memcell_id="", fact="orphan fact")]
|
||||
)
|
||||
|
||||
out = await _recaller().facts_for_episodes(
|
||||
{"alice_ep_a": ""},
|
||||
"owner_id = 'alice' AND owner_type = 'user'",
|
||||
per_episode=10,
|
||||
)
|
||||
assert out == {}
|
||||
|
||||
|
||||
# ── MRAG fact-level scoring (regression for query_vector handling) ─────
|
||||
|
||||
|
||||
def _unit_vector(direction: int, dim: int = 1024) -> list[float]:
|
||||
"""Return a unit vector with 1.0 at ``direction`` axis, 0 elsewhere.
|
||||
|
||||
Used to build deterministic cosine relationships in the tests below:
|
||||
same direction → distance 0 (score 1.0); orthogonal → distance 1
|
||||
(score 0.0). The ``vector`` field on AtomicFact requires 1024-dim,
|
||||
so any test that goes through ``.nearest_to`` needs full-width.
|
||||
"""
|
||||
out = [0.0] * dim
|
||||
out[direction] = 1.0
|
||||
return out
|
||||
|
||||
|
||||
async def test_facts_for_episodes_assigns_real_cosine_score_with_query_vector() -> None:
|
||||
"""Regression: ``query_vector`` triggers cosine ANN, not flat scan.
|
||||
|
||||
Pre-fix, ``facts_for_episodes`` only ran ``where parent_id IN (...)``
|
||||
and emitted every fact with ``score=0.0`` — the MRAG fact-level
|
||||
ranking collapsed to insertion order. Post-fix, ``query_vector``
|
||||
flows into ``.nearest_to(...).distance_type('cosine')`` and each
|
||||
fact lands with its real query↔fact relevance score.
|
||||
|
||||
Setup:
|
||||
- fact A's vector = unit on axis 0 (same direction as the query) →
|
||||
cosine distance 0 → score ≈ 1.0.
|
||||
- fact B's vector = unit on axis 1 (orthogonal to the query) →
|
||||
cosine distance 1 → score ≈ 0.0.
|
||||
|
||||
Assertion: A ranks first AND its score > B's score AND both are
|
||||
non-zero-distinguishable (catches the old hardcoded ``0.0`` bug).
|
||||
"""
|
||||
row_a = _fact_row(fid="alice_af_1", memcell_id="mc_shared", fact="close fact")
|
||||
row_a.vector = _unit_vector(0)
|
||||
row_b = _fact_row(fid="alice_af_2", memcell_id="mc_shared", fact="far fact")
|
||||
row_b.vector = _unit_vector(1)
|
||||
await atomic_fact_repo.upsert([row_a, row_b])
|
||||
|
||||
out = await _recaller().facts_for_episodes(
|
||||
{"alice_ep_a": "mc_shared"},
|
||||
"owner_id = 'alice' AND owner_type = 'user'",
|
||||
per_episode=10,
|
||||
query_vector=_unit_vector(0),
|
||||
)
|
||||
|
||||
facts = out["alice_ep_a"]
|
||||
assert [f.id for f in facts] == ["alice_af_1", "alice_af_2"], (
|
||||
"facts must be ordered by cosine distance ascending (closest first)"
|
||||
)
|
||||
assert facts[0].score > facts[1].score, "real cosine scoring must differentiate"
|
||||
assert facts[0].score > 0.5, "near-identical vectors should score close to 1"
|
||||
assert facts[1].score < 0.5, "orthogonal vectors should score close to 0"
|
||||
|
||||
|
||||
async def test_facts_for_episodes_score_zero_without_query_vector() -> None:
|
||||
"""Backward-compat: omitting ``query_vector`` falls back to flat scan.
|
||||
|
||||
Callers that don't need fact-level relevance (e.g. KV-style fetch
|
||||
where the parent ranking already encodes the signal) keep the old
|
||||
``score=0.0`` semantics. Documents the explicit contract so the
|
||||
fallback path is intentional, not an oversight.
|
||||
"""
|
||||
row = _fact_row(fid="alice_af_1", memcell_id="mc_a", fact="anything")
|
||||
row.vector = _unit_vector(0)
|
||||
await atomic_fact_repo.upsert([row])
|
||||
|
||||
out = await _recaller().facts_for_episodes(
|
||||
{"alice_ep_a": "mc_a"},
|
||||
"owner_id = 'alice' AND owner_type = 'user'",
|
||||
per_episode=10,
|
||||
# no query_vector
|
||||
)
|
||||
|
||||
assert out["alice_ep_a"][0].score == 0.0
|
||||
108
tests/unit/test_memory/test_search/test_recall_episode.py
Normal file
108
tests/unit/test_memory/test_search/test_recall_episode.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""Unit tests for ``EpisodeRecaller.fetch_all_for_owner``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.memory.search.recall.base import RecallerDeps
|
||||
from everos.memory.search.recall.episode import EpisodeRecaller
|
||||
|
||||
|
||||
def _make_row(ep_id: str, mc_id: str) -> dict[str, Any]:
|
||||
"""Build a minimal episode LanceDB row dict for test fixtures."""
|
||||
return {
|
||||
"id": ep_id,
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess_1",
|
||||
"timestamp": 1000000,
|
||||
"sender_ids": ["alice"],
|
||||
"subject": f"subj {ep_id}",
|
||||
"summary": f"summary {ep_id}",
|
||||
"episode": f"body {ep_id}",
|
||||
"parent_id": mc_id,
|
||||
}
|
||||
|
||||
|
||||
def _mock_table(rows: list[dict[str, Any]]) -> MagicMock:
|
||||
tbl = MagicMock()
|
||||
tbl.query.return_value.where.return_value.to_list = AsyncMock(return_value=rows)
|
||||
return tbl
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def recaller() -> EpisodeRecaller:
|
||||
tok = MagicMock(spec=Tokenizer)
|
||||
tok.tokenize.return_value = ["hi"]
|
||||
return EpisodeRecaller(RecallerDeps(tokenizer=tok))
|
||||
|
||||
|
||||
async def test_fetch_all_for_owner_returns_memcell_keyed_candidates(
|
||||
recaller: EpisodeRecaller,
|
||||
) -> None:
|
||||
"""id must equal parent_id (memcell_id) so acluster_retrieve membership works."""
|
||||
rows = [
|
||||
_make_row("ep_1", "mc_1"),
|
||||
_make_row("ep_2", "mc_2"),
|
||||
]
|
||||
with patch(
|
||||
"everos.memory.search.recall.episode.get_table",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_table(rows),
|
||||
):
|
||||
result = await recaller.fetch_all_for_owner("owner_id = 'alice'")
|
||||
|
||||
assert len(result) == 2
|
||||
ids = {c.id for c in result}
|
||||
assert ids == {"mc_1", "mc_2"}, "id must be memcell_id, not episode_id"
|
||||
|
||||
|
||||
async def test_fetch_all_for_owner_stores_episode_id_in_metadata(
|
||||
recaller: EpisodeRecaller,
|
||||
) -> None:
|
||||
"""metadata['episode_id'] carries the real LanceDB episode id for final shaping."""
|
||||
rows = [_make_row("ep_abc", "mc_xyz")]
|
||||
with patch(
|
||||
"everos.memory.search.recall.episode.get_table",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_table(rows),
|
||||
):
|
||||
result = await recaller.fetch_all_for_owner("owner_id = 'alice'")
|
||||
|
||||
assert result[0].metadata["episode_id"] == "ep_abc"
|
||||
assert result[0].metadata["parent_id"] == "mc_xyz"
|
||||
|
||||
|
||||
async def test_fetch_all_for_owner_skips_rows_without_parent_id(
|
||||
recaller: EpisodeRecaller,
|
||||
) -> None:
|
||||
"""Rows without parent_id are silently skipped.
|
||||
|
||||
They are incomplete episode records.
|
||||
"""
|
||||
rows = [
|
||||
{
|
||||
"id": "ep_bad",
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "s",
|
||||
"timestamp": 1,
|
||||
"sender_ids": [],
|
||||
"subject": "",
|
||||
"summary": "",
|
||||
"episode": "",
|
||||
# no parent_id key
|
||||
},
|
||||
]
|
||||
with patch(
|
||||
"everos.memory.search.recall.episode.get_table",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_table(rows),
|
||||
):
|
||||
result = await recaller.fetch_all_for_owner("owner_id = 'alice'")
|
||||
|
||||
assert result == []
|
||||
189
tests/unit/test_memory/test_search/test_recall_or_semantics.py
Normal file
189
tests/unit/test_memory/test_search/test_recall_or_semantics.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Real-LanceDB regression: OR-mode BooleanQuery sparse recall.
|
||||
|
||||
Locks the fix for the tantivy implicit-AND poison: when a query
|
||||
contains an IDF≈0 token (typically the partition owner's own name on
|
||||
an owner-scoped corpus), the entire query used to return 0 hits. The
|
||||
fixed path wraps each token in a ``BooleanQuery`` with ``SHOULD``
|
||||
clauses (mirrors enterprise ES ``bool.should + minimum_should_match=1``)
|
||||
so other tokens can carry the query.
|
||||
|
||||
These tests build a tiny in-memory corpus where one term is 100% DF
|
||||
(the "poison" term) and verify that mixing it with informative
|
||||
content tokens still surfaces results.
|
||||
|
||||
White-box surfaces:
|
||||
- LanceDB ``episode`` table (real, per-test tmp root)
|
||||
- ``EpisodeRecaller.sparse_recall``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.infra.persistence.lancedb import (
|
||||
Episode,
|
||||
ParentType,
|
||||
episode_repo,
|
||||
lancedb_manager,
|
||||
)
|
||||
from everos.memory.search.recall.base import RecallerDeps, build_or_query
|
||||
from everos.memory.search.recall.episode import EpisodeRecaller
|
||||
|
||||
|
||||
class _WhitespaceTokenizer(Tokenizer):
|
||||
"""Split-on-whitespace tokenizer, lowercased.
|
||||
|
||||
The OR-semantics fix is independent of jieba's behaviour, so a
|
||||
trivial tokenizer keeps the test focused.
|
||||
"""
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return [tok for tok in text.lower().split() if tok]
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _episode_row(
|
||||
*,
|
||||
eid: str,
|
||||
owner_id: str,
|
||||
body_tokens: str,
|
||||
) -> Episode:
|
||||
"""Build an Episode row with ``body_tokens`` indexed as ``episode_tokens``."""
|
||||
return Episode(
|
||||
id=f"{owner_id}_{eid}",
|
||||
entry_id=eid,
|
||||
owner_id=owner_id,
|
||||
owner_type="user",
|
||||
session_id="sess_1",
|
||||
timestamp=_ts(),
|
||||
parent_type=ParentType.MEMCELL.value,
|
||||
parent_id="mc_test",
|
||||
sender_ids=[owner_id],
|
||||
episode=body_tokens,
|
||||
episode_tokens=body_tokens,
|
||||
md_path=f"users/{owner_id}/episodes/episode-2026-01-01.md",
|
||||
content_sha256="x" * 64,
|
||||
vector=[0.0] * 1024,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
|
||||
lancedb_manager._conn = None
|
||||
lancedb_manager._tables.clear()
|
||||
yield
|
||||
await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
def _recaller() -> EpisodeRecaller:
|
||||
return EpisodeRecaller(RecallerDeps(tokenizer=_WhitespaceTokenizer()))
|
||||
|
||||
|
||||
# ── build_or_query helper unit-level checks ────────────────────────────
|
||||
|
||||
|
||||
def test_build_or_query_empty_returns_none() -> None:
|
||||
"""Empty / whitespace-only query → ``None`` (caller must short-circuit)."""
|
||||
tk = _WhitespaceTokenizer()
|
||||
assert build_or_query(tk, "", column="episode_tokens") is None
|
||||
assert build_or_query(tk, " ", column="episode_tokens") is None
|
||||
|
||||
|
||||
def test_build_or_query_single_token_returns_match_query() -> None:
|
||||
"""One token → bare MatchQuery (no boolean-wrapper overhead)."""
|
||||
from lancedb.query import MatchQuery
|
||||
|
||||
q = build_or_query(_WhitespaceTokenizer(), "hello", column="episode_tokens")
|
||||
assert isinstance(q, MatchQuery)
|
||||
|
||||
|
||||
def test_build_or_query_multi_token_returns_boolean_query() -> None:
|
||||
"""≥2 tokens → BooleanQuery with one SHOULD clause per token."""
|
||||
from lancedb.query import BooleanQuery
|
||||
|
||||
q = build_or_query(
|
||||
_WhitespaceTokenizer(), "alice support group", column="episode_tokens"
|
||||
)
|
||||
assert isinstance(q, BooleanQuery)
|
||||
|
||||
|
||||
# ── Live recall: poison token + informative token must surface results ──
|
||||
|
||||
|
||||
async def test_or_semantics_poison_token_does_not_kill_query() -> None:
|
||||
"""Two episodes, owner name in every doc (DF=100%), plus distinct content.
|
||||
|
||||
Pre-fix, querying ``"alice support group"`` against owner=alice would
|
||||
return 0 hits — the ``alice`` token (DF=100% → IDF≈0) poisoned the
|
||||
implicit-AND query parser and dragged the score-conjunction to zero.
|
||||
Post-fix, ``BooleanQuery + SHOULD`` lets ``support`` / ``group`` carry
|
||||
the query on their own.
|
||||
"""
|
||||
await episode_repo.upsert(
|
||||
[
|
||||
_episode_row(
|
||||
eid="ep_1",
|
||||
owner_id="alice",
|
||||
body_tokens="alice attended lgbtq support group last tuesday",
|
||||
),
|
||||
_episode_row(
|
||||
eid="ep_2",
|
||||
owner_id="alice",
|
||||
body_tokens="alice tried watercolor painting on saturday morning",
|
||||
),
|
||||
]
|
||||
)
|
||||
# LanceDB FTS only sees data merged into the index after optimize().
|
||||
# Tests treat that as part of "the corpus is ready to query".
|
||||
from everos.infra.persistence.lancedb import get_table
|
||||
|
||||
tbl = await get_table(Episode.TABLE_NAME, Episode)
|
||||
await tbl.optimize()
|
||||
|
||||
where = "owner_id = 'alice' AND owner_type = 'user'"
|
||||
cands = await _recaller().sparse_recall("alice support group", where, limit=10)
|
||||
assert cands, "alice + support + group should recall ep_1 via SHOULD"
|
||||
# ep_1 is the support-group episode; should rank above ep_2 (no support).
|
||||
assert cands[0].id == "alice_ep_1"
|
||||
assert cands[0].score > 0.0
|
||||
|
||||
|
||||
async def test_or_semantics_single_informative_token() -> None:
|
||||
"""Single non-poison token still recalls (regression for ``painting``)."""
|
||||
await episode_repo.upsert(
|
||||
[
|
||||
_episode_row(
|
||||
eid="ep_1",
|
||||
owner_id="alice",
|
||||
body_tokens="alice attended lgbtq support group",
|
||||
),
|
||||
_episode_row(
|
||||
eid="ep_2",
|
||||
owner_id="alice",
|
||||
body_tokens="alice tried watercolor painting on saturday",
|
||||
),
|
||||
]
|
||||
)
|
||||
from everos.infra.persistence.lancedb import get_table
|
||||
|
||||
tbl = await get_table(Episode.TABLE_NAME, Episode)
|
||||
await tbl.optimize()
|
||||
|
||||
where = "owner_id = 'alice' AND owner_type = 'user'"
|
||||
cands = await _recaller().sparse_recall("painting", where, limit=10)
|
||||
assert cands, "single informative token must recall the matching episode"
|
||||
assert cands[0].id == "alice_ep_2"
|
||||
|
||||
|
||||
async def test_or_semantics_empty_query_returns_empty() -> None:
|
||||
"""Tokenisation yields nothing → recall returns ``[]`` without hitting LanceDB."""
|
||||
cands = await _recaller().sparse_recall(" ", "owner_id = 'alice'", limit=10)
|
||||
assert cands == []
|
||||
128
tests/unit/test_memory/test_search/test_recall_profile.py
Normal file
128
tests/unit/test_memory/test_search/test_recall_profile.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Real-LanceDB tests for ``ProfileRecaller`` — KV-by-owner fetch.
|
||||
|
||||
Profile recall has no query / no ranking: ``fetch(owner_id)`` returns
|
||||
the at-most-one row keyed by ``id = owner_id``. These tests exercise
|
||||
the LanceDB path (no stubs) and the JSON unpacking that turns the
|
||||
``*_json`` columns back into the DTO's ``profile_data`` mapping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.infra.persistence.lancedb import (
|
||||
UserProfile,
|
||||
lancedb_manager,
|
||||
user_profile_repo,
|
||||
)
|
||||
from everos.memory.search.recall.profile import ProfileRecaller
|
||||
|
||||
|
||||
def _profile_row(
|
||||
*,
|
||||
owner_id: str,
|
||||
summary: str = "summary text",
|
||||
explicit_info: list | None = None,
|
||||
implicit_traits: list | None = None,
|
||||
profile_timestamp_ms: int = 1_700_000_000_000,
|
||||
) -> UserProfile:
|
||||
return UserProfile(
|
||||
id=owner_id,
|
||||
owner_id=owner_id,
|
||||
owner_type="user",
|
||||
summary=summary,
|
||||
explicit_info_json=json.dumps(explicit_info or [], ensure_ascii=False),
|
||||
implicit_traits_json=json.dumps(implicit_traits or [], ensure_ascii=False),
|
||||
profile_timestamp_ms=profile_timestamp_ms,
|
||||
md_path=f"users/{owner_id}/user.md",
|
||||
content_sha256="x" * 64,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _reset(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("EVEROS_MEMORY__ROOT", str(tmp_path))
|
||||
lancedb_manager._conn = None
|
||||
lancedb_manager._tables.clear()
|
||||
yield
|
||||
await lancedb_manager.dispose_connection()
|
||||
|
||||
|
||||
async def test_fetch_returns_dto_when_row_exists() -> None:
|
||||
await user_profile_repo.upsert(
|
||||
[
|
||||
_profile_row(
|
||||
owner_id="u_alice",
|
||||
summary="Alice likes long hikes.",
|
||||
explicit_info=[{"fact": "lives in tokyo"}],
|
||||
implicit_traits=[{"trait": "introverted"}],
|
||||
profile_timestamp_ms=1_700_000_001_000,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
items = await ProfileRecaller().fetch("u_alice")
|
||||
assert len(items) == 1
|
||||
item = items[0]
|
||||
assert item.id == "u_alice"
|
||||
assert item.user_id == "u_alice"
|
||||
assert item.score is None
|
||||
# JSON columns are decoded back to live Python on the way out.
|
||||
assert item.profile_data["summary"] == "Alice likes long hikes."
|
||||
assert item.profile_data["explicit_info"] == [{"fact": "lives in tokyo"}]
|
||||
assert item.profile_data["implicit_traits"] == [{"trait": "introverted"}]
|
||||
assert item.profile_data["profile_timestamp_ms"] == 1_700_000_001_000
|
||||
|
||||
|
||||
async def test_fetch_returns_empty_when_row_missing() -> None:
|
||||
items = await ProfileRecaller().fetch("u_cold_start")
|
||||
assert items == []
|
||||
|
||||
|
||||
async def test_fetch_returns_empty_for_blank_owner() -> None:
|
||||
"""Blank ``owner_id`` short-circuits — never hit LanceDB with an
|
||||
empty-string PK (which would otherwise return any row whose id was
|
||||
persisted as the empty string)."""
|
||||
items = await ProfileRecaller().fetch("")
|
||||
assert items == []
|
||||
|
||||
|
||||
async def test_fetch_isolates_by_owner() -> None:
|
||||
await user_profile_repo.upsert(
|
||||
[
|
||||
_profile_row(owner_id="u_alice", summary="Alice"),
|
||||
_profile_row(owner_id="u_bob", summary="Bob"),
|
||||
]
|
||||
)
|
||||
bob_items = await ProfileRecaller().fetch("u_bob")
|
||||
assert len(bob_items) == 1
|
||||
assert bob_items[0].profile_data["summary"] == "Bob"
|
||||
|
||||
|
||||
async def test_fetch_tolerates_malformed_json_columns() -> None:
|
||||
"""A column with corrupted JSON should not blow up the recall path —
|
||||
the bucket falls back to ``[]`` and the rest of the DTO survives."""
|
||||
await user_profile_repo.upsert(
|
||||
[
|
||||
UserProfile(
|
||||
id="u_broken",
|
||||
owner_id="u_broken",
|
||||
owner_type="user",
|
||||
summary="ok",
|
||||
explicit_info_json="{not valid json",
|
||||
implicit_traits_json="[]",
|
||||
profile_timestamp_ms=0,
|
||||
md_path="users/u_broken/user.md",
|
||||
content_sha256="y" * 64,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
items = await ProfileRecaller().fetch("u_broken")
|
||||
assert len(items) == 1
|
||||
assert items[0].profile_data["explicit_info"] == []
|
||||
assert items[0].profile_data["implicit_traits"] == []
|
||||
assert items[0].profile_data["summary"] == "ok"
|
||||
214
tests/unit/test_memory/test_search/test_shaper.py
Normal file
214
tests/unit/test_memory/test_search/test_shaper.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Unit tests for ``memory.search.shaper``.
|
||||
|
||||
Tests are pure: no LanceDB, no everalgo, just dataclass-in / DTO-out.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
|
||||
from everalgo.types import Candidate, ScoredItem
|
||||
|
||||
from everos.memory.search.shaper import (
|
||||
reshape_hybrid_output,
|
||||
shape_agent_case_from_candidate,
|
||||
shape_agent_skill_from_candidate,
|
||||
shape_atomic_fact_from_candidate,
|
||||
shape_episode_from_candidate,
|
||||
)
|
||||
|
||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts(year: int = 2026) -> _dt.datetime:
|
||||
return _dt.datetime(year, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _episode_candidate(*, id: str = "alice_ep_1", score: float = 0.9) -> Candidate:
|
||||
return Candidate(
|
||||
id=id,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "sess_a",
|
||||
"timestamp": _ts(),
|
||||
"sender_ids": ["alice", "assistant_1"],
|
||||
"subject": "Coffee chat",
|
||||
"summary": "Discussed coffee preferences.",
|
||||
"episode": "Alice said she prefers oat milk.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _agent_case_candidate() -> Candidate:
|
||||
return Candidate(
|
||||
id="agent_a_case_1",
|
||||
score=0.8,
|
||||
source="keyword",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"session_id": "sess_a",
|
||||
"timestamp": _ts(),
|
||||
"task_intent": "Draft a follow-up email",
|
||||
"approach": "1. summarise...",
|
||||
"quality_score": 0.92,
|
||||
"key_insight": "User prefers brief tone",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _agent_skill_candidate() -> Candidate:
|
||||
return Candidate(
|
||||
id="agent_a_skill_1",
|
||||
score=0.7,
|
||||
source="keyword",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"name": "contract_redline",
|
||||
"description": "Spot risky clauses",
|
||||
"content": "Step 1: ...",
|
||||
"confidence": 0.9,
|
||||
"maturity_score": 0.5,
|
||||
"source_case_ids": ["agent_a_case_1"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── Episode shaping ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_shape_episode_basic() -> None:
|
||||
item = shape_episode_from_candidate(_episode_candidate())
|
||||
assert item is not None
|
||||
assert item.id == "alice_ep_1"
|
||||
assert item.user_id == "alice"
|
||||
assert item.type == "Conversation"
|
||||
assert item.score == 0.9
|
||||
assert item.atomic_facts == []
|
||||
assert item.sender_ids == ["alice", "assistant_1"]
|
||||
|
||||
|
||||
def test_shape_episode_drops_when_owner_type_wrong() -> None:
|
||||
cand = _episode_candidate()
|
||||
cand.metadata["owner_type"] = "agent"
|
||||
assert shape_episode_from_candidate(cand) is None
|
||||
|
||||
|
||||
def test_shape_episode_drops_when_timestamp_missing() -> None:
|
||||
cand = _episode_candidate()
|
||||
del cand.metadata["timestamp"]
|
||||
assert shape_episode_from_candidate(cand) is None
|
||||
|
||||
|
||||
def test_shape_episode_attaches_facts() -> None:
|
||||
facts = [
|
||||
shape_atomic_fact_from_candidate(
|
||||
Candidate(
|
||||
id="f1",
|
||||
score=0.5,
|
||||
source="other",
|
||||
metadata={"fact": "Alice prefers oat milk"},
|
||||
)
|
||||
)
|
||||
]
|
||||
item = shape_episode_from_candidate(_episode_candidate(), atomic_facts=facts)
|
||||
assert item is not None
|
||||
assert len(item.atomic_facts) == 1
|
||||
assert item.atomic_facts[0].content == "Alice prefers oat milk"
|
||||
|
||||
|
||||
# ── Agent case / skill shaping ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_shape_agent_case_basic() -> None:
|
||||
item = shape_agent_case_from_candidate(_agent_case_candidate())
|
||||
assert item is not None
|
||||
assert item.agent_id == "agent_a"
|
||||
assert item.task_intent == "Draft a follow-up email"
|
||||
assert item.quality_score == 0.92
|
||||
assert item.key_insight == "User prefers brief tone"
|
||||
|
||||
|
||||
def test_shape_agent_case_drops_when_owner_type_wrong() -> None:
|
||||
cand = _agent_case_candidate()
|
||||
cand.metadata["owner_type"] = "user"
|
||||
assert shape_agent_case_from_candidate(cand) is None
|
||||
|
||||
|
||||
def test_shape_agent_skill_basic() -> None:
|
||||
item = shape_agent_skill_from_candidate(_agent_skill_candidate())
|
||||
assert item is not None
|
||||
assert item.name == "contract_redline"
|
||||
assert item.maturity_score == 0.5
|
||||
assert item.source_case_ids == ["agent_a_case_1"]
|
||||
|
||||
|
||||
# ── Hybrid reshape ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _scored_episode(eid: str, score: float) -> ScoredItem:
|
||||
return ScoredItem(
|
||||
id=eid,
|
||||
score=score,
|
||||
item_type="episode",
|
||||
metadata={
|
||||
"owner_id": "alice",
|
||||
"owner_type": "user",
|
||||
"session_id": "s1",
|
||||
"timestamp": _ts(),
|
||||
"sender_ids": ["alice"],
|
||||
"subject": "subj",
|
||||
"summary": "summ",
|
||||
"episode": "body",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _scored_fact(fid: str, parent: str, score: float) -> ScoredItem:
|
||||
return ScoredItem(
|
||||
id=fid,
|
||||
score=score,
|
||||
item_type="atomic_fact",
|
||||
parent_episode_id=parent,
|
||||
metadata={"fact": f"fact text {fid}"},
|
||||
)
|
||||
|
||||
|
||||
def test_reshape_hybrid_nests_facts_under_kept_episode() -> None:
|
||||
scored = [
|
||||
_scored_episode("ep_1", 0.9),
|
||||
_scored_fact("f_1", "ep_1", 0.95),
|
||||
_scored_fact("f_2", "ep_1", 0.85),
|
||||
]
|
||||
out = reshape_hybrid_output(scored, episode_pool={})
|
||||
assert len(out) == 1
|
||||
assert out[0].id == "ep_1"
|
||||
# Facts sorted descending by score.
|
||||
assert [f.id for f in out[0].atomic_facts] == ["f_1", "f_2"]
|
||||
|
||||
|
||||
def test_reshape_hybrid_backfills_evicted_episode_from_pool() -> None:
|
||||
# Episode ep_2 was evicted (only facts present),
|
||||
# but it is in episode_pool — should be restored as a result.
|
||||
scored = [
|
||||
_scored_episode("ep_1", 0.7),
|
||||
_scored_fact("f_a", "ep_2", 0.95),
|
||||
]
|
||||
pool_episode = _episode_candidate(id="ep_2", score=0.0)
|
||||
out = reshape_hybrid_output(scored, episode_pool={"ep_2": pool_episode})
|
||||
assert len(out) == 2
|
||||
# Output sorted by score descending — ep_2 takes fact's max score (0.95).
|
||||
assert out[0].id == "ep_2"
|
||||
assert out[0].score == 0.95
|
||||
assert len(out[0].atomic_facts) == 1
|
||||
assert out[1].id == "ep_1"
|
||||
|
||||
|
||||
def test_reshape_hybrid_drops_orphan_facts_with_no_pool_parent() -> None:
|
||||
scored = [_scored_fact("f_x", "ep_missing", 0.5)]
|
||||
out = reshape_hybrid_output(scored, episode_pool={})
|
||||
assert out == []
|
||||
154
tests/unit/test_memory/test_search/test_skill_hybrid.py
Normal file
154
tests/unit/test_memory/test_search/test_skill_hybrid.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""Unit tests for ``memory.search.skill_hybrid``.
|
||||
|
||||
skill_hybrid is the **cross-encoder lane** for skill HYBRID retrieval.
|
||||
The LLM-rerank lane lives in ``SearchManager._search_agent_skills`` and
|
||||
goes through ``everalgo.rank.skill.arank`` directly — covered by
|
||||
``test_manager`` tests instead.
|
||||
|
||||
Covered surfaces:
|
||||
- ``search_agent_skills_hybrid`` (public function, MagicMock stubs)
|
||||
- ``_fuse``, ``_cross_encoder_rerank``, ``_shape_results``
|
||||
(via integration through the public function)
|
||||
|
||||
All I/O (reranker) is injected via MagicMock / stub objects. No LanceDB
|
||||
or network calls are made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from everalgo.types import Candidate
|
||||
|
||||
from everos.memory.search.callbacks import _SKILL_RERANK_INSTRUCTION
|
||||
from everos.memory.search.dto import SearchAgentSkillItem
|
||||
from everos.memory.search.skill_hybrid import search_agent_skills_hybrid
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ts() -> _dt.datetime:
|
||||
return _dt.datetime(2026, 1, 1, tzinfo=_dt.UTC)
|
||||
|
||||
|
||||
def _skill_candidate(
|
||||
sid: str,
|
||||
score: float = 0.8,
|
||||
name: str | None = None,
|
||||
) -> Candidate:
|
||||
label = name or f"skill_{sid}"
|
||||
return Candidate(
|
||||
id=sid,
|
||||
score=score,
|
||||
source="vector",
|
||||
metadata={
|
||||
"owner_id": "agent_a",
|
||||
"owner_type": "agent",
|
||||
"name": label,
|
||||
"description": f"desc {sid}",
|
||||
"content": f"content {sid}",
|
||||
"confidence": 0.9,
|
||||
"maturity_score": 0.6,
|
||||
"source_case_ids": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_reranker(candidates: list[Candidate]) -> MagicMock:
|
||||
"""Stub reranker that returns identity-reranked results in the same order."""
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, index: int, score: float) -> None:
|
||||
self.index = index
|
||||
self.score = score
|
||||
|
||||
reranker = MagicMock()
|
||||
# provider.rerank returns a list of result objects with index + score
|
||||
reranker.rerank = AsyncMock(
|
||||
return_value=[_FakeResult(i, c.score) for i, c in enumerate(candidates)]
|
||||
)
|
||||
return reranker
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchAgentSkillsHybridRerank:
|
||||
"""Cross-encoder rerank path."""
|
||||
|
||||
async def test_returns_shaped_items_up_to_top_k(self) -> None:
|
||||
"""rrf + rerank produces at most top_k SearchAgentSkillItem objects."""
|
||||
c1 = _skill_candidate("s1", score=0.9)
|
||||
c2 = _skill_candidate("s2", score=0.8)
|
||||
c3 = _skill_candidate("s3", score=0.7)
|
||||
|
||||
reranker = _make_reranker([c1, c2, c3])
|
||||
|
||||
result = await search_agent_skills_hybrid(
|
||||
"what skill handles auth?",
|
||||
sparse=[c1, c2, c3],
|
||||
dense=[c1, c2, c3],
|
||||
reranker=reranker,
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(item, SearchAgentSkillItem) for item in result)
|
||||
assert result[0].id == "s1"
|
||||
assert result[1].id == "s2"
|
||||
|
||||
async def test_reranker_receives_skill_instruction_and_shaped_passages(
|
||||
self,
|
||||
) -> None:
|
||||
"""Reranker must see the skill-specific instruction and
|
||||
``"Agent Skill: {name} - {description}"`` passage shape — matches
|
||||
the everosos-opensource contract for skill rerank.
|
||||
"""
|
||||
c1 = _skill_candidate("s1", name="auth_middleware_refactor")
|
||||
c2 = _skill_candidate("s2", name="provider_lookup_split")
|
||||
|
||||
reranker = _make_reranker([c1, c2])
|
||||
|
||||
await search_agent_skills_hybrid(
|
||||
"how to split auth?",
|
||||
sparse=[c1],
|
||||
dense=[c1, c2],
|
||||
reranker=reranker,
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
reranker.rerank.assert_awaited_once()
|
||||
call = reranker.rerank.await_args
|
||||
assert call is not None
|
||||
positional = call.args
|
||||
kw = call.kwargs
|
||||
# Signature: rerank(query, passages, *, instruction=...)
|
||||
assert positional[0] == "how to split auth?"
|
||||
passages = positional[1]
|
||||
assert passages == [
|
||||
"Agent Skill: auth_middleware_refactor - desc s1",
|
||||
"Agent Skill: provider_lookup_split - desc s2",
|
||||
]
|
||||
assert kw["instruction"] == _SKILL_RERANK_INSTRUCTION
|
||||
|
||||
|
||||
class TestSearchAgentSkillsHybridEmpty:
|
||||
"""Empty input / degenerate cases."""
|
||||
|
||||
async def test_empty_sparse_and_dense_returns_empty_list(self) -> None:
|
||||
"""No candidates → no items, no errors."""
|
||||
reranker = MagicMock()
|
||||
reranker.rerank = AsyncMock(return_value=[])
|
||||
|
||||
result = await search_agent_skills_hybrid(
|
||||
"query",
|
||||
sparse=[],
|
||||
dense=[],
|
||||
reranker=reranker,
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
# reranker.rerank must not be called when fused list is empty
|
||||
reranker.rerank.assert_not_called()
|
||||
0
tests/unit/test_memory/test_strategies/__init__.py
Normal file
0
tests/unit/test_memory/test_strategies/__init__.py
Normal file
@ -0,0 +1,323 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
import importlib
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import structlog.testing
|
||||
from everalgo.types import AgentCase, ChatMessage, MemCell
|
||||
|
||||
from everos.core.persistence import EntryId
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import AgentCaseExtracted, AgentPipelineStarted
|
||||
from everos.memory.strategies.extract_agent_case import extract_agent_case
|
||||
|
||||
|
||||
def _fake_eid() -> EntryId:
|
||||
return EntryId(prefix="ac", date=_dt.date(2026, 5, 17), seq=1)
|
||||
|
||||
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_agent_case")
|
||||
|
||||
|
||||
def _agent_memcell() -> MemCell:
|
||||
return MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="please summarise the doc",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="assistant",
|
||||
content="here's the summary ...",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="agent_42",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_001_000,
|
||||
)
|
||||
|
||||
|
||||
def _event() -> AgentPipelineStarted:
|
||||
return AgentPipelineStarted(
|
||||
memcell_id="mc_a", session_id="s1", memcell=_agent_memcell()
|
||||
)
|
||||
|
||||
|
||||
def _algo_case(
|
||||
*,
|
||||
task_intent: str = "summarise doc",
|
||||
approach: str = "read + condense",
|
||||
quality_score: float = 0.8,
|
||||
key_insight: str = "",
|
||||
) -> AgentCase:
|
||||
return AgentCase(
|
||||
id=uuid.uuid4().hex,
|
||||
timestamp=1_700_000_001_000,
|
||||
task_intent=task_intent,
|
||||
approach=approach,
|
||||
quality_score=quality_score,
|
||||
key_insight=key_insight,
|
||||
)
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = extract_agent_case._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "extract_agent_case"
|
||||
assert AgentPipelineStarted in meta.trigger.on
|
||||
assert meta.emits == frozenset({AgentCaseExtracted})
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
async def test_writes_md_when_algo_returns_a_case(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
case = _algo_case(quality_score=0.9, key_insight="batch-then-summarise")
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[case])
|
||||
mock_wcls.return_value.append_entry = AsyncMock(return_value=_fake_eid())
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
await extract_agent_case(_event(), ctx)
|
||||
|
||||
assert mock_cls.return_value.aextract.await_count == 1
|
||||
assert mock_wcls.return_value.append_entry.call_count == 1
|
||||
_, kwargs = mock_wcls.return_value.append_entry.call_args
|
||||
assert kwargs["inline"]["owner_id"] == "agent_42"
|
||||
assert kwargs["inline"]["session_id"] == "s1"
|
||||
assert kwargs["inline"]["parent_type"] == "memcell"
|
||||
assert kwargs["inline"]["parent_id"] == "mc_a"
|
||||
assert kwargs["inline"]["quality_score"] == 0.9
|
||||
assert kwargs["sections"] == {
|
||||
"TaskIntent": "summarise doc",
|
||||
"Approach": "read + condense",
|
||||
"KeyInsight": "batch-then-summarise",
|
||||
}
|
||||
# Chain emit: AgentCaseExtracted fires after the md write.
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, AgentCaseExtracted)]
|
||||
assert len(emitted) == 1
|
||||
assert emitted[0].memcell_id == "mc_a"
|
||||
assert emitted[0].case_entry_id == _fake_eid().format()
|
||||
assert emitted[0].task_intent == "summarise doc"
|
||||
assert emitted[0].quality_score == 0.9
|
||||
assert emitted[0].case_timestamp_ms == 1_700_000_001_000
|
||||
assert emitted[0].agent_id == "agent_42"
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "agent_case_extracted"]
|
||||
assert matching, "expected agent_case_extracted log line"
|
||||
assert matching[0]["owner_ids"] == ["agent_42"]
|
||||
assert matching[0]["fanout"] == 1
|
||||
assert matching[0]["quality_score"] == 0.9
|
||||
|
||||
|
||||
async def test_fans_out_per_assistant_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""One LLM call, then md write + emit per distinct assistant sender.
|
||||
|
||||
Case text is third-person (``the agent did X``) so the same body
|
||||
is a valid reference experience for every assistant sender that
|
||||
participated in the trajectory. Verifies: aextract is called
|
||||
exactly once, md is written once per agent, and an
|
||||
``AgentCaseExtracted`` event fires per agent so the downstream
|
||||
skill clustering chain runs in each agent's own scope.
|
||||
"""
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
multi_agent_cell = MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="please dispatch",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="assistant",
|
||||
content="dispatching to specialist",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="agent_lead",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m3",
|
||||
role="assistant",
|
||||
content="here is the answer",
|
||||
timestamp=1_700_000_002_000,
|
||||
sender_id="agent_specialist",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_002_000,
|
||||
)
|
||||
event = AgentPipelineStarted(
|
||||
memcell_id="mc_multi", session_id="s_multi", memcell=multi_agent_cell
|
||||
)
|
||||
case = _algo_case(quality_score=0.85)
|
||||
|
||||
# writer.append_entry returns a different entry_id per call so the
|
||||
# emitted events carry per-agent entry_ids (cascade keys off owner+entry).
|
||||
eids = [
|
||||
EntryId(prefix="ac", date=_dt.date(2026, 5, 17), seq=1),
|
||||
EntryId(prefix="ac", date=_dt.date(2026, 5, 17), seq=2),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[case])
|
||||
mock_wcls.return_value.append_entry = AsyncMock(side_effect=eids)
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
await extract_agent_case(event, ctx)
|
||||
|
||||
# Exactly one LLM call regardless of agent count.
|
||||
assert mock_cls.return_value.aextract.await_count == 1
|
||||
|
||||
# Two md writes (one per distinct assistant sender), in first-seen order.
|
||||
assert mock_wcls.return_value.append_entry.call_count == 2
|
||||
owners_written = [
|
||||
call.kwargs["inline"]["owner_id"]
|
||||
for call in mock_wcls.return_value.append_entry.call_args_list
|
||||
]
|
||||
assert owners_written == ["agent_lead", "agent_specialist"]
|
||||
|
||||
# Two emits, each tagged with its own agent_id + per-agent entry_id.
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, AgentCaseExtracted)]
|
||||
assert len(emitted) == 2
|
||||
assert [e.agent_id for e in emitted] == ["agent_lead", "agent_specialist"]
|
||||
assert [e.case_entry_id for e in emitted] == [eids[0].format(), eids[1].format()]
|
||||
# Same task body / quality across the fan-out (broadcast semantics).
|
||||
assert {e.task_intent for e in emitted} == {"summarise doc"}
|
||||
assert {e.quality_score for e in emitted} == {0.85}
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "agent_case_extracted"]
|
||||
assert matching, "expected agent_case_extracted log line"
|
||||
assert matching[0]["owner_ids"] == ["agent_lead", "agent_specialist"]
|
||||
assert matching[0]["fanout"] == 2
|
||||
|
||||
|
||||
async def test_omits_key_insight_section_when_empty(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
case = _algo_case(key_insight="")
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseWriter"
|
||||
) as mock_wcls,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[case])
|
||||
mock_wcls.return_value.append_entry = AsyncMock(return_value=_fake_eid())
|
||||
await extract_agent_case(_event(), FakeStrategyContext())
|
||||
|
||||
_, kwargs = mock_wcls.return_value.append_entry.call_args
|
||||
assert "KeyInsight" not in kwargs["sections"]
|
||||
assert kwargs["sections"]["TaskIntent"] == "summarise doc"
|
||||
|
||||
|
||||
async def test_skips_when_algo_returns_empty(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Algo pre-filter rejected the cell — no md written, log a noop line."""
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[])
|
||||
mock_wcls.return_value.append_entry = AsyncMock(return_value=_fake_eid())
|
||||
await extract_agent_case(_event(), FakeStrategyContext())
|
||||
|
||||
mock_wcls.return_value.append_entry.assert_not_called()
|
||||
matching = [e for e in captured if e.get("event") == "agent_case_skipped_by_algo"]
|
||||
assert matching, "expected agent_case_skipped_by_algo log line"
|
||||
|
||||
|
||||
async def test_skips_when_no_assistant_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""No assistant in the cell → no agent_id can be inferred; algo not called."""
|
||||
user_only = MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hi",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
event = AgentPipelineStarted(memcell_id="mc_b", session_id="s1", memcell=user_only)
|
||||
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[])
|
||||
mock_wcls.return_value.append_entry = AsyncMock(return_value=_fake_eid())
|
||||
await extract_agent_case(event, FakeStrategyContext())
|
||||
|
||||
# Algo extractor must not be invoked at all when there's no agent.
|
||||
mock_cls.return_value.aextract.assert_not_called()
|
||||
mock_wcls.return_value.append_entry.assert_not_called()
|
||||
matching = [
|
||||
e for e in captured if e.get("event") == "agent_case_skipped_no_assistant"
|
||||
]
|
||||
assert matching, "expected agent_case_skipped_no_assistant log line"
|
||||
@ -0,0 +1,584 @@
|
||||
"""Tests for :func:`extract_agent_skill`.
|
||||
|
||||
Mocked seams: ``cluster_repo`` (sqlite), ``agent_case_repo`` /
|
||||
``agent_skill_repo`` (LanceDB), ``get_embedder`` (component),
|
||||
``AgentSkillExtractor`` (algo), ``AgentSkillWriter`` (md). Each
|
||||
retry-class exception (cluster missing / case-not-indexed) bubbles up so
|
||||
OME's ``max_retries`` machinery catches the race instead of the strategy
|
||||
implementing its own backoff loop.
|
||||
|
||||
LanceDB repo behaviour itself (predicate isolation, cosine ranking,
|
||||
``_distance`` stripping) lives under
|
||||
``tests/unit/test_infra/test_lancedb/test_repos/``; strategy tests only
|
||||
verify routing decisions and orchestration glue.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as _dt
|
||||
import importlib
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from everalgo.clustering import Cluster as AlgoCluster
|
||||
from everalgo.types import AgentSkill as AlgoAgentSkill
|
||||
|
||||
from everos.component.embedding import (
|
||||
EmbeddingError,
|
||||
EmbeddingNotConfiguredError,
|
||||
)
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import SkillClusterUpdated
|
||||
from everos.memory.strategies._partition_locks import _reset_for_tests
|
||||
from everos.memory.strategies.extract_agent_skill import (
|
||||
MAX_SKILLS_IN_PROMPT,
|
||||
MAX_SUPPORTING_CASES,
|
||||
_CaseNotYetIndexedError,
|
||||
_ClusterMissingError,
|
||||
_collect_supporting_entry_ids,
|
||||
_resolve_query_vector,
|
||||
_select_existing_skills,
|
||||
_select_supporting_cases,
|
||||
extract_agent_skill,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_partition_locks() -> None:
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def _event(
|
||||
*,
|
||||
cluster_id: str = "cl_xxxxxxxxxxx1",
|
||||
case_entry_id: str = "ac_20260517_0001",
|
||||
agent_id: str = "agent_42",
|
||||
) -> SkillClusterUpdated:
|
||||
return SkillClusterUpdated(
|
||||
case_entry_id=case_entry_id,
|
||||
cluster_id=cluster_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
|
||||
def _algo_cluster(
|
||||
*,
|
||||
cluster_id: str = "cl_xxxxxxxxxxx1",
|
||||
members: list[str] | None = None,
|
||||
) -> AlgoCluster:
|
||||
return AlgoCluster(
|
||||
id=cluster_id,
|
||||
centroid=np.zeros(1024, dtype=np.float32),
|
||||
count=len(members or ["ac_20260517_0001"]),
|
||||
last_ts=1_700_000_000_000,
|
||||
preview=[],
|
||||
members=members or ["ac_20260517_0001"],
|
||||
)
|
||||
|
||||
|
||||
def _lance_case(
|
||||
entry_id: str,
|
||||
*,
|
||||
quality_score: float = 0.8,
|
||||
timestamp: _dt.datetime | None = None,
|
||||
vector: list[float] | None = None,
|
||||
task_intent: str | None = None,
|
||||
) -> MagicMock:
|
||||
"""Stand-in for a LanceDB AgentCase row (only fields the strategy reads)."""
|
||||
case = MagicMock()
|
||||
case.entry_id = entry_id
|
||||
case.timestamp = timestamp or _dt.datetime(2026, 5, 17, tzinfo=_dt.UTC)
|
||||
case.task_intent = (
|
||||
task_intent if task_intent is not None else f"intent of {entry_id}"
|
||||
)
|
||||
case.approach = f"approach of {entry_id}"
|
||||
case.quality_score = quality_score
|
||||
case.key_insight = ""
|
||||
case.vector = vector or []
|
||||
return case
|
||||
|
||||
|
||||
def _lance_skill(
|
||||
*,
|
||||
name: str = "old_skill",
|
||||
cluster_id: str = "cl_xxxxxxxxxxx1",
|
||||
source_case_ids: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
skill = MagicMock()
|
||||
skill.id = f"agent_42_{name}"
|
||||
skill.cluster_id = cluster_id
|
||||
skill.name = name
|
||||
skill.description = f"desc {name}"
|
||||
skill.content = f"content {name}"
|
||||
skill.confidence = 0.5
|
||||
skill.maturity_score = 0.5
|
||||
skill.source_case_ids = source_case_ids or []
|
||||
return skill
|
||||
|
||||
|
||||
def _algo_skill(name: str = "summarise_doc") -> AlgoAgentSkill:
|
||||
return AlgoAgentSkill(
|
||||
id="dummyuuid",
|
||||
cluster_id="", # caller will post-stamp
|
||||
name=name,
|
||||
description=f"how to {name}",
|
||||
content="full body of the skill",
|
||||
confidence=0.7,
|
||||
maturity_score=0.5,
|
||||
source_case_ids=["ac_20260517_0001"],
|
||||
)
|
||||
|
||||
|
||||
# ── strategy meta + retry-class errors ───────────────────────────────────
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = extract_agent_skill._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "extract_agent_skill"
|
||||
assert SkillClusterUpdated in meta.trigger.on
|
||||
assert meta.emits == frozenset()
|
||||
assert meta.max_retries == 3
|
||||
|
||||
|
||||
async def test_raises_when_cluster_missing_for_retry() -> None:
|
||||
"""No cluster row yet — OME will retry the run."""
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.cluster_repo"
|
||||
) as mock_repo:
|
||||
mock_repo.get_with_members = AsyncMock(return_value=None)
|
||||
with pytest.raises(_ClusterMissingError):
|
||||
await extract_agent_skill(_event(), FakeStrategyContext())
|
||||
|
||||
|
||||
async def test_raises_when_target_case_not_yet_in_lancedb() -> None:
|
||||
"""LanceDB has not yet indexed the freshly-written case — let OME retry."""
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo,
|
||||
):
|
||||
mock_cluster_repo.get_with_members = AsyncMock(return_value=_algo_cluster())
|
||||
mock_case_repo.find_by_owner_entry = AsyncMock(return_value=None)
|
||||
with pytest.raises(_CaseNotYetIndexedError):
|
||||
await extract_agent_skill(_event(), FakeStrategyContext())
|
||||
|
||||
|
||||
# ── end-to-end orchestration (mocked) ────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_and_persists_with_cluster_id_stamped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""End-to-end (mocked): extractor emits skills → writer stamps cluster_id."""
|
||||
target = _lance_case("ac_20260517_0001", vector=[0.1] * 1024)
|
||||
supporting = [_lance_case("ac_20260517_0000")]
|
||||
existing = [_lance_skill(name="old_skill", source_case_ids=["ac_20260517_0000"])]
|
||||
emitted = [_algo_skill(name="summarise_doc"), _algo_skill(name="batch_then_synth")]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_skill_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.AgentSkillExtractor"
|
||||
) as mock_extractor_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.AgentSkillWriter"
|
||||
) as mock_writer_cls,
|
||||
):
|
||||
mock_cluster_repo.get_with_members = AsyncMock(
|
||||
return_value=_algo_cluster(members=["ac_20260517_0000", "ac_20260517_0001"])
|
||||
)
|
||||
mock_case_repo.find_by_owner_entry = AsyncMock(return_value=target)
|
||||
mock_case_repo.find_by_owner_entries = AsyncMock(return_value=supporting)
|
||||
# Small cluster path: count ≤ K → scalar fetch returns existing.
|
||||
mock_skill_repo.count_in_cluster = AsyncMock(return_value=len(existing))
|
||||
mock_skill_repo.find_in_cluster = AsyncMock(return_value=existing)
|
||||
mock_extractor_cls.return_value.aextract = AsyncMock(return_value=emitted)
|
||||
mock_writer_cls.return_value.write_main = AsyncMock(return_value=None)
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_agent_skill")
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
|
||||
await extract_agent_skill(_event(), FakeStrategyContext())
|
||||
|
||||
extractor_call = mock_extractor_cls.return_value.aextract.call_args
|
||||
target_arg = extractor_call.args[0]
|
||||
assert target_arg.id == "ac_20260517_0001"
|
||||
assert target_arg.task_intent == "intent of ac_20260517_0001"
|
||||
assert [s.name for s in extractor_call.kwargs["existing_relevant_skills"]] == [
|
||||
"old_skill"
|
||||
]
|
||||
assert [c.id for c in extractor_call.kwargs["supporting_cases"]] == [
|
||||
"ac_20260517_0000"
|
||||
]
|
||||
|
||||
write_calls = mock_writer_cls.return_value.write_main.call_args_list
|
||||
assert len(write_calls) == 2
|
||||
for call, expected in zip(write_calls, emitted, strict=True):
|
||||
agent_id_arg, skill_name_arg = call.args
|
||||
fm = call.kwargs["frontmatter"]
|
||||
assert agent_id_arg == "agent_42"
|
||||
assert skill_name_arg == expected.name
|
||||
assert fm.cluster_id == "cl_xxxxxxxxxxx1"
|
||||
assert fm.name == expected.name
|
||||
assert fm.confidence == expected.confidence
|
||||
assert call.kwargs["body"] == expected.content
|
||||
|
||||
|
||||
# ── _select_existing_skills routing (cluster size × vector availability) ─
|
||||
|
||||
|
||||
async def test_select_existing_skills_small_cluster_uses_scalar_fetch() -> None:
|
||||
"""``total ≤ K`` short-circuits — no ranking needed for fully-inclusive set."""
|
||||
target = _lance_case("ac_001", vector=[0.5] * 1024)
|
||||
skills = [_lance_skill(name=f"s{i}") for i in range(3)]
|
||||
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_repo:
|
||||
mock_repo.count_in_cluster = AsyncMock(return_value=3)
|
||||
mock_repo.find_in_cluster = AsyncMock(return_value=skills)
|
||||
mock_repo.find_topk_relevant_in_cluster = AsyncMock()
|
||||
|
||||
got = await _select_existing_skills(
|
||||
agent_id="a", cluster_id="cl_x", target=target
|
||||
)
|
||||
|
||||
assert got == skills
|
||||
mock_repo.find_topk_relevant_in_cluster.assert_not_awaited()
|
||||
mock_repo.find_in_cluster.assert_awaited_once_with(
|
||||
owner_id="a", cluster_id="cl_x", limit=MAX_SKILLS_IN_PROMPT
|
||||
)
|
||||
|
||||
|
||||
async def test_select_existing_skills_large_cluster_with_vector_uses_topk() -> None:
|
||||
"""``total > K`` and target carries vector → cosine top-K path."""
|
||||
target = _lance_case("ac_001", vector=[0.5] * 1024)
|
||||
topk_skills = [_lance_skill(name=f"s{i}") for i in range(MAX_SKILLS_IN_PROMPT)]
|
||||
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_repo:
|
||||
mock_repo.count_in_cluster = AsyncMock(return_value=MAX_SKILLS_IN_PROMPT + 5)
|
||||
mock_repo.find_topk_relevant_in_cluster = AsyncMock(return_value=topk_skills)
|
||||
mock_repo.find_in_cluster = AsyncMock()
|
||||
|
||||
got = await _select_existing_skills(
|
||||
agent_id="a", cluster_id="cl_x", target=target
|
||||
)
|
||||
|
||||
assert got == topk_skills
|
||||
mock_repo.find_in_cluster.assert_not_awaited()
|
||||
call_kwargs = mock_repo.find_topk_relevant_in_cluster.await_args.kwargs
|
||||
assert call_kwargs["query_vector"] == [0.5] * 1024
|
||||
assert call_kwargs["top_k"] == MAX_SKILLS_IN_PROMPT
|
||||
|
||||
|
||||
async def test_select_existing_skills_large_cluster_recomputes_embedding() -> None:
|
||||
"""``total > K`` but case has no vector → re-embed ``task_intent`` on the fly."""
|
||||
target = _lance_case("ac_001", vector=[], task_intent="how to summarise docs")
|
||||
topk_skills = [_lance_skill(name=f"s{i}") for i in range(MAX_SKILLS_IN_PROMPT)]
|
||||
fresh_vec = [0.42] * 1024
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed = AsyncMock(return_value=fresh_vec)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_embedder",
|
||||
return_value=mock_embedder,
|
||||
),
|
||||
):
|
||||
mock_repo.count_in_cluster = AsyncMock(return_value=MAX_SKILLS_IN_PROMPT + 5)
|
||||
mock_repo.find_topk_relevant_in_cluster = AsyncMock(return_value=topk_skills)
|
||||
mock_repo.find_in_cluster = AsyncMock()
|
||||
|
||||
got = await _select_existing_skills(
|
||||
agent_id="a", cluster_id="cl_x", target=target
|
||||
)
|
||||
|
||||
assert got == topk_skills
|
||||
mock_embedder.embed.assert_awaited_once_with("how to summarise docs")
|
||||
call_kwargs = mock_repo.find_topk_relevant_in_cluster.await_args.kwargs
|
||||
assert call_kwargs["query_vector"] == fresh_vec
|
||||
|
||||
|
||||
async def test_select_existing_skills_falls_back_to_scalar_when_embed_fails() -> None:
|
||||
"""``total > K`` + no vector + embedder fails → scalar fetch capped at K."""
|
||||
target = _lance_case("ac_001", vector=[], task_intent="how to summarise docs")
|
||||
scalar_skills = [_lance_skill(name=f"s{i}") for i in range(MAX_SKILLS_IN_PROMPT)]
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed = AsyncMock(side_effect=EmbeddingError("provider down"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_embedder",
|
||||
return_value=mock_embedder,
|
||||
),
|
||||
):
|
||||
mock_repo.count_in_cluster = AsyncMock(return_value=MAX_SKILLS_IN_PROMPT + 5)
|
||||
mock_repo.find_in_cluster = AsyncMock(return_value=scalar_skills)
|
||||
mock_repo.find_topk_relevant_in_cluster = AsyncMock()
|
||||
|
||||
got = await _select_existing_skills(
|
||||
agent_id="a", cluster_id="cl_x", target=target
|
||||
)
|
||||
|
||||
assert got == scalar_skills
|
||||
mock_repo.find_topk_relevant_in_cluster.assert_not_awaited()
|
||||
mock_repo.find_in_cluster.assert_awaited_once_with(
|
||||
owner_id="a", cluster_id="cl_x", limit=MAX_SKILLS_IN_PROMPT
|
||||
)
|
||||
|
||||
|
||||
# ── _resolve_query_vector layered fallback ───────────────────────────────
|
||||
|
||||
|
||||
async def test_resolve_query_vector_prefers_persisted_vector() -> None:
|
||||
"""When ``target.vector`` is set, reuse it; never call the embedder."""
|
||||
target = _lance_case("ac_001", vector=[0.3] * 1024)
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_embedder"
|
||||
) as mock_get_embedder:
|
||||
got = await _resolve_query_vector(target)
|
||||
assert got == [0.3] * 1024
|
||||
mock_get_embedder.assert_not_called()
|
||||
|
||||
|
||||
async def test_resolve_query_vector_returns_empty_when_no_text_either() -> None:
|
||||
"""No persisted vector + no task_intent → ``[]`` (no policy here)."""
|
||||
target = _lance_case("ac_001", vector=[], task_intent="")
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_embedder"
|
||||
) as mock_get_embedder:
|
||||
got = await _resolve_query_vector(target)
|
||||
assert got == []
|
||||
mock_get_embedder.assert_not_called()
|
||||
|
||||
|
||||
async def test_resolve_query_vector_swallows_embedder_not_configured() -> None:
|
||||
"""Missing embedder config is a deployment issue, not a strategy fault."""
|
||||
target = _lance_case("ac_001", vector=[], task_intent="hello")
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed = AsyncMock(
|
||||
side_effect=EmbeddingNotConfiguredError("no api key")
|
||||
)
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_embedder",
|
||||
return_value=mock_embedder,
|
||||
):
|
||||
got = await _resolve_query_vector(target)
|
||||
assert got == []
|
||||
|
||||
|
||||
# ── _select_supporting_cases ranking + cap ───────────────────────────────
|
||||
|
||||
|
||||
async def test_select_supporting_cases_ranks_by_quality_then_timestamp() -> None:
|
||||
"""Hydrated cases sort ``(quality_score desc, timestamp desc)``."""
|
||||
skills = [
|
||||
_lance_skill(name="s1", source_case_ids=["ac_a", "ac_b", "ac_c"]),
|
||||
]
|
||||
case_a = _lance_case(
|
||||
"ac_a",
|
||||
quality_score=0.4,
|
||||
timestamp=_dt.datetime(2026, 5, 1, tzinfo=_dt.UTC),
|
||||
)
|
||||
case_b = _lance_case(
|
||||
"ac_b",
|
||||
quality_score=0.9,
|
||||
timestamp=_dt.datetime(2026, 5, 1, tzinfo=_dt.UTC),
|
||||
)
|
||||
case_c = _lance_case(
|
||||
"ac_c",
|
||||
quality_score=0.9,
|
||||
timestamp=_dt.datetime(2026, 5, 10, tzinfo=_dt.UTC),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo:
|
||||
# Order intentionally scrambled to prove the strategy sorts.
|
||||
mock_case_repo.find_by_owner_entries = AsyncMock(
|
||||
return_value=[case_a, case_b, case_c]
|
||||
)
|
||||
|
||||
got = await _select_supporting_cases(
|
||||
skills,
|
||||
agent_id="a",
|
||||
exclude_entry_id="ac_target",
|
||||
app_id="default",
|
||||
project_id="default",
|
||||
)
|
||||
|
||||
assert [c.entry_id for c in got] == ["ac_c", "ac_b", "ac_a"]
|
||||
|
||||
|
||||
async def test_select_supporting_cases_caps_at_max_supporting() -> None:
|
||||
"""Hydrated set is truncated to ``MAX_SUPPORTING_CASES``."""
|
||||
ids = [f"ac_{i:03d}" for i in range(MAX_SUPPORTING_CASES + 3)]
|
||||
skills = [_lance_skill(name="s1", source_case_ids=ids)]
|
||||
hydrated = [
|
||||
_lance_case(eid, quality_score=0.5 + 0.01 * i) for i, eid in enumerate(ids)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo:
|
||||
mock_case_repo.find_by_owner_entries = AsyncMock(return_value=hydrated)
|
||||
got = await _select_supporting_cases(
|
||||
skills,
|
||||
agent_id="a",
|
||||
exclude_entry_id="ac_target",
|
||||
app_id="default",
|
||||
project_id="default",
|
||||
)
|
||||
|
||||
assert len(got) == MAX_SUPPORTING_CASES
|
||||
|
||||
|
||||
async def test_select_supporting_cases_skips_repo_when_no_lineage_ids() -> None:
|
||||
"""No usable source ids → ``[]`` without a repo round trip."""
|
||||
skills = [_lance_skill(name="s1", source_case_ids=[])]
|
||||
with patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo:
|
||||
mock_case_repo.find_by_owner_entries = AsyncMock()
|
||||
got = await _select_supporting_cases(
|
||||
skills,
|
||||
agent_id="a",
|
||||
exclude_entry_id="ac_target",
|
||||
app_id="default",
|
||||
project_id="default",
|
||||
)
|
||||
assert got == []
|
||||
mock_case_repo.find_by_owner_entries.assert_not_awaited()
|
||||
|
||||
|
||||
# ── _collect_supporting_entry_ids dedup + exclude ────────────────────────
|
||||
|
||||
|
||||
def test_collect_supporting_entry_ids_dedups_and_excludes_target() -> None:
|
||||
"""Source ids fold across skills; duplicates and the target id drop out."""
|
||||
skill_a = MagicMock()
|
||||
skill_a.source_case_ids = ["ac_a", "ac_b", "ac_target"]
|
||||
skill_b = MagicMock()
|
||||
skill_b.source_case_ids = ["ac_b", "ac_c"] # ac_b duplicates skill_a's lineage
|
||||
skill_empty = MagicMock()
|
||||
skill_empty.source_case_ids = []
|
||||
|
||||
got = _collect_supporting_entry_ids(
|
||||
[skill_a, skill_b, skill_empty], exclude="ac_target"
|
||||
)
|
||||
assert got == ["ac_a", "ac_b", "ac_c"]
|
||||
|
||||
|
||||
def test_collect_supporting_entry_ids_handles_empty_input() -> None:
|
||||
"""No skills → no supporting cases."""
|
||||
assert _collect_supporting_entry_ids([], exclude="ac_anything") == []
|
||||
|
||||
|
||||
# ── partition lock (agent_id-level serialisation) ────────────────────────
|
||||
|
||||
|
||||
async def _run_serialisation_probe(
|
||||
agent_id_run_a: str, agent_id_run_b: str
|
||||
) -> list[str]:
|
||||
"""Drive two extract_agent_skill runs and record their critical-section order.
|
||||
|
||||
Mocks every I/O seam so the only async work inside the locked region
|
||||
is a tiny ``asyncio.sleep`` masquerading as the LLM call. The returned
|
||||
log is the strict enter/leave sequence both runs go through.
|
||||
"""
|
||||
log: list[str] = []
|
||||
|
||||
async def mock_aextract(case, **_kwargs):
|
||||
log.append(f"enter:{case.id}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{case.id}")
|
||||
return []
|
||||
|
||||
target_a = _lance_case("ac_run_a", vector=[0.1] * 1024)
|
||||
target_b = _lance_case("ac_run_b", vector=[0.1] * 1024)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_case_repo"
|
||||
) as mock_case_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.agent_skill_repo"
|
||||
) as mock_skill_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_skill.AgentSkillExtractor"
|
||||
) as mock_extractor_cls,
|
||||
patch("everos.memory.strategies.extract_agent_skill.AgentSkillWriter"),
|
||||
):
|
||||
mock_cluster_repo.get_with_members = AsyncMock(
|
||||
return_value=_algo_cluster(members=["ac_run_a", "ac_run_b"])
|
||||
)
|
||||
mock_case_repo.find_by_owner_entry = AsyncMock(
|
||||
side_effect=lambda owner, entry, **_kw: (
|
||||
target_a if entry == "ac_run_a" else target_b
|
||||
)
|
||||
)
|
||||
mock_case_repo.find_by_owner_entries = AsyncMock(return_value=[])
|
||||
mock_skill_repo.count_in_cluster = AsyncMock(return_value=0)
|
||||
mock_skill_repo.find_in_cluster = AsyncMock(return_value=[])
|
||||
mock_extractor_cls.return_value.aextract = mock_aextract
|
||||
await asyncio.gather(
|
||||
extract_agent_skill(
|
||||
_event(agent_id=agent_id_run_a, case_entry_id="ac_run_a"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
extract_agent_skill(
|
||||
_event(agent_id=agent_id_run_b, case_entry_id="ac_run_b"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
)
|
||||
return log
|
||||
|
||||
|
||||
async def test_partition_lock_serialises_runs_on_same_agent() -> None:
|
||||
"""Two runs sharing ``agent_id`` must not overlap critical sections."""
|
||||
log = await _run_serialisation_probe("agent_42", "agent_42")
|
||||
assert log in (
|
||||
["enter:ac_run_a", "leave:ac_run_a", "enter:ac_run_b", "leave:ac_run_b"],
|
||||
["enter:ac_run_b", "leave:ac_run_b", "enter:ac_run_a", "leave:ac_run_a"],
|
||||
)
|
||||
|
||||
|
||||
async def test_partition_lock_lets_different_agents_run_in_parallel() -> None:
|
||||
"""Runs on distinct ``agent_id`` must overlap (no false serialisation)."""
|
||||
log = await _run_serialisation_probe("agent_42", "agent_43")
|
||||
assert log.index("enter:ac_run_a") < log.index("leave:ac_run_b")
|
||||
assert log.index("enter:ac_run_b") < log.index("leave:ac_run_a")
|
||||
@ -0,0 +1,223 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import structlog.testing
|
||||
from everalgo.types import AtomicFact, ChatMessage, MemCell
|
||||
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import UserPipelineStarted
|
||||
from everos.memory.strategies.extract_atomic_facts import extract_atomic_facts
|
||||
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_atomic_facts")
|
||||
|
||||
|
||||
def _two_user_memcell() -> MemCell:
|
||||
return MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hi from alice",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="user",
|
||||
content="hi from bob",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="u_bob",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m3",
|
||||
role="assistant",
|
||||
content="hello both",
|
||||
timestamp=1_700_000_002_000,
|
||||
sender_id="agent",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_002_000,
|
||||
)
|
||||
|
||||
|
||||
def _fact(owner_id: str | None, text: str) -> AtomicFact:
|
||||
return AtomicFact(owner_id=owner_id, content=text, timestamp=1_700_000_000_000)
|
||||
|
||||
|
||||
def _event() -> UserPipelineStarted:
|
||||
return UserPipelineStarted(
|
||||
memcell_id="mc_a", session_id="s1", memcell=_two_user_memcell()
|
||||
)
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = extract_atomic_facts._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "extract_atomic_facts"
|
||||
assert UserPipelineStarted in meta.trigger.on
|
||||
assert meta.emits == frozenset()
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
async def test_extracts_once_and_fans_out_per_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""One LLM call per memcell; same fact list re-written under each sender.
|
||||
|
||||
The algo prompt is subject-agnostic (only ``INPUT_TEXT`` + ``TIME``
|
||||
placeholders), so re-running it per sender would burn LLM tokens
|
||||
and let non-determinism drift the per-sender md files apart. The
|
||||
strategy calls ``aextract`` once with ``sender_id=None`` and
|
||||
broadcasts the resulting list — every user sender gets its own md
|
||||
entries pointing at the same fact bodies.
|
||||
|
||||
Per-owner batching: the strategy collects each sender's full fact
|
||||
list and issues one :meth:`append_entries` per owner (not N single
|
||||
appends), so the call shape is one batch call per sender.
|
||||
"""
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
generic_facts = [
|
||||
_fact(None, "alice mentioned a weekend trip to tokyo"),
|
||||
_fact(None, "bob said he needs hiking gear"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=generic_facts)
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
|
||||
await extract_atomic_facts(_event(), FakeStrategyContext())
|
||||
|
||||
# Exactly one LLM call, parameterised with sender_id=None.
|
||||
assert mock_cls.return_value.aextract.await_count == 1
|
||||
call = mock_cls.return_value.aextract.call_args
|
||||
assert call.kwargs["sender_id"] is None
|
||||
|
||||
# 2 senders → 2 batch calls; each batch carries this sender's 2 facts
|
||||
# (same generic body re-used).
|
||||
assert mock_wcls.return_value.append_entries.call_count == 2
|
||||
batch_calls = mock_wcls.return_value.append_entries.call_args_list
|
||||
batched_owners = sorted(c.args[0] for c in batch_calls)
|
||||
assert batched_owners == ["u_alice", "u_bob"]
|
||||
# Flatten items across batches: (owner, fact_text) pairs.
|
||||
flat = sorted(
|
||||
(c.args[0], sections["Fact"])
|
||||
for c in batch_calls
|
||||
for inline, sections in c.args[1]
|
||||
)
|
||||
assert flat == [
|
||||
("u_alice", "alice mentioned a weekend trip to tokyo"),
|
||||
("u_alice", "bob said he needs hiking gear"),
|
||||
("u_bob", "alice mentioned a weekend trip to tokyo"),
|
||||
("u_bob", "bob said he needs hiking gear"),
|
||||
]
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "atomic_facts_extracted"]
|
||||
assert matching, "expected atomic_facts_extracted log line"
|
||||
record = matching[0]
|
||||
assert record["count"] == 4
|
||||
assert sorted(record["owner_ids"]) == ["u_alice", "u_bob"]
|
||||
|
||||
|
||||
async def test_writes_md_for_each_fact(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
facts = [
|
||||
_fact("u_alice", "alice likes hiking"),
|
||||
_fact("u_alice", "alice lives in tokyo"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactWriter"
|
||||
) as mock_wcls,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=facts)
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hi",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
)
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
)
|
||||
await extract_atomic_facts(event, FakeStrategyContext())
|
||||
|
||||
# Single sender (u_alice) → one batch call with 2 items.
|
||||
assert mock_wcls.return_value.append_entries.call_count == 1
|
||||
batch_call = mock_wcls.return_value.append_entries.call_args
|
||||
assert batch_call.args[0] == "u_alice"
|
||||
items = batch_call.args[1]
|
||||
assert len(items) == 2
|
||||
for (inline, sections), fact in zip(items, facts, strict=True):
|
||||
assert inline["owner_id"] == "u_alice"
|
||||
assert inline["session_id"] == "s1"
|
||||
assert inline["parent_type"] == "memcell"
|
||||
assert inline["parent_id"] == "mc_a"
|
||||
assert "sender_ids" not in inline
|
||||
assert sections == {"Fact": fact.content}
|
||||
|
||||
|
||||
async def test_skips_when_memcell_has_no_messages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_b",
|
||||
session_id="s1",
|
||||
memcell=MemCell(items=[], timestamp=1_700_000_000_000),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[])
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
ctx = FakeStrategyContext()
|
||||
await extract_atomic_facts(event, ctx)
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "atomic_facts_extracted"]
|
||||
assert matching, "log line should still fire (count=0)"
|
||||
assert matching[0]["count"] == 0
|
||||
mock_wcls.return_value.append_entries.assert_not_called()
|
||||
231
tests/unit/test_memory/test_strategies/test_extract_foresight.py
Normal file
231
tests/unit/test_memory/test_strategies/test_extract_foresight.py
Normal file
@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import structlog.testing
|
||||
from everalgo.types import ChatMessage, Foresight, MemCell
|
||||
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import UserPipelineStarted
|
||||
from everos.memory.strategies.extract_foresight import extract_foresight
|
||||
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_foresight")
|
||||
|
||||
|
||||
def _two_user_memcell() -> MemCell:
|
||||
return MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="alice plans a trip",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="user",
|
||||
content="bob will buy tickets",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="u_bob",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m3",
|
||||
role="assistant",
|
||||
content="sounds good",
|
||||
timestamp=1_700_000_002_000,
|
||||
sender_id="agent",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_002_000,
|
||||
)
|
||||
|
||||
|
||||
def _foresight(owner_id: str, text: str) -> Foresight:
|
||||
return Foresight(
|
||||
owner_id=owner_id,
|
||||
foresight=text,
|
||||
evidence="...",
|
||||
timestamp=1_700_000_000_000,
|
||||
)
|
||||
|
||||
|
||||
def _event() -> UserPipelineStarted:
|
||||
return UserPipelineStarted(
|
||||
memcell_id="mc_a", session_id="s1", memcell=_two_user_memcell()
|
||||
)
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = extract_foresight._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "extract_foresight"
|
||||
assert UserPipelineStarted in meta.trigger.on
|
||||
assert meta.emits == frozenset()
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
async def test_extracts_per_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Per-sender extraction (like Episode, unlike AtomicFact's fan-out)."""
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
# sender_ids in the strategy are sorted: alice first, bob second.
|
||||
mock_cls.return_value.aextract = AsyncMock(
|
||||
side_effect=[
|
||||
[_foresight("u_alice", "trip to tokyo")],
|
||||
[_foresight("u_bob", "buy plane tickets")],
|
||||
]
|
||||
)
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
|
||||
await extract_foresight(_event(), FakeStrategyContext())
|
||||
|
||||
# Per-sender semantics: one LLM call per user sender.
|
||||
assert mock_cls.return_value.aextract.await_count == 2
|
||||
sender_id_calls = [
|
||||
call.kwargs.get("sender_id")
|
||||
for call in mock_cls.return_value.aextract.call_args_list
|
||||
]
|
||||
assert sender_id_calls == ["u_alice", "u_bob"]
|
||||
|
||||
# Per-owner batching: one batch call per owner; here each owner has 1
|
||||
# foresight, so two batches each carrying 1 item.
|
||||
assert mock_wcls.return_value.append_entries.call_count == 2
|
||||
batched_owners = sorted(
|
||||
c.args[0] for c in mock_wcls.return_value.append_entries.call_args_list
|
||||
)
|
||||
assert batched_owners == ["u_alice", "u_bob"]
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "foresights_extracted"]
|
||||
assert matching, "expected foresights_extracted log line"
|
||||
record = matching[0]
|
||||
assert record["count"] == 2
|
||||
assert sorted(record["owner_ids"]) == ["u_alice", "u_bob"]
|
||||
|
||||
|
||||
async def test_writes_md_for_each_foresight(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
foresights = [
|
||||
Foresight(
|
||||
owner_id="u_alice",
|
||||
foresight="trip to tokyo",
|
||||
evidence="said so",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
Foresight(
|
||||
owner_id="u_alice",
|
||||
foresight="buy tickets",
|
||||
evidence="confirmed",
|
||||
timestamp=1_700_000_000_000,
|
||||
start_time="2023-11-15",
|
||||
duration_days=7,
|
||||
),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightWriter"
|
||||
) as mock_wcls,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=foresights)
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="planning a trip",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
)
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
)
|
||||
await extract_foresight(event, FakeStrategyContext())
|
||||
|
||||
# Single sender (u_alice) → one batch call with both foresights.
|
||||
assert mock_wcls.return_value.append_entries.call_count == 1
|
||||
batch_call = mock_wcls.return_value.append_entries.call_args
|
||||
assert batch_call.args[0] == "u_alice"
|
||||
items = batch_call.args[1]
|
||||
assert len(items) == 2
|
||||
|
||||
# First foresight: no optional time fields
|
||||
inline0, sections0 = items[0]
|
||||
assert inline0["owner_id"] == "u_alice"
|
||||
assert inline0["session_id"] == "s1"
|
||||
assert inline0["parent_type"] == "memcell"
|
||||
assert inline0["parent_id"] == "mc_a"
|
||||
assert "sender_ids" not in inline0
|
||||
assert "start_time" not in inline0
|
||||
assert "duration_days" not in inline0
|
||||
assert sections0 == {"Foresight": "trip to tokyo", "Evidence": "said so"}
|
||||
|
||||
# Second foresight: has start_time + duration_days
|
||||
inline1, sections1 = items[1]
|
||||
assert inline1["start_time"] == "2023-11-15"
|
||||
assert inline1["duration_days"] == 7
|
||||
assert sections1 == {"Foresight": "buy tickets", "Evidence": "confirmed"}
|
||||
|
||||
|
||||
async def test_skips_when_memcell_has_no_messages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
event = UserPipelineStarted(
|
||||
memcell_id="mc_b",
|
||||
session_id="s1",
|
||||
memcell=MemCell(items=[], timestamp=1_700_000_000_000),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightExtractor"
|
||||
) as mock_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightWriter"
|
||||
) as mock_wcls,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_cls.return_value.aextract = AsyncMock(return_value=[])
|
||||
mock_wcls.return_value.append_entries = AsyncMock(return_value=[])
|
||||
ctx = FakeStrategyContext()
|
||||
await extract_foresight(event, ctx)
|
||||
|
||||
matching = [e for e in captured if e.get("event") == "foresights_extracted"]
|
||||
assert matching, "log line should still fire (count=0)"
|
||||
assert matching[0]["count"] == 0
|
||||
mock_wcls.return_value.append_entries.assert_not_called()
|
||||
@ -0,0 +1,387 @@
|
||||
"""Tests for :func:`extract_user_profile`.
|
||||
|
||||
Heavy mocking — the strategy threads through ``cluster_repo`` (sqlite),
|
||||
``memcell_repo`` (sqlite, payload deserialise), ``ProfileReader`` /
|
||||
``ProfileWriter`` (md), and ``ProfileExtractor`` (algo). We mock all
|
||||
seams so the test exercises the orchestration only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from everalgo.clustering import Cluster as AlgoCluster
|
||||
from everalgo.types import ChatMessage, MemCell
|
||||
from everalgo.types import Profile as AlgoProfile
|
||||
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.infra.persistence.markdown import UserProfileFrontmatter
|
||||
from everos.memory.events import ProfileClusterUpdated
|
||||
from everos.memory.strategies._partition_locks import _reset_for_tests
|
||||
from everos.memory.strategies.extract_user_profile import extract_user_profile
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_partition_locks() -> None:
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def _event(
|
||||
*,
|
||||
owner_id: str = "u_alice",
|
||||
memcell_id: str = "mc_aaaaaaaaaaa1",
|
||||
cluster_id: str = "cl_user00000001",
|
||||
) -> ProfileClusterUpdated:
|
||||
return ProfileClusterUpdated(
|
||||
memcell_id=memcell_id,
|
||||
cluster_id=cluster_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
|
||||
def _algo_cluster(*, cluster_id: str, members: list[str], last_ts: int) -> AlgoCluster:
|
||||
return AlgoCluster(
|
||||
id=cluster_id,
|
||||
centroid=np.zeros(1024, dtype=np.float32),
|
||||
count=len(members),
|
||||
last_ts=last_ts,
|
||||
preview=[],
|
||||
members=members,
|
||||
)
|
||||
|
||||
|
||||
def _memcell_row(memcell_id: str, *, sender_id: str, ts_ms: int) -> MagicMock:
|
||||
"""Stand-in for a sqlite Memcell row — only ``payload_json`` is read."""
|
||||
cell = MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id=f"{memcell_id}_m1",
|
||||
role="user",
|
||||
content=f"hi from {sender_id}",
|
||||
timestamp=ts_ms,
|
||||
sender_id=sender_id,
|
||||
),
|
||||
],
|
||||
timestamp=ts_ms,
|
||||
)
|
||||
row = MagicMock()
|
||||
row.memcell_id = memcell_id
|
||||
row.payload_json = cell.model_dump_json()
|
||||
return row
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = extract_user_profile._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "extract_user_profile"
|
||||
assert ProfileClusterUpdated in meta.trigger.on
|
||||
assert meta.emits == frozenset()
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_mode_writes_profile_when_no_existing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""No prior profile → ProfileExtractor invoked without ``old_profile``."""
|
||||
cluster = _algo_cluster(
|
||||
cluster_id="cl_user00000001",
|
||||
members=["mc_aaaaaaaaaaa1"],
|
||||
last_ts=1_700_000_001_000,
|
||||
)
|
||||
rows = [
|
||||
_memcell_row("mc_aaaaaaaaaaa1", sender_id="u_alice", ts_ms=1_700_000_001_000)
|
||||
]
|
||||
new_profile = AlgoProfile.model_validate(
|
||||
{
|
||||
"owner_id": "u_alice",
|
||||
"summary": "Alice is a hiker.",
|
||||
"timestamp": 1_700_000_001_000,
|
||||
"explicit_info": ["lives in tokyo"],
|
||||
"implicit_traits": ["adventurous"],
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.memcell_repo"
|
||||
) as mock_memcell_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileExtractor"
|
||||
) as mock_extractor_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileReader"
|
||||
) as mock_reader_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileWriter"
|
||||
) as mock_writer_cls,
|
||||
):
|
||||
mock_cluster_repo.list_for_owner = AsyncMock(return_value=[cluster])
|
||||
mock_memcell_repo.find_by_ids = AsyncMock(return_value=rows)
|
||||
mock_reader_cls.return_value.read = AsyncMock(return_value=None)
|
||||
mock_writer_cls.return_value.write = AsyncMock(return_value=None)
|
||||
mock_extractor_cls.return_value.aextract = AsyncMock(return_value=new_profile)
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_user_profile")
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
monkeypatch.setattr(mod, "_reader", None, raising=False)
|
||||
|
||||
await extract_user_profile(_event(), FakeStrategyContext())
|
||||
|
||||
# INIT mode — old_profile is None.
|
||||
extractor_call = mock_extractor_cls.return_value.aextract.call_args
|
||||
assert extractor_call.kwargs["old_profile"] is None
|
||||
assert extractor_call.kwargs["sender_id"] == "u_alice"
|
||||
assert [mc.timestamp for mc in extractor_call.args[0]] == [1_700_000_001_000]
|
||||
|
||||
# Writer received the freshly built frontmatter.
|
||||
write_call = mock_writer_cls.return_value.write.call_args
|
||||
assert write_call.args[0] == "u_alice"
|
||||
fm = write_call.kwargs["frontmatter"]
|
||||
assert fm.user_id == "u_alice"
|
||||
assert fm.summary == "Alice is a hiker."
|
||||
assert fm.profile_timestamp_ms == 1_700_000_001_000
|
||||
assert fm.explicit_info == ["lives in tokyo"]
|
||||
assert fm.implicit_traits == ["adventurous"]
|
||||
assert write_call.kwargs["body"] == "Alice is a hiker."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_mode_rehydrates_old_profile(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Existing profile → algo Profile rehydrated and passed as old_profile."""
|
||||
cluster = _algo_cluster(
|
||||
cluster_id="cl_user00000001",
|
||||
members=["mc_aaaaaaaaaaa1"],
|
||||
last_ts=1_700_000_002_000,
|
||||
)
|
||||
rows = [
|
||||
_memcell_row("mc_aaaaaaaaaaa1", sender_id="u_alice", ts_ms=1_700_000_002_000)
|
||||
]
|
||||
existing_fm = UserProfileFrontmatter(
|
||||
id="profile_u_alice",
|
||||
user_id="u_alice",
|
||||
summary="prior summary",
|
||||
explicit_info=["prior fact"],
|
||||
implicit_traits=["prior trait"],
|
||||
profile_timestamp_ms=1_700_000_000_000,
|
||||
)
|
||||
new_profile = AlgoProfile.model_validate(
|
||||
{
|
||||
"owner_id": "u_alice",
|
||||
"summary": "updated summary",
|
||||
"timestamp": 1_700_000_002_000,
|
||||
"explicit_info": ["prior fact", "new fact"],
|
||||
"implicit_traits": ["prior trait"],
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.memcell_repo"
|
||||
) as mock_memcell_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileExtractor"
|
||||
) as mock_extractor_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileReader"
|
||||
) as mock_reader_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileWriter"
|
||||
) as mock_writer_cls,
|
||||
):
|
||||
mock_cluster_repo.list_for_owner = AsyncMock(return_value=[cluster])
|
||||
mock_memcell_repo.find_by_ids = AsyncMock(return_value=rows)
|
||||
mock_reader_cls.return_value.read = AsyncMock(
|
||||
return_value=(existing_fm, "prior summary")
|
||||
)
|
||||
mock_writer_cls.return_value.write = AsyncMock(return_value=None)
|
||||
mock_extractor_cls.return_value.aextract = AsyncMock(return_value=new_profile)
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_user_profile")
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
monkeypatch.setattr(mod, "_reader", None, raising=False)
|
||||
|
||||
await extract_user_profile(_event(), FakeStrategyContext())
|
||||
|
||||
# UPDATE mode — old_profile is the rehydrated algo type carrying prior fields.
|
||||
extractor_call = mock_extractor_cls.return_value.aextract.call_args
|
||||
old = extractor_call.kwargs["old_profile"]
|
||||
assert isinstance(old, AlgoProfile)
|
||||
assert old.summary == "prior summary"
|
||||
assert old.timestamp == 1_700_000_000_000
|
||||
assert old.model_dump()["explicit_info"] == ["prior fact"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_no_members(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""An empty target cluster set (no fresh clusters) → no extractor call."""
|
||||
# Existing profile timestamp newer than every cluster's last_ts → no
|
||||
# target_cluster matches `last_ts > last_profile_ts`, but the current
|
||||
# cluster_id should still force inclusion. Set the current cluster id
|
||||
# to a non-existent value to drop everything.
|
||||
stale_cluster = _algo_cluster(
|
||||
cluster_id="cl_other000001",
|
||||
members=["mc_other00000"],
|
||||
last_ts=1_600_000_000_000,
|
||||
)
|
||||
existing_fm = UserProfileFrontmatter(
|
||||
id="profile_u_alice",
|
||||
user_id="u_alice",
|
||||
summary="prior",
|
||||
explicit_info=[],
|
||||
implicit_traits=[],
|
||||
profile_timestamp_ms=1_900_000_000_000,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.memcell_repo"
|
||||
) as mock_memcell_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileExtractor"
|
||||
) as mock_extractor_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileReader"
|
||||
) as mock_reader_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileWriter"
|
||||
) as mock_writer_cls,
|
||||
):
|
||||
mock_cluster_repo.list_for_owner = AsyncMock(return_value=[stale_cluster])
|
||||
mock_memcell_repo.find_by_ids = AsyncMock(return_value=[])
|
||||
mock_reader_cls.return_value.read = AsyncMock(
|
||||
return_value=(existing_fm, "prior")
|
||||
)
|
||||
mock_writer_cls.return_value.write = AsyncMock(return_value=None)
|
||||
mock_extractor_cls.return_value.aextract = AsyncMock()
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_user_profile")
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
monkeypatch.setattr(mod, "_reader", None, raising=False)
|
||||
|
||||
await extract_user_profile(
|
||||
_event(cluster_id="cl_unknown00000"), FakeStrategyContext()
|
||||
)
|
||||
|
||||
mock_extractor_cls.return_value.aextract.assert_not_called()
|
||||
mock_writer_cls.return_value.write.assert_not_called()
|
||||
|
||||
|
||||
# ── partition lock (owner_id-level serialisation) ────────────────────────
|
||||
|
||||
|
||||
async def _run_serialisation_probe(
|
||||
owner_a: str, owner_b: str, monkeypatch: pytest.MonkeyPatch
|
||||
) -> list[str]:
|
||||
"""Drive two extract_user_profile runs and record entry/exit order."""
|
||||
log: list[str] = []
|
||||
|
||||
async def mock_aextract(_memcells, *, sender_id, **_kwargs):
|
||||
log.append(f"enter:{sender_id}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{sender_id}")
|
||||
return AlgoProfile(
|
||||
owner_id=sender_id,
|
||||
summary="summary",
|
||||
timestamp=1_700_000_000_000,
|
||||
explicit_info=[],
|
||||
implicit_traits=[],
|
||||
)
|
||||
|
||||
cluster_a = _algo_cluster(
|
||||
cluster_id="cl_a", members=["mc_a"], last_ts=1_700_000_000_000
|
||||
)
|
||||
cluster_b = _algo_cluster(
|
||||
cluster_id="cl_b", members=["mc_b"], last_ts=1_700_000_000_000
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.cluster_repo"
|
||||
) as mock_cluster_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.memcell_repo"
|
||||
) as mock_memcell_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileReader"
|
||||
) as mock_reader_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileWriter"
|
||||
) as mock_writer_cls,
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_user_profile.ProfileExtractor"
|
||||
) as mock_extractor_cls,
|
||||
):
|
||||
mock_cluster_repo.list_for_owner = AsyncMock(
|
||||
side_effect=lambda owner, _kind, **_kw: (
|
||||
[cluster_a] if owner == owner_a else [cluster_b]
|
||||
)
|
||||
)
|
||||
mock_memcell_repo.find_by_ids = AsyncMock(
|
||||
side_effect=lambda ids: [
|
||||
_memcell_row(ids[0], sender_id="sender", ts_ms=1_700_000_000_000)
|
||||
]
|
||||
)
|
||||
mock_reader_cls.return_value.read = AsyncMock(return_value=[])
|
||||
mock_writer_cls.return_value.write = AsyncMock(return_value=None)
|
||||
mock_extractor_cls.return_value.aextract = mock_aextract
|
||||
|
||||
mod = importlib.import_module("everos.memory.strategies.extract_user_profile")
|
||||
monkeypatch.setattr(mod, "_reader", None, raising=False)
|
||||
monkeypatch.setattr(mod, "_writer", None, raising=False)
|
||||
|
||||
await asyncio.gather(
|
||||
extract_user_profile(
|
||||
_event(owner_id=owner_a, cluster_id="cl_a"), FakeStrategyContext()
|
||||
),
|
||||
extract_user_profile(
|
||||
_event(owner_id=owner_b, cluster_id="cl_b"), FakeStrategyContext()
|
||||
),
|
||||
)
|
||||
return log
|
||||
|
||||
|
||||
async def test_partition_lock_serialises_runs_on_same_owner(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Two runs sharing ``owner_id`` must not overlap critical sections."""
|
||||
log = await _run_serialisation_probe("u_alice", "u_alice", monkeypatch)
|
||||
assert log in (
|
||||
["enter:u_alice", "leave:u_alice", "enter:u_alice", "leave:u_alice"],
|
||||
)
|
||||
# Same-owner runs always log "u_alice" twice — verify strict ordering
|
||||
# by tagging entry/leave pairs are adjacent (no interleave possible).
|
||||
assert log[0].startswith("enter:") and log[1].startswith("leave:")
|
||||
assert log[2].startswith("enter:") and log[3].startswith("leave:")
|
||||
|
||||
|
||||
async def test_partition_lock_lets_different_owners_run_in_parallel(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Runs on distinct ``owner_id`` must overlap (no false serialisation)."""
|
||||
log = await _run_serialisation_probe("u_alice", "u_bob", monkeypatch)
|
||||
assert log.index("enter:u_alice") < log.index("leave:u_bob")
|
||||
assert log.index("enter:u_bob") < log.index("leave:u_alice")
|
||||
126
tests/unit/test_memory/test_strategies/test_partition_locks.py
Normal file
126
tests/unit/test_memory/test_strategies/test_partition_locks.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""Tests for :mod:`everos.memory.strategies._partition_locks`.
|
||||
|
||||
The helper is the foundation under every strategy that performs a
|
||||
read → modify → write on shared state; its own behaviour (lock reuse,
|
||||
strategy isolation, FIFO serialisation, parallel keys) is exercised
|
||||
here, in isolation from any business strategy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.strategies._partition_locks import (
|
||||
_reset_for_tests,
|
||||
get_partition_lock,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_locks() -> None:
|
||||
"""Each test gets a clean registry — no inherited holders / waiters."""
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def test_same_strategy_same_key_returns_identical_lock() -> None:
|
||||
"""Repeat lookups must reuse the lock (otherwise serialisation breaks)."""
|
||||
a = get_partition_lock("strategy_x", "k1")
|
||||
b = get_partition_lock("strategy_x", "k1")
|
||||
assert a is b
|
||||
|
||||
|
||||
def test_same_strategy_different_keys_return_distinct_locks() -> None:
|
||||
"""Different partition keys must not block each other."""
|
||||
assert get_partition_lock("strategy_x", "k1") is not get_partition_lock(
|
||||
"strategy_x", "k2"
|
||||
)
|
||||
|
||||
|
||||
def test_different_strategies_share_no_locks_for_identical_key() -> None:
|
||||
"""Strategy namespaces are independent — same key string is two locks."""
|
||||
assert get_partition_lock("strategy_x", "k1") is not get_partition_lock(
|
||||
"strategy_y", "k1"
|
||||
)
|
||||
|
||||
|
||||
def test_reset_for_tests_drops_every_lock() -> None:
|
||||
"""After reset the registry is empty; the next lookup returns a fresh lock."""
|
||||
before = get_partition_lock("strategy_x", "k1")
|
||||
_reset_for_tests()
|
||||
after = get_partition_lock("strategy_x", "k1")
|
||||
assert before is not after
|
||||
|
||||
|
||||
async def test_same_key_serialises_concurrent_acquirers() -> None:
|
||||
"""Two tasks contending the same key must not overlap critical sections."""
|
||||
log: list[str] = []
|
||||
|
||||
async def worker(tag: str) -> None:
|
||||
async with get_partition_lock("strategy_x", "k1"):
|
||||
log.append(f"enter:{tag}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{tag}")
|
||||
|
||||
await asyncio.gather(worker("a"), worker("b"))
|
||||
|
||||
# The two critical sections must run one after the other (either order
|
||||
# is fine — asyncio scheduling decides who acquires first).
|
||||
assert log in (
|
||||
["enter:a", "leave:a", "enter:b", "leave:b"],
|
||||
["enter:b", "leave:b", "enter:a", "leave:a"],
|
||||
)
|
||||
|
||||
|
||||
async def test_different_keys_run_in_parallel() -> None:
|
||||
"""Two tasks on distinct keys must overlap (no false serialisation)."""
|
||||
log: list[str] = []
|
||||
|
||||
async def worker(key: str, tag: str) -> None:
|
||||
async with get_partition_lock("strategy_x", key):
|
||||
log.append(f"enter:{tag}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{tag}")
|
||||
|
||||
await asyncio.gather(worker("k1", "a"), worker("k2", "b"))
|
||||
|
||||
# Both must enter before either leaves — proves no cross-key blocking.
|
||||
assert log.index("enter:a") < log.index("leave:b")
|
||||
assert log.index("enter:b") < log.index("leave:a")
|
||||
|
||||
|
||||
async def test_concurrent_acquirers_fifo_fairness() -> None:
|
||||
"""asyncio.Lock is FIFO — queued waiters acquire in arrival order."""
|
||||
log: list[str] = []
|
||||
holder_in = asyncio.Event()
|
||||
holder_release = asyncio.Event()
|
||||
|
||||
async def holder() -> None:
|
||||
async with get_partition_lock("strategy_x", "k1"):
|
||||
holder_in.set()
|
||||
await holder_release.wait()
|
||||
log.append("leave:holder")
|
||||
|
||||
async def waiter(tag: str, arrived: asyncio.Event) -> None:
|
||||
arrived.set()
|
||||
async with get_partition_lock("strategy_x", "k1"):
|
||||
log.append(f"enter:{tag}")
|
||||
|
||||
arrived_a = asyncio.Event()
|
||||
arrived_b = asyncio.Event()
|
||||
task_holder = asyncio.create_task(holder())
|
||||
await holder_in.wait() # holder owns the lock
|
||||
|
||||
# Enqueue A first, then B — Lock's deque preserves this order.
|
||||
task_a = asyncio.create_task(waiter("a", arrived_a))
|
||||
await arrived_a.wait()
|
||||
await asyncio.sleep(0) # let A actually park on the lock
|
||||
task_b = asyncio.create_task(waiter("b", arrived_b))
|
||||
await arrived_b.wait()
|
||||
await asyncio.sleep(0) # let B park on the lock
|
||||
|
||||
holder_release.set()
|
||||
await asyncio.gather(task_holder, task_a, task_b)
|
||||
|
||||
assert log == ["leave:holder", "enter:a", "enter:b"]
|
||||
56
tests/unit/test_memory/test_strategies/test_registration.py
Normal file
56
tests/unit/test_memory/test_strategies/test_registration.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""Test strategy package exports and OME engine registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from everos.memory.strategies import (
|
||||
extract_agent_case,
|
||||
extract_agent_skill,
|
||||
extract_atomic_facts,
|
||||
extract_foresight,
|
||||
extract_user_profile,
|
||||
trigger_profile_clustering,
|
||||
trigger_skill_clustering,
|
||||
)
|
||||
|
||||
|
||||
def test_strategies_are_re_exported_from_package() -> None:
|
||||
for fn, name in [
|
||||
(extract_atomic_facts, "extract_atomic_facts"),
|
||||
(extract_foresight, "extract_foresight"),
|
||||
(extract_agent_case, "extract_agent_case"),
|
||||
(trigger_skill_clustering, "trigger_skill_clustering"),
|
||||
(extract_agent_skill, "extract_agent_skill"),
|
||||
(trigger_profile_clustering, "trigger_profile_clustering"),
|
||||
(extract_user_profile, "extract_user_profile"),
|
||||
]:
|
||||
assert fn._ome_strategy_meta.name == name # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_get_engine_registers_all_strategies(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
from everos.core.persistence import MemoryRoot
|
||||
|
||||
svc = importlib.import_module("everos.service.memorize")
|
||||
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(svc, "_ome_engine", None, raising=False)
|
||||
|
||||
engine = svc._get_engine()
|
||||
names = {m.name for m in engine._registry.all()} # noqa: SLF001 — test introspection
|
||||
assert names == {
|
||||
"extract_atomic_facts",
|
||||
"extract_foresight",
|
||||
"extract_agent_case",
|
||||
"trigger_skill_clustering",
|
||||
"extract_agent_skill",
|
||||
"trigger_profile_clustering",
|
||||
"extract_user_profile",
|
||||
}
|
||||
@ -0,0 +1,202 @@
|
||||
"""Real md round-trip tests: strategy runs → writer writes → reader finds file."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from everalgo.types import AgentCase, AtomicFact, ChatMessage, Foresight, MemCell
|
||||
|
||||
from everos.core.persistence import MemoryRoot
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.infra.persistence.markdown import (
|
||||
AgentCaseReader,
|
||||
AtomicFactReader,
|
||||
ForesightReader,
|
||||
)
|
||||
from everos.memory.events import AgentPipelineStarted, UserPipelineStarted
|
||||
from everos.memory.strategies.extract_agent_case import extract_agent_case
|
||||
from everos.memory.strategies.extract_atomic_facts import extract_atomic_facts
|
||||
from everos.memory.strategies.extract_foresight import extract_foresight
|
||||
|
||||
|
||||
def _event_for(owner: str) -> UserPipelineStarted:
|
||||
return UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hi",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id=owner,
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _agent_event() -> AgentPipelineStarted:
|
||||
return AgentPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="please summarise",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="assistant",
|
||||
content="here's the summary",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="agent_42",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_001_000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_atomic_facts_round_trip(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
import importlib
|
||||
|
||||
af_mod = importlib.import_module("everos.memory.strategies.extract_atomic_facts")
|
||||
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(af_mod, "_writer", None, raising=False)
|
||||
|
||||
facts = [
|
||||
AtomicFact(
|
||||
owner_id="u_alice",
|
||||
content="alice likes hiking",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
AtomicFact(
|
||||
owner_id="u_alice",
|
||||
content="alice lives in tokyo",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=facts)
|
||||
await extract_atomic_facts(_event_for("u_alice"), FakeStrategyContext())
|
||||
|
||||
reader = AtomicFactReader(root=MemoryRoot(root=tmp_path))
|
||||
path = reader.path_for("u_alice")
|
||||
assert path.is_file(), f"expected md at {path}"
|
||||
content = path.read_text(encoding="utf-8")
|
||||
assert "alice likes hiking" in content
|
||||
assert "alice lives in tokyo" in content
|
||||
|
||||
|
||||
async def test_foresights_round_trip(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
import importlib
|
||||
|
||||
fs_mod = importlib.import_module("everos.memory.strategies.extract_foresight")
|
||||
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(fs_mod, "_writer", None, raising=False)
|
||||
|
||||
foresights = [
|
||||
Foresight(
|
||||
owner_id="u_alice",
|
||||
foresight="plans trip to tokyo",
|
||||
evidence="said so",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=foresights)
|
||||
await extract_foresight(_event_for("u_alice"), FakeStrategyContext())
|
||||
|
||||
reader = ForesightReader(root=MemoryRoot(root=tmp_path))
|
||||
path = reader.path_for("u_alice")
|
||||
assert path.is_file(), f"expected md at {path}"
|
||||
content = path.read_text(encoding="utf-8")
|
||||
assert "plans trip to tokyo" in content
|
||||
assert "said so" in content
|
||||
|
||||
|
||||
async def test_agent_case_round_trip(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
import importlib
|
||||
|
||||
ac_mod = importlib.import_module("everos.memory.strategies.extract_agent_case")
|
||||
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(ac_mod, "_writer", None, raising=False)
|
||||
|
||||
cases = [
|
||||
AgentCase(
|
||||
id=uuid.uuid4().hex,
|
||||
timestamp=1_700_000_001_000,
|
||||
task_intent="summarise the doc",
|
||||
approach="read then condense",
|
||||
quality_score=0.82,
|
||||
key_insight="batch-then-summarise",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=cases)
|
||||
await extract_agent_case(_agent_event(), FakeStrategyContext())
|
||||
|
||||
reader = AgentCaseReader(root=MemoryRoot(root=tmp_path))
|
||||
path = reader.path_for("agent_42")
|
||||
assert path.is_file(), f"expected md at {path}"
|
||||
content = path.read_text(encoding="utf-8")
|
||||
assert "summarise the doc" in content
|
||||
assert "read then condense" in content
|
||||
assert "batch-then-summarise" in content
|
||||
# quality_score must land in inline (cascade requires it via require_float).
|
||||
assert "quality_score" in content
|
||||
@ -0,0 +1,284 @@
|
||||
"""Contract: strategy-written md must round-trip through cascade handler.
|
||||
|
||||
Guards against silent-breakage class: strategy writes section keys
|
||||
(e.g. ``{"fact": ...}``) that the cascade handler reads under a different
|
||||
case (e.g. ``sections.get("Fact")``). Without this contract, the worker
|
||||
still upserts a LanceDB row but with empty ``fact`` / ``foresight``
|
||||
text, empty BM25 tokens, and a vector for the empty string — search
|
||||
fails silently. Earlier unit tests stop at the strategy boundary (mock
|
||||
the writer) or at the writer boundary (skip the strategy); neither
|
||||
catches a key-name drift.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from everalgo.types import AgentCase, AtomicFact, ChatMessage, Foresight, MemCell
|
||||
|
||||
from everos.component.embedding import EmbeddingProvider
|
||||
from everos.component.tokenizer import Tokenizer
|
||||
from everos.core.persistence import MarkdownReader, MemoryRoot
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.cascade.handlers import (
|
||||
AgentCaseHandler,
|
||||
AtomicFactHandler,
|
||||
ForesightHandler,
|
||||
HandlerDeps,
|
||||
)
|
||||
from everos.memory.cascade.handlers._daily_log_base import ParsedEntry
|
||||
from everos.memory.events import AgentPipelineStarted, UserPipelineStarted
|
||||
from everos.memory.strategies.extract_agent_case import extract_agent_case
|
||||
from everos.memory.strategies.extract_atomic_facts import extract_atomic_facts
|
||||
from everos.memory.strategies.extract_foresight import extract_foresight
|
||||
|
||||
|
||||
class _StubTokenizer(Tokenizer):
|
||||
def tokenize(self, text): # type: ignore[no-untyped-def]
|
||||
return [tok for tok in text.split() if tok]
|
||||
|
||||
def tokenize_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [self.tokenize(t) for t in texts]
|
||||
|
||||
|
||||
class _StubEmbedder(EmbeddingProvider):
|
||||
dim = 1024
|
||||
|
||||
async def embed(self, text): # type: ignore[no-untyped-def]
|
||||
return [0.0] * self.dim
|
||||
|
||||
async def embed_batch(self, texts): # type: ignore[no-untyped-def]
|
||||
return [await self.embed(t) for t in texts]
|
||||
|
||||
|
||||
def _event(owner_id: str) -> UserPipelineStarted:
|
||||
return UserPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="hi",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id=owner_id,
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _build_row_from_md(
|
||||
handler: AtomicFactHandler | ForesightHandler | AgentCaseHandler,
|
||||
md_root: Path,
|
||||
md_glob: str,
|
||||
*,
|
||||
owner_id: str = "u_alice",
|
||||
owner_type: str = "user",
|
||||
):
|
||||
md_files: list[anyio.Path] = []
|
||||
async for p in anyio.Path(md_root).glob(md_glob):
|
||||
md_files.append(p)
|
||||
assert len(md_files) == 1, f"expected exactly one md, got: {md_files}"
|
||||
md_abs = Path(md_files[0])
|
||||
rel = str(md_abs.relative_to(md_root))
|
||||
parsed = await MarkdownReader.read(md_abs)
|
||||
assert parsed.entries, "writer should have produced at least one entry"
|
||||
entry = parsed.entries[0]
|
||||
structured = entry.as_structured()
|
||||
pe = ParsedEntry(
|
||||
entry_id=entry.id,
|
||||
structured=structured,
|
||||
content_sha256=handler._content_sha256(structured), # noqa: SLF001
|
||||
)
|
||||
return await handler._build_row( # noqa: SLF001
|
||||
owner_id=owner_id,
|
||||
owner_type=owner_type,
|
||||
md_path=rel,
|
||||
entry=pe,
|
||||
)
|
||||
|
||||
|
||||
async def test_atomic_fact_strategy_md_feeds_handler_with_content(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Strategy → md → AtomicFactHandler must carry the fact text intact."""
|
||||
af_mod = importlib.import_module("everos.memory.strategies.extract_atomic_facts")
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(af_mod, "_writer", None, raising=False)
|
||||
|
||||
facts = [
|
||||
AtomicFact(
|
||||
owner_id="u_alice",
|
||||
content="alice likes hiking",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
]
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_atomic_facts.AtomicFactExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=facts)
|
||||
await extract_atomic_facts(_event("u_alice"), FakeStrategyContext())
|
||||
|
||||
handler = AtomicFactHandler(
|
||||
HandlerDeps(
|
||||
memory_root=MemoryRoot(root=tmp_path),
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
row = await _build_row_from_md(
|
||||
handler, tmp_path, "*/*/users/u_alice/.atomic_facts/atomic_fact-*.md"
|
||||
)
|
||||
# Regression guard: section key drift would land here as fact="".
|
||||
assert row.fact == "alice likes hiking"
|
||||
assert row.fact_tokens == "alice likes hiking"
|
||||
assert len(row.vector) == 1024
|
||||
|
||||
|
||||
async def test_foresight_strategy_md_feeds_handler_with_content(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Strategy → md → ForesightHandler must carry foresight + evidence text."""
|
||||
fs_mod = importlib.import_module("everos.memory.strategies.extract_foresight")
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(fs_mod, "_writer", None, raising=False)
|
||||
|
||||
foresights = [
|
||||
Foresight(
|
||||
owner_id="u_alice",
|
||||
foresight="plans trip to tokyo",
|
||||
evidence="said so explicitly",
|
||||
timestamp=1_700_000_000_000,
|
||||
),
|
||||
]
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_foresight.ForesightExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=foresights)
|
||||
await extract_foresight(_event("u_alice"), FakeStrategyContext())
|
||||
|
||||
handler = ForesightHandler(
|
||||
HandlerDeps(
|
||||
memory_root=MemoryRoot(root=tmp_path),
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
row = await _build_row_from_md(
|
||||
handler, tmp_path, "*/*/users/u_alice/.foresights/foresight-*.md"
|
||||
)
|
||||
# Regression guard: section key drift would land here as foresight="".
|
||||
assert row.foresight == "plans trip to tokyo"
|
||||
assert row.foresight_tokens == "plans trip to tokyo"
|
||||
assert row.evidence == "said so explicitly"
|
||||
assert row.evidence_tokens == "said so explicitly"
|
||||
assert len(row.vector) == 1024
|
||||
|
||||
|
||||
def _agent_event() -> AgentPipelineStarted:
|
||||
return AgentPipelineStarted(
|
||||
memcell_id="mc_a",
|
||||
session_id="s1",
|
||||
memcell=MemCell(
|
||||
items=[
|
||||
ChatMessage(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="please summarise",
|
||||
timestamp=1_700_000_000_000,
|
||||
sender_id="u_alice",
|
||||
),
|
||||
ChatMessage(
|
||||
id="m2",
|
||||
role="assistant",
|
||||
content="here's the summary",
|
||||
timestamp=1_700_000_001_000,
|
||||
sender_id="agent_42",
|
||||
),
|
||||
],
|
||||
timestamp=1_700_000_001_000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_agent_case_strategy_md_feeds_handler_with_content(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Strategy → md → AgentCaseHandler carries task_intent, approach, score."""
|
||||
ac_mod = importlib.import_module("everos.memory.strategies.extract_agent_case")
|
||||
monkeypatch.setattr(
|
||||
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
||||
)
|
||||
monkeypatch.setattr(ac_mod, "_writer", None, raising=False)
|
||||
|
||||
cases = [
|
||||
AgentCase(
|
||||
id=uuid.uuid4().hex,
|
||||
timestamp=1_700_000_001_000,
|
||||
task_intent="summarise the doc",
|
||||
approach="read + condense",
|
||||
quality_score=0.85,
|
||||
key_insight="batch-then-summarise",
|
||||
)
|
||||
]
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.extract_agent_case.AgentCaseExtractor"
|
||||
) as mock_ext,
|
||||
):
|
||||
mock_ext.return_value.aextract = AsyncMock(return_value=cases)
|
||||
await extract_agent_case(_agent_event(), FakeStrategyContext())
|
||||
|
||||
handler = AgentCaseHandler(
|
||||
HandlerDeps(
|
||||
memory_root=MemoryRoot(root=tmp_path),
|
||||
embedder=_StubEmbedder(),
|
||||
tokenizer=_StubTokenizer(),
|
||||
)
|
||||
)
|
||||
row = await _build_row_from_md(
|
||||
handler,
|
||||
tmp_path,
|
||||
"*/*/agents/agent_42/.cases/agent_case-*.md",
|
||||
owner_id="agent_42",
|
||||
owner_type="agent",
|
||||
)
|
||||
# Regression guard: section-key drift or missing quality_score inline
|
||||
# would surface as empty strings / require_float failure.
|
||||
assert row.task_intent == "summarise the doc"
|
||||
assert row.task_intent_tokens == "summarise the doc"
|
||||
assert row.approach == "read + condense"
|
||||
assert row.approach_tokens == "read + condense"
|
||||
assert row.key_insight == "batch-then-summarise"
|
||||
assert row.quality_score == 0.85
|
||||
assert row.owner_id == "agent_42"
|
||||
assert row.owner_type == "agent"
|
||||
assert len(row.vector) == 1024
|
||||
@ -0,0 +1,235 @@
|
||||
"""Tests for :func:`trigger_profile_clustering`.
|
||||
|
||||
Mirrors the skill-side test layout: mock embedder + cluster_repo +
|
||||
cluster_by_geometry, drive the strategy via :class:`FakeStrategyContext`,
|
||||
verify a single :class:`ProfileClusterUpdated` event is emitted.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import structlog.testing
|
||||
from everalgo.clustering import Cluster as AlgoCluster
|
||||
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import EpisodeExtracted, ProfileClusterUpdated
|
||||
from everos.memory.strategies._partition_locks import _reset_for_tests
|
||||
from everos.memory.strategies.trigger_profile_clustering import (
|
||||
trigger_profile_clustering,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_partition_locks() -> None:
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def _event(
|
||||
*,
|
||||
owner_id: str = "u_alice",
|
||||
memcell_id: str = "mc_aaaaaaaaaaa1",
|
||||
episode_text: str = "alice likes hiking",
|
||||
episode_timestamp_ms: int = 1_700_000_001_000,
|
||||
) -> EpisodeExtracted:
|
||||
return EpisodeExtracted(
|
||||
memcell_id=memcell_id,
|
||||
episode_entry_id="ep_20260517_0001",
|
||||
episode_text=episode_text,
|
||||
episode_timestamp_ms=episode_timestamp_ms,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = trigger_profile_clustering._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "trigger_profile_clustering"
|
||||
assert EpisodeExtracted in meta.trigger.on
|
||||
assert meta.emits == frozenset({ProfileClusterUpdated})
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_new_cluster_when_no_existing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Empty existing → cluster_by_geometry returns None → new cluster persisted."""
|
||||
embedder = MagicMock()
|
||||
embedder.embed = AsyncMock(return_value=[0.1] * 1024)
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.get_embedder",
|
||||
return_value=embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_by_geometry",
|
||||
new=MagicMock(return_value=None),
|
||||
) as mock_cluster,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.mint_cluster_id",
|
||||
return_value="cl_newuser00001",
|
||||
),
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[])
|
||||
mock_repo.upsert_with_members = AsyncMock(return_value=None)
|
||||
|
||||
await trigger_profile_clustering(_event(), ctx)
|
||||
|
||||
args, _ = mock_cluster.call_args
|
||||
new_cluster, existing = args
|
||||
assert isinstance(new_cluster, AlgoCluster)
|
||||
assert new_cluster.id == "cl_newuser00001"
|
||||
assert new_cluster.count == 1
|
||||
assert new_cluster.last_ts == 1_700_000_001_000
|
||||
assert new_cluster.members == ["mc_aaaaaaaaaaa1"]
|
||||
assert new_cluster.preview == ["alice likes hiking"]
|
||||
assert existing == []
|
||||
|
||||
upsert_args = mock_repo.upsert_with_members.call_args
|
||||
persisted = upsert_args.args[0]
|
||||
assert persisted.id == "cl_newuser00001"
|
||||
assert upsert_args.kwargs == {
|
||||
"owner_id": "u_alice",
|
||||
"owner_type": "user",
|
||||
"kind": "user_memory",
|
||||
"member_type": "memcell",
|
||||
"app_id": "default",
|
||||
"project_id": "default",
|
||||
}
|
||||
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, ProfileClusterUpdated)]
|
||||
assert len(emitted) == 1
|
||||
assert emitted[0].memcell_id == "mc_aaaaaaaaaaa1"
|
||||
assert emitted[0].cluster_id == "cl_newuser00001"
|
||||
assert emitted[0].owner_id == "u_alice"
|
||||
|
||||
matching = [r for r in captured if r.get("event") == "profile_cluster_updated"]
|
||||
assert matching, "expected profile_cluster_updated log line"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merges_into_existing_cluster_when_algo_matches() -> None:
|
||||
"""algo returns merged Cluster → persisted under the existing id."""
|
||||
embedder = MagicMock()
|
||||
embedder.embed = AsyncMock(return_value=[0.2] * 1024)
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
existing_cluster = AlgoCluster(
|
||||
id="cl_existing0001",
|
||||
centroid=np.array([0.15] * 1024, dtype=np.float32),
|
||||
count=1,
|
||||
last_ts=1_700_000_000_000,
|
||||
preview=["earlier episode"],
|
||||
members=["mc_zzzzzzzzzzz0"],
|
||||
)
|
||||
merged_cluster = AlgoCluster(
|
||||
id="cl_existing0001",
|
||||
centroid=np.array([0.17] * 1024, dtype=np.float32),
|
||||
count=2,
|
||||
last_ts=1_700_000_001_000,
|
||||
preview=["earlier episode", "alice likes hiking"],
|
||||
members=["mc_zzzzzzzzzzz0", "mc_aaaaaaaaaaa1"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.get_embedder",
|
||||
return_value=embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_by_geometry",
|
||||
new=MagicMock(return_value=merged_cluster),
|
||||
),
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[existing_cluster])
|
||||
mock_repo.upsert_with_members = AsyncMock(return_value=None)
|
||||
|
||||
await trigger_profile_clustering(_event(), ctx)
|
||||
|
||||
persisted = mock_repo.upsert_with_members.call_args.args[0]
|
||||
assert persisted.id == "cl_existing0001"
|
||||
assert persisted.count == 2
|
||||
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, ProfileClusterUpdated)]
|
||||
assert len(emitted) == 1
|
||||
assert emitted[0].cluster_id == "cl_existing0001"
|
||||
|
||||
|
||||
# ── partition lock (owner_id-level serialisation) ────────────────────────
|
||||
|
||||
|
||||
async def _run_serialisation_probe(owner_a: str, owner_b: str) -> list[str]:
|
||||
"""Drive two trigger_profile_clustering runs and record entry/exit order."""
|
||||
log: list[str] = []
|
||||
|
||||
def mock_cluster_by_geometry(_new_cluster, _existing):
|
||||
# Sync, matching the real algo signature (must not be awaited).
|
||||
return None
|
||||
|
||||
async def mock_upsert(cluster, **_kwargs):
|
||||
# Delay inside the partition-lock critical section so two concurrent
|
||||
# runs on the same owner are observably serialised. cluster_by_geometry
|
||||
# is synchronous now, so the await point moves here.
|
||||
mid = cluster.members[0]
|
||||
log.append(f"enter:{mid}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{mid}")
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed = AsyncMock(return_value=np.zeros(1024, dtype=np.float32))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.get_embedder",
|
||||
return_value=mock_embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_profile_clustering.cluster_by_geometry",
|
||||
new=mock_cluster_by_geometry,
|
||||
),
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[])
|
||||
mock_repo.upsert_with_members = mock_upsert
|
||||
|
||||
await asyncio.gather(
|
||||
trigger_profile_clustering(
|
||||
_event(owner_id=owner_a, memcell_id="mc_run_a"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
trigger_profile_clustering(
|
||||
_event(owner_id=owner_b, memcell_id="mc_run_b"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
)
|
||||
return log
|
||||
|
||||
|
||||
async def test_partition_lock_serialises_runs_on_same_owner() -> None:
|
||||
"""Two runs sharing ``owner_id`` must not overlap critical sections."""
|
||||
log = await _run_serialisation_probe("u_alice", "u_alice")
|
||||
assert log in (
|
||||
["enter:mc_run_a", "leave:mc_run_a", "enter:mc_run_b", "leave:mc_run_b"],
|
||||
["enter:mc_run_b", "leave:mc_run_b", "enter:mc_run_a", "leave:mc_run_a"],
|
||||
)
|
||||
|
||||
|
||||
async def test_partition_lock_lets_different_owners_run_in_parallel() -> None:
|
||||
"""Runs on distinct ``owner_id`` must overlap (no false serialisation)."""
|
||||
log = await _run_serialisation_probe("u_alice", "u_bob")
|
||||
assert log.index("enter:mc_run_a") < log.index("leave:mc_run_b")
|
||||
assert log.index("enter:mc_run_b") < log.index("leave:mc_run_a")
|
||||
@ -0,0 +1,277 @@
|
||||
"""Tests for :func:`trigger_skill_clustering`.
|
||||
|
||||
Mock surface: ``cluster_by_llm``, ``get_embedder``, ``get_llm_client``,
|
||||
``cluster_repo`` — strategy is wired to use them as module-level imports
|
||||
so each ``patch`` swaps the symbol in the strategy module's namespace.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import structlog.testing
|
||||
from everalgo.clustering import Cluster as AlgoCluster
|
||||
|
||||
from everos.infra.ome.testing import FakeStrategyContext
|
||||
from everos.memory.events import AgentCaseExtracted, SkillClusterUpdated
|
||||
from everos.memory.strategies._partition_locks import _reset_for_tests
|
||||
from everos.memory.strategies.trigger_skill_clustering import (
|
||||
trigger_skill_clustering,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_partition_locks() -> None:
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def _event(
|
||||
*,
|
||||
quality_score: float = 0.8,
|
||||
case_entry_id: str = "ac_20260517_0001",
|
||||
agent_id: str = "agent_42",
|
||||
task_intent: str = "summarise the doc",
|
||||
case_timestamp_ms: int = 1_700_000_001_000,
|
||||
) -> AgentCaseExtracted:
|
||||
return AgentCaseExtracted(
|
||||
memcell_id="mc_a",
|
||||
case_entry_id=case_entry_id,
|
||||
task_intent=task_intent,
|
||||
quality_score=quality_score,
|
||||
case_timestamp_ms=case_timestamp_ms,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
|
||||
async def test_strategy_meta_is_attached() -> None:
|
||||
meta = trigger_skill_clustering._ome_strategy_meta # type: ignore[attr-defined]
|
||||
assert meta.name == "trigger_skill_clustering"
|
||||
assert AgentCaseExtracted in meta.trigger.on
|
||||
assert meta.emits == frozenset({SkillClusterUpdated})
|
||||
assert meta.max_retries == 2
|
||||
|
||||
|
||||
async def test_skips_when_quality_score_below_threshold() -> None:
|
||||
"""quality_score < 0.2 → log + early return; no embedding, no LLM, no repo call."""
|
||||
ctx = FakeStrategyContext()
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_embedder"
|
||||
) as mock_emb,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_by_llm"
|
||||
) as mock_cluster,
|
||||
structlog.testing.capture_logs() as captured,
|
||||
):
|
||||
await trigger_skill_clustering(_event(quality_score=0.1), ctx)
|
||||
|
||||
mock_emb.assert_not_called()
|
||||
mock_repo.list_for_owner.assert_not_called()
|
||||
mock_cluster.assert_not_called()
|
||||
assert ctx.emitted == []
|
||||
matching = [
|
||||
e for e in captured if e.get("event") == "skill_clustering_skipped_low_quality"
|
||||
]
|
||||
assert matching, "expected low-quality skip log line"
|
||||
|
||||
|
||||
async def test_creates_new_cluster_when_no_existing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Empty existing list → cluster_by_llm returns None → new cluster persisted."""
|
||||
embedder = MagicMock()
|
||||
embedder.embed = AsyncMock(return_value=[0.1] * 1024)
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_embedder",
|
||||
return_value=embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_by_llm",
|
||||
new=AsyncMock(return_value=None),
|
||||
) as mock_cluster,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.mint_cluster_id",
|
||||
return_value="cl_newxxxx0001",
|
||||
),
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[])
|
||||
mock_repo.upsert_with_members = AsyncMock(return_value=None)
|
||||
|
||||
await trigger_skill_clustering(_event(), ctx)
|
||||
|
||||
# cluster_by_llm called with the size-1 new cluster + empty existing.
|
||||
args, kwargs = mock_cluster.call_args
|
||||
new_cluster, existing = args
|
||||
assert isinstance(new_cluster, AlgoCluster)
|
||||
assert new_cluster.id == "cl_newxxxx0001"
|
||||
assert new_cluster.count == 1
|
||||
assert new_cluster.last_ts == 1_700_000_001_000
|
||||
assert new_cluster.members == ["ac_20260517_0001"]
|
||||
assert new_cluster.preview == ["summarise the doc"]
|
||||
np.testing.assert_allclose(
|
||||
np.asarray(new_cluster.centroid), np.array([0.1] * 1024, dtype=np.float32)
|
||||
)
|
||||
assert existing == []
|
||||
|
||||
# upsert called with the new cluster (since merge returned None).
|
||||
upsert_args = mock_repo.upsert_with_members.call_args
|
||||
persisted = upsert_args.args[0]
|
||||
assert persisted.id == "cl_newxxxx0001"
|
||||
assert upsert_args.kwargs == {
|
||||
"owner_id": "agent_42",
|
||||
"owner_type": "agent",
|
||||
"kind": "agent_case",
|
||||
"member_type": "case",
|
||||
"app_id": "default",
|
||||
"project_id": "default",
|
||||
}
|
||||
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, SkillClusterUpdated)]
|
||||
assert len(emitted) == 1
|
||||
assert emitted[0].cluster_id == "cl_newxxxx0001"
|
||||
assert emitted[0].case_entry_id == "ac_20260517_0001"
|
||||
assert emitted[0].agent_id == "agent_42"
|
||||
|
||||
|
||||
async def test_merges_into_existing_cluster_when_algo_matches() -> None:
|
||||
"""algo returns a merged Cluster → persisted with the existing id."""
|
||||
embedder = MagicMock()
|
||||
embedder.embed = AsyncMock(return_value=[0.2] * 1024)
|
||||
ctx = FakeStrategyContext()
|
||||
|
||||
existing_cluster = AlgoCluster(
|
||||
id="cl_existing0001",
|
||||
centroid=np.array([0.15] * 1024, dtype=np.float32),
|
||||
count=2,
|
||||
last_ts=1_700_000_000_000,
|
||||
preview=["earlier intent"],
|
||||
members=["ac_20260517_0000"],
|
||||
)
|
||||
# Simulate evercore _merge: id passes through from existing, members appended.
|
||||
merged_cluster = AlgoCluster(
|
||||
id="cl_existing0001",
|
||||
centroid=np.array([0.17] * 1024, dtype=np.float32),
|
||||
count=3,
|
||||
last_ts=1_700_000_001_000,
|
||||
preview=["earlier intent", "summarise the doc"],
|
||||
members=["ac_20260517_0000", "ac_20260517_0001"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_embedder",
|
||||
return_value=embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_by_llm",
|
||||
new=AsyncMock(return_value=merged_cluster),
|
||||
),
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[existing_cluster])
|
||||
mock_repo.upsert_with_members = AsyncMock(return_value=None)
|
||||
|
||||
await trigger_skill_clustering(_event(), ctx)
|
||||
|
||||
upsert_args = mock_repo.upsert_with_members.call_args
|
||||
persisted = upsert_args.args[0]
|
||||
assert persisted.id == "cl_existing0001"
|
||||
assert persisted.members == ["ac_20260517_0000", "ac_20260517_0001"]
|
||||
assert persisted.count == 3
|
||||
|
||||
emitted = [e for e in ctx.emitted if isinstance(e, SkillClusterUpdated)]
|
||||
assert len(emitted) == 1
|
||||
assert emitted[0].cluster_id == "cl_existing0001"
|
||||
|
||||
|
||||
# ── partition lock (agent_id-level serialisation) ────────────────────────
|
||||
|
||||
|
||||
async def _run_serialisation_probe(agent_a: str, agent_b: str) -> list[str]:
|
||||
"""Drive two trigger_skill_clustering runs and record entry/exit order.
|
||||
|
||||
The clustering LLM call is the only awaited work inside the locked
|
||||
region — replacing it with a tiny ``asyncio.sleep`` keeps the test
|
||||
fast while still proving the lock either does or does not interleave
|
||||
the two critical sections.
|
||||
"""
|
||||
log: list[str] = []
|
||||
|
||||
async def mock_cluster_by_llm(new_cluster, _existing, **_kwargs):
|
||||
log.append(f"enter:{new_cluster.members[0]}")
|
||||
await asyncio.sleep(0.01)
|
||||
log.append(f"leave:{new_cluster.members[0]}")
|
||||
return None # no merge → caller persists the size-1 cluster
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed = AsyncMock(return_value=np.zeros(1024, dtype=np.float32))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_embedder",
|
||||
return_value=mock_embedder,
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.get_llm_client",
|
||||
return_value=object(),
|
||||
),
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_repo"
|
||||
) as mock_repo,
|
||||
patch(
|
||||
"everos.memory.strategies.trigger_skill_clustering.cluster_by_llm",
|
||||
new=mock_cluster_by_llm,
|
||||
),
|
||||
):
|
||||
mock_repo.list_for_owner = AsyncMock(return_value=[])
|
||||
mock_repo.upsert_with_members = AsyncMock(return_value=None)
|
||||
|
||||
await asyncio.gather(
|
||||
trigger_skill_clustering(
|
||||
_event(agent_id=agent_a, case_entry_id="ac_run_a"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
trigger_skill_clustering(
|
||||
_event(agent_id=agent_b, case_entry_id="ac_run_b"),
|
||||
FakeStrategyContext(),
|
||||
),
|
||||
)
|
||||
return log
|
||||
|
||||
|
||||
async def test_partition_lock_serialises_runs_on_same_agent() -> None:
|
||||
"""Two runs sharing ``agent_id`` must not overlap critical sections."""
|
||||
log = await _run_serialisation_probe("agent_42", "agent_42")
|
||||
assert log in (
|
||||
["enter:ac_run_a", "leave:ac_run_a", "enter:ac_run_b", "leave:ac_run_b"],
|
||||
["enter:ac_run_b", "leave:ac_run_b", "enter:ac_run_a", "leave:ac_run_a"],
|
||||
)
|
||||
|
||||
|
||||
async def test_partition_lock_lets_different_agents_run_in_parallel() -> None:
|
||||
"""Runs on distinct ``agent_id`` must overlap (no false serialisation)."""
|
||||
log = await _run_serialisation_probe("agent_42", "agent_43")
|
||||
assert log.index("enter:ac_run_a") < log.index("leave:ac_run_b")
|
||||
assert log.index("enter:ac_run_b") < log.index("leave:ac_run_a")
|
||||
Reference in New Issue
Block a user