feat: chat2tts stream

This commit is contained in:
verachen
2025-01-09 11:29:34 +08:00
parent ec3b4b143a
commit 37174413fe
12 changed files with 643 additions and 67 deletions

View File

@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
# 加载embedding模型和chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
id = "g2e"
#client.delete_collection(id)
@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma 召回
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)
# print(collection.count())
@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
# url = "http://192.168.0.200:23333/v1/chat/completions"
# url = "http://10.6.44.141:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

View File

@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
# 加载embedding模型和chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
id = "g2e_english"
client.delete_collection(id)
@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma 召回
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)
# print(collection.count())
@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
# url = "http://192.168.0.200:23333/v1/chat/completions"
# url = "http://10.6.44.141:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

View File

@ -81,7 +81,7 @@ def get_all_files(folder_path):
# # 加载embedding模型和chroma server
# embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
# id = "g2e"
# client.delete_collection(id)
@ -103,7 +103,7 @@ def get_all_files(folder_path):
# chroma 召回
from chromadb.utils import embedding_functions
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
collection = client.get_collection("g2e", embedding_function=embedding_model)
print(collection.count())
@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
# url = "http://192.168.0.200:23333/v1/chat/completions"
# url = "http://10.6.44.141:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

View File

@ -10,64 +10,64 @@ import time
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
id = "boss"
client.delete_collection(id)
id = "kiki"
# client.delete_collection(id)
# 插入向量(如果ids已存在则会更新向量)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
collection = client.get_collection(id, embedding_function=embedding_model)
reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
while True:
usr_question = input("\n 请输入问题: ")
# query it
time1 = time.time()
results = collection.query(
query_texts=[usr_question],
n_results=10,
)
time2 = time.time()
print("query time: ", time2 - time1)
# while True:
# usr_question = input("\n 请输入问题: ")
# # query it
# time1 = time.time()
# results = collection.query(
# query_texts=[usr_question],
# n_results=10,
# )
# time2 = time.time()
# print("query time: ", time2 - time1)
# print("query: ",usr_question)
# print("results: ",print(results["documents"][0]))
# # print("query: ",usr_question)
# # print("results: ",print(results["documents"][0]))
pairs = [[usr_question, doc] for doc in results["documents"][0]]
# print('\n',pairs)
scores = reranker_model.predict(pairs)
# pairs = [[usr_question, doc] for doc in results["documents"][0]]
# # print('\n',pairs)
# scores = reranker_model.predict(pairs)
#重新排列文件顺序:
print("New Ordering:")
i = 0
final_result = ''
for o in np.argsort(scores)[::-1]:
if i == 3 or scores[o] < 0.5:
break
i += 1
print(o+1)
print("Scores:", scores[o])
print(results["documents"][0][o],'\n')
final_result += results["documents"][0][o] + '\n'
# #重新排列文件顺序:
# print("New Ordering:")
# i = 0
# final_result = ''
# for o in np.argsort(scores)[::-1]:
# if i == 3 or scores[o] < 0.5:
# break
# i += 1
# print(o+1)
# print("Scores:", scores[o])
# print(results["documents"][0][o],'\n')
# final_result += results["documents"][0][o] + '\n'
print("\n final_result: ", final_result)
time3 = time.time()
print("rerank time: ", time3 - time2)
# print("\n final_result: ", final_result)
# time3 = time.time()
# print("rerank time: ", time3 - time2)

View File

