# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import librosa
import torch
from datasets import load_dataset
from qwen_asr import Qwen3ASRModel
from transformers import (GenerationConfig, Trainer, TrainerCallback,
                          TrainingArguments)


def patch_outer_forward(model):
    """Make the outer model's ``forward`` delegate to ``model.thinker.forward``.

    The replacement is installed on the model's class (not the instance) and
    is applied at most once per class; the ``_forward_patched`` class
    attribute records that the patch is already in place.

    Args:
        model: Object expected to expose a ``thinker`` attribute that has a
            ``forward`` method.

    Raises:
        RuntimeError: If ``model`` has no ``thinker.forward``.
    """
    cls = model.__class__
    if getattr(cls, "_forward_patched", False):
        return

    if not hasattr(model, "thinker") or not hasattr(model.thinker, "forward"):
        raise RuntimeError(
            "Cannot patch forward: model has no `.thinker.forward`. "
            "Your qwen3_asr model may be incompatible."
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_features=None,
        feature_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        # Pure pass-through: every argument is forwarded unchanged.
        return self.thinker.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            labels=labels,
            **kwargs,
        )

    cls.forward = forward
    cls._forward_patched = True


# Matches directory names of the form ``checkpoint-<step>``.
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the ``checkpoint-<step>`` sub-directory with the highest step.

    Args:
        output_dir: Directory to scan (non-recursively).

    Returns:
        Path of the checkpoint directory with the largest step number, or
        ``None`` when ``output_dir`` is empty, is not a directory, or holds
        no matching sub-directory.
    """
    if not output_dir or not os.path.isdir(output_dir):
        return None

    best_step = None
    best_path = None
    for name in os.listdir(output_dir):
        m = _CKPT_RE.match(name)
        if not m:
            continue
        step = int(m.group(1))
        path = os.path.join(output_dir, name)
        # Only directories count; a stray file named like a checkpoint is ignored.
        if os.path.isdir(path) and (best_step is None or step > best_step):
            best_step = step
            best_path = path
    return best_path


def load_audio(path: str, sr: int = 16000):
    """Load ``path`` with librosa as mono audio resampled to ``sr`` Hz.

    Returns only the waveform; the sample rate librosa reports is dropped.
    """
    wav, _ = librosa.load(path, sr=sr, mono=True)
    return wav


def build_prefix_messages(prompt: str, audio_array):
    """Build the two-message chat (system prompt + user audio) for a sample.

    Args:
        prompt: System prompt text; a falsy value is replaced by ``""``.
        audio_array: Value stored under the ``"audio"`` key of the user
            message content. It is stored as-is.

    Returns:
        A list with one system message and one user message.
    """
    return [
        {"role": "system", "content": prompt or ""},
        {"role": "user", "content": [{"type": "audio", "audio": audio_array}]},
    ]


def make_preprocess_fn_prefix_only(processor):
    """Create a per-example ``map`` function that renders the chat prefix.

    The returned function renders the chat template as text only (no
    tokenization) with a ``None`` audio placeholder; actual audio loading is
    deferred to the data collator.

    Args:
        processor: Object providing ``apply_chat_template``.

    Returns:
        A callable mapping an example dict (fields ``audio``, ``text`` and
        optionally ``prompt``) to a dict with keys ``prompt``, ``audio``,
        ``target`` and ``prefix_text``.
    """

    def _preprocess(ex: Dict[str, Any]) -> Dict[str, Any]:
        prompt = ex.get("prompt", "")
        dummy_audio = None
        prefix_msgs = build_prefix_messages(prompt, dummy_audio)
        # A batch of one conversation goes in, so take element 0 of the result.
        prefix_text = processor.apply_chat_template(
            [prefix_msgs], add_generation_prompt=True, tokenize=False
        )[0]
        return {
            "prompt": prompt,
            "audio": ex["audio"],
            "target": ex["text"],
            "prefix_text": prefix_text,
        }

    return _preprocess


@dataclass
class DataCollatorForQwen3ASRFinetuning:
    """Collate preprocessed examples into model inputs with masked labels.

    Each feature must carry ``audio`` (a path passed to ``load_audio``),
    ``prefix_text`` and ``target``. The processor is run twice per batch:
    once on prefix + target + EOS to build the inputs, and once on the prefix
    alone to measure how many leading tokens must be excluded from the loss.
    """

    # Processor used for tokenization / feature extraction; must expose `.tokenizer`.
    processor: Any
    # Sample rate (Hz) the audio files are resampled to when loaded.
    sampling_rate: int = 16000

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Build a padded batch whose ``labels`` cover only target tokens.

        Args:
            features: Examples with ``audio``, ``prefix_text`` and ``target``.

        Returns:
            The processor output for the full texts, with an added
            ``labels`` entry in which prefix and padding positions are -100.
        """
        audio_paths = [f["audio"] for f in features]
        prefix_texts = [f["prefix_text"] for f in features]
        targets = [f["target"] for f in features]

        eos = self.processor.tokenizer.eos_token or ""
        full_texts = [pfx + tgt + eos for pfx, tgt in zip(prefix_texts, targets)]
        audios = [load_audio(p, sr=self.sampling_rate) for p in audio_paths]

        full_inputs = self.processor(
            text=full_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        prefix_inputs = self.processor(
            text=prefix_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )

        prefix_lens = prefix_inputs["attention_mask"].sum(dim=1).tolist()
        full_mask = full_inputs["attention_mask"]

        labels = full_inputs["input_ids"].clone()
        # Mask padding by attention mask rather than by token id: comparing
        # against pad_token_id would also erase real target tokens that share
        # the pad id (for example when the tokenizer reuses EOS as pad).
        labels[full_mask == 0] = -100

        # First attended position of every row: 0 for right-padded batches,
        # the pad count for left-padded ones. argmax returns the first maximum.
        starts = full_mask.long().argmax(dim=1).tolist()
        for i, (start, pl) in enumerate(zip(starts, prefix_lens)):
            labels[i, start:start + pl] = -100

        full_inputs["labels"] = labels
        return full_inputs


class CastFloatInputsTrainer(Trainer):
    """Trainer that casts floating-point input tensors to the model's dtype."""

    def _prepare_inputs(self, inputs):
        """Run the base preparation, then cast float tensors to ``model.dtype``.

        Non-tensor and non-floating-point entries are left untouched. Nothing
        is cast when the model exposes no ``dtype`` attribute.
        """
        inputs = super()._prepare_inputs(inputs)
        model_dtype = getattr(self.model, "dtype", None)
        if model_dtype is not None:
            # Iterate over a snapshot because entries are reassigned in place.
            for k, v in list(inputs.items()):
                if torch.is_tensor(v) and v.is_floating_point():
                    inputs[k] = v.to(dtype=model_dtype)
        return inputs


def copy_required_hf_files_for_qwen_asr(src_dir: str, dst_dir: str):
    """Copy config/tokenizer/processor files from ``src_dir`` to ``dst_dir``.

    ``dst_dir`` is created if needed. Files from the fixed list below that do
    not exist in ``src_dir`` are skipped silently (best effort).
    """
    os.makedirs(dst_dir, exist_ok=True)
    required = [
        "config.json",
        "generation_config.json",
        "preprocessor_config.json",
        "processor_config.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "special_tokens_map.json",
        "chat_template.json",
        "merges.txt",
        "vocab.json",
    ]
    for fn in required:
        src = os.path.join(src_dir, fn)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(dst_dir, fn))


