Files
Qwen3-ASR/examples/api_unified_fastapi.py
vera 7231ed2354
All checks were successful
Build container / build-docker (push) Successful in 24m10s
fix: uv dependency
2026-04-23 10:07:20 +08:00

189 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
Unified Qwen3-ASR API Server
============================
Loads the model once and serves both:
- POST /asr/transcribe  (non-streaming: transcribe a whole audio file)
- WS   /asr/stream      (streaming: real-time incremental transcription)

Launch example:
    uv run examples/api_unified_fastapi.py --asr-model-path Qwen/Qwen3-ASR-1.7B

Non-streaming client example (Python; the server listens on port 5000 by
default, change the URL if you pass --port):
    import requests
    with open("audio.wav", "rb") as f:
        resp = requests.post("http://localhost:5000/asr/transcribe",
                             files={"file": ("audio.wav", f, "audio/wav")},
                             data={"context": "", "language": ""})
    print(resp.json())
"""
import argparse
import io
import json
import logging
import uuid
from typing import Optional
import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from qwen_asr import Qwen3ASRModel
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
app = FastAPI(title="Qwen3-ASR Unified API")

# ── Process-wide singletons ───────────────────────────────────────────────────
# Assigned by the startup hook; both endpoints share this single model instance.
asr_model: Optional[Qwen3ASRModel] = None
# Parsed CLI options; assigned in the ``__main__`` block before uvicorn starts
# and read by the startup hook and the streaming endpoint.
server_args = None
@app.on_event("startup")
async def startup_event():
    """Load the shared Qwen3-ASR model once, before any request is served.

    Reads the module-level ``server_args`` (populated by the ``__main__``
    block) for the model path, GPU memory fraction and token budget, and
    stores the loaded model in the module-level ``asr_model``.

    Raises:
        RuntimeError: if ``server_args`` was never populated, e.g. when the
            app object is served through the ``uvicorn`` CLI instead of by
            running this script directly.
    """
    global asr_model
    if server_args is None:
        # Without this guard the failure surfaces as the opaque
        # "'NoneType' object has no attribute 'asr_model_path'".
        raise RuntimeError(
            "server_args is not initialised; start the server by running this "
            "script directly so its command-line options are parsed."
        )
    logger.info("Loading ASR model from %s ...", server_args.asr_model_path)
    asr_model = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        # NOTE(review): this one model also serves /asr/transcribe; per the
        # --max-new-tokens help text the budget presumably caps non-streaming
        # output too — confirm against qwen_asr.
        max_new_tokens=server_args.max_new_tokens,
    )
    logger.info("Model loaded successfully.")
# ── Non-streaming endpoint (HTTP POST, multipart/form-data) ──────────────────
def _decode_upload(raw):
    """Decode uploaded audio bytes into a mono float32 waveform.

    Args:
        raw: The complete uploaded file content.

    Returns:
        ``(wav, sr)``: a 1-D float32 numpy array and its sample rate.

    Raises:
        RuntimeError: from ``soundfile`` when the bytes are not a readable
            audio file (``soundfile.LibsndfileError`` subclasses it).
        ValueError: when the file decodes to zero samples.
    """
    with io.BytesIO(raw) as buf:
        wav, sr = sf.read(buf, dtype="float32", always_2d=False)
    wav = np.asarray(wav, dtype=np.float32)
    if wav.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files; average
        # the channels so the model always receives a 1-D mono signal.
        wav = wav.mean(axis=1)
    if wav.size == 0:
        raise ValueError("audio contains no samples")
    return wav, sr


@app.post("/asr/transcribe")
async def transcribe_endpoint(
    file: UploadFile = File(..., description="音频文件wav/mp3/flac 等 soundfile 支持的格式)"),
    context: str = Form(default="", description="可选上下文"),
    language: str = Form(default="", description="可选强制语言,如 Chinese / English"),
):
    """Transcribe one uploaded audio file in a single, non-streaming pass.

    The file is sent as multipart/form-data and decoded with ``soundfile``.

    Example::

        curl -X POST http://localhost:5000/asr/transcribe \\
             -F "file=@audio.wav" -F "context=" -F "language="

    Args:
        file: Uploaded audio file in any format soundfile can read.
        context: Optional context text forwarded to the model.
        language: Optional forced language (e.g. "Chinese", "English");
            blank or whitespace-only means no forced language.

    Returns:
        ``{"language": ..., "text": ...}`` on success. On failure a JSON
        ``{"error": ...}`` body is returned with status 400 when the upload
        cannot be decoded as audio, or 500 for any other failure.
    """
    try:
        raw = await file.read()
        try:
            wav, sr = _decode_upload(raw)
        except (RuntimeError, ValueError) as e:
            # Bad or empty uploads are client errors, not server faults.
            logger.warning("Rejected upload %r: %s", file.filename, e)
            return JSONResponse(status_code=400, content={"error": f"Invalid audio file: {e}"})
        # NOTE(review): transcribe() is synchronous and runs on the event loop,
        # so it stalls concurrent /asr/stream sessions while it runs. Moving it
        # to a thread requires confirming the model tolerates concurrent calls.
        results = asr_model.transcribe(
            audio=(wav, sr),
            context=context,
            language=language.strip() or None,
        )
        r = results[0]
        return {"language": r.language, "text": r.text}
    except Exception as e:
        logger.exception("Transcribe error: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})
# ── Streaming endpoint (WebSocket) ───────────────────────────────────────────
def _stream_payload(session_id, state, is_final):
    """Build one JSON message for a streaming client.

    ``state`` is the object returned by ``asr_model.init_streaming_state``;
    only its ``language`` and ``text`` attributes are read, and falsy values
    are sent as empty strings.
    """
    return {
        "session_id": session_id,
        "language": state.language or "",
        "text": state.text or "",
        "is_final": is_final,
    }


@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Stream audio over a WebSocket and push incremental transcripts back.

    Protocol:
        client -> binary frame : float32 PCM samples, 16 kHz (this handler
                                 neither checks nor resamples the rate)
        client -> text frame   : {"command": "finish"} ends the session
        server -> text frame   : {"session_id": ..., "language": ...,
                                  "text": ..., "is_final": bool}

    Each binary frame gets one ``is_final: false`` reply. A frame whose length
    is not a multiple of 4 bytes gets ``{"error": ...}`` instead. Other text
    frames (invalid JSON, unknown commands) are ignored.

    On "finish" or a server-side error, ``finish_streaming_transcribe`` is
    called and a single ``is_final: true`` message is sent before the socket
    is closed. If the client disconnects first, no final message or close is
    attempted.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info("Stream session started: %s", session_id)
    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )
    # Set once the peer is gone; after that nothing can be sent or closed.
    client_gone = False
    try:
        while True:
            message = await websocket.receive()
            # The low-level receive() *returns* the disconnect event instead of
            # raising WebSocketDisconnect, and raises RuntimeError if called
            # again afterwards, so the event must be handled explicitly here.
            if message["type"] == "websocket.disconnect":
                client_gone = True
                logger.info("Client disconnected: %s", session_id)
                break
            # ASGI permits both keys to be present with the unused one as None.
            raw = message.get("bytes")
            text = message.get("text")
            if raw is not None:
                if len(raw) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (float32)"})
                    continue
                wav = np.frombuffer(raw, dtype=np.float32)
                # NOTE(review): synchronous model call on the event loop; other
                # sessions and HTTP requests wait while it runs.
                asr_model.streaming_transcribe(wav, state)
                await websocket.send_json(_stream_payload(session_id, state, is_final=False))
            elif text is not None:
                try:
                    cmd = json.loads(text)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON: %s", text)
                    continue
                # Valid JSON that is not an object (e.g. a list) has no .get();
                # treat it like any other unrecognised command and ignore it.
                if isinstance(cmd, dict) and cmd.get("command") == "finish":
                    logger.info("Finish command received: %s", session_id)
                    break
    except WebSocketDisconnect:
        client_gone = True
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        logger.exception("Error in session %s: %s", session_id, e)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            pass  # best effort: the socket may already be unusable
    finally:
        try:
            # NOTE(review): still called after a client disconnect, as before;
            # in that case its result is simply not delivered.
            asr_model.finish_streaming_transcribe(state)
            if not client_gone:
                await websocket.send_json(_stream_payload(session_id, state, is_final=True))
                logger.info("Final result sent: %s", session_id)
        except Exception as e:
            logger.error("Error finishing session %s: %s", session_id, e)
        if not client_gone:
            try:
                await websocket.close()
            except Exception:
                pass  # best effort: the peer may have vanished mid-send
        logger.info("Session closed: %s", session_id)
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args(argv=None):
    """Parse command-line options for the unified ASR server.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, matching the
            original zero-argument behaviour. Passing a list (e.g. ``[]`` for
            all defaults) allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``asr_model_path``, ``host``, ``port``,
        ``gpu_memory_utilization``, ``max_new_tokens``, ``unfixed_chunk_num``,
        ``unfixed_token_num`` and ``chunk_size_sec``.
    """
    p = argparse.ArgumentParser(description="Qwen3-ASR Unified API (streaming + non-streaming)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--port", type=int, default=5000)
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8)
    p.add_argument("--max-new-tokens", type=int, default=32,
                   help="Max new tokens per call (streaming). Use larger value for non-streaming.")
    # Streaming-state tuning, forwarded to init_streaming_state per session.
    p.add_argument("--unfixed-chunk-num", type=int, default=4)
    p.add_argument("--unfixed-token-num", type=int, default=5)
    p.add_argument("--chunk-size-sec", type=float, default=1.0)
    return p.parse_args(argv)
if __name__ == "__main__":
    # Populate the module-level ``server_args`` read by the startup hook and the
    # streaming endpoint, then serve this in-process ``app`` object.
    server_args = parse_args()
    listen_host, listen_port = server_args.host, server_args.port
    uvicorn.run(app, port=listen_port, host=listen_host)