# coding=utf-8 # Copyright 2026 The Alibaba Qwen team. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import unicodedata from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import nagisa import torch from qwen_asr.core.transformers_backend import ( Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor, ) from transformers import AutoConfig, AutoModel, AutoProcessor from .utils import ( AudioLike, ensure_list, normalize_audios, ) class Qwen3ForceAlignProcessor(): def __init__(self): ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict") ko_scores = {} with open(ko_dict_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue word = line.split()[0] ko_scores[word] = 1.0 self.ko_score = ko_scores self.ko_tokenizer = None def is_kept_char(self, ch: str) -> bool: if ch == "'": return True cat = unicodedata.category(ch) if cat.startswith("L") or cat.startswith("N"): return True return False def clean_token(self, token: str) -> str: return "".join(ch for ch in token if self.is_kept_char(ch)) def is_cjk_char(self, ch: str) -> bool: code = ord(ch) return ( 0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs or 0x3400 <= code <= 0x4DBF # Extension A or 0x20000 <= code <= 0x2A6DF # Extension B or 0x2A700 <= code <= 0x2B73F # Extension C or 0x2B740 <= code <= 0x2B81F # Extension D or 0x2B820 <= code <= 0x2CEAF # Extension E or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs ) def tokenize_chinese_mixed(self, text: str) -> List[str]: tokens: List[str] = [] current_latin: List[str] = [] def flush_latin(): nonlocal current_latin if current_latin: token = "".join(current_latin) cleaned = self.clean_token(token) if cleaned: tokens.append(cleaned) current_latin = [] for ch in text: if self.is_cjk_char(ch): flush_latin() tokens.append(ch) else: if self.is_kept_char(ch): current_latin.append(ch) else: flush_latin() flush_latin() return tokens def tokenize_japanese(self, text: str) -> List[str]: words = nagisa.tagging(text).words tokens: List[str] = [] for w in words: cleaned = self.clean_token(w) if cleaned: tokens.append(cleaned) return tokens def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]: raw_tokens = ko_tokenizer.tokenize(text) tokens: List[str] = [] for w in raw_tokens: w_clean = self.clean_token(w) if w_clean: tokens.append(w_clean) return tokens def split_segment_with_chinese(self, seg: str) -> List[str]: tokens: List[str] = [] buf: List[str] = [] def flush_buf(): nonlocal buf if buf: tokens.append("".join(buf)) buf = [] for ch in seg: if self.is_cjk_char(ch): flush_buf() tokens.append(ch) else: buf.append(ch) flush_buf() return tokens def tokenize_space_lang(self, text: str) -> List[str]: tokens: List[str] = [] for seg in text.split(): cleaned = self.clean_token(seg) if cleaned: tokens.extend(self.split_segment_with_chinese(cleaned)) return tokens def fix_timestamp(self, data) -> List[int]: data = data.tolist() n = len(data) dp = [1] * n parent = [-1] * n for i in range(1, n): for j in range(i): if data[j] <= data[i] and dp[j] + 1 > dp[i]: dp[i] = dp[j] + 1 parent[i] = j max_length = max(dp) max_idx = dp.index(max_length) lis_indices = [] idx = max_idx while idx != -1: lis_indices.append(idx) idx = parent[idx] lis_indices.reverse() is_normal = [False] * n for idx in lis_indices: is_normal[idx] = True result = data.copy() i = 0 while i < n: if not is_normal[i]: j = i while j < n and not is_normal[j]: j += 1 anomaly_count = j - i if anomaly_count <= 2: left_val = None for k in range(i - 1, -1, -1): if is_normal[k]: left_val = result[k] break right_val = None for k in range(j, n): if is_normal[k]: right_val = result[k] break for k in range(i, j): if left_val is None: result[k] = right_val elif right_val is None: result[k] = left_val else: result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val else: left_val = None for k in range(i - 1, -1, -1): if is_normal[k]: left_val = result[k] break right_val = None for k in range(j, n): if is_normal[k]: right_val = result[k] break if left_val is not None and right_val is not None: step = (right_val - left_val) / (anomaly_count + 1) for k in range(i, j): result[k] = left_val + step * (k - i + 1) elif left_val is not None: for k in range(i, j): result[k] = left_val elif right_val is not None: for k in range(i, j): result[k] = right_val i = j else: i += 1 return [int(res) for res in result] def encode_timestamp(self, text: str, language: str) -> List[str]: language = language.lower() if language.lower() == "japanese": word_list = self.tokenize_japanese(text) elif language.lower() == "korean": if self.ko_tokenizer is None: from soynlp.tokenizer import LTokenizer self.ko_tokenizer = LTokenizer(scores=self.ko_score) word_list = self.tokenize_korean(self.ko_tokenizer, text) else: word_list = self.tokenize_space_lang(text) input_text = "".join(word_list) + "" input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text return word_list, input_text def parse_timestamp(self, word_list, timestamp): timestamp_output = [] timestamp_fixed = self.fix_timestamp(timestamp) for i, word in enumerate(word_list): start_time = timestamp_fixed[i * 2] end_time = timestamp_fixed[i * 2 + 1] timestamp_output.append({ "text": word, "start_time": start_time, "end_time": end_time }) return timestamp_output @dataclass(frozen=True) class ForcedAlignItem: """ One aligned item span. Attributes: text (str): The aligned unit (cjk character or word) produced by the forced aligner processor. start_time (float): Start time in seconds. end_time (float): End time in seconds. """ text: str start_time: int end_time: int @dataclass(frozen=True) class ForcedAlignResult: """ Forced alignment output for one sample. Attributes: items (List[ForcedAlignItem]): Aligned token spans. """ items: List[ForcedAlignItem] def __iter__(self): return iter(self.items) def __len__(self): return len(self.items) def __getitem__(self, idx: int) -> ForcedAlignItem: return self.items[idx] class Qwen3ForcedAligner: """ A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference. This wrapper provides: - `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor - audio input normalization (path/URL/base64/(np.ndarray, sr)) - batch and single-sample forced alignment - structured output with attribute access (`.text`, `.start_time`, `.end_time`) """ def __init__( self, model: Qwen3ASRForConditionalGeneration, processor: Qwen3ASRProcessor, aligner_processor: Qwen3ForceAlignProcessor, ): self.model = model self.processor = processor self.aligner_processor = aligner_processor self.device = getattr(model, "device", None) if self.device is None: try: self.device = next(model.parameters()).device except StopIteration: self.device = torch.device("cpu") self.timestamp_token_id = int(model.config.timestamp_token_id) self.timestamp_segment_time = float(model.config.timestamp_segment_time) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, **kwargs, ) -> "Qwen3ForcedAligner": """ Load Qwen3-ForcedAligner model and initialize processors. This method: 1) Registers config/model/processor for HF auto classes. 2) Loads the model using `AutoModel.from_pretrained(...)`. 3) Initializes: - HF processor (`AutoProcessor.from_pretrained(...)`) - forced alignment text processor (`Qwen3ForceAlignProcessor()`) Args: pretrained_model_name_or_path (str): HuggingFace repo id or local directory. **kwargs: Forwarded to `AutoModel.from_pretrained(...)`. Typical examples: device_map="cuda:0", dtype=torch.bfloat16. Returns: Qwen3ForcedAligner: Initialized wrapper instance. """ AutoConfig.register("qwen3_asr", Qwen3ASRConfig) AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration) AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor) model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) if not isinstance(model, Qwen3ASRForConditionalGeneration): raise TypeError( f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration." ) processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True) aligner_processor = Qwen3ForceAlignProcessor() return cls(model=model, processor=processor, aligner_processor=aligner_processor) def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult: items: List[ForcedAlignItem] = [] for it in timestamp_output: items.append( ForcedAlignItem( text=str(it.get("text", "")), start_time=float(it.get("start_time", 0)), end_time=float(it.get("end_time", 0)), ) ) return ForcedAlignResult(items=items) @torch.inference_mode() def align( self, audio: Union[AudioLike, List[AudioLike]], text: Union[str, List[str]], language: Union[str, List[str]], ) -> List[ForcedAlignResult]: """ Run forced alignment for batch or single sample. Args: audio: Audio input(s). Each item supports: - local path / https URL / base64 string - (np.ndarray, sr) All audios will be converted into mono 16k float32 arrays in [-1, 1]. text: Transcript(s) for alignment. language: Language(s) for each sample (e.g., "Chinese", "English"). Returns: List[ForcedAlignResult]: One result per sample. Each result contains `items`, and each token can be accessed via `.text`, `.start_time`, `.end_time`. """ texts = ensure_list(text) languages = ensure_list(language) audios = normalize_audios(audio) if len(languages) == 1 and len(audios) > 1: languages = languages * len(audios) if not (len(audios) == len(texts) == len(languages)): raise ValueError( f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}" ) word_lists = [] aligner_input_texts = [] for t, lang in zip(texts, languages): word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang) word_lists.append(word_list) aligner_input_texts.append(aligner_input_text) inputs = self.processor( text=aligner_input_texts, audio=audios, return_tensors="pt", padding=True, ) inputs = inputs.to(self.model.device).to(self.model.dtype) logits = self.model.thinker(**inputs).logits output_ids = logits.argmax(dim=-1) results: List[ForcedAlignResult] = [] for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists): masked_output_id = output_id[input_id == self.timestamp_token_id] timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy() timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms) for it in timestamp_output: it['start_time'] = round(it['start_time'] / 1000.0, 3) it['end_time'] = round(it['end_time'] / 1000.0, 3) results.append(self._to_structured_items(timestamp_output)) return results def get_supported_languages(self) -> Optional[List[str]]: """ List supported language names for the current model. This is a thin wrapper around `self.model.get_support_languages()`. If the underlying model does not expose language constraints (returns None), this method also returns None. Returns: Optional[List[str]]: - A sorted list of supported language names (lowercased), if available. - None if the model does not provide supported languages. """ fn = getattr(self.model, "get_support_languages", None) if not callable(fn): return None langs = fn() if langs is None: return None return sorted({str(x).lower() for x in langs})