游雁
2026-01-26 20:31:28 +08:00
parent c38d22f5f2
commit d9ba359ddf
2 changed files with 49 additions and 53 deletions

model.py

@@ -5,10 +5,11 @@ import re
 import string
 import time
 import traceback
+from typing import Union
 
 import torch
 import torch.nn as nn
-from funasr import AutoModel
 from funasr.metrics.compute_acc import compute_accuracy
 from funasr.register import tables
 from funasr.train_utils.device_funcs import force_gatherable, to_device
@@ -44,6 +45,8 @@ class FunASRNano(nn.Module):
             "activation_checkpoint", False
         )
         if hub == "ms":
+            from funasr import AutoModel
+
             model = AutoModel(model=audio_encoder, model_revision="master")
             audio_encoder_output_size = (
                 model.model.encoder_output_size
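Note on the two AutoModel hunks above: the module-level `from funasr import AutoModel` removed in the first hunk reappears here as a function-local import, so funasr is only imported on the `hub == "ms"` path. A minimal sketch of the pattern (the helper name is hypothetical, and the circular-import motive is an assumption the commit itself does not state):

# Lazy-import sketch: defer a heavy (and potentially cyclic) dependency until
# the branch that needs it. Assumption: this file can itself be loaded by
# funasr via trust_remote_code, so a top-level `from funasr import AutoModel`
# risks an import cycle.
def load_ms_encoder(audio_encoder: str):
    from funasr import AutoModel  # resolved only when this function runs

    model = AutoModel(model=audio_encoder, model_revision="master")
    # Same unwrapping as in the diff: some checkpoints nest the encoder one
    # level deeper than others.
    return model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder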
@@ -51,9 +54,7 @@ class FunASRNano(nn.Module):
                 else -1
             )
             audio_encoder = (
-                model.model.model.encoder
-                if hasattr(model.model, "model")
-                else model.model.encoder
+                model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
             )
         else:
             encoder_class = tables.encoder_classes.get(audio_encoder)
@ -109,8 +110,16 @@ class FunASRNano(nn.Module):
# TODO: fix table name # TODO: fix table name
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None)) ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None: if ctc_decoder_class is not None:
ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"] ctc_tokenizer = (
ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"] kwargs.get("ctc_tokenizer", None)
if "ctc_tokenizer" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer"]
)
ctc_tokenizer_conf = (
kwargs.get("ctc_tokenizer_conf", None)
if "ctc_tokenizer_conf" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
)
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None: if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer) ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf) ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
@@ -126,9 +135,7 @@ class FunASRNano(nn.Module):
             if init_param_path is not None:
                 src_state = torch.load(init_param_path, map_location="cpu")
                 flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-                logging.info(
-                    f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
-                )
+                logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
             freeze = ctc_decoder_conf.get("freeze", False)
             if freeze:
                 for _, param in self.ctc_decoder.named_parameters():
@ -182,9 +189,7 @@ class FunASRNano(nn.Module):
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor # audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor( encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
encoder_out, encoder_out_lens
)
batch_size, token_num, dims = inputs_embeds.shape batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len") fake_token_len = kwargs.get("fake_token_len")
@ -223,9 +228,7 @@ class FunASRNano(nn.Module):
stats["batch_size_speech"] = batch_size_speech stats["batch_size_speech"] = batch_size_speech
stats["batch_size_x_frames"] = frames * batch_size_speech stats["batch_size_x_frames"] = frames * batch_size_speech
stats["batch_size_real_frames"] = speech_lengths.sum().item() stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = ( stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
)
device_type = next(self.parameters()).device.type device_type = next(self.parameters()).device.type
with torch.autocast( with torch.autocast(
@ -244,9 +247,7 @@ class FunASRNano(nn.Module):
with torch.no_grad(): with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1) preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy( acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
preds[:, :-1], labels_ids[:, 1:], ignore_label=-100
)
stats["acc"] = acc_att stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach()) stats["loss"] = torch.clone(loss.detach())
@ -254,9 +255,7 @@ class FunASRNano(nn.Module):
stats["batch_size_x_tokens"] = token_num * batch_size stats["batch_size_x_tokens"] = token_num * batch_size
stats["batch_size_real_tokens"] = attention_mask.sum().item() stats["batch_size_real_tokens"] = attention_mask.sum().item()
stats["padding_tokens"] = ( stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
)
dialog_turns = (fbank_beg > 0).sum(-1) dialog_turns = (fbank_beg > 0).sum(-1)
dialog_turns_max = torch.max(dialog_turns).int().item() dialog_turns_max = torch.max(dialog_turns).int().item()
@@ -306,9 +305,7 @@ class FunASRNano(nn.Module):
         return contents
 
-    def data_load_speech(
-        self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs
-    ):
+    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
         system = contents["system"]
         user = contents["user"]
         assistant = contents["assistant"]
@@ -329,9 +326,7 @@ class FunASRNano(nn.Module):
             [],
         )
         input_source_ids = []
-        for i, (system_prompt, user_prompt, target_out) in enumerate(
-            zip(system, user, assistant)
-        ):
+        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
             if i >= kwargs.get("multiturn_num_max", 5):
                 break
             if len(input_ids) > kwargs.get("max_token_length", 1500):
@@ -346,12 +341,16 @@ class FunASRNano(nn.Module):
                 else:
                     source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                 if not sys_prompt:
-                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    source_input = (
+                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    )
             else:
                 if kwargs.get("infer_with_assistant_input", False):
                     source_input = f"<|im_start|>user\n{user_prompt}"
                 else:
-                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    source_input = (
+                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                    )
             if not do_think:
                 source_input += "<think>\n\n</think>\n\n"
             if kwargs.get("prev_text", None) is not None:
@@ -384,9 +383,7 @@ class FunASRNano(nn.Module):
                     time2 = time.perf_counter()
                     meta_data["load_data"] = f"{time2 - time1:0.3f}"
                 except Exception as e:
-                    logging.error(
-                        f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
-                    )
+                    logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
 
                 speech, speech_lengths = extract_fbank(
                     data_src,
@ -428,9 +425,7 @@ class FunASRNano(nn.Module):
fbank.append(speech[0, :, :]) fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths) fbank_lens.append(speech_lengths)
input_ids = torch.tensor( input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
input_ids, dtype=torch.int64
) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
@@ -441,9 +436,7 @@ class FunASRNano(nn.Module):
         target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
         if len(fbank) > 0:
-            speech = torch.nn.utils.rnn.pad_sequence(
-                fbank, batch_first=True, padding_value=0.0
-            )
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
             speech_lengths = torch.nn.utils.rnn.pad_sequence(
                 fbank_lens, batch_first=True, padding_value=-1
             )
@ -480,9 +473,7 @@ class FunASRNano(nn.Module):
raise NotImplementedError("batch decoding is not implemented") raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0]) contents = self.data_template(data_in[0])
output = self.data_load_speech( output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
contents, tokenizer, frontend, meta_data=meta_data, **kwargs
)
batch = to_device(output, kwargs["device"]) batch = to_device(output, kwargs["device"])
# audio encoder # audio encoder
@@ -503,9 +494,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         # audio_adaptor
-        adaptor_out, adaptor_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
         meta_data["encoder_out"] = encoder_out
         meta_data["encoder_out_lens"] = encoder_out_lens
         meta_data["audio_adaptor_out"] = adaptor_out
@@ -573,7 +562,7 @@ class FunASRNano(nn.Module):
             prompt += ",不进行文本规整"
         return prompt + ""
 
-    def generate_chatml(self, prompt: str, data: str | torch.Tensor):
+    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
         if isinstance(data, str):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
@@ -583,11 +572,14 @@ class FunASRNano(nn.Module):
         elif isinstance(data, torch.Tensor):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
+                {
+                    "role": "user",
+                    "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
+                    "audio": data,
+                },
                 {"role": "assistant", "content": "null"},
             ]
 
     def inference(
         self,
         data_in,
@@ -597,7 +589,9 @@ class FunASRNano(nn.Module):
         frontend=None,
         **kwargs,
     ):
-        prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
+        prompt = self.get_prompt(
+            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
+        )
         data_in = [self.generate_chatml(prompt, data) for data in data_in]
 
         if key is None:
@@ -722,8 +716,12 @@ class FunASRNano(nn.Module):
         for ctc_result, result in zip(ctc_results, results):
             result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
-            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
-            result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+            target_ids = torch.tensor(
+                self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
+            )
+            result["ctc_timestamps"] = forced_align(
+                ctc_result["ctc_logits"], target_ids, self.blank_id
+            )
             target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
             result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
             for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
@@ -743,8 +741,6 @@ class FunASRNano(nn.Module):
     def from_pretrained(model: str = None, **kwargs):
         from funasr import AutoModel
 
-        model, kwargs = AutoModel.build_model(
-            model=model, trust_remote_code=True, **kwargs
-        )
+        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
 
         return model, kwargs
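A plausible call pattern for the static helper above (hedged: the model path is a placeholder, and the contents of the returned kwargs, e.g. tokenizer and frontend, follow funasr's AutoModel.build_model conventions rather than anything shown in this diff):

# Hypothetical usage; model path and device are placeholders.
model, kwargs = FunASRNano.from_pretrained(model="/path/to/FunASR-Nano", device="cuda:0")
res = model.inference(["/path/to/audio.wav"], **kwargs)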

tools/__init__.py (new empty file)