Files
beaver_project/app-instance/backend/tests/unit/test_task_skill_resolver.py

401 lines
14 KiB
Python

from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
from beaver.coordinator import AgentDescriptor, ExecutionGraph, ExecutionNode
from beaver.engine.context import SkillContext
from beaver.engine.providers.base import LLMProvider, LLMResponse
from beaver.engine.providers.factory import ProviderBundle
from beaver.skills.drafts import DraftService
from beaver.skills.learning import EphemeralGuidanceSynthesizer
from beaver.skills.publisher import SkillPublisher
from beaver.skills.reviews import ReviewService
from beaver.skills.specs import SkillSpecStore
from beaver.skills import SkillsLoader
from beaver.tasks import TaskRecord, TaskSkillResolver
class RecordingProvider(LLMProvider):
    """Scripted LLM provider stub that records every chat request it receives.

    Responses are handed out in the order given. Once the script is exhausted,
    every further ``chat`` call returns the string ``"[]"``.
    """

    def __init__(self, responses: list[str]) -> None:
        super().__init__()
        # Copy so that pop(0) below never consumes the caller's list.
        self.responses = list(responses)
        # One entry per chat() call: a snapshot of the messages sent on that call.
        self.calls: list[list[dict]] = []

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Record ``messages`` and return the next scripted response.

        ``tools``, ``model``, ``max_tokens`` and ``temperature`` are accepted
        for interface compatibility and ignored.
        """
        import copy

        # Snapshot the request. Storing the caller's list by reference would let
        # later mutations by the code under test rewrite the recorded history.
        self.calls.append(copy.deepcopy(messages))
        content = self.responses.pop(0) if self.responses else "[]"
        return LLMResponse(content=content, finish_reason="stop", provider_name="stub", model="stub-model")

    def get_default_model(self) -> str:
        """Return the fixed model name reported by this stub."""
        return "stub-model"
def _bundle(provider: RecordingProvider) -> ProviderBundle:
    """Build a ProviderBundle whose main provider is ``provider``.

    The main runtime is a bare namespace that exposes only ``model`` and
    ``provider_name``, matching the stub provider's identity.
    """
    stub_runtime = SimpleNamespace(model="stub-model", provider_name="stub")
    return ProviderBundle(main_runtime=stub_runtime, main_provider=provider)
def _task() -> TaskRecord:
    """Return a fresh, open TaskRecord used as the resolver's task in every test."""
    summary = "review api compatibility"
    timestamp = "now"
    return TaskRecord(
        task_id="task-1",
        session_id="session-1",
        description=summary,
        goal=summary,
        constraints=[],
        priority=0,
        status="open",
        creator="test",
        created_at=timestamp,
        updated_at=timestamp,
    )
def _publish_skill(workspace: Path, *, skill_name: str) -> None:
    """Draft, approve and publish a minimal skill called ``skill_name`` in ``workspace``.

    All three steps share a single SkillSpecStore rooted at ``workspace``.
    """
    spec_store = SkillSpecStore(workspace)
    body = f"# {skill_name}\n\nCheck schema compatibility and breaking changes."
    frontmatter = {"description": f"{skill_name} capability", "tools": []}
    new_draft = DraftService(spec_store).create_new_skill_draft(
        skill_name=skill_name,
        proposed_content=body,
        proposed_frontmatter=frontmatter,
        created_by="tester",
        reason="test",
    )
    draft_id = new_draft.draft_id
    ReviewService(spec_store).approve(skill_name, draft_id, reviewer="tester")
    SkillPublisher(spec_store).publish(skill_name, draft_id, publisher="tester")
def test_task_skill_resolver_pins_matching_published_skill(tmp_path: Path) -> None:
    """A published skill named by the stubbed lookup is pinned by name on the node.

    No skill contexts are attached, and the ephemeral-guidance path is not used.
    """
    _publish_skill(tmp_path, skill_name="api-contract-review")
    stub = RecordingProvider(['["api-contract-review"]'])
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
    )
    reviewer = AgentDescriptor(
        name="api_review",
        metadata={
            "skill_query": "API contract compatibility review",
            "required_capabilities": ["schema compatibility"],
        },
    )
    graph = ExecutionGraph(
        strategy="sequence",
        nodes=[ExecutionNode("api_review", "review API compatibility", reviewer)],
    )

    pending = resolver.resolve_graph(
        graph,
        task=_task(),
        user_message="review api",
        attempt_index=1,
        provider_bundle=_bundle(stub),
    )
    resolved, reports = asyncio.run(pending)

    node, report = resolved.nodes[0], reports[0]
    assert node.agent.name == "api_review"
    assert node.agent.role == ""
    assert node.inherited_pinned_skills == ["api-contract-review"]
    assert node.inherited_pinned_skill_contexts == []
    assert report.selected_skill_names == ["api-contract-review"]
    assert report.ephemeral_used is False
