"""Resolve Task team nodes to pinned skills for generic sub-agents.""" from __future__ import annotations import json from dataclasses import dataclass, field, replace from typing import Any from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode from beaver.engine.providers import ProviderBundle from beaver.skills.assembler.embedding_retriever import SkillEmbeddingRetriever from beaver.skills.catalog.loader import SkillsLoader from beaver.skills.drafts import DraftService from beaver.skills.learning import EphemeralGuidanceSynthesizer from beaver.tasks.models import TaskRecord @dataclass(slots=True) class SkillResolutionReport: node_id: str skill_query: str required_capabilities: list[str] = field(default_factory=list) selected_skill_names: list[str] = field(default_factory=list) ephemeral_guidance_id: str | None = None ephemeral_guidance_name: str | None = None ephemeral_used: bool = False reason: str = "" def to_dict(self) -> dict[str, Any]: return { "node_id": self.node_id, "skill_query": self.skill_query, "required_capabilities": list(self.required_capabilities), "selected_skill_names": list(self.selected_skill_names), "ephemeral_guidance_id": self.ephemeral_guidance_id, "ephemeral_guidance_name": self.ephemeral_guidance_name, "ephemeral_used": self.ephemeral_used, "reason": self.reason, } class TaskSkillResolver: """Pins published skills or one-run guidance onto generic team nodes.""" def __init__( self, *, skills_loader: SkillsLoader, draft_service: DraftService, retriever: SkillEmbeddingRetriever | None = None, missing_skill_synthesizer: EphemeralGuidanceSynthesizer | None = None, ) -> None: self.skills_loader = skills_loader self.draft_service = draft_service self.retriever = retriever or SkillEmbeddingRetriever() self.missing_skill_synthesizer = missing_skill_synthesizer or EphemeralGuidanceSynthesizer() async def resolve_graph( self, graph: ExecutionGraph, *, task: TaskRecord, user_message: str, attempt_index: int, provider_bundle: ProviderBundle, ) -> tuple[ExecutionGraph, list[SkillResolutionReport]]: resolved_nodes: list[ExecutionNode] = [] reports: list[SkillResolutionReport] = [] for node in graph.nodes: resolved, report = await self.resolve_node( node, task=task, user_message=user_message, attempt_index=attempt_index, provider_bundle=provider_bundle, ) resolved_nodes.append(resolved) reports.append(report) return ExecutionGraph(strategy=graph.strategy, nodes=resolved_nodes), reports async def resolve_node( self, node: ExecutionNode, *, task: TaskRecord, user_message: str, attempt_index: int, provider_bundle: ProviderBundle, ) -> tuple[ExecutionNode, SkillResolutionReport]: skill_query = str(node.agent.metadata.get("skill_query") or node.task or node.node_id).strip() required_capabilities = [ str(item).strip() for item in node.agent.metadata.get("required_capabilities", []) if str(item).strip() ] selected = await self._select_published_skills( query="\n".join( part for part in [ skill_query, node.task, " ".join(required_capabilities), task.goal, user_message, ] if part ), provider_bundle=provider_bundle, ) if selected: pinned = _merge_names(node.inherited_pinned_skills, selected) resolved = self._generic_node( node, pinned_skill_names=pinned, metadata={ **node.agent.metadata, "skill_query": skill_query, "required_capabilities": required_capabilities, "selected_skill_names": selected, "ephemeral_skill_names": [], }, ) return resolved, SkillResolutionReport( node_id=node.node_id, skill_query=skill_query, required_capabilities=required_capabilities, selected_skill_names=selected, ephemeral_used=False, reason="matched published skill", ) missing = await self.missing_skill_synthesizer.synthesize( task=task, user_message=user_message, attempt_index=attempt_index, node_id=node.node_id, node_task=node.task, skill_query=skill_query, required_capabilities=required_capabilities, provider_bundle=provider_bundle, ) resolved = self._generic_node( node, pinned_skill_names=list(node.inherited_pinned_skills), pinned_skill_contexts=[*node.inherited_pinned_skill_contexts, missing.skill_context], metadata={ **node.agent.metadata, "skill_query": skill_query, "required_capabilities": required_capabilities, "selected_skill_names": [], "ephemeral_guidance_id": missing.guidance_id, "ephemeral_guidance_name": missing.guidance_name, "ephemeral_skill_names": [missing.skill_context.name], }, ) return resolved, SkillResolutionReport( node_id=node.node_id, skill_query=skill_query, required_capabilities=required_capabilities, ephemeral_guidance_id=missing.guidance_id, ephemeral_guidance_name=missing.guidance_name, ephemeral_used=True, reason="generated ephemeral guidance for missing sub-agent capability", ) async def _select_published_skills(self, *, query: str, provider_bundle: ProviderBundle) -> list[str]: candidates = self.skills_loader.build_selection_candidates() if not candidates: return [] candidates = await self.retriever.retrieve( query=query, candidates=candidates, top_k=8, api_key=provider_bundle.embedding_runtime.api_key if provider_bundle.embedding_runtime is not None else None, api_base=provider_bundle.embedding_runtime.api_base if provider_bundle.embedding_runtime is not None else None, model=provider_bundle.embedding_runtime.model if provider_bundle.embedding_runtime is not None else None, extra_headers=( provider_bundle.embedding_runtime.extra_headers if provider_bundle.embedding_runtime is not None else None ), timeout_seconds=( provider_bundle.embedding_runtime.request_timeout_seconds if provider_bundle.embedding_runtime is not None else None ), fallback_top_k=8, ) if not candidates: return [] provider = provider_bundle.auxiliary_provider or provider_bundle.main_provider runtime = provider_bundle.auxiliary_runtime or provider_bundle.main_runtime model = getattr(runtime, "model", None) candidate_names = {item["name"] for item in candidates} try: response = await provider.chat( messages=[ { "role": "system", "content": ( "Select published Beaver skills for one generic sub-agent node. " "Return only a JSON array of skill names. Do not invent names. " "If none of the candidates directly match the required guidance, return []." ), }, { "role": "user", "content": ( f"Node skill query:\n{query}\n\n" f"Candidate skills:\n{self._render_candidates(candidates)}\n\n" "Return only JSON, for example: [\"skill-a\"] or []" ), }, ], tools=None, model=model, max_tokens=2048, temperature=0, ) parsed = self._parse_names(response.content or "") except Exception: parsed = [] selected: list[str] = [] for name in parsed: if name in candidate_names and name not in selected: selected.append(name) return selected @staticmethod def _generic_node( node: ExecutionNode, *, pinned_skill_names: list[str], metadata: dict[str, Any], pinned_skill_contexts: list[Any] | None = None, ) -> ExecutionNode: return replace( node, agent=AgentDescriptor( name=node.node_id, role="", system_prompt="", metadata={ **metadata, "sub_agent_kind": "generic_skill_worker", }, ), inherited_pinned_skills=pinned_skill_names, inherited_pinned_skill_contexts=list(pinned_skill_contexts or node.inherited_pinned_skill_contexts), ) @staticmethod def _render_candidates(candidates: list[dict[str, str]]) -> str: return "\n".join(f"- {item['name']}: {item['description']}" for item in candidates) @staticmethod def _parse_names(content: str) -> list[str]: cleaned = content.strip() if cleaned.startswith("```"): lines = cleaned.splitlines() if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"): cleaned = "\n".join(lines[1:-1]).strip() if cleaned.lower().startswith("json"): cleaned = cleaned[4:].strip() try: payload = json.loads(cleaned) except json.JSONDecodeError: return [] if isinstance(payload, dict): for key in ("skills", "selected_skills", "selected"): value = payload.get(key) if isinstance(value, list): payload = value break if not isinstance(payload, list): return [] return [str(item).strip() for item in payload if str(item).strip()] def _merge_names(parent: list[str], selected: list[str]) -> list[str]: result: list[str] = [] for name in [*parent, *selected]: if name and name not in result: result.append(name) return result