md-first memory extraction framework for AI agents. Markdown is the single source of truth; SQLite holds state and LanceDB provides the rebuildable vector + BM25 + scalar index. The codebase follows a single-direction DDD layering (entrypoints -> service -> memory -> infra, with component / core / config cross-cutting) enforced by import-linter. Engineering surface: - Coding conventions in .claude/rules/ (path-scoped) and workflows in .claude/skills/ (/commit, /new-branch, /pr). - GitHub Actions CI runs make lint + test + integration; pre-commit mirrors the gates locally (ruff, hygiene hooks, gitlint commit-msg). - Commit messages follow Conventional Commits, enforced by gitlint. - make lint also enforces datetime two-zone discipline and OpenAPI drift.
301 lines
11 KiB
Python
301 lines
11 KiB
Python
"""Concurrent /add on one session must not lose messages (regression).
|
|
|
|
White-box integration test for the per-session lock added in
|
|
``everos.service._session_lock``.
|
|
|
|
Bug class
|
|
---------
|
|
Without the lock, two concurrent ``memorize()`` calls on the same
|
|
``session_id`` race on ``unprocessed_buffer``:
|
|
|
|
1. Both read the same pre-existing buffer rows.
|
|
2. Each boundary call sees only its own newly-arrived messages plus
|
|
the shared pre-existing buffer (neither sees the other's messages).
|
|
3. Both call ``_replace_buffer(session_id, tail)`` — the later write
|
|
silently overwrites the earlier write's tail; the earlier task's
|
|
tail messages are lost forever.
|
|
|
|
Invariant under test
|
|
--------------------
|
|
After N concurrent ``memorize()`` calls on one session, every input
|
|
message_id is **either** in some memcell's ``message_ids_json`` **or**
|
|
in the surviving ``unprocessed_buffer`` rows. Nothing silently vanishes.
|
|
|
|
This is a white-box integration test (not e2e): it bypasses HTTP, calls
|
|
``memorize()`` directly, but inspects sqlite tables to assert internal
|
|
state. Uses ``FakeLLMClient`` to avoid real LLM latency and to control
|
|
boundary decisions deterministically.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import importlib
|
|
import json
|
|
from collections.abc import AsyncIterator, Callable
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from everalgo.llm.types import ChatMessage as LLMChatMessage
|
|
from everalgo.llm.types import ChatResponse
|
|
from everalgo.testing.fake_llm import FakeLLMClient
|
|
from sqlalchemy import text
|
|
from sqlmodel import SQLModel
|
|
|
|
from everos.core.persistence import MemoryRoot
|
|
from everos.service.memorize import memorize
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fake LLM that splits each call into one memcell + 0-tail (force extract)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _boundary_response(boundaries: list[int]) -> str:
|
|
return json.dumps(
|
|
{"reasoning": "test", "boundaries": boundaries, "should_wait": False}
|
|
)
|
|
|
|
|
|
def _episode_response(title: str = "T", content: str = "B") -> str:
|
|
return json.dumps({"title": title, "content": content})
|
|
|
|
|
|
def _make_extract_all_llm() -> FakeLLMClient:
|
|
"""Boundary returns single boundary at end → entire merged → 1 cell, tail=[]."""
|
|
|
|
def handler(messages: list[LLMChatMessage], **_: Any) -> ChatResponse:
|
|
prompt = messages[0].content
|
|
if "boundaries" in prompt.lower() or "memcell" in prompt.lower():
|
|
# Always cut: the boundary indices are relative to merged input;
|
|
# an empty list means "no cut, hold". A single [N] means "cut
|
|
# after index N", i.e. everything before goes into one cell.
|
|
# We use a sentinel large index to force boundary to take all.
|
|
return ChatResponse(content=_boundary_response([999]), model="fake")
|
|
return ChatResponse(content=_episode_response(), model="fake")
|
|
|
|
return FakeLLMClient(handler=handler)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixture — mirrors test_memorize_integration's pattern but without OME / strategies
|
|
# (the lock bug lives at the boundary stage; downstream strategies are
|
|
# irrelevant to this race).
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def memorize_env_locked(
|
|
tmp_path: Path,
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> AsyncIterator[Callable[..., AsyncMock]]:
|
|
monkeypatch.setattr(
|
|
MemoryRoot, "default", classmethod(lambda cls: MemoryRoot(root=tmp_path))
|
|
)
|
|
(tmp_path / ".index" / "sqlite").mkdir(parents=True, exist_ok=True)
|
|
|
|
svc = importlib.import_module("everos.service.memorize")
|
|
af_mod = importlib.import_module("everos.memory.strategies.extract_atomic_facts")
|
|
fs_mod = importlib.import_module("everos.memory.strategies.extract_foresight")
|
|
client_mod = importlib.import_module("everos.component.llm.client")
|
|
lock_mod = importlib.import_module("everos.service._session_lock")
|
|
|
|
# Reset memorize singletons + session lock registry.
|
|
for attr in (
|
|
"_episode_writer",
|
|
"_prompt_loader",
|
|
"_user_pipeline",
|
|
"_agent_pipeline",
|
|
"_ome_engine",
|
|
):
|
|
monkeypatch.setattr(svc, attr, None, raising=False)
|
|
monkeypatch.setattr(client_mod, "_llm_client", None, raising=False)
|
|
monkeypatch.setattr(af_mod, "_writer", None, raising=False)
|
|
monkeypatch.setattr(fs_mod, "_writer", None, raising=False)
|
|
lock_mod._reset_for_tests()
|
|
|
|
started: dict[str, Any] = {"engine": None}
|
|
|
|
async def _setup(*, fake_llm: FakeLLMClient) -> None:
|
|
monkeypatch.setenv("EVEROS_MEMORIZE__MODE", "chat")
|
|
monkeypatch.setenv("EVEROS_LLM__API_KEY", "fake-key")
|
|
monkeypatch.setenv("EVEROS_LLM__BASE_URL", "https://fake.example.com")
|
|
from everos.config import load_settings
|
|
|
|
load_settings.cache_clear()
|
|
|
|
monkeypatch.setattr(client_mod, "_llm_client", fake_llm)
|
|
|
|
from everos.infra.persistence.sqlite import get_engine
|
|
|
|
db_engine = get_engine()
|
|
async with db_engine.begin() as conn:
|
|
await conn.run_sync(SQLModel.metadata.create_all)
|
|
|
|
# Silence OME strategy extractors (we only care about the boundary +
|
|
# memcell + buffer cycle; downstream strategies are a separate story).
|
|
mock_af = AsyncMock(return_value=[])
|
|
mock_fs = AsyncMock(return_value=[])
|
|
monkeypatch.setattr(
|
|
af_mod,
|
|
"AtomicFactExtractor",
|
|
lambda *a, **k: type("M", (), {"aextract": mock_af})(),
|
|
)
|
|
monkeypatch.setattr(
|
|
fs_mod,
|
|
"ForesightExtractor",
|
|
lambda *a, **k: type("M", (), {"aextract": mock_fs})(),
|
|
)
|
|
|
|
engine = svc._get_engine()
|
|
await engine.start()
|
|
started["engine"] = engine
|
|
|
|
yield _setup
|
|
|
|
if started["engine"] is not None:
|
|
await started["engine"].stop()
|
|
from everos.infra.persistence.sqlite import dispose_engine
|
|
|
|
await dispose_engine()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _msg(idx: int, sender: str, ts: int) -> dict[str, Any]:
|
|
return {
|
|
"sender_id": sender,
|
|
"role": "user",
|
|
"timestamp": ts,
|
|
"content": f"msg-{idx} from {sender}",
|
|
}
|
|
|
|
|
|
async def _collect_buffer_message_ids(session_id: str) -> set[str]:
|
|
from everos.infra.persistence.sqlite import get_engine
|
|
|
|
eng = get_engine()
|
|
async with eng.connect() as conn:
|
|
result = await conn.execute(
|
|
text("SELECT message_id FROM unprocessed_buffer WHERE session_id = :s"),
|
|
{"s": session_id},
|
|
)
|
|
return {row[0] for row in result.fetchall()}
|
|
|
|
|
|
async def _collect_memcell_message_ids(session_id: str) -> set[str]:
|
|
from everos.infra.persistence.sqlite import get_engine
|
|
|
|
eng = get_engine()
|
|
async with eng.connect() as conn:
|
|
result = await conn.execute(
|
|
text("SELECT message_ids_json FROM memcell WHERE session_id = :s"),
|
|
{"s": session_id},
|
|
)
|
|
out: set[str] = set()
|
|
for (raw,) in result.fetchall():
|
|
out.update(json.loads(raw))
|
|
return out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_concurrent_adds_same_session_no_message_loss(
|
|
memorize_env_locked: Callable[..., AsyncMock],
|
|
) -> None:
|
|
"""Two concurrent /add on one session: every input message must end up
|
|
either in a memcell's message_ids OR in the surviving buffer."""
|
|
await memorize_env_locked(fake_llm=_make_extract_all_llm())
|
|
|
|
session_id = "s_concurrent"
|
|
|
|
batch_a = [_msg(i, "alice", 1_700_000_000_000 + i * 1000) for i in range(4)]
|
|
batch_b = [_msg(i + 100, "bob", 1_700_000_100_000 + i * 1000) for i in range(4)]
|
|
|
|
# Fire both concurrently against the same session.
|
|
await asyncio.gather(
|
|
memorize({"session_id": session_id, "messages": batch_a}),
|
|
memorize({"session_id": session_id, "messages": batch_b}),
|
|
)
|
|
|
|
buffered = await _collect_buffer_message_ids(session_id)
|
|
in_cells = await _collect_memcell_message_ids(session_id)
|
|
covered = buffered | in_cells
|
|
|
|
# The id format is ``m_<session>_<ts_ms>_<idx>`` — we can derive
|
|
# exactly what the 8 inputs should hash to without depending on the
|
|
# internal id_gen import. Easier: assert the *count* covered == 8.
|
|
assert len(covered) == 8, (
|
|
f"expected 8 distinct message ids covered, got {len(covered)}: "
|
|
f"buffer={len(buffered)}, memcell={len(in_cells)}"
|
|
)
|
|
|
|
# Sanity: no message appears in both buffer and memcell at once
|
|
# (consumed = removed from buffer).
|
|
overlap = buffered & in_cells
|
|
assert not overlap, f"messages in both buffer and memcell: {overlap}"
|
|
|
|
|
|
async def test_concurrent_adds_serial_when_locked(
|
|
memorize_env_locked: Callable[..., AsyncMock],
|
|
) -> None:
|
|
"""Same as above but explicitly stress with 4 concurrent batches."""
|
|
await memorize_env_locked(fake_llm=_make_extract_all_llm())
|
|
|
|
session_id = "s_stress"
|
|
|
|
n_batches = 4
|
|
batch_size = 3
|
|
batches = [
|
|
[
|
|
_msg(b * 10 + i, f"u{b}", 1_700_000_000_000 + (b * 10 + i) * 1000)
|
|
for i in range(batch_size)
|
|
]
|
|
for b in range(n_batches)
|
|
]
|
|
|
|
await asyncio.gather(
|
|
*(memorize({"session_id": session_id, "messages": batch}) for batch in batches)
|
|
)
|
|
|
|
buffered = await _collect_buffer_message_ids(session_id)
|
|
in_cells = await _collect_memcell_message_ids(session_id)
|
|
covered = buffered | in_cells
|
|
|
|
expected = n_batches * batch_size
|
|
assert len(covered) == expected, (
|
|
f"expected {expected} message ids covered, got {len(covered)}: "
|
|
f"buffer={len(buffered)}, memcell={len(in_cells)}"
|
|
)
|
|
assert not (buffered & in_cells)
|
|
|
|
|
|
async def test_different_sessions_run_in_parallel(
|
|
memorize_env_locked: Callable[..., AsyncMock],
|
|
) -> None:
|
|
"""Cross-session calls share no lock — must not serialise."""
|
|
await memorize_env_locked(fake_llm=_make_extract_all_llm())
|
|
|
|
def _msgs(sid: str) -> list[dict[str, Any]]:
|
|
return [_msg(i, sid, 1_700_000_000_000 + i * 1000) for i in range(3)]
|
|
|
|
await asyncio.gather(
|
|
memorize({"session_id": "s_a", "messages": _msgs("s_a")}),
|
|
memorize({"session_id": "s_b", "messages": _msgs("s_b")}),
|
|
memorize({"session_id": "s_c", "messages": _msgs("s_c")}),
|
|
)
|
|
|
|
for sid in ("s_a", "s_b", "s_c"):
|
|
buffered = await _collect_buffer_message_ids(sid)
|
|
in_cells = await _collect_memcell_message_ids(sid)
|
|
covered = buffered | in_cells
|
|
assert len(covered) == 3, f"session {sid}: got {len(covered)}, want 3"
|