commit d9ba359ddf (parent c38d22f5f2)
Author: 游雁
Date:   2026-01-26 20:31:28 +08:00

2 changed files with 49 additions and 53 deletions

model.py (102 changed lines)

@@ -5,10 +5,11 @@ import re
import string
import time
import traceback
from typing import Union
import torch
import torch.nn as nn
from funasr import AutoModel
from funasr.metrics.compute_acc import compute_accuracy
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable, to_device
@@ -44,6 +45,8 @@ class FunASRNano(nn.Module):
"activation_checkpoint", False
)
if hub == "ms":
from funasr import AutoModel
model = AutoModel(model=audio_encoder, model_revision="master")
audio_encoder_output_size = (
model.model.encoder_output_size
@@ -51,9 +54,7 @@ class FunASRNano(nn.Module):
else -1
)
audio_encoder = (
model.model.model.encoder
if hasattr(model.model, "model")
else model.model.encoder
model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
)
else:
encoder_class = tables.encoder_classes.get(audio_encoder)
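
The hub == "ms" branch above loads a pretrained FunASR model and reuses only its encoder; the hasattr check covers the two wrapping depths AutoModel can produce. A minimal sketch of that branch in isolation (the model id is a placeholder, and the getattr fallback approximates the hidden else -1 default):

    from funasr import AutoModel

    auto = AutoModel(model="some-audio-encoder-id", model_revision="master")  # placeholder id
    # AutoModel may wrap the network once (auto.model) or twice (auto.model.model),
    # so the encoder is taken from whichever level exists.
    wrapped = auto.model.model if hasattr(auto.model, "model") else auto.model
    audio_encoder = wrapped.encoder
    audio_encoder_output_size = getattr(auto.model, "encoder_output_size", -1)
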
@@ -109,8 +110,16 @@ class FunASRNano(nn.Module):
# TODO: fix table name
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None:
ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
ctc_tokenizer = (
kwargs.get("ctc_tokenizer", None)
if "ctc_tokenizer" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer"]
)
ctc_tokenizer_conf = (
kwargs.get("ctc_tokenizer_conf", None)
if "ctc_tokenizer_conf" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
)
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
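
The two expressions above implement the same lookup order: an explicit top-level kwarg wins, otherwise the value is read from dataset_conf. A small self-contained sketch of that precedence (the helper name and the tokenizer value are illustrative):

    def resolve_option(name, kwargs):
        # An explicit kwarg (even if None) takes precedence over dataset_conf.
        return kwargs.get(name, None) if name in kwargs else kwargs["dataset_conf"][name]

    kwargs = {"dataset_conf": {"ctc_tokenizer": "CharTokenizer", "ctc_tokenizer_conf": {}}}
    resolve_option("ctc_tokenizer", kwargs)       # -> "CharTokenizer" (from dataset_conf)
    resolve_option("ctc_tokenizer_conf", kwargs)  # -> {}
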
@@ -126,9 +135,7 @@ class FunASRNano(nn.Module):
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
logging.info(
f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
)
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
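
The flag logged above is whatever load_state_dict(strict=False) returns, i.e. the keys that could not be matched. A quick standalone illustration of that return value:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    flag = layer.load_state_dict({"weight": torch.eye(4)}, strict=False)
    print(flag)  # _IncompatibleKeys(missing_keys=['bias'], unexpected_keys=[])
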
@@ -182,9 +189,7 @@ class FunASRNano(nn.Module):
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(
encoder_out, encoder_out_lens
)
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
@@ -223,9 +228,7 @@ class FunASRNano(nn.Module):
stats["batch_size_speech"] = batch_size_speech
stats["batch_size_x_frames"] = frames * batch_size_speech
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = (
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
)
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
device_type = next(self.parameters()).device.type
with torch.autocast(
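
The stats entries in this hunk measure how much of each padded speech batch is real audio versus padding. The same bookkeeping on a toy lengths tensor:

    import torch

    speech_lengths = torch.tensor([420, 377, 300])   # real frames per utterance
    frames = int(speech_lengths.max())               # padded batch length: 420
    batch_size_speech = speech_lengths.numel()       # 3

    batch_size_x_frames = frames * batch_size_speech                # 1260 frames allocated
    batch_size_real_frames = int(speech_lengths.sum())              # 1097 frames of real audio
    padding_frames = batch_size_x_frames - batch_size_real_frames   # 163
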
@@ -244,9 +247,7 @@ class FunASRNano(nn.Module):
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(
preds[:, :-1], labels_ids[:, 1:], ignore_label=-100
)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
@@ -254,9 +255,7 @@ class FunASRNano(nn.Module):
stats["batch_size_x_tokens"] = token_num * batch_size
stats["batch_size_real_tokens"] = attention_mask.sum().item()
stats["padding_tokens"] = (
stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
)
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
dialog_turns = (fbank_beg > 0).sum(-1)
dialog_turns_max = torch.max(dialog_turns).int().item()
@@ -306,9 +305,7 @@ class FunASRNano(nn.Module):
return contents
def data_load_speech(
self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs
):
def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
@@ -329,9 +326,7 @@ class FunASRNano(nn.Module):
[],
)
input_source_ids = []
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
):
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
if i >= kwargs.get("multiturn_num_max", 5):
break
if len(input_ids) > kwargs.get("max_token_length", 1500):
@@ -346,12 +341,16 @@ class FunASRNano(nn.Module):
else:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
if not sys_prompt:
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
else:
if kwargs.get("infer_with_assistant_input", False):
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
if not do_think:
source_input += "<think>\n\n</think>\n\n"
if kwargs.get("prev_text", None) is not None:
@@ -384,9 +383,7 @@ class FunASRNano(nn.Module):
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
except Exception as e:
logging.error(
f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
)
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
speech, speech_lengths = extract_fbank(
data_src,
@@ -428,9 +425,7 @@ class FunASRNano(nn.Module):
fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths)
input_ids = torch.tensor(
input_ids, dtype=torch.int64
) # [: self.max_token_length]
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
@@ -441,9 +436,7 @@ class FunASRNano(nn.Module):
target_ids = torch.tensor(target_ids, dtype=torch.int64)
if len(fbank) > 0:
speech = torch.nn.utils.rnn.pad_sequence(
fbank, batch_first=True, padding_value=0.0
)
speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
speech_lengths = torch.nn.utils.rnn.pad_sequence(
fbank_lens, batch_first=True, padding_value=-1
)
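
fbank features of different lengths are right-padded into a single tensor, and the per-utterance length tensors are padded with -1 so missing entries are easy to mask out later. A toy example of the two pad_sequence calls (shapes are illustrative):

    import torch

    fbank = [torch.randn(420, 80), torch.randn(300, 80)]   # frames x mel bins
    fbank_lens = [torch.tensor([420]), torch.tensor([300])]

    speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
    speech_lengths = torch.nn.utils.rnn.pad_sequence(fbank_lens, batch_first=True, padding_value=-1)
    # speech.shape == (2, 420, 80); the shorter utterance is zero-padded to 420 frames.
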
@@ -480,9 +473,7 @@ class FunASRNano(nn.Module):
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(
contents, tokenizer, frontend, meta_data=meta_data, **kwargs
)
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
@@ -503,9 +494,7 @@ class FunASRNano(nn.Module):
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
adaptor_out, adaptor_out_lens = self.audio_adaptor(
encoder_out, encoder_out_lens
)
adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
meta_data["encoder_out"] = encoder_out
meta_data["encoder_out_lens"] = encoder_out_lens
meta_data["audio_adaptor_out"] = adaptor_out
@@ -573,7 +562,7 @@ class FunASRNano(nn.Module):
prompt += ",不进行文本规整"
return prompt + ""
def generate_chatml(self, prompt: str, data: str | torch.Tensor):
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
if isinstance(data, str):
return [
{"role": "system", "content": "You are a helpful assistant."},
@@ -583,11 +572,14 @@ class FunASRNano(nn.Module):
elif isinstance(data, torch.Tensor):
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
{
"role": "user",
"content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
"audio": data,
},
{"role": "assistant", "content": "null"},
]
def inference(
self,
data_in,
@@ -597,7 +589,9 @@ class FunASRNano(nn.Module):
frontend=None,
**kwargs,
):
prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
prompt = self.get_prompt(
kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
)
data_in = [self.generate_chatml(prompt, data) for data in data_in]
if key is None:
@@ -722,8 +716,12 @@ class FunASRNano(nn.Module):
for ctc_result, result in zip(ctc_results, results):
result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
target_ids = torch.tensor(
self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
)
result["ctc_timestamps"] = forced_align(
ctc_result["ctc_logits"], target_ids, self.blank_id
)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
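
forced_align (defined elsewhere in this repo, not shown in this diff) aligns the CTC logits against a fixed token sequence to recover per-token timestamps; it is run twice, once on the CTC hypothesis and once on the LLM text. For orientation only, torchaudio ships a comparable public primitive with the same kind of inputs (frame-level log-probs, target ids, blank id); this is not the repo's implementation:

    import torch
    import torchaudio.functional as TAF

    log_probs = torch.randn(1, 50, 30).log_softmax(-1)   # (batch=1, frames, vocab)
    targets = torch.tensor([[7, 12, 12, 3]])              # (batch=1, target length), no blanks
    alignments, scores = TAF.forced_align(log_probs, targets, blank=0)
    # "alignments" gives the token id emitted at each frame; token boundaries yield timestamps.
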
@@ -743,8 +741,6 @@ class FunASRNano(nn.Module):
def from_pretrained(model: str = None, **kwargs):
from funasr import AutoModel
model, kwargs = AutoModel.build_model(
model=model, trust_remote_code=True, **kwargs
)
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
return model, kwargs
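
from_pretrained simply forwards to AutoModel.build_model with trust_remote_code=True, which is what lets this remote model.py define the class being built. A hedged usage sketch (the model id/path and the extra device kwarg are placeholders passed straight through to build_model):

    model, kwargs = FunASRNano.from_pretrained(model="path/or/model-id", device="cuda:0")
    model.eval()
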

tools/__init__.py (new empty file, 0 changed lines)