def test_task_skill_resolver_generates_ephemeral_guidance_when_missing(tmp_path: Path) -> None:
    """With no published skills, exactly one ephemeral guidance context is attached.

    The synthesized guidance must leave no draft and no published skill behind
    in the spec store.
    """
    guidance_json = """
{
"guidance_name": "api-compatibility-review",
"description": "Review API compatibility",
"content": "# API Compatibility Review\\n\\nCheck schema compatibility.",
"tags": ["api", "review"]
}
"""
    stub = RecordingProvider([guidance_json])
    spec_store = SkillSpecStore(tmp_path)
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(spec_store),
        missing_skill_synthesizer=EphemeralGuidanceSynthesizer(),
    )
    reviewer = AgentDescriptor(
        name="api_review",
        metadata={
            "skill_query": "API compatibility review",
            "required_capabilities": ["schema compatibility"],
        },
    )
    graph = ExecutionGraph(
        strategy="sequence",
        nodes=[ExecutionNode("api_review", "review API compatibility", reviewer)],
    )

    pending = resolver.resolve_graph(
        graph,
        task=_task(),
        user_message="review api",
        attempt_index=1,
        provider_bundle=_bundle(stub),
    )
    resolved, reports = asyncio.run(pending)

    # Ephemeral guidance must not be persisted as a draft or a published skill.
    assert spec_store.list_drafts("api-compatibility-review") == []
    assert spec_store.list_published_skill_names() == []

    node, report = resolved.nodes[0], reports[0]
    assert node.inherited_pinned_skills == []
    assert len(node.inherited_pinned_skill_contexts) == 1
    guidance: SkillContext = node.inherited_pinned_skill_contexts[0]
    assert guidance.name == "ephemeral:api-compatibility-review"
    assert guidance.version.startswith("ephemeral:eg_")
    assert guidance.activation_reason == "ephemeral_guidance"
    assert report.ephemeral_guidance_id is not None
    assert report.ephemeral_guidance_name == "api-compatibility-review"
    assert report.ephemeral_used is True
def test_task_skill_resolver_keeps_summary_nodes_skillless(tmp_path: Path) -> None:
    """Summary nodes drop any inherited skills/contexts and never call the provider.

    The report records the summary reason, and the agent's selected skill list
    ends up empty.
    """
    _publish_skill(tmp_path, skill_name="multi-search-engine")
    stub = RecordingProvider(['["multi-search-engine"]'])
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
    )
    summarizer = AgentDescriptor(
        name="summarize",
        metadata={
            "skill_query": "Summarization",
            "required_capabilities": ["text generation"],
        },
    )
    stale_context = SkillContext(name="ephemeral:search-guidance", content="Search again.")
    summary_node = ExecutionNode(
        "summarize",
        "Compile a clear, concise summary from dependency outputs for the user.",
        summarizer,
        depends_on=["verify_result"],
        inherited_pinned_skills=["multi-search-engine"],
        inherited_pinned_skill_contexts=[stale_context],
    )
    graph = ExecutionGraph(strategy="dag", nodes=[summary_node])

    pending = resolver.resolve_graph(
        graph,
        task=_task(),
        user_message="summarize result",
        attempt_index=2,
        provider_bundle=_bundle(stub),
    )
    resolved, reports = asyncio.run(pending)

    node, report = resolved.nodes[0], reports[0]
    assert node.inherited_pinned_skills == []
    assert node.inherited_pinned_skill_contexts == []
    assert node.agent.metadata["selected_skill_names"] == []
    assert report.selected_skill_names == []
    assert report.ephemeral_used is False
    assert report.reason == "summary node uses dependency outputs directly"
    assert stub.calls == []
def test_resolver_exact_binds_use_skill_before_dynamic_lookup(tmp_path: Path) -> None:
    """A resolvable ``use_skill`` is bound exactly and the provider is never consulted.

    The provider is scripted with a wrong skill name, so any dynamic lookup
    would be visible in the result.
    """
    _publish_skill(tmp_path, skill_name="official-source-research")
    stub = RecordingProvider(['["wrong-dynamic-skill"]'])
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
    )
    collector = AgentDescriptor(
        name="collect",
        metadata={
            "use_skill": "official-source-research",
            "skill_query": "generic web research",
        },
    )
    graph = ExecutionGraph(
        strategy="sequence",
        nodes=[ExecutionNode("collect", "Collect official sources", collector)],
    )

    pending = resolver.resolve_graph(
        graph,
        task=_task(),
        user_message="collect sources",
        attempt_index=1,
        provider_bundle=_bundle(stub),
    )
    resolved, reports = asyncio.run(pending)

    node, report = resolved.nodes[0], reports[0]
    context_names = [ctx.name for ctx in node.inherited_pinned_skill_contexts]
    assert node.inherited_pinned_skills == ["official-source-research"]
    assert context_names == ["official-source-research"]
    assert node.agent.metadata["exact_binding_used"] is True
    assert report.selected_skill_names == ["official-source-research"]
    assert report.exact_binding_used is True
    assert report.warnings == []
    assert stub.calls == []
