feat(tasks): add skill-templated task graph execution

This commit is contained in:
2026-06-23 10:22:58 +08:00
parent 6843d89b2c
commit 53b13e8eac
53 changed files with 4773 additions and 756 deletions

View File

@ -84,11 +84,21 @@ class TeamGraphScheduler:
**kwargs,
) -> list[NodeRunResult]:
results: list[NodeRunResult] = []
nodes_by_id = {node.node_id: node for node in nodes}
for node in nodes:
if any(not item.success for item in results):
results.append(self._blocked(node, results))
blocking = [
item
for item in results
if self._blocks_downstream(item, nodes_by_id[item.node_id])
]
if blocking:
results.append(self._blocked(node, blocking))
continue
dependency_outputs = {item.node_id: item.output_text for item in results if item.success}
dependency_outputs = {
item.node_id: item.output_text
for item in results
if item.completion_status in {"succeeded", "partial"}
}
results.append(await self._run_node(node, dependency_outputs=dependency_outputs, **kwargs))
return results
@ -116,6 +126,7 @@ class TeamGraphScheduler:
**kwargs,
) -> list[NodeRunResult]:
pending = {node.node_id: node for node in nodes}
nodes_by_id = {node.node_id: node for node in nodes}
completed: dict[str, NodeRunResult] = {}
ordered: list[NodeRunResult] = []
@ -123,18 +134,28 @@ class TeamGraphScheduler:
blocked_ids = {
node_id
for node_id, node in pending.items()
if any(dep in completed and not completed[dep].success for dep in node.depends_on)
if any(
dep in completed
and self._blocks_downstream(completed[dep], nodes_by_id[dep])
for dep in node.depends_on
)
}
for node_id in sorted(blocked_ids):
node = pending.pop(node_id)
result = self._blocked(node, list(completed.values()))
completed[node_id] = result
ordered.append(result)
if blocked_ids:
continue
ready = [
node
for node in pending.values()
if all(dep in completed and completed[dep].success for dep in node.depends_on)
if all(
dep in completed
and not self._blocks_downstream(completed[dep], nodes_by_id[dep])
for dep in node.depends_on
)
]
if not ready:
if pending:
@ -196,6 +217,17 @@ class TeamGraphScheduler:
expected_output=node.expected_output,
node_id=node.node_id,
dependency_outputs=dict(dependency_outputs),
input_contract=dict(node.input_contract),
output_contract=dict(node.output_contract),
allowed_tool_names=(
None if node.allowed_tool_names is None else list(node.allowed_tool_names)
),
required_evidence=list(node.required_evidence),
evidence_contract=dict(node.evidence_contract),
validation_rules=list(node.validation_rules),
required_for_completion=node.required_for_completion,
block_downstream_on_partial=node.block_downstream_on_partial,
max_tool_iterations=node.max_tool_iterations,
)
node_provider_bundle = provider_bundle_factory(node) if provider_bundle_factory is not None else provider_bundle
return await self.runner.run(
@ -213,8 +245,17 @@ class TeamGraphScheduler:
output_text="",
finish_reason="error",
error=str(exc),
completion_status="failed",
)
@staticmethod
def _blocks_downstream(result: NodeRunResult, node: ExecutionNode) -> bool:
if result.completion_status in {"failed", "blocked"}:
return True
if result.completion_status == "partial":
return node.block_downstream_on_partial
return not result.success
@staticmethod
def _merge_pinned(parent: list[str], local: list[str]) -> list[str]:
result: list[str] = []
@ -245,6 +286,7 @@ class TeamGraphScheduler:
output_text="",
finish_reason="blocked",
error=f"Blocked by failed dependency: {detail}",
completion_status="blocked",
)
@staticmethod

View File

