diff --git a/decode.py b/decode.py
new file mode 100644
index 0000000..815dba9
--- /dev/null
+++ b/decode.py
@@ -0,0 +1,45 @@
+import argparse
+import os
+
+import torch
+
+from model import FunASRNano
+
+
+def main():
+    parser = argparse.ArgumentParser(description="FunASR-Nano")
+    parser.add_argument("--scp-file", type=str, required=True)
+    parser.add_argument("--output-file", type=str, required=True)
+    parser.add_argument(
+        "--model-dir", type=str, default="FunAudioLLM/Fun-ASR-Nano-2512"
+    )
+    args = parser.parse_args()
+
+    device = (
+        "cuda:0"
+        if torch.cuda.is_available()
+        else "mps"
+        if torch.backends.mps.is_available()
+        else "cpu"
+    )
+    m, kwargs = FunASRNano.from_pretrained(model=args.model_dir, device=device)
+    m.eval()
+
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir and not os.path.exists(output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+
+    with open(args.scp_file, "r", encoding="utf-8") as f1:
+        with open(args.output_file, "w", encoding="utf-8") as f2:
+            for line in f1:
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split(maxsplit=1)
+                if len(parts) == 2:
+                    text = m.inference(data_in=[parts[1]], **kwargs)[0][0]["text"]
+                    f2.write(f"{parts[0]}\t{text}\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/demo1.py b/demo1.py
index 5f8a67e..f99181f 100644
--- a/demo1.py
+++ b/demo1.py
@@ -7,7 +7,9 @@ def main():
     device = (
         "cuda:0"
         if torch.cuda.is_available()
-        else "mps" if torch.backends.mps.is_available() else "cpu"
+        else "mps"
+        if torch.backends.mps.is_available()
+        else "cpu"
     )
     model = AutoModel(
         model=model_dir,
@@ -28,7 +30,7 @@
         # 匈牙利语、爱尔兰语、拉脱维亚语、立陶宛语、马耳他语、波兰语、葡萄牙语、罗马尼亚语、
         # 斯洛伐克语、斯洛文尼亚语、瑞典语 for Fun-ASR-MLT-Nano-2512
         language="中文",
-        itn=True, # or False
+        itn=True,  # or False
     )
     text = res[0]["text"]
     print(text)
diff --git a/demo2.py b/demo2.py
index cb54702..f161d4d 100644
--- a/demo2.py
+++ b/demo2.py
@@ -8,7 +8,9 @@ def main():
     device = (
         "cuda:0"
         if torch.cuda.is_available()
-        else "mps" if torch.backends.mps.is_available() else "cpu"
+        else "mps"
+        if torch.backends.mps.is_available()
+        else "cpu"
     )
     m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
    m.eval()
diff --git a/model.py b/model.py
index 4f2275a..82d9f19 100644
--- a/model.py
+++ b/model.py
@@ -117,7 +117,9 @@ class FunASRNano(nn.Module):
         if init_param_path is not None:
             src_state = torch.load(init_param_path, map_location="cpu")
             flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
-            logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
+            logging.info(
+                f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
+            )
             freeze = ctc_decoder_conf.get("freeze", False)
             if freeze:
                 for _, param in self.ctc_decoder.named_parameters():
@@ -127,7 +129,12 @@ class FunASRNano(nn.Module):
         ctc_conf = kwargs.get("ctc_conf", {})
         self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
         self.ctc_weight = kwargs.get("ctc_weight", 0.3)
-        self.ctc = CTC(odim=ctc_vocab_size, encoder_output_size=audio_encoder_output_size, blank_id=self.blank_id, **ctc_conf)
+        self.ctc = CTC(
+            odim=ctc_vocab_size,
+            encoder_output_size=audio_encoder_output_size,
+            blank_id=self.blank_id,
+            **ctc_conf,
+        )
         self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
         self.error_calculator = None
 
@@ -629,7 +636,7 @@ class FunASRNano(nn.Module):
         with torch.autocast(
             device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
             enabled=True if llm_dtype != "fp32" else False,
-            dtype=dtype_map[llm_dtype]
+            dtype=dtype_map[llm_dtype],
         ):
             label = contents["assistant"][-1]
             self.llm = self.llm.to(dtype_map[llm_dtype])
@@ -677,7 +684,7 @@ class FunASRNano(nn.Module):
             response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
             result_i = {
                 "key": key[0],
-                "text": re.sub(r'\s+', ' ', response.replace("/sil", " ")),
+                "text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
                 "text_tn": response_clean,
                 "label": label,
             }
diff --git a/tools/scp2jsonl.py b/tools/scp2jsonl.py
index e7a1de3..87d8348 100644
--- a/tools/scp2jsonl.py
+++ b/tools/scp2jsonl.py
@@ -17,8 +17,12 @@ def parse_args():
     parser.add_argument("--scp-file", type=str, required=True)
     parser.add_argument("--transcript-file", type=str, required=True)
     parser.add_argument("--jsonl-file", type=str, required=True)
-    parser.add_argument("--max-workers", type=int, default=8,
-                        help="Number of concurrent workers (default: 8)")
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=8,
+        help="Number of concurrent workers (default: 8)",
+    )
     return parser.parse_args()
 
 
@@ -59,11 +63,14 @@ class LineProcessor:
         data = {
             "messages": [
                 {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>"},
-                {"role": "assistant", "content": text}
+                {
+                    "role": "user",
+                    "content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>",
+                },
+                {"role": "assistant", "content": text},
             ],
             "speech_length": int((duration * 1000 - 25) // 10 + 1),
-            "text_length": len(self.tokenizer.tokenize(text))
+            "text_length": len(self.tokenizer.tokenize(text)),
         }
 
         return {"success": data, "utt": utt1}
@@ -79,7 +86,9 @@ def main():
         transcript_lines = f2.readlines()
 
     if len(scp_lines) != len(transcript_lines):
-        print(f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}")
+        print(
+            f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}"
+        )
 
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
     processor = LineProcessor(tokenizer)
@@ -93,8 +102,10 @@ def main():
     with tqdm(total=len(data_pairs), desc="Processing") as pbar:
         with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
             with open(args.jsonl_file, "w") as f_out:
-                futures = {executor.submit(processor.process_line, pair): i
-                           for i, pair in enumerate(data_pairs)}
+                futures = {
+                    executor.submit(processor.process_line, pair): i
+                    for i, pair in enumerate(data_pairs)
+                }
 
                 for future in as_completed(futures):
                     result = future.result()
@@ -109,10 +120,9 @@ def main():
                         error_messages.append(result["error"])
 
                     pbar.update(1)
-                    pbar.set_postfix({
-                        "processed": processed_count,
-                        "failed": failed_count
-                    })
+                    pbar.set_postfix(
+                        {"processed": processed_count, "failed": failed_count}
+                    )
 
     print(f"\nProcessing completed:")
     print(f"  Total lines: {len(data_pairs)}")