Initial commit
This commit is contained in:
536
qwen_asr/cli/demo.py
Normal file
536
qwen_asr/cli/demo.py
Normal file
@ -0,0 +1,536 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A gradio demo for Qwen3 ASR models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
from scipy.io.wavfile import write as wav_write
|
||||
|
||||
|
||||
def _title_case_display(s: str) -> str:
|
||||
s = (s or "").strip()
|
||||
s = s.replace("_", " ")
|
||||
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
||||
|
||||
|
||||
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
    """Build display labels for *items* plus a map from display label back to raw value.

    Returns ``([], {})`` for ``None`` or empty input. If two raw values share a
    display label, the later one wins in the mapping (same as the original zip).
    """
    if not items:
        return [], {}
    display = [_title_case_display(raw) for raw in items]
    mapping = {label: raw for label, raw in zip(display, items)}
    return display, mapping
|
||||
|
||||
|
||||
def _dtype_from_str(s: str) -> torch.dtype:
|
||||
s = (s or "").strip().lower()
|
||||
if s in ("bf16", "bfloat16"):
|
||||
return torch.bfloat16
|
||||
if s in ("fp16", "float16", "half"):
|
||||
return torch.float16
|
||||
if s in ("fp32", "float32"):
|
||||
return torch.float32
|
||||
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
|
||||
|
||||
|
||||
def _normalize_audio(wav, eps=1e-12, clip=True):
|
||||
x = np.asarray(wav)
|
||||
|
||||
if np.issubdtype(x.dtype, np.integer):
|
||||
info = np.iinfo(x.dtype)
|
||||
if info.min < 0:
|
||||
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
||||
else:
|
||||
mid = (info.max + 1) / 2.0
|
||||
y = (x.astype(np.float32) - mid) / mid
|
||||
elif np.issubdtype(x.dtype, np.floating):
|
||||
y = x.astype(np.float32)
|
||||
m = np.max(np.abs(y)) if y.size else 0.0
|
||||
if m > 1.0 + 1e-6:
|
||||
y = y / (m + eps)
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype: {x.dtype}")
|
||||
|
||||
if clip:
|
||||
y = np.clip(y, -1.0, 1.0)
|
||||
|
||||
if y.ndim > 1:
|
||||
y = np.mean(y, axis=-1).astype(np.float32)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
    """Normalize the audio payload shapes gradio may hand us.

    Accepted:
      - ``{"sampling_rate": int, "data": np.ndarray}``
      - ``(sr, np.ndarray)`` or ``(np.ndarray, sr)`` [varies by gradio version]

    Returns ``(wav_float32_mono, sr)``, or ``None`` when the payload is not
    recognized.
    """
    if audio is None:
        return None

    if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
        return _normalize_audio(audio["data"]), int(audio["sampling_rate"])

    if isinstance(audio, tuple) and len(audio) == 2:
        first, second = audio
        # Whichever element is the int is the sample rate.
        if isinstance(first, int):
            return _normalize_audio(second), int(first)
        if isinstance(second, int):
            return _normalize_audio(first), int(second)

    return None
