feat(tasks): add skill-templated task graph execution
This commit is contained in:
@ -84,11 +84,21 @@ class TeamGraphScheduler:
|
||||
**kwargs,
|
||||
) -> list[NodeRunResult]:
|
||||
results: list[NodeRunResult] = []
|
||||
nodes_by_id = {node.node_id: node for node in nodes}
|
||||
for node in nodes:
|
||||
if any(not item.success for item in results):
|
||||
results.append(self._blocked(node, results))
|
||||
blocking = [
|
||||
item
|
||||
for item in results
|
||||
if self._blocks_downstream(item, nodes_by_id[item.node_id])
|
||||
]
|
||||
if blocking:
|
||||
results.append(self._blocked(node, blocking))
|
||||
continue
|
||||
dependency_outputs = {item.node_id: item.output_text for item in results if item.success}
|
||||
dependency_outputs = {
|
||||
item.node_id: item.output_text
|
||||
for item in results
|
||||
if item.completion_status in {"succeeded", "partial"}
|
||||
}
|
||||
results.append(await self._run_node(node, dependency_outputs=dependency_outputs, **kwargs))
|
||||
return results
|
||||
|
||||
@ -116,6 +126,7 @@ class TeamGraphScheduler:
|
||||
**kwargs,
|
||||
) -> list[NodeRunResult]:
|
||||
pending = {node.node_id: node for node in nodes}
|
||||
nodes_by_id = {node.node_id: node for node in nodes}
|
||||
completed: dict[str, NodeRunResult] = {}
|
||||
ordered: list[NodeRunResult] = []
|
||||
|
||||
@ -123,18 +134,28 @@ class TeamGraphScheduler:
|
||||
blocked_ids = {
|
||||
node_id
|
||||
for node_id, node in pending.items()
|
||||
if any(dep in completed and not completed[dep].success for dep in node.depends_on)
|
||||
if any(
|
||||
dep in completed
|
||||
and self._blocks_downstream(completed[dep], nodes_by_id[dep])
|
||||
for dep in node.depends_on
|
||||
)
|
||||
}
|
||||
for node_id in sorted(blocked_ids):
|
||||
node = pending.pop(node_id)
|
||||
result = self._blocked(node, list(completed.values()))
|
||||
completed[node_id] = result
|
||||
ordered.append(result)
|
||||
if blocked_ids:
|
||||
continue
|
||||
|
||||
ready = [
|
||||
node
|
||||
for node in pending.values()
|
||||
if all(dep in completed and completed[dep].success for dep in node.depends_on)
|
||||
if all(
|
||||
dep in completed
|
||||
and not self._blocks_downstream(completed[dep], nodes_by_id[dep])
|
||||
for dep in node.depends_on
|
||||
)
|
||||
]
|
||||
if not ready:
|
||||
if pending:
|
||||
@ -196,6 +217,17 @@ class TeamGraphScheduler:
|
||||
expected_output=node.expected_output,
|
||||
node_id=node.node_id,
|
||||
dependency_outputs=dict(dependency_outputs),
|
||||
input_contract=dict(node.input_contract),
|
||||
output_contract=dict(node.output_contract),
|
||||
allowed_tool_names=(
|
||||
None if node.allowed_tool_names is None else list(node.allowed_tool_names)
|
||||
),
|
||||
required_evidence=list(node.required_evidence),
|
||||
evidence_contract=dict(node.evidence_contract),
|
||||
validation_rules=list(node.validation_rules),
|
||||
required_for_completion=node.required_for_completion,
|
||||
block_downstream_on_partial=node.block_downstream_on_partial,
|
||||
max_tool_iterations=node.max_tool_iterations,
|
||||
)
|
||||
node_provider_bundle = provider_bundle_factory(node) if provider_bundle_factory is not None else provider_bundle
|
||||
return await self.runner.run(
|
||||
@ -213,8 +245,17 @@ class TeamGraphScheduler:
|
||||
output_text="",
|
||||
finish_reason="error",
|
||||
error=str(exc),
|
||||
completion_status="failed",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _blocks_downstream(result: NodeRunResult, node: ExecutionNode) -> bool:
|
||||
if result.completion_status in {"failed", "blocked"}:
|
||||
return True
|
||||
if result.completion_status == "partial":
|
||||
return node.block_downstream_on_partial
|
||||
return not result.success
|
||||
|
||||
@staticmethod
|
||||
def _merge_pinned(parent: list[str], local: list[str]) -> list[str]:
|
||||
result: list[str] = []
|
||||
@ -245,6 +286,7 @@ class TeamGraphScheduler:
|
||||
output_text="",
|
||||
finish_reason="blocked",
|
||||
error=f"Blocked by failed dependency: {detail}",
|
||||
completion_status="blocked",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -6,7 +6,7 @@ from uuid import uuid4
|
||||
|
||||
from beaver.engine import AgentLoop
|
||||
from beaver.engine.providers import ProviderBundle
|
||||
from beaver.tasks.evidence import EvidenceBuilder
|
||||
from beaver.tasks.evidence import EvidenceBuilder, evaluate_node_evidence
|
||||
|
||||
from .models import DelegationEnvelope, NodeRunResult
|
||||
|
||||
@ -54,6 +54,8 @@ class LocalAgentRunner:
|
||||
task_mode=bool(envelope.parent_task_id),
|
||||
pinned_skill_names=envelope.inherited_pinned_skills,
|
||||
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
||||
allowed_tool_names=envelope.allowed_tool_names,
|
||||
max_tool_iterations=envelope.max_tool_iterations,
|
||||
allow_candidate_generation=allow_candidate_generation,
|
||||
)
|
||||
loaded = target_loop.boot()
|
||||
@ -63,7 +65,23 @@ class LocalAgentRunner:
|
||||
result.output_text,
|
||||
result.finish_reason,
|
||||
)
|
||||
success = result.finish_reason == "stop"
|
||||
evidence_gaps = evaluate_node_evidence(
|
||||
evidence,
|
||||
envelope.required_evidence,
|
||||
result.output_text,
|
||||
)
|
||||
run_succeeded = result.finish_reason == "stop"
|
||||
if not run_succeeded:
|
||||
completion_status = "failed"
|
||||
elif evidence_gaps:
|
||||
completion_status = "partial"
|
||||
else:
|
||||
completion_status = "succeeded"
|
||||
success = completion_status == "succeeded"
|
||||
if completion_status == "partial":
|
||||
error = "; ".join(evidence_gaps)
|
||||
else:
|
||||
error = None if success else (result.output_text or result.finish_reason)
|
||||
return NodeRunResult(
|
||||
node_id=envelope.node_id or envelope.agent.name,
|
||||
success=success,
|
||||
@ -71,8 +89,10 @@ class LocalAgentRunner:
|
||||
run_id=result.run_id,
|
||||
session_id=result.session_id,
|
||||
finish_reason=result.finish_reason,
|
||||
error=None if success else (result.output_text or result.finish_reason),
|
||||
error=error,
|
||||
evidence=evidence,
|
||||
completion_status=completion_status,
|
||||
evidence_gaps=evidence_gaps,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -51,6 +51,15 @@ class DelegationEnvelope:
|
||||
expected_output: str | None = None
|
||||
node_id: str | None = None
|
||||
dependency_outputs: dict[str, str] = field(default_factory=dict)
|
||||
input_contract: dict[str, Any] = field(default_factory=dict)
|
||||
output_contract: dict[str, Any] = field(default_factory=dict)
|
||||
allowed_tool_names: list[str] | None = None
|
||||
required_evidence: list[str] = field(default_factory=list)
|
||||
evidence_contract: dict[str, Any] = field(default_factory=dict)
|
||||
validation_rules: list[str] = field(default_factory=list)
|
||||
required_for_completion: bool = True
|
||||
block_downstream_on_partial: bool = False
|
||||
max_tool_iterations: int | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -65,6 +74,15 @@ class ExecutionNode:
|
||||
inherited_pinned_skill_contexts: list["SkillContext"] = field(default_factory=list)
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
expected_output: str | None = None
|
||||
input_contract: dict[str, Any] = field(default_factory=dict)
|
||||
output_contract: dict[str, Any] = field(default_factory=dict)
|
||||
allowed_tool_names: list[str] | None = None
|
||||
required_evidence: list[str] = field(default_factory=list)
|
||||
evidence_contract: dict[str, Any] = field(default_factory=dict)
|
||||
validation_rules: list[str] = field(default_factory=list)
|
||||
required_for_completion: bool = True
|
||||
block_downstream_on_partial: bool = False
|
||||
max_tool_iterations: int | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -74,7 +92,7 @@ class ExecutionGraph:
|
||||
strategy: TeamStrategy
|
||||
nodes: list[ExecutionNode]
|
||||
|
||||
def validate(self) -> None:
|
||||
def validate(self, *, max_depth: int | None = None) -> None:
|
||||
if self.strategy not in {"sequence", "parallel", "dag"}:
|
||||
raise NotImplementedError(f"Team strategy {self.strategy!r} is reserved but not implemented in v1")
|
||||
if not self.nodes:
|
||||
@ -91,19 +109,25 @@ class ExecutionGraph:
|
||||
visited: set[str] = set()
|
||||
deps = {node.node_id: list(node.depends_on) for node in self.nodes}
|
||||
|
||||
def visit(node_id: str) -> None:
|
||||
def visit(node_id: str) -> int:
|
||||
if node_id in visited:
|
||||
return
|
||||
return depths[node_id]
|
||||
if node_id in visiting:
|
||||
raise ValueError(f"ExecutionGraph has cyclic or unresolved dependencies involving {node_id!r}")
|
||||
visiting.add(node_id)
|
||||
depth = 1
|
||||
for dep in deps[node_id]:
|
||||
visit(dep)
|
||||
depth = max(depth, visit(dep) + 1)
|
||||
visiting.remove(node_id)
|
||||
visited.add(node_id)
|
||||
depths[node_id] = depth
|
||||
return depth
|
||||
|
||||
depths: dict[str, int] = {}
|
||||
for node_id in node_ids:
|
||||
visit(node_id)
|
||||
depth = visit(node_id)
|
||||
if max_depth is not None and depth > max_depth:
|
||||
raise ValueError(f"ExecutionGraph exceeds max depth {max_depth}")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -118,6 +142,8 @@ class NodeRunResult:
|
||||
finish_reason: str = "stop"
|
||||
error: str | None = None
|
||||
evidence: "RunEvidence | None" = None
|
||||
completion_status: str = "succeeded"
|
||||
evidence_gaps: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -129,6 +155,8 @@ class NodeRunResult:
|
||||
"finish_reason": self.finish_reason,
|
||||
"error": self.error,
|
||||
"evidence": self.evidence.to_dict() if self.evidence is not None else None,
|
||||
"completion_status": self.completion_status,
|
||||
"evidence_gaps": list(self.evidence_gaps),
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user