feat(team): run parallel nodes with isolated loops

This commit is contained in:
2026-05-22 11:39:17 +08:00
parent c53e221117
commit 4022db8887
4 changed files with 91 additions and 6 deletions

View File

@ -18,8 +18,9 @@ if TYPE_CHECKING:
class TeamGraphScheduler: class TeamGraphScheduler:
"""Execute sequence, parallel, and DAG team graphs.""" """Execute sequence, parallel, and DAG team graphs."""
def __init__(self, runner: LocalAgentRunner) -> None: def __init__(self, runner: LocalAgentRunner, *, max_parallel_team_nodes: int = 3) -> None:
self.runner = runner self.runner = runner
self.max_parallel_team_nodes = max(1, int(max_parallel_team_nodes))
async def run( async def run(
self, self,
@ -96,7 +97,18 @@ class TeamGraphScheduler:
nodes: list[ExecutionNode], nodes: list[ExecutionNode],
**kwargs, **kwargs,
) -> list[NodeRunResult]: ) -> list[NodeRunResult]:
return list(await asyncio.gather(*(self._run_node(node, dependency_outputs={}, **kwargs) for node in nodes))) semaphore = asyncio.Semaphore(self.max_parallel_team_nodes)
async def run_one(node: ExecutionNode) -> NodeRunResult:
async with semaphore:
return await self._run_node(
node,
dependency_outputs={},
execution_mode="isolated_loop",
**kwargs,
)
return list(await asyncio.gather(*(run_one(node) for node in nodes)))
async def _run_dag( async def _run_dag(
self, self,
@ -164,6 +176,7 @@ class TeamGraphScheduler:
inherited_pinned_skill_contexts: list["SkillContext"], inherited_pinned_skill_contexts: list["SkillContext"],
allow_candidate_generation: bool, allow_candidate_generation: bool,
dependency_outputs: dict[str, str], dependency_outputs: dict[str, str],
execution_mode: str = "shared_loop",
) -> NodeRunResult: ) -> NodeRunResult:
try: try:
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills) pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
@ -189,6 +202,7 @@ class TeamGraphScheduler:
envelope, envelope,
provider_bundle=node_provider_bundle, provider_bundle=node_provider_bundle,
allow_candidate_generation=allow_candidate_generation, allow_candidate_generation=allow_candidate_generation,
execution_mode=execution_mode,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise

View File