class MakeEveryCheckpointInferableCallback(TrainerCallback):
    """After each save, copy the base model's auxiliary files into the checkpoint."""

    def __init__(self, base_model_path: str):
        """Remember where the base model's files are expected to live."""
        self.base_model_path = base_model_path
        # Ensures the "nothing to copy" warning is printed only once.
        self._warned_missing_base = False

    def on_save(self, args: TrainingArguments, state, control, **kwargs):
        """Copy auxiliary files into the just-saved checkpoint (process 0 only)."""
        if args.process_index != 0:
            return control

        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        if not os.path.isdir(ckpt_dir):
            ckpt_dir = kwargs.get("checkpoint", ckpt_dir)

        # Files are read from a local directory only. When base_model_path is
        # not one (e.g. a hub id such as the default --model_path), nothing
        # can be copied; say so once instead of failing silently.
        if not os.path.isdir(self.base_model_path) and not self._warned_missing_base:
            print(
                f"[warn] base model path '{self.base_model_path}' is not a local "
                "directory; config/tokenizer files will not be copied into checkpoints."
            )
            self._warned_missing_base = True

        copy_required_hf_files_for_qwen_asr(self.base_model_path, ckpt_dir)
        return control


def parse_args():
    """Parse the command-line options of the fine-tuning script."""
    p = argparse.ArgumentParser("Qwen3-ASR Finetuning")

    # Paths
    p.add_argument("--model_path", type=str, default="Qwen/Qwen3-ASR-1.7B")
    p.add_argument("--train_file", type=str, default="train.jsonl")
    p.add_argument("--eval_file", type=str, default="")
    p.add_argument("--output_dir", type=str, default="./qwen3-asr-finetuning-out")

    # Audio
    p.add_argument("--sr", type=int, default=16000)

    # Train hyper-params
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--grad_acc", type=int, default=4)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--epochs", type=float, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--lr_scheduler_type", type=str, default="linear")
    p.add_argument("--warmup_ratio", type=float, default=0.02)

    # DataLoader
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", type=int, default=1)
    p.add_argument("--persistent_workers", type=int, default=1)
    p.add_argument("--prefetch_factor", type=int, default=2)

    # Save
    p.add_argument("--save_strategy", type=str, default="steps")
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--save_total_limit", type=int, default=5)

    # Resume
    p.add_argument("--resume_from", type=str, default="")
    p.add_argument("--resume", type=int, default=0)

    return p.parse_args()


