Optimize logging of validation set results to TensorBoard
@@ -17,6 +17,8 @@ from transformers import get_cosine_schedule_with_warmup
 import signal
+import os
 
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
 try:
     from safetensors.torch import save_file
     SAFETENSORS_AVAILABLE = True
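The new `TOKENIZERS_PARALLELISM` line silences the HuggingFace tokenizers fork warning, which fires when a fast (Rust-backed) tokenizer has already used its thread pool before DataLoader workers fork. The variable has to be set before the tokenizer does any work; a minimal standalone sketch (the `gpt2` checkpoint is just an example):

```python
import os

# Must be set before the fast tokenizer does any work, or each
# forked DataLoader worker prints a parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import AutoTokenizer  # noqa: E402

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
print(tokenizer("Hello, world!")["input_ids"])
```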
@@ -97,7 +99,10 @@ def train(
         return {"text_ids": text_ids}
 
     train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
+    # Save original validation texts for audio generation display
+    val_texts = None
     if val_ds is not None:
+        val_texts = list(val_ds["text"])  # Save original texts
         val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
 
     dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
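Context for the `val_texts` change: `map(..., remove_columns=["text"])` drops the raw strings once tokenization runs, so anything that later needs them (here, captioning the generated audio) must copy them out first. A self-contained sketch of the pattern, with a toy tokenizer standing in for the real one:

```python
from datasets import Dataset

val_ds = Dataset.from_dict({"text": ["hello there", "general kenobi"]})

# Copy the raw strings before map() removes the column.
val_texts = list(val_ds["text"])

def tokenize(batch):
    # Stand-in tokenizer: character codes instead of a real vocab.
    return {"text_ids": [[ord(c) for c in t] for t in batch["text"]]}

val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
print(val_ds.column_names)  # ['text_ids'] -- 'text' is gone
print(val_texts[0])         # original string is still available
```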
@@ -154,6 +159,8 @@ def train(
         dataset_cnt=dataset_cnt,
         device=accelerator.device,
     )
+    # Save audio_vae for audio generation
+    audio_vae_for_gen = base_model.audio_vae
     del base_model.audio_vae
     model = accelerator.prepare_model(base_model)
     unwrapped_model = accelerator.unwrap(model)
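Context for the `audio_vae_for_gen` change: keeping a reference and then `del`-ing the attribute means the frozen VAE never enters whatever wrapping `prepare_model` applies (DDP buckets, checkpointed state), yet it stays available for decoding at validation time. A toy illustration with plain `nn.Module`s (`TTSModel` is a made-up stand-in):

```python
import torch.nn as nn

class TTSModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # trained
        self.audio_vae = nn.Linear(8, 8)  # frozen, inference-only

base_model = TTSModel()

# Keep a handle to the VAE, then detach it so the wrapper
# only ever sees the trainable part of the model.
audio_vae_for_gen = base_model.audio_vae
del base_model.audio_vae  # nn.Module removes it from _modules

assert not hasattr(base_model, "audio_vae")
assert audio_vae_for_gen is not None  # still usable for decoding later
```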
@@ -287,7 +294,7 @@ def train(
             accelerator.update()
             scheduler.step()
 
-        if step % log_interval == 0:
+        if step % log_interval == 0 or step == num_iters - 1:
             loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
             loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
             # Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
@@ -296,10 +303,13 @@ def train(
                 loss_values["grad_norm"] = float(grad_norm)
             tracker.log_metrics(loss_values, split="train")
 
-        if val_loader is not None and step % valid_interval == 0 and step != 0:
-            validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
+        if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
+            validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+                     writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
+                     sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
+                     valid_interval=valid_interval)
 
-        if step % save_interval == 0 and accelerator.rank == 0:
+        if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
             save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
 
     if accelerator.rank == 0:
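Both this hunk and the previous one extend each periodic trigger with `step == num_iters - 1`, so a run whose length is not a multiple of the interval still ends with a final log, validation pass, and checkpoint. The pattern, factored into a hypothetical helper for clarity:

```python
def due(step: int, interval: int, num_iters: int) -> bool:
    """True on every `interval`-th step and always on the final step."""
    return step % interval == 0 or step == num_iters - 1

# With num_iters=1003 and interval=100, the final step 1002 still fires:
assert due(1002, 100, 1003)
assert not due(1001, 100, 1003)
```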
@@ -308,9 +318,16 @@ def train(
         writer.close()
 
 
-def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
+def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+             writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
+             val_texts=None, tokenizer=None, valid_interval=1000):
+    """Validate and generate sample audio"""
+    import numpy as np
+    from collections import defaultdict
+
     model.eval()
-    losses = []
+    total_losses = []
+    sub_losses = defaultdict(list)  # Track individual sub-losses
     num_batches = 0
     max_val_batches = 10
 
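The `defaultdict(list)` replaces the single `losses` list so each `loss/*` component can be tracked by name without declaring the components up front. The accumulate-then-average pattern in isolation:

```python
from collections import defaultdict
import torch

sub_losses = defaultdict(list)  # loss name -> list of per-batch tensors
batches = [
    {"loss/diffusion": torch.tensor(0.9), "loss/stop": torch.tensor(0.2)},
    {"loss/diffusion": torch.tensor(0.7), "loss/stop": torch.tensor(0.1)},
]
for outputs in batches:
    for key, value in outputs.items():
        if key.startswith("loss/"):
            sub_losses[key].append(value.detach())

# Per-component means, ready to log under their own tags.
means = {k: torch.stack(v).mean().item() for k, v in sub_losses.items()}
print(means)  # {'loss/diffusion': 0.8..., 'loss/stop': 0.15...}
```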
@@ -334,18 +351,179 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
         total = 0.0
         for key, value in outputs.items():
             if key.startswith("loss/"):
-                total += lambdas.get(key, 1.0) * value
-        losses.append(total.detach())
+                weighted_loss = lambdas.get(key, 1.0) * value
+                total += weighted_loss
+                sub_losses[key].append(value.detach())
+        total_losses.append(total.detach())
         num_batches += 1
 
-    if losses:
-        mean_loss = torch.stack(losses).mean()
-        # All-reduce validation loss across processes for global average
-        accelerator.all_reduce(mean_loss)
-        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
+    if total_losses:
+        # Compute mean total loss
+        mean_total_loss = torch.stack(total_losses).mean()
+        accelerator.all_reduce(mean_total_loss)
+
+        # Compute mean of each sub-loss
+        val_metrics = {"loss/total": mean_total_loss.item()}
+        for key, values in sub_losses.items():
+            mean_sub_loss = torch.stack(values).mean()
+            accelerator.all_reduce(mean_sub_loss)
+            val_metrics[key] = mean_sub_loss.item()
+
+        tracker.log_metrics(val_metrics, split="val")
+
+    # Generate sample audio for TensorBoard display
+    if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
+        try:
+            generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
+                                  val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
+                                  tracker=tracker)
+        except Exception as e:
+            tracker.print(f"[Warning] Failed to generate sample audio: {e}")
+            import traceback
+            import io
+            buf = io.StringIO()
+            traceback.print_exc(file=buf)
+            tracker.print(buf.getvalue())
+    else:
+        # Log why audio generation was skipped
+        missing = []
+        if writer is None:
+            missing.append("writer")
+        if val_ds is None:
+            missing.append("val_ds")
+        if audio_vae is None:
+            missing.append("audio_vae")
+        if missing and accelerator.rank == 0:
+            tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
 
     model.train()
+
+
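One caveat worth knowing here: `accelerator.all_reduce` is this codebase's own wrapper, and its reduction op is not visible in the hunk. Raw `torch.distributed.all_reduce` defaults to SUM, so producing a true cross-process mean requires a divide by world size, as in this hedged sketch (assuming the wrapper does not already average):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(t: torch.Tensor) -> torch.Tensor:
    """Average a tensor across ranks; a no-op in single-process runs."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # default op is SUM
        t /= dist.get_world_size()
    return t

local_mean = torch.tensor(1.25)
print(all_reduce_mean(local_mean))  # unchanged when not distributed
```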
+def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
+    """Compute Mel Spectrogram (dB) using librosa"""
+    import numpy as np
+    import librosa
+    audio_np = audio_np.flatten().astype(np.float32)
+    mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
+    return librosa.power_to_db(mel, ref=np.max)
+
+
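`compute_mel_spectrogram` is a thin wrapper over two librosa calls: a power mel spectrogram followed by conversion to decibels relative to the loudest bin. Standalone usage on a synthetic tone, with librosa's default hop and window sizes:

```python
import numpy as np
import librosa

sr = 22050
t = np.linspace(0, 1.0, sr, endpoint=False)
audio = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)  # 1 s of A4

mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128, fmax=sr // 2)
mel_db = librosa.power_to_db(mel, ref=np.max)  # 0 dB at the loudest bin

print(mel_db.shape)  # (128, n_frames)
print(mel_db.max())  # 0.0 by construction of ref=np.max
```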
+def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_np=None, ref_mel=None):
+    """
+    Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
+    """
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+    import librosa.display
+
+    fmax = sample_rate // 2
+    step_str = f" @ Step {step}" if step is not None else ""
+
+    if ref_audio_np is not None and ref_mel is not None:
+        # Comparison mode: reference vs generated
+        fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
+
+        img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
+        ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
+        plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
+
+        img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
+        ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
+        plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
+    else:
+        # Single figure mode: show generated only
+        fig, ax = plt.subplots(figsize=(12, 4))
+        img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
+        ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
+        plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
+
+    plt.tight_layout()
+    return fig
+
+
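`create_mel_figure` relies on `matplotlib.use('Agg')` being called before `pyplot` is imported, so figures render off-screen on a headless trainer; `SummaryWriter.add_figure` then serializes the figure as an image event. A minimal end-to-end sketch (the log directory name is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # select the headless backend before pyplot loads
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot([0, 1, 2], [0.9, 0.5, 0.4])
ax.set_title("val loss (demo)")

writer = SummaryWriter("runs/demo")  # arbitrary log dir
writer.add_figure("val/demo_figure", fig, global_step=0)
writer.close()  # flushes the event file
```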
+def normalize_audio(audio_np):
+    """Normalize audio to [-0.9, 0.9]"""
+    import numpy as np
+    max_val = np.abs(audio_np).max()
+    return audio_np / max_val * 0.9 if max_val > 0 else audio_np
+
+
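Peak-normalizing to ±0.9 leaves headroom below full scale, and the `max_val > 0` guard passes silent buffers through instead of dividing by zero. A quick check of both branches:

```python
import numpy as np

def normalize_audio(audio_np):
    """Scale so the absolute peak sits at 0.9; leave silence alone."""
    max_val = np.abs(audio_np).max()
    return audio_np / max_val * 0.9 if max_val > 0 else audio_np

loud = np.array([0.1, -2.0, 0.5], dtype=np.float32)
print(normalize_audio(loud).min())   # -0.9 (peak mapped to 0.9)
print(normalize_audio(np.zeros(4)))  # [0. 0. 0. 0.] -- no NaNs
```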
+def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
+                          val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
+                          tracker=None):
+    """Select 2 fixed validation samples, generate audio and log to TensorBoard"""
+    import numpy as np
+
+    log = tracker.print if tracker else print
+    num_samples = min(2, len(val_ds))
+    log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
+
+    unwrapped_model = accelerator.unwrap(model)
+
+    for i in range(num_samples):
+        sample = val_ds[i]
+        text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
+
+        # Load reference audio
+        ref_audio_np = None
+        try:
+            if "audio" in sample and isinstance(sample["audio"], dict) and "array" in sample["audio"]:
+                ref_audio_np = np.array(sample["audio"]["array"], dtype=np.float32)
+                ref_sr = sample["audio"].get("sampling_rate", sample_rate)
+                if ref_sr != sample_rate:
+                    import torchaudio.functional as F
+                    ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
+                log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
+        except Exception as e:
+            log(f"[Warning] Failed to load reference audio: {e}")
+
+        try:
+            # Inference setup
+            unwrapped_model.eval()
+            unwrapped_model.to(torch.bfloat16)
+            unwrapped_model.audio_vae = audio_vae.to(torch.float32)
+
+            log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
+            with torch.no_grad():
+                generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
+
+            # Restore training setup
+            unwrapped_model.to(torch.float32)
+            unwrapped_model.audio_vae = None
+
+            if generated is None or len(generated) == 0:
+                log(f"[Warning] Generated audio is empty for sample {i}")
+                continue
+
+            # Process generated audio
+            gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
+            gen_audio_np = normalize_audio(gen_audio_np)
+
+            tag = f"val_sample_{i}"
+            writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
+            log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
+
+            # Log reference audio
+            if ref_audio_np is not None:
+                writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
+
+            # Generate mel spectrogram figure
+            try:
+                mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
+                mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
+                fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
+                writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
+                log(f"[Audio] Created mel spectrogram figure for sample {i}")
+            except Exception as e:
+                log(f"[Warning] Failed to create mel spectrogram: {e}")
+
+        except Exception as e:
+            log(f"[Warning] Failed to generate audio for sample {i}: {e}")
+            import traceback
+            traceback.print_exc()
 
 
 def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
     """
     Load the latest checkpoint if it exists.
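Finally, a note on the TensorBoard side of `generate_sample_audio`: `SummaryWriter.add_audio` expects a 1-D float waveform with samples in [-1, 1], which the earlier peak normalization guarantees. A standalone sketch of logging a playable clip under the same tag scheme:

```python
import numpy as np
from torch.utils.tensorboard import SummaryWriter

sr = 22050
t = np.linspace(0, 0.5, sr // 2, endpoint=False)
wave = (0.9 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float32)  # within [-1, 1]

writer = SummaryWriter("runs/demo")  # arbitrary log dir
# TensorBoard renders this as a playable clip under the Audio tab.
writer.add_audio("val_sample_0/generated_audio", wave, global_step=0, sample_rate=sr)
writer.close()
```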