combine models
This commit is contained in:
12
model.py
12
model.py
@ -15,6 +15,7 @@ from funasr.register import tables
|
||||
from funasr.train_utils.device_funcs import force_gatherable, to_device
|
||||
from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||
|
||||
@ -69,16 +70,9 @@ class FunASRNano(nn.Module):
|
||||
init_param_path = llm_conf.get("init_param_path", None)
|
||||
llm_dim = None
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
llm_load_kwargs = llm_conf.get("load_kwargs", {})
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
init_param_path,
|
||||
load_in_8bit=None,
|
||||
device_map=None,
|
||||
use_cache=None,
|
||||
**llm_load_kwargs,
|
||||
)
|
||||
config = AutoConfig.from_pretrained(init_param_path)
|
||||
model = AutoModelForCausalLM.from_config(config, **llm_load_kwargs)
|
||||
|
||||
freeze = llm_conf.get("freeze", True)
|
||||
if freeze:
|
||||
|
||||
Reference in New Issue
Block a user