add forced align

This commit is contained in:
pengzhendong
2026-01-25 21:23:21 +08:00
parent 3d49f7b0a9
commit 713aa6ff91
3 changed files with 87 additions and 7 deletions

View File

@ -17,6 +17,7 @@ from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
from ctc import CTC
from tools.utils import forced_align
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@ -108,6 +109,14 @@ class FunASRNano(nn.Module):
# TODO: fix table name
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None:
ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
self.ctc_tokenizer = ctc_tokenizer
assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
if audio_encoder_output_size > 0:
@ -492,11 +501,13 @@ class FunASRNano(nn.Module):
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(
adaptor_out, adaptor_out_lens = self.audio_adaptor(
encoder_out, encoder_out_lens
)
meta_data["audio_adaptor_out"] = encoder_out
meta_data["audio_adaptor_out_lens"] = encoder_out_lens
meta_data["encoder_out"] = encoder_out
meta_data["encoder_out_lens"] = encoder_out_lens
meta_data["audio_adaptor_out"] = adaptor_out
meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]
@ -520,7 +531,7 @@ class FunASRNano(nn.Module):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = encoder_out[speech_idx, :speech_token_len, :]
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
try:
inputs_embeds[
@ -532,10 +543,10 @@ class FunASRNano(nn.Module):
#
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
speech_token_len = encoder_out_lens[speech_idx].item()
speech_token = encoder_out[speech_idx, :speech_token_len, :]
speech_token_len = adaptor_out_lens[speech_idx].item()
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
@ -627,6 +638,29 @@ class FunASRNano(nn.Module):
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
ctc_results = []
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
ctc_logits = self.ctc.log_softmax(decoder_out)
b, n, d = encoder_out.size()
if isinstance(key[0], (list, tuple)):
key = key[0]
if len(key) < b:
key = key * b
for i in range(b):
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
# Change integer-ids to tokens
text = self.ctc_tokenizer.decode(token_int)
ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
@ -692,6 +726,18 @@ class FunASRNano(nn.Module):
result_i["loss"] = loss
results.append(result_i)
for ctc_result, result in zip(ctc_results, results):
result["ctc_text"] = ctc_result["text"]
target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
for timestamp in timestamps:
timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
if ibest_writer is not None:
ibest_writer["text"][key[0]] = response.replace("\n", " ")
ibest_writer["label"][key[0]] = label.replace("\n", " ")