Files
beaver_project/app-instance/backend/beaver/engine/loop.py

690 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified agent loop used by all Beaver agents."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4
from beaver.engine.context import ContextBuildInput, SessionContext
from beaver.engine.providers import ProviderBundle, make_provider_bundle
from beaver.tools import ToolContext
from .loader import EngineLoader, EngineLoadResult
@dataclass(slots=True)
class AgentProfile:
    """Runtime profile for a Beaver agent instance.

    The agent loop reads these values as fallbacks whenever a run does not
    supply its own model, token limit, temperature, or tool-iteration cap.
    """

    # Label written into run events and tool-context metadata as "agent_name".
    name: str = "default"
    # Base system prompt handed to the context builder.
    system_prompt: str = ""
    # Model used when a run does not name one.
    default_model: str = "gpt-4.1-mini"
    # Completion token limit used when a run does not set (or sets a falsy) one.
    max_tokens: int = 4096
    # Sampling temperature used when a run passes ``None``.
    temperature: float = 0.2
    # Cap on tool-call rounds used when a run passes ``None``.
    max_tool_iterations: int = 8
@dataclass(slots=True)
class AgentRunResult:
    """Minimal result structure for a single direct run."""

    # Session the run was recorded under.
    session_id: str
    # Identifier generated for this specific run.
    run_id: str
    # Final assistant text, or the error message when the run failed.
    output_text: str
    # Why the run ended: the provider's finish reason, "stop",
    # "max_tool_iterations", or "error".
    finish_reason: str
    # Number of tool-call rounds that were executed.
    tool_iterations: int
    # Provider that served the run, when known.
    provider_name: str | None = None
    # Model that served the run, when known.
    model: str | None = None
    # Usage counters accumulated across every provider call of the run.
    usage: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class _DirectRunRequest:
    """A single direct task queued for the run loop."""

    # Task text forwarded to the direct-run implementation.
    task: str
    # Keyword arguments forwarded alongside the task.
    kwargs: dict[str, Any]
    # Future the run loop resolves with the result, an exception, or a cancellation.
    future: asyncio.Future[AgentRunResult]
class AgentLoop:
"""Single execution kernel shared by root agents and delegated agents."""
def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
    """Wire up the loop's collaborators without loading any runtime.

    Args:
        profile: Agent profile to use; a default ``AgentProfile`` is created
            when a falsy value is passed.
        loader: Engine loader to use; a default ``EngineLoader`` is created
            when a falsy value is passed.
    """
    self.profile = profile if profile else AgentProfile()
    self.loader = loader if loader else EngineLoader()

    # The runtime is loaded lazily by ``boot()``.
    self.loaded: EngineLoadResult | None = None

    # Queue state is only populated while ``run()`` is active.
    self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
    self._running, self._stop_requested = False, False
def boot(self) -> EngineLoadResult:
    """Return the engine runtime, asking the loader for it on first use.

    A runtime that is already cached on ``self.loaded`` is handed back as-is;
    otherwise the loader's result is cached and returned.
    """
    runtime = self.loaded
    if runtime is not None:
        return runtime
    runtime = self.loader.load()
    self.loaded = runtime
    return runtime
@property
def is_running(self) -> bool:
    """Report whether ``run()`` is currently active on this loop."""
    return self._running
