Revert "Vera gdi"

headbigsile
2025-01-20 17:52:31 +08:00
committed by GitHub
parent 2c450a5ffe
commit 83d0a86dae
13 changed files with 95 additions and 677 deletions

View File

@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
# 加载embedding模型和chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
id = "g2e"
#client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma 召回
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)
# print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())
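Every file touched by this revert swaps a hard-coded Chroma endpoint between 10.6.44.141 and 192.168.0.200. A minimal sketch, assuming hypothetical CHROMA_HOST and CHROMA_PORT environment variables, of reading the endpoint from the environment so a host move would not need a code change and a revert like this one:

import os

import chromadb

# Hypothetical variable names; defaults mirror the values this revert restores.
chroma_host = os.environ.get("CHROMA_HOST", "192.168.0.200")
chroma_port = int(os.environ.get("CHROMA_PORT", "7000"))
client = chromadb.HttpClient(host=chroma_host, port=chroma_port)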

View File

@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
# 加载embedding模型和chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
id = "g2e_english"
client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
# # chroma 召回
# from chromadb.utils import embedding_functions
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# collection = client.get_collection("g2e", embedding_function=embedding_model)
# print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

View File

@@ -81,7 +81,7 @@ def get_all_files(folder_path):
# # 加载embedding模型和chroma server
# embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
# id = "g2e"
# client.delete_collection(id)
@@ -103,7 +103,7 @@ def get_all_files(folder_path):
# chroma 召回
from chromadb.utils import embedding_functions
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
collection = client.get_collection("g2e", embedding_function=embedding_model)
print(collection.count())
@@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
# 'Content-Type': 'application/json',
# 'Authorization': "Bearer " + key
# }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())

View File

@@ -10,64 +10,64 @@ import time
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
+loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
-id = "kiki"
-# client.delete_collection(id)
+embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+id = "boss"
+client.delete_collection(id)
# 插入向量(如果ids已存在则会更新向量)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
collection = client.get_collection(id, embedding_function=embedding_model)
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
+reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
-# while True:
-# usr_question = input("\n 请输入问题: ")
-# # query it
-# time1 = time.time()
-# results = collection.query(
-# query_texts=[usr_question],
-# n_results=10,
-# )
-# time2 = time.time()
-# print("query time: ", time2 - time1)
-# # print("query: ",usr_question)
-# # print("results: ",print(results["documents"][0]))
-# pairs = [[usr_question, doc] for doc in results["documents"][0]]
-# # print('\n',pairs)
-# scores = reranker_model.predict(pairs)
-# #重新排列文件顺序:
-# print("New Ordering:")
-# i = 0
-# final_result = ''
-# for o in np.argsort(scores)[::-1]:
-# if i == 3 or scores[o] < 0.5:
-# break
-# i += 1
-# print(o+1)
-# print("Scores:", scores[o])
-# print(results["documents"][0][o],'\n')
-# final_result += results["documents"][0][o] + '\n'
-# print("\n final_result: ", final_result)
-# time3 = time.time()
-# print("rerank time: ", time3 - time2)
+while True:
+usr_question = input("\n 请输入问题: ")
+# query it
+time1 = time.time()
+results = collection.query(
+query_texts=[usr_question],
+n_results=10,
+)
+time2 = time.time()
+print("query time: ", time2 - time1)
+# print("query: ",usr_question)
+# print("results: ",print(results["documents"][0]))
+pairs = [[usr_question, doc] for doc in results["documents"][0]]
+# print('\n',pairs)
+scores = reranker_model.predict(pairs)
+#重新排列文件顺序:
+print("New Ordering:")
+i = 0
+final_result = ''
+for o in np.argsort(scores)[::-1]:
+if i == 3 or scores[o] < 0.5:
+break
+i += 1
+print(o+1)
+print("Scores:", scores[o])
+print(results["documents"][0][o],'\n')
+final_result += results["documents"][0][o] + '\n'
+print("\n final_result: ", final_result)
+time3 = time.time()
+print("rerank time: ", time3 - time2)

View File

@@ -1,7 +1,7 @@
from typing import Union
from fastapi import FastAPI, Request, status
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import JSONResponse
from src.blackbox.blackbox_factory import BlackboxFactory
from fastapi.middleware.cors import CORSMiddleware
@@ -21,7 +21,6 @@ injector = Injector()
blackbox_factory = injector.get(BlackboxFactory)
@app.post("/")
-@app.get("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -30,58 +29,3 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request)
-# @app.get("/audio/{filename}")
-# async def serve_audio(filename: str):
-# import os
-# # 确保文件存在
-# if os.path.exists(filename):
-# with open(filename, 'rb') as f:
-# audio_data = f.read()
-# if filename.endswith(".mp3"):
-# # 对于 MP3 文件,设置为 inline 以在浏览器中播放
-# return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
-# elif filename.endswith(".wav"):
-# # 对于 WAV 文件,设置为 inline 以在浏览器中播放
-# return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
-# else:
-# return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
-@app.get("/audio/audio_files/{filename}")
-async def serve_audio(filename: str):
-import os
-import aiofiles
-filename = os.path.join("audio_files", filename)
-# 确保文件存在
-if os.path.exists(filename):
-try:
-# 使用 aiofiles 异步读取文件
-async with aiofiles.open(filename, 'rb') as f:
-audio_data = await f.read()
-# 根据文件类型返回正确的音频响应
-if filename.endswith(".mp3"):
-return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
-elif filename.endswith(".wav"):
-return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
-else:
-return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
-except asyncio.CancelledError:
-# 处理任务被取消的情况
-return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST)
-except Exception as e:
-# 捕获其他异常
-return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-else:
-return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
-# @app.get("/audio/{filename}")
-# async def serve_audio(filename: str):
-# import os
-# file_path = f"audio/{filename}"
-# # 检查文件是否存在
-# if os.path.exists(file_path):
-# return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav")
-# else:
-# return JSONResponse(content={"error": f"{filename} not found"}, status_code=404)

View File

@@ -62,11 +62,6 @@ def tts_loader():
from .tts import TTS
return Injector().get(TTS)
-@model_loader(lazy=blackboxConf.lazyloading)
-def chatpipeline_loader():
-from .chatpipeline import ChatPipeline
-return Injector().get(ChatPipeline)
@model_loader(lazy=blackboxConf.lazyloading)
def emotion_loader():
from .emotion import Emotion
@@ -137,7 +132,6 @@ class BlackboxFactory:
self.models["audio_to_text"] = audio_to_text_loader
self.models["asr"] = asr_loader
self.models["tts"] = tts_loader
-self.models["chatpipeline"] = chatpipeline_loader
self.models["sentiment_engine"] = sentiment_loader
self.models["emotion"] = emotion_loader
self.models["fastchat"] = fastchat_loader

View File

@@ -61,8 +61,6 @@ class Chat(Blackbox):
user_prompt_template = settings.get('user_prompt_template')
user_stream = settings.get('stream')
-llm_model = "vllm"
if user_context == None:
user_context = []
@@ -102,16 +100,10 @@
#user_presence_penalty = 0.8
if user_model_url is None or user_model_url.isspace() or user_model_url == "":
-if llm_model != "vllm":
user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
-else:
-user_model_url = "http://10.6.80.94:8000/v1/completions"
if user_model_key is None or user_model_key.isspace() or user_model_key == "":
-if llm_model != "vllm":
user_model_key = "YOUR_API_KEY"
-else:
-user_model_key = "vllm"
if chroma_embedding_model:
chroma_response = self.chroma_query(user_question, settings)
@@ -125,10 +117,7 @@
print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + ""
else:
-if llm_model != "vllm":
user_question = user_prompt_template + "问题: " + user_question + ""
-else:
-user_question = user_question
print(f"1.user_question: {user_question}")
@@ -183,17 +172,10 @@
else:
url = user_model_url
key = user_model_key
-if llm_model != "vllm":
header = {
'Content-Type': 'application/json',
"Cache-Control": "no-cache", # 禁用缓存
}
-else:
-header = {
-'Content-Type': 'application/json',
-'Authorization': "Bearer " + key,
-"Cache-Control": "no-cache",
-}
# system_prompt = "# Role: 琪琪,康普可可的代言人。\n\n## Profile:\n**Author**: 琪琪。\n**Language**: 中文。\n**Description**: 琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n\n## Constraints:\n- **严格遵循工作流程** 严格遵循<Workflow >中设定的工作流程。\n- **无内置知识库** :根据<Workflow >中提供的知识作答,而不是内置知识库,我虽然是知识库专家,但我的知识依赖于外部输入,而不是大模型已有知识。\n- **回复格式**:在进行回复时,不能输出“检索内容” 标签字样,同时也不能直接透露知识片段原文。\n\n## Workflow:\n1. **接收查询**:接收用户的问题。\n2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n3. **提供回答**\n\n```\n基于检索内容中的知识片段回答用户的问题。回答内容限制总结在50字内。\n请首先判断提供的检索内容与上述问题是否相关。如果相关直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼 。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。”\n\n```\n## Example:\n\n用户询问“中国的首都是哪个城市” 。\n2.1检索知识库,首先检查知识片段,如果检索内容中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。\n2.2如果有知识片段,在做出回复时,只能基于检索内容中的内容进行回答,且不能透露上下文原文,同时也不能出现检索内容的标签字样。\n"
@@ -201,7 +183,6 @@
{"role": "system", "content": system_prompt}
]
-if llm_model != "vllm":
chat_inputs={
"model": user_model_name,
"messages": prompt_template + user_context + [
@@ -219,19 +200,6 @@
"stop": str(user_stop),
"stream": user_stream,
}
-else:
-chat_inputs={
-"model": user_model_name,
-"prompt": user_question,
-"temperature": float(user_temperature),
-"top_p": float(user_top_p),
-"n": float(user_n),
-"max_tokens": float(user_max_tokens),
-"frequency_penalty": float(user_frequency_penalty),
-"presence_penalty":float( user_presence_penalty),
-# "stop": user_stop,
-"stream": user_stream,
-}
# # 获取当前时间戳
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -284,13 +252,8 @@
if response_result.get("choices") is None:
yield JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
else:
-if llm_model != "vllm":
print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
yield fastchat_response.json()["choices"][0]["message"]["content"]
-else:
-print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["text"],"\n\n")
-yield fastchat_response.json()["choices"][0]["text"]
async def fast_api_handler(self, request: Request) -> Response:
try:
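The removed llm_model flag switched between two request shapes: the OpenAI-style /v1/chat/completions body that this revert keeps, and the plain /v1/completions body the deleted vLLM branch sent. For reference, the two payloads side by side (values illustrative):

# OpenAI-style chat endpoint (kept after the revert):
chat_completions_body = {
    "model": "my-model",                        # hypothetical model name
    "messages": [{"role": "user", "content": "你好"}],
    "temperature": 0.7,
    "stream": False,
}
# Plain completions endpoint (the deleted vllm branch):
completions_body = {
    "model": "my-model",
    "prompt": "你好",                           # raw prompt instead of a messages list
    "max_tokens": 256,
    "stream": False,
}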

View File

@@ -1,270 +0,0 @@
import re
import requests
import json
import queue
import threading
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import os
import io
import time
import requests
from fastapi import Request, Response, status, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from .blackbox import Blackbox
from injector import inject
from injector import singleton
from ..log.logging_time import logging_time
import logging
logger = logging.getLogger(__name__)
import asyncio
from uuid import uuid4 # 用于生成唯一的会话 ID
@singleton
class ChatPipeline(Blackbox):
@inject
def __init__(self, ) -> None:
self.text_queue = queue.Queue() # 文本队列
self.audio_queue = queue.Queue() # 音频队列
self.PUNCTUATION = r'[。!?、,.?]' # 标点符号
self.tts_event = threading.Event() # TTS 事件
self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0
self.audio_dir = "audio_files" # 存储音频文件的目录
self.is_last = False
self.settings = {} # 外部传入的 settings
self.lock = threading.Lock() # 创建锁
if not os.path.exists(self.audio_dir):
os.makedirs(self.audio_dir)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def valid(self, data: any) -> bool:
if isinstance(data, bytes):
return True
return False
def reset_queues(self):
"""清空之前的队列数据"""
self.text_queue.queue.clear() # 清空文本队列
self.audio_queue.queue.clear() # 清空音频队列
def clear_audio_files(self):
"""清空音频文件夹中的所有音频文件"""
for file_name in os.listdir(self.audio_dir):
file_path = os.path.join(self.audio_dir, file_name)
if os.path.isfile(file_path):
os.remove(file_path)
print(f"Removed old audio file: {file_path}")
def save_audio(self, audio_data, part_number: int):
"""保存音频数据为文件"""
file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
with open(file_name, 'wb') as f:
f.write(audio_data)
return file_name
def chat_stream(self, prompt: str, settings: dict):
"""从 chat.py 获取实时生成的文本,并放入队列"""
url = 'http://10.6.44.141:8000/?blackbox_name=chat'
headers = {'Content-Type': 'text/plain',"Cache-Control": "no-cache",}# 禁用缓存}
data = {
"prompt": prompt,
"context": [],
"settings": settings
}
print(f"data_chat: {data}")
# 每次执行时清空原有音频文件
self.clear_audio_files()
self.audio_part_counter = 0
self.text_part_counter = 0
self.is_last = False
with self.lock: # 确保对 settings 的访问是线程安全的
llm_stream = settings.get("stream")
if llm_stream:
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
print(f"data_chat1: {data}")
complete_message = "" # 用于累积完整的文本
lines = list(response.iter_lines()) # 先将所有行读取到一个列表中
total_lines = len(lines)
for i, line in enumerate(lines):
if line:
message = line.decode('utf-8')
if message.strip().lower() == "data:":
continue # 跳过"data:"行
complete_message += message
# 如果包含标点符号,拆分成句子
if re.search(self.PUNCTUATION, complete_message):
sentences = re.split(self.PUNCTUATION, complete_message)
for sentence in sentences[:-1]:
cleaned_sentence = self.filter_invalid_chars(sentence.strip())
if cleaned_sentence:
print(f"Sending complete sentence: {cleaned_sentence}")
self.text_queue.put(cleaned_sentence) # 放入文本队列
complete_message = sentences[-1]
self.text_part_counter += 1
# 判断是否是最后一句
if i == total_lines - 2: # 如果是最后一行
self.is_last = True
print(f'2.is_last: {self.is_last}')
time.sleep(0.2)
else:
with requests.post(url, headers=headers, data=json.dumps(data)) as response:
print(f"data_chat1: {data}")
if response.status_code == 200:
response_json = response.json()
response_content = response_json.get("response")
self.text_queue.put(response_content)
def send_to_tts(self, settings: dict):
"""从队列中获取文本并发送给 tts.py 进行语音合成"""
url = 'http://10.6.44.141:8000/?blackbox_name=tts'
headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",} # 禁用缓存}
with self.lock:
user_stream = settings.get("tts_stream")
tts_model_name = settings.get("tts_model_name")
print(f"data_tts0: {settings}")
while True:
try:
# 获取队列中的一个完整句子
text = self.text_queue.get(timeout=5)
if text is None:
break
if not text.strip():
continue
if tts_model_name == 'sovitstts':
text = self.filter_invalid_chars(text)
print(f"data_tts0.1: {settings}")
data = {
"settings": settings,
"text": text
}
print(f"data_tts1: {data}")
if user_stream:
# 发送请求到 TTS 服务
response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
if response.status_code == 200:
audio_data = response.content
if isinstance(audio_data, bytes):
self.audio_part_counter += 1 # 增加音频段计数器
file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件
print(f"Audio part saved as {file_name}")
# 将文件名和是否是最后一条消息放入音频队列
self.audio_queue.put(file_name) # 放入音频队列
else:
print(f"Error: Received non-binary data.")
else:
print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
else:
print(f"data_tts2: {data}")
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
self.audio_queue.put(response.content)
# 通知下一个 TTS 可以执行了
self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程
time.sleep(0.2)
except queue.Empty:
time.sleep(1)
def filter_invalid_chars(self,text):
"""过滤无效字符(包括字节流)"""
invalid_keywords = ["data:", "\n", "\r", "\t", " "]
if isinstance(text, bytes):
text = text.decode('utf-8', errors='ignore')
for keyword in invalid_keywords:
text = text.replace(keyword, "")
# 移除所有英文字母和符号(保留中文、标点等)
text = re.sub(r'[a-zA-Z]', '', text)
return text.strip()
@logging_time(logger=logger)
def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO:
# 启动聊天流线程
threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start()
# 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务
threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
return {"message": "Chat and TTS processing started"}
# 新增异步方法,等待音频片段生成
async def wait_for_audio(self):
while self.audio_queue.empty():
await asyncio.sleep(0.2) # 使用异步 sleep避免阻塞事件循环
async def fast_api_handler(self, request: Request) -> Response:
try:
data = await request.json()
text = data.get("text")
setting = data.get("settings")
user_stream = setting.get("tts_stream")
self.is_last = False
self.audio_part_counter = 0 # 音频段计数器
self.text_part_counter = 0
self.reset_queues()
self.clear_audio_files()
print(f"data0: {data}")
if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
# 调用 processing 方法,并传递动态的 text 参数
response_data = self.processing(text, settings=setting)
# 根据是否启用流式传输进行处理
if user_stream:
# 等待至少一个音频片段生成完成
await self.wait_for_audio()
def audio_stream():
# 从上游服务器流式读取数据并逐块发送
while self.audio_part_counter != 0 and not self.is_last:
audio = self.audio_queue.get()
audio_file = audio
if audio_file:
with open(audio_file, "rb") as f:
print(f"Sending audio file: {audio_file}")
yield f.read() # 分段发送音频文件内容
return StreamingResponse(audio_stream(), media_type="audio/wav")
else:
# 如果没有启用流式传输,可以返回一个完整的响应或音频文件
await self.wait_for_audio()
file_name = self.audio_queue.get()
if file_name:
file_name_json = json.loads(file_name.decode('utf-8'))
# audio_files = []
# while not self.audio_queue.empty():
# print("9")
# audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名
# 返回多个音频文件
return JSONResponse(content=file_name_json)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@@ -25,7 +25,7 @@ class ChromaQuery(Blackbox):
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
@@ -60,7 +60,7 @@ class ChromaQuery(Blackbox):
chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-chroma_host = "10.6.44.141"
+chroma_host = "192.168.0.200"
if chroma_port is None or chroma_port.isspace() or chroma_port == "":
chroma_port = "7000"
@@ -72,7 +72,7 @@ class ChromaQuery(Blackbox):
chroma_n_results = 10
# load client and embedding model from init
-if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
+if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1
else:
try:
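One caveat about the host check above: in re.search(r"192.168.0.200", chroma_host) each unescaped dot matches any character, so unintended hosts can also hit the cached client. A stricter sketch using plain equality:

def is_default_endpoint(host: str, port: str) -> bool:
    # Exact comparison; the regex form would also match strings like
    # "192x168y0z200" because "." matches any character.
    return host.strip() == "192.168.0.200" and port.strip() == "7000"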

View File

@@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox):
# load embedding model
self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
# load chroma db
-self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)

View File

@@ -23,7 +23,7 @@ class G2E(Blackbox):
if context == None:
context = []
#url = 'http://120.196.116.194:48890/v1'
-url = 'http://10.6.44.141:23333/v1'
+url = 'http://192.168.0.200:23333/v1'
background_prompt = '''KOMBUKIKI是一款茶饮料目标受众 年龄20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯

View File

@@ -3,20 +3,18 @@ import time
from ntpath import join
import requests
-from fastapi import Request, Response, status, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
from .blackbox import Blackbox
from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
-from ..configuration import SovitsConf
from injector import inject
from injector import singleton
import sys,os
sys.path.append('/Workspace/CosyVoice')
-sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import CosyVoice
# from cosyvoice.utils.file_utils import load_wav, speed_change
import soundfile as sf
@@ -31,37 +29,12 @@ import random
import torch
import numpy as np
-from pydub import AudioSegment
-import subprocess
def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
-def convert_wav_to_mp3(wav_filename: str):
-# 检查文件是否是 .wav 格式
-if not wav_filename.lower().endswith(".wav"):
-raise ValueError("The input file must be a .wav file")
-# 构建对应的 .mp3 文件路径
-mp3_filename = wav_filename.replace(".wav", ".mp3")
-# 如果原 MP3 文件已存在,先删除
-if os.path.exists(mp3_filename):
-os.remove(mp3_filename)
-# 使用 FFmpeg 进行转换,直接覆盖原有的 MP3 文件
-command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename] # `-y` 参数会自动覆盖目标文件
-result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-# 检查转换是否成功
-if result.returncode != 0:
-raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}")
-return mp3_filename
@singleton
class TTS(Blackbox):
@@ -88,11 +61,10 @@
self.melo_url = ''
self.melo_mode = melo_config.mode
self.melotts = None
-# self.speaker_ids = None
-self.melo_speaker = None
+self.speaker_ids = None
if self.melo_mode == 'local':
self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
-self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language]
+self.speaker_ids = self.melotts.hps.data.spk2id
else:
self.melo_url = melo_config.url
logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@@ -109,52 +81,19 @@
self.cosyvoicetts = None
# os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
if self.cosyvoice_mode == 'local':
-# self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
-self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
-# self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
+self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
else:
self.cosyvoice_url = cosyvoice_config.url
logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
-@logging_time(logger=logger)
-def sovits_model_init(self, sovits_config: SovitsConf) -> None:
-self.sovits_speed = sovits_config.speed
-self.sovits_device = sovits_config.device
-self.sovits_language = sovits_config.language
-self.sovits_speaker = sovits_config.speaker
-self.sovits_url = sovits_config.url
-self.sovits_mode = sovits_config.mode
-self.sovitstts = None
-self.speaker_ids = None
-self.sovits_text_lang = sovits_config.text_lang
-self.sovits_ref_audio_path = sovits_config.ref_audio_path
-self.sovits_prompt_lang = sovits_config.prompt_lang
-self.sovits_prompt_text = sovits_config.prompt_text
-self.sovits_text_split_method = sovits_config.text_split_method
-self.sovits_batch_size = sovits_config.batch_size
-self.sovits_media_type = sovits_config.media_type
-self.sovits_streaming_mode = sovits_config.streaming_mode
-if self.sovits_mode == 'local':
-# self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device)
-self.sovits_speaker = self.sovitstts.hps.data.spk2id
-else:
-self.sovits_url = sovits_config.url
-logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
-print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
@inject
-def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
+def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
self.tts_service = TTService("yunfeineo")
self.melo_model_init(melo_config)
self.cosyvoice_model_init(cosyvoice_config)
-self.sovits_model_init(sovits_config)
-self.audio_dir = "audio_files" # 存储音频文件的目录
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@@ -167,19 +106,14 @@
settings = {}
user_model_name = settings.get("tts_model_name")
chroma_collection_id = settings.get("chroma_collection_id")
-user_stream = settings.get('tts_stream')
-print(f"chroma_collection_id: {chroma_collection_id}")
print(f"tts_model_name: {user_model_name}")
-if user_stream in [None, ""]:
-user_stream = True
text = args[0]
current_time = time.time()
if user_model_name == 'melotts':
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.melo_mode == 'local':
-audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed)
+audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
f = io.BytesIO()
sf.write(f, audio, 44100, format='wav')
f.seek(0)
@@ -213,7 +147,7 @@
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
-audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -232,8 +166,7 @@
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.cosyvoice_mode == 'local':
set_all_random_seed(56056558)
-print("*"*90)
-audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
+audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -250,7 +183,7 @@
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
-audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -264,64 +197,10 @@
response = requests.post(self.cosyvoice_url, json=message)
print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
return response.content
-elif user_model_name == 'sovitstts':
-if chroma_collection_id == 'kiki' or chroma_collection_id is None:
-if self.sovits_mode == 'local':
-set_all_random_seed(56056558)
-audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
-for i, j in enumerate(audio):
-f = io.BytesIO()
-sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-f.seek(0)
-print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-return f.read()
-else:
-message = {
-"text": text,
-"text_lang": self.sovits_text_lang,
-"ref_audio_path": self.sovits_ref_audio_path,
-"prompt_lang": self.sovits_prompt_lang,
-"prompt_text": self.sovits_prompt_text,
-"text_split_method": self.sovits_text_split_method,
-"batch_size": self.sovits_batch_size,
-"media_type": self.sovits_media_type,
-"streaming_mode": self.sovits_streaming_mode
-}
-if user_stream:
-response = requests.get(self.sovits_url, params=message, stream=True)
-print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-return response
-else:
-response = requests.get(self.sovits_url, params=message)
-print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-return response.content
-print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-# elif chroma_collection_id == 'boss':
-# if self.cosyvoice_mode == 'local':
-# set_all_random_seed(35616313)
-# audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
-# for i, j in enumerate(audio):
-# f = io.BytesIO()
-# sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-# f.seek(0)
-# print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-# return f.read()
-# else:
-# message = {
-# "text": text
-# }
-# response = requests.post(self.cosyvoice_url, json=message)
-# print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-# return response.content
elif user_model_name == 'man':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
-audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -353,73 +232,7 @@
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
text = data.get("text")
setting = data.get("settings")
-tts_model_name = setting.get("tts_model_name")
-user_stream = setting.get("tts_stream")
if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
by = self.processing(text, settings=setting)
-# return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
+return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
-if user_stream:
-if by.status_code == 200:
-print("*"*90)
-def audio_stream():
-# 从上游服务器流式读取数据并逐块发送
-for chunk in by.iter_content(chunk_size=1024):
-if chunk:
-yield chunk
-return StreamingResponse(audio_stream(), media_type="audio/wav")
-else:
-raise HTTPException(status_code=by.status_code, detail="请求失败")
-# if user_stream and tts_model_name == 'sovitstts':
-# if by.status_code == 200:
-# # 保存 WAV 文件
-# wav_filename = os.path.join(self.audio_dir, 'audio.wav')
-# with open(wav_filename, 'wb') as f:
-# for chunk in by.iter_content(chunk_size=1024):
-# if chunk:
-# f.write(chunk)
-# try:
-# # 检查文件是否存在
-# if not os.path.exists(wav_filename):
-# return JSONResponse(content={"error": "File not found"}, status_code=404)
-# # 转换 WAV 为 MP3
-# mp3_filename = convert_wav_to_mp3(wav_filename)
-# # 返回 WAV 和 MP3 文件的下载链接
-# wav_url = f"/audio/{wav_filename}"
-# mp3_url = f"/audio/{mp3_filename}"
-# return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
-# except Exception as e:
-# return JSONResponse(content={"error": str(e)}, status_code=500)
-# else:
-# raise HTTPException(status_code=by.status_code, detail="请求失败")
-else:
-wav_filename = os.path.join(self.audio_dir, 'audio.wav')
-with open(wav_filename, 'wb') as f:
-f.write(by)
-try:
-# 先检查文件是否存在
-if not os.path.exists(wav_filename):
-return JSONResponse(content={"error": "File not found"}, status_code=404)
-# 转换 WAV 为 MP3
-mp3_filename = convert_wav_to_mp3(wav_filename)
-# 返回 MP3 文件的下载链接
-wav_url = f"/audio/{wav_filename}"
-mp3_url = f"/audio/{mp3_filename}"
-return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
-except Exception as e:
-return JSONResponse(content={"error": str(e)}, status_code=500)

View File

@@ -82,32 +82,6 @@ class CosyVoiceConf():
self.language = config.get("cosyvoicetts.language")
self.speaker = config.get("cosyvoicetts.speaker")
-class SovitsConf():
-mode: str
-url: str
-speed: int
-device: str
-language: str
-speaker: str
-@inject
-def __init__(self, config: Configuration) -> None:
-self.mode = config.get("sovitstts.mode")
-self.url = config.get("sovitstts.url")
-self.speed = config.get("sovitstts.speed")
-self.device = config.get("sovitstts.device")
-self.language = config.get("sovitstts.language")
-self.speaker = config.get("sovitstts.speaker")
-self.text_lang = config.get("sovitstts.text_lang")
-self.ref_audio_path = config.get("sovitstts.ref_audio_path")
-self.prompt_lang = config.get("sovitstts.prompt_lang")
-self.prompt_text = config.get("sovitstts.prompt_text")
-self.text_split_method = config.get("sovitstts.text_split_method")
-self.batch_size = config.get("sovitstts.batch_size")
-self.media_type = config.get("sovitstts.media_type")
-self.streaming_mode = config.get("sovitstts.streaming_mode")
class SenseVoiceConf():
mode: str
url: str
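The removed SovitsConf follows the same injected-configuration pattern as the conf classes that remain: plain attributes filled from dotted config keys. A generic, self-contained sketch of that shape (Configuration here is a stub standing in for the repo's config service, and the key names are hypothetical):

class Configuration:
    def __init__(self, data: dict) -> None:
        self._data = data
    def get(self, key: str):
        return self._data.get(key)

class ExampleTTSConf:  # hypothetical; mirrors the shape of the removed SovitsConf
    mode: str
    url: str
    device: str
    def __init__(self, config: Configuration) -> None:
        self.mode = config.get("exampletts.mode")
        self.url = config.get("exampletts.url")
        self.device = config.get("exampletts.device")

conf = ExampleTTSConf(Configuration({"exampletts.mode": "local"}))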