From 82c1e6afc8bab5ec0d35201a1ad6404e076d3959 Mon Sep 17 00:00:00 2001
From: pengzhendong <275331498@qq.com>
Date: Tue, 6 Jan 2026 15:19:42 +0800
Subject: [PATCH] use hydra instead of argparse

---
 decode.py | 48 ++++++++++++++++++++++++++++++------------------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/decode.py b/decode.py
index 815dba9..9f62fef 100644
--- a/decode.py
+++ b/decode.py
@@ -1,19 +1,25 @@
-import argparse
 import os
 
+import hydra
 import torch
-
-from model import FunASRNano
+from funasr import AutoModel
+from omegaconf import DictConfig, OmegaConf, ListConfig
 
 
-def main():
-    parser = argparse.ArgumentParser(description="FunASR-Nano")
-    parser.add_argument("--scp-file", type=str, required=True)
-    parser.add_argument("--output-file", type=str, required=True)
-    parser.add_argument(
-        "--model-dir", type=str, default="FunAudioLLM/Fun-ASR-Nano-2512"
-    )
-    args = parser.parse_args()
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+    def to_plain_list(cfg_item):
+        if isinstance(cfg_item, ListConfig):
+            return OmegaConf.to_container(cfg_item, resolve=True)
+        elif isinstance(cfg_item, DictConfig):
+            return {k: to_plain_list(v) for k, v in cfg_item.items()}
+        else:
+            return cfg_item
 
+    kwargs = to_plain_list(cfg)
+
+    model_dir = kwargs.get("model_dir", "FunAudioLLM/Fun-ASR-Nano-2512")
+    scp_file = kwargs["scp_file"]
+    output_file = kwargs["output_file"]
 
     device = (
         "cuda:0"
@@ -22,24 +28,30 @@ def main():
         if torch.backends.mps.is_available()
         else "cpu"
     )
-    m, kwargs = FunASRNano.from_pretrained(model=args.model_dir, device=device)
-    m.eval()
+    model = AutoModel(
+        model=model_dir,
+        trust_remote_code=True,
+        vad_model="fsmn-vad",
+        vad_kwargs={"max_single_segment_time": 30000},
+        remote_code="./model.py",
+        device=device,
+    )
 
-    output_dir = os.path.dirname(args.output_file)
+    output_dir = os.path.dirname(output_file)
     if output_dir and not os.path.exists(output_dir):
         os.makedirs(output_dir, exist_ok=True)
 
-    with open(args.scp_file, "r", encoding="utf-8") as f1:
-        with open(args.output_file, "w", encoding="utf-8") as f2:
+    with open(scp_file, "r", encoding="utf-8") as f1:
+        with open(output_file, "w", encoding="utf-8") as f2:
             for line in f1:
                 line = line.strip()
                 if not line:
                     continue
                 parts = line.split(maxsplit=1)
                 if len(parts) == 2:
-                    text = m.inference(data_in=[parts[1]], **kwargs)[0][0]["text"]
+                    text = model.generate(input=[parts[1]], cache={}, batch_size=1)[0]["text"]
                     f2.write(f"{parts[0]}\t{text}\n")
 
 
 if __name__ == "__main__":
-    main()
+    main_hydra()