# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import urllib.request
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlparse

import librosa
import numpy as np
import soundfile as sf

AudioLike = Union[
    str,  # wav path / URL / base64
    Tuple[np.ndarray, int],  # (waveform, sr)
]
MaybeList = Union[Any, List[Any]]

SAMPLE_RATE = 16000
MAX_ASR_INPUT_SECONDS = 1200
MAX_FORCE_ALIGN_INPUT_SECONDS = 180
MIN_ASR_INPUT_SECONDS = 0.5

SUPPORTED_LANGUAGES: List[str] = [
    "Chinese", "English", "Cantonese", "Arabic", "German", "French",
    "Spanish", "Portuguese", "Indonesian", "Italian", "Korean", "Russian",
    "Thai", "Vietnamese", "Japanese", "Turkish", "Hindi", "Malay", "Dutch",
    "Swedish", "Danish", "Finnish", "Polish", "Czech", "Filipino", "Persian",
    "Greek", "Romanian", "Hungarian", "Macedonian",
]

# NOTE(review): this tag separates the "language ..." metadata from the
# transcription text in the raw model output.  It is empty here — the literal
# (likely an angle-bracketed tag) appears to have been lost; confirm against
# the Qwen3-ASR model's actual output format.  parse_asr_output() guards
# against the empty value so parsing degrades gracefully instead of crashing.
_ASR_TEXT_TAG = ""
_LANG_PREFIX = "language "


def normalize_language_name(language: str) -> str:
    """
    Normalize language name to the canonical format used by Qwen3-ASR:
    first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').

    Args:
        language (str): Input language name.

    Returns:
        str: Normalized language name.

    Raises:
        ValueError: If language is None or empty.
    """
    if language is None:
        raise ValueError("language is None")
    s = str(language).strip()
    if not s:
        raise ValueError("language is empty")
    return s[:1].upper() + s[1:].lower()


def validate_language(language: str) -> None:
    """
    Validate the language is supported.

    Args:
        language (str): Canonical language name.

    Raises:
        ValueError: If unsupported.
    """
    if language not in SUPPORTED_LANGUAGES:
        raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")


def ensure_list(x: MaybeList) -> List[Any]:
    """Wrap a non-list value in a list; pass lists through unchanged."""
    return x if isinstance(x, list) else [x]


def is_url(s: str) -> bool:
    """Return True if ``s`` parses as an http(s) URL with a network location."""
    try:
        u = urlparse(s)
        return u.scheme in ("http", "https") and bool(u.netloc)
    except Exception:
        return False


def is_probably_base64(s: str) -> bool:
    """
    Heuristically decide whether ``s`` is base64-encoded audio rather than a
    file path: either a ``data:audio`` URI, or a long string containing no
    path separators.
    """
    if s.startswith("data:audio"):
        return True
    if ("/" not in s and "\\" not in s) and len(s) > 256:
        return True
    return False


def decode_base64_bytes(b64: str) -> bytes:
    """Decode a base64 string, stripping a leading ``data:...,`` URI prefix if present."""
    if "," in b64 and b64.strip().startswith("data:"):
        b64 = b64.split(",", 1)[1]
    return base64.b64decode(b64)


def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
    """
    Load audio from a URL, a base64 string, or a local file path.

    URLs and base64 payloads are decoded via soundfile from an in-memory
    buffer; local paths go through librosa (which handles more container
    formats) at the file's native sampling rate without down-mixing.

    Returns:
        Tuple[np.ndarray, int]: (float32 waveform, sampling rate).
    """
    if is_url(x):
        with urllib.request.urlopen(x) as resp:
            audio_bytes = resp.read()
        with io.BytesIO(audio_bytes) as f:
            audio, sr = sf.read(f, dtype="float32", always_2d=False)
    elif is_probably_base64(x):
        audio_bytes = decode_base64_bytes(x)
        with io.BytesIO(audio_bytes) as f:
            audio, sr = sf.read(f, dtype="float32", always_2d=False)
    else:
        audio, sr = librosa.load(x, sr=None, mono=False)
    audio = np.asarray(audio, dtype=np.float32)
    sr = int(sr)
    return audio, sr


def to_mono(audio: np.ndarray) -> np.ndarray:
    """
    Down-mix audio to mono by averaging channels.

    Accepts 1-D audio (returned as-is) or 2-D audio in either (T, C) or
    (C, T) layout; the layout is guessed by assuming at most 8 channels.

    Raises:
        ValueError: If ``audio`` has more than 2 dimensions.
    """
    if audio.ndim == 1:
        return audio
    # soundfile can return shape (T, C); some pipelines use (C, T)
    if audio.ndim == 2:
        if audio.shape[0] <= 8 and audio.shape[1] > audio.shape[0]:
            audio = audio.T
        return np.mean(audio, axis=-1).astype(np.float32)
    raise ValueError(f"Unsupported audio ndim={audio.ndim}")