async def run(self) -> None:
    """Start the minimal run loop and consume submitted direct tasks in order.

    The first version is deliberately restrained:
    1. A single consumer processes tasks serially.
    2. Actual execution still reuses the direct-run path
       (via `_process_direct_impl()`).
    3. No bus / worker / priority / retry is introduced.

    Raises:
        RuntimeError: If the run loop is already active.
    """
    if self._running:
        raise RuntimeError("AgentLoop.run() is already active")
    self.boot()
    self._run_queue = asyncio.Queue()
    self._running = True
    self._stop_requested = False
    try:
        while True:
            item = await self._run_queue.get()
            if item is None:
                # ``None`` is the wake-up sentinel enqueued by ``stop()``;
                # it only ends the loop when a stop was actually requested.
                if self._stop_requested:
                    break
                continue
            if item.future.cancelled():
                # The submitter gave up before the task was picked up.
                continue
            try:
                result = await self._process_direct_impl(item.task, **item.kwargs)
            except asyncio.CancelledError:
                # Propagate cancellation to the waiting submitter, then let
                # the cancellation continue to unwind this loop.
                if not item.future.done():
                    item.future.cancel()
                raise
            except Exception as exc: # pragma: no cover - defensive queue path
                if not item.future.done():
                    item.future.set_exception(exc)
            else:
                if not item.future.done():
                    item.future.set_result(result)
    finally:
        if self._run_queue is not None:
            # Fail every request still sitting in the queue so that no
            # submitter waits forever on a loop that has stopped.
            while True:
                try:
                    pending = self._run_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if isinstance(pending, _DirectRunRequest) and not pending.future.done():
                    pending.future.set_exception(
                        RuntimeError("AgentLoop.run() stopped before processing the queued task")
                    )
        # Reset the loop state so ``run()`` can be started again later.
        self._running = False
        self._stop_requested = False
        self._run_queue = None
async def stop(self) -> None:
    """Ask the run loop to shut down.

    First-version semantics:
    - no new tasks are accepted afterwards
    - a task that has already been taken off the queue may finish
    - the runtime is not closed automatically

    Calling this while the loop is not running is a no-op.
    """
    if self._running and self._run_queue is not None:
        self._stop_requested = True
        # The ``None`` sentinel wakes the consumer so it can observe the flag.
        await self._run_queue.put(None)
async def submit_direct(
    self,
    task: str,
    **kwargs: Any,
) -> AgentRunResult:
    """Queue a direct task on the running loop and wait for its result.

    Args:
        task: Task text to execute.
        **kwargs: Options forwarded to the direct-run implementation.

    Returns:
        The result the run loop sets on the request's future.

    Raises:
        RuntimeError: If ``run()`` is not active, or if ``stop()`` has
            already been requested.
    """
    loop_is_active = self._running and self._run_queue is not None
    if not loop_is_active:
        raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
    if self._stop_requested:
        raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")

    event_loop = asyncio.get_running_loop()
    outcome: asyncio.Future[AgentRunResult] = event_loop.create_future()
    # Copy kwargs so later caller-side mutation cannot affect the queued request.
    request = _DirectRunRequest(task=task, kwargs=dict(kwargs), future=outcome)
    await self._run_queue.put(request)
    return await outcome
def close(self) -> None:
    """Release the runtime currently held by this loop.

    This is the minimal lifecycle counterpart to ``boot()``: ``boot()``
    establishes the runtime and ``close()`` releases whatever it holds.
    Closing when nothing is loaded is a no-op.

    Raises:
        RuntimeError: If the run loop is still active.
    """
    if self._running:
        raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")

    runtime = self.loaded
    if runtime is not None:
        try:
            runtime.close()
        finally:
            # Drop the reference even when closing fails.
            self.loaded = None
async def process_direct(
    self,
    task: str,
    *,
    session_id: str | None = None,
    source: str = "direct",
    user_id: str | None = None,
    title: str | None = None,
    execution_context: str | None = None,
    model: str | None = None,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: Any = None,
    fallback_target: dict[str, Any] | None = None,
    auxiliary_target: dict[str, Any] | None = None,
    embedding_target: dict[str, Any] | None = None,
    embedding_model: str | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
    max_tool_iterations: int | None = None,
    provider_bundle: ProviderBundle | None = None,
) -> AgentRunResult:
    """Execute one direct run outside of the queue-driven ``run()`` loop.

    This is the public entry point for a one-off run. It only guards against
    bypassing the queue and then forwards every option unchanged to
    ``_process_direct_impl()``.

    Raises:
        RuntimeError: If ``run()`` is active; use ``submit_direct()`` then.
    """
    if self._running:
        raise RuntimeError(
            "AgentLoop.process_direct() is disabled while run() is active; "
            "submit tasks via submit_direct() instead."
        )

    # Every keyword option is passed through verbatim.
    forwarded: dict[str, Any] = {
        "session_id": session_id,
        "source": source,
        "user_id": user_id,
        "title": title,
        "execution_context": execution_context,
        "model": model,
        "provider_name": provider_name,
        "api_key": api_key,
        "api_base": api_base,
        "extra_headers": extra_headers,
        "routing": routing,
        "fallback_target": fallback_target,
        "auxiliary_target": auxiliary_target,
        "embedding_target": embedding_target,
        "embedding_model": embedding_model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "max_tool_iterations": max_tool_iterations,
        "provider_bundle": provider_bundle,
    }
    return await self._process_direct_impl(task, **forwarded)
