# coding=utf-8
"""
Unified Qwen3-ASR API Server
============================
The model is loaded once and serves both:
  - POST /asr/transcribe   (non-streaming, whole-audio transcription)
  - WS   /asr/stream       (streaming, real-time incremental transcription)

Launch example:
  uv run examples/api_unified_fastapi.py --asr-model-path Qwen/Qwen3-ASR-1.7B

Non-streaming call example (Python):
  import requests
  with open("audio.wav", "rb") as f:
      resp = requests.post("http://localhost:8000/asr/transcribe",
                           files={"file": ("audio.wav", f, "audio/wav")},
                           data={"context": "", "language": ""})
  print(resp.json())
"""
import argparse
import io
import json
import logging
import uuid
from typing import Optional

import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse

from qwen_asr import Qwen3ASRModel

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Qwen3-ASR Unified API")

# ── Global singletons ────────────────────────────────────────────────────────
# Both are populated at runtime: `server_args` by the __main__ guard below,
# `asr_model` by the startup hook.
asr_model: Optional[Qwen3ASRModel] = None
server_args: Optional[argparse.Namespace] = None


@app.on_event("startup")
async def startup_event():
    """Load the ASR model once, using the parsed CLI arguments.

    Reads the module-level ``server_args`` and stores the loaded model in the
    module-level ``asr_model`` so both endpoints share the same instance.
    """
    global asr_model
    logger.info("Loading ASR model from %s ...", server_args.asr_model_path)
    asr_model = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        max_new_tokens=server_args.max_new_tokens,
    )
    logger.info("Model loaded successfully.")


# ── Non-streaming endpoint (HTTP POST multipart) ─────────────────────────────
@app.post("/asr/transcribe")
async def transcribe_endpoint(
    file: UploadFile = File(..., description="音频文件(wav/mp3/flac 等 soundfile 支持的格式)"),
    context: str = Form(default="", description="可选上下文"),
    language: str = Form(default="", description="可选强制语言,如 Chinese / English"),
):
    """
    Non-streaming, whole-audio transcription.

    Upload an audio file as multipart/form-data; the final transcription text
    is returned as ``{"language": ..., "text": ...}``. Any failure is reported
    as HTTP 500 with ``{"error": ...}``.

    curl example:
      curl -X POST http://localhost:8000/asr/transcribe \\
           -F "file=@audio.wav" -F "context=" -F "language="
    """
    try:
        raw = await file.read()
        with io.BytesIO(raw) as buf:
            wav, sr = sf.read(buf, dtype="float32", always_2d=False)
        wav = np.asarray(wav, dtype=np.float32)

        # NOTE(review): this call is synchronous inside an async endpoint, so
        # it holds the event loop (and every open stream session) until it
        # returns. Offloading to a worker thread would need confirmation that
        # the model tolerates concurrent calls — left as-is deliberately.
        results = asr_model.transcribe(
            audio=(wav, sr),
            context=context,
            # An empty / whitespace-only language means "do not force one".
            language=language.strip() or None,
        )
        if not results:
            # Report an explicit message instead of an opaque IndexError.
            logger.error("Transcribe error: model returned no result")
            return JSONResponse(
                status_code=500,
                content={"error": "model returned no result"},
            )
        r = results[0]
        return {"language": r.language, "text": r.text}
    except Exception as e:
        # Top-level boundary of the endpoint: log with traceback, report 500.
        logger.error("Transcribe error: %s", e, exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})


# ── Streaming endpoint (WebSocket) ───────────────────────────────────────────
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """
    Streaming incremental transcription.

    Protocol:
      client → bytes : float32 PCM 16kHz audio chunk
      client → text  : {"command": "finish"} ends the session
      server → text  : {"session_id": ..., "language": ..., "text": ..., "is_final": bool}

    The final result (``is_final: true``) is sent when the session ends while
    the client is still connected; the streaming state is finished either way.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info("Stream session started: %s", session_id)

    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )
    # Tracks whether it still makes sense to send to / close the socket.
    client_connected = True

    try:
        while True:
            message = await websocket.receive()

            # The raw receive() reports a disconnect as a message instead of
            # raising WebSocketDisconnect; calling receive() again afterwards
            # raises RuntimeError, so leave the loop here.
            if message.get("type") == "websocket.disconnect":
                client_connected = False
                logger.info("Client disconnected: %s", session_id)
                break

            # Servers may include both keys with one of them set to None, so
            # test the values rather than key presence.
            raw = message.get("bytes")
            text = message.get("text")

            if raw is not None:
                if len(raw) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (float32)"})
                    continue

                wav = np.frombuffer(raw, dtype=np.float32)
                asr_model.streaming_transcribe(wav, state)

                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": False,
                })

            elif text is not None:
                try:
                    cmd = json.loads(text)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON: %s", text)
                    continue

                # Valid JSON that is not an object (e.g. a list or number)
                # carries no command; ignore it instead of failing the session.
                if not isinstance(cmd, dict):
                    logger.warning("Ignoring non-object JSON message: %s", text)
                    continue

                if cmd.get("command") == "finish":
                    logger.info("Finish command received: %s", session_id)
                    break

    except WebSocketDisconnect:
        client_connected = False
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        logger.error("Error in session %s: %s", session_id, e, exc_info=True)
        # Best effort: the socket may already be unusable at this point.
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            pass
    finally:
        try:
            # Always finish the streaming state, even if nobody is listening.
            asr_model.finish_streaming_transcribe(state)
            if client_connected:
                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": True,
                })
                logger.info("Final result sent: %s", session_id)
        except Exception as e:
            logger.error("Error finishing session %s: %s", session_id, e)
        if client_connected:
            # Best effort: closing can fail if the peer vanished meanwhile.
            try:
                await websocket.close()
            except Exception:
                pass
        logger.info("Session closed: %s", session_id)


# ── CLI ──────────────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the server and the ASR model."""
    p = argparse.ArgumentParser(description="Qwen3-ASR Unified API (streaming + non-streaming)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8)
    p.add_argument("--max-new-tokens", type=int, default=32,
                   help="Max new tokens per call (streaming). Use larger value for non-streaming.")
    p.add_argument("--unfixed-chunk-num", type=int, default=4)
    p.add_argument("--unfixed-token-num", type=int, default=5)
    p.add_argument("--chunk-size-sec", type=float, default=1.0)
    return p.parse_args()


if __name__ == "__main__":
    # The startup hook reads this module-level value, so it must be set
    # before the server starts.
    server_args = parse_args()
    uvicorn.run(app, host=server_args.host, port=server_args.port)