# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from qwen_asr.core.transformers_backend import (
    Qwen3ASRConfig,
    Qwen3ASRForConditionalGeneration,
    Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor

AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)

from .qwen3_forced_aligner import Qwen3ForcedAligner
from .utils import (
    MAX_ASR_INPUT_SECONDS,
    MAX_FORCE_ALIGN_INPUT_SECONDS,
    SAMPLE_RATE,
    SUPPORTED_LANGUAGES,
    AudioChunk,
    AudioLike,
    chunk_list,
    merge_languages,
    normalize_audios,
    normalize_language_name,
    parse_asr_output,
    split_audio_into_chunks,
    validate_language,
)

try:
    from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
    from vllm import ModelRegistry

    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
except Exception:
    pass


@dataclass
class ASRTranscription:
    """
    One transcription result.

    Attributes:
        language (str): Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
            Empty string if unknown or silent audio.
        text (str): Transcribed text.
        time_stamps (Optional[Any]): Forced aligner output (ForcedAlignResult).
            Present only when return_time_stamps=True.
    """

    language: str
    text: str
    time_stamps: Optional[Any] = None


@dataclass
class ASRStreamingState:
    """
    Streaming ASR state for one audio stream (single utterance).

    Attributes:
        unfixed_chunk_num (int): For the first N chunks, do not use the previous ASR result
            as prefix prompt (reset prefix to "").
        unfixed_token_num (int): When chunk_id >= unfixed_chunk_num, roll back the last K tokens
            from the accumulated text before using it as prefix prompt, to reduce boundary jitter.
        chunk_size_sec (float): Chunk size in seconds. Audio will be fed to the model in
            increments of this length.
        chunk_size_samples (int): Chunk size in samples at 16kHz (derived from chunk_size_sec).
        chunk_id (int): Current chunk index (0-based).
        buffer (np.ndarray): Buffered PCM samples that are not yet consumed into a full chunk.
        audio_accum (np.ndarray): Accumulated audio from the beginning of the stream up to the
            current time (no padding).
        prompt_raw (str): Base prompt generated by the chat template (with generation prompt),
            without appended prefix text.
        context (str): Context string.
        force_language (Optional[str]): If provided, force output to be text-only by appending
            "language X" in prompt_raw, consistent with non-streaming transcribe().
        language (str): Latest parsed language (updated after each chunk decode). Empty if
            unknown/silent.
        text (str): Latest parsed transcription text (updated after each chunk decode).
        _raw_decoded (str): Internal accumulated decoded raw text (before parse_asr_output normalization).
            Used for rollback/token trimming and as prefix for prompting.
    """

    unfixed_chunk_num: int
    unfixed_token_num: int
    chunk_size_sec: float
    chunk_size_samples: int
    chunk_id: int
    buffer: np.ndarray
    audio_accum: np.ndarray
    prompt_raw: str
    context: str
    force_language: Optional[str]
    language: str
    text: str
    _raw_decoded: str


class Qwen3ASRModel:
    """
    Unified inference wrapper for Qwen3-ASR with two backends:
      - Transformers backend
      - vLLM backend

    It optionally supports timestamp output via Qwen3-ForcedAligner.

    Notes:
        - Each request uses a context text and exactly one audio.
        - If language is provided, the prompt will force the output to be text-only
          by appending "language {Language}" to the assistant prompt.
    """

    def __init__(
        self,
        backend: str,
        model: Any,
        processor: Any,
        sampling_params: Optional[Any] = None,
        forced_aligner: Optional[Qwen3ForcedAligner] = None,
        max_inference_batch_size: int = -1,
        max_new_tokens: int = 512,
    ):
        self.backend = backend  # "transformers" | "vllm"
        self.model = model
        self.processor = processor
        self.sampling_params = sampling_params
        self.forced_aligner = forced_aligner
        self.max_inference_batch_size = int(max_inference_batch_size)
        self.max_new_tokens = max_new_tokens

        if backend == "transformers":
            self.device = getattr(model, "device", None)
            if self.device is None:
                try:
                    self.device = next(model.parameters()).device
                except StopIteration:
                    self.device = torch.device("cpu")
            self.dtype = getattr(model, "dtype", torch.float32)
        else:
            self.device = None
            self.dtype = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        forced_aligner: Optional[str] = None,
        forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
        max_inference_batch_size: int = 32,
        max_new_tokens: Optional[int] = 512,
        **kwargs,
    ) -> "Qwen3ASRModel":
        """
        Initialize using the Transformers backend.

        Args:
            pretrained_model_name_or_path: HuggingFace repo id or local directory.
            forced_aligner: Optional forced aligner model path/repo id.
            forced_aligner_kwargs: Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
            max_inference_batch_size: Batch size limit for inference. -1 means no chunking.
                Small values can avoid OOM.
            max_new_tokens: Maximum number of tokens to generate.
            **kwargs: Forwarded to AutoModel.from_pretrained(...).

        Returns:
            Qwen3ASRModel
        """
        model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
        forced_aligner_model = None
        if forced_aligner is not None:
            forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
                forced_aligner, **(forced_aligner_kwargs or {})
            )
        return cls(
            backend="transformers",
            model=model,
            processor=processor,
            sampling_params=None,
            forced_aligner=forced_aligner_model,
            max_inference_batch_size=max_inference_batch_size,
            max_new_tokens=max_new_tokens,
        )

    @classmethod
    def LLM(
        cls,
        model: str,
        forced_aligner: Optional[str] = None,
        forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
        max_inference_batch_size: int = -1,
        max_new_tokens: Optional[int] = 4096,
        **kwargs,
    ) -> "Qwen3ASRModel":
        """
        Initialize using the vLLM backend. The import is isolated to keep vLLM optional.

        Args:
            model: Model path/repo for vLLM.
            forced_aligner: Optional forced aligner model path/repo id.
            forced_aligner_kwargs: Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
            max_inference_batch_size: Batch size limit for inference. -1 means no chunking.
                Small values can avoid OOM.
            max_new_tokens: Maximum number of tokens to generate.
            **kwargs: Forwarded to vllm.LLM(...).
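
        Example (illustrative sketch; the repo id and engine kwargs below are placeholders,
        not verified values):

            >>> asr = Qwen3ASRModel.LLM(
            ...     model="Qwen/Qwen3-ASR",         # hypothetical repo id or local path
            ...     gpu_memory_utilization=0.8,     # forwarded to vllm.LLM(...)
            ... )
            >>> asr.transcribe("sample.wav")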
        Returns:
            Qwen3ASRModel

        Raises:
            ImportError: If vLLM is not installed.
        """
        try:
            from vllm import LLM as vLLM
            from vllm import SamplingParams
        except Exception as e:
            raise ImportError(
                "vLLM is not available. Install with: pip install qwen-asr[vllm]"
            ) from e

        llm = vLLM(model=model, **kwargs)
        processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
        sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)
        forced_aligner_model = None
        if forced_aligner is not None:
            forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
                forced_aligner, **(forced_aligner_kwargs or {})
            )
        return cls(
            backend="vllm",
            model=llm,
            processor=processor,
            sampling_params=sampling_params,
            forced_aligner=forced_aligner_model,
            max_inference_batch_size=max_inference_batch_size,
            max_new_tokens=None,
        )

    def get_supported_languages(self) -> List[str]:
        """
        Return the supported language list.

        Returns:
            List[str]: Canonical language names.
        """
        return list(SUPPORTED_LANGUAGES)

    @torch.no_grad()
    def transcribe(
        self,
        audio: Union[AudioLike, List[AudioLike]],
        context: Union[str, List[str]] = "",
        language: Optional[Union[str, List[Optional[str]]]] = None,
        return_time_stamps: bool = False,
    ) -> List[ASRTranscription]:
        """
        Transcribe audio with optional context and optional forced alignment timestamps.

        Args:
            audio: Audio input(s). Supported:
                - str: local path / URL / base64 data url
                - (np.ndarray, sr)
                - list of the above
            context: Context string(s). If scalar, it will be broadcast to the batch size.
            language: Optional language(s). If provided, it must be in the supported languages.
                If scalar, it will be broadcast to the batch size. If provided, the prompt will
                force the output to be transcription text only.
            return_time_stamps: If True, timestamps are produced via the forced aligner and
                merged across chunks. This requires forced_aligner to be initialized.

        Returns:
            List[ASRTranscription]: One result per input audio.

        Raises:
            ValueError:
                - If return_time_stamps=True but forced_aligner is not provided.
                - If language is unsupported.
                - If batch sizes mismatch for context/language.
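
        Example (illustrative sketch; file names are placeholders):

            >>> results = asr.transcribe(
            ...     audio=["a.wav", "b.wav"],
            ...     context="Speaker names: Alice, Bob.",  # broadcast to both inputs
            ...     language="English",                    # forces text-only output
            ... )
            >>> print(results[0].language, results[0].text)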
""" if return_time_stamps and self.forced_aligner is None: raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.") wavs = normalize_audios(audio) n = len(wavs) ctxs = context if isinstance(context, list) else [context] if len(ctxs) == 1 and n > 1: ctxs = ctxs * n if len(ctxs) != n: raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}") langs_in: List[Optional[str]] if language is None: langs_in = [None] * n else: langs_in = language if isinstance(language, list) else [language] if len(langs_in) == 1 and n > 1: langs_in = langs_in * n if len(langs_in) != n: raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}") langs_norm: List[Optional[str]] = [] for l in langs_in: if l is None or str(l).strip() == "": langs_norm.append(None) else: ln = normalize_language_name(str(l)) validate_language(ln) langs_norm.append(ln) max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS # chunk audios and record mapping chunks: List[AudioChunk] = [] for i, wav in enumerate(wavs): parts = split_audio_into_chunks( wav=wav, sr=SAMPLE_RATE, max_chunk_sec=max_chunk_sec, ) for j, (cwav, offset_sec) in enumerate(parts): chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec)) # run ASR on chunks chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks] chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks] chunk_wavs: List[np.ndarray] = [c.wav for c in chunks] raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang) # parse outputs, prepare for optional alignment per_chunk_lang: List[str] = [] per_chunk_text: List[str] = [] for out, forced_lang in zip(raw_outputs, chunk_lang): lang, txt = parse_asr_output(out, user_language=forced_lang) per_chunk_lang.append(lang) per_chunk_text.append(txt) # forced alignment (optional) per_chunk_align: List[Optional[Any]] = [None] * len(chunks) if return_time_stamps: to_align_audio = [] to_align_text = [] to_align_lang = [] to_align_idx = [] for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)): if txt.strip() == "": continue to_align_audio.append((c.wav, c.sr)) to_align_text.append(txt) to_align_lang.append(lang_pred) to_align_idx.append(idx) # batch align with max_inference_batch_size aligned_results: List[Any] = [] for a_chunk, t_chunk, l_chunk in zip( chunk_list(to_align_audio, self.max_inference_batch_size), chunk_list(to_align_text, self.max_inference_batch_size), chunk_list(to_align_lang, self.max_inference_batch_size), ): aligned_results.extend( self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk) ) # offset fix for k, idx in enumerate(to_align_idx): c = chunks[idx] r = aligned_results[k] per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec) # merge chunks back to original samples out_langs: List[List[str]] = [[] for _ in range(n)] out_texts: List[List[str]] = [[] for _ in range(n)] out_aligns: List[List[Any]] = [[] for _ in range(n)] for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align): out_langs[c.orig_index].append(lang) out_texts[c.orig_index].append(txt) if return_time_stamps and al is not None: out_aligns[c.orig_index].append(al) results: List[ASRTranscription] = [] for i in range(n): merged_text = "".join([t for t in out_texts[i] if t is not None]) merged_language = merge_languages(out_langs[i]) merged_align = None if return_time_stamps: merged_align = 
                merged_align = self._merge_align_results(out_aligns[i])
            results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
        return results

    def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
        return [
            {"role": "system", "content": context or ""},
            {"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
        ]

    def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
        """
        Build the string prompt for one request.

        If force_language is provided, "language X" is appended after the generation prompt
        to request text-only output.
        """
        msgs = self._build_messages(context=context, audio_payload="")
        base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        if force_language:
            base = base + f"language {force_language}"
        return base

    def _infer_asr(
        self,
        contexts: List[str],
        wavs: List[np.ndarray],
        languages: List[Optional[str]],
    ) -> List[str]:
        """
        Run backend inference for chunk-level items.

        Args:
            contexts: List of context strings.
            wavs: List of mono waveforms (np.ndarray).
            languages: List of forced languages or None.

        Returns:
            List[str]: Raw decoded strings (one per chunk).
        """
        if self.backend == "transformers":
            return self._infer_asr_transformers(contexts, wavs, languages)
        if self.backend == "vllm":
            return self._infer_asr_vllm(contexts, wavs, languages)
        raise RuntimeError(f"Unknown backend: {self.backend}")

    def _infer_asr_transformers(
        self,
        contexts: List[str],
        wavs: List[np.ndarray],
        languages: List[Optional[str]],
    ) -> List[str]:
        outs: List[str] = []
        texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
        batch_size = self.max_inference_batch_size
        if batch_size is None or batch_size < 0:
            batch_size = len(texts)
        for i in range(0, len(texts), batch_size):
            sub_text = texts[i : i + batch_size]
            sub_wavs = wavs[i : i + batch_size]
            inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
            inputs = inputs.to(self.model.device).to(self.model.dtype)
            text_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            decoded = self.processor.batch_decode(
                text_ids.sequences[:, inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outs.extend(list(decoded))
        return outs

    def _infer_asr_vllm(
        self,
        contexts: List[str],
        wavs: List[np.ndarray],
        languages: List[Optional[str]],
    ) -> List[str]:
        inputs: List[Dict[str, Any]] = []
        for c, w, fl in zip(contexts, wavs, languages):
            prompt = self._build_text_prompt(context=c, force_language=fl)
            inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
        outs: List[str] = []
        for batch in chunk_list(inputs, self.max_inference_batch_size):
            outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
            for o in outputs:
                outs.append(o.outputs[0].text)
        return outs

    def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
        """
        Apply a time offset to a ForcedAlignResult-like object.

        This function assumes:
          - result has an attribute `.items`, a list of items with start_time/end_time in seconds.
          - dataclasses are frozen in the upstream implementation, so we reconstruct by type.

        Args:
            result: ForcedAlignResult
            offset_sec: Offset in seconds

        Returns:
            ForcedAlignResult: New object with shifted timestamps.
""" if result is None: return None items = [] for it in result.items: items.append(type(it)(text=it.text, start_time=round(it.start_time + offset_sec, 3), end_time=round(it.end_time + offset_sec, 3))) return type(result)(items=items) def _merge_align_results(self, results: List[Any]) -> Optional[Any]: """ Merge multiple ForcedAlignResult objects into a single one by concatenating items. Args: results: List of ForcedAlignResult Returns: ForcedAlignResult or None """ if not results: return None all_items = [] for r in results: if r is None: continue all_items.extend(list(r.items)) if not all_items: return None return type(results[0])(items=all_items) def init_streaming_state( self, context: str = "", language: Optional[str] = None, unfixed_chunk_num: int = 2, unfixed_token_num: int = 5, chunk_size_sec: float = 2.0, ) -> ASRStreamingState: """ Initialize streaming ASR state for a single stream. Notes: - Streaming ASR is supported ONLY for vLLM backend. - Streaming ASR does NOT support timestamps (forced aligner is not used). - Batch inference is NOT supported. Args: context: Context string. language: Optional forced language. If provided, it must be in supported languages. Same behavior as transcribe(): forces text-only output via prompt suffix. unfixed_chunk_num: For the first N chunks, do not use previous output as prefix prompt (reset prefix to ""). unfixed_token_num: Roll back the last K tokens from accumulated output when using it as prefix prompt after unfixed_chunk_num. chunk_size_sec: Chunk size in seconds (audio is 16k PCM). The function will internally convert it to sample count at 16kHz. Returns: ASRStreamingState: Mutable state object to be passed to streaming_transcribe() and finish_streaming_transcribe(). Raises: ValueError: - If backend is not "vllm". - If chunk_size_sec <= 0. - If forced language is invalid (same validation rules as transcribe()). """ if self.backend != "vllm": raise ValueError("Streaming ASR is supported only for vLLM backend (backend='vllm').") if chunk_size_sec is None or float(chunk_size_sec) <= 0: raise ValueError(f"chunk_size_sec must be > 0, got: {chunk_size_sec}") force_language = None if language is not None and str(language).strip() != "": ln = normalize_language_name(str(language)) validate_language(ln) force_language = ln chunk_size_samples = int(round(float(chunk_size_sec) * SAMPLE_RATE)) chunk_size_samples = max(1, chunk_size_samples) prompt_raw = self._build_text_prompt(context=context, force_language=force_language) return ASRStreamingState( unfixed_chunk_num=int(unfixed_chunk_num), unfixed_token_num=int(unfixed_token_num), chunk_size_sec=float(chunk_size_sec), chunk_size_samples=int(chunk_size_samples), chunk_id=0, buffer=np.zeros((0,), dtype=np.float32), audio_accum=np.zeros((0,), dtype=np.float32), prompt_raw=prompt_raw, context=context or "", force_language=force_language, language="", text="", _raw_decoded="", ) def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) -> ASRStreamingState: """ Streaming ASR decode step. This function accepts an arbitrary-length 16k PCM float numpy array (mono). It buffers incoming samples, and whenever enough samples are accumulated to form one full chunk (chunk_size_sec), it runs one incremental decode step and updates: - state.language - state.text The caller only needs to keep passing audio to this function and read state.language/state.text. 

        Implementation details:
            - Each time a new chunk is ready, we append it to audio_accum and re-feed *all*
              audio seen so far to the model (no padding).
            - We update the prompt as: state.prompt_raw + prefix_text
            - Prefix rollback strategy:
                * If chunk_id < unfixed_chunk_num: prefix_text = ""
                * Else: roll back the last unfixed_token_num tokens from the previously
                  accumulated decoded text.

        Notes:
            - vLLM backend only.
            - No timestamps.
            - Single stream only (no batching).

        Args:
            pcm16k: 16kHz mono PCM waveform (np.ndarray). Length can be any non-negative
                integer. dtype can be float32/float64/int16; it will be converted to float32.
            state: Streaming state returned by init_streaming_state().

        Returns:
            ASRStreamingState: The same state object (mutated) for convenience.

        Raises:
            ValueError: If backend is not "vllm" or state is invalid.
        """
        if self.backend != "vllm":
            raise ValueError("streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
        if state is None:
            raise ValueError("state must not be None. Call init_streaming_state() first.")
        if pcm16k is None:
            raise ValueError("pcm16k must not be None.")

        # Ensure 1D mono
        x = np.asarray(pcm16k)
        if x.ndim != 1:
            x = x.reshape(-1)

        # Convert to float32 PCM in [-1, 1] if int16 is provided
        if x.dtype == np.int16:
            x = x.astype(np.float32) / 32768.0
        else:
            x = x.astype(np.float32, copy=False)

        # Append to buffer
        if x.shape[0] > 0:
            state.buffer = np.concatenate([state.buffer, x], axis=0)

        # Consume full chunks
        while state.buffer.shape[0] >= state.chunk_size_samples:
            chunk = state.buffer[: state.chunk_size_samples]
            state.buffer = state.buffer[state.chunk_size_samples :]

            # Accumulate audio (re-feed from the start, no padding)
            if state.audio_accum.shape[0] == 0:
                state.audio_accum = chunk
            else:
                state.audio_accum = np.concatenate([state.audio_accum, chunk], axis=0)

            # Build prefix with the rollback strategy
            prefix = ""
            if state.chunk_id < state.unfixed_chunk_num:
                prefix = ""
            else:
                cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
                k = int(state.unfixed_token_num)
                while True:
                    end_idx = max(0, len(cur_ids) - k)
                    prefix = self.processor.tokenizer.decode(cur_ids[:end_idx]) if end_idx > 0 else ""
                    # Roll back one more token if the cut landed inside a multi-byte character
                    # (the decode would then contain the replacement character).
                    if '\ufffd' not in prefix:
                        break
                    else:
                        if end_idx == 0:
                            prefix = ""
                            break
                        k += 1

            prompt = state.prompt_raw + prefix

            # vLLM input: single item
            inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
            outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
            gen_text = outputs[0].outputs[0].text

            # Accumulate raw decoded text (then parse to language/text)
            state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text
            lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
            state.language = lang
            state.text = txt
            state.chunk_id += 1

        return state

    def finish_streaming_transcribe(self, state: ASRStreamingState) -> ASRStreamingState:
        """
        Finish streaming ASR.

        This function flushes the remaining buffered audio in state.buffer (tail audio). It sends
        the remaining samples to the model even if shorter than chunk_size_sec, without padding.
        Then it updates state.language/state.text one last time.

        Notes:
            - vLLM backend only.
            - No timestamps.
            - Single stream only.

        Args:
            state: Streaming state.

        Returns:
            ASRStreamingState: Updated state (mutated).

        Raises:
            ValueError: If backend is not "vllm" or state is invalid.
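
        Example (illustrative sketch of the full streaming loop; `audio_blocks` is a
        hypothetical iterable of 16kHz mono float32 numpy arrays):

            >>> state = asr.init_streaming_state(context="", chunk_size_sec=2.0)
            >>> for block in audio_blocks:
            ...     state = asr.streaming_transcribe(block, state)
            ...     print(state.text)
            >>> state = asr.finish_streaming_transcribe(state)
            >>> print(state.language, state.text)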
""" if self.backend != "vllm": raise ValueError("finish_streaming_transcribe() is supported only for vLLM backend (backend='vllm').") if state is None: raise ValueError("state must not be None.") # If no remaining buffer, still return state as-is. if state.buffer is None or state.buffer.shape[0] == 0: return state tail = state.buffer state.buffer = np.zeros((0,), dtype=np.float32) # Append tail to accumulated audio if state.audio_accum.shape[0] == 0: state.audio_accum = tail else: state.audio_accum = np.concatenate([state.audio_accum, tail], axis=0) # Prefix rollback strategy (same as per-chunk) prefix = "" if state.chunk_id < state.unfixed_chunk_num: prefix = "" else: cur_ids = self.processor.tokenizer.encode(state._raw_decoded) end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num)) prefix = self.processor.tokenizer.decode(cur_ids[:end_idx]) prompt = state.prompt_raw + prefix inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}} outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False) gen_text = outputs[0].outputs[0].text state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language) state.language = lang state.text = txt state.chunk_id += 1 return state