print debug messages to stderr instead of stdout

Author: vytskalt
Date:   2026-01-09 20:05:52 +02:00
parent 6ecc00a5d3
commit f2e203d5e2
2 changed files with 14 additions and 12 deletions
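
The change itself is mechanical: each diagnostic print gains a file=sys.stderr argument so that stdout stays reserved for actual program output. A minimal sketch of the pattern (the messages below are illustrative, not taken from the diff):

    import sys

    # Diagnostics go to stderr; stdout stays clean for payload/pipeline output.
    print("Loading model...", file=sys.stderr)  # progress/debug message
    print("result.wav")                         # actual output, safe to pipe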

View File

@@ -1,4 +1,5 @@
 import os
+import sys
 import re
 import tempfile
 import numpy as np
@@ -30,7 +31,7 @@ class VoxCPM:
             lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                 containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
         """
-        print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
+        print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)

         # If lora_weights_path is provided but no lora_config, create a default one
         if lora_weights_path is not None and lora_config is None:
@@ -39,15 +40,15 @@ class VoxCPM:
                 enable_dit=True,
                 enable_proj=False,
             )
-            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)

         self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)

         # Load LoRA weights if path is provided
         if lora_weights_path is not None:
-            print(f"Loading LoRA weights from: {lora_weights_path}")
+            print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
             loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
-            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
+            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)

         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
@@ -56,7 +57,7 @@ class VoxCPM:
         else:
             self.denoiser = None
         if optimize:
-            print("Warm up VoxCPMModel...")
+            print("Warm up VoxCPMModel...", file=sys.stderr)
             self.tts_model.generate(
                 target_text="Hello, this is the first test sentence.",
                 max_len=10,

View File

@@ -19,6 +19,7 @@ limitations under the License.
 """
 import os
+import sys
 from typing import Tuple, Union, Generator, List, Optional

 import torch
@@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
                 self.device = "mps"
             else:
                 self.device = "cpu"
-        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)

         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
@@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
                 self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
                 self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
             except Exception as e:
-                print(f"Warning: torch.compile disabled - {e}")
+                print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
         return self

     def forward(
@@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
             latent_pred, pred_audio_feat = next(inference_result)
             if retry_badcase:
                 if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
                     retry_badcase_times += 1
                     continue
                 else:
@@ -683,7 +684,7 @@ class VoxCPMModel(nn.Module):
             latent_pred, pred_audio_feat = next(inference_result)
             if retry_badcase:
                 if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
                     retry_badcase_times += 1
                     continue
                 else:
@@ -868,10 +869,10 @@ class VoxCPMModel(nn.Module):
         pytorch_model_path = os.path.join(path, "pytorch_model.bin")

         if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
-            print(f"Loading model from safetensors: {safetensors_path}")
+            print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
             model_state_dict = load_file(safetensors_path)
         elif os.path.exists(pytorch_model_path):
-            print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
+            print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
             checkpoint = torch.load(
                 pytorch_model_path,
                 map_location="cpu",
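
With diagnostics on stderr, a caller can now capture the two streams independently. A minimal sketch of that effect, assuming a hypothetical entry-point script (synthesize.py is not part of this diff):

    import subprocess

    # Hypothetical caller: stdout carries only real output, stderr the
    # debug messages, so the two can be consumed separately.
    result = subprocess.run(
        ["python", "synthesize.py"],  # hypothetical script using VoxCPM
        capture_output=True,
        text=True,
    )
    payload = result.stdout      # clean program output
    diagnostics = result.stderr  # "Loading model from ..." etc.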