feat(tasks): add skill-templated task graph execution
This commit is contained in:
@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine import AgentLoop, AgentRunResult, EngineLoader
|
||||
from beaver.engine import loop as loop_module
|
||||
|
||||
|
||||
def _run_result(run_id: str, output_text: str) -> AgentRunResult:
|
||||
@ -45,3 +47,37 @@ def test_running_loop_handles_reentrant_submit_direct(tmp_path) -> None:
|
||||
assert calls == ["outer", "inner"]
|
||||
|
||||
asyncio.run(run_case())
|
||||
|
||||
|
||||
def test_web_search_loop_guard_stops_after_repeated_low_quality_results() -> None:
|
||||
guard = loop_module._WebSearchLoopGuard()
|
||||
low_quality = json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": "weather beijing",
|
||||
"quality": "low",
|
||||
"results": [{"title": "Example", "url": "https://example.com", "snippet": ""}],
|
||||
}
|
||||
)
|
||||
|
||||
assert guard.observe_result("web_search", low_quality) is None
|
||||
assert guard.observe_result("web_search", low_quality) is None
|
||||
|
||||
guidance = guard.observe_result("web_search", low_quality)
|
||||
|
||||
assert guidance is not None
|
||||
assert guidance["finish_reason"] == "web_search_low_quality_budget"
|
||||
assert "weather beijing" in guidance["message"]
|
||||
|
||||
|
||||
def test_web_search_loop_guard_resets_after_useful_result() -> None:
|
||||
guard = loop_module._WebSearchLoopGuard()
|
||||
low_quality = json.dumps({"success": True, "query": "weather", "quality": "low", "results": []})
|
||||
useful = json.dumps({"success": True, "query": "weather", "quality": "high", "results": []})
|
||||
|
||||
assert guard.observe_result("web_search", low_quality) is None
|
||||
assert guard.observe_result("web_search", useful) is None
|
||||
assert guard.observe_result("web_search", low_quality) is None
|
||||
assert guard.observe_result("web_search", low_quality) is None
|
||||
|
||||
assert guard.observe_result("web_search", low_quality) is not None
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@ -44,6 +46,49 @@ class ToolCallingProvider(LLMProvider):
|
||||
return "stub"
|
||||
|
||||
|
||||
class ParallelToolProvider(LLMProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.calls = 0
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float = 0.7,
|
||||
thinking_enabled: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(id="call-1", name="read_file", arguments={"path": "README.md"}),
|
||||
ToolCallRequest(id="call-2", name="search_files", arguments={"query": "Beaver"}),
|
||||
],
|
||||
)
|
||||
return LLMResponse(content="done")
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "stub"
|
||||
|
||||
|
||||
class ConcurrentReadOnlyExecutor:
|
||||
def __init__(self) -> None:
|
||||
self.started: list[str] = []
|
||||
self._both_started = asyncio.Event()
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCallRequest | dict[str, Any], *, context=None):
|
||||
name = getattr(tool_call, "name", "")
|
||||
self.started.append(name)
|
||||
if len(self.started) >= 2:
|
||||
self._both_started.set()
|
||||
await asyncio.wait_for(self._both_started.wait(), timeout=0.2)
|
||||
return SimpleNamespace(success=True, error=None, content=f"{name} result", tool_name=name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||
@ -69,3 +114,63 @@ async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
|
||||
assert result.output_text == "done"
|
||||
assert replay_executor.traces
|
||||
assert replay_executor.traces[0]["tool_name"] == "read_file"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_runs_read_only_tool_calls_concurrently(tmp_path: Path) -> None:
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||
provider = ParallelToolProvider()
|
||||
executor = ConcurrentReadOnlyExecutor()
|
||||
runtime = SimpleNamespace(model="stub", provider_name="stub")
|
||||
|
||||
result = await loop.process_direct(
|
||||
"Read and search the workspace.",
|
||||
provider_bundle=ProviderBundle(main_runtime=runtime, main_provider=provider), # type: ignore[arg-type]
|
||||
include_skill_assembly=False,
|
||||
pinned_skill_names=[],
|
||||
tool_executor_override=executor,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
|
||||
assert result.output_text == "done"
|
||||
assert executor.started == ["read_file", "search_files"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_records_latency_breakdown(tmp_path: Path) -> None:
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||
provider = ParallelToolProvider()
|
||||
executor = ConcurrentReadOnlyExecutor()
|
||||
runtime = SimpleNamespace(model="stub", provider_name="stub")
|
||||
|
||||
result = await loop.process_direct(
|
||||
"Read and search the workspace.",
|
||||
provider_bundle=ProviderBundle(main_runtime=runtime, main_provider=provider), # type: ignore[arg-type]
|
||||
include_skill_assembly=False,
|
||||
pinned_skill_names=[],
|
||||
tool_executor_override=executor,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
|
||||
latency = result.usage["latency_ms"]
|
||||
expected_keys = {
|
||||
"router_ms",
|
||||
"mcp_ms",
|
||||
"skill_assembly_ms",
|
||||
"tool_assembly_ms",
|
||||
"context_build_ms",
|
||||
"llm_ms",
|
||||
"tool_ms",
|
||||
"session_write_ms",
|
||||
"total_ms",
|
||||
}
|
||||
assert expected_keys.issubset(latency)
|
||||
assert all(isinstance(latency[key], (int, float)) and latency[key] >= 0 for key in expected_keys)
|
||||
assert latency["llm_ms"] > 0
|
||||
assert latency["tool_ms"] > 0
|
||||
assert latency["total_ms"] >= latency["llm_ms"]
|
||||
|
||||
loaded = loop.boot()
|
||||
events = loaded.session_manager.get_run_event_records(result.session_id, result.run_id)
|
||||
completed = next(event for event in events if event.event_type == "run_completed")
|
||||
assert completed.event_payload["latency_ms"] == latency
|
||||
|
||||
67
app-instance/backend/tests/unit/test_agent_team_toggle.py
Normal file
67
app-instance/backend/tests/unit/test_agent_team_toggle.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.tasks import TaskExecutionPlanner, TaskRecord
|
||||
|
||||
|
||||
class _TeamPlannerProvider(LLMProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.calls = 0
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
self.calls += 1
|
||||
return LLMResponse(
|
||||
content='{"mode":"team","reason":"parallel research","strategy":"parallel","nodes":[{"node_id":"research","task":"research","agent":{"name":"researcher"}}]}',
|
||||
finish_reason="stop",
|
||||
provider_name="stub",
|
||||
model="stub-model",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "stub-model"
|
||||
|
||||
|
||||
def test_agent_team_can_be_disabled_by_environment(monkeypatch) -> None:
|
||||
monkeypatch.setenv("BEAVER_AGENT_TEAM_ENABLED", "0")
|
||||
provider = _TeamPlannerProvider()
|
||||
task = TaskRecord(
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
description="research and compare options",
|
||||
goal="research and compare options",
|
||||
constraints=[],
|
||||
priority=0,
|
||||
status="open",
|
||||
creator="test",
|
||||
created_at="now",
|
||||
updated_at="now",
|
||||
)
|
||||
bundle = ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
main_provider=provider,
|
||||
)
|
||||
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner().plan(
|
||||
task=task,
|
||||
user_message="research and compare options",
|
||||
attempt_index=1,
|
||||
provider_bundle=bundle,
|
||||
)
|
||||
)
|
||||
|
||||
assert plan.mode == "single"
|
||||
assert plan.reason == "planner_disabled_by_environment"
|
||||
assert provider.calls == 0
|
||||
@ -8,7 +8,8 @@ import pytest
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.coordinator import AgentDescriptor, DelegationEnvelope, ExecutionGraph, ExecutionNode
|
||||
from beaver.coordinator import AgentDescriptor, DelegationEnvelope, ExecutionGraph, ExecutionNode, NodeRunResult
|
||||
from beaver.coordinator.execution.scheduler import TeamGraphScheduler
|
||||
from beaver.coordinator.local import LocalAgentRunner
|
||||
from beaver.engine import AgentLoop, EngineLoader
|
||||
from beaver.engine.context import SkillContext
|
||||
@ -90,6 +91,15 @@ class PerRunSnapshotMemoryService(MemoryService):
|
||||
return MemorySnapshot(memory_block="# Memory\n\nshared-snapshot", user_block=None)
|
||||
|
||||
|
||||
class CapturingRunner:
|
||||
def __init__(self) -> None:
|
||||
self.envelopes: list[DelegationEnvelope] = []
|
||||
|
||||
async def run(self, envelope: DelegationEnvelope, **kwargs) -> NodeRunResult:
|
||||
self.envelopes.append(envelope)
|
||||
return NodeRunResult(node_id=envelope.node_id or "node", success=True, output_text="done")
|
||||
|
||||
|
||||
def _bundle(provider: RecordingProvider) -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
@ -161,10 +171,72 @@ def test_local_agent_runner_uses_shared_loop_and_records_parent_task(tmp_path: P
|
||||
child_session = loaded.session_manager.get_session(result.session_id) # type: ignore[union-attr,arg-type]
|
||||
|
||||
assert result.success is True
|
||||
assert result.completion_status == "succeeded"
|
||||
assert result.evidence_gaps == []
|
||||
assert run_record.task_id == "task-parent"
|
||||
assert child_session["parent_session_id"] == "session-root"
|
||||
|
||||
|
||||
def test_node_without_required_tool_result_is_partial(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("collected narrative")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
task="collect",
|
||||
node_id="collect",
|
||||
required_evidence=["tool_result"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is False
|
||||
assert result.completion_status == "partial"
|
||||
assert result.evidence_gaps == ["missing required evidence: tool_result"]
|
||||
|
||||
|
||||
def test_node_with_required_nonempty_output_succeeds(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("verified output")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="verify"),
|
||||
task="verify",
|
||||
node_id="verify",
|
||||
required_evidence=["output"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is True
|
||||
assert result.completion_status == "succeeded"
|
||||
assert result.evidence_gaps == []
|
||||
|
||||
|
||||
def test_unknown_evidence_requirement_makes_node_partial(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("output")])
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
parent_run_id=None,
|
||||
agent=AgentDescriptor(name="verify"),
|
||||
task="verify",
|
||||
node_id="verify",
|
||||
required_evidence=["unknown_type"],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is False
|
||||
assert result.completion_status == "partial"
|
||||
assert result.evidence_gaps == ["unsupported evidence requirement: unknown_type"]
|
||||
|
||||
|
||||
def test_team_node_preserves_evidence_when_finish_reason_is_not_stop(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider([_response("partial evidence", finish_reason="max_tool_iterations")])
|
||||
@ -277,6 +349,108 @@ def test_team_sequence_passes_prior_outputs(tmp_path: Path) -> None:
|
||||
assert "Dependency first output:\nfirst output" in providers["second"].calls[0][0]["content"]
|
||||
|
||||
|
||||
def test_partial_node_allows_downstream_by_default(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
"collect": RecordingProvider([_response("partial source notes")]),
|
||||
"extract": RecordingProvider([_response("extracted metrics")]),
|
||||
}
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
required_evidence=["tool_result"],
|
||||
),
|
||||
ExecutionNode("extract", "extract", AgentDescriptor(name="extract")),
|
||||
],
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
TeamService(loop).run_team(
|
||||
graph,
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.node_results[0].completion_status == "partial"
|
||||
assert result.node_results[1].completion_status == "succeeded"
|
||||
assert "Dependency collect output:\npartial source notes" in providers["extract"].calls[0][0]["content"]
|
||||
|
||||
|
||||
def test_partial_node_blocks_downstream_when_configured(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
"collect": RecordingProvider([_response("partial source notes")]),
|
||||
"extract": RecordingProvider([_response("must not run")]),
|
||||
}
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
required_evidence=["tool_result"],
|
||||
block_downstream_on_partial=True,
|
||||
),
|
||||
ExecutionNode("extract", "extract", AgentDescriptor(name="extract")),
|
||||
],
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
TeamService(loop).run_team(
|
||||
graph,
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.node_results[0].completion_status == "partial"
|
||||
assert result.node_results[1].completion_status == "blocked"
|
||||
assert providers["extract"].calls == []
|
||||
|
||||
|
||||
def test_scheduler_copies_task_two_contract_fields_to_envelope() -> None:
|
||||
runner = CapturingRunner()
|
||||
node = ExecutionNode(
|
||||
"collect",
|
||||
"collect",
|
||||
AgentDescriptor(name="collect"),
|
||||
input_contract={"query": "str"},
|
||||
output_contract={"sources": "list"},
|
||||
required_evidence=["tool_result"],
|
||||
evidence_contract={"entities": ["MGM"]},
|
||||
validation_rules=["official_sources_only"],
|
||||
required_for_completion=False,
|
||||
block_downstream_on_partial=True,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
TeamGraphScheduler(runner).run( # type: ignore[arg-type]
|
||||
ExecutionGraph(strategy="sequence", nodes=[node]),
|
||||
parent_task_id=None,
|
||||
parent_session_id="session-root",
|
||||
)
|
||||
)
|
||||
|
||||
envelope = runner.envelopes[0]
|
||||
assert envelope.input_contract == {"query": "str"}
|
||||
assert envelope.output_contract == {"sources": "list"}
|
||||
assert envelope.required_evidence == ["tool_result"]
|
||||
assert envelope.evidence_contract == {"entities": ["MGM"]}
|
||||
assert envelope.validation_rules == ["official_sources_only"]
|
||||
assert envelope.required_for_completion is False
|
||||
assert envelope.block_downstream_on_partial is True
|
||||
assert envelope.max_tool_iterations == 2
|
||||
|
||||
|
||||
def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
providers = {
|
||||
@ -428,9 +602,12 @@ def test_team_dag_blocks_dependents_after_failure(tmp_path: Path) -> None:
|
||||
)
|
||||
)
|
||||
publish = [item for item in result.node_results if item.node_id == "publish"][0]
|
||||
validate = [item for item in result.node_results if item.node_id == "validate"][0]
|
||||
|
||||
assert result.success is False
|
||||
assert validate.completion_status == "failed"
|
||||
assert publish.finish_reason == "blocked"
|
||||
assert publish.completion_status == "blocked"
|
||||
assert publish.run_id is None
|
||||
assert publish.error == "Blocked by failed dependency: validate"
|
||||
assert "failed" not in result.summary.split("Failed nodes:")[0]
|
||||
@ -471,8 +648,10 @@ def test_dag_node_factory_error_blocks_dependents(tmp_path: Path) -> None:
|
||||
|
||||
assert result.success is False
|
||||
assert validate.finish_reason == "error"
|
||||
assert validate.completion_status == "failed"
|
||||
assert validate.error == "validator unavailable"
|
||||
assert publish.finish_reason == "blocked"
|
||||
assert publish.completion_status == "blocked"
|
||||
assert publish.error == "Blocked by failed dependency: validate"
|
||||
|
||||
|
||||
@ -550,6 +729,76 @@ def test_graph_structure_errors_still_raise(tmp_path: Path) -> None:
|
||||
asyncio.run(TeamService(loop).run_team(cyclic, parent_task_id=None, parent_session_id="session-root"))
|
||||
|
||||
|
||||
def test_execution_node_contract_defaults_preserve_legacy_scope_behavior() -> None:
|
||||
node = ExecutionNode("collect", "Collect sources", AgentDescriptor(name="collect"))
|
||||
|
||||
assert node.input_contract == {}
|
||||
assert node.output_contract == {}
|
||||
assert node.allowed_tool_names is None
|
||||
assert node.required_evidence == []
|
||||
assert node.evidence_contract == {}
|
||||
assert node.validation_rules == []
|
||||
assert node.required_for_completion is True
|
||||
assert node.block_downstream_on_partial is False
|
||||
assert node.max_tool_iterations is None
|
||||
|
||||
|
||||
def test_execution_node_keeps_explicit_empty_tool_scope_distinct_from_unspecified_scope() -> None:
|
||||
unrestricted = ExecutionNode("unrestricted", "Collect", AgentDescriptor(name="unrestricted"))
|
||||
tool_free = ExecutionNode(
|
||||
"tool_free",
|
||||
"Synthesize",
|
||||
AgentDescriptor(name="tool_free"),
|
||||
allowed_tool_names=[],
|
||||
)
|
||||
|
||||
assert unrestricted.allowed_tool_names is None
|
||||
assert tool_free.allowed_tool_names == []
|
||||
|
||||
|
||||
def test_delegation_envelope_and_node_result_preserve_new_contract_metadata() -> None:
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
parent_run_id="run-root",
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
task="Collect sources",
|
||||
allowed_tool_names=["web_search"],
|
||||
required_evidence=["url"],
|
||||
evidence_contract={"entities": ["MGM", "Galaxy"]},
|
||||
validation_rules=["official_sources_only"],
|
||||
required_for_completion=True,
|
||||
block_downstream_on_partial=True,
|
||||
max_tool_iterations=2,
|
||||
)
|
||||
result = NodeRunResult(
|
||||
node_id="collect",
|
||||
success=False,
|
||||
output_text="MGM source only",
|
||||
completion_status="partial",
|
||||
evidence_gaps=["missing required evidence: Galaxy official source"],
|
||||
)
|
||||
|
||||
assert envelope.allowed_tool_names == ["web_search"]
|
||||
assert envelope.evidence_contract == {"entities": ["MGM", "Galaxy"]}
|
||||
assert result.to_dict()["completion_status"] == "partial"
|
||||
assert result.to_dict()["evidence_gaps"] == ["missing required evidence: Galaxy official source"]
|
||||
|
||||
|
||||
def test_graph_rejects_depth_above_configured_limit() -> None:
|
||||
graph = ExecutionGraph(
|
||||
strategy="dag",
|
||||
nodes=[
|
||||
ExecutionNode("a", "A", AgentDescriptor(name="a")),
|
||||
ExecutionNode("b", "B", AgentDescriptor(name="b"), depends_on=["a"]),
|
||||
ExecutionNode("c", "C", AgentDescriptor(name="c"), depends_on=["b"]),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="max depth"):
|
||||
graph.validate(max_depth=2)
|
||||
|
||||
|
||||
def test_team_run_does_not_create_independent_team_task(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
loaded = loop.boot()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@ -74,10 +75,77 @@ def test_debug_chat_logs_group_events_by_run(tmp_path: Path) -> None:
|
||||
assert run["intent_agent_choice"] == "create_task"
|
||||
assert run["user_input"] == "hello"
|
||||
assert [event["event_type"] for event in run["events"]] == [
|
||||
"run_started",
|
||||
"intent_agent_decision_snapshotted",
|
||||
"llm_request_snapshotted",
|
||||
"user_message_added",
|
||||
"assistant_message_added",
|
||||
"user_message_added",
|
||||
"llm_request_snapshotted",
|
||||
"intent_agent_decision_snapshotted",
|
||||
"run_started",
|
||||
]
|
||||
assert run["events"][2]["event_payload"]["messages"][0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_debug_chat_logs_are_reverse_chronological_and_include_latency(tmp_path: Path) -> None:
|
||||
service = AgentService(workspace=tmp_path)
|
||||
loaded = service.create_loop().boot()
|
||||
manager = loaded.session_manager
|
||||
session_id = "web:debug-order"
|
||||
manager.ensure_session(session_id, source="web", title="Debug order")
|
||||
|
||||
manager.append_message(
|
||||
session_id,
|
||||
run_id="run-old",
|
||||
role="system",
|
||||
event_type="run_started",
|
||||
content="old",
|
||||
context_visible=False,
|
||||
)
|
||||
manager.append_message(
|
||||
session_id,
|
||||
run_id="run-old",
|
||||
role="system",
|
||||
event_type="run_completed",
|
||||
event_payload={"latency_ms": {"total_ms": 10.0, "llm_ms": 7.0}},
|
||||
finish_reason="stop",
|
||||
context_visible=False,
|
||||
)
|
||||
sleep(0.01)
|
||||
manager.append_message(
|
||||
session_id,
|
||||
run_id="run-new",
|
||||
role="system",
|
||||
event_type="run_started",
|
||||
content="new",
|
||||
context_visible=False,
|
||||
)
|
||||
manager.append_message(
|
||||
session_id,
|
||||
run_id="run-new",
|
||||
role="system",
|
||||
event_type="run_completed",
|
||||
event_payload={
|
||||
"latency_ms": {
|
||||
"router_ms": 1.0,
|
||||
"mcp_ms": 2.0,
|
||||
"skill_assembly_ms": 3.0,
|
||||
"tool_assembly_ms": 4.0,
|
||||
"context_build_ms": 5.0,
|
||||
"llm_ms": 6.0,
|
||||
"tool_ms": 7.0,
|
||||
"session_write_ms": 8.0,
|
||||
"total_ms": 36.0,
|
||||
}
|
||||
},
|
||||
finish_reason="stop",
|
||||
context_visible=False,
|
||||
)
|
||||
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/debug/chat-logs")
|
||||
|
||||
assert response.status_code == 200
|
||||
runs = response.json()["sessions"][0]["runs"]
|
||||
assert [run["run_id"] for run in runs] == ["run-new", "run-old"]
|
||||
assert [event["event_type"] for event in runs[0]["events"]] == ["run_completed", "run_started"]
|
||||
assert runs[0]["latency_ms"]["total_ms"] == 36.0
|
||||
assert runs[0]["latency_ms"]["router_ms"] == 1.0
|
||||
|
||||
@ -158,7 +158,7 @@ def test_router_receives_thinking_mode() -> None:
|
||||
provider = RouterProvider('{"action":"simple_chat","reason":"simple"}')
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
"你好",
|
||||
"请判断一下这个概念是否合理",
|
||||
provider=provider,
|
||||
thinking_enabled=False,
|
||||
)
|
||||
@ -168,11 +168,84 @@ def test_router_receives_thinking_mode() -> None:
|
||||
assert provider.calls[0]["thinking_enabled"] is False
|
||||
|
||||
|
||||
def test_router_fast_paths_obvious_simple_chat_without_provider_call() -> None:
|
||||
provider = RouterProvider('{"action":"new_task","reason":"should not be used"}')
|
||||
|
||||
decision = asyncio.run(MainAgentRouter().classify("你好", provider=provider))
|
||||
punctuated = asyncio.run(MainAgentRouter().classify("你好!", provider=provider))
|
||||
translation = asyncio.run(MainAgentRouter().classify("翻译这句话:hello world", provider=provider))
|
||||
|
||||
assert not decision.is_task
|
||||
assert decision.action == "simple_chat"
|
||||
assert decision.reason == "obvious_simple_chat"
|
||||
assert not punctuated.is_task
|
||||
assert punctuated.action == "simple_chat"
|
||||
assert not translation.is_task
|
||||
assert translation.action == "simple_chat"
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_router_sends_broad_explanations_to_intent_llm() -> None:
|
||||
provider = RouterProvider('{"action":"simple_chat","reason":"intent decided concept explanation"}')
|
||||
|
||||
explanation = asyncio.run(MainAgentRouter().classify("解释一下什么是 MCP", provider=provider))
|
||||
definition = asyncio.run(MainAgentRouter().classify("什么是 context engineering", provider=provider))
|
||||
|
||||
assert not explanation.is_task
|
||||
assert explanation.reason == "intent decided concept explanation"
|
||||
assert not definition.is_task
|
||||
assert definition.reason == "intent decided concept explanation"
|
||||
assert len(provider.calls) == 2
|
||||
|
||||
|
||||
def test_router_fast_paths_obvious_task_without_provider_call() -> None:
|
||||
provider = RouterProvider('{"action":"simple_chat","reason":"should not be used"}')
|
||||
|
||||
decision = asyncio.run(MainAgentRouter().classify("帮我查一下今天深圳天气", provider=provider))
|
||||
current_event = asyncio.run(
|
||||
MainAgentRouter().classify("解释一下今天法国队在世界杯的表现为什么那么好", provider=provider)
|
||||
)
|
||||
|
||||
assert decision.is_task
|
||||
assert decision.action == "create_task"
|
||||
assert decision.reason == "obvious_task"
|
||||
assert current_event.is_task
|
||||
assert current_event.action == "create_task"
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_router_does_not_simple_fast_path_current_event_explanations() -> None:
|
||||
provider = RouterProvider('{"action":"simple_chat","reason":"llm fallback"}')
|
||||
|
||||
decision = asyncio.run(MainAgentRouter().classify("解释一下昨晚法国队在世界杯的表现为什么那么好", provider=provider))
|
||||
|
||||
assert decision.is_task
|
||||
assert decision.action == "create_task"
|
||||
assert decision.reason == "obvious_task"
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_router_keeps_active_task_followups_in_llm_path() -> None:
|
||||
provider = RouterProvider('{"action":"revise_task","reason":"needs revision","short_title":"任务连续性"}')
|
||||
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
"这个也加上",
|
||||
active_task=_task(),
|
||||
provider=provider,
|
||||
)
|
||||
)
|
||||
|
||||
assert decision.is_task
|
||||
assert decision.action == "revise_task"
|
||||
assert len(provider.calls) == 1
|
||||
|
||||
|
||||
def test_router_injects_intent_skill_guidance() -> None:
|
||||
provider = RouterProvider('{"action":"new_task","reason":"needs weather tool","short_title":"珠海天气"}')
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
"帮我查一下今天珠海天气",
|
||||
"帮我判断这个需求要不要进入任务模式",
|
||||
provider=provider,
|
||||
intent_skill="Weather and current external data must be routed to new_task.",
|
||||
)
|
||||
@ -247,7 +320,7 @@ def test_router_retries_once_after_provider_failure() -> None:
|
||||
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
"帮我看看昨天的中美会面都谈了什么?",
|
||||
"帮我判断这次中美会面分析需求要不要进入任务模式",
|
||||
provider=provider,
|
||||
)
|
||||
)
|
||||
@ -262,7 +335,7 @@ def test_router_fallback_after_two_provider_failures() -> None:
|
||||
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
"帮我看看昨天的中美会面都谈了什么?",
|
||||
"帮我判断这次中美会面分析需求要不要进入任务模式",
|
||||
provider=provider,
|
||||
)
|
||||
)
|
||||
|
||||
@ -103,7 +103,7 @@ def test_skill_selection_receives_thinking_mode() -> None:
|
||||
assert provider.thinking_enabled is False
|
||||
|
||||
|
||||
def test_skill_assembler_loads_detail_directly_for_small_candidate_sets() -> None:
|
||||
def test_skill_assembler_directly_activates_single_clear_candidate_without_llm() -> None:
|
||||
provider = SequencedProvider(['["docker-debug"]'])
|
||||
assembler = SkillAssembler(loader=LoaderWithFullSkill(), retriever=StaticRetriever())
|
||||
|
||||
@ -117,10 +117,8 @@ def test_skill_assembler_loads_detail_directly_for_small_candidate_sets() -> Non
|
||||
|
||||
assert [skill.name for skill in result.activated_skills] == ["docker-debug"]
|
||||
assert result.activated_skills[0].tool_hints == ["search_files"]
|
||||
assert [item["stage"] for item in result.llm_interactions] == ["final"]
|
||||
assert len(provider.messages) == 1
|
||||
first_user_prompt = provider.messages[0][1]["content"]
|
||||
assert "Use this skill when doing Docker log triage" in first_user_prompt
|
||||
assert result.llm_interactions == []
|
||||
assert provider.messages == []
|
||||
|
||||
|
||||
def test_skill_assembler_shortlists_before_loading_detail_for_large_candidate_sets() -> None:
|
||||
|
||||
@ -395,6 +395,52 @@ def test_replay_main_score_uses_validator_not_tool_success(tmp_path: Path) -> No
|
||||
assert report.synthetic_score_avg is not None
|
||||
|
||||
|
||||
def test_replay_real_case_without_validator_uses_same_output_scoring_for_both_arms(tmp_path: Path) -> None:
|
||||
pipeline = _pipeline(tmp_path, task_score=0.8)
|
||||
pipeline.learning_store.update_learning_candidate(
|
||||
"candidate-1",
|
||||
evidence={
|
||||
"eval_cases": [
|
||||
{
|
||||
"run_id": "real-no-validator",
|
||||
"task_id": "real-no-validator",
|
||||
"session_id": "eval",
|
||||
"task_text": "Summarize the release checklist.",
|
||||
"accepted_score": 0.8,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
draft = pipeline.draft_service.create_new_skill_draft(
|
||||
skill_name="release-checklist",
|
||||
proposed_content="# Release\n\nRun tests.",
|
||||
proposed_frontmatter={"description": "release", "tools": []},
|
||||
created_by="test",
|
||||
reason="test",
|
||||
)
|
||||
pipeline.learning_store.update_learning_candidate("candidate-1", draft_skill_name=draft.skill_name, draft_id=draft.draft_id)
|
||||
|
||||
report = asyncio.run(
|
||||
pipeline.evaluate_draft(
|
||||
"candidate-1",
|
||||
draft.skill_name,
|
||||
draft.draft_id,
|
||||
provider_bundle=_bundle(),
|
||||
replay_runner=FakeReplayRunner(
|
||||
baseline_answer="Release checklist summarized.",
|
||||
candidate_answer="Release checklist summarized.",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
case = next(item for item in report.case_reports if item["run_id"] == "real-no-validator")
|
||||
legacy_case = next(item for item in report.cases if item["run_id"] == "real-no-validator")
|
||||
assert case["baseline_score"] == 0.7
|
||||
assert case["candidate_score"] == 0.7
|
||||
assert case["delta"] == 0.0
|
||||
assert legacy_case["delta"] == 0.0
|
||||
|
||||
|
||||
def test_synthetic_cases_without_validator_are_not_replay_scored(tmp_path: Path) -> None:
|
||||
pipeline = _pipeline(tmp_path)
|
||||
pipeline.learning_store.update_learning_candidate(
|
||||
|
||||
65
app-instance/backend/tests/unit/test_skill_team_template.py
Normal file
65
app-instance/backend/tests/unit/test_skill_team_template.py
Normal file
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from beaver.skills.assembler.task_assembler import SkillAssembler
|
||||
from beaver.skills.catalog.loader import SkillsLoader
|
||||
from beaver.skills.catalog.utils import extract_skill_team_template
|
||||
|
||||
|
||||
def test_extract_team_template_returns_none_when_block_is_absent() -> None:
|
||||
result = extract_skill_team_template("# Ordinary Skill")
|
||||
|
||||
assert result.template is None
|
||||
assert result.warnings == []
|
||||
|
||||
|
||||
def test_extract_team_template_parses_valid_json_block() -> None:
|
||||
result = extract_skill_team_template(
|
||||
"```beaver-team-template\n"
|
||||
'{"version": 1, "nodes": [{"node_id": "collect", "task": "Collect"}]}\n'
|
||||
"```"
|
||||
)
|
||||
|
||||
assert result.template == {
|
||||
"version": 1,
|
||||
"nodes": [{"node_id": "collect", "task": "Collect"}],
|
||||
}
|
||||
assert result.warnings == []
|
||||
|
||||
|
||||
def test_invalid_template_is_warning_not_skill_load_failure() -> None:
|
||||
result = extract_skill_team_template("```beaver-team-template\nnot-json\n```")
|
||||
|
||||
assert result.template is None
|
||||
assert result.warnings == ["team template JSON is invalid"]
|
||||
|
||||
|
||||
def test_loader_and_assembler_propagate_team_template_to_skill_context(tmp_path) -> None:
|
||||
skill_dir = tmp_path / "plugin-skills" / "financial-comparison"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"description: Compare financial disclosures.\n"
|
||||
"---\n\n"
|
||||
"# Financial Comparison\n\n"
|
||||
"```beaver-team-template\n"
|
||||
'{"version": 1, "nodes": [{"node_id": "collect", "task": "Collect official sources"}]}\n'
|
||||
"```\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
loader = SkillsLoader(
|
||||
tmp_path,
|
||||
builtin_skills_dir=tmp_path / "no-builtins",
|
||||
extra_dirs=[tmp_path / "plugin-skills"],
|
||||
)
|
||||
|
||||
record = loader.get_skill_record("financial-comparison")
|
||||
context = SkillAssembler(loader)._activate_skill_contexts(["financial-comparison"])[0]
|
||||
|
||||
assert record is not None
|
||||
assert record.team_template == {
|
||||
"version": 1,
|
||||
"nodes": [{"node_id": "collect", "task": "Collect official sources"}],
|
||||
}
|
||||
assert record.team_template_warnings == []
|
||||
assert context.team_template == record.team_template
|
||||
assert context.team_template_warnings == []
|
||||
@ -3,7 +3,65 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
from beaver.engine.session.manager import SessionManager
|
||||
from beaver.tasks.evidence import EvidenceBuilder, RunEvidence, TaskEvidencePacket, ToolEvidence, render_task_evidence
|
||||
from beaver.tasks.evidence import (
|
||||
EvidenceBuilder,
|
||||
RunEvidence,
|
||||
TaskEvidencePacket,
|
||||
ToolEvidence,
|
||||
evaluate_node_evidence,
|
||||
render_task_evidence,
|
||||
)
|
||||
|
||||
|
||||
def _run_evidence(*, tool_results: list[ToolEvidence] | None = None) -> RunEvidence:
|
||||
return RunEvidence(
|
||||
run_id="run-1",
|
||||
session_id="session-1",
|
||||
output_text="",
|
||||
finish_reason="stop",
|
||||
tool_results=list(tool_results or []),
|
||||
)
|
||||
|
||||
|
||||
def test_evaluate_node_evidence_requires_successful_tool_result() -> None:
|
||||
evidence = _run_evidence(
|
||||
tool_results=[
|
||||
ToolEvidence(
|
||||
tool_name="web_fetch",
|
||||
tool_call_id="call-1",
|
||||
content="failed",
|
||||
event_payload={"success": False},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert evaluate_node_evidence(evidence, ["tool_result"], "done") == [
|
||||
"missing required evidence: tool_result"
|
||||
]
|
||||
|
||||
|
||||
def test_evaluate_node_evidence_accepts_url_in_successful_tool_content() -> None:
|
||||
evidence = _run_evidence(
|
||||
tool_results=[
|
||||
ToolEvidence(
|
||||
tool_name="web_fetch",
|
||||
tool_call_id="call-1",
|
||||
content="Source: https://example.test/report",
|
||||
event_payload={"success": True},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert evaluate_node_evidence(evidence, ["tool_result", "url"], "done") == []
|
||||
|
||||
|
||||
def test_evaluate_node_evidence_checks_output_and_unknown_requirements() -> None:
|
||||
evidence = _run_evidence()
|
||||
|
||||
assert evaluate_node_evidence(evidence, ["output", "unknown_type"], " ") == [
|
||||
"missing required evidence: output",
|
||||
"unsupported evidence requirement: unknown_type",
|
||||
]
|
||||
|
||||
|
||||
def test_evidence_builder_preserves_full_tool_result(tmp_path: Path) -> None:
|
||||
|
||||
@ -3,15 +3,19 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from beaver.engine.context import SkillContext
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.tasks import TaskExecutionPlanner, TaskRecord
|
||||
from beaver.tasks import SkillResolutionReport, TaskExecutionPlanner, TaskRecord
|
||||
from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
|
||||
from beaver.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class PlannerProvider(LLMProvider):
|
||||
def __init__(self, response: str) -> None:
|
||||
super().__init__()
|
||||
self.response = response
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@ -21,6 +25,15 @@ class PlannerProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
self.calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"model": model,
|
||||
"tools": tools,
|
||||
}
|
||||
)
|
||||
return LLMResponse(content=self.response, finish_reason="stop", provider_name="stub", model="stub-model")
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
@ -43,6 +56,28 @@ class HangingPlannerProvider(LLMProvider):
|
||||
return "stub-model"
|
||||
|
||||
|
||||
class SequencedPlannerProvider(PlannerProvider):
|
||||
def __init__(self, responses: list[str]) -> None:
|
||||
super().__init__(responses[0])
|
||||
self.responses = list(responses)
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.response = self.responses.pop(0)
|
||||
return await super().chat(*args, **kwargs)
|
||||
|
||||
|
||||
class StubTool(BaseTool):
|
||||
def __init__(self, name: str) -> None:
|
||||
self._spec = ToolSpec(name=name, description=name, input_schema={"type": "object"})
|
||||
|
||||
@property
|
||||
def spec(self) -> ToolSpec:
|
||||
return self._spec
|
||||
|
||||
async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
|
||||
raise AssertionError("Planner tests do not execute tools")
|
||||
|
||||
|
||||
def _task() -> TaskRecord:
|
||||
return TaskRecord(
|
||||
task_id="task-1",
|
||||
@ -59,12 +94,26 @@ def _task() -> TaskRecord:
|
||||
|
||||
|
||||
def _bundle(response: str) -> ProviderBundle:
|
||||
provider = PlannerProvider(response)
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
main_provider=PlannerProvider(response),
|
||||
main_provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def _bundle_with_provider(provider: LLMProvider) -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
main_provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def _registry() -> ToolRegistry:
|
||||
registry = ToolRegistry()
|
||||
registry.register_many([StubTool("web_search"), StubTool("web_fetch"), StubTool("terminal")])
|
||||
return registry
|
||||
|
||||
|
||||
def _hanging_bundle() -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
@ -87,26 +136,55 @@ def test_planner_selects_single_mode() -> None:
|
||||
assert plan.reason == "main agent is enough"
|
||||
|
||||
|
||||
def test_planner_skips_llm_for_simple_task() -> None:
|
||||
provider = PlannerProvider('{"mode":"team","reason":"should not be used"}')
|
||||
bundle = ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
main_provider=provider,
|
||||
)
|
||||
task = _task()
|
||||
task.description = "查询深圳天气"
|
||||
task.goal = "查询深圳天气"
|
||||
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner().plan(
|
||||
task=task,
|
||||
user_message="帮我查一下今天深圳天气",
|
||||
attempt_index=1,
|
||||
provider_bundle=bundle,
|
||||
)
|
||||
)
|
||||
|
||||
assert plan.mode == "single"
|
||||
assert plan.graph is None
|
||||
assert plan.reason == "planner_skipped_simple_task"
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_planner_builds_team_graph() -> None:
|
||||
bundle = _bundle(
|
||||
"""
|
||||
{
|
||||
"mode": "team",
|
||||
"reason": "needs parallel review",
|
||||
"strategy": "dag",
|
||||
"nodes": [
|
||||
{"node_id": "research", "task": "research options"},
|
||||
{"node_id": "review", "task": "review result", "depends_on": ["research"]}
|
||||
],
|
||||
"final_synthesis_instruction": "merge the findings"
|
||||
}
|
||||
"""
|
||||
)
|
||||
provider = bundle.main_provider
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner().plan(
|
||||
task=_task(),
|
||||
user_message="implement workflow",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle(
|
||||
"""
|
||||
{
|
||||
"mode": "team",
|
||||
"reason": "needs parallel review",
|
||||
"strategy": "dag",
|
||||
"nodes": [
|
||||
{"node_id": "research", "task": "research options", "agent": {"name": "researcher"}},
|
||||
{"node_id": "review", "task": "review result", "agent": {"name": "reviewer"}, "depends_on": ["research"]}
|
||||
],
|
||||
"final_synthesis_instruction": "merge the findings"
|
||||
}
|
||||
"""
|
||||
),
|
||||
provider_bundle=bundle,
|
||||
skill_summaries=["docker-debug: Use docker logs before editing config."],
|
||||
tool_hints=["terminal", "search_files"],
|
||||
)
|
||||
)
|
||||
|
||||
@ -116,6 +194,12 @@ def test_planner_builds_team_graph() -> None:
|
||||
assert [node.node_id for node in plan.graph.nodes] == ["research", "review"]
|
||||
assert plan.graph.nodes[1].depends_on == ["research"]
|
||||
assert plan.final_synthesis_instruction == "merge the findings"
|
||||
assert isinstance(provider, PlannerProvider)
|
||||
prompt = provider.calls[0]["messages"][1]["content"]
|
||||
assert "Activated skill summaries" in prompt
|
||||
assert "docker-debug: Use docker logs before editing config." in prompt
|
||||
assert "terminal" in prompt
|
||||
assert "search_files" in prompt
|
||||
|
||||
|
||||
def test_planner_timeout_falls_back_to_single() -> None:
|
||||
@ -134,7 +218,7 @@ def test_planner_timeout_falls_back_to_single() -> None:
|
||||
assert "TimeoutError" in (plan.fallback_error or "")
|
||||
|
||||
|
||||
def test_planner_team_nodes_can_target_skills_without_agent_roles() -> None:
|
||||
def test_planner_team_nodes_use_task_as_internal_skill_query() -> None:
|
||||
plan = TaskExecutionPlanner().from_json(
|
||||
"""
|
||||
{
|
||||
@ -144,9 +228,7 @@ def test_planner_team_nodes_can_target_skills_without_agent_roles() -> None:
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "api_review",
|
||||
"task": "review API compatibility",
|
||||
"skill_query": "API contract compatibility review",
|
||||
"required_capabilities": ["schema compatibility"]
|
||||
"task": "review API compatibility"
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -158,8 +240,77 @@ def test_planner_team_nodes_can_target_skills_without_agent_roles() -> None:
|
||||
node = plan.graph.nodes[0]
|
||||
assert node.agent.name == "api_review"
|
||||
assert node.agent.role == ""
|
||||
assert node.agent.metadata["skill_query"] == "API contract compatibility review"
|
||||
assert node.agent.metadata["required_capabilities"] == ["schema compatibility"]
|
||||
assert node.agent.metadata["skill_query"] == "review API compatibility"
|
||||
assert node.agent.metadata["required_capabilities"] == []
|
||||
|
||||
|
||||
def test_planner_accepts_use_skill_and_skill_query() -> None:
|
||||
plan = TaskExecutionPlanner().from_json(
|
||||
"""
|
||||
{
|
||||
"mode": "team",
|
||||
"strategy": "sequence",
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "collect",
|
||||
"task": "Collect official sources",
|
||||
"use_skill": "official-source-research",
|
||||
"skill_query": "official source verification"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
node = plan.graph.nodes[0]
|
||||
assert node.agent.metadata["use_skill"] == "official-source-research"
|
||||
assert node.agent.metadata["skill_query"] == "official source verification"
|
||||
assert node.inherited_pinned_skills == []
|
||||
assert node.allowed_tool_names is None
|
||||
assert plan.planner_adaptation["node_skill_bindings"] == [
|
||||
{
|
||||
"node_id": "collect",
|
||||
"use_skill": "official-source-research",
|
||||
"skill_query": "official source verification",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_planner_defaults_skill_query_to_node_task_when_absent() -> None:
|
||||
plan = TaskExecutionPlanner().from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"extract","task":"Extract financial metrics","use_skill":"financial-extraction"}]}'
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
assert plan.graph.nodes[0].agent.metadata["skill_query"] == "Extract financial metrics"
|
||||
|
||||
|
||||
def test_planner_adaptation_records_unresolved_use_skill_fallback() -> None:
|
||||
planner = TaskExecutionPlanner()
|
||||
plan = planner.from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"extract","task":"Extract metrics","use_skill":"missing-skill",'
|
||||
'"skill_query":"financial extraction"}]}'
|
||||
)
|
||||
report = SkillResolutionReport(
|
||||
node_id="extract",
|
||||
skill_query="financial extraction",
|
||||
requested_skill_name="missing-skill",
|
||||
exact_binding_used=False,
|
||||
warnings=["use_skill unresolved: missing-skill"],
|
||||
reason="matched published skill",
|
||||
)
|
||||
|
||||
planner._merge_skill_resolution_adaptation(plan, [report])
|
||||
|
||||
assert plan.planner_adaptation["warnings"] == ["use_skill unresolved: missing-skill"]
|
||||
assert plan.planner_adaptation["node_skill_bindings"][0]["fallback_reason"] == (
|
||||
"use_skill unresolved; matched published skill"
|
||||
)
|
||||
|
||||
|
||||
def test_planner_invalid_outputs_fallback_to_single() -> None:
|
||||
@ -193,3 +344,216 @@ def test_planner_invalid_outputs_fallback_to_single() -> None:
|
||||
assert unknown_strategy.mode == "single"
|
||||
assert too_many_nodes.mode == "single"
|
||||
assert cyclic.mode == "single"
|
||||
|
||||
|
||||
def test_template_plan_builds_generic_worker_and_preserves_v1_contract_fields() -> None:
|
||||
plan = TaskExecutionPlanner(tool_registry=_registry()).from_json(
|
||||
"""
|
||||
{
|
||||
"mode": "team",
|
||||
"strategy": "dag",
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "collect",
|
||||
"task": "Collect official sources",
|
||||
"requested_tools": ["web_search"],
|
||||
"evidence_contract": {"entities": ["MGM", "Galaxy"]},
|
||||
"block_downstream_on_partial": true
|
||||
}
|
||||
],
|
||||
"adaptation": {"template_used": true}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
node = plan.graph.nodes[0]
|
||||
assert node.agent.name == "collect"
|
||||
assert node.agent.role == ""
|
||||
assert node.agent.metadata["sub_agent_kind"] == "generic_skill_worker"
|
||||
assert node.allowed_tool_names == ["web_search"]
|
||||
assert node.evidence_contract == {"entities": ["MGM", "Galaxy"]}
|
||||
assert node.block_downstream_on_partial is True
|
||||
assert plan.planner_adaptation["template_used"] is True
|
||||
|
||||
|
||||
def test_unknown_tool_is_removed_and_warned() -> None:
|
||||
plan = TaskExecutionPlanner(tool_registry=_registry()).from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"collect","task":"Collect","requested_tools":["web_search","not_real"]}]}'
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
assert plan.graph.nodes[0].allowed_tool_names == ["web_search"]
|
||||
assert "unknown tool removed: not_real" in plan.planner_adaptation["warnings"]
|
||||
|
||||
|
||||
def test_high_risk_tool_is_removed_without_failing_low_risk_plan() -> None:
|
||||
plan = TaskExecutionPlanner(tool_registry=_registry()).from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"collect","task":"Collect","requested_tools":["web_search","terminal"]}]}'
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
assert plan.graph.nodes[0].allowed_tool_names == ["web_search"]
|
||||
assert "requires_high_risk_review: terminal" in plan.planner_adaptation["warnings"]
|
||||
|
||||
|
||||
def test_planner_rejects_agent_and_role_node_fields() -> None:
|
||||
planner = TaskExecutionPlanner(tool_registry=_registry())
|
||||
|
||||
agent_plan = planner.from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"collect","task":"Collect","agent":{"name":"researcher"}}]}'
|
||||
)
|
||||
role_plan = planner.from_json(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"collect","task":"Collect","role":"researcher"}]}'
|
||||
)
|
||||
|
||||
assert agent_plan.mode == "single"
|
||||
assert "agent" in (agent_plan.fallback_error or "")
|
||||
assert role_plan.mode == "single"
|
||||
assert "role" in (role_plan.fallback_error or "")
|
||||
|
||||
|
||||
def test_planner_records_primary_template_selection_and_ignored_templates() -> None:
|
||||
primary = SkillContext(
|
||||
name="financial-comparison",
|
||||
version="v1",
|
||||
content="Compare official financial disclosures.",
|
||||
team_template={"version": 1, "nodes": [{"node_id": "collect", "task": "Collect"}]},
|
||||
)
|
||||
secondary = SkillContext(
|
||||
name="chart-reporting",
|
||||
version="v2",
|
||||
content="Render chart-ready Markdown.",
|
||||
team_template={"version": 1, "nodes": [{"node_id": "report", "task": "Report"}]},
|
||||
)
|
||||
provider = PlannerProvider(
|
||||
'{"mode":"team","strategy":"sequence","nodes":['
|
||||
'{"node_id":"collect","task":"Collect official sources"}],'
|
||||
'"adaptation":{"template_used":true}}'
|
||||
)
|
||||
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner(tool_registry=_registry()).plan(
|
||||
task=_task(),
|
||||
user_message="compare financial workflow",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle_with_provider(provider),
|
||||
activated_skills=[primary, secondary],
|
||||
)
|
||||
)
|
||||
|
||||
assert plan.planner_adaptation == {
|
||||
"template_used": True,
|
||||
"selected_template": "financial-comparison",
|
||||
"selection_reason": "first activated skill with a valid team template",
|
||||
"ignored_templates": ["chart-reporting"],
|
||||
"warnings": [],
|
||||
}
|
||||
prompt = provider.calls[0]["messages"][1]["content"]
|
||||
assert '"skill_name": "financial-comparison"' in prompt
|
||||
assert "Compare official financial disclosures." in prompt
|
||||
assert "Render chart-ready Markdown." in prompt
|
||||
|
||||
|
||||
def test_malformed_planner_output_repairs_once_without_tools() -> None:
|
||||
provider = SequencedPlannerProvider(
|
||||
[
|
||||
"not json",
|
||||
'{"mode":"team","strategy":"sequence","nodes":[{"node_id":"collect","task":"Collect"}]}',
|
||||
]
|
||||
)
|
||||
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner(tool_registry=_registry()).plan(
|
||||
task=_task(),
|
||||
user_message="implement workflow",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle_with_provider(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert len(provider.calls) == 2
|
||||
assert provider.calls[1]["tools"] is None
|
||||
assert "Repair the invalid planner JSON" in provider.calls[1]["messages"][1]["content"]
|
||||
|
||||
|
||||
def test_failed_planner_repair_falls_back_to_single() -> None:
|
||||
provider = SequencedPlannerProvider(["not json", "still not json"])
|
||||
|
||||
plan = asyncio.run(
|
||||
TaskExecutionPlanner(tool_registry=_registry()).plan(
|
||||
task=_task(),
|
||||
user_message="implement workflow",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle_with_provider(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert plan.mode == "single"
|
||||
assert plan.reason == "planner_fallback_single"
|
||||
assert len(provider.calls) == 2
|
||||
|
||||
|
||||
def test_finance_template_adapts_to_task_oriented_read_only_graph() -> None:
|
||||
plan = TaskExecutionPlanner(tool_registry=_registry()).from_json(
|
||||
"""
|
||||
{
|
||||
"mode": "team",
|
||||
"strategy": "dag",
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "collect_official_sources",
|
||||
"task": "Collect MGM and Galaxy official financial disclosures",
|
||||
"requested_tools": ["web_search", "web_fetch"],
|
||||
"required_evidence": ["tool_result", "url"]
|
||||
},
|
||||
{
|
||||
"node_id": "extract_financial_metrics",
|
||||
"task": "Extract comparable financial metrics from collected sources",
|
||||
"depends_on": ["collect_official_sources"],
|
||||
"requested_tools": ["web_fetch"],
|
||||
"required_evidence": ["output"]
|
||||
},
|
||||
{
|
||||
"node_id": "validate_metrics",
|
||||
"task": "Validate metric units, periods, and source consistency",
|
||||
"depends_on": ["extract_financial_metrics"],
|
||||
"required_evidence": ["output"]
|
||||
},
|
||||
{
|
||||
"node_id": "generate_chart_report",
|
||||
"task": "Generate a Markdown comparison table and chart-ready data without claiming an image or file artifact",
|
||||
"depends_on": ["validate_metrics"],
|
||||
"requested_tools": [],
|
||||
"required_evidence": ["output"]
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
assert plan.is_team
|
||||
assert plan.graph is not None
|
||||
assert [node.node_id for node in plan.graph.nodes] == [
|
||||
"collect_official_sources",
|
||||
"extract_financial_metrics",
|
||||
"validate_metrics",
|
||||
"generate_chart_report",
|
||||
]
|
||||
assert all(node.agent.role == "" for node in plan.graph.nodes)
|
||||
assert not {"researcher", "writer", "reviewer", "analyst"}.intersection(
|
||||
node.node_id for node in plan.graph.nodes
|
||||
)
|
||||
assert plan.graph.nodes[0].allowed_tool_names == ["web_search", "web_fetch"]
|
||||
assert plan.graph.nodes[-1].allowed_tool_names == []
|
||||
report_task = plan.graph.nodes[-1].task.lower()
|
||||
assert "markdown" in report_task
|
||||
assert "without claiming an image or file artifact" in report_task
|
||||
|
||||
@ -4,10 +4,12 @@ import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from beaver.engine import EngineLoader
|
||||
from beaver.engine import AgentRunResult, EngineLoader
|
||||
from beaver.engine.context import SkillContext
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.services.agent_service import AgentService
|
||||
from beaver.skills.assembler import SkillAssemblyResult
|
||||
from beaver.tasks import TaskExecutionPlan, TaskService
|
||||
|
||||
|
||||
@ -39,6 +41,44 @@ class StubTaskExecutionPlanner:
|
||||
return TaskExecutionPlan.single("test-single")
|
||||
|
||||
|
||||
class RecordingTaskExecutionPlanner:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def plan(self, **kwargs) -> TaskExecutionPlan:
|
||||
self.calls.append(dict(kwargs))
|
||||
return TaskExecutionPlan.single("test-single")
|
||||
|
||||
|
||||
class RecordingSkillAssembler:
|
||||
def __init__(self, skills: list[SkillContext]) -> None:
|
||||
self.skills = list(skills)
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def assemble(self, **kwargs) -> SkillAssemblyResult:
|
||||
self.calls.append(dict(kwargs))
|
||||
return SkillAssemblyResult(activated_skills=list(self.skills))
|
||||
|
||||
|
||||
class RecordingTaskAttemptOrchestrator:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def run(self, **kwargs) -> AgentRunResult:
|
||||
self.calls.append(dict(kwargs))
|
||||
task = kwargs["task"]
|
||||
task.task_id = "task-from-orchestrator"
|
||||
return AgentRunResult(
|
||||
session_id=kwargs["kwargs"]["session_id"],
|
||||
run_id="run-from-orchestrator",
|
||||
output_text="orchestrated",
|
||||
finish_reason="stop",
|
||||
tool_iterations=0,
|
||||
task_id=task.task_id,
|
||||
task_status=task.status,
|
||||
)
|
||||
|
||||
|
||||
class FakeLearningCandidate:
|
||||
def to_dict(self) -> dict:
|
||||
return {"candidate_id": "candidate-1", "kind": "new_skill", "status": "open"}
|
||||
@ -101,6 +141,91 @@ def test_task_run_records_evidence_and_waits_for_acceptance(tmp_path: Path) -> N
|
||||
assert "validated" not in event_types
|
||||
|
||||
|
||||
def test_agent_service_records_router_latency(tmp_path: Path) -> None:
|
||||
service = AgentService(
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
task_execution_planner=StubTaskExecutionPlanner(),
|
||||
)
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.process_direct(
|
||||
"draft release notes",
|
||||
session_id="web:latency",
|
||||
provider_bundle=_bundle("Done"),
|
||||
)
|
||||
)
|
||||
|
||||
latency = result.usage["latency_ms"]
|
||||
assert latency["router_ms"] > 0
|
||||
|
||||
|
||||
def test_task_mode_preselects_skills_for_planner_and_reuses_them_in_main_run(tmp_path: Path) -> None:
|
||||
skill = SkillContext(
|
||||
name="docker-debug",
|
||||
content="Use docker logs before editing config.",
|
||||
version="v1",
|
||||
content_hash="hash-v1",
|
||||
activation_reason="llm_selected",
|
||||
tool_hints=["terminal"],
|
||||
)
|
||||
skill_assembler = RecordingSkillAssembler([skill])
|
||||
planner = RecordingTaskExecutionPlanner()
|
||||
service = AgentService(
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
skill_assembler=skill_assembler,
|
||||
task_execution_planner=planner,
|
||||
)
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.process_direct(
|
||||
"debug this workflow",
|
||||
session_id="web:skill-aware-task",
|
||||
provider_bundle=_bundle("Done"),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.task_id
|
||||
assert len(skill_assembler.calls) == 1
|
||||
assert planner.calls
|
||||
assert planner.calls[0]["skill_summaries"] == ["docker-debug: Use docker logs before editing config."]
|
||||
assert planner.calls[0]["tool_hints"] == ["terminal"]
|
||||
|
||||
task_service = service.create_loop().boot().task_service
|
||||
assert task_service is not None
|
||||
task = task_service.get_task(result.task_id)
|
||||
assert task is not None
|
||||
assert task.skill_names == ["docker-debug"]
|
||||
|
||||
|
||||
def test_task_mode_delegates_attempt_execution_to_orchestrator(tmp_path: Path) -> None:
|
||||
orchestrator = RecordingTaskAttemptOrchestrator()
|
||||
service = AgentService(
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
task_execution_planner=StubTaskExecutionPlanner(),
|
||||
)
|
||||
)
|
||||
service._build_task_attempt_orchestrator = lambda loaded: orchestrator # type: ignore[attr-defined]
|
||||
|
||||
result = asyncio.run(
|
||||
service.process_direct(
|
||||
"draft release notes",
|
||||
session_id="web:orchestrator",
|
||||
provider_bundle=_bundle("main runner should not be used"),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.output_text == "orchestrated"
|
||||
assert result.run_id == "run-from-orchestrator"
|
||||
assert len(orchestrator.calls) == 1
|
||||
assert orchestrator.calls[0]["message"] == "draft release notes"
|
||||
assert orchestrator.calls[0]["task"].description == "draft release notes"
|
||||
|
||||
|
||||
def test_task_mode_injects_prompt_locale_output_language(tmp_path: Path) -> None:
|
||||
service = AgentService(
|
||||
loader=EngineLoader(
|
||||
|
||||
@ -222,3 +222,179 @@ def test_task_skill_resolver_keeps_summary_nodes_skillless(tmp_path: Path) -> No
|
||||
assert reports[0].ephemeral_used is False
|
||||
assert reports[0].reason == "summary node uses dependency outputs directly"
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_resolver_exact_binds_use_skill_before_dynamic_lookup(tmp_path: Path) -> None:
|
||||
_publish_skill(tmp_path, skill_name="official-source-research")
|
||||
provider = RecordingProvider(['["wrong-dynamic-skill"]'])
|
||||
resolver = TaskSkillResolver(
|
||||
skills_loader=SkillsLoader(tmp_path),
|
||||
draft_service=DraftService(SkillSpecStore(tmp_path)),
|
||||
)
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"collect",
|
||||
"Collect official sources",
|
||||
AgentDescriptor(
|
||||
name="collect",
|
||||
metadata={
|
||||
"use_skill": "official-source-research",
|
||||
"skill_query": "generic web research",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
resolved, reports = asyncio.run(
|
||||
resolver.resolve_graph(
|
||||
graph,
|
||||
task=_task(),
|
||||
user_message="collect sources",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
node = resolved.nodes[0]
|
||||
assert node.inherited_pinned_skills == ["official-source-research"]
|
||||
assert [context.name for context in node.inherited_pinned_skill_contexts] == ["official-source-research"]
|
||||
assert node.agent.metadata["exact_binding_used"] is True
|
||||
assert reports[0].selected_skill_names == ["official-source-research"]
|
||||
assert reports[0].exact_binding_used is True
|
||||
assert reports[0].warnings == []
|
||||
assert provider.calls == []
|
||||
|
||||
|
||||
def test_resolver_falls_back_to_skill_query_when_use_skill_missing(tmp_path: Path) -> None:
|
||||
_publish_skill(tmp_path, skill_name="financial-metric-extraction")
|
||||
provider = RecordingProvider(['["financial-metric-extraction"]'])
|
||||
resolver = TaskSkillResolver(
|
||||
skills_loader=SkillsLoader(tmp_path),
|
||||
draft_service=DraftService(SkillSpecStore(tmp_path)),
|
||||
)
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"extract",
|
||||
"Extract metrics",
|
||||
AgentDescriptor(
|
||||
name="extract",
|
||||
metadata={
|
||||
"use_skill": "missing-exact-skill",
|
||||
"skill_query": "financial metric extraction",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
resolved, reports = asyncio.run(
|
||||
resolver.resolve_graph(
|
||||
graph,
|
||||
task=_task(),
|
||||
user_message="extract financial metrics",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert resolved.nodes[0].inherited_pinned_skills == ["financial-metric-extraction"]
|
||||
assert reports[0].exact_binding_used is False
|
||||
assert reports[0].selected_skill_names == ["financial-metric-extraction"]
|
||||
assert reports[0].warnings == ["use_skill unresolved: missing-exact-skill"]
|
||||
assert "financial metric extraction" in provider.calls[0][1]["content"]
|
||||
|
||||
|
||||
def test_resolver_falls_back_to_ephemeral_when_exact_and_query_miss(tmp_path: Path) -> None:
|
||||
_publish_skill(tmp_path, skill_name="unrelated-skill")
|
||||
provider = RecordingProvider(
|
||||
[
|
||||
"[]",
|
||||
"""
|
||||
{
|
||||
"guidance_name": "financial-extraction-guidance",
|
||||
"description": "Extract financial metrics",
|
||||
"content": "# Financial Extraction\\n\\nExtract the requested metrics.",
|
||||
"tags": ["finance"]
|
||||
}
|
||||
""",
|
||||
]
|
||||
)
|
||||
resolver = TaskSkillResolver(
|
||||
skills_loader=SkillsLoader(tmp_path),
|
||||
draft_service=DraftService(SkillSpecStore(tmp_path)),
|
||||
missing_skill_synthesizer=EphemeralGuidanceSynthesizer(),
|
||||
)
|
||||
graph = ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"extract",
|
||||
"Extract metrics",
|
||||
AgentDescriptor(
|
||||
name="extract",
|
||||
metadata={
|
||||
"use_skill": "missing-exact-skill",
|
||||
"skill_query": "financial metric extraction",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
resolved, reports = asyncio.run(
|
||||
resolver.resolve_graph(
|
||||
graph,
|
||||
task=_task(),
|
||||
user_message="extract financial metrics",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert resolved.nodes[0].inherited_pinned_skills == []
|
||||
assert resolved.nodes[0].inherited_pinned_skill_contexts[0].name == "ephemeral:financial-extraction-guidance"
|
||||
assert reports[0].ephemeral_used is True
|
||||
assert reports[0].warnings == ["use_skill unresolved: missing-exact-skill"]
|
||||
|
||||
|
||||
def test_explicit_use_skill_is_preserved_for_summary_without_nested_expansion(tmp_path: Path) -> None:
|
||||
_publish_skill(tmp_path, skill_name="summary-formatting")
|
||||
provider = RecordingProvider([])
|
||||
resolver = TaskSkillResolver(
|
||||
skills_loader=SkillsLoader(tmp_path),
|
||||
draft_service=DraftService(SkillSpecStore(tmp_path)),
|
||||
)
|
||||
graph = ExecutionGraph(
|
||||
strategy="dag",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
"summarize",
|
||||
"Compile a summary from dependency outputs",
|
||||
AgentDescriptor(
|
||||
name="summarize",
|
||||
metadata={"use_skill": "summary-formatting", "skill_query": "Summarization"},
|
||||
),
|
||||
depends_on=["collect"],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
resolved, reports = asyncio.run(
|
||||
resolver.resolve_graph(
|
||||
graph,
|
||||
task=_task(),
|
||||
user_message="summarize",
|
||||
attempt_index=1,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert len(resolved.nodes) == 1
|
||||
assert resolved.nodes[0].inherited_pinned_skills == ["summary-formatting"]
|
||||
assert reports[0].exact_binding_used is True
|
||||
assert provider.calls == []
|
||||
|
||||
@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from beaver.coordinator import AgentDescriptor, ExecutionGraph, ExecutionNode, NodeRunResult, TeamRunResult
|
||||
from beaver.engine import AgentRunResult
|
||||
from beaver.tasks import TaskExecutionPlan, TaskRecord
|
||||
from beaver.tasks.attempt_orchestrator import TaskAttemptOrchestrator
|
||||
|
||||
|
||||
def _plan(*, optional_second: bool = False) -> TaskExecutionPlan:
|
||||
return TaskExecutionPlan(
|
||||
mode="team",
|
||||
reason="test team",
|
||||
graph=ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode("collect", "Collect", AgentDescriptor(name="collect")),
|
||||
ExecutionNode(
|
||||
"report",
|
||||
"Report",
|
||||
AgentDescriptor(name="report"),
|
||||
required_for_completion=not optional_second,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _team_result(*results: NodeRunResult) -> TeamRunResult:
|
||||
return TeamRunResult(
|
||||
success=all(result.success for result in results),
|
||||
summary="team summary",
|
||||
node_results=list(results),
|
||||
)
|
||||
|
||||
|
||||
def _result(node_id: str, status: str, *, gaps: list[str] | None = None) -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
node_id=node_id,
|
||||
success=status == "succeeded",
|
||||
output_text=f"{node_id} output",
|
||||
finish_reason="blocked" if status == "blocked" else "stop",
|
||||
error=None if status == "succeeded" else f"{status} node",
|
||||
completion_status=status,
|
||||
evidence_gaps=list(gaps or []),
|
||||
)
|
||||
|
||||
|
||||
def test_required_partial_node_marks_synthesis_incomplete() -> None:
|
||||
context, prefix, metadata = TaskAttemptOrchestrator._team_synthesis_outcome(
|
||||
_plan(),
|
||||
_team_result(
|
||||
_result("collect", "partial", gaps=["missing required evidence: url"]),
|
||||
_result("report", "succeeded"),
|
||||
),
|
||||
)
|
||||
|
||||
assert metadata["task_outcome"] == "incomplete"
|
||||
assert metadata["incomplete_node_ids"] == ["collect"]
|
||||
assert metadata["evidence_gaps"] == {"collect": ["missing required evidence: url"]}
|
||||
assert "Task outcome: incomplete" in context
|
||||
assert "missing required evidence: url" in context
|
||||
assert prefix.startswith("任务未完成:")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status", ["failed", "blocked"])
|
||||
def test_required_failed_or_blocked_node_marks_synthesis_incomplete(status: str) -> None:
|
||||
_, prefix, metadata = TaskAttemptOrchestrator._team_synthesis_outcome(
|
||||
_plan(),
|
||||
_team_result(_result("collect", status), _result("report", "succeeded")),
|
||||
)
|
||||
|
||||
assert metadata["task_outcome"] == "incomplete"
|
||||
assert metadata["incomplete_node_ids"] == ["collect"]
|
||||
assert metadata["node_statuses"]["collect"] == status
|
||||
assert prefix
|
||||
|
||||
|
||||
def test_optional_failed_node_does_not_force_incomplete() -> None:
|
||||
context, prefix, metadata = TaskAttemptOrchestrator._team_synthesis_outcome(
|
||||
_plan(optional_second=True),
|
||||
_team_result(_result("collect", "succeeded"), _result("report", "failed")),
|
||||
)
|
||||
|
||||
assert metadata["task_outcome"] == "complete"
|
||||
assert metadata["incomplete_node_ids"] == []
|
||||
assert "Task outcome: complete" in context
|
||||
assert prefix == ""
|
||||
|
||||
|
||||
def test_all_required_nodes_succeeded_is_complete() -> None:
|
||||
_, prefix, metadata = TaskAttemptOrchestrator._team_synthesis_outcome(
|
||||
_plan(),
|
||||
_team_result(_result("collect", "succeeded"), _result("report", "succeeded")),
|
||||
)
|
||||
|
||||
assert metadata["task_outcome"] == "complete"
|
||||
assert prefix == ""
|
||||
|
||||
|
||||
def test_single_plan_outcome_does_not_add_prefix() -> None:
|
||||
context, prefix, metadata = TaskAttemptOrchestrator._team_synthesis_outcome(
|
||||
TaskExecutionPlan.single("single"),
|
||||
None,
|
||||
)
|
||||
|
||||
assert metadata["task_outcome"] == "single"
|
||||
assert "Task outcome: single" in context
|
||||
assert prefix == ""
|
||||
|
||||
|
||||
class FakeTaskService:
|
||||
def start_run(self, task_id: str, **_: Any) -> None:
|
||||
return None
|
||||
|
||||
def append_run(self, task_id: str, run_id: str, **_: Any) -> TaskRecord:
|
||||
return self.task
|
||||
|
||||
|
||||
class FakeSessionManager:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[dict[str, Any]] = []
|
||||
|
||||
def append_message(self, session_id: str, **kwargs: Any) -> None:
|
||||
self.events.append({"session_id": session_id, **kwargs})
|
||||
|
||||
def update_latest_assistant_event_payload(self, *args: Any, **kwargs: Any) -> None:
|
||||
return None
|
||||
|
||||
def get_run_event_records(self, session_id: str, run_id: str) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
class FixedPlanner:
|
||||
def __init__(self, plan: TaskExecutionPlan) -> None:
|
||||
self.fixed_plan = plan
|
||||
|
||||
async def plan(self, **_: Any) -> TaskExecutionPlan:
|
||||
return self.fixed_plan
|
||||
|
||||
|
||||
def _task() -> TaskRecord:
|
||||
return TaskRecord(
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
description="finance comparison",
|
||||
goal="finance comparison",
|
||||
constraints=[],
|
||||
priority=0,
|
||||
status="open",
|
||||
creator="test",
|
||||
created_at="now",
|
||||
updated_at="now",
|
||||
)
|
||||
|
||||
|
||||
def test_incomplete_team_still_runs_tool_free_synthesis_and_prefixes_output() -> None:
|
||||
plan = _plan()
|
||||
team_result = _team_result(
|
||||
_result("collect", "partial", gaps=["missing required evidence: url"]),
|
||||
_result("report", "succeeded"),
|
||||
)
|
||||
task = _task()
|
||||
task_service = FakeTaskService()
|
||||
task_service.task = task
|
||||
session_manager = FakeSessionManager()
|
||||
loaded = SimpleNamespace(
|
||||
task_service=task_service,
|
||||
task_execution_planner=FixedPlanner(plan),
|
||||
session_manager=session_manager,
|
||||
run_memory_store=None,
|
||||
)
|
||||
orchestrator = TaskAttemptOrchestrator(
|
||||
loaded=loaded,
|
||||
create_loop=lambda: None,
|
||||
make_provider_bundle_for_task=lambda *_: None,
|
||||
)
|
||||
|
||||
async def fake_run_team(*args: Any, **kwargs: Any) -> tuple[TeamRunResult, None]:
|
||||
return team_result, None
|
||||
|
||||
runner_calls: list[dict[str, Any]] = []
|
||||
|
||||
async def runner(message: str, **kwargs: Any) -> AgentRunResult:
|
||||
runner_calls.append(kwargs)
|
||||
return AgentRunResult(
|
||||
session_id="session-1",
|
||||
run_id="main-run",
|
||||
output_text="Available financial comparison.",
|
||||
finish_reason="stop",
|
||||
tool_iterations=0,
|
||||
)
|
||||
|
||||
orchestrator._run_team_for_task = fake_run_team # type: ignore[method-assign]
|
||||
result = asyncio.run(
|
||||
orchestrator.run(
|
||||
message="compare finance",
|
||||
runner=runner,
|
||||
kwargs={
|
||||
"session_id": "session-1",
|
||||
"provider_bundle": SimpleNamespace(),
|
||||
"include_skill_assembly": False,
|
||||
},
|
||||
task=task,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(runner_calls) == 1
|
||||
assert runner_calls[0]["include_tools"] is False
|
||||
assert runner_calls[0]["max_tool_iterations"] == 0
|
||||
assert "Task outcome: incomplete" in runner_calls[0]["execution_context"]
|
||||
assert result.output_text.startswith("任务未完成:")
|
||||
synthesis_event = [event for event in session_manager.events if event.get("event_type") == "task_synthesis_completed"][0]
|
||||
assert synthesis_event["event_payload"]["task_outcome"] == "incomplete"
|
||||
assert synthesis_event["event_payload"]["incomplete_node_ids"] == ["collect"]
|
||||
assert synthesis_event["event_payload"]["node_statuses"] == {
|
||||
"collect": "partial",
|
||||
"report": "succeeded",
|
||||
}
|
||||
assert synthesis_event["event_payload"]["evidence_gaps"] == {
|
||||
"collect": ["missing required evidence: url"]
|
||||
}
|
||||
|
||||
|
||||
def test_incomplete_notice_is_not_prefixed_twice() -> None:
|
||||
text = "任务未完成:缺少官方来源。"
|
||||
|
||||
assert TaskAttemptOrchestrator._apply_incomplete_prefix(text, "任务未完成:部分步骤缺少证据。\n\n") == text
|
||||
231
app-instance/backend/tests/unit/test_team_node_tool_policy.py
Normal file
231
app-instance/backend/tests/unit/test_team_node_tool_policy.py
Normal file
@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from beaver.coordinator import AgentDescriptor, DelegationEnvelope, ExecutionGraph, ExecutionNode, NodeRunResult
|
||||
from beaver.coordinator.execution.scheduler import TeamGraphScheduler
|
||||
from beaver.coordinator.local import LocalAgentRunner
|
||||
from beaver.engine import AgentLoop, EngineLoader
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.tools import BaseTool, ToolContext, ToolExecutor, ToolRegistry, ToolResult, ToolSpec
|
||||
|
||||
|
||||
class RecordingProvider(LLMProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float = 0.7,
|
||||
thinking_enabled: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
self.calls.append({"messages": messages, "tools": tools})
|
||||
return LLMResponse(content="done", finish_reason="stop", provider_name="stub", model="stub")
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "stub"
|
||||
|
||||
|
||||
class StaticToolAssembler:
|
||||
def __init__(self, specs: list[ToolSpec]) -> None:
|
||||
self.specs = specs
|
||||
|
||||
async def assemble(self, **_: Any) -> list[ToolSpec]:
|
||||
return list(self.specs)
|
||||
|
||||
|
||||
class StubTool(BaseTool):
|
||||
def __init__(self, name: str) -> None:
|
||||
self._spec = ToolSpec(name=name, description=name, input_schema={"type": "object"})
|
||||
self.calls = 0
|
||||
|
||||
@property
|
||||
def spec(self) -> ToolSpec:
|
||||
return self._spec
|
||||
|
||||
async def invoke(self, arguments: dict[str, Any], context: ToolContext) -> ToolResult:
|
||||
self.calls += 1
|
||||
return ToolResult(True, "called", self.spec.name)
|
||||
|
||||
|
||||
class CapturingRunner:
|
||||
def __init__(self) -> None:
|
||||
self.envelopes: list[DelegationEnvelope] = []
|
||||
|
||||
async def run(self, envelope: DelegationEnvelope, **_: Any) -> NodeRunResult:
|
||||
self.envelopes.append(envelope)
|
||||
return NodeRunResult(
|
||||
node_id=envelope.node_id or envelope.agent.name,
|
||||
success=True,
|
||||
output_text="done",
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
|
||||
def _bundle(provider: LLMProvider) -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub", provider_name="stub"),
|
||||
main_provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def _loop(tmp_path: Path) -> AgentLoop:
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||
loaded = loop.boot()
|
||||
specs = [loaded.tool_registry.get(name).spec for name in ("read_file", "web_search")]
|
||||
loaded.tool_assembler = StaticToolAssembler(specs) # type: ignore[assignment]
|
||||
return loop
|
||||
|
||||
|
||||
def _tool_names(tools: list[dict] | None) -> list[str]:
|
||||
return [str(tool["function"]["name"]) for tool in tools or []]
|
||||
|
||||
|
||||
def _graph(allowed_tool_names: list[str] | None) -> ExecutionGraph:
|
||||
return ExecutionGraph(
|
||||
strategy="sequence",
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
node_id="collect",
|
||||
task="collect",
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
allowed_tool_names=allowed_tool_names,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_none_tool_scope_preserves_legacy_selection(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
|
||||
asyncio.run(
|
||||
loop.process_direct(
|
||||
"collect",
|
||||
allowed_tool_names=None,
|
||||
include_skill_assembly=False,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert _tool_names(provider.calls[0]["tools"]) == ["read_file", "web_search"]
|
||||
|
||||
|
||||
def test_empty_tool_scope_exposes_no_tools(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
|
||||
asyncio.run(
|
||||
loop.process_direct(
|
||||
"collect",
|
||||
allowed_tool_names=[],
|
||||
include_skill_assembly=False,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert _tool_names(provider.calls[0]["tools"]) == []
|
||||
|
||||
|
||||
def test_named_tool_scope_exposes_only_allowed_schema(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
|
||||
asyncio.run(
|
||||
loop.process_direct(
|
||||
"collect",
|
||||
allowed_tool_names=["web_search"],
|
||||
include_skill_assembly=False,
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert _tool_names(provider.calls[0]["tools"]) == ["web_search"]
|
||||
|
||||
|
||||
def test_executor_rejects_registered_tool_outside_node_allowlist() -> None:
|
||||
registry = ToolRegistry()
|
||||
write_file = StubTool("write_file")
|
||||
registry.register(write_file)
|
||||
executor = ToolExecutor(registry)
|
||||
context = ToolContext(metadata={"allowed_tool_names": ["web_search"]})
|
||||
|
||||
result = asyncio.run(executor.execute("write_file", {"path": "x"}, context=context))
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "tool_not_allowed"
|
||||
assert write_file.calls == 0
|
||||
|
||||
|
||||
def test_local_agent_runner_passes_node_tool_scope(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
envelope = DelegationEnvelope(
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
parent_run_id="run-root",
|
||||
agent=AgentDescriptor(name="collect"),
|
||||
task="collect",
|
||||
node_id="collect",
|
||||
allowed_tool_names=[],
|
||||
)
|
||||
|
||||
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||
|
||||
assert result.success is True
|
||||
assert _tool_names(provider.calls[0]["tools"]) == []
|
||||
|
||||
|
||||
def test_scheduler_copies_named_node_tool_scope_to_envelope() -> None:
|
||||
runner = CapturingRunner()
|
||||
|
||||
asyncio.run(
|
||||
TeamGraphScheduler(runner).run( # type: ignore[arg-type]
|
||||
_graph(["web_search"]),
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
)
|
||||
)
|
||||
|
||||
assert runner.envelopes[0].allowed_tool_names == ["web_search"]
|
||||
|
||||
|
||||
def test_empty_tool_scope_reaches_provider_through_real_team_path(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
|
||||
asyncio.run(
|
||||
TeamGraphScheduler(LocalAgentRunner(loop)).run(
|
||||
_graph([]),
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert _tool_names(provider.calls[0]["tools"]) == []
|
||||
|
||||
|
||||
def test_none_tool_scope_preserves_tools_through_real_team_path(tmp_path: Path) -> None:
|
||||
loop = _loop(tmp_path)
|
||||
provider = RecordingProvider()
|
||||
|
||||
asyncio.run(
|
||||
TeamGraphScheduler(LocalAgentRunner(loop)).run(
|
||||
_graph(None),
|
||||
parent_task_id="task-parent",
|
||||
parent_session_id="session-root",
|
||||
provider_bundle=_bundle(provider),
|
||||
)
|
||||
)
|
||||
|
||||
assert _tool_names(provider.calls[0]["tools"]) == ["read_file", "web_search"]
|
||||
@ -11,6 +11,7 @@ from beaver.services.user_files import (
|
||||
UserFileNotFoundError,
|
||||
UserFilePathError,
|
||||
UserFileSizeError,
|
||||
UserFileStorageError,
|
||||
UserFileService,
|
||||
normalize_user_path,
|
||||
)
|
||||
@ -151,3 +152,68 @@ def test_minio_storage_rejects_paths_that_escape_namespace() -> None:
|
||||
|
||||
with pytest.raises(UserFilePathError):
|
||||
storage._user_path("users/bob/uploads/secret.txt")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minio_storage_translates_s3_errors_to_user_file_errors() -> None:
|
||||
from minio.error import S3Error
|
||||
|
||||
class FakeMinioClient:
|
||||
def list_objects(self, *args, **kwargs):
|
||||
raise S3Error(
|
||||
None,
|
||||
"SignatureDoesNotMatch",
|
||||
"The request signature we calculated does not match",
|
||||
"/beaver-user-files",
|
||||
"request-id",
|
||||
"host-id",
|
||||
bucket_name="beaver-user-files",
|
||||
)
|
||||
|
||||
storage = object.__new__(MinIOUserFileStorage)
|
||||
storage.config = MinIOStorageConfig(
|
||||
endpoint="minio.local:9000",
|
||||
access_key="alice-access",
|
||||
secret_key="alice-secret",
|
||||
bucket="beaver-user-files",
|
||||
namespace="users/alice",
|
||||
)
|
||||
storage.client = FakeMinioClient()
|
||||
|
||||
with pytest.raises(UserFileStorageError) as exc_info:
|
||||
await storage.list_dir("uploads")
|
||||
|
||||
assert "SignatureDoesNotMatch" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minio_storage_does_not_report_auth_errors_as_missing_files() -> None:
|
||||
from minio.error import S3Error
|
||||
|
||||
class FakeMinioClient:
|
||||
def stat_object(self, *args, **kwargs):
|
||||
raise S3Error(
|
||||
None,
|
||||
"SignatureDoesNotMatch",
|
||||
"The request signature we calculated does not match",
|
||||
"/beaver-user-files/uploads/input.txt",
|
||||
"request-id",
|
||||
"host-id",
|
||||
bucket_name="beaver-user-files",
|
||||
object_name="users/alice/uploads/input.txt",
|
||||
)
|
||||
|
||||
storage = object.__new__(MinIOUserFileStorage)
|
||||
storage.config = MinIOStorageConfig(
|
||||
endpoint="minio.local:9000",
|
||||
access_key="alice-access",
|
||||
secret_key="alice-secret",
|
||||
bucket="beaver-user-files",
|
||||
namespace="users/alice",
|
||||
)
|
||||
storage.client = FakeMinioClient()
|
||||
|
||||
with pytest.raises(UserFileStorageError) as exc_info:
|
||||
await storage.read_file("uploads/input.txt")
|
||||
|
||||
assert "SignatureDoesNotMatch" in str(exc_info.value)
|
||||
|
||||
@ -7,7 +7,7 @@ from fastapi.testclient import TestClient
|
||||
from beaver.interfaces.web.app import create_app
|
||||
from beaver.services.agent_service import AgentService
|
||||
from beaver.services.user_file_resolver import UserFileStorageResolver
|
||||
from beaver.services.user_files import LocalUserFileStorage, UserFileService
|
||||
from beaver.services.user_files import LocalUserFileStorage, UserFileService, UserFileStorageError
|
||||
|
||||
|
||||
def _auth_headers(app, username: str = "alice") -> dict[str, str]:
|
||||
@ -191,6 +191,26 @@ def test_user_files_api_authenticated_request_resolves_identity(tmp_path: Path,
|
||||
assert seen[0].storage_namespace == "users/alice"
|
||||
|
||||
|
||||
def test_user_files_api_reports_storage_errors_as_unavailable(tmp_path: Path, monkeypatch) -> None:
|
||||
service = AgentService(workspace=tmp_path)
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
|
||||
class BrokenStorage:
|
||||
async def list_dir(self, path: str):
|
||||
raise UserFileStorageError("User file storage list directory failed: SignatureDoesNotMatch")
|
||||
|
||||
async def fake_service(self):
|
||||
return UserFileService(BrokenStorage())
|
||||
|
||||
monkeypatch.setattr(UserFileStorageResolver, "service", fake_service)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/user-files/browse", params={"path": "uploads"}, headers=_auth_headers(app))
|
||||
|
||||
assert response.status_code == 503
|
||||
assert "SignatureDoesNotMatch" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_user_files_api_streams_upload_and_enforces_configured_limit(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setenv("BEAVER_USER_FILES_MAX_UPLOAD_BYTES", "5")
|
||||
service = AgentService(workspace=tmp_path)
|
||||
|
||||
@ -2,23 +2,43 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
|
||||
from beaver.tools.builtins import web
|
||||
|
||||
|
||||
def _disable_ddgs(monkeypatch) -> None:
|
||||
def _raise_unavailable(query: str, limit: int) -> list[dict[str, str]]:
|
||||
raise ModuleNotFoundError("ddgs disabled for fallback test")
|
||||
|
||||
monkeypatch.setattr(web, "_search_ddgs", _raise_unavailable)
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
headers = {"content-type": "text/html"}
|
||||
status_code = 200
|
||||
fetch_html = """
|
||||
<html>
|
||||
<head><title>Investor Reports</title></head>
|
||||
<body>
|
||||
<a href="/reports/2025-annual.pdf">2025 Annual Report</a>
|
||||
<a href="https://example.com/investor">Investor Centre</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
def __init__(self, url: str = "https://example.com") -> None:
|
||||
self.url = url
|
||||
if "duckduckgo.com" in url:
|
||||
self.text = '<a class="result__a" href="https://duck.example.com">Duck Example</a>'
|
||||
else:
|
||||
elif "bing.com" in url:
|
||||
self.text = (
|
||||
'<li class="b_algo"><h2><a href="https://example.com">Example</a></h2>'
|
||||
"<p>Example result</p></li>"
|
||||
)
|
||||
else:
|
||||
self.text = self.fetch_html
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
@ -48,6 +68,7 @@ class _FakeAsyncClient:
|
||||
|
||||
def test_web_tools_use_environment_proxy_settings(monkeypatch) -> None:
|
||||
_FakeAsyncClient.calls = []
|
||||
_disable_ddgs(monkeypatch)
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
async def _run() -> None:
|
||||
@ -73,10 +94,39 @@ def test_web_fetch_uses_short_connect_timeout(monkeypatch) -> None:
|
||||
assert timeout.read == 12
|
||||
|
||||
|
||||
def test_web_fetch_returns_page_title_and_links(monkeypatch) -> None:
|
||||
_FakeAsyncClient.calls = []
|
||||
_FakeAsyncClient.urls = []
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
raw = asyncio.run(web.WebFetchTool().execute(url="https://example.com/investor"))
|
||||
|
||||
payload = json.loads(raw)
|
||||
assert payload["success"] is True
|
||||
assert payload["title"] == "Investor Reports"
|
||||
assert payload["links"] == [
|
||||
{
|
||||
"text": "2025 Annual Report",
|
||||
"url": "https://example.com/reports/2025-annual.pdf",
|
||||
},
|
||||
{
|
||||
"text": "Investor Centre",
|
||||
"url": "https://example.com/investor",
|
||||
},
|
||||
]
|
||||
assert payload["pdf_links"] == [
|
||||
{
|
||||
"text": "2025 Annual Report",
|
||||
"url": "https://example.com/reports/2025-annual.pdf",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_web_search_uses_reachable_bing_endpoint_first(monkeypatch) -> None:
|
||||
_FakeAsyncClient.calls = []
|
||||
_FakeAsyncClient.urls = []
|
||||
_FakeAsyncClient.fail_bing = False
|
||||
_disable_ddgs(monkeypatch)
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
raw = asyncio.run(web.WebSearchTool().execute(query="weather beijing"))
|
||||
@ -95,10 +145,60 @@ def test_web_search_uses_reachable_bing_endpoint_first(monkeypatch) -> None:
|
||||
assert timeout.read == 8
|
||||
|
||||
|
||||
def test_web_search_prefers_ddgs_provider_when_available(monkeypatch) -> None:
|
||||
class _FakeDDGS:
|
||||
def text(self, query: str, max_results: int) -> list[dict[str, str]]:
|
||||
assert query == "weather beijing"
|
||||
assert max_results == 5
|
||||
return [
|
||||
{
|
||||
"title": "Beijing Weather",
|
||||
"href": "https://weather.example.com/beijing",
|
||||
"body": "Current Beijing weather forecast",
|
||||
}
|
||||
]
|
||||
|
||||
fake_module = types.SimpleNamespace(DDGS=_FakeDDGS)
|
||||
monkeypatch.setitem(sys.modules, "ddgs", fake_module)
|
||||
_FakeAsyncClient.calls = []
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
raw = asyncio.run(web.WebSearchTool().execute(query="weather beijing"))
|
||||
|
||||
payload = json.loads(raw)
|
||||
assert payload["success"] is True
|
||||
assert payload["engine"] == "ddgs"
|
||||
assert payload["quality"] == "high"
|
||||
assert payload["results"] == [
|
||||
{
|
||||
"title": "Beijing Weather",
|
||||
"url": "https://weather.example.com/beijing",
|
||||
"snippet": "Current Beijing weather forecast",
|
||||
}
|
||||
]
|
||||
assert _FakeAsyncClient.calls == []
|
||||
|
||||
|
||||
def test_web_search_reports_low_quality_for_irrelevant_results(monkeypatch) -> None:
|
||||
_FakeAsyncClient.calls = []
|
||||
_FakeAsyncClient.urls = []
|
||||
_FakeAsyncClient.fail_bing = False
|
||||
_disable_ddgs(monkeypatch)
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
raw = asyncio.run(web.WebSearchTool().execute(query="weather beijing"))
|
||||
|
||||
payload = json.loads(raw)
|
||||
assert payload["success"] is True
|
||||
assert payload["quality"] == "low"
|
||||
assert payload["low_relevance_reason"] == "results do not overlap enough with query terms"
|
||||
|
||||
|
||||
def test_web_search_falls_back_when_bing_is_unavailable(monkeypatch) -> None:
|
||||
_FakeAsyncClient.calls = []
|
||||
_FakeAsyncClient.urls = []
|
||||
_FakeAsyncClient.fail_bing = True
|
||||
_disable_ddgs(monkeypatch)
|
||||
monkeypatch.setattr(web.httpx, "AsyncClient", _FakeAsyncClient)
|
||||
|
||||
raw = asyncio.run(web.WebSearchTool().execute(query="weather beijing"))
|
||||
|
||||
Reference in New Issue
Block a user