initialize project with FastAPI embedding service and environment configuration
This commit is contained in:
11
.env.example
Normal file
11
.env.example
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Embedding model path (supports local path or UNC network path)
|
||||||
|
EMBEDDING_PATH=\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3
|
||||||
|
|
||||||
|
# Service bind host
|
||||||
|
EMBEDDING_HOST=0.0.0.0
|
||||||
|
|
||||||
|
# Service bind port
|
||||||
|
EMBEDDING_PORT=8000
|
||||||
|
|
||||||
|
# Uvicorn worker count
|
||||||
|
EMBEDDING_WORKERS=5
|
||||||
71
README.md
71
README.md
@ -0,0 +1,71 @@
|
|||||||
|
# openai_format_embedding
|
||||||
|
|
||||||
|
一个兼容 OpenAI Embeddings 接口格式的本地向量服务,基本支持所有 huggingface 能找到的嵌入模型。
|
||||||
|
|
||||||
|
## 功能说明
|
||||||
|
|
||||||
|
- 提供 `POST /v1/embeddings` 接口
|
||||||
|
- 支持单条文本或文本数组输入
|
||||||
|
- 返回结构与 OpenAI Embeddings 响应格式一致
|
||||||
|
- 可通过环境变量配置模型路径、监听地址、端口和 worker 数
|
||||||
|
|
||||||
|
## 环境变量
|
||||||
|
|
||||||
|
可参考 [.env.example](.env.example)。
|
||||||
|
|
||||||
|
- `EMBEDDING_PATH`:模型目录路径,支持本地路径或 UNC 网络路径
|
||||||
|
- `EMBEDDING_HOST`:服务监听地址,默认 `0.0.0.0`
|
||||||
|
- `EMBEDDING_PORT`:服务监听端口,默认 `8000`
|
||||||
|
- `EMBEDDING_WORKERS`:Uvicorn worker 数,默认 `5`
|
||||||
|
|
||||||
|
## 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 启动服务
|
||||||
|
|
||||||
|
```
|
||||||
|
# 编辑 .env 文件,配置 EMBEDDING_PATH 等参数
|
||||||
|
cp .env.example .env
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 调用示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://127.0.0.1:8000/v1/embeddings" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "bge-m3",
|
||||||
|
"input": ["hello world", "你好,向量服务"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
返回示例(结构示意):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"embedding": [0.0123, -0.0456],
|
||||||
|
"index": 0
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "bge-m3",
|
||||||
|
"object": "list",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 2,
|
||||||
|
"total_tokens": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 说明
|
||||||
|
|
||||||
|
- `prompt_tokens` 当前实现按空格分词统计
|
||||||
|
- `total_tokens` 使用 `cl100k_base` 编码统计
|
||||||
|
|||||||
98
main.py
Normal file
98
main.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
import torch
|
||||||
|
import uvicorn
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 设置文本向量模型
|
||||||
|
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', r'\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3')
|
||||||
|
# 设置服务监听ip
|
||||||
|
EMBEDDING_HOST = os.environ.get('EMBEDDING_HOST', '0.0.0.0')
|
||||||
|
# 设置服务监听端口
|
||||||
|
EMBEDDING_PORT = int(os.environ.get('EMBEDDING_PORT', 8000))
|
||||||
|
# 设置服务进程数量
|
||||||
|
EMBEDDING_WORKERS = int(os.environ.get('EMBEDDING_WORKERS', 5))
|
||||||
|
|
||||||
|
# 模型全局变量
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
embedding_model = SentenceTransformer(EMBEDDING_PATH, device=device)
|
||||||
|
tokenizer = tiktoken.get_encoding('cl100k_base')
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=False,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingRequest(BaseModel):
|
||||||
|
input: Union[str, List[str]]
|
||||||
|
model: str
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponse(BaseModel):
|
||||||
|
data: list
|
||||||
|
model: str
|
||||||
|
object: str
|
||||||
|
usage: dict
|
||||||
|
|
||||||
|
|
||||||
|
# 提取文本向量接口
|
||||||
|
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
|
||||||
|
async def get_embeddings(request: EmbeddingRequest):
|
||||||
|
input_texts = request.input if isinstance(request.input, list) else [request.input]
|
||||||
|
|
||||||
|
# 使用批处理一次编码所有文本
|
||||||
|
embeddings = embedding_model.encode(input_texts, convert_to_list=True)
|
||||||
|
|
||||||
|
def count_tokens(text: str) -> int:
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"embedding": embedding,
|
||||||
|
"index": index
|
||||||
|
}
|
||||||
|
for index, embedding in enumerate(embeddings)
|
||||||
|
],
|
||||||
|
"model": request.model,
|
||||||
|
"object": "list",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": sum(count_tokens(text) for text in input_texts),
|
||||||
|
"total_tokens": sum(count_tokens(text) for text in input_texts),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# 开启服务
|
||||||
|
def start_server():
|
||||||
|
uvicorn.run("main:app", host=EMBEDDING_HOST, port=EMBEDDING_PORT, workers=EMBEDDING_WORKERS)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server()
|
||||||
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
sentence-transformers
|
||||||
|
tiktoken
|
||||||
|
torch
|
||||||
|
python-dotenv
|
||||||
Reference in New Issue
Block a user