OPTIMIZE: Improve sample length computation by using batch column access
This commit is contained in:
@ -70,25 +70,28 @@ def compute_sample_lengths(
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
|
||||
Optimized: read whole columns in one batch instead of indexing item by item (a per-item fallback remains only when there is no duration column).
|
||||
"""
|
||||
lengths: List[int] = []
|
||||
# Read whole columns in one batch - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
text_len = len(item["text_ids"])
|
||||
|
||||
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用)
|
||||
if has_duration:
|
||||
duration = float(item["duration"])
|
||||
durations = ds["duration"]
|
||||
else:
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||
durations = []
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
t_vae = math.ceil(duration * audio_vae_fps)
|
||||
# Compute per-sample sequence lengths from the pre-fetched columns
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
|
||||
total_len = text_len + t_seq + 2
|
||||
lengths.append(total_len)
|
||||
|
||||
@ -211,4 +214,3 @@ def build_dataloader(
|
||||
collate_fn=HFVoxCPMDataset.collate_fn,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user