Initial commit

Xiong Wang
2026-01-29 20:23:50 +08:00
commit 9567667698
32 changed files with 30029 additions and 0 deletions



@@ -0,0 +1,821 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
from .qwen3_forced_aligner import Qwen3ForcedAligner
from .utils import (
MAX_ASR_INPUT_SECONDS,
MAX_FORCE_ALIGN_INPUT_SECONDS,
SAMPLE_RATE,
SUPPORTED_LANGUAGES,
AudioChunk,
AudioLike,
chunk_list,
merge_languages,
normalize_audios,
normalize_language_name,
parse_asr_output,
split_audio_into_chunks,
validate_language,
)
try:
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration as Qwen3ASRForConditionalGenerationVLLM
from vllm import ModelRegistry
ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGenerationVLLM)
except Exception:
# vLLM is optional; skip registration when it is not installed.
pass
@dataclass
class ASRTranscription:
"""
One transcription result.
Attributes:
language (str):
Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
Empty string if unknown or silent audio.
text (str):
Transcribed text.
time_stamps (Optional[Any]):
Forced aligner output (ForcedAlignResult).
Present only when return_time_stamps=True.
"""
language: str
text: str
time_stamps: Optional[Any] = None
@dataclass
class ASRStreamingState:
"""
Streaming ASR state for one audio stream (single utterance).
Attributes:
unfixed_chunk_num (int):
For the first N chunks, do not use the previous ASR result as a prefix prompt (the prefix is reset to "").
unfixed_token_num (int):
When chunk_id >= unfixed_chunk_num, roll back the last K tokens of the accumulated text
before using it as the prefix prompt, to reduce boundary jitter.
chunk_size_sec (float):
Chunk size in seconds. Audio will be fed to the model in increments of this length.
chunk_size_samples (int):
Chunk size in samples at 16kHz (derived from chunk_size_sec).
chunk_id (int):
Current chunk index (0-based).
buffer (np.ndarray):
Buffered PCM samples that are not yet consumed into a full chunk.
audio_accum (np.ndarray):
Accumulated audio from the beginning of the stream up to current time (no padding).
prompt_raw (str):
Base prompt generated by chat template (with generation prompt), without appended prefix text.
context (str):
Context string.
force_language (Optional[str]):
If provided, forces text-only output by appending "language X<asr_text>" to prompt_raw,
consistent with non-streaming transcribe().
language (str):
Latest parsed language (updated after each chunk decode). Empty if unknown/silent.
text (str):
Latest parsed transcription text (updated after each chunk decode).
_raw_decoded (str):
Internal accumulated decoded raw text (before parse_asr_output normalization).
Used for rollback/token trimming and as prefix for prompting.
"""
unfixed_chunk_num: int
unfixed_token_num: int
chunk_size_sec: float
chunk_size_samples: int
chunk_id: int
buffer: np.ndarray
audio_accum: np.ndarray
prompt_raw: str
context: str
force_language: Optional[str]
language: str
text: str
_raw_decoded: str
class Qwen3ASRModel:
"""
Unified inference wrapper for Qwen3-ASR with two backends:
- Transformers backend
- vLLM backend
It optionally supports timestamp output via Qwen3-ForcedAligner.
Notes:
- Each request uses a context text and exactly one audio.
- If language is provided, the prompt will force the output to be text-only by appending
"language {Language}<asr_text>" to the assistant prompt.
"""
def __init__(
self,
backend: str,
model: Any,
processor: Any,
sampling_params: Optional[Any] = None,
forced_aligner: Optional[Qwen3ForcedAligner] = None,
max_inference_batch_size: int = -1,
max_new_tokens: int = 512,
):
self.backend = backend # "transformers" | "vllm"
self.model = model
self.processor = processor
self.sampling_params = sampling_params
self.forced_aligner = forced_aligner
self.max_inference_batch_size = int(max_inference_batch_size)
self.max_new_tokens = max_new_tokens
if backend == "transformers":
self.device = getattr(model, "device", None)
if self.device is None:
try:
self.device = next(model.parameters()).device
except StopIteration:
self.device = torch.device("cpu")
self.dtype = getattr(model, "dtype", torch.float32)
else:
self.device = None
self.dtype = None
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
forced_aligner: Optional[str] = None,
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
max_inference_batch_size: int = 32,
max_new_tokens: Optional[int] = 512,
**kwargs,
) -> "Qwen3ASRModel":
"""
Initialize using Transformers backend.
Args:
pretrained_model_name_or_path:
HuggingFace repo id or local directory.
forced_aligner:
Optional forced aligner model path/repo id.
forced_aligner_kwargs:
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
max_inference_batch_size:
Batch size limit for inference. -1 means no limit (a single batch). Smaller values help avoid OOM.
max_new_tokens:
Maximum number of tokens to generate.
**kwargs:
Forwarded to AutoModel.from_pretrained(...).
Returns:
Qwen3ASRModel
"""
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
forced_aligner_model = None
if forced_aligner is not None:
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
forced_aligner, **(forced_aligner_kwargs or {})
)
return cls(
backend="transformers",
model=model,
processor=processor,
sampling_params=None,
forced_aligner=forced_aligner_model,
max_inference_batch_size=max_inference_batch_size,
max_new_tokens=max_new_tokens,
)
@classmethod
def LLM(
cls,
model: str,
forced_aligner: Optional[str] = None,
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
max_inference_batch_size: int = -1,
max_new_tokens: Optional[int] = 4096,
**kwargs,
) -> "Qwen3ASRModel":
"""
Initialize using vLLM backend.
Import is isolated to keep vLLM optional.
Args:
model:
Model path/repo for vLLM.
forced_aligner:
Optional forced aligner model path/repo id.
forced_aligner_kwargs:
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
max_inference_batch_size:
Batch size limit for inference. -1 means no limit (a single batch). Smaller values help avoid OOM.
max_new_tokens:
Maximum number of tokens to generate.
**kwargs:
Forwarded to vllm.LLM(...).
Returns:
Qwen3ASRModel
Raises:
ImportError: If vLLM is not installed.
"""
try:
from vllm import LLM as vLLM
from vllm import SamplingParams
except Exception as e:
raise ImportError(
"vLLM is not available. Install with: pip install qwen-asr[vllm]"
) from e
llm = vLLM(model=model, **kwargs)
processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)
forced_aligner_model = None
if forced_aligner is not None:
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
forced_aligner, **(forced_aligner_kwargs or {})
)
return cls(
backend="vllm",
model=llm,
processor=processor,
sampling_params=sampling_params,
forced_aligner=forced_aligner_model,
max_inference_batch_size=max_inference_batch_size,
max_new_tokens=None,
)
def get_supported_languages(self) -> List[str]:
"""
Returns the supported language list.
Returns:
List[str]: Canonical language names.
"""
return list(SUPPORTED_LANGUAGES)
@torch.no_grad()
def transcribe(
self,
audio: Union[AudioLike, List[AudioLike]],
context: Union[str, List[str]] = "",
language: Optional[Union[str, List[Optional[str]]]] = None,
return_time_stamps: bool = False,
) -> List[ASRTranscription]:
"""
Transcribe audio with optional context and optional forced alignment timestamps.
Args:
audio:
Audio input(s). Supported:
- str: local path / URL / base64 data url
- (np.ndarray, sr)
- list of above
context:
Context string(s). If scalar, it will be broadcast to batch size.
language:
Optional language(s). If provided, it must be in supported languages.
If scalar, it will be broadcast to batch size.
If provided, the prompt will force output to be transcription text only.
return_time_stamps:
If True, timestamps are produced via the forced aligner and merged across chunks.
This requires forced_aligner to have been provided at initialization.
Returns:
List[ASRTranscription]: One result per input audio.
Raises:
ValueError:
- If return_time_stamps=True but forced_aligner is not provided.
- If language is unsupported.
- If batch sizes mismatch for context/language.
"""
if return_time_stamps and self.forced_aligner is None:
raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.")
wavs = normalize_audios(audio)
n = len(wavs)
ctxs = context if isinstance(context, list) else [context]
if len(ctxs) == 1 and n > 1:
ctxs = ctxs * n
if len(ctxs) != n:
raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}")
langs_in: List[Optional[str]]
if language is None:
langs_in = [None] * n
else:
langs_in = language if isinstance(language, list) else [language]
if len(langs_in) == 1 and n > 1:
langs_in = langs_in * n
if len(langs_in) != n:
raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}")
langs_norm: List[Optional[str]] = []
for l in langs_in:
if l is None or str(l).strip() == "":
langs_norm.append(None)
else:
ln = normalize_language_name(str(l))
validate_language(ln)
langs_norm.append(ln)
max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS
# chunk audios and record mapping
chunks: List[AudioChunk] = []
for i, wav in enumerate(wavs):
parts = split_audio_into_chunks(
wav=wav,
sr=SAMPLE_RATE,
max_chunk_sec=max_chunk_sec,
)
for j, (cwav, offset_sec) in enumerate(parts):
chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec))
# run ASR on chunks
chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks]
chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks]
chunk_wavs: List[np.ndarray] = [c.wav for c in chunks]
raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang)
# parse outputs, prepare for optional alignment
per_chunk_lang: List[str] = []
per_chunk_text: List[str] = []
for out, forced_lang in zip(raw_outputs, chunk_lang):
lang, txt = parse_asr_output(out, user_language=forced_lang)
per_chunk_lang.append(lang)
per_chunk_text.append(txt)
# forced alignment (optional)
per_chunk_align: List[Optional[Any]] = [None] * len(chunks)
if return_time_stamps:
to_align_audio = []
to_align_text = []
to_align_lang = []
to_align_idx = []
for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)):
if txt.strip() == "":
continue
to_align_audio.append((c.wav, c.sr))
to_align_text.append(txt)
to_align_lang.append(lang_pred)
to_align_idx.append(idx)
# batch align with max_inference_batch_size
aligned_results: List[Any] = []
for a_chunk, t_chunk, l_chunk in zip(
chunk_list(to_align_audio, self.max_inference_batch_size),
chunk_list(to_align_text, self.max_inference_batch_size),
chunk_list(to_align_lang, self.max_inference_batch_size),
):
aligned_results.extend(
self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk)
)
# offset fix
for k, idx in enumerate(to_align_idx):
c = chunks[idx]
r = aligned_results[k]
per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec)
# merge chunks back to original samples
out_langs: List[List[str]] = [[] for _ in range(n)]
out_texts: List[List[str]] = [[] for _ in range(n)]
out_aligns: List[List[Any]] = [[] for _ in range(n)]
for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align):
out_langs[c.orig_index].append(lang)
out_texts[c.orig_index].append(txt)
if return_time_stamps and al is not None:
out_aligns[c.orig_index].append(al)
results: List[ASRTranscription] = []
for i in range(n):
merged_text = "".join([t for t in out_texts[i] if t is not None])
merged_language = merge_languages(out_langs[i])
merged_align = None
if return_time_stamps:
merged_align = self._merge_align_results(out_aligns[i])
results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
return results
def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
return [
{"role": "system", "content": context or ""},
{"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
]
def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
"""
Build the string prompt for one request.
If force_language is provided, "language X<asr_text>" is appended after the generation prompt
to request text-only output.
"""
msgs = self._build_messages(context=context, audio_payload="")
base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
if force_language:
base = base + f"language {force_language}<asr_text>"
return base
def _infer_asr(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
"""
Run backend inference for chunk-level items.
Args:
contexts: List of context strings.
wavs: List of mono waveforms (np.ndarray).
languages: List of forced languages or None.
Returns:
List[str]: Raw decoded strings (one per chunk).
"""
if self.backend == "transformers":
return self._infer_asr_transformers(contexts, wavs, languages)
if self.backend == "vllm":
return self._infer_asr_vllm(contexts, wavs, languages)
raise RuntimeError(f"Unknown backend: {self.backend}")
def _infer_asr_transformers(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
outs: List[str] = []
texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
batch_size = self.max_inference_batch_size
if batch_size is None or batch_size <= 0:
batch_size = len(texts)
for i in range(0, len(texts), batch_size):
sub_text = texts[i : i + batch_size]
sub_wavs = wavs[i : i + batch_size]
inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
inputs = inputs.to(self.model.device).to(self.model.dtype)
text_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
decoded = self.processor.batch_decode(
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
outs.extend(list(decoded))
return outs
def _infer_asr_vllm(
self,
contexts: List[str],
wavs: List[np.ndarray],
languages: List[Optional[str]],
) -> List[str]:
inputs: List[Dict[str, Any]] = []
for c, w, fl in zip(contexts, wavs, languages):
prompt = self._build_text_prompt(context=c, force_language=fl)
inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
outs: List[str] = []
for batch in chunk_list(inputs, self.max_inference_batch_size):
outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
for o in outputs:
outs.append(o.outputs[0].text)
return outs
def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
"""
Apply time offset to a ForcedAlignResult-like object.
This function assumes:
- result has attribute `.items` which is a list of items with start_time/end_time in seconds.
- dataclasses are frozen in upstream implementation, so we reconstruct by type.
Args:
result: ForcedAlignResult
offset_sec: Offset in seconds
Returns:
ForcedAlignResult: New object with shifted timestamps.
"""
if result is None:
return None
items = []
for it in result.items:
items.append(type(it)(text=it.text,
start_time=round(it.start_time + offset_sec, 3),
end_time=round(it.end_time + offset_sec, 3)))
return type(result)(items=items)
def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
"""
Merge multiple ForcedAlignResult objects into a single one by concatenating items.
Args:
results: List of ForcedAlignResult
Returns:
ForcedAlignResult or None
"""
if not results:
return None
all_items = []
for r in results:
if r is None:
continue
all_items.extend(list(r.items))
if not all_items:
return None
return type(results[0])(items=all_items)
def init_streaming_state(
self,
context: str = "",
language: Optional[str] = None,
unfixed_chunk_num: int = 2,
unfixed_token_num: int = 5,
chunk_size_sec: float = 2.0,
) -> ASRStreamingState:
"""
Initialize streaming ASR state for a single stream.
Notes:
- Streaming ASR is supported ONLY for vLLM backend.
- Streaming ASR does NOT support timestamps (forced aligner is not used).
- Batch inference is NOT supported.
Args:
context:
Context string.
language:
Optional forced language. If provided, it must be in supported languages.
Same behavior as transcribe(): forces text-only output via prompt suffix.
unfixed_chunk_num:
For the first N chunks, do not use the previous output as a prefix prompt (the prefix is reset to "").
unfixed_token_num:
Roll back the last K tokens of the accumulated output when using it as the prefix prompt
after unfixed_chunk_num chunks.
chunk_size_sec:
Chunk size in seconds (audio is 16k PCM). The function will internally convert it
to sample count at 16kHz.
Returns:
ASRStreamingState: Mutable state object to be passed to streaming_transcribe() and
finish_streaming_transcribe().
Raises:
ValueError:
- If backend is not "vllm".
- If chunk_size_sec <= 0.
- If forced language is invalid (same validation rules as transcribe()).
"""
if self.backend != "vllm":
raise ValueError("Streaming ASR is supported only for vLLM backend (backend='vllm').")
if chunk_size_sec is None or float(chunk_size_sec) <= 0:
raise ValueError(f"chunk_size_sec must be > 0, got: {chunk_size_sec}")
force_language = None
if language is not None and str(language).strip() != "":
ln = normalize_language_name(str(language))
validate_language(ln)
force_language = ln
chunk_size_samples = int(round(float(chunk_size_sec) * SAMPLE_RATE))
chunk_size_samples = max(1, chunk_size_samples)
prompt_raw = self._build_text_prompt(context=context, force_language=force_language)
return ASRStreamingState(
unfixed_chunk_num=int(unfixed_chunk_num),
unfixed_token_num=int(unfixed_token_num),
chunk_size_sec=float(chunk_size_sec),
chunk_size_samples=int(chunk_size_samples),
chunk_id=0,
buffer=np.zeros((0,), dtype=np.float32),
audio_accum=np.zeros((0,), dtype=np.float32),
prompt_raw=prompt_raw,
context=context or "",
force_language=force_language,
language="",
text="",
_raw_decoded="",
)
def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) -> ASRStreamingState:
"""
Streaming ASR decode step.
This function accepts an arbitrary-length 16k PCM float numpy array (mono).
It buffers incoming samples, and whenever enough samples are accumulated to form one
full chunk (chunk_size_sec), it runs one incremental decode step and updates:
- state.language
- state.text
The caller only needs to keep passing audio to this function and read state.language/state.text.
Implementation details:
- Each time a new chunk is ready, we append it to audio_accum and re-feed *all* audio seen
so far to the model (no padding).
- We update the prompt as: state.prompt_raw + prefix_text
- Prefix rollback strategy:
* If chunk_id < unfixed_chunk_num: prefix_text = ""
* Else: rollback last unfixed_token_num tokens from previously accumulated decoded text.
Notes:
- vLLM backend only.
- No timestamps.
- Single stream only (no batching).
Args:
pcm16k:
16kHz mono PCM waveform (np.ndarray). Length can be any non-negative integer.
dtype can be float32/float64/int16; it will be converted to float32.
state:
Streaming state returned by init_streaming_state().
Returns:
ASRStreamingState: The same state object (mutated) for convenience.
Raises:
ValueError:
If backend is not "vllm" or state is invalid.
"""
if self.backend != "vllm":
raise ValueError("streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
if state is None:
raise ValueError("state must not be None. Call init_streaming_state() first.")
if pcm16k is None:
raise ValueError("pcm16k must not be None.")
# Ensure 1D mono
x = np.asarray(pcm16k)
if x.ndim != 1:
x = x.reshape(-1)
# Convert to float32 PCM in [-1, 1] if int16 provided
if x.dtype == np.int16:
x = (x.astype(np.float32) / 32768.0)
else:
x = x.astype(np.float32, copy=False)
# Append to buffer
if x.shape[0] > 0:
state.buffer = np.concatenate([state.buffer, x], axis=0)
# Consume full chunks
while state.buffer.shape[0] >= state.chunk_size_samples:
chunk = state.buffer[: state.chunk_size_samples]
state.buffer = state.buffer[state.chunk_size_samples :]
# Accumulate audio (re-feed from start, no padding)
if state.audio_accum.shape[0] == 0:
state.audio_accum = chunk
else:
state.audio_accum = np.concatenate([state.audio_accum, chunk], axis=0)
# Build prefix with rollback strategy
prefix = ""
if state.chunk_id < state.unfixed_chunk_num:
prefix = ""
else:
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
prompt = state.prompt_raw + prefix
# vLLM input: single item
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
gen_text = outputs[0].outputs[0].text
# Accumulate raw decoded (then parse to lang/text)
state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
state.language = lang
state.text = txt
state.chunk_id += 1
return state
def finish_streaming_transcribe(self, state: ASRStreamingState) -> ASRStreamingState:
"""
Finish streaming ASR.
This function flushes the remaining buffered audio in state.buffer (tail audio).
It sends the remaining samples to the model even if shorter than chunk_size_sec,
without padding. Then it updates state.language/state.text one last time.
Notes:
- vLLM backend only.
- No timestamps.
- Single stream only.
Args:
state:
Streaming state.
Returns:
ASRStreamingState: Updated state (mutated).
Raises:
ValueError:
If backend is not "vllm" or state is invalid.
"""
if self.backend != "vllm":
raise ValueError("finish_streaming_transcribe() is supported only for vLLM backend (backend='vllm').")
if state is None:
raise ValueError("state must not be None.")
# If no remaining buffer, still return state as-is.
if state.buffer is None or state.buffer.shape[0] == 0:
return state
tail = state.buffer
state.buffer = np.zeros((0,), dtype=np.float32)
# Append tail to accumulated audio
if state.audio_accum.shape[0] == 0:
state.audio_accum = tail
else:
state.audio_accum = np.concatenate([state.audio_accum, tail], axis=0)
# Prefix rollback strategy (same as per-chunk)
prefix = ""
if state.chunk_id < state.unfixed_chunk_num:
prefix = ""
else:
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
prompt = state.prompt_raw + prefix
inp = {"prompt": prompt, "multi_modal_data": {"audio": [state.audio_accum]}}
outputs = self.model.generate([inp], sampling_params=self.sampling_params, use_tqdm=False)
gen_text = outputs[0].outputs[0].text
state._raw_decoded = (prefix + gen_text) if prefix is not None else gen_text
lang, txt = parse_asr_output(state._raw_decoded, user_language=state.force_language)
state.language = lang
state.text = txt
state.chunk_id += 1
return state
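# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; not executed on import). The repo ids and the
# "sample.wav" path below are placeholders, not confirmed release names.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Offline transcription with the Transformers backend.
    asr = Qwen3ASRModel.from_pretrained("Qwen/Qwen3-ASR")  # placeholder repo id
    for r in asr.transcribe(audio="sample.wav", context="", language=None):
        print(r.language, r.text)

    # Streaming decode (vLLM backend only): feed 16 kHz mono PCM in arbitrary pieces.
    asr_vllm = Qwen3ASRModel.LLM(model="Qwen/Qwen3-ASR")  # placeholder repo id
    state = asr_vllm.init_streaming_state(context="", language=None, chunk_size_sec=2.0)
    pcm = np.zeros(SAMPLE_RATE * 5, dtype=np.float32)  # stand-in for a live audio feed
    for begin in range(0, pcm.shape[0], 1600):
        state = asr_vllm.streaming_transcribe(pcm[begin:begin + 1600], state)
    state = asr_vllm.finish_streaming_transcribe(state)
    print(state.language, state.text)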

qwen_asr/inference/qwen3_forced_aligner.py

@@ -0,0 +1,483 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import nagisa
import torch
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
from .utils import (
AudioLike,
ensure_list,
normalize_audios,
)
class Qwen3ForceAlignProcessor:
def __init__(self):
ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict")
ko_scores = {}
with open(ko_dict_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
word = line.split()[0]
ko_scores[word] = 1.0
self.ko_score = ko_scores
self.ko_tokenizer = None
def is_kept_char(self, ch: str) -> bool:
if ch == "'":
return True
cat = unicodedata.category(ch)
if cat.startswith("L") or cat.startswith("N"):
return True
return False
def clean_token(self, token: str) -> str:
return "".join(ch for ch in token if self.is_kept_char(ch))
def is_cjk_char(self, ch: str) -> bool:
code = ord(ch)
return (
0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs
or 0x3400 <= code <= 0x4DBF # Extension A
or 0x20000 <= code <= 0x2A6DF # Extension B
or 0x2A700 <= code <= 0x2B73F # Extension C
or 0x2B740 <= code <= 0x2B81F # Extension D
or 0x2B820 <= code <= 0x2CEAF # Extension E
or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs
)
def tokenize_chinese_mixed(self, text: str) -> List[str]:
tokens: List[str] = []
current_latin: List[str] = []
def flush_latin():
nonlocal current_latin
if current_latin:
token = "".join(current_latin)
cleaned = self.clean_token(token)
if cleaned:
tokens.append(cleaned)
current_latin = []
for ch in text:
if self.is_cjk_char(ch):
flush_latin()
tokens.append(ch)
else:
if self.is_kept_char(ch):
current_latin.append(ch)
else:
flush_latin()
flush_latin()
return tokens
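# Quick illustration of the mixed tokenizer above (hand-checked sketch):
#   tokenize_chinese_mixed("你好world") -> ["你", "好", "world"]
# Each CJK character becomes its own token, while contiguous Latin letters are grouped.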
def tokenize_japanese(self, text: str) -> List[str]:
words = nagisa.tagging(text).words
tokens: List[str] = []
for w in words:
cleaned = self.clean_token(w)
if cleaned:
tokens.append(cleaned)
return tokens
def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]:
raw_tokens = ko_tokenizer.tokenize(text)
tokens: List[str] = []
for w in raw_tokens:
w_clean = self.clean_token(w)
if w_clean:
tokens.append(w_clean)
return tokens
def split_segment_with_chinese(self, seg: str) -> List[str]:
tokens: List[str] = []
buf: List[str] = []
def flush_buf():
nonlocal buf
if buf:
tokens.append("".join(buf))
buf = []
for ch in seg:
if self.is_cjk_char(ch):
flush_buf()
tokens.append(ch)
else:
buf.append(ch)
flush_buf()
return tokens
def tokenize_space_lang(self, text: str) -> List[str]:
tokens: List[str] = []
for seg in text.split():
cleaned = self.clean_token(seg)
if cleaned:
tokens.extend(self.split_segment_with_chinese(cleaned))
return tokens
def fix_timestamp(self, data) -> List[int]:
data = data.tolist()
n = len(data)
dp = [1] * n
parent = [-1] * n
for i in range(1, n):
for j in range(i):
if data[j] <= data[i] and dp[j] + 1 > dp[i]:
dp[i] = dp[j] + 1
parent[i] = j
max_length = max(dp)
max_idx = dp.index(max_length)
lis_indices = []
idx = max_idx
while idx != -1:
lis_indices.append(idx)
idx = parent[idx]
lis_indices.reverse()
is_normal = [False] * n
for idx in lis_indices:
is_normal[idx] = True
result = data.copy()
i = 0
while i < n:
if not is_normal[i]:
j = i
while j < n and not is_normal[j]:
j += 1
anomaly_count = j - i
if anomaly_count <= 2:
left_val = None
for k in range(i - 1, -1, -1):
if is_normal[k]:
left_val = result[k]
break
right_val = None
for k in range(j, n):
if is_normal[k]:
right_val = result[k]
break
for k in range(i, j):
if left_val is None:
result[k] = right_val
elif right_val is None:
result[k] = left_val
else:
result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val
else:
left_val = None
for k in range(i - 1, -1, -1):
if is_normal[k]:
left_val = result[k]
break
right_val = None
for k in range(j, n):
if is_normal[k]:
right_val = result[k]
break
if left_val is not None and right_val is not None:
step = (right_val - left_val) / (anomaly_count + 1)
for k in range(i, j):
result[k] = left_val + step * (k - i + 1)
elif left_val is not None:
for k in range(i, j):
result[k] = left_val
elif right_val is not None:
for k in range(i, j):
result[k] = right_val
i = j
else:
i += 1
return [int(res) for res in result]
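# Quick illustration of the repair above (hand-traced sketch): a single value that breaks
# monotonicity is replaced by its nearest "normal" neighbour, e.g.
#   fix_timestamp(torch.tensor([100, 200, 50, 300])) -> [100, 200, 200, 300]
# Longer anomalous runs are filled by linear interpolation between their neighbours.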
def encode_timestamp(self, text: str, language: str) -> Tuple[List[str], str]:
language = language.lower()
if language == "japanese":
word_list = self.tokenize_japanese(text)
elif language == "korean":
if self.ko_tokenizer is None:
from soynlp.tokenizer import LTokenizer
self.ko_tokenizer = LTokenizer(scores=self.ko_score)
word_list = self.tokenize_korean(self.ko_tokenizer, text)
else:
word_list = self.tokenize_space_lang(text)
input_text = "<timestamp><timestamp>".join(word_list) + "<timestamp><timestamp>"
input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text
return word_list, input_text
def parse_timestamp(self, word_list, timestamp):
timestamp_output = []
timestamp_fixed = self.fix_timestamp(timestamp)
for i, word in enumerate(word_list):
start_time = timestamp_fixed[i * 2]
end_time = timestamp_fixed[i * 2 + 1]
timestamp_output.append({
"text": word,
"start_time": start_time,
"end_time": end_time
})
return timestamp_output
@dataclass(frozen=True)
class ForcedAlignItem:
"""
One aligned item span.
Attributes:
text (str):
The aligned unit (a CJK character or a word) produced by the forced aligner processor.
start_time (float):
Start time in seconds.
end_time (float):
End time in seconds.
"""
text: str
start_time: float
end_time: float
@dataclass(frozen=True)
class ForcedAlignResult:
"""
Forced alignment output for one sample.
Attributes:
items (List[ForcedAlignItem]):
Aligned token spans.
"""
items: List[ForcedAlignItem]
def __iter__(self):
return iter(self.items)
def __len__(self):
return len(self.items)
def __getitem__(self, idx: int) -> ForcedAlignItem:
return self.items[idx]
class Qwen3ForcedAligner:
"""
A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.
This wrapper provides:
- `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
- audio input normalization (path/URL/base64/(np.ndarray, sr))
- batch and single-sample forced alignment
- structured output with attribute access (`.text`, `.start_time`, `.end_time`)
"""
def __init__(
self,
model: Qwen3ASRForConditionalGeneration,
processor: Qwen3ASRProcessor,
aligner_processor: Qwen3ForceAlignProcessor,
):
self.model = model
self.processor = processor
self.aligner_processor = aligner_processor
self.device = getattr(model, "device", None)
if self.device is None:
try:
self.device = next(model.parameters()).device
except StopIteration:
self.device = torch.device("cpu")
self.timestamp_token_id = int(model.config.timestamp_token_id)
self.timestamp_segment_time = float(model.config.timestamp_segment_time)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
**kwargs,
) -> "Qwen3ForcedAligner":
"""
Load Qwen3-ForcedAligner model and initialize processors.
This method:
1) Registers config/model/processor for HF auto classes.
2) Loads the model using `AutoModel.from_pretrained(...)`.
3) Initializes:
- HF processor (`AutoProcessor.from_pretrained(...)`)
- forced alignment text processor (`Qwen3ForceAlignProcessor()`)
Args:
pretrained_model_name_or_path (str):
HuggingFace repo id or local directory.
**kwargs:
Forwarded to `AutoModel.from_pretrained(...)`.
Typical examples: device_map="cuda:0", dtype=torch.bfloat16.
Returns:
Qwen3ForcedAligner:
Initialized wrapper instance.
"""
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
if not isinstance(model, Qwen3ASRForConditionalGeneration):
raise TypeError(
f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration."
)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
aligner_processor = Qwen3ForceAlignProcessor()
return cls(model=model, processor=processor, aligner_processor=aligner_processor)
def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
items: List[ForcedAlignItem] = []
for it in timestamp_output:
items.append(
ForcedAlignItem(
text=str(it.get("text", "")),
start_time=float(it.get("start_time", 0)),
end_time=float(it.get("end_time", 0)),
)
)
return ForcedAlignResult(items=items)
@torch.inference_mode()
def align(
self,
audio: Union[AudioLike, List[AudioLike]],
text: Union[str, List[str]],
language: Union[str, List[str]],
) -> List[ForcedAlignResult]:
"""
Run forced alignment for batch or single sample.
Args:
audio:
Audio input(s). Each item supports:
- local path / https URL / base64 string
- (np.ndarray, sr)
All audios will be converted into mono 16k float32 arrays in [-1, 1].
text:
Transcript(s) for alignment.
language:
Language(s) for each sample (e.g., "Chinese", "English").
Returns:
List[ForcedAlignResult]:
One result per sample. Each result contains `items`, and each token can be accessed via
`.text`, `.start_time`, `.end_time`.
"""
texts = ensure_list(text)
languages = ensure_list(language)
audios = normalize_audios(audio)
if len(languages) == 1 and len(audios) > 1:
languages = languages * len(audios)
if not (len(audios) == len(texts) == len(languages)):
raise ValueError(
f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
)
word_lists = []
aligner_input_texts = []
for t, lang in zip(texts, languages):
word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang)
word_lists.append(word_list)
aligner_input_texts.append(aligner_input_text)
inputs = self.processor(
text=aligner_input_texts,
audio=audios,
return_tensors="pt",
padding=True,
)
inputs = inputs.to(self.model.device).to(self.model.dtype)
logits = self.model.thinker(**inputs).logits
output_ids = logits.argmax(dim=-1)
results: List[ForcedAlignResult] = []
for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists):
masked_output_id = output_id[input_id == self.timestamp_token_id]
timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
for it in timestamp_output:
it['start_time'] = round(it['start_time'] / 1000.0, 3)
it['end_time'] = round(it['end_time'] / 1000.0, 3)
results.append(self._to_structured_items(timestamp_output))
return results
def get_supported_languages(self) -> Optional[List[str]]:
"""
List supported language names for the current model.
This is a thin wrapper around `self.model.get_support_languages()`.
If the underlying model does not expose language constraints (returns None),
this method also returns None.
Returns:
Optional[List[str]]:
- A sorted list of supported language names (lowercased), if available.
- None if the model does not provide supported languages.
"""
fn = getattr(self.model, "get_support_languages", None)
if not callable(fn):
return None
langs = fn()
if langs is None:
return None
return sorted({str(x).lower() for x in langs})
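# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; not executed on import). The repo id is a
# placeholder and the silent waveform is only a stand-in to show the call shape.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    aligner = Qwen3ForcedAligner.from_pretrained("Qwen/Qwen3-ForcedAligner")  # placeholder repo id
    wav = np.zeros(16000 * 2, dtype=np.float32)  # 2 s of silence at 16 kHz
    results = aligner.align(audio=[(wav, 16000)], text=["hello world"], language=["English"])
    for item in results[0]:
        print(item.text, item.start_time, item.end_time)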

qwen_asr/inference/utils.py

@@ -0,0 +1,497 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import urllib.request
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlparse
import librosa
import numpy as np
import soundfile as sf
AudioLike = Union[
str, # wav path / URL / base64
Tuple[np.ndarray, int], # (waveform, sr)
]
MaybeList = Union[Any, List[Any]]
SAMPLE_RATE = 16000
MAX_ASR_INPUT_SECONDS = 1200
MAX_FORCE_ALIGN_INPUT_SECONDS = 180
MIN_ASR_INPUT_SECONDS = 0.5
SUPPORTED_LANGUAGES: List[str] = [
"Chinese",
"English",
"Cantonese",
"Arabic",
"German",
"French",
"Spanish",
"Portuguese",
"Indonesian",
"Italian",
"Korean",
"Russian",
"Thai",
"Vietnamese",
"Japanese",
"Turkish",
"Hindi",
"Malay",
"Dutch",
"Swedish",
"Danish",
"Finnish",
"Polish",
"Czech",
"Filipino",
"Persian",
"Greek",
"Romanian",
"Hungarian",
"Macedonian"
]
_ASR_TEXT_TAG = "<asr_text>"
_LANG_PREFIX = "language "
def normalize_language_name(language: str) -> str:
"""
Normalize language name to the canonical format used by Qwen3-ASR:
first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').
Args:
language (str): Input language name.
Returns:
str: Normalized language name.
Raises:
ValueError: If language is empty.
"""
if language is None:
raise ValueError("language is None")
s = str(language).strip()
if not s:
raise ValueError("language is empty")
return s[:1].upper() + s[1:].lower()
def validate_language(language: str) -> None:
"""
Validate the language is supported.
Args:
language (str): Canonical language name.
Raises:
ValueError: If unsupported.
"""
if language not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")
def ensure_list(x: MaybeList) -> List[Any]:
return x if isinstance(x, list) else [x]
def is_url(s: str) -> bool:
try:
u = urlparse(s)
return u.scheme in ("http", "https") and bool(u.netloc)
except Exception:
return False
def is_probably_base64(s: str) -> bool:
if s.startswith("data:audio"):
return True
if ("/" not in s and "\\" not in s) and len(s) > 256:
return True
return False
def decode_base64_bytes(b64: str) -> bytes:
if "," in b64 and b64.strip().startswith("data:"):
b64 = b64.split(",", 1)[1]
return base64.b64decode(b64)
def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
if is_url(x):
with urllib.request.urlopen(x) as resp:
audio_bytes = resp.read()
with io.BytesIO(audio_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
elif is_probably_base64(x):
audio_bytes = decode_base64_bytes(x)
with io.BytesIO(audio_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
audio, sr = librosa.load(x, sr=None, mono=False)
audio = np.asarray(audio, dtype=np.float32)
sr = int(sr)
return audio, sr
def to_mono(audio: np.ndarray) -> np.ndarray:
if audio.ndim == 1:
return audio
# soundfile can return shape (T, C); some pipelines use (C, T)
if audio.ndim == 2:
if audio.shape[0] <= 8 and audio.shape[1] > audio.shape[0]:
audio = audio.T
return np.mean(audio, axis=-1).astype(np.float32)
raise ValueError(f"Unsupported audio ndim={audio.ndim}")
def float_range_normalize(audio: np.ndarray) -> np.ndarray:
audio = audio.astype(np.float32)
if audio.size == 0:
return audio
peak = float(np.max(np.abs(audio)))
if peak == 0.0:
return audio
# If decoded audio is int-like scaled or out-of-range, normalize conservatively.
if peak > 1.0:
audio = audio / peak
audio = np.clip(audio, -1.0, 1.0)
return audio
def normalize_audio_input(a: AudioLike) -> np.ndarray:
"""
Normalize one audio input to mono 16k float32 waveform in [-1, 1].
Supported inputs:
- str: local file path / https URL / base64 audio string
- (np.ndarray, sr): waveform and sampling rate
Returns:
np.ndarray:
Mono 16k float32 waveform in [-1, 1].
"""
if isinstance(a, str):
audio, sr = load_audio_any(a)
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
audio, sr = a[0], int(a[1])
else:
raise TypeError(f"Unsupported audio input type: {type(a)}")
audio = to_mono(np.asarray(audio))
if sr != SAMPLE_RATE:
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
audio = float_range_normalize(audio)
return audio
def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
items = ensure_list(audios)
return [normalize_audio_input(a) for a in items]
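# Example (sketch): both input forms below normalize to mono 16 kHz float32 in [-1, 1];
# "speech.wav" is a placeholder path.
#   normalize_audios("speech.wav")
#   normalize_audios([(np.zeros(8000, dtype=np.float32), 8000)])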
def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
"""
Yield chunks of a list.
Args:
xs (List[Any]): Input list.
chunk_size (int): Chunk size.
Yields:
List[Any]: Slices of xs.
"""
if chunk_size <= 0:
yield xs
return
for i in range(0, len(xs), chunk_size):
yield xs[i : i + chunk_size]
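# Example: list(chunk_list([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]];
# a non-positive chunk_size yields the whole list as a single chunk.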
@dataclass(frozen=True)
class AudioChunk:
"""
One chunk cut from an original audio.
Attributes:
orig_index: Index of the original sample in the input batch.
chunk_index: Index of this chunk within the original sample.
wav: Mono float32 waveform.
sr: Sampling rate.
offset_sec: Start offset of this chunk in the original audio, in seconds.
"""
orig_index: int
chunk_index: int
wav: np.ndarray
sr: int
offset_sec: float
def split_audio_into_chunks(
wav: np.ndarray,
sr: int,
max_chunk_sec: float,
search_expand_sec: float = 5.0,
min_window_ms: float = 100.0,
) -> List[Tuple[np.ndarray, float]]:
"""
Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.
This implementation guarantees:
- Concatenating all returned chunks reproduces the original audio exactly
(total number of samples is identical, no overlaps, no gaps).
Args:
wav: Mono waveform float32.
sr: Sampling rate.
max_chunk_sec: Target max chunk duration in seconds.
search_expand_sec: Boundary search half-window in seconds.
min_window_ms: Sliding window in milliseconds for energy estimation.
Returns:
List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
"""
wav = np.asarray(wav, dtype=np.float32)
if wav.ndim > 1:
wav = np.mean(wav, axis=-1).astype(np.float32)
total_len = int(wav.shape[0])
total_sec = total_len / float(sr)
if total_sec <= max_chunk_sec:
return [(wav, 0.0)]
max_len = int(max_chunk_sec * sr)
expand = int(search_expand_sec * sr)
win = max(4, int((min_window_ms / 1000.0) * sr))
chunks: List[Tuple[np.ndarray, float]] = []
start = 0
offset_sec = 0.0
while (total_len - start) > max_len:
cut = start + max_len
left = max(start, cut - expand)
right = min(total_len, cut + expand)
if right - left <= win:
boundary = cut
else:
seg = wav[left:right]
seg_abs = np.abs(seg)
window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")
min_pos = int(np.argmin(window_sums))
wstart = min_pos
wend = min_pos + win
local = seg_abs[wstart:wend]
inner = int(np.argmin(local))
boundary = left + wstart + inner
boundary = int(max(boundary, start + 1))
boundary = int(min(boundary, total_len))
chunk = wav[start:boundary]
chunks.append((chunk, offset_sec))
offset_sec += (boundary - start) / float(sr)
start = boundary
tail = wav[start:total_len]
chunks.append((tail, offset_sec))
# Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
min_len = int(MIN_ASR_INPUT_SECONDS * sr)
padded: List[Tuple[np.ndarray, float]] = []
for c, off in chunks:
if c.shape[0] < min_len:
pad = min_len - int(c.shape[0])
c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
padded.append((c, off))
chunks = padded
return chunks
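# Example (sketch): chunking a long waveform and concatenating the pieces reproduces
# the original samples exactly, as long as no chunk is short enough to be zero-padded:
#   wav = np.random.randn(SAMPLE_RATE * 1500).astype(np.float32)   # ~25 minutes
#   parts = split_audio_into_chunks(wav, SAMPLE_RATE, MAX_ASR_INPUT_SECONDS)
#   assert np.array_equal(np.concatenate([p for p, _ in parts]), wav)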
def detect_and_fix_repetitions(text, threshold=20):
def fix_char_repeats(s, thresh):
res = []
i = 0
n = len(s)
while i < n:
count = 1
while i + count < n and s[i + count] == s[i]:
count += 1
if count > thresh:
res.append(s[i])
i += count
else:
res.append(s[i:i+count])
i += count
return ''.join(res)
def fix_pattern_repeats(s, thresh, max_len=20):
n = len(s)
min_repeat_chars = thresh * 2
if n < min_repeat_chars:
return s
i = 0
result = []
while i <= n - min_repeat_chars:
found = False
for k in range(1, max_len + 1):
if i + k * thresh > n:
break
pattern = s[i:i+k]
valid = True
for rep in range(1, thresh):
start_idx = i + rep * k
if s[start_idx:start_idx+k] != pattern:
valid = False
break
if valid:
total_rep = thresh
end_index = i + thresh * k
while end_index + k <= n and s[end_index:end_index+k] == pattern:
total_rep += 1
end_index += k
result.append(pattern)
result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
i = n
found = True
break
if found:
break
else:
result.append(s[i])
i += 1
if not found:
result.append(s[i:])
return ''.join(result)
text_raw = text
text = fix_char_repeats(text_raw, threshold)
text = fix_pattern_repeats(text, threshold)
return text
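# Example (sketch, hand-traced): with a lowered threshold, a long run of a repeated
# pattern collapses to a single occurrence, e.g.
#   detect_and_fix_repetitions("ab" * 25, threshold=5)  # expected to return "ab"
# while ordinary text below the threshold passes through unchanged.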
def parse_asr_output(
raw: str,
user_language: Optional[str] = None,
) -> Tuple[str, str]:
"""
Parse Qwen3-ASR raw output into (language, text).
Cases:
- With tag: "language Chinese<asr_text>...."
- With newlines: "language Chinese\\n...\\n<asr_text>...."
- No tag: treat whole string as text.
- "language None<asr_text>": treat as empty audio -> ("", "")
If user_language is provided, language is forced to user_language and raw is treated as text-only
(the model is expected to output plain transcription without metadata).
Args:
raw: Raw decoded string.
user_language: Canonical language name if user forced language.
Returns:
Tuple[str, str]: (language, text)
"""
if raw is None:
return "", ""
s = str(raw).strip()
if not s:
return "", ""
s = detect_and_fix_repetitions(s)
if user_language:
# user explicitly forced language => model output is treated as pure text
return user_language, s
meta_part = s
text_part = ""
has_tag = _ASR_TEXT_TAG in s
if has_tag:
meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
else:
# no tag => pure text
return "", s.strip()
meta_lower = meta_part.lower()
# empty audio heuristic
if "language none" in meta_lower:
t = text_part.strip()
if not t:
return "", ""
# if model still returned something, keep it but language unknown
return "", t
# extract "language xxx" from meta
lang = ""
for line in meta_part.splitlines():
line = line.strip()
if not line:
continue
low = line.lower()
if low.startswith(_LANG_PREFIX):
val = line[len(_LANG_PREFIX):].strip()
if val:
lang = normalize_language_name(val)
break
return lang, text_part.strip()
def merge_languages(langs: List[str]) -> str:
"""
Merge per-chunk languages into a compact comma-separated string,
keeping order and removing consecutive duplicates and empty entries.
Example:
["Chinese", "English", "English"] -> "Chinese,English"
Args:
langs: List of canonical language names.
Returns:
str: Merged language string.
"""
out: List[str] = []
prev = None
for x in langs:
x = (x or "").strip()
if not x:
continue
if x == prev:
continue
out.append(x)
prev = x
return ",".join(out)