"""HTTP service exposing ``POST /v1/embeddings`` on top of a SentenceTransformer model."""
import os
from contextlib import asynccontextmanager
from typing import List, Union

import tiktoken
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

load_dotenv()

# Location of the text-embedding model weights.
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', r'\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3')
# IP address the service listens on.
EMBEDDING_HOST = os.environ.get('EMBEDDING_HOST', '0.0.0.0')
# Port the service listens on.
EMBEDDING_PORT = int(os.environ.get('EMBEDDING_PORT', 8000))
# Number of server worker processes.
EMBEDDING_WORKERS = int(os.environ.get('EMBEDDING_WORKERS', 5))

# Module-level model state, created once when the module is imported.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = SentenceTransformer(EMBEDDING_PATH, device=device)
# Used only for the "usage" figures in responses. This is tiktoken's
# cl100k_base encoding, not a tokenizer loaded from EMBEDDING_PATH, so the
# reported counts are tiktoken counts rather than the embedding model's own.
tokenizer = tiktoken.get_encoding('cl100k_base')


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: on shutdown, release cached CUDA memory."""
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


class EmbeddingRequest(BaseModel):
    """Request body for ``/v1/embeddings``."""

    # A single text or a list of texts to embed.
    input: Union[str, List[str]]
    # Model name supplied by the client; echoed back in the response.
    model: str


class EmbeddingResponse(BaseModel):
    """Response body for ``/v1/embeddings``."""

    # One {"object", "embedding", "index"} entry per input text.
    data: list
    model: str
    object: str
    # Token counts: {"prompt_tokens", "total_tokens"}.
    usage: dict


def _count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*.

    ``disallowed_special=()`` makes tiktoken encode special-token literals
    (for example ``<|endoftext|>``) as ordinary text instead of raising
    ``ValueError``, so arbitrary user input cannot fail the request here.
    """
    return len(tokenizer.encode(text, disallowed_special=()))


# Text-embedding endpoint.
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Embed the requested text(s) and return the vectors with token usage.

    Args:
        request: Parsed request body; ``input`` may be one string or a list
            of strings.

    Returns:
        A dict matching ``EmbeddingResponse``: one entry per input text (in
        input order, with its index), the echoed model name, and token usage.
    """
    input_texts = request.input if isinstance(request.input, list) else [request.input]

    # Encode all texts in a single batched call. The call is synchronous, so
    # it runs to completion inside this coroutine.
    embeddings = embedding_model.encode(input_texts)
    # Convert every vector to a plain list of floats so the response is
    # JSON-serialisable.
    vectors = [embedding.tolist() for embedding in embeddings]

    # Count once and reuse for both usage fields.
    token_total = sum(_count_tokens(text) for text in input_texts)

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": vector,
                "index": index,
            }
            for index, vector in enumerate(vectors)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": token_total,
            "total_tokens": token_total,
        },
    }
    return response


# Start the service.
def start_server():
    """Run the app under uvicorn with the configured host, port and workers."""
    # NOTE(review): the "main:app" import string presumably requires this file
    # to be importable as module ``main`` - confirm against the deployed filename.
    uvicorn.run("main:app", host=EMBEDDING_HOST, port=EMBEDDING_PORT, workers=EMBEDDING_WORKERS)


if __name__ == "__main__":
    start_server()