Optimize logging validation set results to TensorBoard

This commit is contained in:
刘鑫
2025-12-27 11:49:04 +08:00
parent d57ac634f8
commit 6499215204
3 changed files with 197 additions and 16 deletions

View File

@ -228,12 +228,13 @@ def load_model(pretrained_path, lora_path=None):
)
return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
global current_model
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
if current_model is None:
base_model_path = default_pretrained_path # 默认路径
# 优先使用用户指定的预训练模型路径
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
# 如果选择了 LoRA尝试从其 config 读取 base_model
if lora_selection and lora_selection != "None":
@ -1133,7 +1134,7 @@ with gr.Blocks(
generate_btn.click(
run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
outputs=[audio_out, status_out]
)