Merge pull request #128 from jayll1303/feat/optimize-data-loader

perf: optimize dataset length calculation via batch column access
This commit is contained in:
xliucs
2025-12-20 14:19:35 +08:00
committed by GitHub

View File

@ -70,25 +70,28 @@ def compute_sample_lengths(
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size) t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2 - 序列总长约为: text_len + t_seq + 2
Optimized: Use batch column access instead of iterating item by item.
""" """
lengths: List[int] = [] # Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
has_duration = "duration" in ds.column_names has_duration = "duration" in ds.column_names
if has_duration:
for i in range(len(ds)): durations = ds["duration"]
item = ds[i] else:
text_len = len(item["text_ids"]) # Fallback: need to compute from audio (slow, but unavoidable without duration column)
durations = []
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用) for i in range(len(ds)):
if has_duration: audio = ds[i][DEFAULT_AUDIO_COLUMN]
duration = float(item["duration"]) durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
else:
audio = item[DEFAULT_AUDIO_COLUMN] # Compute each sample's total length from the pre-fetched columns (plain Python loop, not vectorized)
duration = len(audio["array"]) / float(audio["sampling_rate"]) lengths = []
for text_len, duration in zip(text_lens, durations):
t_vae = math.ceil(duration * audio_vae_fps) t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size) t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2 total_len = text_len + t_seq + 2
lengths.append(total_len) lengths.append(total_len)
@ -211,4 +214,3 @@ def build_dataloader(
collate_fn=HFVoxCPMDataset.collate_fn, collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last, drop_last=drop_last,
) )