From 64e4d92a35c158b2728c4dc2a00dc0776d2e3554 Mon Sep 17 00:00:00 2001
From: pengzhendong <275331498@qq.com>
Date: Sun, 25 Jan 2026 22:59:46 +0800
Subject: [PATCH] fix warning

Pass attention_mask and an explicit pad_token_id to generate() so
transformers stops warning that the attention mask and pad token id
were not set. Also factor prompt/ChatML construction out of
inference() into get_prompt()/generate_chatml(), strip <|nospeech|>
from the CTC text before forced alignment, and add a load_audio()
helper to tools/utils.py.
---
 model.py       | 88 ++++++++++++++++++++++++---------------------
 tools/utils.py | 24 ++++++++++++++
 2 files changed, 63 insertions(+), 49 deletions(-)

diff --git a/model.py b/model.py
index e8f3810..8ac2c3f 100644
--- a/model.py
+++ b/model.py
@@ -556,6 +556,36 @@ class FunASRNano(nn.Module):
             speech_idx += 1
         return inputs_embeds, contents, batch, source_ids, meta_data
 
+    def get_prompt(self, hotwords: list[str], language: str | None = None, itn: bool = True):
+        if len(hotwords) > 0:
+            hotwords = ", ".join(hotwords)
+            prompt = "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
+            prompt += f"热词列表:[{hotwords}]\n"
+        else:
+            prompt = ""
+        if language is None:
+            prompt += "语音转写"
+        else:
+            prompt += f"语音转写成{language}"
+        if not itn:
+            prompt += ",不进行文本规整"
+        return prompt + ":"
+
+    def generate_chatml(self, prompt: str, data: str | torch.Tensor):
+        if isinstance(data, str):
+            return [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"},
+                {"role": "assistant", "content": "null"},
+            ]
+        elif isinstance(data, torch.Tensor):
+            return [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
+                {"role": "assistant", "content": "null"},
+            ]
+
+
     def inference(
         self,
         data_in,
@@ -565,57 +595,14 @@ class FunASRNano(nn.Module):
         frontend=None,
         **kwargs,
     ):
-        hotwords = kwargs.get("hotwords", [])
-        if len(hotwords) > 0:
-            hotwords = ", ".join(hotwords)
-            prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
-            prompt += f"热词列表:[{hotwords}]\n"
-        else:
-            prompt = ""
-        language = kwargs.get("language", None)
-        if language is None:
-            prompt += "语音转写"
-        else:
-            prompt += f"语音转写成{language}"
-        itn = kwargs.get("itn", True)
-        if not itn:
-            prompt += ",不进行文本规整"
-        prompt += ":"
-
-        new_data_in = []
-        for data in data_in:
-            if isinstance(data, str):
-                new_data_in.append(
-                    [
-                        {"role": "system", "content": "You are a helpful assistant."},
-                        {
-                            "role": "user",
-                            "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
-                        },
-                        {"role": "assistant", "content": "null"},
-                    ]
-                )
-            elif isinstance(data, torch.Tensor):
-                new_data_in.append(
-                    [
-                        {"role": "system", "content": "You are a helpful assistant."},
-                        {
-                            "role": "user",
-                            "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
-                            "audio": data,
-                        },
-                        {"role": "assistant", "content": "null"},
-                    ]
-                )
-        data_in = new_data_in
+        prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
+        data_in = [self.generate_chatml(prompt, data) for data in data_in]
 
         if key is None:
             key = []
             for _ in data_in:
                 chars = string.ascii_letters + string.digits
-                key.append(
-                    "rand_key_" + "".join(random.choice(chars) for _ in range(13))
-                )
+                key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
 
         return self.inference_llm(
             data_in,
@@ -676,10 +663,13 @@ class FunASRNano(nn.Module):
             self.llm = self.llm.to(dtype_map[llm_dtype])
             inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
         llm_kwargs = kwargs.get("llm_kwargs", {})
-        if not kwargs.get("teachforing", False):
+        if not kwargs.get("teacherforcing", False):
+            attention_mask = batch.get("attention_mask", None)
             generated_ids = self.llm.generate(
                 inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
                 max_new_tokens=kwargs.get("max_length", 512),
+                pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
                 **llm_kwargs,
             )
 
@@ -727,8 +717,8 @@ class FunASRNano(nn.Module):
                 results.append(result_i)
 
         for ctc_result, result in zip(ctc_results, results):
-            result["ctc_text"] = ctc_result["text"]
-            target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
+            result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
+            target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
             result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
             target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
             result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
diff --git a/tools/utils.py b/tools/utils.py
index 7526a68..7125bca 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -1,9 +1,33 @@
 from itertools import groupby
 
+import soundfile as sf
 import torch
+import torchaudio
 import torchaudio.functional as F
 
 
+def load_audio(wav_path, rate: int | None = None, offset: float = 0, duration: float | None = None):
+    with sf.SoundFile(wav_path) as f:
+        start_frame = int(offset * f.samplerate)
+        if duration is None:
+            frames_to_read = f.frames - start_frame
+        else:
+            frames_to_read = int(duration * f.samplerate)
+        f.seek(start_frame)
+        # always_2d keeps a (frames, channels) layout even for mono files
+        audio_data = f.read(frames_to_read, dtype="float32", always_2d=True)
+        # transpose to (channels, frames), the layout torchaudio expects
+        audio_tensor = torch.from_numpy(audio_data).T
+        # resample only when an explicit target rate differs from the file's rate
+        if rate is not None and f.samplerate != rate:
+            resampler = torchaudio.transforms.Resample(orig_freq=f.samplerate, new_freq=rate)
+            audio_tensor = resampler(audio_tensor)
+        # return mono audio as a 1-D tensor for convenience
+        if audio_tensor.shape[0] == 1:
+            audio_tensor = audio_tensor.squeeze(0)
+        return audio_tensor, rate if rate is not None else f.samplerate
+
+
 def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
     items = []
     try:
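
Usage note (not part of the commit): a minimal sketch of how the new
helpers compose, assuming an already-constructed FunASRNano instance
named `model`; the file name, hotword, and 16 kHz target rate are
illustrative, not part of this patch.

    from tools.utils import load_audio

    # decode 3 s of audio starting at 0.5 s, resampled to 16 kHz
    # (a mono file comes back as a 1-D float32 tensor)
    audio, sr = load_audio("example.wav", rate=16000, offset=0.5, duration=3.0)

    # inference() now builds the prompt and ChatML internally via
    # get_prompt()/generate_chatml(), so hotwords/language/itn are
    # passed as plain keyword arguments
    results = model.inference([audio], hotwords=["FunASR"], language="中文", itn=True)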