|
||||
|
||||
|
||||
def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
    """Validate and convert a gradio audio payload; raise on anything unusable."""
    if audio is None:
        raise ValueError("Audio is required.")
    parsed = _audio_to_tuple(audio)
    if parsed is None:
        raise ValueError("Unsupported audio input format.")
    return parsed
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``qwen-asr-demo`` CLI.

    Fix: boolean flags use ``argparse.BooleanOptionalAction``, which derives
    the ``--no-*`` negative form automatically, so the option string must be
    the plain positive flag (``--share``, ``--ssl-verify``). The previous
    literal strings ``"--share/--no-share"`` / ``"--ssl-verify/--no-ssl-verify"``
    registered options with those exact (untypeable) names, so neither the
    positive nor the negative flag was ever accepted.
    """
    parser = argparse.ArgumentParser(
        prog="qwen-asr-demo",
        description=(
            "Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
            "Examples:\n"
            " qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
            " qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
            " qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
            " qwen-asr-demo --backend transformers --backend-kwargs '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
            " qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
        add_help=True,
    )

    # Model checkpoints.
    parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
    parser.add_argument(
        "--aligner-checkpoint",
        default=None,
        help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
    )

    parser.add_argument(
        "--backend",
        default="transformers",
        choices=["transformers", "vllm"],
        help="Backend for ASR model loading (default: transformers).",
    )

    parser.add_argument(
        "--cuda-visible-devices",
        default="0",
        help=(
            "Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
            "Use e.g. '0' or '1'"
        ),
    )

    # Backend-specific JSON kwargs (merged over per-backend defaults).
    parser.add_argument(
        "--backend-kwargs",
        default=None,
        help=(
            "JSON dict for backend-specific kwargs excluding checkpoints.\n"
            "Examples:\n"
            " transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
            " vllm : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
        ),
    )
    parser.add_argument(
        "--aligner-kwargs",
        default=None,
        help=(
            "JSON dict for forced aligner kwargs (only used when --aligner-checkpoint is set).\n"
            "Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
        ),
    )

    # Gradio server args
    parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
    parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
    parser.add_argument(
        "--share",
        dest="share",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to create a public Gradio link (default: disabled).",
    )
    parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")

    # HTTPS args
    parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
    parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
    parser.add_argument(
        "--ssl-verify",
        dest="ssl_verify",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to verify SSL certificate (default: enabled).",
    )

    return parser
|
||||
|
||||
|
||||
def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
|
||||
if s is None or not str(s).strip():
|
||||
return {}
|
||||
try:
|
||||
obj = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid JSON for {name}: {e}")
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError(f"{name} must be a JSON object (dict).")
|
||||
return obj
|
||||
|
||||
|
||||
def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
|
||||
v = (cuda_visible_devices or "").strip()
|
||||
if not v:
|
||||
return
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = v
|
||||
|
||||
|
||||
def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
|
||||
if backend == "transformers":
|
||||
return dict(
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
max_inference_batch_size=4,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
else:
|
||||
return dict(
|
||||
gpu_memory_utilization=0.8,
|
||||
max_inference_batch_size=4,
|
||||
max_new_tokens=4096,
|
||||
)
|
||||
|
||||
|
||||
def _default_aligner_kwargs() -> Dict[str, Any]:
|
||||
return dict(
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
|
||||
|
||||
def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||||
out = dict(base)
|
||||
out.update(override)
|
||||
return out
|
||||
|
||||
|
||||
def _coerce_special_types(d: Dict[str, Any]) -> Dict[str, Any]:
    """Convert JSON-friendly values to runtime types (currently only ``dtype`` strings)."""
    return {
        key: _dtype_from_str(value) if key == "dtype" and isinstance(value, str) else value
        for key, value in d.items()
    }
|
||||
|
||||
|
||||
def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str:
    """
    Build HTML with a playable audio slice for each timestamp entry.

    Each entry in *timestamps* is expected to be a dict with keys ``text``,
    ``start_time`` and ``end_time``. The times are interpreted in **seconds**:
    they are multiplied by the sample rate to index the waveform and rendered
    with an "s" suffix. (The previous docstring claimed milliseconds, which
    contradicted the code.)

    The slices are embedded as base64 ``data:`` URLs so the page is
    self-contained.
    """
    import html as _html  # stdlib; local import keeps this fix self-contained

    at = _audio_to_tuple(audio_upload)
    if at is None:
        raise ValueError("Audio input is required for visualization.")
    audio, sr = at

    if not timestamps:
        return "<div style='color:#666'>No timestamps to visualize.</div>"
    if not isinstance(timestamps, list):
        raise ValueError("Timestamps must be a list (JSON array).")

    html_content = """
