remove lora

Author: pengzhendong
Date:   2025-12-29 17:12:44 +08:00
Parent: 023d22efd8
Commit: 95de8cfe20

2 changed files with 3 additions and 24 deletions


@@ -70,7 +70,5 @@ ${train_tool} \
 ++optim_conf.lr=0.0002 \
 ++audio_encoder_conf.freeze=true \
 ++audio_adaptor_conf.freeze=true \
-++llm_conf.freeze=false \
-++llm_conf.use_lora=true \
-++llm_conf.lora_conf.freeze_lora=false \
+++llm_conf.freeze=true \
 ++output_dir="${output_dir}" &> ${log_file}


@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import random
@@ -79,28 +78,6 @@ class FunASRNano(nn.Module):
             for name, param in model.named_parameters():
                 param.requires_grad = False
             model.eval()
-        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}")
-        if llm_conf.get("use_lora", False):
-            from omegaconf import DictConfig, OmegaConf
-            lora_conf = llm_conf.get("lora_conf", {})
-            if isinstance(lora_conf, (OmegaConf, DictConfig)):
-                lora_conf = OmegaConf.to_container(lora_conf, resolve=True)
-            from peft import LoraConfig, PeftModel, get_peft_model
-            lora_init_param_path = lora_conf.get("init_param_path", None)
-            if lora_init_param_path is not None:
-                logging.info(f"lora_init_param_path: {lora_init_param_path}")
-                model = PeftModel.from_pretrained(model, lora_init_param_path)
-                for name, param in model.named_parameters():
-                    if not lora_conf.get("freeze_lora", False):
-                        if "lora_" in name:
-                            param.requires_grad = True
-            else:
-                peft_config = LoraConfig(**lora_conf)
-                model = get_peft_model(model, peft_config)
-            model.print_trainable_parameters()
         if llm_conf.get("activation_checkpoint", False):
             model.gradient_checkpointing_enable()
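
For reference, the deleted branch followed the standard PEFT LoRA pattern: keep the base LLM frozen, then either resume saved adapters or attach fresh ones so that only the lora_ parameters remain trainable. A minimal standalone sketch of that pattern, assuming a Hugging Face causal LM and the peft package; the checkpoint name, adapter-path variable, and LoRA hyperparameters are placeholders rather than what FunASRNano actually used:

# Sketch of the LoRA setup the removed block performed (all names and values
# here are illustrative assumptions, not FunASRNano's real configuration).
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder LLM

# Freeze the base model, mirroring the surviving code above.
for param in model.parameters():
    param.requires_grad = False
model.eval()

lora_init_param_path = None  # hypothetical: path to a saved adapter directory
if lora_init_param_path is not None:
    # Resume previously trained adapters and unfreeze only the LoRA weights.
    model = PeftModel.from_pretrained(model, lora_init_param_path)
    for name, param in model.named_parameters():
        if "lora_" in name:
            param.requires_grad = True
else:
    # Attach fresh adapters; only the LoRA parameters stay trainable.
    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)

model.print_trainable_parameters()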