@ -1,7 +1,7 @@
from typing import Union
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, Response, StreamingResponse
from src.blackbox.blackbox_factory import BlackboxFactory
from fastapi.middleware.cors import CORSMiddleware
@ -21,6 +21,7 @@ injector = Injector()
blackbox_factory = injector.get(BlackboxFactory)
@app.post("/")
@app.get("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@ -29,3 +30,57 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request)
# @app.get("/audio/{filename}")
# async def serve_audio(filename: str):
# import os
# # 确保文件存在
# if os.path.exists(filename):
# with open(filename, 'rb') as f:
# audio_data = f.read()
# if filename.endswith(".mp3"):
# # 对于 MP3 文件,设置为 inline 以在浏览器中播放
# return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
# elif filename.endswith(".wav"):
# # 对于 WAV 文件,设置为 inline 以在浏览器中播放
# return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
# else:
# return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
@app.get("/audio/{filename}")
async def serve_audio(filename: str):
import os
import aiofiles
# 确保文件存在
if os.path.exists(filename):
try:
# 使用 aiofiles 异步读取文件
async with aiofiles.open(filename, 'rb') as f:
audio_data = await f.read()
# 根据文件类型返回正确的音频响应
if filename.endswith(".mp3"):
return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
elif filename.endswith(".wav"):
return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
else:
return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
except asyncio.CancelledError:
# 处理任务被取消的情况
return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST)
except Exception as e:
# 捕获其他异常
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
else:
return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
# @app.get("/audio/{filename}")
# async def serve_audio(filename: str):
# import os
# file_path = f"audio/{filename}"
# # 检查文件是否存在
# if os.path.exists(file_path):
# return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav")
# else:
# return JSONResponse(content={"error": f"{filename} not found"}, status_code=404)

View File

@ -62,6 +62,11 @@ def tts_loader():
from .tts import TTS
return Injector().get(TTS)
@model_loader(lazy=blackboxConf.lazyloading)
def chatpipeline_loader():
from .chatpipeline import ChatPipeline
return Injector().get(ChatPipeline)
@model_loader(lazy=blackboxConf.lazyloading)
def emotion_loader():
from .emotion import Emotion
@ -132,6 +137,7 @@ class BlackboxFactory:
self.models["audio_to_text"] = audio_to_text_loader
self.models["asr"] = asr_loader
self.models["tts"] = tts_loader
self.models["chatpipeline"] = chatpipeline_loader
self.models["sentiment_engine"] = sentiment_loader
self.models["emotion"] = emotion_loader
self.models["fastchat"] = fastchat_loader

View File

