Optimize logging validation set results to TensorBoard
This commit is contained in:
@ -228,12 +228,13 @@ def load_model(pretrained_path, lora_path=None):
|
||||
)
|
||||
return "Model loaded successfully!"
|
||||
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||
global current_model
|
||||
|
||||
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
||||
if current_model is None:
|
||||
base_model_path = default_pretrained_path # 默认路径
|
||||
# 优先使用用户指定的预训练模型路径
|
||||
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
|
||||
|
||||
# 如果选择了 LoRA,尝试从其 config 读取 base_model
|
||||
if lora_selection and lora_selection != "None":
|
||||
@ -1133,7 +1134,7 @@ with gr.Blocks(
|
||||
|
||||
generate_btn.click(
|
||||
run_inference,
|
||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
|
||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
|
||||
outputs=[audio_out, status_out]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user