189 lines
6.0 KiB
Python
189 lines
6.0 KiB
Python
"""Embedding-based skill candidate retrieval.
|
||
|
||
当前实现使用 OpenAI-compatible `/v1/embeddings` 接口调用
|
||
阿里云百炼 `text-embedding-v4` 做最小语义召回:
|
||
1. 复用当前 provider 的 `api_key/api_base`
|
||
2. 先用 embedding 相似度召回一小批候选
|
||
3. 再交给上层 LLM selector 做最终技能选择
|
||
"""
|
||
|
||
from __future__ import annotations

import asyncio
import json
import logging
import math
import os
from typing import Any
from urllib import request
|
||
|
||
|
||
class SkillEmbeddingRetriever:
|
||
"""用 OpenAI-compatible embeddings API 为 skill 选择做候选召回。"""
|
||
|
||
def __init__(
|
||
self,
|
||
*,
|
||
api_key_env: str = "OPENAI_API_KEY",
|
||
api_base_env: str = "OPENAI_API_BASE",
|
||
model: str = "text-embedding-v4",
|
||
timeout_seconds: float = 20.0,
|
||
) -> None:
|
||
self.api_key_env = api_key_env
|
||
self.api_base_env = api_base_env
|
||
self.model = model
|
||
self.timeout_seconds = timeout_seconds
|
||
|
||
async def retrieve(
|
||
self,
|
||
*,
|
||
query: str,
|
||
candidates: list[dict[str, str]],
|
||
top_k: int = 12,
|
||
api_key: str | None = None,
|
||
api_base: str | None = None,
|
||
model: str | None = None,
|
||
) -> list[dict[str, str]]:
|
||
"""按 embedding 相似度召回 top-k 候选。
|
||
|
||
如果没有可用的 API Key / base URL,或者 embedding 调用失败,
|
||
当前阶段先退回到“全部候选交给 LLM selector”。
|
||
"""
|
||
|
||
if not candidates:
|
||
return []
|
||
|
||
resolved_api_key = api_key or os.getenv(self.api_key_env)
|
||
resolved_api_base = api_base or os.getenv(self.api_base_env)
|
||
if not resolved_api_key or not resolved_api_base:
|
||
return candidates
|
||
|
||
try:
|
||
query_embedding = await self._embed_texts(
|
||
api_key=resolved_api_key,
|
||
api_base=resolved_api_base,
|
||
texts=[query],
|
||
model=model or self.model,
|
||
)
|
||
candidate_texts = [self._candidate_text(item) for item in candidates]
|
||
candidate_embeddings = await self._embed_texts(
|
||
api_key=resolved_api_key,
|
||
api_base=resolved_api_base,
|
||
texts=candidate_texts,
|
||
model=model or self.model,
|
||
)
|
||
except Exception:
|
||
return candidates
|
||
|
||
if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates):
|
||
return candidates
|
||
|
||
query_vector = query_embedding[0]
|
||
scored: list[tuple[float, dict[str, str]]] = []
|
||
for candidate, vector in zip(candidates, candidate_embeddings, strict=False):
|
||
if not vector:
|
||
continue
|
||
scored.append((self._cosine_similarity(query_vector, vector), candidate))
|
||
|
||
scored.sort(key=lambda item: item[0], reverse=True)
|
||
return [item[1] for item in scored[:top_k]]
|
||
|
||
async def _embed_texts(
|
||
self,
|
||
*,
|
||
api_key: str,
|
||
api_base: str,
|
||
texts: list[str],
|
||
model: str,
|
||
) -> list[list[float]]:
|
||
"""调用 OpenAI-compatible embeddings 接口。
|
||
|
||
当前对齐的是你们实际在用的网关配置:
|
||
- `POST {api_base}/embeddings`
|
||
- `model=text-embedding-v4`
|
||
- `encoding_format=float`
|
||
"""
|
||
|
||
all_vectors: list[list[float]] = []
|
||
endpoint = self._normalize_embeddings_endpoint(api_base)
|
||
for start in range(0, len(texts), 10):
|
||
batch = texts[start:start + 10]
|
||
payload = await self._post_embeddings(
|
||
endpoint=endpoint,
|
||
api_key=api_key,
|
||
model=model,
|
||
texts=batch,
|
||
)
|
||
embeddings = payload.get("data") or []
|
||
embeddings = sorted(embeddings, key=lambda item: item.get("index", 0))
|
||
all_vectors.extend([list(item.get("embedding") or []) for item in embeddings])
|
||
return all_vectors
|
||
|
||
async def _post_embeddings(
|
||
self,
|
||
*,
|
||
endpoint: str,
|
||
api_key: str,
|
||
model: str,
|
||
texts: list[str],
|
||
) -> dict[str, Any]:
|
||
return await asyncio.to_thread(
|
||
self._post_embeddings_sync,
|
||
endpoint=endpoint,
|
||
api_key=api_key,
|
||
model=model,
|
||
texts=texts,
|
||
)
|
||
|
||
def _post_embeddings_sync(
|
||
self,
|
||
*,
|
||
endpoint: str,
|
||
api_key: str,
|
||
model: str,
|
||
texts: list[str],
|
||
) -> dict[str, Any]:
|
||
body = json.dumps(
|
||
{
|
||
"model": model,
|
||
"input": texts if len(texts) > 1 else texts[0],
|
||
"encoding_format": "float",
|
||
}
|
||
).encode("utf-8")
|
||
req = request.Request(
|
||
endpoint,
|
||
data=body,
|
||
headers={
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json",
|
||
},
|
||
method="POST",
|
||
)
|
||
with request.urlopen(req, timeout=self.timeout_seconds) as response:
|
||
return json.loads(response.read().decode("utf-8"))
|
||
|
||
@staticmethod
|
||
def _candidate_text(candidate: dict[str, str]) -> str:
|
||
name = (candidate.get("name") or "").strip()
|
||
description = (candidate.get("description") or "").strip()
|
||
return f"{name}\n{description}".strip()
|
||
|
||
@staticmethod
|
||
def _normalize_embeddings_endpoint(api_base: str) -> str:
|
||
base = api_base.rstrip("/")
|
||
if base.endswith("/embeddings"):
|
||
return base
|
||
if base.endswith("/v1"):
|
||
return f"{base}/embeddings"
|
||
return f"{base}/v1/embeddings"
|
||
|
||
@staticmethod
|
||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||
if not left or not right or len(left) != len(right):
|
||
return -1.0
|
||
dot = sum(a * b for a, b in zip(left, right, strict=False))
|
||
left_norm = math.sqrt(sum(a * a for a in left))
|
||
right_norm = math.sqrt(sum(b * b for b in right))
|
||
if left_norm == 0 or right_norm == 0:
|
||
return -1.0
|
||
return dot / (left_norm * right_norm)
|