537 lines
18 KiB
Python
537 lines
18 KiB
Python
# coding=utf-8
|
||
# Copyright 2026 The Alibaba Qwen team.
|
||
# SPDX-License-Identifier: Apache-2.0
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""
|
||
A gradio demo for Qwen3 ASR models.
|
||
"""
|
||
|
||
import argparse
|
||
import base64
|
||
import io
|
||
import json
|
||
import os
|
||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||
|
||
import gradio as gr
|
||
import numpy as np
|
||
import torch
|
||
from qwen_asr import Qwen3ASRModel
|
||
from scipy.io.wavfile import write as wav_write
|
||
|
||
|
||
def _title_case_display(s: str) -> str:
|
||
s = (s or "").strip()
|
||
s = s.replace("_", " ")
|
||
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
||
|
||
|
||
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
|
||
if not items:
|
||
return [], {}
|
||
display = [_title_case_display(x) for x in items]
|
||
mapping = {d: r for d, r in zip(display, items)}
|
||
return display, mapping
|
||
|
||
|
||
def _dtype_from_str(s: str) -> torch.dtype:
|
||
s = (s or "").strip().lower()
|
||
if s in ("bf16", "bfloat16"):
|
||
return torch.bfloat16
|
||
if s in ("fp16", "float16", "half"):
|
||
return torch.float16
|
||
if s in ("fp32", "float32"):
|
||
return torch.float32
|
||
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
|
||
|
||
|
||
def _normalize_audio(wav, eps=1e-12, clip=True):
|
||
x = np.asarray(wav)
|
||
|
||
if np.issubdtype(x.dtype, np.integer):
|
||
info = np.iinfo(x.dtype)
|
||
if info.min < 0:
|
||
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
||
else:
|
||
mid = (info.max + 1) / 2.0
|
||
y = (x.astype(np.float32) - mid) / mid
|
||
elif np.issubdtype(x.dtype, np.floating):
|
||
y = x.astype(np.float32)
|
||
m = np.max(np.abs(y)) if y.size else 0.0
|
||
if m > 1.0 + 1e-6:
|
||
y = y / (m + eps)
|
||
else:
|
||
raise TypeError(f"Unsupported dtype: {x.dtype}")
|
||
|
||
if clip:
|
||
y = np.clip(y, -1.0, 1.0)
|
||
|
||
if y.ndim > 1:
|
||
y = np.mean(y, axis=-1).astype(np.float32)
|
||
|
||
return y
|
||
|
||
|
||
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
|
||
"""
|
||
Accept gradio audio:
|
||
- {"sampling_rate": int, "data": np.ndarray}
|
||
- (sr, np.ndarray) [some gradio versions]
|
||
Return: (wav_float32_mono, sr)
|
||
"""
|
||
if audio is None:
|
||
return None
|
||
|
||
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
|
||
sr = int(audio["sampling_rate"])
|
||
wav = _normalize_audio(audio["data"])
|
||
return wav, sr
|
||
|
||
if isinstance(audio, tuple) and len(audio) == 2:
|
||
a0, a1 = audio
|
||
if isinstance(a0, int):
|
||
sr = int(a0)
|
||
wav = _normalize_audio(a1)
|
||
return wav, sr
|
||
if isinstance(a1, int):
|
||
wav = _normalize_audio(a0)
|
||
sr = int(a1)
|
||
return wav, sr
|
||
|
||
return None
|
||
|
||
|
||
def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
|
||
if audio is None:
|
||
raise ValueError("Audio is required.")
|
||
at = _audio_to_tuple(audio)
|
||
if at is not None:
|
||
return at
|
||
raise ValueError("Unsupported audio input format.")
|
||
|
||
|
||
def build_parser() -> argparse.ArgumentParser:
|
||
parser = argparse.ArgumentParser(
|
||
prog="qwen-asr-demo",
|
||
description=(
|
||
"Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
|
||
"Examples:\n"
|
||
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
|
||
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
|
||
" qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
|
||
" qwen-asr-demo --backend transformers --backend-kwargs '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
|
||
" qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
|
||
),
|
||
formatter_class=argparse.RawTextHelpFormatter,
|
||
add_help=True,
|
||
)
|
||
|
||
parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
|
||
parser.add_argument(
|
||
"--aligner-checkpoint",
|
||
default=None,
|
||
help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
|
||
)
|
||
|
||
parser.add_argument(
|
||
"--backend",
|
||
default="transformers",
|
||
choices=["transformers", "vllm"],
|
||
help="Backend for ASR model loading (default: transformers).",
|
||
)
|
||
|
||
parser.add_argument(
|
||
"--cuda-visible-devices",
|
||
default="0",
|
||
help=(
|
||
"Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
|
||
"Use e.g. '0' or '1'"
|
||
),
|
||
)
|
||
|
||
parser.add_argument(
|
||
"--backend-kwargs",
|
||
default=None,
|
||
help=(
|
||
"JSON dict for backend-specific kwargs excluding checkpoints.\n"
|
||
"Examples:\n"
|
||
" transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
|
||
" vllm : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
|
||
),
|
||
)
|
||
parser.add_argument(
|
||
"--aligner-kwargs",
|
||
default=None,
|
||
help=(
|
||
"JSON dict for forced aligner kwargs (only used when --aligner-checkpoint is set).\n"
|
||
"Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
|
||
),
|
||
)
|
||
|
||
# Gradio server args
|
||
parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
|
||
parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
|
||
parser.add_argument(
|
||
"--share/--no-share",
|
||
dest="share",
|
||
default=False,
|
||
action=argparse.BooleanOptionalAction,
|
||
help="Whether to create a public Gradio link (default: disabled).",
|
||
)
|
||
parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")
|
||
|
||
# HTTPS args
|
||
parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
|
||
parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
|
||
parser.add_argument(
|
||
"--ssl-verify/--no-ssl-verify",
|
||
dest="ssl_verify",
|
||
default=True,
|
||
action=argparse.BooleanOptionalAction,
|
||
help="Whether to verify SSL certificate (default: enabled).",
|
||
)
|
||
|
||
return parser
|
||
|
||
|
||
def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
|
||
if s is None or not str(s).strip():
|
||
return {}
|
||
try:
|
||
obj = json.loads(s)
|
||
except Exception as e:
|
||
raise ValueError(f"Invalid JSON for {name}: {e}")
|
||
if not isinstance(obj, dict):
|
||
raise ValueError(f"{name} must be a JSON object (dict).")
|
||
return obj
|
||
|
||
|
||
def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
|
||
v = (cuda_visible_devices or "").strip()
|
||
if not v:
|
||
return
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = v
|
||
|
||
|
||
def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
|
||
if backend == "transformers":
|
||
return dict(
|
||
dtype=torch.bfloat16,
|
||
device_map="cuda:0",
|
||
max_inference_batch_size=4,
|
||
max_new_tokens=512,
|
||
)
|
||
else:
|
||
return dict(
|
||
gpu_memory_utilization=0.8,
|
||
max_inference_batch_size=4,
|
||
max_new_tokens=4096,
|
||
)
|
||
|
||
|
||
def _default_aligner_kwargs() -> Dict[str, Any]:
|
||
return dict(
|
||
dtype=torch.bfloat16,
|
||
device_map="cuda:0",
|
||
)
|
||
|
||
|
||
def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||
out = dict(base)
|
||
out.update(override)
|
||
return out
|
||
|
||
|
||
def _coerce_special_types(d: Dict[str, Any]) -> Dict[str, Any]:
|
||
out: Dict[str, Any] = {}
|
||
for k, v in d.items():
|
||
if k == "dtype" and isinstance(v, str):
|
||
out[k] = _dtype_from_str(v)
|
||
else:
|
||
out[k] = v
|
||
return out
|
||
|
||
|
||
def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str:
|
||
"""
|
||
Build HTML with per-token audio slices, using base64 data URLs.
|
||
Expect timestamps as list[dict] with keys: text, start_time, end_time (ms).
|
||
"""
|
||
at = _audio_to_tuple(audio_upload)
|
||
if at is None:
|
||
raise ValueError("Audio input is required for visualization.")
|
||
audio, sr = at
|
||
|
||
if not timestamps:
|
||
return "<div style='color:#666'>No timestamps to visualize.</div>"
|
||
if not isinstance(timestamps, list):
|
||
raise ValueError("Timestamps must be a list (JSON array).")
|
||
|
||
html_content = """
|
||
<style>
|
||
.word-alignment-container { display: flex; flex-wrap: wrap; gap: 10px; }
|
||
.word-box {
|
||
border: 1px solid #ddd; border-radius: 8px; padding: 10px;
|
||
background-color: #f9f9f9; box-shadow: 0 2px 4px rgba(0,0,0,0.06);
|
||
text-align: center;
|
||
}
|
||
.word-text { font-size: 18px; font-weight: 700; margin-bottom: 5px; }
|
||
.word-time { font-size: 12px; color: #666; margin-bottom: 8px; }
|
||
.word-audio audio { width: 140px; height: 30px; }
|
||
details { border: 1px solid #ddd; border-radius: 6px; padding: 10px; background-color: #f7f7f7; }
|
||
summary { font-weight: 700; cursor: pointer; }
|
||
</style>
|
||
"""
|
||
|
||
html_content += """
|
||
<details open>
|
||
<summary>Timestamps Visualization (时间戳可视化结果)</summary>
|
||
<div class="word-alignment-container" style="margin-top: 14px;">
|
||
"""
|
||
|
||
for item in timestamps:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
word = str(item.get("text", "") or "")
|
||
start = item.get("start_time", None)
|
||
end = item.get("end_time", None)
|
||
if start is None or end is None:
|
||
continue
|
||
|
||
start = float(start)
|
||
end = float(end)
|
||
if end <= start:
|
||
continue
|
||
|
||
start_sample = max(0, int(start * sr))
|
||
end_sample = min(len(audio), int(end * sr))
|
||
if end_sample <= start_sample:
|
||
continue
|
||
|
||
seg = audio[start_sample:end_sample]
|
||
seg_i16 = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)
|
||
|
||
mem = io.BytesIO()
|
||
wav_write(mem, sr, seg_i16)
|
||
mem.seek(0)
|
||
b64 = base64.b64encode(mem.read()).decode("utf-8")
|
||
audio_src = f"data:audio/wav;base64,{b64}"
|
||
|
||
html_content += f"""
|
||
<div class="word-box">
|
||
<div class="word-text">{word}</div>
|
||
<div class="word-time">{start} - {end} s</div>
|
||
<div class="word-audio">
|
||
<audio controls preload="none" src="{audio_src}"></audio>
|
||
</div>
|
||
</div>
|
||
"""
|
||
|
||
html_content += "</div></details>"
|
||
return html_content
|
||
|
||
|
||
def build_demo(
|
||
asr: Qwen3ASRModel,
|
||
asr_ckpt: str,
|
||
backend: str,
|
||
aligner_ckpt: Optional[str] = None,
|
||
) -> gr.Blocks:
|
||
supported_langs_raw = asr.get_supported_languages()
|
||
lang_choices_disp, lang_map = _build_choices_and_map([x for x in supported_langs_raw])
|
||
lang_choices = ["Auto"] + lang_choices_disp
|
||
|
||
has_aligner = bool(aligner_ckpt)
|
||
|
||
theme = gr.themes.Soft(
|
||
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
|
||
)
|
||
css = ".gradio-container {max-width: none !important;}"
|
||
|
||
with gr.Blocks(theme=theme, css=css) as demo:
|
||
gr.Markdown(
|
||
f"""
|
||
# Qwen3 ASR Demo
|
||
**Backend:** `{backend}`
|
||
**ASR Checkpoint:** `{asr_ckpt}`
|
||
**Forced Aligner:** `{aligner_ckpt if aligner_ckpt else "(none)"}`
|
||
"""
|
||
)
|
||
|
||
with gr.Row():
|
||
with gr.Column(scale=2):
|
||
audio_in = gr.Audio(label="Audio Input (上传音频)", type="numpy")
|
||
lang_in = gr.Dropdown(
|
||
label="Language (语种)",
|
||
choices=lang_choices,
|
||
value="Auto",
|
||
interactive=True,
|
||
)
|
||
if has_aligner:
|
||
ts_in = gr.Checkbox(
|
||
label="Return Timestamps (是否返回时间戳)",
|
||
value=True,
|
||
)
|
||
else:
|
||
ts_in = gr.State(False)
|
||
|
||
btn = gr.Button("Transcribe (识别)", variant="primary")
|
||
|
||
with gr.Column(scale=2):
|
||
out_lang = gr.Textbox(label="Detected Language", lines=1)
|
||
out_text = gr.Textbox(label="Result Text", lines=12)
|
||
|
||
if has_aligner:
|
||
with gr.Column(scale=3):
|
||
out_ts = gr.JSON(label="Timestamps(时间戳结果)")
|
||
viz_btn = gr.Button("Visualize Timestamps (可视化时间戳)", variant="secondary")
|
||
else:
|
||
with gr.Column(scale=3):
|
||
out_ts = gr.State(None)
|
||
viz_btn = gr.State(None)
|
||
|
||
# Put the visualization panel below the three columns
|
||
if has_aligner:
|
||
with gr.Row():
|
||
out_ts_html = gr.HTML(label="Timestamps Visualization (时间戳可视化结果)")
|
||
else:
|
||
out_ts_html = gr.State("")
|
||
|
||
def run(audio_upload: Any, lang_disp: str, return_ts: bool):
|
||
audio_obj = _parse_audio_any(audio_upload)
|
||
|
||
language = None
|
||
if lang_disp and lang_disp != "Auto":
|
||
language = lang_map.get(lang_disp, lang_disp)
|
||
|
||
return_ts = bool(return_ts) and has_aligner
|
||
|
||
results = asr.transcribe(
|
||
audio=audio_obj,
|
||
language=language,
|
||
return_time_stamps=return_ts,
|
||
)
|
||
if not isinstance(results, list) or len(results) != 1:
|
||
raise RuntimeError(
|
||
f"Unexpected result size: {type(results)} "
|
||
f"len={len(results) if isinstance(results, list) else 'N/A'}"
|
||
)
|
||
|
||
r = results[0]
|
||
|
||
if has_aligner:
|
||
ts_payload = None
|
||
if return_ts:
|
||
ts_payload = [
|
||
dict(
|
||
text=getattr(t, "text", None),
|
||
start_time=getattr(t, "start_time", None),
|
||
end_time=getattr(t, "end_time", None),
|
||
)
|
||
for t in (getattr(r, "time_stamps", None) or [])
|
||
]
|
||
return (
|
||
getattr(r, "language", "") or "",
|
||
getattr(r, "text", "") or "",
|
||
gr.update(value=ts_payload) if return_ts else gr.update(value=None),
|
||
gr.update(value=""), # clear html on each transcribe
|
||
)
|
||
else:
|
||
return (
|
||
getattr(r, "language", "") or "",
|
||
getattr(r, "text", "") or "",
|
||
)
|
||
|
||
def visualize(audio_upload: Any, timestamps_json: Any):
|
||
return _make_timestamp_html(audio_upload, timestamps_json)
|
||
|
||
if has_aligner:
|
||
btn.click(
|
||
run,
|
||
inputs=[audio_in, lang_in, ts_in],
|
||
outputs=[out_lang, out_text, out_ts, out_ts_html],
|
||
)
|
||
viz_btn.click(
|
||
visualize,
|
||
inputs=[audio_in, out_ts],
|
||
outputs=[out_ts_html],
|
||
)
|
||
else:
|
||
btn.click(
|
||
run,
|
||
inputs=[audio_in, lang_in, ts_in],
|
||
outputs=[out_lang, out_text],
|
||
)
|
||
|
||
return demo
|
||
|
||
|
||
def main(argv=None) -> int:
|
||
parser = build_parser()
|
||
args = parser.parse_args(argv)
|
||
|
||
_apply_cuda_visible_devices(args.cuda_visible_devices)
|
||
|
||
backend = args.backend
|
||
asr_ckpt = args.asr_checkpoint
|
||
aligner_ckpt = args.aligner_checkpoint
|
||
|
||
user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
|
||
user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")
|
||
|
||
backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
|
||
backend_kwargs = _coerce_special_types(backend_kwargs)
|
||
|
||
forced_aligner = None
|
||
forced_aligner_kwargs = None
|
||
if aligner_ckpt:
|
||
forced_aligner = aligner_ckpt
|
||
aligner_kwargs = _merge_dicts(_default_aligner_kwargs(), user_aligner_kwargs)
|
||
forced_aligner_kwargs = _coerce_special_types(aligner_kwargs)
|
||
|
||
if backend == "transformers":
|
||
asr = Qwen3ASRModel.from_pretrained(
|
||
asr_ckpt,
|
||
forced_aligner=forced_aligner,
|
||
forced_aligner_kwargs=forced_aligner_kwargs,
|
||
**backend_kwargs,
|
||
)
|
||
else:
|
||
asr = Qwen3ASRModel.LLM(
|
||
model=asr_ckpt,
|
||
forced_aligner=forced_aligner,
|
||
forced_aligner_kwargs=forced_aligner_kwargs,
|
||
**backend_kwargs,
|
||
)
|
||
|
||
demo = build_demo(asr, asr_ckpt, backend, aligner_ckpt=aligner_ckpt)
|
||
|
||
launch_kwargs: Dict[str, Any] = dict(
|
||
server_name=args.ip,
|
||
server_port=args.port,
|
||
share=args.share,
|
||
ssl_verify=True if args.ssl_verify else False,
|
||
)
|
||
if args.ssl_certfile is not None:
|
||
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
||
if args.ssl_keyfile is not None:
|
||
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
||
|
||
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
raise SystemExit(main())
|