finetuning/README.md

## Fine-tuning Qwen3-ASR

This script fine-tunes **Qwen3-ASR** on JSONL audio-text pairs. It supports multi-GPU training via `torchrun`.

### 1) Setup

First, install the `qwen-asr` and `datasets` Python packages:

```bash
pip install -U qwen-asr datasets
```

Then, to reduce GPU memory usage and speed up training, install FlashAttention 2 (recommended):

```bash
pip install -U flash-attn --no-build-isolation
```

If your machine has less than 96 GB of RAM and many CPU cores, cap the number of parallel compile jobs:

```bash
MAX_JOBS=4 pip install -U flash-attn --no-build-isolation
```

Your hardware must also be compatible with FlashAttention 2; see the official documentation in the [FlashAttention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention 2 can only be used when the model is loaded in `torch.float16` or `torch.bfloat16`.

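You can verify the installation with a quick check (a minimal sketch; `flash_attn_available` is an illustrative helper, not part of the training script):

```python
import importlib.util


def flash_attn_available() -> bool:
    # FlashAttention 2 is imported as `flash_attn`; find_spec avoids
    # actually importing (and compiling CUDA kernels) just to check.
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    print("flash-attn installed:", flash_attn_available())
```

If this prints `False`, training still works, but falls back to the slower default attention implementation.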
### 2) Input JSONL format

Prepare your training file as JSONL (one JSON object per line). Each line must contain:

- `audio`: path to a WAV file
- `text`: transcript text (you can include a language prefix)

Example:

```jsonl
{"audio":"/data/wavs/utt0001.wav","text":"language English<asr_text>This is a test sentence."}
{"audio":"/data/wavs/utt0002.wav","text":"language English<asr_text>Another example."}
{"audio":"/data/wavs/utt0003.wav","text":"language English<asr_text>Fine-tuning data line."}
```

Language prefix recommendation:

- If you **have** language info, use:
  - `language English<asr_text>...`
  - `language Chinese<asr_text>...`
- If you **do not have** language info, use:
  - `language None<asr_text>...`

Note:

- If you set `language None`, the model will not learn language detection from that prefix.

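If you are assembling the training file programmatically, a small helper can emit this format (a sketch; `write_jsonl` and its signature are illustrative, not part of the training script — pass `language="None"` when the language is unknown):

```python
import json


def write_jsonl(pairs, out_path, language="English"):
    """Write (wav_path, transcript) pairs as Qwen3-ASR training JSONL."""
    prefix = f"language {language}<asr_text>"
    with open(out_path, "w", encoding="utf-8") as f:
        for wav_path, transcript in pairs:
            row = {"audio": wav_path, "text": prefix + transcript}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
```

For example, `write_jsonl([("/data/wavs/utt0001.wav", "This is a test sentence.")], "train.jsonl")` reproduces the first example line above.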
### 3) Fine-tune (single GPU)

```bash
python qwen3_asr_sft.py \
  --model_path Qwen/Qwen3-ASR-1.7B \
  --train_file ./train.jsonl \
  --output_dir ./qwen3-asr-finetuning-out \
  --batch_size 32 \
  --grad_acc 4 \
  --lr 2e-5 \
  --epochs 1 \
  --save_steps 200 \
  --save_total_limit 5
```

Checkpoints will be written to:

- `./qwen3-asr-finetuning-out/checkpoint-<global_step>`

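Note that the effective (optimizer-step) batch size is `batch_size × grad_acc × num_GPUs`, which a quick calculation illustrates (the helper name is illustrative):

```python
def effective_batch_size(per_device: int, grad_acc: int, world_size: int = 1) -> int:
    # Samples consumed per optimizer step across all processes.
    return per_device * grad_acc * world_size


print(effective_batch_size(32, 4))     # single GPU -> 128
print(effective_batch_size(32, 4, 2))  # two GPUs with torchrun -> 256
```

If you reduce `batch_size` to fit GPU memory, raise `grad_acc` to keep the effective batch size (and hence the learning-rate schedule behavior) comparable.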
### 4) Fine-tune (multi GPU with torchrun)

```bash
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 qwen3_asr_sft.py \
  --model_path Qwen/Qwen3-ASR-1.7B \
  --train_file ./train.jsonl \
  --output_dir ./qwen3-asr-finetuning-out \
  --batch_size 32 \
  --grad_acc 4 \
  --lr 2e-5 \
  --epochs 1 \
  --save_steps 200
```

### 5) Resume training

Option A: explicitly set a checkpoint path:

```bash
python qwen3_asr_sft.py \
  --train_file ./train.jsonl \
  --output_dir ./qwen3-asr-finetuning-out \
  --resume_from ./qwen3-asr-finetuning-out/checkpoint-200
```

Option B: automatically resume from the latest checkpoint under `output_dir`:

```bash
python qwen3_asr_sft.py \
  --train_file ./train.jsonl \
  --output_dir ./qwen3-asr-finetuning-out \
  --resume 1
```

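Option B mirrors the `find_latest_checkpoint` helper in `qwen3_asr_sft.py`: it scans `output_dir` for `checkpoint-<step>` directories and picks the one with the highest step:

```python
import os
import re

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> subdirectory with the largest step, or None."""
    if not output_dir or not os.path.isdir(output_dir):
        return None
    best_step, best_path = None, None
    for name in os.listdir(output_dir):
        m = _CKPT_RE.match(name)
        if not m:
            continue
        path = os.path.join(output_dir, name)
        step = int(m.group(1))
        if os.path.isdir(path) and (best_step is None or step > best_step):
            best_step, best_path = step, path
    return best_path
```

Because the comparison is on the parsed integer step (not the directory name as a string), `checkpoint-1000` correctly beats `checkpoint-200`.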
### 6) Quick inference test

```python
import torch
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "qwen3-asr-finetuning-out/checkpoint-200",
    dtype=torch.bfloat16,
    device_map="cuda:0",
)

results = model.transcribe(
    audio="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav",
)

print(results[0].language)
print(results[0].text)
```

### One-click shell script example

```bash
#!/usr/bin/env bash
set -e

export CUDA_VISIBLE_DEVICES=0,1

MODEL_PATH="Qwen/Qwen3-ASR-1.7B"
TRAIN_FILE="./train.jsonl"
EVAL_FILE="./eval.jsonl"
OUTPUT_DIR="./qwen3-asr-finetuning-out"

torchrun --nproc_per_node=2 qwen3_asr_sft.py \
  --model_path "${MODEL_PATH}" \
  --train_file "${TRAIN_FILE}" \
  --eval_file "${EVAL_FILE}" \
  --output_dir "${OUTPUT_DIR}" \
  --batch_size 32 \
  --grad_acc 4 \
  --lr 2e-5 \
  --epochs 1 \
  --log_steps 10 \
  --save_strategy steps \
  --save_steps 200 \
  --save_total_limit 5 \
  --num_workers 2 \
  --pin_memory 1 \
  --persistent_workers 1 \
  --prefetch_factor 2
```


finetuning/qwen3_asr_sft.py

# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import librosa
import torch
from datasets import load_dataset
from qwen_asr import Qwen3ASRModel
from transformers import (GenerationConfig, Trainer, TrainerCallback,
                          TrainingArguments)


def patch_outer_forward(model):
    # Route the wrapper's forward() to the inner `thinker` module so that
    # HF Trainer's loss computation reaches the actual ASR model.
    cls = model.__class__
    if getattr(cls, "_forward_patched", False):
        return

    if not hasattr(model, "thinker") or not hasattr(model.thinker, "forward"):
        raise RuntimeError(
            "Cannot patch forward: model has no `.thinker.forward`. "
            "Your qwen3_asr model may be incompatible."
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_features=None,
        feature_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        return self.thinker.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            labels=labels,
            **kwargs,
        )

    cls.forward = forward
    cls._forward_patched = True


_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    if not output_dir or not os.path.isdir(output_dir):
        return None
    best_step = None
    best_path = None
    for name in os.listdir(output_dir):
        m = _CKPT_RE.match(name)
        if not m:
            continue
        step = int(m.group(1))
        path = os.path.join(output_dir, name)
        if os.path.isdir(path) and (best_step is None or step > best_step):
            best_step = step
            best_path = path
    return best_path


def load_audio(path: str, sr: int = 16000):
    wav, _ = librosa.load(path, sr=sr, mono=True)
    return wav


def build_prefix_messages(prompt: str, audio_array):
    return [
        {"role": "system", "content": prompt or ""},
        {"role": "user", "content": [{"type": "audio", "audio": audio_array}]},
    ]


def make_preprocess_fn_prefix_only(processor):
    def _preprocess(ex: Dict[str, Any]) -> Dict[str, Any]:
        prompt = ex.get("prompt", "")
        dummy_audio = None
        prefix_msgs = build_prefix_messages(prompt, dummy_audio)
        prefix_text = processor.apply_chat_template(
            [prefix_msgs], add_generation_prompt=True, tokenize=False
        )[0]
        return {
            "prompt": prompt,
            "audio": ex["audio"],
            "target": ex["text"],
            "prefix_text": prefix_text,
        }

    return _preprocess


@dataclass
class DataCollatorForQwen3ASRFinetuning:
    processor: Any
    sampling_rate: int = 16000

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        audio_paths = [f["audio"] for f in features]
        prefix_texts = [f["prefix_text"] for f in features]
        targets = [f["target"] for f in features]

        eos = self.processor.tokenizer.eos_token or ""
        full_texts = [pfx + tgt + eos for pfx, tgt in zip(prefix_texts, targets)]
        audios = [load_audio(p, sr=self.sampling_rate) for p in audio_paths]

        full_inputs = self.processor(
            text=full_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        prefix_inputs = self.processor(
            text=prefix_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )

        # Supervise only the target tokens: mask out the chat/audio prefix ...
        prefix_lens = prefix_inputs["attention_mask"].sum(dim=1).tolist()
        labels = full_inputs["input_ids"].clone()
        for i, pl in enumerate(prefix_lens):
            labels[i, :pl] = -100

        # ... and the padding tokens.
        pad_id = self.processor.tokenizer.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100

        full_inputs["labels"] = labels
        return full_inputs


class CastFloatInputsTrainer(Trainer):
    def _prepare_inputs(self, inputs):
        inputs = super()._prepare_inputs(inputs)
        model_dtype = getattr(self.model, "dtype", None)
        if model_dtype is not None:
            for k, v in list(inputs.items()):
                if torch.is_tensor(v) and v.is_floating_point():
                    inputs[k] = v.to(dtype=model_dtype)
        return inputs


def copy_required_hf_files_for_qwen_asr(src_dir: str, dst_dir: str):
    os.makedirs(dst_dir, exist_ok=True)
    required = [
        "config.json",
        "generation_config.json",
        "preprocessor_config.json",
        "processor_config.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "special_tokens_map.json",
        "chat_template.json",
        "merges.txt",
        "vocab.json",
    ]
    for fn in required:
        src = os.path.join(src_dir, fn)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(dst_dir, fn))


class MakeEveryCheckpointInferableCallback(TrainerCallback):
    def __init__(self, base_model_path: str):
        self.base_model_path = base_model_path

    def on_save(self, args: TrainingArguments, state, control, **kwargs):
        if args.process_index != 0:
            return control

        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        if not os.path.isdir(ckpt_dir):
            ckpt_dir = kwargs.get("checkpoint", ckpt_dir)

        copy_required_hf_files_for_qwen_asr(self.base_model_path, ckpt_dir)
        return control


def parse_args():
    p = argparse.ArgumentParser("Qwen3-ASR Finetuning")

    # Paths
    p.add_argument("--model_path", type=str, default="Qwen/Qwen3-ASR-1.7B")
    p.add_argument("--train_file", type=str, default="train.jsonl")
    p.add_argument("--eval_file", type=str, default="")
    p.add_argument("--output_dir", type=str, default="./qwen3-asr-finetuning-out")

    # Audio
    p.add_argument("--sr", type=int, default=16000)

    # Train hyper-params
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--grad_acc", type=int, default=4)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--epochs", type=float, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--lr_scheduler_type", type=str, default="linear")
    p.add_argument("--warmup_ratio", type=float, default=0.02)

    # DataLoader
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", type=int, default=1)
    p.add_argument("--persistent_workers", type=int, default=1)
    p.add_argument("--prefetch_factor", type=int, default=2)

    # Save
    p.add_argument("--save_strategy", type=str, default="steps")
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--save_total_limit", type=int, default=5)

    # Resume
    p.add_argument("--resume_from", type=str, default="")
    p.add_argument("--resume", type=int, default=0)

    return p.parse_args()


def main():
    args_cli = parse_args()

    if not args_cli.train_file:
        raise ValueError("TRAIN_FILE is required (json/jsonl). Needs fields: audio, text, optional prompt")

    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    asr_wrapper = Qwen3ASRModel.from_pretrained(
        args_cli.model_path,
        dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map=None,
    )
    model = asr_wrapper.model
    processor = asr_wrapper.processor

    patch_outer_forward(model)
    model.generation_config = GenerationConfig.from_model_config(model.config)

    raw_ds = load_dataset(
        "json",
        data_files={
            "train": args_cli.train_file,
            **({"validation": args_cli.eval_file} if args_cli.eval_file else {}),
        },
    )
    ds = raw_ds.map(make_preprocess_fn_prefix_only(processor), num_proc=1)

    keep = {"prompt", "audio", "target", "prefix_text"}
    for split in ds.keys():
        drop = [c for c in ds[split].column_names if c not in keep]
        if drop:
            ds[split] = ds[split].remove_columns(drop)

    collator = DataCollatorForQwen3ASRFinetuning(processor=processor, sampling_rate=args_cli.sr)

    training_args = TrainingArguments(
        output_dir=args_cli.output_dir,
        per_device_train_batch_size=args_cli.batch_size,
        gradient_accumulation_steps=args_cli.grad_acc,
        learning_rate=args_cli.lr,
        num_train_epochs=args_cli.epochs,
        logging_steps=args_cli.log_steps,
        lr_scheduler_type=args_cli.lr_scheduler_type,
        warmup_ratio=args_cli.warmup_ratio,
        dataloader_num_workers=args_cli.num_workers,
        dataloader_pin_memory=(args_cli.pin_memory == 1),
        dataloader_persistent_workers=(args_cli.persistent_workers == 1),
        dataloader_prefetch_factor=args_cli.prefetch_factor if args_cli.num_workers > 0 else None,
        save_strategy=args_cli.save_strategy,
        save_steps=args_cli.save_steps,
        save_total_limit=args_cli.save_total_limit,
        save_safetensors=True,
        # Only evaluate when an eval file was provided; eval_strategy="steps"
        # without an eval dataset makes Trainer raise an error.
        eval_strategy="steps" if args_cli.eval_file else "no",
        eval_steps=args_cli.save_steps,
        do_eval=bool(args_cli.eval_file),
        bf16=use_bf16,
        fp16=not use_bf16,
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
        report_to="none",
    )

    trainer = CastFloatInputsTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("validation", None),
        data_collator=collator,
        tokenizer=processor.tokenizer,
        callbacks=[MakeEveryCheckpointInferableCallback(base_model_path=args_cli.model_path)],
    )

    resume_from = (args_cli.resume_from or "").strip()
    if not resume_from and args_cli.resume == 1:
        resume_from = find_latest_checkpoint(training_args.output_dir) or ""

    if resume_from:
        if trainer.args.process_index == 0:
            print(f"[resume] resume_from_checkpoint = {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()


if __name__ == "__main__":
    main()