diff --git a/sample/chroma_client1.py b/sample/chroma_client1.py
index edce182..4f798fe 100644
--- a/sample/chroma_client1.py
+++ b/sample/chroma_client1.py
@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]

 # Load the embedding model and the Chroma server
 embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

 id = "g2e"
 #client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
 # # Chroma retrieval
 # from chromadb.utils import embedding_functions
 # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)

 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)

 # # print(fastchat_response.json())
diff --git a/sample/chroma_client_en.py b/sample/chroma_client_en.py
index d3e3de9..8ca506c 100644
--- a/sample/chroma_client_en.py
+++ b/sample/chroma_client_en.py
@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]

 # Load the embedding model and the Chroma server
 embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

 id = "g2e_english"
 client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
 # # Chroma retrieval
 # from chromadb.utils import embedding_functions
 # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)

 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)

 # # print(fastchat_response.json())
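Both client scripts change only the Chroma endpoint, from 192.168.0.200 to 10.6.44.141. Before re-running the ingestion, the new server can be sanity-checked with a couple of chromadb calls (a standalone sketch, not part of the patch; the collection name is the one used above):

import chromadb

# Point the HTTP client at the new Chroma server from the diff.
client = chromadb.HttpClient(host='10.6.44.141', port=7000)

# heartbeat() returns a nanosecond timestamp if the server is reachable.
print("heartbeat:", client.heartbeat())

# Confirm the "g2e" collection survived the move and still holds vectors.
collection = client.get_collection("g2e")
print("g2e count:", collection.count())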
diff --git a/sample/chroma_client_query.py b/sample/chroma_client_query.py
index cafbfcc..bcc3331 100644
--- a/sample/chroma_client_query.py
+++ b/sample/chroma_client_query.py
@@ -81,7 +81,7 @@ def get_all_files(folder_path):

 # # Load the embedding model and the Chroma server
 # embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+# client = chromadb.HttpClient(host='10.6.44.141', port=7000)

 # id = "g2e"
 # client.delete_collection(id)
@@ -103,7 +103,7 @@ def get_all_files(folder_path):
 # Chroma retrieval
 from chromadb.utils import embedding_functions
 embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)
 collection = client.get_collection("g2e", embedding_function=embedding_model)

 print(collection.count())
@@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://192.168.0.200:23333/v1/chat/completions"
+# url = "http://10.6.44.141:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)

 # # print(fastchat_response.json())
diff --git a/sample/chroma_rerank.py b/sample/chroma_rerank.py
index 3093c85..09be954 100644
--- a/sample/chroma_rerank.py
+++ b/sample/chroma_rerank.py
@@ -10,64 +10,64 @@ import time
 # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0

 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
+loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
 documents = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
 docs = text_splitter.split_documents(documents)
 print("len(docs)", len(docs))
 ids = ["粤语语料"+str(i) for i in range(len(docs))]

-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)

-id = "boss"
-client.delete_collection(id)
+id = "kiki"
+# client.delete_collection(id)

 # Insert vectors (if the ids already exist, the vectors are updated)
 db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)

-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")

-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)
 collection = client.get_collection(id, embedding_function=embedding_model)

-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
+reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")

-while True:
-    usr_question = input("\n 请输入问题: ")
-    # query it
-    time1 = time.time()
-    results = collection.query(
-        query_texts=[usr_question],
-        n_results=10,
-    )
-    time2 = time.time()
-    print("query time: ", time2 - time1)
+# while True:
+#     usr_question = input("\n 请输入问题: ")
+#     # query it
+#     time1 = time.time()
+#     results = collection.query(
+#         query_texts=[usr_question],
+#         n_results=10,
+#     )
+#     time2 = time.time()
+#     print("query time: ", time2 - time1)

-    # print("query: ",usr_question)
-    # print("results: ",print(results["documents"][0]))
+#     # print("query: ",usr_question)
+#     # print("results: ",print(results["documents"][0]))

-    pairs = [[usr_question, doc] for doc in results["documents"][0]]
-    # print('\n',pairs)
-    scores = reranker_model.predict(pairs)
+#     pairs = [[usr_question, doc] for doc in results["documents"][0]]
+#     # print('\n',pairs)
+#     scores = reranker_model.predict(pairs)

-    # Reorder the documents:
-    print("New Ordering:")
-    i = 0
-    final_result = ''
-    for o in np.argsort(scores)[::-1]:
-        if i == 3 or scores[o] < 0.5:
-            break
-        i += 1
-        print(o+1)
-        print("Scores:", scores[o])
-        print(results["documents"][0][o],'\n')
-        final_result += results["documents"][0][o] + '\n'
+#     # Reorder the documents:
+#     print("New Ordering:")
+#     i = 0
+#     final_result = ''
+#     for o in np.argsort(scores)[::-1]:
+#         if i == 3 or scores[o] < 0.5:
+#             break
+#         i += 1
+#         print(o+1)
+#         print("Scores:", scores[o])
+#         print(results["documents"][0][o],'\n')
+#         final_result += results["documents"][0][o] + '\n'

-    print("\n final_result: ", final_result)
-    time3 = time.time()
-    print("rerank time: ", time3 - time2)
\ No newline at end of file
+#     print("\n final_result: ", final_result)
+#     time3 = time.time()
+#     print("rerank time: ", time3 - time2)
\ No newline at end of file
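The retrieval loop that chroma_rerank.py comments out is a standard retrieve-then-rerank pass: fetch the top 10 ANN hits from Chroma, score each (question, document) pair with the bge-reranker cross-encoder, and keep at most three documents whose score clears 0.5. The same flow, extracted into a reusable function (a sketch using the model path and thresholds from the file above):

import numpy as np
from sentence_transformers import CrossEncoder

def rerank(question, documents, reranker, top_k=3, min_score=0.5):
    # Score every (question, document) pair with the cross-encoder.
    scores = reranker.predict([[question, doc] for doc in documents])
    # Sort descending by score; keep at most top_k hits above the threshold.
    order = np.argsort(scores)[::-1]
    return [(documents[i], float(scores[i])) for i in order[:top_k] if scores[i] >= min_score]

reranker = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device="cuda:0")
# `documents` would be results["documents"][0] from collection.query(...).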
diff --git a/server.py b/server.py
index d32b58d..c5c7990 100644
--- a/server.py
+++ b/server.py
@@ -1,7 +1,7 @@
 from typing import Union

 from fastapi import FastAPI, Request, status
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse

 from src.blackbox.blackbox_factory import BlackboxFactory
 from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +21,7 @@ injector = Injector()
 blackbox_factory = injector.get(BlackboxFactory)

 @app.post("/")
+@app.get("/")
 async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
     if not blackbox_name:
         return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -29,3 +30,58 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
     except ValueError:
         return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
     return await box.fast_api_handler(request)
+
+# @app.get("/audio/{filename}")
+# async def serve_audio(filename: str):
+#     import os
+#     # Make sure the file exists
+#     if os.path.exists(filename):
+#         with open(filename, 'rb') as f:
+#             audio_data = f.read()
+#         if filename.endswith(".mp3"):
+#             # For MP3 files, use inline so the browser plays them
+#             return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
+#         elif filename.endswith(".wav"):
+#             # For WAV files, use inline so the browser plays them
+#             return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
+#     else:
+#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
+
+@app.get("/audio/audio_files/{filename}")
+async def serve_audio(filename: str):
+    import os
+    import asyncio, aiofiles
+    filename = os.path.join("audio_files", filename)
+    # Make sure the file exists
+    if os.path.exists(filename):
+        try:
+            # Read the file asynchronously with aiofiles
+            async with aiofiles.open(filename, 'rb') as f:
+                audio_data = await f.read()
+
+            # Return the correct audio response for the file type
+            if filename.endswith(".mp3"):
+                return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
+            elif filename.endswith(".wav"):
+                return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
+            else:
+                return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
+        except asyncio.CancelledError:
+            # Handle a cancelled request
+            return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST)
+        except Exception as e:
+            # Catch any other exception
+            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+    else:
+        return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
+
+
+# @app.get("/audio/{filename}")
+# async def serve_audio(filename: str):
+#     import os
+#     file_path = f"audio/{filename}"
+#     # Check whether the file exists
+#     if os.path.exists(file_path):
+#         return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav")
+#     else:
+#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=404)
\ No newline at end of file
diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py
index 0bdf5a9..d7ba14e 100644
--- a/src/blackbox/blackbox_factory.py
+++ b/src/blackbox/blackbox_factory.py
@@ -62,6 +62,11 @@ def tts_loader():
     from .tts import TTS
     return Injector().get(TTS)

+@model_loader(lazy=blackboxConf.lazyloading)
+def chatpipeline_loader():
+    from .chatpipeline import ChatPipeline
+    return Injector().get(ChatPipeline)
+
 @model_loader(lazy=blackboxConf.lazyloading)
 def emotion_loader():
     from .emotion import Emotion
@@ -132,6 +137,7 @@ class BlackboxFactory:
         self.models["audio_to_text"] = audio_to_text_loader
         self.models["asr"] = asr_loader
         self.models["tts"] = tts_loader
+        self.models["chatpipeline"] = chatpipeline_loader
         self.models["sentiment_engine"] = sentiment_loader
         self.models["emotion"] = emotion_loader
         self.models["fastchat"] = fastchat_loader
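Registering the new pipeline takes two steps: a chatpipeline_loader decorated with model_loader, and a "chatpipeline" entry in BlackboxFactory.models. The model_loader decorator itself is not part of this diff; one plausible minimal shape for such a lazy, cached loader factory is (an illustration, not the repo's actual implementation):

def model_loader(lazy=True):
    # Decorator factory: wrap a loader so the model is built at most once,
    # and only on first call when `lazy` is true.
    def wrap(fn):
        cache = {}
        def loader():
            if "model" not in cache:
                cache["model"] = fn()   # triggers the deferred import
            return cache["model"]
        if not lazy:
            loader()                    # eager mode: build immediately
        return loader
    return wrap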
diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py
index b7d2527..4973e0d 100644
--- a/src/blackbox/chat.py
+++ b/src/blackbox/chat.py
@@ -60,6 +60,8 @@ class Chat(Blackbox):
         system_prompt = settings.get('system_prompt')
         user_prompt_template = settings.get('user_prompt_template')
         user_stream = settings.get('stream')
+
+        llm_model = "vllm"

         if user_context == None:
             user_context = []
@@ -100,10 +102,16 @@ class Chat(Blackbox):
         #user_presence_penalty = 0.8

         if user_model_url is None or user_model_url.isspace() or user_model_url == "":
-            user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
+            if llm_model != "vllm":
+                user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
+            else:
+                user_model_url = "http://10.6.80.94:8000/v1/completions"

         if user_model_key is None or user_model_key.isspace() or user_model_key == "":
-            user_model_key = "YOUR_API_KEY"
+            if llm_model != "vllm":
+                user_model_key = "YOUR_API_KEY"
+            else:
+                user_model_key = "vllm"

         if chroma_embedding_model:
             chroma_response = self.chroma_query(user_question, settings)
@@ -117,7 +125,10 @@ class Chat(Blackbox):
             print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
             user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "。"
         else:
-            user_question = user_prompt_template + "问题: " + user_question + "。"
+            if llm_model != "vllm":
+                user_question = user_prompt_template + "问题: " + user_question + "。"
+            else:
+                user_question = user_question

         print(f"1.user_question: {user_question}")

@@ -172,10 +183,17 @@
         else:
             url = user_model_url
             key = user_model_key
-            header = {
-                'Content-Type': 'application/json',
-                "Cache-Control": "no-cache",  # disable caching
-            }
+            if llm_model != "vllm":
+                header = {
+                    'Content-Type': 'application/json',
+                    "Cache-Control": "no-cache",  # disable caching
+                }
+            else:
+                header = {
+                    'Content-Type': 'application/json',
+                    'Authorization': "Bearer " + key,
+                    "Cache-Control": "no-cache",
+                }

         # system_prompt = "# Role: 琪琪,康普可可的代言人。\n\n## Profile:\n**Author**: 琪琪。\n**Language**: 中文。\n**Description**: 琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n\n## Constraints:\n- **严格遵循工作流程**: 严格遵循中设定的工作流程。\n- **无内置知识库** :根据中提供的知识作答,而不是内置知识库,我虽然是知识库专家,但我的知识依赖于外部输入,而不是大模型已有知识。\n- **回复格式**:在进行回复时,不能输出“检索内容” 标签字样,同时也不能直接透露知识片段原文。\n\n## Workflow:\n1. **接收查询**:接收用户的问题。\n2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n3. **提供回答**:\n\n```\n基于检索内容中的知识片段回答用户的问题。回答内容限制总结在50字内。\n请首先判断提供的检索内容与上述问题是否相关。如果相关,直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼 。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。”\n\n```\n## Example:\n\n用户询问:“中国的首都是哪个城市?” 。\n2.1检索知识库,首先检查知识片段,如果检索内容中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。\n2.2如果有知识片段,在做出回复时,只能基于检索内容中的内容进行回答,且不能透露上下文原文,同时也不能出现检索内容的标签字样。\n"

@@ -183,23 +201,37 @@
             {"role": "system", "content": system_prompt}
         ]

-        chat_inputs={
-            "model": user_model_name,
-            "messages": prompt_template + user_context + [
-                {
-                    "role": "user",
-                    "content": user_question
-                }
-            ],
-            "temperature": str(user_temperature),
-            "top_p": str(user_top_p),
-            "n": str(user_n),
-            "max_tokens": str(user_max_tokens),
-            "frequency_penalty": str(user_frequency_penalty),
-            "presence_penalty": str(user_presence_penalty),
-            "stop": str(user_stop),
-            "stream": user_stream,
-        }
+        if llm_model != "vllm":
+            chat_inputs={
+                "model": user_model_name,
+                "messages": prompt_template + user_context + [
+                    {
+                        "role": "user",
+                        "content": user_question
+                    }
+                ],
+                "temperature": str(user_temperature),
+                "top_p": str(user_top_p),
+                "n": str(user_n),
+                "max_tokens": str(user_max_tokens),
+                "frequency_penalty": str(user_frequency_penalty),
+                "presence_penalty": str(user_presence_penalty),
+                "stop": str(user_stop),
+                "stream": user_stream,
+            }
+        else:
+            chat_inputs={
+                "model": user_model_name,
+                "prompt": user_question,
+                "temperature": float(user_temperature),
+                "top_p": float(user_top_p),
+                "n": int(user_n),
+                "max_tokens": int(user_max_tokens),
+                "frequency_penalty": float(user_frequency_penalty),
+                "presence_penalty": float(user_presence_penalty),
+                # "stop": user_stop,
+                "stream": user_stream,
+            }

         # # Get the current timestamp
         # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -252,9 +284,14 @@
                 if response_result.get("choices") is None:
                     yield JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
                 else:
-                    print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
-                    yield fastchat_response.json()["choices"][0]["message"]["content"]
-
+                    if llm_model != "vllm":
+                        print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
+                        yield fastchat_response.json()["choices"][0]["message"]["content"]
+                    else:
+                        print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["text"],"\n\n")
+                        yield fastchat_response.json()["choices"][0]["text"]
+
+
     async def fast_api_handler(self, request: Request) -> Response:
         try:
             data = await request.json()
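The llm_model switch in chat.py toggles between two OpenAI-style endpoints with different shapes: /v1/chat/completions takes a role-tagged messages list and answers in choices[0]["message"]["content"], while vLLM's /v1/completions takes a flat prompt, expects numeric sampling parameters, and answers in choices[0]["text"]. Side by side (the URLs and key are the ones hardcoded above; the model name "my-model" is a placeholder):

import requests

prompt = "你好"

# vLLM text-completions endpoint: flat "prompt", numeric sampling params.
vllm_resp = requests.post(
    "http://10.6.80.94:8000/v1/completions",
    headers={"Authorization": "Bearer vllm", "Content-Type": "application/json"},
    json={"model": "my-model", "prompt": prompt, "temperature": 0.7, "max_tokens": 256},
)
print(vllm_resp.json()["choices"][0]["text"])

# Chat-completions endpoint: role-tagged messages instead of a raw prompt.
chat_resp = requests.post(
    "http://10.6.80.75:23333/v1/chat/completions",
    json={"model": "my-model", "messages": [{"role": "user", "content": prompt}], "max_tokens": 256},
)
print(chat_resp.json()["choices"][0]["message"]["content"])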
diff --git a/src/blackbox/chatpipeline.py b/src/blackbox/chatpipeline.py
new file mode 100644
index 0000000..ff3c979
--- /dev/null
+++ b/src/blackbox/chatpipeline.py
@@ -0,0 +1,270 @@
+import re
+import requests
+import json
+import queue
+import threading
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+
+import os
+import io
+import time
+import requests
+from fastapi import Request, Response, status, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
+from .blackbox import Blackbox
+
+from injector import inject
+from injector import singleton
+
+from ..log.logging_time import logging_time
+import logging
+logger = logging.getLogger(__name__)
+
+import asyncio
+from uuid import uuid4  # used to generate unique session IDs
+
+@singleton
+class ChatPipeline(Blackbox):
+
+    @inject
+    def __init__(self, ) -> None:
+        self.text_queue = queue.Queue()     # text queue
+        self.audio_queue = queue.Queue()    # audio queue
+        self.PUNCTUATION = r'[。!?、,.,?]'  # sentence-ending punctuation
+        self.tts_event = threading.Event()  # TTS event
+        self.audio_part_counter = 0         # audio segment counter
+        self.text_part_counter = 0
+        self.audio_dir = "audio_files"      # directory for audio files
+        self.is_last = False
+
+        self.settings = {}                  # settings passed in by the caller
+        self.lock = threading.Lock()        # lock guarding shared state
+
+        if not os.path.exists(self.audio_dir):
+            os.makedirs(self.audio_dir)
+
+    def __call__(self, *args, **kwargs):
+        return self.processing(*args, **kwargs)
+
+    def valid(self, data: any) -> bool:
+        if isinstance(data, bytes):
+            return True
+        return False
+
+    def reset_queues(self):
+        """Clear any queued data from a previous run."""
+        self.text_queue.queue.clear()   # clear the text queue
+        self.audio_queue.queue.clear()  # clear the audio queue
+
+    def clear_audio_files(self):
+        """Remove all audio files from the audio directory."""
+        for file_name in os.listdir(self.audio_dir):
+            file_path = os.path.join(self.audio_dir, file_name)
+            if os.path.isfile(file_path):
+                os.remove(file_path)
+                print(f"Removed old audio file: {file_path}")
+
+    def save_audio(self, audio_data, part_number: int):
+        """Save audio data to a file."""
+        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
+        with open(file_name, 'wb') as f:
+            f.write(audio_data)
+        return file_name
+
+    def chat_stream(self, prompt: str, settings: dict):
+        """Fetch text generated in real time by chat.py and push it onto the queue."""
+        url = 'http://10.6.44.141:8000/?blackbox_name=chat'
+        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",}  # disable caching
+        data = {
+            "prompt": prompt,
+            "context": [],
+            "settings": settings
+        }
+        print(f"data_chat: {data}")
+
+        # Clear out old audio files on every run
+        self.clear_audio_files()
+        self.audio_part_counter = 0
+        self.text_part_counter = 0
+        self.is_last = False
+        with self.lock:  # keep access to settings thread-safe
+            llm_stream = settings.get("stream")
+        if llm_stream:
+            with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
+                print(f"data_chat1: {data}")
+                complete_message = ""  # accumulates the complete text
+                lines = list(response.iter_lines())  # read all lines into a list first
+                total_lines = len(lines)
+                for i, line in enumerate(lines):
+                    if line:
+                        message = line.decode('utf-8')
+
+                        if message.strip().lower() == "data:":
+                            continue  # skip "data:" lines
+
+                        complete_message += message
+
+                        # If it contains punctuation, split it into sentences
+                        if re.search(self.PUNCTUATION, complete_message):
+                            sentences = re.split(self.PUNCTUATION, complete_message)
+                            for sentence in sentences[:-1]:
+                                cleaned_sentence = self.filter_invalid_chars(sentence.strip())
+                                if cleaned_sentence:
+                                    print(f"Sending complete sentence: {cleaned_sentence}")
+                                    self.text_queue.put(cleaned_sentence)  # push onto the text queue
+                            complete_message = sentences[-1]
+                            self.text_part_counter += 1
+                        # Check whether this is the last sentence
+                        if i == total_lines - 2:  # last data line
+                            self.is_last = True
+                            print(f'2.is_last: {self.is_last}')
+
+                        time.sleep(0.2)
+        else:
+            with requests.post(url, headers=headers, data=json.dumps(data)) as response:
+                print(f"data_chat1: {data}")
+                if response.status_code == 200:
+                    response_json = response.json()
+                    response_content = response_json.get("response")
+                    self.text_queue.put(response_content)
response_json.get("response") + self.text_queue.put(response_content) + + + def send_to_tts(self, settings: dict): + """从队列中获取文本并发送给 tts.py 进行语音合成""" + url = 'http://10.6.44.141:8000/?blackbox_name=tts' + headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",} # 禁用缓存} + with self.lock: + user_stream = settings.get("tts_stream") + tts_model_name = settings.get("tts_model_name") + print(f"data_tts0: {settings}") + while True: + try: + # 获取队列中的一个完整句子 + text = self.text_queue.get(timeout=5) + + if text is None: + break + + if not text.strip(): + continue + + if tts_model_name == 'sovitstts': + text = self.filter_invalid_chars(text) + print(f"data_tts0.1: {settings}") + data = { + "settings": settings, + "text": text + } + print(f"data_tts1: {data}") + if user_stream: + # 发送请求到 TTS 服务 + response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) + + if response.status_code == 200: + audio_data = response.content + if isinstance(audio_data, bytes): + self.audio_part_counter += 1 # 增加音频段计数器 + file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件 + print(f"Audio part saved as {file_name}") + + # 将文件名和是否是最后一条消息放入音频队列 + self.audio_queue.put(file_name) # 放入音频队列 + + else: + print(f"Error: Received non-binary data.") + else: + print(f"Failed to send to TTS: {response.status_code}, Text: {text}") + + else: + print(f"data_tts2: {data}") + response = requests.post(url, headers=headers, data=json.dumps(data)) + if response.status_code == 200: + self.audio_queue.put(response.content) + + # 通知下一个 TTS 可以执行了 + self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程 + time.sleep(0.2) + + except queue.Empty: + time.sleep(1) + + def filter_invalid_chars(self,text): + """过滤无效字符(包括字节流)""" + invalid_keywords = ["data:", "\n", "\r", "\t", " "] + + if isinstance(text, bytes): + text = text.decode('utf-8', errors='ignore') + + for keyword in invalid_keywords: + text = text.replace(keyword, "") + + # 移除所有英文字母和符号(保留中文、标点等) + text = re.sub(r'[a-zA-Z]', '', text) + + return text.strip() + + + @logging_time(logger=logger) + def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO: + + # 启动聊天流线程 + threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start() + + # 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务 + threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start() + + return {"message": "Chat and TTS processing started"} + + # 新增异步方法,等待音频片段生成 + async def wait_for_audio(self): + while self.audio_queue.empty(): + await asyncio.sleep(0.2) # 使用异步 sleep,避免阻塞事件循环 + + async def fast_api_handler(self, request: Request) -> Response: + try: + data = await request.json() + text = data.get("text") + setting = data.get("settings") + user_stream = setting.get("tts_stream") + self.is_last = False + self.audio_part_counter = 0 # 音频段计数器 + self.text_part_counter = 0 + self.reset_queues() + self.clear_audio_files() + print(f"data0: {data}") + if text is None: + return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + # 调用 processing 方法,并传递动态的 text 参数 + response_data = self.processing(text, settings=setting) + # 根据是否启用流式传输进行处理 + if user_stream: + # 等待至少一个音频片段生成完成 + await self.wait_for_audio() + def audio_stream(): + # 从上游服务器流式读取数据并逐块发送 + while self.audio_part_counter != 0 and not self.is_last: + audio = self.audio_queue.get() + audio_file = audio + if audio_file: + with open(audio_file, "rb") as f: + print(f"Sending audio file: {audio_file}") + yield f.read() # 分段发送音频文件内容 + return 
diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index ba4767a..dff10ce 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -25,7 +25,7 @@ class ChromaQuery(Blackbox):
         self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
         self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
         self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-        self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
+        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)

         self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
@@ -60,7 +60,7 @@ class ChromaQuery(Blackbox):
             chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"

         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "192.168.0.200"
+            chroma_host = "10.6.44.141"

         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
             chroma_port = "7000"
@@ -72,7 +72,7 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 10

         # load client and embedding model from init
-        if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port):
+        if re.search(r"10\.6\.44\.141", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             try:
diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py
index c07ba7a..f266e0a 100755
--- a/src/blackbox/chroma_upsert.py
+++ b/src/blackbox/chroma_upsert.py
@@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox):
         # load embedding model
         self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
         # load chroma db
-        self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
+        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)

     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
diff --git a/src/blackbox/g2e.py b/src/blackbox/g2e.py
index d43bcc0..b4ec980 100755
--- a/src/blackbox/g2e.py
+++ b/src/blackbox/g2e.py
@@ -23,7 +23,7 @@ class G2E(Blackbox):
         if context == None:
             context = []
         #url = 'http://120.196.116.194:48890/v1'
-        url = 'http://192.168.0.200:23333/v1'
+        url = 'http://10.6.44.141:23333/v1'
         background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众
 年龄:20-35岁
 性别:女性
 地点:一线城市、二线城市
 职业:精英中产、都市白领
 收入水平:中高收入,有一定消费能力
 兴趣和爱好:注重健康,有运动习惯
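One detail worth calling out in the chroma_query.py hunk: in a regular expression an unescaped dot matches any character, so the old pattern r"192.168.0.200" would also match unrelated strings. The updated line escapes the dots so only the literal address matches:

import re

# Unescaped dots match any character, so the loose pattern yields false positives:
print(bool(re.search(r"10.6.44.141", "10x6x44x141")))     # True (false positive)
# Escaping the dots (as the updated chroma_query.py does) pins the literal IP:
print(bool(re.search(r"10\.6\.44\.141", "10x6x44x141")))  # False
print(bool(re.search(r"10\.6\.44\.141", "10.6.44.141")))  # True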
diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py
index 247fd02..7446190 100644
--- a/src/blackbox/tts.py
+++ b/src/blackbox/tts.py
@@ -3,18 +3,20 @@ import time
 from ntpath import join

 import requests
-from fastapi import Request, Response, status
-from fastapi.responses import JSONResponse
+from fastapi import Request, Response, status, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
 from .blackbox import Blackbox
 from ..tts.tts_service import TTService
 from ..configuration import MeloConf
 from ..configuration import CosyVoiceConf
+from ..configuration import SovitsConf

 from injector import inject
 from injector import singleton

 import sys,os
 sys.path.append('/Workspace/CosyVoice')
-from cosyvoice.cli.cosyvoice import CosyVoice
+sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 # from cosyvoice.utils.file_utils import load_wav, speed_change
 import soundfile as sf
@@ -29,13 +31,38 @@
 import random
 import torch
 import numpy as np

+from pydub import AudioSegment
+import subprocess
+
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

+def convert_wav_to_mp3(wav_filename: str):
+    # Check that the input file is a .wav
+    if not wav_filename.lower().endswith(".wav"):
+        raise ValueError("The input file must be a .wav file")
+
+    # Build the corresponding .mp3 path
+    mp3_filename = wav_filename.replace(".wav", ".mp3")
+
+    # If an old MP3 already exists, delete it first
+    if os.path.exists(mp3_filename):
+        os.remove(mp3_filename)
+
+    # Convert with FFmpeg, overwriting the existing MP3 directly
+    command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename]  # `-y` overwrites the target file automatically
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

+    # Check whether the conversion succeeded
+    if result.returncode != 0:
+        raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}")
+
+    return mp3_filename
+
+
 @singleton
 class TTS(Blackbox):
     melo_mode: str
@@ -61,10 +88,11 @@
         self.melo_url = ''
         self.melo_mode = melo_config.mode
         self.melotts = None
-        self.speaker_ids = None
+        # self.speaker_ids = None
+        self.melo_speaker = None
         if self.melo_mode == 'local':
             self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
-            self.speaker_ids = self.melotts.hps.data.spk2id
+            self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language]
         else:
             self.melo_url = melo_config.url
         logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@@ -81,19 +109,52 @@
         self.cosyvoicetts = None
         # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
         if self.cosyvoice_mode == 'local':
-            self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
+            # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
+            self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
+            # self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
+
         else:
             self.cosyvoice_url = cosyvoice_config.url
         logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
         print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
+
+    @logging_time(logger=logger)
+    def sovits_model_init(self, sovits_config: SovitsConf) -> None:
+        self.sovits_speed = sovits_config.speed
+        self.sovits_device = sovits_config.device
+        self.sovits_language = sovits_config.language
+        self.sovits_speaker = sovits_config.speaker
+        self.sovits_url = sovits_config.url
+        self.sovits_mode = sovits_config.mode
+        self.sovitstts = None
+        self.speaker_ids = None
+
+        self.sovits_text_lang = sovits_config.text_lang
+        self.sovits_ref_audio_path = sovits_config.ref_audio_path
+        self.sovits_prompt_lang = sovits_config.prompt_lang
+        self.sovits_prompt_text = sovits_config.prompt_text
+        self.sovits_text_split_method = sovits_config.text_split_method
+        self.sovits_batch_size = sovits_config.batch_size
+        self.sovits_media_type = sovits_config.media_type
+        self.sovits_streaming_mode = sovits_config.streaming_mode
+
+        if self.sovits_mode == 'local':
+            # self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device)
+            self.sovits_speaker = self.sovitstts.hps.data.spk2id
+        else:
+            self.sovits_url = sovits_config.url
+        logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
+        print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
+
+
     @inject
-    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
+    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
         self.tts_service = TTService("yunfeineo")
         self.melo_model_init(melo_config)
         self.cosyvoice_model_init(cosyvoice_config)
-
+        self.sovits_model_init(sovits_config)
+        self.audio_dir = "audio_files"  # directory for audio files

     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -106,14 +167,19 @@
             settings = {}
         user_model_name = settings.get("tts_model_name")
         chroma_collection_id = settings.get("chroma_collection_id")
+        user_stream = settings.get('tts_stream')
+
         print(f"chroma_collection_id: {chroma_collection_id}")
         print(f"tts_model_name: {user_model_name}")

+        if user_stream in [None, ""]:
+            user_stream = True
+
         text = args[0]
         current_time = time.time()
         if user_model_name == 'melotts':
             if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                 if self.melo_mode == 'local':
-                    audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
+                    audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed)
                     f = io.BytesIO()
                     sf.write(f, audio, 44100, format='wav')
                     f.seek(0)
@@ -147,7 +213,7 @@
             elif chroma_collection_id == 'boss':
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(35616313)
-                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -166,7 +232,8 @@
             if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(56056558)
-                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
+                    print("*"*90)
+                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -183,7 +250,7 @@
             elif chroma_collection_id == 'boss':
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(35616313)
-                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -196,11 +263,65 @@
                     }
                     response = requests.post(self.cosyvoice_url, json=message)
                     print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-                    return response.content
+                    return response.content
+
+        elif user_model_name == 'sovitstts':
+            if chroma_collection_id == 'kiki' or chroma_collection_id is None:
+                if self.sovits_mode == 'local':
+                    set_all_random_seed(56056558)
+                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
+                    for i, j in enumerate(audio):
+                        f = io.BytesIO()
+                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                        f.seek(0)
+                    print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
+                    return f.read()
+                else:
+                    message = {
+                        "text": text,
+                        "text_lang": self.sovits_text_lang,
+                        "ref_audio_path": self.sovits_ref_audio_path,
+                        "prompt_lang": self.sovits_prompt_lang,
+                        "prompt_text": self.sovits_prompt_text,
+                        "text_split_method": self.sovits_text_split_method,
+                        "batch_size": self.sovits_batch_size,
+                        "media_type": self.sovits_media_type,
+                        "streaming_mode": self.sovits_streaming_mode
+                    }
+                    if user_stream:
+                        response = requests.get(self.sovits_url, params=message, stream=True)
+                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
+                        return response
+                    else:
+                        response = requests.get(self.sovits_url, params=message)
+                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
+                        return response.content
+
+            # elif chroma_collection_id == 'boss':
+            #     if self.cosyvoice_mode == 'local':
+            #         set_all_random_seed(35616313)
+            #         audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+            #         for i, j in enumerate(audio):
+            #             f = io.BytesIO()
+            #             sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+            #             f.seek(0)
+            #         print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
+            #         return f.read()
+            #     else:
+            #         message = {
+            #             "text": text
+            #         }
+            #         response = requests.post(self.cosyvoice_url, json=message)
+            #         print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
+            #         return response.content
+
+
         elif user_model_name == 'man':
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(35616313)
-                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
+                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                 for i, j in enumerate(audio):
                     f = io.BytesIO()
                     sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -232,7 +353,73 @@
                 return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
         text = data.get("text")
         setting = data.get("settings")
+        tts_model_name = setting.get("tts_model_name")
+        user_stream = setting.get("tts_stream")
+
         if text is None:
             return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
         by = self.processing(text, settings=setting)
-        return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
\ No newline at end of file
+        # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
+
+        if user_stream:
+            if by.status_code == 200:
+                print("*"*90)
+                def audio_stream():
+                    # Stream from the upstream server and forward it chunk by chunk
+                    for chunk in by.iter_content(chunk_size=1024):
+                        if chunk:
+                            yield chunk
+
+                return StreamingResponse(audio_stream(), media_type="audio/wav")
+            else:
+                raise HTTPException(status_code=by.status_code, detail="请求失败")
+
+        # if user_stream and tts_model_name == 'sovitstts':
+        #     if by.status_code == 200:
+        #         # Save the WAV file
+        #         wav_filename = os.path.join(self.audio_dir, 'audio.wav')
+        #         with open(wav_filename, 'wb') as f:
+        #             for chunk in by.iter_content(chunk_size=1024):
+        #                 if chunk:
+        #                     f.write(chunk)
+
+        #         try:
+        #             # Check that the file exists
+        #             if not os.path.exists(wav_filename):
+        #                 return JSONResponse(content={"error": "File not found"}, status_code=404)
+
+        #             # Convert WAV to MP3
+        #             mp3_filename = convert_wav_to_mp3(wav_filename)
+
+        #             # Return download links for the WAV and MP3 files
+        #             wav_url = f"/audio/{wav_filename}"
+        #             mp3_url = f"/audio/{mp3_filename}"
+        #             return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
+
+        #         except Exception as e:
+        #             return JSONResponse(content={"error": str(e)}, status_code=500)
+
+        #     else:
+        #         raise HTTPException(status_code=by.status_code, detail="请求失败")
+
+        else:
+            wav_filename = os.path.join(self.audio_dir, 'audio.wav')
+            with open(wav_filename, 'wb') as f:
+                f.write(by)
+
+            try:
+                # Check that the file exists first
+                if not os.path.exists(wav_filename):
+                    return JSONResponse(content={"error": "File not found"}, status_code=404)
+
+                # Convert WAV to MP3
+                mp3_filename = convert_wav_to_mp3(wav_filename)
+
+                # Return download links for the WAV and MP3 files
+                wav_url = f"/audio/{wav_filename}"
+                mp3_url = f"/audio/{mp3_filename}"
+                return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
+
+            except Exception as e:
+                return JSONResponse(content={"error": str(e)}, status_code=500)
\ No newline at end of file
diff --git a/src/configuration.py b/src/configuration.py
index 0c6e0cb..bac66ea 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -82,6 +82,32 @@ class CosyVoiceConf():
         self.language = config.get("cosyvoicetts.language")
         self.speaker = config.get("cosyvoicetts.speaker")

+class SovitsConf():
+    mode: str
+    url: str
+    speed: int
+    device: str
+    language: str
+    speaker: str
+
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.mode = config.get("sovitstts.mode")
+        self.url = config.get("sovitstts.url")
+        self.speed = config.get("sovitstts.speed")
+        self.device = config.get("sovitstts.device")
+        self.language = config.get("sovitstts.language")
+        self.speaker = config.get("sovitstts.speaker")
+
+        self.text_lang = config.get("sovitstts.text_lang")
+        self.ref_audio_path = config.get("sovitstts.ref_audio_path")
+        self.prompt_lang = config.get("sovitstts.prompt_lang")
+        self.prompt_text = config.get("sovitstts.prompt_text")
+        self.text_split_method = config.get("sovitstts.text_split_method")
+        self.batch_size = config.get("sovitstts.batch_size")
+        self.media_type = config.get("sovitstts.media_type")
+        self.streaming_mode = config.get("sovitstts.streaming_mode")
+
 class SenseVoiceConf():
     mode: str
     url: str
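With SovitsConf wired in, a caller selects the SoVITS path purely through the settings payload. A minimal end-to-end request against the chatpipeline endpoint could look like the following (host and field names are taken from the code above; the streamed WAV segments are forwarded as-is, so a real client would handle each segment's header rather than naively concatenating):

import requests

# Settings keys read by tts.py / chatpipeline.py in this diff.
settings = {
    "tts_model_name": "sovitstts",  # route synthesis to the SoVITS branch
    "tts_stream": True,             # stream WAV segments back as they are ready
    "stream": True,                 # stream LLM tokens sentence by sentence
}
resp = requests.post(
    "http://10.6.44.141:8000/?blackbox_name=chatpipeline",
    json={"text": "你好", "settings": settings},
    stream=True,
)
with open("reply.wav", "wb") as f:
    for chunk in resp.iter_content(chunk_size=1024):
        f.write(chunk)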