Initial commit

Xiong Wang
2026-01-29 20:23:50 +08:00
commit 9567667698
32 changed files with 30029 additions and 0 deletions

finetuning/README.md
@@ -0,0 +1,155 @@
## Fine-tuning Qwen3-ASR
This script fine-tunes **Qwen3-ASR** using JSONL audio-text pairs. It supports multi-GPU training via `torchrun`.
### 1) Setup
First, install the `qwen-asr` and `datasets` Python packages:
```bash
pip install -U qwen-asr datasets
```
Then, to reduce GPU memory usage and speed up training, we recommend installing FlashAttention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
If your machine has less than 96 GB of RAM and many CPU cores, limit the number of parallel compilation jobs so the build does not exhaust memory:
```bash
MAX_JOBS=4 pip install -U flash-attn --no-build-isolation
```
Your hardware must also be compatible with FlashAttention 2; see the [FlashAttention repository](https://github.com/Dao-AILab/flash-attention) for the supported GPUs. Note that FlashAttention 2 can only be used when the model is loaded in `torch.float16` or `torch.bfloat16`.
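As a rough sanity check (an illustrative sketch using plain PyTorch, not part of this repo), you can query your GPU's compute capability; FlashAttention 2 targets Ampere (SM 8.0) and newer:
```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"Compute capability: {major}.{minor}")
    # FlashAttention 2 and bfloat16 both need Ampere (8.0) or newer.
    print("FlashAttention 2 supported:", major >= 8)
else:
    print("No CUDA device found.")
```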
### 2) Input JSONL format
Prepare your training file as JSONL (one JSON object per line). Each line must contain:
- `audio`: path to a WAV file
- `text`: transcript text (you can include a language prefix)
Example:
```jsonl
{"audio":"/data/wavs/utt0001.wav","text":"language English<asr_text>This is a test sentence."}
{"audio":"/data/wavs/utt0002.wav","text":"language English<asr_text>Another example."}
{"audio":"/data/wavs/utt0003.wav","text":"language English<asr_text>Fine-tuning data line."}
```
Language prefix recommendation:
- If you **have** language info, use:
- `language English<asr_text>...`
- `language Chinese<asr_text>...`
- If you **do not have** language info, use:
- `language None<asr_text>...`
Note:
- If you set `language None`, the model will not learn language detection from that prefix.
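A small helper along these lines (a hypothetical sketch; only the `audio` and `text` field names and the prefix format are fixed by the script) can generate `train.jsonl` from WAV/transcript pairs:
```python
import json

# (path, transcript, language) triples; use "None" when the language is unknown.
pairs = [
    ("/data/wavs/utt0001.wav", "This is a test sentence.", "English"),
    ("/data/wavs/utt0002.wav", "Another example.", "None"),
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for audio, transcript, lang in pairs:
        record = {"audio": audio, "text": f"language {lang}<asr_text>{transcript}"}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```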
### 3) Fine-tune (single GPU)
```bash
python qwen3_asr_sft.py \
--model_path Qwen/Qwen3-ASR-1.7B \
--train_file ./train.jsonl \
--output_dir ./qwen3-asr-finetuning-out \
--batch_size 32 \
--grad_acc 4 \
--lr 2e-5 \
--epochs 1 \
--save_steps 200 \
--save_total_limit 5
```
`--batch_size` times `--grad_acc` gives an effective batch of 32 × 4 = 128 samples per optimizer step. Checkpoints will be written to:
- `./qwen3-asr-finetuning-out/checkpoint-<global_step>`
### 4) Fine-tune (multi GPU with torchrun)
```bash
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 qwen3_asr_sft.py \
--model_path Qwen/Qwen3-ASR-1.7B \
--train_file ./train.jsonl \
--output_dir ./qwen3-asr-finetuning-out \
--batch_size 32 \
--grad_acc 4 \
--lr 2e-5 \
--epochs 1 \
--save_steps 200
```
`--batch_size` is the per-device batch size, so with 2 GPUs the effective batch doubles to 32 × 2 × 4 = 256 samples per optimizer step.
### 5) Resume training
Option A: explicitly set a checkpoint path:
```bash
python qwen3_asr_sft.py \
--train_file ./train.jsonl \
--output_dir ./qwen3-asr-finetuning-out \
--resume_from ./qwen3-asr-finetuning-out/checkpoint-200
```
Option B: automatically resume from the latest checkpoint under `output_dir`:
```bash
python qwen3_asr_sft.py \
--train_file ./train.jsonl \
--output_dir ./qwen3-asr-finetuning-out \
--resume 1
```
### 6) Quick inference test
```python
import torch
from qwen_asr import Qwen3ASRModel
model = Qwen3ASRModel.from_pretrained(
"qwen3-asr-finetuning-out/checkpoint-200",
dtype=torch.bfloat16,
device_map="cuda:0",
)
results = model.transcribe(
audio="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav",
)
print(results[0].language)
print(results[0].text)
```
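If `transcribe` also accepts local file paths (an assumption here; only the URL form is shown above), you can spot-check one of your training utterances with the same `model` object:
```python
# Assumption: local paths are accepted like URLs; the path is illustrative.
results = model.transcribe(audio="/data/wavs/utt0001.wav")
print(results[0].language, results[0].text)
```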
### One-click shell script example
```bash
#!/usr/bin/env bash
set -e
export CUDA_VISIBLE_DEVICES=0,1
MODEL_PATH="Qwen/Qwen3-ASR-1.7B"
TRAIN_FILE="./train.jsonl"
EVAL_FILE="./eval.jsonl"
OUTPUT_DIR="./qwen3-asr-finetuning-out"
torchrun --nproc_per_node=2 qwen3_asr_sft.py \
--model_path ${MODEL_PATH} \
--train_file ${TRAIN_FILE} \
--eval_file ${EVAL_FILE} \
--output_dir ${OUTPUT_DIR} \
--batch_size 32 \
--grad_acc 4 \
--lr 2e-5 \
--epochs 1 \
--log_steps 10 \
--save_strategy steps \
--save_steps 200 \
--save_total_limit 5 \
--num_workers 2 \
--pin_memory 1 \
--persistent_workers 1 \
--prefetch_factor 2
```

finetuning/qwen3_asr_sft.py
@@ -0,0 +1,327 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import librosa
import torch
from datasets import load_dataset
from qwen_asr import Qwen3ASRModel
from transformers import (GenerationConfig, Trainer, TrainerCallback,
TrainingArguments)


def patch_outer_forward(model):
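    # HF Trainer calls the outer model's `forward`; route it to the inner
    # `thinker` module, which computes the LM loss when `labels` is passed.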
cls = model.__class__
if getattr(cls, "_forward_patched", False):
return
if not hasattr(model, "thinker") or not hasattr(model.thinker, "forward"):
raise RuntimeError(
"Cannot patch forward: model has no `.thinker.forward`. "
"Your qwen3_asr model may be incompatible."
)
def forward(
self,
input_ids=None,
attention_mask=None,
input_features=None,
feature_attention_mask=None,
labels=None,
**kwargs,
):
return self.thinker.forward(
input_ids=input_ids,
attention_mask=attention_mask,
input_features=input_features,
feature_attention_mask=feature_attention_mask,
labels=labels,
**kwargs,
)
cls.forward = forward
cls._forward_patched = True


_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
if not output_dir or not os.path.isdir(output_dir):
return None
best_step = None
best_path = None
for name in os.listdir(output_dir):
m = _CKPT_RE.match(name)
if not m:
continue
step = int(m.group(1))
path = os.path.join(output_dir, name)
if os.path.isdir(path) and (best_step is None or step > best_step):
best_step = step
best_path = path
return best_path


def load_audio(path: str, sr: int = 16000):
wav, _ = librosa.load(path, sr=sr, mono=True)
return wav


def build_prefix_messages(prompt: str, audio_array):
return [
{"role": "system", "content": prompt or ""},
{"role": "user", "content": [{"type": "audio", "audio": audio_array}]},
]


def make_preprocess_fn_prefix_only(processor):
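    # Pre-compute the chat-template text that precedes the transcript; the
    # audio placeholder is filled in later, when the collator runs the
    # processor with the real waveform.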
def _preprocess(ex: Dict[str, Any]) -> Dict[str, Any]:
prompt = ex.get("prompt", "")
dummy_audio = None
prefix_msgs = build_prefix_messages(prompt, dummy_audio)
prefix_text = processor.apply_chat_template(
[prefix_msgs], add_generation_prompt=True, tokenize=False
)[0]
return {
"prompt": prompt,
"audio": ex["audio"],
"target": ex["text"],
"prefix_text": prefix_text,
}
return _preprocess


@dataclass
class DataCollatorForQwen3ASRFinetuning:
processor: Any
sampling_rate: int = 16000
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
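        # Tokenize prefix+transcript+EOS and prefix-only in parallel; the
        # per-sample prefix length tells us how many leading label tokens to
        # mask with -100 so the loss covers only the transcript.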
audio_paths = [f["audio"] for f in features]
prefix_texts = [f["prefix_text"] for f in features]
targets = [f["target"] for f in features]
eos = self.processor.tokenizer.eos_token or ""
full_texts = [pfx + tgt + eos for pfx, tgt in zip(prefix_texts, targets)]
audios = [load_audio(p, sr=self.sampling_rate) for p in audio_paths]
full_inputs = self.processor(
text=full_texts,
audio=audios,
return_tensors="pt",
padding=True,
truncation=False,
)
prefix_inputs = self.processor(
text=prefix_texts,
audio=audios,
return_tensors="pt",
padding=True,
truncation=False,
)
prefix_lens = prefix_inputs["attention_mask"].sum(dim=1).tolist()
labels = full_inputs["input_ids"].clone()
for i, pl in enumerate(prefix_lens):
labels[i, :pl] = -100
pad_id = self.processor.tokenizer.pad_token_id
if pad_id is not None:
labels[labels == pad_id] = -100
full_inputs["labels"] = labels
return full_inputs


class CastFloatInputsTrainer(Trainer):
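    # The processor emits float32 audio features; cast floating-point inputs
    # to the model dtype (bf16/fp16) before the forward pass.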
def _prepare_inputs(self, inputs):
inputs = super()._prepare_inputs(inputs)
model_dtype = getattr(self.model, "dtype", None)
if model_dtype is not None:
for k, v in list(inputs.items()):
if torch.is_tensor(v) and v.is_floating_point():
inputs[k] = v.to(dtype=model_dtype)
return inputs


def copy_required_hf_files_for_qwen_asr(src_dir: str, dst_dir: str):
os.makedirs(dst_dir, exist_ok=True)
required = [
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
"tokenizer.json",
"special_tokens_map.json",
"chat_template.json",
"merges.txt",
"vocab.json",
]
for fn in required:
src = os.path.join(src_dir, fn)
if os.path.exists(src):
shutil.copy2(src, os.path.join(dst_dir, fn))


class MakeEveryCheckpointInferableCallback(TrainerCallback):
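    # Trainer checkpoints hold weights and optimizer state only; copy the
    # config/tokenizer/processor files from the base model so every
    # checkpoint can be loaded directly with `from_pretrained`.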
def __init__(self, base_model_path: str):
self.base_model_path = base_model_path
def on_save(self, args: TrainingArguments, state, control, **kwargs):
if args.process_index != 0:
return control
ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
if not os.path.isdir(ckpt_dir):
ckpt_dir = kwargs.get("checkpoint", ckpt_dir)
copy_required_hf_files_for_qwen_asr(self.base_model_path, ckpt_dir)
return control


def parse_args():
p = argparse.ArgumentParser("Qwen3-ASR Finetuning")
# Paths
p.add_argument("--model_path", type=str, default="Qwen/Qwen3-ASR-1.7B")
p.add_argument("--train_file", type=str, default="train.jsonl")
p.add_argument("--eval_file", type=str, default="")
p.add_argument("--output_dir", type=str, default="./qwen3-asr-finetuning-out")
# Audio
p.add_argument("--sr", type=int, default=16000)
# Train hyper-params
p.add_argument("--batch_size", type=int, default=32)
p.add_argument("--grad_acc", type=int, default=4)
p.add_argument("--lr", type=float, default=2e-5)
p.add_argument("--epochs", type=float, default=1)
p.add_argument("--log_steps", type=int, default=10)
p.add_argument("--lr_scheduler_type", type=str, default="linear")
p.add_argument("--warmup_ratio", type=float, default=0.02)
# DataLoader
p.add_argument("--num_workers", type=int, default=4)
p.add_argument("--pin_memory", type=int, default=1)
p.add_argument("--persistent_workers", type=int, default=1)
p.add_argument("--prefetch_factor", type=int, default=2)
# Save
p.add_argument("--save_strategy", type=str, default="steps")
p.add_argument("--save_steps", type=int, default=200)
p.add_argument("--save_total_limit", type=int, default=5)
# Resume
p.add_argument("--resume_from", type=str, default="")
p.add_argument("--resume", type=int, default=0)
return p.parse_args()


def main():
args_cli = parse_args()
    if not args_cli.train_file:
        raise ValueError("--train_file is required (json/jsonl) with fields: audio, text, and optional prompt")
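    # bfloat16 needs compute capability >= 8.0 (Ampere or newer); otherwise fall back to fp16.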
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
asr_wrapper = Qwen3ASRModel.from_pretrained(
args_cli.model_path,
dtype=torch.bfloat16 if use_bf16 else torch.float16,
device_map=None,
)
model = asr_wrapper.model
processor = asr_wrapper.processor
patch_outer_forward(model)
model.generation_config = GenerationConfig.from_model_config(model.config)
raw_ds = load_dataset(
"json",
data_files={
"train": args_cli.train_file,
**({"validation": args_cli.eval_file} if args_cli.eval_file else {}),
},
)
ds = raw_ds.map(make_preprocess_fn_prefix_only(processor), num_proc=1)
keep = {"prompt", "audio", "target", "prefix_text"}
for split in ds.keys():
drop = [c for c in ds[split].column_names if c not in keep]
if drop:
ds[split] = ds[split].remove_columns(drop)
collator = DataCollatorForQwen3ASRFinetuning(processor=processor, sampling_rate=args_cli.sr)
training_args = TrainingArguments(
output_dir=args_cli.output_dir,
per_device_train_batch_size=args_cli.batch_size,
gradient_accumulation_steps=args_cli.grad_acc,
learning_rate=args_cli.lr,
num_train_epochs=args_cli.epochs,
logging_steps=args_cli.log_steps,
lr_scheduler_type=args_cli.lr_scheduler_type,
warmup_ratio=args_cli.warmup_ratio,
dataloader_num_workers=args_cli.num_workers,
dataloader_pin_memory=(args_cli.pin_memory == 1),
dataloader_persistent_workers=(args_cli.persistent_workers == 1),
dataloader_prefetch_factor=args_cli.prefetch_factor if args_cli.num_workers > 0 else None,
save_strategy=args_cli.save_strategy,
save_steps=args_cli.save_steps,
save_total_limit=args_cli.save_total_limit,
save_safetensors=True,
eval_strategy="steps",
eval_steps=args_cli.save_steps,
do_eval=bool(args_cli.eval_file),
bf16=use_bf16,
fp16=not use_bf16,
ddp_find_unused_parameters=False,
remove_unused_columns=False,
report_to="none",
)
trainer = CastFloatInputsTrainer(
model=model,
args=training_args,
train_dataset=ds["train"],
eval_dataset=ds.get("validation", None),
data_collator=collator,
tokenizer=processor.tokenizer,
callbacks=[MakeEveryCheckpointInferableCallback(base_model_path=args_cli.model_path)],
)
resume_from = (args_cli.resume_from or "").strip()
if not resume_from and args_cli.resume == 1:
resume_from = find_latest_checkpoint(training_args.output_dir) or ""
if resume_from:
if trainer.args.process_index == 0:
print(f"[resume] resume_from_checkpoint = {resume_from}")
trainer.train(resume_from_checkpoint=resume_from)
else:
trainer.train()


if __name__ == "__main__":
main()