async def _process_direct_impl(
    self,
    task: str,
    *,
    session_id: str | None = None,
    source: str = "direct",
    user_id: str | None = None,
    title: str | None = None,
    execution_context: str | None = None,
    model: str | None = None,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: Any = None,
    fallback_target: dict[str, Any] | None = None,
    auxiliary_target: dict[str, Any] | None = None,
    embedding_target: dict[str, Any] | None = None,
    embedding_model: str | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
    max_tool_iterations: int | None = None,
    provider_bundle: ProviderBundle | None = None,
) -> AgentRunResult:
    """Internal implementation that actually executes one direct run.

    Rules:
    - External direct callers go through `process_direct()`.
    - The run loop consumes queued tasks through `_process_direct_impl()`.
    - This split is what prevents external callers from bypassing the queue
      while run mode is active.

    Outline of one run:
    1. Resolve per-run settings against the agent profile.
    2. Ensure the session exists and record a ``run_started`` event.
    3. Build the provider bundle, assemble skills, and build the prompt.
    4. Call the provider, executing tool calls until the model stops asking
       for tools or the iteration limit is reached.
    5. Record ``run_completed`` and return the result.

    Any ``Exception`` raised after the ``run_started`` event is converted
    into an error result (``finish_reason="error"``) instead of propagating.
    """
    loaded = self.boot()
    session_manager = self._require_loaded("session_manager")
    memory_service = self._require_loaded("memory_service")
    context_builder = self._require_loaded("context_builder")
    tool_registry = self._require_loaded("tool_registry")
    tool_executor = self._require_loaded("tool_executor")
    skill_assembler = self._require_loaded("skill_assembler")
    resolved_session_id = session_id or uuid4().hex
    resolved_run_id = uuid4().hex
    resolved_model = model or self.profile.default_model
    # NOTE: ``or`` means a falsy ``max_tokens`` (including 0) also falls back
    # to the profile value, unlike the ``is None`` checks used below.
    resolved_max_tokens = max_tokens or self.profile.max_tokens
    resolved_temperature = self.profile.temperature if temperature is None else temperature
    resolved_max_tool_iterations = (
        self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
    )
    # Refresh live state through MemoryService before every new run, so that
    # memory policy stays inside the service instead of being scattered
    # through the loop.
    memory_service.reload_for_new_run()
    session_manager.ensure_session(
        resolved_session_id,
        source=source,
        model=resolved_model,
        title=title,
        user_id=user_id,
    )
    # Bookkeeping event marking the start of the run; recorded with
    # ``context_visible=False``.
    session_manager.append_message(
        resolved_session_id,
        run_id=resolved_run_id,
        role="system",
        event_type="run_started",
        event_payload={
            "source": source,
            "model": resolved_model,
            "agent_name": self.profile.name,
        },
        content=task,
        context_visible=False,
        source=source,
        title=title,
        model=resolved_model,
        user_id=user_id,
    )
    # State read by the error handler below, so it must exist before ``try``.
    user_message_recorded = False
    iterations = 0
    final_usage: dict[str, Any] = {}
    final_provider_name: str | None = provider_name
    final_model: str | None = resolved_model
    try:
        # A caller-supplied bundle takes precedence over building one here.
        bundle = provider_bundle or make_provider_bundle(
            model=resolved_model,
            provider_name=provider_name,
            api_key=api_key,
            api_base=api_base,
            extra_headers=extra_headers,
            routing=routing,
            fallback_target=fallback_target,
            auxiliary_target=auxiliary_target,
            embedding_target=embedding_target,
            embedding_model=embedding_model or "text-embedding-v4",
        )
        # Skill selection uses the auxiliary provider/model when the bundle
        # has one, otherwise the main provider/model.
        skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
        skill_selector_model = (
            bundle.auxiliary_runtime.model
            if bundle.auxiliary_runtime is not None
            else bundle.main_runtime.model
        )
        assembled_skills = await skill_assembler.assemble(
            task_description=task,
            provider=skill_selector_provider,
            model=skill_selector_model,
            embedding_runtime=bundle.embedding_runtime,
        )
        skill_activation_messages = context_builder.build_skill_activation_messages(
            assembled_skills.activated_skills
        )
        if skill_activation_messages:
            # Snapshot the activation messages in the session log.
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="skill_activation_snapshotted",
                event_payload={
                    "activation_messages": skill_activation_messages,
                },
                content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
                context_visible=False,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
        # History is read before the current user message is appended; the
        # task itself is passed separately as ``current_user_input``.
        build_input = ContextBuildInput(
            base_system_prompt=self.profile.system_prompt,
            history=session_manager.get_history(resolved_session_id),
            current_user_input=task,
            memory_snapshot=memory_service.get_snapshot(),
            activated_skills=assembled_skills.activated_skills,
            session_context=SessionContext(
                session_id=resolved_session_id,
                source=source,
                model=resolved_model,
                user_id=user_id,
            ),
            execution_context=execution_context,
        )
        context_result = context_builder.build_messages(build_input)
        session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
        session_manager.append_message(
            resolved_session_id,
            run_id=resolved_run_id,
            role="system",
            event_type="system_prompt_snapshotted",
            event_payload={
                "source": source,
                "model": resolved_model,
                "system_prompt_length": len(context_result.system_prompt),
            },
            content=context_result.system_prompt,
            context_visible=False,
            source=source,
            title=title,
            model=resolved_model,
            user_id=user_id,
        )
        session_manager.append_message(
            resolved_session_id,
            run_id=resolved_run_id,
            role="user",
            event_type="user_message_added",
            content=task,
            source=source,
            title=title,
            model=resolved_model,
            user_id=user_id,
        )
        # From here on the error handler must not record the task again.
        user_message_recorded = True
        provider = bundle.main_provider
        # Work on a copy so the builder's result is not mutated by the loop.
        messages = list(context_result.messages)
        tool_schemas = tool_registry.export_provider_schemas()
        tool_context = ToolContext(
            workspace=str(loaded.workspace),
            session_id=resolved_session_id,
            user_id=user_id,
            services={
                "session_manager": session_manager,
                "memory_service": memory_service,
                "memory_store": memory_service.get_store(),
                "tool_registry": tool_registry,
            },
            metadata={
                "source": source,
                "agent_name": self.profile.name,
            },
        )
        final_text = ""
        final_finish_reason = "stop"
        final_provider_name = bundle.main_runtime.provider_name
        final_model = bundle.main_runtime.model
        # Provider/tool loop: one provider call per pass, followed by the
        # requested tool calls, until one of the two ``break`` conditions.
        while True:
            response = await provider.chat(
                messages=messages,
                tools=tool_schemas,
                model=final_model,
                max_tokens=resolved_max_tokens,
                temperature=resolved_temperature,
            )
            # Prefer what the response reports; keep the previous value when
            # the response leaves a field empty.
            final_provider_name = response.provider_name or final_provider_name
            final_model = response.model or final_model
            final_usage = self._merge_usage(final_usage, response.usage or {})
            self._record_usage(session_manager, resolved_session_id, response.usage or {})
            assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
            # Record the assistant turn in the session and mirror it into the
            # in-flight message list used for the next provider call.
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="assistant",
                event_type="assistant_message_added",
                content=response.content,
                tool_calls=assistant_tool_calls or None,
                finish_reason=response.finish_reason,
                reasoning=response.reasoning_content,
                source=source,
                title=title,
                model=final_model,
                user_id=user_id,
            )
            context_builder.add_assistant_message(
                messages,
                content=response.content,
                tool_calls=assistant_tool_calls or None,
                reasoning_content=response.reasoning_content,
            )
            if not response.has_tool_calls:
                # Normal completion: the model answered without asking for tools.
                final_text = response.content or ""
                final_finish_reason = response.finish_reason or "stop"
                break
            if iterations >= resolved_max_tool_iterations:
                # Iteration limit reached: stop without executing the
                # requested tools and record a closing assistant message.
                # NOTE(review): the assistant message carrying these
                # unanswered tool calls was already recorded above - confirm
                # that history consumers tolerate tool calls without results.
                final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
                final_finish_reason = "max_tool_iterations"
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="assistant",
                    event_type="assistant_message_added",
                    content=final_text,
                    finish_reason=final_finish_reason,
                    source=source,
                    title=title,
                    model=final_model,
                    user_id=user_id,
                )
                context_builder.add_assistant_message(
                    messages,
                    content=final_text,
                )
                break
            iterations += 1
            # Execute the requested tools one after another, recording each
            # result in the session and in the in-flight message list.
            for tool_call in response.tool_calls:
                result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="tool",
                    event_type="tool_result_recorded",
                    event_payload={
                        "success": result.success,
                        "error": result.error,
                    },
                    content=result.content,
                    tool_name=result.tool_name,
                    tool_call_id=tool_call.id,
                    source=source,
                    title=title,
                    model=final_model,
                    user_id=user_id,
                )
                context_builder.add_tool_result(
                    messages,
                    tool_call_id=tool_call.id,
                    tool_name=result.tool_name,
                    result=result.content,
                )
        # Bookkeeping event marking the successful end of the run.
        session_manager.append_message(
            resolved_session_id,
            run_id=resolved_run_id,
            role="system",
            event_type="run_completed",
            event_payload={
                "finish_reason": final_finish_reason,
                "tool_iterations": iterations,
            },
            content=final_text,
            finish_reason=final_finish_reason,
            context_visible=False,
            source=source,
            title=title,
            model=final_model,
            user_id=user_id,
        )
        return AgentRunResult(
            session_id=resolved_session_id,
            run_id=resolved_run_id,
            output_text=final_text,
            finish_reason=final_finish_reason,
            tool_iterations=iterations,
            provider_name=final_provider_name,
            model=final_model,
            usage=final_usage,
        )
    except Exception as exc:
        # If the failure happened before the user message was written,
        # record it now so the session still contains the task text.
        if not user_message_recorded:
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="user",
                event_type="user_message_added",
                content=task,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
        # Convert the failure into a recorded error turn and an error result.
        return self._build_error_result(
            session_manager=session_manager,
            session_id=resolved_session_id,
            run_id=resolved_run_id,
            source=source,
            title=title,
            user_id=user_id,
            model=final_model or resolved_model,
            message=f"Run failed before completion: {exc}",
            tool_iterations=iterations,
            provider_name=final_provider_name,
            usage=final_usage,
        )
