Optimize logging of validation-set results to TensorBoard

刘鑫
2025-12-27 11:49:04 +08:00
parent d57ac634f8
commit 6499215204
3 changed files with 197 additions and 16 deletions

View File

@@ -228,12 +228,13 @@ def load_model(pretrained_path, lora_path=None):
     )
     return "Model loaded successfully!"
-def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
+def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
     global current_model
-    # If a LoRA model is selected and the current model is not loaded, try reading base_model from the LoRA config
     if current_model is None:
-        base_model_path = default_pretrained_path  # default path
+        # Prefer the user-specified pretrained model path
+        base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
+        # If a LoRA is selected, try reading base_model from its config
         if lora_selection and lora_selection != "None":
@@ -1133,7 +1134,7 @@ with gr.Blocks(
     generate_btn.click(
         run_inference,
-        inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
+        inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
         outputs=[audio_out, status_out]
     )
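The extra `pretrained_path` parameter lets the UI override the base model per request, falling back to `default_pretrained_path` when the textbox is blank. A minimal sketch of the same wiring pattern, with hypothetical component names standing in for the app's real ones:

import gradio as gr

DEFAULT_PRETRAINED = "/models/base"  # hypothetical default path

def run_inference_stub(text, pretrained_path=None):
    # A blank or whitespace-only textbox falls back to the default
    base = pretrained_path if pretrained_path and pretrained_path.strip() else DEFAULT_PRETRAINED
    return f"would load {base} and synthesize: {text}"

with gr.Blocks() as demo:
    infer_text = gr.Textbox(label="Text")
    model_path = gr.Textbox(label="Pretrained model path (optional)")
    status = gr.Textbox(label="Status")
    gr.Button("Generate").click(
        run_inference_stub,
        # Components listed in `inputs` are passed to the callback positionally,
        # which is why the diff appends train_pretrained_path to this list.
        inputs=[infer_text, model_path],
        outputs=status,
    )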

View File

@@ -41,6 +41,8 @@ dependencies = [
     "simplejson",
     "sortedcontainers",
     "soundfile",
+    "librosa",
+    "matplotlib",
     "funasr",
     "spaces",
     "argbind",

View File

@@ -17,6 +17,8 @@ from transformers import get_cosine_schedule_with_warmup
 import signal
 import os
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 try:
     from safetensors.torch import save_file
     SAFETENSORS_AVAILABLE = True
@@ -97,7 +99,10 @@ def train(
         return {"text_ids": text_ids}
     train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
+    # Save original validation texts for audio generation display
+    val_texts = None
     if val_ds is not None:
+        val_texts = list(val_ds["text"])  # Save original texts
         val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
     dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
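Because `map(..., remove_columns=["text"])` drops the raw strings from the dataset, the originals must be copied out first; that is what `val_texts` preserves. A minimal sketch of the pitfall, assuming a Hugging Face `datasets` Dataset and a toy tokenizer:

from datasets import Dataset

val_ds = Dataset.from_dict({"text": ["hello world", "second sample"]})
val_texts = list(val_ds["text"])  # copy the raw strings before map() drops them

def tokenize(batch):
    # toy tokenizer: characters -> ordinals
    return {"text_ids": [[ord(c) for c in t] for t in batch["text"]]}

val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
assert "text" not in val_ds.column_names
assert val_texts[0] == "hello world"  # originals survive for display/logging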
@@ -154,6 +159,8 @@ def train(
         dataset_cnt=dataset_cnt,
         device=accelerator.device,
     )
+    # Save audio_vae for audio generation
+    audio_vae_for_gen = base_model.audio_vae
     del base_model.audio_vae
     model = accelerator.prepare_model(base_model)
     unwrapped_model = accelerator.unwrap(model)
@@ -287,7 +294,7 @@ def train(
             accelerator.update()
             scheduler.step()
-        if step % log_interval == 0:
+        if step % log_interval == 0 or step == num_iters - 1:
             loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
             loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
             # Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
@@ -296,10 +303,13 @@ def train(
             loss_values["grad_norm"] = float(grad_norm)
             tracker.log_metrics(loss_values, split="train")
-        if val_loader is not None and step % valid_interval == 0 and step != 0:
-            validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
+        if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
+            validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+                     writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
+                     sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
+                     valid_interval=valid_interval)
-        if step % save_interval == 0 and accelerator.rank == 0:
+        if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
             save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
         if accelerator.rank == 0:
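The reworked conditions replace plain `step % interval == 0` checks with `... or step == num_iters - 1`, so the final iteration always validates and saves even when `num_iters` is not a multiple of the interval (and validation now also fires at step 0, which the old `step != 0` guard suppressed). For example:

num_iters, valid_interval = 10, 4
fired = [step for step in range(num_iters)
         if step % valid_interval == 0 or step == num_iters - 1]
print(fired)  # [0, 4, 8, 9] -- step 9 is included only because it is the final step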
@@ -308,9 +318,16 @@ def train(
     writer.close()
-def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
+def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+             writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
+             val_texts=None, tokenizer=None, valid_interval=1000):
+    """Validate and generate sample audio"""
+    import numpy as np
+    from collections import defaultdict
     model.eval()
-    losses = []
+    total_losses = []
+    sub_losses = defaultdict(list)  # Track individual sub-losses
     num_batches = 0
     max_val_batches = 10
@@ -334,18 +351,179 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
             total = 0.0
             for key, value in outputs.items():
                 if key.startswith("loss/"):
-                    total += lambdas.get(key, 1.0) * value
-            losses.append(total.detach())
+                    weighted_loss = lambdas.get(key, 1.0) * value
+                    total += weighted_loss
+                    sub_losses[key].append(value.detach())
+            total_losses.append(total.detach())
             num_batches += 1
-    if losses:
-        mean_loss = torch.stack(losses).mean()
-        # All-reduce validation loss across processes for global average
-        accelerator.all_reduce(mean_loss)
-        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
+    if total_losses:
+        # Compute mean total loss
+        mean_total_loss = torch.stack(total_losses).mean()
+        accelerator.all_reduce(mean_total_loss)
+        # Compute mean of each sub-loss
+        val_metrics = {"loss/total": mean_total_loss.item()}
+        for key, values in sub_losses.items():
+            mean_sub_loss = torch.stack(values).mean()
+            accelerator.all_reduce(mean_sub_loss)
+            val_metrics[key] = mean_sub_loss.item()
+        tracker.log_metrics(val_metrics, split="val")
+    # Generate sample audio for TensorBoard display
+    if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
+        try:
+            generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
+                                  val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
+                                  tracker=tracker)
+        except Exception as e:
+            tracker.print(f"[Warning] Failed to generate sample audio: {e}")
+            import traceback
+            import io
+            buf = io.StringIO()
+            traceback.print_exc(file=buf)
+            tracker.print(buf.getvalue())
+    else:
+        # Log why audio generation was skipped
+        missing = []
+        if writer is None:
+            missing.append("writer")
+        if val_ds is None:
+            missing.append("val_ds")
+        if audio_vae is None:
+            missing.append("audio_vae")
+        if missing and accelerator.rank == 0:
+            tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
     model.train()
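Averaging the per-rank means first and then calling `accelerator.all_reduce` on each scalar is what turns local statistics into global ones; the code implies the wrapper's `all_reduce` averages across processes. A raw `torch.distributed` equivalent, as a sketch (note it is an exact global mean only when every rank sees the same number of batches):

import torch
import torch.distributed as dist

def global_mean_loss(local_losses):
    # Per-rank mean first, then average the rank means.
    mean = torch.stack(local_losses).mean()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(mean, op=dist.ReduceOp.SUM)  # sum of rank means
        mean /= dist.get_world_size()                # -> average of rank means
    return mean.item()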
+def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
+    """Compute Mel Spectrogram (dB) using librosa"""
+    import numpy as np
+    import librosa
+    audio_np = audio_np.flatten().astype(np.float32)
+    mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
+    return librosa.power_to_db(mel, ref=np.max)
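As a quick sanity check of the helper above (with librosa's default hop length of 512 and `center=True`), a one-second sine at 22.05 kHz yields roughly 44 frames, and `power_to_db(..., ref=np.max)` pins the peak at 0 dB:

import numpy as np

sr = 22050
audio = np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr).astype(np.float32)
mel_db = compute_mel_spectrogram(audio, sr)
print(mel_db.shape, mel_db.max())  # e.g. (128, 44) 0.0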
+def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_np=None, ref_mel=None):
+    """
+    Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
+    """
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+    import librosa.display
+    fmax = sample_rate // 2
+    step_str = f" @ Step {step}" if step is not None else ""
+    if ref_audio_np is not None and ref_mel is not None:
+        # Comparison mode: reference vs generated
+        fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
+        img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
+        ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
+        plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
+        img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
+        ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
+        plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
+    else:
+        # Single figure mode: show generated only
+        fig, ax = plt.subplots(figsize=(12, 4))
+        img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
+        ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
+        plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
+    plt.tight_layout()
+    return fig
+def normalize_audio(audio_np):
+    """Normalize audio to [-0.9, 0.9]"""
+    import numpy as np
+    max_val = np.abs(audio_np).max()
+    return audio_np / max_val * 0.9 if max_val > 0 else audio_np
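TensorBoard's `add_audio` documents its input as floats in [-1, 1], so rescaling the peak to 0.9 leaves headroom against clipping, while the `max_val > 0` guard keeps silent clips from triggering a division by zero:

import numpy as np

loud = np.array([0.0, 2.5, -3.0], dtype=np.float32)  # peaks outside [-1, 1]
print(normalize_audio(loud))         # [ 0.    0.75 -0.9 ]
print(normalize_audio(np.zeros(3)))  # silence passes through unchanged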
+def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
+                          val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
+                          tracker=None):
+    """Select 2 fixed validation samples, generate audio and log to TensorBoard"""
+    import numpy as np
+    log = tracker.print if tracker else print
+    num_samples = min(2, len(val_ds))
+    log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
+    unwrapped_model = accelerator.unwrap(model)
+    for i in range(num_samples):
+        sample = val_ds[i]
+        text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
+        # Load reference audio
+        ref_audio_np = None
+        try:
+            if "audio" in sample and isinstance(sample["audio"], dict) and "array" in sample["audio"]:
+                ref_audio_np = np.array(sample["audio"]["array"], dtype=np.float32)
+                ref_sr = sample["audio"].get("sampling_rate", sample_rate)
+                if ref_sr != sample_rate:
+                    import torchaudio.functional as F
+                    ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
+                log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
+        except Exception as e:
+            log(f"[Warning] Failed to load reference audio: {e}")
+        try:
+            # Inference setup
+            unwrapped_model.eval()
+            unwrapped_model.to(torch.bfloat16)
+            unwrapped_model.audio_vae = audio_vae.to(torch.float32)
+            log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
+            with torch.no_grad():
+                generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
+            # Restore training setup
+            unwrapped_model.to(torch.float32)
+            unwrapped_model.audio_vae = None
+            if generated is None or len(generated) == 0:
+                log(f"[Warning] Generated audio is empty for sample {i}")
+                continue
+            # Process generated audio
+            gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
+            gen_audio_np = normalize_audio(gen_audio_np)
+            tag = f"val_sample_{i}"
+            writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
+            log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
+            # Log reference audio
+            if ref_audio_np is not None:
+                writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
+            # Generate mel spectrogram figure
+            try:
+                mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
+                mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
+                fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
+                writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
+                log(f"[Audio] Created mel spectrogram figure for sample {i}")
+            except Exception as e:
+                log(f"[Warning] Failed to create mel spectrogram: {e}")
+        except Exception as e:
+            log(f"[Warning] Failed to generate audio for sample {i}: {e}")
+            import traceback
+            traceback.print_exc()
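The writer calls above use only stock `torch.utils.tensorboard` APIs; a self-contained sketch of logging one generated clip under the same tag scheme (the log directory is an assumption):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/val_audio_demo")  # assumed log dir
sr = 22050
tone = 0.9 * np.sin(2 * np.pi * 220.0 * np.arange(sr) / sr).astype(np.float32)

# add_audio expects a 1-D float array with values in [-1, 1];
# grouping generated/reference under one "val_sample_i" prefix
# places them side by side in TensorBoard's audio tab.
writer.add_audio("val_sample_0/generated_audio", tone, global_step=0, sample_rate=sr)
writer.close()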
 def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
     """
     Load the latest checkpoint if it exists.