initialize project with FastAPI embedding service and environment configuration
This commit is contained in:
98
main.py
Normal file
98
main.py
Normal file
@ -0,0 +1,98 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Union
|
||||
|
||||
import tiktoken
|
||||
import torch
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 设置文本向量模型
|
||||
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', r'\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3')
|
||||
# 设置服务监听ip
|
||||
EMBEDDING_HOST = os.environ.get('EMBEDDING_HOST', '0.0.0.0')
|
||||
# 设置服务监听端口
|
||||
EMBEDDING_PORT = int(os.environ.get('EMBEDDING_PORT', 8000))
|
||||
# 设置服务进程数量
|
||||
EMBEDDING_WORKERS = int(os.environ.get('EMBEDDING_WORKERS', 5))
|
||||
|
||||
# 模型全局变量
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
embedding_model = SentenceTransformer(EMBEDDING_PATH, device=device)
|
||||
tokenizer = tiktoken.get_encoding('cl100k_base')
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=False,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: Union[str, List[str]]
|
||||
model: str
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
data: list
|
||||
model: str
|
||||
object: str
|
||||
usage: dict
|
||||
|
||||
|
||||
# 提取文本向量接口
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
|
||||
async def get_embeddings(request: EmbeddingRequest):
|
||||
input_texts = request.input if isinstance(request.input, list) else [request.input]
|
||||
|
||||
# 使用批处理一次编码所有文本
|
||||
embeddings = embedding_model.encode(input_texts, convert_to_list=True)
|
||||
|
||||
def count_tokens(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
response = {
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": embedding,
|
||||
"index": index
|
||||
}
|
||||
for index, embedding in enumerate(embeddings)
|
||||
],
|
||||
"model": request.model,
|
||||
"object": "list",
|
||||
"usage": {
|
||||
"prompt_tokens": sum(count_tokens(text) for text in input_texts),
|
||||
"total_tokens": sum(count_tokens(text) for text in input_texts),
|
||||
},
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
# 开启服务
|
||||
def start_server():
|
||||
uvicorn.run("main:app", host=EMBEDDING_HOST, port=EMBEDDING_PORT, workers=EMBEDDING_WORKERS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server()
|
||||
Reference in New Issue
Block a user