Merge pull request #128 from jayll1303/feat/optimize-data-loader

perf: optimize dataset length calculation via batch column access
This commit is contained in:
xliucs
2025-12-20 14:19:35 +08:00
committed by GitHub

View File

@ -70,25 +70,28 @@ def compute_sample_lengths(
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size) t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2 - 序列总长约为: text_len + t_seq + 2
Optimized: Use batch column access instead of iterating item by item.
""" """
lengths: List[int] = [] # Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
has_duration = "duration" in ds.column_names has_duration = "duration" in ds.column_names
for i in range(len(ds)):
item = ds[i]
text_len = len(item["text_ids"])
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用)
if has_duration: if has_duration:
duration = float(item["duration"]) durations = ds["duration"]
else: else:
audio = item[DEFAULT_AUDIO_COLUMN] # Fallback: need to compute from audio (slow, but unavoidable without duration column)
duration = len(audio["array"]) / float(audio["sampling_rate"]) durations = []
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
t_vae = math.ceil(duration * audio_vae_fps) # Vectorized length computation
lengths = []
for text_len, duration in zip(text_lens, durations):
t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size) t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2 total_len = text_len + t_seq + 2
lengths.append(total_len) lengths.append(total_len)
@ -211,4 +214,3 @@ def build_dataloader(
collate_fn=HFVoxCPMDataset.collate_fn, collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last, drop_last=drop_last,
) )