Files
beaver_project/app-instance/backend/beaver/tasks/planner.py

613 lines
24 KiB
Python

"""Internal Task execution planner for single-agent vs team execution."""
from __future__ import annotations
import asyncio
import json
import os
from dataclasses import dataclass, field
from typing import Any, Literal
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
from beaver.engine.context import SkillContext
from beaver.engine.providers import ProviderBundle
from beaver.tools.registry import ToolRegistry
from .models import TaskRecord
from .skill_resolver import SkillResolutionReport, TaskSkillResolver
# The two execution modes a planned Task attempt can take.
TaskExecutionMode = Literal["single", "team"]

# Temporary name-based denylist until high-risk tool approval is implemented.
# Keep this policy centralized so planner behavior cannot drift by call site.
# Entries are lower-case; the planner lower-cases requested names before lookup.
HIGH_RISK_PLANNER_TOOL_NAMES = frozenset(
    (
        "delete_file",
        "execute_command",
        "external_send",
        "send_email",
        "terminal",
        "write_file",
    )
)
def _agent_team_enabled() -> bool:
    """Return whether agent-team planning is enabled.

    Reads ``BEAVER_AGENT_TEAM_ENABLED`` (default ``"1"``). Only the values
    ``0``/``false``/``no``/``off`` (case-insensitive, surrounding whitespace
    ignored) disable it; any other value, including an empty string, enables it.
    """
    raw_setting = os.getenv("BEAVER_AGENT_TEAM_ENABLED", "1")
    disabling_values = {"0", "false", "no", "off"}
    normalized = raw_setting.strip().lower()
    return normalized not in disabling_values