@ -0,0 +1,304 @@
import re
import requests
import json
import queue
import threading
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import os
import io
import time
import requests
from fastapi import Request, Response, status, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox
from injector import inject
from injector import singleton
from ..log.logging_time import logging_time
import logging
logger = logging.getLogger(__name__)
import asyncio
from uuid import uuid4 # 用于生成唯一的会话 ID
@singleton
class ChatPipeline(Blackbox):
@inject
def __init__(self, ) -> None:
self.text_queue = queue.Queue() # 文本队列
self.audio_queue = queue.Queue() # 音频队列
self.PUNCTUATION = r'[。!?、,.?]' # 标点符号
self.tts_event = threading.Event() # TTS 事件
self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0
self.audio_dir = "audio_files" # 存储音频文件的目录
if not os.path.exists(self.audio_dir):
os.makedirs(self.audio_dir)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def valid(self, data: any) -> bool:
if isinstance(data, bytes):
return True
return False
def reset_queues(self):
"""清空之前的队列数据"""
self.text_queue.queue.clear() # 清空文本队列
self.audio_queue.queue.clear() # 清空音频队列
def clear_audio_files(self):
"""清空音频文件夹中的所有音频文件"""
for file_name in os.listdir(self.audio_dir):
file_path = os.path.join(self.audio_dir, file_name)
if os.path.isfile(file_path):
os.remove(file_path)
print(f"Removed old audio file: {file_path}")
def save_audio(self, audio_data, part_number: int):
"""保存音频数据为文件"""
file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
with open(file_name, 'wb') as f:
f.write(audio_data)
return file_name
def chat_stream(self, prompt: str):
"""从 chat.py 获取实时生成的文本,并放入队列"""
url = 'http://10.6.44.141:8000/?blackbox_name=chat'
headers = {'Content-Type': 'text/plain'}
data = {
"prompt": prompt,
"context": [],
"settings": {
"stream": True
}
}
# 每次执行时清空原有音频文件
self.clear_audio_files()
self.audio_part_counter = 0
self.text_part_counter = 0
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
complete_message = "" # 用于累积完整的文本
last_message = False # 标记是否是最后一句
for line in response.iter_lines():
if line:
message = line.decode('utf-8')
if message.strip().lower() == "data:":
continue # 跳过"data:"行
complete_message += message
# 如果包含标点符号,拆分成句子
if re.search(self.PUNCTUATION, complete_message):
sentences = re.split(self.PUNCTUATION, complete_message)
for sentence in sentences[:-1]:
cleaned_sentence = self.filter_invalid_chars(sentence.strip())
if cleaned_sentence:
print(f"Sending complete sentence: {cleaned_sentence}")
self.text_queue.put({
'text': cleaned_sentence,
'is_last': False # 当前不是最后一句
}) # 放入文本队列
complete_message = sentences[-1]
self.text_part_counter += 1
print(f"***text_part_counter: {self.text_part_counter}")
# 判断是否是最后一句
print(f'1.last_message: {last_message}')
if not response.iter_lines():
last_message = True
print(f'2.last_message: {last_message}')
time.sleep(0.2)
# 放入最后一句消息
print(f'---a---')
if complete_message.strip():
print('---b---')
cleaned_sentence = self.filter_invalid_chars(complete_message.strip())
if cleaned_sentence:
print(f"Sending last complete sentence: {cleaned_sentence}")
self.text_queue.put({
'text': cleaned_sentence,
'is_last': True # 最后一条消息
})
else:
self.text_queue.put({
'text': "结束",
'is_last': True # 最后一条消息
})
def send_to_tts(self, settings: dict):
"""从队列中获取文本并发送给 tts.py 进行语音合成"""
url = 'http://10.6.44.141:8000/?blackbox_name=tts'
headers = {'Content-Type': 'text/plain'}
user_stream = settings.get("tts_stream")
while True:
try:
# 获取队列中的一个完整句子
item = self.text_queue.get(timeout=5)
if item is None:
break
text = item['text']
is_last = item['is_last'] # 判断是否是最后一句
if not text.strip():
continue
text = self.filter_invalid_chars(text)
data = {
"settings": settings,
"text": text
}
# if user_stream:
# # 发送请求到 TTS 服务
# response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
# if response.status_code == 200:
# audio_data = response.content
# if isinstance(audio_data, bytes):
# self.audio_part_counter += 1 # 增加音频段计数器
# file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
# print(f"Audio part saved as {file_name}")
# # 将文件名和是否是最后一条消息放入音频队列
# self.audio_queue.put({
# 'file_name': file_name,
# 'is_last': is_last
# }) # 放入音频队列
# # 如果是最后一句,执行额外的处理
# if is_last:
# print("This is the last sentence in the stream.")
# # 你可以在这里添加处理最后一句的额外逻辑,例如:
# # - 通知系统音频合成已完成
# # - 结束音频流等
# else:
# print(f"Error: Received non-binary data.")
# else:
# print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
# else:
# response = requests.post(url, headers=headers, data=json.dumps(data))
# if response.status_code == 200:
# print("1"*90)
# audio_data = response.content
# print(audio_data)
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
print("1"*90)
audio_data = response.content
print(audio_data)
# 通知下一个 TTS 可以执行了
self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程
time.sleep(0.2)
except queue.Empty:
time.sleep(1)
def filter_invalid_chars(self,text):
"""过滤无效字符(包括字节流)"""
invalid_keywords = ["data:", "\n", "\r", "\t", " "]
if isinstance(text, bytes):
text = text.decode('utf-8', errors='ignore')
for keyword in invalid_keywords:
text = text.replace(keyword, "")
# 移除所有英文字母和符号(保留中文、标点等)
text = re.sub(r'[a-zA-Z]', '', text)
return text.strip()
@logging_time(logger=logger)
def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO:
# 启动聊天流线程
threading.Thread(target=self.chat_stream, args=(text,), daemon=True).start()
# 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务
threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
return {"message": "Chat and TTS processing started"}
# 新增异步方法,等待音频片段生成
async def wait_for_audio(self):
while self.audio_queue.empty():
await asyncio.sleep(0.2) # 使用异步 sleep避免阻塞事件循环
async def fast_api_handler(self, request: Request) -> Response:
try:
data = await request.json()
text = data.get("text")
setting = data.get("settings")
user_stream = setting.get("tts_stream")
is_last = False
self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0
self.reset_queues()
if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
# 调用 processing 方法,并传递动态的 text 参数
response_data = self.processing(text, settings=setting)
print('---1---')
# 根据是否启用流式传输进行处理
if user_stream:
print('---2---')
# 等待至少一个音频片段生成完成
await self.wait_for_audio()
print('---3---')
def audio_stream():
print('---4---')
is_last = False
# 从上游服务器流式读取数据并逐块发送
print(f'111self.audio_part_counter: {self.audio_part_counter}')
print(f'111self.text_part_counter: {self.text_part_counter}')
print(f"111.self.audio_queue.qsize: {self.audio_queue.qsize()}")
print(f'111is_last: {is_last}')
while self.audio_part_counter != 0 and not is_last:
print('---5---')
print(f'1.is_last: {is_last}')
print(f"222.self.audio_queue.qsize: {self.audio_queue.qsize()}")
audio = self.audio_queue.get()
audio_file = audio['file_name']
is_last = audio['is_last']
print(f'2.is_last: {is_last}')
if audio_file:
print('---6---')
with open(audio_file, "rb") as f:
print('---7---')
print(f"Sending audio file: {audio_file}")
yield f.read() # 分段发送音频文件内容
print('---8---')
print('---9---')
return StreamingResponse(audio_stream(), media_type="audio/wav")
else:
print('---10---')
# 如果没有启用流式传输,可以返回一个完整的响应或音频文件
audio_files = []
while not self.audio_queue.empty():
audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名
# 返回多个音频文件
return JSONResponse(content={"audio_files": audio_files})
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -25,7 +25,7 @@ class ChromaQuery(Blackbox):
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
@ -60,7 +60,7 @@ class ChromaQuery(Blackbox):
chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
if chroma_host is None or chroma_host.isspace() or chroma_host == "":
chroma_host = "192.168.0.200"
chroma_host = "10.6.44.141"
if chroma_port is None or chroma_port.isspace() or chroma_port == "":
chroma_port = "7000"
@ -72,7 +72,7 @@ class ChromaQuery(Blackbox):
chroma_n_results = 10
# load client and embedding model from init
if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port):
if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1
else:
try:

