"""Embedding-based skill candidate retrieval.

The current implementation calls Alibaba Cloud Bailian's ``text-embedding-v4``
through an OpenAI-compatible ``/v1/embeddings`` endpoint to perform a minimal
semantic recall step:

1. Reuse the current provider's ``api_key`` / ``api_base``.
2. Recall a small set of candidates by embedding similarity.
3. Hand those candidates to the upper-layer LLM selector, which makes the
   final skill choice.
"""

from __future__ import annotations

import asyncio
import json
import logging
import math
import os
from typing import Any
from urllib import request

logger = logging.getLogger(__name__)

# Number of input texts sent per embeddings request. The original code used a
# fixed batch of 10; presumably this matches the provider's per-request input
# limit for text-embedding-v4 (TODO confirm against the gateway docs).
_DEFAULT_BATCH_SIZE = 10


class SkillEmbeddingRetriever:
    """Recall skill candidates via an OpenAI-compatible embeddings API.

    Credentials come from explicit arguments to :meth:`retrieve` or, when those
    are omitted, from the environment variables named by ``api_key_env`` /
    ``api_base_env``.
    """

    def __init__(
        self,
        *,
        api_key_env: str = "OPENAI_API_KEY",
        api_base_env: str = "OPENAI_API_BASE",
        model: str = "text-embedding-v4",
        timeout_seconds: float = 20.0,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ) -> None:
        """Configure the retriever.

        Args:
            api_key_env: Environment variable read when no ``api_key`` is passed.
            api_base_env: Environment variable read when no ``api_base`` is passed.
            model: Default embedding model name.
            timeout_seconds: Timeout for each HTTP request.
            batch_size: Maximum number of texts sent in one embeddings request.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.api_key_env = api_key_env
        self.api_base_env = api_base_env
        self.model = model
        self.timeout_seconds = timeout_seconds
        self.batch_size = batch_size

    async def retrieve(
        self,
        *,
        query: str,
        candidates: list[dict[str, str]],
        top_k: int = 12,
        api_key: str | None = None,
        api_base: str | None = None,
        model: str | None = None,
    ) -> list[dict[str, str]]:
        """Return the ``top_k`` candidates most similar to ``query``.

        Candidates are ranked by cosine similarity between the query embedding
        and the embedding of each candidate's ``name`` + ``description``.

        This is deliberately best-effort. The full ``candidates`` list is
        returned unchanged, so the LLM selector can still choose from
        everything, when:

        - no API key or base URL is available,
        - the embeddings call fails, or
        - the response does not line up with the inputs.

        Args:
            query: Text to match skills against.
            candidates: Skill dicts; their ``name`` and ``description`` are embedded.
            top_k: Maximum number of candidates to return; values <= 0 yield ``[]``.
            api_key: Overrides the key read from ``api_key_env``.
            api_base: Overrides the base URL read from ``api_base_env``.
            model: Overrides the default embedding model.

        Returns:
            Up to ``top_k`` candidates, most similar first. Candidates whose
            embedding came back empty are omitted. Returns ``[]`` for empty
            ``candidates``.
        """
        if not candidates:
            return []

        resolved_api_key = api_key or os.getenv(self.api_key_env)
        resolved_api_base = api_base or os.getenv(self.api_base_env)
        if not resolved_api_key or not resolved_api_base:
            return candidates

        # Embed the query together with the candidates so that small candidate
        # sets need a single HTTP round trip instead of two.
        texts = [query, *(self._candidate_text(item) for item in candidates)]
        try:
            vectors = await self._embed_texts(
                api_key=resolved_api_key,
                api_base=resolved_api_base,
                texts=texts,
                model=model or self.model,
            )
        except Exception:
            # Best-effort fallback, but keep a trace so failures are diagnosable.
            logger.warning(
                "Skill embedding retrieval failed; falling back to all candidates",
                exc_info=True,
            )
            return candidates

        if len(vectors) != len(texts) or not vectors[0]:
            return candidates

        query_vector = vectors[0]
        scored = [
            (self._cosine_similarity(query_vector, vector), candidate)
            for candidate, vector in zip(candidates, vectors[1:])
            if vector
        ]
        # list.sort is stable, so equally-scored candidates keep input order.
        scored.sort(key=lambda item: item[0], reverse=True)
        # Clamp so a negative top_k cannot turn into "all but the last N".
        return [candidate for _, candidate in scored[: max(top_k, 0)]]

    async def _embed_texts(
        self,
        *,
        api_key: str,
        api_base: str,
        texts: list[str],
        model: str,
    ) -> list[list[float]]:
        """Embed ``texts`` via the OpenAI-compatible embeddings endpoint.

        The request shape matches the gateway configuration currently in use:

        - ``POST {api_base}/embeddings``
        - ``model=text-embedding-v4`` (or the override passed in)
        - ``encoding_format=float``

        Texts are sent in batches of ``self.batch_size``. The returned vectors
        are in the same order as ``texts``.

        Raises:
            ValueError: If a response's ``data`` does not hold exactly one item
                per input text of that batch.
        """
        endpoint = self._normalize_embeddings_endpoint(api_base)
        all_vectors: list[list[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            payload = await self._post_embeddings(
                endpoint=endpoint,
                api_key=api_key,
                model=model,
                texts=batch,
            )
            items = payload.get("data") or []
            # Fail fast on a mismatched response rather than letting vectors
            # shift against the inputs they are supposed to describe.
            if len(items) != len(batch):
                raise ValueError(
                    f"embeddings response returned {len(items)} items "
                    f"for {len(batch)} inputs"
                )
            # Restore input order using each item's "index" field.
            items = sorted(items, key=lambda item: item.get("index", 0))
            all_vectors.extend(list(item.get("embedding") or []) for item in items)
        return all_vectors

    async def _post_embeddings(
        self,
        *,
        endpoint: str,
        api_key: str,
        model: str,
        texts: list[str],
    ) -> dict[str, Any]:
        """Run the blocking HTTP request in a worker thread.

        This keeps the event loop free while the request is in flight.
        """
        return await asyncio.to_thread(
            self._post_embeddings_sync,
            endpoint=endpoint,
            api_key=api_key,
            model=model,
            texts=texts,
        )

    def _post_embeddings_sync(
        self,
        *,
        endpoint: str,
        api_key: str,
        model: str,
        texts: list[str],
    ) -> dict[str, Any]:
        """POST one embeddings request and return the decoded JSON body.

        ``texts`` must be non-empty. A single text is sent as a bare string;
        several texts are sent as a list. HTTP and network errors propagate
        from ``urllib.request.urlopen``.
        """
        body = json.dumps(
            {
                "model": model,
                "input": texts if len(texts) > 1 else texts[0],
                "encoding_format": "float",
            }
        ).encode("utf-8")
        req = request.Request(
            endpoint,
            data=body,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as response:
            return json.loads(response.read().decode("utf-8"))

    @staticmethod
    def _candidate_text(candidate: dict[str, str]) -> str:
        """Build the text embedded for a candidate.

        The text is the stripped ``name`` and ``description`` joined by a
        newline.
        """
        name = (candidate.get("name") or "").strip()
        description = (candidate.get("description") or "").strip()
        return f"{name}\n{description}".strip()

    @staticmethod
    def _normalize_embeddings_endpoint(api_base: str) -> str:
        """Turn a base URL into the full ``.../embeddings`` endpoint URL.

        Accepted forms are a bare host, a ``/v1`` base, or a URL that already
        ends in ``/embeddings``.
        """
        base = api_base.rstrip("/")
        if base.endswith("/embeddings"):
            return base
        if base.endswith("/v1"):
            return f"{base}/embeddings"
        return f"{base}/v1/embeddings"

    @staticmethod
    def _cosine_similarity(left: list[float], right: list[float]) -> float:
        """Return the cosine similarity of two vectors.

        Returns -1.0 (the lowest possible score) when either vector is empty,
        the lengths differ, or either vector has zero norm.
        """
        if not left or not right or len(left) != len(right):
            return -1.0
        dot = sum(a * b for a, b in zip(left, right))
        left_norm = math.sqrt(sum(a * a for a in left))
        right_norm = math.sqrt(sum(b * b for b in right))
        if left_norm == 0 or right_norm == 0:
            return -1.0
        return dot / (left_norm * right_norm)