From 30412af1565553f964234741835c76bdc9ecc696 Mon Sep 17 00:00:00 2001 From: verachen <511201264@qq.com> Date: Mon, 13 Jan 2025 11:20:15 +0800 Subject: [PATCH] feat: chatpipeline chat2tts flow --- server.py | 3 +- src/blackbox/chatpipeline.py | 222 +++++++++++++++-------------------- src/blackbox/tts.py | 8 +- 3 files changed, 101 insertions(+), 132 deletions(-) diff --git a/server.py b/server.py index 60cce03..c5c7990 100644 --- a/server.py +++ b/server.py @@ -47,10 +47,11 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No # else: # return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND) -@app.get("/audio/{filename}") +@app.get("/audio/audio_files/{filename}") async def serve_audio(filename: str): import os import aiofiles + filename = os.path.join("audio_files", filename) # 确保文件存在 if os.path.exists(filename): try: diff --git a/src/blackbox/chatpipeline.py b/src/blackbox/chatpipeline.py index 55f6d83..ff3c979 100644 --- a/src/blackbox/chatpipeline.py +++ b/src/blackbox/chatpipeline.py @@ -36,6 +36,10 @@ class ChatPipeline(Blackbox): self.audio_part_counter = 0 # 音频段计数器 self.text_part_counter = 0 self.audio_dir = "audio_files" # 存储音频文件的目录 + self.is_last = False + + self.settings = {} # 外部传入的 settings + self.lock = threading.Lock() # 创建锁 if not os.path.exists(self.audio_dir): os.makedirs(self.audio_dir) @@ -68,140 +72,116 @@ class ChatPipeline(Blackbox): f.write(audio_data) return file_name - def chat_stream(self, prompt: str): + def chat_stream(self, prompt: str, settings: dict): """从 chat.py 获取实时生成的文本,并放入队列""" url = 'http://10.6.44.141:8000/?blackbox_name=chat' - headers = {'Content-Type': 'text/plain'} + headers = {'Content-Type': 'text/plain',"Cache-Control": "no-cache",}# 禁用缓存} data = { "prompt": prompt, "context": [], - "settings": { - "stream": True - } + "settings": settings } + print(f"data_chat: {data}") # 每次执行时清空原有音频文件 self.clear_audio_files() self.audio_part_counter = 0 self.text_part_counter = 0 + self.is_last = False + with self.lock: # 确保对 settings 的访问是线程安全的 + llm_stream = settings.get("stream") + if llm_stream: + with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response: + print(f"data_chat1: {data}") + complete_message = "" # 用于累积完整的文本 + lines = list(response.iter_lines()) # 先将所有行读取到一个列表中 + total_lines = len(lines) + for i, line in enumerate(lines): + if line: + message = line.decode('utf-8') + + if message.strip().lower() == "data:": + continue # 跳过"data:"行 + + complete_message += message - with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response: - complete_message = "" # 用于累积完整的文本 - last_message = False # 标记是否是最后一句 - for line in response.iter_lines(): - if line: - message = line.decode('utf-8') - - if message.strip().lower() == "data:": - continue # 跳过"data:"行 - - complete_message += message + # 如果包含标点符号,拆分成句子 + if re.search(self.PUNCTUATION, complete_message): + sentences = re.split(self.PUNCTUATION, complete_message) + for sentence in sentences[:-1]: + cleaned_sentence = self.filter_invalid_chars(sentence.strip()) + if cleaned_sentence: + print(f"Sending complete sentence: {cleaned_sentence}") + self.text_queue.put(cleaned_sentence) # 放入文本队列 + complete_message = sentences[-1] + self.text_part_counter += 1 + # 判断是否是最后一句 + if i == total_lines - 2: # 如果是最后一行 + self.is_last = True + print(f'2.is_last: {self.is_last}') - # 如果包含标点符号,拆分成句子 - if re.search(self.PUNCTUATION, complete_message): - sentences = re.split(self.PUNCTUATION, complete_message) - for sentence in sentences[:-1]: - cleaned_sentence = self.filter_invalid_chars(sentence.strip()) - if cleaned_sentence: - print(f"Sending complete sentence: {cleaned_sentence}") - self.text_queue.put({ - 'text': cleaned_sentence, - 'is_last': False # 当前不是最后一句 - }) # 放入文本队列 - complete_message = sentences[-1] - self.text_part_counter += 1 - print(f"***text_part_counter: {self.text_part_counter}") - # 判断是否是最后一句 - print(f'1.last_message: {last_message}') - if not response.iter_lines(): - last_message = True - print(f'2.last_message: {last_message}') + time.sleep(0.2) + else: + with requests.post(url, headers=headers, data=json.dumps(data)) as response: + print(f"data_chat1: {data}") + if response.status_code == 200: + response_json = response.json() + response_content = response_json.get("response") + self.text_queue.put(response_content) - time.sleep(0.2) - # 放入最后一句消息 - print(f'---a---') - if complete_message.strip(): - print('---b---') - cleaned_sentence = self.filter_invalid_chars(complete_message.strip()) - if cleaned_sentence: - print(f"Sending last complete sentence: {cleaned_sentence}") - self.text_queue.put({ - 'text': cleaned_sentence, - 'is_last': True # 最后一条消息 - }) - else: - self.text_queue.put({ - 'text': "结束", - 'is_last': True # 最后一条消息 - }) def send_to_tts(self, settings: dict): """从队列中获取文本并发送给 tts.py 进行语音合成""" url = 'http://10.6.44.141:8000/?blackbox_name=tts' - headers = {'Content-Type': 'text/plain'} - user_stream = settings.get("tts_stream") + headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",} # 禁用缓存} + with self.lock: + user_stream = settings.get("tts_stream") + tts_model_name = settings.get("tts_model_name") + print(f"data_tts0: {settings}") while True: try: # 获取队列中的一个完整句子 - item = self.text_queue.get(timeout=5) + text = self.text_queue.get(timeout=5) - if item is None: + if text is None: break - - text = item['text'] - is_last = item['is_last'] # 判断是否是最后一句 if not text.strip(): continue - text = self.filter_invalid_chars(text) - + if tts_model_name == 'sovitstts': + text = self.filter_invalid_chars(text) + print(f"data_tts0.1: {settings}") data = { "settings": settings, "text": text } - - # if user_stream: - # # 发送请求到 TTS 服务 - # response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) + print(f"data_tts1: {data}") + if user_stream: + # 发送请求到 TTS 服务 + response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) - # if response.status_code == 200: - # audio_data = response.content - # if isinstance(audio_data, bytes): - # self.audio_part_counter += 1 # 增加音频段计数器 - # file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件 - # print(f"Audio part saved as {file_name}") + if response.status_code == 200: + audio_data = response.content + if isinstance(audio_data, bytes): + self.audio_part_counter += 1 # 增加音频段计数器 + file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件 + print(f"Audio part saved as {file_name}") - # # 将文件名和是否是最后一条消息放入音频队列 - # self.audio_queue.put({ - # 'file_name': file_name, - # 'is_last': is_last - # }) # 放入音频队列 - - # # 如果是最后一句,执行额外的处理 - # if is_last: - # print("This is the last sentence in the stream.") - # # 你可以在这里添加处理最后一句的额外逻辑,例如: - # # - 通知系统音频合成已完成 - # # - 结束音频流等 + # 将文件名和是否是最后一条消息放入音频队列 + self.audio_queue.put(file_name) # 放入音频队列 - # else: - # print(f"Error: Received non-binary data.") - # else: - # print(f"Failed to send to TTS: {response.status_code}, Text: {text}") + else: + print(f"Error: Received non-binary data.") + else: + print(f"Failed to send to TTS: {response.status_code}, Text: {text}") - # else: - # response = requests.post(url, headers=headers, data=json.dumps(data)) - # if response.status_code == 200: - # print("1"*90) - # audio_data = response.content - # print(audio_data) - response = requests.post(url, headers=headers, data=json.dumps(data)) - if response.status_code == 200: - print("1"*90) - audio_data = response.content - print(audio_data) - + else: + print(f"data_tts2: {data}") + response = requests.post(url, headers=headers, data=json.dumps(data)) + if response.status_code == 200: + self.audio_queue.put(response.content) + # 通知下一个 TTS 可以执行了 self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程 time.sleep(0.2) @@ -229,7 +209,7 @@ class ChatPipeline(Blackbox): def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO: # 启动聊天流线程 - threading.Thread(target=self.chat_stream, args=(text,), daemon=True).start() + threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start() # 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务 threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start() @@ -247,58 +227,44 @@ class ChatPipeline(Blackbox): text = data.get("text") setting = data.get("settings") user_stream = setting.get("tts_stream") - is_last = False + self.is_last = False self.audio_part_counter = 0 # 音频段计数器 self.text_part_counter = 0 self.reset_queues() - + self.clear_audio_files() + print(f"data0: {data}") if text is None: return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) # 调用 processing 方法,并传递动态的 text 参数 response_data = self.processing(text, settings=setting) - print('---1---') # 根据是否启用流式传输进行处理 if user_stream: - print('---2---') # 等待至少一个音频片段生成完成 await self.wait_for_audio() - print('---3---') def audio_stream(): - print('---4---') - is_last = False # 从上游服务器流式读取数据并逐块发送 - print(f'111self.audio_part_counter: {self.audio_part_counter}') - print(f'111self.text_part_counter: {self.text_part_counter}') - print(f"111.self.audio_queue.qsize: {self.audio_queue.qsize()}") - print(f'111is_last: {is_last}') - while self.audio_part_counter != 0 and not is_last: - print('---5---') - print(f'1.is_last: {is_last}') - print(f"222.self.audio_queue.qsize: {self.audio_queue.qsize()}") + while self.audio_part_counter != 0 and not self.is_last: audio = self.audio_queue.get() - audio_file = audio['file_name'] - is_last = audio['is_last'] - print(f'2.is_last: {is_last}') + audio_file = audio if audio_file: - print('---6---') with open(audio_file, "rb") as f: - print('---7---') print(f"Sending audio file: {audio_file}") yield f.read() # 分段发送音频文件内容 - print('---8---') - print('---9---') return StreamingResponse(audio_stream(), media_type="audio/wav") else: - print('---10---') # 如果没有启用流式传输,可以返回一个完整的响应或音频文件 - audio_files = [] - while not self.audio_queue.empty(): - audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名 - + await self.wait_for_audio() + file_name = self.audio_queue.get() + if file_name: + file_name_json = json.loads(file_name.decode('utf-8')) + # audio_files = [] + # while not self.audio_queue.empty(): + # print("9") + # audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名 # 返回多个音频文件 - return JSONResponse(content={"audio_files": audio_files}) + return JSONResponse(content=file_name_json) except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) \ No newline at end of file diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index ac604f9..7446190 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -154,7 +154,7 @@ class TTS(Blackbox): self.melo_model_init(melo_config) self.cosyvoice_model_init(cosyvoice_config) self.sovits_model_init(sovits_config) - + self.audio_dir = "audio_files" # 存储音频文件的目录 def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -290,9 +290,11 @@ class TTS(Blackbox): } if user_stream: response = requests.get(self.sovits_url, params=message, stream=True) + print("#### SoVITS Service consume - docker : ", (time.time()-current_time)) return response else: response = requests.get(self.sovits_url, params=message) + print("#### SoVITS Service consume - docker : ", (time.time()-current_time)) return response.content print("#### SoVITS Service consume - docker : ", (time.time()-current_time)) @@ -375,7 +377,7 @@ class TTS(Blackbox): # if user_stream and tts_model_name == 'sovitstts': # if by.status_code == 200: # # 保存 WAV 文件 - # wav_filename = 'audio.wav' + # wav_filename = os.path.join(self.audio_dir, 'audio.wav') # with open(wav_filename, 'wb') as f: # for chunk in by.iter_content(chunk_size=1024): # if chunk: @@ -402,7 +404,7 @@ class TTS(Blackbox): else: - wav_filename = 'audio.wav' + wav_filename = os.path.join(self.audio_dir, 'audio.wav') with open(wav_filename, 'wb') as f: f.write(by)