def test_resolver_falls_back_to_skill_query_when_use_skill_missing(tmp_path: Path) -> None:
    """An unresolvable ``use_skill`` falls back to a ``skill_query`` lookup.

    The report carries an "unresolved" warning, and the query text is forwarded
    to the provider in the second message of the first call.
    """
    _publish_skill(tmp_path, skill_name="financial-metric-extraction")
    provider = RecordingProvider(['["financial-metric-extraction"]'])
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
    )
    graph = ExecutionGraph(
        strategy="sequence",
        nodes=[
            ExecutionNode(
                "extract",
                "Extract metrics",
                AgentDescriptor(
                    name="extract",
                    metadata={
                        "use_skill": "missing-exact-skill",
                        "skill_query": "financial metric extraction",
                    },
                ),
            )
        ],
    )
    resolved, reports = asyncio.run(
        resolver.resolve_graph(
            graph,
            task=_task(),
            user_message="extract financial metrics",
            attempt_index=1,
            provider_bundle=_bundle(provider),
        )
    )
    assert resolved.nodes[0].inherited_pinned_skills == ["financial-metric-extraction"]
    assert reports[0].exact_binding_used is False
    assert reports[0].selected_skill_names == ["financial-metric-extraction"]
    assert reports[0].warnings == ["use_skill unresolved: missing-exact-skill"]
    # Guard the indexing below: without these checks, a regression that skips
    # the lookup call surfaces as an opaque IndexError instead of a clear failure.
    assert provider.calls, "expected the skill_query fallback to call the provider"
    lookup_messages = provider.calls[0]
    assert len(lookup_messages) > 1, f"expected at least two messages, got {lookup_messages!r}"
    assert "financial metric extraction" in lookup_messages[1]["content"]
def test_resolver_falls_back_to_ephemeral_when_exact_and_query_miss(tmp_path: Path) -> None:
    """If both the exact ``use_skill`` binding and the query lookup miss, ephemeral guidance is used.

    The scripted provider first answers the lookup with ``[]``, then returns the
    guidance JSON for the synthesizer.
    """
    _publish_skill(tmp_path, skill_name="unrelated-skill")
    provider = RecordingProvider(
        [
            "[]",
            """
{
"guidance_name": "financial-extraction-guidance",
"description": "Extract financial metrics",
"content": "# Financial Extraction\\n\\nExtract the requested metrics.",
"tags": ["finance"]
}
""",
        ]
    )
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
        missing_skill_synthesizer=EphemeralGuidanceSynthesizer(),
    )
    graph = ExecutionGraph(
        strategy="sequence",
        nodes=[
            ExecutionNode(
                "extract",
                "Extract metrics",
                AgentDescriptor(
                    name="extract",
                    metadata={
                        "use_skill": "missing-exact-skill",
                        "skill_query": "financial metric extraction",
                    },
                ),
            )
        ],
    )
    resolved, reports = asyncio.run(
        resolver.resolve_graph(
            graph,
            task=_task(),
            user_message="extract financial metrics",
            attempt_index=1,
            provider_bundle=_bundle(provider),
        )
    )
    assert resolved.nodes[0].inherited_pinned_skills == []
    contexts = resolved.nodes[0].inherited_pinned_skill_contexts
    # Fail with a clear message rather than an IndexError if no guidance was attached.
    assert contexts, "expected an ephemeral guidance context to be attached"
    assert contexts[0].name == "ephemeral:financial-extraction-guidance"
    assert reports[0].ephemeral_used is True
    # The exact binding missed; mirror the check made in the skill_query fallback test.
    assert reports[0].exact_binding_used is False
    assert reports[0].warnings == ["use_skill unresolved: missing-exact-skill"]
def test_explicit_use_skill_is_preserved_for_summary_without_nested_expansion(tmp_path: Path) -> None:
    """An explicit ``use_skill`` on a summary node is still honored via exact binding.

    The graph must keep its single node, and the provider must never be called.
    """
    _publish_skill(tmp_path, skill_name="summary-formatting")
    stub = RecordingProvider([])
    resolver = TaskSkillResolver(
        skills_loader=SkillsLoader(tmp_path),
        draft_service=DraftService(SkillSpecStore(tmp_path)),
    )
    summarizer = AgentDescriptor(
        name="summarize",
        metadata={"use_skill": "summary-formatting", "skill_query": "Summarization"},
    )
    summary_node = ExecutionNode(
        "summarize",
        "Compile a summary from dependency outputs",
        summarizer,
        depends_on=["collect"],
    )
    graph = ExecutionGraph(strategy="dag", nodes=[summary_node])

    pending = resolver.resolve_graph(
        graph,
        task=_task(),
        user_message="summarize",
        attempt_index=1,
        provider_bundle=_bundle(stub),
    )
    resolved, reports = asyncio.run(pending)

    assert len(resolved.nodes) == 1
    assert resolved.nodes[0].inherited_pinned_skills == ["summary-formatting"]
    assert reports[0].exact_binding_used is True
    assert stub.calls == []