Files
jarvis-models/tts/tts_service.py
2024-03-19 17:33:09 +08:00

64 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import sys
import time
# Extend the module search path with the relative directory 'tts/vits'
# (a relative entry is resolved against the current working directory).
# NOTE(review): presumably required so modules inside the VITS tree can
# import each other by bare name - confirm before moving or removing.
sys.path.append('tts/vits')
import numpy as np
import soundfile
import os
# Must stay above `import torch`: the variable is placed in the environment
# before torch is loaded. PyTorch documents PYTORCH_JIT=0 as the switch that
# disables TorchScript script/trace compilation.
os.environ["PYTORCH_JIT"] = "0"
import torch
import tts.vits.commons as commons
import tts.vits.utils as utils
from tts.vits.models import SynthesizerTrn
from tts.vits.text.symbols import symbols
from tts.vits.text import text_to_sequence
import logging
# Import-time side effect: the ROOT logger is switched to INFO and a default
# handler is installed, which affects every module in the process, not only
# this one.
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
from pydub import AudioSegment
class TTService():
    """Text-to-speech service built on a ``SynthesizerTrn`` model kept on CPU.

    Hyper-parameters and checkpoint weights are loaded once in the
    constructor; :meth:`read` then synthesises speech for a piece of text and
    returns it as an in-memory audio file.
    """

    def __init__(self, cfg, model, char, speed):
        """Load the hyper-parameters and weights for one voice.

        Args:
            cfg: Forwarded to ``utils.get_hparams_from_file``; presumably the
                path of the model's config file.
            model: Forwarded to ``utils.load_checkpoint``; presumably the
                path of the checkpoint to restore.
            char: Label of the voice; only used in the start-up log message.
            speed: Stored and passed to the model as ``length_scale`` on
                every :meth:`read` call.
        """
        # Lazy %-args: logging does the interpolation itself, and a tuple
        # value for `char` can no longer break the format operation.
        logging.info('Initializing TTS Service for %s...', char)
        self.hps = utils.get_hparams_from_file(cfg)
        self.speed = speed
        self.net_g = SynthesizerTrn(
            len(symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model).cpu()
        # Inference mode first, then restore the weights; no optimizer state
        # is restored (None is passed in its place).
        self.net_g.eval()
        utils.load_checkpoint(model, self.net_g, None)

    def get_text(self, text, hps):
        """Convert ``text`` into a ``torch.LongTensor`` of symbol ids.

        Args:
            text: The text to convert.
            hps: Hyper-parameters; ``hps.data.text_cleaners`` selects the
                cleaners and ``hps.data.add_blank`` controls blank insertion.

        Returns:
            The id sequence produced by ``text_to_sequence``, with id ``0``
            interspersed via ``commons.intersperse`` when ``add_blank`` is
            truthy, wrapped in a ``torch.LongTensor``.
        """
        text_norm = text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def read(self, text: str, format: str = "wav") -> io.BytesIO:
        """Synthesise ``text`` and return the audio as an in-memory file.

        Args:
            text: Text to speak. Every ``'~'`` character is removed before
                the text is converted to symbol ids.
            format: Container format name handed to ``soundfile.write``.

        Returns:
            A ``BytesIO`` holding the encoded audio, rewound to position 0 so
            the caller can read it immediately.
        """
        text = text.replace('~', '')
        stn_tst = self.get_text(text, self.hps)
        with torch.no_grad():
            # Add a leading batch axis of size one; the lengths tensor holds
            # the single sequence length.
            x_tst = stn_tst.cpu().unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu()
            outputs = self.net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=.667,
                noise_scale_w=0.2,
                length_scale=self.speed,
            )
            # First output, indexed at [0, 0], converted to a float32 numpy
            # array. `.detach()` replaces the legacy `.data` accessor.
            audio = outputs[0][0, 0].detach().cpu().float().numpy()
        f = io.BytesIO()
        soundfile.write(f, audio, self.hps.data.sampling_rate, format=format)
        # Rewind so callers start reading at the beginning of the audio.
        f.seek(0)
        return f