feat(team): run parallel nodes with isolated loops
This commit is contained in:
@ -45,6 +45,18 @@ class RecordingProvider(LLMProvider):
|
||||
return "stub-model"
|
||||
|
||||
|
||||
class BlockingProvider(RecordingProvider):
|
||||
def __init__(self, content: str, started: asyncio.Event, release: asyncio.Event) -> None:
|
||||
super().__init__([_response(content)])
|
||||
self.started = started
|
||||
self.release = release
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.started.set()
|
||||
await self.release.wait()
|
||||
return await super().chat(*args, **kwargs)
|
||||
|
||||
|
||||
class StubSkillAssembler:
|
||||
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
|
||||
self.activated_skills = list(activated_skills or [])
|
||||
@ -298,6 +310,57 @@ def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
|
||||
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
|
||||
|
||||
|
||||
def test_team_parallel_starts_nodes_concurrently_with_isolated_loops(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
first_started = asyncio.Event()
|
||||
second_started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
providers = {
|
||||
"one": BlockingProvider("one", first_started, release),
|
||||
"two": BlockingProvider("two", second_started, release),
|
||||
}
|
||||
graph = ExecutionGraph(
|
||||
strategy="parallel",
|
||||
nodes=[
|
||||
ExecutionNode("one", "task one", AgentDescriptor(name="one")),
|
||||
ExecutionNode("two", "task two", AgentDescriptor(name="two")),
|
||||
],
|
||||
)
|
||||
|
||||
async def run_case():
|
||||
loop_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0)
|
||||
task = asyncio.create_task(
|
||||
TeamService(loop).run_team(
|
||||
graph,
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id="run-root",
|
||||
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||
)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(first_started.wait(), timeout=1)
|
||||
await asyncio.wait_for(second_started.wait(), timeout=1)
|
||||
release.set()
|
||||
return await task
|
||||
finally:
|
||||
release.set()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await loop.stop()
|
||||
await loop_task
|
||||
|
||||
result = asyncio.run(run_case())
|
||||
|
||||
assert result.success is True
|
||||
assert [item.node_id for item in result.node_results] == ["one", "two"]
|
||||
|
||||
|
||||
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
loaded = loop.boot()
|
||||
|
||||
Reference in New Issue
Block a user