feat: chatpipeline chat2tts flow

This commit is contained in:
verachen
2025-01-13 11:20:15 +08:00
parent 37174413fe
commit 30412af156
3 changed files with 101 additions and 132 deletions

View File

@ -47,10 +47,11 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
# else: # else:
# return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND) # return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
@app.get("/audio/{filename}") @app.get("/audio/audio_files/{filename}")
async def serve_audio(filename: str): async def serve_audio(filename: str):
import os import os
import aiofiles import aiofiles
filename = os.path.join("audio_files", filename)
# 确保文件存在 # 确保文件存在
if os.path.exists(filename): if os.path.exists(filename):
try: try:

View File

@ -36,6 +36,10 @@ class ChatPipeline(Blackbox):
self.audio_part_counter = 0 # 音频段计数器 self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0 self.text_part_counter = 0
self.audio_dir = "audio_files" # 存储音频文件的目录 self.audio_dir = "audio_files" # 存储音频文件的目录
self.is_last = False
self.settings = {} # 外部传入的 settings
self.lock = threading.Lock() # 创建锁
if not os.path.exists(self.audio_dir): if not os.path.exists(self.audio_dir):
os.makedirs(self.audio_dir) os.makedirs(self.audio_dir)
@ -68,27 +72,31 @@ class ChatPipeline(Blackbox):
f.write(audio_data) f.write(audio_data)
return file_name return file_name
def chat_stream(self, prompt: str): def chat_stream(self, prompt: str, settings: dict):
"""从 chat.py 获取实时生成的文本,并放入队列""" """从 chat.py 获取实时生成的文本,并放入队列"""
url = 'http://10.6.44.141:8000/?blackbox_name=chat' url = 'http://10.6.44.141:8000/?blackbox_name=chat'
headers = {'Content-Type': 'text/plain'} headers = {'Content-Type': 'text/plain',"Cache-Control": "no-cache",}# 禁用缓存}
data = { data = {
"prompt": prompt, "prompt": prompt,
"context": [], "context": [],
"settings": { "settings": settings
"stream": True
}
} }
print(f"data_chat: {data}")
# 每次执行时清空原有音频文件 # 每次执行时清空原有音频文件
self.clear_audio_files() self.clear_audio_files()
self.audio_part_counter = 0 self.audio_part_counter = 0
self.text_part_counter = 0 self.text_part_counter = 0
self.is_last = False
with self.lock: # 确保对 settings 的访问是线程安全的
llm_stream = settings.get("stream")
if llm_stream:
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response: with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
print(f"data_chat1: {data}")
complete_message = "" # 用于累积完整的文本 complete_message = "" # 用于累积完整的文本
last_message = False # 标记是否是最后一句 lines = list(response.iter_lines()) # 先将所有行读取到一个列表中
for line in response.iter_lines(): total_lines = len(lines)
for i, line in enumerate(lines):
if line: if line:
message = line.decode('utf-8') message = line.decode('utf-8')
@ -104,103 +112,75 @@ class ChatPipeline(Blackbox):
cleaned_sentence = self.filter_invalid_chars(sentence.strip()) cleaned_sentence = self.filter_invalid_chars(sentence.strip())
if cleaned_sentence: if cleaned_sentence:
print(f"Sending complete sentence: {cleaned_sentence}") print(f"Sending complete sentence: {cleaned_sentence}")
self.text_queue.put({ self.text_queue.put(cleaned_sentence) # 放入文本队列
'text': cleaned_sentence,
'is_last': False # 当前不是最后一句
}) # 放入文本队列
complete_message = sentences[-1] complete_message = sentences[-1]
self.text_part_counter += 1 self.text_part_counter += 1
print(f"***text_part_counter: {self.text_part_counter}")
# 判断是否是最后一句 # 判断是否是最后一句
print(f'1.last_message: {last_message}') if i == total_lines - 2: # 如果是最后一行
if not response.iter_lines(): self.is_last = True
last_message = True print(f'2.is_last: {self.is_last}')
print(f'2.last_message: {last_message}')
time.sleep(0.2) time.sleep(0.2)
# 放入最后一句消息
print(f'---a---')
if complete_message.strip():
print('---b---')
cleaned_sentence = self.filter_invalid_chars(complete_message.strip())
if cleaned_sentence:
print(f"Sending last complete sentence: {cleaned_sentence}")
self.text_queue.put({
'text': cleaned_sentence,
'is_last': True # 最后一条消息
})
else: else:
self.text_queue.put({ with requests.post(url, headers=headers, data=json.dumps(data)) as response:
'text': "结束", print(f"data_chat1: {data}")
'is_last': True # 最后一条消息 if response.status_code == 200:
}) response_json = response.json()
response_content = response_json.get("response")
self.text_queue.put(response_content)
def send_to_tts(self, settings: dict): def send_to_tts(self, settings: dict):
"""从队列中获取文本并发送给 tts.py 进行语音合成""" """从队列中获取文本并发送给 tts.py 进行语音合成"""
url = 'http://10.6.44.141:8000/?blackbox_name=tts' url = 'http://10.6.44.141:8000/?blackbox_name=tts'
headers = {'Content-Type': 'text/plain'} headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",} # 禁用缓存}
with self.lock:
user_stream = settings.get("tts_stream") user_stream = settings.get("tts_stream")
tts_model_name = settings.get("tts_model_name")
print(f"data_tts0: {settings}")
while True: while True:
try: try:
# 获取队列中的一个完整句子 # 获取队列中的一个完整句子
item = self.text_queue.get(timeout=5) text = self.text_queue.get(timeout=5)
if item is None: if text is None:
break break
text = item['text']
is_last = item['is_last'] # 判断是否是最后一句
if not text.strip(): if not text.strip():
continue continue
if tts_model_name == 'sovitstts':
text = self.filter_invalid_chars(text) text = self.filter_invalid_chars(text)
print(f"data_tts0.1: {settings}")
data = { data = {
"settings": settings, "settings": settings,
"text": text "text": text
} }
print(f"data_tts1: {data}")
if user_stream:
# 发送请求到 TTS 服务
response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
# if user_stream: if response.status_code == 200:
# # 发送请求到 TTS 服务 audio_data = response.content
# response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) if isinstance(audio_data, bytes):
self.audio_part_counter += 1 # 增加音频段计数器
file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
print(f"Audio part saved as {file_name}")
# if response.status_code == 200: # 将文件名和是否是最后一条消息放入音频队列
# audio_data = response.content self.audio_queue.put(file_name) # 放入音频队列
# if isinstance(audio_data, bytes):
# self.audio_part_counter += 1 # 增加音频段计数器
# file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
# print(f"Audio part saved as {file_name}")
# # 将文件名和是否是最后一条消息放入音频队列 else:
# self.audio_queue.put({ print(f"Error: Received non-binary data.")
# 'file_name': file_name, else:
# 'is_last': is_last print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
# }) # 放入音频队列
# # 如果是最后一句,执行额外的处理 else:
# if is_last: print(f"data_tts2: {data}")
# print("This is the last sentence in the stream.")
# # 你可以在这里添加处理最后一句的额外逻辑,例如:
# # - 通知系统音频合成已完成
# # - 结束音频流等
# else:
# print(f"Error: Received non-binary data.")
# else:
# print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
# else:
# response = requests.post(url, headers=headers, data=json.dumps(data))
# if response.status_code == 200:
# print("1"*90)
# audio_data = response.content
# print(audio_data)
response = requests.post(url, headers=headers, data=json.dumps(data)) response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200: if response.status_code == 200:
print("1"*90) self.audio_queue.put(response.content)
audio_data = response.content
print(audio_data)
# 通知下一个 TTS 可以执行了 # 通知下一个 TTS 可以执行了
self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程 self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程
@ -229,7 +209,7 @@ class ChatPipeline(Blackbox):
def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO: def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO:
# 启动聊天流线程 # 启动聊天流线程
threading.Thread(target=self.chat_stream, args=(text,), daemon=True).start() threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start()
# 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务 # 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务
threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start() threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
@ -247,58 +227,44 @@ class ChatPipeline(Blackbox):
text = data.get("text") text = data.get("text")
setting = data.get("settings") setting = data.get("settings")
user_stream = setting.get("tts_stream") user_stream = setting.get("tts_stream")
is_last = False self.is_last = False
self.audio_part_counter = 0 # 音频段计数器 self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0 self.text_part_counter = 0
self.reset_queues() self.reset_queues()
self.clear_audio_files()
print(f"data0: {data}")
if text is None: if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
# 调用 processing 方法,并传递动态的 text 参数 # 调用 processing 方法,并传递动态的 text 参数
response_data = self.processing(text, settings=setting) response_data = self.processing(text, settings=setting)
print('---1---')
# 根据是否启用流式传输进行处理 # 根据是否启用流式传输进行处理
if user_stream: if user_stream:
print('---2---')
# 等待至少一个音频片段生成完成 # 等待至少一个音频片段生成完成
await self.wait_for_audio() await self.wait_for_audio()
print('---3---')
def audio_stream(): def audio_stream():
print('---4---')
is_last = False
# 从上游服务器流式读取数据并逐块发送 # 从上游服务器流式读取数据并逐块发送
print(f'111self.audio_part_counter: {self.audio_part_counter}') while self.audio_part_counter != 0 and not self.is_last:
print(f'111self.text_part_counter: {self.text_part_counter}')
print(f"111.self.audio_queue.qsize: {self.audio_queue.qsize()}")
print(f'111is_last: {is_last}')
while self.audio_part_counter != 0 and not is_last:
print('---5---')
print(f'1.is_last: {is_last}')
print(f"222.self.audio_queue.qsize: {self.audio_queue.qsize()}")
audio = self.audio_queue.get() audio = self.audio_queue.get()
audio_file = audio['file_name'] audio_file = audio
is_last = audio['is_last']
print(f'2.is_last: {is_last}')
if audio_file: if audio_file:
print('---6---')
with open(audio_file, "rb") as f: with open(audio_file, "rb") as f:
print('---7---')
print(f"Sending audio file: {audio_file}") print(f"Sending audio file: {audio_file}")
yield f.read() # 分段发送音频文件内容 yield f.read() # 分段发送音频文件内容
print('---8---')
print('---9---')
return StreamingResponse(audio_stream(), media_type="audio/wav") return StreamingResponse(audio_stream(), media_type="audio/wav")
else: else:
print('---10---')
# 如果没有启用流式传输,可以返回一个完整的响应或音频文件 # 如果没有启用流式传输,可以返回一个完整的响应或音频文件
audio_files = [] await self.wait_for_audio()
while not self.audio_queue.empty(): file_name = self.audio_queue.get()
audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名 if file_name:
file_name_json = json.loads(file_name.decode('utf-8'))
# audio_files = []
# while not self.audio_queue.empty():
# print("9")
# audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名
# 返回多个音频文件 # 返回多个音频文件
return JSONResponse(content={"audio_files": audio_files}) return JSONResponse(content=file_name_json)
except Exception as e: except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -154,7 +154,7 @@ class TTS(Blackbox):
self.melo_model_init(melo_config) self.melo_model_init(melo_config)
self.cosyvoice_model_init(cosyvoice_config) self.cosyvoice_model_init(cosyvoice_config)
self.sovits_model_init(sovits_config) self.sovits_model_init(sovits_config)
self.audio_dir = "audio_files" # 存储音频文件的目录
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
@ -290,9 +290,11 @@ class TTS(Blackbox):
} }
if user_stream: if user_stream:
response = requests.get(self.sovits_url, params=message, stream=True) response = requests.get(self.sovits_url, params=message, stream=True)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response return response
else: else:
response = requests.get(self.sovits_url, params=message) response = requests.get(self.sovits_url, params=message)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response.content return response.content
print("#### SoVITS Service consume - docker : ", (time.time()-current_time)) print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
@ -375,7 +377,7 @@ class TTS(Blackbox):
# if user_stream and tts_model_name == 'sovitstts': # if user_stream and tts_model_name == 'sovitstts':
# if by.status_code == 200: # if by.status_code == 200:
# # 保存 WAV 文件 # # 保存 WAV 文件
# wav_filename = 'audio.wav' # wav_filename = os.path.join(self.audio_dir, 'audio.wav')
# with open(wav_filename, 'wb') as f: # with open(wav_filename, 'wb') as f:
# for chunk in by.iter_content(chunk_size=1024): # for chunk in by.iter_content(chunk_size=1024):
# if chunk: # if chunk:
@ -402,7 +404,7 @@ class TTS(Blackbox):
else: else:
wav_filename = 'audio.wav' wav_filename = os.path.join(self.audio_dir, 'audio.wav')
with open(wav_filename, 'wb') as f: with open(wav_filename, 'wb') as f:
f.write(by) f.write(by)