Use Hydra instead of argparse for command-line configuration

This commit is contained in:
pengzhendong
2026-01-07 17:02:10 +08:00
parent 82c1e6afc8
commit 4506e00e9c
2 changed files with 21 additions and 22 deletions

View File

@ -3,7 +3,7 @@ import os
import hydra import hydra
import torch import torch
from funasr import AutoModel from funasr import AutoModel
from omegaconf import DictConfig, OmegaConf, ListConfig from omegaconf import DictConfig, ListConfig, OmegaConf
@hydra.main(config_name=None, version_base=None) @hydra.main(config_name=None, version_base=None)

View File

@ -1,4 +1,4 @@
import argparse import hydra
import json import json
import os import os
import threading import threading
@ -10,20 +10,7 @@ from urllib.request import urlopen
import soundfile as sf import soundfile as sf
from modelscope import AutoTokenizer from modelscope import AutoTokenizer
from tqdm import tqdm from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--scp-file", type=str, required=True)
parser.add_argument("--transcript-file", type=str, required=True)
parser.add_argument("--jsonl-file", type=str, required=True)
parser.add_argument(
"--max-workers",
type=int,
default=8,
help="Number of concurrent workers (default: 8)",
)
return parser.parse_args()
class LineProcessor: class LineProcessor:
@ -78,10 +65,22 @@ class LineProcessor:
return {"error": f"Error processing {wav_path}: {str(e)}"} return {"error": f"Error processing {wav_path}: {str(e)}"}
def main(): @hydra.main(config_name=None, version_base=None)
args = parse_args() def main_hydra(cfg: DictConfig):
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item
kwargs = to_plain_list(cfg)
scp_file = kwargs["scp_file"]
transcript_file = kwargs["transcript_file"]
max_workers = kwargs.get("max_workers", os.cpu_count())
jsonl_file = kwargs["jsonl_file"]
with open(args.scp_file, "r") as f1, open(args.transcript_file, "r") as f2: with open(scp_file, "r") as f1, open(transcript_file, "r") as f2:
scp_lines = f1.readlines() scp_lines = f1.readlines()
transcript_lines = f2.readlines() transcript_lines = f2.readlines()
@ -100,8 +99,8 @@ def main():
error_messages = [] error_messages = []
with tqdm(total=len(data_pairs), desc="Processing") as pbar: with tqdm(total=len(data_pairs), desc="Processing") as pbar:
with ThreadPoolExecutor(max_workers=args.max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
with open(args.jsonl_file, "w") as f_out: with open(jsonl_file, "w") as f_out:
futures = { futures = {
executor.submit(processor.process_line, pair): i executor.submit(processor.process_line, pair): i
for i, pair in enumerate(data_pairs) for i, pair in enumerate(data_pairs)
@ -141,4 +140,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main_hydra()