use hydra instead of argparse

pengzhendong
2026-01-06 15:19:42 +08:00
parent e068b29c23
commit 82c1e6afc8

View File

@@ -1,19 +1,25 @@
-import argparse
 import os
 
+import hydra
 import torch
-
-from model import FunASRNano
+from funasr import AutoModel
+from omegaconf import DictConfig, OmegaConf, ListConfig
 
 
-def main():
-    parser = argparse.ArgumentParser(description="FunASR-Nano")
-    parser.add_argument("--scp-file", type=str, required=True)
-    parser.add_argument("--output-file", type=str, required=True)
-    parser.add_argument(
-        "--model-dir", type=str, default="FunAudioLLM/Fun-ASR-Nano-2512"
-    )
-    args = parser.parse_args()
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+    def to_plain_list(cfg_item):
+        if isinstance(cfg_item, ListConfig):
+            return OmegaConf.to_container(cfg_item, resolve=True)
+        elif isinstance(cfg_item, DictConfig):
+            return {k: to_plain_list(v) for k, v in cfg_item.items()}
+        else:
+            return cfg_item
+
+    kwargs = to_plain_list(cfg)
+    model_dir = kwargs.get("model_dir", "FunAudioLLM/Fun-ASR-Nano-2512")
+    scp_file = kwargs["scp_file"]
+    output_file = kwargs["output_file"]
 
     device = (
         "cuda:0"
@@ -22,24 +28,30 @@ def main():
         if torch.backends.mps.is_available()
         else "cpu"
     )
-    m, kwargs = FunASRNano.from_pretrained(model=args.model_dir, device=device)
-    m.eval()
+    model = AutoModel(
+        model=model_dir,
+        trust_remote_code=True,
+        vad_model="fsmn-vad",
+        vad_kwargs={"max_single_segment_time": 30000},
+        remote_code="./model.py",
+        device=device,
+    )
 
-    output_dir = os.path.dirname(args.output_file)
+    output_dir = os.path.dirname(output_file)
     if output_dir and not os.path.exists(output_dir):
         os.makedirs(output_dir, exist_ok=True)
-    with open(args.scp_file, "r", encoding="utf-8") as f1:
-        with open(args.output_file, "w", encoding="utf-8") as f2:
+    with open(scp_file, "r", encoding="utf-8") as f1:
+        with open(output_file, "w", encoding="utf-8") as f2:
             for line in f1:
                 line = line.strip()
                 if not line:
                     continue
                 parts = line.split(maxsplit=1)
                 if len(parts) == 2:
-                    text = m.inference(data_in=[parts[1]], **kwargs)[0][0]["text"]
+                    text = model.generate(input=[parts[1]], cache={}, batch_size=1)[0]["text"]
                     f2.write(f"{parts[0]}\t{text}\n")
 
 
 if __name__ == "__main__":
-    main()
+    main_hydra()
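
Because the entry point is declared with @hydra.main(config_name=None, version_base=None), no YAML config is loaded; every setting now reaches main_hydra as a Hydra command-line override (key=value) rather than an argparse flag, and to_plain_list converts the resulting DictConfig into plain Python objects before use. A minimal sketch of the new interface, assuming the script is saved as infer.py (the filename, paths, and values below are illustrative, not part of this commit):

# Old invocation (argparse flags):
#   python infer.py --scp-file data/wav.scp --output-file exp/text
# New invocation (Hydra overrides, key=value with underscores):
#   python infer.py scp_file=data/wav.scp output_file=exp/text model_dir=FunAudioLLM/Fun-ASR-Nano-2512
#
# Inside main_hydra the overrides arrive as an OmegaConf DictConfig;
# to_plain_list then unwraps it, which for a flat config like this is
# equivalent to:
from omegaconf import OmegaConf

cfg = OmegaConf.from_dotlist(["scp_file=data/wav.scp", "output_file=exp/text"])
kwargs = OmegaConf.to_container(cfg, resolve=True)
print(kwargs)  # {'scp_file': 'data/wav.scp', 'output_file': 'exp/text'}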