import asyncio
import io
import json
import logging
import os
import queue
import re
import threading
import time
from uuid import uuid4  # Currently unused; originally noted as "for generating unique session IDs".

import requests
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from injector import inject, singleton

from ..log.logging_time import logging_time
from .blackbox import Blackbox

logger = logging.getLogger(__name__)


@singleton
class ChatPipeline(Blackbox):
    """Pipe text from the ``chat`` blackbox into the ``tts`` blackbox.

    A producer thread (:meth:`chat_stream`) pushes sentences onto
    ``text_queue``; a consumer thread (:meth:`send_to_tts`) synthesizes each
    one and pushes the result onto ``audio_queue``.  Both queues are closed
    with a ``None`` sentinel so every consumer knows when its producer is done.

    NOTE(review): the class is a singleton and the queues/counters are shared
    instance state, so overlapping requests will interleave their data.
    """

    # Upstream blackbox endpoints; class attributes so they can be overridden.
    CHAT_URL = 'http://10.6.44.141:8000/?blackbox_name=chat'
    TTS_URL = 'http://10.6.44.141:8000/?blackbox_name=tts'
    # Headers sent to both upstream services (caching disabled).
    REQUEST_HEADERS = {'Content-Type': 'text/plain', "Cache-Control": "no-cache"}
    # Marks the end of text_queue / audio_queue for the current run.
    _END_OF_STREAM = None

    @inject
    def __init__(self) -> None:
        self.text_queue = queue.Queue()  # sentences waiting for TTS
        self.audio_queue = queue.Queue()  # saved file paths (tts_stream) or raw TTS bodies
        self.PUNCTUATION = r'[。!?、,.,?]'  # sentence delimiters
        self.tts_event = threading.Event()  # set after each successful TTS call
        self.audio_part_counter = 0  # audio parts saved in the current run
        self.text_part_counter = 0  # sentences queued in the current run
        self.audio_dir = "audio_files"  # directory where audio parts are written
        self.is_last = False  # True once chat_stream has finished producing text
        self.settings = {}  # externally supplied settings; not read in this class
        self.lock = threading.Lock()  # guards reads of the per-request settings dict
        os.makedirs(self.audio_dir, exist_ok=True)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, data: object) -> bool:
        """Return True when *data* is raw ``bytes``."""
        return isinstance(data, bytes)

    @staticmethod
    def _drain(pending: queue.Queue) -> None:
        """Remove every queued item from *pending* without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def reset_queues(self):
        """Discard text and audio left over from a previous run."""
        # Draining via the public API respects the queue's internal lock,
        # unlike clearing ``Queue.queue`` while worker threads may be active.
        self._drain(self.text_queue)
        self._drain(self.audio_queue)

    def clear_audio_files(self):
        """Delete every regular file in ``audio_dir``."""
        for file_name in os.listdir(self.audio_dir):
            file_path = os.path.join(self.audio_dir, file_name)
            if not os.path.isfile(file_path):
                continue
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already removed by a concurrent cleanup; nothing to do.
                continue
            logger.info("Removed old audio file: %s", file_path)

    def save_audio(self, audio_data, part_number: int):
        """Write *audio_data* to ``audio_dir/audio_part_<part_number>.wav``.

        Returns:
            The path of the written file.
        """
        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
        with open(file_name, 'wb') as f:
            f.write(audio_data)
        return file_name

    def _queue_sentence(self, sentence: str) -> None:
        """Clean *sentence* and enqueue it for TTS if anything is left."""
        cleaned_sentence = self.filter_invalid_chars(sentence.strip())
        if cleaned_sentence:
            logger.info("Sending complete sentence: %s", cleaned_sentence)
            self.text_queue.put(cleaned_sentence)
            self.text_part_counter += 1

    def _stream_chat_sentences(self, data: dict) -> None:
        """Stream the chat reply line by line and enqueue each finished sentence."""
        delimiter = re.compile(self.PUNCTUATION)
        with requests.post(self.CHAT_URL, headers=self.REQUEST_HEADERS,
                           data=json.dumps(data), stream=True) as response:
            if response.status_code != 200:
                logger.error("Chat request failed: %s", response.status_code)
                return
            pending = ""  # text received since the last delimiter
            # Iterate lazily: materialising all lines first would wait for the
            # entire LLM reply and defeat sentence-level streaming.
            for line in response.iter_lines():
                if not line:
                    continue
                message = line.decode('utf-8')
                if message.strip().lower() == "data:":
                    continue  # skip bare "data:" lines
                pending += message
                # split() always yields >= 1 item; the last one is unfinished.
                *sentences, pending = delimiter.split(pending)
                for sentence in sentences:
                    self._queue_sentence(sentence)
            # Text after the final delimiter is still a sentence; flush it.
            self._queue_sentence(pending)

    def _fetch_chat_response(self, data: dict) -> None:
        """Request the full chat reply and enqueue its ``"response"`` field."""
        with requests.post(self.CHAT_URL, headers=self.REQUEST_HEADERS,
                           data=json.dumps(data)) as response:
            if response.status_code != 200:
                logger.error("Chat request failed: %s", response.status_code)
                return
            response_content = response.json().get("response")
        if response_content:
            self.text_queue.put(response_content)

    def chat_stream(self, prompt: str, settings: dict):
        """Fetch text from the chat blackbox and feed it into ``text_queue``.

        With ``settings["stream"]`` truthy the reply is split into sentences on
        ``self.PUNCTUATION`` as it arrives; otherwise the whole reply is queued
        as one item.  A ``None`` sentinel is always appended to ``text_queue``
        on exit (including on error) so :meth:`send_to_tts` can terminate.
        """
        data = {
            "prompt": prompt,
            "context": [],
            "settings": settings,
        }
        logger.debug("data_chat: %s", data)

        # Every run starts from a clean slate.
        self.clear_audio_files()
        self.audio_part_counter = 0
        self.text_part_counter = 0
        self.is_last = False

        with self.lock:
            llm_stream = settings.get("stream")

        try:
            if llm_stream:
                self._stream_chat_sentences(data)
            else:
                self._fetch_chat_response(data)
        except Exception:
            # Thread entry point: log instead of dying silently.
            logger.exception("chat_stream failed")
        finally:
            self.is_last = True
            self.text_queue.put(self._END_OF_STREAM)

    def _synthesize(self, text: str, settings: dict, user_stream) -> None:
        """Send one sentence to TTS and enqueue the resulting audio."""
        data = {
            "settings": settings,
            "text": text,
        }
        logger.debug("data_tts1: %s", data)
        with requests.post(self.TTS_URL, headers=self.REQUEST_HEADERS,
                           data=json.dumps(data), stream=bool(user_stream)) as response:
            if response.status_code != 200:
                logger.error("Failed to send to TTS: %s, Text: %s", response.status_code, text)
                return
            audio_data = response.content
        if not audio_data:
            logger.error("Error: Received empty audio data. Text: %s", text)
            return
        if user_stream:
            self.audio_part_counter += 1
            file_name = self.save_audio(audio_data, self.audio_part_counter)
            logger.info("Audio part saved as %s", file_name)
            self.audio_queue.put(file_name)
        else:
            self.audio_queue.put(audio_data)
        # Signal any waiter that another part is ready.
        self.tts_event.set()

    def send_to_tts(self, settings: dict):
        """Synthesize every sentence from ``text_queue`` via the TTS blackbox.

        Runs until the ``None`` sentinel arrives.  With ``settings["tts_stream"]``
        truthy each result is saved as a WAV file and its path is queued;
        otherwise the raw response body is queued.  A ``None`` sentinel is
        always appended to ``audio_queue`` on exit.
        """
        with self.lock:
            user_stream = settings.get("tts_stream")
            tts_model_name = settings.get("tts_model_name")
        logger.debug("data_tts0: %s", settings)
        try:
            while True:
                text = self.text_queue.get()
                if text is self._END_OF_STREAM:
                    break
                if not text.strip():
                    continue
                if tts_model_name == 'sovitstts':
                    text = self.filter_invalid_chars(text)
                    if not text:
                        continue
                try:
                    self._synthesize(text, settings, user_stream)
                except requests.RequestException:
                    # One failed sentence should not abort the rest.
                    logger.exception("TTS request failed, Text: %s", text)
        except Exception:
            logger.exception("send_to_tts failed")
        finally:
            self.audio_queue.put(self._END_OF_STREAM)

    def filter_invalid_chars(self, text):
        """Strip transport artefacts and ASCII letters from *text*.

        ``bytes`` input is decoded as UTF-8 (undecodable bytes ignored).  The
        substrings ``"data:"``, ``\\n``, ``\\r``, ``\\t`` and spaces are removed
        first, then every ASCII letter; CJK text and punctuation are kept.
        """
        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='ignore')
        for keyword in ("data:", "\n", "\r", "\t", " "):
            text = text.replace(keyword, "")
        return re.sub(r'[a-zA-Z]', '', text).strip()

    @logging_time(logger=logger)
    def processing(self, text: str, settings: dict) -> dict:
        """Start the chat producer and TTS consumer threads; return immediately."""
        threading.Thread(target=self.chat_stream, args=(text, settings), daemon=True).start()
        threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
        return {"message": "Chat and TTS processing started"}

    async def wait_for_audio(self):
        """Poll until ``audio_queue`` holds an item (audio or end sentinel)."""
        while self.audio_queue.empty():
            await asyncio.sleep(0.2)  # async sleep keeps the event loop free

    def _iter_audio_files(self, first_file: str):
        """Yield the bytes of each saved audio part until the end sentinel."""
        audio_file = first_file
        while audio_file is not self._END_OF_STREAM:
            with open(audio_file, "rb") as f:
                logger.info("Sending audio file: %s", audio_file)
                yield f.read()
            # Blocks until the TTS worker produces the next part or finishes.
            audio_file = self.audio_queue.get()

    async def fast_api_handler(self, request: Request) -> Response:
        """HTTP entry point: run chat + TTS for ``text`` and return the audio.

        Expects a JSON body ``{"text": ..., "settings": {...}}``.  When
        ``settings["tts_stream"]`` is truthy the saved WAV parts are streamed
        back as they are produced; otherwise the first TTS response body is
        decoded as JSON and returned.
        """
        try:
            data = await request.json()
            text = data.get("text")
            setting = data.get("settings") or {}
            logger.debug("data0: %s", data)
            if text is None:
                return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
            user_stream = setting.get("tts_stream")

            # Reset per-run state before the worker threads start.
            self.is_last = False
            self.audio_part_counter = 0
            self.text_part_counter = 0
            self.reset_queues()
            self.clear_audio_files()

            self.processing(text, settings=setting)

            # Wait for the first audio part, or the sentinel if nothing was made.
            await self.wait_for_audio()
            first_item = self.audio_queue.get_nowait()
            if first_item is self._END_OF_STREAM:
                return JSONResponse(content={"error": "no audio was generated"},
                                    status_code=status.HTTP_502_BAD_GATEWAY)

            if user_stream:
                return StreamingResponse(self._iter_audio_files(first_item), media_type="audio/wav")
            # Non-streaming TTS replies with a JSON body; only the first is returned.
            return JSONResponse(content=json.loads(first_item.decode('utf-8')))
        except Exception as e:
            logger.exception("ChatPipeline request failed")
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)