fix warning

This commit is contained in:
pengzhendong
2026-01-25 22:59:46 +08:00
parent ad332b018e
commit 64e4d92a35
2 changed files with 64 additions and 49 deletions

View File

@ -556,6 +556,36 @@ class FunASRNano(nn.Module):
speech_idx += 1
return inputs_embeds, contents, batch, source_ids, meta_data
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
if len(hotwords) > 0:
hotwords = ", ".join(hotwords)
prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
prompt += f"热词列表:[{hotwords}]\n"
else:
prompt = ""
if language is None:
prompt += "语音转写"
else:
prompt += f"语音转写成{language}"
if not itn:
prompt += ",不进行文本规整"
return prompt + ""
def generate_chatml(self, prompt: str, data: str | torch.Tensor):
if isinstance(data, str):
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"},
{"role": "assistant", "content": "null"},
]
elif isinstance(data, torch.Tensor):
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
{"role": "assistant", "content": "null"},
]
def inference(
self,
data_in,
@ -565,57 +595,14 @@ class FunASRNano(nn.Module):
frontend=None,
**kwargs,
):
hotwords = kwargs.get("hotwords", [])
if len(hotwords) > 0:
hotwords = ", ".join(hotwords)
prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
prompt += f"热词列表:[{hotwords}]\n"
else:
prompt = ""
language = kwargs.get("language", None)
if language is None:
prompt += "语音转写"
else:
prompt += f"语音转写成{language}"
itn = kwargs.get("itn", True)
if not itn:
prompt += ",不进行文本规整"
prompt += ""
new_data_in = []
for data in data_in:
if isinstance(data, str):
new_data_in.append(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
},
{"role": "assistant", "content": "null"},
]
)
elif isinstance(data, torch.Tensor):
new_data_in.append(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
"audio": data,
},
{"role": "assistant", "content": "null"},
]
)
data_in = new_data_in
prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
data_in = [self.generate_chatml(prompt, data) for data in data_in]
if key is None:
key = []
for _ in data_in:
chars = string.ascii_letters + string.digits
key.append(
"rand_key_" + "".join(random.choice(chars) for _ in range(13))
)
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
return self.inference_llm(
data_in,
@ -676,10 +663,13 @@ class FunASRNano(nn.Module):
self.llm = self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
llm_kwargs = kwargs.get("llm_kwargs", {})
if not kwargs.get("teachforing", False):
if not kwargs.get("teacherforcing", False):
attention_mask = batch.get("attention_mask", None)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_length", 512),
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
@ -697,6 +687,7 @@ class FunASRNano(nn.Module):
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels_ids,
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
@ -727,8 +718,8 @@ class FunASRNano(nn.Module):
results.append(result_i)
for ctc_result, result in zip(ctc_results, results):
result["ctc_text"] = ctc_result["text"]
target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)

View File

@ -1,9 +1,33 @@
from itertools import groupby
import soundfile as sf
import torch
import torchaudio
import torchaudio.functional as F
def load_audio(wav_path, rate: int = None, offset: float = 0, duration: float = None):
with sf.SoundFile(wav_path) as f:
start_frame = int(offset * f.samplerate)
if duration is None:
frames_to_read = f.frames - start_frame
else:
frames_to_read = int(duration * f.samplerate)
f.seek(start_frame)
audio_data = f.read(frames_to_read, dtype="float32")
audio_tensor = torch.from_numpy(audio_data)
if rate is not None and f.samplerate != rate:
if audio_tensor.ndim == 1:
audio_tensor = audio_tensor.unsqueeze(0)
else:
audio_tensor = audio_tensor.T
resampler = torchaudio.transforms.Resample(orig_freq=f.samplerate, new_freq=rate)
audio_tensor = resampler(audio_tensor)
if audio_tensor.shape[0] == 1:
audio_tensor = audio_tensor.squeeze(0)
return audio_tensor, rate if rate is not None else f.samplerate
def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
items = []
try: