From 8783f27fdb7e8ad2bfaba909e02bea8e0996f2cb Mon Sep 17 00:00:00 2001
From: 0Xiao0 <511201264@qq.com>
Date: Fri, 11 Oct 2024 20:18:50 +0800
Subject: [PATCH] improvement: update tts

---
 src/blackbox/chat.py         |  3 ++-
 src/blackbox/chroma_query.py |  8 +++---
 src/blackbox/tts.py          | 49 ++++++++++++++++++------------------
 src/tts/tts_service.py       |  2 +-
 4 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py
index 932724d..a3ae892 100644
--- a/src/blackbox/chat.py
+++ b/src/blackbox/chat.py
@@ -110,8 +110,9 @@ class Chat(Blackbox):
         if chroma_response:
             if chroma_collection_id == 'boss':
                 user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE(目标) #\n回答游客的提问。\n# STYLE(风格)#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口,澳门博维集团董事长,周家俊,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”"
-            elif chroma_collection_id == 'g2e':
+            elif chroma_collection_id == 'g2e' or chroma_collection_id == 'kiki':
                 user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE(目标) #\n回答游客的提问。\n# STYLE(风格)#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅,博维,康普可可,琪琪,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”"
+            print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
             user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "。"
         else:
             user_question = user_prompt_template + "问题: " + user_question + "。"
diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index ae851e4..c936bf3 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -22,11 +22,11 @@ class ChromaQuery(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:1")
-        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1")
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
+        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
         self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
-        self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/bge-reranker-v2-m3", max_length=512, device = "cuda")
+        self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -85,7 +85,7 @@ class ChromaQuery(Blackbox):
             embedding_model = self.embedding_model_2
         else:
             try:
-                embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
+                embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:0")
             except:
                 return JSONResponse(content={"error": "embedding model not found"}, status_code=status.HTTP_400_BAD_REQUEST)
 
diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py
index 4fc90f4..64f6d86 100644
--- a/src/blackbox/tts.py
+++ b/src/blackbox/tts.py
@@ -15,9 +15,9 @@ from injector import singleton
 import sys,os
 sys.path.append('/home/gpu/Workspace/CosyVoice')
 from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav, speed_change
+# from cosyvoice.utils.file_utils import load_wav, speed_change
 
-import soundfile
+import soundfile as sf
 import pyloudnorm as pyln
 
 from melo.api import TTS as MELOTTS
@@ -115,11 +115,11 @@ class TTS(Blackbox):
             if self.melo_mode == 'local':
                 audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
                 f = io.BytesIO()
-                soundfile.write(f, audio, 44100, format='wav')
+                sf.write(f, audio, 44100, format='wav')
                 f.seek(0)
 
                 # Read the audio data from the buffer
-                data, rate = soundfile.read(f, dtype='float32')
+                data, rate = sf.read(f, dtype='float32')
 
                 # Peak normalization
                 peak_normalized_audio = pyln.normalize.peak(data, -1.0)
@@ -131,7 +131,7 @@ class TTS(Blackbox):
 
                 # Write the loudness normalized audio to an in-memory buffer
                 normalized_audio_buffer = io.BytesIO()
-                soundfile.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
+                sf.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
                 normalized_audio_buffer.seek(0)
 
                 print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
@@ -147,10 +147,11 @@ class TTS(Blackbox):
         elif chroma_collection_id == 'boss':
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(35616313)
-                audio = self.cosyvoicetts.inference_sft(text, '中文男')
-                f = io.BytesIO()
-                soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-                f.seek(0)
+                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                for i, j in enumerate(audio):
+                    f = io.BytesIO()
+                    sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                    f.seek(0)
                 print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                 return f.read()
             else:
@@ -166,9 +167,10 @@ class TTS(Blackbox):
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(56056558)
                 audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
-                f = io.BytesIO()
-                soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-                f.seek(0)
+                for i, j in enumerate(audio):
+                    f = io.BytesIO()
+                    sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                    f.seek(0)
                 print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                 return f.read()
             else:
@@ -181,10 +183,11 @@ class TTS(Blackbox):
         elif chroma_collection_id == 'boss':
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(35616313)
-                audio = self.cosyvoicetts.inference_sft(text, '中文男')
-                f = io.BytesIO()
-                soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-                f.seek(0)
+                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                for i, j in enumerate(audio):
+                    f = io.BytesIO()
+                    sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                    f.seek(0)
                 print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                 return f.read()
             else:
@@ -197,15 +200,11 @@ class TTS(Blackbox):
         elif user_model_name == 'man':
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(35616313)
-                audio = self.cosyvoicetts.inference_sft(text, '中文男')
-                try:
-                    audio, sample_rate = speed_change(audio["tts_speech"], 22050, str(1.5))
-                    audio = audio.numpy().flatten()
-                except Exception as e:
-                    print(f"Failed to change speed of audio: \n{e}")
-                f = io.BytesIO()
-                soundfile.write(f, audio, 22050, format='wav')
-                f.seek(0)
+                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                for i, j in enumerate(audio):
+                    f = io.BytesIO()
+                    sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                    f.seek(0)
                 print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                 return f.read()
             else:
diff --git a/src/tts/tts_service.py b/src/tts/tts_service.py
index f768f1a..0770442 100644
--- a/src/tts/tts_service.py
+++ b/src/tts/tts_service.py
@@ -5,7 +5,7 @@ sys.path.append('src/tts/vits')
 import soundfile
 import os
 
-os.environ["PYTORCH_JIT"] = "0"
+# os.environ["PYTORCH_JIT"] = "0"
 import torch
 
 import src.tts.vits.commons as commons
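
Reviewer note (not part of the commit): the patch switches from the removed speed_change helper to passing speed=1.5 directly to inference_sft, and iterates over the result, which is consistent with inference_sft returning a generator that yields one dict with a 'tts_speech' tensor per chunk. As written, each iteration rebuilds f, so return f.read() only ever returns the WAV bytes of the last yielded chunk. Below is a minimal sketch of how all chunks could be concatenated before encoding, assuming the same 22050 Hz output rate used in the patch; the helper name synthesize_wav_bytes is illustrative, not something in this codebase.

    import io
    import numpy as np
    import soundfile as sf

    def synthesize_wav_bytes(cosyvoicetts, text, speaker='中文男', speed=1.5):
        # Collect every chunk yielded by the generator instead of keeping
        # only the last one; each 'tts_speech' tensor is (1, n_samples),
        # so squeeze(0) gives a 1-D float array per chunk.
        chunks = [j['tts_speech'].cpu().numpy().squeeze(0)
                  for j in cosyvoicetts.inference_sft(text, speaker, speed=speed)]
        audio = np.concatenate(chunks)
        f = io.BytesIO()
        sf.write(f, audio, 22050, format='wav')  # 22050 Hz, matching the patch
        f.seek(0)
        return f.read()

With non-streaming inference the generator typically yields a single chunk, in which case the patched code and this sketch produce the same bytes; the concatenation only matters if inference_sft is ever driven in a mode that yields multiple chunks.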