use hydra instead of argparse
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
import argparse
|
||||
import hydra
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
@ -10,20 +10,7 @@ from urllib.request import urlopen
|
||||
import soundfile as sf
|
||||
from modelscope import AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--scp-file", type=str, required=True)
|
||||
parser.add_argument("--transcript-file", type=str, required=True)
|
||||
parser.add_argument("--jsonl-file", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of concurrent workers (default: 8)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
from omegaconf import DictConfig, OmegaConf, ListConfig
|
||||
|
||||
|
||||
class LineProcessor:
|
||||
@ -78,10 +65,22 @@ class LineProcessor:
|
||||
return {"error": f"Error processing {wav_path}: {str(e)}"}
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
@hydra.main(config_name=None, version_base=None)
|
||||
def main_hydra(cfg: DictConfig):
|
||||
def to_plain_list(cfg_item):
|
||||
if isinstance(cfg_item, ListConfig):
|
||||
return OmegaConf.to_container(cfg_item, resolve=True)
|
||||
elif isinstance(cfg_item, DictConfig):
|
||||
return {k: to_plain_list(v) for k, v in cfg_item.items()}
|
||||
else:
|
||||
return cfg_item
|
||||
kwargs = to_plain_list(cfg)
|
||||
scp_file = kwargs["scp_file"]
|
||||
transcript_file = kwargs["transcript_file"]
|
||||
max_workers = kwargs.get("max_workers", os.cpu_count())
|
||||
jsonl_file = kwargs["jsonl_file"]
|
||||
|
||||
with open(args.scp_file, "r") as f1, open(args.transcript_file, "r") as f2:
|
||||
with open(scp_file, "r") as f1, open(transcript_file, "r") as f2:
|
||||
scp_lines = f1.readlines()
|
||||
transcript_lines = f2.readlines()
|
||||
|
||||
@ -100,8 +99,8 @@ def main():
|
||||
error_messages = []
|
||||
|
||||
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
|
||||
with open(args.jsonl_file, "w") as f_out:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
with open(jsonl_file, "w") as f_out:
|
||||
futures = {
|
||||
executor.submit(processor.process_line, pair): i
|
||||
for i, pair in enumerate(data_pairs)
|
||||
@ -141,4 +140,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main_hydra()
|
||||
|
||||
Reference in New Issue
Block a user