@dataclass(slots=True)
class TaskExecutionPlan:
mode: TaskExecutionMode
reason: str = ""
graph: ExecutionGraph | None = None
final_synthesis_instruction: str = ""
fallback_error: str | None = None
skill_resolution_report: list[SkillResolutionReport] = field(default_factory=list)
planner_adaptation: dict[str, Any] = field(default_factory=dict)
@property
def is_team(self) -> bool:
return self.mode == "team" and self.graph is not None
@classmethod
def single(
cls,
reason: str,
*,
fallback_error: str | None = None,
planner_adaptation: dict[str, Any] | None = None,
) -> "TaskExecutionPlan":
return cls(
mode="single",
reason=reason,
fallback_error=fallback_error,
planner_adaptation=dict(planner_adaptation or {}),
)
def to_event_payload(self) -> dict[str, Any]:
strategy = self.graph.strategy if self.graph is not None else None
nodes = self.graph.nodes if self.graph is not None else []
return {
"plan_mode": self.mode,
"reason": self.reason,
"strategy": strategy,
"node_ids": [node.node_id for node in nodes],
"skill_queries": [
str(node.agent.metadata.get("skill_query") or "")
for node in nodes
],
"selected_skill_names": [
name
for node in nodes
for name in node.inherited_pinned_skills
],
"ephemeral_guidance_ids": [
item.ephemeral_guidance_id
for item in self.skill_resolution_report
if item.ephemeral_guidance_id
],
"skill_resolution_report": [item.to_dict() for item in self.skill_resolution_report],
"planner_adaptation": dict(self.planner_adaptation),
"fallback_error": self.fallback_error,
}
class TaskExecutionPlanner:
"""Plan whether a Task attempt should run through a team first."""
_MAX_NODES = 6
_MAX_DEPTH = 4
_SUPPORTED_STRATEGIES = {"sequence", "parallel", "dag"}
_ALLOWED_NODE_FIELDS = {
"node_id",
"task",
"use_skill",
"skill_query",
"depends_on",
"input_contract",
"output_contract",
"requested_tools",
"required_evidence",
"evidence_contract",
"validation_rules",
"required_for_completion",
"block_downstream_on_partial",
"max_tool_iterations",
"constraints",
}
def __init__(
self,
*,
task_skill_resolver: TaskSkillResolver | None = None,
tool_registry: ToolRegistry | None = None,
) -> None:
self.task_skill_resolver = task_skill_resolver
self.tool_registry = tool_registry
async def plan(
self,
*,
task: TaskRecord,
user_message: str,
attempt_index: int,
provider_bundle: ProviderBundle | None = None,
timeout_seconds: float = 30.0,
skill_summaries: list[str] | None = None,
tool_hints: list[str] | None = None,
activated_skills: list[SkillContext] | None = None,
) -> TaskExecutionPlan:
if not _agent_team_enabled():
return TaskExecutionPlan.single("planner_disabled_by_environment")
if not self._needs_team_planning(task=task, user_message=user_message):
return TaskExecutionPlan.single("planner_skipped_simple_task")
provider = None
model = None
if provider_bundle is not None:
provider = provider_bundle.auxiliary_provider or provider_bundle.main_provider
runtime = provider_bundle.auxiliary_runtime or provider_bundle.main_runtime
model = getattr(runtime, "model", None)
if provider is None:
return TaskExecutionPlan.single("planner_provider_unavailable")
selected_template, base_adaptation = self._select_team_template(activated_skills or [])
try:
response = await asyncio.wait_for(
provider.chat(
messages=[
{
"role": "system",
"content": (
"You choose whether an internal Beaver Task attempt should run as a single "
"main-agent pass or use a small sub-agent team first. Return only compact JSON."
),
},
{
"role": "user",
"content": self._prompt(
task=task,
user_message=user_message,
attempt_index=attempt_index,
skill_summaries=skill_summaries or [],
tool_hints=tool_hints or [],
activated_skills=activated_skills or [],
selected_template=selected_template,
),
},
],
tools=None,
model=model,
max_tokens=4096,
temperature=0.0,
),
timeout=timeout_seconds,
)
try:
plan = self._from_json_or_raise(response.content or "")
except Exception as first_error:
repair_response = await asyncio.wait_for(
provider.chat(
messages=[
{
"role": "system",
"content": "Repair invalid Beaver task planner JSON. Return only one compact JSON object.",
},
{
"role": "user",
"content": (
"Repair the invalid planner JSON using the task-only schema from the original "
f"request. Validation error: {first_error}\nInvalid output:\n{response.content or ''}"
),
},
],
tools=None,
model=model,
max_tokens=4096,
temperature=0.0,
),
timeout=timeout_seconds,
)
try:
plan = self._from_json_or_raise(repair_response.content or "")
except Exception as repair_error:
return TaskExecutionPlan.single(
"planner_fallback_single",
fallback_error=f"initial validation: {first_error}; repair validation: {repair_error}",
planner_adaptation=base_adaptation,
)
self._merge_adaptation(plan, base_adaptation)
return await self._resolve_plan(
plan,
task=task,
user_message=user_message,
attempt_index=attempt_index,
provider_bundle=provider_bundle,
)
except Exception as exc:
detail = str(exc)
error = f"{type(exc).__name__}: {detail}" if detail else type(exc).__name__
return TaskExecutionPlan.single("planner_failed", fallback_error=error)
async def _resolve_plan(
self,
plan: TaskExecutionPlan,
*,
task: TaskRecord,
user_message: str,
attempt_index: int,
provider_bundle: ProviderBundle | None,
) -> TaskExecutionPlan:
if not plan.is_team or self.task_skill_resolver is None:
return plan
if provider_bundle is None:
return TaskExecutionPlan.single("planner_fallback_single", fallback_error="task_skill_resolver_provider_unavailable")
try:
assert plan.graph is not None
graph, reports = await self.task_skill_resolver.resolve_graph(
plan.graph,
task=task,
user_message=user_message,
attempt_index=attempt_index,
provider_bundle=provider_bundle,
)
graph.validate()
plan.graph = graph
plan.skill_resolution_report = reports
self._merge_skill_resolution_adaptation(plan, reports)
return plan
except Exception as exc:
return TaskExecutionPlan.single("planner_fallback_single", fallback_error=f"task_skill_resolver_failed: {exc}")
@staticmethod
def _needs_team_planning(*, task: TaskRecord, user_message: str) -> bool:
text = " ".join(
part
for part in (
task.goal,
task.description,
user_message,
)
if part
).lower()
if not text.strip():
return False
complex_markers = (
"agent team",
"sub-agent",
"multi-agent",
"parallel",
"dag",
"workflow",
"review",
"research",
"compare",
"comparison",
"architecture",
"refactor",
"multi-file",
"end-to-end",
"并行",
"团队",
"多智能体",
"子代理",
"工作流",
"评审",
"审查",
"调研",
"研究",
"对比",
"架构",
"重构",
"多文件",
"端到端",
)
return any(marker in text for marker in complex_markers)
def from_json(self, text: str) -> TaskExecutionPlan:
try:
return self._from_json_or_raise(text)
except Exception as exc:
return TaskExecutionPlan.single("planner_fallback_single", fallback_error=str(exc))
def _from_json_or_raise(self, text: str) -> TaskExecutionPlan:
payload = self._parse_json_object(text)
mode = str(payload.get("mode") or "single").strip().lower()
reason = str(payload.get("reason") or "")
adaptation = self._adaptation_from_payload(payload)
if mode != "team":
return TaskExecutionPlan.single(
reason or "planner_selected_single",
planner_adaptation=adaptation,
)
graph = self._graph_from_payload(payload, adaptation=adaptation)
graph.validate(max_depth=self._MAX_DEPTH)
return TaskExecutionPlan(
mode="team",
reason=reason or "planner_selected_team",
graph=graph,
final_synthesis_instruction=str(payload.get("final_synthesis_instruction") or ""),
planner_adaptation=adaptation,
)
def _graph_from_payload(
self,
payload: dict[str, Any],
*,
adaptation: dict[str, Any],
) -> ExecutionGraph:
strategy = str(payload.get("strategy") or "sequence").strip().lower()
if strategy not in self._SUPPORTED_STRATEGIES:
raise ValueError(f"Unsupported team strategy: {strategy}")
raw_nodes = payload.get("nodes")
if not isinstance(raw_nodes, list) or not raw_nodes:
raise ValueError("Team plan requires at least one node")
if len(raw_nodes) > self._MAX_NODES:
raise ValueError(f"Team plan exceeds max node count {self._MAX_NODES}")
nodes: list[ExecutionNode] = []
for index, item in enumerate(raw_nodes, start=1):
if not isinstance(item, dict):
raise ValueError("Each team node must be an object")
unsupported = sorted(set(item) - self._ALLOWED_NODE_FIELDS)
if unsupported:
raise ValueError(f"Unsupported team node field(s): {', '.join(unsupported)}")
node_id = str(item.get("node_id") or f"node_{index}").strip()
task = str(item.get("task") or "").strip()
if not node_id or not task:
raise ValueError("Each team node requires node_id and task")
allowed_tool_names = self._resolve_requested_tools(
item.get("requested_tools"),
warnings=adaptation["warnings"],
)
use_skill = _optional_str(item.get("use_skill"))
skill_query = _optional_str(item.get("skill_query")) or task
if use_skill is not None or "skill_query" in item:
adaptation.setdefault("node_skill_bindings", []).append(
{
"node_id": node_id,
"use_skill": use_skill,
"skill_query": skill_query,
}
)
nodes.append(
ExecutionNode(
node_id=node_id,
task=task,
agent=AgentDescriptor(
name=node_id,
role="",
system_prompt="",
metadata={
"use_skill": use_skill,
"skill_query": skill_query,
"required_capabilities": [],
"requested_tags": [],
"sub_agent_kind": "generic_skill_worker",
},
),
depends_on=[str(dep) for dep in item.get("depends_on") or []],
constraints=[str(value) for value in item.get("constraints") or []],
input_contract=_dict_value(item.get("input_contract")),
output_contract=_dict_value(item.get("output_contract")),
allowed_tool_names=allowed_tool_names,
required_evidence=_string_list(item.get("required_evidence")),
evidence_contract=_dict_value(item.get("evidence_contract")),
validation_rules=_string_list(item.get("validation_rules")),
required_for_completion=bool(item.get("required_for_completion", True)),
block_downstream_on_partial=bool(item.get("block_downstream_on_partial", False)),
max_tool_iterations=_optional_int(item.get("max_tool_iterations")),
)
)
return ExecutionGraph(strategy=strategy, nodes=nodes) # type: ignore[arg-type]
def _resolve_requested_tools(self, value: Any, *, warnings: list[str]) -> list[str] | None:
if value is None:
return None
result: list[str] = []
for name in _string_list(value):
if name.lower() in HIGH_RISK_PLANNER_TOOL_NAMES:
_append_unique(warnings, f"requires_high_risk_review: {name}")
continue
if self.tool_registry is None or self.tool_registry.get(name) is None:
_append_unique(warnings, f"unknown tool removed: {name}")
continue
result.append(name)
return result
@staticmethod
def _adaptation_from_payload(payload: dict[str, Any]) -> dict[str, Any]:
raw = payload.get("adaptation")
adaptation = dict(raw) if isinstance(raw, dict) else {}
adaptation["warnings"] = _string_list(adaptation.get("warnings"))
return adaptation
@staticmethod
def _select_team_template(
activated_skills: list[SkillContext],
) -> tuple[SkillContext | None, dict[str, Any]]:
candidates = [
skill
for skill in activated_skills
if isinstance(skill.team_template, dict) and isinstance(skill.team_template.get("nodes"), list)
]
selected = candidates[0] if candidates else None
warnings: list[str] = []
for skill in activated_skills:
for warning in skill.team_template_warnings:
_append_unique(warnings, f"{skill.name}: {warning}")
return selected, {
"template_used": False,
"selected_template": selected.name if selected else None,
"selection_reason": (
"first activated skill with a valid team template"
if selected
else "no activated skill has a valid team template"
),
"ignored_templates": [skill.name for skill in candidates[1:]],
"warnings": warnings,
}
@staticmethod
def _merge_adaptation(plan: TaskExecutionPlan, base: dict[str, Any]) -> None:
payload = dict(plan.planner_adaptation)
warnings: list[str] = []
for warning in [*base.get("warnings", []), *payload.get("warnings", [])]:
_append_unique(warnings, str(warning))
merged = {
"template_used": bool(payload.get("template_used", False)),
"selected_template": base.get("selected_template"),
"selection_reason": base.get("selection_reason"),
"ignored_templates": list(base.get("ignored_templates", [])),
"warnings": warnings,
}
if isinstance(payload.get("node_skill_bindings"), list):
merged["node_skill_bindings"] = [dict(item) for item in payload["node_skill_bindings"] if isinstance(item, dict)]
plan.planner_adaptation = merged
@staticmethod
def _merge_skill_resolution_adaptation(
plan: TaskExecutionPlan,
reports: list[SkillResolutionReport],
) -> None:
warnings = plan.planner_adaptation.setdefault("warnings", [])
bindings = plan.planner_adaptation.get("node_skill_bindings")
binding_by_node = {
str(item.get("node_id")): item
for item in bindings or []
if isinstance(item, dict)
}
for report in reports:
for warning in report.warnings:
_append_unique(warnings, warning)
binding = binding_by_node.get(report.node_id)
if binding is not None and report.requested_skill_name and not report.exact_binding_used:
binding["fallback_reason"] = f"use_skill unresolved; {report.reason}"
@staticmethod
def _prompt(
*,
task: TaskRecord,
user_message: str,
attempt_index: int,
skill_summaries: list[str] | None = None,
tool_hints: list[str] | None = None,
activated_skills: list[SkillContext] | None = None,
selected_template: SkillContext | None = None,
) -> str:
history_note = ""
if task.feedback:
history_note = "\nRelevant task history:\n" + json.dumps(task.feedback[-5:], ensure_ascii=False)
skill_note = ""
if skill_summaries:
skill_note = "\nActivated skill summaries:\n" + "\n".join(f"- {item}" for item in skill_summaries)
guidance_note = ""
if activated_skills:
guidance_note = "\nActivated Skill guidance:\n" + "\n".join(
f"[{skill.name}]\n{skill.content}" for skill in activated_skills
)
template_note = ""
if selected_template is not None:
template_note = "\nPrimary Skill team template:\n" + json.dumps(
{
"skill_name": selected_template.name,
"skill_version": selected_template.version,
"template": selected_template.team_template,
},
ensure_ascii=False,
indent=2,
)
tool_note = ""
if tool_hints:
tool_note = "\nActivated skill tool hints:\n" + "\n".join(f"- {item}" for item in tool_hints)
return (
"Decide execution mode for this internal Task attempt.\n"
"Use mode=team only when independent research, review, implementation slices, or staged checks "
"would materially improve the result. Otherwise use mode=single.\n\n"
"JSON schema:\n"
"{\n"
' "mode": "single" | "team",\n'
' "reason": "short reason",\n'
' "strategy": "sequence" | "parallel" | "dag",\n'
' "nodes": [{"node_id": "collect", "task": "...", "use_skill": "optional exact skill", '
'"skill_query": "optional dynamic skill query", "depends_on": [], '
'"input_contract": {}, "output_contract": {}, "requested_tools": [], '
'"required_evidence": [], "evidence_contract": {}, "validation_rules": [], '
'"required_for_completion": true, "block_downstream_on_partial": false, '
'"max_tool_iterations": 3, "constraints": []}],\n'
' "adaptation": {"template_used": true, "warnings": []},\n'
' "final_synthesis_instruction": "how the main agent should synthesize team output"\n'
"}\n\n"
"Node definitions are task-only. Never output agent or role fields. Use at most one primary "
"Skill template; treat all other activated Skills as guidance.\n\n"
f"Task goal:\n{task.goal}\n\n"
f"Current user request:\n{user_message}\n\n"
f"Attempt index: {attempt_index}\n"
f"{skill_note}"
f"{guidance_note}"
f"{template_note}"
f"{tool_note}"
f"{history_note}"
)
@staticmethod
def _parse_json_object(text: str) -> dict[str, Any]:
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.strip("`")
if cleaned.lower().startswith("json"):
cleaned = cleaned[4:].strip()
start = cleaned.find("{")
end = cleaned.rfind("}")
if start >= 0 and end >= start:
cleaned = cleaned[start : end + 1]
payload = json.loads(cleaned)
if not isinstance(payload, dict):
raise ValueError("planner response must be a JSON object")
return payload
def _optional_str(value: Any) -> str | None:
if value in (None, ""):
return None
text = str(value).strip()
return text or None
def _optional_int(value: Any) -> int | None:
if value in (None, ""):
return None
if isinstance(value, bool):
raise ValueError("max_tool_iterations must be an integer")
result = int(value)
if result < 0:
raise ValueError("max_tool_iterations must be non-negative")
return result
def _dict_value(value: Any) -> dict[str, Any]:
    """Return a shallow copy of ``value`` if it is a dict, otherwise an empty dict."""
    if not isinstance(value, dict):
        return {}
    return dict(value)
def _append_unique(values: list[str], value: str) -> None:
    """Append ``value`` to ``values`` in place unless it is empty or already present."""
    if not value or value in values:
        return
    values.append(value)
def _string_list(value: Any) -> list[str]:
    """Normalize ``value`` into a de-duplicated list of non-empty, stripped strings.

    A string is split on commas. A list is stringified item by item. Any other
    type (including tuples and dicts) yields ``[]``. First-occurrence order is
    preserved.
    """
    if isinstance(value, str):
        candidates = [piece.strip() for piece in value.split(",")]
    elif isinstance(value, list):
        candidates = value
    else:
        return []
    cleaned = (str(candidate).strip() for candidate in candidates)
    # dict.fromkeys keeps insertion order while dropping duplicates.
    return list(dict.fromkeys(text for text in cleaned if text))