improvement: update tts

This commit is contained in:
0Xiao0
2024-10-11 20:18:50 +08:00
parent 7a5008abb7
commit 8783f27fdb
4 changed files with 31 additions and 31 deletions

View File

@ -110,8 +110,9 @@ class Chat(Blackbox):
if chroma_response: if chroma_response:
if chroma_collection_id == 'boss': if chroma_collection_id == 'boss':
user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口澳门博维集团董事长周家俊G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”" user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口澳门博维集团董事长周家俊G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
elif chroma_collection_id == 'g2e': elif chroma_collection_id == 'g2e' or chroma_collection_id == 'kiki':
user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅博维康普可可琪琪G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”" user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅博维康普可可琪琪G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "" user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + ""
else: else:
user_question = user_prompt_template + "问题: " + user_question + "" user_question = user_prompt_template + "问题: " + user_question + ""

View File

@ -22,11 +22,11 @@ class ChromaQuery(Blackbox):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0]) # config = read_yaml(args[0])
# load chromadb and embedding model # load chromadb and embedding model
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:1") self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1") self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000) self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/bge-reranker-v2-m3", max_length=512, device = "cuda") self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
@ -85,7 +85,7 @@ class ChromaQuery(Blackbox):
embedding_model = self.embedding_model_2 embedding_model = self.embedding_model_2
else: else:
try: try:
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1") embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:0")
except: except:
return JSONResponse(content={"error": "embedding model not found"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "embedding model not found"}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -15,9 +15,9 @@ from injector import singleton
import sys,os import sys,os
sys.path.append('/home/gpu/Workspace/CosyVoice') sys.path.append('/home/gpu/Workspace/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav, speed_change # from cosyvoice.utils.file_utils import load_wav, speed_change
import soundfile import soundfile as sf
import pyloudnorm as pyln import pyloudnorm as pyln
from melo.api import TTS as MELOTTS from melo.api import TTS as MELOTTS
@ -115,11 +115,11 @@ class TTS(Blackbox):
if self.melo_mode == 'local': if self.melo_mode == 'local':
audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed) audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
f = io.BytesIO() f = io.BytesIO()
soundfile.write(f, audio, 44100, format='wav') sf.write(f, audio, 44100, format='wav')
f.seek(0) f.seek(0)
# Read the audio data from the buffer # Read the audio data from the buffer
data, rate = soundfile.read(f, dtype='float32') data, rate = sf.read(f, dtype='float32')
# Peak normalization # Peak normalization
peak_normalized_audio = pyln.normalize.peak(data, -1.0) peak_normalized_audio = pyln.normalize.peak(data, -1.0)
@ -131,7 +131,7 @@ class TTS(Blackbox):
# Write the loudness normalized audio to an in-memory buffer # Write the loudness normalized audio to an in-memory buffer
normalized_audio_buffer = io.BytesIO() normalized_audio_buffer = io.BytesIO()
soundfile.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav') sf.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
normalized_audio_buffer.seek(0) normalized_audio_buffer.seek(0)
print("#### MeloTTS Service consume - local : ", (time.time() - current_time)) print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
@ -147,9 +147,10 @@ class TTS(Blackbox):
elif chroma_collection_id == 'boss': elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local': if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313) set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男') audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
for i, j in enumerate(audio):
f = io.BytesIO() f = io.BytesIO()
soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0) f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read() return f.read()
@ -166,8 +167,9 @@ class TTS(Blackbox):
if self.cosyvoice_mode == 'local': if self.cosyvoice_mode == 'local':
set_all_random_seed(56056558) set_all_random_seed(56056558)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language) audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
for i, j in enumerate(audio):
f = io.BytesIO() f = io.BytesIO()
soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0) f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read() return f.read()
@ -181,9 +183,10 @@ class TTS(Blackbox):
elif chroma_collection_id == 'boss': elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local': if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313) set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男') audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
for i, j in enumerate(audio):
f = io.BytesIO() f = io.BytesIO()
soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0) f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read() return f.read()
@ -197,14 +200,10 @@ class TTS(Blackbox):
elif user_model_name == 'man': elif user_model_name == 'man':
if self.cosyvoice_mode == 'local': if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313) set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男') audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
try: for i, j in enumerate(audio):
audio, sample_rate = speed_change(audio["tts_speech"], 22050, str(1.5))
audio = audio.numpy().flatten()
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
f = io.BytesIO() f = io.BytesIO()
soundfile.write(f, audio, 22050, format='wav') sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0) f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read() return f.read()

View File

@ -5,7 +5,7 @@ sys.path.append('src/tts/vits')
import soundfile import soundfile
import os import os
os.environ["PYTORCH_JIT"] = "0" # os.environ["PYTORCH_JIT"] = "0"
import torch import torch
import src.tts.vits.commons as commons import src.tts.vits.commons as commons