129 lines
5.0 KiB
Python
129 lines
5.0 KiB
Python
# coding=utf-8
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import uvicorn
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from qwen_asr import Qwen3ASRModel
|
|
|
|
# Configure logging for the whole process (root logger): timestamped,
# name- and level-tagged lines on stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ASGI application object; uvicorn serves this instance (see __main__).
app = FastAPI(title="Qwen3-ASR Streaming API")

# Global ASR model instance
# Populated by the FastAPI startup hook; None until the server has started.
asr_model: Optional[Qwen3ASRModel] = None
# CLI options (argparse.Namespace). Set in the __main__ block BEFORE uvicorn
# runs, because the startup hook and the WebSocket endpoint both read it.
# NOTE(review): stays None when launched via `uvicorn module:app` — confirm
# the intended launch path is this module's __main__.
server_args = None
|
@app.on_event("startup")  # NOTE(review): on_event is deprecated in newer FastAPI; consider a lifespan handler
async def startup_event():
    """Load the global ASR model once, when the server process starts.

    Reads the module-level ``server_args`` (populated by the ``__main__``
    entry point) and stores the loaded model in the module-level
    ``asr_model`` used by the WebSocket endpoint.

    Raises:
        RuntimeError: if ``server_args`` was never set — e.g. the app was
            launched directly via ``uvicorn module:app`` instead of this
            module's ``__main__`` block. The original code crashed here with
            an opaque ``AttributeError`` on ``None``; fail fast with a clear
            message instead.
    """
    global asr_model
    if server_args is None:
        raise RuntimeError(
            "server_args is not set; start the server via this module's "
            "__main__ entry point so CLI arguments are parsed first."
        )
    logger.info(f"Loading ASR model from {server_args.asr_model_path}...")
    # Using vLLM backend for streaming
    asr_model = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        max_new_tokens=server_args.max_new_tokens,
    )
    logger.info("Model loaded successfully.")
|
|
|
|
async def _send_transcript(websocket: WebSocket, session_id: str, state, is_final: bool) -> None:
    """Send the current transcription snapshot for one session to the client.

    ``state`` is the model's opaque streaming-state object; ``language`` and
    ``text`` are read defensively because their presence depends on the model
    implementation (presumably set by streaming_transcribe — not visible here).
    """
    await websocket.send_json({
        "session_id": session_id,
        "language": getattr(state, "language", "") or "",
        "text": getattr(state, "text", "") or "",
        "is_final": is_final,
    })


@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Stream audio over a WebSocket and emit incremental transcriptions.

    Protocol:
      * Binary frames: little-endian Float32 PCM samples (16 kHz per the
        original comment). Each frame yields one intermediate result
        (``is_final: false``).
      * Text frames: JSON commands; ``{"command": "finish"}`` ends the
        session, after which one final result (``is_final: true``) is sent
        and the socket is closed.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info(f"New session started: {session_id}")

    # Initialize streaming state for this session
    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )

    try:
        while True:
            # Receive a raw ASGI message from the client.
            message = await websocket.receive()

            # FIX: Starlette's WebSocket.receive() does NOT raise
            # WebSocketDisconnect — it returns a raw message whose "type" is
            # "websocket.disconnect". The original code matched neither the
            # "bytes" nor the "text" branch and then crashed with a
            # RuntimeError on the next receive() call. Break out cleanly so
            # the finally-block can finish and close the session.
            if message.get("type") == "websocket.disconnect":
                logger.info(f"Client disconnected: {session_id}")
                break

            # ASGI receive messages may carry the key with a None value, so
            # test the value, not mere key presence.
            if message.get("bytes") is not None:
                # Binary audio data (Float32, 16kHz)
                raw_bytes = message["bytes"]
                if len(raw_bytes) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (Float32)"})
                    continue

                # Zero-copy, read-only view over the received frame.
                # NOTE(review): frombuffer arrays are immutable — assumes the
                # model does not mutate its input in place; confirm.
                wav = np.frombuffer(raw_bytes, dtype=np.float32)

                # Feed the chunk to the incremental decoder (updates `state`).
                asr_model.streaming_transcribe(wav, state)

                # Send back current intermediate transcription
                await _send_transcript(websocket, session_id, state, is_final=False)

            elif message.get("text") is not None:
                # Command message
                try:
                    text_data = json.loads(message["text"])
                    if text_data.get("command") == "finish":
                        logger.info(f"Finish command received for session: {session_id}")
                        break
                except json.JSONDecodeError:
                    logger.warning(f"Received invalid JSON text: {message['text']}")

    except WebSocketDisconnect:
        # Raised by send_json()/other typed helpers if the peer vanishes.
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in session {session_id}: {e}", exc_info=True)
        # Best-effort error report; the socket may already be gone.
        # (Removed a leftover debug print(e) — the logger call above covers it.)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            pass
    finally:
        # Finish transcription and send final results
        try:
            asr_model.finish_streaming_transcribe(state)
            await _send_transcript(websocket, session_id, state, is_final=True)
            logger.info(f"Sent final result for session: {session_id}")
        except Exception as e:
            # Expected to fail when the client already disconnected.
            logger.error(f"Error while finishing session {session_id}: {e}")

        try:
            await websocket.close()
        except Exception:
            pass
        logger.info(f"Session closed: {session_id}")
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the streaming ASR server.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, preserving the original
            call-site behavior; passing an explicit list makes the parser
            unit-testable without touching the process arguments.

    Returns:
        argparse.Namespace with the server configuration.
    """
    p = argparse.ArgumentParser(description="Qwen3-ASR Streaming API (vLLM backend)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    p.add_argument("--host", default="0.0.0.0", help="Bind host")
    p.add_argument("--port", type=int, default=8000, help="Bind port")
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8, help="vLLM GPU memory utilization")
    p.add_argument("--max-new-tokens", type=int, default=32, help="Max new tokens to generate per streaming call. Small value is recommended for low latency.")
    p.add_argument("--unfixed-chunk-num", type=int, default=4, help="Number of unfixed chunks in streaming")
    p.add_argument("--unfixed-token-num", type=int, default=5, help="Number of unfixed tokens in streaming")
    p.add_argument("--chunk-size-sec", type=float, default=1.0, help="Size of each chunk in seconds")
    return p.parse_args(argv)
|
|
|
|
if __name__ == "__main__":
    # Populate the module-level server_args BEFORE uvicorn starts the app:
    # the FastAPI startup hook and the WebSocket endpoint both read it.
    server_args = parse_args()
    # Note: Use uvicorn to run the FastAPI app
    uvicorn.run(app, host=server_args.host, port=server_args.port)
|