Merge pull request #4 from BoardWare-Genius/refactor

Refactor
This commit is contained in:
Benson Ou
2024-04-10 10:35:43 +08:00
committed by GitHub
5 changed files with 13 additions and 5 deletions

5
cuda.py Normal file
View File

@@ -0,0 +1,5 @@
import torch

# Quick environment probe: report the installed torch version, then whether
# torch.cuda reports CUDA as available on this machine.
torch_version = torch.__version__
print("Torch version:", torch_version)
cuda_enabled = torch.cuda.is_available()
print("Is CUDA enabled?", cuda_enabled)

View File

@@ -31,7 +31,7 @@ class Tesou(Blackbox):
except: except:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_id = data.get("user_id") user_id = data.get("user_id")
user_prompt = data.get("prompt") user_prompt = data.get("prompt")
if user_prompt is None: if user_prompt is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK) return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK)

View File

@@ -1,4 +1,5 @@
import io import io
import time
from ntpath import join from ntpath import join
from fastapi import Request, Response, status from fastapi import Request, Response, status
@@ -16,7 +17,9 @@ class TTS(Blackbox):
def processing(self, *args, **kwargs) -> io.BytesIO: def processing(self, *args, **kwargs) -> io.BytesIO:
text = args[0] text = args[0]
current_time = time.time()
audio = self.tts_service.read(text) audio = self.tts_service.read(text)
print("#### TTS Service consume : ", (time.time()-current_time))
return audio return audio
def valid(self, *args, **kwargs) -> bool: def valid(self, *args, **kwargs) -> bool:

View File

@@ -53,7 +53,7 @@ class TTService():
len(symbols), len(symbols),
self.hps.data.filter_length // 2 + 1, self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length, self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model).cpu() **self.hps.model).cuda()
_ = self.net_g.eval() _ = self.net_g.eval()
_ = utils.load_checkpoint(cfg["model"], self.net_g, None) _ = utils.load_checkpoint(cfg["model"], self.net_g, None)
@@ -69,8 +69,8 @@ class TTService():
stn_tst = self.get_text(text, self.hps) stn_tst = self.get_text(text, self.hps)
with torch.no_grad(): with torch.no_grad():
x_tst = stn_tst.cpu().unsqueeze(0) x_tst = stn_tst.cuda().unsqueeze(0)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu() x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
# tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed) # tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][ audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][
0, 0].data.cpu().float().numpy() 0, 0].data.cpu().float().numpy()

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import torch import torch
from .monotonic_align.core import maximum_path_c from .core import maximum_path_c
def maximum_path(neg_cent, mask): def maximum_path(neg_cent, mask):