feat: streaming api
All checks were successful
Build container / build-docker (push) Successful in 17m34s

This commit is contained in:
vera
2026-04-22 18:33:08 +08:00
parent c17a131fe0
commit 42eb035f4b
8 changed files with 7025 additions and 2 deletions

View File

@ -0,0 +1,128 @@
# coding=utf-8
import argparse
import json
import logging
import uuid
from typing import Optional
import numpy as np
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from qwen_asr import Qwen3ASRModel
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
app = FastAPI(title="Qwen3-ASR Streaming API")
# Global ASR model instance
asr_model: Optional[Qwen3ASRModel] = None  # populated once in startup_event
server_args = None  # argparse.Namespace; assigned in the __main__ guard before uvicorn starts
@app.on_event("startup")
async def startup_event():
    """Load the ASR model once when the server process starts.

    Populates the module-level ``asr_model`` singleton using the vLLM
    backend, configured from the CLI arguments in ``server_args``.
    """
    global asr_model
    logger.info(f"Loading ASR model from {server_args.asr_model_path}...")
    # Using vLLM backend for streaming
    model = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        max_new_tokens=server_args.max_new_tokens,
    )
    # Publish the fully-constructed model only after loading succeeds.
    asr_model = model
    logger.info("Model loaded successfully.")
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Streaming incremental transcription over WebSocket.

    Protocol:
        client -> bytes : float32 PCM audio chunks (16 kHz -- assumed by the
                          model; confirm against the client capture rate)
        client -> text  : {"command": "finish"} ends the session
        server -> text  : {"session_id", "language", "text", "is_final"}
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info(f"New session started: {session_id}")
    # Initialize streaming state for this session
    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )
    try:
        while True:
            # Receive message from client
            message = await websocket.receive()
            if "bytes" in message:
                # Binary audio data (Float32, 16kHz)
                raw_bytes = message["bytes"]
                # float32 samples are 4 bytes each; reject misaligned payloads.
                if len(raw_bytes) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (Float32)"})
                    continue
                # Convert bytes to numpy array
                wav = np.frombuffer(raw_bytes, dtype=np.float32)
                # Perform streaming transcription
                asr_model.streaming_transcribe(wav, state)
                # Send back current intermediate transcription
                await websocket.send_json({
                    "session_id": session_id,
                    "language": getattr(state, "language", "") or "",
                    "text": getattr(state, "text", "") or "",
                    "is_final": False
                })
            elif "text" in message:
                # Command message
                try:
                    text_data = json.loads(message["text"])
                    if text_data.get("command") == "finish":
                        logger.info(f"Finish command received for session: {session_id}")
                        break
                except json.JSONDecodeError:
                    logger.warning(f"Received invalid JSON text: {message['text']}")
    except WebSocketDisconnect:
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in session {session_id}: {e}", exc_info=True)
        try:
            # Best-effort error report; the socket may already be gone.
            # (Removed stray debug print(e); logger.error above already records it.)
            await websocket.send_json({"error": str(e)})
        except Exception:
            # Never use a bare except here: it would swallow SystemExit /
            # KeyboardInterrupt. Matches the unified server's handling.
            pass
    finally:
        # Finish transcription and send final results
        try:
            asr_model.finish_streaming_transcribe(state)
            await websocket.send_json({
                "session_id": session_id,
                "language": getattr(state, "language", "") or "",
                "text": getattr(state, "text", "") or "",
                "is_final": True
            })
            logger.info(f"Sent final result for session: {session_id}")
        except Exception as e:
            logger.error(f"Error while finishing session {session_id}: {e}")
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info(f"Session closed: {session_id}")
def parse_args():
    """Build and parse the CLI options for the streaming API server."""
    parser = argparse.ArgumentParser(description="Qwen3-ASR Streaming API (vLLM backend)")
    parser.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    parser.add_argument("--host", default="0.0.0.0", help="Bind host")
    parser.add_argument("--port", type=int, default=8000, help="Bind port")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.8, help="vLLM GPU memory utilization")
    parser.add_argument("--max-new-tokens", type=int, default=32, help="Max new tokens to generate per streaming call. Small value is recommended for low latency.")
    parser.add_argument("--unfixed-chunk-num", type=int, default=4, help="Number of unfixed chunks in streaming")
    parser.add_argument("--unfixed-token-num", type=int, default=5, help="Number of unfixed tokens in streaming")
    parser.add_argument("--chunk-size-sec", type=float, default=1.0, help="Size of each chunk in seconds")
    return parser.parse_args()
if __name__ == "__main__":
server_args = parse_args()
# Note: Use uvicorn to run the FastAPI app
uvicorn.run(app, host=server_args.host, port=server_args.port)

View File

