Files
jarvis-models/src/blackbox/chatpipeline.py
2025-01-13 11:20:15 +08:00

270 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import requests
import json
import queue
import threading
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import os
import io
import time
import requests
from fastapi import Request, Response, status, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox
from injector import inject
from injector import singleton
from ..log.logging_time import logging_time
import logging
logger = logging.getLogger(__name__)
import asyncio
from uuid import uuid4 # 用于生成唯一的会话 ID
@singleton
class ChatPipeline(Blackbox):
    """Chain the chat blackbox and the TTS blackbox into one pipeline.

    ``chat_stream`` (producer thread) fetches generated text from the chat
    service, splits it into sentences and puts them on ``text_queue``.
    ``send_to_tts`` (consumer thread) synthesizes each sentence and puts the
    result on ``audio_queue``, which ``fast_api_handler`` drains. A ``None``
    item on either queue is the end-of-request marker.

    NOTE(review): all per-request state lives on this singleton instance, so
    overlapping requests share queues, counters and the audio directory; the
    pipeline effectively assumes one request at a time.
    """

    # Upstream blackbox endpoints. Class attributes so a deployment (or a
    # subclass) can point elsewhere without editing the methods.
    CHAT_URL = 'http://10.6.44.141:8000/?blackbox_name=chat'
    TTS_URL = 'http://10.6.44.141:8000/?blackbox_name=tts'

    @inject
    def __init__(self) -> None:
        self.text_queue = queue.Queue()  # sentences waiting for TTS
        self.audio_queue = queue.Queue()  # audio file paths (tts_stream) or raw TTS replies
        self.PUNCTUATION = r'[。!?、,.?]'  # sentence delimiters
        self.tts_event = threading.Event()  # set after every finished TTS request
        self.audio_part_counter = 0  # audio parts saved for the current request
        self.text_part_counter = 0  # stream chunks that completed at least one sentence
        self.audio_dir = "audio_files"  # directory the audio parts are written to
        self.is_last = False  # set once the streamed chat reply has been fully read
        self.settings = {}  # externally supplied settings (not read by this class)
        # No method of this class needs this lock any more; kept in case
        # external code references ``self.lock``.
        self.lock = threading.Lock()
        os.makedirs(self.audio_dir, exist_ok=True)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, data: object) -> bool:
        """Return True when *data* is raw ``bytes``."""
        return isinstance(data, bytes)

    @staticmethod
    def _drain(q: queue.Queue) -> None:
        """Discard every item currently in *q* using the queue's public API."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def reset_queues(self):
        """Discard leftover items from a previous request."""
        # Drain through get_nowait(): clearing ``Queue.queue`` directly
        # bypasses the lock that guards the underlying deque.
        self._drain(self.text_queue)
        self._drain(self.audio_queue)

    def clear_audio_files(self):
        """Delete every regular file in ``self.audio_dir``."""
        for file_name in os.listdir(self.audio_dir):
            file_path = os.path.join(self.audio_dir, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)
                print(f"Removed old audio file: {file_path}")

    def save_audio(self, audio_data, part_number: int):
        """Write *audio_data* to ``audio_part_<part_number>.wav`` and return its path."""
        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
        with open(file_name, 'wb') as f:
            f.write(audio_data)
        return file_name

    def chat_stream(self, prompt: str, settings: dict):
        """Fetch generated text from the chat blackbox and queue it for TTS.

        If ``settings["stream"]`` is truthy, the reply is read line by line. It
        is split on ``self.PUNCTUATION``, and each cleaned sentence is queued
        as soon as it is complete. Otherwise the ``"response"`` field of the
        JSON reply is queued as a single item.

        A ``None`` end-of-stream marker is always queued last, even when the
        request fails, so the matching ``send_to_tts`` worker terminates.
        """
        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache"}  # disable caching
        data = {
            "prompt": prompt,
            "context": [],
            "settings": settings,
        }
        print(f"data_chat: {data}")
        # Every run starts from a clean slate.
        self.clear_audio_files()
        self.audio_part_counter = 0
        self.text_part_counter = 0
        self.is_last = False
        try:
            if settings.get("stream"):
                self._stream_chat_sentences(headers, data)
            else:
                with requests.post(self.CHAT_URL, headers=headers, data=json.dumps(data)) as response:
                    print(f"data_chat1: {data}")
                    if response.status_code == 200:
                        response_content = response.json().get("response")
                        if response_content is not None:
                            self.text_queue.put(response_content)
        finally:
            # End-of-stream marker: send_to_tts leaves its loop on None.
            self.text_queue.put(None)

    def _stream_chat_sentences(self, headers: dict, data: dict) -> None:
        """Read the streamed chat reply and queue each complete, cleaned sentence."""
        with requests.post(self.CHAT_URL, headers=headers, data=json.dumps(data), stream=True) as response:
            print(f"data_chat1: {data}")
            if response.status_code != 200:
                print(f"Chat request failed: {response.status_code}")
                return
            pending = ""  # text received since the last sentence delimiter
            # Iterate lazily so sentences reach TTS while the reply is still
            # being generated.
            for line in response.iter_lines():
                if not line:
                    continue
                message = line.decode('utf-8')
                if message.strip().lower() == "data:":
                    continue  # skip bare "data:" lines
                pending += message
                if re.search(self.PUNCTUATION, pending):
                    # Every piece but the last ends at a delimiter; the last
                    # piece is an unfinished sentence and stays pending.
                    *sentences, pending = re.split(self.PUNCTUATION, pending)
                    for sentence in sentences:
                        self._queue_sentence(sentence)
                    self.text_part_counter += 1
            # Flush whatever followed the final delimiter; it would otherwise be lost.
            self._queue_sentence(pending)
        self.is_last = True
        print(f'2.is_last: {self.is_last}')

    def _queue_sentence(self, sentence: str) -> None:
        """Clean *sentence* and queue it for TTS unless nothing is left."""
        cleaned_sentence = self.filter_invalid_chars(sentence.strip())
        if cleaned_sentence:
            print(f"Sending complete sentence: {cleaned_sentence}")
            self.text_queue.put(cleaned_sentence)

    def send_to_tts(self, settings: dict):
        """Consume sentences from ``text_queue`` and synthesize them via the TTS blackbox.

        Runs until the ``None`` end-of-stream marker arrives.

        - If ``settings["tts_stream"]`` is truthy, each audio reply is saved to
          a file and the file path is queued.
        - Otherwise the raw reply body is queued.

        A ``None`` end-of-audio marker is always put on ``audio_queue`` on
        exit, so consumers know no more audio is coming.
        """
        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache"}  # disable caching
        user_stream = settings.get("tts_stream")
        tts_model_name = settings.get("tts_model_name")
        print(f"data_tts0: {settings}")
        try:
            while True:
                try:
                    # Block for the next complete sentence.
                    text = self.text_queue.get(timeout=5)
                except queue.Empty:
                    # The chat service may still be generating; keep waiting.
                    time.sleep(1)
                    continue
                if text is None:
                    break  # end-of-stream marker from chat_stream
                if not text.strip():
                    continue
                if tts_model_name == 'sovitstts':
                    text = self.filter_invalid_chars(text)
                print(f"data_tts0.1: {settings}")
                data = {
                    "settings": settings,
                    "text": text,
                }
                print(f"data_tts1: {data}")
                try:
                    if user_stream:
                        self._synthesize_to_file(headers, data, text)
                    else:
                        print(f"data_tts2: {data}")
                        with requests.post(self.TTS_URL, headers=headers, data=json.dumps(data)) as response:
                            if response.status_code == 200:
                                self.audio_queue.put(response.content)
                except requests.RequestException as e:
                    # One failed sentence should not abort the rest of the reply.
                    print(f"Failed to send to TTS: {e}, Text: {text}")
                # Notify any thread waiting for a TTS request to finish.
                self.tts_event.set()
                time.sleep(0.2)
        finally:
            # End-of-audio marker consumed by fast_api_handler.
            self.audio_queue.put(None)

    def _synthesize_to_file(self, headers: dict, data: dict, text: str) -> None:
        """POST one sentence to TTS, save the audio reply to disk and queue its path."""
        with requests.post(self.TTS_URL, headers=headers, data=json.dumps(data), stream=True) as response:
            if response.status_code != 200:
                print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
                return
            audio_data = response.content
        if not isinstance(audio_data, bytes):
            print("Error: Received non-binary data.")
            return
        self.audio_part_counter += 1
        file_name = self.save_audio(audio_data, self.audio_part_counter)
        print(f"Audio part saved as {file_name}")
        self.audio_queue.put(file_name)

    def filter_invalid_chars(self, text):
        """Remove "data:" markers, whitespace characters and ASCII letters from *text*.

        ``bytes`` input is decoded as UTF-8 first, ignoring undecodable bytes.
        """
        invalid_keywords = ["data:", "\n", "\r", "\t", " "]
        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='ignore')
        for keyword in invalid_keywords:
            text = text.replace(keyword, "")
        # Drop every ASCII letter, keeping CJK text, digits and punctuation.
        text = re.sub(r'[a-zA-Z]', '', text)
        return text.strip()

    @logging_time(logger=logger)
    def processing(self, text: str, settings: dict) -> dict:
        """Start the chat-producer and TTS-consumer threads for *text* and return at once."""
        threading.Thread(target=self.chat_stream, args=(text, settings), daemon=True).start()
        threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
        return {"message": "Chat and TTS processing started"}

    async def wait_for_audio(self):
        """Poll every 0.2 s until ``audio_queue`` holds an item, without blocking the event loop."""
        while self.audio_queue.empty():
            await asyncio.sleep(0.2)

    async def fast_api_handler(self, request: Request) -> Response:
        """Serve a chat+TTS request.

        Expects a JSON body with ``text`` and ``settings``.

        - If ``settings["tts_stream"]`` is truthy, every synthesized audio part
          is streamed as ``audio/wav``.
        - Otherwise the first TTS reply is decoded as JSON and returned.

        Any exception is reported as a 400 JSON error.
        """
        try:
            data = await request.json()
            text = data.get("text")
            setting = data.get("settings") or {}
            print(f"data0: {data}")
            if text is None:
                return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
            user_stream = setting.get("tts_stream")
            # Reset per-request state only once the request has been validated.
            self.is_last = False
            self.audio_part_counter = 0
            self.text_part_counter = 0
            self.reset_queues()
            self.clear_audio_files()
            self.processing(text, settings=setting)
            # Wait for the first audio item (or the end-of-audio marker).
            await self.wait_for_audio()
            if user_stream:
                def audio_stream():
                    # Yield saved parts in order until send_to_tts signals completion.
                    while True:
                        audio_file = self.audio_queue.get()
                        if audio_file is None:
                            break
                        with open(audio_file, "rb") as f:
                            print(f"Sending audio file: {audio_file}")
                            yield f.read()
                return StreamingResponse(audio_stream(), media_type="audio/wav")
            first_reply = self.audio_queue.get()
            if not first_reply:
                return JSONResponse(content={"error": "TTS produced no audio"}, status_code=status.HTTP_502_BAD_GATEWAY)
            return JSONResponse(content=json.loads(first_reply.decode('utf-8')))
        except Exception as e:
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)