diff --git a/lora_ft_webui.py b/lora_ft_webui.py
index 9f679e6..95da1ce 100644
--- a/lora_ft_webui.py
+++ b/lora_ft_webui.py
@@ -228,12 +228,13 @@ def load_model(pretrained_path, lora_path=None):
     )
     return "Model loaded successfully!"
 
-def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
+def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
     global current_model
 
     # If a LoRA model is selected and no model is loaded yet, try reading base_model from the LoRA config
     if current_model is None:
-        base_model_path = default_pretrained_path  # default path
+        # Prefer the pretrained model path specified by the user
+        base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
 
         # If a LoRA is selected, try reading base_model from its config
         if lora_selection and lora_selection != "None":
@@ -1133,7 +1134,7 @@ with gr.Blocks(
 
     generate_btn.click(
         run_inference,
-        inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
+        inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
         outputs=[audio_out, status_out]
     )
 
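Note on the `run_inference` change above: the base model is now resolved in the order explicit UI path → `base_model` recorded in the selected LoRA's config → `default_pretrained_path`. A minimal sketch of that chain as a standalone helper; the helper name and the `lora_config.json` file name are assumptions for illustration, not necessarily the webui's actual layout:

```python
import json
from pathlib import Path

def resolve_base_model_path(user_path, lora_dir, default_path):
    """Hypothetical helper mirroring run_inference's fallback order."""
    # 1. An explicit, non-blank path typed into the UI wins.
    if user_path and user_path.strip():
        return user_path.strip()
    # 2. Otherwise fall back to the base_model recorded alongside the LoRA
    #    (config file name assumed here).
    if lora_dir:
        cfg = Path(lora_dir) / "lora_config.json"
        if cfg.exists():
            base = json.loads(cfg.read_text()).get("base_model")
            if base:
                return base
    # 3. Finally, the module-level default.
    return default_path
```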
diff --git a/pyproject.toml b/pyproject.toml
index 3f5a379..cef7027 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,8 @@ dependencies = [
     "simplejson",
     "sortedcontainers",
     "soundfile",
+    "librosa",
+    "matplotlib",
     "funasr",
     "spaces",
     "argbind",
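The two new dependencies back the mel-spectrogram logging added to the training script below: `librosa` computes the spectrograms and `matplotlib` renders the TensorBoard figures. A quick standalone check of the same librosa calls the script relies on, using a synthetic test tone (any mono float32 signal works):

```python
import numpy as np
import librosa

sr = 22050
y = np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr).astype(np.float32)  # 1 s, 440 Hz tone
mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=sr // 2)
mel_db = librosa.power_to_db(mel, ref=np.max)
print(mel_db.shape)  # (128, n_frames), values in dB relative to the peak
```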
diff --git a/scripts/train_voxcpm_finetune.py b/scripts/train_voxcpm_finetune.py
index e17e46f..a9c9492 100644
--- a/scripts/train_voxcpm_finetune.py
+++ b/scripts/train_voxcpm_finetune.py
@@ -17,6 +17,8 @@ from transformers import get_cosine_schedule_with_warmup
 import signal
 import os
 
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
 try:
     from safetensors.torch import save_file
     SAFETENSORS_AVAILABLE = True
@@ -97,7 +99,10 @@ def train(
         return {"text_ids": text_ids}
 
     train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
+    # Save original validation texts for audio generation display
+    val_texts = None
     if val_ds is not None:
+        val_texts = list(val_ds["text"])  # Save original texts
         val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
 
     dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
@@ -154,6 +159,8 @@ def train(
         dataset_cnt=dataset_cnt,
         device=accelerator.device,
     )
+    # Save audio_vae for audio generation
+    audio_vae_for_gen = base_model.audio_vae
     del base_model.audio_vae
     model = accelerator.prepare_model(base_model)
     unwrapped_model = accelerator.unwrap(model)
@@ -287,7 +294,7 @@ def train(
             accelerator.update()
             scheduler.step()
 
-            if step % log_interval == 0:
+            if step % log_interval == 0 or step == num_iters - 1:
                 loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
                 loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
                 # Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
@@ -296,10 +303,13 @@
                 loss_values["grad_norm"] = float(grad_norm)
                 tracker.log_metrics(loss_values, split="train")
 
-            if val_loader is not None and step % valid_interval == 0 and step != 0:
-                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
+            if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
+                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+                         writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
+                         sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
+                         valid_interval=valid_interval)
 
-            if step % save_interval == 0 and accelerator.rank == 0:
+            if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
                 save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
 
     if accelerator.rank == 0:
@@ -308,9 +318,16 @@
         writer.close()
 
 
-def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
+def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
+             writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
+             val_texts=None, tokenizer=None, valid_interval=1000):
+    """Validate and generate sample audio"""
+    import numpy as np
+    from collections import defaultdict
+
     model.eval()
-    losses = []
+    total_losses = []
+    sub_losses = defaultdict(list)  # Track individual sub-losses
     num_batches = 0
     max_val_batches = 10
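The rewritten loss accumulation in the next hunk averages each `loss/*` term over a rank's local batches and then calls `accelerator.all_reduce` on the result; the logged values are global means only if the wrapper's all-reduce averages across ranks rather than summing, which is assumed here. A minimal `torch.distributed` equivalent of that assumed behavior:

```python
import torch
import torch.distributed as dist

def global_mean(local_mean: torch.Tensor) -> torch.Tensor:
    """Average a per-rank scalar across all ranks (assumed all_reduce semantics)."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_mean, op=dist.ReduceOp.SUM)  # sum of per-rank means
        local_mean /= dist.get_world_size()                # -> global mean
    return local_mean
```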
@@ -334,18 +351,179 @@
             total = 0.0
             for key, value in outputs.items():
                 if key.startswith("loss/"):
-                    total += lambdas.get(key, 1.0) * value
-            losses.append(total.detach())
+                    weighted_loss = lambdas.get(key, 1.0) * value
+                    total += weighted_loss
+                    sub_losses[key].append(value.detach())
+            total_losses.append(total.detach())
             num_batches += 1
 
-    if losses:
-        mean_loss = torch.stack(losses).mean()
-        # All-reduce validation loss across processes for global average
-        accelerator.all_reduce(mean_loss)
-        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
+    if total_losses:
+        # Compute mean total loss
+        mean_total_loss = torch.stack(total_losses).mean()
+        accelerator.all_reduce(mean_total_loss)
+
+        # Compute mean of each sub-loss
+        val_metrics = {"loss/total": mean_total_loss.item()}
+        for key, values in sub_losses.items():
+            mean_sub_loss = torch.stack(values).mean()
+            accelerator.all_reduce(mean_sub_loss)
+            val_metrics[key] = mean_sub_loss.item()
+
+        tracker.log_metrics(val_metrics, split="val")
+
+        # Generate sample audio for TensorBoard display
+        if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
+            try:
+                generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
+                                      val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
+                                      tracker=tracker)
+            except Exception as e:
+                tracker.print(f"[Warning] Failed to generate sample audio: {e}")
+                import traceback
+                import io
+                buf = io.StringIO()
+                traceback.print_exc(file=buf)
+                tracker.print(buf.getvalue())
+        else:
+            # Log why audio generation was skipped
+            missing = []
+            if writer is None:
+                missing.append("writer")
+            if val_ds is None:
+                missing.append("val_ds")
+            if audio_vae is None:
+                missing.append("audio_vae")
+            if missing and accelerator.rank == 0:
+                tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
+
     model.train()
 
 
+def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
+    """Compute Mel Spectrogram (dB) using librosa"""
+    import numpy as np
+    import librosa
+    audio_np = audio_np.flatten().astype(np.float32)
+    mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
+    return librosa.power_to_db(mel, ref=np.max)
+
+
+def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_np=None, ref_mel=None):
+    """
+    Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
+    """
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+    import librosa.display
+
+    fmax = sample_rate // 2
+    step_str = f" @ Step {step}" if step is not None else ""
+
+    if ref_audio_np is not None and ref_mel is not None:
+        # Comparison mode: reference vs generated
+        fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
+
+        img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
+        ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
+        plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
+
+        img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
+        ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
+        plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
+    else:
+        # Single figure mode: show generated only
+        fig, ax = plt.subplots(figsize=(12, 4))
+        img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
+        ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
+        plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
+
+    plt.tight_layout()
+    return fig
+
+
+def normalize_audio(audio_np):
+    """Normalize audio to [-0.9, 0.9]"""
+    import numpy as np
+    max_val = np.abs(audio_np).max()
+    return audio_np / max_val * 0.9 if max_val > 0 else audio_np
+
+
+def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
+                          val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
+                          tracker=None):
+    """Select 2 fixed validation samples, generate audio and log to TensorBoard"""
+    import numpy as np
+
+    log = tracker.print if tracker else print
+    num_samples = min(2, len(val_ds))
+    log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
+
+    unwrapped_model = accelerator.unwrap(model)
+
+    for i in range(num_samples):
+        sample = val_ds[i]
+        text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
+
+        # Load reference audio
+        ref_audio_np = None
+        try:
+            if "audio" in sample and isinstance(sample["audio"], dict) and "array" in sample["audio"]:
+                ref_audio_np = np.array(sample["audio"]["array"], dtype=np.float32)
+                ref_sr = sample["audio"].get("sampling_rate", sample_rate)
+                if ref_sr != sample_rate:
+                    import torchaudio.functional as F
+                    ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
+                log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
+        except Exception as e:
+            log(f"[Warning] Failed to load reference audio: {e}")
+
+        try:
+            # Inference setup
+            unwrapped_model.eval()
+            unwrapped_model.to(torch.bfloat16)
+            unwrapped_model.audio_vae = audio_vae.to(torch.float32)
+
+            log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
+            with torch.no_grad():
+                generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
+
+            # Restore training setup
+            unwrapped_model.to(torch.float32)
+            unwrapped_model.audio_vae = None
+
+            if generated is None or len(generated) == 0:
+                log(f"[Warning] Generated audio is empty for sample {i}")
+                continue
+
+            # Process generated audio
+            gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
+            gen_audio_np = normalize_audio(gen_audio_np)
+
+            tag = f"val_sample_{i}"
+            writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
+            log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
+
+            # Log reference audio
+            if ref_audio_np is not None:
+                writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
+
+            # Generate mel spectrogram figure
+            try:
+                mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
+                mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
+                fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
+                writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
+                log(f"[Audio] Created mel spectrogram figure for sample {i}")
+            except Exception as e:
+                log(f"[Warning] Failed to create mel spectrogram: {e}")
+
+        except Exception as e:
+            log(f"[Warning] Failed to generate audio for sample {i}: {e}")
+            import traceback
+            traceback.print_exc()
+
+
 def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
     """
     Load the latest checkpoint if it exists.
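One caveat in `generate_sample_audio` above: the model is cast to bfloat16 for inference and restored to float32 afterwards, but the restore lives inside the `try` body, so an exception in `generate()` can leave the model in bfloat16 for the next training step. A sketch of a `try/finally` variant that guarantees the round-trip (same calls as in the diff, only the control flow differs):

```python
import torch

def generate_in_bf16(unwrapped_model, audio_vae, text):
    """Run generation in bfloat16, always restoring the fp32 training state."""
    unwrapped_model.eval()
    unwrapped_model.to(torch.bfloat16)
    unwrapped_model.audio_vae = audio_vae.to(torch.float32)
    try:
        with torch.no_grad():
            return unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
    finally:
        # Restore training configuration even if generate() raises.
        unwrapped_model.to(torch.float32)
        unwrapped_model.audio_vae = None
```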