560 lines
18 KiB
Python
560 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from types import SimpleNamespace
|
|
|
|
from beaver.engine.context import SkillContext
|
|
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
|
from beaver.engine.providers.factory import ProviderBundle
|
|
from beaver.tasks import SkillResolutionReport, TaskExecutionPlanner, TaskRecord
|
|
from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
|
|
from beaver.tools.registry import ToolRegistry
|
|
|
|
|
|
class PlannerProvider(LLMProvider):
    """Stub provider that answers every chat with one canned string and logs the call."""

    def __init__(self, response: str) -> None:
        """Keep the canned ``response`` and start with an empty call log."""
        super().__init__()
        self.response = response
        self.calls: list[dict] = []

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Record the call arguments, then return the canned response.

        The recorded entry is what the tests later inspect through ``self.calls``.
        """
        recorded_call = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "model": model,
            "tools": tools,
        }
        self.calls.append(recorded_call)
        reply = LLMResponse(
            content=self.response,
            finish_reason="stop",
            provider_name="stub",
            model="stub-model",
        )
        return reply

    def get_default_model(self) -> str:
        """Return the fixed stub model name."""
        return "stub-model"
|
|
|
|
|
|
class HangingPlannerProvider(LLMProvider):
    """Stub provider whose ``chat`` sleeps for ten seconds before replying."""

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Sleep for 10 seconds, then return a team-mode JSON reply."""
        delay_seconds = 10
        await asyncio.sleep(delay_seconds)
        late_reply = LLMResponse(
            content='{"mode":"team"}',
            finish_reason="stop",
            provider_name="stub",
            model="stub-model",
        )
        return late_reply

    def get_default_model(self) -> str:
        """Return the fixed stub model name."""
        return "stub-model"
|
|
|
|
|
|
class SequencedPlannerProvider(PlannerProvider):
    """Planner stub that hands out a queue of responses, one per ``chat`` call."""

    def __init__(self, responses: list[str]) -> None:
        """Seed the parent with the first response and keep a private copy of the queue."""
        first_response = responses[0]
        super().__init__(first_response)
        # Copy so the caller's list is not consumed by ``chat``.
        self.responses = list(responses)

    async def chat(self, *args, **kwargs) -> LLMResponse:
        """Pop the next queued response, then delegate to ``PlannerProvider.chat``."""
        upcoming = self.responses.pop(0)
        self.response = upcoming
        return await super().chat(*args, **kwargs)
|
|
|
|
|
|
class StubTool(BaseTool):
    """Registry-only tool: exposes a ``ToolSpec`` and refuses to be invoked."""

    def __init__(self, name: str) -> None:
        """Build a spec that uses ``name`` for both the name and the description."""
        object_schema = {"type": "object"}
        self._spec = ToolSpec(name=name, description=name, input_schema=object_schema)

    @property
    def spec(self) -> ToolSpec:
        """Return the spec created in ``__init__``."""
        return self._spec

    async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
        """Always raise ``AssertionError``; these tests never execute tools."""
        raise AssertionError("Planner tests do not execute tools")
|
|
|
|
|
|
def _task() -> TaskRecord:
    """Build the ``TaskRecord`` fixture shared by the planner tests."""
    objective = "implement workflow"
    timestamp = "now"
    record = TaskRecord(
        task_id="task-1",
        session_id="session-1",
        description=objective,
        goal=objective,
        constraints=[],
        priority=0,
        status="open",
        creator="test",
        created_at=timestamp,
        updated_at=timestamp,
    )
    return record
|
|
|
|
|
|
def _bundle(response: str) -> ProviderBundle:
    """Return a ``ProviderBundle`` whose main provider replies with ``response``.

    The bundle is assembled by ``_bundle_with_provider`` so the stub runtime
    settings are defined in one place instead of being repeated here.
    """
    return _bundle_with_provider(PlannerProvider(response))
|
|
|
|
|
|
def _bundle_with_provider(provider: LLMProvider) -> ProviderBundle:
    """Wrap ``provider`` in a ``ProviderBundle`` with a stub main runtime."""
    stub_runtime = SimpleNamespace(model="stub-model", provider_name="stub")
    return ProviderBundle(main_runtime=stub_runtime, main_provider=provider)
