Initial commit

 21009  qwen_asr/inference/assets/korean_dict_jieba.dict  (new file; diff suppressed because it is too large)
   821  qwen_asr/inference/qwen3_asr.py  (new file)

@@ -0,0 +1,821 @@
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from qwen_asr.core.transformers_backend import (
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
||||
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
||||
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
||||
|
||||
from .qwen3_forced_aligner import Qwen3ForcedAligner
|
||||
from .utils import (
|
||||
MAX_ASR_INPUT_SECONDS,
|
||||
MAX_FORCE_ALIGN_INPUT_SECONDS,
|
||||
SAMPLE_RATE,
|
||||
SUPPORTED_LANGUAGES,
|
||||
AudioChunk,
|
||||
AudioLike,
|
||||
chunk_list,
|
||||
merge_languages,
|
||||
normalize_audios,
|
||||
normalize_language_name,
|
||||
parse_asr_output,
|
||||
split_audio_into_chunks,
|
||||
validate_language,
|
||||
)
|
||||
|
||||
try:
    from qwen_asr.core.vllm_backend import (
        Qwen3ASRForConditionalGeneration as Qwen3ASRForConditionalGenerationVLLM,
    )
    from vllm import ModelRegistry

    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGenerationVLLM)
except Exception:
    # vLLM is optional; skip model registration when it is unavailable.
    pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRTranscription:
|
||||
"""
|
||||
One transcription result.
|
||||
|
||||
Attributes:
|
||||
language (str):
|
||||
Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
|
||||
Empty string if unknown or silent audio.
|
||||
text (str):
|
||||
Transcribed text.
|
||||
time_stamps (Optional[Any]):
|
||||
Forced aligner output (ForcedAlignResult).
|
||||
Present only when return_time_stamps=True.
|
||||
"""
|
||||
language: str
|
||||
text: str
|
||||
time_stamps: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRStreamingState:
|
||||
"""
|
||||
Streaming ASR state for one audio stream (single utterance).
|
||||
|
||||
Attributes:
|
||||
unfixed_chunk_num (int):
|
||||
For the first N chunks, do not use previous ASR result as prefix prompt (reset prefix to "").
|
||||
unfixed_token_num (int):
|
||||
When chunk_id >= unfixed_chunk_num, rollback the last K tokens from the accumulated text
|
||||
before using it as prefix prompt, to reduce boundary jitter.
|
||||
chunk_size_sec (float):
|
||||
Chunk size in seconds. Audio will be fed to the model in increments of this length.
|
||||
chunk_size_samples (int):
|
||||
Chunk size in samples at 16kHz (derived from chunk_size_sec).
|
||||
chunk_id (int):
|
||||
Current chunk index (0-based).
|
||||
buffer (np.ndarray):
|
||||
Buffered PCM samples that are not yet consumed into a full chunk.
|
||||
audio_accum (np.ndarray):
|
||||
Accumulated audio from the beginning of the stream up to current time (no padding).
|
||||
prompt_raw (str):
|
||||
Base prompt generated by chat template (with generation prompt), without appended prefix text.
|
||||
context (str):
|
||||
Context string.
|
||||
force_language (Optional[str]):
|
||||
If provided, force output to be text-only by appending "language X<asr_text>" in prompt_raw,
|
||||
consistent with non-streaming transcribe().
|
||||
language (str):
|
||||
Latest parsed language (updated after each chunk decode). Empty if unknown/silent.
|
||||
text (str):
|
||||
Latest parsed transcription text (updated after each chunk decode).
|
||||
_raw_decoded (str):
|
||||
Internal accumulated decoded raw text (before parse_asr_output normalization).
|
||||
Used for rollback/token trimming and as prefix for prompting.
|
||||
"""
|
||||
unfixed_chunk_num: int
|
||||
unfixed_token_num: int
|
||||
chunk_size_sec: float
|
||||
chunk_size_samples: int
|
||||
|
||||
chunk_id: int
|
||||
buffer: np.ndarray
|
||||
audio_accum: np.ndarray
|
||||
|
||||
prompt_raw: str
|
||||
context: str
|
||||
force_language: Optional[str]
|
||||
|
||||
language: str
|
||||
text: str
|
||||
_raw_decoded: str
|
||||
|
||||
|
||||
class Qwen3ASRModel:
|
||||
"""
|
||||
Unified inference wrapper for Qwen3-ASR with two backends:
|
||||
- Transformers backend
|
||||
- vLLM backend
|
||||
|
||||
It optionally supports time stamp output via Qwen3-ForcedAligner.
|
||||
|
||||
Notes:
|
||||
- Each request uses a context text and exactly one audio.
|
||||
- If language is provided, the prompt will force the output to be text-only by appending
|
||||
"language {Language}<asr_text>" to the assistant prompt.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: str,
|
||||
model: Any,
|
||||
processor: Any,
|
||||
sampling_params: Optional[Any] = None,
|
||||
forced_aligner: Optional[Qwen3ForcedAligner] = None,
|
||||
max_inference_batch_size: int = -1,
|
||||
        max_new_tokens: Optional[int] = 512,
|
||||
):
|
||||
self.backend = backend # "transformers" | "vllm"
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.sampling_params = sampling_params
|
||||
self.forced_aligner = forced_aligner
|
||||
self.max_inference_batch_size = int(max_inference_batch_size)
|
||||
self.max_new_tokens = max_new_tokens
|
||||
|
||||
if backend == "transformers":
|
||||
self.device = getattr(model, "device", None)
|
||||
if self.device is None:
|
||||
try:
|
||||
self.device = next(model.parameters()).device
|
||||
except StopIteration:
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = getattr(model, "dtype", torch.float32)
|
||||
else:
|
||||
self.device = None
|
||||
self.dtype = None
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
forced_aligner: Optional[str] = None,
|
||||
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
|
||||
max_inference_batch_size: int = 32,
|
||||
max_new_tokens: Optional[int] = 512,
|
||||
**kwargs,
|
||||
) -> "Qwen3ASRModel":
|
||||
"""
|
||||
Initialize using Transformers backend.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path:
|
||||
HuggingFace repo id or local directory.
|
||||
forced_aligner:
|
||||
Optional forced aligner model path/repo id.
|
||||
forced_aligner_kwargs:
|
||||
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
|
||||
max_inference_batch_size:
|
||||
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
|
||||
max_new_tokens:
|
||||
Maximum number of tokens to generate.
|
||||
**kwargs:
|
||||
Forwarded to AutoModel.from_pretrained(...).
|
||||
|
||||
Returns:
|
||||
Qwen3ASRModel
|
||||
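
        Example (illustrative; the repo id is a placeholder):
            >>> asr = Qwen3ASRModel.from_pretrained(
            ...     "Qwen/Qwen3-ASR",          # hypothetical model path
            ...     dtype=torch.bfloat16,
            ...     device_map="cuda:0",
            ... )
            >>> results = asr.transcribe(audio="sample.wav")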
"""
|
||||
|
||||
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
|
||||
|
||||
forced_aligner_model = None
|
||||
if forced_aligner is not None:
|
||||
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
|
||||
forced_aligner, **(forced_aligner_kwargs or {})
|
||||
)
|
||||
|
||||
return cls(
|
||||
backend="transformers",
|
||||
model=model,
|
||||
processor=processor,
|
||||
sampling_params=None,
|
||||
forced_aligner=forced_aligner_model,
|
||||
max_inference_batch_size=max_inference_batch_size,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def LLM(
|
||||
cls,
|
||||
model: str,
|
||||
forced_aligner: Optional[str] = None,
|
||||
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
|
||||
max_inference_batch_size: int = -1,
|
||||
max_new_tokens: Optional[int] = 4096,
|
||||
**kwargs,
|
||||
) -> "Qwen3ASRModel":
|
||||
"""
|
||||
Initialize using vLLM backend.
|
||||
|
||||
Import is isolated to keep vLLM optional.
|
||||
|
||||
Args:
|
||||
model:
|
||||
Model path/repo for vLLM.
|
||||
forced_aligner:
|
||||
Optional forced aligner model path/repo id.
|
||||
forced_aligner_kwargs:
|
||||
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
|
||||
max_inference_batch_size:
|
||||
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
|
||||
max_new_tokens:
|
||||
Maximum number of tokens to generate.
|
||||
**kwargs:
|
||||
Forwarded to vllm.LLM(...).
|
||||
|
||||
Returns:
|
||||
Qwen3ASRModel
|
||||
|
||||
Raises:
|
||||
ImportError: If vLLM is not installed.
|
||||
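
        Example (illustrative; the model id is a placeholder):
            >>> asr = Qwen3ASRModel.LLM(
            ...     model="Qwen/Qwen3-ASR",         # hypothetical model path
            ...     gpu_memory_utilization=0.8,     # forwarded to vllm.LLM(...)
            ... )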
"""
|
||||
try:
|
||||
from vllm import LLM as vLLM
|
||||
from vllm import SamplingParams
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
"vLLM is not available. Install with: pip install qwen-asr[vllm]"
|
||||
) from e
|
||||
|
||||
llm = vLLM(model=model, **kwargs)
|
||||
|
||||
processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
|
||||
        sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)
|
||||
|
||||
forced_aligner_model = None
|
||||
if forced_aligner is not None:
|
||||
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
|
||||
forced_aligner, **(forced_aligner_kwargs or {})
|
||||
)
|
||||
|
||||
return cls(
|
||||
backend="vllm",
|
||||
model=llm,
|
||||
processor=processor,
|
||||
sampling_params=sampling_params,
|
||||
forced_aligner=forced_aligner_model,
|
||||
max_inference_batch_size=max_inference_batch_size,
|
||||
max_new_tokens=None,
|
||||
)
|
||||
|
||||
def get_supported_languages(self) -> List[str]:
|
||||
"""
|
||||
Returns the supported language list.
|
||||
|
||||
Returns:
|
||||
List[str]: Canonical language names.
|
||||
"""
|
||||
return list(SUPPORTED_LANGUAGES)
|
||||
|
||||
@torch.no_grad()
|
||||
def transcribe(
|
||||
self,
|
||||
audio: Union[AudioLike, List[AudioLike]],
|
||||
context: Union[str, List[str]] = "",
|
||||
language: Optional[Union[str, List[Optional[str]]]] = None,
|
||||
return_time_stamps: bool = False,
|
||||
) -> List[ASRTranscription]:
|
||||
"""
|
||||
Transcribe audio with optional context and optional forced alignment timestamps.
|
||||
|
||||
Args:
|
||||
audio:
|
||||
Audio input(s). Supported:
|
||||
- str: local path / URL / base64 data url
|
||||
- (np.ndarray, sr)
|
||||
- list of above
|
||||
context:
|
||||
Context string(s). If scalar, it will be broadcast to batch size.
|
||||
language:
|
||||
Optional language(s). If provided, it must be in supported languages.
|
||||
If scalar, it will be broadcast to batch size.
|
||||
If provided, the prompt will force output to be transcription text only.
|
||||
return_time_stamps:
|
||||
If True, timestamps are produced via forced aligner and merged across chunks.
|
||||
This requires forced_aligner initialized.
|
||||
|
||||
Returns:
|
||||
List[ASRTranscription]: One result per input audio.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- If return_time_stamps=True but forced_aligner is not provided.
|
||||
- If language is unsupported.
|
||||
- If batch sizes mismatch for context/language.
|
||||
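
        Example (illustrative; file names and context are placeholders):
            >>> results = asr.transcribe(
            ...     audio=["a.wav", "b.wav"],
            ...     context="Meeting about the quarterly roadmap.",
            ...     language="English",
            ... )
            >>> results[0].language, results[0].text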
"""
|
||||
if return_time_stamps and self.forced_aligner is None:
|
||||
raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.")
|
||||
|
||||
wavs = normalize_audios(audio)
|
||||
n = len(wavs)
|
||||
|
||||
ctxs = context if isinstance(context, list) else [context]
|
||||
if len(ctxs) == 1 and n > 1:
|
||||
ctxs = ctxs * n
|
||||
if len(ctxs) != n:
|
||||
raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}")
|
||||
|
||||
langs_in: List[Optional[str]]
|
||||
if language is None:
|
||||
langs_in = [None] * n
|
||||
else:
|
||||
langs_in = language if isinstance(language, list) else [language]
|
||||
if len(langs_in) == 1 and n > 1:
|
||||
langs_in = langs_in * n
|
||||
if len(langs_in) != n:
|
||||
raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}")
|
||||
|
||||
langs_norm: List[Optional[str]] = []
|
||||
for l in langs_in:
|
||||
if l is None or str(l).strip() == "":
|
||||
langs_norm.append(None)
|
||||
else:
|
||||
ln = normalize_language_name(str(l))
|
||||
validate_language(ln)
|
||||
langs_norm.append(ln)
|
||||
|
||||
max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS
|
||||
|
||||
# chunk audios and record mapping
|
||||
chunks: List[AudioChunk] = []
|
||||
for i, wav in enumerate(wavs):
|
||||
parts = split_audio_into_chunks(
|
||||
wav=wav,
|
||||
sr=SAMPLE_RATE,
|
||||
max_chunk_sec=max_chunk_sec,
|
||||
)
|
||||
for j, (cwav, offset_sec) in enumerate(parts):
|
||||
chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec))
|
||||
|
||||
# run ASR on chunks
|
||||
chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks]
|
||||
chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks]
|
||||
chunk_wavs: List[np.ndarray] = [c.wav for c in chunks]
|
||||
raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang)
|
||||
|
||||
# parse outputs, prepare for optional alignment
|
||||
per_chunk_lang: List[str] = []
|
||||
per_chunk_text: List[str] = []
|
||||
for out, forced_lang in zip(raw_outputs, chunk_lang):
|
||||
lang, txt = parse_asr_output(out, user_language=forced_lang)
|
||||
per_chunk_lang.append(lang)
|
||||
per_chunk_text.append(txt)
|
||||
|
||||
# forced alignment (optional)
|
||||
per_chunk_align: List[Optional[Any]] = [None] * len(chunks)
|
||||
if return_time_stamps:
|
||||
to_align_audio = []
|
||||
to_align_text = []
|
||||
to_align_lang = []
|
||||
to_align_idx = []
|
||||
|
||||
for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)):
|
||||
if txt.strip() == "":
|
||||
continue
|
||||
to_align_audio.append((c.wav, c.sr))
|
||||
to_align_text.append(txt)
|
||||
to_align_lang.append(lang_pred)
|
||||
to_align_idx.append(idx)
|
||||
|
||||
# batch align with max_inference_batch_size
|
||||
aligned_results: List[Any] = []
|
||||
for a_chunk, t_chunk, l_chunk in zip(
|
||||
chunk_list(to_align_audio, self.max_inference_batch_size),
|
||||
chunk_list(to_align_text, self.max_inference_batch_size),
|
||||
chunk_list(to_align_lang, self.max_inference_batch_size),
|
||||
):
|
||||
aligned_results.extend(
|
||||
self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk)
|
||||
)
|
||||
|
||||
# offset fix
|
||||
for k, idx in enumerate(to_align_idx):
|
||||
c = chunks[idx]
|
||||
r = aligned_results[k]
|
||||
per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec)
|
||||
|
||||
# merge chunks back to original samples
|
||||
out_langs: List[List[str]] = [[] for _ in range(n)]
|
||||
out_texts: List[List[str]] = [[] for _ in range(n)]
|
||||
out_aligns: List[List[Any]] = [[] for _ in range(n)]
|
||||
|
||||
for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align):
|
||||
out_langs[c.orig_index].append(lang)
|
||||
out_texts[c.orig_index].append(txt)
|
||||
if return_time_stamps and al is not None:
|
||||
out_aligns[c.orig_index].append(al)
|
||||
|
||||
results: List[ASRTranscription] = []
|
||||
for i in range(n):
|
||||
merged_text = "".join([t for t in out_texts[i] if t is not None])
|
||||
merged_language = merge_languages(out_langs[i])
|
||||
merged_align = None
|
||||
if return_time_stamps:
|
||||
merged_align = self._merge_align_results(out_aligns[i])
|
||||
results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
|
||||
|
||||
return results
|
||||
|
||||
def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": context or ""},
|
||||
{"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
|
||||
]
|
||||
|
||||
def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
|
||||
"""
|
||||
Build the string prompt for one request.
|
||||
|
||||
If force_language is provided, "language X<asr_text>" is appended after the generation prompt
|
||||
to request text-only output.
|
||||
"""
|
||||
msgs = self._build_messages(context=context, audio_payload="")
|
||||
base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
|
||||
if force_language:
|
||||
            base = base + f"language {force_language}<asr_text>"
|
||||
return base
|
||||
|
||||
def _infer_asr(
|
||||
self,
|
||||
contexts: List[str],
|
||||
wavs: List[np.ndarray],
|
||||
languages: List[Optional[str]],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Run backend inference for chunk-level items.
|
||||
|
||||
Args:
|
||||
contexts: List of context strings.
|
||||
wavs: List of mono waveforms (np.ndarray).
|
||||
languages: List of forced languages or None.
|
||||
|
||||
Returns:
|
||||
List[str]: Raw decoded strings (one per chunk).
|
||||
"""
|
||||
if self.backend == "transformers":
|
||||
return self._infer_asr_transformers(contexts, wavs, languages)
|
||||
if self.backend == "vllm":
|
||||
return self._infer_asr_vllm(contexts, wavs, languages)
|
||||
raise RuntimeError(f"Unknown backend: {self.backend}")
|
||||
|
||||
def _infer_asr_transformers(
|
||||
self,
|
||||
contexts: List[str],
|
||||
wavs: List[np.ndarray],
|
||||
languages: List[Optional[str]],
|
||||
) -> List[str]:
|
||||
outs: List[str] = []
|
||||
|
||||
texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
|
||||
|
||||
batch_size = self.max_inference_batch_size
|
||||
        if batch_size is None or batch_size <= 0:
|
||||
batch_size = len(texts)
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
sub_text = texts[i : i + batch_size]
|
||||
sub_wavs = wavs[i : i + batch_size]
|
||||
inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
|
||||
inputs = inputs.to(self.model.device).to(self.model.dtype)
|
||||
|
||||
text_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
|
||||
|
||||
decoded = self.processor.batch_decode(
|
||||
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
outs.extend(list(decoded))
|
||||
|
||||
return outs
|
||||
|
||||
def _infer_asr_vllm(
|
||||
self,
|
||||
contexts: List[str],
|
||||
wavs: List[np.ndarray],
|
||||
languages: List[Optional[str]],
|
||||
) -> List[str]:
|
||||
inputs: List[Dict[str, Any]] = []
|
||||
for c, w, fl in zip(contexts, wavs, languages):
|
||||
prompt = self._build_text_prompt(context=c, force_language=fl)
|
||||
inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
|
||||
|
||||
outs: List[str] = []
|
||||
for batch in chunk_list(inputs, self.max_inference_batch_size):
|
||||
outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
|
||||
for o in outputs:
|
||||
outs.append(o.outputs[0].text)
|
||||
return outs
|
||||
|
||||
def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
|
||||
"""
|
||||
Apply time offset to a ForcedAlignResult-like object.
|
||||
|
||||
This function assumes:
|
||||
- result has attribute `.items` which is a list of items with start_time/end_time in seconds.
|
||||
- dataclasses are frozen in upstream implementation, so we reconstruct by type.
|
||||
|
||||
Args:
|
||||
result: ForcedAlignResult
|
||||
offset_sec: Offset in seconds
|
||||
|
||||
Returns:
|
||||
ForcedAlignResult: New object with shifted timestamps.
|
||||
"""
|
||||
if result is None:
|
||||
return None
|
||||
items = []
|
||||
for it in result.items:
|
||||
items.append(type(it)(text=it.text,
|
||||
start_time=round(it.start_time + offset_sec, 3),
|
||||
end_time=round(it.end_time + offset_sec, 3)))
|
||||
return type(result)(items=items)
|
||||
|
||||
def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
|
||||
"""
|
||||
Merge multiple ForcedAlignResult objects into a single one by concatenating items.
|
||||
|
||||
Args:
|
||||
results: List of ForcedAlignResult
|
||||
|
||||
Returns:
|
||||
ForcedAlignResult or None
|
||||
"""
|
||||
if not results:
|
||||
return None
|
||||
all_items = []
|
||||
for r in results:
|
||||
if r is None:
|
||||
continue
|
||||
all_items.extend(list(r.items))
|
||||
if not all_items:
|
||||
return None
|
||||
return type(results[0])(items=all_items)
|
||||
|
||||
def init_streaming_state(
|
||||
self,
|
||||
context: str = "",
|
||||
language: Optional[str] = None,
|
||||
unfixed_chunk_num: int = 2,
|
||||
unfixed_token_num: int = 5,
|
||||
chunk_size_sec: float = 2.0,
|
||||
) -> ASRStreamingState:
|
||||
"""
|
||||
Initialize streaming ASR state for a single stream.
|
||||
|
||||
Notes:
|
||||
- Streaming ASR is supported ONLY for vLLM backend.
|
||||
- Streaming ASR does NOT support timestamps (forced aligner is not used).
|
||||
- Batch inference is NOT supported.
|
||||
|
||||
Args:
|
||||
context:
|
||||
Context string.
|
||||
language:
|
||||
Optional forced language. If provided, it must be in supported languages.
|
||||
Same behavior as transcribe(): forces text-only output via prompt suffix.
|
||||
unfixed_chunk_num:
|
||||
For the first N chunks, do not use previous output as prefix prompt (reset prefix to "").
|
||||
unfixed_token_num:
|
||||
Roll back the last K tokens from accumulated output when using it as prefix prompt
|
||||
after unfixed_chunk_num.
|
||||
chunk_size_sec:
|
||||
Chunk size in seconds (audio is 16k PCM). The function will internally convert it
|
||||
to sample count at 16kHz.
|
||||
|
||||
Returns:
|
||||
ASRStreamingState: Mutable state object to be passed to streaming_transcribe() and
|
||||
finish_streaming_transcribe().
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- If backend is not "vllm".
|
||||
- If chunk_size_sec <= 0.
|
||||
- If forced language is invalid (same validation rules as transcribe()).
|
||||
"""
|
||||
if self.backend != "vllm":
|
||||
raise ValueError("Streaming ASR is supported only for vLLM backend (backend='vllm').")
|
||||
if chunk_size_sec is None or float(chunk_size_sec) <= 0:
|
||||
raise ValueError(f"chunk_size_sec must be > 0, got: {chunk_size_sec}")
|
||||
|
||||
force_language = None
|
||||
if language is not None and str(language).strip() != "":
|
||||
ln = normalize_language_name(str(language))
|
||||
validate_language(ln)
|
||||
force_language = ln
|
||||
|
||||
chunk_size_samples = int(round(float(chunk_size_sec) * SAMPLE_RATE))
|
||||
chunk_size_samples = max(1, chunk_size_samples)
|
||||
|
||||
prompt_raw = self._build_text_prompt(context=context, force_language=force_language)
|
||||
|
||||
return ASRStreamingState(
|
||||
unfixed_chunk_num=int(unfixed_chunk_num),
|
||||
unfixed_token_num=int(unfixed_token_num),
|
||||
chunk_size_sec=float(chunk_size_sec),
|
||||
chunk_size_samples=int(chunk_size_samples),
|
||||
chunk_id=0,
|
||||
buffer=np.zeros((0,), dtype=np.float32),
|
||||
audio_accum=np.zeros((0,), dtype=np.float32),
|
||||
prompt_raw=prompt_raw,
|
||||
context=context or "",
|
||||
force_language=force_language,
|
||||
language="",
|
||||
text="",
|
||||
_raw_decoded="",
|
||||
)
|
||||
|
||||
def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) -> ASRStreamingState:
|
||||
"""
|
||||
Streaming ASR decode step.
|
||||
|
||||
This function accepts an arbitrary-length 16k PCM float numpy array (mono).
|
||||
It buffers incoming samples, and whenever enough samples are accumulated to form one
|
||||
full chunk (chunk_size_sec), it runs one incremental decode step and updates:
|
||||
|
||||
- state.language
|
||||
- state.text
|
||||
|
||||
The caller only needs to keep passing audio to this function and read state.language/state.text.
|
||||
|
||||
Implementation details:
|
||||
- Each time a new chunk is ready, we append it to audio_accum and re-feed *all* audio seen
|
||||
so far to the model (no padding).
|
||||
- We update the prompt as: state.prompt_raw + prefix_text
|
||||
- Prefix rollback strategy:
|
||||
* If chunk_id < unfixed_chunk_num: prefix_text = ""
|
||||
* Else: rollback last unfixed_token_num tokens from previously accumulated decoded text.
|
||||
|
||||
Notes:
|
||||
- vLLM backend only.
|
||||
- No timestamps.
|
||||
- Single stream only (no batching).
|
||||
|
||||
Args:
|
||||
pcm16k:
|
||||
16kHz mono PCM waveform (np.ndarray). Length can be any non-negative integer.
|
||||
dtype can be float32/float64/int16; it will be converted to float32.
|
||||
state:
|
||||
Streaming state returned by init_streaming_state().
|
||||
|
||||
Returns:
|
||||
ASRStreamingState: The same state object (mutated) for convenience.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If backend is not "vllm" or state is invalid.
|
||||
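
        Example (illustrative; `mic_frames()` is a hypothetical source of 16 kHz mono frames):
            >>> state = asr.init_streaming_state(chunk_size_sec=2.0)
            >>> for frame in mic_frames():
            ...     state = asr.streaming_transcribe(frame, state)
            ...     print(state.language, state.text)
            >>> state = asr.finish_streaming_transcribe(state)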
"""
|
||||
if self.backend != "vllm":
|
||||
raise ValueError("streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
|
||||
if state is None:
|
||||
raise ValueError("state must not be None. Call init_streaming_state() first.")
|
||||
if pcm16k is None:
|
||||
raise ValueError("pcm16k must not be None.")
|
||||
|
||||
# Ensure 1D mono
|
||||
x = np.asarray(pcm16k)
|
||||
if x.ndim != 1:
|
||||
x = x.reshape(-1)
|
||||
|
||||
# Convert to float32 PCM in [-1, 1] if int16 provided
|
||||
if x.dtype == np.int16:
|
||||
x = (x.astype(np.float32) / 32768.0)
|
||||
else:
|
||||
x = x.astype(np.float32, copy=False)
|
||||
|
||||
# Append to buffer
|
||||
if x.shape[0] > 0:
|
||||
state.buffer = np.concatenate([state.buffer, x], axis=0)
|
||||
|
||||
# Consume full chunks
|
||||
while state.buffer.shape[0] >= state.chunk_size_samples:
|
||||
chunk = state.buffer[: state.chunk_size_samples]
|
||||
state.buffer = state.buffer[state.chunk_size_samples :]
|
||||
|
||||
# Accumulate audio (re-feed from start, no padding)
|
||||
if state.audio_accum.shape[0] == 0:
|
||||
state.audio_accum = chunk
|
||||
else:
|
||||
state.audio_accum = np.concatenate([state.audio_accum, chunk], axis=0)
|
||||
|
||||
# Build prefix with rollback strategy
|
||||
prefix = ""
|
||||
if state.chunk_id < state.unfixed_chunk_num:
|
||||
prefix = ""
|
||||
else:
|
||||
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
|
||||
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
|
||||
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
|
||||
|
||||
prompt = state.prompt_raw + prefix
|
||||
|
||||
# vLLM input: single item
|
||||
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
|
||||
|
||||
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
|
||||
gen_text = outputs[0].outputs[0].text
|
||||
|
||||
# Accumulate raw decoded (then parse to lang/text)
|
||||
state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text
|
||||
|
||||
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
|
||||
state.language = lang
|
||||
state.text = txt
|
||||
|
||||
state.chunk_id += 1
|
||||
|
||||
return state
|
||||
|
||||
def finish_streaming_transcribe(self, state: ASRStreamingState) -> ASRStreamingState:
|
||||
"""
|
||||
Finish streaming ASR.
|
||||
|
||||
This function flushes the remaining buffered audio in state.buffer (tail audio).
|
||||
It sends the remaining samples to the model even if shorter than chunk_size_sec,
|
||||
without padding. Then it updates state.language/state.text one last time.
|
||||
|
||||
Notes:
|
||||
- vLLM backend only.
|
||||
- No timestamps.
|
||||
- Single stream only.
|
||||
|
||||
Args:
|
||||
state:
|
||||
Streaming state.
|
||||
|
||||
Returns:
|
||||
ASRStreamingState: Updated state (mutated).
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If backend is not "vllm" or state is invalid.
|
||||
"""
|
||||
if self.backend != "vllm":
|
||||
raise ValueError("finish_streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
|
||||
if state is None:
|
||||
raise ValueError("state must not be None.")
|
||||
|
||||
# If no remaining buffer, still return state as-is.
|
||||
if state.buffer is None or state.buffer.shape[0] == 0:
|
||||
return state
|
||||
|
||||
tail = state.buffer
|
||||
state.buffer = np.zeros((0,), dtype=np.float32)
|
||||
|
||||
# Append tail to accumulated audio
|
||||
if state.audio_accum.shape[0] == 0:
|
||||
state.audio_accum = tail
|
||||
else:
|
||||
state.audio_accum = np.concatenate([state.audio_accum, tail], axis=0)
|
||||
|
||||
# Prefix rollback strategy (same as per-chunk)
|
||||
prefix = ""
|
||||
if state.chunk_id < state.unfixed_chunk_num:
|
||||
prefix = ""
|
||||
else:
|
||||
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
|
||||
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
|
||||
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
|
||||
|
||||
prompt = state.prompt_raw + prefix
|
||||
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
|
||||
|
||||
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
|
||||
gen_text = outputs[0].outputs[0].text
|
||||
|
||||
state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text
|
||||
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
|
||||
state.language = lang
|
||||
state.text = txt
|
||||
|
||||
state.chunk_id += 1
|
||||
return state
|
||||
  483  qwen_asr/inference/qwen3_forced_aligner.py  (new file)

@@ -0,0 +1,483 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import unicodedata
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import nagisa
|
||||
import torch
|
||||
from qwen_asr.core.transformers_backend import (
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
from .utils import (
|
||||
AudioLike,
|
||||
ensure_list,
|
||||
normalize_audios,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3ForceAlignProcessor:
|
||||
def __init__(self):
|
||||
ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict")
|
||||
ko_scores = {}
|
||||
with open(ko_dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
word = line.split()[0]
|
||||
ko_scores[word] = 1.0
|
||||
self.ko_score = ko_scores
|
||||
self.ko_tokenizer = None
|
||||
|
||||
def is_kept_char(self, ch: str) -> bool:
|
||||
if ch == "'":
|
||||
return True
|
||||
cat = unicodedata.category(ch)
|
||||
if cat.startswith("L") or cat.startswith("N"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def clean_token(self, token: str) -> str:
|
||||
return "".join(ch for ch in token if self.is_kept_char(ch))
|
||||
|
||||
def is_cjk_char(self, ch: str) -> bool:
|
||||
code = ord(ch)
|
||||
return (
|
||||
0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs
|
||||
or 0x3400 <= code <= 0x4DBF # Extension A
|
||||
or 0x20000 <= code <= 0x2A6DF # Extension B
|
||||
or 0x2A700 <= code <= 0x2B73F # Extension C
|
||||
or 0x2B740 <= code <= 0x2B81F # Extension D
|
||||
or 0x2B820 <= code <= 0x2CEAF # Extension E
|
||||
or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs
|
||||
)
|
||||
|
||||
def tokenize_chinese_mixed(self, text: str) -> List[str]:
|
||||
tokens: List[str] = []
|
||||
current_latin: List[str] = []
|
||||
|
||||
def flush_latin():
|
||||
nonlocal current_latin
|
||||
if current_latin:
|
||||
token = "".join(current_latin)
|
||||
cleaned = self.clean_token(token)
|
||||
if cleaned:
|
||||
tokens.append(cleaned)
|
||||
current_latin = []
|
||||
|
||||
for ch in text:
|
||||
if self.is_cjk_char(ch):
|
||||
flush_latin()
|
||||
tokens.append(ch)
|
||||
else:
|
||||
if self.is_kept_char(ch):
|
||||
current_latin.append(ch)
|
||||
else:
|
||||
flush_latin()
|
||||
|
||||
flush_latin()
|
||||
|
||||
return tokens
|
||||
|
||||
def tokenize_japanese(self, text: str) -> List[str]:
|
||||
words = nagisa.tagging(text).words
|
||||
tokens: List[str] = []
|
||||
for w in words:
|
||||
cleaned = self.clean_token(w)
|
||||
if cleaned:
|
||||
tokens.append(cleaned)
|
||||
return tokens
|
||||
|
||||
def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]:
|
||||
raw_tokens = ko_tokenizer.tokenize(text)
|
||||
tokens: List[str] = []
|
||||
for w in raw_tokens:
|
||||
w_clean = self.clean_token(w)
|
||||
if w_clean:
|
||||
tokens.append(w_clean)
|
||||
return tokens
|
||||
|
||||
def split_segment_with_chinese(self, seg: str) -> List[str]:
|
||||
tokens: List[str] = []
|
||||
buf: List[str] = []
|
||||
|
||||
def flush_buf():
|
||||
nonlocal buf
|
||||
if buf:
|
||||
tokens.append("".join(buf))
|
||||
buf = []
|
||||
|
||||
for ch in seg:
|
||||
if self.is_cjk_char(ch):
|
||||
flush_buf()
|
||||
tokens.append(ch)
|
||||
else:
|
||||
buf.append(ch)
|
||||
|
||||
flush_buf()
|
||||
return tokens
|
||||
|
||||
def tokenize_space_lang(self, text: str) -> List[str]:
|
||||
tokens: List[str] = []
|
||||
for seg in text.split():
|
||||
cleaned = self.clean_token(seg)
|
||||
if cleaned:
|
||||
tokens.extend(self.split_segment_with_chinese(cleaned))
|
||||
return tokens
|
||||
|
||||
def fix_timestamp(self, data) -> List[int]:
|
||||
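        # Repair non-monotonic timestamps: keep the longest non-decreasing
        # subsequence as "normal" anchors, then patch anomaly runs of length <= 2
        # by copying the nearest anchor value, and longer runs by linear
        # interpolation between the surrounding anchors. Returns integer values.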
data = data.tolist()
|
||||
n = len(data)
|
||||
|
||||
dp = [1] * n
|
||||
parent = [-1] * n
|
||||
|
||||
for i in range(1, n):
|
||||
for j in range(i):
|
||||
if data[j] <= data[i] and dp[j] + 1 > dp[i]:
|
||||
dp[i] = dp[j] + 1
|
||||
parent[i] = j
|
||||
|
||||
max_length = max(dp)
|
||||
max_idx = dp.index(max_length)
|
||||
|
||||
lis_indices = []
|
||||
idx = max_idx
|
||||
while idx != -1:
|
||||
lis_indices.append(idx)
|
||||
idx = parent[idx]
|
||||
lis_indices.reverse()
|
||||
|
||||
is_normal = [False] * n
|
||||
for idx in lis_indices:
|
||||
is_normal[idx] = True
|
||||
|
||||
result = data.copy()
|
||||
i = 0
|
||||
|
||||
while i < n:
|
||||
if not is_normal[i]:
|
||||
j = i
|
||||
while j < n and not is_normal[j]:
|
||||
j += 1
|
||||
|
||||
anomaly_count = j - i
|
||||
|
||||
if anomaly_count <= 2:
|
||||
left_val = None
|
||||
for k in range(i - 1, -1, -1):
|
||||
if is_normal[k]:
|
||||
left_val = result[k]
|
||||
break
|
||||
|
||||
right_val = None
|
||||
for k in range(j, n):
|
||||
if is_normal[k]:
|
||||
right_val = result[k]
|
||||
break
|
||||
|
||||
for k in range(i, j):
|
||||
if left_val is None:
|
||||
result[k] = right_val
|
||||
elif right_val is None:
|
||||
result[k] = left_val
|
||||
else:
|
||||
result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val
|
||||
|
||||
else:
|
||||
left_val = None
|
||||
for k in range(i - 1, -1, -1):
|
||||
if is_normal[k]:
|
||||
left_val = result[k]
|
||||
break
|
||||
|
||||
right_val = None
|
||||
for k in range(j, n):
|
||||
if is_normal[k]:
|
||||
right_val = result[k]
|
||||
break
|
||||
|
||||
if left_val is not None and right_val is not None:
|
||||
step = (right_val - left_val) / (anomaly_count + 1)
|
||||
for k in range(i, j):
|
||||
result[k] = left_val + step * (k - i + 1)
|
||||
elif left_val is not None:
|
||||
for k in range(i, j):
|
||||
result[k] = left_val
|
||||
elif right_val is not None:
|
||||
for k in range(i, j):
|
||||
result[k] = right_val
|
||||
|
||||
i = j
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return [int(res) for res in result]
|
||||
|
||||
    def encode_timestamp(self, text: str, language: str) -> Tuple[List[str], str]:
|
||||
        language = language.lower()

        if language == "japanese":
|
||||
word_list = self.tokenize_japanese(text)
|
||||
        elif language == "korean":
|
||||
if self.ko_tokenizer is None:
|
||||
from soynlp.tokenizer import LTokenizer
|
||||
self.ko_tokenizer = LTokenizer(scores=self.ko_score)
|
||||
word_list = self.tokenize_korean(self.ko_tokenizer, text)
|
||||
else:
|
||||
word_list = self.tokenize_space_lang(text)
|
||||
|
||||
input_text = "<timestamp><timestamp>".join(word_list) + "<timestamp><timestamp>"
|
||||
input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text
|
||||
|
||||
return word_list, input_text
|
||||
|
||||
def parse_timestamp(self, word_list, timestamp):
|
||||
timestamp_output = []
|
||||
|
||||
timestamp_fixed = self.fix_timestamp(timestamp)
|
||||
for i, word in enumerate(word_list):
|
||||
start_time = timestamp_fixed[i * 2]
|
||||
end_time = timestamp_fixed[i * 2 + 1]
|
||||
timestamp_output.append({
|
||||
"text": word,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time
|
||||
})
|
||||
|
||||
return timestamp_output
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ForcedAlignItem:
|
||||
"""
|
||||
One aligned item span.
|
||||
|
||||
Attributes:
|
||||
text (str):
|
||||
            The aligned unit (CJK character or word) produced by the forced aligner processor.
|
||||
start_time (float):
|
||||
Start time in seconds.
|
||||
end_time (float):
|
||||
End time in seconds.
|
||||
"""
|
||||
text: str
|
||||
    start_time: float
    end_time: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ForcedAlignResult:
|
||||
"""
|
||||
Forced alignment output for one sample.
|
||||
|
||||
Attributes:
|
||||
items (List[ForcedAlignItem]):
|
||||
Aligned token spans.
|
||||
"""
|
||||
items: List[ForcedAlignItem]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def __getitem__(self, idx: int) -> ForcedAlignItem:
|
||||
return self.items[idx]
|
||||
|
||||
|
||||
class Qwen3ForcedAligner:
|
||||
"""
|
||||
A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.
|
||||
|
||||
This wrapper provides:
|
||||
- `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
|
||||
- audio input normalization (path/URL/base64/(np.ndarray, sr))
|
||||
- batch and single-sample forced alignment
|
||||
- structured output with attribute access (`.text`, `.start_time`, `.end_time`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Qwen3ASRForConditionalGeneration,
|
||||
processor: Qwen3ASRProcessor,
|
||||
aligner_processor: Qwen3ForceAlignProcessor,
|
||||
):
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.aligner_processor = aligner_processor
|
||||
|
||||
self.device = getattr(model, "device", None)
|
||||
if self.device is None:
|
||||
try:
|
||||
self.device = next(model.parameters()).device
|
||||
except StopIteration:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.timestamp_token_id = int(model.config.timestamp_token_id)
|
||||
self.timestamp_segment_time = float(model.config.timestamp_segment_time)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
**kwargs,
|
||||
) -> "Qwen3ForcedAligner":
|
||||
"""
|
||||
Load Qwen3-ForcedAligner model and initialize processors.
|
||||
|
||||
This method:
|
||||
1) Registers config/model/processor for HF auto classes.
|
||||
2) Loads the model using `AutoModel.from_pretrained(...)`.
|
||||
3) Initializes:
|
||||
- HF processor (`AutoProcessor.from_pretrained(...)`)
|
||||
- forced alignment text processor (`Qwen3ForceAlignProcessor()`)
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str):
|
||||
HuggingFace repo id or local directory.
|
||||
**kwargs:
|
||||
Forwarded to `AutoModel.from_pretrained(...)`.
|
||||
Typical examples: device_map="cuda:0", dtype=torch.bfloat16.
|
||||
|
||||
Returns:
|
||||
Qwen3ForcedAligner:
|
||||
Initialized wrapper instance.
|
||||
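
        Example (illustrative; the repo id is a placeholder):
            >>> aligner = Qwen3ForcedAligner.from_pretrained(
            ...     "Qwen/Qwen3-ForcedAligner",  # hypothetical model path
            ...     dtype=torch.bfloat16,
            ...     device_map="cuda:0",
            ... )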
"""
|
||||
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
||||
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
||||
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
||||
|
||||
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
if not isinstance(model, Qwen3ASRForConditionalGeneration):
|
||||
raise TypeError(
|
||||
f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration."
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
|
||||
aligner_processor = Qwen3ForceAlignProcessor()
|
||||
|
||||
return cls(model=model, processor=processor, aligner_processor=aligner_processor)
|
||||
|
||||
def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
|
||||
items: List[ForcedAlignItem] = []
|
||||
for it in timestamp_output:
|
||||
items.append(
|
||||
ForcedAlignItem(
|
||||
text=str(it.get("text", "")),
|
||||
start_time=float(it.get("start_time", 0)),
|
||||
end_time=float(it.get("end_time", 0)),
|
||||
)
|
||||
)
|
||||
return ForcedAlignResult(items=items)
|
||||
|
||||
@torch.inference_mode()
|
||||
def align(
|
||||
self,
|
||||
audio: Union[AudioLike, List[AudioLike]],
|
||||
text: Union[str, List[str]],
|
||||
language: Union[str, List[str]],
|
||||
) -> List[ForcedAlignResult]:
|
||||
"""
|
||||
Run forced alignment for batch or single sample.
|
||||
|
||||
Args:
|
||||
audio:
|
||||
Audio input(s). Each item supports:
|
||||
- local path / https URL / base64 string
|
||||
- (np.ndarray, sr)
|
||||
All audios will be converted into mono 16k float32 arrays in [-1, 1].
|
||||
text:
|
||||
Transcript(s) for alignment.
|
||||
language:
|
||||
Language(s) for each sample (e.g., "Chinese", "English").
|
||||
|
||||
Returns:
|
||||
List[ForcedAlignResult]:
|
||||
One result per sample. Each result contains `items`, and each token can be accessed via
|
||||
`.text`, `.start_time`, `.end_time`.
|
||||
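
        Example (illustrative; the file name and transcript are placeholders):
            >>> results = aligner.align(
            ...     audio="hello.wav",
            ...     text="hello world",
            ...     language="English",
            ... )
            >>> for item in results[0]:
            ...     print(item.text, item.start_time, item.end_time)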
"""
|
||||
texts = ensure_list(text)
|
||||
languages = ensure_list(language)
|
||||
audios = normalize_audios(audio)
|
||||
|
||||
if len(languages) == 1 and len(audios) > 1:
|
||||
languages = languages * len(audios)
|
||||
|
||||
if not (len(audios) == len(texts) == len(languages)):
|
||||
raise ValueError(
|
||||
f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
|
||||
)
|
||||
|
||||
word_lists = []
|
||||
aligner_input_texts = []
|
||||
for t, lang in zip(texts, languages):
|
||||
word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang)
|
||||
word_lists.append(word_list)
|
||||
aligner_input_texts.append(aligner_input_text)
|
||||
|
||||
inputs = self.processor(
|
||||
text=aligner_input_texts,
|
||||
audio=audios,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
inputs = inputs.to(self.model.device).to(self.model.dtype)
|
||||
|
||||
logits = self.model.thinker(**inputs).logits
|
||||
output_ids = logits.argmax(dim=-1)
|
||||
|
||||
results: List[ForcedAlignResult] = []
|
||||
for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists):
|
||||
masked_output_id = output_id[input_id == self.timestamp_token_id]
|
||||
timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
|
||||
timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
|
||||
for it in timestamp_output:
|
||||
it['start_time'] = round(it['start_time'] / 1000.0, 3)
|
||||
it['end_time'] = round(it['end_time'] / 1000.0, 3)
|
||||
results.append(self._to_structured_items(timestamp_output))
|
||||
|
||||
return results
|
||||
|
||||
def get_supported_languages(self) -> Optional[List[str]]:
|
||||
"""
|
||||
List supported language names for the current model.
|
||||
|
||||
This is a thin wrapper around `self.model.get_support_languages()`.
|
||||
If the underlying model does not expose language constraints (returns None),
|
||||
this method also returns None.
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]:
|
||||
- A sorted list of supported language names (lowercased), if available.
|
||||
- None if the model does not provide supported languages.
|
||||
"""
|
||||
fn = getattr(self.model, "get_support_languages", None)
|
||||
if not callable(fn):
|
||||
return None
|
||||
|
||||
langs = fn()
|
||||
if langs is None:
|
||||
return None
|
||||
|
||||
return sorted({str(x).lower() for x in langs})
|
||||
  497  qwen_asr/inference/utils.py  (new file)

@@ -0,0 +1,497 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
AudioLike = Union[
|
||||
str, # wav path / URL / base64
|
||||
Tuple[np.ndarray, int], # (waveform, sr)
|
||||
]
|
||||
MaybeList = Union[Any, List[Any]]
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_ASR_INPUT_SECONDS = 1200
|
||||
MAX_FORCE_ALIGN_INPUT_SECONDS = 180
|
||||
MIN_ASR_INPUT_SECONDS = 0.5
|
||||
SUPPORTED_LANGUAGES: List[str] = [
|
||||
"Chinese",
|
||||
"English",
|
||||
"Cantonese",
|
||||
"Arabic",
|
||||
"German",
|
||||
"French",
|
||||
"Spanish",
|
||||
"Portuguese",
|
||||
"Indonesian",
|
||||
"Italian",
|
||||
"Korean",
|
||||
"Russian",
|
||||
"Thai",
|
||||
"Vietnamese",
|
||||
"Japanese",
|
||||
"Turkish",
|
||||
"Hindi",
|
||||
"Malay",
|
||||
"Dutch",
|
||||
"Swedish",
|
||||
"Danish",
|
||||
"Finnish",
|
||||
"Polish",
|
||||
"Czech",
|
||||
"Filipino",
|
||||
"Persian",
|
||||
"Greek",
|
||||
"Romanian",
|
||||
"Hungarian",
|
||||
"Macedonian"
|
||||
]
|
||||
_ASR_TEXT_TAG = "<asr_text>"
|
||||
_LANG_PREFIX = "language "
|
||||
|
||||
|
||||
def normalize_language_name(language: str) -> str:
|
||||
"""
|
||||
Normalize language name to the canonical format used by Qwen3-ASR:
|
||||
first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').
|
||||
|
||||
Args:
|
||||
language (str): Input language name.
|
||||
|
||||
Returns:
|
||||
str: Normalized language name.
|
||||
|
||||
Raises:
|
||||
ValueError: If language is empty.
|
||||
"""
|
||||
if language is None:
|
||||
raise ValueError("language is None")
|
||||
s = str(language).strip()
|
||||
if not s:
|
||||
raise ValueError("language is empty")
|
||||
return s[:1].upper() + s[1:].lower()
|
||||
|
||||
|
||||
def validate_language(language: str) -> None:
|
||||
"""
|
||||
Validate the language is supported.
|
||||
|
||||
Args:
|
||||
language (str): Canonical language name.
|
||||
|
||||
Raises:
|
||||
ValueError: If unsupported.
|
||||
"""
|
||||
if language not in SUPPORTED_LANGUAGES:
|
||||
raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")
|
||||
|
||||
|
||||
def ensure_list(x: MaybeList) -> List[Any]:
|
||||
return x if isinstance(x, list) else [x]
|
||||
|
||||
|
||||
def is_url(s: str) -> bool:
|
||||
try:
|
||||
u = urlparse(s)
|
||||
return u.scheme in ("http", "https") and bool(u.netloc)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_probably_base64(s: str) -> bool:
|
||||
if s.startswith("data:audio"):
|
||||
return True
|
||||
if ("/" not in s and "\\" not in s) and len(s) > 256:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def decode_base64_bytes(b64: str) -> bytes:
|
||||
if "," in b64 and b64.strip().startswith("data:"):
|
||||
b64 = b64.split(",", 1)[1]
|
||||
return base64.b64decode(b64)
|
||||
|
||||
|
||||
def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
|
||||
if is_url(x):
|
||||
with urllib.request.urlopen(x) as resp:
|
||||
audio_bytes = resp.read()
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
elif is_probably_base64(x):
|
||||
audio_bytes = decode_base64_bytes(x)
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
else:
|
||||
audio, sr = librosa.load(x, sr=None, mono=False)
|
||||
|
||||
audio = np.asarray(audio, dtype=np.float32)
|
||||
sr = int(sr)
|
||||
return audio, sr
|
||||
|
||||
|
||||
def to_mono(audio: np.ndarray) -> np.ndarray:
|
||||
if audio.ndim == 1:
|
||||
return audio
|
||||
# soundfile can return shape (T, C); some pipelines use (C, T)
|
||||
if audio.ndim == 2:
|
||||
if audio.shape[0] <= 8 and audio.shape[1] > audio.shape[0]:
|
||||
audio = audio.T
|
||||
return np.mean(audio, axis=-1).astype(np.float32)
|
||||
raise ValueError(f"Unsupported audio ndim={audio.ndim}")
|
||||
|
||||
|
||||
def float_range_normalize(audio: np.ndarray) -> np.ndarray:
|
||||
audio = audio.astype(np.float32)
|
||||
if audio.size == 0:
|
||||
return audio
|
||||
peak = float(np.max(np.abs(audio)))
|
||||
if peak == 0.0:
|
||||
return audio
|
||||
# If decoded audio is int-like scaled or out-of-range, normalize conservatively.
|
||||
if peak > 1.0:
|
||||
audio = audio / peak
|
||||
audio = np.clip(audio, -1.0, 1.0)
|
||||
return audio
|
||||
|
||||
|
||||
def normalize_audio_input(a: AudioLike) -> np.ndarray:
|
||||
"""
|
||||
Normalize one audio input to mono 16k float32 waveform in [-1, 1].
|
||||
|
||||
Supported inputs:
|
||||
- str: local file path / https URL / base64 audio string
|
||||
- (np.ndarray, sr): waveform and sampling rate
|
||||
|
||||
Returns:
|
||||
np.ndarray:
|
||||
Mono 16k float32 waveform in [-1, 1].
|
||||
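
    Example (illustrative; input is one second of random audio at 32 kHz):
        >>> wav16k = normalize_audio_input((np.random.randn(32000).astype(np.float32), 32000))
        >>> wav16k.dtype, wav16k.ndim   # mono float32, resampled from 32 kHz to 16 kHz
        (dtype('float32'), 1)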
"""
|
||||
if isinstance(a, str):
|
||||
audio, sr = load_audio_any(a)
|
||||
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
|
||||
audio, sr = a[0], int(a[1])
|
||||
else:
|
||||
raise TypeError(f"Unsupported audio input type: {type(a)}")
|
||||
|
||||
audio = to_mono(np.asarray(audio))
|
||||
if sr != SAMPLE_RATE:
|
||||
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
|
||||
audio = float_range_normalize(audio)
|
||||
return audio
|
||||
|
||||
|
||||
def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
|
||||
items = ensure_list(audios)
|
||||
return [normalize_audio_input(a) for a in items]
|
||||
|
||||
|
||||
def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
|
||||
"""
|
||||
Yield chunks of a list.
|
||||
|
||||
Args:
|
||||
xs (List[Any]): Input list.
|
||||
chunk_size (int): Chunk size.
|
||||
|
||||
Yields:
|
||||
List[Any]: Slices of xs.
|
||||
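
    Example (illustrative):
        >>> list(chunk_list([1, 2, 3, 4, 5], 2))
        [[1, 2], [3, 4], [5]]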
"""
|
||||
if chunk_size <= 0:
|
||||
yield xs
|
||||
return
|
||||
for i in range(0, len(xs), chunk_size):
|
||||
yield xs[i : i + chunk_size]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioChunk:
|
||||
"""
|
||||
One chunk cut from an original audio.
|
||||
|
||||
Attributes:
|
||||
orig_index: Index of the original sample in the input batch.
|
||||
chunk_index: Index of this chunk within the original sample.
|
||||
wav: Mono float32 waveform.
|
||||
sr: Sampling rate.
|
||||
offset_sec: Start offset of this chunk in the original audio, in seconds.
|
||||
"""
|
||||
orig_index: int
|
||||
chunk_index: int
|
||||
wav: np.ndarray
|
||||
sr: int
|
||||
offset_sec: float
|
||||
|
||||
|
||||
def split_audio_into_chunks(
|
||||
wav: np.ndarray,
|
||||
sr: int,
|
||||
max_chunk_sec: float,
|
||||
search_expand_sec: float = 5.0,
|
||||
min_window_ms: float = 100.0,
|
||||
) -> List[Tuple[np.ndarray, float]]:
|
||||
"""
|
||||
Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.
|
||||
|
||||
    This implementation cuts chunks with no overlaps and no gaps, so concatenating
    them reproduces the original audio, except that chunks shorter than
    MIN_ASR_INPUT_SECONDS are zero-padded at the tail before being returned.
|
||||
|
||||
Args:
|
||||
wav: Mono waveform float32.
|
||||
sr: Sampling rate.
|
||||
max_chunk_sec: Target max chunk duration in seconds.
|
||||
search_expand_sec: Boundary search half-window in seconds.
|
||||
min_window_ms: Sliding window in milliseconds for energy estimation.
|
||||
|
||||
Returns:
|
||||
List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
|
||||
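
    Example (illustrative; a long silent waveform as input):
        >>> wav = np.zeros(16000 * 100, dtype=np.float32)   # 100 s of silence at 16 kHz
        >>> parts = split_audio_into_chunks(wav, sr=16000, max_chunk_sec=30.0)
        >>> offsets = [off for _, off in parts]              # increasing chunk offsets in seconds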
"""
|
||||
wav = np.asarray(wav, dtype=np.float32)
|
||||
if wav.ndim > 1:
|
||||
wav = np.mean(wav, axis=-1).astype(np.float32)
|
||||
|
||||
total_len = int(wav.shape[0])
|
||||
total_sec = total_len / float(sr)
|
||||
if total_sec <= max_chunk_sec:
|
||||
return [(wav, 0.0)]
|
||||
|
||||
max_len = int(max_chunk_sec * sr)
|
||||
expand = int(search_expand_sec * sr)
|
||||
win = max(4, int((min_window_ms / 1000.0) * sr))
|
||||
|
||||
chunks: List[Tuple[np.ndarray, float]] = []
|
||||
|
||||
start = 0
|
||||
offset_sec = 0.0
|
||||
|
||||
while (total_len - start) > max_len:
|
||||
cut = start + max_len
|
||||
|
||||
left = max(start, cut - expand)
|
||||
right = min(total_len, cut + expand)
|
||||
|
||||
if right - left <= win:
|
||||
boundary = cut
|
||||
else:
|
||||
seg = wav[left:right]
|
||||
seg_abs = np.abs(seg)
|
||||
|
||||
window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")
|
||||
|
||||
min_pos = int(np.argmin(window_sums))
|
||||
|
||||
wstart = min_pos
|
||||
wend = min_pos + win
|
||||
local = seg_abs[wstart:wend]
|
||||
inner = int(np.argmin(local))
|
||||
boundary = left + wstart + inner
|
||||
|
||||
boundary = int(max(boundary, start + 1))
|
||||
boundary = int(min(boundary, total_len))
|
||||
|
||||
chunk = wav[start:boundary]
|
||||
chunks.append((chunk, offset_sec))
|
||||
|
||||
offset_sec += (boundary - start) / float(sr)
|
||||
start = boundary
|
||||
|
||||
tail = wav[start:total_len]
|
||||
chunks.append((tail, offset_sec))
|
||||
|
||||
# Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
|
||||
min_len = int(MIN_ASR_INPUT_SECONDS * sr)
|
||||
padded: List[Tuple[np.ndarray, float]] = []
|
||||
for c, off in chunks:
|
||||
if c.shape[0] < min_len:
|
||||
pad = min_len - int(c.shape[0])
|
||||
c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
|
||||
padded.append((c, off))
|
||||
chunks = padded
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def detect_and_fix_repetitions(text, threshold=20):
|
||||
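    # Collapse pathological repetitions in decoder output: a run of one character
    # repeated more than `threshold` times is reduced to a single occurrence, and a
    # short pattern (up to 20 characters) repeated at least `threshold` times in a
    # row is kept once, with the remainder of the string cleaned recursively.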
def fix_char_repeats(s, thresh):
|
||||
res = []
|
||||
i = 0
|
||||
n = len(s)
|
||||
while i < n:
|
||||
count = 1
|
||||
while i + count < n and s[i + count] == s[i]:
|
||||
count += 1
|
||||
|
||||
if count > thresh:
|
||||
res.append(s[i])
|
||||
i += count
|
||||
else:
|
||||
res.append(s[i:i+count])
|
||||
i += count
|
||||
return ''.join(res)
|
||||
|
||||
def fix_pattern_repeats(s, thresh, max_len=20):
|
||||
n = len(s)
|
||||
min_repeat_chars = thresh * 2
|
||||
if n < min_repeat_chars:
|
||||
return s
|
||||
|
||||
i = 0
|
||||
result = []
|
||||
while i <= n - min_repeat_chars:
|
||||
found = False
|
||||
for k in range(1, max_len + 1):
|
||||
if i + k * thresh > n:
|
||||
break
|
||||
|
||||
pattern = s[i:i+k]
|
||||
valid = True
|
||||
for rep in range(1, thresh):
|
||||
start_idx = i + rep * k
|
||||
if s[start_idx:start_idx+k] != pattern:
|
||||
valid = False
|
||||
break
|
||||
|
||||
if valid:
|
||||
total_rep = thresh
|
||||
end_index = i + thresh * k
|
||||
while end_index + k <= n and s[end_index:end_index+k] == pattern:
|
||||
total_rep += 1
|
||||
end_index += k
|
||||
result.append(pattern)
|
||||
result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
|
||||
i = n
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
break
|
||||
else:
|
||||
result.append(s[i])
|
||||
i += 1
|
||||
|
||||
if not found:
|
||||
result.append(s[i:])
|
||||
return ''.join(result)
|
||||
|
||||
text_raw = text
|
||||
text = fix_char_repeats(text_raw, threshold)
|
||||
text = fix_pattern_repeats(text, threshold)
|
||||
return text
|
||||
|
||||
|
||||
def parse_asr_output(
|
||||
raw: str,
|
||||
user_language: Optional[str] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Parse Qwen3-ASR raw output into (language, text).
|
||||
|
||||
Cases:
|
||||
- With tag: "language Chinese<asr_text>...."
|
||||
- With newlines: "language Chinese\\n...\\n<asr_text>...."
|
||||
- No tag: treat whole string as text.
|
||||
- "language None<asr_text>": treat as empty audio -> ("", "")
|
||||
|
||||
If user_language is provided, language is forced to user_language and raw is treated as text-only
|
||||
(the model is expected to output plain transcription without metadata).
|
||||
|
||||
Args:
|
||||
raw: Raw decoded string.
|
||||
user_language: Canonical language name if user forced language.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (language, text)
|
||||
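
    Example (illustrative):
        >>> parse_asr_output("language Chinese<asr_text>你好")
        ('Chinese', '你好')
        >>> parse_asr_output("hello world", user_language="English")
        ('English', 'hello world')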
"""
|
||||
if raw is None:
|
||||
return "", ""
|
||||
s = str(raw).strip()
|
||||
if not s:
|
||||
return "", ""
|
||||
|
||||
s = detect_and_fix_repetitions(s)
|
||||
|
||||
if user_language:
|
||||
# user explicitly forced language => model output is treated as pure text
|
||||
return user_language, s
|
||||
|
||||
meta_part = s
|
||||
text_part = ""
|
||||
has_tag = _ASR_TEXT_TAG in s
|
||||
if has_tag:
|
||||
meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
|
||||
else:
|
||||
# no tag => pure text
|
||||
return "", s.strip()
|
||||
|
||||
meta_lower = meta_part.lower()
|
||||
|
||||
# empty audio heuristic
|
||||
if "language none" in meta_lower:
|
||||
t = text_part.strip()
|
||||
if not t:
|
||||
return "", ""
|
||||
# if model still returned something, keep it but language unknown
|
||||
return "", t
|
||||
|
||||
# extract "language xxx" from meta
|
||||
lang = ""
|
||||
for line in meta_part.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
low = line.lower()
|
||||
if low.startswith(_LANG_PREFIX):
|
||||
val = line[len(_LANG_PREFIX):].strip()
|
||||
if val:
|
||||
lang = normalize_language_name(val)
|
||||
break
|
||||
|
||||
return lang, text_part.strip()
|
||||
|
||||
|
||||
def merge_languages(langs: List[str]) -> str:
|
||||
"""
|
||||
Merge per-chunk languages into a compact comma-separated string,
|
||||
keeping order and removing consecutive duplicates and empty entries.
|
||||
|
||||
Example:
|
||||
["Chinese", "English", "English"] -> "Chinese,English"
|
||||
|
||||
Args:
|
||||
langs: List of canonical language names.
|
||||
|
||||
Returns:
|
||||
str: Merged language string.
|
||||
"""
|
||||
out: List[str] = []
|
||||
prev = None
|
||||
for x in langs:
|
||||
x = (x or "").strip()
|
||||
if not x:
|
||||
continue
|
||||
if x == prev:
|
||||
continue
|
||||
out.append(x)
|
||||
prev = x
|
||||
return ",".join(out)