fix warning
This commit is contained in:
89
model.py
89
model.py
@ -556,6 +556,36 @@ class FunASRNano(nn.Module):
|
||||
speech_idx += 1
|
||||
return inputs_embeds, contents, batch, source_ids, meta_data
|
||||
|
||||
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
|
||||
if len(hotwords) > 0:
|
||||
hotwords = ", ".join(hotwords)
|
||||
prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
|
||||
prompt += f"热词列表:[{hotwords}]\n"
|
||||
else:
|
||||
prompt = ""
|
||||
if language is None:
|
||||
prompt += "语音转写"
|
||||
else:
|
||||
prompt += f"语音转写成{language}"
|
||||
if not itn:
|
||||
prompt += ",不进行文本规整"
|
||||
return prompt + ":"
|
||||
|
||||
def generate_chatml(self, prompt: str, data: str | torch.Tensor):
|
||||
if isinstance(data, str):
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"},
|
||||
{"role": "assistant", "content": "null"},
|
||||
]
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>", "audio": data},
|
||||
{"role": "assistant", "content": "null"},
|
||||
]
|
||||
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data_in,
|
||||
@ -565,57 +595,14 @@ class FunASRNano(nn.Module):
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
hotwords = kwargs.get("hotwords", [])
|
||||
if len(hotwords) > 0:
|
||||
hotwords = ", ".join(hotwords)
|
||||
prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
|
||||
prompt += f"热词列表:[{hotwords}]\n"
|
||||
else:
|
||||
prompt = ""
|
||||
language = kwargs.get("language", None)
|
||||
if language is None:
|
||||
prompt += "语音转写"
|
||||
else:
|
||||
prompt += f"语音转写成{language}"
|
||||
itn = kwargs.get("itn", True)
|
||||
if not itn:
|
||||
prompt += ",不进行文本规整"
|
||||
prompt += ":"
|
||||
|
||||
new_data_in = []
|
||||
for data in data_in:
|
||||
if isinstance(data, str):
|
||||
new_data_in.append(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
|
||||
},
|
||||
{"role": "assistant", "content": "null"},
|
||||
]
|
||||
)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
new_data_in.append(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
|
||||
"audio": data,
|
||||
},
|
||||
{"role": "assistant", "content": "null"},
|
||||
]
|
||||
)
|
||||
data_in = new_data_in
|
||||
prompt = self.get_prompt(kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True))
|
||||
data_in = [self.generate_chatml(prompt, data) for data in data_in]
|
||||
|
||||
if key is None:
|
||||
key = []
|
||||
for _ in data_in:
|
||||
chars = string.ascii_letters + string.digits
|
||||
key.append(
|
||||
"rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
||||
)
|
||||
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
|
||||
|
||||
return self.inference_llm(
|
||||
data_in,
|
||||
@ -676,10 +663,13 @@ class FunASRNano(nn.Module):
|
||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||
llm_kwargs = kwargs.get("llm_kwargs", {})
|
||||
if not kwargs.get("teachforing", False):
|
||||
if not kwargs.get("teacherforcing", False):
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=kwargs.get("max_length", 512),
|
||||
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
@ -697,6 +687,7 @@ class FunASRNano(nn.Module):
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels_ids,
|
||||
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
@ -727,8 +718,8 @@ class FunASRNano(nn.Module):
|
||||
results.append(result_i)
|
||||
|
||||
for ctc_result, result in zip(ctc_results, results):
|
||||
result["ctc_text"] = ctc_result["text"]
|
||||
target_ids = torch.tensor(self.ctc_tokenizer.encode(ctc_result["text"]), dtype=torch.int64)
|
||||
result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
|
||||
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64)
|
||||
result["ctc_timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
|
||||
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
|
||||
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
|
||||
|
||||
Reference in New Issue
Block a user