remove lora
1 changed file: model.py (0 additions, 23 deletions)
@@ -1,4 +1,3 @@
 import json
-import logging
 import os
 import random
@@ -79,28 +78,6 @@ class FunASRNano(nn.Module):
         for name, param in model.named_parameters():
             param.requires_grad = False
         model.eval()
-        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}")
-        if llm_conf.get("use_lora", False):
-            from omegaconf import DictConfig, OmegaConf
-
-            lora_conf = llm_conf.get("lora_conf", {})
-            if isinstance(lora_conf, (OmegaConf, DictConfig)):
-                lora_conf = OmegaConf.to_container(lora_conf, resolve=True)
-            from peft import LoraConfig, PeftModel, get_peft_model
-
-            lora_init_param_path = lora_conf.get("init_param_path", None)
-            if lora_init_param_path is not None:
-                logging.info(f"lora_init_param_path: {lora_init_param_path}")
-                model = PeftModel.from_pretrained(model, lora_init_param_path)
-                for name, param in model.named_parameters():
-                    if not lora_conf.get("freeze_lora", False):
-                        if "lora_" in name:
-                            param.requires_grad = True
-            else:
-                peft_config = LoraConfig(**lora_conf)
-                model = get_peft_model(model, peft_config)
-                model.print_trainable_parameters()
-
         if llm_conf.get("activation_checkpoint", False):
             model.gradient_checkpointing_enable()
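For reference, the deleted block wired LoRA into the frozen LLM through peft in two modes: attach fresh adapters built from lora_conf, or resume adapters saved at init_param_path and optionally make them trainable again. Below is a minimal standalone sketch of that flow, not the repo's code: the base model name, rank, alpha, dropout, and target_modules are illustrative assumptions, and the repo-specific keys init_param_path and freeze_lora are popped out of the dict before it reaches LoraConfig, which rejects unknown keyword arguments.

# Standalone sketch of the LoRA path this commit deletes. The model name,
# adapter hyperparameters, and target_modules are illustrative assumptions;
# init_param_path and freeze_lora are repo-specific keys, not peft arguments.
import logging

from omegaconf import DictConfig, OmegaConf
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

# Frozen base LLM, matching the context lines above the deleted block.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # assumed model
for param in model.parameters():
    param.requires_grad = False
model.eval()

# Training configs arrive as OmegaConf nodes; peft expects a plain dict.
lora_conf = OmegaConf.create({
    "r": 8,                                   # illustrative rank
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],   # illustrative targets
    "init_param_path": None,                  # repo-specific key
    "freeze_lora": False,                     # repo-specific key
})
if isinstance(lora_conf, DictConfig):
    lora_conf = OmegaConf.to_container(lora_conf, resolve=True)

# Strip the repo-specific keys so LoraConfig(**lora_conf) stays valid.
init_param_path = lora_conf.pop("init_param_path", None)
freeze_lora = lora_conf.pop("freeze_lora", False)

if init_param_path is not None:
    # Resume adapters saved earlier with save_pretrained() ...
    logging.info(f"lora_init_param_path: {init_param_path}")
    model = PeftModel.from_pretrained(model, init_param_path)
    if not freeze_lora:
        # ... and make only the adapter weights trainable again.
        for name, param in model.named_parameters():
            if "lora_" in name:
                param.requires_grad = True
else:
    # Attach fresh adapters; only the LoRA weights end up trainable.
    model = get_peft_model(model, LoraConfig(**lora_conf))
    model.print_trainable_parameters()

The deleted code read init_param_path and freeze_lora from lora_conf with .get and presumably omitted them from configs that built fresh adapters; popping them first, as above, keeps the dict a valid set of LoraConfig arguments either way.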