add forced align
This commit is contained in:
60
model.py
60
model.py
@ -17,6 +17,7 @@ from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from ctc import CTC
|
||||
from tools.utils import forced_align
|
||||
|
||||
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||
|
||||
@ -108,6 +109,14 @@ class FunASRNano(nn.Module):
|
||||
# TODO: fix table name
|
||||
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
|
||||
if ctc_decoder_class is not None:
|
||||
ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
|
||||
ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
|
||||
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
|
||||
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
|
||||
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
|
||||
self.ctc_tokenizer = ctc_tokenizer
|
||||
assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
|
||||
|
||||
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
|
||||
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
|
||||
if audio_encoder_output_size > 0:
|
||||
@ -492,11 +501,13 @@ class FunASRNano(nn.Module):
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(
|
||||
adaptor_out, adaptor_out_lens = self.audio_adaptor(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
meta_data["audio_adaptor_out"] = encoder_out
|
||||
meta_data["audio_adaptor_out_lens"] = encoder_out_lens
|
||||
meta_data["encoder_out"] = encoder_out
|
||||
meta_data["encoder_out_lens"] = encoder_out_lens
|
||||
meta_data["audio_adaptor_out"] = adaptor_out
|
||||
meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
source_ids = batch["source_ids"]
|
||||
@ -520,7 +531,7 @@ class FunASRNano(nn.Module):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
@ -532,10 +543,10 @@ class FunASRNano(nn.Module):
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
speech_token_len = adaptor_out_lens[speech_idx].item()
|
||||
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx,
|
||||
fbank_beg_idx : fbank_beg_idx + speech_token_len,
|
||||
@ -627,6 +638,29 @@ class FunASRNano(nn.Module):
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
||||
)
|
||||
|
||||
ctc_results = []
|
||||
if self.ctc_decoder is not None:
|
||||
encoder_out = meta_data["encoder_out"]
|
||||
encoder_out_lens = meta_data["encoder_out_lens"]
|
||||
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
|
||||
ctc_logits = self.ctc.log_softmax(decoder_out)
|
||||
|
||||
b, n, d = encoder_out.size()
|
||||
if isinstance(key[0], (list, tuple)):
|
||||
key = key[0]
|
||||
if len(key) < b:
|
||||
key = key * b
|
||||
for i in range(b):
|
||||
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
|
||||
yseq = x.argmax(dim=-1)
|
||||
yseq = torch.unique_consecutive(yseq, dim=-1)
|
||||
mask = yseq != self.blank_id
|
||||
token_int = yseq[mask].tolist()
|
||||
# Change integer-ids to tokens
|
||||
text = self.ctc_tokenizer.decode(token_int)
|
||||
ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
|
||||
|
||||
llm_dtype = kwargs.get("llm_dtype", "fp32")
|
||||
if llm_dtype == "fp32":
|
||||
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
|
||||
@ -692,6 +726,18 @@ class FunASRNano(nn.Module):
|
||||
result_i["loss"] = loss
|
||||
results.append(result_i)
|
||||
|
||||
for ctc_result, result in zip(ctc_results, results):
|
||||
result["ctc_text"] = ctc_result["text"]
|
||||
target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
|
||||
result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
|
||||
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
|
||||
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
|
||||
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
|
||||
for timestamp in timestamps:
|
||||
timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
|
||||
timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
|
||||
timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["text"][key[0]] = response.replace("\n", " ")
|
||||
ibest_writer["label"][key[0]] = label.replace("\n", " ")
|
||||
|
||||
Reference in New Issue
Block a user