Files
beaver_project/app-instance/backend/tests/unit/test_plugin_skill_learning.py

240 lines
9.5 KiB
Python

from __future__ import annotations
import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from beaver.engine.providers.base import LLMProvider, LLMResponse
from beaver.engine.providers.factory import ProviderBundle
from beaver.foundation.utils.file_lock import WorkspaceWriteLock
from beaver.memory.runs import RunMemoryStore
from beaver.memory.skills import SkillLearningCandidate, SkillLearningStore
from beaver.plugins.discovery import discover_plugins
from beaver.plugins.skills import PluginManager
from beaver.plugins.state import PluginStateStore
from beaver.plugins.tree_merge import merge_supporting_file_trees
from beaver.skills.drafts import DraftService
from beaver.skills.learning import EvidenceSelector, SkillDraftSynthesizer, SkillLearningService
from beaver.skills.learning.safety import SkillDraftSafetyChecker
from beaver.skills.publisher import SkillPublisher
from beaver.skills.specs import SkillDraft, SkillReviewState, SkillSpecStore
class CountingProvider(LLMProvider):
    """Stub LLM provider: answers every chat with a canned string and logs requests.

    Tests inspect ``calls`` to check whether the provider was consulted and
    what prompt it received.
    """

    def __init__(self, content: str = "{}") -> None:
        super().__init__()
        # Canned reply body returned verbatim from every ``chat`` call.
        self.content = content
        # One record per ``chat`` call: the messages sent and the model requested.
        self.calls: list[dict] = []

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        thinking_enabled: bool | None = None,
    ) -> LLMResponse:
        """Log the request and hand back ``self.content`` as the response.

        ``tools``, ``max_tokens``, ``temperature`` and ``thinking_enabled`` are
        accepted only to match the provider signature; they are ignored.
        """
        request_record = dict(messages=messages, model=model)
        self.calls.append(request_record)
        return LLMResponse(content=self.content)

    def get_default_model(self) -> str:
        """Return the fixed stub model name."""
        return "stub"
def _bundle(provider: CountingProvider) -> ProviderBundle:
    """Wrap *provider* in a ProviderBundle whose main runtime is a "stub" namespace."""
    # A SimpleNamespace stands in for the real runtime object; it only carries
    # the ``model`` and ``provider_name`` attributes set here.
    stub_runtime = SimpleNamespace(model="stub", provider_name="stub")
    return ProviderBundle(main_provider=provider, main_runtime=stub_runtime)  # type: ignore[arg-type]
# Frontmatter shared by every SKILL.md these fixtures write. Only the body varies,
# so the initial and rewritten plugin describe the same skill.
_SKILL_FRONTMATTER = "---\nname: baoyu-comic\ndescription: Comic workflow\ntools: []\n---\n\n"


def _skill_root(plugin_root: Path) -> Path:
    """Return the baoyu-comic skill directory inside *plugin_root*."""
    return plugin_root / "skills" / "baoyu-comic"


def _write_skill_files(skill_root: Path, *, body: str, template: str) -> None:
    """Write ``SKILL.md`` (frontmatter + *body*) and ``templates/panel.txt`` under *skill_root*."""
    (skill_root / "SKILL.md").write_text(_SKILL_FRONTMATTER + body, encoding="utf-8")
    templates_dir = skill_root / "templates"
    templates_dir.mkdir(exist_ok=True)
    (templates_dir / "panel.txt").write_text(template, encoding="utf-8")


def _write_plugin(root: Path, *, version: str = "1.0.0", body: str = "# Comic\n\nV1.\n", template: str = "v1") -> Path:
    """Create the ``baoyu-comic`` plugin on disk under *root*.

    Layout written:
      * ``<root>/baoyu-comic/beaver.plugin.json`` - manifest at *version*
        declaring one skill at ``skills/baoyu-comic``;
      * ``.../skills/baoyu-comic/SKILL.md`` - shared frontmatter + *body*;
      * ``.../skills/baoyu-comic/templates/panel.txt`` - contains *template*.

    Returns:
        The plugin root directory (``root / "baoyu-comic"``).
    """
    plugin_root = root / "baoyu-comic"
    skill_root = _skill_root(plugin_root)
    skill_root.mkdir(parents=True, exist_ok=True)
    _write_skill_files(skill_root, body=body, template=template)
    manifest = {
        "schema_version": 1,
        "id": "baoyu-comic",
        "name": "Baoyu Comic",
        "version": version,
        "skills": [{"name": "baoyu-comic", "path": "skills/baoyu-comic"}],
    }
    (plugin_root / "beaver.plugin.json").write_text(json.dumps(manifest), encoding="utf-8")
    return plugin_root


