init ctc decoder
This commit is contained in:
38
model.py
38
model.py
@ -16,6 +16,8 @@ from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from ctc import CTC
|
||||
|
||||
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||
|
||||
|
||||
@ -57,13 +59,13 @@ class FunASRNano(nn.Module):
|
||||
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
|
||||
audio_encoder_output_size = audio_encoder.output_size()
|
||||
freeze = audio_encoder_conf.get("freeze", True)
|
||||
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
|
||||
|
||||
if freeze:
|
||||
for name, param in audio_encoder.named_parameters():
|
||||
for _, param in audio_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
audio_encoder.eval()
|
||||
self.audio_encoder = audio_encoder
|
||||
|
||||
# llm
|
||||
self.llm = None
|
||||
init_param_path = llm_conf.get("init_param_path", None)
|
||||
@ -75,7 +77,7 @@ class FunASRNano(nn.Module):
|
||||
|
||||
freeze = llm_conf.get("freeze", True)
|
||||
if freeze:
|
||||
for name, param in model.named_parameters():
|
||||
for _, param in model.named_parameters():
|
||||
param.requires_grad = False
|
||||
model.eval()
|
||||
if llm_conf.get("activation_checkpoint", False):
|
||||
@ -95,12 +97,40 @@ class FunASRNano(nn.Module):
|
||||
audio_adaptor = adaptor_class(**audio_adaptor_conf)
|
||||
freeze = audio_adaptor_conf.get("freeze", False)
|
||||
if freeze:
|
||||
for name, param in audio_adaptor.named_parameters():
|
||||
for _, param in audio_adaptor.named_parameters():
|
||||
param.requires_grad = False
|
||||
audio_adaptor.eval()
|
||||
self.audio_adaptor = audio_adaptor
|
||||
self.use_low_frame_rate = audio_adaptor_conf.get("use_low_frame_rate", False)
|
||||
|
||||
# ctc decoder
|
||||
self.ctc_decoder = None
|
||||
# TODO: fix table name -- the CTC decoder class is currently looked up in tables.adaptor_classes
|
||||
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
|
||||
if ctc_decoder_class is not None:
|
||||
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
|
||||
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
|
||||
if audio_encoder_output_size > 0:
|
||||
ctc_decoder_conf["encoder_dim"] = audio_encoder_output_size
|
||||
self.ctc_decoder = ctc_decoder_class(**ctc_decoder_conf)
|
||||
init_param_path = ctc_decoder_conf.get("init_param_path", None)
|
||||
if init_param_path is not None:
|
||||
src_state = torch.load(init_param_path, map_location="cpu")
|
||||
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
|
||||
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
|
||||
freeze = ctc_decoder_conf.get("freeze", False)
|
||||
if freeze:
|
||||
for _, param in self.ctc_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.ctc_decoder.eval()
|
||||
|
||||
ctc_conf = kwargs.get("ctc_conf", {})
|
||||
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
|
||||
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
|
||||
self.ctc = CTC(odim=ctc_vocab_size, encoder_output_size=audio_encoder_output_size, blank_id=self.blank_id, **ctc_conf)
|
||||
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
|
||||
self.error_calculator = None
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
logging.info(f"rank: {rank}, model is builded.")
|
||||
|
||||
Reference in New Issue
Block a user