feat(app-instance): 集成Beaver后端并更新配置管理

集成新的Beaver后端服务到应用实例中,替换原有的nanobot实现。

主要变更包括:
- 在Dockerfile和环境配置中添加Beaver相关路径和配置变量
- 更新工作目录结构从.nanobot到.beaver
- 实现Beaver引擎加载器,支持配置文件加载和工具组装
- 添加内置工具如ListDirectoryTool、ReadFileTool、SearchFilesTool
- 更新消息处理流程,支持通道适配器和网关模式
- 重构技能系统,支持显式工具提示和嵌入式检索
- 改进错误处理和生命周期管理

此变更使应用实例能够使用统一的Beaver后端进行AI代理运行时管理。
This commit is contained in:
2026-04-27 17:37:40 +08:00
parent 36882a7d7b
commit 5ba5c7e4c1
47 changed files with 2821 additions and 462 deletions

View File

@ -2,17 +2,27 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable
from beaver.engine.context import ContextBuilder
from beaver.engine.session import SessionManager
from beaver.foundation.config import BeaverConfig, load_config
from beaver.memory.curated.store import MemoryStore
from beaver.services.memory_service import MemoryService
from beaver.skills import SkillAssembler, SkillsLoader
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
from beaver.tools import ObjectBackedTool, ToolAssembler, ToolExecutor, ToolRegistry
from beaver.tools.builtins import (
EchoTool,
ListDirectoryTool,
MemoryTool,
ReadFileTool,
SearchFilesTool,
SessionSearchTool,
SkillViewTool,
)
@dataclass(slots=True)
@ -27,6 +37,7 @@ class EngineLoadResult:
"""
workspace: Path
config: BeaverConfig = field(default_factory=BeaverConfig)
tools: list[str] = field(default_factory=list)
skills: list[str] = field(default_factory=list)
memory_stores: list[str] = field(default_factory=list)
@ -35,6 +46,7 @@ class EngineLoadResult:
curated_memory_store: MemoryStore | None = None
memory_service: MemoryService | None = None
tool_registry: ToolRegistry | None = None
tool_assembler: ToolAssembler | None = None
tool_executor: ToolExecutor | None = None
context_builder: ContextBuilder | None = None
skills_loader: SkillsLoader | None = None
@ -89,19 +101,26 @@ class EngineLoader:
self,
*,
workspace: str | Path | None = None,
config_path: str | Path | None = None,
config: BeaverConfig | None = None,
session_manager: SessionManager | None = None,
curated_memory_store: MemoryStore | None = None,
memory_service: MemoryService | None = None,
tool_registry: ToolRegistry | None = None,
tool_assembler: ToolAssembler | None = None,
context_builder: ContextBuilder | None = None,
skills_loader: SkillsLoader | None = None,
skill_assembler: SkillAssembler | None = None,
) -> None:
self.workspace = Path(workspace or Path.cwd())
self.config = config or load_config(workspace=workspace, config_path=config_path)
configured_workspace = self.config.agents_defaults.workspace
env_workspace = os.getenv("BEAVER_WORKSPACE")
self.workspace = Path(workspace or configured_workspace or env_workspace or Path.cwd())
self._session_manager = session_manager
self._curated_memory_store = curated_memory_store
self._memory_service = memory_service
self._tool_registry = tool_registry
self._tool_assembler = tool_assembler
self._context_builder = context_builder
self._skills_loader = skills_loader
self._skill_assembler = skill_assembler
@ -127,15 +146,20 @@ class EngineLoader:
ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
ObjectBackedTool(SkillViewTool(loader=skills_loader)),
ObjectBackedTool(SessionSearchTool(db=session_manager)),
ObjectBackedTool(ListDirectoryTool()),
ObjectBackedTool(ReadFileTool()),
ObjectBackedTool(SearchFilesTool()),
]
)
context_builder = self._context_builder or ContextBuilder()
tool_assembler = self._tool_assembler or ToolAssembler()
tool_executor = ToolExecutor(tool_registry)
skill_assembler = self._skill_assembler or SkillAssembler(skills_loader)
result = EngineLoadResult(
workspace=workspace,
config=self.config,
tools=[spec.name for spec in tool_registry.list_specs()],
skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
memory_stores=["curated"],
@ -144,6 +168,7 @@ class EngineLoader:
curated_memory_store=memory_service.get_store(),
memory_service=memory_service,
tool_registry=tool_registry,
tool_assembler=tool_assembler,
tool_executor=tool_executor,
context_builder=context_builder,
skills_loader=skills_loader,

View File

@ -272,12 +272,24 @@ class AgentLoop:
memory_service = self._require_loaded("memory_service")
context_builder = self._require_loaded("context_builder")
tool_registry = self._require_loaded("tool_registry")
tool_assembler = self._require_loaded("tool_assembler")
tool_executor = self._require_loaded("tool_executor")
skills_loader = self._require_loaded("skills_loader")
skill_assembler = self._require_loaded("skill_assembler")
config = loaded.config
configured_provider = config.resolve_provider_target(model=model, provider_name=provider_name)
resolved_session_id = session_id or uuid4().hex
resolved_run_id = uuid4().hex
resolved_model = model or self.profile.default_model
resolved_model = configured_provider.get("model") or self.profile.default_model
resolved_provider_name = configured_provider.get("provider_name") or provider_name
resolved_api_key = api_key or configured_provider.get("api_key")
resolved_api_base = api_base or configured_provider.get("api_base")
resolved_extra_headers = extra_headers or configured_provider.get("extra_headers")
resolved_request_timeout_seconds = configured_provider.get("request_timeout_seconds")
resolved_embedding_model = embedding_model or config.default_embedding_model
resolved_embedding_target = embedding_target or config.resolve_embedding_target()
resolved_max_tokens = max_tokens or self.profile.max_tokens
resolved_temperature = self.profile.temperature if temperature is None else temperature
resolved_max_tool_iterations = (
@ -316,20 +328,21 @@ class AgentLoop:
user_message_recorded = False
iterations = 0
final_usage: dict[str, Any] = {}
final_provider_name: str | None = provider_name
final_provider_name: str | None = resolved_provider_name
final_model: str | None = resolved_model
try:
bundle = provider_bundle or make_provider_bundle(
model=resolved_model,
provider_name=provider_name,
api_key=api_key,
api_base=api_base,
extra_headers=extra_headers,
provider_name=resolved_provider_name,
api_key=resolved_api_key,
api_base=resolved_api_base,
request_timeout_seconds=resolved_request_timeout_seconds,
extra_headers=resolved_extra_headers,
routing=routing,
fallback_target=fallback_target,
auxiliary_target=auxiliary_target,
embedding_target=embedding_target,
embedding_model=embedding_model or "text-embedding-v4",
embedding_target=resolved_embedding_target,
embedding_model=resolved_embedding_model,
)
skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
skill_selector_model = (
@ -364,6 +377,32 @@ class AgentLoop:
user_id=user_id,
)
selected_tool_specs = await tool_assembler.assemble(
task_description=task,
registry=tool_registry,
skills_loader=skills_loader,
activated_skills=assembled_skills.activated_skills,
embedding_runtime=bundle.embedding_runtime,
top_k=10,
)
tool_schemas = tool_registry.export_selected_provider_schemas(selected_tool_specs)
session_manager.append_message(
resolved_session_id,
run_id=resolved_run_id,
role="system",
event_type="tool_selection_snapshotted",
event_payload={
"tools": [spec.to_mcp_descriptor() for spec in selected_tool_specs],
"tool_names": [spec.name for spec in selected_tool_specs],
},
content=", ".join(spec.name for spec in selected_tool_specs) or None,
context_visible=False,
source=source,
title=title,
model=resolved_model,
user_id=user_id,
)
build_input = ContextBuildInput(
base_system_prompt=self.profile.system_prompt,
history=session_manager.get_history(resolved_session_id),
@ -412,7 +451,6 @@ class AgentLoop:
provider = bundle.main_provider
messages = list(context_result.messages)
tool_schemas = tool_registry.export_provider_schemas()
tool_context = ToolContext(
workspace=str(loaded.workspace),
session_id=resolved_session_id,

View File

@ -8,7 +8,7 @@ import os
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .registry import find_by_model, find_gateway
from .registry import find_by_model, find_by_name, find_gateway
from .runtime import ProviderRoutingConfig
try: # pragma: no cover - optional dependency
@ -58,7 +58,11 @@ class LiteLLMProvider(LLMProvider):
if not api_key:
return {}
spec = self._gateway or find_by_model(model)
spec = self._gateway
if spec is None and self.provider_name:
spec = find_by_name(self.provider_name)
if spec is None:
spec = find_by_model(model)
if spec is None or not spec.env_key:
return {}
overrides: dict[str, str] = {spec.env_key: api_key}
@ -97,6 +101,15 @@ class LiteLLMProvider(LLMProvider):
if prefix and not resolved.startswith(f"{prefix}/"):
resolved = f"{prefix}/{resolved}"
return resolved
if self.provider_name:
spec = find_by_name(self.provider_name)
if spec is not None and not spec.is_gateway and not spec.is_local:
resolved = model
if spec.litellm_prefix and not any(resolved.startswith(prefix) for prefix in spec.skip_prefixes):
resolved = f"{spec.litellm_prefix}/{resolved}"
elif spec.name == "openai" and "/" not in resolved:
resolved = f"openai/{resolved}"
return resolved
spec = find_by_model(model)
if spec and spec.litellm_prefix:
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):

View File

@ -1,2 +1,13 @@
"""Configuration models and loaders."""
from .loader import default_config_path, load_config
from .schema import AgentDefaultsConfig, BeaverConfig, EmbeddingConfig, ProviderConfig
__all__ = [
"AgentDefaultsConfig",
"BeaverConfig",
"EmbeddingConfig",
"ProviderConfig",
"default_config_path",
"load_config",
]

View File

@ -0,0 +1,127 @@
"""Config loader for per-sandbox Beaver runtime settings."""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from .schema import AgentDefaultsConfig, BeaverConfig, EmbeddingConfig, ProviderConfig
def default_config_path(*, workspace: str | Path | None = None) -> Path:
"""Resolve the default config path for a single-user sandbox instance.
Priority:
1. `BEAVER_CONFIG_PATH`
2. `NANOBOT_CONFIG_PATH` for compatibility during migration
3. `BEAVER_HOME/config.json`
4. `NANOBOT_HOME/config.json` for migration compatibility
5. `<workspace>/.beaver/config.json`
6. `./.beaver/config.json`
"""
explicit = os.getenv("BEAVER_CONFIG_PATH") or os.getenv("NANOBOT_CONFIG_PATH")
if explicit:
return Path(explicit).expanduser()
beaver_home = os.getenv("BEAVER_HOME")
if beaver_home:
return Path(beaver_home).expanduser() / "config.json"
nanobot_home = os.getenv("NANOBOT_HOME")
if nanobot_home:
return Path(nanobot_home).expanduser() / "config.json"
root = Path(workspace).expanduser() if workspace is not None else Path.cwd()
return root / ".beaver" / "config.json"
def load_config(
*,
workspace: str | Path | None = None,
config_path: str | Path | None = None,
) -> BeaverConfig:
"""Load backend config; missing config is treated as an empty config."""
path = Path(config_path).expanduser() if config_path is not None else default_config_path(workspace=workspace)
if not path.exists():
return BeaverConfig(config_path=path)
data = json.loads(path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
raise ValueError(f"Beaver config must be a JSON object: {path}")
return BeaverConfig(
agents_defaults=_parse_agent_defaults(data),
providers=_parse_providers(data.get("providers")),
embedding=_parse_embedding(data),
config_path=path,
)
def _parse_agent_defaults(data: dict[str, Any]) -> AgentDefaultsConfig:
agents = _as_dict(data.get("agents"))
defaults = _as_dict(agents.get("defaults"))
return AgentDefaultsConfig(
workspace=_string(defaults.get("workspace") or data.get("workspace")),
model=_string(defaults.get("model") or data.get("model")),
provider=_string(defaults.get("provider") or data.get("provider")),
embedding_model=_string(defaults.get("embeddingModel") or defaults.get("embedding_model") or data.get("embeddingModel")),
)
def _parse_providers(raw: Any) -> dict[str, ProviderConfig]:
providers: dict[str, ProviderConfig] = {}
for name, payload in _as_dict(raw).items():
if not isinstance(payload, dict):
continue
providers[str(name)] = ProviderConfig(
api_key=_string(payload.get("apiKey") or payload.get("api_key")),
api_base=_string(payload.get("apiBase") or payload.get("api_base") or payload.get("baseUrl") or payload.get("base_url")),
extra_headers=_string_dict(payload.get("extraHeaders") or payload.get("extra_headers") or payload.get("headers")),
request_timeout_seconds=_float(
payload.get("requestTimeoutSeconds")
or payload.get("request_timeout_seconds")
or payload.get("timeout")
),
)
return providers
def _parse_embedding(data: dict[str, Any]) -> EmbeddingConfig:
raw = _as_dict(data.get("embedding") or data.get("embeddings"))
return EmbeddingConfig(
provider=_string(raw.get("provider") or raw.get("provider_name")),
model=_string(raw.get("model") or data.get("embeddingModel") or data.get("embedding_model")),
api_key=_string(raw.get("apiKey") or raw.get("api_key")),
api_base=_string(raw.get("apiBase") or raw.get("api_base") or raw.get("baseUrl") or raw.get("base_url")),
extra_headers=_string_dict(raw.get("extraHeaders") or raw.get("extra_headers") or raw.get("headers")),
request_timeout_seconds=_float(
raw.get("requestTimeoutSeconds") or raw.get("request_timeout_seconds") or raw.get("timeout")
),
)
def _as_dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def _string(value: Any) -> str | None:
if value is None:
return None
value = str(value).strip()
return value or None
def _string_dict(value: Any) -> dict[str, str]:
if not isinstance(value, dict):
return {}
return {str(key): str(item) for key, item in value.items() if item is not None}
def _float(value: Any) -> float | None:
if value in (None, ""):
return None
return float(value)

View File

@ -0,0 +1,136 @@
"""Runtime configuration schema for Beaver sandbox instances."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass(slots=True)
class ProviderConfig:
"""One configured LLM provider profile."""
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
@dataclass(slots=True)
class AgentDefaultsConfig:
"""Default agent settings for this sandbox instance."""
workspace: str | None = None
model: str | None = None
provider: str | None = None
embedding_model: str | None = None
@dataclass(slots=True)
class EmbeddingConfig:
"""Optional dedicated embedding model settings."""
provider: str | None = None
model: str | None = None
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
@dataclass(slots=True)
class BeaverConfig:
"""Config loaded once per backend sandbox instance."""
agents_defaults: AgentDefaultsConfig = field(default_factory=AgentDefaultsConfig)
providers: dict[str, ProviderConfig] = field(default_factory=dict)
embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
config_path: Path | None = None
@property
def default_model(self) -> str | None:
return _clean(self.agents_defaults.model)
@property
def default_embedding_model(self) -> str:
return _clean(self.embedding.model) or _clean(self.agents_defaults.embedding_model) or "text-embedding-v4"
def resolve_provider_target(
self,
*,
model: str | None = None,
provider_name: str | None = None,
) -> dict[str, Any]:
"""Resolve model/provider credentials from instance config.
Request-level model/provider overrides are allowed, but credentials are still
read from backend config, not from Web/channel payloads.
"""
resolved_model = _clean(model) or self.default_model
resolved_provider = _clean(provider_name) or self._infer_provider(resolved_model)
provider_cfg = self.providers.get(resolved_provider or "") if resolved_provider else None
payload: dict[str, Any] = {
"model": resolved_model,
"provider_name": resolved_provider,
}
if provider_cfg is not None:
payload.update(
{
"api_key": provider_cfg.api_key,
"api_base": provider_cfg.api_base,
"extra_headers": dict(provider_cfg.extra_headers),
"request_timeout_seconds": provider_cfg.request_timeout_seconds,
}
)
return {key: value for key, value in payload.items() if value not in (None, "", {})}
def resolve_embedding_target(self) -> dict[str, Any] | None:
"""Return an explicit embedding target when configured."""
has_explicit_embedding = any(
[
_clean(self.embedding.provider),
_clean(self.embedding.api_key),
_clean(self.embedding.api_base),
self.embedding.extra_headers,
self.embedding.request_timeout_seconds is not None,
]
)
if not has_explicit_embedding:
return None
provider_cfg = self.providers.get(_clean(self.embedding.provider) or "")
payload: dict[str, Any] = {
"provider": _clean(self.embedding.provider),
"model": self.default_embedding_model,
"api_key": _clean(self.embedding.api_key) or (provider_cfg.api_key if provider_cfg else None),
"api_base": _clean(self.embedding.api_base) or (provider_cfg.api_base if provider_cfg else None),
"extra_headers": self.embedding.extra_headers or (dict(provider_cfg.extra_headers) if provider_cfg else {}),
"request_timeout_seconds": self.embedding.request_timeout_seconds
or (provider_cfg.request_timeout_seconds if provider_cfg else None),
}
return {key: value for key, value in payload.items() if value not in (None, "", {})}
def _infer_provider(self, model: str | None) -> str | None:
configured_provider = _clean(self.agents_defaults.provider)
if configured_provider:
return configured_provider
if model and "/" in model:
prefix = model.split("/", 1)[0]
if prefix in self.providers:
return prefix
if len(self.providers) == 1:
return next(iter(self.providers))
return None
def _clean(value: str | None) -> str | None:
if value is None:
return None
value = str(value).strip()
return value or None

