- 集成MCP连接管理器,支持MCP服务器连接 - 添加多种内置工具:ClarifyTool、CronTool、DelegateTool、ExecuteCodeTool、 PatchFileTool、ProcessTool、SendMessageTool、SpawnTool、TerminalTool、 TodoTool、WebFetchTool、WebSearchTool、WriteFileTool等 - 实现工具注册和装配功能 - 添加技能选择上下文参数 - 支持思考模式控制参数thinking_enabled feat(coordinator): 重构任务执行计划器参数命名 - 将learning_candidate_enabled重命名为allow_candidate_generation - 更新TeamGraphScheduler中的参数传递 - 修改LocalAgentRunner中的相关参数处理 - 更新README文档中的相应描述 refactor(context): 标准化工具调用参数格式 - 添加_json导入用于参数序列化 - 实现_provider_tool_calls方法标准化OpenAI兼容的工具调用载荷 - 修复工具调用中参数非字符串类型的序列化问题 refactor(session): 优化消息历史记录过滤逻辑 - 修改get_messages_as_conversation为基于运行状态过滤消息 - 排除未完成、失败或错误结束的运行记录 - 改进对话历史的可见性控制机制 fix(store): 修复FTS索引重建逻辑 - 添加异常处理防止FTS索引创建失败 - 实现_rebuild_fts_index方法重新构建全文搜索索引 - 优化索引触发器和表的维护流程
206 lines
6.8 KiB
Python
206 lines
6.8 KiB
Python
"""Shared embedding-based semantic retrieval utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import math
|
|
import os
|
|
from typing import Any
|
|
from urllib import request
|
|
|
|
|
|
class EmbeddingRetriever:
|
|
"""Use an OpenAI-compatible embeddings API to rank lightweight candidates."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key_env: str = "OPENAI_API_KEY",
|
|
api_base_env: str = "OPENAI_API_BASE",
|
|
model: str = "text-embedding-v4",
|
|
timeout_seconds: float = 3.0,
|
|
) -> None:
|
|
self.api_key_env = api_key_env
|
|
self.api_base_env = api_base_env
|
|
self.model = model
|
|
self.timeout_seconds = timeout_seconds
|
|
|
|
async def retrieve(
|
|
self,
|
|
*,
|
|
query: str,
|
|
candidates: list[dict[str, str]],
|
|
top_k: int,
|
|
api_key: str | None = None,
|
|
api_base: str | None = None,
|
|
model: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
timeout_seconds: float | None = None,
|
|
fallback_top_k: int | None = None,
|
|
) -> list[dict[str, str]]:
|
|
"""Return candidates ordered by embedding similarity.
|
|
|
|
If embedding config is missing or the request fails, return the original
|
|
candidate order. This keeps retrieval non-blocking for the main run.
|
|
"""
|
|
|
|
if not candidates or top_k <= 0:
|
|
return []
|
|
|
|
fallback = self._fallback_candidates(candidates, fallback_top_k=fallback_top_k)
|
|
resolved_api_key = api_key or os.getenv(self.api_key_env)
|
|
resolved_api_base = api_base or os.getenv(self.api_base_env)
|
|
if not resolved_api_key or not resolved_api_base:
|
|
return fallback
|
|
|
|
try:
|
|
query_embedding = await self._embed_texts(
|
|
api_key=resolved_api_key,
|
|
api_base=resolved_api_base,
|
|
texts=[query],
|
|
model=model or self.model,
|
|
extra_headers=extra_headers,
|
|
timeout_seconds=timeout_seconds,
|
|
)
|
|
candidate_embeddings = await self._embed_texts(
|
|
api_key=resolved_api_key,
|
|
api_base=resolved_api_base,
|
|
texts=[self._candidate_text(item) for item in candidates],
|
|
model=model or self.model,
|
|
extra_headers=extra_headers,
|
|
timeout_seconds=timeout_seconds,
|
|
)
|
|
except Exception:
|
|
return fallback
|
|
|
|
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
|
|
return fallback
|
|
|
|
query_vector = query_embedding[0]
|
|
scored: list[tuple[float, dict[str, str]]] = []
|
|
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
|
|
if vector:
|
|
scored.append((self._cosine_similarity(query_vector, vector), candidate))
|
|
|
|
scored.sort(key=lambda item: item[0], reverse=True)
|
|
return [item[1] for item in scored[:top_k]]
|
|
|
|
async def _embed_texts(
|
|
self,
|
|
*,
|
|
api_key: str,
|
|
api_base: str,
|
|
texts: list[str],
|
|
model: str,
|
|
extra_headers: dict[str, str] | None = None,
|
|
timeout_seconds: float | None = None,
|
|
) -> list[list[float]]:
|
|
all_vectors: list[list[float]] = []
|
|
endpoint = self._normalize_embeddings_endpoint(api_base)
|
|
for start in range(0, len(texts), 10):
|
|
batch = texts[start:start + 10]
|
|
payload = await self._post_embeddings(
|
|
endpoint=endpoint,
|
|
api_key=api_key,
|
|
model=model,
|
|
texts=batch,
|
|
extra_headers=extra_headers,
|
|
timeout_seconds=timeout_seconds,
|
|
)
|
|
embeddings = payload.get("data") or []
|
|
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
|
|
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
|
|
return all_vectors
|
|
|
|
async def _post_embeddings(
|
|
self,
|
|
*,
|
|
endpoint: str,
|
|
api_key: str,
|
|
model: str,
|
|
texts: list[str],
|
|
extra_headers: dict[str, str] | None = None,
|
|
timeout_seconds: float | None = None,
|
|
) -> dict[str, Any]:
|
|
return await asyncio.to_thread(
|
|
self._post_embeddings_sync,
|
|
endpoint=endpoint,
|
|
api_key=api_key,
|
|
model=model,
|
|
texts=texts,
|
|
extra_headers=extra_headers,
|
|
timeout_seconds=timeout_seconds,
|
|
)
|
|
|
|
def _post_embeddings_sync(
|
|
self,
|
|
*,
|
|
endpoint: str,
|
|
api_key: str,
|
|
model: str,
|
|
texts: list[str],
|
|
extra_headers: dict[str, str] | None = None,
|
|
timeout_seconds: float | None = None,
|
|
) -> dict[str, Any]:
|
|
body = json.dumps(
|
|
{
|
|
"model": model,
|
|
"input": texts if len(texts) > 1 else texts[0],
|
|
"encoding_format": "float",
|
|
}
|
|
).encode("utf-8")
|
|
req = request.Request(
|
|
endpoint,
|
|
data=body,
|
|
headers={
|
|
"Authorization": f"Bearer {api_key}",
|
|
"Content-Type": "application/json",
|
|
**(extra_headers or {}),
|
|
},
|
|
method="POST",
|
|
)
|
|
with request.urlopen(req, timeout=timeout_seconds or self.timeout_seconds) as response:
|
|
return json.loads(response.read().decode("utf-8"))
|
|
|
|
@staticmethod
|
|
def _fallback_candidates(
|
|
candidates: list[dict[str, str]],
|
|
*,
|
|
fallback_top_k: int | None,
|
|
) -> list[dict[str, str]]:
|
|
if fallback_top_k is None:
|
|
return list(candidates)
|
|
if fallback_top_k <= 0:
|
|
return []
|
|
return candidates[:fallback_top_k]
|
|
|
|
@staticmethod
|
|
def _candidate_text(candidate: dict[str, str]) -> str:
|
|
parts = [
|
|
(candidate.get("name") or "").strip(),
|
|
(candidate.get("description") or "").strip(),
|
|
(candidate.get("input_schema") or "").strip(),
|
|
]
|
|
return "\n".join(part for part in parts if part)
|
|
|
|
@staticmethod
|
|
def _normalize_embeddings_endpoint(api_base: str) -> str:
|
|
base = api_base.rstrip("/")
|
|
if base.endswith("/embeddings"):
|
|
return base
|
|
if base.endswith("/v1"):
|
|
return f"{base}/embeddings"
|
|
return f"{base}/v1/embeddings"
|
|
|
|
@staticmethod
|
|
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
|
if not left or not right or len(left) != len(right):
|
|
return -1.0
|
|
dot = sum(a * b for a, b in zip(left, right, strict=False))
|
|
left_norm = math.sqrt(sum(a * a for a in left))
|
|
right_norm = math.sqrt(sum(b * b for b in right))
|
|
if left_norm == 0 or right_norm == 0:
|
|
return -1.0
|
|
return dot / (left_norm * right_norm)
|