def float_range_normalize(audio: np.ndarray) -> np.ndarray:
    """
    Ensure the waveform lies in [-1, 1].

    Peak-normalizes only when the peak exceeds 1.0 (e.g., int-scaled or
    out-of-range decodes), then clips defensively. Empty or all-zero input
    is returned unchanged.
    """
    audio = audio.astype(np.float32)
    if audio.size == 0:
        return audio
    peak = float(np.max(np.abs(audio)))
    if peak == 0.0:
        return audio
    # If decoded audio is int-like scaled or out-of-range, normalize conservatively.
    if peak > 1.0:
        audio = audio / peak
    audio = np.clip(audio, -1.0, 1.0)
    return audio


def normalize_audio_input(a: AudioLike) -> np.ndarray:
    """
    Normalize one audio input to mono 16k float32 waveform in [-1, 1].

    Supported inputs:
        - str: local file path / https URL / base64 audio string
        - (np.ndarray, sr): waveform and sampling rate

    Returns:
        np.ndarray: Mono 16k float32 waveform in [-1, 1].

    Raises:
        TypeError: If ``a`` is neither a string nor an (ndarray, sr) tuple.
    """
    if isinstance(a, str):
        audio, sr = load_audio_any(a)
    elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
        audio, sr = a[0], int(a[1])
    else:
        raise TypeError(f"Unsupported audio input type: {type(a)}")
    audio = to_mono(np.asarray(audio))
    if sr != SAMPLE_RATE:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
    audio = float_range_normalize(audio)
    return audio


def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
    """Normalize one audio input or a batch of them; always returns a list."""
    items = ensure_list(audios)
    return [normalize_audio_input(a) for a in items]


def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
    """
    Yield chunks of a list.

    A non-positive ``chunk_size`` yields the whole list as a single chunk.

    Args:
        xs (List[Any]): Input list.
        chunk_size (int): Chunk size.

    Yields:
        List[Any]: Slices of xs.
    """
    if chunk_size <= 0:
        yield xs
        return
    for i in range(0, len(xs), chunk_size):
        yield xs[i : i + chunk_size]


@dataclass(frozen=True)
class AudioChunk:
    """
    One chunk cut from an original audio.

    Attributes:
        orig_index: Index of the original sample in the input batch.
        chunk_index: Index of this chunk within the original sample.
        wav: Mono float32 waveform.
        sr: Sampling rate.
        offset_sec: Start offset of this chunk in the original audio, in seconds.
    """
    orig_index: int
    chunk_index: int
    wav: np.ndarray
    sr: int
    offset_sec: float


def split_audio_into_chunks(
    wav: np.ndarray,
    sr: int,
    max_chunk_sec: float,
    search_expand_sec: float = 5.0,
    min_window_ms: float = 100.0,
) -> List[Tuple[np.ndarray, float]]:
    """
    Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.

    This implementation guarantees:
    - Concatenating all returned chunks reproduces the original audio exactly
      (total number of samples is identical, no overlaps, no gaps), except that
      chunks shorter than MIN_ASR_INPUT_SECONDS are zero-padded at the tail.

    Args:
        wav: Mono waveform float32.
        sr: Sampling rate.
        max_chunk_sec: Target max chunk duration in seconds.
        search_expand_sec: Boundary search half-window in seconds.
        min_window_ms: Sliding window in milliseconds for energy estimation.

    Returns:
        List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
    """
    wav = np.asarray(wav, dtype=np.float32)
    if wav.ndim > 1:
        wav = np.mean(wav, axis=-1).astype(np.float32)
    total_len = int(wav.shape[0])
    total_sec = total_len / float(sr)
    if total_sec <= max_chunk_sec:
        return [(wav, 0.0)]
    max_len = int(max_chunk_sec * sr)
    expand = int(search_expand_sec * sr)
    win = max(4, int((min_window_ms / 1000.0) * sr))
    chunks: List[Tuple[np.ndarray, float]] = []
    start = 0
    offset_sec = 0.0
    while (total_len - start) > max_len:
        cut = start + max_len
        # Search a +/- expand window around the nominal cut for the
        # lowest-energy position, so chunks tend to break in silence.
        left = max(start, cut - expand)
        right = min(total_len, cut + expand)
        if right - left <= win:
            boundary = cut
        else:
            seg = wav[left:right]
            seg_abs = np.abs(seg)
            window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")
            min_pos = int(np.argmin(window_sums))
            wstart = min_pos
            wend = min_pos + win
            # Within the quietest window, cut at the single quietest sample.
            local = seg_abs[wstart:wend]
            inner = int(np.argmin(local))
            boundary = left + wstart + inner
        # Always make progress (at least 1 sample) and never run past the end.
        boundary = int(max(boundary, start + 1))
        boundary = int(min(boundary, total_len))
        chunk = wav[start:boundary]
        chunks.append((chunk, offset_sec))
        offset_sec += (boundary - start) / float(sr)
        start = boundary
    tail = wav[start:total_len]
    chunks.append((tail, offset_sec))
    # Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
    min_len = int(MIN_ASR_INPUT_SECONDS * sr)
    padded: List[Tuple[np.ndarray, float]] = []
    for c, off in chunks:
        if c.shape[0] < min_len:
            pad = min_len - int(c.shape[0])
            c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
        padded.append((c, off))
    chunks = padded
    return chunks


