add decode script

This commit is contained in:
pengzhendong
2026-01-04 15:12:17 +08:00
parent c35ce7601a
commit 403d305df9
5 changed files with 85 additions and 19 deletions

View File

@ -17,8 +17,12 @@ def parse_args():
parser.add_argument("--scp-file", type=str, required=True)
parser.add_argument("--transcript-file", type=str, required=True)
parser.add_argument("--jsonl-file", type=str, required=True)
parser.add_argument("--max-workers", type=int, default=8,
help="Number of concurrent workers (default: 8)")
parser.add_argument(
"--max-workers",
type=int,
default=8,
help="Number of concurrent workers (default: 8)",
)
return parser.parse_args()
@ -59,11 +63,14 @@ class LineProcessor:
data = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>"},
{"role": "assistant", "content": text}
{
"role": "user",
"content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>",
},
{"role": "assistant", "content": text},
],
"speech_length": int((duration * 1000 - 25) // 10 + 1),
"text_length": len(self.tokenizer.tokenize(text))
"text_length": len(self.tokenizer.tokenize(text)),
}
return {"success": data, "utt": utt1}
@ -79,7 +86,9 @@ def main():
transcript_lines = f2.readlines()
if len(scp_lines) != len(transcript_lines):
print(f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}")
print(
f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
processor = LineProcessor(tokenizer)
@ -93,8 +102,10 @@ def main():
with tqdm(total=len(data_pairs), desc="Processing") as pbar:
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
with open(args.jsonl_file, "w") as f_out:
futures = {executor.submit(processor.process_line, pair): i
for i, pair in enumerate(data_pairs)}
futures = {
executor.submit(processor.process_line, pair): i
for i, pair in enumerate(data_pairs)
}
for future in as_completed(futures):
result = future.result()
@ -109,10 +120,9 @@ def main():
error_messages.append(result["error"])
pbar.update(1)
pbar.set_postfix({
"processed": processed_count,
"failed": failed_count
})
pbar.set_postfix(
{"processed": processed_count, "failed": failed_count}
)
print(f"\nProcessing completed:")
print(f" Total lines: {len(data_pairs)}")