feat(team): run parallel nodes with isolated loops
This commit is contained in:
@ -18,8 +18,9 @@ if TYPE_CHECKING:
|
||||
class TeamGraphScheduler:
|
||||
"""Execute sequence, parallel, and DAG team graphs."""
|
||||
|
||||
def __init__(self, runner: LocalAgentRunner) -> None:
|
||||
def __init__(self, runner: LocalAgentRunner, *, max_parallel_team_nodes: int = 3) -> None:
|
||||
self.runner = runner
|
||||
self.max_parallel_team_nodes = max(1, int(max_parallel_team_nodes))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@ -96,7 +97,18 @@ class TeamGraphScheduler:
|
||||
nodes: list[ExecutionNode],
|
||||
**kwargs,
|
||||
) -> list[NodeRunResult]:
|
||||
return list(await asyncio.gather(*(self._run_node(node, dependency_outputs={}, **kwargs) for node in nodes)))
|
||||
semaphore = asyncio.Semaphore(self.max_parallel_team_nodes)
|
||||
|
||||
async def run_one(node: ExecutionNode) -> NodeRunResult:
|
||||
async with semaphore:
|
||||
return await self._run_node(
|
||||
node,
|
||||
dependency_outputs={},
|
||||
execution_mode="isolated_loop",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return list(await asyncio.gather(*(run_one(node) for node in nodes)))
|
||||
|
||||
async def _run_dag(
|
||||
self,
|
||||
@ -164,6 +176,7 @@ class TeamGraphScheduler:
|
||||
inherited_pinned_skill_contexts: list["SkillContext"],
|
||||
allow_candidate_generation: bool,
|
||||
dependency_outputs: dict[str, str],
|
||||
execution_mode: str = "shared_loop",
|
||||
) -> NodeRunResult:
|
||||
try:
|
||||
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
|
||||
@ -189,6 +202,7 @@ class TeamGraphScheduler:
|
||||
envelope,
|
||||
provider_bundle=node_provider_bundle,
|
||||
allow_candidate_generation=allow_candidate_generation,
|
||||
execution_mode=execution_mode,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
@ -23,6 +23,7 @@ class LocalAgentRunner:
|
||||
*,
|
||||
provider_bundle: ProviderBundle | None = None,
|
||||
allow_candidate_generation: bool = False,
|
||||
execution_mode: str = "shared_loop",
|
||||
) -> NodeRunResult:
|
||||
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
|
||||
raise ValueError(
|
||||
@ -30,7 +31,14 @@ class LocalAgentRunner:
|
||||
"build a node-specific provider bundle instead."
|
||||
)
|
||||
child_session_id = self._child_session_id(envelope)
|
||||
runner = self.loop.submit_direct if self.loop.is_running else self.loop.process_direct
|
||||
target_loop = self.loop
|
||||
if execution_mode == "isolated_loop":
|
||||
target_loop = AgentLoop(profile=self.loop.profile, loader=self.loop.loader)
|
||||
runner = (
|
||||
target_loop.process_direct
|
||||
if execution_mode == "isolated_loop"
|
||||
else (self.loop.submit_direct if self.loop.is_running else self.loop.process_direct)
|
||||
)
|
||||
result = await runner(
|
||||
envelope.task,
|
||||
session_id=child_session_id,
|
||||
@ -48,7 +56,7 @@ class LocalAgentRunner:
|
||||
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
||||
allow_candidate_generation=allow_candidate_generation,
|
||||
)
|
||||
loaded = self.loop.boot()
|
||||
loaded = target_loop.boot()
|
||||
evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence(
|
||||
result.session_id,
|
||||
result.run_id,
|
||||
|
||||
Reference in New Issue
Block a user