remove lora
This commit is contained in:
@ -70,5 +70,7 @@ ${train_tool} \
|
|||||||
++optim_conf.lr=0.0002 \
|
++optim_conf.lr=0.0002 \
|
||||||
++audio_encoder_conf.freeze=true \
|
++audio_encoder_conf.freeze=true \
|
||||||
++audio_adaptor_conf.freeze=true \
|
++audio_adaptor_conf.freeze=true \
|
||||||
++llm_conf.freeze=false \
|
++llm_conf.freeze=true \
|
||||||
|
++llm_conf.use_lora=true \
|
||||||
|
++llm_conf.lora_conf.freeze_lora=false \
|
||||||
++output_dir="${output_dir}" &> ${log_file}
|
++output_dir="${output_dir}" &> ${log_file}
|
||||||
|
|||||||
23
model.py
23
model.py
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@ -79,28 +78,6 @@ class FunASRNano(nn.Module):
|
|||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
model.eval()
|
model.eval()
|
||||||
logging.info(f"use_lora: {llm_conf.get('use_lora', False)}")
|
|
||||||
if llm_conf.get("use_lora", False):
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
|
|
||||||
lora_conf = llm_conf.get("lora_conf", {})
|
|
||||||
if isinstance(lora_conf, (OmegaConf, DictConfig)):
|
|
||||||
lora_conf = OmegaConf.to_container(lora_conf, resolve=True)
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
|
||||||
|
|
||||||
lora_init_param_path = lora_conf.get("init_param_path", None)
|
|
||||||
if lora_init_param_path is not None:
|
|
||||||
logging.info(f"lora_init_param_path: {lora_init_param_path}")
|
|
||||||
model = PeftModel.from_pretrained(model, lora_init_param_path)
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if not lora_conf.get("freeze_lora", False):
|
|
||||||
if "lora_" in name:
|
|
||||||
param.requires_grad = True
|
|
||||||
else:
|
|
||||||
peft_config = LoraConfig(**lora_conf)
|
|
||||||
model = get_peft_model(model, peft_config)
|
|
||||||
model.print_trainable_parameters()
|
|
||||||
|
|
||||||
if llm_conf.get("activation_checkpoint", False):
|
if llm_conf.get("activation_checkpoint", False):
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user