# coding=utf-8
"""WebSocket streaming speech-recognition server built on Qwen3-ASR.

Protocol of the ``/asr/stream`` endpoint, as implemented below:

* Binary frames carry raw audio samples decoded as float32 (the client is
  expected to send 16 kHz audio). Each frame is answered with a JSON message
  holding the current intermediate transcription (``is_final`` is ``False``).
* Text frames carry JSON commands. ``{"command": "finish"}`` ends the session;
  the server then sends one last JSON message with ``is_final`` set to
  ``True`` and closes the socket.
* Problems are reported as ``{"error": "<message>"}``.
"""
import argparse
import json
import logging
import uuid
from typing import Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from qwen_asr import Qwen3ASRModel

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI(title="Qwen3-ASR Streaming API")

# Global ASR model instance, created once in ``startup_event``.
asr_model: Optional[Qwen3ASRModel] = None
# Parsed command-line options; populated by the ``__main__`` guard below.
server_args = None


@app.on_event("startup")
async def startup_event():
    """Load the ASR model once when the application starts.

    Raises:
        RuntimeError: if ``server_args`` has not been populated, which happens
            when the app is imported instead of being run as a script.
    """
    global asr_model
    if server_args is None:
        # Without this guard the first attribute access below fails with an
        # opaque "'NoneType' object has no attribute ..." error.
        raise RuntimeError(
            "server_args is not set; run this file as a script or assign "
            "server_args = parse_args() before starting the application."
        )

    logger.info("Loading ASR model from %s...", server_args.asr_model_path)
    # Using vLLM backend for streaming
    asr_model = Qwen3ASRModel.LLM(
        model=server_args.asr_model_path,
        gpu_memory_utilization=server_args.gpu_memory_utilization,
        max_new_tokens=server_args.max_new_tokens,
    )
    logger.info("Model loaded successfully.")


def _result_payload(session_id, state, is_final):
    """Build the JSON-serialisable transcription message for one session."""
    return {
        "session_id": session_id,
        # A missing or None attribute is reported as an empty string.
        "language": getattr(state, "language", "") or "",
        "text": getattr(state, "text", "") or "",
        "is_final": is_final,
    }


@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one streaming transcription session over a WebSocket.

    Audio frames are fed to the model as they arrive and every frame is
    answered with the current intermediate transcription. The session ends on
    a ``finish`` command, a client disconnect, or an unexpected error; in all
    cases the streaming state is finalised, and the final result is sent only
    while the client is still connected.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info("New session started: %s", session_id)

    # Initialize streaming state for this session
    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )

    # Tracks whether it still makes sense to send to / close the socket.
    client_connected = True

    try:
        while True:
            message = await websocket.receive()

            # The low-level receive() hands back the disconnect message rather
            # than raising WebSocketDisconnect; calling receive() again after
            # it would raise RuntimeError, so leave the loop here.
            if message.get("type") == "websocket.disconnect":
                client_connected = False
                logger.info("Client disconnected: %s", session_id)
                break

            # A frame may carry both keys with one of them set to None, so
            # test the value rather than the presence of the key.
            raw_bytes = message.get("bytes")
            if raw_bytes is not None:
                # Binary audio data (Float32, 16kHz)
                if len(raw_bytes) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (Float32)"})
                    continue

                wav = np.frombuffer(raw_bytes, dtype=np.float32)

                # NOTE(review): this call is synchronous, so the event loop is
                # blocked while it runs. Offloading it to a worker thread would
                # need confirmation that the model is safe to call that way.
                asr_model.streaming_transcribe(wav, state)

                # Send back current intermediate transcription
                await websocket.send_json(_result_payload(session_id, state, is_final=False))
                continue

            text = message.get("text")
            if text is None:
                continue

            # Command message
            try:
                text_data = json.loads(text)
            except json.JSONDecodeError:
                logger.warning("Received invalid JSON text: %s", text)
                continue

            # Valid JSON that is not an object (e.g. a list) has no .get();
            # ignore it instead of failing the whole session.
            if isinstance(text_data, dict) and text_data.get("command") == "finish":
                logger.info("Finish command received for session: %s", session_id)
                break

    except WebSocketDisconnect:
        # Raised when a send hits a connection the client has already dropped.
        client_connected = False
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        logger.error("Error in session %s: %s", session_id, e, exc_info=True)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            # The error could not be delivered, so treat the peer as gone.
            client_connected = False
            logger.debug("Could not report error to client for session %s", session_id, exc_info=True)
    finally:
        # Finish transcription and, if possible, send final results
        try:
            asr_model.finish_streaming_transcribe(state)
            if client_connected:
                await websocket.send_json(_result_payload(session_id, state, is_final=True))
                logger.info("Sent final result for session: %s", session_id)
        except Exception as e:
            logger.error("Error while finishing session %s: %s", session_id, e)

        if client_connected:
            try:
                await websocket.close()
            except Exception:
                # Best effort: the socket may already be closed by the peer.
                logger.debug("WebSocket already closed for session %s", session_id, exc_info=True)
        logger.info("Session closed: %s", session_id)


def parse_args():
    """Parse the command-line options of the streaming server.

    Returns:
        argparse.Namespace: model path, bind address and streaming parameters.
    """
    p = argparse.ArgumentParser(description="Qwen3-ASR Streaming API (vLLM backend)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    p.add_argument("--host", default="0.0.0.0", help="Bind host")
    p.add_argument("--port", type=int, default=8000, help="Bind port")
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8, help="vLLM GPU memory utilization")
    p.add_argument("--max-new-tokens", type=int, default=32,
                   help="Max new tokens to generate per streaming call. "
                        "Small value is recommended for low latency.")
    p.add_argument("--unfixed-chunk-num", type=int, default=4, help="Number of unfixed chunks in streaming")
    p.add_argument("--unfixed-token-num", type=int, default=5, help="Number of unfixed tokens in streaming")
    p.add_argument("--chunk-size-sec", type=float, default=1.0, help="Size of each chunk in seconds")
    return p.parse_args()


if __name__ == "__main__":
    server_args = parse_args()
    # Note: Use uvicorn to run the FastAPI app
    uvicorn.run(app, host=server_args.host, port=server_args.port)