feat(app-instance): 集成Beaver后端并更新配置管理
集成新的Beaver后端服务到应用实例中,替换原有的nanobot实现。

主要变更包括:
- 在Dockerfile和环境配置中添加Beaver相关路径和配置变量
- 更新工作目录结构从.nanobot到.beaver
- 实现Beaver引擎加载器,支持配置文件加载和工具组装
- 添加内置工具如ListDirectoryTool、ReadFileTool、SearchFilesTool
- 更新消息处理流程,支持通道适配器和网关模式
- 重构技能系统,支持显式工具提示和嵌入式检索
- 改进错误处理和生命周期管理

此变更使应用实例能够使用统一的Beaver后端进行AI代理运行时管理。
This commit is contained in:
@ -2,17 +2,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from beaver.engine.context import ContextBuilder
|
||||
from beaver.engine.session import SessionManager
|
||||
from beaver.foundation.config import BeaverConfig, load_config
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.skills import SkillAssembler, SkillsLoader
|
||||
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
|
||||
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
|
||||
from beaver.tools import ObjectBackedTool, ToolAssembler, ToolExecutor, ToolRegistry
|
||||
from beaver.tools.builtins import (
|
||||
EchoTool,
|
||||
ListDirectoryTool,
|
||||
MemoryTool,
|
||||
ReadFileTool,
|
||||
SearchFilesTool,
|
||||
SessionSearchTool,
|
||||
SkillViewTool,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -27,6 +37,7 @@ class EngineLoadResult:
|
||||
"""
|
||||
|
||||
workspace: Path
|
||||
config: BeaverConfig = field(default_factory=BeaverConfig)
|
||||
tools: list[str] = field(default_factory=list)
|
||||
skills: list[str] = field(default_factory=list)
|
||||
memory_stores: list[str] = field(default_factory=list)
|
||||
@ -35,6 +46,7 @@ class EngineLoadResult:
|
||||
curated_memory_store: MemoryStore | None = None
|
||||
memory_service: MemoryService | None = None
|
||||
tool_registry: ToolRegistry | None = None
|
||||
tool_assembler: ToolAssembler | None = None
|
||||
tool_executor: ToolExecutor | None = None
|
||||
context_builder: ContextBuilder | None = None
|
||||
skills_loader: SkillsLoader | None = None
|
||||
@ -89,19 +101,26 @@ class EngineLoader:
|
||||
self,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
config_path: str | Path | None = None,
|
||||
config: BeaverConfig | None = None,
|
||||
session_manager: SessionManager | None = None,
|
||||
curated_memory_store: MemoryStore | None = None,
|
||||
memory_service: MemoryService | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
tool_assembler: ToolAssembler | None = None,
|
||||
context_builder: ContextBuilder | None = None,
|
||||
skills_loader: SkillsLoader | None = None,
|
||||
skill_assembler: SkillAssembler | None = None,
|
||||
) -> None:
|
||||
self.workspace = Path(workspace or Path.cwd())
|
||||
self.config = config or load_config(workspace=workspace, config_path=config_path)
|
||||
configured_workspace = self.config.agents_defaults.workspace
|
||||
env_workspace = os.getenv("BEAVER_WORKSPACE")
|
||||
self.workspace = Path(workspace or configured_workspace or env_workspace or Path.cwd())
|
||||
self._session_manager = session_manager
|
||||
self._curated_memory_store = curated_memory_store
|
||||
self._memory_service = memory_service
|
||||
self._tool_registry = tool_registry
|
||||
self._tool_assembler = tool_assembler
|
||||
self._context_builder = context_builder
|
||||
self._skills_loader = skills_loader
|
||||
self._skill_assembler = skill_assembler
|
||||
@ -127,15 +146,20 @@ class EngineLoader:
|
||||
ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
|
||||
ObjectBackedTool(SkillViewTool(loader=skills_loader)),
|
||||
ObjectBackedTool(SessionSearchTool(db=session_manager)),
|
||||
ObjectBackedTool(ListDirectoryTool()),
|
||||
ObjectBackedTool(ReadFileTool()),
|
||||
ObjectBackedTool(SearchFilesTool()),
|
||||
]
|
||||
)
|
||||
|
||||
context_builder = self._context_builder or ContextBuilder()
|
||||
tool_assembler = self._tool_assembler or ToolAssembler()
|
||||
tool_executor = ToolExecutor(tool_registry)
|
||||
skill_assembler = self._skill_assembler or SkillAssembler(skills_loader)
|
||||
|
||||
result = EngineLoadResult(
|
||||
workspace=workspace,
|
||||
config=self.config,
|
||||
tools=[spec.name for spec in tool_registry.list_specs()],
|
||||
skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
|
||||
memory_stores=["curated"],
|
||||
@ -144,6 +168,7 @@ class EngineLoader:
|
||||
curated_memory_store=memory_service.get_store(),
|
||||
memory_service=memory_service,
|
||||
tool_registry=tool_registry,
|
||||
tool_assembler=tool_assembler,
|
||||
tool_executor=tool_executor,
|
||||
context_builder=context_builder,
|
||||
skills_loader=skills_loader,
|
||||
|
||||
@ -272,12 +272,24 @@ class AgentLoop:
|
||||
memory_service = self._require_loaded("memory_service")
|
||||
context_builder = self._require_loaded("context_builder")
|
||||
tool_registry = self._require_loaded("tool_registry")
|
||||
tool_assembler = self._require_loaded("tool_assembler")
|
||||
tool_executor = self._require_loaded("tool_executor")
|
||||
skills_loader = self._require_loaded("skills_loader")
|
||||
skill_assembler = self._require_loaded("skill_assembler")
|
||||
|
||||
config = loaded.config
|
||||
configured_provider = config.resolve_provider_target(model=model, provider_name=provider_name)
|
||||
|
||||
resolved_session_id = session_id or uuid4().hex
|
||||
resolved_run_id = uuid4().hex
|
||||
resolved_model = model or self.profile.default_model
|
||||
resolved_model = configured_provider.get("model") or self.profile.default_model
|
||||
resolved_provider_name = configured_provider.get("provider_name") or provider_name
|
||||
resolved_api_key = api_key or configured_provider.get("api_key")
|
||||
resolved_api_base = api_base or configured_provider.get("api_base")
|
||||
resolved_extra_headers = extra_headers or configured_provider.get("extra_headers")
|
||||
resolved_request_timeout_seconds = configured_provider.get("request_timeout_seconds")
|
||||
resolved_embedding_model = embedding_model or config.default_embedding_model
|
||||
resolved_embedding_target = embedding_target or config.resolve_embedding_target()
|
||||
resolved_max_tokens = max_tokens or self.profile.max_tokens
|
||||
resolved_temperature = self.profile.temperature if temperature is None else temperature
|
||||
resolved_max_tool_iterations = (
|
||||
@ -316,20 +328,21 @@ class AgentLoop:
|
||||
user_message_recorded = False
|
||||
iterations = 0
|
||||
final_usage: dict[str, Any] = {}
|
||||
final_provider_name: str | None = provider_name
|
||||
final_provider_name: str | None = resolved_provider_name
|
||||
final_model: str | None = resolved_model
|
||||
try:
|
||||
bundle = provider_bundle or make_provider_bundle(
|
||||
model=resolved_model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
provider_name=resolved_provider_name,
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
request_timeout_seconds=resolved_request_timeout_seconds,
|
||||
extra_headers=resolved_extra_headers,
|
||||
routing=routing,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
embedding_model=embedding_model or "text-embedding-v4",
|
||||
embedding_target=resolved_embedding_target,
|
||||
embedding_model=resolved_embedding_model,
|
||||
)
|
||||
skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
|
||||
skill_selector_model = (
|
||||
@ -364,6 +377,32 @@ class AgentLoop:
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
selected_tool_specs = await tool_assembler.assemble(
|
||||
task_description=task,
|
||||
registry=tool_registry,
|
||||
skills_loader=skills_loader,
|
||||
activated_skills=assembled_skills.activated_skills,
|
||||
embedding_runtime=bundle.embedding_runtime,
|
||||
top_k=10,
|
||||
)
|
||||
tool_schemas = tool_registry.export_selected_provider_schemas(selected_tool_specs)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="tool_selection_snapshotted",
|
||||
event_payload={
|
||||
"tools": [spec.to_mcp_descriptor() for spec in selected_tool_specs],
|
||||
"tool_names": [spec.name for spec in selected_tool_specs],
|
||||
},
|
||||
content=", ".join(spec.name for spec in selected_tool_specs) or None,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
build_input = ContextBuildInput(
|
||||
base_system_prompt=self.profile.system_prompt,
|
||||
history=session_manager.get_history(resolved_session_id),
|
||||
@ -412,7 +451,6 @@ class AgentLoop:
|
||||
|
||||
provider = bundle.main_provider
|
||||
messages = list(context_result.messages)
|
||||
tool_schemas = tool_registry.export_provider_schemas()
|
||||
tool_context = ToolContext(
|
||||
workspace=str(loaded.workspace),
|
||||
session_id=resolved_session_id,
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_gateway
|
||||
from .registry import find_by_model, find_by_name, find_gateway
|
||||
from .runtime import ProviderRoutingConfig
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
@ -58,7 +58,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
if not api_key:
|
||||
return {}
|
||||
spec = self._gateway or find_by_model(model)
|
||||
spec = self._gateway
|
||||
if spec is None and self.provider_name:
|
||||
spec = find_by_name(self.provider_name)
|
||||
if spec is None:
|
||||
spec = find_by_model(model)
|
||||
if spec is None or not spec.env_key:
|
||||
return {}
|
||||
overrides: dict[str, str] = {spec.env_key: api_key}
|
||||
@ -97,6 +101,15 @@ class LiteLLMProvider(LLMProvider):
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
if self.provider_name:
|
||||
spec = find_by_name(self.provider_name)
|
||||
if spec is not None and not spec.is_gateway and not spec.is_local:
|
||||
resolved = model
|
||||
if spec.litellm_prefix and not any(resolved.startswith(prefix) for prefix in spec.skip_prefixes):
|
||||
resolved = f"{spec.litellm_prefix}/{resolved}"
|
||||
elif spec.name == "openai" and "/" not in resolved:
|
||||
resolved = f"openai/{resolved}"
|
||||
return resolved
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
|
||||
|
||||
@ -1,2 +1,13 @@
|
||||
"""Configuration models and loaders."""
|
||||
|
||||
from .loader import default_config_path, load_config
|
||||
from .schema import AgentDefaultsConfig, BeaverConfig, EmbeddingConfig, ProviderConfig
|
||||
|
||||
__all__ = [
|
||||
"AgentDefaultsConfig",
|
||||
"BeaverConfig",
|
||||
"EmbeddingConfig",
|
||||
"ProviderConfig",
|
||||
"default_config_path",
|
||||
"load_config",
|
||||
]
|
||||
|
||||
127
app-instance/backend/beaver/foundation/config/loader.py
Normal file
127
app-instance/backend/beaver/foundation/config/loader.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Config loader for per-sandbox Beaver runtime settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .schema import AgentDefaultsConfig, BeaverConfig, EmbeddingConfig, ProviderConfig
|
||||
|
||||
|
||||
def default_config_path(*, workspace: str | Path | None = None) -> Path:
    """Resolve the default config path for a single-user sandbox instance.

    Priority:
    1. `BEAVER_CONFIG_PATH`
    2. `NANOBOT_CONFIG_PATH` for compatibility during migration
    3. `BEAVER_HOME/config.json`
    4. `NANOBOT_HOME/config.json` for migration compatibility
    5. `<workspace>/.beaver/config.json`
    6. `./.beaver/config.json`

    Environment variables that are empty or contain only whitespace are
    treated as unset, so a blank export never resolves to a bogus path.

    Args:
        workspace: Optional workspace root used for rule 5; when ``None`` the
            current working directory is used (rule 6).

    Returns:
        The path where the config file is expected. The file need not exist.
    """

    def _env(name: str) -> str | None:
        # Normalise blank values to None so the `or` chains below skip them.
        value = os.getenv(name, "").strip()
        return value or None

    explicit = _env("BEAVER_CONFIG_PATH") or _env("NANOBOT_CONFIG_PATH")
    if explicit:
        return Path(explicit).expanduser()

    # Home directories are consulted in priority order: Beaver first, then the
    # legacy nanobot location.
    for home_var in ("BEAVER_HOME", "NANOBOT_HOME"):
        home = _env(home_var)
        if home:
            return Path(home).expanduser() / "config.json"

    root = Path(workspace).expanduser() if workspace is not None else Path.cwd()
    return root / ".beaver" / "config.json"
|
||||
|
||||
|
||||
def load_config(
    *,
    workspace: str | Path | None = None,
    config_path: str | Path | None = None,
) -> BeaverConfig:
    """Load backend config; missing config is treated as an empty config.

    Args:
        workspace: Workspace root forwarded to ``default_config_path`` when no
            explicit ``config_path`` is given.
        config_path: Explicit config file location; takes precedence over the
            default resolution.

    Returns:
        A ``BeaverConfig`` whose ``config_path`` records the path that was
        consulted, whether or not the file existed.

    Raises:
        ValueError: If the file is not valid JSON or its top-level value is not
            a JSON object.
    """

    path = Path(config_path).expanduser() if config_path is not None else default_config_path(workspace=workspace)

    # Read first and treat "not there" as empty config: this avoids the race
    # between an exists() probe and the actual read. "utf-8-sig" also accepts
    # files saved with a byte-order mark, which json.loads would reject.
    try:
        text = path.read_text(encoding="utf-8-sig")
    except (FileNotFoundError, NotADirectoryError):
        return BeaverConfig(config_path=path)

    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        # JSONDecodeError is a ValueError subclass, so callers catching
        # ValueError keep working; the message now names the offending file.
        raise ValueError(f"Beaver config is not valid JSON: {path}: {exc}") from exc

    if not isinstance(data, dict):
        raise ValueError(f"Beaver config must be a JSON object: {path}")

    return BeaverConfig(
        agents_defaults=_parse_agent_defaults(data),
        providers=_parse_providers(data.get("providers")),
        embedding=_parse_embedding(data),
        config_path=path,
    )
|
||||
|
||||
|
||||
def _parse_agent_defaults(data: dict[str, Any]) -> AgentDefaultsConfig:
    """Build agent defaults from ``agents.defaults``, falling back to top-level keys.

    Each setting is read from the nested ``agents.defaults`` object first and
    then from the same key at the top level of the config. The embedding model
    accepts both ``embeddingModel`` and ``embedding_model`` spellings at either
    level.
    """

    defaults = _as_dict(_as_dict(data.get("agents")).get("defaults"))
    return AgentDefaultsConfig(
        workspace=_string(defaults.get("workspace") or data.get("workspace")),
        model=_string(defaults.get("model") or data.get("model")),
        provider=_string(defaults.get("provider") or data.get("provider")),
        embedding_model=_string(
            defaults.get("embeddingModel")
            or defaults.get("embedding_model")
            or data.get("embeddingModel")
            # Top-level snake_case spelling, accepted for parity with the
            # nested lookup above.
            or data.get("embedding_model")
        ),
    )
|
||||
|
||||
|
||||
def _parse_providers(raw: Any) -> dict[str, ProviderConfig]:
    """Convert the raw ``providers`` mapping into ``ProviderConfig`` objects.

    Entries whose payload is not a JSON object are skipped. Each field accepts
    several key spellings; the first truthy value wins.
    """

    parsed: dict[str, ProviderConfig] = {}
    for provider_name, entry in _as_dict(raw).items():
        if not isinstance(entry, dict):
            continue

        def pick(*keys: str, _entry: dict[str, Any] = entry) -> Any:
            # Mirrors an `a or b or c` chain: first truthy value, otherwise
            # whatever the last key holds.
            found = None
            for key in keys:
                found = _entry.get(key)
                if found:
                    break
            return found

        parsed[str(provider_name)] = ProviderConfig(
            api_key=_string(pick("apiKey", "api_key")),
            api_base=_string(pick("apiBase", "api_base", "baseUrl", "base_url")),
            extra_headers=_string_dict(pick("extraHeaders", "extra_headers", "headers")),
            request_timeout_seconds=_float(
                pick("requestTimeoutSeconds", "request_timeout_seconds", "timeout")
            ),
        )
    return parsed
|
||||
|
||||
|
||||
def _parse_embedding(data: dict[str, Any]) -> EmbeddingConfig:
    """Parse the optional ``embedding`` (or ``embeddings``) config section.

    The model falls back to the top-level ``embeddingModel`` /
    ``embedding_model`` keys when the section does not name one.
    """

    section = _as_dict(data.get("embedding") or data.get("embeddings"))
    lookup = section.get

    provider = lookup("provider") or lookup("provider_name")
    model = lookup("model") or data.get("embeddingModel") or data.get("embedding_model")
    api_key = lookup("apiKey") or lookup("api_key")
    api_base = lookup("apiBase") or lookup("api_base") or lookup("baseUrl") or lookup("base_url")
    headers = lookup("extraHeaders") or lookup("extra_headers") or lookup("headers")
    timeout = lookup("requestTimeoutSeconds") or lookup("request_timeout_seconds") or lookup("timeout")

    return EmbeddingConfig(
        provider=_string(provider),
        model=_string(model),
        api_key=_string(api_key),
        api_base=_string(api_base),
        extra_headers=_string_dict(headers),
        request_timeout_seconds=_float(timeout),
    )
|
||||
|
||||
|
||||
def _as_dict(value: Any) -> dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _string(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = str(value).strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def _string_dict(value: Any) -> dict[str, str]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
return {str(key): str(item) for key, item in value.items() if item is not None}
|
||||
|
||||
|
||||
def _float(value: Any) -> float | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return float(value)
|
||||
136
app-instance/backend/beaver/foundation/config/schema.py
Normal file
136
app-instance/backend/beaver/foundation/config/schema.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""Runtime configuration schema for Beaver sandbox instances."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderConfig:
|
||||
"""One configured LLM provider profile."""
|
||||
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentDefaultsConfig:
|
||||
"""Default agent settings for this sandbox instance."""
|
||||
|
||||
workspace: str | None = None
|
||||
model: str | None = None
|
||||
provider: str | None = None
|
||||
embedding_model: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class EmbeddingConfig:
|
||||
"""Optional dedicated embedding model settings."""
|
||||
|
||||
provider: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class BeaverConfig:
    """Config loaded once per backend sandbox instance."""

    agents_defaults: AgentDefaultsConfig = field(default_factory=AgentDefaultsConfig)
    providers: dict[str, ProviderConfig] = field(default_factory=dict)
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    # Path the config was (or would have been) read from.
    config_path: Path | None = None

    @property
    def default_model(self) -> str | None:
        """Configured default chat model, or ``None`` when unset or blank."""
        return _clean(self.agents_defaults.model)

    @property
    def default_embedding_model(self) -> str:
        """Embedding model: dedicated section first, then agent defaults, then a built-in default."""
        return _clean(self.embedding.model) or _clean(self.agents_defaults.embedding_model) or "text-embedding-v4"

    def resolve_provider_target(
        self,
        *,
        model: str | None = None,
        provider_name: str | None = None,
    ) -> dict[str, Any]:
        """Resolve model/provider credentials from instance config.

        Request-level model/provider overrides are allowed, but credentials are still
        read from backend config, not from Web/channel payloads.

        Returns:
            A dict that may contain ``model``, ``provider_name``, ``api_key``,
            ``api_base``, ``extra_headers`` and ``request_timeout_seconds``.
            Keys whose value is ``None``, ``""`` or ``{}`` are omitted.
        """

        resolved_model = _clean(model) or self.default_model
        resolved_provider = _clean(provider_name) or self._infer_provider(resolved_model)
        provider_cfg = self.providers.get(resolved_provider) if resolved_provider else None
        payload: dict[str, Any] = {
            "model": resolved_model,
            "provider_name": resolved_provider,
        }
        if provider_cfg is not None:
            payload.update(
                {
                    "api_key": provider_cfg.api_key,
                    "api_base": provider_cfg.api_base,
                    # Copy so callers cannot mutate the stored config.
                    "extra_headers": dict(provider_cfg.extra_headers),
                    "request_timeout_seconds": provider_cfg.request_timeout_seconds,
                }
            )
        return self._without_empty(payload)

    def resolve_embedding_target(self) -> dict[str, Any] | None:
        """Return an explicit embedding target when configured.

        Returns ``None`` when the embedding section sets none of provider,
        API key, API base, extra headers or timeout. Otherwise values missing
        from the embedding section are filled from the provider entry named by
        ``embedding.provider``, if such an entry exists.
        """

        has_explicit_embedding = any(
            [
                _clean(self.embedding.provider),
                _clean(self.embedding.api_key),
                _clean(self.embedding.api_base),
                self.embedding.extra_headers,
                self.embedding.request_timeout_seconds is not None,
            ]
        )
        if not has_explicit_embedding:
            return None

        provider_cfg = self.providers.get(_clean(self.embedding.provider) or "")
        # Always hand out a copy of the headers: returning the live dict would
        # let a caller's mutation leak back into the stored config.
        if self.embedding.extra_headers:
            headers = dict(self.embedding.extra_headers)
        elif provider_cfg:
            headers = dict(provider_cfg.extra_headers)
        else:
            headers = {}
        payload: dict[str, Any] = {
            "provider": _clean(self.embedding.provider),
            "model": self.default_embedding_model,
            "api_key": _clean(self.embedding.api_key) or (provider_cfg.api_key if provider_cfg else None),
            "api_base": _clean(self.embedding.api_base) or (provider_cfg.api_base if provider_cfg else None),
            "extra_headers": headers,
            "request_timeout_seconds": self.embedding.request_timeout_seconds
            or (provider_cfg.request_timeout_seconds if provider_cfg else None),
        }
        return self._without_empty(payload)

    def _infer_provider(self, model: str | None) -> str | None:
        """Pick a provider name when the request did not specify one.

        Order: the configured default provider; a ``prefix/`` on the model name
        that matches a configured provider; the sole configured provider when
        exactly one exists; otherwise ``None``.
        """
        configured_provider = _clean(self.agents_defaults.provider)
        if configured_provider:
            return configured_provider

        if model and "/" in model:
            prefix = model.split("/", 1)[0]
            if prefix in self.providers:
                return prefix

        if len(self.providers) == 1:
            return next(iter(self.providers))
        return None

    @staticmethod
    def _without_empty(payload: dict[str, Any]) -> dict[str, Any]:
        """Drop entries whose value is ``None``, an empty string or an empty dict."""
        return {key: value for key, value in payload.items() if value not in (None, "", {})}
|
||||
|
||||
|
||||
def _clean(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = str(value).strip()
|
||||
return value or None
|
||||
|
||||
205
app-instance/backend/beaver/foundation/embedding.py
Normal file
205
app-instance/backend/beaver/foundation/embedding.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Shared embedding-based semantic retrieval utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib import request
|
||||
|
||||
|
||||
class EmbeddingRetriever:
|
||||
"""Use an OpenAI-compatible embeddings API to rank lightweight candidates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key_env: str = "OPENAI_API_KEY",
|
||||
api_base_env: str = "OPENAI_API_BASE",
|
||||
model: str = "text-embedding-v4",
|
||||
timeout_seconds: float = 20.0,
|
||||
) -> None:
|
||||
self.api_key_env = api_key_env
|
||||
self.api_base_env = api_base_env
|
||||
self.model = model
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
candidates: list[dict[str, str]],
|
||||
top_k: int,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
model: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
fallback_top_k: int | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Return candidates ordered by embedding similarity.
|
||||
|
||||
If embedding config is missing or the request fails, return the original
|
||||
candidate order. This keeps retrieval non-blocking for the main run.
|
||||
"""
|
||||
|
||||
if not candidates or top_k <= 0:
|
||||
return []
|
||||
|
||||
fallback = self._fallback_candidates(candidates, fallback_top_k=fallback_top_k)
|
||||
resolved_api_key = api_key or os.getenv(self.api_key_env)
|
||||
resolved_api_base = api_base or os.getenv(self.api_base_env)
|
||||
if not resolved_api_key or not resolved_api_base:
|
||||
return fallback
|
||||
|
||||
try:
|
||||
query_embedding = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=[query],
|
||||
model=model or self.model,
|
||||
extra_headers=extra_headers,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
candidate_embeddings = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=[self._candidate_text(item) for item in candidates],
|
||||
model=model or self.model,
|
||||
extra_headers=extra_headers,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
|
||||
return fallback
|
||||
|
||||
query_vector = query_embedding[0]
|
||||
scored: list[tuple[float, dict[str, str]]] = []
|
||||
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
|
||||
if vector:
|
||||
scored.append((self._cosine_similarity(query_vector, vector), candidate))
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
return [item[1] for item in scored[:top_k]]
|
||||
|
||||
async def _embed_texts(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
texts: list[str],
|
||||
model: str,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> list[list[float]]:
|
||||
all_vectors: list[list[float]] = []
|
||||
endpoint = self._normalize_embeddings_endpoint(api_base)
|
||||
for start in range(0, len(texts), 10):
|
||||
batch = texts[start:start + 10]
|
||||
payload = await self._post_embeddings(
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=batch,
|
||||
extra_headers=extra_headers,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
embeddings = payload.get("data") or []
|
||||
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
|
||||
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
|
||||
return all_vectors
|
||||
|
||||
async def _post_embeddings(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await asyncio.to_thread(
|
||||
self._post_embeddings_sync,
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=texts,
|
||||
extra_headers=extra_headers,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def _post_embeddings_sync(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
body = json.dumps(
|
||||
{
|
||||
"model": model,
|
||||
"input": texts if len(texts) > 1 else texts[0],
|
||||
"encoding_format": "float",
|
||||
}
|
||||
).encode("utf-8")
|
||||
req = request.Request(
|
||||
endpoint,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**(extra_headers or {}),
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
with request.urlopen(req, timeout=timeout_seconds or self.timeout_seconds) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def _fallback_candidates(
|
||||
candidates: list[dict[str, str]],
|
||||
*,
|
||||
fallback_top_k: int | None,
|
||||
) -> list[dict[str, str]]:
|
||||
if fallback_top_k is None:
|
||||
return list(candidates)
|
||||
if fallback_top_k <= 0:
|
||||
return []
|
||||
return candidates[:fallback_top_k]
|
||||
|
||||
@staticmethod
|
||||
def _candidate_text(candidate: dict[str, str]) -> str:
|
||||
parts = [
|
||||
(candidate.get("name") or "").strip(),
|
||||
(candidate.get("description") or "").strip(),
|
||||
(candidate.get("input_schema") or "").strip(),
|
||||
]
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_embeddings_endpoint(api_base: str) -> str:
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/embeddings"):
|
||||
return base
|
||||
if base.endswith("/v1"):
|
||||
return f"{base}/embeddings"
|
||||
return f"{base}/v1/embeddings"
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||||
if not left or not right or len(left) != len(right):
|
||||
return -1.0
|
||||
dot = sum(a * b for a, b in zip(left, right, strict=False))
|
||||
left_norm = math.sqrt(sum(a * a for a in left))
|
||||
right_norm = math.sqrt(sum(b * b for b in right))
|
||||
if left_norm == 0 or right_norm == 0:
|
||||
return -1.0
|
||||
return dot / (left_norm * right_norm)
|
||||
@ -1,2 +1,7 @@
|
||||
"""Channel interfaces."""
|
||||
|
||||
from .base import ChannelAdapter
|
||||
from .manager import ChannelManager
|
||||
from .memory import MemoryChannelAdapter
|
||||
|
||||
__all__ = ["ChannelAdapter", "ChannelManager", "MemoryChannelAdapter"]
|
||||
|
||||
24
app-instance/backend/beaver/interfaces/channels/base.py
Normal file
24
app-instance/backend/beaver/interfaces/channels/base.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""Channel adapter contracts for gateway-facing integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||
|
||||
|
||||
class ChannelAdapter(Protocol):
    """Structural interface that every gateway channel has to satisfy."""

    # Identifier of the channel.
    name: str
    # Message bus the channel is attached to.
    bus: MessageBus

    async def start(self) -> None:
        """Get the channel ready before any message is routed to it."""
        ...

    async def stop(self) -> None:
        """Cease accepting and routing messages on this channel."""
        ...

    async def send(self, message: OutboundMessage) -> None:
        """Hand one outbound message over to the concrete channel."""
        ...
|
||||
|
||||
76
app-instance/backend/beaver/interfaces/channels/manager.py
Normal file
76
app-instance/backend/beaver/interfaces/channels/manager.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""Channel manager for routing gateway outbound messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
|
||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||
|
||||
from .base import ChannelAdapter
|
||||
|
||||
|
||||
class ChannelManager:
    """Start/stop channel adapters and dispatch outbound messages to them."""

    def __init__(self, bus: MessageBus) -> None:
        """Create a manager bound to *bus* with no channels registered."""
        self.bus = bus
        self.channels: dict[str, ChannelAdapter] = {}
        # Messages with no matching channel, or whose send raised.
        self.undeliverable: list[OutboundMessage] = []
        self.started = False

    def register(self, channel: ChannelAdapter) -> None:
        """Add *channel* under its ``name``.

        Raises:
            RuntimeError: If the manager has already been started.
            ValueError: If the name is already registered or the channel uses a
                different bus object than this manager.
        """
        if self.started:
            raise RuntimeError("Cannot register channels after ChannelManager.start()")
        if channel.name in self.channels:
            raise ValueError(f"Channel already registered: {channel.name}")
        if channel.bus is not self.bus:
            raise ValueError("Channel must share the same MessageBus as ChannelManager")
        self.channels[channel.name] = channel

    async def start(self) -> None:
        """Start every registered channel in registration order.

        Calling ``start`` on an already started manager is a no-op, so adapters
        are never started twice. If any channel fails to start, the channels
        started so far are stopped in reverse order and the original error is
        re-raised.
        """
        if self.started:
            return

        started: list[ChannelAdapter] = []
        try:
            for channel in self.channels.values():
                await channel.start()
                started.append(channel)
        except BaseException:
            # Roll back best-effort; the start failure is what the caller sees.
            for channel in reversed(started):
                with suppress(BaseException):
                    await channel.stop()
            raise
        else:
            self.started = True

    async def stop(self) -> None:
        """Stop all channels in reverse registration order.

        Every channel gets a stop attempt even when an earlier one fails.

        Raises:
            RuntimeError: If at least one channel raised while stopping; the
                first failure is attached as the cause.
        """
        errors: list[BaseException] = []
        for channel in reversed(tuple(self.channels.values())):
            try:
                await channel.stop()
            except Exception as exc:  # pragma: no cover - defensive cleanup path
                errors.append(exc)
        self.started = False
        if errors:
            raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]

    async def dispatch_outbound(self, stop_event: asyncio.Event) -> None:
        """Route bus outbound messages until stopped and the queue is drained.

        Messages addressed to an unknown channel, or whose ``send`` raises, are
        collected in ``undeliverable`` instead of stopping the loop.
        """
        import logging

        logger = logging.getLogger(__name__)

        while True:
            if stop_event.is_set() and self.bus.outbound_size == 0:
                break

            try:
                # Short timeout so the stop condition is re-checked regularly.
                message = await asyncio.wait_for(self.bus.consume_outbound(), timeout=0.25)
            except asyncio.TimeoutError:
                continue

            channel = self.channels.get(message.channel)
            if channel is None:
                self.undeliverable.append(message)
                continue

            try:
                await channel.send(message)
            except Exception as exc:  # pragma: no cover - defensive channel isolation
                # One failing channel must not stop delivery to the others.
                logger.warning("Channel %r failed to send outbound message: %s", message.channel, exc)
                self.undeliverable.append(message)
|
||||
57
app-instance/backend/beaver/interfaces/channels/memory.py
Normal file
57
app-instance/backend/beaver/interfaces/channels/memory.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""In-memory channel adapter for tests and local gateway embedding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
|
||||
|
||||
class MemoryChannelAdapter:
    """Channel adapter that keeps everything it is asked to send in a list."""

    def __init__(self, bus: MessageBus, *, name: str = "memory") -> None:
        """Remember the shared bus and the channel name used for routing."""
        self.name = name
        self.bus = bus
        self.started = False
        self.sent_messages: list[OutboundMessage] = []

    async def start(self) -> None:
        """Flag the adapter as started."""
        self.started = True

    async def stop(self) -> None:
        """Flag the adapter as stopped."""
        self.started = False

    async def send(self, message: OutboundMessage) -> None:
        """Record an outbound message in ``sent_messages``."""
        self.sent_messages.append(message)

    async def publish_text(
        self,
        content: str,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        title: str | None = None,
        execution_context: str | None = None,
        model: str | None = None,
        provider_name: str | None = None,
        embedding_model: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> InboundMessage:
        """Wrap ``content`` in an inbound message, publish it on the bus, and return it."""
        # A missing or empty metadata mapping becomes a fresh empty dict.
        payload_metadata = metadata if metadata else {}
        inbound = InboundMessage(
            channel=self.name,
            content=content,
            session_id=session_id,
            user_id=user_id,
            title=title,
            execution_context=execution_context,
            model=model,
            provider_name=provider_name,
            embedding_model=embedding_model,
            metadata=payload_metadata,
        )
        await self.bus.publish_inbound(inbound)
        return inbound
|
||||
|
||||
@ -35,6 +35,7 @@ app = typer.Typer(help="Beaver backend CLI") if hasattr(typer, "Typer") else typ
|
||||
def run(
|
||||
message: str | None = typer.Option(None, "--message", "-m", help="Run one direct Beaver request."),
|
||||
workspace: str | None = typer.Option(None, "--workspace", help="Workspace root for this run."),
|
||||
config: str | None = typer.Option(None, "--config", help="Backend config path for this run."),
|
||||
) -> None:
|
||||
"""Thin CLI wrapper around AgentService.
|
||||
|
||||
@ -44,7 +45,7 @@ def run(
|
||||
3. 打印结果
|
||||
"""
|
||||
|
||||
service = AgentService(workspace=workspace)
|
||||
service = AgentService(workspace=workspace, config_path=config)
|
||||
if not message:
|
||||
service.create_loop()
|
||||
typer.echo("Beaver engine booted.")
|
||||
|
||||
@ -1,41 +1,39 @@
|
||||
"""Gateway entrypoint for Beaver.
|
||||
|
||||
当前阶段先不扩 bus / channels adapter,只做最小消息桥接:
|
||||
当前阶段只做最小 gateway 宿主与 channel adapter 桥接:
|
||||
1. 启动时托管 `AgentService.start()`
|
||||
2. 常驻消费 `MessageBus.inbound`
|
||||
3. 调 `service.submit_direct(...)`
|
||||
3. 调 `service.handle_inbound_message(...)`
|
||||
4. 将结果写回 `MessageBus.outbound`
|
||||
5. 退出时走 `AgentService.shutdown()`
|
||||
5. 如果配置了 channel adapters,则由 `ChannelManager` 分发 outbound
|
||||
6. 退出时走 `AgentService.shutdown()`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
from beaver.foundation.events import InboundMessage, MessageBus
|
||||
from beaver.interfaces.channels import ChannelAdapter, ChannelManager
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
async def _publish_bridge_error(
|
||||
bus: MessageBus,
|
||||
inbound: InboundMessage,
|
||||
async def _cleanup_owned_service(
|
||||
service: AgentService,
|
||||
*,
|
||||
detail: str,
|
||||
finish_reason: str = "error",
|
||||
timeout_seconds: float | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
"""把 bridge 处理失败转换成结构化 outbound 错误消息。"""
|
||||
"""Best-effort cleanup for service startup failures or cancellations."""
|
||||
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
message_id=inbound.message_id,
|
||||
channel=inbound.channel,
|
||||
session_id=inbound.session_id,
|
||||
content=detail,
|
||||
finish_reason=finish_reason,
|
||||
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
|
||||
)
|
||||
)
|
||||
with suppress(BaseException):
|
||||
if service.is_running:
|
||||
await service.shutdown(timeout_seconds=timeout_seconds, force=force)
|
||||
else:
|
||||
service.close()
|
||||
|
||||
|
||||
async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
|
||||
@ -46,11 +44,17 @@ async def _flush_pending_inbound(bus: MessageBus, *, reason: str) -> None:
|
||||
pending = bus.inbound.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await _publish_bridge_error(bus, pending, detail=reason, finish_reason="stopped")
|
||||
await bus.publish_outbound(
|
||||
AgentService.build_outbound_error(
|
||||
pending,
|
||||
detail=reason,
|
||||
finish_reason="stopped",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _await_bridge_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
|
||||
"""等待 bridge 退出;超时则取消,避免 shutdown 被桥接层反向卡死。"""
|
||||
async def _await_task_shutdown(task: asyncio.Task[None], *, timeout_seconds: float = 1.0) -> None:
|
||||
"""等待后台任务退出;超时则取消,避免 shutdown 被反向卡死。"""
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=timeout_seconds)
|
||||
@ -85,53 +89,28 @@ async def _bridge_inbound_to_runtime(
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await service.submit_direct(
|
||||
inbound.content,
|
||||
session_id=inbound.session_id,
|
||||
source=f"gateway:{inbound.channel}",
|
||||
user_id=inbound.user_id,
|
||||
title=inbound.title,
|
||||
execution_context=inbound.execution_context,
|
||||
model=inbound.model,
|
||||
provider_name=inbound.provider_name,
|
||||
embedding_model=inbound.embedding_model,
|
||||
)
|
||||
outbound = await service.handle_inbound_message(inbound)
|
||||
except asyncio.CancelledError:
|
||||
await _publish_bridge_error(
|
||||
bus,
|
||||
inbound,
|
||||
detail="Gateway stopped before completing the inbound message",
|
||||
finish_reason="cancelled",
|
||||
)
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive bridge path
|
||||
await _publish_bridge_error(
|
||||
bus,
|
||||
inbound,
|
||||
detail=str(exc),
|
||||
)
|
||||
else:
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
message_id=inbound.message_id,
|
||||
channel=inbound.channel,
|
||||
session_id=result.session_id,
|
||||
run_id=result.run_id,
|
||||
content=result.output_text,
|
||||
finish_reason=result.finish_reason,
|
||||
provider_name=result.provider_name,
|
||||
model=result.model,
|
||||
usage=dict(result.usage),
|
||||
metadata={"inbound_metadata": dict(inbound.metadata)},
|
||||
AgentService.build_outbound_error(
|
||||
inbound,
|
||||
detail="Gateway stopped before completing the inbound message",
|
||||
finish_reason="cancelled",
|
||||
)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
await bus.publish_outbound(outbound)
|
||||
|
||||
|
||||
async def run_gateway(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
config_path: str | Path | None = None,
|
||||
service: AgentService | None = None,
|
||||
bus: MessageBus | None = None,
|
||||
channels: Sequence[ChannelAdapter] | None = None,
|
||||
channel_manager: ChannelManager | None = None,
|
||||
manage_service_lifecycle: bool | None = None,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
shutdown_timeout_seconds: float | None = 5.0,
|
||||
@ -142,19 +121,41 @@ async def run_gateway(
|
||||
默认 ownership 语义:
|
||||
- 未传 `service`:gateway 自己创建并接管其 lifecycle
|
||||
- 传入外部 `service`:默认只使用,不自动 start/shutdown
|
||||
- `channel_manager` 和 `channels` 二选一,避免隐式修改外部 manager
|
||||
"""
|
||||
|
||||
attached_service = service or AgentService(workspace=workspace)
|
||||
attached_bus = bus or MessageBus()
|
||||
attached_service = service or AgentService(workspace=workspace, config_path=config_path)
|
||||
if channel_manager is not None and channels is not None:
|
||||
raise ValueError("Pass either channel_manager or channels, not both")
|
||||
if bus is not None:
|
||||
attached_bus = bus
|
||||
elif channel_manager is not None:
|
||||
attached_bus = channel_manager.bus
|
||||
else:
|
||||
attached_bus = MessageBus()
|
||||
attached_channel_manager = channel_manager
|
||||
if attached_channel_manager is not None and attached_channel_manager.bus is not attached_bus:
|
||||
raise ValueError("Injected channel_manager must share the gateway MessageBus")
|
||||
if attached_channel_manager is None and channels is not None:
|
||||
attached_channel_manager = ChannelManager(attached_bus)
|
||||
if attached_channel_manager is not None and channels is not None:
|
||||
for channel in channels:
|
||||
attached_channel_manager.register(channel)
|
||||
|
||||
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
|
||||
owned_stop_event = stop_event or asyncio.Event()
|
||||
started = False
|
||||
channels_started = False
|
||||
if owns_service:
|
||||
try:
|
||||
await attached_service.start()
|
||||
started = True
|
||||
except Exception:
|
||||
attached_service.close()
|
||||
except BaseException:
|
||||
await _cleanup_owned_service(
|
||||
attached_service,
|
||||
timeout_seconds=shutdown_timeout_seconds,
|
||||
force=shutdown_force,
|
||||
)
|
||||
raise
|
||||
|
||||
if not attached_service.is_running:
|
||||
@ -163,7 +164,25 @@ async def run_gateway(
|
||||
"or allow the gateway to manage its lifecycle."
|
||||
)
|
||||
|
||||
if attached_channel_manager is not None:
|
||||
try:
|
||||
await attached_channel_manager.start()
|
||||
channels_started = True
|
||||
except BaseException:
|
||||
if owns_service and started:
|
||||
await _cleanup_owned_service(
|
||||
attached_service,
|
||||
timeout_seconds=shutdown_timeout_seconds,
|
||||
force=shutdown_force,
|
||||
)
|
||||
raise
|
||||
|
||||
bridge_task = asyncio.create_task(_bridge_inbound_to_runtime(attached_service, attached_bus, owned_stop_event))
|
||||
dispatch_task: asyncio.Task[None] | None = None
|
||||
dispatch_stop_event = asyncio.Event()
|
||||
if attached_channel_manager is not None:
|
||||
dispatch_task = asyncio.create_task(attached_channel_manager.dispatch_outbound(dispatch_stop_event))
|
||||
|
||||
try:
|
||||
await owned_stop_event.wait()
|
||||
finally:
|
||||
@ -175,9 +194,14 @@ async def run_gateway(
|
||||
force=shutdown_force,
|
||||
)
|
||||
finally:
|
||||
await _await_bridge_shutdown(bridge_task)
|
||||
await _await_task_shutdown(bridge_task)
|
||||
else:
|
||||
await _await_bridge_shutdown(bridge_task)
|
||||
await _await_task_shutdown(bridge_task)
|
||||
if dispatch_task is not None:
|
||||
dispatch_stop_event.set()
|
||||
await _await_task_shutdown(dispatch_task)
|
||||
if attached_channel_manager is not None and channels_started:
|
||||
await attached_channel_manager.stop()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
@ -56,6 +56,7 @@ async def _app_lifespan(
|
||||
app: FastAPI,
|
||||
*,
|
||||
workspace: str | Path | None,
|
||||
config_path: str | Path | None,
|
||||
service: AgentService | None,
|
||||
manage_service_lifecycle: bool | None,
|
||||
shutdown_timeout_seconds: float | None,
|
||||
@ -63,7 +64,7 @@ async def _app_lifespan(
|
||||
) -> AsyncIterator[None]:
|
||||
"""把 Web app 接到 AgentService lifecycle 上。"""
|
||||
|
||||
attached_service = service or AgentService(workspace=workspace)
|
||||
attached_service = service or AgentService(workspace=workspace, config_path=config_path)
|
||||
owns_service = manage_service_lifecycle if manage_service_lifecycle is not None else service is None
|
||||
app.state.agent_service = attached_service
|
||||
started = False
|
||||
@ -71,8 +72,15 @@ async def _app_lifespan(
|
||||
try:
|
||||
await attached_service.start()
|
||||
started = True
|
||||
except Exception:
|
||||
attached_service.close()
|
||||
except BaseException:
|
||||
with suppress(BaseException):
|
||||
if attached_service.is_running:
|
||||
await attached_service.shutdown(
|
||||
timeout_seconds=shutdown_timeout_seconds,
|
||||
force=shutdown_force,
|
||||
)
|
||||
else:
|
||||
attached_service.close()
|
||||
raise
|
||||
try:
|
||||
yield
|
||||
@ -87,6 +95,7 @@ async def _app_lifespan(
|
||||
def create_app(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
config_path: str | Path | None = None,
|
||||
service: AgentService | None = None,
|
||||
manage_service_lifecycle: bool | None = None,
|
||||
shutdown_timeout_seconds: float | None = 5.0,
|
||||
@ -106,6 +115,7 @@ def create_app(
|
||||
lifespan=lambda fastapi_app: _app_lifespan(
|
||||
fastapi_app,
|
||||
workspace=workspace,
|
||||
config_path=config_path,
|
||||
service=service,
|
||||
manage_service_lifecycle=manage_service_lifecycle,
|
||||
shutdown_timeout_seconds=shutdown_timeout_seconds,
|
||||
|
||||
@ -1,6 +1,15 @@
|
||||
"""Application services for Beaver."""
|
||||
|
||||
from .agent_service import AgentService
|
||||
from .memory_service import MemoryService
|
||||
|
||||
__all__ = ["AgentService", "MemoryService"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Lazily resolve the package's exported services on first access.

    The submodule import happens only when ``AgentService`` or
    ``MemoryService`` is actually requested, so importing the package does
    not import either service module up front.

    Raises:
        AttributeError: for any other name, using the standard
            "module ... has no attribute ..." wording so failures read like
            ordinary missing-attribute errors.
    """
    if name == "AgentService":
        from .agent_service import AgentService

        return AgentService
    if name == "MemoryService":
        from .memory_service import MemoryService

        return MemoryService
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -17,6 +17,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
class AgentService:
|
||||
@ -36,11 +37,12 @@ class AgentService:
|
||||
self,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
config_path: str | Path | None = None,
|
||||
profile: AgentProfile | None = None,
|
||||
loader: EngineLoader | None = None,
|
||||
) -> None:
|
||||
self.profile = profile or AgentProfile()
|
||||
self.loader = loader or EngineLoader(workspace=workspace)
|
||||
self.loader = loader or EngineLoader(workspace=workspace, config_path=config_path)
|
||||
self._loop: AgentLoop | None = None
|
||||
self._run_task: asyncio.Task[None] | None = None
|
||||
|
||||
@ -189,6 +191,60 @@ class AgentService:
|
||||
loop = self.create_loop()
|
||||
return await loop.submit_direct(message, **kwargs)
|
||||
|
||||
async def handle_inbound_message(self, inbound: InboundMessage) -> OutboundMessage:
    """Run one bus inbound message through the runtime and return the outbound reply.

    The inbound fields are forwarded to ``submit_direct`` with the source
    tagged as ``gateway:<channel>``. Any ``Exception`` raised by the run is
    converted into a structured error message via ``build_outbound_error``
    instead of propagating; cancellation is not caught here.
    """

    try:
        result = await self.submit_direct(
            inbound.content,
            session_id=inbound.session_id,
            source=f"gateway:{inbound.channel}",
            user_id=inbound.user_id,
            title=inbound.title,
            execution_context=inbound.execution_context,
            model=inbound.model,
            provider_name=inbound.provider_name,
            embedding_model=inbound.embedding_model,
        )
    except Exception as exc:
        # Exceptions raised without arguments stringify to "", which would
        # produce an error reply with no content; fall back to the class name.
        detail = str(exc) or type(exc).__name__
        return self.build_outbound_error(inbound, detail=detail)
    return self.build_outbound_message(inbound, result)
|
||||
|
||||
@staticmethod
def build_outbound_message(inbound: InboundMessage, result: AgentRunResult) -> OutboundMessage:
    """Convert a successful runtime result into a bus outbound message.

    Routing fields (message id, channel) come from the inbound message; the
    session/run identifiers, text, finish reason, provider, model and usage
    come from the run result. The inbound metadata is echoed back under
    ``metadata["inbound_metadata"]``.
    """
    # Copy both mappings so the outbound message does not alias its inputs.
    usage_snapshot = dict(result.usage)
    echoed_metadata = {"inbound_metadata": dict(inbound.metadata)}
    return OutboundMessage(
        message_id=inbound.message_id,
        channel=inbound.channel,
        session_id=result.session_id,
        run_id=result.run_id,
        content=result.output_text,
        finish_reason=result.finish_reason,
        provider_name=result.provider_name,
        model=result.model,
        usage=usage_snapshot,
        metadata=echoed_metadata,
    )
|
||||
|
||||
@staticmethod
def build_outbound_error(
    inbound: InboundMessage,
    *,
    detail: str,
    finish_reason: str = "error",
) -> OutboundMessage:
    """Convert a failed inbound message into a structured outbound error.

    ``detail`` is used both as the message content and as
    ``metadata["error"]``; the inbound metadata is echoed back under
    ``metadata["inbound_metadata"]``.
    """
    error_metadata = {
        "error": detail,
        "inbound_metadata": dict(inbound.metadata),
    }
    return OutboundMessage(
        message_id=inbound.message_id,
        channel=inbound.channel,
        session_id=inbound.session_id,
        content=detail,
        finish_reason=finish_reason,
        metadata=error_metadata,
    )
|
||||
|
||||
def run_direct(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
"""Skill system for Beaver."""
|
||||
"""Skill system for Beaver.
|
||||
|
||||
from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever
|
||||
from .catalog import SkillRecord, SkillsLoader
|
||||
顶层包保持 lazy export,避免只导入 catalog/loader 时顺带拉起
|
||||
SkillAssembler -> provider -> litellm 这条重依赖链。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"SkillAssembler",
|
||||
@ -10,3 +15,22 @@ __all__ = [
|
||||
"SkillRecord",
|
||||
"SkillsLoader",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve the package's public names lazily, importing only the submodule needed."""
    assembler_exports = ("SkillAssembler", "SkillAssemblyResult", "SkillEmbeddingRetriever")
    catalog_exports = ("SkillRecord", "SkillsLoader")

    if name in assembler_exports:
        from .assembler import SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever

        # Tuple order mirrors ``assembler_exports`` so the index lines up.
        resolved = (SkillAssembler, SkillAssemblyResult, SkillEmbeddingRetriever)
        return resolved[assembler_exports.index(name)]

    if name in catalog_exports:
        from .catalog import SkillRecord, SkillsLoader

        resolved = (SkillRecord, SkillsLoader)
        return resolved[catalog_exports.index(name)]

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -1,188 +1,9 @@
|
||||
"""Embedding-based skill candidate retrieval.
|
||||
|
||||
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
|
||||
阿里云百炼 `text-embedding-v4` 做最小语义召回:
|
||||
1. 复用当前 provider 的 `api_key/api_base`
|
||||
2. 先用 embedding 相似度召回一小批候选
|
||||
3. 再交给上层 LLM selector 做最终技能选择
|
||||
"""
|
||||
"""Embedding-based skill candidate retrieval."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import json
|
||||
from urllib import request
|
||||
from typing import Any
|
||||
from beaver.foundation.embedding import EmbeddingRetriever
|
||||
|
||||
|
||||
class SkillEmbeddingRetriever:
|
||||
class SkillEmbeddingRetriever(EmbeddingRetriever):
|
||||
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key_env: str = "OPENAI_API_KEY",
|
||||
api_base_env: str = "OPENAI_API_BASE",
|
||||
model: str = "text-embedding-v4",
|
||||
timeout_seconds: float = 20.0,
|
||||
) -> None:
|
||||
self.api_key_env = api_key_env
|
||||
self.api_base_env = api_base_env
|
||||
self.model = model
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
candidates: list[dict[str, str]],
|
||||
top_k: int = 12,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""按 embedding 相似度召回 top-k 候选。
|
||||
|
||||
如果没有可用的 API Key / base URL,或者 embedding 调用失败,
|
||||
当前阶段先退回到“全部候选交给 LLM selector”。
|
||||
"""
|
||||
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
resolved_api_key = api_key or os.getenv(self.api_key_env)
|
||||
resolved_api_base = api_base or os.getenv(self.api_base_env)
|
||||
if not resolved_api_key or not resolved_api_base:
|
||||
return candidates
|
||||
|
||||
try:
|
||||
query_embedding = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=[query],
|
||||
model=model or self.model,
|
||||
)
|
||||
candidate_texts = [self._candidate_text(item) for item in candidates]
|
||||
candidate_embeddings = await self._embed_texts(
|
||||
api_key=resolved_api_key,
|
||||
api_base=resolved_api_base,
|
||||
texts=candidate_texts,
|
||||
model=model or self.model,
|
||||
)
|
||||
except Exception:
|
||||
return candidates
|
||||
|
||||
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
|
||||
return candidates
|
||||
|
||||
query_vector = query_embedding[0]
|
||||
scored: list[tuple[float, dict[str, str]]] = []
|
||||
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
|
||||
if not vector:
|
||||
continue
|
||||
scored.append((self._cosine_similarity(query_vector, vector), candidate))
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
return [item[1] for item in scored[:top_k]]
|
||||
|
||||
async def _embed_texts(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
texts: list[str],
|
||||
model: str,
|
||||
) -> list[list[float]]:
|
||||
"""调用 OpenAI-compatible embeddings 接口。
|
||||
|
||||
当前对齐的是你们实际在用的网关配置:
|
||||
- `POST {api_base}/embeddings`
|
||||
- `model=text-embedding-v4`
|
||||
- `encoding_format=float`
|
||||
"""
|
||||
|
||||
all_vectors: list[list[float]] = []
|
||||
endpoint = self._normalize_embeddings_endpoint(api_base)
|
||||
for start in range(0, len(texts), 10):
|
||||
batch = texts[start:start + 10]
|
||||
payload = await self._post_embeddings(
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=batch,
|
||||
)
|
||||
embeddings = payload.get("data") or []
|
||||
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
|
||||
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
|
||||
return all_vectors
|
||||
|
||||
async def _post_embeddings(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
) -> dict[str, Any]:
|
||||
return await asyncio.to_thread(
|
||||
self._post_embeddings_sync,
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def _post_embeddings_sync(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
texts: list[str],
|
||||
) -> dict[str, Any]:
|
||||
body = json.dumps(
|
||||
{
|
||||
"model": model,
|
||||
"input": texts if len(texts) > 1 else texts[0],
|
||||
"encoding_format": "float",
|
||||
}
|
||||
).encode("utf-8")
|
||||
req = request.Request(
|
||||
endpoint,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
with request.urlopen(req, timeout=self.timeout_seconds) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def _candidate_text(candidate: dict[str, str]) -> str:
|
||||
name = (candidate.get("name") or "").strip()
|
||||
description = (candidate.get("description") or "").strip()
|
||||
return f"{name}\n{description}".strip()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_embeddings_endpoint(api_base: str) -> str:
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/embeddings"):
|
||||
return base
|
||||
if base.endswith("/v1"):
|
||||
return f"{base}/embeddings"
|
||||
return f"{base}/v1/embeddings"
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||||
if not left or not right or len(left) != len(right):
|
||||
return -1.0
|
||||
dot = sum(a * b for a, b in zip(left, right, strict=False))
|
||||
left_norm = math.sqrt(sum(a * a for a in left))
|
||||
right_norm = math.sqrt(sum(b * b for b in right))
|
||||
if left_norm == 0 or right_norm == 0:
|
||||
return -1.0
|
||||
return dot / (left_norm * right_norm)
|
||||
|
||||
@ -63,6 +63,11 @@ class SkillAssembler:
|
||||
api_key=embedding_runtime.api_key if embedding_runtime is not None else None,
|
||||
api_base=embedding_runtime.api_base if embedding_runtime is not None else None,
|
||||
model=embedding_runtime.model if embedding_runtime is not None else None,
|
||||
extra_headers=embedding_runtime.extra_headers if embedding_runtime is not None else None,
|
||||
timeout_seconds=(
|
||||
embedding_runtime.request_timeout_seconds if embedding_runtime is not None else None
|
||||
),
|
||||
fallback_top_k=None,
|
||||
)
|
||||
if not candidates:
|
||||
return SkillAssemblyResult()
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -111,6 +112,32 @@ class SkillsLoader:
|
||||
metadata, _ = parse_frontmatter(content)
|
||||
return metadata
|
||||
|
||||
def get_skill_tool_hints(self, name: str) -> list[str]:
    """Return the tools a skill explicitly recommends, de-duplicated in order.

    Only explicit metadata is trusted; nothing is guessed from the skill
    body. Names are collected, in this order, from:

    - the frontmatter ``tools`` entry, e.g. ``tools: read_file, search_files``,
      ``tools: ["read_file", "search_files"]``, or a YAML-like list::

          tools:
            - read_file
            - search_files

    - ``tools`` inside the ``metadata`` JSON blob
    - ``required_tools`` inside the ``metadata`` JSON blob
    """
    frontmatter = self.get_skill_metadata(name) or {}
    meta_blob = parse_skill_metadata_blob(frontmatter.get("metadata", ""))
    declared_sources = (
        frontmatter.get("tools"),
        meta_blob.get("tools"),
        meta_blob.get("required_tools"),
    )

    # A dict keeps first-seen order while dropping repeated names.
    ordered_unique: dict[str, None] = {}
    for declared in declared_sources:
        for tool_name in self._coerce_tool_names(declared):
            if tool_name:
                ordered_unique.setdefault(tool_name, None)
    return list(ordered_unique)
|
||||
|
||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||
"""加载指定 skills 的正文,并整理成上下文块。"""
|
||||
|
||||
@ -253,6 +280,26 @@ class SkillsLoader:
|
||||
result.append(record.name)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_names(value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return []
|
||||
if raw.startswith("["):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
parsed = None
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()]
|
||||
return [item.strip() for item in raw.split(",") if item.strip()]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [str(item).strip() for item in value if str(item).strip()]
|
||||
return []
|
||||
|
||||
def _find_record(self, name: str) -> SkillRecord | None:
|
||||
for record in self.list_skills(filter_unavailable=False):
|
||||
if record.name == name:
|
||||
|
||||
@ -20,7 +20,7 @@ import shutil
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
|
||||
def parse_frontmatter(content: str) -> tuple[dict[str, Any], str]:
|
||||
"""解析 Markdown 文件顶部的极简 frontmatter。
|
||||
|
||||
当前先只支持最常见的:
|
||||
@ -43,12 +43,36 @@ def parse_frontmatter(content: str) -> tuple[dict[str, str], str]:
|
||||
if match is None:
|
||||
return {}, content
|
||||
|
||||
metadata: dict[str, str] = {}
|
||||
for line in match.group(1).splitlines():
|
||||
metadata: dict[str, Any] = {}
|
||||
lines = match.group(1).splitlines()
|
||||
index = 0
|
||||
while index < len(lines):
|
||||
line = lines[index]
|
||||
if ":" not in line:
|
||||
index += 1
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not value:
|
||||
items: list[str] = []
|
||||
lookahead = index + 1
|
||||
while lookahead < len(lines):
|
||||
candidate = lines[lookahead]
|
||||
stripped = candidate.strip()
|
||||
if not stripped:
|
||||
lookahead += 1
|
||||
continue
|
||||
if not stripped.startswith("- "):
|
||||
break
|
||||
items.append(stripped[2:].strip().strip('"\''))
|
||||
lookahead += 1
|
||||
if items:
|
||||
metadata[key] = items
|
||||
index = lookahead
|
||||
continue
|
||||
metadata[key] = value.strip('"\'')
|
||||
index += 1
|
||||
body = content[match.end():].strip()
|
||||
return metadata, body
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""Tool system for Beaver."""
|
||||
|
||||
from .base import BaseTool, ObjectBackedTool, ToolContext, ToolResult, ToolSpec
|
||||
from .assembler import ToolAssembler
|
||||
from .registry import ToolRegistry
|
||||
from .runtime import ToolExecutor
|
||||
|
||||
@ -8,6 +9,7 @@ __all__ = [
|
||||
"BaseTool",
|
||||
"ObjectBackedTool",
|
||||
"ToolContext",
|
||||
"ToolAssembler",
|
||||
"ToolExecutor",
|
||||
"ToolRegistry",
|
||||
"ToolResult",
|
||||
|
||||
5
app-instance/backend/beaver/tools/assembler/__init__.py
Normal file
5
app-instance/backend/beaver/tools/assembler/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Tool selection for a single Beaver run."""
|
||||
|
||||
from .task_assembler import ToolAssembler
|
||||
|
||||
__all__ = ["ToolAssembler"]
|
||||
106
app-instance/backend/beaver/tools/assembler/task_assembler.py
Normal file
106
app-instance/backend/beaver/tools/assembler/task_assembler.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Task-driven tool assembler.
|
||||
|
||||
这层和 SkillAssembler 的位置类似:它不执行工具,只决定本轮 run 应该把哪些
|
||||
tool schema 暴露给模型。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from beaver.engine.context import SkillContext
|
||||
from beaver.foundation.embedding import EmbeddingRetriever
|
||||
from beaver.tools.base import ToolSpec
|
||||
from beaver.tools.registry import ToolRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from beaver.engine.providers.runtime import ProviderRuntime
|
||||
from beaver.skills.catalog.loader import SkillsLoader
|
||||
|
||||
|
||||
class ToolAssembler:
    """Choose which tool specs a single run exposes, from skill hints and embedding retrieval."""

    def __init__(
        self,
        *,
        retriever: EmbeddingRetriever | None = None,
        always_tool_names: Sequence[str] | None = None,
    ) -> None:
        """Store the retriever and the tool names that are always included."""
        default_always = ("memory", "session_search", "skill_view")
        self.retriever = retriever or EmbeddingRetriever()
        self.always_tool_names = tuple(always_tool_names or default_always)

    async def assemble(
        self,
        *,
        task_description: str,
        registry: ToolRegistry,
        skills_loader: SkillsLoader | None = None,
        activated_skills: Sequence[SkillContext] | None = None,
        embedding_runtime: ProviderRuntime | None = None,
        top_k: int = 10,
    ) -> list[ToolSpec]:
        """Build the ordered, duplicate-free list of tool specs for this run.

        Specs are added in a fixed order, and the first occurrence of a name wins:

        1. specs the registry marks as always-on, then ``always_tool_names``
        2. tools declared by the activated skills
        3. the retriever's picks among the remaining tools for the task
        """
        # Insertion-ordered mapping: name -> first spec seen under that name.
        chosen: dict[str, ToolSpec] = {}

        def include(specs: Sequence[ToolSpec]) -> None:
            for spec in specs:
                chosen.setdefault(spec.name, spec)

        include(registry.list_always_specs())
        include(registry.get_specs(self.always_tool_names))

        hinted_names = self._collect_skill_tool_names(
            skills_loader=skills_loader,
            activated_skills=activated_skills or (),
        )
        include(registry.get_specs(hinted_names))

        # Only tools not already selected are offered to the retriever.
        remaining = [
            spec.to_embedding_candidate()
            for spec in registry.list_specs()
            if spec.name not in chosen
        ]

        if embedding_runtime is None:
            api_key = api_base = model = extra_headers = timeout_seconds = None
        else:
            api_key = embedding_runtime.api_key
            api_base = embedding_runtime.api_base
            model = embedding_runtime.model
            extra_headers = embedding_runtime.extra_headers
            timeout_seconds = embedding_runtime.request_timeout_seconds

        retrieved = await self.retriever.retrieve(
            query=task_description,
            candidates=remaining,
            top_k=top_k,
            api_key=api_key,
            api_base=api_base,
            model=model,
            extra_headers=extra_headers,
            timeout_seconds=timeout_seconds,
            fallback_top_k=top_k,
        )
        include(registry.get_specs([item["name"] for item in retrieved]))
        return list(chosen.values())

    @staticmethod
    def _collect_skill_tool_names(
        *,
        skills_loader: SkillsLoader | None,
        activated_skills: Sequence[SkillContext],
    ) -> list[str]:
        """Gather the tool hints of all activated skills, keeping first-seen order."""
        if skills_loader is None or not activated_skills:
            return []

        ordered_unique: dict[str, None] = {}
        for skill in activated_skills:
            for hinted in skills_loader.get_skill_tool_hints(skill.name):
                ordered_unique.setdefault(hinted, None)
        return list(ordered_unique)
|
||||
@ -29,13 +29,30 @@ class ToolSpec:
|
||||
"""单个工具对外暴露的描述信息。
|
||||
|
||||
这份信息主要服务两个场景:
|
||||
1. 导出给 provider 的 function schema
|
||||
2. 在 registry 中做列出、查找、调试
|
||||
1. 以 MCP-style descriptor 作为统一事实来源
|
||||
2. 导出给 provider 的 function schema
|
||||
3. 在 registry 中做列出、查找、调试与 embedding 召回
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
toolset: str = "core"
|
||||
always_available: bool = False
|
||||
|
||||
def to_mcp_descriptor(self) -> dict[str, Any]:
    """Export this tool as an MCP ListTools-style descriptor.

    MCP's basic fields are ``name``, ``description`` and ``inputSchema``.
    Beaver's internal ``toolset`` / ``always_available`` extras are kept out
    of this object on purpose, so the format does not drift when a real MCP
    server is integrated later.
    """
    descriptor: dict[str, Any] = {}
    descriptor["name"] = self.name
    descriptor["description"] = self.description
    descriptor["inputSchema"] = self.input_schema
    return descriptor
|
||||
|
||||
def to_provider_schema(self) -> dict[str, Any]:
|
||||
"""导出为 OpenAI-compatible function tool schema。"""
|
||||
@ -49,6 +66,15 @@ class ToolSpec:
|
||||
},
|
||||
}
|
||||
|
||||
def to_embedding_candidate(self) -> dict[str, str]:
    """Export a lightweight, text-only candidate for semantic retrieval.

    The input schema is serialized to JSON with sorted keys so the same
    schema always produces the same text.
    """
    schema_text = json.dumps(self.input_schema, ensure_ascii=False, sort_keys=True)
    candidate = {"name": self.name, "description": self.description}
    candidate["input_schema"] = schema_text
    return candidate
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolContext:
|
||||
@ -113,6 +139,8 @@ class ObjectBackedTool(BaseTool):
|
||||
name=str(getattr(backend, "name")),
|
||||
description=str(getattr(backend, "description", "")),
|
||||
input_schema=dict(getattr(backend, "parameters", {"type": "object", "properties": {}})),
|
||||
toolset=str(getattr(backend, "toolset", "core")),
|
||||
always_available=bool(getattr(backend, "always_available", False)),
|
||||
)
|
||||
|
||||
@property
|
||||
@ -150,6 +178,8 @@ class ObjectBackedTool(BaseTool):
|
||||
|
||||
if "current_session_id" not in arguments and hasattr(self.backend, "current_session_id"):
|
||||
arguments["current_session_id"] = context.session_id
|
||||
if "workspace" not in arguments and hasattr(self.backend, "workspace"):
|
||||
arguments["workspace"] = context.workspace
|
||||
|
||||
@staticmethod
|
||||
def _normalize_output(content: Any) -> dict[str, Any]:
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
"""Built-in Beaver tools."""
|
||||
|
||||
from .echo import EchoTool, echo_tool
|
||||
from .filesystem import ListDirectoryTool, ReadFileTool, SearchFilesTool
|
||||
from .memory import MemoryTool, memory_tool
|
||||
from .skill_view import SkillViewTool, skill_view
|
||||
from .session_search import SessionSearchTool, session_search
|
||||
|
||||
__all__ = [
|
||||
"EchoTool",
|
||||
"ListDirectoryTool",
|
||||
"MemoryTool",
|
||||
"ReadFileTool",
|
||||
"SearchFilesTool",
|
||||
"SkillViewTool",
|
||||
"SessionSearchTool",
|
||||
"echo_tool",
|
||||
|
||||
@ -34,6 +34,8 @@ class EchoTool:
|
||||
|
||||
name: str = "echo"
|
||||
description: str = ECHO_TOOL_DESCRIPTION
|
||||
toolset: str = "debug"
|
||||
always_available: bool = False
|
||||
parameters: dict[str, Any] = field(default_factory=lambda: dict(ECHO_TOOL_PARAMETERS))
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
|
||||
442
app-instance/backend/beaver/tools/builtins/filesystem.py
Normal file
442
app-instance/backend/beaver/tools/builtins/filesystem.py
Normal file
@ -0,0 +1,442 @@
|
||||
"""Workspace-scoped read-only filesystem tools.
|
||||
|
||||
这些工具是 Beaver 第一批真实本地工具,只做只读能力:
|
||||
- list_directory
|
||||
- read_file
|
||||
- search_files
|
||||
|
||||
安全边界先保持非常明确:所有用户传入路径都必须解析到当前
|
||||
`ToolContext.workspace` 内部。即使 workspace 里有指向外部的符号链接,
|
||||
读取时也会因为真实路径越界而被拒绝。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
# Hard upper bounds enforced by the filesystem tools, regardless of what the
# caller asks for.
MAX_LIST_ENTRIES = 1000  # cap on entries returned by list_directory
MAX_READ_LINES = 1000  # cap on lines returned by read_file
MAX_READ_CHARS = 120000  # cap on characters returned by read_file
MAX_SEARCH_RESULTS = 200  # cap on matches returned by search_files
MAX_SEARCH_FILE_BYTES = 2000000  # larger files are skipped by search_files
MAX_SEARCH_FILES = 5000  # cap on files visited by one search

# Directory names that search_files does not descend into.
SKIP_DIR_NAMES = {
    # version-control metadata
    ".git", ".hg", ".svn",
    # virtual environments
    ".venv", "venv",
    # tool caches
    "__pycache__", ".pytest_cache", ".mypy_cache", ".ruff_cache",
    # dependencies and build output
    "node_modules", "dist", "build",
}

# JSON schema for the list_directory arguments.
LIST_DIRECTORY_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {
            "type": "string",
            "default": ".",
            "description": "Directory path relative to the current workspace. Absolute paths are allowed only if they stay inside the workspace.",
        },
        "recursive": {
            "type": "boolean",
            "default": False,
            "description": "Whether to recursively list child entries. Symlink directories are not followed.",
        },
        "max_entries": {
            "type": "integer",
            "default": 200,
            "minimum": 1,
            "maximum": MAX_LIST_ENTRIES,
            "description": "Maximum number of entries to return.",
        },
    },
    "required": [],
}

# JSON schema for the read_file arguments.
READ_FILE_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {
            "type": "string",
            "description": "File path relative to the current workspace. Absolute paths are allowed only if they stay inside the workspace.",
        },
        "start_line": {
            "type": "integer",
            "default": 1,
            "minimum": 1,
            "description": "1-based line number to start reading from.",
        },
        "max_lines": {
            "type": "integer",
            "default": 200,
            "minimum": 1,
            "maximum": MAX_READ_LINES,
            "description": "Maximum number of lines to read.",
        },
    },
    "required": ["path"],
}

# JSON schema for the search_files arguments.
SEARCH_FILES_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "Plain text query to search in file paths and UTF-8 text files.",
        },
        "path": {
            "type": "string",
            "default": ".",
            "description": "Directory or file path relative to the current workspace.",
        },
        "max_results": {
            "type": "integer",
            "default": 50,
            "minimum": 1,
            "maximum": MAX_SEARCH_RESULTS,
            "description": "Maximum number of matches to return.",
        },
        "case_sensitive": {
            "type": "boolean",
            "default": False,
            "description": "Whether search should be case-sensitive.",
        },
    },
    "required": ["query"],
}
|
||||
|
||||
|
||||
class WorkspacePathError(ValueError):
    """Signal that a workspace is unusable or a path resolves outside it.

    Subclasses ``ValueError`` so handlers that catch ``ValueError`` also
    catch this error.
    """
|
||||
|
||||
|
||||
def _json_result(success: bool, **payload: Any) -> str:
    """Serialize a tool result as pretty-printed JSON.

    The ``success`` flag always comes first, followed by the payload keys in
    the order they were passed. Non-ASCII text is emitted as-is.
    """
    envelope: dict[str, Any] = {"success": success}
    envelope.update(payload)
    return json.dumps(envelope, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def _clamp_int(value: Any, *, default: int, minimum: int, maximum: int) -> int:
    """Coerce ``value`` to an int and clamp it into ``[minimum, maximum]``.

    Args:
        value: Raw argument, typically taken from a tool-call payload.
        default: Value used when ``value`` cannot be converted to an int.
        minimum: Inclusive lower bound of the result.
        maximum: Inclusive upper bound of the result.

    Returns:
        The converted value, or ``default`` if conversion failed, clamped into
        the inclusive range.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers float infinities, for example a JSON payload
        # carrying ``1e999``; these fall back to the default like other junk.
        parsed = default
    return max(minimum, min(parsed, maximum))
|
||||
|
||||
|
||||
def _workspace_root(workspace: str | None) -> Path:
    """Return the resolved workspace directory.

    Raises:
        WorkspacePathError: If no workspace is given or it is not a directory.
        OSError: If the workspace path does not exist (strict resolution).
    """
    if workspace:
        resolved_root = Path(workspace).expanduser().resolve(strict=True)
        if resolved_root.is_dir():
            return resolved_root
        raise WorkspacePathError(f"workspace is not a directory: {resolved_root}")
    raise WorkspacePathError("workspace is not configured for filesystem tools")
|
||||
|
||||
|
||||
def _resolve_existing_path(workspace: str | None, user_path: str | None) -> tuple[Path, Path]:
    """Resolve a user path and ensure the real target stays inside workspace.

    Args:
        workspace: Workspace root directory.
        user_path: Path supplied by the caller; relative paths are taken
            relative to the workspace root, and an empty value means ``"."``.

    Returns:
        ``(root, resolved)``: the resolved workspace root and the resolved
        target, with symlinks followed.

    Raises:
        WorkspacePathError: If the workspace is unusable, the path cannot be
            resolved (unknown ``~user`` or a symlink loop), or the real
            target lies outside the workspace.
        OSError: If the target does not exist (strict resolution).
    """
    root = _workspace_root(workspace)
    shown = user_path or "."
    try:
        raw_path = Path(shown).expanduser()
        candidate = raw_path if raw_path.is_absolute() else root / raw_path
        resolved = candidate.resolve(strict=True)
    except RuntimeError as exc:
        # expanduser() raises RuntimeError for an unknown "~user", and before
        # Python 3.13 resolve() raises it for symlink loops. Callers only
        # handle OSError/ValueError, so convert it to the domain error.
        raise WorkspacePathError(f"path cannot be resolved: {shown}") from exc

    try:
        resolved.relative_to(root)
    except ValueError as exc:
        raise WorkspacePathError(
            f"path escapes workspace: {user_path or '.'}"
        ) from exc
    return root, resolved
|
||||
|
||||
|
||||
def _relative_path(root: Path, path: Path) -> str:
    """Return ``path`` relative to ``root`` as a string.

    A path that is not under ``root`` is returned unchanged as a string.
    """
    try:
        inside = path.relative_to(root)
    except ValueError:
        return str(path)
    return str(inside) or "."
|
||||
|
||||
|
||||
def _entry_type(path: Path) -> str:
    """Classify ``path`` as ``symlink``, ``directory``, ``file`` or ``other``.

    The symlink check runs first, so a link is reported as ``symlink``
    whatever it points to.
    """
    probes = (
        ("symlink", path.is_symlink),
        ("directory", path.is_dir),
        ("file", path.is_file),
    )
    for label, probe in probes:
        if probe():
            return label
    return "other"
|
||||
|
||||
|
||||
def _entry_payload(root: Path, path: Path) -> dict[str, Any]:
    """Describe one directory entry as a JSON-serializable dict.

    ``size`` is the entry's own size (a symlink is not followed) or ``None``
    when it cannot be stat'ed.
    """
    size: int | None
    try:
        # lstat() never follows the final component, so it equals stat() for
        # non-symlinks and reports the link itself for symlinks.
        size = path.lstat().st_size
    except OSError:
        size = None

    payload: dict[str, Any] = {"name": path.name}
    payload["path"] = _relative_path(root, path)
    payload["type"] = _entry_type(path)
    payload["size"] = size
    return payload
|
||||
|
||||
|
||||
def _iter_directory(root: Path, directory: Path, *, recursive: bool) -> Iterable[Path]:
    """Yield the entries of ``directory``, optionally descending into children.

    Entries are ordered with real directories first, then by lower-cased
    name. With ``recursive`` a directory's contents follow it immediately
    (pre-order). Symlinked directories are yielded but never followed.

    Args:
        root: Workspace root; accepted for signature compatibility and not
            used by the traversal itself.
        directory: Directory whose entries are listed.
        recursive: Whether to descend into real sub-directories.

    Raises:
        OSError: If ``directory`` itself cannot be listed. Unreadable nested
            directories are still yielded, but their contents are skipped.
    """

    def sort_key(item: Path) -> tuple[bool, str]:
        is_real_directory = not item.is_symlink() and item.is_dir()
        return (not is_real_directory, item.name.lower())

    entries = sorted(directory.iterdir(), key=sort_key)
    for entry in entries:
        yield entry
        if not recursive or entry.is_symlink() or not entry.is_dir():
            continue
        try:
            yield from _iter_directory(root, entry, recursive=True)
        except OSError:
            # One unreadable sub-directory (for example a permission error)
            # must not abort the rest of the listing.
            continue
|
||||
|
||||
|
||||
def _looks_binary(path: Path) -> bool:
    """Guess whether ``path`` is binary by looking for a NUL byte.

    Only the first 4096 bytes are inspected. A file that cannot be opened or
    read is reported as binary.
    """
    try:
        with path.open("rb") as stream:
            head = stream.read(4096)
    except OSError:
        return True
    else:
        return head.find(b"\0") != -1
|
||||
|
||||
|
||||
def _read_text_file(path: Path) -> str:
    """Return the whole file decoded as UTF-8.

    Raises:
        ValueError: If the file looks binary.
        UnicodeDecodeError: If the content is not valid UTF-8.
    """
    if not _looks_binary(path):
        return path.read_text(encoding="utf-8")
    raise ValueError("binary files cannot be read by read_file/search_files")
|
||||
|
||||
|
||||
def _iter_search_files(root: Path, start: Path) -> Iterable[Path]:
    """Yield the regular files to search, starting from ``start``.

    A file ``start`` is yielded on its own. For a directory the tree is
    walked with an explicit stack: symlinks are ignored, directories named
    in ``SKIP_DIR_NAMES`` are not entered, unreadable directories are
    skipped, and at most ``MAX_SEARCH_FILES`` files are yielded. ``root`` is
    accepted but not used by the walk.
    """
    if start.is_file():
        yield start
        return

    pending = [start]
    yielded = 0
    while pending and yielded < MAX_SEARCH_FILES:
        directory = pending.pop()
        try:
            children = sorted(directory.iterdir(), key=lambda item: item.name.lower())
        except OSError:
            continue

        for child in children:
            if child.is_symlink():
                continue
            if child.is_dir():
                if child.name not in SKIP_DIR_NAMES:
                    pending.append(child)
            elif child.is_file():
                yielded += 1
                yield child
                if yielded >= MAX_SEARCH_FILES:
                    break
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ListDirectoryTool:
    """List files and directories inside the current workspace.

    NOTE(review): ``workspace`` is trusted as the sandbox root. The visible
    ObjectBackedTool code injects the ToolContext workspace only when the
    call arguments do not already carry a ``workspace`` key — confirm that
    model-supplied arguments cannot set it.
    """

    name: str = "list_directory"
    description: str = (
        "List files and directories inside the current workspace. "
        "Use this before reading files when you need to inspect project structure. "
        "This tool never follows paths outside the workspace."
    )
    toolset: str = "filesystem"
    always_available: bool = True
    # Fallback workspace root, used when ``execute`` is called without one.
    workspace: str | None = None
    parameters: dict[str, Any] = field(default_factory=lambda: dict(LIST_DIRECTORY_PARAMETERS))

    async def execute(
        self,
        *,
        path: str = ".",
        recursive: bool = False,
        max_entries: int = 200,
        workspace: str | None = None,
    ) -> str:
        """List ``path`` and return the result as a JSON string.

        Args:
            path: Directory to list, relative to the workspace.
            recursive: Whether to include nested entries.
            max_entries: Requested entry cap, clamped to
                ``1..MAX_LIST_ENTRIES``.
            workspace: Workspace root for this call; falls back to the
                instance's ``workspace`` when empty.

        Returns:
            JSON with ``success``, ``path``, ``recursive``, ``entries`` and
            ``truncated`` on success, or ``success=False`` with ``error``.
        """
        try:
            root, resolved = _resolve_existing_path(workspace or self.workspace, path)
            if not resolved.is_dir():
                return _json_result(False, error="not_a_directory", path=path)

            limit = _clamp_int(max_entries, default=200, minimum=1, maximum=MAX_LIST_ENTRIES)
            entries: list[dict[str, Any]] = []
            truncated = False
            for entry in _iter_directory(root, resolved, recursive=bool(recursive)):
                if len(entries) >= limit:
                    # Flag truncation only when an entry really was left out,
                    # not when the listing holds exactly ``limit`` entries.
                    truncated = True
                    break
                entries.append(_entry_payload(root, entry))

            return _json_result(
                True,
                path=_relative_path(root, resolved),
                recursive=bool(recursive),
                entries=entries,
                truncated=truncated,
            )
        except (OSError, WorkspacePathError, ValueError) as exc:
            return _json_result(False, error=str(exc), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ReadFileTool:
    """Read a UTF-8 text file inside the current workspace.

    NOTE(review): ``workspace`` is trusted as the sandbox root. The visible
    ObjectBackedTool code injects the ToolContext workspace only when the
    call arguments do not already carry a ``workspace`` key — confirm that
    model-supplied arguments cannot set it.
    """

    name: str = "read_file"
    description: str = (
        "Read a UTF-8 text file inside the current workspace with line limits. "
        "Use this to inspect source code, docs, config, or logs. "
        "This tool rejects binary files and paths outside the workspace."
    )
    toolset: str = "filesystem"
    always_available: bool = True
    # Fallback workspace root, used when ``execute`` is called without one.
    workspace: str | None = None
    parameters: dict[str, Any] = field(default_factory=lambda: dict(READ_FILE_PARAMETERS))

    async def execute(
        self,
        *,
        path: str,
        start_line: int = 1,
        max_lines: int = 200,
        workspace: str | None = None,
    ) -> str:
        """Read a window of lines from ``path`` and return it as JSON.

        Args:
            path: File to read, relative to the workspace.
            start_line: 1-based first line, clamped to at least 1.
            max_lines: Requested line cap, clamped to ``1..MAX_READ_LINES``.
            workspace: Workspace root for this call; falls back to the
                instance's ``workspace`` when empty.

        Returns:
            JSON with ``success``, ``path``, ``start_line``, ``end_line``,
            ``total_lines``, ``truncated`` and ``content`` on success, or
            ``success=False`` with ``error``. ``end_line`` is the last line
            present in ``content``, so a follow-up read can start at
            ``end_line + 1``.
        """
        try:
            root, resolved = _resolve_existing_path(workspace or self.workspace, path)
            if not resolved.is_file():
                return _json_result(False, error="not_a_file", path=path)

            start = _clamp_int(start_line, default=1, minimum=1, maximum=10_000_000)
            limit = _clamp_int(max_lines, default=200, minimum=1, maximum=MAX_READ_LINES)
            lines = _read_text_file(resolved).splitlines()
            window = lines[start - 1 : start - 1 + limit]

            # Apply the character budget on whole-line boundaries so that
            # ``end_line`` matches what is actually returned.
            kept: list[str] = []
            used_chars = 0
            char_truncated = False
            for line in window:
                cost = len(line) + (1 if kept else 0)  # +1 for the joining "\n"
                if used_chars + cost > MAX_READ_CHARS:
                    char_truncated = True
                    break
                kept.append(line)
                used_chars += cost
            if char_truncated and not kept:
                # A single over-long first line: return its prefix rather
                # than nothing.
                kept = [window[0][:MAX_READ_CHARS]]

            end_line = start + len(kept) - 1 if kept else start - 1
            return _json_result(
                True,
                path=_relative_path(root, resolved),
                start_line=start,
                end_line=end_line,
                total_lines=len(lines),
                truncated=end_line < len(lines) or char_truncated,
                content="\n".join(kept),
            )
        except UnicodeDecodeError:
            return _json_result(False, error="file is not valid UTF-8 text", path=path)
        except (OSError, WorkspacePathError, ValueError) as exc:
            return _json_result(False, error=str(exc), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class SearchFilesTool:
    """Search filenames and UTF-8 text file contents inside the workspace.

    NOTE(review): ``workspace`` is trusted as the sandbox root. The visible
    ObjectBackedTool code injects the ToolContext workspace only when the
    call arguments do not already carry a ``workspace`` key — confirm that
    model-supplied arguments cannot set it.
    """

    name: str = "search_files"
    description: str = (
        "Search file paths and UTF-8 text file contents inside the current workspace. "
        "Use this to find relevant source files, docs, config keys, or log lines. "
        "This tool skips large/binary files and never searches outside the workspace."
    )
    toolset: str = "filesystem"
    always_available: bool = True
    # Fallback workspace root, used when ``execute`` is called without one.
    workspace: str | None = None
    parameters: dict[str, Any] = field(default_factory=lambda: dict(SEARCH_FILES_PARAMETERS))

    async def execute(
        self,
        *,
        query: str,
        path: str = ".",
        max_results: int = 50,
        case_sensitive: bool = False,
        workspace: str | None = None,
    ) -> str:
        """Search paths and file contents for ``query`` and return JSON.

        Args:
            query: Plain-text needle; must be a non-empty string.
            path: File or directory to search, relative to the workspace.
            max_results: Requested match cap, clamped to
                ``1..MAX_SEARCH_RESULTS``.
            case_sensitive: Whether matching is case-sensitive; otherwise
                both sides are lower-cased.
            workspace: Workspace root for this call; falls back to the
                instance's ``workspace`` when empty.

        Returns:
            JSON with ``success``, ``query``, ``path``, ``results``,
            ``truncated`` (True once the match cap is reached),
            ``searched_files`` and ``skipped_files`` on success, or
            ``success=False`` with ``error``.
        """
        try:
            if not isinstance(query, str) or not query.strip():
                return _json_result(False, error="query must be a non-empty string")
            root, resolved = _resolve_existing_path(workspace or self.workspace, path)
            if not resolved.is_dir() and not resolved.is_file():
                return _json_result(False, error="path must be a file or directory", path=path)

            limit = _clamp_int(max_results, default=50, minimum=1, maximum=MAX_SEARCH_RESULTS)
            needle = query if case_sensitive else query.lower()
            results: list[dict[str, Any]] = []
            searched_files = 0
            skipped_files = 0

            for file_path in _iter_search_files(root, resolved):
                relative = _relative_path(root, file_path)
                haystack_path = relative if case_sensitive else relative.lower()
                if needle in haystack_path:
                    results.append(
                        {
                            "path": relative,
                            "line": None,
                            "match_type": "path",
                            "preview": relative,
                        }
                    )
                    if len(results) >= limit:
                        break

                try:
                    # Oversized and binary files are counted as skipped.
                    if file_path.stat().st_size > MAX_SEARCH_FILE_BYTES or _looks_binary(file_path):
                        skipped_files += 1
                        continue
                    text = file_path.read_text(encoding="utf-8")
                except (OSError, UnicodeDecodeError):
                    skipped_files += 1
                    continue

                searched_files += 1
                for index, line in enumerate(text.splitlines(), start=1):
                    haystack_line = line if case_sensitive else line.lower()
                    if needle not in haystack_line:
                        continue
                    results.append(
                        {
                            "path": relative,
                            "line": index,
                            "match_type": "content",
                            "preview": line[:500],
                        }
                    )
                    if len(results) >= limit:
                        break
                if len(results) >= limit:
                    break

            return _json_result(
                True,
                query=query,
                path=_relative_path(root, resolved),
                results=results,
                truncated=len(results) >= limit,
                searched_files=searched_files,
                skipped_files=skipped_files,
            )
        except (OSError, WorkspacePathError, ValueError) as exc:
            return _json_result(False, error=str(exc), path=path)
|
||||
@ -123,6 +123,8 @@ class MemoryTool:
|
||||
store: MemoryStore
|
||||
name: str = "memory"
|
||||
description: str = MEMORY_TOOL_DESCRIPTION
|
||||
toolset: str = "memory"
|
||||
always_available: bool = True
|
||||
parameters: dict[str, Any] = field(default_factory=lambda: dict(MEMORY_TOOL_PARAMETERS))
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
|
||||
@ -406,6 +406,8 @@ class SessionSearchTool:
|
||||
summarizer: SessionSummarizer | None = None
|
||||
name: str = "session_search"
|
||||
description: str = SESSION_SEARCH_TOOL_DESCRIPTION
|
||||
toolset: str = "session"
|
||||
always_available: bool = True
|
||||
parameters: dict[str, Any] = field(default_factory=lambda: dict(SESSION_SEARCH_TOOL_PARAMETERS))
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
|
||||
@ -76,6 +76,8 @@ class SkillViewTool:
|
||||
loader: SkillsLoader
|
||||
name: str = "skill_view"
|
||||
description: str = SKILL_VIEW_TOOL_DESCRIPTION
|
||||
toolset: str = "skills"
|
||||
always_available: bool = True
|
||||
parameters: dict[str, Any] = field(default_factory=lambda: dict(SKILL_VIEW_TOOL_PARAMETERS))
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Iterable
|
||||
|
||||
from beaver.tools.base import BaseTool, ToolSpec
|
||||
@ -49,7 +50,30 @@ class ToolRegistry:
|
||||
def list_specs(self) -> list[ToolSpec]:
|
||||
return [tool.spec for tool in self._tools.values()]
|
||||
|
||||
def list_always_specs(self) -> list[ToolSpec]:
    """List the baseline tools that should be exposed to the model on every run."""
    always: list[ToolSpec] = []
    for spec in self.list_specs():
        if spec.always_available:
            always.append(spec)
    return always
|
||||
|
||||
def get_specs(self, names: Sequence[str]) -> list[ToolSpec]:
    """Return specs of registered tools in the order of ``names``.

    Unknown names are ignored and repeated names contribute only once.
    """
    found: list[ToolSpec] = []
    emitted: set[str] = set()
    for name in names:
        tool = self.get(name)
        if tool is not None and name not in emitted:
            found.append(tool.spec)
            emitted.add(name)
    return found
|
||||
|
||||
def export_provider_schemas(self) -> list[dict]:
    """Export the provider function-tool schemas of every registered tool."""
    schemas: list[dict] = []
    for spec in self.list_specs():
        schemas.append(spec.to_provider_schema())
    return schemas
|
||||
|
||||
def export_selected_provider_schemas(self, specs: Sequence[ToolSpec]) -> list[dict]:
    """Export provider schemas for an already-selected group of tool specs."""
    schemas: list[dict] = []
    for spec in specs:
        schemas.append(spec.to_provider_schema())
    return schemas
|
||||
|
||||
@ -12,12 +12,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from beaver.engine.providers.base import ToolCallRequest
|
||||
from beaver.tools.base import ToolContext, ToolResult
|
||||
from beaver.tools.registry.tool_registry import ToolRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from beaver.engine.providers.base import ToolCallRequest
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""统一执行单个 tool call。"""
|
||||
@ -80,16 +82,17 @@ class ToolExecutor:
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if isinstance(tool_call, ToolCallRequest):
|
||||
return tool_call.name, dict(tool_call.arguments)
|
||||
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
arguments = function.get("arguments", {})
|
||||
if not isinstance(tool_call, dict):
|
||||
name = getattr(tool_call, "name", None)
|
||||
arguments = getattr(tool_call, "arguments", {})
|
||||
else:
|
||||
name = tool_call.get("name")
|
||||
arguments = tool_call.get("arguments", {})
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
arguments = function.get("arguments", {})
|
||||
else:
|
||||
name = tool_call.get("name")
|
||||
arguments = tool_call.get("arguments", {})
|
||||
|
||||
if not name:
|
||||
raise ValueError("Tool call is missing a tool name")
|
||||
@ -104,8 +107,8 @@ class ToolExecutor:
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
|
||||
if isinstance(tool_call, ToolCallRequest):
|
||||
return str(tool_call.name or "unknown")
|
||||
if not isinstance(tool_call, dict):
|
||||
return str(getattr(tool_call, "name", None) or "unknown")
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and function.get("name"):
|
||||
return str(function["name"])
|
||||
|
||||
Reference in New Issue
Block a user