Files
jarvis-models/src/tts/tts_service.py
superobk 2a0c0e0477 feat
2024-03-27 16:20:12 +08:00

81 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import sys
sys.path.append('src/tts/vits')
import soundfile
import os
os.environ["PYTORCH_JIT"] = "0"
import torch
import src.tts.vits.commons as commons
import src.tts.vits.utils as utils
from src.tts.vits.models import SynthesizerTrn
from src.tts.vits.text.symbols import symbols
from src.tts.vits.text import text_to_sequence
import logging
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
# Absolute directory that contains this file. Backslashes are normalised to
# forward slashes so the plain string concatenations below give a usable path
# on every platform (the previous derivation split on "\\" and hard-coded a
# "C://" prefix, which only worked for absolute paths on a Windows C: drive).
dirbaspath = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")

# Voice presets, keyed by the name passed to TTService(model_name=...).
#   cfg   - path to the hyper-parameter JSON file
#   model - path to the checkpoint file
#   char  - label used in the start-up log message
#   speed - value forwarded as ``length_scale`` at inference time
config = {
    'paimon': {
        'cfg': dirbaspath + '/models/paimon6k.json',
        'model': dirbaspath + '/models/paimon6k_390k.pth',
        'char': 'character_paimon',
        'speed': 1,
    },
    'yunfei': {
        # NOTE(review): this is the only path with an extra '/tts' segment;
        # every other entry lives directly under '/models'. Confirm that the
        # file really is at <dir>/tts/models/ and not <dir>/models/.
        'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
        'model': dirbaspath + '/models/yunfeimix2_53k.pth',
        'char': 'character_yunfei',
        'speed': 1.1,
    },
    'catmaid': {
        'cfg': dirbaspath + '/models/catmix.json',
        'model': dirbaspath + '/models/catmix_107k.pth',
        'char': 'character_catmaid',
        'speed': 1.2,
    },
}
class TTService():
    """Text-to-speech service bound to one voice preset from ``config``.

    Construction loads the preset's hyper-parameters and checkpoint into a
    ``SynthesizerTrn`` kept on the CPU; ``read`` then turns text into an
    in-memory audio file.
    """

    def __init__(self, model_name="catmaid"):
        """Load the model described by ``config[model_name]``.

        Args:
            model_name: Key of the preset in the module-level ``config``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``config``.
        """
        try:
            cfg = config[model_name]
        except KeyError:
            # Same exception type as a plain lookup, but say what is valid.
            raise KeyError(
                f"Unknown TTS model {model_name!r}; "
                f"available models: {sorted(config)}"
            ) from None
        logging.info('Initializing TTS Service for %s...', cfg["char"])
        self.hps = utils.get_hparams_from_file(cfg["cfg"])
        # Forwarded to ``infer`` as ``length_scale`` on every ``read`` call.
        self.speed = cfg["speed"]
        self.net_g = SynthesizerTrn(
            len(symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model).cpu()
        # Presumably switches the network to inference mode (torch
        # ``Module.eval``) -- assumes SynthesizerTrn is an nn.Module.
        self.net_g.eval()
        # No optimizer is passed (third argument is None).
        utils.load_checkpoint(cfg["model"], self.net_g, None)

    def get_text(self, text, hps):
        """Convert ``text`` into a ``torch.LongTensor`` of symbol ids.

        Args:
            text: Input string to synthesize.
            hps: Hyper-parameters; ``hps.data.text_cleaners`` and
                ``hps.data.add_blank`` are read.

        Returns:
            A ``torch.LongTensor`` built from the id sequence, with ``0``
            interspersed via ``commons.intersperse`` when
            ``hps.data.add_blank`` is truthy.
        """
        text_norm = text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def read(self, text, format="wav") -> io.BytesIO:
        """Synthesize ``text`` and return the audio as an in-memory file.

        Args:
            text: Text to speak; every ``~`` character is removed first.
            format: Container format name passed through to
                ``soundfile.write``.

        Returns:
            A ``io.BytesIO`` holding the encoded audio, positioned at offset
            0 so it can be read immediately.
        """
        text = text.replace('~', '')
        stn_tst = self.get_text(text, self.hps)
        with torch.no_grad():
            # Add a leading dimension of size 1 and pass the sequence length
            # alongside it.
            x_tst = stn_tst.cpu().unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu()
            outputs = self.net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=.667,
                noise_scale_w=0.2,
                length_scale=self.speed,
            )
            # First element of the result, indexed at [0, 0] -- presumably
            # the waveform of the single batch item; confirm against
            # SynthesizerTrn.infer.
            audio = outputs[0][0, 0].detach().cpu().float().numpy()
        f = io.BytesIO()
        soundfile.write(f, audio, self.hps.data.sampling_rate, format=format)
        f.seek(0)
        return f