Merge pull request #141 from vytskalt/main
Print debug messages to stderr instead of stdout
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -30,7 +31,7 @@ class VoxCPM:
|
|||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||||
"""
|
"""
|
||||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||||
|
|
||||||
# If lora_weights_path is provided but no lora_config, create a default one
|
# If lora_weights_path is provided but no lora_config, create a default one
|
||||||
if lora_weights_path is not None and lora_config is None:
|
if lora_weights_path is not None and lora_config is None:
|
||||||
@@ -39,15 +40,15 @@ class VoxCPM:
|
|||||||
enable_dit=True,
|
enable_dit=True,
|
||||||
enable_proj=False,
|
enable_proj=False,
|
||||||
)
|
)
|
||||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
|
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||||
|
|
||||||
# Load LoRA weights if path is provided
|
# Load LoRA weights if path is provided
|
||||||
if lora_weights_path is not None:
|
if lora_weights_path is not None:
|
||||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||||
|
|
||||||
self.text_normalizer = None
|
self.text_normalizer = None
|
||||||
if enable_denoiser and zipenhancer_model_path is not None:
|
if enable_denoiser and zipenhancer_model_path is not None:
|
||||||
@@ -56,7 +57,7 @@ class VoxCPM:
|
|||||||
else:
|
else:
|
||||||
self.denoiser = None
|
self.denoiser = None
|
||||||
if optimize:
|
if optimize:
|
||||||
print("Warm up VoxCPMModel...")
|
print("Warm up VoxCPMModel...", file=sys.stderr)
|
||||||
self.tts_model.generate(
|
self.tts_model.generate(
|
||||||
target_text="Hello, this is the first test sentence.",
|
target_text="Hello, this is the first test sentence.",
|
||||||
max_len=10,
|
max_len=10,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from typing import Tuple, Union, Generator, List, Optional
|
from typing import Tuple, Union, Generator, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.device = "mps"
|
self.device = "mps"
|
||||||
else:
|
else:
|
||||||
self.device = "cpu"
|
self.device = "cpu"
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
self.base_lm = MiniCPMModel(config.lm_config)
|
self.base_lm = MiniCPMModel(config.lm_config)
|
||||||
@@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: torch.compile disabled - {e}")
|
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -683,7 +684,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -868,10 +869,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||||
|
|
||||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||||
print(f"Loading model from safetensors: {safetensors_path}")
|
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
|
||||||
model_state_dict = load_file(safetensors_path)
|
model_state_dict = load_file(safetensors_path)
|
||||||
elif os.path.exists(pytorch_model_path):
|
elif os.path.exists(pytorch_model_path):
|
||||||
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
|
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
pytorch_model_path,
|
pytorch_model_path,
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
|
|||||||
Reference in New Issue
Block a user