mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
feat: chatpipeline chat2tts flow
This commit is contained in:
@ -47,10 +47,11 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
|
||||
# else:
|
||||
# return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
@app.get("/audio/{filename}")
|
||||
@app.get("/audio/audio_files/{filename}")
|
||||
async def serve_audio(filename: str):
|
||||
import os
|
||||
import aiofiles
|
||||
filename = os.path.join("audio_files", filename)
|
||||
# Make sure the file exists
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
|
||||
@ -36,6 +36,10 @@ class ChatPipeline(Blackbox):
|
||||
self.audio_part_counter = 0 # 音频段计数器
|
||||
self.text_part_counter = 0
|
||||
self.audio_dir = "audio_files" # 存储音频文件的目录
|
||||
self.is_last = False
|
||||
|
||||
self.settings = {} # 外部传入的 settings
|
||||
self.lock = threading.Lock() # 创建锁
|
||||
|
||||
if not os.path.exists(self.audio_dir):
|
||||
os.makedirs(self.audio_dir)
|
||||
@ -68,27 +72,31 @@ class ChatPipeline(Blackbox):
|
||||
f.write(audio_data)
|
||||
return file_name
|
||||
|
||||
def chat_stream(self, prompt: str):
|
||||
def chat_stream(self, prompt: str, settings: dict):
|
||||
"""Fetch the text generated in real time by chat.py and put it on the queue."""
|
||||
url = 'http://10.6.44.141:8000/?blackbox_name=chat'
|
||||
headers = {'Content-Type': 'text/plain'}
|
||||
headers = {'Content-Type': 'text/plain',"Cache-Control": "no-cache",}# 禁用缓存}
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"context": [],
|
||||
"settings": {
|
||||
"stream": True
|
||||
}
|
||||
"settings": settings
|
||||
}
|
||||
print(f"data_chat: {data}")
|
||||
|
||||
# Clear previously generated audio files at the start of every run
|
||||
self.clear_audio_files()
|
||||
self.audio_part_counter = 0
|
||||
self.text_part_counter = 0
|
||||
|
||||
self.is_last = False
|
||||
with self.lock: # 确保对 settings 的访问是线程安全的
|
||||
llm_stream = settings.get("stream")
|
||||
if llm_stream:
|
||||
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
|
||||
print(f"data_chat1: {data}")
|
||||
complete_message = "" # 用于累积完整的文本
|
||||
last_message = False # 标记是否是最后一句
|
||||
for line in response.iter_lines():
|
||||
lines = list(response.iter_lines()) # 先将所有行读取到一个列表中
|
||||
total_lines = len(lines)
|
||||
for i, line in enumerate(lines):
|
||||
if line:
|
||||
message = line.decode('utf-8')
|
||||
|
||||
@ -104,103 +112,75 @@ class ChatPipeline(Blackbox):
|
||||
cleaned_sentence = self.filter_invalid_chars(sentence.strip())
|
||||
if cleaned_sentence:
|
||||
print(f"Sending complete sentence: {cleaned_sentence}")
|
||||
self.text_queue.put({
|
||||
'text': cleaned_sentence,
|
||||
'is_last': False # 当前不是最后一句
|
||||
}) # 放入文本队列
|
||||
self.text_queue.put(cleaned_sentence) # 放入文本队列
|
||||
complete_message = sentences[-1]
|
||||
self.text_part_counter += 1
|
||||
print(f"***text_part_counter: {self.text_part_counter}")
|
||||
# 判断是否是最后一句
|
||||
print(f'1.last_message: {last_message}')
|
||||
if not response.iter_lines():
|
||||
last_message = True
|
||||
print(f'2.last_message: {last_message}')
|
||||
if i == total_lines - 2: # 如果是最后一行
|
||||
self.is_last = True
|
||||
print(f'2.is_last: {self.is_last}')
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# 放入最后一句消息
|
||||
print(f'---a---')
|
||||
if complete_message.strip():
|
||||
print('---b---')
|
||||
cleaned_sentence = self.filter_invalid_chars(complete_message.strip())
|
||||
if cleaned_sentence:
|
||||
print(f"Sending last complete sentence: {cleaned_sentence}")
|
||||
self.text_queue.put({
|
||||
'text': cleaned_sentence,
|
||||
'is_last': True # 最后一条消息
|
||||
})
|
||||
else:
|
||||
self.text_queue.put({
|
||||
'text': "结束",
|
||||
'is_last': True # 最后一条消息
|
||||
})
|
||||
with requests.post(url, headers=headers, data=json.dumps(data)) as response:
|
||||
print(f"data_chat1: {data}")
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
response_content = response_json.get("response")
|
||||
self.text_queue.put(response_content)
|
||||
|
||||
|
||||
def send_to_tts(self, settings: dict):
|
||||
"""Take text from the queue and send it to tts.py for speech synthesis."""
|
||||
url = 'http://10.6.44.141:8000/?blackbox_name=tts'
|
||||
headers = {'Content-Type': 'text/plain'}
|
||||
headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",} # 禁用缓存}
|
||||
with self.lock:
|
||||
user_stream = settings.get("tts_stream")
|
||||
tts_model_name = settings.get("tts_model_name")
|
||||
print(f"data_tts0: {settings}")
|
||||
while True:
|
||||
try:
|
||||
# Take one complete sentence from the queue
|
||||
item = self.text_queue.get(timeout=5)
|
||||
text = self.text_queue.get(timeout=5)
|
||||
|
||||
if item is None:
|
||||
if text is None:
|
||||
break
|
||||
|
||||
text = item['text']
|
||||
is_last = item['is_last'] # 判断是否是最后一句
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
if tts_model_name == 'sovitstts':
|
||||
text = self.filter_invalid_chars(text)
|
||||
|
||||
print(f"data_tts0.1: {settings}")
|
||||
data = {
|
||||
"settings": settings,
|
||||
"text": text
|
||||
}
|
||||
print(f"data_tts1: {data}")
|
||||
if user_stream:
|
||||
# 发送请求到 TTS 服务
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
|
||||
|
||||
# if user_stream:
|
||||
# # 发送请求到 TTS 服务
|
||||
# response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
|
||||
if response.status_code == 200:
|
||||
audio_data = response.content
|
||||
if isinstance(audio_data, bytes):
|
||||
self.audio_part_counter += 1 # 增加音频段计数器
|
||||
file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
|
||||
print(f"Audio part saved as {file_name}")
|
||||
|
||||
# if response.status_code == 200:
|
||||
# audio_data = response.content
|
||||
# if isinstance(audio_data, bytes):
|
||||
# self.audio_part_counter += 1 # 增加音频段计数器
|
||||
# file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
|
||||
# print(f"Audio part saved as {file_name}")
|
||||
# 将文件名和是否是最后一条消息放入音频队列
|
||||
self.audio_queue.put(file_name) # 放入音频队列
|
||||
|
||||
# # 将文件名和是否是最后一条消息放入音频队列
|
||||
# self.audio_queue.put({
|
||||
# 'file_name': file_name,
|
||||
# 'is_last': is_last
|
||||
# }) # 放入音频队列
|
||||
else:
|
||||
print(f"Error: Received non-binary data.")
|
||||
else:
|
||||
print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
|
||||
|
||||
# # 如果是最后一句,执行额外的处理
|
||||
# if is_last:
|
||||
# print("This is the last sentence in the stream.")
|
||||
# # 你可以在这里添加处理最后一句的额外逻辑,例如:
|
||||
# # - 通知系统音频合成已完成
|
||||
# # - 结束音频流等
|
||||
|
||||
# else:
|
||||
# print(f"Error: Received non-binary data.")
|
||||
# else:
|
||||
# print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
|
||||
|
||||
# else:
|
||||
# response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
# if response.status_code == 200:
|
||||
# print("1"*90)
|
||||
# audio_data = response.content
|
||||
# print(audio_data)
|
||||
else:
|
||||
print(f"data_tts2: {data}")
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
if response.status_code == 200:
|
||||
print("1"*90)
|
||||
audio_data = response.content
|
||||
print(audio_data)
|
||||
self.audio_queue.put(response.content)
|
||||
|
||||
# Signal that the next TTS request may proceed
|
||||
self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程
|
||||
@ -229,7 +209,7 @@ class ChatPipeline(Blackbox):
|
||||
def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO:
|
||||
|
||||
# Start the chat-stream thread
|
||||
threading.Thread(target=self.chat_stream, args=(text,), daemon=True).start()
|
||||
threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start()
|
||||
|
||||
# Start the TTS thread; it is meant to finish the current task before handling the next TTS request
|
||||
threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
|
||||
@ -247,58 +227,44 @@ class ChatPipeline(Blackbox):
|
||||
text = data.get("text")
|
||||
setting = data.get("settings")
|
||||
user_stream = setting.get("tts_stream")
|
||||
is_last = False
|
||||
self.is_last = False
|
||||
self.audio_part_counter = 0 # 音频段计数器
|
||||
self.text_part_counter = 0
|
||||
self.reset_queues()
|
||||
|
||||
self.clear_audio_files()
|
||||
print(f"data0: {data}")
|
||||
if text is None:
|
||||
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Call processing() with the text supplied in the request
|
||||
response_data = self.processing(text, settings=setting)
|
||||
print('---1---')
|
||||
# Branch on whether streaming output is enabled
|
||||
if user_stream:
|
||||
print('---2---')
|
||||
# Wait until at least one audio segment has been generated
|
||||
await self.wait_for_audio()
|
||||
print('---3---')
|
||||
def audio_stream():
|
||||
print('---4---')
|
||||
is_last = False
|
||||
# Stream the data read from the upstream server and send it chunk by chunk
|
||||
print(f'111self.audio_part_counter: {self.audio_part_counter}')
|
||||
print(f'111self.text_part_counter: {self.text_part_counter}')
|
||||
print(f"111.self.audio_queue.qsize: {self.audio_queue.qsize()}")
|
||||
print(f'111is_last: {is_last}')
|
||||
while self.audio_part_counter != 0 and not is_last:
|
||||
print('---5---')
|
||||
print(f'1.is_last: {is_last}')
|
||||
print(f"222.self.audio_queue.qsize: {self.audio_queue.qsize()}")
|
||||
while self.audio_part_counter != 0 and not self.is_last:
|
||||
audio = self.audio_queue.get()
|
||||
audio_file = audio['file_name']
|
||||
is_last = audio['is_last']
|
||||
print(f'2.is_last: {is_last}')
|
||||
audio_file = audio
|
||||
if audio_file:
|
||||
print('---6---')
|
||||
with open(audio_file, "rb") as f:
|
||||
print('---7---')
|
||||
print(f"Sending audio file: {audio_file}")
|
||||
yield f.read() # 分段发送音频文件内容
|
||||
print('---8---')
|
||||
print('---9---')
|
||||
return StreamingResponse(audio_stream(), media_type="audio/wav")
|
||||
|
||||
else:
|
||||
print('---10---')
|
||||
# When streaming is not enabled, return a complete response or audio file
|
||||
audio_files = []
|
||||
while not self.audio_queue.empty():
|
||||
audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名
|
||||
|
||||
await self.wait_for_audio()
|
||||
file_name = self.audio_queue.get()
|
||||
if file_name:
|
||||
file_name_json = json.loads(file_name.decode('utf-8'))
|
||||
# audio_files = []
|
||||
# while not self.audio_queue.empty():
|
||||
# print("9")
|
||||
# audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名
|
||||
# 返回多个音频文件
|
||||
return JSONResponse(content={"audio_files": audio_files})
|
||||
return JSONResponse(content=file_name_json)
|
||||
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
@ -154,7 +154,7 @@ class TTS(Blackbox):
|
||||
self.melo_model_init(melo_config)
|
||||
self.cosyvoice_model_init(cosyvoice_config)
|
||||
self.sovits_model_init(sovits_config)
|
||||
|
||||
self.audio_dir = "audio_files" # 存储音频文件的目录
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
@ -290,9 +290,11 @@ class TTS(Blackbox):
|
||||
}
|
||||
if user_stream:
|
||||
response = requests.get(self.sovits_url, params=message, stream=True)
|
||||
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
|
||||
return response
|
||||
else:
|
||||
response = requests.get(self.sovits_url, params=message)
|
||||
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
|
||||
return response.content
|
||||
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
|
||||
|
||||
@ -375,7 +377,7 @@ class TTS(Blackbox):
|
||||
# if user_stream and tts_model_name == 'sovitstts':
|
||||
# if by.status_code == 200:
|
||||
# # 保存 WAV 文件
|
||||
# wav_filename = 'audio.wav'
|
||||
# wav_filename = os.path.join(self.audio_dir, 'audio.wav')
|
||||
# with open(wav_filename, 'wb') as f:
|
||||
# for chunk in by.iter_content(chunk_size=1024):
|
||||
# if chunk:
|
||||
@ -402,7 +404,7 @@ class TTS(Blackbox):
|
||||
|
||||
|
||||
else:
|
||||
wav_filename = 'audio.wav'
|
||||
wav_filename = os.path.join(self.audio_dir, 'audio.wav')
|
||||
with open(wav_filename, 'wb') as f:
|
||||
f.write(by)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user