Files
beaver_project/app-instance/backend/beaver/services/team_service.py

91 lines
3.8 KiB
Python

"""Application service for coordinated team runs."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
from beaver.coordinator import ExecutionGraph, ExecutionNode, LocalAgentRunner, TeamGraphScheduler, TeamRunResult
from beaver.engine import AgentLoop
from beaver.engine.providers import ProviderBundle
if TYPE_CHECKING:
from beaver.engine.context import SkillContext
class TeamService:
    """Internal service for Beaver-native multi-agent execution.

    Owns a ``LocalAgentRunner`` and a ``TeamGraphScheduler`` built on the supplied
    agent loop, validates the parent task before a team run, and records the
    resulting run ids on that task afterwards.
    """

    def __init__(self, loop: AgentLoop, *, max_parallel_team_nodes: int = 3) -> None:
        """Build the runner and scheduler on top of *loop*.

        Args:
            loop: Agent loop passed to ``LocalAgentRunner``. Its ``boot()`` result
                is also used to look up ``task_service`` and ``run_memory_store``.
            max_parallel_team_nodes: Forwarded unchanged to ``TeamGraphScheduler``.
        """
        self.loop = loop
        self.runner = LocalAgentRunner(loop)
        self.scheduler = TeamGraphScheduler(self.runner, max_parallel_team_nodes=max_parallel_team_nodes)

    async def run_team(
        self,
        graph: ExecutionGraph,
        *,
        parent_task_id: str | None,
        parent_session_id: str,
        parent_run_id: str | None = None,
        provider_bundle: ProviderBundle | None = None,
        provider_bundle_factory: Callable[[ExecutionNode], ProviderBundle | None] | None = None,
        inherited_pinned_skills: list[str] | None = None,
        inherited_pinned_skill_contexts: list["SkillContext"] | None = None,
        allow_candidate_generation: bool = False,
    ) -> TeamRunResult:
        """Run a team graph inside the parent task context.

        The parent task is validated first, then every keyword argument is
        forwarded unchanged to ``self.scheduler.run``. Once the scheduler
        returns, the produced run ids are attached to the parent task.

        Returns:
            The ``TeamRunResult`` returned by the scheduler.

        Raises:
            RuntimeError: ``parent_task_id`` is given but no ``task_service`` is
                available on the booted loop.
            ValueError: ``parent_task_id`` is unknown or belongs to a different
                session than ``parent_session_id``.
        """
        self._validate_parent_task(parent_task_id, parent_session_id)
        result = await self.scheduler.run(
            graph,
            parent_task_id=parent_task_id,
            parent_session_id=parent_session_id,
            parent_run_id=parent_run_id,
            provider_bundle=provider_bundle,
            provider_bundle_factory=provider_bundle_factory,
            inherited_pinned_skills=inherited_pinned_skills,
            inherited_pinned_skill_contexts=inherited_pinned_skill_contexts,
            allow_candidate_generation=allow_candidate_generation,
        )
        self._attach_runs_to_parent_task(result)
        return result

    def run(self, task: str) -> str:
        """Compatibility shim for old callers that only expected a string.

        Nothing is executed; the returned message points callers at ``run_team()``.
        """
        return f"team service requires run_team() for coordinated execution: {task}"

    def _validate_parent_task(self, parent_task_id: str | None, parent_session_id: str) -> None:
        """Check that the parent task exists and belongs to the given session.

        A falsy ``parent_task_id`` means there is no parent task to check, so the
        method returns without booting the loop.

        Raises:
            RuntimeError: The booted loop exposes no ``task_service``.
            ValueError: The task is unknown, or its ``session_id`` differs from
                ``parent_session_id``.
        """
        if not parent_task_id:
            return
        loaded = self.loop.boot()
        task_service = getattr(loaded, "task_service", None)
        if task_service is None:
            raise RuntimeError("TeamService requires task_service when parent_task_id is provided")
        task = task_service.get_task(parent_task_id)
        if task is None:
            raise ValueError(f"Unknown parent_task_id: {parent_task_id}")
        if task.session_id != parent_session_id:
            raise ValueError(
                f"parent_task_id {parent_task_id!r} belongs to session {task.session_id!r}, "
                f"not {parent_session_id!r}"
            )

    def _attach_runs_to_parent_task(self, result: TeamRunResult) -> None:
        """Append every run id in *result* to its task, with the skills it activated.

        Best-effort: returns silently when the result carries no task id or run
        ids, when no ``task_service`` is available, or when the task is unknown.
        A run with no record in ``run_memory_store`` (or when there is no store)
        is appended with an empty skill list.
        """
        if not result.task_id or not result.run_ids:
            return
        loaded = self.loop.boot()
        task_service = getattr(loaded, "task_service", None)
        if task_service is None or task_service.get_task(result.task_id) is None:
            return
        run_store = getattr(loaded, "run_memory_store", None)

        # List the store once and index the records we care about, instead of
        # re-listing and re-scanning every record for each run id.
        records_by_run_id: dict = {}
        if run_store is not None:
            wanted = set(result.run_ids)
            for record in run_store.list_runs():
                if record.run_id in wanted:
                    # setdefault keeps the first record per run id, matching the
                    # first-match-wins behaviour of a linear scan.
                    records_by_run_id.setdefault(record.run_id, record)

        for run_id in result.run_ids:
            record = records_by_run_id.get(run_id)
            # Build a fresh list per call so callers never share one list object.
            skill_names: list[str] = (
                [] if record is None else [receipt.skill_name for receipt in record.activated_skills]
            )
            task_service.append_run(result.task_id, run_id, skill_names=skill_names)