|
|
|
|
|
|
def _registry() -> ToolRegistry:
    """Return a ``ToolRegistry`` holding stub ``web_search``, ``web_fetch`` and ``terminal`` tools."""
    tool_names = ("web_search", "web_fetch", "terminal")
    stub_tools = [StubTool(tool_name) for tool_name in tool_names]
    tool_registry = ToolRegistry()
    tool_registry.register_many(stub_tools)
    return tool_registry
|
|
|
|
|
|
def _hanging_bundle() -> ProviderBundle:
    """Return a ``ProviderBundle`` whose main provider is a ``HangingPlannerProvider``.

    The bundle is assembled by ``_bundle_with_provider`` so the stub runtime
    settings are defined in one place instead of being repeated here.
    """
    return _bundle_with_provider(HangingPlannerProvider())
|
|
|
|
|
|
def test_planner_selects_single_mode() -> None:
    """A single-mode planner reply yields a single-mode plan with no graph."""
    planner = TaskExecutionPlanner()
    bundle = _bundle('{"mode":"single","reason":"main agent is enough"}')
    planning = planner.plan(
        task=_task(),
        user_message="implement workflow",
        attempt_index=1,
        provider_bundle=bundle,
    )
    plan = asyncio.run(planning)

    assert plan.mode == "single"
    assert plan.graph is None
    assert plan.reason == "main agent is enough"
|
|
|
|
|
|
def test_planner_skips_llm_for_simple_task() -> None:
    """A simple weather lookup is planned as single mode without calling the provider."""
    # The canned reply says "team" so any provider call would contradict the asserts below.
    provider = PlannerProvider('{"mode":"team","reason":"should not be used"}')
    # Use the shared helper, as the other provider-inspecting tests do.
    bundle = _bundle_with_provider(provider)
    task = _task()
    task.description = "查询深圳天气"
    task.goal = "查询深圳天气"

    plan = asyncio.run(
        TaskExecutionPlanner().plan(
            task=task,
            user_message="帮我查一下今天深圳天气",
            attempt_index=1,
            provider_bundle=bundle,
        )
    )

    assert plan.mode == "single"
    assert plan.graph is None
    assert plan.reason == "planner_skipped_simple_task"
    assert provider.calls == []
|
|
|
|
|
|
def test_planner_builds_team_graph() -> None:
    """A team-mode reply becomes a DAG plan, and the prompt carries skill summaries and tool hints."""
    planner_reply = """
        {
          "mode": "team",
          "reason": "needs parallel review",
          "strategy": "dag",
          "nodes": [
            {"node_id": "research", "task": "research options"},
            {"node_id": "review", "task": "review result", "depends_on": ["research"]}
          ],
          "final_synthesis_instruction": "merge the findings"
        }
        """
    bundle = _bundle(planner_reply)
    provider = bundle.main_provider
    planning = TaskExecutionPlanner().plan(
        task=_task(),
        user_message="implement workflow",
        attempt_index=1,
        provider_bundle=bundle,
        skill_summaries=["docker-debug: Use docker logs before editing config."],
        tool_hints=["terminal", "search_files"],
    )
    plan = asyncio.run(planning)

    assert plan.is_team
    assert plan.graph is not None
    graph = plan.graph
    assert graph.strategy == "dag"
    assert [node.node_id for node in graph.nodes] == ["research", "review"]
    assert graph.nodes[1].depends_on == ["research"]
    assert plan.final_synthesis_instruction == "merge the findings"

    assert isinstance(provider, PlannerProvider)
    # The second message of the first call is the prompt text inspected here.
    prompt = provider.calls[0]["messages"][1]["content"]
    expected_fragments = (
        "Activated skill summaries",
        "docker-debug: Use docker logs before editing config.",
        "terminal",
        "search_files",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
|
|
|
|
|
|
def test_planner_timeout_falls_back_to_single() -> None:
    """With a hanging provider and a 0.01s timeout, the plan falls back to single mode."""
    planner = TaskExecutionPlanner()
    planning = planner.plan(
        task=_task(),
        user_message="implement workflow",
        attempt_index=1,
        provider_bundle=_hanging_bundle(),
        timeout_seconds=0.01,
    )
    plan = asyncio.run(planning)

    assert plan.mode == "single"
    assert plan.reason == "planner_failed"
    # ``fallback_error`` may be None, so normalise it before the substring check.
    error_text = plan.fallback_error or ""
    assert "TimeoutError" in error_text
|
|
|
|
|
|
def test_planner_team_nodes_use_task_as_internal_skill_query() -> None:
    """A node without ``skill_query`` gets its task text as the agent's skill query."""
    planner_json = """
        {
          "mode": "team",
          "reason": "needs skill-guided review",
          "strategy": "sequence",
          "nodes": [
            {
              "node_id": "api_review",
              "task": "review API compatibility"
            }
          ]
        }
        """
    plan = TaskExecutionPlanner().from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    agent = plan.graph.nodes[0].agent
    assert agent.name == "api_review"
    assert agent.role == ""
    assert agent.metadata["skill_query"] == "review API compatibility"
    assert agent.metadata["required_capabilities"] == []
|
|
|
|
|
|
def test_planner_accepts_use_skill_and_skill_query() -> None:
    """Explicit ``use_skill`` and ``skill_query`` reach agent metadata and the adaptation record."""
    planner_json = """
        {
          "mode": "team",
          "strategy": "sequence",
          "nodes": [
            {
              "node_id": "collect",
              "task": "Collect official sources",
              "use_skill": "official-source-research",
              "skill_query": "official source verification"
            }
          ]
        }
        """
    plan = TaskExecutionPlanner().from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    node = plan.graph.nodes[0]
    metadata = node.agent.metadata
    assert metadata["use_skill"] == "official-source-research"
    assert metadata["skill_query"] == "official source verification"
    assert node.inherited_pinned_skills == []
    assert node.allowed_tool_names is None

    expected_bindings = [
        {
            "node_id": "collect",
            "use_skill": "official-source-research",
            "skill_query": "official source verification",
        }
    ]
    assert plan.planner_adaptation["node_skill_bindings"] == expected_bindings
|
|
|
|
|
|
def test_planner_defaults_skill_query_to_node_task_when_absent() -> None:
    """With ``use_skill`` but no ``skill_query``, the node task becomes the skill query."""
    planner_json = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"extract","task":"Extract financial metrics","use_skill":"financial-extraction"}]}'
    )
    plan = TaskExecutionPlanner().from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    first_node = plan.graph.nodes[0]
    assert first_node.agent.metadata["skill_query"] == "Extract financial metrics"