def _rewrite_plugin(plugin_root: Path, *, version: str, body: str, template: str) -> None:
    """Simulate an upstream plugin release in place.

    Bumps the manifest ``version`` (all other manifest fields are kept) and
    rewrites the skill's ``SKILL.md`` body and ``templates/panel.txt``.
    """
    manifest_path = plugin_root / "beaver.plugin.json"
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    manifest["version"] = version
    manifest_path.write_text(json.dumps(manifest), encoding="utf-8")
    _write_skill_files(_skill_root(plugin_root), body=body, template=template)
def _manager(workspace: Path) -> tuple[PluginManager, SkillSpecStore, SkillLearningStore]:
    """Build a PluginManager for *workspace* together with the stores it writes to.

    Plugin manifests come from ``discover_plugins(workspace, search_paths=[])``
    at construction time, so a manager built before the plugin files change
    keeps the old manifests.

    Returns:
        ``(manager, skill_store, learning_store)``. The stores are returned so
        tests can inspect them directly. The learning store lives under
        ``workspace/memory/skills``.
    """
    found = discover_plugins(workspace, search_paths=[])
    specs = SkillSpecStore(workspace)
    learning = SkillLearningStore(workspace / "memory" / "skills")
    plugin_manager = PluginManager(
        workspace=workspace,
        manifests=found.manifests,
        discovery_errors=found.errors,
        state_store=PluginStateStore(workspace),
        skill_store=specs,
        learning_store=learning,
        publisher=SkillPublisher(specs),
        safety_checker=SkillDraftSafetyChecker(),
        write_lock=WorkspaceWriteLock(workspace),
    )
    return plugin_manager, specs, learning
def test_skill_draft_from_legacy_payload_has_empty_provenance() -> None:
    """A draft payload with no ``provenance`` key deserializes to ``provenance == {}``."""
    legacy_payload = {
        "draft_id": "draft-1",
        "skill_name": "debug",
        "proposed_content": "# Debug\n",
        "created_at": "now",
        "created_by": "tester",
    }
    restored = SkillDraft.from_dict(legacy_payload)
    assert restored.provenance == {}
def _fast_forward_plugin_update_draft(
    tmp_path: Path, provider: CountingProvider
) -> tuple[Path, SkillSpecStore, SkillLearningCandidate, SkillDraft]:
    """Drive baoyu-comic through a 1.0.0 -> 1.1.0 upstream update and synthesize the draft.

    Steps:
      1. write the plugin (template "v1") and enable it;
      2. rewrite it on disk as 1.1.0 (body "V2.", template "v2");
      3. resync through a fresh manager;
      4. synthesize a draft for the first learning candidate using *provider*.

    A fresh manager is needed in step 3 because ``_manager`` captures
    discovered manifests at construction time.

    Returns:
        ``(workspace, skill_store, candidate, draft)``.
    """
    workspace = tmp_path / "workspace"
    plugin_root = _write_plugin(workspace / "plugins", template="v1")
    manager, skill_store, learning_store = _manager(workspace)
    manager.enable("baoyu-comic")
    _rewrite_plugin(plugin_root, version="1.1.0", body="# Comic\n\nV2.\n", template="v2")
    _manager(workspace)[0].sync_enabled()
    candidate = learning_store.list_learning_candidates()[0]
    # One run store shared by the service and its evidence selector, rather
    # than two independent instances over the same directory.
    run_store = RunMemoryStore(workspace / "memory" / "runs")
    service = SkillLearningService(
        run_store=run_store,
        learning_store=learning_store,
        draft_service=DraftService(skill_store),
        evidence_selector=EvidenceSelector(run_store),
    )
    draft = asyncio.run(service.synthesize_draft(candidate.candidate_id, _bundle(provider)))
    return workspace, skill_store, candidate, draft


def test_fast_forward_plugin_update_synthesis_uses_exact_upstream_without_llm(tmp_path: Path) -> None:
    """An unmodified local skill fast-forwards to the new upstream without any LLM call."""
    provider = CountingProvider()
    _, skill_store, candidate, draft = _fast_forward_plugin_update_draft(tmp_path, provider)
    upstream = skill_store.read_upstream_snapshot(
        "baoyu-comic",
        "baoyu-comic",
        candidate.evidence["new_upstream_tree_hash"],
    )
    assert upstream is not None
    assert draft.proposal_kind == "plugin_skill_update"
    assert draft.proposed_content == "# Comic\n\nV2."
    assert draft.base_version == "v0001"
    assert draft.provenance["merge_mode"] == "fast_forward"
    assert draft.provenance["new_upstream_tree_hash"] == upstream.snapshot.skill_tree_hash
    assert provider.calls == []