View File

@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox):
# load embedding model
self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
# load chroma db
self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)

View File

@ -23,7 +23,7 @@ class G2E(Blackbox):
if context == None:
context = []
#url = 'http://120.196.116.194:48890/v1'
url = 'http://192.168.0.200:23333/v1'
url = 'http://10.6.44.141:23333/v1'
background_prompt = '''KOMBUKIKI是一款茶饮料目标受众 年龄20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯

View File

@ -3,18 +3,20 @@ import time
from ntpath import join
import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from fastapi import Request, Response, status, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox
from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
from ..configuration import SovitsConf
from injector import inject
from injector import singleton
import sys,os
sys.path.append('/Workspace/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice
sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
# from cosyvoice.utils.file_utils import load_wav, speed_change
import soundfile as sf
@ -29,13 +31,38 @@ import random
import torch
import numpy as np
from pydub import AudioSegment
import subprocess
def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def convert_wav_to_mp3(wav_filename: str):
# 检查文件是否是 .wav 格式
if not wav_filename.lower().endswith(".wav"):
raise ValueError("The input file must be a .wav file")
# 构建对应的 .mp3 文件路径
mp3_filename = wav_filename.replace(".wav", ".mp3")
# 如果原 MP3 文件已存在,先删除
if os.path.exists(mp3_filename):
os.remove(mp3_filename)
# 使用 FFmpeg 进行转换,直接覆盖原有的 MP3 文件
command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename] # `-y` 参数会自动覆盖目标文件
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# 检查转换是否成功
if result.returncode != 0:
raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}")
return mp3_filename
@singleton
class TTS(Blackbox):
melo_mode: str
@ -61,10 +88,11 @@ class TTS(Blackbox):
self.melo_url = ''
self.melo_mode = melo_config.mode
self.melotts = None
self.speaker_ids = None
# self.speaker_ids = None
self.melo_speaker = None
if self.melo_mode == 'local':
self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
self.speaker_ids = self.melotts.hps.data.spk2id
self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language]
else:
self.melo_url = melo_config.url
logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@ -81,18 +109,51 @@ class TTS(Blackbox):
self.cosyvoicetts = None
# os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
if self.cosyvoice_mode == 'local':
self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
# self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
# self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
else:
self.cosyvoice_url = cosyvoice_config.url
logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
@logging_time(logger=logger)
def sovits_model_init(self, sovits_config: SovitsConf) -> None:
self.sovits_speed = sovits_config.speed
self.sovits_device = sovits_config.device
self.sovits_language = sovits_config.language
self.sovits_speaker = sovits_config.speaker
self.sovits_url = sovits_config.url
self.sovits_mode = sovits_config.mode
self.sovitstts = None
self.speaker_ids = None
self.sovits_text_lang = sovits_config.text_lang
self.sovits_ref_audio_path = sovits_config.ref_audio_path
self.sovits_prompt_lang = sovits_config.prompt_lang
self.sovits_prompt_text = sovits_config.prompt_text
self.sovits_text_split_method = sovits_config.text_split_method
self.sovits_batch_size = sovits_config.batch_size
self.sovits_media_type = sovits_config.media_type
self.sovits_streaming_mode = sovits_config.streaming_mode
if self.sovits_mode == 'local':
# self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device)
self.sovits_speaker = self.sovitstts.hps.data.spk2id
else:
self.sovits_url = sovits_config.url
logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
@inject
def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
self.tts_service = TTService("yunfeineo")
self.melo_model_init(melo_config)
self.cosyvoice_model_init(cosyvoice_config)
self.sovits_model_init(sovits_config)
def __call__(self, *args, **kwargs):
@ -106,14 +167,19 @@ class TTS(Blackbox):
settings = {}
user_model_name = settings.get("tts_model_name")
chroma_collection_id = settings.get("chroma_collection_id")
user_stream = settings.get('tts_stream')
print(f"chroma_collection_id: {chroma_collection_id}")
print(f"tts_model_name: {user_model_name}")
if user_stream in [None, ""]:
user_stream = True
text = args[0]
current_time = time.time()
if user_model_name == 'melotts':
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.melo_mode == 'local':
audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed)
f = io.BytesIO()
sf.write(f, audio, 44100, format='wav')
f.seek(0)
@ -147,7 +213,7 @@ class TTS(Blackbox):
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@ -166,7 +232,8 @@ class TTS(Blackbox):
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.cosyvoice_mode == 'local':
set_all_random_seed(56056558)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
print("*"*90)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@ -183,7 +250,7 @@ class TTS(Blackbox):
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@ -196,11 +263,63 @@ class TTS(Blackbox):
}
response = requests.post(self.cosyvoice_url, json=message)
print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
return response.content
return response.content
elif user_model_name == 'sovitstts':
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.sovits_mode == 'local':
set_all_random_seed(56056558)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read()
else:
message = {
"text": text,
"text_lang": self.sovits_text_lang,
"ref_audio_path": self.sovits_ref_audio_path,
"prompt_lang": self.sovits_prompt_lang,
"prompt_text": self.sovits_prompt_text,
"text_split_method": self.sovits_text_split_method,
"batch_size": self.sovits_batch_size,
"media_type": self.sovits_media_type,
"streaming_mode": self.sovits_streaming_mode
}
if user_stream:
response = requests.get(self.sovits_url, params=message, stream=True)
return response
else:
response = requests.get(self.sovits_url, params=message)
return response.content
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
# elif chroma_collection_id == 'boss':
# if self.cosyvoice_mode == 'local':
# set_all_random_seed(35616313)
# audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
# for i, j in enumerate(audio):
# f = io.BytesIO()
# sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
# f.seek(0)
# print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
# return f.read()
# else:
# message = {
# "text": text
# }
# response = requests.post(self.cosyvoice_url, json=message)
# print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
# return response.content
elif user_model_name == 'man':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@ -232,7 +351,73 @@ class TTS(Blackbox):
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
text = data.get("text")
setting = data.get("settings")
tts_model_name = setting.get("tts_model_name")
user_stream = setting.get("tts_stream")
if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
by = self.processing(text, settings=setting)
return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
# return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
if user_stream:
if by.status_code == 200:
print("*"*90)
def audio_stream():
# 从上游服务器流式读取数据并逐块发送
for chunk in by.iter_content(chunk_size=1024):
if chunk:
yield chunk
return StreamingResponse(audio_stream(), media_type="audio/wav")
else:
raise HTTPException(status_code=by.status_code, detail="请求失败")
# if user_stream and tts_model_name == 'sovitstts':
# if by.status_code == 200:
# # 保存 WAV 文件
# wav_filename = 'audio.wav'
# with open(wav_filename, 'wb') as f:
# for chunk in by.iter_content(chunk_size=1024):
# if chunk:
# f.write(chunk)
# try:
# # 检查文件是否存在
# if not os.path.exists(wav_filename):
# return JSONResponse(content={"error": "File not found"}, status_code=404)
# # 转换 WAV 为 MP3
# mp3_filename = convert_wav_to_mp3(wav_filename)
# # 返回 WAV 和 MP3 文件的下载链接
# wav_url = f"/audio/{wav_filename}"
# mp3_url = f"/audio/{mp3_filename}"
# return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
# except Exception as e:
# return JSONResponse(content={"error": str(e)}, status_code=500)
# else:
# raise HTTPException(status_code=by.status_code, detail="请求失败")
else:
wav_filename = 'audio.wav'
with open(wav_filename, 'wb') as f:
f.write(by)
try:
# 先检查文件是否存在
if not os.path.exists(wav_filename):
return JSONResponse(content={"error": "File not found"}, status_code=404)
# 转换 WAV 为 MP3
mp3_filename = convert_wav_to_mp3(wav_filename)
# 返回 MP3 文件的下载链接
wav_url = f"/audio/{wav_filename}"
mp3_url = f"/audio/{mp3_filename}"
return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

View File

@ -82,6 +82,32 @@ class CosyVoiceConf():
self.language = config.get("cosyvoicetts.language")
self.speaker = config.get("cosyvoicetts.speaker")
class SovitsConf():
mode: str
url: str
speed: int
device: str
language: str
speaker: str
@inject
def __init__(self, config: Configuration) -> None:
self.mode = config.get("sovitstts.mode")
self.url = config.get("sovitstts.url")
self.speed = config.get("sovitstts.speed")
self.device = config.get("sovitstts.device")
self.language = config.get("sovitstts.language")
self.speaker = config.get("sovitstts.speaker")
self.text_lang = config.get("sovitstts.text_lang")
self.ref_audio_path = config.get("sovitstts.ref_audio_path")
self.prompt_lang = config.get("sovitstts.prompt_lang")
self.prompt_text = config.get("sovitstts.prompt_text")
self.text_split_method = config.get("sovitstts.text_split_method")
self.batch_size = config.get("sovitstts.batch_size")
self.media_type = config.get("sovitstts.media_type")
self.streaming_mode = config.get("sovitstts.streaming_mode")
class SenseVoiceConf():
mode: str
url: str