|
|
|
|
|
|
def test_planner_adaptation_records_unresolved_use_skill_fallback() -> None:
    """Merging a resolution report with an unresolved ``use_skill`` adds a warning and fallback reason."""
    planner = TaskExecutionPlanner()
    planner_json = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"extract","task":"Extract metrics","use_skill":"missing-skill",'
        '"skill_query":"financial extraction"}]}'
    )
    plan = planner.from_json(planner_json)
    unresolved_warning = "use_skill unresolved: missing-skill"
    report = SkillResolutionReport(
        node_id="extract",
        skill_query="financial extraction",
        requested_skill_name="missing-skill",
        exact_binding_used=False,
        warnings=[unresolved_warning],
        reason="matched published skill",
    )

    planner._merge_skill_resolution_adaptation(plan, [report])

    adaptation = plan.planner_adaptation
    assert adaptation["warnings"] == ["use_skill unresolved: missing-skill"]
    first_binding = adaptation["node_skill_bindings"][0]
    assert first_binding["fallback_reason"] == "use_skill unresolved; matched published skill"
|
|
|
|
|
|
def test_planner_invalid_outputs_fallback_to_single() -> None:
    """Four rejected planner outputs each produce a single-mode plan.

    Covered inputs: non-JSON text, a ``moa`` strategy, a seven-node parallel
    plan, and a two-node dependency cycle.
    """
    planner = TaskExecutionPlanner()

    # Build the seven-node payload up front; the result is the same JSON text as before.
    node_template = '{"node_id":"n%s","task":"work","agent":{"name":"n%s"}}'
    seven_nodes = ",".join(node_template % (index, index) for index in range(7))
    seven_node_json = '{"mode":"team","strategy":"parallel","nodes":[' + seven_nodes + "]}"
    cyclic_json = """
        {
          "mode": "team",
          "strategy": "dag",
          "nodes": [
            {"node_id": "a", "task": "a", "agent": {"name": "a"}, "depends_on": ["b"]},
            {"node_id": "b", "task": "b", "agent": {"name": "b"}, "depends_on": ["a"]}
          ]
        }
        """

    invalid_json = planner.from_json("not json")
    unknown_strategy = planner.from_json(
        '{"mode":"team","strategy":"moa","nodes":[{"node_id":"a","task":"a","agent":{"name":"a"}}]}'
    )
    too_many_nodes = planner.from_json(seven_node_json)
    cyclic = planner.from_json(cyclic_json)

    for rejected_plan in (invalid_json, unknown_strategy, too_many_nodes, cyclic):
        assert rejected_plan.mode == "single"
