# Mirror of https://github.com/BoardWare-Genius/jarvis-models.git
# Synced 2025-12-13 16:53:24 +00:00 (270 lines, 11 KiB, Python)
import asyncio
import io
import json
import logging
import os
import queue
import re
import threading
import time
from typing import Any, Iterator
from uuid import uuid4  # used to generate unique session IDs

import requests
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from injector import inject, singleton

from ..log.logging_time import logging_time
from .blackbox import Blackbox

logger = logging.getLogger(__name__)


@singleton
class ChatPipeline(Blackbox):
    """Chain the upstream chat blackbox to the upstream TTS blackbox.

    A prompt is posted to the chat endpoint, the reply is cut into sentences
    at punctuation marks, and every sentence is posted to the TTS endpoint.
    Two daemon threads do the work and communicate through two queues:

    * ``text_queue``  -- sentences produced by :meth:`chat_stream`
    * ``audio_queue`` -- results produced by :meth:`send_to_tts` (saved file
      paths when ``tts_stream`` is truthy, raw response bodies otherwise)

    Both producers append ``None`` as an end-of-stream marker when they are
    done, so consumers always terminate even when an upstream call fails.

    NOTE(review): the class is registered with ``@singleton`` and keeps
    per-request state (queues, counters, ``is_last``) on the instance, so
    overlapping requests would share that state. This rewrite does not try
    to isolate concurrent requests.
    """

    # Upstream endpoints; kept as class attributes so a deployment can
    # override them without editing method bodies.
    CHAT_URL = 'http://10.6.44.141:8000/?blackbox_name=chat'
    TTS_URL = 'http://10.6.44.141:8000/?blackbox_name=tts'
    # "Cache-Control: no-cache" disables caching on the upstream calls.
    REQUEST_HEADERS = {'Content-Type': 'text/plain', "Cache-Control": "no-cache"}
    # (connect, read) timeout in seconds so a stalled upstream cannot block a
    # worker thread forever.
    REQUEST_TIMEOUT = (10, 300)

    @inject
    def __init__(self) -> None:
        """Create the queues and counters and make sure the audio directory exists."""
        self.text_queue = queue.Queue()  # sentences waiting for TTS
        self.audio_queue = queue.Queue()  # TTS results waiting to be returned
        self.PUNCTUATION = r'[。!?、,.,?]'  # sentence delimiters (regex character class)
        self.tts_event = threading.Event()  # set after each TTS round trip
        self.audio_part_counter = 0  # number of audio segments saved so far
        self.text_part_counter = 0  # number of chat lines that completed a sentence
        self.audio_dir = "audio_files"  # directory holding the saved audio segments
        self.is_last = False  # True once the TTS worker has finished

        self.settings = {}  # settings supplied by the caller
        self.lock = threading.Lock()  # guards reads of the shared settings dict

        os.makedirs(self.audio_dir, exist_ok=True)

    def __call__(self, *args, **kwargs):
        """Delegate to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, data: Any) -> bool:
        """Return True only when *data* is a ``bytes`` object."""
        return isinstance(data, bytes)

    def reset_queues(self):
        """Drop everything that is still queued from a previous run."""
        for pending in (self.text_queue, self.audio_queue):
            # Hold the queue's own mutex: clearing the underlying deque while a
            # worker is inside put()/get() is otherwise not thread-safe.
            with pending.mutex:
                pending.queue.clear()

    def clear_audio_files(self):
        """Delete every regular file inside ``audio_dir``."""
        try:
            file_names = os.listdir(self.audio_dir)
        except FileNotFoundError:
            # The directory was removed after construction; nothing to clean.
            return
        for file_name in file_names:
            file_path = os.path.join(self.audio_dir, file_name)
            if not os.path.isfile(file_path):
                continue
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Another cleanup pass removed it first.
                continue
            logger.debug("Removed old audio file: %s", file_path)

    def save_audio(self, audio_data, part_number: int):
        """Write *audio_data* to ``audio_part_<part_number>.wav`` and return the path."""
        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
        with open(file_name, 'wb') as f:
            f.write(audio_data)
        return file_name

    def chat_stream(self, prompt: str, settings: dict):
        """Fetch the chat reply for *prompt* and queue it sentence by sentence.

        When ``settings["stream"]`` is truthy the reply is consumed line by
        line as it arrives; otherwise the JSON field ``"response"`` of a
        single reply is queued as one item. A ``None`` marker is always
        appended to ``text_queue`` at the end, including on failure, so the
        TTS worker can stop.
        """
        data = {
            "prompt": prompt,
            "context": [],
            "settings": settings,
        }
        logger.debug("chat request payload: %s", data)

        # Start every run from a clean slate.
        self.clear_audio_files()
        self.audio_part_counter = 0
        self.text_part_counter = 0
        self.is_last = False

        with self.lock:  # settings is shared with the TTS thread
            llm_stream = settings.get("stream")

        try:
            if llm_stream:
                self._stream_chat_sentences(data)
            else:
                self._fetch_chat_reply(data)
        except (requests.RequestException, ValueError):
            # ValueError covers a reply body that is not valid JSON.
            logger.exception("Chat request failed")
        finally:
            # End-of-text marker; send_to_tts() breaks out of its loop on it.
            self.text_queue.put(None)

    def _stream_chat_sentences(self, data: dict) -> None:
        """Consume a streamed chat reply and queue each complete sentence."""
        with requests.post(self.CHAT_URL, headers=self.REQUEST_HEADERS,
                           data=json.dumps(data), stream=True,
                           timeout=self.REQUEST_TIMEOUT) as response:
            if response.status_code != 200:
                logger.error("Chat service returned %s", response.status_code)
                return

            pending = ""  # text received so far that does not end a sentence yet
            # Iterate lazily so the first sentence reaches TTS while the model
            # is still generating the rest.
            for line in response.iter_lines():
                if not line:
                    continue
                message = line.decode('utf-8', errors='ignore')
                if message.strip().lower() == "data:":
                    continue  # bare "data:" line carries no text

                pending += message
                if not re.search(self.PUNCTUATION, pending):
                    continue

                *sentences, pending = re.split(self.PUNCTUATION, pending)
                for sentence in sentences:
                    self._enqueue_sentence(sentence)
                self.text_part_counter += 1

            # Flush the tail: a reply that does not end with punctuation would
            # otherwise lose its final sentence.
            self._enqueue_sentence(pending)

    def _fetch_chat_reply(self, data: dict) -> None:
        """Request a non-streamed chat reply and queue its ``"response"`` field."""
        with requests.post(self.CHAT_URL, headers=self.REQUEST_HEADERS,
                           data=json.dumps(data),
                           timeout=self.REQUEST_TIMEOUT) as response:
            if response.status_code != 200:
                logger.error("Chat service returned %s", response.status_code)
                return
            reply = response.json().get("response")
        if reply is not None:
            self.text_queue.put(reply)

    def _enqueue_sentence(self, sentence: str) -> None:
        """Clean *sentence* and put it on ``text_queue`` unless nothing is left."""
        cleaned_sentence = self.filter_invalid_chars(sentence.strip())
        if cleaned_sentence:
            logger.debug("Sending complete sentence: %s", cleaned_sentence)
            self.text_queue.put(cleaned_sentence)

    def send_to_tts(self, settings: dict):
        """Synthesize every queued sentence until the end-of-text marker arrives.

        With ``settings["tts_stream"]`` truthy each reply body is saved to a
        file and the file path is queued; otherwise the raw reply body is
        queued. When the loop ends, for any reason, ``is_last`` is set and a
        ``None`` marker is appended to ``audio_queue`` so readers can stop.
        """
        with self.lock:
            user_stream = settings.get("tts_stream")
            tts_model_name = settings.get("tts_model_name")
        logger.debug("tts settings: %s", settings)

        try:
            while True:
                try:
                    text = self.text_queue.get(timeout=5)
                except queue.Empty:
                    # The chat thread is still working; keep waiting for text
                    # or for its end-of-text marker.
                    continue

                if text is None:
                    break
                if not isinstance(text, str) or not text.strip():
                    continue

                if tts_model_name == 'sovitstts':
                    text = self.filter_invalid_chars(text)
                    if not text:
                        continue  # nothing left to synthesize after filtering

                self._synthesize(text, settings, user_stream)

                # Signal that this TTS round trip has completed.
                self.tts_event.set()
                time.sleep(0.2)  # brief pause between consecutive TTS requests
        finally:
            self.is_last = True
            self.audio_queue.put(None)  # end-of-audio marker

    def _synthesize(self, text: str, settings: dict, user_stream) -> None:
        """Post one sentence to the TTS endpoint and queue the result."""
        data = {
            "settings": settings,
            "text": text,
        }
        logger.debug("tts request payload: %s", data)
        try:
            with requests.post(self.TTS_URL, headers=self.REQUEST_HEADERS,
                               data=json.dumps(data),
                               timeout=self.REQUEST_TIMEOUT) as response:
                status_code = response.status_code
                audio_data = response.content
        except requests.RequestException:
            logger.exception("TTS request failed, Text: %s", text)
            return

        if status_code != 200:
            logger.error("Failed to send to TTS: %s, Text: %s", status_code, text)
            return

        if not user_stream:
            # The handler parses this body as JSON in non-streaming mode.
            self.audio_queue.put(audio_data)
            return

        if not audio_data:
            logger.error("TTS returned an empty body, Text: %s", text)
            return

        self.audio_part_counter += 1
        try:
            file_name = self.save_audio(audio_data, self.audio_part_counter)
        except OSError:
            logger.exception("Could not save audio part %s", self.audio_part_counter)
            return
        logger.debug("Audio part saved as %s", file_name)
        self.audio_queue.put(file_name)

    def filter_invalid_chars(self, text):
        """Return *text* without SSE markers, whitespace and ASCII letters.

        ``bytes`` input is decoded as UTF-8 first, ignoring undecodable bytes.
        """
        invalid_keywords = ["data:", "\n", "\r", "\t", " "]

        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='ignore')

        for keyword in invalid_keywords:
            text = text.replace(keyword, "")

        # Drop ASCII letters only; digits, punctuation and non-ASCII text stay.
        text = re.sub(r'[a-zA-Z]', '', text)

        return text.strip()

    @logging_time(logger=logger)
    def processing(self, text: str, settings: dict) -> dict:
        """Start the chat and TTS worker threads and return immediately.

        Returns:
            A small status dict; the audio itself is delivered through
            ``audio_queue``.
        """
        # Producer: chat reply -> sentences on text_queue.
        threading.Thread(target=self.chat_stream, args=(text, settings), daemon=True).start()

        # Consumer: sentences on text_queue -> audio on audio_queue.
        threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()

        return {"message": "Chat and TTS processing started"}

    async def wait_for_audio(self):
        """Wait until ``audio_queue`` holds at least one item.

        The item may be the end-of-audio marker (``None``) when nothing could
        be synthesized.
        """
        while self.audio_queue.empty():
            await asyncio.sleep(0.2)  # async sleep keeps the event loop responsive

    def _iter_audio(self, first_file) -> Iterator[bytes]:
        """Yield the content of *first_file* and of every further queued file.

        Stops at the ``None`` end-of-audio marker. Files that cannot be read
        are logged and skipped.
        """
        audio_file = first_file
        while audio_file is not None:
            try:
                with open(audio_file, "rb") as f:
                    chunk = f.read()
            except OSError:
                logger.exception("Could not read audio file: %s", audio_file)
            else:
                logger.debug("Sending audio file: %s", audio_file)
                yield chunk
            audio_file = self.audio_queue.get()

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a request whose JSON body carries ``text`` and ``settings``.

        With ``settings["tts_stream"]`` truthy the saved audio segments are
        streamed back as ``audio/wav``; otherwise the first TTS reply is
        parsed as JSON and returned. Invalid input yields HTTP 400, and a run
        that produced no audio yields HTTP 502.
        """
        try:
            data = await request.json()
            if not isinstance(data, dict):
                return JSONResponse(content={"error": "request body must be a JSON object"},
                                    status_code=status.HTTP_400_BAD_REQUEST)

            # Validate before touching shared state so a bad request cannot
            # wipe the queues or audio files of a run in progress.
            text = data.get("text")
            if text is None:
                return JSONResponse(content={"error": "text is required"},
                                    status_code=status.HTTP_400_BAD_REQUEST)
            settings = data.get("settings")
            if not isinstance(settings, dict):
                return JSONResponse(content={"error": "settings must be a JSON object"},
                                    status_code=status.HTTP_400_BAD_REQUEST)
            user_stream = settings.get("tts_stream")
            logger.debug("request payload: %s", data)

            self.is_last = False
            self.audio_part_counter = 0
            self.text_part_counter = 0
            self.reset_queues()
            self.clear_audio_files()

            self.processing(text, settings=settings)

            # Returns once the first result, or the end-of-audio marker, is queued.
            await self.wait_for_audio()
            first = self.audio_queue.get()
            if not first:
                return JSONResponse(content={"error": "no audio was produced"},
                                    status_code=status.HTTP_502_BAD_GATEWAY)

            if user_stream:
                return StreamingResponse(self._iter_audio(first), media_type="audio/wav")

            # Non-streaming mode: only the first TTS reply is returned.
            return JSONResponse(content=json.loads(first.decode('utf-8')))

        except Exception as e:
            # Top-level boundary: log the traceback and keep the original
            # contract of answering with HTTP 400 and the error text.
            logger.exception("ChatPipeline request failed")
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)