346 lines
12 KiB
Python
346 lines
12 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
|
||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||
from beaver.tasks import MainAgentRouter, TaskRecord
|
||
|
||
|
||
class RouterProvider(LLMProvider):
    """Stub LLM provider that answers every chat call with one fixed response.

    ``response`` is either the text returned as ``LLMResponse.content`` or an
    exception instance that is raised on every call. The arguments of each call
    (except ``tools``) are captured in ``calls`` so tests can inspect what the
    router sent.
    """

    def __init__(self, response: str | Exception) -> None:
        super().__init__()
        self.response = response
        self.calls: list[dict] = []

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        thinking_enabled: bool | None = None,
    ) -> LLMResponse:
        """Record the call, then raise or return the configured response."""
        recorded = dict(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            model=model,
            thinking_enabled=thinking_enabled,
        )
        self.calls.append(recorded)

        outcome = self.response
        if not isinstance(outcome, Exception):
            return LLMResponse(
                content=outcome,
                finish_reason="stop",
                provider_name="stub",
                model="stub-model",
            )
        raise outcome

    def get_default_model(self) -> str:
        """Return the fixed model name reported by this stub."""
        return "stub-model"
|
||
|
||
|
||
class SequenceRouterProvider(LLMProvider):
    """Stub LLM provider that replays a scripted sequence, one entry per call.

    Each ``chat`` call records its arguments (except ``tools``) in ``calls`` and
    then consumes the next entry of ``responses``. An ``Exception`` entry is
    raised; any other entry is returned as the content of an ``LLMResponse``.
    """

    def __init__(self, responses: list[str | Exception]) -> None:
        super().__init__()
        # Copy so that consuming responses never mutates the caller's list.
        self.responses = list(responses)
        self.calls: list[dict] = []

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        thinking_enabled: bool | None = None,
    ) -> LLMResponse:
        """Record the call, then raise or return the next scripted response.

        Raises:
            IndexError: if the script has no entries left. The call is still
                recorded first, so ``len(calls)`` counts the over-budget call.
            Exception: whichever exception instance was scripted for this call.
        """
        self.calls.append(
            {
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "model": model,
                "thinking_enabled": thinking_enabled,
            }
        )
        if not self.responses:
            # The router reports provider errors through the decision reason
            # (e.g. "router_failed: ..."). A bare "pop from empty list" would
            # look like a real provider failure, so name the actual cause.
            raise IndexError(
                f"SequenceRouterProvider has no scripted response left for call #{len(self.calls)}"
            )
        response = self.responses.pop(0)
        if isinstance(response, Exception):
            raise response
        return LLMResponse(content=response, finish_reason="stop", provider_name="stub", model="stub-model")

    def get_default_model(self) -> str:
        """Return the fixed model name reported by this stub."""
        return "stub-model"
|
||
|
||
|
||
def _task() -> TaskRecord:
    """Build a fresh task fixture that is waiting for user acceptance.

    A new ``metadata`` dict is created on every call, so callers may mutate it
    freely.
    """
    summary = "实现任务连续性"
    return TaskRecord(
        task_id="task-1",
        session_id="web:task",
        creator="test",
        status="awaiting_acceptance",
        priority=0,
        description=summary,
        goal=summary,
        constraints=[],
        created_at="now",
        updated_at="now",
        metadata={"short_title": "任务连续性"},
    )
|
||
|
||
|
||
def _weather_task() -> TaskRecord:
    """Return the base task fixture rewritten as a Zhuhai weather query."""
    weather = _task()
    question = "珠海天气怎样"
    weather.description = question
    weather.goal = question
    weather.metadata.update({"short_title": "查询珠海天气"})
    return weather
|
||
|
||
|
||
def test_router_continues_active_task_from_llm_decision() -> None:
    """A related follow-up stays on the active task when the LLM says continue_task."""
    stub = RouterProvider('{"action":"continue_task","reason":"related","short_title":"任务连续性"}')
    pending = MainAgentRouter().classify(
        "再把输入框标识也补上",
        active_task=_task(),
        provider=stub,
    )
    result = asyncio.run(pending)

    assert result.is_task
    assert result.starts_new_task is False
    assert result.short_title == "任务连续性"
    # The classification request is expected to be capped at 256 tokens.
    assert stub.calls[0]["max_tokens"] == 256