|
|
|
|
|
|
def test_template_plan_builds_generic_worker_and_preserves_v1_contract_fields() -> None:
    """A template-style node becomes a generic skill worker and keeps its contract fields."""
    planner = TaskExecutionPlanner(tool_registry=_registry())
    planner_json = """
        {
          "mode": "team",
          "strategy": "dag",
          "nodes": [
            {
              "node_id": "collect",
              "task": "Collect official sources",
              "requested_tools": ["web_search"],
              "evidence_contract": {"entities": ["MGM", "Galaxy"]},
              "block_downstream_on_partial": true
            }
          ],
          "adaptation": {"template_used": true}
        }
        """
    plan = planner.from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    node = plan.graph.nodes[0]
    agent = node.agent
    assert agent.name == "collect"
    assert agent.role == ""
    assert agent.metadata["sub_agent_kind"] == "generic_skill_worker"
    assert node.allowed_tool_names == ["web_search"]
    assert node.evidence_contract == {"entities": ["MGM", "Galaxy"]}
    assert node.block_downstream_on_partial is True
    assert plan.planner_adaptation["template_used"] is True
|
|
|
|
|
|
def test_unknown_tool_is_removed_and_warned() -> None:
    """A requested tool missing from the registry is dropped and reported as a warning."""
    planner = TaskExecutionPlanner(tool_registry=_registry())
    planner_json = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"collect","task":"Collect","requested_tools":["web_search","not_real"]}]}'
    )
    plan = planner.from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    first_node = plan.graph.nodes[0]
    assert first_node.allowed_tool_names == ["web_search"]
    warnings = plan.planner_adaptation["warnings"]
    assert "unknown tool removed: not_real" in warnings
|
|
|
|
|
|
def test_high_risk_tool_is_removed_without_failing_low_risk_plan() -> None:
    """Requesting ``terminal`` drops that tool with a review warning; the plan stays team mode."""
    planner = TaskExecutionPlanner(tool_registry=_registry())
    planner_json = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"collect","task":"Collect","requested_tools":["web_search","terminal"]}]}'
    )
    plan = planner.from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    first_node = plan.graph.nodes[0]
    assert first_node.allowed_tool_names == ["web_search"]
    warnings = plan.planner_adaptation["warnings"]
    assert "requires_high_risk_review: terminal" in warnings
|
|
|
|
|
|
def test_planner_rejects_agent_and_role_node_fields() -> None:
    """Nodes carrying an ``agent`` or ``role`` field fall back to single mode and name the field."""
    planner = TaskExecutionPlanner(tool_registry=_registry())
    json_with_agent = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"collect","task":"Collect","agent":{"name":"researcher"}}]}'
    )
    json_with_role = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"collect","task":"Collect","role":"researcher"}]}'
    )

    agent_plan = planner.from_json(json_with_agent)
    role_plan = planner.from_json(json_with_role)

    assert agent_plan.mode == "single"
    # ``fallback_error`` may be None, so normalise it before the substring check.
    agent_error = agent_plan.fallback_error or ""
    assert "agent" in agent_error
    assert role_plan.mode == "single"
    role_error = role_plan.fallback_error or ""
    assert "role" in role_error
