"""FastAPI service exposing two FunASR inference modes.

* ``/inference/funasr`` - the ``funasr.AutoModel`` pipeline with an fsmn-vad front end.
* ``/inference/direct`` - ``FunASRNano`` from ``model.py``, called directly, with an
  optional chunked mode that re-decodes a growing audio prefix.
"""

import contextlib
import os
import shutil
import traceback
import uuid

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile

# Libraries required by the two inference modes
from funasr import AutoModel
from model import FunASRNano
from tools.utils import load_audio

app = FastAPI(title="FunASR Dual-Mode API")

# --- Environment configuration ---
# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
MODEL_DIR = os.getenv("MODEL_DIR", "/models/Fun-ASR-Nano-2512")
TEMP_DIR = "./temp_audio"
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Chunked (mode B) settings ---
# Seconds of audio added per step; compared against sf.info(...).duration.
CHUNK_SIZE = 0.72
# Passed to load_audio as its second argument; presumably the target sample rate.
LOAD_SAMPLE_RATE = 16000
# Tokens dropped from the tail of each intermediate transcript before it is fed
# back as ``prev_text`` (presumably because the tail is the least stable part).
PREFIX_TRIM_TOKENS = 5

# Lower-cased form values accepted as "true" for boolean-like string fields.
_TRUTHY = frozenset({"true", "1", "t"})

# --- Global model initialisation (runs once, at import time) ---
print("正在加载 AutoModel (Mode 1)...")
model_auto = AutoModel(
    model=MODEL_DIR,
    trust_remote_code=True,
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device=device,
    hub="ms",
)

print("正在加载 Direct Model (Mode 2)...")
model_direct, direct_kwargs = FunASRNano.from_pretrained(model=MODEL_DIR, device=device)
tokenizer = direct_kwargs.get("tokenizer", None)
model_direct.eval()


# --- Utility functions ---
def save_temp_file(upload_file):
    """Copy an uploaded file into TEMP_DIR under a random name and return its path.

    The client's file extension (if any) is kept on the new name.
    ``UploadFile.filename`` may be ``None``; then no extension is added.
    If the copy fails, the partially written file is removed before the
    exception propagates.
    """
    ext = os.path.splitext(upload_file.filename or "")[1]
    path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}{ext}")
    try:
        with open(path, "wb") as buffer:
            shutil.copyfileobj(upload_file.file, buffer)
    except BaseException:
        # Don't leave a truncated file behind in TEMP_DIR.
        remove_temp_file(path)
        raise
    return path


def remove_temp_file(path):
    """Delete *path*; a file that is already gone is not an error."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def _clean_form_value(value):
    """Strip surrounding whitespace and double quotes from a form field."""
    return value.strip().strip('"')


def _extract_text(entry):
    """Return the transcript from one ``FunASRNano.inference`` result entry.

    Chunk mode reads ``result[0][0]["text"]``, so entries appear to be dicts
    carrying a ``"text"`` key; any other value is returned unchanged.
    """
    if isinstance(entry, dict) and "text" in entry:
        return entry["text"]
    return entry


def _chunked_inference(audio_path):
    """Decode *audio_path* by re-running the model on a growing audio prefix.

    Each step loads the first ``k * CHUNK_SIZE`` seconds and passes the
    previous transcript as ``prev_text``. Except for the final step, the last
    PREFIX_TRIM_TOKENS tokens of that transcript are dropped first (only when
    a tokenizer is available). Returns the transcript of the final step, or
    "" when no steps run.
    """
    duration = sf.info(audio_path).duration
    cum_durations = np.arange(CHUNK_SIZE, duration + CHUNK_SIZE, CHUNK_SIZE)
    last_idx = len(cum_durations) - 1

    prev_text = ""
    for idx, cum_duration in enumerate(cum_durations):
        audio, _rate = load_audio(
            audio_path, LOAD_SAMPLE_RATE, duration=round(cum_duration, 3)
        )
        # Note: this calls the model's internal inference logic directly.
        step_res = model_direct.inference(
            [torch.tensor(audio).to(device)], prev_text=prev_text, **direct_kwargs
        )
        prev_text = step_res[0][0]["text"]

        # Special decoding logic from script 2.
        if idx != last_idx and tokenizer:
            token_ids = tokenizer.encode(prev_text)[:-PREFIX_TRIM_TOKENS]
            # Cutting the token sequence can split a multi-byte character; the
            # decoder presumably renders that as U+FFFD, which is stripped here.
            prev_text = tokenizer.decode(token_ids).replace("\ufffd", "")
    return prev_text


# NOTE(review): both handlers are ``async def`` but call blocking model code,
# so each request blocks the event loop while it runs (effectively serialising
# inference). Switching to plain ``def`` would use FastAPI's thread pool but
# allow concurrent calls into the shared models - confirm thread safety first.


# --- Endpoint 1: Using FunASR for Inference ---
@app.post("/inference/funasr")
async def inference_funasr(
    file: UploadFile = File(...),
    language: str = Form("中文"),
    itn: str = Form("true"),
    hotwords: str = Form(""),
):
    """Transcribe an upload with the VAD-segmented ``AutoModel`` pipeline.

    Form fields (surrounding whitespace and double quotes are stripped):
        language: forwarded to ``generate`` as ``language``.
        itn: "true"/"1"/"t" (case-insensitive) enables ``itn``.
        hotwords: forwarded as ``hotwords``; an empty value is sent as None.

    Returns:
        ``{"status": "success", "text": <transcript>}``; the text is "" when
        ``generate`` returns no results.

    Raises:
        HTTPException: status 500 wrapping any inference error.
    """
    temp_path = save_temp_file(file)
    try:
        is_itn = _clean_form_value(itn).lower() in _TRUTHY
        clean_lang = _clean_form_value(language)
        clean_hw = _clean_form_value(hotwords)

        # Key fix: do not pass `cache`, and only pass hotwords when non-empty.
        res = model_auto.generate(
            input=temp_path,
            batch_size=1,
            hotwords=clean_hw or None,
            language=clean_lang,
            itn=is_itn,
        )
        text = res[0]["text"] if res else ""
        return {"status": "success", "text": text}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        remove_temp_file(temp_path)


# --- Endpoint 2: Direct Inference ---
@app.post("/inference/direct")
async def inference_direct(
    file: UploadFile = File(...),
    chunk_mode: bool = Form(False),  # whether to enable the chunking logic from script 2
):
    """Run inference by calling FunASRNano from model.py directly.

    With ``chunk_mode`` false, the whole file is decoded in one call.
    Otherwise ``_chunked_inference`` re-decodes a growing audio prefix.

    Returns:
        ``{"status": "success", "mode": "direct", "text": <transcript>}``.

    Raises:
        HTTPException: status 500 wrapping any inference error.
    """
    temp_path = save_temp_file(file)
    try:
        if not chunk_mode:
            # Mode A: standard direct inference.
            res = model_direct.inference(data_in=[temp_path], **direct_kwargs)
            text = _extract_text(res[0][0])
        else:
            # Mode B: emulate the chunked loop from script 2.
            text = _chunked_inference(temp_path)
        return {"status": "success", "mode": "direct", "text": text}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        remove_temp_file(temp_path)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=5000)