Files
Fun-ASR/api.py
vera 7879751126
Some checks failed
Build container / build-docker (push) Failing after 28s
feat: api
2026-02-10 17:56:37 +08:00

134 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import uuid
import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
# 导入两种模式需要的库
from funasr import AutoModel
from model import FunASRNano
from tools.utils import load_audio
app = FastAPI(title="FunASR Dual-Mode API")

# --- Environment configuration ---
# Device preference: CUDA, then Apple MPS, then CPU.
device = (
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Both directories can be overridden through environment variables.
MODEL_DIR = os.getenv("MODEL_DIR", "/models/Fun-ASR-Nano-2512")
TEMP_DIR = os.getenv("TEMP_DIR", "./temp_audio")
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Global model initialization (runs once, at import time) ---
print("正在加载 AutoModel (Mode 1)...")
# Mode 1: FunASR AutoModel pipeline with the fsmn-vad segmenter in front.
model_auto = AutoModel(
    model=MODEL_DIR,
    trust_remote_code=True,
    vad_model="fsmn-vad",
    # Passed straight to funasr; presumably milliseconds (30 s) - confirm.
    vad_kwargs={"max_single_segment_time": 30000},
    device=device,
    hub="ms",  # presumably ModelScope as the model hub - confirm
)
print("正在加载 Direct Model (Mode 2)...")
# Mode 2: the FunASRNano model used directly; from_pretrained also returns
# the extra kwargs that are forwarded to every inference call.
model_direct, direct_kwargs = FunASRNano.from_pretrained(model=MODEL_DIR, device=device)
# Optional; the chunked decoding path skips prefix trimming when it is None.
tokenizer = direct_kwargs.get("tokenizer")
model_direct.eval()
# --- Endpoint 1: inference through FunASR's AutoModel pipeline ---
@app.post("/inference/funasr")
async def inference_funasr(
    file: UploadFile = File(...),
    language: str = Form("中文"),
    itn: str = Form("true"),
    hotwords: str = Form(""),
):
    """Transcribe an uploaded audio file with the VAD-enabled ``model_auto``.

    Form fields:
        file: audio upload; it is copied to a temporary file for the model.
        language: language hint passed to ``model_auto.generate``.
        itn: "true", "1" or "t" (case-insensitive) enables inverse text
            normalization; any other value disables it.
        hotwords: optional hotword string; an empty value is sent as None.

    Leading/trailing whitespace and double quotes are stripped from every
    text field, so clients that send quoted form values behave the same as
    clients that do not.

    Returns:
        ``{"status": "success", "text": <transcript>}``; the transcript is ""
        if the model returns no result.

    Raises:
        HTTPException: status 500 on any inference failure (the traceback is
        printed to stderr first).
    """
    # NOTE(review): blocking file I/O and model inference run directly inside
    # an ``async def``, so the event loop is blocked for the whole request.
    # A plain ``def`` would let requests run concurrently on the shared model;
    # confirm that is safe before switching.
    temp_path = save_temp_file(file)
    try:
        clean_itn = itn.strip().strip('"').lower()
        clean_lang = language.strip().strip('"')
        clean_hw = hotwords.strip().strip('"')
        # Deliberately no ``cache`` argument; an empty hotword string becomes None.
        res = model_auto.generate(
            input=temp_path,
            batch_size=1,
            hotwords=clean_hw or None,
            language=clean_lang,
            itn=clean_itn in ("true", "1", "t"),
        )
        text = res[0]["text"] if res else ""
        return {"status": "success", "text": text}
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        remove_temp_file(temp_path)
# --- Endpoint 2: direct inference with FunASRNano ---
def _extract_text(res):
    """Return the transcript string from a ``model_direct.inference`` result.

    The result is indexed as ``res[0][0]`` and, as in the chunked path, the
    transcript is read from that item's ``"text"`` key. A non-dict item is
    returned unchanged as a fallback.
    """
    item = res[0][0]
    return item["text"] if isinstance(item, dict) else item


def _transcribe_chunked(path, chunk_size=0.72):
    """Simulate streaming ASR by re-decoding ever-longer prefixes of ``path``.

    Each step loads the first ``k * chunk_size`` seconds of audio (resampled
    to 16 kHz by ``load_audio``) and decodes it with the previous hypothesis
    passed as ``prev_text``. The final step's transcript is returned; an
    empty audio file yields "".
    """
    duration = sf.info(path).duration
    cum_durations = np.arange(chunk_size, duration + chunk_size, chunk_size)
    last_idx = len(cum_durations) - 1
    prev_text = ""
    for idx, cum_duration in enumerate(cum_durations):
        audio, _rate = load_audio(path, 16000, duration=round(cum_duration, 3))
        step_res = model_direct.inference(
            [torch.tensor(audio).to(device)],
            prev_text=prev_text,
            **direct_kwargs,
        )
        prev_text = _extract_text(step_res)
        if idx != last_idx and tokenizer:
            # Drop the last 5 tokens of every intermediate hypothesis before
            # feeding it back (presumably the least stable part of the tail).
            # Cutting tokens can split a multi-byte character, so strip any
            # U+FFFD replacement characters the decode may produce.
            prev_text = tokenizer.decode(
                tokenizer.encode(prev_text)[:-5]
            ).replace("\ufffd", "")
    return prev_text


@app.post("/inference/direct")
async def inference_direct(
    file: UploadFile = File(...),
    chunk_mode: bool = Form(False),  # enable the chunked (pseudo-streaming) decoding path
):
    """Run inference by calling ``FunASRNano`` from model.py directly.

    Returns:
        ``{"status": "success", "mode": "direct", "text": <transcript>}``.

    Raises:
        HTTPException: status 500 on any inference failure (the traceback is
        printed to stderr first).
    """
    # NOTE(review): like endpoint 1, this ``async def`` blocks the event loop
    # during inference; see the note there before changing it.
    temp_path = save_temp_file(file)
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if chunk_mode:
                text = _transcribe_chunked(temp_path)
            else:
                res = model_direct.inference(data_in=[temp_path], **direct_kwargs)
                text = _extract_text(res)
        return {"status": "success", "mode": "direct", "text": text}
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        remove_temp_file(temp_path)
# --- Utility functions ---
def save_temp_file(upload_file, *, temp_dir=None):
    """Copy an uploaded file into a uniquely named file and return its path.

    Args:
        upload_file: object exposing ``filename`` (may be None) and a readable
            binary ``file`` attribute, e.g. FastAPI's ``UploadFile``.
        temp_dir: destination directory; defaults to the module's TEMP_DIR.

    Returns:
        ``<temp_dir>/<uuid4><ext>``, where ``ext`` is the extension of the
        client-supplied filename ("" when there is none or no filename).

    Raises:
        Whatever the copy raises; a partially written file is removed first.
    """
    if temp_dir is None:
        temp_dir = TEMP_DIR
    # Only the client's extension is kept; the basename is server-generated.
    ext = os.path.splitext(upload_file.filename or "")[1]
    path = os.path.join(temp_dir, f"{uuid.uuid4()}{ext}")
    try:
        with open(path, "wb") as buffer:
            shutil.copyfileobj(upload_file.file, buffer)
    except BaseException:
        # Don't leave a half-written file behind when the copy fails.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
def remove_temp_file(path):
    """Delete the file at ``path``; a file that is already gone is not an error.

    EAFP instead of an ``exists()`` pre-check: a file that disappears between
    the check and the delete would otherwise raise FileNotFoundError from the
    callers' ``finally`` clauses and mask their real response.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
if __name__ == "__main__":
    import uvicorn

    # HOST / PORT override the bind address; the defaults match the
    # previously hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "5000")),
    )