From d9ba359ddf9eeaaa7ad4db8ae209abfe0596cee8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Mon, 26 Jan 2026 20:31:28 +0800
Subject: [PATCH] update

Normalize formatting in model.py (join short wrapped statements, split
over-long ones), replace the PEP 604 annotation `str | torch.Tensor` with
`typing.Union` so the module also loads on Python < 3.10, move the
`funasr.AutoModel` import into the `hub == "ms"` branch, and add an empty
`tools/__init__.py` so that `tools` is importable as a package.
---
 model.py          | 102 ++++++++++++++++++++++------------------
 tools/__init__.py |   0
 2 files changed, 49 insertions(+), 53 deletions(-)
 create mode 100644 tools/__init__.py

diff --git a/model.py b/model.py
index fcf7720..3087400 100644
--- a/model.py
+++ b/model.py
@@ -5,10 +5,11 @@
 import re
 import string
 import time
 import traceback
+from typing import Union
 import torch
 import torch.nn as nn
-from funasr import AutoModel
+
 from funasr.metrics.compute_acc import compute_accuracy
 from funasr.register import tables
 from funasr.train_utils.device_funcs import force_gatherable, to_device
@@ -44,6 +45,8 @@ class FunASRNano(nn.Module):
             "activation_checkpoint", False
         )
         if hub == "ms":
+            from funasr import AutoModel
+
             model = AutoModel(model=audio_encoder, model_revision="master")
             audio_encoder_output_size = (
                 model.model.encoder_output_size
@@ -51,9 +54,7 @@ class FunASRNano(nn.Module):
                 else -1
             )
             audio_encoder = (
-                model.model.model.encoder
-                if hasattr(model.model, "model")
-                else model.model.encoder
+                model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
             )
         else:
             encoder_class = tables.encoder_classes.get(audio_encoder)
@@ -109,8 +110,16 @@ class FunASRNano(nn.Module):
         # TODO: fix table name
         ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
         if ctc_decoder_class is not None:
-            ctc_tokenizer = kwargs.get("ctc_tokenizer", None) if "ctc_tokenizer" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer"]
-            ctc_tokenizer_conf = kwargs.get("ctc_tokenizer_conf", None) if "ctc_tokenizer_conf" in kwargs else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            ctc_tokenizer = (
+                kwargs.get("ctc_tokenizer", None)
+                if "ctc_tokenizer" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer"]
+            )
+            ctc_tokenizer_conf = (
+                kwargs.get("ctc_tokenizer_conf", None)
+                if "ctc_tokenizer_conf" in kwargs
+                else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
+            )
             if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
                 ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
                 ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
@@ -126,9 +135,7 @@ class FunASRNano(nn.Module):
             if init_param_path is not None:
                 src_state = torch.load(init_param_path, map_location="cpu")
                 flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-                logging.info(
-                    f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
-                )
+                logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
             freeze = ctc_decoder_conf.get("freeze", False)
             if freeze:
                 for _, param in self.ctc_decoder.named_parameters():
@@ -182,9 +189,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

         # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)

         batch_size, token_num, dims = inputs_embeds.shape
         fake_token_len = kwargs.get("fake_token_len")
@@ -223,9 +228,7 @@ class FunASRNano(nn.Module):
         stats["batch_size_speech"] = batch_size_speech
         stats["batch_size_x_frames"] = frames * batch_size_speech
         stats["batch_size_real_frames"] = speech_lengths.sum().item()
-        stats["padding_frames"] = (
-            stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
-        )
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]

