189 lines
7.1 KiB
Python
189 lines
7.1 KiB
Python
# coding=utf-8
|
||
"""
|
||
Unified Qwen3-ASR API Server
============================
The model is loaded only once and serves both:
  - POST /asr/transcribe   (non-streaming, transcribes a whole audio clip)
  - WS   /asr/stream       (streaming, real-time incremental transcription)

Launch example:
    uv run examples/api_unified_fastapi.py --asr-model-path Qwen/Qwen3-ASR-1.7B

Non-streaming client example (Python):
    import requests
    with open("audio.wav", "rb") as f:
        resp = requests.post("http://localhost:8000/asr/transcribe",
                             files={"file": ("audio.wav", f, "audio/wav")},
                             data={"context": "", "language": ""})
    print(resp.json())
|
||
"""
|
||
import argparse
|
||
import io
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import soundfile as sf
|
||
import uvicorn
|
||
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
|
||
from fastapi.responses import JSONResponse
|
||
from qwen_asr import Qwen3ASRModel
|
||
|
||
# Configure logging once at import time: INFO level, with timestamp, logger
# name and level prefixed to every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# The single ASGI application; the endpoints below register themselves on it.
app = FastAPI(title="Qwen3-ASR Unified API")

# ── Global singletons ─────────────────────────────────────────────────────────
# Shared model instance; remains None until the startup hook assigns it.
asr_model: Optional[Qwen3ASRModel] = None
# Parsed CLI options; assigned in the ``__main__`` guard before the server runs.
server_args: Optional[argparse.Namespace] = None
|
||
|
||
|
||
@app.on_event("startup")
async def startup_event():
    """Load the shared ASR model once when the application starts.

    Reads the model path, GPU memory utilization and max-new-tokens settings
    from the module-level ``server_args`` and stores the resulting
    ``Qwen3ASRModel.LLM`` instance in the module-level ``asr_model`` so that
    both endpoints reuse the same model.
    """
    global asr_model

    model_path = server_args.asr_model_path
    logger.info(f"Loading ASR model from {model_path} ...")

    # Collect the constructor options first so the call site stays short.
    llm_options = {
        "model": model_path,
        "gpu_memory_utilization": server_args.gpu_memory_utilization,
        "max_new_tokens": server_args.max_new_tokens,
    }
    asr_model = Qwen3ASRModel.LLM(**llm_options)

    logger.info("Model loaded successfully.")
|
||
|
||
|
||
# ── Non-streaming endpoint (HTTP POST multipart) ─────────────────────────────
|
||
|
||
@app.post("/asr/transcribe")
async def transcribe_endpoint(
    file: UploadFile = File(..., description="音频文件(wav/mp3/flac 等 soundfile 支持的格式)"),
    context: str = Form(default="", description="可选上下文"),
    language: str = Form(default="", description="可选强制语言,如 Chinese / English"),
):
    """Transcribe a complete audio file in one (non-streaming) request.

    The audio is uploaded as ``multipart/form-data`` and decoded with
    ``soundfile``; the final transcription is returned as JSON.

    Args:
        file: Uploaded audio file in any format ``soundfile`` can read.
        context: Optional context string forwarded to the model.
        language: Optional forced language (e.g. ``Chinese`` / ``English``).
            A blank or whitespace-only value is passed to the model as
            ``None``.

    Returns:
        ``{"language": ..., "text": ...}`` taken from the first result on
        success. On failure a ``JSONResponse`` of the form
        ``{"error": "<message>"}`` is returned with status 400 when the
        upload is empty or cannot be decoded as audio, and status 500 for
        any other error.

    curl example:
        curl -X POST http://localhost:8000/asr/transcribe \\
            -F "file=@audio.wav" -F "context=" -F "language="
    """
    try:
        raw = await file.read()

        # An empty body can never be valid audio; reject it as a client error
        # instead of letting the decoder fail with an opaque message.
        if not raw:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded audio file is empty"},
            )

        # Decoding operates purely on client-supplied bytes, so any failure in
        # this step is reported as a bad request rather than a server fault.
        try:
            with io.BytesIO(raw) as buf:
                wav, sr = sf.read(buf, dtype="float32", always_2d=False)
        except Exception as e:
            logger.warning(f"Could not decode uploaded audio: {e}")
            return JSONResponse(
                status_code=400,
                content={"error": f"Could not decode audio: {e}"},
            )

        wav = np.asarray(wav, dtype=np.float32)

        results = asr_model.transcribe(
            audio=(wav, sr),
            context=context,
            # Blank language means "no forced language".
            language=language.strip() or None,
        )
        r = results[0]
        return {"language": r.language, "text": r.text}
    except Exception as e:
        # Endpoint boundary: log with traceback and surface a 500 to the client.
        logger.error(f"Transcribe error: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
|
||
|
||
|
||
# ── Streaming endpoint (WebSocket) ────────────────────────────────────────────
|
||
|
||
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Stream incremental transcription results over a WebSocket.

    Protocol:
        client -> bytes : chunk of float32 PCM audio at 16 kHz
        client -> text  : {"command": "finish"} ends the session
        server -> text  : {"session_id": ..., "language": ..., "text": ..., "is_final": bool}

    After every audio chunk the current language/text of the streaming state
    is sent back with ``is_final`` set to ``False``. When the session ends the
    streaming state is finalized and, if the client is still connected, one
    last message with ``is_final`` set to ``True`` is sent before the socket
    is closed.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info(f"Stream session started: {session_id}")

    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )

    # Whether the peer can still receive frames. Once it is gone we must not
    # try to send the final result or close the socket again.
    client_connected = True

    try:
        while True:
            message = await websocket.receive()

            # ``receive()`` returns the raw ASGI message and does NOT raise on
            # disconnect; calling it again afterwards raises RuntimeError. So
            # the disconnect message has to be detected explicitly here.
            if message.get("type") == "websocket.disconnect":
                client_connected = False
                logger.info(f"Client disconnected: {session_id}")
                break

            # A key may be present with a ``None`` value, so test the value
            # rather than key membership.
            raw = message.get("bytes")
            text = message.get("text")

            if raw is not None:
                if len(raw) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (float32)"})
                    continue

                wav = np.frombuffer(raw, dtype=np.float32)
                asr_model.streaming_transcribe(wav, state)

                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": False,
                })

            elif text is not None:
                try:
                    cmd = json.loads(text)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON: {text}")
                    continue
                # Valid JSON that is not an object (e.g. a list or number) has
                # no ``command`` field; ignore it instead of crashing.
                if isinstance(cmd, dict) and cmd.get("command") == "finish":
                    logger.info(f"Finish command received: {session_id}")
                    break

    except WebSocketDisconnect:
        # May still be raised by the send side when the peer vanishes.
        client_connected = False
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in session {session_id}: {e}", exc_info=True)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception as send_err:
            # Best effort only: the connection is unusable, so skip the
            # final send/close below as well.
            client_connected = False
            logger.debug(f"Could not report error to client {session_id}: {send_err}")
    finally:
        try:
            # Always finalize the streaming state, even without a listener.
            asr_model.finish_streaming_transcribe(state)
            if client_connected:
                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": True,
                })
                logger.info(f"Final result sent: {session_id}")
        except Exception as e:
            logger.error(f"Error finishing session {session_id}: {e}")
        if client_connected:
            try:
                await websocket.close()
            except Exception as close_err:
                # Closing is best effort; the peer may have gone away meanwhile.
                logger.debug(f"Could not close socket for {session_id}: {close_err}")
        logger.info(f"Session closed: {session_id}")
|
||
|
||
|
||
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||
|
||
def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Args:
        argv: Argument strings to parse. With the default ``None`` argparse
            reads ``sys.argv[1:]``; passing an explicit list allows the parser
            to be used programmatically (e.g. from tests or another launcher).

    Returns:
        Namespace with the attributes ``asr_model_path``, ``host``, ``port``,
        ``gpu_memory_utilization``, ``max_new_tokens``, ``unfixed_chunk_num``,
        ``unfixed_token_num`` and ``chunk_size_sec``.
    """
    p = argparse.ArgumentParser(description="Qwen3-ASR Unified API (streaming + non-streaming)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    # NOTE: the default host binds to all network interfaces.
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8)
    p.add_argument("--max-new-tokens", type=int, default=32,
                   help="Max new tokens per call (streaming). Use larger value for non-streaming.")
    # Streaming-state tuning values handed to ``init_streaming_state``.
    p.add_argument("--unfixed-chunk-num", type=int, default=4)
    p.add_argument("--unfixed-token-num", type=int, default=5)
    p.add_argument("--chunk-size-sec", type=float, default=1.0)
    return p.parse_args(argv)
|
||
|
||
|
||
if __name__ == "__main__":
    # Populate the module-level ``server_args`` first: the startup hook and the
    # streaming endpoint read their settings from it once the server is running.
    server_args = parse_args()
    # Serve the app object directly on the host/port given on the command line.
    uvicorn.run(app, host=server_args.host, port=server_args.port)
|