init ctc decoder

This commit is contained in:
pengzhendong
2026-01-01 00:27:35 +08:00
parent 9ec8afa472
commit 2e31abe02d
2 changed files with 94 additions and 4 deletions

View File

@ -16,6 +16,8 @@ from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
from ctc import CTC
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@ -57,13 +59,13 @@ class FunASRNano(nn.Module):
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
if freeze:
for name, param in audio_encoder.named_parameters():
for _, param in audio_encoder.named_parameters():
param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
self.llm = None
init_param_path = llm_conf.get("init_param_path", None)
@ -75,7 +77,7 @@ class FunASRNano(nn.Module):
freeze = llm_conf.get("freeze", True)
if freeze:
for name, param in model.named_parameters():
for _, param in model.named_parameters():
param.requires_grad = False
model.eval()
if llm_conf.get("activation_checkpoint", False):
@ -95,12 +97,40 @@ class FunASRNano(nn.Module):
audio_adaptor = adaptor_class(**audio_adaptor_conf)
freeze = audio_adaptor_conf.get("freeze", False)
if freeze:
for name, param in audio_adaptor.named_parameters():
for _, param in audio_adaptor.named_parameters():
param.requires_grad = False
audio_adaptor.eval()
self.audio_adaptor = audio_adaptor
self.use_low_frame_rate = audio_adaptor_conf.get("use_low_frame_rate", False)
# ctc decoder
self.ctc_decoder = None
# TODO: fix table name
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None:
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
if audio_encoder_output_size > 0:
ctc_decoder_conf["encoder_dim"] = audio_encoder_output_size
self.ctc_decoder = ctc_decoder_class(**ctc_decoder_conf)
init_param_path = ctc_decoder_conf.get("init_param_path", None)
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
param.requires_grad = False
self.ctc_decoder.eval()
ctc_conf = kwargs.get("ctc_conf", {})
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
self.ctc = CTC(odim=ctc_vocab_size, encoder_output_size=audio_encoder_output_size, blank_id=self.blank_id, **ctc_conf)
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
rank = int(os.environ.get("RANK", 0))
logging.info(f"rank: {rank}, model is builded.")