def detect_and_fix_repetitions(text, threshold=20):
    """
    Collapse pathological repetitions in ASR output (a common decoding-loop
    artifact): runs of more than ``threshold`` identical characters are
    reduced to one character, and a pattern (up to 20 chars) repeated at
    least ``threshold`` times back-to-back is reduced to one occurrence.
    """
    def fix_char_repeats(s, thresh):
        # Collapse any run of > thresh identical characters to a single one.
        res = []
        i = 0
        n = len(s)
        while i < n:
            count = 1
            while i + count < n and s[i + count] == s[i]:
                count += 1
            if count > thresh:
                res.append(s[i])
                i += count
            else:
                res.append(s[i:i + count])
                i += count
        return ''.join(res)

    def fix_pattern_repeats(s, thresh, max_len=20):
        # Find the first position where some pattern of length 1..max_len
        # repeats >= thresh times; keep one occurrence and recurse on the rest.
        n = len(s)
        min_repeat_chars = thresh * 2
        if n < min_repeat_chars:
            return s
        i = 0
        result = []
        while i <= n - min_repeat_chars:
            found = False
            for k in range(1, max_len + 1):
                if i + k * thresh > n:
                    break
                pattern = s[i:i + k]
                valid = True
                for rep in range(1, thresh):
                    start_idx = i + rep * k
                    if s[start_idx:start_idx + k] != pattern:
                        valid = False
                        break
                if valid:
                    # Extend past any additional repetitions before recursing.
                    total_rep = thresh
                    end_index = i + thresh * k
                    while end_index + k <= n and s[end_index:end_index + k] == pattern:
                        total_rep += 1
                        end_index += k
                    result.append(pattern)
                    result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
                    i = n
                    found = True
                    break
            if found:
                break
            else:
                result.append(s[i])
                i += 1
        if not found:
            result.append(s[i:])
        return ''.join(result)

    text_raw = text
    text = fix_char_repeats(text_raw, threshold)
    text = fix_pattern_repeats(text, threshold)
    return text


def parse_asr_output(
    raw: str,
    user_language: Optional[str] = None,
) -> Tuple[str, str]:
    """
    Parse Qwen3-ASR raw output into (language, text).

    Cases:
    - With tag: "language Chinese" + tag + transcription.
    - With newlines: "language Chinese\\n" ... tag ... transcription.
    - No tag: treat whole string as text.
    - "language None": treat as empty audio -> ("", "")

    If user_language is provided, language is forced to user_language and raw
    is treated as text-only (the model is expected to output plain
    transcription without metadata).

    Args:
        raw: Raw decoded string.
        user_language: Canonical language name if user forced language.

    Returns:
        Tuple[str, str]: (language, text)
    """
    if raw is None:
        return "", ""
    s = str(raw).strip()
    if not s:
        return "", ""
    s = detect_and_fix_repetitions(s)
    if user_language:
        # user explicitly forced language => model output is treated as pure text
        return user_language, s
    meta_part = s
    text_part = ""
    # BUGFIX: an empty _ASR_TEXT_TAG made `"" in s` always True and
    # s.split("", 1) raise ValueError ("empty separator") for every call
    # without user_language.  Guard on a non-empty tag so a missing tag
    # falls through to the documented "no tag => pure text" path.
    has_tag = bool(_ASR_TEXT_TAG) and _ASR_TEXT_TAG in s
    if has_tag:
        meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
    else:
        # no tag => pure text
        return "", s.strip()
    meta_lower = meta_part.lower()
    # empty audio heuristic
    if "language none" in meta_lower:
        t = text_part.strip()
        if not t:
            return "", ""
        # if model still returned something, keep it but language unknown
        return "", t
    # extract "language xxx" from meta
    lang = ""
    for line in meta_part.splitlines():
        line = line.strip()
        if not line:
            continue
        low = line.lower()
        if low.startswith(_LANG_PREFIX):
            val = line[len(_LANG_PREFIX):].strip()
            if val:
                lang = normalize_language_name(val)
            break
    return lang, text_part.strip()


def merge_languages(langs: List[str]) -> str:
    """
    Merge per-chunk languages into a compact comma-separated string, keeping
    order and removing consecutive duplicates and empty entries.

    Example: ["Chinese", "English", "English"] -> "Chinese,English"

    Args:
        langs: List of canonical language names.

    Returns:
        str: Merged language string.
    """
    out: List[str] = []
    prev = None
    for x in langs:
        x = (x or "").strip()
        if not x:
            continue
        if x == prev:
            continue
        out.append(x)
        prev = x
    return ",".join(out)