def test_publish_plugin_update_materializes_referenced_supporting_files(tmp_path: Path) -> None:
    """Publishing an approved plugin-update draft writes the new supporting files."""
    workspace, skill_store, _, draft = _fast_forward_plugin_update_draft(tmp_path, CountingProvider())
    draft.status = SkillReviewState.APPROVED.value
    skill_store.write_draft(draft)
    version = SkillPublisher(skill_store).publish("baoyu-comic", draft.draft_id, publisher="tester")
    assert version.version == "v0002"
    panel = workspace / "skills" / "baoyu-comic" / "versions" / "v0002" / "templates" / "panel.txt"
    assert panel.read_text(encoding="utf-8") == "v2"
def test_supporting_file_merge_adopts_upstream_when_local_is_unchanged() -> None:
    """If local still equals base, the upstream copy of the file is taken with no conflict."""

    def entry(digest: str) -> dict:
        # Minimal tree entry: content hash plus a non-executable flag.
        return {"content_hash": digest, "executable": False}

    plan = merge_supporting_file_trees(
        base={"a.txt": entry("A")},
        local={"a.txt": entry("A")},
        upstream={"a.txt": entry("U")},
    )
    chosen = plan.files["a.txt"]
    assert chosen.source == "upstream"
    assert plan.conflicts == []
def test_supporting_file_merge_blocks_divergent_edits() -> None:
    """When local and upstream both changed a file differently, the merge reports a conflict."""
    plan = merge_supporting_file_trees(
        base={"a.txt": {"content_hash": "A", "executable": False}},
        local={"a.txt": {"content_hash": "L", "executable": False}},
        upstream={"a.txt": {"content_hash": "U", "executable": False}},
    )
    # Guard first: an unexpectedly clean merge should fail with a readable
    # message, not an IndexError from ``conflicts[0]``.
    assert plan.conflicts, "expected a conflict for divergently edited a.txt"
    assert plan.conflicts[0].path == "a.txt"
def test_three_way_synthesizer_prompt_labels_all_inputs(tmp_path: Path) -> None:
    """The plugin-update prompt labels OLD UPSTREAM, CURRENT LOCAL and NEW UPSTREAM.

    Also checks that the section lists from the stubbed JSON reply appear in
    the returned payload.
    """
    provider = CountingProvider(
        json.dumps(
            {
                "frontmatter": {"name": "baoyu-comic", "description": "Comic workflow", "tools": []},
                "content": "# Baoyu Comic\n\nMerged.",
                "change_reason": "Adopt upstream while preserving local review.",
                "preserved_local_sections": ["Review"],
                "adopted_upstream_sections": ["Panel Layout"],
                "resolved_conflicts": ["Output ordering"],
                "dropped_sections": [],
            }
        )
    )
    # The evidence packet is empty, so the run store is never meaningfully read.
    # It is still rooted in pytest's per-test tmp_path rather than a hard-coded,
    # POSIX-only, shared /tmp path.
    unused_runs = tmp_path / "unused-runs"

    async def run_case() -> dict:
        return await SkillDraftSynthesizer().synthesize_plugin_update(
            SkillLearningCandidate(
                candidate_id="candidate",
                kind="plugin_skill_update",
                source_run_ids=[],
                source_session_ids=[],
                related_skill_names=["baoyu-comic"],
                reason="merge",
            ),
            EvidenceSelector(RunMemoryStore(unused_runs)).build_evidence_packet([], []),
            provider,
            "stub",
            old_upstream={"content": "# Old\n"},
            current_local={"content": "# Local\n"},
            new_upstream={"content": "# New\n"},
        )

    payload = asyncio.run(run_case())
    # NOTE(review): assumes the synthesizer places the prompt in the second
    # message (index 1), presumably after a system message; confirm.
    prompt = provider.calls[0]["messages"][1]["content"]
    assert "OLD UPSTREAM" in prompt
    assert "CURRENT LOCAL" in prompt
    assert "NEW UPSTREAM" in prompt
    assert payload["preserved_local_sections"] == ["Review"]
    assert payload["adopted_upstream_sections"] == ["Panel Layout"]