Mirror of https://github.com/BoardWare-Genius/jarvis-models.git, synced 2025-12-13 16:53:24 +00:00
@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]

# Load the embedding model and the chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

id = "g2e"
#client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma retrieval
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)

# print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"

# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]

# Load the embedding model and the chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

id = "g2e_english"
client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma retrieval
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)

# print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"

# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

@@ -81,7 +81,7 @@ def get_all_files(folder_path):

# # Load the embedding model and the chroma server
# embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)

# id = "g2e"
# client.delete_collection(id)
@@ -103,7 +103,7 @@ def get_all_files(folder_path):
# chroma retrieval
from chromadb.utils import embedding_functions
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)
collection = client.get_collection("g2e", embedding_function=embedding_model)

print(collection.count())
@@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"

# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

@@ -10,64 +10,64 @@ import time

# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
+loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True, separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]

-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

-id = "boss"
-client.delete_collection(id)
+id = "kiki"
+# client.delete_collection(id)
# Insert the vectors (if the ids already exist, the vectors are updated)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)




-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")

-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

collection = client.get_collection(id, embedding_function=embedding_model)

-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
+reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")

-while True:
-    usr_question = input("\n 请输入问题: ")
-    # query it
-    time1 = time.time()
-    results = collection.query(
-        query_texts=[usr_question],
-        n_results=10,
-    )
-    time2 = time.time()
-    print("query time: ", time2 - time1)
+# while True:
+#     usr_question = input("\n 请输入问题: ")
+#     # query it
+#     time1 = time.time()
+#     results = collection.query(
+#         query_texts=[usr_question],
+#         n_results=10,
+#     )
+#     time2 = time.time()
+#     print("query time: ", time2 - time1)

-    # print("query: ",usr_question)
-    # print("results: ",print(results["documents"][0]))
+    # # print("query: ",usr_question)
+    # # print("results: ",print(results["documents"][0]))


-    pairs = [[usr_question, doc] for doc in results["documents"][0]]
-    # print('\n',pairs)
-    scores = reranker_model.predict(pairs)
+    # pairs = [[usr_question, doc] for doc in results["documents"][0]]
+    # # print('\n',pairs)
+    # scores = reranker_model.predict(pairs)

-    # Re-order the documents:
-    print("New Ordering:")
-    i = 0
-    final_result = ''
-    for o in np.argsort(scores)[::-1]:
-        if i == 3 or scores[o] < 0.5:
-            break
-        i += 1
-        print(o+1)
-        print("Scores:", scores[o])
-        print(results["documents"][0][o], '\n')
-        final_result += results["documents"][0][o] + '\n'
+    # # Re-order the documents:
+    # print("New Ordering:")
+    # i = 0
+    # final_result = ''
+    # for o in np.argsort(scores)[::-1]:
+    #     if i == 3 or scores[o] < 0.5:
+    #         break
+    #     i += 1
+    #     print(o+1)
+    #     print("Scores:", scores[o])
+    #     print(results["documents"][0][o], '\n')
+    #     final_result += results["documents"][0][o] + '\n'

-    print("\n final_result: ", final_result)
-    time3 = time.time()
-    print("rerank time: ", time3 - time2)
+    # print("\n final_result: ", final_result)
+    # time3 = time.time()
+    # print("rerank time: ", time3 - time2)
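
Note: the reranking loop above (shown in its pre-change form) keeps at most the top 3 retrieved documents whose cross-encoder score is at least 0.5. A minimal standalone sketch of that selection, assuming `scores` and the retrieved documents as in the script:

    import numpy as np

    def select_top_docs(scores, docs, top_k=3, min_score=0.5):
        # Hypothetical helper mirroring the loop above: walk documents in
        # descending score order, stopping after top_k hits or once the
        # score drops below min_score.
        picked = []
        for o in np.argsort(scores)[::-1]:
            if len(picked) == top_k or scores[o] < min_score:
                break
            picked.append(docs[o])
        return "\n".join(picked)

    # final_result = select_top_docs(scores, results["documents"][0])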
server.py | 58
@@ -1,7 +1,7 @@
from typing import Union

from fastapi import FastAPI, Request, status
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse

from src.blackbox.blackbox_factory import BlackboxFactory
from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +21,7 @@ injector = Injector()
blackbox_factory = injector.get(BlackboxFactory)

@app.post("/")
+@app.get("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
    if not blackbox_name:
        return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -29,3 +30,58 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
    except ValueError:
        return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
    return await box.fast_api_handler(request)

+# @app.get("/audio/{filename}")
+# async def serve_audio(filename: str):
+#     import os
+#     # make sure the file exists
+#     if os.path.exists(filename):
+#         with open(filename, 'rb') as f:
+#             audio_data = f.read()
+#         if filename.endswith(".mp3"):
+#             # for MP3 files, use inline so they play in the browser
+#             return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
+#         elif filename.endswith(".wav"):
+#             # for WAV files, use inline so they play in the browser
+#             return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
+#     else:
+#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)

+@app.get("/audio/audio_files/{filename}")
+async def serve_audio(filename: str):
+    import os
+    import aiofiles
+    filename = os.path.join("audio_files", filename)
+    # make sure the file exists
+    if os.path.exists(filename):
+        try:
+            # read the file asynchronously with aiofiles
+            async with aiofiles.open(filename, 'rb') as f:
+                audio_data = await f.read()
+
+            # return the right audio response for the file type
+            if filename.endswith(".mp3"):
+                return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
+            elif filename.endswith(".wav"):
+                return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
+            else:
+                return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
+        except asyncio.CancelledError:
+            # handle the request task being cancelled
+            return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST)
+        except Exception as e:
+            # catch any other exception
+            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+    else:
+        return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)


+# @app.get("/audio/{filename}")
+# async def serve_audio(filename: str):
+#     import os
+#     file_path = f"audio/{filename}"
+#     # check whether the file exists
+#     if os.path.exists(file_path):
+#         return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav")
+#     else:
+#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=404)
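
For reference, a minimal client for the new endpoint above; the host and port are assumptions carried over from the URLs hard-coded elsewhere in this diff:

    import requests

    # Hypothetical fetch, assuming the server from this diff listens on 10.6.44.141:8000.
    resp = requests.get("http://10.6.44.141:8000/audio/audio_files/audio_part_1.wav")
    if resp.status_code == 200:
        with open("audio_part_1.wav", "wb") as f:
            f.write(resp.content)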
@@ -62,6 +62,11 @@ def tts_loader():
    from .tts import TTS
    return Injector().get(TTS)

+@model_loader(lazy=blackboxConf.lazyloading)
+def chatpipeline_loader():
+    from .chatpipeline import ChatPipeline
+    return Injector().get(ChatPipeline)
+
@model_loader(lazy=blackboxConf.lazyloading)
def emotion_loader():
    from .emotion import Emotion
@@ -132,6 +137,7 @@ class BlackboxFactory:
        self.models["audio_to_text"] = audio_to_text_loader
        self.models["asr"] = asr_loader
        self.models["tts"] = tts_loader
+       self.models["chatpipeline"] = chatpipeline_loader
        self.models["sentiment_engine"] = sentiment_loader
        self.models["emotion"] = emotion_loader
        self.models["fastchat"] = fastchat_loader
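
With this registration the factory can serve the pipeline by name. A hypothetical lookup, assuming the `models` dict stores loader functions as shown above (the factory's public resolution method is not part of this hunk):

    # Hypothetical: "chatpipeline" maps to chatpipeline_loader, so calling
    # the stored loader yields the ChatPipeline singleton.
    loader = blackbox_factory.models["chatpipeline"]
    box = loader()
    result = box("你好", {"stream": True, "tts_stream": True})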
@@ -60,6 +60,8 @@ class Chat(Blackbox):
        system_prompt = settings.get('system_prompt')
        user_prompt_template = settings.get('user_prompt_template')
        user_stream = settings.get('stream')

+       llm_model = "vllm"
+
        if user_context == None:
            user_context = []
@@ -100,10 +102,16 @@ class Chat(Blackbox):
        #user_presence_penalty = 0.8

        if user_model_url is None or user_model_url.isspace() or user_model_url == "":
-           user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
+           if llm_model != "vllm":
+               user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
+           else:
+               user_model_url = "http://10.6.80.94:8000/v1/completions"

        if user_model_key is None or user_model_key.isspace() or user_model_key == "":
-           user_model_key = "YOUR_API_KEY"
+           if llm_model != "vllm":
+               user_model_key = "YOUR_API_KEY"
+           else:
+               user_model_key = "vllm"

        if chroma_embedding_model:
            chroma_response = self.chroma_query(user_question, settings)
@@ -117,7 +125,10 @@ class Chat(Blackbox):
            print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
            user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "。"
        else:
-           user_question = user_prompt_template + "问题: " + user_question + "。"
+           if llm_model != "vllm":
+               user_question = user_prompt_template + "问题: " + user_question + "。"
+           else:
+               user_question = user_question

        print(f"1.user_question: {user_question}")

@@ -172,10 +183,17 @@ class Chat(Blackbox):
        else:
            url = user_model_url
            key = user_model_key
-       header = {
-           'Content-Type': 'application/json',
-           "Cache-Control": "no-cache",  # disable caching
-       }
+       if llm_model != "vllm":
+           header = {
+               'Content-Type': 'application/json',
+               "Cache-Control": "no-cache",  # disable caching
+           }
+       else:
+           header = {
+               'Content-Type': 'application/json',
+               'Authorization': "Bearer " + key,
+               "Cache-Control": "no-cache",
+           }

        # system_prompt = "# Role: 琪琪,康普可可的代言人。\n\n## Profile:\n**Author**: 琪琪。\n**Language**: 中文。\n**Description**: 琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n\n## Constraints:\n- **严格遵循工作流程**: 严格遵循<Workflow >中设定的工作流程。\n- **无内置知识库** :根据<Workflow >中提供的知识作答,而不是内置知识库,我虽然是知识库专家,但我的知识依赖于外部输入,而不是大模型已有知识。\n- **回复格式**:在进行回复时,不能输出“检索内容” 标签字样,同时也不能直接透露知识片段原文。\n\n## Workflow:\n1. **接收查询**:接收用户的问题。\n2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n3. **提供回答**:\n\n```\n基于检索内容中的知识片段回答用户的问题。回答内容限制总结在50字内。\n请首先判断提供的检索内容与上述问题是否相关。如果相关,直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼 。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。”\n\n```\n## Example:\n\n用户询问:“中国的首都是哪个城市?” 。\n2.1检索知识库,首先检查知识片段,如果检索内容中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。\n2.2如果有知识片段,在做出回复时,只能基于检索内容中的内容进行回答,且不能透露上下文原文,同时也不能出现检索内容的标签字样。\n"

@@ -183,23 +201,37 @@ class Chat(Blackbox):
            {"role": "system", "content": system_prompt}
        ]

-       chat_inputs={
-           "model": user_model_name,
-           "messages": prompt_template + user_context + [
-               {
-                   "role": "user",
-                   "content": user_question
-               }
-           ],
-           "temperature": str(user_temperature),
-           "top_p": str(user_top_p),
-           "n": str(user_n),
-           "max_tokens": str(user_max_tokens),
-           "frequency_penalty": str(user_frequency_penalty),
-           "presence_penalty": str(user_presence_penalty),
-           "stop": str(user_stop),
-           "stream": user_stream,
-       }
+       if llm_model != "vllm":
+           chat_inputs={
+               "model": user_model_name,
+               "messages": prompt_template + user_context + [
+                   {
+                       "role": "user",
+                       "content": user_question
+                   }
+               ],
+               "temperature": str(user_temperature),
+               "top_p": str(user_top_p),
+               "n": str(user_n),
+               "max_tokens": str(user_max_tokens),
+               "frequency_penalty": str(user_frequency_penalty),
+               "presence_penalty": str(user_presence_penalty),
+               "stop": str(user_stop),
+               "stream": user_stream,
+           }
+       else:
+           chat_inputs={
+               "model": user_model_name,
+               "prompt": user_question,
+               "temperature": float(user_temperature),
+               "top_p": float(user_top_p),
+               "n": float(user_n),
+               "max_tokens": float(user_max_tokens),
+               "frequency_penalty": float(user_frequency_penalty),
+               "presence_penalty": float(user_presence_penalty),
+               # "stop": user_stop,
+               "stream": user_stream,
+           }

        # # get the current timestamp
        # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
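
The vllm branch above targets the OpenAI-style /v1/completions route, so it sends a flat prompt instead of a messages list. A minimal standalone sketch of an equivalent request (the URL and key are the defaults assumed earlier in this file; the model name is hypothetical; note that in the OpenAI/vLLM schema n and max_tokens are integers):

    import requests

    chat_inputs = {
        "model": "my-model",       # hypothetical; the code above reads it from settings
        "prompt": "你好",
        "temperature": 0.7,
        "top_p": 0.9,
        "n": 1,                    # integer in the OpenAI/vLLM schema
        "max_tokens": 256,         # integer in the OpenAI/vLLM schema
        "stream": False,
    }
    resp = requests.post("http://10.6.80.94:8000/v1/completions",
                         headers={"Authorization": "Bearer vllm"},
                         json=chat_inputs)
    print(resp.json()["choices"][0]["text"])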
@@ -252,9 +284,14 @@ class Chat(Blackbox):
        if response_result.get("choices") is None:
            yield JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
        else:
-           print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
-           yield fastchat_response.json()["choices"][0]["message"]["content"]
+           if llm_model != "vllm":
+               print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
+               yield fastchat_response.json()["choices"][0]["message"]["content"]
+           else:
+               print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["text"],"\n\n")
+               yield fastchat_response.json()["choices"][0]["text"]


    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
src/blackbox/chatpipeline.py | 270 (new file)
@@ -0,0 +1,270 @@
import re
import requests
import json
import queue
import threading
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

import os
import io
import time
import requests
from fastapi import Request, Response, status, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox

from injector import inject
from injector import singleton

from ..log.logging_time import logging_time
import logging
logger = logging.getLogger(__name__)

import asyncio
from uuid import uuid4  # used to generate unique session IDs

@singleton
class ChatPipeline(Blackbox):

    @inject
    def __init__(self) -> None:
        self.text_queue = queue.Queue()  # text queue
        self.audio_queue = queue.Queue()  # audio queue
        self.PUNCTUATION = r'[。!?、,.,?]'  # punctuation marks
        self.tts_event = threading.Event()  # TTS event
        self.audio_part_counter = 0  # audio segment counter
        self.text_part_counter = 0
        self.audio_dir = "audio_files"  # directory for storing audio files
        self.is_last = False

        self.settings = {}  # settings passed in from the caller
        self.lock = threading.Lock()  # create a lock

        if not os.path.exists(self.audio_dir):
            os.makedirs(self.audio_dir)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, data: any) -> bool:
        if isinstance(data, bytes):
            return True
        return False

    def reset_queues(self):
        """Clear any queued data from a previous run"""
        self.text_queue.queue.clear()  # clear the text queue
        self.audio_queue.queue.clear()  # clear the audio queue

    def clear_audio_files(self):
        """Remove every audio file in the audio directory"""
        for file_name in os.listdir(self.audio_dir):
            file_path = os.path.join(self.audio_dir, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)
                print(f"Removed old audio file: {file_path}")

    def save_audio(self, audio_data, part_number: int):
        """Save audio data to a file"""
        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
        with open(file_name, 'wb') as f:
            f.write(audio_data)
        return file_name

    def chat_stream(self, prompt: str, settings: dict):
        """Fetch the text generated in real time by chat.py and put it on the queue"""
        url = 'http://10.6.44.141:8000/?blackbox_name=chat'
        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",}  # disable caching
        data = {
            "prompt": prompt,
            "context": [],
            "settings": settings
        }
        print(f"data_chat: {data}")

        # clear any existing audio files on every run
        self.clear_audio_files()
        self.audio_part_counter = 0
        self.text_part_counter = 0
        self.is_last = False
        with self.lock:  # make access to settings thread-safe
            llm_stream = settings.get("stream")
        if llm_stream:
            with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
                print(f"data_chat1: {data}")
                complete_message = ""  # accumulates the complete text
                lines = list(response.iter_lines())  # read all lines into a list first
                total_lines = len(lines)
                for i, line in enumerate(lines):
                    if line:
                        message = line.decode('utf-8')

                        if message.strip().lower() == "data:":
                            continue  # skip "data:" lines

                        complete_message += message

                        # if punctuation is present, split into sentences
                        if re.search(self.PUNCTUATION, complete_message):
                            sentences = re.split(self.PUNCTUATION, complete_message)
                            for sentence in sentences[:-1]:
                                cleaned_sentence = self.filter_invalid_chars(sentence.strip())
                                if cleaned_sentence:
                                    print(f"Sending complete sentence: {cleaned_sentence}")
                                    self.text_queue.put(cleaned_sentence)  # put on the text queue
                            complete_message = sentences[-1]
                            self.text_part_counter += 1
                        # check whether this is the last sentence
                        if i == total_lines - 2:  # if this is the last line
                            self.is_last = True
                            print(f'2.is_last: {self.is_last}')

                    time.sleep(0.2)
        else:
            with requests.post(url, headers=headers, data=json.dumps(data)) as response:
                print(f"data_chat1: {data}")
                if response.status_code == 200:
                    response_json = response.json()
                    response_content = response_json.get("response")
                    self.text_queue.put(response_content)

    def send_to_tts(self, settings: dict):
        """Take text off the queue and send it to tts.py for speech synthesis"""
        url = 'http://10.6.44.141:8000/?blackbox_name=tts'
        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",}  # disable caching
        with self.lock:
            user_stream = settings.get("tts_stream")
            tts_model_name = settings.get("tts_model_name")
        print(f"data_tts0: {settings}")
        while True:
            try:
                # take one complete sentence from the queue
                text = self.text_queue.get(timeout=5)

                if text is None:
                    break

                if not text.strip():
                    continue

                if tts_model_name == 'sovitstts':
                    text = self.filter_invalid_chars(text)
                print(f"data_tts0.1: {settings}")
                data = {
                    "settings": settings,
                    "text": text
                }
                print(f"data_tts1: {data}")
                if user_stream:
                    # send the request to the TTS service
                    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)

                    if response.status_code == 200:
                        audio_data = response.content
                        if isinstance(audio_data, bytes):
                            self.audio_part_counter += 1  # increment the audio segment counter
                            file_name = self.save_audio(audio_data, self.audio_part_counter)  # save to a file
                            print(f"Audio part saved as {file_name}")

                            # put the file name (and last-message flag) on the audio queue
                            self.audio_queue.put(file_name)  # put on the audio queue

                        else:
                            print(f"Error: Received non-binary data.")
                    else:
                        print(f"Failed to send to TTS: {response.status_code}, Text: {text}")

                else:
                    print(f"data_tts2: {data}")
                    response = requests.post(url, headers=headers, data=json.dumps(data))
                    if response.status_code == 200:
                        self.audio_queue.put(response.content)

                # signal that the next TTS request may run
                self.tts_event.set()  # with threading.Event(), this notifies waiting threads
                time.sleep(0.2)

            except queue.Empty:
                time.sleep(1)

    def filter_invalid_chars(self, text):
        """Filter out invalid characters (including byte streams)"""
        invalid_keywords = ["data:", "\n", "\r", "\t", " "]

        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='ignore')

        for keyword in invalid_keywords:
            text = text.replace(keyword, "")

        # strip all Latin letters (Chinese characters, punctuation, etc. are kept)
        text = re.sub(r'[a-zA-Z]', '', text)

        return text.strip()

    @logging_time(logger=logger)
    def processing(self, text: str, settings: dict) -> str:  # -> io.BytesIO:

        # start the chat streaming thread
        threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start()

        # start the TTS thread; it finishes its current task before the next TTS runs
        threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()

        return {"message": "Chat and TTS processing started"}

    # new async helper that waits for audio segments to be generated
    async def wait_for_audio(self):
        while self.audio_queue.empty():
            await asyncio.sleep(0.2)  # async sleep to avoid blocking the event loop

    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
            text = data.get("text")
            setting = data.get("settings")
            user_stream = setting.get("tts_stream")
            self.is_last = False
            self.audio_part_counter = 0  # audio segment counter
            self.text_part_counter = 0
            self.reset_queues()
            self.clear_audio_files()
            print(f"data0: {data}")
            if text is None:
                return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)

            # call processing with the dynamic text argument
            response_data = self.processing(text, settings=setting)
            # branch on whether streaming is enabled
            if user_stream:
                # wait until at least one audio segment has been generated
                await self.wait_for_audio()
                def audio_stream():
                    # read data from the queue and send it file by file
                    while self.audio_part_counter != 0 and not self.is_last:
                        audio = self.audio_queue.get()
                        audio_file = audio
                        if audio_file:
                            with open(audio_file, "rb") as f:
                                print(f"Sending audio file: {audio_file}")
                                yield f.read()  # send the audio file contents in segments
                return StreamingResponse(audio_stream(), media_type="audio/wav")

            else:
                # without streaming, return a complete response or audio file
                await self.wait_for_audio()
                file_name = self.audio_queue.get()
                if file_name:
                    file_name_json = json.loads(file_name.decode('utf-8'))
                # audio_files = []
                # while not self.audio_queue.empty():
                #     print("9")
                #     audio_files.append(self.audio_queue.get())  # collect the generated audio file names
                # return multiple audio files
                return JSONResponse(content=file_name_json)

        except Exception as e:
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
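
A minimal client for the new pipeline, based on the handler above (host and port are the ones hard-coded in this file); with tts_stream enabled the endpoint streams the synthesized WAV segments back as they are produced:

    import requests

    # Hypothetical end-to-end call, assuming the service runs as configured above.
    payload = {
        "text": "你好",
        "settings": {"stream": True, "tts_stream": True, "tts_model_name": "cosyvoicetts"},
    }
    with requests.post("http://10.6.44.141:8000/?blackbox_name=chatpipeline",
                       json=payload, stream=True) as resp:
        with open("reply.wav", "wb") as f:
            for chunk in resp.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)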
@@ -25,7 +25,7 @@ class ChromaQuery(Blackbox):
        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-       self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
+       self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
        # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
        self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")

@@ -60,7 +60,7 @@ class ChromaQuery(Blackbox):
            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"

        if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-           chroma_host = "192.168.0.200"
+           chroma_host = "10.6.44.141"

        if chroma_port is None or chroma_port.isspace() or chroma_port == "":
            chroma_port = "7000"
@@ -72,7 +72,7 @@ class ChromaQuery(Blackbox):
            chroma_n_results = 10

        # load client and embedding model from init
-       if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port):
+       if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
            client = self.client_1
        else:
            try:

@@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox):
        # load embedding model
        self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
        # load chroma db
-       self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
+       self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

@@ -23,7 +23,7 @@ class G2E(Blackbox):
        if context == None:
            context = []
        #url = 'http://120.196.116.194:48890/v1'
-       url = 'http://192.168.0.200:23333/v1'
+       url = 'http://10.6.44.141:23333/v1'

        background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众 年龄:20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯


@@ -3,18 +3,20 @@ import time
from ntpath import join

import requests
-from fastapi import Request, Response, status
-from fastapi.responses import JSONResponse
+from fastapi import Request, Response, status, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox
from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
+from ..configuration import SovitsConf
from injector import inject
from injector import singleton

import sys,os
sys.path.append('/Workspace/CosyVoice')
-from cosyvoice.cli.cosyvoice import CosyVoice
+sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
# from cosyvoice.utils.file_utils import load_wav, speed_change

import soundfile as sf
@@ -29,13 +31,38 @@ import random
import torch
import numpy as np

+from pydub import AudioSegment
+import subprocess
+
def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

+def convert_wav_to_mp3(wav_filename: str):
+    # check that the input file is a .wav file
+    if not wav_filename.lower().endswith(".wav"):
+        raise ValueError("The input file must be a .wav file")
+
+    # build the corresponding .mp3 file path
+    mp3_filename = wav_filename.replace(".wav", ".mp3")
+
+    # if an old MP3 file already exists, delete it first
+    if os.path.exists(mp3_filename):
+        os.remove(mp3_filename)
+
+    # convert with FFmpeg, overwriting any existing MP3 file
+    command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename]  # the `-y` flag overwrites the target file automatically
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    # check whether the conversion succeeded
+    if result.returncode != 0:
+        raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}")
+
+    return mp3_filename
+

@singleton
class TTS(Blackbox):
    melo_mode: str
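
Usage of the new helper is straightforward, assuming ffmpeg is available on PATH as the function requires:

    # Hypothetical call: converts audio_files/audio.wav to audio_files/audio.mp3,
    # overwriting any existing MP3 (ffmpeg -y) and raising if ffmpeg fails.
    mp3_path = convert_wav_to_mp3("audio_files/audio.wav")
    print(mp3_path)  # -> audio_files/audio.mp3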
@@ -61,10 +88,11 @@ class TTS(Blackbox):
        self.melo_url = ''
        self.melo_mode = melo_config.mode
        self.melotts = None
-       self.speaker_ids = None
+       # self.speaker_ids = None
        self.melo_speaker = None
        if self.melo_mode == 'local':
            self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
            self.speaker_ids = self.melotts.hps.data.spk2id
+           self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language]
        else:
            self.melo_url = melo_config.url
        logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@@ -81,19 +109,52 @@ class TTS(Blackbox):
        self.cosyvoicetts = None
        # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
        if self.cosyvoice_mode == 'local':
-           self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
+           # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
+           self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
+           # self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)


        else:
            self.cosyvoice_url = cosyvoice_config.url
        logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
+       print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')

+   @logging_time(logger=logger)
+   def sovits_model_init(self, sovits_config: SovitsConf) -> None:
+       self.sovits_speed = sovits_config.speed
+       self.sovits_device = sovits_config.device
+       self.sovits_language = sovits_config.language
+       self.sovits_speaker = sovits_config.speaker
+       self.sovits_url = sovits_config.url
+       self.sovits_mode = sovits_config.mode
+       self.sovitstts = None
+       self.speaker_ids = None
+
+       self.sovits_text_lang = sovits_config.text_lang
+       self.sovits_ref_audio_path = sovits_config.ref_audio_path
+       self.sovits_prompt_lang = sovits_config.prompt_lang
+       self.sovits_prompt_text = sovits_config.prompt_text
+       self.sovits_text_split_method = sovits_config.text_split_method
+       self.sovits_batch_size = sovits_config.batch_size
+       self.sovits_media_type = sovits_config.media_type
+       self.sovits_streaming_mode = sovits_config.streaming_mode
+
+       if self.sovits_mode == 'local':
+           # self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device)
+           self.sovits_speaker = self.sovitstts.hps.data.spk2id
+       else:
+           self.sovits_url = sovits_config.url
+       logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
+       print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')


    @inject
-   def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
+   def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
        self.tts_service = TTService("yunfeineo")
        self.melo_model_init(melo_config)
        self.cosyvoice_model_init(cosyvoice_config)
+       self.sovits_model_init(sovits_config)
+       self.audio_dir = "audio_files"  # directory for storing audio files

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)
@@ -106,14 +167,19 @@ class TTS(Blackbox):
            settings = {}
        user_model_name = settings.get("tts_model_name")
        chroma_collection_id = settings.get("chroma_collection_id")
+       user_stream = settings.get('tts_stream')
        print(f"chroma_collection_id: {chroma_collection_id}")
        print(f"tts_model_name: {user_model_name}")

+       if user_stream in [None, ""]:
+           user_stream = True
+
        text = args[0]
        current_time = time.time()
        if user_model_name == 'melotts':
+           if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                if self.melo_mode == 'local':
-                   audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
+                   audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed)
                    f = io.BytesIO()
                    sf.write(f, audio, 44100, format='wav')
                    f.seek(0)
@@ -147,7 +213,7 @@ class TTS(Blackbox):
            elif chroma_collection_id == 'boss':
                if self.cosyvoice_mode == 'local':
                    set_all_random_seed(35616313)
-                   audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                   audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -166,7 +232,8 @@ class TTS(Blackbox):
            if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                if self.cosyvoice_mode == 'local':
                    set_all_random_seed(56056558)
-                   audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
+                   print("*"*90)
+                   audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -183,7 +250,7 @@ class TTS(Blackbox):
            elif chroma_collection_id == 'boss':
                if self.cosyvoice_mode == 'local':
                    set_all_random_seed(35616313)
-                   audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                   audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -196,11 +263,65 @@ class TTS(Blackbox):
                    }
                    response = requests.post(self.cosyvoice_url, json=message)
                    print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-                   return response.content
+                   return response.content

+       elif user_model_name == 'sovitstts':
+           if chroma_collection_id == 'kiki' or chroma_collection_id is None:
+               if self.sovits_mode == 'local':
+                   set_all_random_seed(56056558)
+                   audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
+                   for i, j in enumerate(audio):
+                       f = io.BytesIO()
+                       sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                       f.seek(0)
+                   print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
+                   return f.read()
+               else:
+                   message = {
+                       "text": text,
+                       "text_lang": self.sovits_text_lang,
+                       "ref_audio_path": self.sovits_ref_audio_path,
+                       "prompt_lang": self.sovits_prompt_lang,
+                       "prompt_text": self.sovits_prompt_text,
+                       "text_split_method": self.sovits_text_split_method,
+                       "batch_size": self.sovits_batch_size,
+                       "media_type": self.sovits_media_type,
+                       "streaming_mode": self.sovits_streaming_mode
+                   }
+                   if user_stream:
+                       response = requests.get(self.sovits_url, params=message, stream=True)
+                       print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
+                       return response
+                   else:
+                       response = requests.get(self.sovits_url, params=message)
+                       print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
+                       return response.content
+                   print("#### SoVITS Service consume - docker : ", (time.time()-current_time))


+           # elif chroma_collection_id == 'boss':
+           #     if self.cosyvoice_mode == 'local':
+           #         set_all_random_seed(35616313)
+           #         audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+           #         for i, j in enumerate(audio):
+           #             f = io.BytesIO()
+           #             sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+           #             f.seek(0)
+           #         print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
+           #         return f.read()
+           #     else:
+           #         message = {
+           #             "text": text
+           #         }
+           #         response = requests.post(self.cosyvoice_url, json=message)
+           #         print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
+           #         return response.content


        elif user_model_name == 'man':
            if self.cosyvoice_mode == 'local':
                set_all_random_seed(35616313)
-               audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+               audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                for i, j in enumerate(audio):
                    f = io.BytesIO()
                    sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -232,7 +353,73 @@ class TTS(Blackbox):
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        text = data.get("text")
        setting = data.get("settings")
+       tts_model_name = setting.get("tts_model_name")
+       user_stream = setting.get("tts_stream")

        if text is None:
            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        by = self.processing(text, settings=setting)
-       return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
+       # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})

+       if user_stream:
+           if by.status_code == 200:
+               print("*"*90)
+               def audio_stream():
+                   # stream data from the upstream server and send it chunk by chunk
+                   for chunk in by.iter_content(chunk_size=1024):
+                       if chunk:
+                           yield chunk
+
+               return StreamingResponse(audio_stream(), media_type="audio/wav")
+           else:
+               raise HTTPException(status_code=by.status_code, detail="请求失败")

+       # if user_stream and tts_model_name == 'sovitstts':
+       #     if by.status_code == 200:
+       #         # save the WAV file
+       #         wav_filename = os.path.join(self.audio_dir, 'audio.wav')
+       #         with open(wav_filename, 'wb') as f:
+       #             for chunk in by.iter_content(chunk_size=1024):
+       #                 if chunk:
+       #                     f.write(chunk)

+       #         try:
+       #             # check that the file exists
+       #             if not os.path.exists(wav_filename):
+       #                 return JSONResponse(content={"error": "File not found"}, status_code=404)

+       #             # convert the WAV to MP3
+       #             mp3_filename = convert_wav_to_mp3(wav_filename)

+       #             # return download links for the WAV and MP3 files
+       #             wav_url = f"/audio/{wav_filename}"
+       #             mp3_url = f"/audio/{mp3_filename}"
+       #             return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)

+       #         except Exception as e:
+       #             return JSONResponse(content={"error": str(e)}, status_code=500)

+       #     else:
+       #         raise HTTPException(status_code=by.status_code, detail="请求失败")


+       else:
+           wav_filename = os.path.join(self.audio_dir, 'audio.wav')
+           with open(wav_filename, 'wb') as f:
+               f.write(by)

+           try:
+               # first check that the file exists
+               if not os.path.exists(wav_filename):
+                   return JSONResponse(content={"error": "File not found"}, status_code=404)

+               # convert the WAV to MP3
+               mp3_filename = convert_wav_to_mp3(wav_filename)

+               # return download links for the WAV and MP3 files
+               wav_url = f"/audio/{wav_filename}"
+               mp3_url = f"/audio/{mp3_filename}"
+               return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)

+           except Exception as e:
+               return JSONResponse(content={"error": str(e)}, status_code=500)
@@ -82,6 +82,32 @@ class CosyVoiceConf():
        self.language = config.get("cosyvoicetts.language")
        self.speaker = config.get("cosyvoicetts.speaker")

+class SovitsConf():
+    mode: str
+    url: str
+    speed: int
+    device: str
+    language: str
+    speaker: str
+
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.mode = config.get("sovitstts.mode")
+        self.url = config.get("sovitstts.url")
+        self.speed = config.get("sovitstts.speed")
+        self.device = config.get("sovitstts.device")
+        self.language = config.get("sovitstts.language")
+        self.speaker = config.get("sovitstts.speaker")
+
+        self.text_lang = config.get("sovitstts.text_lang")
+        self.ref_audio_path = config.get("sovitstts.ref_audio_path")
+        self.prompt_lang = config.get("sovitstts.prompt_lang")
+        self.prompt_text = config.get("sovitstts.prompt_text")
+        self.text_split_method = config.get("sovitstts.text_split_method")
+        self.batch_size = config.get("sovitstts.batch_size")
+        self.media_type = config.get("sovitstts.media_type")
+        self.streaming_mode = config.get("sovitstts.streaming_mode")
+
class SenseVoiceConf():
    mode: str
    url: str
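
For reference, the keys SovitsConf reads could be backed by values like these; a hypothetical sketch only, since the real Configuration store and its values are not part of this diff:

    # Hypothetical values for the sovitstts.* keys read above.
    sovits_settings = {
        "sovitstts.mode": "remote",
        "sovitstts.url": "http://10.6.44.141:9880/tts",         # assumed GPT-SoVITS endpoint
        "sovitstts.text_lang": "zh",
        "sovitstts.ref_audio_path": "/Workspace/ref/kiki.wav",  # assumed reference clip
        "sovitstts.prompt_lang": "zh",
        "sovitstts.prompt_text": "",
        "sovitstts.text_split_method": "cut5",
        "sovitstts.batch_size": 1,
        "sovitstts.media_type": "wav",
        "sovitstts.streaming_mode": True,
    }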