From bab055e7d60ea412545248be0ec1b842bf7e0344 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 10 Apr 2024 10:13:36 +0800 Subject: [PATCH 1/2] TTS & cuda processing updated --- cuda.py | 5 +++++ src/blackbox/tts.py | 3 +++ src/tts/tts_service.py | 6 +++--- src/tts/vits/monotonic_align/__init__.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 cuda.py diff --git a/cuda.py b/cuda.py new file mode 100644 index 0000000..159370c --- /dev/null +++ b/cuda.py @@ -0,0 +1,5 @@ +import torch + +print("Torch version:",torch.__version__) + +print("Is CUDA enabled?",torch.cuda.is_available()) \ No newline at end of file diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index aea74d6..669f273 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -1,4 +1,5 @@ import io +import time from ntpath import join from fastapi import Request, Response, status @@ -16,7 +17,9 @@ class TTS(Blackbox): def processing(self, *args, **kwargs) -> io.BytesIO: text = args[0] + current_time = time.time() audio = self.tts_service.read(text) + print("#### TTS Service consume : ", (time.time()-current_time)) return audio def valid(self, *args, **kwargs) -> bool: diff --git a/src/tts/tts_service.py b/src/tts/tts_service.py index 938df56..e140ff4 100644 --- a/src/tts/tts_service.py +++ b/src/tts/tts_service.py @@ -53,7 +53,7 @@ class TTService(): len(symbols), self.hps.data.filter_length // 2 + 1, self.hps.train.segment_size // self.hps.data.hop_length, - **self.hps.model).cpu() + **self.hps.model).cuda() _ = self.net_g.eval() _ = utils.load_checkpoint(cfg["model"], self.net_g, None) @@ -69,8 +69,8 @@ class TTService(): stn_tst = self.get_text(text, self.hps) with torch.no_grad(): - x_tst = stn_tst.cpu().unsqueeze(0) - x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu() + x_tst = stn_tst.cuda().unsqueeze(0) + x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() # tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed) audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][ 0, 0].data.cpu().float().numpy() diff --git a/src/tts/vits/monotonic_align/__init__.py b/src/tts/vits/monotonic_align/__init__.py index 3d7009c..9293c5a 100644 --- a/src/tts/vits/monotonic_align/__init__.py +++ b/src/tts/vits/monotonic_align/__init__.py @@ -1,6 +1,6 @@ import numpy as np import torch -from .monotonic_align.core import maximum_path_c +from .core import maximum_path_c def maximum_path(neg_cent, mask): From 005acc087405c36031045fd9f587ef060851c770 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 10 Apr 2024 10:29:41 +0800 Subject: [PATCH 2/2] feat: tts enable gpu --- src/blackbox/tesou.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index ae547fc..723b811 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -31,7 +31,7 @@ class Tesou(Blackbox): except: return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) user_id = data.get("user_id") - user_prompt = data.get("prompt") + user_prompt = data.get("prompt") if user_prompt is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK) \ No newline at end of file