# Repository file listing header (not part of the program):
#   path:     jarvis-models/src/blackbox/tts.py
#   modified: 2024-09-12 15:56:00 +08:00
#   size:     239 lines, 10 KiB
#   language: Python

import io
import time
from ntpath import join
import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
from injector import inject
from injector import singleton
import sys,os
sys.path.append('/home/gpu/Workspace/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav, speed_change
import soundfile
import pyloudnorm as pyln
from melo.api import TTS as MELOTTS
from ..log.logging_time import logging_time
import logging
logger = logging.getLogger(__name__)
import random
import torch
import numpy as np
def set_all_random_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and torch's
    default and all-CUDA-device generators with the same value, so that a
    subsequent stochastic call is repeatable for a given ``seed``.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
@singleton
class TTS(Blackbox):
    """Text-to-speech blackbox routing requests to one of several backends.

    Depending on the ``tts_model_name`` / ``chroma_collection_id`` entries of
    the per-request settings, text is synthesised with MeloTTS, CosyVoice, or
    the fallback ``TTService``. MeloTTS and CosyVoice each run either
    in-process (``mode == 'local'``) or by POSTing ``{"text": ...}`` to the
    configured URL (any other mode).
    """

    # Backend configuration, filled in by ``melo_model_init`` and
    # ``cosyvoice_model_init`` from the injected config objects.
    melo_mode: str
    melo_url: str
    melo_speed: int
    melo_device: str
    melo_language: str
    melo_speaker: str
    cosyvoice_mode: str
    cosyvoice_url: str
    cosyvoice_speed: int
    cosyvoice_device: str
    cosyvoice_language: str
    cosyvoice_speaker: str

    # Sample rates used when encoding locally synthesised audio as WAV.
    _MELO_SAMPLE_RATE = 44100
    _COSYVOICE_SAMPLE_RATE = 22050
    # Fixed voice and seed used for the 'boss' collection and the 'man' model.
    _BOSS_SPEAKER = '中文男'
    _BOSS_SEED = 35616313
    # Seed used for the default CosyVoice voice.
    _COSYVOICE_DEFAULT_SEED = 56056558
    # Playback speed factor applied by the 'man' model.
    _MAN_SPEED = 1.5
    # Collection ids that select a model's default voice.
    _DEFAULT_COLLECTION_IDS = ('kiki', None)

    @logging_time(logger=logger)
    def melo_model_init(self, melo_config: MeloConf) -> None:
        """Copy the MeloTTS settings and, in 'local' mode, load the model.

        Args:
            melo_config: Source of ``speed``, ``device``, ``language``,
                ``speaker``, ``mode`` and (for non-local modes) ``url``.
        """
        self.melo_speed = melo_config.speed
        self.melo_device = melo_config.device
        self.melo_language = melo_config.language
        self.melo_speaker = melo_config.speaker
        self.melo_url = ''
        self.melo_mode = melo_config.mode
        self.melotts = None
        self.speaker_ids = None
        if self.melo_mode == 'local':
            self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
            self.speaker_ids = self.melotts.hps.data.spk2id
        else:
            # Only non-local modes need the remote endpoint.
            self.melo_url = melo_config.url
        logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
        print('1.#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')

    @logging_time(logger=logger)
    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
        """Copy the CosyVoice settings and, in 'local' mode, load the model.

        Args:
            cosyvoice_config: Source of ``speed``, ``device``, ``language``,
                ``speaker``, ``mode`` and (for non-local modes) ``url``.
        """
        self.cosyvoice_speed = cosyvoice_config.speed
        self.cosyvoice_device = cosyvoice_config.device
        self.cosyvoice_language = cosyvoice_config.language
        self.cosyvoice_speaker = cosyvoice_config.speaker
        self.cosyvoice_url = ''
        self.cosyvoice_mode = cosyvoice_config.mode
        self.cosyvoicetts = None
        if self.cosyvoice_mode == 'local':
            # NOTE(review): the model directory is hard-coded to this host's
            # filesystem layout; consider moving it into CosyVoiceConf.
            self.cosyvoicetts = CosyVoice('/home/gpu/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
        else:
            self.cosyvoice_url = cosyvoice_config.url
        logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
        print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')

    @inject
    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
        """Create the fallback service and initialise both model backends.

        Args:
            melo_config: Injected MeloTTS configuration.
            cosyvoice_config: Injected CosyVoice configuration.
            settings: Injected settings dict; not read by this constructor.
        """
        self.tts_service = TTService("yunfeineo")
        self.melo_model_init(melo_config)
        self.cosyvoice_model_init(cosyvoice_config)

    def __call__(self, *args, **kwargs):
        """Forward to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    @staticmethod
    def _to_wav_bytes(samples, sample_rate) -> bytes:
        """Encode ``samples`` as an in-memory WAV file and return its bytes."""
        buffer = io.BytesIO()
        soundfile.write(buffer, samples, sample_rate, format='wav')
        return buffer.getvalue()

    @staticmethod
    def _normalize_loudness(data, rate):
        """Peak-normalise to -1.0 dB, then loudness-normalise to -12.0 LUFS.

        If the loudness measurement raises ``ValueError`` (pyloudnorm does
        this for clips shorter than its measurement block — TODO confirm for
        the installed version), the peak-normalised audio is returned instead
        of failing the whole request.
        """
        peak_normalized = pyln.normalize.peak(data, -1.0)
        try:
            loudness = pyln.Meter(rate).integrated_loudness(peak_normalized)
        except ValueError as err:
            logger.warning("Skipping loudness normalization: %s", err)
            return peak_normalized
        return pyln.normalize.loudness(peak_normalized, loudness, -12.0)

    @staticmethod
    def _remote_tts(url, text, timing_label, started) -> bytes:
        """POST ``{"text": text}`` to ``url`` and return the raw response body.

        NOTE(review): there is no timeout and the HTTP status is not checked,
        so an error response body is returned to the caller as if it were
        audio. Left unchanged to preserve existing behaviour.
        """
        response = requests.post(url, json={"text": text})
        print(timing_label, (time.time()-started))
        return response.content

    def _change_speed(self, speech, speed):
        """Return ``speech`` as a 1-D array, sped up by ``speed`` if possible.

        Falls back to the unmodified waveform when the speed change fails, so
        the caller always receives sample data it can encode.
        """
        try:
            changed, _sample_rate = speed_change(speech, self._COSYVOICE_SAMPLE_RATE, str(speed))
            return changed.numpy().flatten()
        except Exception as e:
            print(f"Failed to change speed of audio: \n{e}")
            return speech.cpu().numpy().squeeze(0)

    def _melo_tts(self, text, started) -> bytes:
        """Synthesise ``text`` with MeloTTS (local model or remote endpoint)."""
        if self.melo_mode != 'local':
            return self._remote_tts(
                self.melo_url, text, "#### MeloTTS Service consume - docker : ", started)
        audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
        # Round-trip through a WAV buffer so normalisation operates on the
        # float32 samples soundfile decodes, exactly as before.
        raw = io.BytesIO()
        soundfile.write(raw, audio, self._MELO_SAMPLE_RATE, format='wav')
        raw.seek(0)
        data, rate = soundfile.read(raw, dtype='float32')
        wav = self._to_wav_bytes(self._normalize_loudness(data, rate), rate)
        print("#### MeloTTS Service consume - local : ", (time.time() - started))
        return wav

    def _cosyvoice_tts(self, text, speaker, seed, started, speed=None) -> bytes:
        """Synthesise ``text`` with CosyVoice (local model or remote endpoint).

        Args:
            text: Text to synthesise.
            speaker: Speaker id passed to ``inference_sft`` in local mode.
            seed: Seed applied before local inference for repeatable output.
            started: ``time.time()`` value the elapsed-time print is based on.
            speed: Optional speed factor applied to the local result.
        """
        if self.cosyvoice_mode != 'local':
            return self._remote_tts(
                self.cosyvoice_url, text, "#### CosyVoiceTTS Service consume - docker : ", started)
        set_all_random_seed(seed)
        result = self.cosyvoicetts.inference_sft(text, speaker)
        if speed is None:
            samples = result['tts_speech'].cpu().numpy().squeeze(0)
        else:
            samples = self._change_speed(result['tts_speech'], speed)
        wav = self._to_wav_bytes(samples, self._COSYVOICE_SAMPLE_RATE)
        print("#### CosyVoiceTTS Service consume - local : ", (time.time() - started))
        return wav

    @logging_time(logger=logger)
    def processing(self, *args, settings: dict = None) -> bytes:
        """Synthesise speech for ``args[0]`` using the backend the settings select.

        Routing:
            * ``tts_model_name == 'man'``: CosyVoice, fixed male voice, sped up.
            * ``'melotts'`` / ``'cosyvoicetts'`` with collection ``'boss'``:
              CosyVoice, fixed male voice.
            * ``'melotts'`` / ``'cosyvoicetts'`` otherwise: that model's
              default voice. Collection ids other than ``'kiki'``/``None``
              previously produced ``None``; they now log a warning and use
              the default voice.
            * any other model name: the fallback ``TTService``.

        Args:
            *args: The first positional argument is the text to synthesise.
            settings: Optional dict; the keys ``"tts_model_name"`` and
                ``"chroma_collection_id"`` are read. ``None`` means ``{}``.

        Returns:
            Audio bytes: a WAV file for local backends, the raw HTTP response
            body for remote backends, or whatever the fallback service's
            returned stream yields from ``read()``.

        Raises:
            IndexError: If no positional argument is supplied.
        """
        print("\nChat Settings: ", settings)
        if settings is None:
            settings = {}
        user_model_name = settings.get("tts_model_name")
        chroma_collection_id = settings.get("chroma_collection_id")
        print(f"tts_model_name: {user_model_name}")
        text = args[0]
        started = time.time()

        if user_model_name == 'man':
            return self._cosyvoice_tts(
                text, self._BOSS_SPEAKER, self._BOSS_SEED, started, speed=self._MAN_SPEED)

        if user_model_name in ('melotts', 'cosyvoicetts'):
            if chroma_collection_id == 'boss':
                return self._cosyvoice_tts(text, self._BOSS_SPEAKER, self._BOSS_SEED, started)
            if chroma_collection_id not in self._DEFAULT_COLLECTION_IDS:
                logger.warning(
                    "Unknown chroma_collection_id %r for %s; using the default voice",
                    chroma_collection_id, user_model_name)
            if user_model_name == 'melotts':
                return self._melo_tts(text, started)
            # NOTE(review): the configured *language* is passed as the speaker
            # id here (not ``cosyvoice_speaker``); kept as in the original.
            return self._cosyvoice_tts(
                text, self.cosyvoice_language, self._COSYVOICE_DEFAULT_SEED, started)

        audio = self.tts_service.read(text)
        print("#### TTS Service consume : ", (time.time()-started))
        return audio.read()

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a string.

        Returns False (rather than raising) when called without arguments.
        """
        return bool(args) and isinstance(args[0], str)

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle an HTTP request whose JSON body carries ``text`` and ``settings``.

        Returns:
            A 400 JSON error when the body is not parseable JSON, is not a
            JSON object, or lacks ``text``; otherwise an ``audio/wav``
            attachment containing the synthesised audio.
        """
        try:
            data = await request.json()
        except Exception:
            # ``Exception`` rather than a bare ``except`` so task cancellation
            # and interpreter-exit signals are not swallowed.
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        print(f"data: {data}")
        if not isinstance(data, dict):
            # A JSON list/scalar has no ``.get``; reject it explicitly.
            return JSONResponse(content={"error": "json body must be an object"}, status_code=status.HTTP_400_BAD_REQUEST)
        text = data.get("text")
        setting = data.get("settings")
        if text is None:
            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        # NOTE(review): ``processing`` is synchronous and is called directly
        # from this coroutine; consider offloading it to a worker thread.
        by = self.processing(text, settings=setting)
        return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})