Initial commit
finetuning/qwen3_asr_sft.py (new file)
@@ -0,0 +1,327 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
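"""Supervised fine-tuning (SFT) entry point for Qwen3-ASR.

The training data is a json/jsonl file whose records provide an ``audio``
path, a ``text`` transcript, and an optional ``prompt`` (system context), as
read by the preprocessing code below. A sketch of one record (only the field
names come from this script; the values are illustrative):

    {"audio": "/path/to/utt.wav", "text": "reference transcript", "prompt": ""}

Example launch (paths and flag values are placeholders for your setup):

    python finetuning/qwen3_asr_sft.py \\
        --model_path Qwen/Qwen3-ASR-1.7B \\
        --train_file train.jsonl \\
        --output_dir ./qwen3-asr-finetuning-out
"""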

import argparse
import os
import re
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import librosa
import torch
from datasets import load_dataset
from qwen_asr import Qwen3ASRModel
from transformers import (GenerationConfig, Trainer, TrainerCallback,
                          TrainingArguments)
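

# HF Trainer computes the loss via `model(**batch)` with `labels`. Route the
# outer model's `forward` to the inner `.thinker.forward`, which accepts the
# audio features and `labels` directly; the class-level flag keeps the patch
# idempotent across repeated calls.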
def patch_outer_forward(model):
    cls = model.__class__
    if getattr(cls, "_forward_patched", False):
        return

    if not hasattr(model, "thinker") or not hasattr(model.thinker, "forward"):
        raise RuntimeError(
            "Cannot patch forward: model has no `.thinker.forward`. "
            "Your qwen3_asr model may be incompatible."
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_features=None,
        feature_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        return self.thinker.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            labels=labels,
            **kwargs,
        )

    cls.forward = forward
    cls._forward_patched = True


_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")
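

# Used by `--resume 1`: pick the `checkpoint-<step>` directory with the
# largest step number in `output_dir`, if any exists.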
def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    if not output_dir or not os.path.isdir(output_dir):
        return None
    best_step = None
    best_path = None
    for name in os.listdir(output_dir):
        m = _CKPT_RE.match(name)
        if not m:
            continue
        step = int(m.group(1))
        path = os.path.join(output_dir, name)
        if os.path.isdir(path) and (best_step is None or step > best_step):
            best_step = step
            best_path = path
    return best_path


def load_audio(path: str, sr: int = 16000):
    wav, _ = librosa.load(path, sr=sr, mono=True)
    return wav


def build_prefix_messages(prompt: str, audio_array):
    return [
        {"role": "system", "content": prompt or ""},
        {"role": "user", "content": [{"type": "audio", "audio": audio_array}]},
    ]
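

# The "prefix" is everything the model sees before the transcript: the chat
# template around the system prompt, the audio placeholder, and the
# generation prompt. It is kept as plain text here; the collator tokenizes it
# later to find how many leading tokens to exclude from the loss.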
def make_preprocess_fn_prefix_only(processor):
    def _preprocess(ex: Dict[str, Any]) -> Dict[str, Any]:
        prompt = ex.get("prompt", "")
        dummy_audio = None
        prefix_msgs = build_prefix_messages(prompt, dummy_audio)
        prefix_text = processor.apply_chat_template(
            [prefix_msgs], add_generation_prompt=True, tokenize=False
        )[0]
        return {
            "prompt": prompt,
            "audio": ex["audio"],
            "target": ex["text"],
            "prefix_text": prefix_text,
        }

    return _preprocess
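

# Collates raw examples at batch time: loads audio from disk, tokenizes
# prefix + target + EOS, and builds `labels` so that only transcript tokens
# contribute to the loss.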
@dataclass
class DataCollatorForQwen3ASRFinetuning:
    processor: Any
    sampling_rate: int = 16000

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        audio_paths = [f["audio"] for f in features]
        prefix_texts = [f["prefix_text"] for f in features]
        targets = [f["target"] for f in features]

        eos = self.processor.tokenizer.eos_token or ""
        full_texts = [pfx + tgt + eos for pfx, tgt in zip(prefix_texts, targets)]
        audios = [load_audio(p, sr=self.sampling_rate) for p in audio_paths]

        full_inputs = self.processor(
            text=full_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        prefix_inputs = self.processor(
            text=prefix_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
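
        # The prefix was tokenized separately above; its per-row length says
        # how many leading tokens of the full sequence to mask with -100 (the
        # ignore index), on top of masking any padding. This assumes the
        # tokenizer right-pads, so row i's prefix occupies positions [0, pl).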
        prefix_lens = prefix_inputs["attention_mask"].sum(dim=1).tolist()
        labels = full_inputs["input_ids"].clone()
        for i, pl in enumerate(prefix_lens):
            labels[i, :pl] = -100

        pad_id = self.processor.tokenizer.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100

        full_inputs["labels"] = labels
        return full_inputs
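

# Audio features coming out of the processor are typically float32; when the
# model runs in bf16/fp16, cast floating-point inputs to the model dtype so
# the encoder does not receive mismatched tensors.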
class CastFloatInputsTrainer(Trainer):
    def _prepare_inputs(self, inputs):
        inputs = super()._prepare_inputs(inputs)
        model_dtype = getattr(self.model, "dtype", None)
        if model_dtype is not None:
            for k, v in list(inputs.items()):
                if torch.is_tensor(v) and v.is_floating_point():
                    inputs[k] = v.to(dtype=model_dtype)
        return inputs
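

# Trainer checkpoints contain weights and optimizer state but not the
# processor/tokenizer configs. Copy whichever of these files exist next to
# the base model so a checkpoint directory can be loaded for inference as-is.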
def copy_required_hf_files_for_qwen_asr(src_dir: str, dst_dir: str):
    os.makedirs(dst_dir, exist_ok=True)
    required = [
        "config.json",
        "generation_config.json",
        "preprocessor_config.json",
        "processor_config.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "special_tokens_map.json",
        "chat_template.json",
        "merges.txt",
        "vocab.json",
    ]
    for fn in required:
        src = os.path.join(src_dir, fn)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(dst_dir, fn))
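

# Runs on every save (rank 0 only) and drops the config files above into the
# freshly written `checkpoint-<step>` directory.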
class MakeEveryCheckpointInferableCallback(TrainerCallback):
    def __init__(self, base_model_path: str):
        self.base_model_path = base_model_path

    def on_save(self, args: TrainingArguments, state, control, **kwargs):
        if args.process_index != 0:
            return control

        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        if not os.path.isdir(ckpt_dir):
            ckpt_dir = kwargs.get("checkpoint", ckpt_dir)

        copy_required_hf_files_for_qwen_asr(self.base_model_path, ckpt_dir)
        return control


def parse_args():
    p = argparse.ArgumentParser("Qwen3-ASR Finetuning")

    # Paths
    p.add_argument("--model_path", type=str, default="Qwen/Qwen3-ASR-1.7B")
    p.add_argument("--train_file", type=str, default="train.jsonl")
    p.add_argument("--eval_file", type=str, default="")
    p.add_argument("--output_dir", type=str, default="./qwen3-asr-finetuning-out")

    # Audio
    p.add_argument("--sr", type=int, default=16000)

    # Train hyper-params
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--grad_acc", type=int, default=4)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--epochs", type=float, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--lr_scheduler_type", type=str, default="linear")
    p.add_argument("--warmup_ratio", type=float, default=0.02)

    # DataLoader
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", type=int, default=1)
    p.add_argument("--persistent_workers", type=int, default=1)
    p.add_argument("--prefetch_factor", type=int, default=2)

    # Save
    p.add_argument("--save_strategy", type=str, default="steps")
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--save_total_limit", type=int, default=5)

    # Resume
    p.add_argument("--resume_from", type=str, default="")
    p.add_argument("--resume", type=int, default=0)

    return p.parse_args()


def main():
    args_cli = parse_args()

    if not args_cli.train_file:
        raise ValueError(
            "--train_file is required (json/jsonl) with fields: "
            "audio, text, and optional prompt"
        )

    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    asr_wrapper = Qwen3ASRModel.from_pretrained(
        args_cli.model_path,
        dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map=None,
    )
    model = asr_wrapper.model
    processor = asr_wrapper.processor

    patch_outer_forward(model)
    model.generation_config = GenerationConfig.from_model_config(model.config)

    raw_ds = load_dataset(
        "json",
        data_files={
            "train": args_cli.train_file,
            **({"validation": args_cli.eval_file} if args_cli.eval_file else {}),
        },
    )
    ds = raw_ds.map(make_preprocess_fn_prefix_only(processor), num_proc=1)

    keep = {"prompt", "audio", "target", "prefix_text"}
    for split in ds.keys():
        drop = [c for c in ds[split].column_names if c not in keep]
        if drop:
            ds[split] = ds[split].remove_columns(drop)

    collator = DataCollatorForQwen3ASRFinetuning(processor=processor, sampling_rate=args_cli.sr)
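
    # Evaluation runs on the checkpoint cadence, but only when an eval file
    # was provided; otherwise the eval schedule is disabled outright.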
    training_args = TrainingArguments(
        output_dir=args_cli.output_dir,
        per_device_train_batch_size=args_cli.batch_size,
        gradient_accumulation_steps=args_cli.grad_acc,
        learning_rate=args_cli.lr,
        num_train_epochs=args_cli.epochs,
        logging_steps=args_cli.log_steps,
        lr_scheduler_type=args_cli.lr_scheduler_type,
        warmup_ratio=args_cli.warmup_ratio,
        dataloader_num_workers=args_cli.num_workers,
        dataloader_pin_memory=(args_cli.pin_memory == 1),
        # Persistent workers and prefetching only apply with worker processes.
        dataloader_persistent_workers=(args_cli.persistent_workers == 1 and args_cli.num_workers > 0),
        dataloader_prefetch_factor=args_cli.prefetch_factor if args_cli.num_workers > 0 else None,
        save_strategy=args_cli.save_strategy,
        save_steps=args_cli.save_steps,
        save_total_limit=args_cli.save_total_limit,
        save_safetensors=True,
        eval_strategy="steps" if args_cli.eval_file else "no",
        eval_steps=args_cli.save_steps,
        do_eval=bool(args_cli.eval_file),
        bf16=use_bf16,
        fp16=not use_bf16,
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
        report_to="none",
    )

    trainer = CastFloatInputsTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("validation", None),
        data_collator=collator,
        tokenizer=processor.tokenizer,
        callbacks=[MakeEveryCheckpointInferableCallback(base_model_path=args_cli.model_path)],
    )

    resume_from = (args_cli.resume_from or "").strip()
    if not resume_from and args_cli.resume == 1:
        resume_from = find_latest_checkpoint(training_args.output_dir) or ""

    if resume_from:
        if trainer.args.process_index == 0:
            print(f"[resume] resume_from_checkpoint = {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()


if __name__ == "__main__":
    main()