feat: use cosyvoice2

verachen
2025-01-25 10:40:44 +08:00
parent c2d6fca633
commit 05a48ad022


@@ -17,7 +17,7 @@ import sys,os
sys.path.append('/Workspace/CosyVoice')
sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
# from cosyvoice.utils.file_utils import load_wav, speed_change
from cosyvoice.utils.file_utils import load_wav#, speed_change
import soundfile as sf
import pyloudnorm as pyln
@@ -33,6 +33,7 @@ import numpy as np
from pydub import AudioSegment
import subprocess
import re
def set_all_random_seed(seed):
random.seed(seed)
@@ -107,12 +108,13 @@ class TTS(Blackbox):
self.cosyvoice_url = ''
self.cosyvoice_mode = cosyvoice_config.mode
self.cosyvoicetts = None
self.prompt_speech_16k = None
# os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
if self.cosyvoice_mode == 'local':
# self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
# self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
# self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
self.prompt_speech_16k = load_wav('/Workspace/jarvis-models/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav', 16000)
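# Note: prompt_speech_16k loaded above is passed as the reference prompt audio to inference_instruct2 in processing() below.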
else:
self.cosyvoice_url = cosyvoice_config.url
@@ -158,7 +160,22 @@ class TTS(Blackbox):
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def filter_invalid_chars(self,text):
"""过滤无效字符(包括字节流)"""
invalid_keywords = ["data:", "\n", "\r", "\t", " "]
if isinstance(text, bytes):
text = text.decode('utf-8', errors='ignore')
for keyword in invalid_keywords:
text = text.replace(keyword, "")
# Remove all English letters (Chinese characters, punctuation, etc. are kept)
text = re.sub(r'[a-zA-Z]', '', text)
return text.strip()
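# Usage sketch (illustrative only, not part of this commit; the example input is hypothetical):
#   self.filter_invalid_chars(b"data: 你好 hello\n")  ->  "你好"
# i.e. bytes are decoded, "data:"/whitespace markers stripped, and ASCII letters removed.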
@logging_time(logger=logger)
def processing(self, *args, settings: dict) -> io.BytesIO:
@@ -233,13 +250,45 @@ class TTS(Blackbox):
if self.cosyvoice_mode == 'local':
set_all_random_seed(56056558)
print("*"*90)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
# audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
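# inference_instruct2 yields one dict per segment with a 'tts_speech' tensor; the loop below collects and concatenates them.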
# for i, j in enumerate(audio):
# f = io.BytesIO()
# sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
# f.seek(0)
# print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
# return f.read()
# Print the length and structure of the audio result
# print(f"Total audio segments: {len(audio)}")
# print(f"Audio data structure: {audio}")
# Create an empty list to hold the NumPy arrays of all audio segments
all_audio_data = []
# Iterate over each audio segment and store it in the all_audio_data list
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0)
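# NOTE: the per-segment BytesIO written above is leftover from the streaming path and is not used; only audio_data below is kept.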
# print(f"Processing segment {i + 1}...")
# Print info for each segment to verify it is correct
# print(f"Segment {i + 1} shape: {j['tts_speech'].shape}")
# Convert the audio data directly to a NumPy array
audio_data = j['tts_speech'].cpu().numpy()
# Append each segment's audio data to the all_audio_data list
all_audio_data.append(audio_data[0])  # take the first channel (assumed mono)
# Concatenate all segment arrays into one complete audio array
combined_audio_data = np.concatenate(all_audio_data, axis=0)
# Write the combined audio data into a BytesIO buffer
f = io.BytesIO()
sf.write(f, combined_audio_data, 22050, format='wav')  # 22050 is the sample rate; adjust to match the model output if needed
f.seek(0)
# Return the combined audio
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read()
return f.read()  # return the final combined audio data
else:
message = {
"text": text
@@ -266,6 +315,7 @@ class TTS(Blackbox):
return response.content
elif user_model_name == 'sovitstts':
# text = self.filter_invalid_chars(text)
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.sovits_mode == 'local':
set_all_random_seed(56056558)
@@ -288,7 +338,7 @@ class TTS(Blackbox):
"media_type": self.sovits_media_type,
"streaming_mode": self.sovits_streaming_mode
}
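# user_stream may arrive as a bool or as the string "true"/"True"; both forms are handled below.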
if user_stream:
if user_stream == True or str(user_stream).lower() == "true":
response = requests.get(self.sovits_url, params=message, stream=True)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response
@@ -360,8 +410,10 @@ class TTS(Blackbox):
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
by = self.processing(text, settings=setting)
# return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
print(f"tts user_stream: {type(user_stream)}")
if user_stream:
if user_stream == True or str(user_stream).lower() == "true":
print(f"tts user_stream22: {user_stream}")
if by.status_code == 200:
print("*"*90)
def audio_stream():
@@ -405,6 +457,7 @@ class TTS(Blackbox):
else:
wav_filename = os.path.join(self.audio_dir, 'audio.wav')
print("8"*90)
with open(wav_filename, 'wb') as f:
f.write(by)