|
||
|
||
|
||
def test_router_keeps_same_session_but_starts_new_task_for_standalone_weather_repeat() -> None:
    """Re-asking a standalone weather question opens a new task even if the LLM says continue."""
    current = _weather_task()
    stub = RouterProvider('{"action":"continue_task","reason":"neutral follow-up","short_title":"查询珠海天气"}')
    result = asyncio.run(
        MainAgentRouter().classify("珠海天气怎么样", active_task=current, provider=stub)
    )

    assert result.is_task
    assert result.action == "create_task"
    assert result.starts_new_task is True
    # The override must be explained in the decision reason.
    assert "fresh standalone task request" in result.reason
|
||
|
||
|
||
def test_router_allows_explicit_followup_to_continue_active_weather_task() -> None:
    """An explicit follow-up ("also check Shenzhen") keeps the active weather task."""
    current = _weather_task()
    stub = RouterProvider('{"action":"continue_task","reason":"related follow-up","short_title":"查询珠海天气"}')
    result = asyncio.run(
        MainAgentRouter().classify("顺便查一下深圳", active_task=current, provider=stub)
    )

    assert result.is_task
    assert result.action == "continue_task"
    assert result.starts_new_task is False
|
||
|
||
|
||
def test_router_marks_revision_from_llm_decision() -> None:
    """A change request on the active task is classified as revise_task."""
    current = _task()
    stub = RouterProvider('{"action":"revise_task","reason":"user requested changes","short_title":"任务连续性"}')
    result = asyncio.run(
        MainAgentRouter().classify("再详细一点,并加上表格", active_task=current, provider=stub)
    )

    assert result.is_task
    assert result.starts_new_task is False
    assert result.action == "revise_task"
|
||
|
||
|
||
def test_router_receives_thinking_mode() -> None:
    """The caller's thinking_enabled flag is forwarded unchanged to the provider."""
    stub = RouterProvider('{"action":"simple_chat","reason":"simple"}')
    pending = MainAgentRouter().classify(
        "请判断一下这个概念是否合理",
        provider=stub,
        thinking_enabled=False,
    )
    result = asyncio.run(pending)

    assert not result.is_task
    forwarded = stub.calls[0]["thinking_enabled"]
    assert forwarded is False
|
||
|
||
|
||
def test_router_fast_paths_obvious_simple_chat_without_provider_call() -> None:
    """Greetings and short translation requests become simple chat without calling the LLM."""
    stub = RouterProvider('{"action":"new_task","reason":"should not be used"}')

    # Each input gets a fresh router and is classified in this order.
    greeting, punctuated, translation = (
        asyncio.run(MainAgentRouter().classify(text, provider=stub))
        for text in ("你好", "你好!", "翻译这句话:hello world")
    )

    assert not greeting.is_task
    assert greeting.action == "simple_chat"
    assert greeting.reason == "obvious_simple_chat"
    for other in (punctuated, translation):
        assert not other.is_task
        assert other.action == "simple_chat"
    # None of the inputs may reach the provider.
    assert stub.calls == []
|
||
|
||
|
||
def test_router_sends_broad_explanations_to_intent_llm() -> None:
    """Concept explanations are not fast-pathed; the intent LLM decides them."""
    stub = RouterProvider('{"action":"simple_chat","reason":"intent decided concept explanation"}')

    explanation, definition = [
        asyncio.run(MainAgentRouter().classify(text, provider=stub))
        for text in ("解释一下什么是 MCP", "什么是 context engineering")
    ]

    for result in (explanation, definition):
        assert not result.is_task
        assert result.reason == "intent decided concept explanation"
    # Exactly one provider call per input.
    assert len(stub.calls) == 2
|
||
|
||
|
||
def test_router_fast_paths_obvious_task_without_provider_call() -> None:
    """Lookups of live data (weather, today's events) become tasks without calling the LLM."""
    stub = RouterProvider('{"action":"simple_chat","reason":"should not be used"}')

    weather_request, current_event = [
        asyncio.run(MainAgentRouter().classify(text, provider=stub))
        for text in ("帮我查一下今天深圳天气", "解释一下今天法国队在世界杯的表现为什么那么好")
    ]

    assert weather_request.is_task
    assert weather_request.action == "create_task"
    assert weather_request.reason == "obvious_task"
    assert current_event.is_task
    assert current_event.action == "create_task"
    assert stub.calls == []
|
||
|
||
|
||
def test_router_does_not_simple_fast_path_current_event_explanations() -> None:
    """Explaining last night's match is a task despite the "explain" phrasing."""
    stub = RouterProvider('{"action":"simple_chat","reason":"llm fallback"}')
    pending = MainAgentRouter().classify("解释一下昨晚法国队在世界杯的表现为什么那么好", provider=stub)
    result = asyncio.run(pending)

    assert result.is_task
    assert result.action == "create_task"
    assert result.reason == "obvious_task"
    # The stub's simple_chat answer must never be consulted.
    assert stub.calls == []