@ -23,6 +23,7 @@ class LocalAgentRunner:
*, *,
provider_bundle: ProviderBundle | None = None, provider_bundle: ProviderBundle | None = None,
allow_candidate_generation: bool = False, allow_candidate_generation: bool = False,
execution_mode: str = "shared_loop",
) -> NodeRunResult: ) -> NodeRunResult:
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name): if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
raise ValueError( raise ValueError(
@ -30,7 +31,14 @@ class LocalAgentRunner:
"build a node-specific provider bundle instead." "build a node-specific provider bundle instead."
) )
child_session_id = self._child_session_id(envelope) child_session_id = self._child_session_id(envelope)
runner = self.loop.submit_direct if self.loop.is_running else self.loop.process_direct target_loop = self.loop
if execution_mode == "isolated_loop":
target_loop = AgentLoop(profile=self.loop.profile, loader=self.loop.loader)
runner = (
target_loop.process_direct
if execution_mode == "isolated_loop"
else (self.loop.submit_direct if self.loop.is_running else self.loop.process_direct)
)
result = await runner( result = await runner(
envelope.task, envelope.task,
session_id=child_session_id, session_id=child_session_id,
@ -48,7 +56,7 @@ class LocalAgentRunner:
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts, pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
allow_candidate_generation=allow_candidate_generation, allow_candidate_generation=allow_candidate_generation,
) )
loaded = self.loop.boot() loaded = target_loop.boot()
evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence( evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence(
result.session_id, result.session_id,
result.run_id, result.run_id,

View File

@ -16,10 +16,10 @@ if TYPE_CHECKING:
class TeamService: class TeamService:
"""Internal service for Beaver-native multi-agent execution.""" """Internal service for Beaver-native multi-agent execution."""
def __init__(self, loop: AgentLoop) -> None: def __init__(self, loop: AgentLoop, *, max_parallel_team_nodes: int = 3) -> None:
self.loop = loop self.loop = loop
self.runner = LocalAgentRunner(loop) self.runner = LocalAgentRunner(loop)
self.scheduler = TeamGraphScheduler(self.runner) self.scheduler = TeamGraphScheduler(self.runner, max_parallel_team_nodes=max_parallel_team_nodes)
async def run_team( async def run_team(
self, self,

View File

@ -45,6 +45,18 @@ class RecordingProvider(LLMProvider):
return "stub-model" return "stub-model"
class BlockingProvider(RecordingProvider):
    """Recording provider whose ``chat`` call parks until a test releases it.

    ``started`` is set as soon as ``chat`` is entered, so a test can observe
    that the node reached its provider; the call then waits on ``release``
    before delegating to the recording implementation.
    """

    def __init__(self, content: str, started: asyncio.Event, release: asyncio.Event) -> None:
        """Queue a single canned response and remember the two gate events."""
        super().__init__([_response(content)])
        self.started = started
        self.release = release

    async def chat(self, *args, **kwargs) -> LLMResponse:
        """Signal entry, wait for the release gate, then answer as the parent would."""
        # Announce that this provider has been reached before blocking.
        self.started.set()
        # Hold the call open until the test opens the gate.
        await self.release.wait()
        response = await super().chat(*args, **kwargs)
        return response
class StubSkillAssembler: class StubSkillAssembler:
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None: def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
self.activated_skills = list(activated_skills or []) self.activated_skills = list(activated_skills or [])
@ -298,6 +310,57 @@ def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
assert [item.output_text for item in result.node_results] == ["one", "two", "three"] assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
def test_team_parallel_starts_nodes_concurrently_with_isolated_loops(tmp_path: Path) -> None:
    """Check that two parallel nodes are both in flight before either finishes.

    Each node gets a ``BlockingProvider`` that sets its own "started" event and
    then waits on a shared ``release`` event. The test only opens ``release``
    after both "started" events have fired, so it can pass only if the second
    node starts while the first is still blocked.
    """
    loop = _loop(tmp_path)
    graph = ExecutionGraph(
        strategy="parallel",
        nodes=[
            ExecutionNode("one", "task one", AgentDescriptor(name="one")),
            ExecutionNode("two", "task two", AgentDescriptor(name="two")),
        ],
    )

    async def run_case():
        # Build the events inside the running event loop: on Python < 3.10
        # asyncio primitives bind to the loop that is current at construction
        # time, which would not be the loop created by ``asyncio.run``.
        first_started = asyncio.Event()
        second_started = asyncio.Event()
        release = asyncio.Event()
        providers = {
            "one": BlockingProvider("one", first_started, release),
            "two": BlockingProvider("two", second_started, release),
        }

        # Keep the shared loop running in the background while the team runs.
        loop_task = asyncio.create_task(loop.run())
        await asyncio.sleep(0)
        task = asyncio.create_task(
            TeamService(loop).run_team(
                graph,
                parent_task_id=None,
                parent_session_id="session-root",
                parent_run_id="run-root",
                provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
            )
        )
        try:
            # Both nodes must reach their provider before anything is released.
            await asyncio.wait_for(first_started.wait(), timeout=1)
            await asyncio.wait_for(second_started.wait(), timeout=1)
            release.set()
            return await task
        finally:
            # Open the gate unconditionally so no provider stays parked.
            release.set()
            try:
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            finally:
                # Always shut the shared loop down, even if awaiting the
                # cancelled team task raised something unexpected.
                await loop.stop()
                await loop_task

    result = asyncio.run(run_case())

    assert result.success is True
    assert [item.node_id for item in result.node_results] == ["one", "two"]
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None: def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
loop = _loop(tmp_path) loop = _loop(tmp_path)
loaded = loop.boot() loaded = loop.boot()