feat(engine): allow replay tool executor injection
This commit is contained in:
@ -247,6 +247,7 @@ class AgentLoop:
|
|||||||
attempt_index: int | None = None,
|
attempt_index: int | None = None,
|
||||||
pinned_skill_names: list[str] | None = None,
|
pinned_skill_names: list[str] | None = None,
|
||||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||||
|
tool_executor_override: Any = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
intent_agent_decision: dict[str, Any] | None = None,
|
intent_agent_decision: dict[str, Any] | None = None,
|
||||||
channel_identity: ChannelIdentity | None = None,
|
channel_identity: ChannelIdentity | None = None,
|
||||||
@ -297,6 +298,7 @@ class AgentLoop:
|
|||||||
attempt_index=attempt_index,
|
attempt_index=attempt_index,
|
||||||
pinned_skill_names=pinned_skill_names,
|
pinned_skill_names=pinned_skill_names,
|
||||||
pinned_skill_contexts=pinned_skill_contexts,
|
pinned_skill_contexts=pinned_skill_contexts,
|
||||||
|
tool_executor_override=tool_executor_override,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
intent_agent_decision=intent_agent_decision,
|
intent_agent_decision=intent_agent_decision,
|
||||||
channel_identity=channel_identity,
|
channel_identity=channel_identity,
|
||||||
@ -335,6 +337,7 @@ class AgentLoop:
|
|||||||
attempt_index: int | None = None,
|
attempt_index: int | None = None,
|
||||||
pinned_skill_names: list[str] | None = None,
|
pinned_skill_names: list[str] | None = None,
|
||||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||||
|
tool_executor_override: Any = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
intent_agent_decision: dict[str, Any] | None = None,
|
intent_agent_decision: dict[str, Any] | None = None,
|
||||||
channel_identity: ChannelIdentity | None = None,
|
channel_identity: ChannelIdentity | None = None,
|
||||||
@ -354,6 +357,7 @@ class AgentLoop:
|
|||||||
tool_registry = self._require_loaded("tool_registry")
|
tool_registry = self._require_loaded("tool_registry")
|
||||||
tool_assembler = self._require_loaded("tool_assembler")
|
tool_assembler = self._require_loaded("tool_assembler")
|
||||||
tool_executor = self._require_loaded("tool_executor")
|
tool_executor = self._require_loaded("tool_executor")
|
||||||
|
effective_tool_executor = tool_executor_override or tool_executor
|
||||||
skills_loader = self._require_loaded("skills_loader")
|
skills_loader = self._require_loaded("skills_loader")
|
||||||
skill_assembler = self._require_loaded("skill_assembler")
|
skill_assembler = self._require_loaded("skill_assembler")
|
||||||
skill_learning_service = self._require_loaded("skill_learning_service")
|
skill_learning_service = self._require_loaded("skill_learning_service")
|
||||||
@ -789,7 +793,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
|
result = await effective_tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||||
session_manager.append_message(
|
session_manager.append_message(
|
||||||
resolved_session_id,
|
resolved_session_id,
|
||||||
run_id=resolved_run_id,
|
run_id=resolved_run_id,
|
||||||
|
|||||||
@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from beaver.engine.loader import EngineLoader
|
||||||
|
from beaver.engine.loop import AgentLoop
|
||||||
|
from beaver.engine.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from beaver.engine.providers.factory import ProviderBundle
|
||||||
|
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallingProvider(LLMProvider):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
thinking_enabled: bool | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call-1",
|
||||||
|
name="read_file",
|
||||||
|
arguments={"path": "README.md"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return LLMResponse(content="done")
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "stub"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
|
||||||
|
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||||
|
loaded = loop.boot()
|
||||||
|
provider = ToolCallingProvider()
|
||||||
|
runtime = SimpleNamespace(model="stub", provider_name="stub")
|
||||||
|
replay_executor = ReplayToolExecutor(
|
||||||
|
loaded.tool_executor,
|
||||||
|
registry=loaded.tool_registry,
|
||||||
|
policy=ReplayToolPolicy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await loop.process_direct(
|
||||||
|
"Read the README.",
|
||||||
|
provider_bundle=ProviderBundle(main_runtime=runtime, main_provider=provider), # type: ignore[arg-type]
|
||||||
|
include_skill_assembly=False,
|
||||||
|
pinned_skill_names=[],
|
||||||
|
tool_executor_override=replay_executor,
|
||||||
|
max_tool_iterations=2,
|
||||||
|
source="skill_replay_eval",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.output_text == "done"
|
||||||
|
assert replay_executor.traces
|
||||||
|
assert replay_executor.traces[0]["tool_name"] == "read_file"
|
||||||
Reference in New Issue
Block a user