|
|
|
|
|
|
def test_planner_records_primary_template_selection_and_ignored_templates() -> None:
    """With two templated skills, the first is recorded as selected and the second as ignored."""
    primary = SkillContext(
        name="financial-comparison",
        version="v1",
        content="Compare official financial disclosures.",
        team_template={"version": 1, "nodes": [{"node_id": "collect", "task": "Collect"}]},
    )
    secondary = SkillContext(
        name="chart-reporting",
        version="v2",
        content="Render chart-ready Markdown.",
        team_template={"version": 1, "nodes": [{"node_id": "report", "task": "Report"}]},
    )
    planner_reply = (
        '{"mode":"team","strategy":"sequence","nodes":['
        '{"node_id":"collect","task":"Collect official sources"}],'
        '"adaptation":{"template_used":true}}'
    )
    provider = PlannerProvider(planner_reply)
    planner = TaskExecutionPlanner(tool_registry=_registry())

    planning = planner.plan(
        task=_task(),
        user_message="compare financial workflow",
        attempt_index=1,
        provider_bundle=_bundle_with_provider(provider),
        activated_skills=[primary, secondary],
    )
    plan = asyncio.run(planning)

    expected_adaptation = {
        "template_used": True,
        "selected_template": "financial-comparison",
        "selection_reason": "first activated skill with a valid team template",
        "ignored_templates": ["chart-reporting"],
        "warnings": [],
    }
    assert plan.planner_adaptation == expected_adaptation

    # The second message of the first call is the prompt text inspected here.
    prompt = provider.calls[0]["messages"][1]["content"]
    expected_fragments = (
        '"skill_name": "financial-comparison"',
        "Compare official financial disclosures.",
        "Render chart-ready Markdown.",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
|
|
|
|
|
|
def test_malformed_planner_output_repairs_once_without_tools() -> None:
    """A non-JSON first reply triggers one repair call, made without tools, that yields a team plan."""
    queued_replies = [
        "not json",
        '{"mode":"team","strategy":"sequence","nodes":[{"node_id":"collect","task":"Collect"}]}',
    ]
    provider = SequencedPlannerProvider(queued_replies)
    planner = TaskExecutionPlanner(tool_registry=_registry())

    planning = planner.plan(
        task=_task(),
        user_message="implement workflow",
        attempt_index=1,
        provider_bundle=_bundle_with_provider(provider),
    )
    plan = asyncio.run(planning)

    assert plan.is_team
    assert len(provider.calls) == 2
    repair_call = provider.calls[1]
    assert repair_call["tools"] is None
    assert "Repair the invalid planner JSON" in repair_call["messages"][1]["content"]
|
|
|
|
|
|
def test_failed_planner_repair_falls_back_to_single() -> None:
    """Two non-JSON replies in a row end in a single-mode fallback after exactly two calls."""
    provider = SequencedPlannerProvider(["not json", "still not json"])
    planner = TaskExecutionPlanner(tool_registry=_registry())

    planning = planner.plan(
        task=_task(),
        user_message="implement workflow",
        attempt_index=1,
        provider_bundle=_bundle_with_provider(provider),
    )
    plan = asyncio.run(planning)

    assert plan.mode == "single"
    assert plan.reason == "planner_fallback_single"
    assert len(provider.calls) == 2
|
|
|
|
|
|
def test_finance_template_adapts_to_task_oriented_read_only_graph() -> None:
    """A four-step finance plan keeps task-named nodes, empty roles and the requested web tools."""
    planner = TaskExecutionPlanner(tool_registry=_registry())
    planner_json = """
        {
          "mode": "team",
          "strategy": "dag",
          "nodes": [
            {
              "node_id": "collect_official_sources",
              "task": "Collect MGM and Galaxy official financial disclosures",
              "requested_tools": ["web_search", "web_fetch"],
              "required_evidence": ["tool_result", "url"]
            },
            {
              "node_id": "extract_financial_metrics",
              "task": "Extract comparable financial metrics from collected sources",
              "depends_on": ["collect_official_sources"],
              "requested_tools": ["web_fetch"],
              "required_evidence": ["output"]
            },
            {
              "node_id": "validate_metrics",
              "task": "Validate metric units, periods, and source consistency",
              "depends_on": ["extract_financial_metrics"],
              "required_evidence": ["output"]
            },
            {
              "node_id": "generate_chart_report",
              "task": "Generate a Markdown comparison table and chart-ready data without claiming an image or file artifact",
              "depends_on": ["validate_metrics"],
              "requested_tools": [],
              "required_evidence": ["output"]
            }
          ]
        }
        """
    plan = planner.from_json(planner_json)

    assert plan.is_team
    assert plan.graph is not None
    nodes = plan.graph.nodes
    node_ids = [node.node_id for node in nodes]
    assert node_ids == [
        "collect_official_sources",
        "extract_financial_metrics",
        "validate_metrics",
        "generate_chart_report",
    ]
    assert all(node.agent.role == "" for node in nodes)
    # No node may be named after a role.
    role_like_names = {"researcher", "writer", "reviewer", "analyst"}
    assert role_like_names.isdisjoint(node_ids)

    first_node, last_node = nodes[0], nodes[-1]
    assert first_node.allowed_tool_names == ["web_search", "web_fetch"]
    assert last_node.allowed_tool_names == []
    report_task = last_node.task.lower()
    assert "markdown" in report_task
    assert "without claiming an image or file artifact" in report_task
|