View File

@ -0,0 +1,205 @@
"""Shared embedding-based semantic retrieval utilities."""
from __future__ import annotations
import asyncio
import json
import math
import os
from typing import Any
from urllib import request
class EmbeddingRetriever:
"""Use an OpenAI-compatible embeddings API to rank lightweight candidates."""
def __init__(
self,
*,
api_key_env: str = "OPENAI_API_KEY",
api_base_env: str = "OPENAI_API_BASE",
model: str = "text-embedding-v4",
timeout_seconds: float = 20.0,
) -> None:
self.api_key_env = api_key_env
self.api_base_env = api_base_env
self.model = model
self.timeout_seconds = timeout_seconds
async def retrieve(
self,
*,
query: str,
candidates: list[dict[str, str]],
top_k: int,
api_key: str | None = None,
api_base: str | None = None,
model: str | None = None,
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
fallback_top_k: int | None = None,
) -> list[dict[str, str]]:
"""Return candidates ordered by embedding similarity.
If embedding config is missing or the request fails, return the original
candidate order. This keeps retrieval non-blocking for the main run.
"""
if not candidates or top_k <= 0:
return []
fallback = self._fallback_candidates(candidates, fallback_top_k=fallback_top_k)
resolved_api_key = api_key or os.getenv(self.api_key_env)
resolved_api_base = api_base or os.getenv(self.api_base_env)
if not resolved_api_key or not resolved_api_base:
return fallback
try:
query_embedding = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[query],
model=model or self.model,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
candidate_embeddings = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[self._candidate_text(item) for item in candidates],
model=model or self.model,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
except Exception:
return fallback
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
return fallback
query_vector = query_embedding[0]
scored: list[tuple[float, dict[str, str]]] = []
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
if vector:
scored.append((self._cosine_similarity(query_vector, vector), candidate))
scored.sort(key=lambda item: item[0], reverse=True)
return [item[1] for item in scored[:top_k]]
async def _embed_texts(
self,
*,
api_key: str,
api_base: str,
texts: list[str],
model: str,
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> list[list[float]]:
all_vectors: list[list[float]] = []
endpoint = self._normalize_embeddings_endpoint(api_base)
for start in range(0, len(texts), 10):
batch = texts[start:start + 10]
payload = await self._post_embeddings(
endpoint=endpoint,
api_key=api_key,
model=model,
texts=batch,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
embeddings = payload.get("data") or []
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
return all_vectors
async def _post_embeddings(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> dict[str, Any]:
return await asyncio.to_thread(
self._post_embeddings_sync,
endpoint=endpoint,
api_key=api_key,
model=model,
texts=texts,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
def _post_embeddings_sync(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> dict[str, Any]:
body = json.dumps(
{
"model": model,
"input": texts if len(texts) > 1 else texts[0],
"encoding_format": "float",
}
).encode("utf-8")
req = request.Request(
endpoint,
data=body,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
},
method="POST",
)
with request.urlopen(req, timeout=timeout_seconds or self.timeout_seconds) as response:
return json.loads(response.read().decode("utf-8"))
@staticmethod
def _fallback_candidates(
candidates: list[dict[str, str]],
*,
fallback_top_k: int | None,
) -> list[dict[str, str]]:
if fallback_top_k is None:
return list(candidates)
if fallback_top_k <= 0:
return []
return candidates[:fallback_top_k]
@staticmethod
def _candidate_text(candidate: dict[str, str]) -> str:
parts = [
(candidate.get("name") or "").strip(),
(candidate.get("description") or "").strip(),
(candidate.get("input_schema") or "").strip(),
]
return "\n".join(part for part in parts if part)
@staticmethod
def _normalize_embeddings_endpoint(api_base: str) -> str:
base = api_base.rstrip("/")
if base.endswith("/embeddings"):
return base
if base.endswith("/v1"):
return f"{base}/embeddings"
return f"{base}/v1/embeddings"
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
if not left or not right or len(left) != len(right):
return -1.0
dot = sum(a * b for a, b in zip(left, right, strict=False))
left_norm = math.sqrt(sum(a * a for a in left))
right_norm = math.sqrt(sum(b * b for b in right))
if left_norm == 0 or right_norm == 0:
return -1.0
return dot / (left_norm * right_norm)

View File

@ -1,2 +1,7 @@
"""Channel interfaces."""
from .base import ChannelAdapter
from .manager import ChannelManager
from .memory import MemoryChannelAdapter
__all__ = ["ChannelAdapter", "ChannelManager", "MemoryChannelAdapter"]

View File

@ -0,0 +1,24 @@
"""Channel adapter contracts for gateway-facing integrations."""
from __future__ import annotations
from typing import Protocol
from beaver.foundation.events import MessageBus, OutboundMessage
class ChannelAdapter(Protocol):
"""Minimal contract every gateway channel must implement."""
name: str
bus: MessageBus
async def start(self) -> None:
"""Prepare the channel before messages are routed."""
async def stop(self) -> None:
"""Stop accepting/routing channel messages."""
async def send(self, message: OutboundMessage) -> None:
"""Deliver an outbound message to the concrete channel."""

View File

@ -0,0 +1,76 @@
"""Channel manager for routing gateway outbound messages."""
from __future__ import annotations
import asyncio
from contextlib import suppress
from beaver.foundation.events import MessageBus, OutboundMessage
from .base import ChannelAdapter
class ChannelManager:
"""Start/stop channel adapters and dispatch outbound messages to them."""
def __init__(self, bus: MessageBus) -> None:
self.bus = bus
self.channels: dict[str, ChannelAdapter] = {}
self.undeliverable: list[OutboundMessage] = []
self.started = False
def register(self, channel: ChannelAdapter) -> None:
if self.started:
raise RuntimeError("Cannot register channels after ChannelManager.start()")
if channel.name in self.channels:
raise ValueError(f"Channel already registered: {channel.name}")
if channel.bus is not self.bus:
raise ValueError("Channel must share the same MessageBus as ChannelManager")
self.channels[channel.name] = channel
async def start(self) -> None:
started: list[ChannelAdapter] = []
try:
for channel in self.channels.values():
await channel.start()
started.append(channel)
except BaseException:
for channel in reversed(started):
with suppress(BaseException):
await channel.stop()
raise
else:
self.started = True
async def stop(self) -> None:
errors: list[BaseException] = []
for channel in reversed(tuple(self.channels.values())):
try:
await channel.stop()
except Exception as exc: # pragma: no cover - defensive cleanup path
errors.append(exc)
self.started = False
if errors:
raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]
async def dispatch_outbound(self, stop_event: asyncio.Event) -> None:
"""Route bus outbound messages until stopped and the queue is drained."""
while True:
if stop_event.is_set() and self.bus.outbound_size == 0:
break
try:
message = await asyncio.wait_for(self.bus.consume_outbound(), timeout=0.25)
except asyncio.TimeoutError:
continue
channel = self.channels.get(message.channel)
if channel is None:
self.undeliverable.append(message)
continue
try:
await channel.send(message)
except Exception: # pragma: no cover - defensive channel isolation
self.undeliverable.append(message)

View File

@ -0,0 +1,57 @@
"""In-memory channel adapter for tests and local gateway embedding."""
from __future__ import annotations
from typing import Any
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
class MemoryChannelAdapter:
"""A local channel that stores outbound messages in memory."""
def __init__(self, bus: MessageBus, *, name: str = "memory") -> None:
self.name = name
self.bus = bus
self.started = False
self.sent_messages: list[OutboundMessage] = []
async def start(self) -> None:
self.started = True
async def stop(self) -> None:
self.started = False
async def send(self, message: OutboundMessage) -> None:
self.sent_messages.append(message)
async def publish_text(
self,
content: str,
*,
session_id: str | None = None,
user_id: str | None = None,
title: str | None = None,
execution_context: str | None = None,
model: str | None = None,
provider_name: str | None = None,
embedding_model: str | None = None,
metadata: dict[str, Any] | None = None,
) -> InboundMessage:
"""Publish a text message from this channel into the shared bus."""
message = InboundMessage(
channel=self.name,
content=content,
session_id=session_id,
user_id=user_id,
title=title,
execution_context=execution_context,
model=model,
provider_name=provider_name,
embedding_model=embedding_model,
metadata=metadata or {},
)
await self.bus.publish_inbound(message)
return message

View File

@ -35,6 +35,7 @@ app = typer.Typer(help="Beaver backend CLI") if hasattr(typer, "Typer") else typ
def run(
message: str | None = typer.Option(None, "--message", "-m", help="Run one direct Beaver request."),
workspace: str | None = typer.Option(None, "--workspace", help="Workspace root for this run."),
config: str | None = typer.Option(None, "--config", help="Backend config path for this run."),
) -> None:
"""Thin CLI wrapper around AgentService.
@ -44,7 +45,7 @@ def run(
3. 打印结果
"""
service = AgentService(workspace=workspace)
service = AgentService(workspace=workspace, config_path=config)
if not message:
service.create_loop()
typer.echo("Beaver engine booted.")

View File

@ -1,41 +1,39 @@
"""Gateway entrypoint for Beaver.
当前阶段先不扩 bus / channels adapter,只做最小消息桥接:
当前阶段只做最小 gateway 宿主与 channel adapter 桥接:
1. 启动时托管 `AgentService.start()`
2. 常驻消费 `MessageBus.inbound`
3. 调 `service.submit_direct(...)`
3. 调 `service.handle_inbound_message(...)`
4. 将结果写回 `MessageBus.outbound`
5. 退出时走 `AgentService.shutdown()`
5. 如果配置了 channel adapters则由 `ChannelManager` 分发 outbound
6. 退出时走 `AgentService.shutdown()`
"""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from contextlib import suppress
from pathlib import Path
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
from beaver.foundation.events import InboundMessage, MessageBus
from beaver.interfaces.channels import ChannelAdapter, ChannelManager
from beaver.services.agent_service import AgentService
async def _publish_bridge_error(
bus: MessageBus,
inbound: InboundMessage,
async def _cleanup_owned_service(
service: AgentService,
*,
detail: str,
finish_reason: str = "error",
timeout_seconds: float | None,
force: bool,
) -> None:
"""把 bridge 处理失败转换成结构化 outbound 错误消息。"""
"""Best-effort cleanup for service startup failures or cancellations."""
await bus.publish_outbound(
OutboundMessage(
message_id=inbound.message_id,
channel=inbound.channel,
session_id=inbound.session_id,
content=detail,
finish_reason=finish_reason,
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
)
)
with suppress(BaseException):
if service.is_running:
await service.shutdown(timeout_seconds=timeout_seconds, force=force)
else:
service.close()
async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
@ -46,11 +44,17 @@ async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
pending = bus.inbound.get_nowait()
except asyncio.QueueEmpty:
break
await _publish_bridge_error(bus, pending, detail=reason, finish_reason="stopped")
await bus.publish_outbound(
AgentService.build_outbound_error(
pending,
detail=reason,
finish_reason="stopped",
)
)
async def _await_bridge_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
"""等待 bridge 退出;超时则取消,避免 shutdown 被桥接层反向卡死。"""
async def _await_task_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
"""等待后台任务退出;超时则取消,避免 shutdown 被反向卡死。"""
try:
await asyncio.wait_for(task, timeout=timeout_seconds)
@ -85,53 +89,28 @@ async def _bridge_inbound_to_runtime(
continue
try:
result = await service.submit_direct(
inbound.content,
session_id=inbound.session_id,
source=f"gateway:{inbound.channel}",
user_id=inbound.user_id,
title=inbound.title,
execution_context=inbound.execution_context,
model=inbound.model,
provider_name=inbound.provider_name,
embedding_model=inbound.embedding_model,
)
outbound = await service.handle_inbound_message(inbound)
except asyncio.CancelledError:
await _publish_bridge_error(
bus,
inbound,
detail="Gateway stopped before completing the inbound message",
finish_reason="cancelled",
)
raise
except Exception as exc: # pragma: no cover - defensive bridge path
await _publish_bridge_error(
bus,
inbound,
detail=str(exc),
)
else:
await bus.publish_outbound(
OutboundMessage(
message_id=inbound.message_id,
channel=inbound.channel,
session_id=result.session_id,
run_id=result.run_id,
content=result.output_text,
finish_reason=result.finish_reason,
provider_name=result.provider_name,
model=result.model,
usage=dict(result.usage),
metadata={"inbound_metadata": dict(inbound.metadata)},
AgentService.build_outbound_error(
inbound,
detail="Gateway stopped before completing the inbound message",
finish_reason="cancelled",
)
)
raise
else:
await bus.publish_outbound(outbound)
async def run_gateway(
*,
workspace: str | Path | None = None,
config_path: str | Path | None = None,
service: AgentService | None = None,
bus: MessageBus | None = None,
channels: Sequence[ChannelAdapter] | None = None,
channel_manager: ChannelManager | None = None,
manage_service_lifecycle: bool | None = None,
stop_event: asyncio.Event | None = None,
shutdown_timeout_seconds: float | None = 5.0,
@ -142,19 +121,41 @@ async def run_gateway(
默认 ownership 语义:
- 未传 `service`gateway 自己创建并接管其 lifecycle
- 传入外部 `service`:默认只使用,不自动 start/shutdown
- `channel_manager` 和 `channels` 二选一,避免隐式修改外部 manager
"""
attached_service = service or AgentService(workspace=workspace)
attached_bus = bus or MessageBus()
attached_service = service or AgentService(workspace=workspace, config_path=config_path)
if channel_manager is not None and channels is not None:
raise ValueError("Pass either channel_manager or channels, not both")
if bus is not None:
attached_bus = bus
elif channel_manager is not None:
attached_bus = channel_manager.bus
else:
attached_bus = MessageBus()
attached_channel_manager = channel_manager
if attached_channel_manager is not None and attached_channel_manager.bus is not attached_bus:
raise ValueError("Injected channel_manager must share the gateway MessageBus")
if attached_channel_manager is None and channels is not None:
attached_channel_manager = ChannelManager(attached_bus)
if attached_channel_manager is not None and channels is not None:
for channel in channels:
attached_channel_manager.register(channel)
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
owned_stop_event = stop_event or asyncio.Event()
started = False
channels_started = False
if owns_service:
try:
await attached_service.start()
started = True
except Exception:
attached_service.close()
except BaseException:
await _cleanup_owned_service(
attached_service,
timeout_seconds=shutdown_timeout_seconds,
force=shutdown_force,
)
raise
if not attached_service.is_running:
@ -163,7 +164,25 @@ async def run_gateway(
"or allow the gateway to manage its lifecycle."
)
if attached_channel_manager is not None:
try:
await attached_channel_manager.start()
channels_started = True
except BaseException:
if owns_service and started:
await _cleanup_owned_service(
attached_service,
timeout_seconds=shutdown_timeout_seconds,
force=shutdown_force,
)
raise
bridge_task = asyncio.create_task(_bridge_inbound_to_runtime(attached_service, attached_bus, owned_stop_event))
dispatch_task: asyncio.Task[None] | None = None
dispatch_stop_event = asyncio.Event()
if attached_channel_manager is not None:
dispatch_task = asyncio.create_task(attached_channel_manager.dispatch_outbound(dispatch_stop_event))
try:
await owned_stop_event.wait()
finally:
@ -175,9 +194,14 @@ async def run_gateway(
force=shutdown_force,
)
finally:
await _await_bridge_shutdown(bridge_task)
await _await_task_shutdown(bridge_task)
else:
await _await_bridge_shutdown(bridge_task)
await _await_task_shutdown(bridge_task)
if dispatch_task is not None:
dispatch_stop_event.set()
await _await_task_shutdown(dispatch_task)
if attached_channel_manager is not None and channels_started:
await attached_channel_manager.stop()
def main() -> None:

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from pathlib import Path
from types import SimpleNamespace
from typing import Any
@ -56,6 +56,7 @@ async def _app_lifespan(
app: FastAPI,
*,
workspace: str | Path | None,
config_path: str | Path | None,
service: AgentService | None,
manage_service_lifecycle: bool | None,
shutdown_timeout_seconds: float | None,
@ -63,7 +64,7 @@ async def _app_lifespan(
) -> AsyncIterator[None]:
"""把 Web app 接到 AgentService lifecycle 上。"""
attached_service = service or AgentService(workspace=workspace)
attached_service = service or AgentService(workspace=workspace, config_path=config_path)
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
app.state.agent_service = attached_service
started = False
@ -71,8 +72,15 @@ async def _app_lifespan(
try:
await attached_service.start()
started = True
except Exception:
attached_service.close()
except BaseException:
with suppress(BaseException):
if attached_service.is_running:
await attached_service.shutdown(
timeout_seconds=shutdown_timeout_seconds,
force=shutdown_force,
)
else:
attached_service.close()
raise
try:
yield
@ -87,6 +95,7 @@ async def _app_lifespan(
def create_app(
*,
workspace: str | Path | None = None,
config_path: str | Path | None = None,
service: AgentService | None = None,
manage_service_lifecycle: bool | None = None,
shutdown_timeout_seconds: float | None = 5.0,
@ -106,6 +115,7 @@ def create_app(
lifespan=lambda fastapi_app: _app_lifespan(
fastapi_app,
workspace=workspace,
config_path=config_path,
service=service,
manage_service_lifecycle=manage_service_lifecycle,
shutdown_timeout_seconds=shutdown_timeout_seconds,

View File

@ -1,6 +1,15 @@
"""Application services for Beaver."""
from .agent_service import AgentService
from .memory_service import MemoryService
__all__ = ["AgentService", "MemoryService"]
def __getattr__(name: str):
if name == "AgentService":
from .agent_service import AgentService
return AgentService
if name == "MemoryService":
from .memory_service import MemoryService
return MemoryService
raise AttributeError(name)

View File

@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any
from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
from beaver.foundation.events import InboundMessage, OutboundMessage
class AgentService:
@ -36,11 +37,12 @@ class AgentService:
self,
*,
workspace: str | Path | None = None,
config_path: str | Path | None = None,
profile: AgentProfile | None = None,
loader: EngineLoader | None = None,
) -> None:
self.profile = profile or AgentProfile()
self.loader = loader or EngineLoader(workspace=workspace)
self.loader = loader or EngineLoader(workspace=workspace, config_path=config_path)
self._loop: AgentLoop | None = None
self._run_task: asyncio.Task[None] | None = None
@ -189,6 +191,60 @@ class AgentService:
loop = self.create_loop()
return await loop.submit_direct(message, **kwargs)
async def handle_inbound_message(self, inbound: InboundMessage) -> OutboundMessage:
"""把 bus inbound 映射成标准 runtime 调用,并返回结构化 outbound。"""
try:
result = await self.submit_direct(
inbound.content,
session_id=inbound.session_id,
source=f"gateway:{inbound.channel}",
user_id=inbound.user_id,
title=inbound.title,
execution_context=inbound.execution_context,
model=inbound.model,
provider_name=inbound.provider_name,
embedding_model=inbound.embedding_model,
)
except Exception as exc:
return self.build_outbound_error(inbound, detail=str(exc))
return self.build_outbound_message(inbound, result)
@staticmethod
def build_outbound_message(inbound: InboundMessage, result: AgentRunResult) -> OutboundMessage:
"""把一次 runtime 正常结果转成 bus outbound。"""
return OutboundMessage(
message_id=inbound.message_id,
channel=inbound.channel,
session_id=result.session_id,
run_id=result.run_id,
content=result.output_text,
finish_reason=result.finish_reason,
provider_name=result.provider_name,
model=result.model,
usage=dict(result.usage),
metadata={"inbound_metadata": dict(inbound.metadata)},
)
@staticmethod
def build_outbound_error(
inbound: InboundMessage,
*,
detail: str,
finish_reason: str = "error",
) -> OutboundMessage:
"""把 inbound 处理失败转换成结构化 outbound 错误消息。"""
return OutboundMessage(
message_id=inbound.message_id,
channel=inbound.channel,
session_id=inbound.session_id,
content=detail,
finish_reason=finish_reason,
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
)
def run_direct(
self,
message: str,

View File

@ -1,7 +1,12 @@
"""Skill system for Beaver."""
"""Skill system for Beaver.
from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever
from .catalog import SkillRecord, SkillsLoader
顶层包保持 lazy export避免只导入 catalog/loader 时顺带拉起
SkillAssembler -> provider -> litellm 这条重依赖链。
"""
from __future__ import annotations
from typing import Any
__all__ = [
"SkillAssembler",
@ -10,3 +15,22 @@ __all__ = [
"SkillRecord",
"SkillsLoader",
]
def __getattr__(name: str) -> Any:
if name in {"SkillAssembler", "SkillAssemblyResult", "SkillEmbeddingRetriever"}:
from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever
return {
"SkillAssembler": SkillAssembler,
"SkillAssemblyResult": SkillAssemblyResult,
"SkillEmbeddingRetriever": SkillEmbeddingRetriever,
}[name]
if name in {"SkillRecord", "SkillsLoader"}:
from .catalog import SkillRecord, SkillsLoader
return {
"SkillRecord": SkillRecord,
"SkillsLoader": SkillsLoader,
}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -1,188 +1,9 @@
"""Embedding-based skill candidate retrieval.
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
阿里云百炼 `text-embedding-v4` 做最小语义召回:
1. 复用当前 provider 的 `api_key/api_base`
2. 先用 embedding 相似度召回一小批候选
3. 再交给上层 LLM selector 做最终技能选择
"""
"""Embedding-based skill candidate retrieval."""
from __future__ import annotations
import asyncio
import math
import os
import json
from urllib import request
from typing import Any
from beaver.foundation.embedding import EmbeddingRetriever
class SkillEmbeddingRetriever:
class SkillEmbeddingRetriever(EmbeddingRetriever):
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
def __init__(
self,
*,
api_key_env: str = "OPENAI_API_KEY",
api_base_env: str = "OPENAI_API_BASE",
model: str = "text-embedding-v4",
timeout_seconds: float = 20.0,
) -> None:
self.api_key_env = api_key_env
self.api_base_env = api_base_env
self.model = model
self.timeout_seconds = timeout_seconds
async def retrieve(
self,
*,
query: str,
candidates: list[dict[str, str]],
top_k: int = 12,
api_key: str | None = None,
api_base: str | None = None,
model: str | None = None,
) -> list[dict[str, str]]:
"""按 embedding 相似度召回 top-k 候选。
如果没有可用的 API Key / base URL或者 embedding 调用失败,
当前阶段先退回到“全部候选交给 LLM selector”。
"""
if not candidates:
return []
resolved_api_key = api_key or os.getenv(self.api_key_env)
resolved_api_base = api_base or os.getenv(self.api_base_env)
if not resolved_api_key or not resolved_api_base:
return candidates
try:
query_embedding = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[query],
model=model or self.model,
)
candidate_texts = [self._candidate_text(item) for item in candidates]
candidate_embeddings = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=candidate_texts,
model=model or self.model,
)
except Exception:
return candidates
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
return candidates
query_vector = query_embedding[0]
scored: list[tuple[float, dict[str, str]]] = []
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
if not vector:
continue
scored.append((self._cosine_similarity(query_vector, vector), candidate))
scored.sort(key=lambda item: item[0], reverse=True)
return [item[1] for item in scored[:top_k]]
async def _embed_texts(
self,
*,
api_key: str,
api_base: str,
texts: list[str],
model: str,
) -> list[list[float]]:
"""调用 OpenAI-compatible embeddings 接口。
当前对齐的是你们实际在用的网关配置:
- `POST {api_base}/embeddings`
- `model=text-embedding-v4`
- `encoding_format=float`
"""
all_vectors: list[list[float]] = []
endpoint = self._normalize_embeddings_endpoint(api_base)
for start in range(0, len(texts), 10):
batch = texts[start:start + 10]
payload = await self._post_embeddings(
endpoint=endpoint,
api_key=api_key,
model=model,
texts=batch,
)
embeddings = payload.get("data") or []
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
return all_vectors
async def _post_embeddings(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
return await asyncio.to_thread(
self._post_embeddings_sync,
endpoint=endpoint,
api_key=api_key,
model=model,
texts=texts,
)
def _post_embeddings_sync(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
body = json.dumps(
{
"model": model,
"input": texts if len(texts) > 1 else texts[0],
"encoding_format": "float",
}
).encode("utf-8")
req = request.Request(
endpoint,
data=body,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
method="POST",
)
with request.urlopen(req, timeout=self.timeout_seconds) as response:
return json.loads(response.read().decode("utf-8"))
@staticmethod
def _candidate_text(candidate: dict[str, str]) -> str:
name = (candidate.get("name") or "").strip()
description = (candidate.get("description") or "").strip()
return f"{name}\n{description}".strip()
@staticmethod
def _normalize_embeddings_endpoint(api_base: str) -> str:
base = api_base.rstrip("/")
if base.endswith("/embeddings"):
return base
if base.endswith("/v1"):
return f"{base}/embeddings"
return f"{base}/v1/embeddings"
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
if not left or not right or len(left) != len(right):
return -1.0
dot = sum(a * b for a, b in zip(left, right, strict=False))
left_norm = math.sqrt(sum(a * a for a in left))
right_norm = math.sqrt(sum(b * b for b in right))
if left_norm == 0 or right_norm == 0:
return -1.0
return dot / (left_norm * right_norm)

View File

@ -63,6 +63,11 @@ class SkillAssembler:
api_key=embedding_runtime.api_key if embedding_runtime is not None else None,
api_base=embedding_runtime.api_base if embedding_runtime is not None else None,
model=embedding_runtime.model if embedding_runtime is not None else None,
extra_headers=embedding_runtime.extra_headers if embedding_runtime is not None else None,
timeout_seconds=(
embedding_runtime.request_timeout_seconds if embedding_runtime is not None else None
),
fallback_top_k=None,
)
if not candidates:
return SkillAssemblyResult()

View File

@ -18,6 +18,7 @@
from __future__ import annotations
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Any
@ -111,6 +112,32 @@ class SkillsLoader:
metadata, _ = parse_frontmatter(content)
return metadata
def get_skill_tool_hints(self, name: str) -> list[str]:
"""读取 skill 显式声明的推荐工具。
第一版只信任显式 metadata不从正文里猜
- `tools: read_file, search_files`
- `tools: ["read_file", "search_files"]`
- YAML-like list:
tools:
- read_file
- search_files
- 兼容 metadata JSON blob 里的 `tools`
"""
frontmatter = self.get_skill_metadata(name) or {}
meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
names = [
*self._coerce_tool_names(frontmatter.get("tools")),
*self._coerce_tool_names(meta_blob.get("tools")),
*self._coerce_tool_names(meta_blob.get("required_tools")),
]
result: list[str] = []
for item in names:
if item and item not in result:
result.append(item)
return result
def load_skills_for_context(self, skill_names: list[str]) -> str:
"""加载指定 skills 的正文,并整理成上下文块。"""
@ -253,6 +280,26 @@ class SkillsLoader:
result.append(record.name)
return result
@staticmethod
def _coerce_tool_names(value: Any) -> list[str]:
if value is None:
return []
if isinstance(value, str):
raw = value.strip()
if not raw:
return []
if raw.startswith("["):
try:
parsed = json.loads(raw)
except Exception:
parsed = None
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()]
return [item.strip() for item in raw.split(",") if item.strip()]
if isinstance(value, (list, tuple, set)):
return [str(item).strip() for item in value if str(item).strip()]
return []
def _find_record(self, name: str) -> SkillRecord | None:
for record in self.list_skills(filter_unavailable=False):
if record.name == name:

View File

@ -20,7 +20,7 @@ import shutil
from typing import Any
def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
def parse_frontmatter(content: str) -> tuple[dict[str, Any], str]:
"""解析 Markdown 文件顶部的极简 frontmatter。
当前先只支持最常见的:
@ -43,12 +43,36 @@ def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
if match is None:
return {}, content
metadata: dict[str, str] = {}
for line in match.group(1).splitlines():
metadata: dict[str, Any] = {}
lines = match.group(1).splitlines()
index = 0
while index < len(lines):
line = lines[index]
if ":" not in line:
index += 1
continue
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'')
key = key.strip()
value = value.strip()
if not value:
items: list[str] = []
lookahead = index + 1
while lookahead < len(lines):
candidate = lines[lookahead]
stripped = candidate.strip()
if not stripped:
lookahead += 1
continue
if not stripped.startswith("- "):
break
items.append(stripped[2:].strip().strip('"\''))
lookahead += 1
if items:
metadata[key] = items
index = lookahead
continue
metadata[key] = value.strip('"\'')
index += 1
body = content[match.end():].strip()
return metadata, body

View File

@ -1,6 +1,7 @@
"""Tool system for Beaver."""
from .base import BaseTool, ObjectBackedTool, ToolContext, ToolResult, ToolSpec
from .assembler import ToolAssembler
from .registry import ToolRegistry
from .runtime import ToolExecutor
@ -8,6 +9,7 @@ __all__ = [
"BaseTool",
"ObjectBackedTool",
"ToolContext",
"ToolAssembler",
"ToolExecutor",
"ToolRegistry",
"ToolResult",

View File

@ -0,0 +1,5 @@
"""Tool selection for a single Beaver run."""
from .task_assembler import ToolAssembler
__all__ = ["ToolAssembler"]

View File

@ -0,0 +1,106 @@
"""Task-driven tool assembler.
这层和 SkillAssembler 的位置类似:它不执行工具,只决定本轮 run 应该把哪些
tool schema 暴露给模型。
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from beaver.engine.context import SkillContext
from beaver.foundation.embedding import EmbeddingRetriever
from beaver.tools.base import ToolSpec
from beaver.tools.registry import ToolRegistry
if TYPE_CHECKING:
from beaver.engine.providers.runtime import ProviderRuntime
from beaver.skills.catalog.loader import SkillsLoader
class ToolAssembler:
"""Use skill hints and embedding retrieval to select run-scoped tools."""
def __init__(
self,
*,
retriever: EmbeddingRetriever | None = None,
always_tool_names: Sequence[str] | None = None,
) -> None:
self.retriever = retriever or EmbeddingRetriever()
self.always_tool_names = tuple(always_tool_names or ("memory", "session_search", "skill_view"))
async def assemble(
self,
*,
task_description: str,
registry: ToolRegistry,
skills_loader: SkillsLoader | None = None,
activated_skills: Sequence[SkillContext] | None = None,
embedding_runtime: ProviderRuntime | None = None,
top_k: int = 10,
) -> list[ToolSpec]:
"""Return selected tool specs for the current run.
Selection order is intentionally deterministic:
1. always tools from config/spec
2. tools explicitly declared by activated skills
3. embedding top-k tools for the task
"""
selected: list[ToolSpec] = []
selected_names: set[str] = set()
def add_specs(specs: Sequence[ToolSpec]) -> None:
for spec in specs:
if spec.name in selected_names:
continue
selected.append(spec)
selected_names.add(spec.name)
add_specs(registry.list_always_specs())
add_specs(registry.get_specs(self.always_tool_names))
skill_tool_names = self._collect_skill_tool_names(
skills_loader=skills_loader,
activated_skills=activated_skills or (),
)
add_specs(registry.get_specs(skill_tool_names))
candidates = [
spec.to_embedding_candidate()
for spec in registry.list_specs()
if spec.name not in selected_names
]
retrieved = await self.retriever.retrieve(
query=task_description,
candidates=candidates,
top_k=top_k,
api_key=embedding_runtime.api_key if embedding_runtime is not None else None,
api_base=embedding_runtime.api_base if embedding_runtime is not None else None,
model=embedding_runtime.model if embedding_runtime is not None else None,
extra_headers=embedding_runtime.extra_headers if embedding_runtime is not None else None,
timeout_seconds=(
embedding_runtime.request_timeout_seconds if embedding_runtime is not None else None
),
fallback_top_k=top_k,
)
add_specs(registry.get_specs([item["name"] for item in retrieved]))
return selected
@staticmethod
def _collect_skill_tool_names(
*,
skills_loader: SkillsLoader | None,
activated_skills: Sequence[SkillContext],
) -> list[str]:
if skills_loader is None or not activated_skills:
return []
result: list[str] = []
for skill in activated_skills:
for name in skills_loader.get_skill_tool_hints(skill.name):
if name not in result:
result.append(name)
return result

View File

@ -29,13 +29,30 @@ class ToolSpec:
"""单个工具对外暴露的描述信息。
这份信息主要服务两个场景:
1. 导出给 provider 的 function schema
2. 在 registry 中做列出、查找、调试
1. 以 MCP-style descriptor 作为统一事实来源
2. 导出给 provider 的 function schema
3. 在 registry 中做列出、查找、调试与 embedding 召回
"""
name: str
description: str
input_schema: dict[str, Any]
toolset: str = "core"
always_available: bool = False
def to_mcp_descriptor(self) -> dict[str, Any]:
"""导出 MCP ListTools 风格的工具描述。
MCP 的基础字段是 `name`、`description`、`inputSchema`。
Beaver 内部额外的 toolset/always_available 不塞进这个对象,
避免未来对接真实 MCP server 时出现格式偏差。
"""
return {
"name": self.name,
"description": self.description,
"inputSchema": self.input_schema,
}
def to_provider_schema(self) -> dict[str, Any]:
"""导出为 OpenAI-compatible function tool schema。"""
@ -49,6 +66,15 @@ class ToolSpec:
},
}
def to_embedding_candidate(self) -> dict[str, str]:
"""导出给语义召回使用的轻量文本候选。"""
return {
"name": self.name,
"description": self.description,
"input_schema": json.dumps(self.input_schema, ensure_ascii=False, sort_keys=True),
}
@dataclass(slots=True)
class ToolContext:
@ -113,6 +139,8 @@ class ObjectBackedTool(BaseTool):
name=str(getattr(backend, "name")),
description=str(getattr(backend, "description", "")),
input_schema=dict(getattr(backend, "parameters", {"type": "object", "properties": {}})),
toolset=str(getattr(backend, "toolset", "core")),
always_available=bool(getattr(backend, "always_available", False)),
)
@property
@ -150,6 +178,8 @@ class ObjectBackedTool(BaseTool):
if "current_session_id" not in arguments and hasattr(self.backend, "current_session_id"):
arguments["current_session_id"] = context.session_id
if "workspace" not in arguments and hasattr(self.backend, "workspace"):
arguments["workspace"] = context.workspace
@staticmethod
def _normalize_output(content: Any) -> dict[str, Any]:

View File

@ -1,13 +1,17 @@
"""Built-in Beaver tools."""
from .echo import EchoTool, echo_tool
from .filesystem import ListDirectoryTool, ReadFileTool, SearchFilesTool
from .memory import MemoryTool, memory_tool
from .skill_view import SkillViewTool, skill_view
from .session_search import SessionSearchTool, session_search
__all__ = [
"EchoTool",
"ListDirectoryTool",
"MemoryTool",
"ReadFileTool",
"SearchFilesTool",
"SkillViewTool",
"SessionSearchTool",
"echo_tool",

View File

@ -34,6 +34,8 @@ class EchoTool:
name: str = "echo"
description: str = ECHO_TOOL_DESCRIPTION
toolset: str = "debug"
always_available: bool = False
parameters: dict[str, Any] = field(default_factory=lambda: dict(ECHO_TOOL_PARAMETERS))
async def execute(self, **kwargs: Any) -> str:

View File

@ -0,0 +1,442 @@
"""Workspace-scoped read-only filesystem tools.
这些工具是 Beaver 第一批真实本地工具,只做只读能力:
- list_directory
- read_file
- search_files
安全边界先保持非常明确:所有用户传入路径都必须解析到当前
`ToolContext.workspace` 内部。即使 workspace 里有指向外部的符号链接,
读取时也会因为真实路径越界而被拒绝。
"""
from __future__ import annotations
from dataclasses import dataclass, field
import json
from pathlib import Path
from typing import Any, Iterable
MAX_LIST_ENTRIES = 1_000
MAX_READ_LINES = 1_000
MAX_READ_CHARS = 120_000
MAX_SEARCH_RESULTS = 200
MAX_SEARCH_FILE_BYTES = 2_000_000
MAX_SEARCH_FILES = 5_000
SKIP_DIR_NAMES = {
".git",
".hg",
".svn",
".venv",
"venv",
"__pycache__",
".pytest_cache",
".mypy_cache",
".ruff_cache",
"node_modules",
"dist",
"build",
}
LIST_DIRECTORY_PARAMETERS: dict[str, Any] = {
"type": "object",
"properties": {
"path": {
"type": "string",
"default": ".",
"description": "Directory path relative to the current workspace. Absolute paths are allowed only if they stay inside the workspace.",
},
"recursive": {
"type": "boolean",
"default": False,
"description": "Whether to recursively list child entries. Symlink directories are not followed.",
},
"max_entries": {
"type": "integer",
"default": 200,
"minimum": 1,
"maximum": MAX_LIST_ENTRIES,
"description": "Maximum number of entries to return.",
},
},
"required": [],
}
READ_FILE_PARAMETERS: dict[str, Any] = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path relative to the current workspace. Absolute paths are allowed only if they stay inside the workspace.",
},
"start_line": {
"type": "integer",
"default": 1,
"minimum": 1,
"description": "1-based line number to start reading from.",
},
"max_lines": {
"type": "integer",
"default": 200,
"minimum": 1,
"maximum": MAX_READ_LINES,
"description": "Maximum number of lines to read.",
},
},
"required": ["path"],
}
SEARCH_FILES_PARAMETERS: dict[str, Any] = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Plain text query to search in file paths and UTF-8 text files.",
},
"path": {
"type": "string",
"default": ".",
"description": "Directory or file path relative to the current workspace.",
},
"max_results": {
"type": "integer",
"default": 50,
"minimum": 1,
"maximum": MAX_SEARCH_RESULTS,
"description": "Maximum number of matches to return.",
},
"case_sensitive": {
"type": "boolean",
"default": False,
"description": "Whether search should be case-sensitive.",
},
},
"required": ["query"],
}
class WorkspacePathError(ValueError):
"""Raised when a requested path escapes the configured workspace."""
def _json_result(success: bool, **payload: Any) -> str:
return json.dumps({"success": success, **payload}, ensure_ascii=False, indent=2)
def _clamp_int(value: Any, *, default: int, minimum: int, maximum: int) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = default
return max(minimum, min(parsed, maximum))
def _workspace_root(workspace: str | None) -> Path:
if not workspace:
raise WorkspacePathError("workspace is not configured for filesystem tools")
root = Path(workspace).expanduser().resolve(strict=True)
if not root.is_dir():
raise WorkspacePathError(f"workspace is not a directory: {root}")
return root
def _resolve_existing_path(workspace: str | None, user_path: str | None) -> tuple[Path, Path]:
"""Resolve a user path and ensure the real target stays inside workspace."""
root = _workspace_root(workspace)
raw_path = Path(user_path or ".").expanduser()
candidate = raw_path if raw_path.is_absolute() else root / raw_path
resolved = candidate.resolve(strict=True)
try:
resolved.relative_to(root)
except ValueError as exc:
raise WorkspacePathError(
f"path escapes workspace: {user_path or '.'}"
) from exc
return root, resolved
def _relative_path(root: Path, path: Path) -> str:
try:
return str(path.relative_to(root)) or "."
except ValueError:
return str(path)
def _entry_type(path: Path) -> str:
if path.is_symlink():
return "symlink"
if path.is_dir():
return "directory"
if path.is_file():
return "file"
return "other"
def _entry_payload(root: Path, path: Path) -> dict[str, Any]:
try:
stat = path.lstat() if path.is_symlink() else path.stat()
size = stat.st_size
except OSError:
size = None
return {
"name": path.name,
"path": _relative_path(root, path),
"type": _entry_type(path),
"size": size,
}
def _iter_directory(root: Path, directory: Path, *, recursive: bool) -> Iterable[Path]:
def sort_key(item: Path) -> tuple[bool, str]:
is_real_directory = not item.is_symlink() and item.is_dir()
return (not is_real_directory, item.name.lower())
entries = sorted(directory.iterdir(), key=sort_key)
for entry in entries:
yield entry
if not recursive or entry.is_symlink() or not entry.is_dir():
continue
yield from _iter_directory(root, entry, recursive=True)
def _looks_binary(path: Path) -> bool:
try:
with path.open("rb") as handle:
sample = handle.read(4096)
except OSError:
return True
return b"\0" in sample
def _read_text_file(path: Path) -> str:
if _looks_binary(path):
raise ValueError("binary files cannot be read by read_file/search_files")
return path.read_text(encoding="utf-8")
def _iter_search_files(root: Path, start: Path) -> Iterable[Path]:
if start.is_file():
yield start
return
stack = [start]
visited = 0
while stack and visited < MAX_SEARCH_FILES:
current = stack.pop()
try:
children = sorted(current.iterdir(), key=lambda item: item.name.lower())
except OSError:
continue
for child in children:
if child.is_symlink():
continue
if child.is_dir():
if child.name in SKIP_DIR_NAMES:
continue
stack.append(child)
continue
if child.is_file():
visited += 1
yield child
if visited >= MAX_SEARCH_FILES:
break
@dataclass(slots=True)
class ListDirectoryTool:
"""List files and directories inside the current workspace."""
name: str = "list_directory"
description: str = (
"List files and directories inside the current workspace. "
"Use this before reading files when you need to inspect project structure. "
"This tool never follows paths outside the workspace."
)
toolset: str = "filesystem"
always_available: bool = True
workspace: str | None = None
parameters: dict[str, Any] = field(default_factory=lambda: dict(LIST_DIRECTORY_PARAMETERS))
async def execute(
self,
*,
path: str = ".",
recursive: bool = False,
max_entries: int = 200,
workspace: str | None = None,
) -> str:
try:
root, resolved = _resolve_existing_path(workspace, path)
if not resolved.is_dir():
return _json_result(False, error="not_a_directory", path=path)
limit = _clamp_int(max_entries, default=200, minimum=1, maximum=MAX_LIST_ENTRIES)
entries: list[dict[str, Any]] = []
truncated = False
for entry in _iter_directory(root, resolved, recursive=bool(recursive)):
entries.append(_entry_payload(root, entry))
if len(entries) >= limit:
truncated = True
break
return _json_result(
True,
path=_relative_path(root, resolved),
recursive=bool(recursive),
entries=entries,
truncated=truncated,
)
except (OSError, WorkspacePathError, ValueError) as exc:
return _json_result(False, error=str(exc), path=path)
@dataclass(slots=True)
class ReadFileTool:
"""Read a UTF-8 text file inside the current workspace."""
name: str = "read_file"
description: str = (
"Read a UTF-8 text file inside the current workspace with line limits. "
"Use this to inspect source code, docs, config, or logs. "
"This tool rejects binary files and paths outside the workspace."
)
toolset: str = "filesystem"
always_available: bool = True
workspace: str | None = None
parameters: dict[str, Any] = field(default_factory=lambda: dict(READ_FILE_PARAMETERS))
async def execute(
self,
*,
path: str,
start_line: int = 1,
max_lines: int = 200,
workspace: str | None = None,
) -> str:
try:
root, resolved = _resolve_existing_path(workspace, path)
if not resolved.is_file():
return _json_result(False, error="not_a_file", path=path)
start = _clamp_int(start_line, default=1, minimum=1, maximum=10_000_000)
limit = _clamp_int(max_lines, default=200, minimum=1, maximum=MAX_READ_LINES)
content = _read_text_file(resolved)
lines = content.splitlines()
selected = lines[start - 1 : start - 1 + limit]
selected_text = "\n".join(selected)
char_truncated = False
if len(selected_text) > MAX_READ_CHARS:
selected_text = selected_text[:MAX_READ_CHARS]
char_truncated = True
end_line = start + len(selected) - 1 if selected else start - 1
return _json_result(
True,
path=_relative_path(root, resolved),
start_line=start,
end_line=end_line,
total_lines=len(lines),
truncated=end_line < len(lines) or char_truncated,
content=selected_text,
)
except UnicodeDecodeError:
return _json_result(False, error="file is not valid UTF-8 text", path=path)
except (OSError, WorkspacePathError, ValueError) as exc:
return _json_result(False, error=str(exc), path=path)
@dataclass(slots=True)
class SearchFilesTool:
"""Search filenames and UTF-8 text file contents inside the workspace."""
name: str = "search_files"
description: str = (
"Search file paths and UTF-8 text file contents inside the current workspace. "
"Use this to find relevant source files, docs, config keys, or log lines. "
"This tool skips large/binary files and never searches outside the workspace."
)
toolset: str = "filesystem"
always_available: bool = True
workspace: str | None = None
parameters: dict[str, Any] = field(default_factory=lambda: dict(SEARCH_FILES_PARAMETERS))
async def execute(
self,
*,
query: str,
path: str = ".",
max_results: int = 50,
case_sensitive: bool = False,
workspace: str | None = None,
) -> str:
try:
if not isinstance(query, str) or not query.strip():
return _json_result(False, error="query must be a non-empty string")
root, resolved = _resolve_existing_path(workspace, path)
if not resolved.is_dir() and not resolved.is_file():
return _json_result(False, error="path must be a file or directory", path=path)
limit = _clamp_int(max_results, default=50, minimum=1, maximum=MAX_SEARCH_RESULTS)
needle = query if case_sensitive else query.lower()
results: list[dict[str, Any]] = []
searched_files = 0
skipped_files = 0
for file_path in _iter_search_files(root, resolved):
relative = _relative_path(root, file_path)
haystack_path = relative if case_sensitive else relative.lower()
if needle in haystack_path:
results.append(
{
"path": relative,
"line": None,
"match_type": "path",
"preview": relative,
}
)
if len(results) >= limit:
break
try:
if file_path.stat().st_size > MAX_SEARCH_FILE_BYTES or _looks_binary(file_path):
skipped_files += 1
continue
text = file_path.read_text(encoding="utf-8")
except (OSError, UnicodeDecodeError):
skipped_files += 1
continue
searched_files += 1
lines = text.splitlines()
for index, line in enumerate(lines, start=1):
haystack_line = line if case_sensitive else line.lower()
if needle not in haystack_line:
continue
results.append(
{
"path": relative,
"line": index,
"match_type": "content",
"preview": line[:500],
}
)
if len(results) >= limit:
break
if len(results) >= limit:
break
return _json_result(
True,
query=query,
path=_relative_path(root, resolved),
results=results,
truncated=len(results) >= limit,
searched_files=searched_files,
skipped_files=skipped_files,
)
except (OSError, WorkspacePathError, ValueError) as exc:
return _json_result(False, error=str(exc), path=path)

View File

@ -123,6 +123,8 @@ class MemoryTool:
store: MemoryStore
name: str = "memory"
description: str = MEMORY_TOOL_DESCRIPTION
toolset: str = "memory"
always_available: bool = True
parameters: dict[str, Any] = field(default_factory=lambda: dict(MEMORY_TOOL_PARAMETERS))
async def execute(self, **kwargs: Any) -> str:

View File

@ -406,6 +406,8 @@ class SessionSearchTool:
summarizer: SessionSummarizer | None = None
name: str = "session_search"
description: str = SESSION_SEARCH_TOOL_DESCRIPTION
toolset: str = "session"
always_available: bool = True
parameters: dict[str, Any] = field(default_factory=lambda: dict(SESSION_SEARCH_TOOL_PARAMETERS))
async def execute(self, **kwargs: Any) -> str:

View File

@ -76,6 +76,8 @@ class SkillViewTool:
loader: SkillsLoader
name: str = "skill_view"
description: str = SKILL_VIEW_TOOL_DESCRIPTION
toolset: str = "skills"
always_available: bool = True
parameters: dict[str, Any] = field(default_factory=lambda: dict(SKILL_VIEW_TOOL_PARAMETERS))
async def execute(self, **kwargs: Any) -> str:

View File

@ -11,6 +11,7 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Iterable
from beaver.tools.base import BaseTool, ToolSpec
@ -49,7 +50,30 @@ class ToolRegistry:
def list_specs(self) -> list[ToolSpec]:
return [tool.spec for tool in self._tools.values()]
def list_always_specs(self) -> list[ToolSpec]:
"""列出每轮 run 都应该暴露给模型的基础工具。"""
return [spec for spec in self.list_specs() if spec.always_available]
def get_specs(self, names: Sequence[str]) -> list[ToolSpec]:
"""按名称顺序返回已注册工具 spec忽略未知工具。"""
specs: list[ToolSpec] = []
seen: set[str] = set()
for name in names:
tool = self.get(name)
if tool is None or name in seen:
continue
specs.append(tool.spec)
seen.add(name)
return specs
def export_provider_schemas(self) -> list[dict]:
"""导出给 provider 的函数工具 schema 列表。"""
return [spec.to_provider_schema() for spec in self.list_specs()]
def export_selected_provider_schemas(self, specs: Sequence[ToolSpec]) -> list[dict]:
"""导出一组已选择工具的 provider schema。"""
return [spec.to_provider_schema() for spec in specs]

View File

@ -12,12 +12,14 @@
from __future__ import annotations
import json
from typing import Any
from typing import TYPE_CHECKING, Any
from beaver.engine.providers.base import ToolCallRequest
from beaver.tools.base import ToolContext, ToolResult
from beaver.tools.registry.tool_registry import ToolRegistry
if TYPE_CHECKING:
from beaver.engine.providers.base import ToolCallRequest
class ToolExecutor:
"""统一执行单个 tool call。"""
@ -80,16 +82,17 @@ class ToolExecutor:
@staticmethod
def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
if isinstance(tool_call, ToolCallRequest):
return tool_call.name, dict(tool_call.arguments)
function = tool_call.get("function")
if isinstance(function, dict):
name = function.get("name")
arguments = function.get("arguments", {})
if not isinstance(tool_call, dict):
name = getattr(tool_call, "name", None)
arguments = getattr(tool_call, "arguments", {})
else:
name = tool_call.get("name")
arguments = tool_call.get("arguments", {})
function = tool_call.get("function")
if isinstance(function, dict):
name = function.get("name")
arguments = function.get("arguments", {})
else:
name = tool_call.get("name")
arguments = tool_call.get("arguments", {})
if not name:
raise ValueError("Tool call is missing a tool name")
@ -104,8 +107,8 @@ class ToolExecutor:
@staticmethod
def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
if isinstance(tool_call, ToolCallRequest):
return str(tool_call.name or "unknown")
if not isinstance(tool_call, dict):
return str(getattr(tool_call, "name", None) or "unknown")
function = tool_call.get("function")
if isinstance(function, dict) and function.get("name"):
return str(function["name"])