diff --git a/model.py b/model.py
index 82d9f19..e8f3810 100644
--- a/model.py
+++ b/model.py
@@ -17,6 +17,7 @@
 from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 from transformers import AutoConfig, AutoModelForCausalLM
 from ctc import CTC
+from tools.utils import forced_align
 
 dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
 
@@ -108,6 +109,14 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
+            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
+            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
+                ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
+                ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
+            self.ctc_tokenizer = ctc_tokenizer
+            assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
+
             ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
             ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
             if audio_encoder_output_size > 0:
@@ -492,11 +501,13 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
         # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(
             encoder_out, encoder_out_lens
         )
-        meta_data["audio_adaptor_out"] = encoder_out
-        meta_data["audio_adaptor_out_lens"] = encoder_out_lens
+        meta_data["encoder_out"] = encoder_out
+        meta_data["encoder_out_lens"] = encoder_out_lens
+        meta_data["audio_adaptor_out"] = adaptor_out
+        meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
 
         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
@@ -520,7 +531,7 @@ class FunASRNano(nn.Module):
                 fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
                 if fbank_beg_idx > 0:
                     speech_token_len = fake_token_len[batch_idx, turn_id]
-                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                    speech_token = adaptor_out[speech_idx, :speech_token_len, :]
 
                     try:
                         inputs_embeds[
@@ -532,10 +543,10 @@ class FunASRNano(nn.Module):
 
                         # logging.error(f"{str(e)}, {traceback.format_exc()}")
                         logging.info(
-                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                         )
-                        speech_token_len = encoder_out_lens[speech_idx].item()
-                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                        speech_token_len = adaptor_out_lens[speech_idx].item()
+                        speech_token = adaptor_out[speech_idx, :speech_token_len, :]
                         inputs_embeds[
                             batch_idx,
                             fbank_beg_idx : fbank_beg_idx + speech_token_len,
@@ -627,6 +638,29 @@ class FunASRNano(nn.Module):
         inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
             data_in, data_lengths, key, tokenizer, frontend, **kwargs
         )
+
+        ctc_results = []
+        if self.ctc_decoder is not None:
+            encoder_out = meta_data["encoder_out"]
+            encoder_out_lens = meta_data["encoder_out_lens"]
+            decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
+            ctc_logits = self.ctc.log_softmax(decoder_out)
+
+            b, n, d = encoder_out.size()
+            if isinstance(key[0], (list, tuple)):
+                key = key[0]
+            if len(key) < b:
+                key = key * b
+            for i in range(b):
+                x = ctc_logits[i, : encoder_out_lens[i].item(), :]
+                yseq = x.argmax(dim=-1)
+                yseq = torch.unique_consecutive(yseq, dim=-1)
+                mask = yseq != self.blank_id
+                token_int = yseq[mask].tolist()
+                # Change integer-ids to tokens
+                text = self.ctc_tokenizer.decode(token_int)
+                ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
+
         llm_dtype = kwargs.get("llm_dtype", "fp32")
         if llm_dtype == "fp32":
             llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
@@ -692,6 +726,18 @@ class FunASRNano(nn.Module):
             result_i["loss"] = loss
         results.append(result_i)
 
+        for ctc_result, result in zip(ctc_results, results):
+            result["ctc_text"] = ctc_result["text"]
+            target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
+            result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
+            result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
+                for timestamp in timestamps:
+                    timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
+                    timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
+                    timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
+
         if ibest_writer is not None:
             ibest_writer["text"][key[0]] = response.replace("\n", " ")
             ibest_writer["label"][key[0]] = label.replace("\n", " ")
diff --git a/requirements.txt b/requirements.txt
index d8a6816..ca7800a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ zhconv
 whisper_normalizer
 pyopenjtalk-plus
 compute-wer
+openai-whisper
diff --git a/tools/utils.py b/tools/utils.py
new file mode 100644
index 0000000..7526a68
--- /dev/null
+++ b/tools/utils.py
@@ -0,0 +1,33 @@
+from itertools import groupby
+
+import torch
+import torchaudio.functional as F
+
+
+def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
+    items = []
+    try:
+        # The current version only supports batch_size==1.
+        log_probs, targets = log_probs.unsqueeze(0).cpu(), targets.unsqueeze(0).cpu()
+        assert log_probs.shape[1] >= targets.shape[1]
+        alignments, scores = F.forced_align(log_probs, targets, blank=blank)
+        alignments, scores = alignments[0], torch.exp(scores[0]).tolist()
+        # use enumerate to keep track of the original indices, then group by token value
+        for token, group in groupby(enumerate(alignments), key=lambda item: item[1]):
+            if token == blank:
+                continue
+            group = list(group)
+            start = group[0][0]
+            end = start + len(group)
+            score = max(scores[start:end])
+            items.append(
+                {
+                    "token": token.item(),
+                    "start_time": start,
+                    "end_time": end,
+                    "score": round(score, 3),
+                }
+            )
+    except:
+        pass
+    return items