<style>
.word-alignment-container { display: flex; flex-wrap: wrap; gap: 10px; }
.word-box {
border: 1px solid #ddd; border-radius: 8px; padding: 10px;
background-color: #f9f9f9; box-shadow: 0 2px 4px rgba(0,0,0,0.06);
text-align: center;
}
.word-text { font-size: 18px; font-weight: 700; margin-bottom: 5px; }
.word-time { font-size: 12px; color: #666; margin-bottom: 8px; }
.word-audio audio { width: 140px; height: 30px; }
details { border: 1px solid #ddd; border-radius: 6px; padding: 10px; background-color: #f7f7f7; }
summary { font-weight: 700; cursor: pointer; }
</style>
"""

    html_content += """
<details open>
<summary>Timestamps Visualization (时间戳可视化结果)</summary>
<div class="word-alignment-container" style="margin-top: 14px;">
"""

    for item in timestamps:
        if not isinstance(item, dict):
            continue
        # Escape the token text: transcripts may contain '<', '&', quotes, etc.,
        # which would otherwise break the markup (or inject HTML).
        word = _html.escape(str(item.get("text", "") or ""))
        start = item.get("start_time", None)
        end = item.get("end_time", None)
        if start is None or end is None:
            continue

        start = float(start)
        end = float(end)
        if end <= start:
            continue

        # Clamp the slice to the waveform bounds.
        start_sample = max(0, int(start * sr))
        end_sample = min(len(audio), int(end * sr))
        if end_sample <= start_sample:
            continue

        seg = audio[start_sample:end_sample]
        seg_i16 = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)

        # Encode the slice as an in-page WAV data URL.
        mem = io.BytesIO()
        wav_write(mem, sr, seg_i16)
        mem.seek(0)
        b64 = base64.b64encode(mem.read()).decode("utf-8")
        audio_src = f"data:audio/wav;base64,{b64}"

        html_content += f"""
<div class="word-box">
<div class="word-text">{word}</div>
<div class="word-time">{start} - {end} s</div>
<div class="word-audio">
<audio controls preload="none" src="{audio_src}"></audio>
</div>
</div>
"""

    html_content += "</div></details>"
    return html_content
|
||||
|
||||
|
||||
def build_demo(
    asr: Qwen3ASRModel,
    asr_ckpt: str,
    backend: str,
    aligner_ckpt: Optional[str] = None,
) -> gr.Blocks:
    """Assemble the Gradio Blocks UI around an already-loaded ASR model.

    Args:
        asr: Loaded Qwen3ASRModel (transformers or vLLM backend).
        asr_ckpt: Checkpoint id/path, shown in the page header only.
        backend: Backend name, shown in the page header only.
        aligner_ckpt: Forced-aligner checkpoint; when set, timestamp
            controls and the visualization panel are added to the UI.

    Returns:
        The constructed (not yet launched) gr.Blocks demo.
    """
    supported_langs_raw = asr.get_supported_languages()
    # Display labels ("Zh Cn") plus a map back to the raw ids ("zh_cn").
    lang_choices_disp, lang_map = _build_choices_and_map([x for x in supported_langs_raw])
    lang_choices = ["Auto"] + lang_choices_disp

    has_aligner = bool(aligner_ckpt)

    theme = gr.themes.Soft(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
    )
    css = ".gradio-container {max-width: none !important;}"

    with gr.Blocks(theme=theme, css=css) as demo:
        gr.Markdown(
            f"""
# Qwen3 ASR Demo
**Backend:** `{backend}`
**ASR Checkpoint:** `{asr_ckpt}`
**Forced Aligner:** `{aligner_ckpt if aligner_ckpt else "(none)"}`
"""
        )

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=2):
                audio_in = gr.Audio(label="Audio Input (上传音频)", type="numpy")
                lang_in = gr.Dropdown(
                    label="Language (语种)",
                    choices=lang_choices,
                    value="Auto",
                    interactive=True,
                )
                # Without an aligner the checkbox is replaced by a constant
                # False gr.State so the event wiring stays uniform.
                if has_aligner:
                    ts_in = gr.Checkbox(
                        label="Return Timestamps (是否返回时间戳)",
                        value=True,
                    )
                else:
                    ts_in = gr.State(False)

                btn = gr.Button("Transcribe (识别)", variant="primary")

            # Middle column: text outputs.
            with gr.Column(scale=2):
                out_lang = gr.Textbox(label="Detected Language", lines=1)
                out_text = gr.Textbox(label="Result Text", lines=12)

            # Right column: timestamp JSON + visualize button (aligner only).
            if has_aligner:
                with gr.Column(scale=3):
                    out_ts = gr.JSON(label="Timestamps(时间戳结果)")
                    viz_btn = gr.Button("Visualize Timestamps (可视化时间戳)", variant="secondary")
            else:
                with gr.Column(scale=3):
                    out_ts = gr.State(None)
                    viz_btn = gr.State(None)

        # Put the visualization panel below the three columns
        if has_aligner:
            with gr.Row():
                out_ts_html = gr.HTML(label="Timestamps Visualization (时间戳可视化结果)")
        else:
            out_ts_html = gr.State("")

        def run(audio_upload: Any, lang_disp: str, return_ts: bool):
            # Transcribe one uploaded clip; output arity depends on has_aligner.
            audio_obj = _parse_audio_any(audio_upload)

            # "Auto" means let the model detect the language (None).
            language = None
            if lang_disp and lang_disp != "Auto":
                language = lang_map.get(lang_disp, lang_disp)

            return_ts = bool(return_ts) and has_aligner

            results = asr.transcribe(
                audio=audio_obj,
                language=language,
                return_time_stamps=return_ts,
            )
            # One input clip must yield exactly one result.
            if not isinstance(results, list) or len(results) != 1:
                raise RuntimeError(
                    f"Unexpected result size: {type(results)} "
                    f"len={len(results) if isinstance(results, list) else 'N/A'}"
                )

            r = results[0]

            if has_aligner:
                ts_payload = None
                if return_ts:
                    # Flatten timestamp objects into plain dicts for gr.JSON.
                    ts_payload = [
                        dict(
                            text=getattr(t, "text", None),
                            start_time=getattr(t, "start_time", None),
                            end_time=getattr(t, "end_time", None),
                        )
                        for t in (getattr(r, "time_stamps", None) or [])
                    ]
                return (
                    getattr(r, "language", "") or "",
                    getattr(r, "text", "") or "",
                    gr.update(value=ts_payload) if return_ts else gr.update(value=None),
                    gr.update(value=""),  # clear html on each transcribe
                )
            else:
                return (
                    getattr(r, "language", "") or "",
                    getattr(r, "text", "") or "",
                )

        def visualize(audio_upload: Any, timestamps_json: Any):
            # Render per-token audio slices from the JSON shown in out_ts.
            return _make_timestamp_html(audio_upload, timestamps_json)

        if has_aligner:
            btn.click(
                run,
                inputs=[audio_in, lang_in, ts_in],
                outputs=[out_lang, out_text, out_ts, out_ts_html],
            )
            viz_btn.click(
                visualize,
                inputs=[audio_in, out_ts],
                outputs=[out_ts_html],
            )
        else:
            btn.click(
                run,
                inputs=[audio_in, lang_in, ts_in],
                outputs=[out_lang, out_text],
            )

    return demo
|
||||
|
||||
|
||||
def main(argv=None) -> int:
    """CLI entry point: parse args, load the model(s), launch the Gradio demo.

    Args:
        argv: Optional explicit argument list (defaults to ``sys.argv[1:]``).

    Returns:
        Process exit code (0 on success).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Must happen before any CUDA work so device selection takes effect.
    _apply_cuda_visible_devices(args.cuda_visible_devices)

    backend = args.backend
    asr_ckpt = args.asr_checkpoint
    aligner_ckpt = args.aligner_checkpoint

    user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
    user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")

    # User-supplied JSON overrides the per-backend defaults; dtype strings
    # are converted to torch dtypes afterwards.
    backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
    backend_kwargs = _coerce_special_types(backend_kwargs)

    forced_aligner = None
    forced_aligner_kwargs = None
    if aligner_ckpt:
        forced_aligner = aligner_ckpt
        aligner_kwargs = _merge_dicts(_default_aligner_kwargs(), user_aligner_kwargs)
        forced_aligner_kwargs = _coerce_special_types(aligner_kwargs)

    if backend == "transformers":
        asr = Qwen3ASRModel.from_pretrained(
            asr_ckpt,
            forced_aligner=forced_aligner,
            forced_aligner_kwargs=forced_aligner_kwargs,
            **backend_kwargs,
        )
    else:
        asr = Qwen3ASRModel.LLM(
            model=asr_ckpt,
            forced_aligner=forced_aligner,
            forced_aligner_kwargs=forced_aligner_kwargs,
            **backend_kwargs,
        )

    demo = build_demo(asr, asr_ckpt, backend, aligner_ckpt=aligner_ckpt)

    launch_kwargs: Dict[str, Any] = dict(
        server_name=args.ip,
        server_port=args.port,
        share=args.share,
        # args.ssl_verify is already a bool; the old `True if ... else False`
        # ternary was redundant.
        ssl_verify=bool(args.ssl_verify),
    )
    if args.ssl_certfile is not None:
        launch_kwargs["ssl_certfile"] = args.ssl_certfile
    if args.ssl_keyfile is not None:
        launch_kwargs["ssl_keyfile"] = args.ssl_keyfile

    demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
507
qwen_asr/cli/demo_streaming.py
Normal file
507
qwen_asr/cli/demo_streaming.py
Normal file
@ -0,0 +1,507 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Minimal web demo for Qwen3ASRModel Streaming Inference (vLLM backend).
|
||||
|
||||
Install:
|
||||
pip install qwen-asr[vllm]
|
||||
|
||||
Run:
|
||||
python -m qwen_asr.cli.demo_streaming
|
||||
Open:
|
||||
http://127.0.0.1:8000  (default --port is 8000)
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from flask import Flask, Response, jsonify, request
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
|
||||
@dataclass
class Session:
    """One live streaming-recognition session."""

    # Opaque streaming state returned by asr.init_streaming_state().
    state: object
    # Wall-clock creation time (time.time()).
    created_at: float
    # Last time this session handled a request; refreshed on lookup and
    # used by _gc_sessions() for TTL-based cleanup.
    last_seen: float
|
||||
|
||||
app = Flask(__name__)

# `asr`, `UNFIXED_CHUNK_NUM`, `UNFIXED_TOKEN_NUM` and `CHUNK_SIZE_SEC` are
# module-level settings assigned inside main() (which declares them with
# `global`). The bare `global` statements that previously appeared here were
# removed: `global` only has an effect inside a function body, so at module
# scope they were misleading no-ops.

# session_id -> Session; stale entries are pruned by _gc_sessions().
SESSIONS: Dict[str, Session] = {}
SESSION_TTL_SEC = 10 * 60  # drop sessions idle for more than 10 minutes
|
||||
|
||||
|
||||
def _gc_sessions():
    """Drop sessions idle longer than SESSION_TTL_SEC, finishing their streams."""
    cutoff = time.time() - SESSION_TTL_SEC
    expired = [sid for sid, sess in SESSIONS.items() if sess.last_seen < cutoff]
    for sid in expired:
        try:
            # Best effort: let the model finalize/release the streaming state.
            asr.finish_streaming_transcribe(SESSIONS[sid].state)
        except Exception:
            pass
        SESSIONS.pop(sid, None)
|
||||
|
||||
|
||||
def _get_session(session_id: str) -> Optional[Session]:
    """Look up a live session and refresh its idle timer; None if unknown/expired."""
    _gc_sessions()
    sess = SESSIONS.get(session_id)
    if sess is not None:
        sess.last_seen = time.time()
    return sess
|
||||
|
||||
|
||||
INDEX_HTML = r"""<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
||||
<title>Qwen3-ASR Streaming</title>
|
||||
<style>
|
||||
:root{
|
||||
--bg:#ffffff;
|
||||
--card:#ffffff;
|
||||
--muted:#5b6472;
|
||||
--text:#0f172a;
|
||||
--border:#e5e7eb;
|
||||
--ok:#059669;
|
||||
--warn:#d97706;
|
||||
--danger:#e11d48;
|
||||
}
|
||||
|
||||
html, body { height: 100%; }
|
||||
|
||||
body{
|
||||
margin:0;
|
||||
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Noto Sans";
|
||||
background: var(--bg);
|
||||
color:var(--text);
|
||||
}
|
||||
|
||||
.wrap{
|
||||
height: 100vh;
|
||||
max-width: none;
|
||||
margin: 0;
|
||||
padding: 16px;
|
||||
box-sizing: border-box;
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.card{
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: var(--card);
|
||||
border:1px solid var(--border);
|
||||
border-radius: 14px;
|
||||
padding: 16px;
|
||||
box-sizing: border-box;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,.06);
|
||||
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
h1{ font-size: 16px; margin: 0; letter-spacing:.2px;}
|
||||
|
||||
.row{ display:flex; gap:12px; align-items:center; flex-wrap: wrap; }
|
||||
|
||||
button{
|
||||
border:1px solid var(--border); border-radius: 12px;
|
||||
padding: 10px 14px; cursor:pointer; color:var(--text);
|
||||
background: #f8fafc;
|
||||
transition: transform .05s ease, background .15s ease, border-color .15s ease;
|
||||
font-weight: 700;
|
||||
}
|
||||
button:hover{ background: #f1f5f9; border-color:#cbd5e1; }
|
||||
button:active{ transform: translateY(1px); }
|
||||
button.primary{ border-color: rgba(5,150,105,.35); background: rgba(5,150,105,.10); }
|
||||
button.danger{ border-color: rgba(225,29,72,.35); background: rgba(225,29,72,.10); }
|
||||
button:disabled{ opacity:.5; cursor:not-allowed; }
|
||||
|
||||
.pill{
|
||||
font-size: 12px; padding: 6px 10px; border-radius: 999px;
|
||||
border:1px solid var(--border); color: var(--muted);
|
||||
background: #f8fafc;
|
||||
user-select:none;
|
||||
}
|
||||
.pill.ok{ color: #065f46; border-color: rgba(5,150,105,.35); background: rgba(5,150,105,.10); }
|
||||
.pill.warn{ color: #92400e; border-color: rgba(217,119,6,.35); background: rgba(217,119,6,.10); }
|
||||
.pill.err{ color: #9f1239; border-color: rgba(225,29,72,.35); background: rgba(225,29,72,.10); }
|
||||
|
||||
.panel{
|
||||
border:1px solid var(--border);
|
||||
border-radius: 12px;
|
||||
background: #ffffff;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.panel.textpanel{
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.label{ color:var(--muted); font-size: 12px; margin-bottom: 6px; }
|
||||
.mono{ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New"; }
|
||||
|
||||
#text{
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
white-space: pre-wrap;
|
||||
line-height: 1.6;
|
||||
font-size: 15px;
|
||||
padding: 12px;
|
||||
border-radius: 12px;
|
||||
border: 1px solid var(--border);
|
||||
background: #f8fafc;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
a{ color: #2563eb; text-decoration:none; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<div class="card">
|
||||
<h1>Qwen3-ASR Streaming</h1>
|
||||
|
||||
<div class="row">
|
||||
<button id="btnStart" class="primary">Start / 开始</button>
|
||||
<button id="btnStop" class="danger" disabled>Stop / 停止</button>
|
||||
<span id="status" class="pill warn">Idle / 未开始</span>
|
||||
<a href="javascript:void(0)" id="btnClear" class="mono" style="margin-left:auto;">Clear / 清空</a>
|
||||
</div>
|
||||
|
||||
<div class="panel">
|
||||
<div class="label">Language / 语言</div>
|
||||
<div id="lang" class="mono">—</div>
|
||||
</div>
|
||||
|
||||
<div class="panel textpanel">
|
||||
<div class="label">Text / 文本</div>
|
||||
<div id="text"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
(() => {
|
||||
const $ = (id) => document.getElementById(id);
|
||||
|
||||
const btnStart = $("btnStart");
|
||||
const btnStop = $("btnStop");
|
||||
const btnClear = $("btnClear");
|
||||
const statusEl = $("status");
|
||||
const langEl = $("lang");
|
||||
const textEl = $("text");
|
||||
|
||||
const CHUNK_MS = 500;
|
||||
const TARGET_SR = 16000;
|
||||
|
||||
let audioCtx = null;
|
||||
let processor = null;
|
||||
let source = null;
|
||||
let mediaStream = null;
|
||||
|
||||
let sessionId = null;
|
||||
let running = false;
|
||||
|
||||
let buf = new Float32Array(0);
|
||||
let pushing = false;
|
||||
|
||||
function setStatus(text, cls){
|
||||
statusEl.textContent = text;
|
||||
statusEl.className = "pill " + (cls || "");
|
||||
}
|
||||
|
||||
function lockUI(on){
|
||||
btnStart.disabled = on;
|
||||
btnStop.disabled = !on;
|
||||
}
|
||||
|
||||
function concatFloat32(a, b){
|
||||
const out = new Float32Array(a.length + b.length);
|
||||
out.set(a, 0);
|
||||
out.set(b, a.length);
|
||||
return out;
|
||||
}
|
||||
|
||||
function resampleLinear(input, srcSr, dstSr){
|
||||
if (srcSr === dstSr) return input;
|
||||
const ratio = dstSr / srcSr;
|
||||
const outLen = Math.max(0, Math.round(input.length * ratio));
|
||||
const out = new Float32Array(outLen);
|
||||
for (let i = 0; i < outLen; i++){
|
||||
const x = i / ratio;
|
||||
const x0 = Math.floor(x);
|
||||
const x1 = Math.min(x0 + 1, input.length - 1);
|
||||
const t = x - x0;
|
||||
out[i] = input[x0] * (1 - t) + input[x1] * t;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
async function apiStart(){
|
||||
const r = await fetch("/api/start", {method:"POST"});
|
||||
if(!r.ok) throw new Error(await r.text());
|
||||
const j = await r.json();
|
||||
sessionId = j.session_id;
|
||||
}
|
||||
|
||||
async function apiPushChunk(float32_16k){
|
||||
const r = await fetch("/api/chunk?session_id=" + encodeURIComponent(sessionId), {
|
||||
method: "POST",
|
||||
headers: {"Content-Type":"application/octet-stream"},
|
||||
body: float32_16k.buffer
|
||||
});
|
||||
if(!r.ok) throw new Error(await r.text());
|
||||
return await r.json();
|
||||
}
|
||||
|
||||
async function apiFinish(){
|
||||
const r = await fetch("/api/finish?session_id=" + encodeURIComponent(sessionId), {method:"POST"});
|
||||
if(!r.ok) throw new Error(await r.text());
|
||||
return await r.json();
|
||||
}
|
||||
|
||||
btnClear.onclick = () => { textEl.textContent = ""; };
|
||||
|
||||
async function stopAudioPipeline(){
|
||||
try{
|
||||
if (processor){ processor.disconnect(); processor.onaudioprocess = null; }
|
||||
if (source) source.disconnect();
|
||||
if (audioCtx) await audioCtx.close();
|
||||
if (mediaStream) mediaStream.getTracks().forEach(t => t.stop());
|
||||
}catch(e){}
|
||||
processor = null; source = null; audioCtx = null; mediaStream = null;
|
||||
}
|
||||
|
||||
btnStart.onclick = async () => {
|
||||
if (running) return;
|
||||
|
||||
textEl.textContent = "";
|
||||
langEl.textContent = "—";
|
||||
buf = new Float32Array(0);
|
||||
|
||||
try{
|
||||
setStatus("Starting… / 启动中…", "warn");
|
||||
lockUI(true);
|
||||
|
||||
await apiStart();
|
||||
|
||||
mediaStream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: {
|
||||
channelCount: 1,
|
||||
echoCancellation: true,
|
||||
noiseSuppression: true,
|
||||
autoGainControl: true
|
||||
},
|
||||
video: false
|
||||
});
|
||||
|
||||
audioCtx = new (window.AudioContext || window.webkitAudioContext)();
|
||||
source = audioCtx.createMediaStreamSource(mediaStream);
|
||||
|
||||
processor = audioCtx.createScriptProcessor(4096, 1, 1);
|
||||
const chunkSamples = Math.round(TARGET_SR * (CHUNK_MS / 1000));
|
||||
|
||||
processor.onaudioprocess = (e) => {
|
||||
if (!running) return;
|
||||
const input = e.inputBuffer.getChannelData(0);
|
||||
const resampled = resampleLinear(input, audioCtx.sampleRate, TARGET_SR);
|
||||
buf = concatFloat32(buf, resampled);
|
||||
if (!pushing) pump();
|
||||
};
|
||||
|
||||
source.connect(processor);
|
||||
processor.connect(audioCtx.destination);
|
||||
|
||||
running = true;
|
||||
setStatus("Listening… / 识别中…", "ok");
|
||||
|
||||
}catch(err){
|
||||
console.error(err);
|
||||
setStatus("Start failed / 启动失败: " + err.message, "err");
|
||||
lockUI(false);
|
||||
running = false;
|
||||
sessionId = null;
|
||||
await stopAudioPipeline();
|
||||
}
|
||||
};
|
||||
|
||||
async function pump(){
|
||||
if (pushing) return;
|
||||
pushing = true;
|
||||
|
||||
const chunkSamples = Math.round(TARGET_SR * (CHUNK_MS / 1000));
|
||||
|
||||
try{
|
||||
while (running && buf.length >= chunkSamples){
|
||||
const chunk = buf.slice(0, chunkSamples);
|
||||
buf = buf.slice(chunkSamples);
|
||||
|
||||
const j = await apiPushChunk(chunk);
|
||||
langEl.textContent = j.language || "—";
|
||||
textEl.textContent = j.text || "";
|
||||
if (running) setStatus("Listening… / 识别中…", "ok");
|
||||
}
|
||||
}catch(err){
|
||||
console.error(err);
|
||||
if (running) setStatus("Backend error / 后端错误: " + err.message, "err");
|
||||
}finally{
|
||||
pushing = false;
|
||||
}
|
||||
}
|
||||
|
||||
btnStop.onclick = async () => {
|
||||
if (!running) return;
|
||||
|
||||
running = false;
|
||||
setStatus("Finishing… / 收尾中…", "warn");
|
||||
lockUI(false);
|
||||
|
||||
await stopAudioPipeline();
|
||||
|
||||
try{
|
||||
if (sessionId){
|
||||
const j = await apiFinish();
|
||||
langEl.textContent = j.language || "—";
|
||||
textEl.textContent = j.text || "";
|
||||
}
|
||||
setStatus("Stopped / 已停止", "");
|
||||
}catch(err){
|
||||
console.error(err);
|
||||
setStatus("Finish failed / 收尾失败: " + err.message, "err");
|
||||
}finally{
|
||||
sessionId = null;
|
||||
buf = new Float32Array(0);
|
||||
pushing = false;
|
||||
}
|
||||
};
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@app.get("/")
def index():
    """Serve the self-contained single-page recorder UI."""
    return Response(INDEX_HTML, mimetype="text/html; charset=utf-8")
|
||||
|
||||
|
||||
@app.post("/api/start")
def api_start():
    """Create a new streaming session and return its id as JSON."""
    sid = uuid.uuid4().hex
    # Streaming hyperparameters are module-level settings assigned in main().
    state = asr.init_streaming_state(
        unfixed_chunk_num=UNFIXED_CHUNK_NUM,
        unfixed_token_num=UNFIXED_TOKEN_NUM,
        chunk_size_sec=CHUNK_SIZE_SEC,
    )
    now = time.time()
    SESSIONS[sid] = Session(state=state, created_at=now, last_seen=now)
    return jsonify({"session_id": sid})
|
||||
|
||||
|
||||
@app.post("/api/chunk")
def api_chunk():
    """Feed one raw float32 PCM chunk into a session; return the partial result."""
    sess = _get_session(request.args.get("session_id", ""))
    if not sess:
        return jsonify({"error": "invalid session_id"}), 400

    if request.mimetype != "application/octet-stream":
        return jsonify({"error": "expect application/octet-stream"}), 400

    raw = request.get_data(cache=False)
    # Body must be a whole number of little-endian float32 samples.
    if len(raw) % 4 != 0:
        return jsonify({"error": "float32 bytes length not multiple of 4"}), 400

    samples = np.frombuffer(raw, dtype=np.float32).reshape(-1)
    asr.streaming_transcribe(samples, sess.state)

    payload = {
        "language": getattr(sess.state, "language", "") or "",
        "text": getattr(sess.state, "text", "") or "",
    }
    return jsonify(payload)
|
||||
|
||||
|
||||
@app.post("/api/finish")
def api_finish():
    """Finalize a session, return its full transcript, and discard the session."""
    session_id = request.args.get("session_id", "")
    sess = _get_session(session_id)
    if not sess:
        return jsonify({"error": "invalid session_id"}), 400

    asr.finish_streaming_transcribe(sess.state)
    result = {
        "language": getattr(sess.state, "language", "") or "",
        "text": getattr(sess.state, "text", "") or "",
    }
    SESSIONS.pop(session_id, None)
    return jsonify(result)
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse CLI options for the streaming demo.

    Args:
        argv: Optional explicit argument list; defaults to ``sys.argv[1:]``.
            Accepting it makes the parser testable without touching the real
            process arguments (backward compatible: ``parse_args()`` behaves
            exactly as before).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Qwen3-ASR Streaming Web Demo (vLLM backend)")
    p.add_argument("--asr-model-path", default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")
    p.add_argument("--host", default="0.0.0.0", help="Bind host")
    p.add_argument("--port", type=int, default=8000, help="Bind port")
    p.add_argument("--gpu-memory-utilization", type=float, default=0.8, help="vLLM GPU memory utilization")

    # Streaming hyperparameters, forwarded to asr.init_streaming_state().
    p.add_argument("--unfixed-chunk-num", type=int, default=4)
    p.add_argument("--unfixed-token-num", type=int, default=5)
    p.add_argument("--chunk-size-sec", type=float, default=1.0)
    return p.parse_args(argv)
|
||||
|
||||
|
||||
def main():
    """Parse CLI args, load the vLLM-backed ASR model, and serve the web demo."""
    global asr, UNFIXED_CHUNK_NUM, UNFIXED_TOKEN_NUM, CHUNK_SIZE_SEC

    args = parse_args()

    # Publish the streaming hyperparameters for the request handlers.
    UNFIXED_CHUNK_NUM = args.unfixed_chunk_num
    UNFIXED_TOKEN_NUM = args.unfixed_token_num
    CHUNK_SIZE_SEC = args.chunk_size_sec

    asr = Qwen3ASRModel.LLM(
        model=args.asr_model_path,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_new_tokens=32,
    )
    print("Model loaded.")
    app.run(host=args.host, port=args.port, debug=False, use_reloader=False, threaded=True)


if __name__ == "__main__":
    main()
|
||||
46
qwen_asr/cli/serve.py
Normal file
46
qwen_asr/cli/serve.py
Normal file
@ -0,0 +1,46 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from qwen_asr.core.transformers_backend import (
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
# Register the Qwen3-ASR classes with transformers' Auto* factories so the
# "qwen3_asr" model type resolves through AutoConfig/AutoModel/AutoProcessor.
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)

try:
    # NOTE(review): this import rebinds Qwen3ASRForConditionalGeneration to the
    # vLLM implementation, shadowing the transformers class imported above.
    # The transformers class has already been registered at this point, so the
    # shadowing appears intentional — confirm if both names must coexist.
    from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
    from vllm import ModelRegistry
    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
except Exception as e:
    # vLLM is a hard requirement for the serve entry point; fail loudly with
    # an install hint rather than letting a later import error confuse users.
    raise ImportError(
        "vLLM is not available, to use qwen-asr-serve, please install with: pip install qwen-asr[vllm]"
    ) from e

from vllm.entrypoints.cli.main import main as vllm_main
|
||||
|
||||
def main():
    """Entry point for ``qwen-asr-serve``: delegate to ``vllm serve``.

    The "serve" subcommand is spliced in right after the program name so the
    remaining CLI arguments pass through to vLLM untouched.
    """
    sys.argv[1:1] = ["serve"]
    vllm_main()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user