add decode script

This commit is contained in:
pengzhendong
2026-01-04 15:12:17 +08:00
parent c35ce7601a
commit 403d305df9
5 changed files with 85 additions and 19 deletions

View File

@@ -117,7 +117,9 @@ class FunASRNano(nn.Module):
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
+logging.info(
+f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
+)
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
@@ -127,7 +129,12 @@ class FunASRNano(nn.Module):
ctc_conf = kwargs.get("ctc_conf", {})
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
-self.ctc = CTC(odim=ctc_vocab_size, encoder_output_size=audio_encoder_output_size, blank_id=self.blank_id, **ctc_conf)
+self.ctc = CTC(
+odim=ctc_vocab_size,
+encoder_output_size=audio_encoder_output_size,
+blank_id=self.blank_id,
+**ctc_conf,
+)
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
self.error_calculator = None
@@ -629,7 +636,7 @@ class FunASRNano(nn.Module):
with torch.autocast(
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
enabled=True if llm_dtype != "fp32" else False,
-dtype=dtype_map[llm_dtype]
+dtype=dtype_map[llm_dtype],
):
label = contents["assistant"][-1]
self.llm = self.llm.to(dtype_map[llm_dtype])
@@ -677,7 +684,7 @@ class FunASRNano(nn.Module):
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
result_i = {
"key": key[0],
-"text": re.sub(r'\s+', ' ', response.replace("/sil", " ")),
+"text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
"text_tn": response_clean,
"label": label,
}