|
||
|
||
|
||
def test_router_keeps_active_task_followups_in_llm_path() -> None:
    """Short follow-ups on an active task go through the LLM instead of a fast path."""
    stub = RouterProvider('{"action":"revise_task","reason":"needs revision","short_title":"任务连续性"}')
    current = _task()

    result = asyncio.run(
        MainAgentRouter().classify("这个也加上", active_task=current, provider=stub)
    )

    assert result.is_task
    assert result.action == "revise_task"
    assert len(stub.calls) == 1
|
||
|
||
|
||
def test_router_injects_intent_skill_guidance() -> None:
    """The intent_skill text is embedded in the prompt, and new_task maps to create_task."""
    stub = RouterProvider('{"action":"new_task","reason":"needs weather tool","short_title":"珠海天气"}')
    pending = MainAgentRouter().classify(
        "帮我判断这个需求要不要进入任务模式",
        provider=stub,
        intent_skill="Weather and current external data must be routed to new_task.",
    )
    result = asyncio.run(pending)

    assert result.is_task
    assert result.starts_new_task is True
    assert result.action == "create_task"

    # The second message of the first request carries the router prompt.
    first_request = stub.calls[0]
    router_prompt = first_request["messages"][1]["content"]
    for fragment in ("Intent Agent skill guidance", "Weather and current external data"):
        assert fragment in router_prompt
|
||
|
||
|
||
def test_router_prompt_treats_unrelated_lightweight_conversation_as_new_topic() -> None:
    """The router prompt explains the Session/Task split and how to treat small talk."""
    stub = RouterProvider('{"action":"simple_chat","reason":"unrelated lightweight conversation"}')
    current = _task()

    # Only the prompt sent to the provider matters; the decision itself is ignored.
    asyncio.run(MainAgentRouter().classify("吃饭没", active_task=current, provider=stub))

    router_prompt = stub.calls[0]["messages"][1]["content"]
    expected_fragments = (
        "unrelated lightweight conversation",
        "must not be classified as revise_task merely because the active Task is awaiting acceptance",
        "A Session is the durable conversation/device/group context",
        "Repeating '珠海天气怎么样' later is a new Task",
    )
    for fragment in expected_fragments:
        assert fragment in router_prompt
|
||
|
||
|
||
def test_router_closes_active_task_from_llm_decision() -> None:
    """A close_task decision ends the active task and is not itself a task."""
    current = _task()
    stub = RouterProvider('{"action":"close_task","reason":"user said done"}')
    result = asyncio.run(
        MainAgentRouter().classify("这个任务结束了", active_task=current, provider=stub)
    )

    assert not result.is_task
    assert result.closes_task is True
|
||
|
||
|
||
def test_router_fallback_keeps_active_task_but_not_new_task() -> None:
    """When the provider is down, an active task is kept but no new task is created."""
    with_active = asyncio.run(
        MainAgentRouter().classify(
            "继续",
            active_task=_task(),
            provider=RouterProvider(RuntimeError("provider down")),
        )
    )
    without_active = asyncio.run(
        MainAgentRouter().classify(
            "implement something",
            active_task=None,
            provider=RouterProvider(RuntimeError("provider down")),
        )
    )

    assert with_active.is_task
    assert not without_active.is_task
|
||
|
||
|
||
def test_router_retries_once_after_provider_failure() -> None:
    """A single transient provider failure is retried, and the second answer is used."""
    script = [
        TimeoutError(),
        '{"action":"new_task","reason":"needs search","short_title":"中美会面"}',
    ]
    stub = SequenceRouterProvider(script)

    pending = MainAgentRouter().classify("帮我判断这次中美会面分析需求要不要进入任务模式", provider=stub)
    result = asyncio.run(pending)

    assert result.is_task
    assert result.action == "create_task"
    # One failed attempt plus one successful retry.
    assert len(stub.calls) == 2
|
||
|
||
|
||
def test_router_fallback_after_two_provider_failures() -> None:
    """After the retry also fails, the router falls back and reports the last error."""
    stub = SequenceRouterProvider([TimeoutError(), RuntimeError("provider down")])

    pending = MainAgentRouter().classify("帮我判断这次中美会面分析需求要不要进入任务模式", provider=stub)
    result = asyncio.run(pending)

    assert not result.is_task
    # The reason carries the message of the final (second) failure.
    assert result.reason == "router_failed: provider down"
    assert len(stub.calls) == 2
|