first commit

This commit is contained in:
vera
2026-01-20 17:37:43 +08:00
parent e8dd956fc2
commit 721c53fe87
6 changed files with 614 additions and 0 deletions

159
api_concurrent.py Normal file
View File

@ -0,0 +1,159 @@
import os
import torch
import tempfile
import soundfile as sf
import asyncio
from typing import Optional
from fastapi import FastAPI, Form, UploadFile, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from concurrent.futures import ThreadPoolExecutor
import voxcpm
import uvicorn
# Silence the HF tokenizers fork warning before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Default HF_REPO_ID to the local model directory when it is unset or blank.
if not os.environ.get("HF_REPO_ID", "").strip():
    os.environ["HF_REPO_ID"] = "/models/Voice/VoxCPM/VoxCPM1.5/"
# ========== Model class ==========
class VoxCPMDemo:
    """Lazy-loading wrapper around the VoxCPM text-to-speech model."""

    def __init__(self) -> None:
        # Prefer GPU when available; all generation runs on this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Running on device: {self.device}")
        # Loaded lazily on first use (see get_or_load_voxcpm) to keep startup cheap.
        # String annotation so the class can be defined without evaluating voxcpm types.
        self.voxcpm_model: "Optional[voxcpm.VoxCPM]" = None
        self.default_local_model_dir = "/models/Voice/VoxCPM/VoxCPM1.5/"

    def _resolve_model_dir(self) -> str:
        """Return the local model directory if it exists, else the fallback "models"."""
        if os.path.isdir(self.default_local_model_dir):
            return self.default_local_model_dir
        return "models"

    def get_or_load_voxcpm(self) -> "voxcpm.VoxCPM":
        """Load the VoxCPM model on first call and cache it on the instance."""
        if self.voxcpm_model is None:
            print("🔄 Loading VoxCPM model...")
            model_dir = self._resolve_model_dir()
            self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
            print("✅ VoxCPM model loaded.")
        return self.voxcpm_model

    def tts_generate(
        self,
        text: str,
        prompt_wav_path: Optional[str] = None,
        prompt_text: Optional[str] = None,
        cfg_value: float = 2.0,
        inference_timesteps: int = 10,
        do_normalize: bool = True,
        denoise: bool = True,
    ) -> str:
        """Synthesize *text* to a WAV file and return the file's path.

        The caller is responsible for deleting the returned temp file.
        """
        model = self.get_or_load_voxcpm()
        wav = model.generate(
            text=text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            cfg_value=float(cfg_value),
            inference_timesteps=int(inference_timesteps),
            normalize=do_normalize,
            denoise=denoise,
        )
        # Close the temp-file handle before soundfile reopens the path by name;
        # the original kept the handle open, which leaks an fd and fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
            out_path = tmp_wav.name
        # NOTE(review): 44100 Hz is assumed to match the model's output rate — confirm.
        sf.write(out_path, wav, 44100)
        # Release cached GPU memory between requests (no-op when CUDA is unused).
        torch.cuda.empty_cache()
        return out_path
# ========== FastAPI ==========
app = FastAPI(title="VoxCPM API", version="1.0.0")
demo = VoxCPMDemo()
executor = ThreadPoolExecutor(max_workers=2)  # thread pool that runs the blocking GPU work
gpu_queue = asyncio.Queue()  # FIFO queue serializing GPU jobs
MAX_GPU_CONCURRENT = 1  # max simultaneous GPU tasks (single GPU); adjustable
# GPU queue consumer coroutine: pulls jobs and runs them on the thread pool.
async def gpu_worker():
    """Forever consume (future, func, args, kwargs) jobs from gpu_queue.

    Each job is executed on the shared ThreadPoolExecutor so blocking GPU
    work does not stall the event loop; the result (or exception) is
    delivered through the job's future.
    """
    while True:
        future, func, args, kwargs = await gpu_queue.get()
        loop = asyncio.get_running_loop()
        try:
            # run_in_executor() accepts positional args only; wrap the call so
            # keyword arguments are forwarded correctly (the original passed
            # **kwargs, which raises TypeError whenever kwargs is non-empty).
            result = await loop.run_in_executor(
                executor, lambda: func(*args, **kwargs)
            )
            # Guard against a caller that gave up (disconnected request):
            # setting a result on a cancelled future raises InvalidStateError
            # and would kill this worker.
            if not future.cancelled():
                future.set_result(result)
        except Exception as e:
            if not future.cancelled():
                future.set_exception(e)
        finally:
            gpu_queue.task_done()
# Strong references to the worker tasks: the event loop keeps only weak
# references to tasks, so without these the workers could be garbage-collected
# mid-run and silently stop consuming the queue.
_gpu_worker_tasks = []

# Launch the GPU consumer coroutines.
async def start_gpu_workers():
    """Spawn MAX_GPU_CONCURRENT gpu_worker() consumer tasks."""
    for _ in range(MAX_GPU_CONCURRENT):
        _gpu_worker_tasks.append(asyncio.create_task(gpu_worker()))
@app.on_event("startup")  # NOTE(review): on_event is deprecated in newer FastAPI — migrate to a lifespan handler when convenient
async def startup_event():
    """Start the GPU queue consumer tasks once the event loop is running."""
    await start_gpu_workers()
# Enqueue a job for the GPU workers and wait for its outcome.
async def submit_to_gpu(func, *args, **kwargs):
    """Schedule *func(*args, **kwargs)* on the GPU queue and await its result."""
    result_future = asyncio.get_running_loop().create_future()
    job = (result_future, func, args, kwargs)
    await gpu_queue.put(job)
    return await result_future
# ---------- TTS API ----------
@app.post("/generate_tts")
async def generate_tts(
    background_tasks: BackgroundTasks,
    text: str = Form(...),
    prompt_text: Optional[str] = Form(None),
    cfg_value: float = Form(2.0),
    inference_timesteps: int = Form(10),
    do_normalize: bool = Form(True),
    denoise: bool = Form(True),
    prompt_wav: Optional[UploadFile] = None,
):
    """Synthesize speech for *text*, optionally voice-cloning from prompt_wav.

    Streams the generated WAV back; temp files are removed by background tasks
    after the response is sent.
    """
    try:
        prompt_path = None
        if prompt_wav:
            # Persist the uploaded reference audio so the model can read it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(await prompt_wav.read())
                prompt_path = tmp.name
                background_tasks.add_task(os.remove, tmp.name)
        # Hand the blocking synthesis off to the GPU queue.
        output_path = await submit_to_gpu(
            demo.tts_generate,
            text,
            prompt_path,
            prompt_text,
            cfg_value,
            inference_timesteps,
            do_normalize,
            denoise,
        )
        # Delete the generated file once the response has been streamed.
        background_tasks.add_task(os.remove, output_path)
        wav_headers = {"Content-Disposition": 'attachment; filename="output.wav"'}
        return StreamingResponse(
            open(output_path, "rb"),
            media_type="audio/wav",
            headers=wav_headers,
        )
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.get("/")
async def root():
    """Health/info endpoint listing the available routes."""
    return {
        "message": "VoxCPM API running 🚀",
        "endpoints": ["/generate_tts"],
    }
if __name__ == "__main__":
    # NOTE(review): workers=8 spawns 8 independent processes; each lazily loads
    # its own model copy and runs its own queue, so MAX_GPU_CONCURRENT=1 is a
    # per-process limit (up to 8 concurrent GPU jobs overall). Confirm one GPU
    # can hold 8 model copies, or reduce workers.
    uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=8)