feat(team): run parallel nodes with isolated loops
This commit is contained in:
@ -18,8 +18,9 @@ if TYPE_CHECKING:
|
|||||||
class TeamGraphScheduler:
|
class TeamGraphScheduler:
|
||||||
"""Execute sequence, parallel, and DAG team graphs."""
|
"""Execute sequence, parallel, and DAG team graphs."""
|
||||||
|
|
||||||
def __init__(self, runner: LocalAgentRunner, *, max_parallel_team_nodes: int = 3) -> None:
    """Bind the scheduler to a runner and set the parallel fan-out limit.

    Args:
        runner: Agent runner stored on ``self.runner``.
        max_parallel_team_nodes: Requested upper bound on concurrently
            running parallel nodes. The value is coerced with ``int()``
            and raised to at least 1.

    Raises:
        TypeError, ValueError: If ``max_parallel_team_nodes`` cannot be
            converted by ``int()``.
    """
    self.runner = runner
    # Clamp to >= 1: the parallel path builds an asyncio.Semaphore from this
    # value, and a semaphore of 0 would never let any node start.
    self.max_parallel_team_nodes = max(1, int(max_parallel_team_nodes))
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@ -96,7 +97,18 @@ class TeamGraphScheduler:
|
|||||||
nodes: list[ExecutionNode],
|
nodes: list[ExecutionNode],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[NodeRunResult]:
|
) -> list[NodeRunResult]:
|
||||||
return list(await asyncio.gather(*(self._run_node(node, dependency_outputs={}, **kwargs) for node in nodes)))
|
semaphore = asyncio.Semaphore(self.max_parallel_team_nodes)
|
||||||
|
|
||||||
|
async def run_one(node: ExecutionNode) -> NodeRunResult:
|
||||||
|
async with semaphore:
|
||||||
|
return await self._run_node(
|
||||||
|
node,
|
||||||
|
dependency_outputs={},
|
||||||
|
execution_mode="isolated_loop",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(await asyncio.gather(*(run_one(node) for node in nodes)))
|
||||||
|
|
||||||
async def _run_dag(
|
async def _run_dag(
|
||||||
self,
|
self,
|
||||||
@ -164,6 +176,7 @@ class TeamGraphScheduler:
|
|||||||
inherited_pinned_skill_contexts: list["SkillContext"],
|
inherited_pinned_skill_contexts: list["SkillContext"],
|
||||||
allow_candidate_generation: bool,
|
allow_candidate_generation: bool,
|
||||||
dependency_outputs: dict[str, str],
|
dependency_outputs: dict[str, str],
|
||||||
|
execution_mode: str = "shared_loop",
|
||||||
) -> NodeRunResult:
|
) -> NodeRunResult:
|
||||||
try:
|
try:
|
||||||
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
|
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
|
||||||
@ -189,6 +202,7 @@ class TeamGraphScheduler:
|
|||||||
envelope,
|
envelope,
|
||||||
provider_bundle=node_provider_bundle,
|
provider_bundle=node_provider_bundle,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
|
execution_mode=execution_mode,
|
||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class LocalAgentRunner:
|
|||||||
*,
|
*,
|
||||||
provider_bundle: ProviderBundle | None = None,
|
provider_bundle: ProviderBundle | None = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
|
execution_mode: str = "shared_loop",
|
||||||
) -> NodeRunResult:
|
) -> NodeRunResult:
|
||||||
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
|
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -30,7 +31,14 @@ class LocalAgentRunner:
|
|||||||
"build a node-specific provider bundle instead."
|
"build a node-specific provider bundle instead."
|
||||||
)
|
)
|
||||||
child_session_id = self._child_session_id(envelope)
|
child_session_id = self._child_session_id(envelope)
|
||||||
runner = self.loop.submit_direct if self.loop.is_running else self.loop.process_direct
|
target_loop = self.loop
|
||||||
|
if execution_mode == "isolated_loop":
|
||||||
|
target_loop = AgentLoop(profile=self.loop.profile, loader=self.loop.loader)
|
||||||
|
runner = (
|
||||||
|
target_loop.process_direct
|
||||||
|
if execution_mode == "isolated_loop"
|
||||||
|
else (self.loop.submit_direct if self.loop.is_running else self.loop.process_direct)
|
||||||
|
)
|
||||||
result = await runner(
|
result = await runner(
|
||||||
envelope.task,
|
envelope.task,
|
||||||
session_id=child_session_id,
|
session_id=child_session_id,
|
||||||
@ -48,7 +56,7 @@ class LocalAgentRunner:
|
|||||||
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
)
|
)
|
||||||
loaded = self.loop.boot()
|
loaded = target_loop.boot()
|
||||||
evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence(
|
evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence(
|
||||||
result.session_id,
|
result.session_id,
|
||||||
result.run_id,
|
result.run_id,
|
||||||
|
|||||||
@ -16,10 +16,10 @@ if TYPE_CHECKING:
|
|||||||
class TeamService:
|
class TeamService:
|
||||||
"""Internal service for Beaver-native multi-agent execution."""
|
"""Internal service for Beaver-native multi-agent execution."""
|
||||||
|
|
||||||
def __init__(self, loop: AgentLoop, *, max_parallel_team_nodes: int = 3) -> None:
    """Wire a runner and a graph scheduler around the given agent loop.

    Args:
        loop: Agent loop stored on ``self.loop`` and handed to the
            ``LocalAgentRunner``.
        max_parallel_team_nodes: Forwarded unchanged to
            ``TeamGraphScheduler`` as its parallel-node limit.
    """
    self.loop = loop
    self.runner = LocalAgentRunner(loop)
    self.scheduler = TeamGraphScheduler(self.runner, max_parallel_team_nodes=max_parallel_team_nodes)
|
||||||
|
|
||||||
async def run_team(
|
async def run_team(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -45,6 +45,18 @@ class RecordingProvider(LLMProvider):
|
|||||||
return "stub-model"
|
return "stub-model"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingProvider(RecordingProvider):
    """Recording provider whose ``chat`` call parks until it is released.

    ``chat`` sets ``started`` as soon as it is entered and then waits on
    ``release`` before delegating to ``RecordingProvider.chat``. A test can
    therefore observe that a call has begun while holding it open.
    """

    def __init__(self, content: str, started: asyncio.Event, release: asyncio.Event) -> None:
        """Queue one response built from ``content`` and keep both events.

        Args:
            content: Text passed to ``_response`` to build the single queued reply.
            started: Event set when ``chat`` is entered.
            release: Event ``chat`` waits on before producing its reply.
        """
        super().__init__([_response(content)])
        self.started = started
        self.release = release

    async def chat(self, *args, **kwargs) -> LLMResponse:
        """Signal entry, wait for release, then return the parent's reply."""
        self.started.set()
        # Hold the call open until the test decides every node has started.
        await self.release.wait()
        return await super().chat(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class StubSkillAssembler:
|
class StubSkillAssembler:
|
||||||
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
|
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
|
||||||
self.activated_skills = list(activated_skills or [])
|
self.activated_skills = list(activated_skills or [])
|
||||||
@ -298,6 +310,57 @@ def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
|
|||||||
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
|
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_team_parallel_starts_nodes_concurrently_with_isolated_loops(tmp_path: Path) -> None:
    """Both parallel nodes must reach their provider before either is released.

    Each node gets a ``BlockingProvider`` that shares one ``release`` event.
    The test only sets ``release`` after both ``started`` events have fired,
    so it can pass only if the two nodes are in flight at the same time.
    """
    loop = _loop(tmp_path)
    graph = ExecutionGraph(
        strategy="parallel",
        nodes=[
            ExecutionNode("one", "task one", AgentDescriptor(name="one")),
            ExecutionNode("two", "task two", AgentDescriptor(name="two")),
        ],
    )

    async def run_case():
        # Create the events inside the coroutine so they belong to the loop
        # started by asyncio.run(); on Python <= 3.9 an Event built earlier
        # binds to a different loop.
        first_started = asyncio.Event()
        second_started = asyncio.Event()
        release = asyncio.Event()
        providers = {
            "one": BlockingProvider("one", first_started, release),
            "two": BlockingProvider("two", second_started, release),
        }

        loop_task = asyncio.create_task(loop.run())
        # Yield once so the agent loop task gets scheduled before the team run.
        await asyncio.sleep(0)
        task = asyncio.create_task(
            TeamService(loop).run_team(
                graph,
                parent_task_id=None,
                parent_session_id="session-root",
                parent_run_id="run-root",
                provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
            )
        )
        try:
            await asyncio.wait_for(first_started.wait(), timeout=1)
            await asyncio.wait_for(second_started.wait(), timeout=1)
            release.set()
            return await task
        finally:
            # On failure, unblock any parked provider and cancel the team run
            # so the test cannot hang.
            release.set()
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            await loop.stop()
            await loop_task

    result = asyncio.run(run_case())

    assert result.success is True
    assert [item.node_id for item in result.node_results] == ["one", "two"]
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
|
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
|
||||||
loop = _loop(tmp_path)
|
loop = _loop(tmp_path)
|
||||||
loaded = loop.boot()
|
loaded = loop.boot()
|
||||||
|
|||||||
Reference in New Issue
Block a user