feat(team): run parallel nodes with isolated loops

This commit is contained in:
2026-05-22 11:39:17 +08:00
parent c53e221117
commit 4022db8887
4 changed files with 91 additions and 6 deletions

View File

@ -45,6 +45,18 @@ class RecordingProvider(LLMProvider):
return "stub-model"
class BlockingProvider(RecordingProvider):
def __init__(self, content: str, started: asyncio.Event, release: asyncio.Event) -> None:
super().__init__([_response(content)])
self.started = started
self.release = release
async def chat(self, *args, **kwargs) -> LLMResponse:
self.started.set()
await self.release.wait()
return await super().chat(*args, **kwargs)
class StubSkillAssembler:
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
self.activated_skills = list(activated_skills or [])
@ -298,6 +310,57 @@ def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
def test_team_parallel_starts_nodes_concurrently_with_isolated_loops(tmp_path: Path) -> None:
loop = _loop(tmp_path)
first_started = asyncio.Event()
second_started = asyncio.Event()
release = asyncio.Event()
providers = {
"one": BlockingProvider("one", first_started, release),
"two": BlockingProvider("two", second_started, release),
}
graph = ExecutionGraph(
strategy="parallel",
nodes=[
ExecutionNode("one", "task one", AgentDescriptor(name="one")),
ExecutionNode("two", "task two", AgentDescriptor(name="two")),
],
)
async def run_case():
loop_task = asyncio.create_task(loop.run())
await asyncio.sleep(0)
task = asyncio.create_task(
TeamService(loop).run_team(
graph,
parent_task_id=None,
parent_session_id="session-root",
parent_run_id="run-root",
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
)
)
try:
await asyncio.wait_for(first_started.wait(), timeout=1)
await asyncio.wait_for(second_started.wait(), timeout=1)
release.set()
return await task
finally:
release.set()
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await loop.stop()
await loop_task
result = asyncio.run(run_case())
assert result.success is True
assert [item.node_id for item in result.node_results] == ["one", "two"]
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
loop = _loop(tmp_path)
loaded = loop.boot()