def _require_loaded(self, field_name: str) -> Any:
    """Fetch a mandatory dependency from the booted runtime.

    Args:
        field_name: Attribute name to read from the load result.

    Returns:
        The attribute's value.

    Raises:
        RuntimeError: If the attribute exists but is ``None``.
    """
    dependency = getattr(self.boot(), field_name)
    if dependency is not None:
        return dependency
    raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
@staticmethod
def _serialize_tool_calls(tool_calls: list[Any] | None) -> list[dict[str, Any]]:
    """Convert provider tool-call objects into plain function-call dicts.

    Args:
        tool_calls: Objects exposing ``id``, ``name`` and ``arguments``.
            ``None`` is accepted and treated like an empty list, because this
            helper is invoked for every provider response, including those
            that request no tools.

    Returns:
        One ``{"id", "type": "function", "function": {"name", "arguments"}}``
        dict per tool call, in the original order.
    """
    return [
        {
            "id": tool_call.id,
            "type": "function",
            "function": {
                "name": tool_call.name,
                "arguments": tool_call.arguments,
            },
        }
        for tool_call in tool_calls or ()
    ]
@staticmethod
def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
    """Map provider usage onto the session's usage counters.

    Only the most common fields are mapped for now:
    - ``input_tokens``, falling back to ``prompt_tokens``
    - ``output_tokens``, falling back to ``completion_tokens``
    - ``reasoning_tokens``

    A key whose value is ``None`` is treated as missing, so the fallback
    alias is still consulted when a provider reports e.g.
    ``{"input_tokens": None, "prompt_tokens": 12}``. Finer-grained
    cache/reasoning/cost fields can be added here once the provider layer
    exposes them.

    Args:
        session_manager: Object exposing ``update_usage(session_id, ...)``.
        session_id: Session whose counters are updated.
        usage: Usage mapping from one provider response; nothing is recorded
            when it is empty.
    """
    if not usage:
        return

    def first_count(*keys: str) -> int:
        # First key carrying a non-None value wins; falsy values count as 0.
        for key in keys:
            value = usage.get(key)
            if value is not None:
                return int(value or 0)
        return 0

    session_manager.update_usage(
        session_id,
        input_tokens=first_count("input_tokens", "prompt_tokens"),
        output_tokens=first_count("output_tokens", "completion_tokens"),
        reasoning_tokens=first_count("reasoning_tokens"),
    )
