add forced align
model.py (60 changed lines)
@@ -17,6 +17,7 @@ from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 from transformers import AutoConfig, AutoModelForCausalLM

 from ctc import CTC
+from tools.utils import forced_align

 dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

@@ -108,6 +109,14 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
+            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
+            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
+                ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
+                ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
+            self.ctc_tokenizer = ctc_tokenizer
+            assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
+
             ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
             ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
             if audio_encoder_output_size > 0:
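Note: for orientation only, a hypothetical set of kwargs that would exercise the new tokenizer branch might look like the sketch below. The key names come from the hunk above; the class names, vocab size, and tokenizer settings are placeholders, not values taken from this commit.

```python
# Hypothetical config sketch (class names and paths are placeholders, not from this commit)
kwargs = {
    "ctc_decoder": "Linear",                     # looked up in tables.adaptor_classes (name assumed)
    "ctc_decoder_conf": {},
    "ctc_vocab_size": 60515,
    "ctc_tokenizer": "SentencepiecesTokenizer",  # looked up in tables.tokenizer_classes (name assumed)
    "ctc_tokenizer_conf": {"bpemodel": "path/to/bpe.model"},  # assumed conf keys
}
```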
@@ -492,11 +501,13 @@ class FunASRNano(nn.Module):
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # audio_adaptor
-       encoder_out, encoder_out_lens = self.audio_adaptor(
+       adaptor_out, adaptor_out_lens = self.audio_adaptor(
            encoder_out, encoder_out_lens
        )
-       meta_data["audio_adaptor_out"] = encoder_out
-       meta_data["audio_adaptor_out_lens"] = encoder_out_lens
+       meta_data["encoder_out"] = encoder_out
+       meta_data["encoder_out_lens"] = encoder_out_lens
+       meta_data["audio_adaptor_out"] = adaptor_out
+       meta_data["audio_adaptor_out_lens"] = adaptor_out_lens

        input_ids = batch["input_ids"]
        source_ids = batch["source_ids"]
@@ -520,7 +531,7 @@ class FunASRNano(nn.Module):
                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
                if fbank_beg_idx > 0:
                    speech_token_len = fake_token_len[batch_idx, turn_id]
-                   speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                   speech_token = adaptor_out[speech_idx, :speech_token_len, :]

                    try:
                        inputs_embeds[
@@ -532,10 +543,10 @@ class FunASRNano(nn.Module):
                        #
                        logging.error(f"{str(e)}, {traceback.format_exc()}")
                        logging.info(
-                           f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                           f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                        )
-                       speech_token_len = encoder_out_lens[speech_idx].item()
-                       speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                       speech_token_len = adaptor_out_lens[speech_idx].item()
+                       speech_token = adaptor_out[speech_idx, :speech_token_len, :]
                        inputs_embeds[
                            batch_idx,
                            fbank_beg_idx : fbank_beg_idx + speech_token_len,
@@ -627,6 +638,29 @@ class FunASRNano(nn.Module):
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )
+
+       ctc_results = []
+       if self.ctc_decoder is not None:
+           encoder_out = meta_data["encoder_out"]
+           encoder_out_lens = meta_data["encoder_out_lens"]
+           decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
+           ctc_logits = self.ctc.log_softmax(decoder_out)
+
+           b, n, d = encoder_out.size()
+           if isinstance(key[0], (list, tuple)):
+               key = key[0]
+           if len(key) < b:
+               key = key * b
+           for i in range(b):
+               x = ctc_logits[i, : encoder_out_lens[i].item(), :]
+               yseq = x.argmax(dim=-1)
+               yseq = torch.unique_consecutive(yseq, dim=-1)
+               mask = yseq != self.blank_id
+               token_int = yseq[mask].tolist()
+               # Change integer-ids to tokens
+               text = self.ctc_tokenizer.decode(token_int)
+               ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
+
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":
            llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
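Note: the argmax / unique_consecutive / blank-mask sequence added above is the standard greedy CTC collapse. A minimal, self-contained illustration with made-up frame predictions, assuming the blank id is 0:

```python
import torch

# Made-up frame-level argmax output; 0 plays the role of self.blank_id here.
yseq = torch.tensor([0, 0, 7, 7, 7, 0, 12, 12, 0, 7])
yseq = torch.unique_consecutive(yseq, dim=-1)  # tensor([ 0,  7,  0, 12,  0,  7]) -- merge repeats
token_int = yseq[yseq != 0].tolist()           # [7, 12, 7] -- drop blanks, then pass to decode()
```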
@@ -692,6 +726,18 @@ class FunASRNano(nn.Module):
            result_i["loss"] = loss
            results.append(result_i)

+       for ctc_result, result in zip(ctc_results, results):
+           result["ctc_text"] = ctc_result["text"]
+           target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
+           result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+           target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
+           result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+           for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
+               for timestamp in timestamps:
+                   timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
+                   timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
+                   timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
+
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = response.replace("\n", " ")
            ibest_writer["label"][key[0]] = label.replace("\n", " ")
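Note: the `* 6 * 10 / 1000` factor appears to assume 10 ms fbank frames and a 6x-subsampled CTC output, i.e. 60 ms per alignment index; that subsampling factor is inferred from the code, not stated in the commit. A small worked example under that assumption:

```python
# Assumed: 10 ms frames, 6x subsampling -> 60 ms per CTC frame (inferred, not stated in the commit).
start_frame, end_frame = 25, 31            # indices as returned by forced_align
start_time = start_frame * 6 * 10 / 1000   # 1.5 seconds
end_time = end_frame * 6 * 10 / 1000       # 1.86 seconds
```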
@@ -6,3 +6,4 @@ zhconv
 whisper_normalizer
 pyopenjtalk-plus
 compute-wer
+openai-whisper
tools/utils.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+from itertools import groupby
+
+import torch
+import torchaudio.functional as F
+
+
+def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
+    items = []
+    try:
+        # The current version only supports batch_size==1.
+        log_probs, targets = log_probs.unsqueeze(0).cpu(), targets.unsqueeze(0).cpu()
+        assert log_probs.shape[1] >= targets.shape[1]
+        alignments, scores = F.forced_align(log_probs, targets, blank=blank)
+        alignments, scores = alignments[0], torch.exp(scores[0]).tolist()
+        # use enumerate to keep track of the original indices, then group by token value
+        for token, group in groupby(enumerate(alignments), key=lambda item: item[1]):
+            if token == blank:
+                continue
+            group = list(group)
+            start = group[0][0]
+            end = start + len(group)
+            score = max(scores[start:end])
+            items.append(
+                {
+                    "token": token.item(),
+                    "start_time": start,
+                    "end_time": end,
+                    "score": round(score, 3),
+                }
+            )
+    except:
+        pass
+    return items
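Note: a minimal usage sketch of the new helper, with random log-probabilities standing in for real CTC output (the underlying forced alignment comes from torchaudio's functional API; the shapes and ids below are made up):

```python
import torch
from tools.utils import forced_align

T, V = 50, 60515                                              # frames, vocab size (made-up values)
log_probs = torch.randn(T, V).log_softmax(dim=-1)             # (T, V) frame-level log-probs
target_ids = torch.tensor([12, 7, 431], dtype=torch.int64)    # token ids from the CTC tokenizer

for item in forced_align(log_probs, target_ids, blank=0):
    # e.g. {"token": 12, "start_time": 3, "end_time": 5, "score": 0.97}
    print(item)
```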