diff --git a/demo2.py b/demo2.py index f161d4d..5161968 100644 --- a/demo2.py +++ b/demo2.py @@ -1,6 +1,9 @@ +import numpy as np +import soundfile as sf import torch from model import FunASRNano +from tools.utils import load_audio def main(): @@ -13,13 +16,26 @@ def main(): else "cpu" ) m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device) + tokenizer = kwargs.get("tokenizer", None) m.eval() wav_path = f"{kwargs['model_path']}/example/zh.mp3" res = m.inference(data_in=[wav_path], **kwargs) - text = res[0][0]["text"] + text = res[0][0] print(text) + chunk_size = 0.72 + duration = sf.info(wav_path).duration + cum_durations = np.arange(chunk_size, duration + chunk_size, chunk_size) + prev_text = "" + for idx, cum_duration in enumerate(cum_durations): + audio, rate = load_audio(wav_path, 16000, duration=round(cum_duration, 3)) + prev_text = m.inference([torch.tensor(audio)], prev_text=prev_text, **kwargs)[0][0]["text"] + if idx != len(cum_durations) - 1: + prev_text = tokenizer.decode(tokenizer.encode(prev_text)[:-5]).replace("�", "") + if prev_text: + print(prev_text) + if __name__ == "__main__": main() diff --git a/model.py b/model.py index 8ac2c3f..fcf7720 100644 --- a/model.py +++ b/model.py @@ -354,6 +354,8 @@ class FunASRNano(nn.Module): source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" if not do_think: source_input += "\n\n\n\n" + if kwargs.get("prev_text", None) is not None: + source_input += kwargs["prev_text"] splits = pattern.split(source_input) source_ids = [] @@ -514,7 +516,7 @@ class FunASRNano(nn.Module): fbank_beg = batch["fbank_beg"] fake_token_len = batch["fake_token_len"] - if not kwargs.get("tearchforing", False): + if not kwargs.get("teacherforcing", False): input_ids = source_ids input_ids[input_ids < 0] = 0 @@ -698,6 +700,7 @@ class FunASRNano(nn.Module): skip_special_tokens=kwargs.get("skip_special_tokens", True), )[0] loss = model_outputs.loss.item() + response = kwargs.get("prev_text", "") + response ibest_writer = None if kwargs.get("output_dir") is not None: