feat(engine): allow replay tool executor injection

This commit is contained in:
2026-06-08 13:32:14 +08:00
parent eb69bb168a
commit 70014c0f70
2 changed files with 76 additions and 1 deletions

View File

@ -247,6 +247,7 @@ class AgentLoop:
attempt_index: int | None = None,
pinned_skill_names: list[str] | None = None,
pinned_skill_contexts: list[SkillContext] | None = None,
tool_executor_override: Any = None,
allow_candidate_generation: bool = False,
intent_agent_decision: dict[str, Any] | None = None,
channel_identity: ChannelIdentity | None = None,
@ -297,6 +298,7 @@ class AgentLoop:
attempt_index=attempt_index,
pinned_skill_names=pinned_skill_names,
pinned_skill_contexts=pinned_skill_contexts,
tool_executor_override=tool_executor_override,
allow_candidate_generation=allow_candidate_generation,
intent_agent_decision=intent_agent_decision,
channel_identity=channel_identity,
@ -335,6 +337,7 @@ class AgentLoop:
attempt_index: int | None = None,
pinned_skill_names: list[str] | None = None,
pinned_skill_contexts: list[SkillContext] | None = None,
tool_executor_override: Any = None,
allow_candidate_generation: bool = False,
intent_agent_decision: dict[str, Any] | None = None,
channel_identity: ChannelIdentity | None = None,
@ -354,6 +357,7 @@ class AgentLoop:
tool_registry = self._require_loaded("tool_registry")
tool_assembler = self._require_loaded("tool_assembler")
tool_executor = self._require_loaded("tool_executor")
effective_tool_executor = tool_executor_override or tool_executor
skills_loader = self._require_loaded("skills_loader")
skill_assembler = self._require_loaded("skill_assembler")
skill_learning_service = self._require_loaded("skill_learning_service")
@ -789,7 +793,7 @@ class AgentLoop:
iterations += 1
for tool_call in response.tool_calls:
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
result = await effective_tool_executor.execute_tool_call(tool_call, context=tool_context)
session_manager.append_message(
resolved_session_id,
run_id=resolved_run_id,

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
import pytest
from beaver.engine.loader import EngineLoader
from beaver.engine.loop import AgentLoop
from beaver.engine.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from beaver.engine.providers.factory import ProviderBundle
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy
class ToolCallingProvider(LLMProvider):
def __init__(self) -> None:
super().__init__()
self.calls = 0
async def chat(
self,
messages: list[dict],
tools: list[dict] | None = None,
model: str | None = None,
max_tokens: int | None = None,
temperature: float = 0.7,
thinking_enabled: bool | None = None,
) -> LLMResponse:
self.calls += 1
if self.calls == 1:
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="call-1",
name="read_file",
arguments={"path": "README.md"},
)
],
)
return LLMResponse(content="done")
def get_default_model(self) -> str:
return "stub"
@pytest.mark.asyncio
async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
loaded = loop.boot()
provider = ToolCallingProvider()
runtime = SimpleNamespace(model="stub", provider_name="stub")
replay_executor = ReplayToolExecutor(
loaded.tool_executor,
registry=loaded.tool_registry,
policy=ReplayToolPolicy(),
)
result = await loop.process_direct(
"Read the README.",
provider_bundle=ProviderBundle(main_runtime=runtime, main_provider=provider), # type: ignore[arg-type]
include_skill_assembly=False,
pinned_skill_names=[],
tool_executor_override=replay_executor,
max_tool_iterations=2,
source="skill_replay_eval",
)
assert result.output_text == "done"
assert replay_executor.traces
assert replay_executor.traces[0]["tool_name"] == "read_file"