This commit is contained in:
133
api.py
Normal file
133
api.py
Normal file
@ -0,0 +1,133 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
|
||||
# 导入两种模式需要的库
|
||||
from funasr import AutoModel
|
||||
from model import FunASRNano
|
||||
from tools.utils import load_audio
|
||||
|
||||
app = FastAPI(title="FunASR Dual-Mode API")
|
||||
|
||||
# --- 环境配置 ---
|
||||
device = (
|
||||
"cuda:0" if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
MODEL_DIR = os.getenv("MODEL_DIR", "/models/Fun-ASR-Nano-2512")
|
||||
TEMP_DIR = "./temp_audio"
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
# --- 模型全局初始化 ---
|
||||
print(f"正在加载 AutoModel (Mode 1)...")
|
||||
model_auto = AutoModel(
|
||||
model=MODEL_DIR,
|
||||
trust_remote_code=True,
|
||||
vad_model="fsmn-vad",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
device=device,
|
||||
hub="ms"
|
||||
)
|
||||
|
||||
print(f"正在加载 Direct Model (Mode 2)...")
|
||||
model_direct, direct_kwargs = FunASRNano.from_pretrained(model=MODEL_DIR, device=device)
|
||||
tokenizer = direct_kwargs.get("tokenizer", None)
|
||||
model_direct.eval()
|
||||
|
||||
|
||||
# --- 接口 1: Using FunASR for Inference ---
|
||||
@app.post("/inference/funasr")
|
||||
async def inference_funasr(
|
||||
file: UploadFile = File(...),
|
||||
language: str = Form("中文"),
|
||||
itn: str = Form("true"),
|
||||
hotwords: str = Form("")
|
||||
):
|
||||
temp_path = save_temp_file(file)
|
||||
try:
|
||||
is_itn = True if itn.lower() in ["true", "1", "t"] else False
|
||||
clean_lang = language.strip().strip('"')
|
||||
clean_hw = hotwords.strip().strip('"')
|
||||
|
||||
# 核心修复点:不传 cache,且处理 hotwords
|
||||
res = model_auto.generate(
|
||||
input=temp_path,
|
||||
batch_size=1,
|
||||
hotwords=clean_hw if clean_hw else None,
|
||||
language=clean_lang,
|
||||
itn=is_itn,
|
||||
)
|
||||
|
||||
return {"status": "success", "text": res[0]["text"]}
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
remove_temp_file(temp_path)
|
||||
|
||||
|
||||
# --- 接口 2: Direct Inference ---
|
||||
@app.post("/inference/direct")
|
||||
async def inference_direct(
|
||||
file: UploadFile = File(...),
|
||||
chunk_mode: bool = Form(False) # 是否开启你脚本2中的分片逻辑
|
||||
):
|
||||
"""直接调用 model.py 中的 FunASRNano 进行推理"""
|
||||
temp_path = save_temp_file(file)
|
||||
try:
|
||||
if not chunk_mode:
|
||||
# 模式 A: 标准直接推理
|
||||
res = model_direct.inference(data_in=[temp_path], **direct_kwargs)
|
||||
text = res[0][0]
|
||||
else:
|
||||
# 模式 B: 模拟脚本 2 中的分片循环逻辑
|
||||
duration = sf.info(temp_path).duration
|
||||
chunk_size = 0.72
|
||||
cum_durations = np.arange(chunk_size, duration + chunk_size, chunk_size)
|
||||
prev_text = ""
|
||||
|
||||
for idx, cum_duration in enumerate(cum_durations):
|
||||
audio, rate = load_audio(temp_path, 16000, duration=round(cum_duration, 3))
|
||||
# 注意:这里调用的是模型内部的推理逻辑
|
||||
step_res = model_direct.inference(
|
||||
[torch.tensor(audio).to(device)],
|
||||
prev_text=prev_text,
|
||||
**direct_kwargs
|
||||
)
|
||||
prev_text = step_res[0][0]["text"]
|
||||
|
||||
# 脚本 2 中的特殊解码逻辑
|
||||
if idx != len(cum_durations) - 1 and tokenizer:
|
||||
prev_text = tokenizer.decode(tokenizer.encode(prev_text)[:-5]).replace("", "")
|
||||
|
||||
text = prev_text
|
||||
|
||||
return {"status": "success", "mode": "direct", "text": text}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
remove_temp_file(temp_path)
|
||||
|
||||
|
||||
# --- 工具函数 ---
|
||||
def save_temp_file(upload_file):
|
||||
ext = os.path.splitext(upload_file.filename)[1]
|
||||
path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}{ext}")
|
||||
with open(path, "wb") as buffer:
|
||||
shutil.copyfileobj(upload_file.file, buffer)
|
||||
return path
|
||||
|
||||
def remove_temp_file(path):
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=5000)
|
||||
Reference in New Issue
Block a user