stats["batch_size_real_frames"] device_type = next(self.parameters()).device.type with torch.autocast( @@ -244,9 +247,7 @@ class FunASRNano(nn.Module): with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) - acc_att = compute_accuracy( - preds[:, :-1], labels_ids[:, 1:], ignore_label=-100 - ) + acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) stats["acc"] = acc_att stats["loss"] = torch.clone(loss.detach()) @@ -254,9 +255,7 @@ class FunASRNano(nn.Module): stats["batch_size_x_tokens"] = token_num * batch_size stats["batch_size_real_tokens"] = attention_mask.sum().item() - stats["padding_tokens"] = ( - stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] - ) + stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] dialog_turns = (fbank_beg > 0).sum(-1) dialog_turns_max = torch.max(dialog_turns).int().item() @@ -306,9 +305,7 @@ class FunASRNano(nn.Module): return contents - def data_load_speech( - self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs - ): + def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): system = contents["system"] user = contents["user"] assistant = contents["assistant"] @@ -329,9 +326,7 @@ class FunASRNano(nn.Module): [], ) input_source_ids = [] - for i, (system_prompt, user_prompt, target_out) in enumerate( - zip(system, user, assistant) - ): + for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): if i >= kwargs.get("multiturn_num_max", 5): break if len(input_ids) > kwargs.get("max_token_length", 1500): @@ -346,12 +341,16 @@ class FunASRNano(nn.Module): else: source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" if not sys_prompt: - source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + source_input = ( + f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + ) else: if kwargs.get("infer_with_assistant_input", False): source_input = f"<|im_start|>user\n{user_prompt}" else: - source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + source_input = ( + f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + ) if not do_think: source_input += "\n\n\n\n" if kwargs.get("prev_text", None) is not None: @@ -384,9 +383,7 @@ class FunASRNano(nn.Module): time2 = time.perf_counter() meta_data["load_data"] = f"{time2 - time1:0.3f}" except Exception as e: - logging.error( - f"Loading wav failed! {str(e)}, {traceback.format_exc()}" - ) + logging.error(f"Loading wav failed! 
                 speech, speech_lengths = extract_fbank(
                     data_src,
@@ -428,9 +425,7 @@ class FunASRNano(nn.Module):
                 fbank.append(speech[0, :, :])
                 fbank_lens.append(speech_lengths)

-        input_ids = torch.tensor(
-            input_ids, dtype=torch.int64
-        )  # [: self.max_token_length]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]

@@ -441,9 +436,7 @@ class FunASRNano(nn.Module):
         target_ids = torch.tensor(target_ids, dtype=torch.int64)

         if len(fbank) > 0:
-            speech = torch.nn.utils.rnn.pad_sequence(
-                fbank, batch_first=True, padding_value=0.0
-            )
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
             speech_lengths = torch.nn.utils.rnn.pad_sequence(
                 fbank_lens, batch_first=True, padding_value=-1
             )
@@ -480,9 +473,7 @@ class FunASRNano(nn.Module):
             raise NotImplementedError("batch decoding is not implemented")

         contents = self.data_template(data_in[0])
-        output = self.data_load_speech(
-            contents, tokenizer, frontend, meta_data=meta_data, **kwargs
-        )
+        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
         batch = to_device(output, kwargs["device"])

         # audio encoder
@@ -503,9 +494,7 @@ class FunASRNano(nn.Module):
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

         # audio_adaptor
-        adaptor_out, adaptor_out_lens = self.audio_adaptor(
-            encoder_out, encoder_out_lens
-        )
+        adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
         meta_data["encoder_out"] = encoder_out
         meta_data["encoder_out_lens"] = encoder_out_lens
         meta_data["audio_adaptor_out"] = adaptor_out
@@ -573,7 +562,7 @@ class FunASRNano(nn.Module):
             prompt += ",不进行文本规整"
         return prompt + ":"

-    def generate_chatml(self, prompt: str, data: str | torch.Tensor):
+    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
         if isinstance(data, str):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
@@ -583,11 +572,14 @@ class FunASRNano(nn.Module):
         elif isinstance(data, torch.Tensor):
             return [
                 {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
+                {
+                    "role": "user",
+                    "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
+                    "audio": data,
+                },
                 {"role": "assistant", "content": "null"},
             ]

-
     def inference(
         self,
         data_in,
@@ -597,7 +589,9 @@ class FunASRNano(nn.Module):
         frontend=None,
         **kwargs,
     ):
-        prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
+        prompt = self.get_prompt(
+            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
+        )
         data_in = [self.generate_chatml(prompt, data) for data in data_in]

         if key is None:
@@ -722,8 +716,12 @@ class FunASRNano(nn.Module):
             for ctc_result, result in zip(ctc_results, results):
                 result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
-                target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
-                result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
+                target_ids = torch.tensor(
+                    self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
+                )
+                result["ctc_timestamps"] = forced_align(
+                    ctc_result["ctc_logits"], target_ids, self.blank_id
+                )
                 target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
                 result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)

                 for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
@@ -743,8 +741,6 @@ class FunASRNano(nn.Module):
     def from_pretrained(model: str = None, **kwargs):
         from funasr import AutoModel

-        model, kwargs = AutoModel.build_model(
-            model=model, trust_remote_code=True, **kwargs
-        )
+        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

         return model, kwargs
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..e69de29