Initial commit
This commit is contained in:
150
examples/example_qwen3_asr_transformers.py
Normal file
150
examples/example_qwen3_asr_transformers.py
Normal file
@ -0,0 +1,150 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Examples for Qwen3ASRModel (Transformers backend).
|
||||
|
||||
Covers:
|
||||
- single-sample inference (URL audio)
|
||||
- batch inference (mixed URL / base64 / (np.ndarray, sr))
|
||||
- forcing language (text-only output)
|
||||
- returning time_stamps (single + batch) via Qwen3ForcedAligner
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
|
||||
ASR_MODEL_PATH = "Qwen/Qwen3-ASR-1.7B"
|
||||
FORCED_ALIGNER_PATH = "Qwen/Qwen3-ForcedAligner-0.6B"
|
||||
|
||||
URL_ZH = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav"
|
||||
URL_EN = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
||||
|
||||
|
||||
def _download_audio_bytes(url: str, timeout: int = 30) -> bytes:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return resp.read()
|
||||
|
||||
|
||||
def _read_wav_from_bytes(audio_bytes: bytes) -> Tuple[np.ndarray, int]:
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
wav, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
return np.asarray(wav, dtype=np.float32), int(sr)
|
||||
|
||||
|
||||
def _to_data_url_base64(audio_bytes: bytes, mime: str = "audio/wav") -> str:
|
||||
b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
|
||||
def _print_result(title: str, results) -> None:
|
||||
print(f"\n===== {title} =====")
|
||||
for i, r in enumerate(results):
|
||||
print(f"[sample {i}] language={r.language!r}")
|
||||
print(f"[sample {i}] text={r.text!r}")
|
||||
if r.time_stamps is not None and len(r.time_stamps) > 0:
|
||||
head = r.time_stamps[0]
|
||||
tail = r.time_stamps[-1]
|
||||
print(f"[sample {i}] ts_first: {head.text!r} {head.start_time}->{head.end_time} s")
|
||||
print(f"[sample {i}] ts_last : {tail.text!r} {tail.start_time}->{tail.end_time} s")
|
||||
|
||||
|
||||
def test_single_url(asr: Qwen3ASRModel) -> None:
|
||||
results = asr.transcribe(
|
||||
audio=URL_ZH,
|
||||
language=None,
|
||||
return_time_stamps=False,
|
||||
)
|
||||
assert isinstance(results, list) and len(results) == 1
|
||||
_print_result("single-url (no forced language, no timestamps)", results)
|
||||
|
||||
|
||||
def test_batch_mixed(asr: Qwen3ASRModel) -> None:
|
||||
zh_bytes = _download_audio_bytes(URL_ZH)
|
||||
en_bytes = _download_audio_bytes(URL_EN)
|
||||
|
||||
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
||||
en_wav, en_sr = _read_wav_from_bytes(en_bytes)
|
||||
|
||||
results = asr.transcribe(
|
||||
audio=[URL_ZH, zh_b64, (en_wav, en_sr)],
|
||||
context=["", "交易 停滞", ""],
|
||||
language=[None, "Chinese", "English"],
|
||||
return_time_stamps=False,
|
||||
)
|
||||
assert len(results) == 3
|
||||
_print_result("batch-mixed (forced language for some)", results)
|
||||
|
||||
|
||||
def test_single_with_timestamps(asr: Qwen3ASRModel) -> None:
|
||||
results = asr.transcribe(
|
||||
audio=URL_EN,
|
||||
language="English",
|
||||
return_time_stamps=True,
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].time_stamps is not None
|
||||
_print_result("single-url (forced language + timestamps)", results)
|
||||
|
||||
|
||||
def test_batch_with_timestamps(asr: Qwen3ASRModel) -> None:
|
||||
zh_bytes = _download_audio_bytes(URL_ZH)
|
||||
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
||||
|
||||
results = asr.transcribe(
|
||||
audio=[URL_ZH, zh_b64, URL_EN],
|
||||
context=["", "交易 停滞", ""],
|
||||
language=["Chinese", "Chinese", "English"],
|
||||
return_time_stamps=True,
|
||||
)
|
||||
assert len(results) == 3
|
||||
assert all(r.time_stamps is not None for r in results)
|
||||
_print_result("batch (forced language + timestamps)", results)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asr = Qwen3ASRModel.from_pretrained(
|
||||
ASR_MODEL_PATH,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
# attn_implementation="flash_attention_2",
|
||||
forced_aligner=FORCED_ALIGNER_PATH,
|
||||
forced_aligner_kwargs=dict(
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
# attn_implementation="flash_attention_2",
|
||||
),
|
||||
max_inference_batch_size=32,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
test_single_url(asr)
|
||||
test_batch_mixed(asr)
|
||||
test_single_with_timestamps(asr)
|
||||
test_batch_with_timestamps(asr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
152
examples/example_qwen3_asr_vllm.py
Normal file
152
examples/example_qwen3_asr_vllm.py
Normal file
@ -0,0 +1,152 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Examples for Qwen3ASRModel (vLLM backend).
|
||||
|
||||
Covers:
|
||||
- single-sample inference (URL audio)
|
||||
- batch inference (mixed URL / base64 / (np.ndarray, sr))
|
||||
- forcing language (text-only output)
|
||||
- returning timestamps (single + batch) via Qwen3ForcedAligner
|
||||
|
||||
Note:
|
||||
Requires vLLM extra:
|
||||
pip install qwen-asr[vllm]
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
|
||||
ASR_MODEL_PATH = "Qwen/Qwen3-ASR-1.7B"
|
||||
FORCED_ALIGNER_PATH = "Qwen/Qwen3-ForcedAligner-0.6B"
|
||||
|
||||
URL_ZH = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav"
|
||||
URL_EN = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
||||
|
||||
|
||||
def _download_audio_bytes(url: str, timeout: int = 30) -> bytes:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return resp.read()
|
||||
|
||||
|
||||
def _read_wav_from_bytes(audio_bytes: bytes) -> Tuple[np.ndarray, int]:
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
wav, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
return np.asarray(wav, dtype=np.float32), int(sr)
|
||||
|
||||
|
||||
def _to_data_url_base64(audio_bytes: bytes, mime: str = "audio/wav") -> str:
|
||||
b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
|
||||
def _print_result(title: str, results) -> None:
|
||||
print(f"\n===== {title} =====")
|
||||
for i, r in enumerate(results):
|
||||
print(f"[sample {i}] language={r.language!r}")
|
||||
print(f"[sample {i}] text={r.text!r}")
|
||||
if r.time_stamps is not None and len(r.time_stamps) > 0:
|
||||
head = r.time_stamps[0]
|
||||
tail = r.time_stamps[-1]
|
||||
print(f"[sample {i}] ts_first: {head.text!r} {head.start_time}->{head.end_time} s")
|
||||
print(f"[sample {i}] ts_last : {tail.text!r} {tail.start_time}->{tail.end_time} s")
|
||||
|
||||
|
||||
def test_single_url(asr: Qwen3ASRModel) -> None:
|
||||
results = asr.transcribe(
|
||||
audio=URL_ZH,
|
||||
language=None,
|
||||
return_time_stamps=False,
|
||||
)
|
||||
assert isinstance(results, list) and len(results) == 1
|
||||
_print_result("single-url (no forced language, no timestamps)", results)
|
||||
|
||||
|
||||
def test_batch_mixed(asr: Qwen3ASRModel) -> None:
|
||||
zh_bytes = _download_audio_bytes(URL_ZH)
|
||||
en_bytes = _download_audio_bytes(URL_EN)
|
||||
|
||||
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
||||
en_wav, en_sr = _read_wav_from_bytes(en_bytes)
|
||||
|
||||
results = asr.transcribe(
|
||||
audio=[URL_ZH, zh_b64, (en_wav, en_sr)],
|
||||
context=["", "交易 停滞", ""],
|
||||
language=[None, "Chinese", "English"],
|
||||
return_time_stamps=False,
|
||||
)
|
||||
assert len(results) == 3
|
||||
_print_result("batch-mixed (forced language for some)", results)
|
||||
|
||||
|
||||
def test_single_with_timestamps(asr: Qwen3ASRModel) -> None:
|
||||
results = asr.transcribe(
|
||||
audio=URL_EN,
|
||||
language="English",
|
||||
return_time_stamps=True,
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].time_stamps is not None
|
||||
_print_result("single-url (forced language + timestamps)", results)
|
||||
|
||||
|
||||
def test_batch_with_timestamps(asr: Qwen3ASRModel) -> None:
|
||||
zh_bytes = _download_audio_bytes(URL_ZH)
|
||||
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
||||
|
||||
results = asr.transcribe(
|
||||
audio=[URL_ZH, zh_b64, URL_EN],
|
||||
context=["", "交易 停滞", ""],
|
||||
language=["Chinese", "Chinese", "English"],
|
||||
return_time_stamps=True,
|
||||
)
|
||||
assert len(results) == 3
|
||||
assert all(r.time_stamps is not None for r in results)
|
||||
_print_result("batch (forced language + timestamps)", results)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asr = Qwen3ASRModel.LLM(
|
||||
model=ASR_MODEL_PATH,
|
||||
gpu_memory_utilization=0.8,
|
||||
forced_aligner=FORCED_ALIGNER_PATH,
|
||||
forced_aligner_kwargs=dict(
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
# attn_implementation="flash_attention_2",
|
||||
),
|
||||
max_inference_batch_size=32,
|
||||
max_new_tokens=1024,
|
||||
)
|
||||
|
||||
test_single_url(asr)
|
||||
test_batch_mixed(asr)
|
||||
test_single_with_timestamps(asr)
|
||||
test_batch_with_timestamps(asr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
105
examples/example_qwen3_asr_vllm_streaming.py
Normal file
105
examples/example_qwen3_asr_vllm_streaming.py
Normal file
@ -0,0 +1,105 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Examples for Qwen3ASRModel Streaming Inference (vLLM backend).
|
||||
|
||||
Note:
|
||||
Requires vLLM extra:
|
||||
pip install qwen-asr[vllm]
|
||||
"""
|
||||
|
||||
import io
|
||||
import urllib.request
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
|
||||
ASR_MODEL_PATH = "Qwen/Qwen3-ASR-1.7B"
|
||||
URL_EN = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
||||
|
||||
|
||||
def _download_audio_bytes(url: str, timeout: int = 30) -> bytes:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return resp.read()
|
||||
|
||||
|
||||
def _read_wav_from_bytes(audio_bytes: bytes) -> Tuple[np.ndarray, int]:
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
wav, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
return np.asarray(wav, dtype=np.float32), int(sr)
|
||||
|
||||
|
||||
def _resample_to_16k(wav: np.ndarray, sr: int) -> np.ndarray:
|
||||
"""Simple resample to 16k if needed (uses linear interpolation; good enough for a test)."""
|
||||
if sr == 16000:
|
||||
return wav.astype(np.float32, copy=False)
|
||||
wav = wav.astype(np.float32, copy=False)
|
||||
dur = wav.shape[0] / float(sr)
|
||||
n16 = int(round(dur * 16000))
|
||||
if n16 <= 0:
|
||||
return np.zeros((0,), dtype=np.float32)
|
||||
x_old = np.linspace(0.0, dur, num=wav.shape[0], endpoint=False)
|
||||
x_new = np.linspace(0.0, dur, num=n16, endpoint=False)
|
||||
return np.interp(x_new, x_old, wav).astype(np.float32)
|
||||
|
||||
|
||||
def run_streaming_case(asr: Qwen3ASRModel, wav16k: np.ndarray, step_ms: int) -> None:
|
||||
sr = 16000
|
||||
step = int(round(step_ms / 1000.0 * sr))
|
||||
|
||||
print(f"\n===== streaming step = {step_ms} ms =====")
|
||||
state = asr.init_streaming_state(
|
||||
unfixed_chunk_num=2,
|
||||
unfixed_token_num=5,
|
||||
chunk_size_sec=2.0,
|
||||
)
|
||||
|
||||
pos = 0
|
||||
call_id = 0
|
||||
while pos < wav16k.shape[0]:
|
||||
seg = wav16k[pos : pos + step]
|
||||
pos += seg.shape[0]
|
||||
call_id += 1
|
||||
asr.streaming_transcribe(seg, state)
|
||||
print(f"[call {call_id:03d}] language={state.language!r} text={state.text!r}")
|
||||
|
||||
asr.finish_streaming_transcribe(state)
|
||||
print(f"[final] language={state.language!r} text={state.text!r}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Streaming is vLLM-only and no forced aligner supported.
|
||||
asr = Qwen3ASRModel.LLM(
|
||||
model=ASR_MODEL_PATH,
|
||||
gpu_memory_utilization=0.8,
|
||||
max_new_tokens=32, # set a small value for streaming
|
||||
)
|
||||
|
||||
audio_bytes = _download_audio_bytes(URL_EN)
|
||||
wav, sr = _read_wav_from_bytes(audio_bytes)
|
||||
wav16k = _resample_to_16k(wav, sr)
|
||||
|
||||
for step_ms in [500, 1000, 2000, 4000]:
|
||||
run_streaming_case(asr, wav16k, step_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
214
examples/example_qwen3_forced_aligner.py
Normal file
214
examples/example_qwen3_forced_aligner.py
Normal file
@ -0,0 +1,214 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Examples for Qwen3ForcedAligner.
|
||||
|
||||
Covers:
|
||||
- single-sample inference (URL audio)
|
||||
- batch inference (URL audio)
|
||||
- base64 audio input (data:audio/wav;base64,...)
|
||||
- numpy waveform input as (np.ndarray, sr) using urllib request
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from qwen_asr import Qwen3ForcedAligner
|
||||
|
||||
|
||||
MODEL_PATH = "Qwen/Qwen3-ForcedAligner-0.6B"
|
||||
|
||||
URL_ZH = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav"
|
||||
URL_EN = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
||||
|
||||
TEXT_ZH = "甚至出现交易几乎停滞的情况。"
|
||||
TEXT_EN = (
|
||||
"Mm. Oh, yeah, yeah. He wasn't even that big when I started listening to him, "
|
||||
"but and his solo music didn't do overly well, but he did very well when he "
|
||||
"started writing for other people."
|
||||
)
|
||||
|
||||
|
||||
def _download_audio_bytes(url: str, timeout: int = 30) -> bytes:
|
||||
"""
|
||||
Download audio bytes from a URL.
|
||||
|
||||
Args:
|
||||
url (str): Audio URL.
|
||||
timeout (int): Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
bytes: Raw response bytes.
|
||||
"""
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return resp.read()
|
||||
|
||||
|
||||
def _read_wav_from_bytes(audio_bytes: bytes) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Decode audio bytes into waveform and sampling rate.
|
||||
|
||||
Args:
|
||||
audio_bytes (bytes): Encoded audio bytes (wav/flac/ogg supported by libsndfile).
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: (waveform, sr). Waveform may be mono or multi-channel.
|
||||
"""
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
wav, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
return np.asarray(wav, dtype=np.float32), int(sr)
|
||||
|
||||
|
||||
def _to_data_url_base64(audio_bytes: bytes, mime: str = "audio/wav") -> str:
|
||||
"""
|
||||
Convert audio bytes into a base64 data URL string.
|
||||
|
||||
Args:
|
||||
audio_bytes (bytes): Encoded audio bytes.
|
||||
mime (str): MIME type.
|
||||
|
||||
Returns:
|
||||
str: data:{mime};base64,... string.
|
||||
"""
|
||||
b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
|
||||
def _print_result(title: str, results) -> None:
|
||||
"""
|
||||
Print a compact summary for debugging.
|
||||
|
||||
Args:
|
||||
title (str): Case name.
|
||||
results (List[ForcedAlignResult]): Outputs from aligner.align(...).
|
||||
"""
|
||||
print(f"\n===== {title} =====")
|
||||
for i, r in enumerate(results):
|
||||
n = len(r)
|
||||
head = r[0] if n > 0 else None
|
||||
tail = r[-1] if n > 0 else None
|
||||
print(f"[sample {i}] item={n}")
|
||||
if head is not None:
|
||||
print(f" first: {head.text!r} {head.start_time}->{head.end_time} s")
|
||||
print(f" last : {tail.text!r} {tail.start_time}->{tail.end_time} s")
|
||||
|
||||
|
||||
def test_single_url(aligner: Qwen3ForcedAligner) -> None:
|
||||
"""
|
||||
Single-sample alignment using HTTPS URL audio input.
|
||||
"""
|
||||
results = aligner.align(
|
||||
audio=URL_ZH,
|
||||
text=TEXT_ZH,
|
||||
language="Chinese",
|
||||
)
|
||||
assert isinstance(results, list) and len(results) == 1
|
||||
assert len(results[0]) > 0
|
||||
_print_result("single-url", results)
|
||||
|
||||
|
||||
def test_batch_url(aligner: Qwen3ForcedAligner) -> None:
|
||||
"""
|
||||
Batch alignment using HTTPS URL audio input.
|
||||
"""
|
||||
results = aligner.align(
|
||||
audio=[URL_ZH, URL_EN],
|
||||
text=[TEXT_ZH, TEXT_EN],
|
||||
language=["Chinese", "English"],
|
||||
)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) > 0 and len(results[1]) > 0
|
||||
_print_result("batch-url", results)
|
||||
|
||||
|
||||
def test_base64_data_url(aligner: Qwen3ForcedAligner) -> None:
|
||||
"""
|
||||
Single-sample alignment using base64 data URL audio input.
|
||||
"""
|
||||
audio_bytes = _download_audio_bytes(URL_ZH)
|
||||
b64 = _to_data_url_base64(audio_bytes, mime="audio/wav")
|
||||
|
||||
results = aligner.align(
|
||||
audio=b64,
|
||||
text=TEXT_ZH,
|
||||
language="Chinese",
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) > 0
|
||||
_print_result("single-base64-data-url", results)
|
||||
|
||||
|
||||
def test_numpy_tuple_from_request(aligner: Qwen3ForcedAligner) -> None:
|
||||
"""
|
||||
Single-sample alignment using (np.ndarray, sr) input where waveform is obtained by HTTP request.
|
||||
"""
|
||||
audio_bytes = _download_audio_bytes(URL_EN)
|
||||
wav, sr = _read_wav_from_bytes(audio_bytes)
|
||||
|
||||
results = aligner.align(
|
||||
audio=(wav, sr),
|
||||
text=TEXT_EN,
|
||||
language="English",
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) > 0
|
||||
_print_result("single-numpy-tuple-from-request", results)
|
||||
|
||||
|
||||
def test_batch_mixed_inputs(aligner: Qwen3ForcedAligner) -> None:
|
||||
"""
|
||||
Batch alignment mixing URL, base64, and (np.ndarray, sr) inputs.
|
||||
"""
|
||||
zh_bytes = _download_audio_bytes(URL_ZH)
|
||||
en_bytes = _download_audio_bytes(URL_EN)
|
||||
|
||||
zh_b64 = _to_data_url_base64(zh_bytes, mime="audio/wav")
|
||||
en_wav, en_sr = _read_wav_from_bytes(en_bytes)
|
||||
|
||||
results = aligner.align(
|
||||
audio=[URL_ZH, zh_b64, (en_wav, en_sr)],
|
||||
text=[TEXT_ZH, TEXT_ZH, TEXT_EN],
|
||||
language=["Chinese", "Chinese", "English"],
|
||||
)
|
||||
assert len(results) == 3
|
||||
assert all(len(r) > 0 for r in results)
|
||||
_print_result("batch-mixed-inputs", results)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
aligner = Qwen3ForcedAligner.from_pretrained(
|
||||
MODEL_PATH,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
# attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
test_single_url(aligner)
|
||||
test_batch_url(aligner)
|
||||
test_base64_data_url(aligner)
|
||||
test_numpy_tuple_from_request(aligner)
|
||||
test_batch_mixed_inputs(aligner)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user