Merge pull request #128 from jayll1303/feat/optimize-data-loader
perf: optimize dataset length calculation via batch column access
This commit is contained in:
@ -70,25 +70,28 @@ def compute_sample_lengths(
|
|||||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||||
t_seq = ceil(t_vae / patch_size)
|
t_seq = ceil(t_vae / patch_size)
|
||||||
- 序列总长约为: text_len + t_seq + 2
|
- 序列总长约为: text_len + t_seq + 2
|
||||||
|
|
||||||
|
Optimized: Use batch column access instead of iterating item by item.
|
||||||
"""
|
"""
|
||||||
lengths: List[int] = []
|
# Batch access columns - much faster than per-item access
|
||||||
|
text_ids_list = ds["text_ids"]
|
||||||
|
text_lens = [len(t) for t in text_ids_list]
|
||||||
|
|
||||||
has_duration = "duration" in ds.column_names
|
has_duration = "duration" in ds.column_names
|
||||||
|
if has_duration:
|
||||||
|
durations = ds["duration"]
|
||||||
|
else:
|
||||||
|
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||||
|
durations = []
|
||||||
|
for i in range(len(ds)):
|
||||||
|
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||||
|
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||||
|
|
||||||
for i in range(len(ds)):
|
# Vectorized length computation
|
||||||
item = ds[i]
|
lengths = []
|
||||||
text_len = len(item["text_ids"])
|
for text_len, duration in zip(text_lens, durations):
|
||||||
|
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||||
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用)
|
|
||||||
if has_duration:
|
|
||||||
duration = float(item["duration"])
|
|
||||||
else:
|
|
||||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
|
||||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
|
||||||
|
|
||||||
t_vae = math.ceil(duration * audio_vae_fps)
|
|
||||||
t_seq = math.ceil(t_vae / patch_size)
|
t_seq = math.ceil(t_vae / patch_size)
|
||||||
|
|
||||||
total_len = text_len + t_seq + 2
|
total_len = text_len + t_seq + 2
|
||||||
lengths.append(total_len)
|
lengths.append(total_len)
|
||||||
|
|
||||||
@ -211,4 +214,3 @@ def build_dataloader(
|
|||||||
collate_fn=HFVoxCPMDataset.collate_fn,
|
collate_fn=HFVoxCPMDataset.collate_fn,
|
||||||
drop_last=drop_last,
|
drop_last=drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user