add decode script
This commit is contained in:
45
decode.py
Normal file
45
decode.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from model import FunASRNano
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="FunASR-Nano")
|
||||||
|
parser.add_argument("--scp-file", type=str, required=True)
|
||||||
|
parser.add_argument("--output-file", type=str, required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-dir", type=str, default="FunAudioLLM/Fun-ASR-Nano-2512"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = (
|
||||||
|
"cuda:0"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
m, kwargs = FunASRNano.from_pretrained(model=args.model_dir, device=device)
|
||||||
|
m.eval()
|
||||||
|
|
||||||
|
output_dir = os.path.dirname(args.output_file)
|
||||||
|
if output_dir and not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with open(args.scp_file, "r", encoding="utf-8") as f1:
|
||||||
|
with open(args.output_file, "w", encoding="utf-8") as f2:
|
||||||
|
for line in f1:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(maxsplit=1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
text = m.inference(data_in=[parts[1]], **kwargs)[0][0]["text"]
|
||||||
|
f2.write(f"{parts[0]}\t{text}\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
4
demo1.py
4
demo1.py
@ -7,7 +7,9 @@ def main():
|
|||||||
device = (
|
device = (
|
||||||
"cuda:0"
|
"cuda:0"
|
||||||
if torch.cuda.is_available()
|
if torch.cuda.is_available()
|
||||||
else "mps" if torch.backends.mps.is_available() else "cpu"
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
)
|
)
|
||||||
model = AutoModel(
|
model = AutoModel(
|
||||||
model=model_dir,
|
model=model_dir,
|
||||||
|
|||||||
4
demo2.py
4
demo2.py
@ -8,7 +8,9 @@ def main():
|
|||||||
device = (
|
device = (
|
||||||
"cuda:0"
|
"cuda:0"
|
||||||
if torch.cuda.is_available()
|
if torch.cuda.is_available()
|
||||||
else "mps" if torch.backends.mps.is_available() else "cpu"
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
)
|
)
|
||||||
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
|
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
|
||||||
m.eval()
|
m.eval()
|
||||||
|
|||||||
15
model.py
15
model.py
@ -117,7 +117,9 @@ class FunASRNano(nn.Module):
|
|||||||
if init_param_path is not None:
|
if init_param_path is not None:
|
||||||
src_state = torch.load(init_param_path, map_location="cpu")
|
src_state = torch.load(init_param_path, map_location="cpu")
|
||||||
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
|
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
|
||||||
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
|
logging.info(
|
||||||
|
f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
|
||||||
|
)
|
||||||
freeze = ctc_decoder_conf.get("freeze", False)
|
freeze = ctc_decoder_conf.get("freeze", False)
|
||||||
if freeze:
|
if freeze:
|
||||||
for _, param in self.ctc_decoder.named_parameters():
|
for _, param in self.ctc_decoder.named_parameters():
|
||||||
@ -127,7 +129,12 @@ class FunASRNano(nn.Module):
|
|||||||
ctc_conf = kwargs.get("ctc_conf", {})
|
ctc_conf = kwargs.get("ctc_conf", {})
|
||||||
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
|
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
|
||||||
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
|
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
|
||||||
self.ctc = CTC(odim=ctc_vocab_size, encoder_output_size=audio_encoder_output_size, blank_id=self.blank_id, **ctc_conf)
|
self.ctc = CTC(
|
||||||
|
odim=ctc_vocab_size,
|
||||||
|
encoder_output_size=audio_encoder_output_size,
|
||||||
|
blank_id=self.blank_id,
|
||||||
|
**ctc_conf,
|
||||||
|
)
|
||||||
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
|
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
|
||||||
self.error_calculator = None
|
self.error_calculator = None
|
||||||
|
|
||||||
@ -629,7 +636,7 @@ class FunASRNano(nn.Module):
|
|||||||
with torch.autocast(
|
with torch.autocast(
|
||||||
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
|
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
|
||||||
enabled=True if llm_dtype != "fp32" else False,
|
enabled=True if llm_dtype != "fp32" else False,
|
||||||
dtype=dtype_map[llm_dtype]
|
dtype=dtype_map[llm_dtype],
|
||||||
):
|
):
|
||||||
label = contents["assistant"][-1]
|
label = contents["assistant"][-1]
|
||||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||||
@ -677,7 +684,7 @@ class FunASRNano(nn.Module):
|
|||||||
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
||||||
result_i = {
|
result_i = {
|
||||||
"key": key[0],
|
"key": key[0],
|
||||||
"text": re.sub(r'\s+', ' ', response.replace("/sil", " ")),
|
"text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
|
||||||
"text_tn": response_clean,
|
"text_tn": response_clean,
|
||||||
"label": label,
|
"label": label,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,8 +17,12 @@ def parse_args():
|
|||||||
parser.add_argument("--scp-file", type=str, required=True)
|
parser.add_argument("--scp-file", type=str, required=True)
|
||||||
parser.add_argument("--transcript-file", type=str, required=True)
|
parser.add_argument("--transcript-file", type=str, required=True)
|
||||||
parser.add_argument("--jsonl-file", type=str, required=True)
|
parser.add_argument("--jsonl-file", type=str, required=True)
|
||||||
parser.add_argument("--max-workers", type=int, default=8,
|
parser.add_argument(
|
||||||
help="Number of concurrent workers (default: 8)")
|
"--max-workers",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of concurrent workers (default: 8)",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -59,11 +63,14 @@ class LineProcessor:
|
|||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>"},
|
{
|
||||||
{"role": "assistant", "content": text}
|
"role": "user",
|
||||||
|
"content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": text},
|
||||||
],
|
],
|
||||||
"speech_length": int((duration * 1000 - 25) // 10 + 1),
|
"speech_length": int((duration * 1000 - 25) // 10 + 1),
|
||||||
"text_length": len(self.tokenizer.tokenize(text))
|
"text_length": len(self.tokenizer.tokenize(text)),
|
||||||
}
|
}
|
||||||
return {"success": data, "utt": utt1}
|
return {"success": data, "utt": utt1}
|
||||||
|
|
||||||
@ -79,7 +86,9 @@ def main():
|
|||||||
transcript_lines = f2.readlines()
|
transcript_lines = f2.readlines()
|
||||||
|
|
||||||
if len(scp_lines) != len(transcript_lines):
|
if len(scp_lines) != len(transcript_lines):
|
||||||
print(f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}")
|
print(
|
||||||
|
f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}"
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||||
processor = LineProcessor(tokenizer)
|
processor = LineProcessor(tokenizer)
|
||||||
@ -93,8 +102,10 @@ def main():
|
|||||||
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
|
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
|
||||||
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
|
||||||
with open(args.jsonl_file, "w") as f_out:
|
with open(args.jsonl_file, "w") as f_out:
|
||||||
futures = {executor.submit(processor.process_line, pair): i
|
futures = {
|
||||||
for i, pair in enumerate(data_pairs)}
|
executor.submit(processor.process_line, pair): i
|
||||||
|
for i, pair in enumerate(data_pairs)
|
||||||
|
}
|
||||||
|
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
result = future.result()
|
result = future.result()
|
||||||
@ -109,10 +120,9 @@ def main():
|
|||||||
error_messages.append(result["error"])
|
error_messages.append(result["error"])
|
||||||
|
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
pbar.set_postfix({
|
pbar.set_postfix(
|
||||||
"processed": processed_count,
|
{"processed": processed_count, "failed": failed_count}
|
||||||
"failed": failed_count
|
)
|
||||||
})
|
|
||||||
|
|
||||||
print(f"\nProcessing completed:")
|
print(f"\nProcessing completed:")
|
||||||
print(f" Total lines: {len(data_pairs)}")
|
print(f" Total lines: {len(data_pairs)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user