add forced align

pengzhendong
2026-01-25 21:23:21 +08:00
parent 3d49f7b0a9
commit 713aa6ff91
3 changed files with 87 additions and 7 deletions


@@ -17,6 +17,7 @@ from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 from transformers import AutoConfig, AutoModelForCausalLM
 from ctc import CTC
+from tools.utils import forced_align


 dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@@ -108,6 +109,14 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
+            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
+            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
+                ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
+                ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
+            self.ctc_tokenizer = ctc_tokenizer
+            assert ctc_tokenizer is not None, "ctc_tokenizer must be set"
+
             ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
             ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
             if audio_encoder_output_size > 0:
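For context, the tokenizer name is resolved from kwargs (falling back to dataset_conf), looked up in FunASR's tables registry, and instantiated from its conf dict. A hypothetical kwargs shape that would satisfy this code path; the tokenizer name and conf keys below are illustrative assumptions, not values taken from this commit:

# Illustrative only: the tokenizer name and conf keys are assumptions.
kwargs = {
    "ctc_tokenizer": "SentencepiecesTokenizer",
    "ctc_tokenizer_conf": {"bpemodel": "path/to/bpe.model"},
}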
@@ -492,11 +501,13 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

         # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(
             encoder_out, encoder_out_lens
         )
-        meta_data["audio_adaptor_out"] = encoder_out
-        meta_data["audio_adaptor_out_lens"] = encoder_out_lens
+        meta_data["encoder_out"] = encoder_out
+        meta_data["encoder_out_lens"] = encoder_out_lens
+        meta_data["audio_adaptor_out"] = adaptor_out
+        meta_data["audio_adaptor_out_lens"] = adaptor_out_lens

         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
@@ -520,7 +531,7 @@ class FunASRNano(nn.Module):
                 fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
                 if fbank_beg_idx > 0:
                     speech_token_len = fake_token_len[batch_idx, turn_id]
-                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                    speech_token = adaptor_out[speech_idx, :speech_token_len, :]

                     try:
                         inputs_embeds[
@@ -532,10 +543,10 @@ class FunASRNano(nn.Module):
                         #
                         logging.error(f"{str(e)}, {traceback.format_exc()}")
                         logging.info(
-                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                         )
-                        speech_token_len = encoder_out_lens[speech_idx].item()
-                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                        speech_token_len = adaptor_out_lens[speech_idx].item()
+                        speech_token = adaptor_out[speech_idx, :speech_token_len, :]
                         inputs_embeds[
                             batch_idx,
                             fbank_beg_idx : fbank_beg_idx + speech_token_len,
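The two hunks above splice the audio adaptor's output into the LLM input embeddings at the fbank placeholder positions, falling back to the adaptor's true output length when the expected length mismatches. A toy sketch of that splice, with illustrative shapes rather than the model's real dimensions:

import torch

# Toy shapes: batch 1, sequence of 8 embedding slots, hidden size 4.
inputs_embeds = torch.zeros(1, 8, 4)
speech_token = torch.ones(3, 4)  # 3 adaptor frames of embeddings
fbank_beg_idx = 2                # placeholder start position in the sequence
inputs_embeds[0, fbank_beg_idx : fbank_beg_idx + speech_token.shape[0], :] = speech_token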
@@ -627,6 +638,29 @@ class FunASRNano(nn.Module):
         inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
             data_in, data_lengths, key, tokenizer, frontend, **kwargs
         )
+
+        ctc_results = []
+        if self.ctc_decoder is not None:
+            encoder_out = meta_data["encoder_out"]
+            encoder_out_lens = meta_data["encoder_out_lens"]
+            decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
+            ctc_logits = self.ctc.log_softmax(decoder_out)
+            b, n, d = encoder_out.size()
+            if isinstance(key[0], (list, tuple)):
+                key = key[0]
+            if len(key) < b:
+                key = key * b
+
+            for i in range(b):
+                x = ctc_logits[i, : encoder_out_lens[i].item(), :]
+                yseq = x.argmax(dim=-1)
+                yseq = torch.unique_consecutive(yseq, dim=-1)
+                mask = yseq != self.blank_id
+                token_int = yseq[mask].tolist()
+                # Change integer ids to tokens
+                text = self.ctc_tokenizer.decode(token_int)
+                ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
+
         llm_dtype = kwargs.get("llm_dtype", "fp32")
         if llm_dtype == "fp32":
             llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
@@ -692,6 +726,18 @@ class FunASRNano(nn.Module):
                 result_i["loss"] = loss
             results.append(result_i)

+        for ctc_result, result in zip(ctc_results, results):
+            result["ctc_text"] = ctc_result["text"]
+            target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
+            result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
+            result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
+                for timestamp in timestamps:
+                    timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
+                    timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
+                    timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
+
         if ibest_writer is not None:
             ibest_writer["text"][key[0]] = response.replace("\n", " ")
             ibest_writer["label"][key[0]] = label.replace("\n", " ")


@@ -6,3 +6,4 @@ zhconv
 whisper_normalizer
 pyopenjtalk-plus
 compute-wer
+openai-whisper

tools/utils.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+from itertools import groupby
+
+import torch
+import torchaudio.functional as F
+
+
+def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
+    items = []
+    try:
+        # torchaudio's forced_align currently only supports batch_size == 1.
+        log_probs, targets = log_probs.unsqueeze(0).cpu(), targets.unsqueeze(0).cpu()
+        assert log_probs.shape[1] >= targets.shape[1]
+        alignments, scores = F.forced_align(log_probs, targets, blank=blank)
+        alignments, scores = alignments[0], torch.exp(scores[0]).tolist()
+        # Use enumerate to keep track of the original frame indices, then group by token id.
+        for token, group in groupby(enumerate(alignments), key=lambda item: item[1]):
+            if token == blank:
+                continue
+            group = list(group)
+            start = group[0][0]
+            end = start + len(group)
+            score = max(scores[start:end])
+            items.append(
+                {
+                    "token": token.item(),
+                    "start_time": start,
+                    "end_time": end,
+                    "score": round(score, 3),
+                }
+            )
+    except Exception:
+        pass
+    return items
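For reference, a minimal usage sketch of this helper on dummy inputs. The shapes follow the function's own contract (unbatched (T, C) log probabilities and 1-D target ids that exclude the blank), F.forced_align is available in torchaudio 2.1 and later, and the concrete numbers and printed output are illustrative only:

import torch
from tools.utils import forced_align

torch.manual_seed(0)
log_probs = torch.randn(20, 5).log_softmax(dim=-1)    # (T=20 frames, C=5 classes)
targets = torch.tensor([1, 2, 3], dtype=torch.int64)  # token ids; 0 is the blank
for item in forced_align(log_probs, targets, blank=0):
    print(item)  # e.g. {'token': 1, 'start_time': 0, 'end_time': 4, 'score': 0.31}

Note that start_time and end_time here are raw CTC frame indices; the inference code above is what scales them to seconds.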