Files
beaver_project/app-instance/backend/beaver/skills/assembler/embedding_retriever.py

189 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Embedding-based skill candidate retrieval.
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
阿里云百炼 `text-embedding-v4` 做最小语义召回:
1. 复用当前 provider 的 `api_key/api_base`
2. 先用 embedding 相似度召回一小批候选
3. 再交给上层 LLM selector 做最终技能选择
"""
from __future__ import annotations

import asyncio
import json
import logging
import math
import os
from typing import Any
from urllib import request
class SkillEmbeddingRetriever:
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
def __init__(
self,
*,
api_key_env: str = "OPENAI_API_KEY",
api_base_env: str = "OPENAI_API_BASE",
model: str = "text-embedding-v4",
timeout_seconds: float = 20.0,
) -> None:
self.api_key_env = api_key_env
self.api_base_env = api_base_env
self.model = model
self.timeout_seconds = timeout_seconds
async def retrieve(
self,
*,
query: str,
candidates: list[dict[str, str]],
top_k: int = 12,
api_key: str | None = None,
api_base: str | None = None,
model: str | None = None,
) -> list[dict[str, str]]:
"""按 embedding 相似度召回 top-k 候选。
如果没有可用的 API Key / base URL或者 embedding 调用失败,
当前阶段先退回到“全部候选交给 LLM selector”。
"""
if not candidates:
return []
resolved_api_key = api_key or os.getenv(self.api_key_env)
resolved_api_base = api_base or os.getenv(self.api_base_env)
if not resolved_api_key or not resolved_api_base:
return candidates
try:
query_embedding = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=[query],
model=model or self.model,
)
candidate_texts = [self._candidate_text(item) for item in candidates]
candidate_embeddings = await self._embed_texts(
api_key=resolved_api_key,
api_base=resolved_api_base,
texts=candidate_texts,
model=model or self.model,
)
except Exception:
return candidates
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
return candidates
query_vector = query_embedding[0]
scored: list[tuple[float, dict[str, str]]] = []
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
if not vector:
continue
scored.append((self._cosine_similarity(query_vector, vector), candidate))
scored.sort(key=lambda item: item[0], reverse=True)
return [item[1] for item in scored[:top_k]]
async def _embed_texts(
self,
*,
api_key: str,
api_base: str,
texts: list[str],
model: str,
) -> list[list[float]]:
"""调用 OpenAI-compatible embeddings 接口。
当前对齐的是你们实际在用的网关配置:
- `POST {api_base}/embeddings`
- `model=text-embedding-v4`
- `encoding_format=float`
"""
all_vectors: list[list[float]] = []
endpoint = self._normalize_embeddings_endpoint(api_base)
for start in range(0, len(texts), 10):
batch = texts[start:start + 10]
payload = await self._post_embeddings(
endpoint=endpoint,
api_key=api_key,
model=model,
texts=batch,
)
embeddings = payload.get("data") or []
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
return all_vectors
async def _post_embeddings(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
return await asyncio.to_thread(
self._post_embeddings_sync,
endpoint=endpoint,
api_key=api_key,
model=model,
texts=texts,
)
def _post_embeddings_sync(
self,
*,
endpoint: str,
api_key: str,
model: str,
texts: list[str],
) -> dict[str, Any]:
body = json.dumps(
{
"model": model,
"input": texts if len(texts) > 1 else texts[0],
"encoding_format": "float",
}
).encode("utf-8")
req = request.Request(
endpoint,
data=body,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
method="POST",
)
with request.urlopen(req, timeout=self.timeout_seconds) as response:
return json.loads(response.read().decode("utf-8"))
@staticmethod
def _candidate_text(candidate: dict[str, str]) -> str:
name = (candidate.get("name") or "").strip()
description = (candidate.get("description") or "").strip()
return f"{name}\n{description}".strip()
@staticmethod
def _normalize_embeddings_endpoint(api_base: str) -> str:
base = api_base.rstrip("/")
if base.endswith("/embeddings"):
return base
if base.endswith("/v1"):
return f"{base}/embeddings"
return f"{base}/v1/embeddings"
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
if not left or not right or len(left) != len(right):
return -1.0
dot = sum(a * b for a, b in zip(left, right, strict=False))
left_norm = math.sqrt(sum(a * a for a in left))
right_norm = math.sqrt(sum(b * b for b in right))
if left_norm == 0 or right_norm == 0:
return -1.0
return dot / (left_norm * right_norm)