@ -6,7 +6,7 @@ from uuid import uuid4
from beaver.engine import AgentLoop
from beaver.engine.providers import ProviderBundle
from beaver.tasks.evidence import EvidenceBuilder
from beaver.tasks.evidence import EvidenceBuilder, evaluate_node_evidence
from .models import DelegationEnvelope, NodeRunResult
@ -54,6 +54,8 @@ class LocalAgentRunner:
task_mode=bool(envelope.parent_task_id),
pinned_skill_names=envelope.inherited_pinned_skills,
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
allowed_tool_names=envelope.allowed_tool_names,
max_tool_iterations=envelope.max_tool_iterations,
allow_candidate_generation=allow_candidate_generation,
)
loaded = target_loop.boot()
@ -63,7 +65,23 @@ class LocalAgentRunner:
result.output_text,
result.finish_reason,
)
success = result.finish_reason == "stop"
evidence_gaps = evaluate_node_evidence(
evidence,
envelope.required_evidence,
result.output_text,
)
run_succeeded = result.finish_reason == "stop"
if not run_succeeded:
completion_status = "failed"
elif evidence_gaps:
completion_status = "partial"
else:
completion_status = "succeeded"
success = completion_status == "succeeded"
if completion_status == "partial":
error = "; ".join(evidence_gaps)
else:
error = None if success else (result.output_text or result.finish_reason)
return NodeRunResult(
node_id=envelope.node_id or envelope.agent.name,
success=success,
@ -71,8 +89,10 @@ class LocalAgentRunner:
run_id=result.run_id,
session_id=result.session_id,
finish_reason=result.finish_reason,
error=None if success else (result.output_text or result.finish_reason),
error=error,
evidence=evidence,
completion_status=completion_status,
evidence_gaps=evidence_gaps,
)
@staticmethod

View File

@ -51,6 +51,15 @@ class DelegationEnvelope:
expected_output: str | None = None
node_id: str | None = None
dependency_outputs: dict[str, str] = field(default_factory=dict)
input_contract: dict[str, Any] = field(default_factory=dict)
output_contract: dict[str, Any] = field(default_factory=dict)
allowed_tool_names: list[str] | None = None
required_evidence: list[str] = field(default_factory=list)
evidence_contract: dict[str, Any] = field(default_factory=dict)
validation_rules: list[str] = field(default_factory=list)
required_for_completion: bool = True
block_downstream_on_partial: bool = False
max_tool_iterations: int | None = None
@dataclass(slots=True)
@ -65,6 +74,15 @@ class ExecutionNode:
inherited_pinned_skill_contexts: list["SkillContext"] = field(default_factory=list)
constraints: list[str] = field(default_factory=list)
expected_output: str | None = None
input_contract: dict[str, Any] = field(default_factory=dict)
output_contract: dict[str, Any] = field(default_factory=dict)
allowed_tool_names: list[str] | None = None
required_evidence: list[str] = field(default_factory=list)
evidence_contract: dict[str, Any] = field(default_factory=dict)
validation_rules: list[str] = field(default_factory=list)
required_for_completion: bool = True
block_downstream_on_partial: bool = False
max_tool_iterations: int | None = None
@dataclass(slots=True)
@ -74,7 +92,7 @@ class ExecutionGraph:
strategy: TeamStrategy
nodes: list[ExecutionNode]
def validate(self) -> None:
def validate(self, *, max_depth: int | None = None) -> None:
if self.strategy not in {"sequence", "parallel", "dag"}:
raise NotImplementedError(f"Team strategy {self.strategy!r} is reserved but not implemented in v1")
if not self.nodes:
@ -91,19 +109,25 @@ class ExecutionGraph:
visited: set[str] = set()
deps = {node.node_id: list(node.depends_on) for node in self.nodes}
def visit(node_id: str) -> None:
def visit(node_id: str) -> int:
if node_id in visited:
return
return depths[node_id]
if node_id in visiting:
raise ValueError(f"ExecutionGraph has cyclic or unresolved dependencies involving {node_id!r}")
visiting.add(node_id)
depth = 1
for dep in deps[node_id]:
visit(dep)
depth = max(depth, visit(dep) + 1)
visiting.remove(node_id)
visited.add(node_id)
depths[node_id] = depth
return depth
depths: dict[str, int] = {}
for node_id in node_ids:
visit(node_id)
depth = visit(node_id)
if max_depth is not None and depth > max_depth:
raise ValueError(f"ExecutionGraph exceeds max depth {max_depth}")
@dataclass(slots=True)
@ -118,6 +142,8 @@ class NodeRunResult:
finish_reason: str = "stop"
error: str | None = None
evidence: "RunEvidence | None" = None
completion_status: str = "succeeded"
evidence_gaps: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
@ -129,6 +155,8 @@ class NodeRunResult:
"finish_reason": self.finish_reason,
"error": self.error,
"evidence": self.evidence.to_dict() if self.evidence is not None else None,
"completion_status": self.completion_status,
"evidence_gaps": list(self.evidence_gaps),
}