TTS & cuda processing updated

This commit is contained in:
superobk
2024-04-10 10:13:36 +08:00
parent c104ea52b5
commit bab055e7d6
4 changed files with 12 additions and 4 deletions

View File

@ -53,7 +53,7 @@ class TTService():
len(symbols),
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model).cpu()
**self.hps.model).cuda()
_ = self.net_g.eval()
_ = utils.load_checkpoint(cfg["model"], self.net_g, None)
@ -69,8 +69,8 @@ class TTService():
stn_tst = self.get_text(text, self.hps)
with torch.no_grad():
x_tst = stn_tst.cpu().unsqueeze(0)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu()
x_tst = stn_tst.cuda().unsqueeze(0)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
# tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][
0, 0].data.cpu().float().numpy()