feat(tasks): add skill-templated task graph execution
This commit is contained in:
@ -8,7 +8,8 @@ import pytest
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.coordinator import AgentDescriptor, DelegationEnvelope, ExecutionGraph, ExecutionNode
|
||||
from beaver.coordinator import AgentDescriptor, DelegationEnvelope, ExecutionGraph, ExecutionNode, NodeRunResult
|
||||
from beaver.coordinator.execution.scheduler import TeamGraphScheduler
|
||||
from beaver.coordinator.local import LocalAgentRunner
|
||||
from beaver.engine import AgentLoop, EngineLoader
|
||||
from beaver.engine.context import SkillContext
|
||||
@ -90,6 +91,15 @@ class PerRunSnapshotMemoryService(MemoryService):
|
||||
return MemorySnapshot(memory_block="# Memory\n\nshared-snapshot", user_block=None)
|
||||
|
||||
|
||||
class CapturingRunner:
|
||||
def __init__(self) -> None:
|
||||
self.envelopes: list[DelegationEnvelope] = []
|
||||
|
||||
async def run(self, envelope: DelegationEnvelope, **kwargs) -> NodeRunResult:
|
||||
self.envelopes.append(envelope)
|
||||
return NodeRunResult(node_id=envelope.node_id or "node", success=True, output_text="done")
|
||||
|
||||
|
||||
def _bundle(provider: RecordingProvider) -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
@ -161,10 +171,72 @@ def test_local_agent_runner_uses_shared_loop_and_records_parent_task(tmp_path: P
|
||||
child_session = loaded.session_manager.get_session(result.session_id) # type: ignore[union-attr,arg-type]
|
||||
|
||||
assert result.success is True
|
||||
assert result.completion_status == "succeeded"
|
||||
assert result.evidence_gaps == []
|
||||
assert run_record.task_id == "task-parent"
|
||||
assert child_session["parent_session_id"] == "session-root"
|
||||
|
||||
|
||||
def test_node_without_required_tool_result_is_partial(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("collected narrative")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
task="collect",
|
||||
node_id="collect",
|
||||
required_evidence=["tool_result"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is False
|
||||
assert result.completion_status == "partial"
|
||||
assert result.evidence_gaps == ["missing required evidence: tool_result"]
|
||||
|
||||
|
||||
def test_node_with_required_nonempty_output_succeeds(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("verified output")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="verify"),
|
||||
task="verify",
|
||||
node_id="verify",
|
||||
required_evidence=["output"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is True
|
||||
assert result.completion_status == "succeeded"
|
||||
assert result.evidence_gaps == []
|
||||
|
||||
|
||||
def test_unknown_evidence_requirement_makes_node_partial(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("output")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="verify"),
|
||||
task="verify",
|
||||
node_id="verify",
|
||||
required_evidence=["unknown_type"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is False
|
||||
assert result.completion_status == "partial"
|
||||
assert result.evidence_gaps == ["unsupported evidence requirement: unknown_type"]
|
||||
|
||||
|
||||
def test_team_node_preserves_evidence_when_finish_reason_is_not_stop(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("partial evidence", finish_reason="max_tool_iterations")])
|
||||
@ -277,6 +349,108 @@ def test_team_sequence_passes_prior_outputs(tmp_path: Path) -> None:
|
||||
assert "Dependency first output:\nfirst output" in providers["second"].calls[0][0]["content"]
|
||||
|
||||
|
||||
def test_partial_node_allows_downstream_by_default(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
"collect": RecordingProvider([_response("partial source notes")]),
|
||||
"extract": RecordingProvider([_response("extracted metrics")]),
|
||||
}
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
required_evidence=["tool_result"],
|
||||
),
|
||||
ExecutionNode("extract", "extract", AgentDescriptor(name="extract")),
|
||||
],
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
TeamService(loop).run_team(
|
||||
graph,
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.node_results[0].completion_status == "partial"
|
||||
assert result.node_results[1].completion_status == "succeeded"
|
||||
assert "Dependency collect output:\npartial source notes" in providers["extract"].calls[0][0]["content"]
|
||||
|
||||
|
||||
def test_partial_node_blocks_downstream_when_configured(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
"collect": RecordingProvider([_response("partial source notes")]),
|
||||
"extract": RecordingProvider([_response("must not run")]),
|
||||
}
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
required_evidence=["tool_result"],
|
||||
block_downstream_on_partial=True,
|
||||
),
|
||||
ExecutionNode("extract", "extract", AgentDescriptor(name="extract")),
|
||||
],
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
TeamService(loop).run_team(
|
||||
graph,
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.node_results[0].completion_status == "partial"
|
||||
assert result.node_results[1].completion_status == "blocked"
|
||||
assert providers["extract"].calls == []
|
||||
|
||||
|
||||
def test_scheduler_copies_task_two_contract_fields_to_envelope() -> None:
|
||||
runner = CapturingRunner()
|
||||
node = ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
input_contract={"query": "str"},
|
||||
output_contract={"sources": "list"},
|
||||
required_evidence=["tool_result"],
|
||||
evidence_contract={"entities": ["MGM"]},
|
||||
validation_rules=["official_sources_only"],
|
||||
required_for_completion=False,
|
||||
block_downstream_on_partial=True,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
TeamGraphScheduler(runner).run( # type: ignore[arg-type]
|
||||
ExecutionGraph(strategy="sequence", nodes=[node]),
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
)
|
||||
)
|
||||
|
||||
envelope = runner.envelopes[0]
|
||||
assert envelope.input_contract == {"query": "str"}
|
||||
assert envelope.output_contract == {"sources": "list"}
|
||||
assert envelope.required_evidence == ["tool_result"]
|
||||
assert envelope.evidence_contract == {"entities": ["MGM"]}
|
||||
assert envelope.validation_rules == ["official_sources_only"]
|
||||
assert envelope.required_for_completion is False
|
||||
assert envelope.block_downstream_on_partial is True
|
||||
assert envelope.max_tool_iterations == 2
|
||||
|
||||
|
||||
def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
@ -428,9 +602,12 @@ def test_team_dag_blocks_dependents_after_failure(tmp_path: Path) -> None:
|
||||
)
|
||||
)
|
||||
publish = [item for item in result.node_results if item.node_id == "publish"][0]
|
||||
validate = [item for item in result.node_results if item.node_id == "validate"][0]
|
||||
|
||||
assert result.success is False
|
||||
assert validate.completion_status == "failed"
|
||||
assert publish.finish_reason == "blocked"
|
||||
assert publish.completion_status == "blocked"
|
||||
assert publish.run_id is None
|
||||
assert publish.error == "Blocked by failed dependency: validate"
|
||||
assert "failed" not in result.summary.split("Failed nodes:")[0]
|
||||
@ -471,8 +648,10 @@ def test_dag_node_factory_error_blocks_dependents(tmp_path: Path) -> None:
|
||||
|
||||
assert result.success is False
|
||||
assert validate.finish_reason == "error"
|
||||
assert validate.completion_status == "failed"
|
||||
assert validate.error == "validator unavailable"
|
||||
assert publish.finish_reason == "blocked"
|
||||
assert publish.completion_status == "blocked"
|
||||
assert publish.error == "Blocked by failed dependency: validate"
|
||||
|
||||
|
||||
@ -550,6 +729,76 @@ def test_graph_structure_errors_still_raise(tmp_path: Path) -> None:
|
||||
asyncio.run(TeamService(loop).run_team(cyclic, parent_task_id=None, parent_session_id="session-root"))
|
||||
|
||||
|
||||
def test_execution_node_contract_defaults_preserve_legacy_scope_behavior() -> None:
|
||||
node = ExecutionNode("collect", "Collect sources", AgentDescriptor(name="collect"))
|
||||
|
||||
assert node.input_contract == {}
|
||||
assert node.output_contract == {}
|
||||
assert node.allowed_tool_names is None
|
||||
assert node.required_evidence == []
|
||||
assert node.evidence_contract == {}
|
||||
assert node.validation_rules == []
|
||||
assert node.required_for_completion is True
|
||||
assert node.block_downstream_on_partial is False
|
||||
assert node.max_tool_iterations is None
|
||||
|
||||
|
||||
def test_execution_node_keeps_explicit_empty_tool_scope_distinct_from_unspecified_scope() -> None:
|
||||
unrestricted = ExecutionNode("unrestricted", "Collect", AgentDescriptor(name="unrestricted"))
|
||||
tool_free = ExecutionNode(
|
||||
"tool_free",
|
||||
"Synthesize",
|
||||
AgentDescriptor(name="tool_free"),
|
||||
allowed_tool_names=[],
|
||||
)
|
||||
|
||||
assert unrestricted.allowed_tool_names is None
|
||||
assert tool_free.allowed_tool_names == []
|
||||
|
||||
|
||||
def test_delegation_envelope_and_node_result_preserve_new_contract_metadata() -> None:
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
parent_run_id="run-root",
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
task="Collect sources",
|
||||
allowed_tool_names=["web_search"],
|
||||
required_evidence=["url"],
|
||||
evidence_contract={"entities": ["MGM", "Galaxy"]},
|
||||
validation_rules=["official_sources_only"],
|
||||
required_for_completion=True,
|
||||
block_downstream_on_partial=True,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
result = NodeRunResult(
|
||||
node_id="collect",
|
||||
success=False,
|
||||
output_text="MGM source only",
|
||||
completion_status="partial",
|
||||
evidence_gaps=["missing required evidence: Galaxy official source"],
|
||||
)
|
||||
|
||||
assert envelope.allowed_tool_names == ["web_search"]
|
||||
assert envelope.evidence_contract == {"entities": ["MGM", "Galaxy"]}
|
||||
assert result.to_dict()["completion_status"] == "partial"
|
||||
assert result.to_dict()["evidence_gaps"] == ["missing required evidence: Galaxy official source"]
|
||||
|
||||
|
||||
def test_graph_rejects_depth_above_configured_limit() -> None:
|
||||
graph = ExecutionGraph(
|
||||
strategy="dag",
|
||||
nodes=[
|
||||
ExecutionNode("a", "A", AgentDescriptor(name="a")),
|
||||
ExecutionNode("b", "B", AgentDescriptor(name="b"), depends_on=["a"]),
|
||||
ExecutionNode("c", "C", AgentDescriptor(name="c"), depends_on=["b"]),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="max depth"):
|
||||
graph.validate(max_depth=2)
|
||||
|
||||
|
||||
def test_team_run_does_not_create_independent_team_task(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
loaded = loop.boot()
|
||||
|
||||
Reference in New Issue
Block a user