diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..485fd99 --- /dev/null +++ b/.env.example @@ -0,0 +1,11 @@ +# Embedding model path (supports local path or UNC network path) +EMBEDDING_PATH=\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3 + +# Service bind host +EMBEDDING_HOST=0.0.0.0 + +# Service bind port +EMBEDDING_PORT=8000 + +# Uvicorn worker count +EMBEDDING_WORKERS=5 \ No newline at end of file diff --git a/README.md b/README.md index e69de29..69042b8 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,71 @@ +# openai_format_embedding + +一个兼容 OpenAI Embeddings 接口格式的本地向量服务,基本支持所有 huggingface 能找到的嵌入模型。 + +## 功能说明 + +- 提供 `POST /v1/embeddings` 接口 +- 支持单条文本或文本数组输入 +- 返回结构与 OpenAI Embeddings 响应格式一致 +- 可通过环境变量配置模型路径、监听地址、端口和 worker 数 + +## 环境变量 + +可参考 [.env.example](.env.example)。 + +- `EMBEDDING_PATH`:模型目录路径,支持本地路径或 UNC 网络路径 +- `EMBEDDING_HOST`:服务监听地址,默认 `0.0.0.0` +- `EMBEDDING_PORT`:服务监听端口,默认 `8000` +- `EMBEDDING_WORKERS`:Uvicorn worker 数,默认 `5` + +## 安装依赖 + +```bash +pip install -r requirements.txt +``` + +## 启动服务 + +``` +# 编辑 .env 文件,配置 EMBEDDING_PATH 等参数 +cp .env.example .env + +# 启动服务 +python main.py +``` + +## 调用示例 + +```bash +curl -X POST "http://127.0.0.1:8000/v1/embeddings" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "bge-m3", + "input": ["hello world", "你好,向量服务"] + }' +``` + +返回示例(结构示意): + +```json +{ + "data": [ + { + "object": "embedding", + "embedding": [0.0123, -0.0456], + "index": 0 + } + ], + "model": "bge-m3", + "object": "list", + "usage": { + "prompt_tokens": 2, + "total_tokens": 4 + } +} +``` + +## 说明 + +- `prompt_tokens` 当前实现按空格分词统计 +- `total_tokens` 使用 `cl100k_base` 编码统计 diff --git a/main.py b/main.py new file mode 100644 index 0000000..4737907 --- /dev/null +++ b/main.py @@ -0,0 +1,98 @@ +import os +from contextlib import asynccontextmanager +from typing import List, Union + +import tiktoken +import torch +import uvicorn +from dotenv import load_dotenv +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from sentence_transformers import SentenceTransformer + +load_dotenv() + +# 设置文本向量模型 +EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', r'\\10.6.80.11\Dataset\PVStore\lab-data-model-pvc-c0beeab1-6dd5-4c6a-bd2c-6ce9e114c25e\Weight\BAAI\bge-m3') +# 设置服务监听ip +EMBEDDING_HOST = os.environ.get('EMBEDDING_HOST', '0.0.0.0') +# 设置服务监听端口 +EMBEDDING_PORT = int(os.environ.get('EMBEDDING_PORT', 8000)) +# 设置服务进程数量 +EMBEDDING_WORKERS = int(os.environ.get('EMBEDDING_WORKERS', 5)) + +# 模型全局变量 +device = "cuda" if torch.cuda.is_available() else "cpu" +embedding_model = SentenceTransformer(EMBEDDING_PATH, device=device) +tokenizer = tiktoken.get_encoding('cl100k_base') + + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], +) + + +class EmbeddingRequest(BaseModel): + input: Union[str, List[str]] + model: str + + +class EmbeddingResponse(BaseModel): + data: list + model: str + object: str + usage: dict + + +# 提取文本向量接口 +@app.post("/v1/embeddings", response_model=EmbeddingResponse) +async def get_embeddings(request: EmbeddingRequest): + input_texts = request.input if isinstance(request.input, list) else [request.input] + + # 使用批处理一次编码所有文本 + embeddings = embedding_model.encode(input_texts, convert_to_list=True) + + def count_tokens(text: str) -> int: + return len(tokenizer.encode(text)) + + response = { + "data": [ + { + "object": "embedding", + "embedding": embedding, + "index": index + } + for index, embedding in enumerate(embeddings) + ], + "model": request.model, + "object": "list", + "usage": { + "prompt_tokens": sum(count_tokens(text) for text in input_texts), + "total_tokens": sum(count_tokens(text) for text in input_texts), + }, + } + return response + + +# 开启服务 +def start_server(): + uvicorn.run("main:app", host=EMBEDDING_HOST, port=EMBEDDING_PORT, workers=EMBEDDING_WORKERS) + + +if __name__ == "__main__": + start_server() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b0e0714 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +fastapi +uvicorn +sentence-transformers +tiktoken +torch +python-dotenv