diff --git a/src/voxcpm/training/data.py b/src/voxcpm/training/data.py index b2c2631..cd919a6 100644 --- a/src/voxcpm/training/data.py +++ b/src/voxcpm/training/data.py @@ -70,25 +70,28 @@ def compute_sample_lengths( duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae t_seq = ceil(t_vae / patch_size) - 序列总长约为: text_len + t_seq + 2 + + Optimized: Use batch column access instead of iterating item by item. """ - lengths: List[int] = [] - + # Batch access columns - much faster than per-item access + text_ids_list = ds["text_ids"] + text_lens = [len(t) for t in text_ids_list] + has_duration = "duration" in ds.column_names - - for i in range(len(ds)): - item = ds[i] - text_len = len(item["text_ids"]) - - # 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用) - if has_duration: - duration = float(item["duration"]) - else: - audio = item[DEFAULT_AUDIO_COLUMN] - duration = len(audio["array"]) / float(audio["sampling_rate"]) - - t_vae = math.ceil(duration * audio_vae_fps) + if has_duration: + durations = ds["duration"] + else: + # Fallback: need to compute from audio (slow, but unavoidable without duration column) + durations = [] + for i in range(len(ds)): + audio = ds[i][DEFAULT_AUDIO_COLUMN] + durations.append(len(audio["array"]) / float(audio["sampling_rate"])) + + # Vectorized length computation + lengths = [] + for text_len, duration in zip(text_lens, durations): + t_vae = math.ceil(float(duration) * audio_vae_fps) t_seq = math.ceil(t_vae / patch_size) - total_len = text_len + t_seq + 2 lengths.append(total_len) @@ -211,4 +214,3 @@ def build_dataloader( collate_fn=HFVoxCPMDataset.collate_fn, drop_last=drop_last, ) -