"""Unified agent loop used by all Beaver agents."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4

from beaver.engine.context import ContextBuildInput, SessionContext
from beaver.engine.providers import ProviderBundle, make_provider_bundle
from beaver.tools import ToolContext

from .loader import EngineLoader, EngineLoadResult


@dataclass(slots=True)
class AgentProfile:
    """Runtime profile for a Beaver agent instance."""

    name: str = "default"
    system_prompt: str = ""
    default_model: str = "gpt-4.1-mini"
    max_tokens: int = 4096
    temperature: float = 0.2
    max_tool_iterations: int = 8


@dataclass(slots=True)
class AgentRunResult:
    """Minimal result structure of a single direct run."""

    session_id: str
    run_id: str
    output_text: str
    finish_reason: str
    tool_iterations: int
    provider_name: str | None = None
    model: str | None = None
    usage: dict[str, Any] = field(default_factory=dict)


@dataclass(slots=True)
class _DirectRunRequest:
    """A single direct task waiting in the run loop's queue."""

    task: str
    # Keyword arguments forwarded verbatim to `_process_direct_impl()`.
    kwargs: dict[str, Any]
    # Resolved by the run loop with the result (or exception) of the task.
    future: asyncio.Future[AgentRunResult]


def _is_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool (which is an int subclass)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _usage_int(usage: dict[str, Any], *keys: str) -> int:
    """Return the first non-None value among ``keys`` in ``usage`` as an int, or 0.

    Aliases are tried in order, so an explicit ``None`` under the preferred key
    does not hide a real count stored under a later alias.
    """
    for key in keys:
        value = usage.get(key)
        if value is not None:
            return int(value or 0)
    return 0


