Files
beaver_project/app-instance/backend/beaver/foundation/embedding.py
steven_li 30ab74ffb2 feat(engine): 添加MCP连接管理和工具集成功能
- 集成MCP连接管理器,支持MCP服务器连接
- 添加多种内置工具:ClarifyTool、CronTool、DelegateTool、ExecuteCodeTool、
  PatchFileTool、ProcessTool、SendMessageTool、SpawnTool、TerminalTool、
  TodoTool、WebFetchTool、WebSearchTool、WriteFileTool等
- 实现工具注册和装配功能
- 添加技能选择上下文参数
- 支持思考模式控制参数thinking_enabled

feat(coordinator): 重构任务执行计划器参数命名

- 将learning_candidate_enabled重命名为allow_candidate_generation
- 更新TeamGraphScheduler中的参数传递
- 修改LocalAgentRunner中的相关参数处理
- 更新README文档中的相应描述

refactor(context): 标准化工具调用参数格式

- 添加_json导入用于参数序列化
- 实现_provider_tool_calls方法标准化OpenAI兼容的工具调用载荷
- 修复工具调用中参数非字符串类型的序列化问题

refactor(session): 优化消息历史记录过滤逻辑

- 修改get_messages_as_conversation为基于运行状态过滤消息
- 排除未完成、失败或错误结束的运行记录
- 改进对话历史的可见性控制机制

fix(store): 修复FTS索引重建逻辑

- 添加异常处理防止FTS索引创建失败
- 实现_rebuild_fts_index方法重新构建全文搜索索引
- 优化索引触发器和表的维护流程
2026-05-14 09:43:48 +08:00

206 lines
6.8 KiB
Python

"""Shared embedding-based semantic retrieval utilities."""
from __future__ import annotations

import asyncio
import json
import logging
import math
import os
from typing import Any
from urllib import request
class EmbeddingRetriever:
"""Use an OpenAI-compatible embeddings API to rank lightweight candidates."""
def __init__(
self,
*,
api_key_env: str = "OPENAI_API_KEY",
api_base_env: str = "OPENAI_API_BASE",
model: str = "text-embedding-v4",
timeout_seconds: float = 3.0,
) -> None:
self.api_key_env = api_key_env
self.api_base_env = api_base_env
self.model = model
self.timeout_seconds = timeout_seconds
async def retrieve(
self,
*,
query: str,
candidates: list[dict[str, str]],
top_k: int,
api_key: str | None = None,
api_base: str | None = None,
model: str | None = None,
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
fallback_top_k: int | None = None,
) -> list[dict[str, str]]:
"""Return candidates ordered by embedding similarity.
If embedding config is missing or the request fails, return the original
candidate order. This keeps retrieval non-blocking for the main run.
"""
if not candidates or top_k <= 0:
return []
fallback = self._fallback_candidates(candidates, fallback_top_k=fallback_top_k)
resolved_api_key = api_key or os.getenv(self.api_key_env)
resolved_api_base = api_base or os.getenv(self.api_base_env)
if not resolved_api_key or not resolved_api_base:
return fallback
try:
query_embedding = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[query],
model=model or self.model,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
candidate_embeddings = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[self._candidate_text(item) for item in candidates],
model=model or self.model,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
except Exception:
return fallback
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
return fallback
query_vector = query_embedding[0]
scored: list[tuple[float, dict[str, str]]] = []
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
if vector:
scored.append((self._cosine_similarity(query_vector, vector), candidate))
scored.sort(key=lambda item: item[0], reverse=True)
return [item[1] for item in scored[:top_k]]
async def _embed_texts(
self,
*,
api_key: str,
api_base: str,
texts: list[str],
model: str,
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> list[list[float]]:
all_vectors: list[list[float]] = []
endpoint = self._normalize_embeddings_endpoint(api_base)
for start in range(0, len(texts), 10):
batch = texts[start:start + 10]
payload = await self._post_embeddings(
endpoint=endpoint,
api_key=api_key,
model=model,
texts=batch,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
embeddings = payload.get("data") or []
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
return all_vectors
async def _post_embeddings(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> dict[str, Any]:
return await asyncio.to_thread(
self._post_embeddings_sync,
endpoint=endpoint,
api_key=api_key,
model=model,
texts=texts,
extra_headers=extra_headers,
timeout_seconds=timeout_seconds,
)
def _post_embeddings_sync(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
extra_headers: dict[str, str] | None = None,
timeout_seconds: float | None = None,
) -> dict[str, Any]:
body = json.dumps(
{
"model": model,
"input": texts if len(texts) > 1 else texts[0],
"encoding_format": "float",
}
).encode("utf-8")
req = request.Request(
endpoint,
data=body,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
},
method="POST",
)
with request.urlopen(req, timeout=timeout_seconds or self.timeout_seconds) as response:
return json.loads(response.read().decode("utf-8"))
@staticmethod
def _fallback_candidates(
candidates: list[dict[str, str]],
*,
fallback_top_k: int | None,
) -> list[dict[str, str]]:
if fallback_top_k is None:
return list(candidates)
if fallback_top_k <= 0:
return []
return candidates[:fallback_top_k]
@staticmethod
def _candidate_text(candidate: dict[str, str]) -> str:
parts = [
(candidate.get("name") or "").strip(),
(candidate.get("description") or "").strip(),
(candidate.get("input_schema") or "").strip(),
]
return "\n".join(part for part in parts if part)
@staticmethod
def _normalize_embeddings_endpoint(api_base: str) -> str:
base = api_base.rstrip("/")
if base.endswith("/embeddings"):
return base
if base.endswith("/v1"):
return f"{base}/embeddings"
return f"{base}/v1/embeddings"
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
if not left or not right or len(left) != len(right):
return -1.0
dot = sum(a * b for a, b in zip(left, right, strict=False))
left_norm = math.sqrt(sum(a * a for a in left))
right_norm = math.sqrt(sum(b * b for b in right))
if left_norm == 0 or right_norm == 0:
return -1.0
return dot / (left_norm * right_norm)