# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Gradio demo for Qwen3 ASR models."""

import argparse
import base64
import io
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import torch
from qwen_asr import Qwen3ASRModel
from scipy.io.wavfile import write as wav_write


def _title_case_display(s: str) -> str:
    """Render a raw option name like "my_option" as "My Option" for display."""
    s = (s or "").strip()
    s = s.replace("_", " ")
    return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])


def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
    """Build display choices plus a display-name -> raw-value mapping."""
    if not items:
        return [], {}
    display = [_title_case_display(x) for x in items]
    mapping = {d: r for d, r in zip(display, items)}
    return display, mapping


def _dtype_from_str(s: str) -> torch.dtype:
    s = (s or "").strip().lower()
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16
    if s in ("fp32", "float32"):
        return torch.float32
    raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")


def _normalize_audio(wav, eps=1e-12, clip=True):
    """Convert integer or float PCM of any shape to mono float32 in [-1, 1]."""
    x = np.asarray(wav)
    if np.issubdtype(x.dtype, np.integer):
        info = np.iinfo(x.dtype)
        if info.min < 0:  # signed PCM (e.g. int16): scale by the full range
            y = x.astype(np.float32) / max(abs(info.min), info.max)
        else:  # unsigned PCM (e.g. uint8): re-center around zero first
            mid = (info.max + 1) / 2.0
            y = (x.astype(np.float32) - mid) / mid
    elif np.issubdtype(x.dtype, np.floating):
        y = x.astype(np.float32)
        m = np.max(np.abs(y)) if y.size else 0.0
        if m > 1.0 + 1e-6:  # only rescale if the signal exceeds [-1, 1]
            y = y / (m + eps)
    else:
        raise TypeError(f"Unsupported dtype: {x.dtype}")
    if clip:
        y = np.clip(y, -1.0, 1.0)
    if y.ndim > 1:  # downmix multi-channel audio to mono (average over the last axis)
        y = np.mean(y, axis=-1).astype(np.float32)
    return y


def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
    """
    Accept Gradio audio in either form:
      - {"sampling_rate": int, "data": np.ndarray}
      - (sr, np.ndarray) or (np.ndarray, sr)  [varies across Gradio versions]
    Return (wav_float32_mono, sr), or None if the format is not recognized.
    """
    if audio is None:
        return None
    if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
        sr = int(audio["sampling_rate"])
        wav = _normalize_audio(audio["data"])
        return wav, sr
    if isinstance(audio, tuple) and len(audio) == 2:
        a0, a1 = audio
        if isinstance(a0, int):
            return _normalize_audio(a1), int(a0)
        if isinstance(a1, int):
            return _normalize_audio(a0), int(a1)
    return None


def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
    if audio is None:
        raise ValueError("Audio is required.")
    if isinstance(audio, str):  # file path: pass through (matches the declared str return type)
        return audio
    at = _audio_to_tuple(audio)
    if at is not None:
        return at
    raise ValueError("Unsupported audio input format.")
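
# A quick, illustrative sanity check for the helpers above (not part of the
# original module; shown only to document expected behavior):
#
#     >>> _normalize_audio(np.array([0, 16384, -32768], dtype=np.int16)).tolist()
#     [0.0, 0.5, -1.0]
#     >>> _audio_to_tuple((16000, np.zeros(4, dtype=np.int16)))[1]
#     16000
#
# Full-scale signed PCM maps onto [-1.0, 1.0], and both tuple orderings
# resolve to the same (float32 mono, sr) result.
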
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="qwen-asr-demo",
        description=(
            "Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
            "Examples:\n"
            "  qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
            "  qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
            "  qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
            "  qwen-asr-demo --backend transformers --backend-kwargs "
            "'{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
            "  qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
        add_help=True,
    )
    parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
    parser.add_argument(
        "--aligner-checkpoint",
        default=None,
        help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
    )
    parser.add_argument(
        "--backend",
        default="transformers",
        choices=["transformers", "vllm"],
        help="Backend for ASR model loading (default: transformers).",
    )
    parser.add_argument(
        "--cuda-visible-devices",
        default="0",
        help=(
            "Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
            "Use e.g. '0' or '1'."
        ),
    )
    parser.add_argument(
        "--backend-kwargs",
        default=None,
        help=(
            "JSON dict of backend-specific kwargs, excluding checkpoints.\n"
            "Examples:\n"
            "  transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
            "  vllm        : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
        ),
    )
    parser.add_argument(
        "--aligner-kwargs",
        default=None,
        help=(
            "JSON dict of forced-aligner kwargs (only used when --aligner-checkpoint is set).\n"
            "Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
        ),
    )
    # Gradio server args
    parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
    parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
    parser.add_argument(
        "--share",  # BooleanOptionalAction generates the matching --no-share flag
        dest="share",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to create a public Gradio link (default: disabled).",
    )
    parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")
    # HTTPS args
    parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
    parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
    parser.add_argument(
        "--ssl-verify",  # BooleanOptionalAction generates the matching --no-ssl-verify flag
        dest="ssl_verify",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to verify the SSL certificate (default: enabled).",
    )
    return parser


def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
    if s is None or not str(s).strip():
        return {}
    try:
        obj = json.loads(s)
    except Exception as e:
        raise ValueError(f"Invalid JSON for {name}: {e}") from e
    if not isinstance(obj, dict):
        raise ValueError(f"{name} must be a JSON object (dict).")
    return obj


def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
    v = (cuda_visible_devices or "").strip()
    if not v:
        return
    os.environ["CUDA_VISIBLE_DEVICES"] = v


def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
    if backend == "transformers":
        return dict(
            dtype=torch.bfloat16,
            device_map="cuda:0",
            max_inference_batch_size=4,
            max_new_tokens=512,
        )
    return dict(
        gpu_memory_utilization=0.8,
        max_inference_batch_size=4,
        max_new_tokens=4096,
    )


def _default_aligner_kwargs() -> Dict[str, Any]:
    return dict(
        dtype=torch.bfloat16,
        device_map="cuda:0",
    )


def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    out = dict(base)
    out.update(override)
    return out
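
# Illustrative only (not in the original module): how a user-supplied
# --backend-kwargs JSON string flows through the helpers above. The JSON is
# parsed, merged over the backend defaults, and dtype strings are coerced to
# torch dtypes by _coerce_special_types (defined just below):
#
#     >>> user = _parse_json_dict('{"dtype": "float16"}', name="--backend-kwargs")
#     >>> merged = _merge_dicts(_default_backend_kwargs("transformers"), user)
#     >>> _coerce_special_types(merged)["dtype"]
#     torch.float16
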
"dtype" and isinstance(v, str): out[k] = _dtype_from_str(v) else: out[k] = v return out def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str: """ Build HTML with per-token audio slices, using base64 data URLs. Expect timestamps as list[dict] with keys: text, start_time, end_time (ms). """ at = _audio_to_tuple(audio_upload) if at is None: raise ValueError("Audio input is required for visualization.") audio, sr = at if not timestamps: return "