use hydra instead of argparse

This commit is contained in:
pengzhendong
2026-01-07 17:02:10 +08:00
parent 82c1e6afc8
commit 4506e00e9c
2 changed files with 21 additions and 22 deletions

View File

@ -1,4 +1,4 @@
import argparse
import hydra
import json
import os
import threading
@ -10,20 +10,7 @@ from urllib.request import urlopen
import soundfile as sf
from modelscope import AutoTokenizer
from tqdm import tqdm
def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` call sites behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``scp_file``,
        ``transcript_file`` and ``jsonl_file`` (required strings) and
        ``max_workers`` (int, defaults to 8).
    """
    parser = argparse.ArgumentParser()
    # All three paths are mandatory; argparse exits with a usage error
    # when any of them is missing.
    parser.add_argument("--scp-file", type=str, required=True)
    parser.add_argument("--transcript-file", type=str, required=True)
    parser.add_argument("--jsonl-file", type=str, required=True)
    parser.add_argument(
        "--max-workers",
        type=int,
        default=8,
        help="Number of concurrent workers (default: 8)",
    )
    return parser.parse_args(argv)
from omegaconf import DictConfig, OmegaConf, ListConfig
class LineProcessor:
@ -78,10 +65,22 @@ class LineProcessor:
return {"error": f"Error processing {wav_path}: {str(e)}"}
def main():
args = parse_args()
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
def to_plain_list(cfg_item):
    """Convert an OmegaConf node into plain Python containers.

    Despite the name, this handles more than lists:
      * a ``ListConfig`` is converted in one step through
        ``OmegaConf.to_container`` with ``resolve=True``;
      * a ``DictConfig`` becomes a ``dict`` whose values are converted
        recursively by this same function;
      * any other value is returned unchanged.
    """
    if isinstance(cfg_item, ListConfig):
        return OmegaConf.to_container(cfg_item, resolve=True)
    if isinstance(cfg_item, DictConfig):
        plain = {}
        for key, value in cfg_item.items():
            plain[key] = to_plain_list(value)
        return plain
    return cfg_item
kwargs = to_plain_list(cfg)
scp_file = kwargs["scp_file"]
transcript_file = kwargs["transcript_file"]
max_workers = kwargs.get("max_workers", os.cpu_count())
jsonl_file = kwargs["jsonl_file"]
with open(args.scp_file, "r") as f1, open(args.transcript_file, "r") as f2:
with open(scp_file, "r") as f1, open(transcript_file, "r") as f2:
scp_lines = f1.readlines()
transcript_lines = f2.readlines()
@ -100,8 +99,8 @@ def main():
error_messages = []
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
with open(args.jsonl_file, "w") as f_out:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
with open(jsonl_file, "w") as f_out:
futures = {
executor.submit(processor.process_line, pair): i
for i, pair in enumerate(data_pairs)
@ -141,4 +140,4 @@ def main():
if __name__ == "__main__":
main()
main_hydra()