@ -0,0 +1,188 @@
# coding=utf-8
"""
Unified Qwen3-ASR API Server
============================
模型只加载一次,同时提供:
- POST /asr/transcribe (非流式,整段音频转写)
- WS /asr/stream (流式,实时增量转写)
启动示例:
uv run examples/api_unified_fastapi.py --asr-model-path Qwen/Qwen3-ASR-1.7B
非流式调用示例(Python):
import requests
with open("audio.wav", "rb") as f:
resp = requests.post("http://localhost:8000/asr/transcribe",
files={"file": ("audio.wav", f, "audio/wav")},
data={"context": "", "language": ""})
print(resp.json())
"""
import argparse
import io
import json
import logging
import uuid
from typing import Optional
import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from qwen_asr import Qwen3ASRModel
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
app = FastAPI(title="Qwen3-ASR Unified API")
# ── Global singletons ─────────────────────────────────────────────────────────
asr_model: Optional[Qwen3ASRModel] = None  # loaded once in startup_event
server_args = None  # argparse.Namespace; assigned in the __main__ guard before uvicorn starts
@app.on_event("startup")
async def startup_event():
    """Load the ASR model exactly once at process start (vLLM backend).

    Fills the module-level ``asr_model`` singleton shared by both the
    non-streaming and streaming endpoints.
    """
    global asr_model
    logger.info(f"Loading ASR model from {server_args.asr_model_path} ...")
    loaded = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        max_new_tokens=server_args.max_new_tokens,
    )
    # Assign only after construction completes.
    asr_model = loaded
    logger.info("Model loaded successfully.")
# ── 非流式端点(HTTP POST multipart)───────────────────────────────────────────
@app.post("/asr/transcribe")
async def transcribe_endpoint(
file: UploadFile = File(..., description="音频文件wav/mp3/flac 等 soundfile 支持的格式)"),
context: str = Form(default="", description="可选上下文"),
language: str = Form(default="", description="可选强制语言,如 Chinese / English"),
):
"""
非流式整段转写。
以 multipart/form-data 上传音频文件,返回最终转写文本。
curl 示例:
curl -X POST http://localhost:8000/asr/transcribe \\
-F "file=@audio.wav" -F "context=" -F "language="
"""
try:
raw = await file.read()
with io.BytesIO(raw) as buf:
wav, sr = sf.read(buf, dtype="float32", always_2d=False)
wav = np.asarray(wav, dtype=np.float32)
results = asr_model.transcribe(
audio=(wav, sr),
context=context,
language=language.strip() or None,
)
r = results[0]
return {"language": r.language, "text": r.text}
except Exception as e:
logger.error(f"Transcribe error: {e}", exc_info=True)
return JSONResponse(status_code=500, content={"error": str(e)})
# ── 流式端点(WebSocket)───────────────────────────────────────────────────────
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
"""
流式增量转写。
协议:
客户端 → bytes : float32 PCM 16kHz 音频块
客户端 → text : {"command": "finish"} 结束会话
服务端 → text : {"session_id": ..., "language": ..., "text": ..., "is_final": bool}
"""
await websocket.accept()
session_id = uuid.uuid4().hex
logger.info(f"Stream session started: {session_id}")
state = asr_model.init_streaming_state(
unfixed_chunk_num=server_args.unfixed_chunk_num,
unfixed_token_num=server_args.unfixed_token_num,
chunk_size_sec=server_args.chunk_size_sec,
)
try:
while True:
message = await websocket.receive()
if "bytes" in message:
raw = message["bytes"]
if len(raw) % 4 != 0:
await websocket.send_json({"error": "Data length must be multiple of 4 bytes (float32)"})
continue
wav = np.frombuffer(raw, dtype=np.float32)
asr_model.streaming_transcribe(wav, state)
await websocket.send_json({
"session_id": session_id,
"language": state.language or "",
"text": state.text or "",
"is_final": False,
})
elif "text" in message:
try:
cmd = json.loads(message["text"])
if cmd.get("command") == "finish":
logger.info(f"Finish command received: {session_id}")
break
except json.JSONDecodeError:
logger.warning(f"Invalid JSON: {message['text']}")
except WebSocketDisconnect:
logger.info(f"Client disconnected: {session_id}")
except Exception as e:
logger.error(f"Error in session {session_id}: {e}", exc_info=True)
try:
await websocket.send_json({"error": str(e)})
except Exception:
pass
finally:
try:
asr_model.finish_streaming_transcribe(state)
await websocket.send_json({
"session_id": session_id,
"language": state.language or "",
"text": state.text or "",
"is_final": True,
})
logger.info(f"Final result sent: {session_id}")
except Exception as e:
logger.error(f"Error finishing session {session_id}: {e}")
try:
await websocket.close()
except Exception:
pass
logger.info(f"Session closed: {session_id}")
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args():
    """Parse CLI options shared by the streaming and non-streaming endpoints."""
    ap = argparse.ArgumentParser(description="Qwen3-ASR Unified API (streaming + non-streaming)")
    ap.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--port", type=int, default=8000)
    ap.add_argument("--gpu-memory-utilization", type=float, default=0.8)
    ap.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Max new tokens per call (streaming). Use larger value for non-streaming.",
    )
    ap.add_argument("--unfixed-chunk-num", type=int, default=4)
    ap.add_argument("--unfixed-token-num", type=int, default=5)
    ap.add_argument("--chunk-size-sec", type=float, default=1.0)
    return ap.parse_args()
if __name__ == "__main__":
server_args = parse_args()
uvicorn.run(app, host=server_args.host, port=server_args.port)