Initial commit

Xiong Wang
2026-01-29 20:23:50 +08:00
commit 9567667698
32 changed files with 30029 additions and 0 deletions

qwen_asr/__init__.py Normal file

@@ -0,0 +1,25 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
qwen_asr: Qwen3-ASR package.
"""
from .inference.qwen3_asr import Qwen3ASRModel
from .inference.qwen3_forced_aligner import Qwen3ForcedAligner
from .inference.utils import parse_asr_output
__all__ = ["__version__"]

qwen_asr/__main__.py Normal file

@@ -0,0 +1,26 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def main():
print(
"qwen_asr package.\n"
"Use CLI entrypoints:\n"
" - qwen-asr-demo\n"
" - qwen-asr-demo-streaming\n"
" - qwen-asr-serve\n"
)
if __name__ == "__main__":
main()

qwen_asr/cli/demo.py Normal file

@@ -0,0 +1,536 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A gradio demo for Qwen3 ASR models.
"""
import argparse
import base64
import io
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import gradio as gr
import numpy as np
import torch
from qwen_asr import Qwen3ASRModel
from scipy.io.wavfile import write as wav_write
def _title_case_display(s: str) -> str:
s = (s or "").strip()
s = s.replace("_", " ")
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
if not items:
return [], {}
display = [_title_case_display(x) for x in items]
mapping = {d: r for d, r in zip(display, items)}
return display, mapping
def _dtype_from_str(s: str) -> torch.dtype:
s = (s or "").strip().lower()
if s in ("bf16", "bfloat16"):
return torch.bfloat16
if s in ("fp16", "float16", "half"):
return torch.float16
if s in ("fp32", "float32"):
return torch.float32
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
def _normalize_audio(wav, eps=1e-12, clip=True):
x = np.asarray(wav)
if np.issubdtype(x.dtype, np.integer):
info = np.iinfo(x.dtype)
if info.min < 0:
y = x.astype(np.float32) / max(abs(info.min), info.max)
else:
mid = (info.max + 1) / 2.0
y = (x.astype(np.float32) - mid) / mid
elif np.issubdtype(x.dtype, np.floating):
y = x.astype(np.float32)
m = np.max(np.abs(y)) if y.size else 0.0
if m > 1.0 + 1e-6:
y = y / (m + eps)
else:
raise TypeError(f"Unsupported dtype: {x.dtype}")
if clip:
y = np.clip(y, -1.0, 1.0)
if y.ndim > 1:
y = np.mean(y, axis=-1).astype(np.float32)
return y
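# A quick check of _normalize_audio's contract: integer PCM is scaled into
# [-1, 1] and multi-channel input is averaged to mono. For example:
#
#     _normalize_audio(np.array([0, 16384, -32768], dtype=np.int16))
#     # -> array([ 0. ,  0.5, -1. ], dtype=float32)   (int16 divided by 32768)
#     _normalize_audio(np.array([[1.0, 0.0], [0.5, 0.5]], dtype=np.float32))
#     # -> array([0.5, 0.5], dtype=float32)           (stereo averaged to mono)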
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
"""
Accept gradio audio:
- {"sampling_rate": int, "data": np.ndarray}
- (sr, np.ndarray) [some gradio versions]
Return: (wav_float32_mono, sr)
"""
if audio is None:
return None
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
sr = int(audio["sampling_rate"])
wav = _normalize_audio(audio["data"])
return wav, sr
if isinstance(audio, tuple) and len(audio) == 2:
a0, a1 = audio
if isinstance(a0, int):
sr = int(a0)
wav = _normalize_audio(a1)
return wav, sr
if isinstance(a1, int):
wav = _normalize_audio(a0)
sr = int(a1)
return wav, sr
return None
def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
if audio is None:
raise ValueError("Audio is required.")
at = _audio_to_tuple(audio)
if at is not None:
return at
raise ValueError("Unsupported audio input format.")
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="qwen-asr-demo",
description=(
"Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
"Examples:\n"
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
" qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
" qwen-asr-demo --backend transformers --backend-kwargs '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
" qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
),
formatter_class=argparse.RawTextHelpFormatter,
add_help=True,
)
parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
parser.add_argument(
"--aligner-checkpoint",
default=None,
help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
)
parser.add_argument(
"--backend",
default="transformers",
choices=["transformers", "vllm"],
help="Backend for ASR model loading (default: transformers).",
)
parser.add_argument(
"--cuda-visible-devices",
default="0",
help=(
"Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
"Use e.g. '0' or '1'"
),
)
parser.add_argument(
"--backend-kwargs",
default=None,
help=(
"JSON dict for backend-specific kwargs excluding checkpoints.\n"
"Examples:\n"
" transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
" vllm : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
),
)
parser.add_argument(
"--aligner-kwargs",
default=None,
help=(
"JSON dict for forced aligner kwargs (only used when --aligner-checkpoint is set).\n"
"Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
),
)
# Gradio server args
parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
    parser.add_argument(
        "--share",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to create a public Gradio link (default: disabled).",
    )
parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")
# HTTPS args
parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
    parser.add_argument(
        "--ssl-verify",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to verify SSL certificate (default: enabled).",
    )
return parser
def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
if s is None or not str(s).strip():
return {}
try:
obj = json.loads(s)
except Exception as e:
raise ValueError(f"Invalid JSON for {name}: {e}")
if not isinstance(obj, dict):
raise ValueError(f"{name} must be a JSON object (dict).")
return obj
def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
v = (cuda_visible_devices or "").strip()
if not v:
return
os.environ["CUDA_VISIBLE_DEVICES"] = v
def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
if backend == "transformers":
return dict(
dtype=torch.bfloat16,
device_map="cuda:0",
max_inference_batch_size=4,
max_new_tokens=512,
)
else:
return dict(
gpu_memory_utilization=0.8,
max_inference_batch_size=4,
max_new_tokens=4096,
)
def _default_aligner_kwargs() -> Dict[str, Any]:
return dict(
dtype=torch.bfloat16,
device_map="cuda:0",
)
def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
out = dict(base)
out.update(override)
return out
def _coerce_special_types(d: Dict[str, Any]) -> Dict[str, Any]:
out: Dict[str, Any] = {}
for k, v in d.items():
if k == "dtype" and isinstance(v, str):
out[k] = _dtype_from_str(v)
else:
out[k] = v
return out
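# How these helpers compose (mirrors main() below; the JSON string is an
# example value):
#
#     user = _parse_json_dict('{"dtype": "float16", "max_new_tokens": 256}', name="--backend-kwargs")
#     kwargs = _coerce_special_types(_merge_dicts(_default_backend_kwargs("transformers"), user))
#     kwargs["dtype"]           # torch.float16 (coerced from the string)
#     kwargs["max_new_tokens"]  # 256 (user value overrides the default 512)
#     kwargs["device_map"]      # "cuda:0" (untouched default)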
def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str:
"""
Build HTML with per-token audio slices, using base64 data URLs.
    Expect timestamps as list[dict] with keys: text, start_time, end_time (in seconds).
"""
at = _audio_to_tuple(audio_upload)
if at is None:
raise ValueError("Audio input is required for visualization.")
audio, sr = at
if not timestamps:
return "<div style='color:#666'>No timestamps to visualize.</div>"
if not isinstance(timestamps, list):
raise ValueError("Timestamps must be a list (JSON array).")
html_content = """
<style>
.word-alignment-container { display: flex; flex-wrap: wrap; gap: 10px; }
.word-box {
border: 1px solid #ddd; border-radius: 8px; padding: 10px;
background-color: #f9f9f9; box-shadow: 0 2px 4px rgba(0,0,0,0.06);
text-align: center;
}
.word-text { font-size: 18px; font-weight: 700; margin-bottom: 5px; }
.word-time { font-size: 12px; color: #666; margin-bottom: 8px; }
.word-audio audio { width: 140px; height: 30px; }
details { border: 1px solid #ddd; border-radius: 6px; padding: 10px; background-color: #f7f7f7; }
summary { font-weight: 700; cursor: pointer; }
</style>
"""
html_content += """
<details open>
<summary>Timestamps Visualization (时间戳可视化结果)</summary>
<div class="word-alignment-container" style="margin-top: 14px;">
"""
for item in timestamps:
if not isinstance(item, dict):
continue
word = str(item.get("text", "") or "")
start = item.get("start_time", None)
end = item.get("end_time", None)
if start is None or end is None:
continue
start = float(start)
end = float(end)
if end <= start:
continue
start_sample = max(0, int(start * sr))
end_sample = min(len(audio), int(end * sr))
if end_sample <= start_sample:
continue
seg = audio[start_sample:end_sample]
seg_i16 = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)
mem = io.BytesIO()
wav_write(mem, sr, seg_i16)
mem.seek(0)
b64 = base64.b64encode(mem.read()).decode("utf-8")
audio_src = f"data:audio/wav;base64,{b64}"
html_content += f"""
<div class="word-box">
<div class="word-text">{word}</div>
<div class="word-time">{start} - {end} s</div>
<div class="word-audio">
<audio controls preload="none" src="{audio_src}"></audio>
</div>
</div>
"""
html_content += "</div></details>"
return html_content
def build_demo(
asr: Qwen3ASRModel,
asr_ckpt: str,
backend: str,
aligner_ckpt: Optional[str] = None,
) -> gr.Blocks:
supported_langs_raw = asr.get_supported_languages()
    lang_choices_disp, lang_map = _build_choices_and_map(list(supported_langs_raw))
lang_choices = ["Auto"] + lang_choices_disp
has_aligner = bool(aligner_ckpt)
theme = gr.themes.Soft(
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
)
css = ".gradio-container {max-width: none !important;}"
with gr.Blocks(theme=theme, css=css) as demo:
gr.Markdown(
f"""
# Qwen3 ASR Demo
**Backend:** `{backend}`
**ASR Checkpoint:** `{asr_ckpt}`
**Forced Aligner:** `{aligner_ckpt if aligner_ckpt else "(none)"}`
"""
)
with gr.Row():
with gr.Column(scale=2):
audio_in = gr.Audio(label="Audio Input (上传音频)", type="numpy")
lang_in = gr.Dropdown(
label="Language (语种)",
choices=lang_choices,
value="Auto",
interactive=True,
)
if has_aligner:
ts_in = gr.Checkbox(
label="Return Timestamps (是否返回时间戳)",
value=True,
)
else:
ts_in = gr.State(False)
btn = gr.Button("Transcribe (识别)", variant="primary")
with gr.Column(scale=2):
out_lang = gr.Textbox(label="Detected Language", lines=1)
out_text = gr.Textbox(label="Result Text", lines=12)
if has_aligner:
with gr.Column(scale=3):
                    out_ts = gr.JSON(label="Timestamps (时间戳结果)")
viz_btn = gr.Button("Visualize Timestamps (可视化时间戳)", variant="secondary")
else:
with gr.Column(scale=3):
out_ts = gr.State(None)
viz_btn = gr.State(None)
# Put the visualization panel below the three columns
if has_aligner:
with gr.Row():
out_ts_html = gr.HTML(label="Timestamps Visualization (时间戳可视化结果)")
else:
out_ts_html = gr.State("")
def run(audio_upload: Any, lang_disp: str, return_ts: bool):
audio_obj = _parse_audio_any(audio_upload)
language = None
if lang_disp and lang_disp != "Auto":
language = lang_map.get(lang_disp, lang_disp)
return_ts = bool(return_ts) and has_aligner
results = asr.transcribe(
audio=audio_obj,
language=language,
return_time_stamps=return_ts,
)
if not isinstance(results, list) or len(results) != 1:
raise RuntimeError(
f"Unexpected result size: {type(results)} "
f"len={len(results) if isinstance(results, list) else 'N/A'}"
)
r = results[0]
if has_aligner:
ts_payload = None
if return_ts:
ts_payload = [
dict(
text=getattr(t, "text", None),
start_time=getattr(t, "start_time", None),
end_time=getattr(t, "end_time", None),
)
for t in (getattr(r, "time_stamps", None) or [])
]
return (
getattr(r, "language", "") or "",
getattr(r, "text", "") or "",
gr.update(value=ts_payload) if return_ts else gr.update(value=None),
gr.update(value=""), # clear html on each transcribe
)
else:
return (
getattr(r, "language", "") or "",
getattr(r, "text", "") or "",
)
def visualize(audio_upload: Any, timestamps_json: Any):
return _make_timestamp_html(audio_upload, timestamps_json)
if has_aligner:
btn.click(
run,
inputs=[audio_in, lang_in, ts_in],
outputs=[out_lang, out_text, out_ts, out_ts_html],
)
viz_btn.click(
visualize,
inputs=[audio_in, out_ts],
outputs=[out_ts_html],
)
else:
btn.click(
run,
inputs=[audio_in, lang_in, ts_in],
outputs=[out_lang, out_text],
)
return demo
def main(argv=None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
_apply_cuda_visible_devices(args.cuda_visible_devices)
backend = args.backend
asr_ckpt = args.asr_checkpoint
aligner_ckpt = args.aligner_checkpoint
user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")
backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
backend_kwargs = _coerce_special_types(backend_kwargs)
forced_aligner = None
forced_aligner_kwargs = None
if aligner_ckpt:
forced_aligner = aligner_ckpt
aligner_kwargs = _merge_dicts(_default_aligner_kwargs(), user_aligner_kwargs)
forced_aligner_kwargs = _coerce_special_types(aligner_kwargs)
if backend == "transformers":
asr = Qwen3ASRModel.from_pretrained(
asr_ckpt,
forced_aligner=forced_aligner,
forced_aligner_kwargs=forced_aligner_kwargs,
**backend_kwargs,
)
else:
asr = Qwen3ASRModel.LLM(
model=asr_ckpt,
forced_aligner=forced_aligner,
forced_aligner_kwargs=forced_aligner_kwargs,
**backend_kwargs,
)
demo = build_demo(asr, asr_ckpt, backend, aligner_ckpt=aligner_ckpt)
launch_kwargs: Dict[str, Any] = dict(
server_name=args.ip,
server_port=args.port,
share=args.share,
        ssl_verify=bool(args.ssl_verify),
)
if args.ssl_certfile is not None:
launch_kwargs["ssl_certfile"] = args.ssl_certfile
if args.ssl_keyfile is not None:
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
return 0
if __name__ == "__main__":
raise SystemExit(main())
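# The demo can also be wired up programmatically; a sketch equivalent to the
# CLI defaults above (the checkpoint id is a placeholder):
#
#     import torch
#     from qwen_asr import Qwen3ASRModel
#     from qwen_asr.cli.demo import build_demo
#
#     asr = Qwen3ASRModel.from_pretrained(
#         "Qwen/Qwen3-ASR-1.7B", dtype=torch.bfloat16, device_map="cuda:0"
#     )
#     demo = build_demo(asr, "Qwen/Qwen3-ASR-1.7B", backend="transformers")
#     demo.queue(default_concurrency_limit=16).launch(server_name="0.0.0.0", server_port=8000)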


@@ -0,0 +1,507 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Minimal web demo for Qwen3ASRModel Streaming Inference (vLLM backend).
Install:
pip install qwen-asr[vllm]
Run:
python streaming/demo_qwen3_asr_vllm_streaming.py
Open:
http://127.0.0.1:7860
"""
import argparse
import time
import uuid
from dataclasses import dataclass
from typing import Dict, Optional
import numpy as np
from flask import Flask, Response, jsonify, request
from qwen_asr import Qwen3ASRModel
@dataclass
class Session:
state: object
created_at: float
last_seen: float
app = Flask(__name__)
# Populated in main() before the server starts; defaults mirror the CLI flags.
asr: Optional[Qwen3ASRModel] = None
UNFIXED_CHUNK_NUM: int = 4
UNFIXED_TOKEN_NUM: int = 5
CHUNK_SIZE_SEC: float = 1.0
SESSIONS: Dict[str, Session] = {}
SESSION_TTL_SEC = 10 * 60
def _gc_sessions():
now = time.time()
dead = [sid for sid, s in SESSIONS.items() if now - s.last_seen > SESSION_TTL_SEC]
for sid in dead:
try:
asr.finish_streaming_transcribe(SESSIONS[sid].state)
except Exception:
pass
SESSIONS.pop(sid, None)
def _get_session(session_id: str) -> Optional[Session]:
_gc_sessions()
s = SESSIONS.get(session_id)
if s:
s.last_seen = time.time()
return s
INDEX_HTML = r"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>Qwen3-ASR Streaming</title>
<style>
:root{
--bg:#ffffff;
--card:#ffffff;
--muted:#5b6472;
--text:#0f172a;
--border:#e5e7eb;
--ok:#059669;
--warn:#d97706;
--danger:#e11d48;
}
html, body { height: 100%; }
body{
margin:0;
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Noto Sans";
background: var(--bg);
color:var(--text);
}
.wrap{
height: 100vh;
max-width: none;
margin: 0;
padding: 16px;
box-sizing: border-box;
display: flex;
}
.card{
width: 100%;
height: 100%;
background: var(--card);
border:1px solid var(--border);
border-radius: 14px;
padding: 16px;
box-sizing: border-box;
box-shadow: 0 10px 30px rgba(0,0,0,.06);
display: flex;
flex-direction: column;
gap: 12px;
min-height: 0;
}
h1{ font-size: 16px; margin: 0; letter-spacing:.2px;}
.row{ display:flex; gap:12px; align-items:center; flex-wrap: wrap; }
button{
border:1px solid var(--border); border-radius: 12px;
padding: 10px 14px; cursor:pointer; color:var(--text);
background: #f8fafc;
transition: transform .05s ease, background .15s ease, border-color .15s ease;
font-weight: 700;
}
button:hover{ background: #f1f5f9; border-color:#cbd5e1; }
button:active{ transform: translateY(1px); }
button.primary{ border-color: rgba(5,150,105,.35); background: rgba(5,150,105,.10); }
button.danger{ border-color: rgba(225,29,72,.35); background: rgba(225,29,72,.10); }
button:disabled{ opacity:.5; cursor:not-allowed; }
.pill{
font-size: 12px; padding: 6px 10px; border-radius: 999px;
border:1px solid var(--border); color: var(--muted);
background: #f8fafc;
user-select:none;
}
.pill.ok{ color: #065f46; border-color: rgba(5,150,105,.35); background: rgba(5,150,105,.10); }
.pill.warn{ color: #92400e; border-color: rgba(217,119,6,.35); background: rgba(217,119,6,.10); }
.pill.err{ color: #9f1239; border-color: rgba(225,29,72,.35); background: rgba(225,29,72,.10); }
.panel{
border:1px solid var(--border);
border-radius: 12px;
background: #ffffff;
padding: 12px;
}
.panel.textpanel{
flex: 1;
display: flex;
flex-direction: column;
min-height: 0;
}
.label{ color:var(--muted); font-size: 12px; margin-bottom: 6px; }
.mono{ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New"; }
#text{
flex: 1;
min-height: 0;
white-space: pre-wrap;
line-height: 1.6;
font-size: 15px;
padding: 12px;
border-radius: 12px;
border: 1px solid var(--border);
background: #f8fafc;
overflow: auto;
}
a{ color: #2563eb; text-decoration:none; }
</style>
</head>
<body>
<div class="wrap">
<div class="card">
<h1>Qwen3-ASR Streaming</h1>
<div class="row">
<button id="btnStart" class="primary">Start / 开始</button>
<button id="btnStop" class="danger" disabled>Stop / 停止</button>
<span id="status" class="pill warn">Idle / 未开始</span>
<a href="javascript:void(0)" id="btnClear" class="mono" style="margin-left:auto;">Clear / 清空</a>
</div>
<div class="panel">
<div class="label">Language / 语言</div>
<div id="lang" class="mono">—</div>
</div>
<div class="panel textpanel">
<div class="label">Text / 文本</div>
<div id="text"></div>
</div>
</div>
</div>
<script>
(() => {
const $ = (id) => document.getElementById(id);
const btnStart = $("btnStart");
const btnStop = $("btnStop");
const btnClear = $("btnClear");
const statusEl = $("status");
const langEl = $("lang");
const textEl = $("text");
const CHUNK_MS = 500;
const TARGET_SR = 16000;
let audioCtx = null;
let processor = null;
let source = null;
let mediaStream = null;
let sessionId = null;
let running = false;
let buf = new Float32Array(0);
let pushing = false;
function setStatus(text, cls){
statusEl.textContent = text;
statusEl.className = "pill " + (cls || "");
}
function lockUI(on){
btnStart.disabled = on;
btnStop.disabled = !on;
}
function concatFloat32(a, b){
const out = new Float32Array(a.length + b.length);
out.set(a, 0);
out.set(b, a.length);
return out;
}
function resampleLinear(input, srcSr, dstSr){
if (srcSr === dstSr) return input;
const ratio = dstSr / srcSr;
const outLen = Math.max(0, Math.round(input.length * ratio));
const out = new Float32Array(outLen);
for (let i = 0; i < outLen; i++){
const x = i / ratio;
const x0 = Math.floor(x);
const x1 = Math.min(x0 + 1, input.length - 1);
const t = x - x0;
out[i] = input[x0] * (1 - t) + input[x1] * t;
}
return out;
}
async function apiStart(){
const r = await fetch("/api/start", {method:"POST"});
if(!r.ok) throw new Error(await r.text());
const j = await r.json();
sessionId = j.session_id;
}
async function apiPushChunk(float32_16k){
const r = await fetch("/api/chunk?session_id=" + encodeURIComponent(sessionId), {
method: "POST",
headers: {"Content-Type":"application/octet-stream"},
body: float32_16k.buffer
});
if(!r.ok) throw new Error(await r.text());
return await r.json();
}
async function apiFinish(){
const r = await fetch("/api/finish?session_id=" + encodeURIComponent(sessionId), {method:"POST"});
if(!r.ok) throw new Error(await r.text());
return await r.json();
}
btnClear.onclick = () => { textEl.textContent = ""; };
async function stopAudioPipeline(){
try{
if (processor){ processor.disconnect(); processor.onaudioprocess = null; }
if (source) source.disconnect();
if (audioCtx) await audioCtx.close();
if (mediaStream) mediaStream.getTracks().forEach(t => t.stop());
}catch(e){}
processor = null; source = null; audioCtx = null; mediaStream = null;
}
btnStart.onclick = async () => {
if (running) return;
textEl.textContent = "";
langEl.textContent = "";
buf = new Float32Array(0);
try{
setStatus("Starting… / 启动中…", "warn");
lockUI(true);
await apiStart();
mediaStream = await navigator.mediaDevices.getUserMedia({
audio: {
channelCount: 1,
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true
},
video: false
});
audioCtx = new (window.AudioContext || window.webkitAudioContext)();
source = audioCtx.createMediaStreamSource(mediaStream);
processor = audioCtx.createScriptProcessor(4096, 1, 1);
processor.onaudioprocess = (e) => {
if (!running) return;
const input = e.inputBuffer.getChannelData(0);
const resampled = resampleLinear(input, audioCtx.sampleRate, TARGET_SR);
buf = concatFloat32(buf, resampled);
if (!pushing) pump();
};
source.connect(processor);
processor.connect(audioCtx.destination);
running = true;
setStatus("Listening… / 识别中…", "ok");
}catch(err){
console.error(err);
setStatus("Start failed / 启动失败: " + err.message, "err");
lockUI(false);
running = false;
sessionId = null;
await stopAudioPipeline();
}
};
async function pump(){
if (pushing) return;
pushing = true;
const chunkSamples = Math.round(TARGET_SR * (CHUNK_MS / 1000));
try{
while (running && buf.length >= chunkSamples){
const chunk = buf.slice(0, chunkSamples);
buf = buf.slice(chunkSamples);
const j = await apiPushChunk(chunk);
langEl.textContent = j.language || "";
textEl.textContent = j.text || "";
if (running) setStatus("Listening… / 识别中…", "ok");
}
}catch(err){
console.error(err);
if (running) setStatus("Backend error / 后端错误: " + err.message, "err");
}finally{
pushing = false;
}
}
btnStop.onclick = async () => {
if (!running) return;
running = false;
setStatus("Finishing… / 收尾中…", "warn");
lockUI(false);
await stopAudioPipeline();
try{
if (sessionId){
const j = await apiFinish();
langEl.textContent = j.language || "";
textEl.textContent = j.text || "";
}
setStatus("Stopped / 已停止", "");
}catch(err){
console.error(err);
setStatus("Finish failed / 收尾失败: " + err.message, "err");
}finally{
sessionId = null;
buf = new Float32Array(0);
pushing = false;
}
};
})();
</script>
</body>
</html>
"""
@app.get("/")
def index():
return Response(INDEX_HTML, mimetype="text/html; charset=utf-8")
@app.post("/api/start")
def api_start():
session_id = uuid.uuid4().hex
state = asr.init_streaming_state(
unfixed_chunk_num=UNFIXED_CHUNK_NUM,
unfixed_token_num=UNFIXED_TOKEN_NUM,
chunk_size_sec=CHUNK_SIZE_SEC,
)
now = time.time()
SESSIONS[session_id] = Session(state=state, created_at=now, last_seen=now)
return jsonify({"session_id": session_id})
@app.post("/api/chunk")
def api_chunk():
session_id = request.args.get("session_id", "")
s = _get_session(session_id)
if not s:
return jsonify({"error": "invalid session_id"}), 400
if request.mimetype != "application/octet-stream":
return jsonify({"error": "expect application/octet-stream"}), 400
raw = request.get_data(cache=False)
if len(raw) % 4 != 0:
return jsonify({"error": "float32 bytes length not multiple of 4"}), 400
wav = np.frombuffer(raw, dtype=np.float32).reshape(-1)
asr.streaming_transcribe(wav, s.state)
return jsonify(
{
"language": getattr(s.state, "language", "") or "",
"text": getattr(s.state, "text", "") or "",
}
)
@app.post("/api/finish")
def api_finish():
session_id = request.args.get("session_id", "")
s = _get_session(session_id)
if not s:
return jsonify({"error": "invalid session_id"}), 400
asr.finish_streaming_transcribe(s.state)
out = {
"language": getattr(s.state, "language", "") or "",
"text": getattr(s.state, "text", "") or "",
}
SESSIONS.pop(session_id, None)
return jsonify(out)
def parse_args():
p = argparse.ArgumentParser(description="Qwen3-ASR Streaming Web Demo (vLLM backend)")
p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
p.add_argument("--host", default="0.0.0.0", help="Bind host")
p.add_argument("--port", type=int, default=8000, help="Bind port")
p.add_argument("--gpu-memory-utilization", type=float, default=0.8, help="vLLM GPU memory utilization")
p.add_argument("--unfixed-chunk-num", type=int, default=4)
p.add_argument("--unfixed-token-num", type=int, default=5)
p.add_argument("--chunk-size-sec", type=float, default=1.0)
return p.parse_args()
def main():
args = parse_args()
global asr
global UNFIXED_CHUNK_NUM
global UNFIXED_TOKEN_NUM
global CHUNK_SIZE_SEC
UNFIXED_CHUNK_NUM = args.unfixed_chunk_num
UNFIXED_TOKEN_NUM = args.unfixed_token_num
CHUNK_SIZE_SEC = args.chunk_size_sec
asr = Qwen3ASRModel.LLM(
model=args.asr_model_path,
gpu_memory_utilization=args.gpu_memory_utilization,
max_new_tokens=32,
)
print("Model loaded.")
app.run(host=args.host, port=args.port, debug=False, use_reloader=False, threaded=True)
if __name__ == "__main__":
main()
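# The HTTP protocol above can be driven from any client: POST /api/start for a
# session id, stream raw little-endian float32 mono 16 kHz bytes to /api/chunk,
# then POST /api/finish. A minimal client sketch (assumes the `requests`
# package and a server on the default port):
#
#     import numpy as np
#     import requests
#
#     BASE = "http://127.0.0.1:8000"
#     sid = requests.post(f"{BASE}/api/start").json()["session_id"]
#     wav = np.zeros(16000 * 3, dtype=np.float32)  # placeholder: 3 s of silence
#     step = 16000 // 2                            # 500 ms chunks, like the web UI
#     for i in range(0, len(wav), step):
#         r = requests.post(
#             f"{BASE}/api/chunk",
#             params={"session_id": sid},
#             headers={"Content-Type": "application/octet-stream"},
#             data=wav[i : i + step].tobytes(),
#         )
#         print(r.json()["text"])
#     print(requests.post(f"{BASE}/api/finish", params={"session_id": sid}).json())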

qwen_asr/cli/serve.py Normal file

@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
try:
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
from vllm import ModelRegistry
ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
except Exception as e:
    raise ImportError(
        "vLLM is not available. To use qwen-asr-serve, install it with: pip install qwen-asr[vllm]"
    ) from e
from vllm.entrypoints.cli.main import main as vllm_main
def main():
sys.argv.insert(1, "serve")
vllm_main()
if __name__ == "__main__":
main()
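# main() prepends "serve" and hands off to vLLM's CLI, so this behaves like
# `vllm serve` with the Qwen3-ASR architecture registered. Because the model
# implements SupportsTranscription (see the vLLM backend below), vLLM's
# OpenAI-compatible transcription route should apply. A hedged client sketch,
# assuming the server runs on the default port and exposes that route:
#
#     import requests
#
#     with open("sample.wav", "rb") as f:  # placeholder audio file
#         resp = requests.post(
#             "http://localhost:8000/v1/audio/transcriptions",
#             files={"file": ("sample.wav", f, "audio/wav")},
#             data={"model": "Qwen/Qwen3-ASR-1.7B"},
#         )
#     print(resp.json())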

qwen_asr/core/transformers_backend/__init__.py Normal file

@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_qwen3_asr import Qwen3ASRConfig
from .modeling_qwen3_asr import Qwen3ASRForConditionalGeneration
from .processing_qwen3_asr import Qwen3ASRProcessor

qwen_asr/core/transformers_backend/configuration_qwen3_asr.py Normal file

@@ -0,0 +1,425 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen3ASRAudioEncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3ASRAudioEncoder`]. It is used to instantiate a
Qwen3-ASR audio encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
architecture.
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_mel_bins (`int`, *optional*, defaults to 128):
Number of mel features used per input features. Should correspond to the value used in the
`Qwen3ASRProcessor` class.
encoder_layers (`int`, *optional*, defaults to 32):
Number of encoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 20):
Number of attention heads for each attention layer in the Transformer encoder.
encoder_ffn_dim (`int`, *optional*, defaults to 5120):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
d_model (`int`, *optional*, defaults to 1280):
Dimensionality of the layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_function (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
max_source_positions (`int`, *optional*, defaults to 1500):
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
n_window (`int`, *optional*, defaults to 100):
The chunk for conv and flash attn in AudioEncoder.
output_dim (`int`, *optional*, defaults to 3584):
The output dimension of AudioEncoder.
Example:
```python
>>> from transformers import Qwen3ASRAudioEncoderConfig, Qwen3ASRAudioEncoder
>>> # Initializing a Qwen3ASRAudioEncoderConfig
>>> configuration = Qwen3ASRAudioEncoderConfig()
>>> # Initializing a Qwen3ASRAudioEncoder (with random weights)
>>> model = Qwen3ASRAudioEncoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_asr_audio_encoder"
def __init__(
self,
num_mel_bins=128,
encoder_layers=32,
encoder_attention_heads=20,
encoder_ffn_dim=5120,
d_model=1280,
        dropout=0.0,
        attention_dropout=0.0,
        activation_function="gelu",
        activation_dropout=0.0,
scale_embedding=False,
initializer_range=0.02,
max_source_positions=1500,
n_window=100,
output_dim=3584,
n_window_infer=400,
conv_chunksize=500,
downsample_hidden_size=480,
**kwargs,
):
super().__init__(**kwargs)
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.num_hidden_layers = encoder_layers
self.initializer_range = initializer_range
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.n_window = n_window
self.output_dim = output_dim
self.n_window_infer = n_window_infer
self.conv_chunksize = conv_chunksize
self.downsample_hidden_size = downsample_hidden_size
class Qwen3ASRTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3ASRTextModel`]. It is used to instantiate a
Qwen3-ASR model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-ASR-1.7B [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen3ASR model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen3ASRModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
head_dim (`int`, *optional*, defaults to 128):
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 5000000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen3ASRTextModel, Qwen3ASRTextConfig
>>> # Initializing a Qwen3ASR style configuration
>>> configuration = Qwen3ASRTextConfig()
    >>> # Initializing a model from the Qwen3-ASR-1.7B style configuration
>>> model = Qwen3ASRTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_asr_text"
base_config_key = "text_config"
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
head_dim=128,
hidden_act="silu",
max_position_embeddings=128000,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=5000000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3ASRThinkerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3ASRThinker`]. It is used to instantiate a
Qwen3-ASR-Thinker model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the thinker component of the Qwen3-Omni
architecture.
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
audio_config (`dict`, *optional*):
The config dictionary of the audio backbone.
text_config (`dict`, *optional*):
The config dictionary of the text backbone.
audio_token_id (`int`, *optional*, defaults to 151646):
The audio token id to encode the audio prompt.
audio_start_token_id (`int`, *optional*, defaults to 151647):
The audio start token id to encode the audio prompt.
user_token_id (`int`, *optional*, defaults to 872):
The user token id to encode the user token.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Qwen3ASRThinkerModel, Qwen3ASRThinkerConfig
>>> # Initializing a default Qwen3ASRThinkerConfig
>>> configuration = Qwen3ASRThinkerConfig()
>>> # Initializing a model (with random weights) from the default configuration
>>> model = Qwen3ASRThinkerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_asr_thinker"
attribute_map = {}
sub_configs = {
"audio_config": Qwen3ASRAudioEncoderConfig,
"text_config": Qwen3ASRTextConfig,
}
def __init__(
self,
audio_config=None,
text_config=None,
audio_token_id=151646,
audio_start_token_id=151647,
user_token_id=872,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.user_token_id = user_token_id
self.audio_start_token_id = audio_start_token_id
self.initializer_range = initializer_range
if isinstance(audio_config, dict):
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = Qwen3ASRAudioEncoderConfig()
self.audio_config = audio_config
if isinstance(text_config, dict):
text_config = Qwen3ASRTextConfig(**text_config)
elif text_config is None:
text_config = Qwen3ASRTextConfig()
self.text_config = text_config
self.audio_token_id = audio_token_id
class Qwen3ASRConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`Qwen3ASRForConditionalGeneration`]. It is used to instantiate a Qwen3ASR
model according to the specified sub-models configurations, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the
[Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
support_languages (`List[str]`, *optional*): The languages supported by the model.
Example:
```python
>>> from transformers import (
... Qwen3ASRThinkerConfig,
... Qwen3ASRForConditionalGeneration,
... Qwen3ASRConfig,
... )
>>> # Initializing a Qwen3ASR style configuration
>>> configuration = Qwen3ASRConfig()
>>> # Initializing a model from the configuration
>>> model = Qwen3ASRForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_asr"
sub_configs = {
"thinker_config": Qwen3ASRThinkerConfig,
}
def __init__(
self,
thinker_config=None,
support_languages=None,
**kwargs,
):
super().__init__(**kwargs)
if thinker_config is None:
thinker_config = {}
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
self.support_languages = support_languages
def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
If set to `True`, then only search for decoder config names.
"""
# Overridden for deeply nested config like Qwen2.5-Omni. We don't have any omni model
# except for Qwen yet. This has to be generalized if more deeply nested configs are
# added. NOTE: currently method used only by vLLM
return self.thinker_config.get_text_config()
__all__ = ["Qwen3ASRConfig", "Qwen3ASRThinkerConfig", "Qwen3ASRAudioEncoderConfig"]

qwen_asr/core/transformers_backend/modeling_qwen3_asr.py: file diff suppressed because it is too large

qwen_asr/core/transformers_backend/processing_qwen3_asr.py Normal file

@@ -0,0 +1,209 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils_base import TextInput
class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"padding_side": "left",
},
"audio_kwargs": {
"sampling_rate": 16000,
"padding": True,
"return_attention_mask": True,
},
}
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
return output_lengths
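# Worked examples: each full 100-frame block of mel features (~1 s of 16 kHz
# audio at the 10 ms hop used by WhisperFeatureExtractor) contributes 13
# encoder outputs, and the remainder passes through three stride-2 reductions:
#
#     _get_feat_extract_output_lengths(np.array(100))   # -> 13
#     _get_feat_extract_output_lengths(np.array(250))   # -> 33
#     _get_feat_extract_output_lengths(np.array(1000))  # -> 130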
class Qwen3ASRProcessor(ProcessorMixin):
r"""
Constructs a Qwen3ASR processor.
[`Qwen3ASRProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
[`~Qwen3ASRProcessor.__call__`] and [`~Qwen3ASRProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The text tokenizer.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
"""
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self, feature_extractor=None, tokenizer=None, chat_template=None
):
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
self.audio_token = self.tokenizer.audio_token
self.audio_bos_token = self.tokenizer.audio_bos_token
self.audio_eos_token = self.tokenizer.audio_eos_token
def __call__(
self,
text: TextInput = None,
audio: AudioInput = None,
**kwargs,
) -> BatchFeature:
"""
        Main method to prepare one or several sequence(s) and audio(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
        WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the docstring
        of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
"""
if text is None:
raise ValueError("You need to specify either a `text` input to process.")
output_kwargs = self._merge_kwargs(
Qwen3ASRProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if audio is not None:
output_kwargs["audio_kwargs"]["padding"] = True
output_kwargs["audio_kwargs"]["truncation"] = False
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename feature_attention_mask to prevent conflicts later on
audio_inputs["input_features"] = audio_inputs.pop(
"input_features"
) # rename input_features to prevent conflicts later on
audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1)))
else:
audio_inputs = {}
audio_lengths = iter([])
if not isinstance(text, list):
text = [text]
text = self.replace_multimodal_special_tokens(
text,
audio_lengths,
)
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(
data={**texts_inputs, **audio_inputs},
tensor_type=kwargs.get("return_tensors"),
)
def replace_multimodal_special_tokens(
self,
text,
audio_lengths,
):
processed_text = []
for sample in text:
            pattern = "|".join(re.escape(tok) for tok in [self.audio_token])
            positions = sorted((match.start(), match.group()) for match in re.finditer(pattern, sample))
for _, special_token in positions:
if special_token == self.audio_token:
sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
sample = sample.replace("<|audio_placeholder|>", self.audio_token)
processed_text.append(sample)
return processed_text
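    # The two-step replacement above matters: each matched audio token is first
    # expanded into a private placeholder (one per audio-embedding slot) so that
    # just-expanded tokens cannot be matched again, and only then are the
    # placeholders mapped back. A self-contained sketch, with a hypothetical
    # "<AUDIO>" string standing in for tokenizer.audio_token:
    #
    #     import re
    #     audio_token = "<AUDIO>"
    #     lengths = iter([3, 2])  # embedding slots for two audio clips
    #     sample = f"user: {audio_token} then {audio_token}"
    #     for _ in re.finditer(re.escape(audio_token), sample):
    #         sample = sample.replace(audio_token, "<ph>" * next(lengths), 1)
    #     sample = sample.replace("<ph>", audio_token)
    #     # -> "user: <AUDIO><AUDIO><AUDIO> then <AUDIO><AUDIO>"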
def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
"""
Splits token index list into chunks based on token value ranges.
Given a list of token indices, returns a list of (start, end) index tuples representing
        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
        - the first chunk contains token values < 1000,
        - the second chunk contains values >= 1000 and < 2000, and so on.
        Parameters:
            token_indices (`np.ndarray`): A monotonically increasing list of token index values.
            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
Returns:
`list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
and end (exclusive) indices of a chunk in `token_indices`.
"""
def _iter():
i, start_idx = 0, 0 # skip bos token
current_chunk = 1
while i < len(token_indices): # skip eos token
if token_indices[i] >= current_chunk * tokens_per_chunk:
yield (start_idx, i)
start_idx = i
current_chunk += 1
i += 1
yield (start_idx, len(token_indices))
return list(_iter())
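    # Since the method never reads `self`, its behavior is easy to check in
    # isolation; with tokens_per_chunk=1000, values are grouped into
    # thousand-bands:
    #
    #     Qwen3ASRProcessor.get_chunked_index(
    #         None, np.array([5, 300, 999, 1001, 1500, 2100]), tokens_per_chunk=1000
    #     )
    #     # -> [(0, 3), (3, 5), (5, 6)]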
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
return super().apply_chat_template(conversations, chat_template, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(
dict.fromkeys(
tokenizer_input_names
+ feature_extractor_input_names
+ ["feature_attention_mask"]
)
)
__all__ = ["Qwen3ASRProcessor"]

qwen_asr/core/vllm_backend/__init__.py Normal file

@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .qwen3_asr import Qwen3ASRForConditionalGeneration

qwen_asr/core/vllm_backend/qwen3_asr.py Normal file

@@ -0,0 +1,997 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-ASR model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, cast
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import MultiModalConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
Qwen2_5OmniAudioFeatureInputs,
Qwen3OmniMoeThinkerMultiModalProcessor,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
AudioItem,
ModalityData,
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
DictEmbeddingItems,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.model_executor.models.vision import (
get_vit_attn_backend,
)
from ..transformers_backend.configuration_qwen3_asr import (
Qwen3ASRConfig,
Qwen3ASRThinkerConfig,
Qwen3ASRAudioEncoderConfig
)
from ..transformers_backend.processing_qwen3_asr import (
Qwen3ASRProcessor,
)
try:
from vllm.multimodal.profiling import BaseDummyInputsBuilder
except ImportError:
    from vllm.multimodal.processing import BaseDummyInputsBuilder
logger = init_logger(__name__)
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
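# Sanity check (a sketch): each full 100-frame mel block contributes 13 encoder
# tokens, and a remainder of r frames contributes roughly ceil(r / 8) tokens
# (three stride-2 reductions), e.g.:
#   _get_feat_extract_output_lengths(torch.tensor([100, 250]))  # -> [13, 33]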
# ============= Audio Encoder Components =============
class SinusoidsPositionEmbedding(nn.Module):
"""Sinusoidal position embedding for audio encoder."""
def __init__(self, length: int, channels: int, max_timescale: int = 10000):
super().__init__()
self.length = length
self.channels = channels
self.max_timescale = max_timescale
if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2).float()
)
scaled_time = (
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
)
positional_embedding = torch.cat(
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
)
self.register_buffer(
"positional_embedding", positional_embedding, persistent=False
)
def forward(self, seqlen: int) -> torch.Tensor:
return self.positional_embedding[:seqlen, :]
class Qwen3ASRAudioAttention(nn.Module):
"""Multi-headed attention for Qwen3-Omni Audio Encoder using MMEncoderAttention."""
def __init__(
self,
config: Qwen3ASRAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
self.num_heads = config.encoder_attention_heads
self.head_dim = self.embed_dim // self.num_heads
tp_size = get_tensor_model_parallel_world_size()
self.num_local_heads = self.num_heads // tp_size
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
total_num_kv_heads=self.num_heads,
bias=True,
prefix=f"{prefix}.qkv",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
bias=True,
prefix=f"{prefix}.out_proj",
)
self.attn = MMEncoderAttention(
num_heads=self.num_local_heads,
head_size=self.head_dim,
scale=self.scaling,
multimodal_config=multimodal_config,
)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor | None,
) -> torch.Tensor:
seq_length, _ = hidden_states.size()
qkv, _ = self.qkv(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(1, seq_length, -1, self.head_dim)
k = k.view(1, seq_length, -1, self.head_dim)
v = v.view(1, seq_length, -1, self.head_dim)
attn_output = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.view(seq_length, -1)
output, _ = self.out_proj(attn_output)
return output
class Qwen3ASRAudioEncoderLayer(nn.Module):
"""Transformer encoder layer for Qwen3-Omni Audio Encoder."""
def __init__(
self,
config: Qwen3ASRAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Qwen3ASRAudioAttention(
config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn"
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.encoder_ffn_dim,
bias=True,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.encoder_ffn_dim,
self.embed_dim,
bias=True,
prefix=f"{prefix}.fc2",
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor | None,
) -> torch.Tensor:
"""
Args:
hidden_states: Input tensor of shape (seq_len, hidden_size)
cu_seqlens: Cumulative sequence lengths
max_seqlen: Maximum sequence length in the batch
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
# Clamp for numerical stability with fp16
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
return hidden_states
class Qwen3ASRAudioEncoder(nn.Module):
"""vLLM-native Qwen3-ASR Audio Encoder."""
def __init__(
self,
config: Qwen3ASRAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.n_window = config.n_window
self.n_window_infer = config.n_window_infer
self.conv_chunksize = config.conv_chunksize
# Position embedding
self.positional_embedding = SinusoidsPositionEmbedding(
self.max_source_positions, embed_dim
)
# Convolutional layers for mel-spectrogram processing
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
self.conv2d2 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
self.conv2d3 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
conv_out_dim = config.downsample_hidden_size * (
(((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
)
self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)
# Transformer encoder layers
self.layers = nn.ModuleList(
[
Qwen3ASRAudioEncoderLayer(
config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{i}",
)
for i in range(config.encoder_layers)
]
)
# Output layers
self.ln_post = nn.LayerNorm(config.d_model)
self.proj1 = nn.Linear(config.d_model, config.d_model)
self.act = _ACTIVATION_REGISTRY[config.activation_function]
self.proj2 = nn.Linear(config.d_model, config.output_dim)
# Get attention backend
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=config.d_model // config.encoder_attention_heads,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
"""Compute max_seqlen only for flash attention backends."""
max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
@property
def dtype(self) -> torch.dtype:
return self.conv2d1.weight.dtype
@property
def device(self) -> torch.device:
return self.conv2d1.weight.device
def forward(
self,
input_features: torch.Tensor,
feature_lens: torch.Tensor,
aftercnn_lens: torch.Tensor,
):
# Compute chunk information
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
chunk_lengths = torch.tensor(
[self.n_window * 2] * chunk_num.sum(),
dtype=torch.long,
device=feature_lens.device,
)
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
# Split input features into chunks and pad
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
padded_feature = nn.utils.rnn.pad_sequence(
chunk_list, batch_first=True
).transpose(1, 2)
# Compute feature lengths after CNN
feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)
# Vectorized mask creation: avoid creating many small tensors
max_len_after_cnn = feature_lens_after_cnn.max().item()
indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
1
)
# Add channel dimension for conv2d
padded_feature = padded_feature.unsqueeze(1)
# Apply convolutional layers (chunk if needed to avoid OOM)
if padded_feature.size(0) <= self.conv_chunksize:
# Fast path: no chunking needed
padded_embed = F.gelu(self.conv2d1(padded_feature))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
else:
# Chunked processing to avoid OOM
padded_embeds = []
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
padded_embed = F.gelu(self.conv2d1(chunk))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
padded_embeds.append(padded_embed)
padded_embed = torch.cat(padded_embeds, dim=0)
# (batch, channels, freq, time) -> (batch, time, channels*freq)
b, c, f, t = padded_embed.size()
padded_embed = self.conv_out(
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
)
# Add positional embedding
positional_embedding = (
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
.unsqueeze(0)
.to(padded_embed.dtype)
)
padded_embed = padded_embed + positional_embedding
# Extract valid hidden states and compute cu_seqlens
hidden_states = padded_embed[padded_mask_after_cnn]
# Compute cumulative sequence lengths for chunked attention
cu_chunk_lens = [0]
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
self.n_window_infer // (self.n_window * 2)
)
# Use tolist() for efficient batch conversion from tensor to Python
for cnn_len in aftercnn_lens.tolist():
num_full_chunks = cnn_len // window_aftercnn
remainder = cnn_len % window_aftercnn
cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
if remainder:
cu_chunk_lens.append(remainder)
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
-1, dtype=torch.int32
)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
# Apply transformer layers
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
cu_seqlens,
max_seqlen,
)
# Apply output layers
hidden_states = self.ln_post(hidden_states)
hidden_states = self.proj1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.proj2(hidden_states)
return hidden_states
def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after the three conv2d layers."""
lengths = input_lengths
for _ in range(3):
lengths = (lengths - 1) // 2 + 1
return lengths
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights with mapping from HuggingFace format."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("self_attn.qkv.", "self_attn.q_proj.", "q"),
("self_attn.qkv.", "self_attn.k_proj.", "k"),
("self_attn.qkv.", "self_attn.v_proj.", "v"),
]
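# e.g. a checkpoint entry "layers.0.self_attn.q_proj.weight" is loaded into
# the fused "layers.0.self_attn.qkv.weight" parameter as its "q" shard.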
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3ASRProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen3ASRConfig).thinker_config
def get_hf_processor(self, **kwargs: object) -> Qwen3ASRProcessor:
processor = self.ctx.get_hf_processor(
Qwen3ASRProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
if not hasattr(processor, "audio_token"):
processor.audio_token = "<|audio_pad|>"
return processor
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token
return audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
feature_extractor = self.info.get_feature_extractor()
target_audio_length = (
min(
feature_extractor.chunk_length,
30,
)
* feature_extractor.sampling_rate
)
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=target_audio_length,
num_audios=num_audios,
overrides=audio_overrides,
),
}
def _qwen3asr_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_feature_lengths, dim=1
),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
)
class Qwen3ASRMultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"input_audio_features", "audio_feature_lengths"},
fields_factory=_qwen3asr_field_config,
)
return super()._parse_audio_data(data)
class Qwen3ASRMultiModalProcessor(
Qwen3OmniMoeThinkerMultiModalProcessor,
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return Qwen3ASRMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _qwen3asr_field_config(hf_inputs)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
audio_token = processor.audio_token
audio_token_id = vocab[audio_token]
out_mm_data = out_mm_kwargs.get_data()
audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
feature_attention_mask = out_mm_data.get("feature_attention_mask")
if audio_feature_lengths is None and feature_attention_mask is None:
audio_output_lengths = []
elif audio_feature_lengths is not None:
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor)
audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1)
)
audio_output_lengths = audio_output_lens.tolist()
def get_replacement_qwen2_audio(item_idx: int):
num_features = audio_output_lengths[item_idx]
if num_features == 0:
audios = mm_items.get_items("audio", AudioProcessorItems)
audio = audios.get(item_idx)
raise ValueError(
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model"
)
return [audio_token_id] * num_features
return [
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_qwen2_audio,
),
]
@MULTIMODAL_REGISTRY.register_processor(
Qwen3ASRMultiModalProcessor,
info=Qwen3ASRProcessingInfo,
dummy_inputs=Qwen3ASRDummyInputsBuilder,
)
class Qwen3ASRForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsMRoPE,
SupportsTranscription,
):
supported_languages = ISO639_1_SUPPORTED_LANGS
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "language_model.lm_head.",
"thinker.model.": "language_model.model.",
"thinker.": "",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("audio"):
return "<|audio_start|><|audio_pad|><|audio_end|>"
raise ValueError("Only audio modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config # needed for torch compile forward context
thinker_config: Qwen3ASRThinkerConfig = (
vllm_config.model_config.hf_config.thinker_config
)
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
self.multimodal_config = multimodal_config
self.audio_tower = Qwen3ASRAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
self.quant_config = quant_config
self.language_model = Qwen3ForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config, architectures=["Qwen3ForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> Qwen2_5OmniAudioFeatureInputs | None:
input_audio_features = kwargs.pop("input_audio_features", None)
audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
if input_audio_features is None:
return None
return Qwen2_5OmniAudioFeatureInputs(
type="audio_features",
input_features=input_audio_features,
audio_feature_lengths=audio_feature_lengths,
feature_attention_mask=feature_attention_mask,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if (
input_key in ("input_audio_features",)
and "audio" not in mm_input_by_modality
):
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
**kwargs
)
return mm_input_by_modality
def _process_audio_input(
self,
audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] | None = None,
cached_audio_features: torch.Tensor | None = None,
) -> torch.Tensor:
input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"]
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_features = self.audio_tower(
input_features.to(self.audio_tower.dtype),
feature_lens=audio_feature_lengths,
aftercnn_lens=audio_output_lengths,
)
return audio_features.split(audio_output_lengths.tolist())
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return []
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to one multimodal data item (an audio clip here).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "audio":
audio_embeddings = self._process_audio_input(multimodal_input)
multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["talker.", "code2wav."],
)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
seq_len = len(input_tokens)
if not mm_features:
# No audio features, just return linear positions
llm_positions = (
torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1)
)
return llm_positions.clone(), 0
llm_pos_ids_list: list[torch.Tensor] = []
st = 0
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
# Get audio feature length from mm_feature data
audio_feature_length = mm_feature.data["audio_feature_lengths"].data
if isinstance(audio_feature_length, torch.Tensor):
audio_feature_length = audio_feature_length.item()
audio_len = _get_feat_extract_output_lengths(
torch.tensor(audio_feature_length)
).item()
# Text segment before audio (includes audio_start token)
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_positions = (
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx = st_idx + text_len
# Audio token segment
audio_positions = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(audio_positions)
st = offset + audio_len
# Handle remaining text (includes audio_end and any trailing text)
if st < seq_len:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_len = seq_len - st
final_text_positions = (
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
return llm_positions, mrope_position_delta
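# e.g. for a 10-token prompt whose audio span covers positions 3..6, all three
# M-RoPE planes get the linear positions 0..9 (audio tokens are positioned
# linearly, like text), so mrope_position_delta works out to 0.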
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
tower_model=["audio_tower."],
)
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config)
feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
return SpeechToTextConfig(
max_audio_clip_s=feature_extractor.chunk_length,
sample_rate=feature_extractor.sampling_rate,
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
tokenizer = cached_tokenizer_from_config(model_config)
audio_placeholder = cls.get_placeholder_str("audio", 0)
if task_type not in ("transcribe", "translate"):
raise ValueError(
f"Unsupported task_type '{task_type}'. "
"Supported task types are 'transcribe' and 'translate'."
)
full_lang_name_to = cls.supported_languages.get(to_language, to_language)
if to_language is None:
prompt = (
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
else:
prompt = (
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
f"<|im_start|>assistant\nlanguage {full_lang_name_to}<asr_text>"
)
prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": {"audio": audio},
}
return cast(PromptType, prompt_dict)

File diff suppressed because it is too large

View File

@ -0,0 +1,821 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
from .qwen3_forced_aligner import Qwen3ForcedAligner
from .utils import (
MAX_ASR_INPUT_SECONDS,
MAX_FORCE_ALIGN_INPUT_SECONDS,
SAMPLE_RATE,
SUPPORTED_LANGUAGES,
AudioChunk,
AudioLike,
chunk_list,
merge_languages,
normalize_audios,
normalize_language_name,
parse_asr_output,
split_audio_into_chunks,
validate_language,
)
try:
    from qwen_asr.core.vllm_backend import (
        Qwen3ASRForConditionalGeneration as Qwen3ASRForConditionalGenerationVLLM,
    )
    from vllm import ModelRegistry
    # Register the vLLM implementation under an alias so it does not shadow
    # the transformers-backend class imported above.
    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGenerationVLLM)
except ImportError:
    pass
@dataclass
class ASRTranscription:
"""
One transcription result.
Attributes:
language (str):
Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
Empty string if unknown or silent audio.
text (str):
Transcribed text.
time_stamps (Optional[Any]):
Forced aligner output (ForcedAlignResult).
Present only when return_time_stamps=True.
"""
language: str
text: str
time_stamps: Optional[Any] = None
@dataclass
class ASRStreamingState:
"""
Streaming ASR state for one audio stream (single utterance).
Attributes:
unfixed_chunk_num (int):
For the first N chunks, do not use previous ASR result as prefix prompt (reset prefix to "").
unfixed_token_num (int):
When chunk_id >= unfixed_chunk_num, roll back the last K tokens from the accumulated text
before using it as a prefix prompt, to reduce boundary jitter.
chunk_size_sec (float):
Chunk size in seconds. Audio will be fed to the model in increments of this length.
chunk_size_samples (int):
Chunk size in samples at 16kHz (derived from chunk_size_sec).
chunk_id (int):
Current chunk index (0-based).
buffer (np.ndarray):
Buffered PCM samples that are not yet consumed into a full chunk.
audio_accum (np.ndarray):
Accumulated audio from the beginning of the stream up to current time (no padding).
prompt_raw (str):
Base prompt generated by chat template (with generation prompt), without appended prefix text.
context (str):
Context string.
force_language (Optional[str]):
If provided, force output to be text-only by appending "language X<asr_text>" in prompt_raw,
consistent with non-streaming transcribe().
language (str):
Latest parsed language (updated after each chunk decode). Empty if unknown/silent.
text (str):
Latest parsed transcription text (updated after each chunk decode).
_raw_decoded (str):
Internal accumulated decoded raw text (before parse_asr_output normalization).
Used for rollback/token trimming and as prefix for prompting.
"""
unfixed_chunk_num: int
unfixed_token_num: int
chunk_size_sec: float
chunk_size_samples: int
chunk_id: int
buffer: np.ndarray
audio_accum: np.ndarray
prompt_raw: str
context: str
force_language: Optional[str]
language: str
text: str
_raw_decoded: str
class Qwen3ASRModel:
"""
Unified inference wrapper for Qwen3-ASR with two backends:
- Transformers backend
- vLLM backend
It optionally supports time stamp output via Qwen3-ForcedAligner.
Notes:
- Each request uses a context text and exactly one audio.
- If language is provided, the prompt will force the output to be text-only by appending
"language {Language}<asr_text>" to the assistant prompt.
"""
def __init__(
self,
backend: str,
model: Any,
processor: Any,
sampling_params: Optional[Any] = None,
forced_aligner: Optional[Qwen3ForcedAligner] = None,
max_inference_batch_size: int = -1,
max_new_tokens: int = 512,
):
self.backend = backend # "transformers" | "vllm"
self.model = model
self.processor = processor
self.sampling_params = sampling_params
self.forced_aligner = forced_aligner
self.max_inference_batch_size = int(max_inference_batch_size)
self.max_new_tokens = max_new_tokens
if backend == "transformers":
self.device = getattr(model, "device", None)
if self.device is None:
try:
self.device = next(model.parameters()).device
except StopIteration:
self.device = torch.device("cpu")
self.dtype = getattr(model, "dtype", torch.float32)
else:
self.device = None
self.dtype = None
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
forced_aligner: Optional[str] = None,
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
max_inference_batch_size: int = 32,
max_new_tokens: Optional[int] = 512,
**kwargs,
) -> "Qwen3ASRModel":
"""
Initialize using Transformers backend.
Args:
pretrained_model_name_or_path:
HuggingFace repo id or local directory.
forced_aligner:
Optional forced aligner model path/repo id.
forced_aligner_kwargs:
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
max_inference_batch_size:
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
max_new_tokens:
Maximum number of tokens to generate.
**kwargs:
Forwarded to AutoModel.from_pretrained(...).
Returns:
Qwen3ASRModel
"""
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
forced_aligner_model = None
if forced_aligner is not None:
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
forced_aligner, **(forced_aligner_kwargs or {})
)
return cls(
backend="transformers",
model=model,
processor=processor,
sampling_params=None,
forced_aligner=forced_aligner_model,
max_inference_batch_size=max_inference_batch_size,
max_new_tokens=max_new_tokens,
)
@classmethod
def LLM(
cls,
model: str,
forced_aligner: Optional[str] = None,
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
max_inference_batch_size: int = -1,
max_new_tokens: Optional[int] = 4096,
**kwargs,
) -> "Qwen3ASRModel":
"""
Initialize using vLLM backend.
Import is isolated to keep vLLM optional.
Args:
model:
Model path/repo for vLLM.
forced_aligner:
Optional forced aligner model path/repo id.
forced_aligner_kwargs:
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
max_inference_batch_size:
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
max_new_tokens:
Maximum number of tokens to generate.
**kwargs:
Forwarded to vllm.LLM(...).
Returns:
Qwen3ASRModel
Raises:
ImportError: If vLLM is not installed.
"""
try:
from vllm import LLM as vLLM
from vllm import SamplingParams
except Exception as e:
raise ImportError(
"vLLM is not available. Install with: pip install qwen-asr[vllm]"
) from e
llm = vLLM(model=model, **kwargs)
processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)
forced_aligner_model = None
if forced_aligner is not None:
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
forced_aligner, **(forced_aligner_kwargs or {})
)
return cls(
backend="vllm",
model=llm,
processor=processor,
sampling_params=sampling_params,
forced_aligner=forced_aligner_model,
max_inference_batch_size=max_inference_batch_size,
max_new_tokens=None,
)
def get_supported_languages(self) -> List[str]:
"""
Returns the supported language list.
Returns:
List[str]: Canonical language names.
"""
return list(SUPPORTED_LANGUAGES)
@torch.no_grad()
def transcribe(
self,
audio: Union[AudioLike, List[AudioLike]],
context: Union[str, List[str]] = "",
language: Optional[Union[str, List[Optional[str]]]] = None,
return_time_stamps: bool = False,
) -> List[ASRTranscription]:
"""
Transcribe audio with optional context and optional forced alignment timestamps.
Args:
audio:
Audio input(s). Supported:
- str: local path / URL / base64 data url
- (np.ndarray, sr)
- list of above
context:
Context string(s). If scalar, it will be broadcast to batch size.
language:
Optional language(s). If provided, it must be in supported languages.
If scalar, it will be broadcast to batch size.
If provided, the prompt will force output to be transcription text only.
return_time_stamps:
If True, timestamps are produced via forced aligner and merged across chunks.
This requires forced_aligner initialized.
Returns:
List[ASRTranscription]: One result per input audio.
Raises:
ValueError:
- If return_time_stamps=True but forced_aligner is not provided.
- If language is unsupported.
- If batch sizes mismatch for context/language.
"""
if return_time_stamps and self.forced_aligner is None:
raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.")
wavs = normalize_audios(audio)
n = len(wavs)
ctxs = context if isinstance(context, list) else [context]
if len(ctxs) == 1 and n > 1:
ctxs = ctxs * n
if len(ctxs) != n:
raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}")
langs_in: List[Optional[str]]
if language is None:
langs_in = [None] * n
else:
langs_in = language if isinstance(language, list) else [language]
if len(langs_in) == 1 and n > 1:
langs_in = langs_in * n
if len(langs_in) != n:
raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}")
langs_norm: List[Optional[str]] = []
for l in langs_in:
if l is None or str(l).strip() == "":
langs_norm.append(None)
else:
ln = normalize_language_name(str(l))
validate_language(ln)
langs_norm.append(ln)
max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS
# chunk audios and record mapping
chunks: List[AudioChunk] = []
for i, wav in enumerate(wavs):
parts = split_audio_into_chunks(
wav=wav,
sr=SAMPLE_RATE,
max_chunk_sec=max_chunk_sec,
)
for j, (cwav, offset_sec) in enumerate(parts):
chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec))
# run ASR on chunks
chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks]
chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks]
chunk_wavs: List[np.ndarray] = [c.wav for c in chunks]
raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang)
# parse outputs, prepare for optional alignment
per_chunk_lang: List[str] = []
per_chunk_text: List[str] = []
for out, forced_lang in zip(raw_outputs, chunk_lang):
lang, txt = parse_asr_output(out, user_language=forced_lang)
per_chunk_lang.append(lang)
per_chunk_text.append(txt)
# forced alignment (optional)
per_chunk_align: List[Optional[Any]] = [None] * len(chunks)
if return_time_stamps:
to_align_audio = []
to_align_text = []
to_align_lang = []
to_align_idx = []
for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)):
if txt.strip() == "":
continue
to_align_audio.append((c.wav, c.sr))
to_align_text.append(txt)
to_align_lang.append(lang_pred)
to_align_idx.append(idx)
# batch align with max_inference_batch_size
aligned_results: List[Any] = []
for a_chunk, t_chunk, l_chunk in zip(
chunk_list(to_align_audio, self.max_inference_batch_size),
chunk_list(to_align_text, self.max_inference_batch_size),
chunk_list(to_align_lang, self.max_inference_batch_size),
):
aligned_results.extend(
self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk)
)
# offset fix
for k, idx in enumerate(to_align_idx):
c = chunks[idx]
r = aligned_results[k]
per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec)
# merge chunks back to original samples
out_langs: List[List[str]] = [[] for _ in range(n)]
out_texts: List[List[str]] = [[] for _ in range(n)]
out_aligns: List[List[Any]] = [[] for _ in range(n)]
for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align):
out_langs[c.orig_index].append(lang)
out_texts[c.orig_index].append(txt)
if return_time_stamps and al is not None:
out_aligns[c.orig_index].append(al)
results: List[ASRTranscription] = []
for i in range(n):
merged_text = "".join([t for t in out_texts[i] if t is not None])
merged_language = merge_languages(out_langs[i])
merged_align = None
if return_time_stamps:
merged_align = self._merge_align_results(out_aligns[i])
results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
return results
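# Usage sketch (hypothetical repo id; API as defined above):
#   asr = Qwen3ASRModel.from_pretrained("Qwen/Qwen3-ASR", dtype=torch.bfloat16)
#   results = asr.transcribe(audio="sample.wav", context="", language=None)
#   print(results[0].language, results[0].text)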
def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
return [
{"role": "system", "content": context or ""},
{"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
]
def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
"""
Build the string prompt for one request.
If force_language is provided, "language X<asr_text>" is appended after the generation prompt
to request text-only output.
"""
msgs = self._build_messages(context=context, audio_payload="")
base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
if force_language:
base = base + f"language {force_language}<asr_text>"
return base
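# e.g. with force_language="English", the returned prompt ends with
# "<|im_start|>assistant\nlanguage English<asr_text>", matching the text-only
# prompt format used by the vLLM transcription entrypoint above.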
def _infer_asr(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
"""
Run backend inference for chunk-level items.
Args:
contexts: List of context strings.
wavs: List of mono waveforms (np.ndarray).
languages: List of forced languages or None.
Returns:
List[str]: Raw decoded strings (one per chunk).
"""
if self.backend == "transformers":
return self._infer_asr_transformers(contexts, wavs, languages)
if self.backend == "vllm":
return self._infer_asr_vllm(contexts, wavs, languages)
raise RuntimeError(f"Unknown backend: {self.backend}")
def _infer_asr_transformers(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
outs: List[str] = []
texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
batch_size = self.max_inference_batch_size
if batch_size is None or batch_size < 0:
batch_size = len(texts)
for i in range(0, len(texts), batch_size):
sub_text = texts[i : i + batch_size]
sub_wavs = wavs[i : i + batch_size]
inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
inputs = inputs.to(self.model.device).to(self.model.dtype)
text_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
decoded = self.processor.batch_decode(
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
outs.extend(list(decoded))
return outs
def _infer_asr_vllm(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
inputs: List[Dict[str, Any]] = []
for c, w, fl in zip(contexts, wavs, languages):
prompt = self._build_text_prompt(context=c, force_language=fl)
inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
outs: List[str] = []
for batch in chunk_list(inputs, self.max_inference_batch_size):
outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
for o in outputs:
outs.append(o.outputs[0].text)
return outs
def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
"""
Apply time offset to a ForcedAlignResult-like object.
This function assumes:
- result has attribute `.items` which is a list of items with start_time/end_time in seconds.
- dataclasses are frozen in upstream implementation, so we reconstruct by type.
Args:
result: ForcedAlignResult
offset_sec: Offset in seconds
Returns:
ForcedAlignResult: New object with shifted timestamps.
"""
if result is None:
return None
items = []
for it in result.items:
items.append(type(it)(text=it.text,
start_time=round(it.start_time + offset_sec, 3),
end_time=round(it.end_time + offset_sec, 3)))
return type(result)(items=items)
def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
"""
Merge multiple ForcedAlignResult objects into a single one by concatenating items.
Args:
results: List of ForcedAlignResult
Returns:
ForcedAlignResult or None
"""
if not results:
return None
all_items = []
for r in results:
if r is None:
continue
all_items.extend(list(r.items))
if not all_items:
return None
return type(results[0])(items=all_items)
def init_streaming_state(
self,
context: str = "",
language: Optional[str] = None,
unfixed_chunk_num: int = 2,
unfixed_token_num: int = 5,
chunk_size_sec: float = 2.0,
) -> ASRStreamingState:
"""
Initialize streaming ASR state for a single stream.
Notes:
- Streaming ASR is supported ONLY for vLLM backend.
- Streaming ASR does NOT support timestamps (forced aligner is not used).
- Batch inference is NOT supported.
Args:
context:
Context string.
language:
Optional forced language. If provided, it must be in supported languages.
Same behavior as transcribe(): forces text-only output via prompt suffix.
unfixed_chunk_num:
For the first N chunks, do not use previous output as prefix prompt (reset prefix to "").
unfixed_token_num:
Roll back the last K tokens from accumulated output when using it as prefix prompt
after unfixed_chunk_num.
chunk_size_sec:
Chunk size in seconds (audio is 16k PCM). The function will internally convert it
to sample count at 16kHz.
Returns:
ASRStreamingState: Mutable state object to be passed to streaming_transcribe() and
finish_streaming_transcribe().
Raises:
ValueError:
- If backend is not "vllm".
- If chunk_size_sec <= 0.
- If forced language is invalid (same validation rules as transcribe()).
"""
if self.backend != "vllm":
raise ValueError("Streaming ASR is supported only for vLLM backend (backend='vllm').")
if chunk_size_sec is None or float(chunk_size_sec) <= 0:
raise ValueError(f"chunk_size_sec must be > 0, got: {chunk_size_sec}")
force_language = None
if language is not None and str(language).strip() != "":
ln = normalize_language_name(str(language))
validate_language(ln)
force_language = ln
chunk_size_samples = int(round(float(chunk_size_sec) * SAMPLE_RATE))
chunk_size_samples = max(1, chunk_size_samples)
prompt_raw = self._build_text_prompt(context=context, force_language=force_language)
return ASRStreamingState(
unfixed_chunk_num=int(unfixed_chunk_num),
unfixed_token_num=int(unfixed_token_num),
chunk_size_sec=float(chunk_size_sec),
chunk_size_samples=int(chunk_size_samples),
chunk_id=0,
buffer=np.zeros((0,), dtype=np.float32),
audio_accum=np.zeros((0,), dtype=np.float32),
prompt_raw=prompt_raw,
context=context or "",
force_language=force_language,
language="",
text="",
_raw_decoded="",
)
def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) -> ASRStreamingState:
"""
Streaming ASR decode step.
This function accepts an arbitrary-length 16k PCM float numpy array (mono).
It buffers incoming samples, and whenever enough samples are accumulated to form one
full chunk (chunk_size_sec), it runs one incremental decode step and updates:
- state.language
- state.text
The caller only needs to keep passing audio to this function and read state.language/state.text.
Implementation details:
- Each time a new chunk is ready, we append it to audio_accum and re-feed *all* audio seen
so far to the model (no padding).
- We update the prompt as: state.prompt_raw + prefix_text
- Prefix rollback strategy:
* If chunk_id < unfixed_chunk_num: prefix_text = ""
* Else: rollback last unfixed_token_num tokens from previously accumulated decoded text.
Notes:
- vLLM backend only.
- No timestamps.
- Single stream only (no batching).
Args:
pcm16k:
16kHz mono PCM waveform (np.ndarray). Length can be any non-negative integer.
dtype can be float32/float64/int16; it will be converted to float32.
state:
Streaming state returned by init_streaming_state().
Returns:
ASRStreamingState: The same state object (mutated) for convenience.
Raises:
ValueError:
If backend is not "vllm" or state is invalid.
"""
if self.backend != "vllm":
raise ValueError("streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
if state is None:
raise ValueError("state must not be None. Call init_streaming_state() first.")
if pcm16k is None:
raise ValueError("pcm16k must not be None.")
# Ensure 1D mono
x = np.asarray(pcm16k)
if x.ndim != 1:
x = x.reshape(-1)
# Convert to float32 PCM in [-1, 1] if int16 provided
if x.dtype == np.int16:
x = (x.astype(np.float32) / 32768.0)
else:
x = x.astype(np.float32, copy=False)
# Append to buffer
if x.shape[0] > 0:
state.buffer = np.concatenate([state.buffer, x], axis=0)
# Consume full chunks
while state.buffer.shape[0] >= state.chunk_size_samples:
chunk = state.buffer[: state.chunk_size_samples]
state.buffer = state.buffer[state.chunk_size_samples :]
# Accumulate audio (re-feed from start, no padding)
if state.audio_accum.shape[0] == 0:
state.audio_accum = chunk
else:
state.audio_accum = np.concatenate([state.audio_accum, chunk], axis=0)
# Build prefix with rollback strategy
prefix = ""
if state.chunk_id < state.unfixed_chunk_num:
prefix = ""
else:
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
prompt = state.prompt_raw + prefix
# vLLM input: single item
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
gen_text = outputs[0].outputs[0].text
# Accumulate raw decoded (then parse to lang/text)
state._raw_decoded = prefix + gen_text
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
state.language = lang
state.text = txt
state.chunk_id += 1
return state
def finish_streaming_transcribe(self, state: ASRStreamingState) -> ASRStreamingState:
"""
Finish streaming ASR.
This function flushes the remaining buffered audio in state.buffer (tail audio).
It sends the remaining samples to the model even if shorter than chunk_size_sec,
without padding. Then it updates state.language/state.text one last time.
Notes:
- vLLM backend only.
- No timestamps.
- Single stream only.
Args:
state:
Streaming state.
Returns:
ASRStreamingState: Updated state (mutated).
Raises:
ValueError:
If backend is not "vllm" or state is invalid.
"""
if self.backend != "vllm":
raise ValueError("finish_streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
if state is None:
raise ValueError("state must not be None.")
# If no remaining buffer, still return state as-is.
if state.buffer is None or state.buffer.shape[0] == 0:
return state
tail = state.buffer
state.buffer = np.zeros((0,), dtype=np.float32)
# Append tail to accumulated audio
if state.audio_accum.shape[0] == 0:
state.audio_accum = tail
else:
state.audio_accum = np.concatenate([state.audio_accum, tail], axis=0)
# Prefix rollback strategy (same as per-chunk)
prefix = ""
if state.chunk_id < state.unfixed_chunk_num:
prefix = ""
else:
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
prompt = state.prompt_raw + prefix
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
gen_text = outputs[0].outputs[0].text
state._raw_decoded = prefix + gen_text
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
state.language = lang
state.text = txt
state.chunk_id += 1
return state
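# Streaming usage sketch (vLLM backend only; pcm_blocks is a hypothetical
# iterator of 16 kHz mono float32 arrays):
#   state = asr.init_streaming_state(context="", language=None, chunk_size_sec=2.0)
#   for block in pcm_blocks:
#       state = asr.streaming_transcribe(block, state)
#       print(state.language, state.text)  # partial hypothesis so far
#   state = asr.finish_streaming_transcribe(state)  # flush the tail buffer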

View File

@ -0,0 +1,483 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import nagisa
import torch
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
from .utils import (
AudioLike,
ensure_list,
normalize_audios,
)
class Qwen3ForceAlignProcessor:
def __init__(self):
ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict")
ko_scores = {}
with open(ko_dict_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
word = line.split()[0]
ko_scores[word] = 1.0
self.ko_score = ko_scores
self.ko_tokenizer = None
def is_kept_char(self, ch: str) -> bool:
if ch == "'":
return True
cat = unicodedata.category(ch)
if cat.startswith("L") or cat.startswith("N"):
return True
return False
def clean_token(self, token: str) -> str:
return "".join(ch for ch in token if self.is_kept_char(ch))
def is_cjk_char(self, ch: str) -> bool:
code = ord(ch)
return (
0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs
or 0x3400 <= code <= 0x4DBF # Extension A
or 0x20000 <= code <= 0x2A6DF # Extension B
or 0x2A700 <= code <= 0x2B73F # Extension C
or 0x2B740 <= code <= 0x2B81F # Extension D
or 0x2B820 <= code <= 0x2CEAF # Extension E
or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs
)
def tokenize_chinese_mixed(self, text: str) -> List[str]:
tokens: List[str] = []
current_latin: List[str] = []
def flush_latin():
nonlocal current_latin
if current_latin:
token = "".join(current_latin)
cleaned = self.clean_token(token)
if cleaned:
tokens.append(cleaned)
current_latin = []
for ch in text:
if self.is_cjk_char(ch):
flush_latin()
tokens.append(ch)
else:
if self.is_kept_char(ch):
current_latin.append(ch)
else:
flush_latin()
flush_latin()
return tokens
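# e.g. tokenize_chinese_mixed("你好world 123") -> ["你", "好", "world", "123"]:
# CJK characters become single tokens while Latin/digit runs stay grouped.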
def tokenize_japanese(self, text: str) -> List[str]:
words = nagisa.tagging(text).words
tokens: List[str] = []
for w in words:
cleaned = self.clean_token(w)
if cleaned:
tokens.append(cleaned)
return tokens
def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]:
raw_tokens = ko_tokenizer.tokenize(text)
tokens: List[str] = []
for w in raw_tokens:
w_clean = self.clean_token(w)
if w_clean:
tokens.append(w_clean)
return tokens
def split_segment_with_chinese(self, seg: str) -> List[str]:
tokens: List[str] = []
buf: List[str] = []
def flush_buf():
nonlocal buf
if buf:
tokens.append("".join(buf))
buf = []
for ch in seg:
if self.is_cjk_char(ch):
flush_buf()
tokens.append(ch)
else:
buf.append(ch)
flush_buf()
return tokens
def tokenize_space_lang(self, text: str) -> List[str]:
tokens: List[str] = []
for seg in text.split():
cleaned = self.clean_token(seg)
if cleaned:
tokens.extend(self.split_segment_with_chinese(cleaned))
return tokens
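# e.g. tokenize_space_lang("hello 世界!") -> ["hello", "世", "界"]: punctuation
# is stripped and embedded CJK characters are split out of each segment.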
def fix_timestamp(self, data) -> List[int]:
data = data.tolist()
n = len(data)
dp = [1] * n
parent = [-1] * n
for i in range(1, n):
for j in range(i):
if data[j] <= data[i] and dp[j] + 1 > dp[i]:
dp[i] = dp[j] + 1
parent[i] = j
max_length = max(dp)
max_idx = dp.index(max_length)
lis_indices = []
idx = max_idx
while idx != -1:
lis_indices.append(idx)
idx = parent[idx]
lis_indices.reverse()
is_normal = [False] * n
for idx in lis_indices:
is_normal[idx] = True
result = data.copy()
i = 0
while i < n:
if not is_normal[i]:
j = i
while j < n and not is_normal[j]:
j += 1
anomaly_count = j - i
if anomaly_count <= 2:
left_val = None
for k in range(i - 1, -1, -1):
if is_normal[k]:
left_val = result[k]
break
right_val = None
for k in range(j, n):
if is_normal[k]:
right_val = result[k]
break
for k in range(i, j):
if left_val is None:
result[k] = right_val
elif right_val is None:
result[k] = left_val
else:
result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val
else:
left_val = None
for k in range(i - 1, -1, -1):
if is_normal[k]:
left_val = result[k]
break
right_val = None
for k in range(j, n):
if is_normal[k]:
right_val = result[k]
break
if left_val is not None and right_val is not None:
step = (right_val - left_val) / (anomaly_count + 1)
for k in range(i, j):
result[k] = left_val + step * (k - i + 1)
elif left_val is not None:
for k in range(i, j):
result[k] = left_val
elif right_val is not None:
for k in range(i, j):
result[k] = right_val
i = j
else:
i += 1
return [int(res) for res in result]
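# e.g. fix_timestamp on values [1, 2, 50, 3, 4]: the longest non-decreasing
# subsequence [1, 2, 3, 4] marks 50 as an outlier, which is replaced by its
# nearest "normal" neighbor, giving [1, 2, 2, 3, 4].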
def encode_timestamp(self, text: str, language: str) -> Tuple[List[str], str]:
language = language.lower()
if language == "japanese":
word_list = self.tokenize_japanese(text)
elif language == "korean":
if self.ko_tokenizer is None:
from soynlp.tokenizer import LTokenizer
self.ko_tokenizer = LTokenizer(scores=self.ko_score)
word_list = self.tokenize_korean(self.ko_tokenizer, text)
else:
word_list = self.tokenize_space_lang(text)
input_text = "<timestamp><timestamp>".join(word_list) + "<timestamp><timestamp>"
input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text
return word_list, input_text
def parse_timestamp(self, word_list, timestamp):
timestamp_output = []
timestamp_fixed = self.fix_timestamp(timestamp)
for i, word in enumerate(word_list):
start_time = timestamp_fixed[i * 2]
end_time = timestamp_fixed[i * 2 + 1]
timestamp_output.append({
"text": word,
"start_time": start_time,
"end_time": end_time
})
return timestamp_output
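    # Illustration (timestamps in model time units, milliseconds before align()
    # converts them to seconds):
    #   parse_timestamp(["hi"], np.array([120, 480]))
    #   -> [{"text": "hi", "start_time": 120, "end_time": 480}]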
@dataclass(frozen=True)
class ForcedAlignItem:
"""
One aligned item span.
Attributes:
text (str):
            The aligned unit (a CJK character or a word) produced by the forced-aligner processor.
start_time (float):
Start time in seconds.
end_time (float):
End time in seconds.
"""
text: str
    start_time: float
    end_time: float
@dataclass(frozen=True)
class ForcedAlignResult:
"""
Forced alignment output for one sample.
Attributes:
items (List[ForcedAlignItem]):
Aligned token spans.
"""
items: List[ForcedAlignItem]
def __iter__(self):
return iter(self.items)
def __len__(self):
return len(self.items)
def __getitem__(self, idx: int) -> ForcedAlignItem:
return self.items[idx]
class Qwen3ForcedAligner:
"""
A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.
This wrapper provides:
- `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
- audio input normalization (path/URL/base64/(np.ndarray, sr))
- batch and single-sample forced alignment
- structured output with attribute access (`.text`, `.start_time`, `.end_time`)
"""
def __init__(
self,
model: Qwen3ASRForConditionalGeneration,
processor: Qwen3ASRProcessor,
aligner_processor: Qwen3ForceAlignProcessor,
):
self.model = model
self.processor = processor
self.aligner_processor = aligner_processor
self.device = getattr(model, "device", None)
if self.device is None:
try:
self.device = next(model.parameters()).device
except StopIteration:
self.device = torch.device("cpu")
self.timestamp_token_id = int(model.config.timestamp_token_id)
self.timestamp_segment_time = float(model.config.timestamp_segment_time)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
**kwargs,
) -> "Qwen3ForcedAligner":
"""
Load Qwen3-ForcedAligner model and initialize processors.
This method:
1) Registers config/model/processor for HF auto classes.
2) Loads the model using `AutoModel.from_pretrained(...)`.
3) Initializes:
- HF processor (`AutoProcessor.from_pretrained(...)`)
- forced alignment text processor (`Qwen3ForceAlignProcessor()`)
Args:
pretrained_model_name_or_path (str):
HuggingFace repo id or local directory.
**kwargs:
Forwarded to `AutoModel.from_pretrained(...)`.
Typical examples: device_map="cuda:0", dtype=torch.bfloat16.
Returns:
Qwen3ForcedAligner:
Initialized wrapper instance.
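        Example (a minimal sketch; the checkpoint path is an assumption):
            >>> aligner = Qwen3ForcedAligner.from_pretrained(
            ...     "path/to/qwen3-forced-aligner", device_map="cuda:0", dtype=torch.bfloat16
            ... )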
"""
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
if not isinstance(model, Qwen3ASRForConditionalGeneration):
raise TypeError(
f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration."
)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
aligner_processor = Qwen3ForceAlignProcessor()
return cls(model=model, processor=processor, aligner_processor=aligner_processor)
def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
items: List[ForcedAlignItem] = []
for it in timestamp_output:
items.append(
ForcedAlignItem(
text=str(it.get("text", "")),
start_time=float(it.get("start_time", 0)),
end_time=float(it.get("end_time", 0)),
)
)
return ForcedAlignResult(items=items)
@torch.inference_mode()
def align(
self,
audio: Union[AudioLike, List[AudioLike]],
text: Union[str, List[str]],
language: Union[str, List[str]],
) -> List[ForcedAlignResult]:
"""
Run forced alignment for batch or single sample.
Args:
audio:
Audio input(s). Each item supports:
- local path / https URL / base64 string
- (np.ndarray, sr)
All audios will be converted into mono 16k float32 arrays in [-1, 1].
text:
Transcript(s) for alignment.
language:
Language(s) for each sample (e.g., "Chinese", "English").
Returns:
List[ForcedAlignResult]:
One result per sample. Each result contains `items`, and each token can be accessed via
`.text`, `.start_time`, `.end_time`.
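        Example (a minimal sketch; the file name is an assumption):
            >>> results = aligner.align(audio="sample.wav", text="hello world", language="English")
            >>> for item in results[0]:
            ...     print(item.text, item.start_time, item.end_time)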
"""
texts = ensure_list(text)
languages = ensure_list(language)
audios = normalize_audios(audio)
if len(languages) == 1 and len(audios) > 1:
languages = languages * len(audios)
if not (len(audios) == len(texts) == len(languages)):
raise ValueError(
f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
)
word_lists = []
aligner_input_texts = []
for t, lang in zip(texts, languages):
word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang)
word_lists.append(word_list)
aligner_input_texts.append(aligner_input_text)
inputs = self.processor(
text=aligner_input_texts,
audio=audios,
return_tensors="pt",
padding=True,
)
inputs = inputs.to(self.model.device).to(self.model.dtype)
logits = self.model.thinker(**inputs).logits
output_ids = logits.argmax(dim=-1)
results: List[ForcedAlignResult] = []
for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists):
masked_output_id = output_id[input_id == self.timestamp_token_id]
timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
for it in timestamp_output:
it['start_time'] = round(it['start_time'] / 1000.0, 3)
it['end_time'] = round(it['end_time'] / 1000.0, 3)
results.append(self._to_structured_items(timestamp_output))
return results
def get_supported_languages(self) -> Optional[List[str]]:
"""
List supported language names for the current model.
This is a thin wrapper around `self.model.get_support_languages()`.
If the underlying model does not expose language constraints (returns None),
this method also returns None.
Returns:
Optional[List[str]]:
- A sorted list of supported language names (lowercased), if available.
- None if the model does not provide supported languages.
"""
fn = getattr(self.model, "get_support_languages", None)
if not callable(fn):
return None
langs = fn()
if langs is None:
return None
return sorted({str(x).lower() for x in langs})

497
qwen_asr/inference/utils.py Normal file

@ -0,0 +1,497 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import urllib.request
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlparse
import librosa
import numpy as np
import soundfile as sf
AudioLike = Union[
str, # wav path / URL / base64
Tuple[np.ndarray, int], # (waveform, sr)
]
MaybeList = Union[Any, List[Any]]
SAMPLE_RATE = 16000
MAX_ASR_INPUT_SECONDS = 1200
MAX_FORCE_ALIGN_INPUT_SECONDS = 180
MIN_ASR_INPUT_SECONDS = 0.5
SUPPORTED_LANGUAGES: List[str] = [
"Chinese",
"English",
"Cantonese",
"Arabic",
"German",
"French",
"Spanish",
"Portuguese",
"Indonesian",
"Italian",
"Korean",
"Russian",
"Thai",
"Vietnamese",
"Japanese",
"Turkish",
"Hindi",
"Malay",
"Dutch",
"Swedish",
"Danish",
"Finnish",
"Polish",
"Czech",
"Filipino",
"Persian",
"Greek",
"Romanian",
"Hungarian",
"Macedonian"
]
_ASR_TEXT_TAG = "<asr_text>"
_LANG_PREFIX = "language "
def normalize_language_name(language: str) -> str:
"""
Normalize language name to the canonical format used by Qwen3-ASR:
first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').
Args:
language (str): Input language name.
Returns:
str: Normalized language name.
Raises:
ValueError: If language is empty.
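    Example:
        >>> normalize_language_name("cHINese")
        'Chinese'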
"""
if language is None:
raise ValueError("language is None")
s = str(language).strip()
if not s:
raise ValueError("language is empty")
return s[:1].upper() + s[1:].lower()
def validate_language(language: str) -> None:
"""
Validate the language is supported.
Args:
language (str): Canonical language name.
Raises:
ValueError: If unsupported.
"""
if language not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")
def ensure_list(x: MaybeList) -> List[Any]:
return x if isinstance(x, list) else [x]
def is_url(s: str) -> bool:
try:
u = urlparse(s)
return u.scheme in ("http", "https") and bool(u.netloc)
except Exception:
return False
def is_probably_base64(s: str) -> bool:
if s.startswith("data:audio"):
return True
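    # Heuristic: a long string containing no path separators is assumed to be raw base64.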
if ("/" not in s and "\\" not in s) and len(s) > 256:
return True
return False
def decode_base64_bytes(b64: str) -> bytes:
if "," in b64 and b64.strip().startswith("data:"):
b64 = b64.split(",", 1)[1]
return base64.b64decode(b64)
def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
if is_url(x):
with urllib.request.urlopen(x) as resp:
audio_bytes = resp.read()
with io.BytesIO(audio_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
elif is_probably_base64(x):
audio_bytes = decode_base64_bytes(x)
with io.BytesIO(audio_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
audio, sr = librosa.load(x, sr=None, mono=False)
audio = np.asarray(audio, dtype=np.float32)
sr = int(sr)
return audio, sr
def to_mono(audio: np.ndarray) -> np.ndarray:
if audio.ndim == 1:
return audio
# soundfile can return shape (T, C); some pipelines use (C, T)
if audio.ndim == 2:
if audio.shape[0] <= 8 and audio.shape[1] > audio.shape[0]:
audio = audio.T
return np.mean(audio, axis=-1).astype(np.float32)
raise ValueError(f"Unsupported audio ndim={audio.ndim}")
def float_range_normalize(audio: np.ndarray) -> np.ndarray:
audio = audio.astype(np.float32)
if audio.size == 0:
return audio
peak = float(np.max(np.abs(audio)))
if peak == 0.0:
return audio
# If decoded audio is int-like scaled or out-of-range, normalize conservatively.
if peak > 1.0:
audio = audio / peak
audio = np.clip(audio, -1.0, 1.0)
return audio
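# Illustration: out-of-range float audio is peak-normalized, then clipped, e.g.
#   float_range_normalize(np.array([0.0, 2.0, -4.0], np.float32)) -> [0.0, 0.5, -1.0]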
def normalize_audio_input(a: AudioLike) -> np.ndarray:
"""
Normalize one audio input to mono 16k float32 waveform in [-1, 1].
Supported inputs:
- str: local file path / https URL / base64 audio string
- (np.ndarray, sr): waveform and sampling rate
Returns:
np.ndarray:
Mono 16k float32 waveform in [-1, 1].
"""
if isinstance(a, str):
audio, sr = load_audio_any(a)
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
audio, sr = a[0], int(a[1])
else:
raise TypeError(f"Unsupported audio input type: {type(a)}")
audio = to_mono(np.asarray(audio))
if sr != SAMPLE_RATE:
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
audio = float_range_normalize(audio)
return audio
def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
items = ensure_list(audios)
return [normalize_audio_input(a) for a in items]
def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
"""
Yield chunks of a list.
Args:
xs (List[Any]): Input list.
chunk_size (int): Chunk size.
Yields:
List[Any]: Slices of xs.
"""
if chunk_size <= 0:
yield xs
return
for i in range(0, len(xs), chunk_size):
yield xs[i : i + chunk_size]
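# Illustration: chunk_list splits a list into fixed-size slices (the last slice
# may be shorter), e.g. list(chunk_list([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]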
@dataclass(frozen=True)
class AudioChunk:
"""
One chunk cut from an original audio.
Attributes:
orig_index: Index of the original sample in the input batch.
chunk_index: Index of this chunk within the original sample.
wav: Mono float32 waveform.
sr: Sampling rate.
offset_sec: Start offset of this chunk in the original audio, in seconds.
"""
orig_index: int
chunk_index: int
wav: np.ndarray
sr: int
offset_sec: float
def split_audio_into_chunks(
wav: np.ndarray,
sr: int,
max_chunk_sec: float,
search_expand_sec: float = 5.0,
min_window_ms: float = 100.0,
) -> List[Tuple[np.ndarray, float]]:
"""
Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.
    This implementation guarantees:
    - Chunks are cut with no overlaps and no gaps, so concatenating them
      reproduces the original audio sample-for-sample, except that chunks
      shorter than MIN_ASR_INPUT_SECONDS are zero-padded at the tail.
Args:
wav: Mono waveform float32.
sr: Sampling rate.
max_chunk_sec: Target max chunk duration in seconds.
search_expand_sec: Boundary search half-window in seconds.
min_window_ms: Sliding window in milliseconds for energy estimation.
Returns:
List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
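    Example (illustrative; actual boundaries depend on the low-energy search):
        >>> chunks = split_audio_into_chunks(wav, sr=16000, max_chunk_sec=12.0)
        >>> [round(off, 1) for _, off in chunks]   # e.g. [0.0, 11.8, 23.9]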
"""
wav = np.asarray(wav, dtype=np.float32)
if wav.ndim > 1:
wav = np.mean(wav, axis=-1).astype(np.float32)
total_len = int(wav.shape[0])
total_sec = total_len / float(sr)
if total_sec <= max_chunk_sec:
return [(wav, 0.0)]
max_len = int(max_chunk_sec * sr)
expand = int(search_expand_sec * sr)
win = max(4, int((min_window_ms / 1000.0) * sr))
chunks: List[Tuple[np.ndarray, float]] = []
start = 0
offset_sec = 0.0
while (total_len - start) > max_len:
cut = start + max_len
left = max(start, cut - expand)
right = min(total_len, cut + expand)
if right - left <= win:
boundary = cut
else:
seg = wav[left:right]
seg_abs = np.abs(seg)
window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")
min_pos = int(np.argmin(window_sums))
wstart = min_pos
wend = min_pos + win
local = seg_abs[wstart:wend]
inner = int(np.argmin(local))
boundary = left + wstart + inner
boundary = int(max(boundary, start + 1))
boundary = int(min(boundary, total_len))
chunk = wav[start:boundary]
chunks.append((chunk, offset_sec))
offset_sec += (boundary - start) / float(sr)
start = boundary
tail = wav[start:total_len]
chunks.append((tail, offset_sec))
# Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
min_len = int(MIN_ASR_INPUT_SECONDS * sr)
padded: List[Tuple[np.ndarray, float]] = []
for c, off in chunks:
if c.shape[0] < min_len:
pad = min_len - int(c.shape[0])
c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
padded.append((c, off))
chunks = padded
return chunks
def detect_and_fix_repetitions(text: str, threshold: int = 20) -> str:
def fix_char_repeats(s, thresh):
res = []
i = 0
n = len(s)
while i < n:
count = 1
while i + count < n and s[i + count] == s[i]:
count += 1
if count > thresh:
res.append(s[i])
i += count
else:
res.append(s[i:i+count])
i += count
return ''.join(res)
def fix_pattern_repeats(s, thresh, max_len=20):
n = len(s)
min_repeat_chars = thresh * 2
if n < min_repeat_chars:
return s
i = 0
result = []
while i <= n - min_repeat_chars:
found = False
for k in range(1, max_len + 1):
if i + k * thresh > n:
break
pattern = s[i:i+k]
valid = True
for rep in range(1, thresh):
start_idx = i + rep * k
if s[start_idx:start_idx+k] != pattern:
valid = False
break
if valid:
total_rep = thresh
end_index = i + thresh * k
while end_index + k <= n and s[end_index:end_index+k] == pattern:
total_rep += 1
end_index += k
result.append(pattern)
result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
i = n
found = True
break
if found:
break
else:
result.append(s[i])
i += 1
if not found:
result.append(s[i:])
return ''.join(result)
text_raw = text
text = fix_char_repeats(text_raw, threshold)
text = fix_pattern_repeats(text, threshold)
return text
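# Illustration: a pattern repeated more than `threshold` times collapses to a
# single occurrence, e.g. detect_and_fix_repetitions("ab" * 30 + "c") -> "abc"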
def parse_asr_output(
raw: str,
user_language: Optional[str] = None,
) -> Tuple[str, str]:
"""
Parse Qwen3-ASR raw output into (language, text).
Cases:
- With tag: "language Chinese<asr_text>...."
- With newlines: "language Chinese\\n...\\n<asr_text>...."
- No tag: treat whole string as text.
- "language None<asr_text>": treat as empty audio -> ("", "")
If user_language is provided, language is forced to user_language and raw is treated as text-only
(the model is expected to output plain transcription without metadata).
Args:
raw: Raw decoded string.
user_language: Canonical language name if user forced language.
Returns:
Tuple[str, str]: (language, text)
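    Example:
        >>> parse_asr_output("language Chinese<asr_text>hello")
        ('Chinese', 'hello')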
"""
if raw is None:
return "", ""
s = str(raw).strip()
if not s:
return "", ""
s = detect_and_fix_repetitions(s)
if user_language:
# user explicitly forced language => model output is treated as pure text
return user_language, s
meta_part = s
text_part = ""
has_tag = _ASR_TEXT_TAG in s
if has_tag:
meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
else:
# no tag => pure text
return "", s.strip()
meta_lower = meta_part.lower()
# empty audio heuristic
if "language none" in meta_lower:
t = text_part.strip()
if not t:
return "", ""
# if model still returned something, keep it but language unknown
return "", t
# extract "language xxx" from meta
lang = ""
for line in meta_part.splitlines():
line = line.strip()
if not line:
continue
low = line.lower()
if low.startswith(_LANG_PREFIX):
val = line[len(_LANG_PREFIX):].strip()
if val:
lang = normalize_language_name(val)
break
return lang, text_part.strip()
def merge_languages(langs: List[str]) -> str:
"""
Merge per-chunk languages into a compact comma-separated string,
keeping order and removing consecutive duplicates and empty entries.
Example:
["Chinese", "English", "English"] -> "Chinese,English"
Args:
langs: List of canonical language names.
Returns:
str: Merged language string.
"""
out: List[str] = []
prev = None
for x in langs:
x = (x or "").strip()
if not x:
continue
if x == prev:
continue
out.append(x)
prev = x
return ",".join(out)