"""Shared embedding-based semantic retrieval utilities.""" from __future__ import annotations import asyncio import json import math import os from typing import Any from urllib import request class EmbeddingRetriever: """Use an OpenAI-compatible embeddings API to rank lightweight candidates.""" def __init__( self, *, api_key_env: str = "OPENAI_API_KEY", api_base_env: str = "OPENAI_API_BASE", model: str = "text-embedding-v4", timeout_seconds: float = 20.0, ) -> None: self.api_key_env = api_key_env self.api_base_env = api_base_env self.model = model self.timeout_seconds = timeout_seconds async def retrieve( self, *, query: str, candidates: list[dict[str, str]], top_k: int, api_key: str | None = None, api_base: str | None = None, model: str | None = None, extra_headers: dict[str, str] | None = None, timeout_seconds: float | None = None, fallback_top_k: int | None = None, ) -> list[dict[str, str]]: """Return candidates ordered by embedding similarity. If embedding config is missing or the request fails, return the original candidate order. This keeps retrieval non-blocking for the main run. """ if not candidates or top_k <= 0: return [] fallback = self._fallback_candidates(candidates, fallback_top_k=fallback_top_k) resolved_api_key = api_key or os.getenv(self.api_key_env) resolved_api_base = api_base or os.getenv(self.api_base_env) if not resolved_api_key or not resolved_api_base: return fallback try: query_embedding = await self._embed_texts( api_key=resolved_api_key, api_base=resolved_api_base, texts=[query], model=model or self.model, extra_headers=extra_headers, timeout_seconds=timeout_seconds, ) candidate_embeddings = await self._embed_texts( api_key=resolved_api_key, api_base=resolved_api_base, texts=[self._candidate_text(item) for item in candidates], model=model or self.model, extra_headers=extra_headers, timeout_seconds=timeout_seconds, ) except Exception: return fallback if not query_embedding or not query_embedding[0] or len(candidate_embeddings) != len(candidates): return fallback query_vector = query_embedding[0] scored: list[tuple[float, dict[str, str]]] = [] for candidate, vector in zip(candidates, candidate_embeddings, strict=False): if vector: scored.append((self._cosine_similarity(query_vector, vector), candidate)) scored.sort(key=lambda item: item[0], reverse=True) return [item[1] for item in scored[:top_k]] async def _embed_texts( self, *, api_key: str, api_base: str, texts: list[str], model: str, extra_headers: dict[str, str] | None = None, timeout_seconds: float | None = None, ) -> list[list[float]]: all_vectors: list[list[float]] = [] endpoint = self._normalize_embeddings_endpoint(api_base) for start in range(0, len(texts), 10): batch = texts[start:start + 10] payload = await self._post_embeddings( endpoint=endpoint, api_key=api_key, model=model, texts=batch, extra_headers=extra_headers, timeout_seconds=timeout_seconds, ) embeddings = payload.get("data") or [] embeddings = sorted(embeddings, key=lambda item: item.get("index", 0)) all_vectors.extend([list(item.get("embedding") or []) for item in embeddings]) return all_vectors async def _post_embeddings( self, *, endpoint: str, api_key: str, model: str, texts: list[str], extra_headers: dict[str, str] | None = None, timeout_seconds: float | None = None, ) -> dict[str, Any]: return await asyncio.to_thread( self._post_embeddings_sync, endpoint=endpoint, api_key=api_key, model=model, texts=texts, extra_headers=extra_headers, timeout_seconds=timeout_seconds, ) def _post_embeddings_sync( self, *, endpoint: str, api_key: str, model: str, texts: list[str], extra_headers: dict[str, str] | None = None, timeout_seconds: float | None = None, ) -> dict[str, Any]: body = json.dumps( { "model": model, "input": texts if len(texts) > 1 else texts[0], "encoding_format": "float", } ).encode("utf-8") req = request.Request( endpoint, data=body, headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", **(extra_headers or {}), }, method="POST", ) with request.urlopen(req, timeout=timeout_seconds or self.timeout_seconds) as response: return json.loads(response.read().decode("utf-8")) @staticmethod def _fallback_candidates( candidates: list[dict[str, str]], *, fallback_top_k: int | None, ) -> list[dict[str, str]]: if fallback_top_k is None: return list(candidates) if fallback_top_k <= 0: return [] return candidates[:fallback_top_k] @staticmethod def _candidate_text(candidate: dict[str, str]) -> str: parts = [ (candidate.get("name") or "").strip(), (candidate.get("description") or "").strip(), (candidate.get("input_schema") or "").strip(), ] return "\n".join(part for part in parts if part) @staticmethod def _normalize_embeddings_endpoint(api_base: str) -> str: base = api_base.rstrip("/") if base.endswith("/embeddings"): return base if base.endswith("/v1"): return f"{base}/embeddings" return f"{base}/v1/embeddings" @staticmethod def _cosine_similarity(left: list[float], right: list[float]) -> float: if not left or not right or len(left) != len(right): return -1.0 dot = sum(a * b for a, b in zip(left, right, strict=False)) left_norm = math.sqrt(sum(a * a for a in left)) right_norm = math.sqrt(sum(b * b for b in right)) if left_norm == 0 or right_norm == 0: return -1.0 return dot / (left_norm * right_norm)