add streaming usage
This commit is contained in:
18
demo2.py
18
demo2.py
@ -1,6 +1,9 @@
|
|||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from model import FunASRNano
|
from model import FunASRNano
|
||||||
|
from tools.utils import load_audio
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -13,13 +16,26 @@ def main():
|
|||||||
else "cpu"
|
else "cpu"
|
||||||
)
|
)
|
||||||
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
|
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
|
||||||
|
tokenizer = kwargs.get("tokenizer", None)
|
||||||
m.eval()
|
m.eval()
|
||||||
|
|
||||||
wav_path = f"{kwargs['model_path']}/example/zh.mp3"
|
wav_path = f"{kwargs['model_path']}/example/zh.mp3"
|
||||||
res = m.inference(data_in=[wav_path], **kwargs)
|
res = m.inference(data_in=[wav_path], **kwargs)
|
||||||
text = res[0][0]["text"]
|
text = res[0][0]
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
|
chunk_size = 0.72
|
||||||
|
duration = sf.info(wav_path).duration
|
||||||
|
cum_durations = np.arange(chunk_size, duration + chunk_size, chunk_size)
|
||||||
|
prev_text = ""
|
||||||
|
for idx, cum_duration in enumerate(cum_durations):
|
||||||
|
audio, rate = load_audio(wav_path, 16000, duration=round(cum_duration, 3))
|
||||||
|
prev_text = m.inference([torch.tensor(audio)], prev_text=prev_text, **kwargs)[0][0]["text"]
|
||||||
|
if idx != len(cum_durations) - 1:
|
||||||
|
prev_text = tokenizer.decode(tokenizer.encode(prev_text)[:-5]).replace("<EFBFBD>", "")
|
||||||
|
if prev_text:
|
||||||
|
print(prev_text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
5
model.py
5
model.py
@ -354,6 +354,8 @@ class FunASRNano(nn.Module):
|
|||||||
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
if not do_think:
|
if not do_think:
|
||||||
source_input += "<think>\n\n</think>\n\n"
|
source_input += "<think>\n\n</think>\n\n"
|
||||||
|
if kwargs.get("prev_text", None) is not None:
|
||||||
|
source_input += kwargs["prev_text"]
|
||||||
|
|
||||||
splits = pattern.split(source_input)
|
splits = pattern.split(source_input)
|
||||||
source_ids = []
|
source_ids = []
|
||||||
@ -514,7 +516,7 @@ class FunASRNano(nn.Module):
|
|||||||
fbank_beg = batch["fbank_beg"]
|
fbank_beg = batch["fbank_beg"]
|
||||||
fake_token_len = batch["fake_token_len"]
|
fake_token_len = batch["fake_token_len"]
|
||||||
|
|
||||||
if not kwargs.get("tearchforing", False):
|
if not kwargs.get("teacherforcing", False):
|
||||||
input_ids = source_ids
|
input_ids = source_ids
|
||||||
|
|
||||||
input_ids[input_ids < 0] = 0
|
input_ids[input_ids < 0] = 0
|
||||||
@ -698,6 +700,7 @@ class FunASRNano(nn.Module):
|
|||||||
skip_special_tokens=kwargs.get("skip_special_tokens", True),
|
skip_special_tokens=kwargs.get("skip_special_tokens", True),
|
||||||
)[0]
|
)[0]
|
||||||
loss = model_outputs.loss.item()
|
loss = model_outputs.loss.item()
|
||||||
|
response = kwargs.get("prev_text", "") + response
|
||||||
|
|
||||||
ibest_writer = None
|
ibest_writer = None
|
||||||
if kwargs.get("output_dir") is not None:
|
if kwargs.get("output_dir") is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user