def main():
    """Fine-tune Qwen3-ASR on a json/jsonl dataset according to the CLI options.

    Loads the model and processor, preprocesses the dataset, builds the
    Trainer and runs training, optionally resuming from a checkpoint.

    Raises:
        ValueError: If ``--train_file`` is empty.
    """
    args_cli = parse_args()

    if not args_cli.train_file:
        raise ValueError("TRAIN_FILE is required (json/jsonl). Needs fields: audio, text, optional prompt")

    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.get_device_capability(0)[0] >= 8
    # fp16 mixed precision is only meaningful on CUDA devices without bf16.
    use_fp16 = has_cuda and not use_bf16
    # With fp16 AMP the master weights must stay in float32: the gradient
    # scaler cannot unscale gradients of float16 parameters.
    load_dtype = torch.bfloat16 if use_bf16 else torch.float32

    asr_wrapper = Qwen3ASRModel.from_pretrained(
        args_cli.model_path,
        dtype=load_dtype,
        device_map=None,
    )
    model = asr_wrapper.model
    processor = asr_wrapper.processor

    patch_outer_forward(model)
    model.generation_config = GenerationConfig.from_model_config(model.config)

    has_eval = bool(args_cli.eval_file)
    data_files = {"train": args_cli.train_file}
    if has_eval:
        data_files["validation"] = args_cli.eval_file
    raw_ds = load_dataset("json", data_files=data_files)

    ds = raw_ds.map(make_preprocess_fn_prefix_only(processor), num_proc=1)

    # Keep only the columns the collator reads.
    keep = {"prompt", "audio", "target", "prefix_text"}
    for split in ds.keys():
        drop = [c for c in ds[split].column_names if c not in keep]
        if drop:
            ds[split] = ds[split].remove_columns(drop)

    collator = DataCollatorForQwen3ASRFinetuning(processor=processor, sampling_rate=args_cli.sr)

    # Worker-related DataLoader options are only valid with at least one worker.
    use_workers = args_cli.num_workers > 0

    training_args = TrainingArguments(
        output_dir=args_cli.output_dir,
        per_device_train_batch_size=args_cli.batch_size,
        gradient_accumulation_steps=args_cli.grad_acc,
        learning_rate=args_cli.lr,
        num_train_epochs=args_cli.epochs,
        logging_steps=args_cli.log_steps,
        lr_scheduler_type=args_cli.lr_scheduler_type,
        warmup_ratio=args_cli.warmup_ratio,
        dataloader_num_workers=args_cli.num_workers,
        dataloader_pin_memory=(args_cli.pin_memory == 1),
        dataloader_persistent_workers=(args_cli.persistent_workers == 1) and use_workers,
        dataloader_prefetch_factor=args_cli.prefetch_factor if use_workers else None,
        save_strategy=args_cli.save_strategy,
        save_steps=args_cli.save_steps,
        save_total_limit=args_cli.save_total_limit,
        save_safetensors=True,
        # Evaluating on a schedule requires an eval dataset; disable it otherwise.
        eval_strategy="steps" if has_eval else "no",
        eval_steps=args_cli.save_steps,
        do_eval=has_eval,
        bf16=use_bf16,
        fp16=use_fp16,
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
        report_to="none",
    )

    trainer = CastFloatInputsTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("validation", None),
        data_collator=collator,
        tokenizer=processor.tokenizer,
        callbacks=[MakeEveryCheckpointInferableCallback(base_model_path=args_cli.model_path)],
    )

    # An explicit --resume_from wins; otherwise --resume 1 picks the newest
    # checkpoint found in the output directory, if any.
    resume_from = (args_cli.resume_from or "").strip()
    if not resume_from and args_cli.resume == 1:
        resume_from = find_latest_checkpoint(training_args.output_dir) or ""

    if resume_from:
        if trainer.args.process_index == 0:
            print(f"[resume] resume_from_checkpoint = {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()


if __name__ == "__main__":
    main()