use hydra instead of argparse
This commit is contained in:
@ -3,7 +3,7 @@ import os
|
|||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
from omegaconf import DictConfig, OmegaConf, ListConfig
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_name=None, version_base=None)
|
@hydra.main(config_name=None, version_base=None)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import argparse
|
import hydra
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@ -10,20 +10,7 @@ from urllib.request import urlopen
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from modelscope import AutoTokenizer
|
from modelscope import AutoTokenizer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from omegaconf import DictConfig, OmegaConf, ListConfig
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--scp-file", type=str, required=True)
|
|
||||||
parser.add_argument("--transcript-file", type=str, required=True)
|
|
||||||
parser.add_argument("--jsonl-file", type=str, required=True)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-workers",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="Number of concurrent workers (default: 8)",
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
class LineProcessor:
|
class LineProcessor:
|
||||||
@ -78,10 +65,22 @@ class LineProcessor:
|
|||||||
return {"error": f"Error processing {wav_path}: {str(e)}"}
|
return {"error": f"Error processing {wav_path}: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
@hydra.main(config_name=None, version_base=None)
|
||||||
args = parse_args()
|
def main_hydra(cfg: DictConfig):
|
||||||
|
def to_plain_list(cfg_item):
|
||||||
|
if isinstance(cfg_item, ListConfig):
|
||||||
|
return OmegaConf.to_container(cfg_item, resolve=True)
|
||||||
|
elif isinstance(cfg_item, DictConfig):
|
||||||
|
return {k: to_plain_list(v) for k, v in cfg_item.items()}
|
||||||
|
else:
|
||||||
|
return cfg_item
|
||||||
|
kwargs = to_plain_list(cfg)
|
||||||
|
scp_file = kwargs["scp_file"]
|
||||||
|
transcript_file = kwargs["transcript_file"]
|
||||||
|
max_workers = kwargs.get("max_workers", os.cpu_count())
|
||||||
|
jsonl_file = kwargs["jsonl_file"]
|
||||||
|
|
||||||
with open(args.scp_file, "r") as f1, open(args.transcript_file, "r") as f2:
|
with open(scp_file, "r") as f1, open(transcript_file, "r") as f2:
|
||||||
scp_lines = f1.readlines()
|
scp_lines = f1.readlines()
|
||||||
transcript_lines = f2.readlines()
|
transcript_lines = f2.readlines()
|
||||||
|
|
||||||
@ -100,8 +99,8 @@ def main():
|
|||||||
error_messages = []
|
error_messages = []
|
||||||
|
|
||||||
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
|
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
|
||||||
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
with open(args.jsonl_file, "w") as f_out:
|
with open(jsonl_file, "w") as f_out:
|
||||||
futures = {
|
futures = {
|
||||||
executor.submit(processor.process_line, pair): i
|
executor.submit(processor.process_line, pair): i
|
||||||
for i, pair in enumerate(data_pairs)
|
for i, pair in enumerate(data_pairs)
|
||||||
@ -141,4 +140,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main_hydra()
|
||||||
|
|||||||
Reference in New Issue
Block a user