# Source file: beaver_project/app-instance/backend/beaver/tasks/skill_resolver.py
# File-viewer header retained as a comment: 418 lines, 16 KiB, Python.

"""Resolve Task team nodes to pinned skills for generic sub-agents."""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field, replace
from typing import Any

from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
from beaver.engine.context import SkillContext
from beaver.engine.providers import ProviderBundle
from beaver.skills.assembler.embedding_retriever import SkillEmbeddingRetriever
from beaver.skills.catalog.loader import SkillsLoader
from beaver.skills.catalog.utils import strip_frontmatter
from beaver.skills.drafts import DraftService
from beaver.skills.learning import EphemeralGuidanceSynthesizer
from beaver.tasks.models import TaskRecord


@dataclass(slots=True)
class SkillResolutionReport:
node_id: str
skill_query: str
required_capabilities: list[str] = field(default_factory=list)
selected_skill_names: list[str] = field(default_factory=list)
ephemeral_guidance_id: str | None = None
ephemeral_guidance_name: str | None = None
ephemeral_used: bool = False
requested_skill_name: str | None = None
exact_binding_used: bool = False
warnings: list[str] = field(default_factory=list)
reason: str = ""
def to_dict(self) -> dict[str, Any]:
return {
"node_id": self.node_id,
"skill_query": self.skill_query,
"required_capabilities": list(self.required_capabilities),
"selected_skill_names": list(self.selected_skill_names),
"ephemeral_guidance_id": self.ephemeral_guidance_id,
"ephemeral_guidance_name": self.ephemeral_guidance_name,
"ephemeral_used": self.ephemeral_used,
"requested_skill_name": self.requested_skill_name,
"exact_binding_used": self.exact_binding_used,
"warnings": list(self.warnings),
"reason": self.reason,
}
class TaskSkillResolver:
    """Pins published skills or one-run guidance onto generic team nodes.

    ``resolve_node`` tries, in order:

    1. an exact ``use_skill`` binding taken from the node's agent metadata;
    2. a skill-free pass-through for summary-only nodes;
    3. published skills picked by the retriever plus a model selection call;
    4. ephemeral guidance from ``missing_skill_synthesizer``.
    """

    # Cap handed to the retriever as both ``top_k`` and ``fallback_top_k``.
    _CANDIDATE_LIMIT = 8
    # ``max_tokens`` for the skill-selection chat call.
    _SELECTION_MAX_TOKENS = 2048
    # Vocabulary used by ``_is_summary_only_node``.
    _SUMMARY_NODE_IDS = frozenset({"summarize", "summary", "synthesis"})
    _SUMMARY_QUERIES = frozenset({"summarization", "summary", "synthesis", "final synthesis"})
    _TEXT_ONLY_CAPABILITIES = frozenset({"text generation", "summarization", "summary", "synthesis"})
    _SUMMARY_TASK_KEYWORDS = ("summary", "summarize", "synthesis", "compile")

    def __init__(
        self,
        *,
        skills_loader: SkillsLoader,
        draft_service: DraftService,
        retriever: SkillEmbeddingRetriever | None = None,
        missing_skill_synthesizer: EphemeralGuidanceSynthesizer | None = None,
    ) -> None:
        """Store collaborators, creating default ones where none are given.

        Args:
            skills_loader: Source of skill records, published skill content
                and selection candidates.
            draft_service: Stored on the instance; not used by the methods
                defined in this class.
            retriever: Candidate retriever; a falsy value is replaced by a
                new ``SkillEmbeddingRetriever()``.
            missing_skill_synthesizer: Synthesizer for one-run guidance; a
                falsy value is replaced by a new
                ``EphemeralGuidanceSynthesizer()``.
        """
        self.skills_loader = skills_loader
        self.draft_service = draft_service
        self.retriever = retriever or SkillEmbeddingRetriever()
        self.missing_skill_synthesizer = missing_skill_synthesizer or EphemeralGuidanceSynthesizer()

    async def resolve_graph(
        self,
        graph: ExecutionGraph,
        *,
        task: TaskRecord,
        user_message: str,
        attempt_index: int,
        provider_bundle: ProviderBundle,
    ) -> tuple[ExecutionGraph, list[SkillResolutionReport]]:
        """Resolve every node of ``graph`` one after another.

        Returns:
            A new ``ExecutionGraph`` with the same strategy and the resolved
            nodes in their original order, plus one report per node.
        """
        resolved_nodes: list[ExecutionNode] = []
        reports: list[SkillResolutionReport] = []
        for node in graph.nodes:
            resolved, report = await self.resolve_node(
                node,
                task=task,
                user_message=user_message,
                attempt_index=attempt_index,
                provider_bundle=provider_bundle,
            )
            resolved_nodes.append(resolved)
            reports.append(report)
        return ExecutionGraph(strategy=graph.strategy, nodes=resolved_nodes), reports

    async def resolve_node(
        self,
        node: ExecutionNode,
        *,
        task: TaskRecord,
        user_message: str,
        attempt_index: int,
        provider_bundle: ProviderBundle,
    ) -> tuple[ExecutionNode, SkillResolutionReport]:
        """Resolve one node into a generic skill worker plus a report.

        Reads ``use_skill``, ``skill_query`` and ``required_capabilities``
        from ``node.agent.metadata``. The skill query falls back to the node
        task and then to the node id.

        Returns:
            The rewritten node and the ``SkillResolutionReport`` describing
            which resolution path was taken.
        """
        agent_metadata = node.agent.metadata
        use_skill = str(agent_metadata.get("use_skill") or "").strip()
        skill_query = str(agent_metadata.get("skill_query") or node.task or node.node_id).strip()
        required_capabilities = self._normalize_capabilities(agent_metadata.get("required_capabilities"))
        warnings: list[str] = []

        # Keys written onto the resolved node's metadata on every path.
        base_metadata: dict[str, Any] = {
            **agent_metadata,
            "skill_query": skill_query,
            "required_capabilities": required_capabilities,
        }

        # Path 1: explicit binding to a named published skill.
        if use_skill:
            exact_context = self._load_exact_skill_context(use_skill)
            if exact_context is not None:
                resolved = self._generic_node(
                    node,
                    pinned_skill_names=_merge_names(node.inherited_pinned_skills, [use_skill]),
                    pinned_skill_contexts=_merge_skill_contexts(
                        node.inherited_pinned_skill_contexts,
                        [exact_context],
                    ),
                    metadata={
                        **base_metadata,
                        "use_skill": use_skill,
                        "selected_skill_names": [use_skill],
                        "ephemeral_skill_names": [],
                        "exact_binding_used": True,
                    },
                )
                return resolved, SkillResolutionReport(
                    node_id=node.node_id,
                    skill_query=skill_query,
                    required_capabilities=required_capabilities,
                    selected_skill_names=[use_skill],
                    requested_skill_name=use_skill,
                    exact_binding_used=True,
                    reason="exact use_skill binding",
                )
            # The binding could not be loaded; record it and keep resolving.
            warnings.append(f"use_skill unresolved: {use_skill}")

        # Path 2: summary-only nodes get no pinned skills or contexts at all.
        if self._is_summary_only_node(node, skill_query=skill_query, required_capabilities=required_capabilities):
            resolved = self._generic_node(
                node,
                pinned_skill_names=[],
                pinned_skill_contexts=[],
                metadata={
                    **base_metadata,
                    "selected_skill_names": [],
                    "ephemeral_skill_names": [],
                    "exact_binding_used": False,
                    "summary_uses_dependency_outputs_only": True,
                },
            )
            return resolved, SkillResolutionReport(
                node_id=node.node_id,
                skill_query=skill_query,
                required_capabilities=required_capabilities,
                selected_skill_names=[],
                ephemeral_used=False,
                requested_skill_name=use_skill or None,
                exact_binding_used=False,
                warnings=warnings,
                reason="summary node uses dependency outputs directly",
            )

        # Path 3: pick from published skills using every non-empty hint.
        query_parts = [
            skill_query,
            node.task,
            " ".join(required_capabilities),
            task.goal,
            user_message,
        ]
        selected = await self._select_published_skills(
            query="\n".join(part for part in query_parts if part),
            provider_bundle=provider_bundle,
        )
        if selected:
            resolved = self._generic_node(
                node,
                pinned_skill_names=_merge_names(node.inherited_pinned_skills, selected),
                metadata={
                    **base_metadata,
                    "selected_skill_names": selected,
                    "ephemeral_skill_names": [],
                    "exact_binding_used": False,
                },
            )
            return resolved, SkillResolutionReport(
                node_id=node.node_id,
                skill_query=skill_query,
                required_capabilities=required_capabilities,
                selected_skill_names=selected,
                ephemeral_used=False,
                requested_skill_name=use_skill or None,
                exact_binding_used=False,
                warnings=warnings,
                reason="matched published skill",
            )

        # Path 4: nothing matched, so synthesize guidance for this run only.
        missing = await self.missing_skill_synthesizer.synthesize(
            task=task,
            user_message=user_message,
            attempt_index=attempt_index,
            node_id=node.node_id,
            node_task=node.task,
            skill_query=skill_query,
            required_capabilities=required_capabilities,
            provider_bundle=provider_bundle,
        )
        resolved = self._generic_node(
            node,
            pinned_skill_names=list(node.inherited_pinned_skills),
            pinned_skill_contexts=[*node.inherited_pinned_skill_contexts, missing.skill_context],
            metadata={
                **base_metadata,
                "selected_skill_names": [],
                "ephemeral_guidance_id": missing.guidance_id,
                "ephemeral_guidance_name": missing.guidance_name,
                "ephemeral_skill_names": [missing.skill_context.name],
                "exact_binding_used": False,
            },
        )
        return resolved, SkillResolutionReport(
            node_id=node.node_id,
            skill_query=skill_query,
            required_capabilities=required_capabilities,
            ephemeral_guidance_id=missing.guidance_id,
            ephemeral_guidance_name=missing.guidance_name,
            ephemeral_used=True,
            requested_skill_name=use_skill or None,
            exact_binding_used=False,
            warnings=warnings,
            reason="generated ephemeral guidance for missing sub-agent capability",
        )

    @staticmethod
    def _normalize_capabilities(raw: Any) -> list[str]:
        """Return ``raw`` as a list of stripped, non-empty strings.

        A missing or otherwise falsy value yields ``[]``. A bare string is
        treated as one capability rather than iterated character by
        character. Any other value is iterated and each item stringified.
        """
        if not raw:
            return []
        items = [raw] if isinstance(raw, str) else raw
        stripped = (str(item).strip() for item in items)
        return [text for text in stripped if text]

    def _load_exact_skill_context(self, name: str) -> SkillContext | None:
        """Build the context for the published skill ``name``.

        Returns:
            ``None`` when the loader has no record for ``name`` or when the
            published content is empty once frontmatter is stripped;
            otherwise a ``SkillContext`` with activation reason
            ``"explicit_node_binding"``.
        """
        record = self.skills_loader.get_skill_record(name)
        raw_content = self.skills_loader.load_published_skill(name)
        content = strip_frontmatter(raw_content).strip() if raw_content else ""
        if record is None or not content:
            return None
        return SkillContext(
            name=name,
            content=content,
            version=record.version,
            content_hash=record.content_hash or "",
            activation_reason="explicit_node_binding",
            tool_hints=list(record.tool_hints),
        )

    async def _select_published_skills(self, *, query: str, provider_bundle: ProviderBundle) -> list[str]:
        """Choose published skill names that match ``query``.

        Candidates from the loader are narrowed by the retriever, then a
        chat call is asked to return a JSON array of names. Only names that
        belong to the narrowed candidates are kept, de-duplicated in the
        order the model returned them.

        Returns:
            The selected names; ``[]`` when there are no candidates or when
            the chat call or its parsing fails (the failure is logged).
        """
        candidates = self.skills_loader.build_selection_candidates()
        if not candidates:
            return []

        embedding = provider_bundle.embedding_runtime
        has_embedding = embedding is not None
        candidates = await self.retriever.retrieve(
            query=query,
            candidates=candidates,
            top_k=self._CANDIDATE_LIMIT,
            api_key=embedding.api_key if has_embedding else None,
            api_base=embedding.api_base if has_embedding else None,
            model=embedding.model if has_embedding else None,
            extra_headers=embedding.extra_headers if has_embedding else None,
            timeout_seconds=embedding.request_timeout_seconds if has_embedding else None,
            fallback_top_k=self._CANDIDATE_LIMIT,
        )
        if not candidates:
            return []

        provider = provider_bundle.auxiliary_provider or provider_bundle.main_provider
        runtime = provider_bundle.auxiliary_runtime or provider_bundle.main_runtime
        model = getattr(runtime, "model", None)
        candidate_names = {item["name"] for item in candidates}

        try:
            response = await provider.chat(
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "Select published Beaver skills for one generic sub-agent node. "
                            "Return only a JSON array of skill names. Do not invent names. "
                            "If none of the candidates directly match the required guidance, return []."
                        ),
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Node skill query:\n{query}\n\n"
                            f"Candidate skills:\n{self._render_candidates(candidates)}\n\n"
                            "Return only JSON, for example: [\"skill-a\"] or []"
                        ),
                    },
                ],
                tools=None,
                model=model,
                max_tokens=self._SELECTION_MAX_TOKENS,
                temperature=0,
            )
            parsed = self._parse_names(response.content or "")
        except Exception:
            # Selection stays best-effort (the caller moves on to ephemeral
            # guidance), but the failure is logged instead of disappearing.
            logging.getLogger(__name__).warning(
                "Published skill selection failed; treating it as no match",
                exc_info=True,
            )
            return []

        # Drop invented names and duplicates while keeping the model's order.
        return list(dict.fromkeys(name for name in parsed if name in candidate_names))

    @staticmethod
    def _is_summary_only_node(
        node: ExecutionNode,
        *,
        skill_query: str,
        required_capabilities: list[str],
    ) -> bool:
        """Tell whether ``node`` only summarizes its dependencies' outputs.

        All three checks must hold, each case-insensitive:

        * the node id or the skill query is a known summary label;
        * every required capability (if any) is a text-only one;
        * the task text mentions summary, summarize, synthesis or compile.
        """
        cls = TaskSkillResolver
        node_id = node.node_id.strip().lower()
        query = skill_query.strip().lower()
        capabilities = {item.strip().lower() for item in required_capabilities}
        # ``resolve_node`` treats the task as possibly empty, so guard here too.
        task_text = str(node.task or "").strip().lower()

        summary_identity = node_id in cls._SUMMARY_NODE_IDS or query in cls._SUMMARY_QUERIES
        # An empty capability set is a subset of anything, so it passes.
        text_only_capabilities = capabilities.issubset(cls._TEXT_ONLY_CAPABILITIES)
        dependency_summary_task = any(keyword in task_text for keyword in cls._SUMMARY_TASK_KEYWORDS)
        return summary_identity and text_only_capabilities and dependency_summary_task

    @staticmethod
    def _generic_node(
        node: ExecutionNode,
        *,
        pinned_skill_names: list[str],
        metadata: dict[str, Any],
        pinned_skill_contexts: list[Any] | None = None,
    ) -> ExecutionNode:
        """Return a copy of ``node`` rewritten as a generic skill worker.

        The copy gets a new ``AgentDescriptor`` named after the node id with
        an empty role and system prompt, and ``metadata`` extended with
        ``sub_agent_kind="generic_skill_worker"``.

        Args:
            node: Node to copy.
            pinned_skill_names: Value for ``inherited_pinned_skills``.
            metadata: Metadata for the new agent descriptor.
            pinned_skill_contexts: Value for
                ``inherited_pinned_skill_contexts``; ``None`` keeps a copy
                of the contexts already on ``node``.
        """
        if pinned_skill_contexts is None:
            pinned_skill_contexts = node.inherited_pinned_skill_contexts
        return replace(
            node,
            agent=AgentDescriptor(
                name=node.node_id,
                role="",
                system_prompt="",
                metadata={
                    **metadata,
                    "sub_agent_kind": "generic_skill_worker",
                },
            ),
            inherited_pinned_skills=pinned_skill_names,
            inherited_pinned_skill_contexts=list(pinned_skill_contexts),
        )

    @staticmethod
    def _render_candidates(candidates: list[dict[str, str]]) -> str:
        """Render candidates as one ``- name: description`` line each."""
        return "\n".join(f"- {item['name']}: {item['description']}" for item in candidates)

    @staticmethod
    def _parse_names(content: str) -> list[str]:
        """Extract skill names from a model reply.

        Accepts a JSON array, or a JSON object holding the array under
        ``skills``, ``selected_skills`` or ``selected``. A surrounding
        Markdown code fence and a leading ``json`` tag are removed first.
        If the text still is not valid JSON, the outermost ``[...]`` span is
        tried, so an array wrapped in prose is still understood.

        Returns:
            Stripped, non-empty string forms of the array items, or ``[]``
            when no array can be recovered.
        """
        cleaned = content.strip()
        if cleaned.startswith("```"):
            lines = cleaned.splitlines()
            # Only unwrap a complete fence: opening line, body, closing line.
            if len(lines) >= 3 and lines[-1].startswith("```"):
                cleaned = "\n".join(lines[1:-1]).strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[4:].strip()

        try:
            payload = json.loads(cleaned)
        except json.JSONDecodeError:
            # Fallback for replies such as 'Selected: ["skill-a"]'.
            start = cleaned.find("[")
            end = cleaned.rfind("]")
            if start == -1 or end <= start:
                return []
            try:
                payload = json.loads(cleaned[start : end + 1])
            except json.JSONDecodeError:
                return []

        if isinstance(payload, dict):
            for key in ("skills", "selected_skills", "selected"):
                value = payload.get(key)
                if isinstance(value, list):
                    payload = value
                    break
        if not isinstance(payload, list):
            return []
        stripped = (str(item).strip() for item in payload)
        return [name for name in stripped if name]
def _merge_names(parent: list[str], selected: list[str]) -> list[str]:
    """Return ``parent`` followed by ``selected`` without blanks or repeats.

    Falsy names are dropped and only the first occurrence of each name is
    kept, so names from ``parent`` come before new ones from ``selected``.
    De-duplication is hash-based (``dict.fromkeys`` keeps insertion order),
    which avoids rescanning the result list for every name.
    """
    return list(dict.fromkeys(name for name in [*parent, *selected] if name))
def _merge_skill_contexts(parent: list[SkillContext], selected: list[SkillContext]) -> list[SkillContext]:
    """Return ``parent`` then ``selected`` contexts, unique by ``name``.

    When two contexts share a name the one met first is kept, so a context
    from ``parent`` wins over a same-named one from ``selected``.
    """
    first_by_name: dict[str, SkillContext] = {}
    for candidate in [*parent, *selected]:
        # setdefault leaves an existing entry alone, keeping the first one.
        first_by_name.setdefault(candidate.name, candidate)
    return list(first_by_name.values())