@staticmethod
def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
    """Fold one provider call's usage into the run's cumulative usage.

    Merge rules, applied per key:
    - numeric counters (``int``/``float``, but not ``bool``) are summed
    - nested dicts (for example per-category token details) are merged
      recursively with the same rules instead of being overwritten
    - a ``None`` in ``delta`` never erases a value that is already present
    - anything else (strings, flags, mismatched types) takes the newest value

    Args:
        total: Usage accumulated so far; not mutated.
        delta: Usage reported by the latest provider call; not mutated.

    Returns:
        A new dict holding the merged usage.
    """

    def is_count(value: Any) -> bool:
        # bool is a subclass of int, but flags must not be added together.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def combine(base: dict[str, Any], extra: dict[str, Any]) -> dict[str, Any]:
        merged = dict(base)
        for key, value in extra.items():
            current = merged.get(key)
            if value is None:
                # Keep an accumulated value; only note the key when it is new.
                merged.setdefault(key, None)
            elif is_count(value) and is_count(current):
                merged[key] = current + value
            elif isinstance(value, dict):
                # Recursing from an empty base also copies the nested dict,
                # so the result never aliases the caller's data.
                merged[key] = combine(current if isinstance(current, dict) else {}, value)
            else:
                merged[key] = value
        return merged

    return combine(total, delta)
