update

model.py (102 changed lines)
@@ -5,10 +5,11 @@ import re
 import string
 import time
 import traceback
+from typing import Union
 
 import torch
 import torch.nn as nn
-from funasr import AutoModel
 from funasr.metrics.compute_acc import compute_accuracy
 from funasr.register import tables
 from funasr.train_utils.device_funcs import force_gatherable, to_device
@@ -44,6 +45,8 @@ class FunASRNano(nn.Module):
             "activation_checkpoint", False
         )
         if hub == "ms":
+            from funasr import AutoModel
+
             model = AutoModel(model=audio_encoder, model_revision="master")
             audio_encoder_output_size = (
                 model.model.encoder_output_size
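
Note: together with the first hunk, this moves "from funasr import AutoModel" from module scope into the hub == "ms" branch. Deferring an import this way is the standard remedy when the importing module is itself part of the imported package and a circular import looms; the commit does not state the motivation, so that reading is inferred. A minimal sketch of the pattern, with a hypothetical helper name:

    def build_audio_encoder(audio_encoder: str, hub: str):
        if hub == "ms":
            # Imported only when this branch executes, so importing this
            # module never has to pull in the whole funasr package first.
            from funasr import AutoModel

            return AutoModel(model=audio_encoder, model_revision="master")
        raise ValueError(f"unsupported hub: {hub}")
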
@@ -51,9 +54,7 @@ class FunASRNano(nn.Module):
                 else -1
             )
             audio_encoder = (
-                model.model.model.encoder
-                if hasattr(model.model, "model")
-                else model.model.encoder
+                model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
             )
         else:
             encoder_class = tables.encoder_classes.get(audio_encoder)
@@ -109,8 +110,16 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
-            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
-            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            ctc_tokenizer = (
+                kwargs.get("ctc_tokenizer", None)
+                if "ctc_tokenizer" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer"]
+            )
+            ctc_tokenizer_conf = (
+                kwargs.get("ctc_tokenizer_conf", None)
+                if "ctc_tokenizer_conf" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            )
             if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
                 ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
                 ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
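
Note: the two rewrapped conditional expressions encode one lookup rule, now just split across lines: prefer an explicit keyword argument, otherwise fall back to the nested dataset_conf entry. A sketch of that resolution order as a helper (resolve_conf and the sample values are illustrative, not part of the codebase):

    def resolve_conf(kwargs: dict, key: str):
        # An explicit keyword wins; otherwise read the nested dataset_conf.
        if key in kwargs:
            return kwargs.get(key, None)
        return kwargs["dataset_conf"][key]

    kwargs = {"dataset_conf": {"ctc_tokenizer": "CharTokenizer", "ctc_tokenizer_conf": {}}}
    ctc_tokenizer = resolve_conf(kwargs, "ctc_tokenizer")  # -> "CharTokenizer"
    ctc_tokenizer_conf = resolve_conf(kwargs, "ctc_tokenizer_conf")  # -> {}
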
@@ -126,9 +135,7 @@ class FunASRNano(nn.Module):
             if init_param_path is not None:
                 src_state = torch.load(init_param_path, map_location="cpu")
                 flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-                logging.info(
-                    f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
-                )
+                logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
             freeze = ctc_decoder_conf.get("freeze", False)
             if freeze:
                 for _, param in self.ctc_decoder.named_parameters():
@@ -182,9 +189,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
         # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
         batch_size, token_num, dims = inputs_embeds.shape
         fake_token_len = kwargs.get("fake_token_len")
@@ -223,9 +228,7 @@ class FunASRNano(nn.Module):
         stats["batch_size_speech"] = batch_size_speech
         stats["batch_size_x_frames"] = frames * batch_size_speech
         stats["batch_size_real_frames"] = speech_lengths.sum().item()
-        stats["padding_frames"] = (
-            stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
-        )
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
 
         device_type = next(self.parameters()).device.type
         with torch.autocast(
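
Note: the collapsed assignment makes the bookkeeping easier to read: batch_size_x_frames is what the padded batch allocates, batch_size_real_frames is what actually carries audio, so padding_frames measures the overhead introduced by batching. A self-contained illustration with invented lengths:

    import torch

    speech_lengths = torch.tensor([230, 198, 170])  # real frames per utterance
    frames, batch_size_speech = 230, 3              # padded length x batch
    allocated = frames * batch_size_speech          # 690 frames allocated
    real = speech_lengths.sum().item()              # 598 frames of audio
    padding_frames = allocated - real               # 92 frames of pure padding
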
@@ -244,9 +247,7 @@ class FunASRNano(nn.Module):
 
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
-            acc_att = compute_accuracy(
-                preds[:, :-1], labels_ids[:, 1:], ignore_label=-100
-            )
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
             stats["acc"] = acc_att
 
         stats["loss"] = torch.clone(loss.detach())
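
Note: the one-line compute_accuracy call keeps the usual causal-LM offset: the logit at position t predicts the token at position t+1, so predictions drop the last step, labels drop the first, and -100 marks positions excluded from scoring. A sketch of the computation under those assumptions (not funasr's implementation):

    import torch

    def shifted_accuracy(preds: torch.Tensor, labels: torch.Tensor, ignore_label: int = -100) -> float:
        pred, ref = preds[:, :-1], labels[:, 1:]  # align logit t with token t+1
        mask = ref != ignore_label                # skip prompt/padding positions
        return (pred[mask] == ref[mask]).float().mean().item()
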
@@ -254,9 +255,7 @@ class FunASRNano(nn.Module):
 
         stats["batch_size_x_tokens"] = token_num * batch_size
         stats["batch_size_real_tokens"] = attention_mask.sum().item()
-        stats["padding_tokens"] = (
-            stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
-        )
+        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
 
         dialog_turns = (fbank_beg > 0).sum(-1)
         dialog_turns_max = torch.max(dialog_turns).int().item()
@@ -306,9 +305,7 @@ class FunASRNano(nn.Module):
 
         return contents
 
-    def data_load_speech(
-        self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs
-    ):
+    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
         system = contents["system"]
         user = contents["user"]
         assistant = contents["assistant"]
@@ -329,9 +326,7 @@ class FunASRNano(nn.Module):
             [],
         )
         input_source_ids = []
-        for i, (system_prompt, user_prompt, target_out) in enumerate(
-            zip(system, user, assistant)
-        ):
+        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
             if i >= kwargs.get("multiturn_num_max", 5):
                 break
             if len(input_ids) > kwargs.get("max_token_length", 1500):
@@ -346,12 +341,16 @@ class FunASRNano(nn.Module):
                 else:
                     source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                     if not sys_prompt:
-                        source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                        source_input = (
+                            f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                        )
             else:
                 if kwargs.get("infer_with_assistant_input", False):
                     source_input = f"<|im_start|>user\n{user_prompt}"
                 else:
-                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    source_input = (
+                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    )
                 if not do_think:
                     source_input += "<think>\n\n</think>\n\n"
             if kwargs.get("prev_text", None) is not None:
@@ -384,9 +383,7 @@ class FunASRNano(nn.Module):
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
         except Exception as e:
-            logging.error(
-                f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
-            )
+            logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
 
         speech, speech_lengths = extract_fbank(
             data_src,
@@ -428,9 +425,7 @@ class FunASRNano(nn.Module):
             fbank.append(speech[0, :, :])
             fbank_lens.append(speech_lengths)
 
-        input_ids = torch.tensor(
-            input_ids, dtype=torch.int64
-        )  # [: self.max_token_length]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
 
@@ -441,9 +436,7 @@ class FunASRNano(nn.Module):
         target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
         if len(fbank) > 0:
-            speech = torch.nn.utils.rnn.pad_sequence(
-                fbank, batch_first=True, padding_value=0.0
-            )
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
             speech_lengths = torch.nn.utils.rnn.pad_sequence(
                 fbank_lens, batch_first=True, padding_value=-1
             )
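
Note: pad_sequence stacks the per-utterance fbank tensors of shape [T_i, D] into one [B, T_max, D] batch, zero-filling the short ones; the lengths are padded with -1 so real frame counts stay distinguishable from padding. A runnable illustration with invented shapes:

    import torch

    feats = [torch.randn(230, 80), torch.randn(170, 80)]  # two utterances, 80-dim fbank
    speech = torch.nn.utils.rnn.pad_sequence(feats, batch_first=True, padding_value=0.0)
    print(speech.shape)  # torch.Size([2, 230, 80])
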
@@ -480,9 +473,7 @@ class FunASRNano(nn.Module):
             raise NotImplementedError("batch decoding is not implemented")
 
         contents = self.data_template(data_in[0])
-        output = self.data_load_speech(
-            contents, tokenizer, frontend, meta_data=meta_data, **kwargs
-        )
+        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
         batch = to_device(output, kwargs["device"])
 
         # audio encoder
@@ -503,9 +494,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
         # audio_adaptor
-        adaptor_out, adaptor_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
         meta_data["encoder_out"] = encoder_out
         meta_data["encoder_out_lens"] = encoder_out_lens
         meta_data["audio_adaptor_out"] = adaptor_out
@@ -573,7 +562,7 @@ class FunASRNano(nn.Module):
             prompt += ",不进行文本规整"
         return prompt + ":"
 
-    def generate_chatml(self, prompt: str, data: str | torch.Tensor):
+    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
         if isinstance(data, str):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
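
Note: this one-line change is the reason the commit adds "from typing import Union" up top: the PEP 604 spelling "str | torch.Tensor" is evaluated when the def statement runs and raises TypeError on Python < 3.10, while Union[str, torch.Tensor] imports everywhere. The equivalence in isolation:

    from typing import Union

    import torch

    # def generate_chatml(self, prompt: str, data: str | torch.Tensor):  # needs Python >= 3.10
    def generate_chatml(prompt: str, data: Union[str, torch.Tensor]):  # same meaning, runs on 3.8+
        ...
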
@@ -583,11 +572,14 @@ class FunASRNano(nn.Module):
         elif isinstance(data, torch.Tensor):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
+                {
+                    "role": "user",
+                    "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
+                    "audio": data,
+                },
                 {"role": "assistant", "content": "null"},
             ]
-
 
     def inference(
         self,
         data_in,
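
Note: the reflowed dict shows the message layout generate_chatml emits for tensor input: the "!!" between <|startofspeech|> and <|endofspeech|> acts as the audio placeholder in the text channel, while the waveform itself rides along under the "audio" key. An example of the resulting structure (prompt text and tensor are invented):

    import torch

    waveform = torch.zeros(16000)  # 1 s of silence at 16 kHz as stand-in audio
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Transcribe:<|startofspeech|>!!<|endofspeech|>",
            "audio": waveform,
        },
        {"role": "assistant", "content": "null"},
    ]
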
@@ -597,7 +589,9 @@ class FunASRNano(nn.Module):
         frontend=None,
         **kwargs,
     ):
-        prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
+        prompt = self.get_prompt(
+            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
+        )
         data_in = [self.generate_chatml(prompt, data) for data in data_in]
 
         if key is None:
@@ -722,8 +716,12 @@ class FunASRNano(nn.Module):
 
         for ctc_result, result in zip(ctc_results, results):
             result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
-            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
-            result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            target_ids = torch.tensor(
+                self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
+            )
+            result["ctc_timestamps"] = forced_align(
+                ctc_result["ctc_logits"], target_ids, self.blank_id
+            )
             target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
             result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
             for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
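
Note: the repository's forced_align itself is not shown in this diff; all the hunk establishes is its call contract, (ctc_logits, target_ids, blank_id) -> timestamps. For illustration only, torchaudio ships a CTC forced aligner with a similar shape (assumes torchaudio >= 2.1; this is not the function used above):

    import torch
    import torchaudio.functional as AF

    log_probs = torch.randn(1, 120, 32).log_softmax(-1)         # [batch=1, frames, vocab]
    targets = torch.tensor([[7, 3, 11, 5]], dtype=torch.int32)  # token ids, no blanks
    frame_labels, scores = AF.forced_align(log_probs, targets, blank=0)
    print(frame_labels.shape)  # per-frame token ids: torch.Size([1, 120])
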
@@ -743,8 +741,6 @@ class FunASRNano(nn.Module):
     def from_pretrained(model: str = None, **kwargs):
         from funasr import AutoModel
 
-        model, kwargs = AutoModel.build_model(
-            model=model, trust_remote_code=True, **kwargs
-        )
+        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
 
         return model, kwargs
tools/__init__.py (new file, 0 lines)