class AgentLoop:
    """Single execution kernel shared by root agents and delegated agents."""

    def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
        self.profile = profile or AgentProfile()
        self.loader = loader or EngineLoader()
        self.loaded: EngineLoadResult | None = None
        # Created by `run()` and reset to None when the loop exits.
        self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
        self._running = False
        self._stop_requested = False

    def boot(self) -> EngineLoadResult:
        """Load shared runtime capabilities once for this agent instance."""
        if self.loaded is None:
            self.loaded = self.loader.load()
        return self.loaded

    @property
    def is_running(self) -> bool:
        """Whether `run()` is currently consuming the task queue."""
        return self._running

    async def run(self) -> None:
        """Start the minimal run loop and consume submitted direct tasks one by one.

        The first version is deliberately restrained:
        1. a single consumer processes tasks serially;
        2. actual execution still reuses the `process_direct()` path
           (through `_process_direct_impl()`);
        3. no bus / worker / priority / retry is introduced.

        When the loop exits (via `stop()`, cancellation or an error), any task
        still queued has its future failed with RuntimeError.

        Raises:
            RuntimeError: if `run()` is already active on this instance.
        """
        if self._running:
            raise RuntimeError("AgentLoop.run() is already active")
        self.boot()
        self._run_queue = asyncio.Queue()
        self._running = True
        self._stop_requested = False
        try:
            while True:
                item = await self._run_queue.get()
                if item is None:
                    # `None` is the sentinel pushed by `stop()`; ignore stray ones.
                    if self._stop_requested:
                        break
                    continue
                if item.future.cancelled():
                    # The submitter stopped waiting; don't spend work on it.
                    continue
                try:
                    result = await self._process_direct_impl(item.task, **item.kwargs)
                except asyncio.CancelledError:
                    if not item.future.done():
                        item.future.cancel()
                    raise
                except Exception as exc:  # pragma: no cover - defensive queue path
                    if not item.future.done():
                        item.future.set_exception(exc)
                else:
                    if not item.future.done():
                        item.future.set_result(result)
        finally:
            if self._run_queue is not None:
                # Fail whatever is still queued so no submitter waits forever.
                while True:
                    try:
                        pending = self._run_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    if isinstance(pending, _DirectRunRequest) and not pending.future.done():
                        pending.future.set_exception(
                            RuntimeError("AgentLoop.run() stopped before processing the queued task")
                        )
            self._running = False
            self._stop_requested = False
            self._run_queue = None

    async def stop(self) -> None:
        """Stop the run loop.

        First-version semantics:
        - no new tasks are accepted after this call;
        - the task already taken off the queue is allowed to finish (tasks queued
          before the stop sentinel are still consumed in FIFO order);
        - the runtime is NOT closed automatically (call `close()` for that).

        A no-op when the loop is not running.
        """
        if not self._running or self._run_queue is None:
            return
        self._stop_requested = True
        await self._run_queue.put(None)

    async def submit_direct(
        self,
        task: str,
        **kwargs: Any,
    ) -> AgentRunResult:
        """Submit a direct task to the running loop and wait for its result.

        Args:
            task: The user task text.
            **kwargs: Keyword arguments accepted by `process_direct()`.

        Raises:
            RuntimeError: if no `run()` loop is active, or `stop()` was requested.
        """
        if not self._running or self._run_queue is None:
            raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
        if self._stop_requested:
            raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")
        future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
        await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
        return await future

    def close(self) -> None:
        """Close the runtime held by this loop.

        Stage 6 only establishes the minimal lifecycle skeleton:
        - `boot()` builds the runtime;
        - `close()` releases the resources held by the runtime;
        - `run()/stop()/shutdown hooks` are to be extended on top of this later.

        Raises:
            RuntimeError: if the run loop is still active.
        """
        if self._running:
            raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")
        if self.loaded is None:
            return
        try:
            self.loaded.close()
        finally:
            self.loaded = None

    async def process_direct(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Run the minimal direct-run main path.

        The main path is deliberately restrained and only does the following:
        1. ensure the session exists;
        2. build the prompt from frozen memory + history;
        3. call the provider;
        4. if there are tool calls, enter the minimal tool loop;
        5. write user/assistant/tool messages and usage back to the session.

        Raises:
            RuntimeError: while `run()` is active; use `submit_direct()` instead.
        """
        if self._running:
            raise RuntimeError(
                "AgentLoop.process_direct() is disabled while run() is active; "
                "submit tasks via submit_direct() instead."
            )
        return await self._process_direct_impl(
            task,
            session_id=session_id,
            source=source,
            user_id=user_id,
            title=title,
            execution_context=execution_context,
            model=model,
            provider_name=provider_name,
            api_key=api_key,
            api_base=api_base,
            extra_headers=extra_headers,
            routing=routing,
            fallback_target=fallback_target,
            auxiliary_target=auxiliary_target,
            embedding_target=embedding_target,
            embedding_model=embedding_model,
            max_tokens=max_tokens,
            temperature=temperature,
            max_tool_iterations=max_tool_iterations,
            provider_bundle=provider_bundle,
        )

    async def _process_direct_impl(
        self,
        task: str,
        *,
        session_id: str | None = None,
        source: str = "direct",
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        routing: Any = None,
        fallback_target: dict[str, Any] | None = None,
        auxiliary_target: dict[str, Any] | None = None,
        embedding_target: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        max_tool_iterations: int | None = None,
        provider_bundle: ProviderBundle | None = None,
    ) -> AgentRunResult:
        """Internal implementation that actually executes one direct run.

        Rules:
        - external callers go through `process_direct()`;
        - the run loop consumes tasks through `_process_direct_impl()`;
        - this guarantees callers cannot bypass the queue while in run mode.

        Exceptions raised inside the main `try` are converted into an error
        `AgentRunResult` (finish_reason "error") rather than propagated.
        """
        loaded = self.boot()
        session_manager = self._require_loaded("session_manager")
        memory_service = self._require_loaded("memory_service")
        context_builder = self._require_loaded("context_builder")
        tool_registry = self._require_loaded("tool_registry")
        tool_assembler = self._require_loaded("tool_assembler")
        tool_executor = self._require_loaded("tool_executor")
        skills_loader = self._require_loaded("skills_loader")
        skill_assembler = self._require_loaded("skill_assembler")
        config = loaded.config
        configured_provider = config.resolve_provider_target(model=model, provider_name=provider_name)

        resolved_session_id = session_id or uuid4().hex
        resolved_run_id = uuid4().hex
        resolved_model = configured_provider.get("model") or self.profile.default_model
        resolved_provider_name = configured_provider.get("provider_name") or provider_name
        resolved_api_key = api_key or configured_provider.get("api_key")
        resolved_api_base = api_base or configured_provider.get("api_base")
        resolved_extra_headers = extra_headers or configured_provider.get("extra_headers")
        resolved_request_timeout_seconds = configured_provider.get("request_timeout_seconds")
        resolved_embedding_model = embedding_model or config.default_embedding_model
        resolved_embedding_target = embedding_target or config.resolve_embedding_target()
        resolved_max_tokens = max_tokens or self.profile.max_tokens
        resolved_temperature = self.profile.temperature if temperature is None else temperature
        resolved_max_tool_iterations = (
            self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
        )

        # Refresh live state through MemoryService before every new run, so the
        # memory policy stays inside the service instead of leaking into the loop.
        memory_service.reload_for_new_run()
        session_manager.ensure_session(
            resolved_session_id,
            source=source,
            model=resolved_model,
            title=title,
            user_id=user_id,
        )
        session_manager.append_message(
            resolved_session_id,
            run_id=resolved_run_id,
            role="system",
            event_type="run_started",
            event_payload={
                "source": source,
                "model": resolved_model,
                "agent_name": self.profile.name,
            },
            content=task,
            context_visible=False,
            source=source,
            title=title,
            model=resolved_model,
            user_id=user_id,
        )

        user_message_recorded = False
        iterations = 0
        final_usage: dict[str, Any] = {}
        final_provider_name: str | None = resolved_provider_name
        final_model: str | None = resolved_model
        try:
            bundle = provider_bundle or make_provider_bundle(
                model=resolved_model,
                provider_name=resolved_provider_name,
                api_key=resolved_api_key,
                api_base=resolved_api_base,
                request_timeout_seconds=resolved_request_timeout_seconds,
                extra_headers=resolved_extra_headers,
                routing=routing,
                fallback_target=fallback_target,
                auxiliary_target=auxiliary_target,
                embedding_target=resolved_embedding_target,
                embedding_model=resolved_embedding_model,
            )

            # Skill selection prefers the auxiliary provider when one is configured.
            skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
            skill_selector_model = (
                bundle.auxiliary_runtime.model
                if bundle.auxiliary_runtime is not None
                else bundle.main_runtime.model
            )
            assembled_skills = await skill_assembler.assemble(
                task_description=task,
                provider=skill_selector_provider,
                model=skill_selector_model,
                embedding_runtime=bundle.embedding_runtime,
            )
            skill_activation_messages = context_builder.build_skill_activation_messages(
                assembled_skills.activated_skills
            )
            if skill_activation_messages:
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="system",
                    event_type="skill_activation_snapshotted",
                    event_payload={
                        "activation_messages": skill_activation_messages,
                    },
                    content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
                    context_visible=False,
                    source=source,
                    title=title,
                    model=resolved_model,
                    user_id=user_id,
                )

            selected_tool_specs = await tool_assembler.assemble(
                task_description=task,
                registry=tool_registry,
                skills_loader=skills_loader,
                activated_skills=assembled_skills.activated_skills,
                embedding_runtime=bundle.embedding_runtime,
                top_k=10,
            )
            tool_schemas = tool_registry.export_selected_provider_schemas(selected_tool_specs)
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="tool_selection_snapshotted",
                event_payload={
                    "tools": [spec.to_mcp_descriptor() for spec in selected_tool_specs],
                    "tool_names": [spec.name for spec in selected_tool_specs],
                },
                content=", ".join(spec.name for spec in selected_tool_specs) or None,
                context_visible=False,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )

            build_input = ContextBuildInput(
                base_system_prompt=self.profile.system_prompt,
                history=session_manager.get_history(resolved_session_id),
                current_user_input=task,
                memory_snapshot=memory_service.get_snapshot(),
                activated_skills=assembled_skills.activated_skills,
                session_context=SessionContext(
                    session_id=resolved_session_id,
                    source=source,
                    model=resolved_model,
                    user_id=user_id,
                ),
                execution_context=execution_context,
            )
            context_result = context_builder.build_messages(build_input)
            session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="system_prompt_snapshotted",
                event_payload={
                    "source": source,
                    "model": resolved_model,
                    "system_prompt_length": len(context_result.system_prompt),
                },
                content=context_result.system_prompt,
                context_visible=False,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="user",
                event_type="user_message_added",
                content=task,
                source=source,
                title=title,
                model=resolved_model,
                user_id=user_id,
            )
            user_message_recorded = True

            provider = bundle.main_provider
            # The model name sent on every round stays pinned to the configured
            # main runtime. `response.model` is only used for reporting: it can
            # differ from a valid request name (dated snapshots, deployment vs.
            # underlying model, a fallback route's model).
            request_model = bundle.main_runtime.model
            messages = list(context_result.messages)
            tool_context = ToolContext(
                workspace=str(loaded.workspace),
                session_id=resolved_session_id,
                user_id=user_id,
                services={
                    "session_manager": session_manager,
                    "memory_service": memory_service,
                    "memory_store": memory_service.get_store(),
                    "tool_registry": tool_registry,
                },
                metadata={
                    "source": source,
                    "agent_name": self.profile.name,
                },
            )
            final_text = ""
            final_finish_reason = "stop"
            final_provider_name = bundle.main_runtime.provider_name
            final_model = request_model

            while True:
                response = await provider.chat(
                    messages=messages,
                    tools=tool_schemas,
                    model=request_model,
                    max_tokens=resolved_max_tokens,
                    temperature=resolved_temperature,
                )
                final_provider_name = response.provider_name or final_provider_name
                final_model = response.model or final_model
                final_usage = self._merge_usage(final_usage, response.usage or {})
                self._record_usage(session_manager, resolved_session_id, response.usage or {})

                assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="assistant",
                    event_type="assistant_message_added",
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    finish_reason=response.finish_reason,
                    reasoning=response.reasoning_content,
                    source=source,
                    title=title,
                    model=final_model,
                    user_id=user_id,
                )
                context_builder.add_assistant_message(
                    messages,
                    content=response.content,
                    tool_calls=assistant_tool_calls or None,
                    reasoning_content=response.reasoning_content,
                )

                if not response.has_tool_calls:
                    final_text = response.content or ""
                    final_finish_reason = response.finish_reason or "stop"
                    break

                if iterations >= resolved_max_tool_iterations:
                    # The assistant turn recorded above carries tool_calls that will
                    # never run. Answer each one explicitly so the persisted history
                    # keeps the assistant(tool_calls) -> tool(result) pairing that
                    # OpenAI-style chat APIs require when the session is replayed.
                    skipped_tool_error = (
                        "Tool call skipped: the tool loop reached the configured iteration limit."
                    )
                    for tool_call in response.tool_calls or ():
                        session_manager.append_message(
                            resolved_session_id,
                            run_id=resolved_run_id,
                            role="tool",
                            event_type="tool_result_recorded",
                            event_payload={
                                "success": False,
                                "error": skipped_tool_error,
                            },
                            content=skipped_tool_error,
                            tool_name=tool_call.name,
                            tool_call_id=tool_call.id,
                            source=source,
                            title=title,
                            model=final_model,
                            user_id=user_id,
                        )
                        context_builder.add_tool_result(
                            messages,
                            tool_call_id=tool_call.id,
                            tool_name=tool_call.name,
                            result=skipped_tool_error,
                        )

                    final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
                    final_finish_reason = "max_tool_iterations"
                    session_manager.append_message(
                        resolved_session_id,
                        run_id=resolved_run_id,
                        role="assistant",
                        event_type="assistant_message_added",
                        content=final_text,
                        finish_reason=final_finish_reason,
                        source=source,
                        title=title,
                        model=final_model,
                        user_id=user_id,
                    )
                    context_builder.add_assistant_message(
                        messages,
                        content=final_text,
                    )
                    break

                iterations += 1
                for tool_call in response.tool_calls:
                    result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
                    session_manager.append_message(
                        resolved_session_id,
                        run_id=resolved_run_id,
                        role="tool",
                        event_type="tool_result_recorded",
                        event_payload={
                            "success": result.success,
                            "error": result.error,
                        },
                        content=result.content,
                        tool_name=result.tool_name,
                        tool_call_id=tool_call.id,
                        source=source,
                        title=title,
                        model=final_model,
                        user_id=user_id,
                    )
                    context_builder.add_tool_result(
                        messages,
                        tool_call_id=tool_call.id,
                        tool_name=result.tool_name,
                        result=result.content,
                    )

            session_manager.append_message(
                resolved_session_id,
                run_id=resolved_run_id,
                role="system",
                event_type="run_completed",
                event_payload={
                    "finish_reason": final_finish_reason,
                    "tool_iterations": iterations,
                },
                content=final_text,
                finish_reason=final_finish_reason,
                context_visible=False,
                source=source,
                title=title,
                model=final_model,
                user_id=user_id,
            )
            return AgentRunResult(
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                output_text=final_text,
                finish_reason=final_finish_reason,
                tool_iterations=iterations,
                provider_name=final_provider_name,
                model=final_model,
                usage=final_usage,
            )
        except Exception as exc:
            # Make sure the failed turn still shows the user's input in the session.
            if not user_message_recorded:
                session_manager.append_message(
                    resolved_session_id,
                    run_id=resolved_run_id,
                    role="user",
                    event_type="user_message_added",
                    content=task,
                    source=source,
                    title=title,
                    model=resolved_model,
                    user_id=user_id,
                )
            return self._build_error_result(
                session_manager=session_manager,
                session_id=resolved_session_id,
                run_id=resolved_run_id,
                source=source,
                title=title,
                user_id=user_id,
                model=final_model or resolved_model,
                message=f"Run failed before completion: {exc}",
                tool_iterations=iterations,
                provider_name=final_provider_name,
                usage=final_usage,
            )

    def _require_loaded(self, field_name: str) -> Any:
        """Return a dependency from the loaded runtime, booting it if needed.

        Raises:
            RuntimeError: if the loader left the requested dependency as None.
        """
        loaded = self.boot()
        value = getattr(loaded, field_name)
        if value is None:
            raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
        return value

    @staticmethod
    def _serialize_tool_calls(tool_calls: list[Any] | None) -> list[dict[str, Any]]:
        """Convert provider tool-call objects into OpenAI-style `tool_calls` dicts.

        Each item must expose `id`, `name` and `arguments`. A missing (None)
        list is treated as empty.
        """
        return [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.name,
                    "arguments": tool_call.arguments,
                },
            }
            for tool_call in tool_calls or ()
        ]

    @staticmethod
    def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
        """Map one provider usage payload onto the session usage fields.

        Minimal mapping for the most common fields:
        - input_tokens, falling back to prompt_tokens;
        - output_tokens, falling back to completion_tokens;
        - reasoning_tokens, falling back to the nested
          `completion_tokens_details` / `output_tokens_details` entries.

        Aliases are consulted in order and the first non-None value wins. Extend
        here once the provider layer reports finer cache/reasoning/cost data.
        """
        if not usage:
            return
        reasoning_tokens = _usage_int(usage, "reasoning_tokens")
        if usage.get("reasoning_tokens") is None:
            for details_key in ("completion_tokens_details", "output_tokens_details"):
                details = usage.get(details_key)
                if isinstance(details, dict) and details.get("reasoning_tokens") is not None:
                    reasoning_tokens = _usage_int(details, "reasoning_tokens")
                    break
        session_manager.update_usage(
            session_id,
            input_tokens=_usage_int(usage, "input_tokens", "prompt_tokens"),
            output_tokens=_usage_int(usage, "output_tokens", "completion_tokens"),
            reasoning_tokens=reasoning_tokens,
        )

    @staticmethod
    def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
        """Merge one round of provider usage into the cumulative usage of a run.

        - numeric values (bool excluded) are summed; a missing/None total counts as 0;
        - nested dicts (e.g. `prompt_tokens_details`) are merged recursively;
        - anything else is overwritten by the latest value.

        Neither input is mutated; a new dict is returned.
        """
        merged = dict(total)
        for key, value in delta.items():
            current = merged.get(key)
            if _is_number(value) and (current is None or _is_number(current)):
                merged[key] = (current or 0) + value
            elif isinstance(value, dict):
                base = current if isinstance(current, dict) else {}
                merged[key] = AgentLoop._merge_usage(base, value)
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _build_error_result(
        *,
        session_manager: Any,
        session_id: str,
        run_id: str,
        source: str,
        title: str | None,
        user_id: str | None,
        model: str | None,
        message: str,
        tool_iterations: int,
        provider_name: str | None,
        usage: dict[str, Any],
    ) -> AgentRunResult:
        """Turn an unhandled main-path exception into a traceable assistant error turn.

        Records an assistant message and a hidden `run_failed` system event, then
        returns an `AgentRunResult` with finish_reason "error".
        """
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="assistant",
            event_type="assistant_message_added",
            content=message,
            finish_reason="error",
            source=source,
            title=title,
            model=model,
            user_id=user_id,
        )
        session_manager.append_message(
            session_id,
            run_id=run_id,
            role="system",
            event_type="run_failed",
            event_payload={
                "tool_iterations": tool_iterations,
                "provider_name": provider_name,
            },
            content=message,
            finish_reason="error",
            context_visible=False,
            source=source,
            title=title,
            model=model,
            user_id=user_id,
        )
        return AgentRunResult(
            session_id=session_id,
            run_id=run_id,
            output_text=message,
            finish_reason="error",
            tool_iterations=tool_iterations,
            provider_name=provider_name,
            model=model,
            usage=usage,
        )