"""Internal Task execution planner for single-agent vs team execution.""" from __future__ import annotations import asyncio import json from dataclasses import dataclass, field from typing import Any, Literal from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode from beaver.engine.providers import ProviderBundle from .models import TaskRecord, ValidationResult from .skill_resolver import SkillResolutionReport, TaskSkillResolver TaskExecutionMode = Literal["single", "team"] @dataclass(slots=True) class TaskExecutionPlan: mode: TaskExecutionMode reason: str = "" graph: ExecutionGraph | None = None final_synthesis_instruction: str = "" fallback_error: str | None = None skill_resolution_report: list[SkillResolutionReport] = field(default_factory=list) @property def is_team(self) -> bool: return self.mode == "team" and self.graph is not None @classmethod def single(cls, reason: str, *, fallback_error: str | None = None) -> "TaskExecutionPlan": return cls(mode="single", reason=reason, fallback_error=fallback_error) def to_event_payload(self) -> dict[str, Any]: strategy = self.graph.strategy if self.graph is not None else None nodes = self.graph.nodes if self.graph is not None else [] return { "plan_mode": self.mode, "reason": self.reason, "strategy": strategy, "node_ids": [node.node_id for node in nodes], "skill_queries": [ str(node.agent.metadata.get("skill_query") or "") for node in nodes ], "selected_skill_names": [ name for node in nodes for name in node.inherited_pinned_skills ], "ephemeral_guidance_ids": [ item.ephemeral_guidance_id for item in self.skill_resolution_report if item.ephemeral_guidance_id ], "skill_resolution_report": [item.to_dict() for item in self.skill_resolution_report], "fallback_error": self.fallback_error, } class TaskExecutionPlanner: """Plan whether a Task attempt should run through a team first.""" _MAX_NODES = 6 _SUPPORTED_STRATEGIES = {"sequence", "parallel", "dag"} def __init__(self, *, task_skill_resolver: TaskSkillResolver | None = None) -> None: self.task_skill_resolver = task_skill_resolver async def plan( self, *, task: TaskRecord, user_message: str, attempt_index: int, latest_validation: ValidationResult | None = None, provider_bundle: ProviderBundle | None = None, timeout_seconds: float = 30.0, ) -> TaskExecutionPlan: provider = None model = None if provider_bundle is not None: provider = provider_bundle.auxiliary_provider or provider_bundle.main_provider runtime = provider_bundle.auxiliary_runtime or provider_bundle.main_runtime model = getattr(runtime, "model", None) if provider is None: return TaskExecutionPlan.single("planner_provider_unavailable") try: response = await asyncio.wait_for( provider.chat( messages=[ { "role": "system", "content": ( "You choose whether an internal Beaver Task attempt should run as a single " "main-agent pass or use a small sub-agent team first. Return only compact JSON." ), }, { "role": "user", "content": self._prompt( task=task, user_message=user_message, attempt_index=attempt_index, latest_validation=latest_validation, ), }, ], tools=None, model=model, max_tokens=4096, temperature=0.0, ), timeout=timeout_seconds, ) plan = self.from_json(response.content or "") return await self._resolve_plan( plan, task=task, user_message=user_message, attempt_index=attempt_index, provider_bundle=provider_bundle, ) except Exception as exc: detail = str(exc) error = f"{type(exc).__name__}: {detail}" if detail else type(exc).__name__ return TaskExecutionPlan.single("planner_failed", fallback_error=error) async def _resolve_plan( self, plan: TaskExecutionPlan, *, task: TaskRecord, user_message: str, attempt_index: int, provider_bundle: ProviderBundle | None, ) -> TaskExecutionPlan: if not plan.is_team or self.task_skill_resolver is None: return plan if provider_bundle is None: return TaskExecutionPlan.single("planner_fallback_single", fallback_error="task_skill_resolver_provider_unavailable") try: assert plan.graph is not None graph, reports = await self.task_skill_resolver.resolve_graph( plan.graph, task=task, user_message=user_message, attempt_index=attempt_index, provider_bundle=provider_bundle, ) graph.validate() plan.graph = graph plan.skill_resolution_report = reports return plan except Exception as exc: return TaskExecutionPlan.single("planner_fallback_single", fallback_error=f"task_skill_resolver_failed: {exc}") def from_json(self, text: str) -> TaskExecutionPlan: try: payload = self._parse_json_object(text) mode = str(payload.get("mode") or "single").strip().lower() reason = str(payload.get("reason") or "") if mode != "team": return TaskExecutionPlan.single(reason or "planner_selected_single") graph = self._graph_from_payload(payload) graph.validate() return TaskExecutionPlan( mode="team", reason=reason or "planner_selected_team", graph=graph, final_synthesis_instruction=str(payload.get("final_synthesis_instruction") or ""), ) except Exception as exc: return TaskExecutionPlan.single("planner_fallback_single", fallback_error=str(exc)) def _graph_from_payload(self, payload: dict[str, Any]) -> ExecutionGraph: strategy = str(payload.get("strategy") or "sequence").strip().lower() if strategy not in self._SUPPORTED_STRATEGIES: raise ValueError(f"Unsupported team strategy: {strategy}") raw_nodes = payload.get("nodes") if not isinstance(raw_nodes, list) or not raw_nodes: raise ValueError("Team plan requires at least one node") if len(raw_nodes) > self._MAX_NODES: raise ValueError(f"Team plan exceeds max node count {self._MAX_NODES}") nodes: list[ExecutionNode] = [] for index, item in enumerate(raw_nodes, start=1): if not isinstance(item, dict): raise ValueError("Each team node must be an object") agent_payload = item.get("agent") if isinstance(item.get("agent"), dict) else {} skill_query = str(item.get("skill_query") or agent_payload.get("skill_query") or item.get("task") or "").strip() requested_capabilities = _string_list( item.get("required_capabilities") or item.get("capabilities") or agent_payload.get("capabilities") ) requested_tags = _string_list(item.get("tags") or agent_payload.get("tags")) node_id = str(item.get("node_id") or item.get("id") or agent_payload.get("name") or f"node_{index}").strip() task = str(item.get("task") or "").strip() if not node_id or not task: raise ValueError("Each team node requires node_id/id and task") nodes.append( ExecutionNode( node_id=node_id, task=task, agent=AgentDescriptor( name=node_id, role="", system_prompt="", metadata={ "skill_query": skill_query, "required_capabilities": requested_capabilities, "requested_tags": requested_tags, "sub_agent_kind": "generic_skill_worker", }, ), depends_on=[str(dep) for dep in item.get("depends_on") or []], inherited_pinned_skills=[str(name) for name in item.get("pinned_skills") or []], constraints=[str(value) for value in item.get("constraints") or []], expected_output=str(item.get("expected_output") or "") or None, ) ) return ExecutionGraph(strategy=strategy, nodes=nodes) # type: ignore[arg-type] @staticmethod def _prompt( *, task: TaskRecord, user_message: str, attempt_index: int, latest_validation: ValidationResult | None, ) -> str: validation_note = "" if latest_validation is not None: validation_note = ( "\nPrevious validation issues:\n" + json.dumps(latest_validation.to_dict(), ensure_ascii=False) ) return ( "Decide execution mode for this internal Task attempt.\n" "Use mode=team only when independent research, review, implementation slices, or staged checks " "would materially improve the result. Otherwise use mode=single.\n\n" "JSON schema:\n" "{\n" ' "mode": "single" | "team",\n' ' "reason": "short reason",\n' ' "strategy": "sequence" | "parallel" | "dag",\n' ' "nodes": [{"node_id": "api_review", "task": "...", "skill_query": "API contract review", ' '"required_capabilities": ["schema compatibility"], "depends_on": []}],\n' ' "final_synthesis_instruction": "how the main agent should synthesize team output"\n' "}\n\n" f"Task goal:\n{task.goal}\n\n" f"Current user request:\n{user_message}\n\n" f"Attempt index: {attempt_index}\n" f"{validation_note}" ) @staticmethod def _parse_json_object(text: str) -> dict[str, Any]: cleaned = text.strip() if cleaned.startswith("```"): cleaned = cleaned.strip("`") if cleaned.lower().startswith("json"): cleaned = cleaned[4:].strip() start = cleaned.find("{") end = cleaned.rfind("}") if start >= 0 and end >= start: cleaned = cleaned[start : end + 1] payload = json.loads(cleaned) if not isinstance(payload, dict): raise ValueError("planner response must be a JSON object") return payload def _optional_str(value: Any) -> str | None: if value in (None, ""): return None text = str(value).strip() return text or None def _string_list(value: Any) -> list[str]: if not isinstance(value, list): if isinstance(value, str): value = [item.strip() for item in value.split(",")] else: return [] result: list[str] = [] for item in value: text = str(item).strip() if text and text not in result: result.append(text) return result