update

Re-wraps model.py for a longer maximum line length, replaces the PEP 604 annotation "str | torch.Tensor" with "Union[str, torch.Tensor]", moves the module-level "from funasr import AutoModel" into the hub == "ms" branch, and adds an empty tools/__init__.py.

model.py (102 lines changed)
@@ -5,10 +5,11 @@ import re
 import string
 import time
 import traceback
+from typing import Union
+
 import torch
 import torch.nn as nn
-from funasr import AutoModel

 from funasr.metrics.compute_acc import compute_accuracy
 from funasr.register import tables
 from funasr.train_utils.device_funcs import force_gatherable, to_device
@@ -44,6 +45,8 @@ class FunASRNano(nn.Module):
             "activation_checkpoint", False
         )
         if hub == "ms":
+            from funasr import AutoModel
+
             model = AutoModel(model=audio_encoder, model_revision="master")
             audio_encoder_output_size = (
                 model.model.encoder_output_size
@@ -51,9 +54,7 @@ class FunASRNano(nn.Module):
                 else -1
             )
             audio_encoder = (
-                model.model.model.encoder
-                if hasattr(model.model, "model")
-                else model.model.encoder
+                model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
             )
         else:
             encoder_class = tables.encoder_classes.get(audio_encoder)
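
Note: the AutoModel import is removed at module scope (first hunk) and re-added inside the hub == "ms" branch (second hunk), so funasr is only imported when the ModelScope-hub path is actually taken. The commit does not state the motivation, but a deferred import like this is the usual way to avoid a circular import when the module is itself loaded through funasr's trust_remote_code machinery. A minimal standalone sketch of the pattern, with a hypothetical build_encoder wrapper:

    def build_encoder(audio_encoder: str, hub: str):
        if hub == "ms":
            from funasr import AutoModel  # imported only on this branch

            return AutoModel(model=audio_encoder, model_revision="master")
        raise NotImplementedError(hub)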
@@ -109,8 +110,16 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
-            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
-            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            ctc_tokenizer = (
+                kwargs.get("ctc_tokenizer", None)
+                if "ctc_tokenizer" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer"]
+            )
+            ctc_tokenizer_conf = (
+                kwargs.get("ctc_tokenizer_conf", None)
+                if "ctc_tokenizer_conf" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            )
             if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
                 ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
                 ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
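
Note: the re-wrapped conditionals implement a plain override-then-fallback lookup: take ctc_tokenizer from kwargs if the caller passed it, otherwise read it from dataset_conf. A self-contained sketch with made-up config values:

    kwargs = {"dataset_conf": {"ctc_tokenizer": "CharTokenizer", "ctc_tokenizer_conf": {}}}
    ctc_tokenizer = (
        kwargs.get("ctc_tokenizer", None)
        if "ctc_tokenizer" in kwargs
        else kwargs["dataset_conf"]["ctc_tokenizer"]
    )
    print(ctc_tokenizer)  # "CharTokenizer", taken from the dataset_conf fallback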
@@ -126,9 +135,7 @@ class FunASRNano(nn.Module):
             if init_param_path is not None:
                 src_state = torch.load(init_param_path, map_location="cpu")
                 flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-                logging.info(
-                    f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
-                )
+                logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
             freeze = ctc_decoder_conf.get("freeze", False)
             if freeze:
                 for _, param in self.ctc_decoder.named_parameters():
@@ -182,9 +189,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

         # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)

         batch_size, token_num, dims = inputs_embeds.shape
         fake_token_len = kwargs.get("fake_token_len")
@@ -223,9 +228,7 @@ class FunASRNano(nn.Module):
         stats["batch_size_speech"] = batch_size_speech
         stats["batch_size_x_frames"] = frames * batch_size_speech
         stats["batch_size_real_frames"] = speech_lengths.sum().item()
-        stats["padding_frames"] = (
-            stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
-        )
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]

         device_type = next(self.parameters()).device.type
         with torch.autocast(
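
Note: assuming frames is the padded (maximum) frame count of the batch, the collapsed line computes wasted padding by simple arithmetic: with two utterances of 80 and 100 frames padded to 100, batch_size_x_frames = 2 * 100 = 200, batch_size_real_frames = 180, so padding_frames = 20.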
@@ -244,9 +247,7 @@ class FunASRNano(nn.Module):

         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
-            acc_att = compute_accuracy(
-                preds[:, :-1], labels_ids[:, 1:], ignore_label=-100
-            )
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
             stats["acc"] = acc_att

         stats["loss"] = torch.clone(loss.detach())
@@ -254,9 +255,7 @@ class FunASRNano(nn.Module):

         stats["batch_size_x_tokens"] = token_num * batch_size
         stats["batch_size_real_tokens"] = attention_mask.sum().item()
-        stats["padding_tokens"] = (
-            stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
-        )
+        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]

         dialog_turns = (fbank_beg > 0).sum(-1)
         dialog_turns_max = torch.max(dialog_turns).int().item()
@@ -306,9 +305,7 @@ class FunASRNano(nn.Module):

         return contents

-    def data_load_speech(
-        self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs
-    ):
+    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
         system = contents["system"]
         user = contents["user"]
         assistant = contents["assistant"]
@@ -329,9 +326,7 @@ class FunASRNano(nn.Module):
             [],
         )
         input_source_ids = []
-        for i, (system_prompt, user_prompt, target_out) in enumerate(
-            zip(system, user, assistant)
-        ):
+        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
             if i >= kwargs.get("multiturn_num_max", 5):
                 break
             if len(input_ids) > kwargs.get("max_token_length", 1500):
@@ -346,12 +341,16 @@ class FunASRNano(nn.Module):
                 else:
                     source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                 if not sys_prompt:
-                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    source_input = (
+                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    )
             else:
                 if kwargs.get("infer_with_assistant_input", False):
                     source_input = f"<|im_start|>user\n{user_prompt}"
                 else:
-                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    source_input = (
+                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    )
                 if not do_think:
                     source_input += "<think>\n\n</think>\n\n"
             if kwargs.get("prev_text", None) is not None:
@@ -384,9 +383,7 @@ class FunASRNano(nn.Module):
                     time2 = time.perf_counter()
                     meta_data["load_data"] = f"{time2 - time1:0.3f}"
                 except Exception as e:
-                    logging.error(
-                        f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
-                    )
+                    logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")

                 speech, speech_lengths = extract_fbank(
                     data_src,
@@ -428,9 +425,7 @@ class FunASRNano(nn.Module):
                 fbank.append(speech[0, :, :])
                 fbank_lens.append(speech_lengths)

-        input_ids = torch.tensor(
-            input_ids, dtype=torch.int64
-        )  # [: self.max_token_length]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]

@@ -441,9 +436,7 @@ class FunASRNano(nn.Module):
         target_ids = torch.tensor(target_ids, dtype=torch.int64)

         if len(fbank) > 0:
-            speech = torch.nn.utils.rnn.pad_sequence(
-                fbank, batch_first=True, padding_value=0.0
-            )
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
             speech_lengths = torch.nn.utils.rnn.pad_sequence(
                 fbank_lens, batch_first=True, padding_value=-1
             )
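
Note: pad_sequence stacks the variable-length (T_i, D) fbank tensors into one (B, T_max, D) batch. A minimal runnable illustration with synthetic shapes:

    import torch

    fbank = [torch.randn(5, 80), torch.randn(3, 80)]  # two fbank features, 80 mel bins
    speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
    print(speech.shape)  # torch.Size([2, 5, 80]); the shorter one is zero-padded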
@@ -480,9 +473,7 @@ class FunASRNano(nn.Module):
             raise NotImplementedError("batch decoding is not implemented")

         contents = self.data_template(data_in[0])
-        output = self.data_load_speech(
-            contents, tokenizer, frontend, meta_data=meta_data, **kwargs
-        )
+        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
         batch = to_device(output, kwargs["device"])

         # audio encoder
@@ -503,9 +494,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

         # audio_adaptor
-        adaptor_out, adaptor_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
         meta_data["encoder_out"] = encoder_out
         meta_data["encoder_out_lens"] = encoder_out_lens
         meta_data["audio_adaptor_out"] = adaptor_out
@@ -573,7 +562,7 @@ class FunASRNano(nn.Module):
             prompt += ",不进行文本规整"
         return prompt + ":"

-    def generate_chatml(self, prompt: str, data: str | torch.Tensor):
+    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
         if isinstance(data, str):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
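
Note: this is the one change in the commit that is not pure re-wrapping. The PEP 604 union str | torch.Tensor in an annotation is a syntax error on Python < 3.10, while typing.Union (now imported at the top of the file) spells the same type and parses on older interpreters. A minimal demonstration, with a hypothetical function name:

    from typing import Union

    import torch

    def describe(data: Union[str, torch.Tensor]) -> str:  # also valid on Python 3.8/3.9
        return "tensor" if isinstance(data, torch.Tensor) else "text"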
@@ -583,11 +572,14 @@ class FunASRNano(nn.Module):
         elif isinstance(data, torch.Tensor):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
+                {
+                    "role": "user",
+                    "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
+                    "audio": data,
+                },
                 {"role": "assistant", "content": "null"},
             ]

-
     def inference(
         self,
         data_in,
@@ -597,7 +589,9 @@ class FunASRNano(nn.Module):
         frontend=None,
         **kwargs,
     ):
-        prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
+        prompt = self.get_prompt(
+            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
+        )
         data_in = [self.generate_chatml(prompt, data) for data in data_in]

         if key is None:
@@ -722,8 +716,12 @@ class FunASRNano(nn.Module):

         for ctc_result, result in zip(ctc_results, results):
             result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
-            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
-            result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            target_ids = torch.tensor(
+                self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
+            )
+            result["ctc_timestamps"] = forced_align(
+                ctc_result["ctc_logits"], target_ids, self.blank_id
+            )
             target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
             result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
             for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
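
Note: forced_align here is the repository's own helper; beyond the (ctc_logits, target_ids, blank_id) call shape, its signature is not visible in this diff. For reference only, torchaudio ships an analogous public CTC forced-alignment API that aligns a known transcript against frame-level emissions; a hedged sketch (not the repo helper, synthetic inputs):

    import torch
    import torchaudio.functional as F

    log_probs = torch.randn(1, 50, 32).log_softmax(-1)       # (batch=1, frames, vocab)
    targets = torch.tensor([[7, 3, 11]], dtype=torch.int32)  # token ids, no blanks
    alignments, scores = F.forced_align(log_probs, targets, blank=0)
    # alignments labels every frame with a token id or the blank, which is
    # the raw material for per-token start/end timestamps.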
@@ -743,8 +741,6 @@ class FunASRNano(nn.Module):
     def from_pretrained(model: str = None, **kwargs):
         from funasr import AutoModel

-        model, kwargs = AutoModel.build_model(
-            model=model, trust_remote_code=True, **kwargs
-        )
+        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

         return model, kwargs
tools/__init__.py (new file, 0 lines)
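
Usage note: per from_pretrained above, the model is built through funasr's AutoModel entry point with trust_remote_code enabled, and build_model returns a (model, kwargs) pair. A sketch, with "<model_name_or_path>" as a placeholder:

    from funasr import AutoModel

    model, kwargs = AutoModel.build_model(model="<model_name_or_path>", trust_remote_code=True)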