153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
# coding=utf-8
|
|
# Copyright 2026 The Alibaba Qwen team.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Examples for Qwen3ASRModel (vLLM backend).
|
|
|
|
Covers:
|
|
- single-sample inference (URL audio)
|
|
- batch inference (mixed URL / base64 / (np.ndarray, sr))
|
|
- forcing language (text-only output)
|
|
- returning timestamps (single + batch) via Qwen3ForcedAligner
|
|
|
|
Note:
|
|
Requires vLLM extra:
|
|
pip install qwen-asr[vllm]
|
|
"""
|
|
|
|
import base64
|
|
import io
|
|
import urllib.request
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
|
|
from qwen_asr import Qwen3ASRModel
|
|
|
|
|
|
ASR_MODEL_PATH = "Qwen/Qwen3-ASR-1.7B"
|
|
FORCED_ALIGNER_PATH = "Qwen/Qwen3-ForcedAligner-0.6B"
|
|
|
|
URL_ZH = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav"
|
|
URL_EN = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
|
|
|
|
|
def _download_audio_bytes(url: str, timeout: int = 30) -> bytes:
|
|
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
|
return resp.read()
|
|
|
|
|
|
def _read_wav_from_bytes(audio_bytes: bytes) -> Tuple[np.ndarray, int]:
|
|
with io.BytesIO(audio_bytes) as f:
|
|
wav, sr = sf.read(f, dtype="float32", always_2d=False)
|
|
return np.asarray(wav, dtype=np.float32), int(sr)
|
|
|
|
|
|
def _to_data_url_base64(audio_bytes: bytes, mime: str = "audio/wav") -> str:
|
|
b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
|
return f"data:{mime};base64,{b64}"
|
|
|
|
|
|
def _print_result(title: str, results) -> None:
|
|
print(f"\n===== {title} =====")
|
|
for i, r in enumerate(results):
|
|
print(f"[sample {i}] language={r.language!r}")
|
|
print(f"[sample {i}] text={r.text!r}")
|
|
if r.time_stamps is not None and len(r.time_stamps) > 0:
|
|
head = r.time_stamps[0]
|
|
tail = r.time_stamps[-1]
|
|
print(f"[sample {i}] ts_first: {head.text!r} {head.start_time}->{head.end_time} s")
|
|
print(f"[sample {i}] ts_last : {tail.text!r} {tail.start_time}->{tail.end_time} s")
|
|
|
|
|
|
def test_single_url(asr: Qwen3ASRModel) -> None:
|
|
results = asr.transcribe(
|
|
audio=URL_ZH,
|
|
language=None,
|
|
return_time_stamps=False,
|
|
)
|
|
assert isinstance(results, list) and len(results) == 1
|
|
_print_result("single-url (no forced language, no timestamps)", results)
|
|
|
|
|
|
def test_batch_mixed(asr: Qwen3ASRModel) -> None:
|
|
zh_bytes = _download_audio_bytes(URL_ZH)
|
|
en_bytes = _download_audio_bytes(URL_EN)
|
|
|
|
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
|
en_wav, en_sr = _read_wav_from_bytes(en_bytes)
|
|
|
|
results = asr.transcribe(
|
|
audio=[URL_ZH, zh_b64, (en_wav, en_sr)],
|
|
context=["", "交易 停滞", ""],
|
|
language=[None, "Chinese", "English"],
|
|
return_time_stamps=False,
|
|
)
|
|
assert len(results) == 3
|
|
_print_result("batch-mixed (forced language for some)", results)
|
|
|
|
|
|
def test_single_with_timestamps(asr: Qwen3ASRModel) -> None:
|
|
results = asr.transcribe(
|
|
audio=URL_EN,
|
|
language="English",
|
|
return_time_stamps=True,
|
|
)
|
|
assert len(results) == 1
|
|
assert results[0].time_stamps is not None
|
|
_print_result("single-url (forced language + timestamps)", results)
|
|
|
|
|
|
def test_batch_with_timestamps(asr: Qwen3ASRModel) -> None:
|
|
zh_bytes = _download_audio_bytes(URL_ZH)
|
|
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
|
|
|
results = asr.transcribe(
|
|
audio=[URL_ZH, zh_b64, URL_EN],
|
|
context=["", "交易 停滞", ""],
|
|
language=["Chinese", "Chinese", "English"],
|
|
return_time_stamps=True,
|
|
)
|
|
assert len(results) == 3
|
|
assert all(r.time_stamps is not None for r in results)
|
|
_print_result("batch (forced language + timestamps)", results)
|
|
|
|
|
|
def main() -> None:
|
|
asr = Qwen3ASRModel.LLM(
|
|
model=ASR_MODEL_PATH,
|
|
gpu_memory_utilization=0.8,
|
|
forced_aligner=FORCED_ALIGNER_PATH,
|
|
forced_aligner_kwargs=dict(
|
|
dtype=torch.bfloat16,
|
|
device_map="cuda:0",
|
|
# attn_implementation="flash_attention_2",
|
|
),
|
|
max_inference_batch_size=32,
|
|
max_new_tokens=1024,
|
|
)
|
|
|
|
test_single_url(asr)
|
|
test_batch_mixed(asr)
|
|
test_single_with_timestamps(asr)
|
|
test_batch_with_timestamps(asr)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|