@staticmethod
def _build_error_result(
    *,
    session_manager: Any,
    session_id: str,
    run_id: str,
    source: str,
    title: str | None,
    user_id: str | None,
    model: str | None,
    message: str,
    tool_iterations: int,
    provider_name: str | None,
    usage: dict[str, Any],
) -> AgentRunResult:
    """Turn an unhandled failure of the main run path into a traceable error turn.

    Two session messages are written, both with ``finish_reason="error"``:
    an assistant message carrying ``message`` and a ``run_failed`` system
    event recorded with ``context_visible=False``. The returned result mirrors
    the same message and finish reason.
    """
    failure_reason = "error"
    # Fields shared by both session records.
    shared: dict[str, Any] = {
        "run_id": run_id,
        "content": message,
        "finish_reason": failure_reason,
        "source": source,
        "title": title,
        "model": model,
        "user_id": user_id,
    }
    session_manager.append_message(
        session_id,
        role="assistant",
        event_type="assistant_message_added",
        **shared,
    )
    session_manager.append_message(
        session_id,
        role="system",
        event_type="run_failed",
        event_payload={
            "tool_iterations": tool_iterations,
            "provider_name": provider_name,
        },
        context_visible=False,
        **shared,
    )
    return AgentRunResult(
        session_id=session_id,
        run_id=run_id,
        output_text=message,
        finish_reason=failure_reason,
        tool_iterations=tool_iterations,
        provider_name=provider_name,
        model=model,
        usage=usage,
    )