combine models

This commit is contained in:
pengzhendong
2025-12-15 23:40:22 +08:00
parent 84f8d2a4be
commit 7f77591531

View File

@ -15,6 +15,7 @@ from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable, to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@ -69,16 +70,9 @@ class FunASRNano(nn.Module):
init_param_path = llm_conf.get("init_param_path", None)
llm_dim = None
-from transformers import AutoModelForCausalLM
 llm_load_kwargs = llm_conf.get("load_kwargs", {})
-model = AutoModelForCausalLM.from_pretrained(
-    init_param_path,
-    load_in_8bit=None,
-    device_map=None,
-    use_cache=None,
-    **llm_load_kwargs,
-)
+config = AutoConfig.from_pretrained(init_param_path)
+model = AutoModelForCausalLM.from_config(config, **llm_load_kwargs)
freeze = llm_conf.get("freeze", True)
if freeze: