diff --git a/sample/chroma_client1.py b/sample/chroma_client1.py
index 4f798fe..edce182 100644
--- a/sample/chroma_client1.py
+++ b/sample/chroma_client1.py
@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
 # load the embedding model and the chroma server
 embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 
 id = "g2e"
 #client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
 # # chroma retrieval
 # from chromadb.utils import embedding_functions
 # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)
 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
diff --git a/sample/chroma_client_en.py b/sample/chroma_client_en.py
index 8ca506c..d3e3de9 100644
--- a/sample/chroma_client_en.py
+++ b/sample/chroma_client_en.py
@@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
 # load the embedding model and the chroma server
 embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 
 id = "g2e_english"
 client.delete_collection(id)
@@ -107,7 +107,7 @@ print("collection_number",collection_number)
 # # chroma retrieval
 # from chromadb.utils import embedding_functions
 # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)
 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
diff --git a/sample/chroma_client_query.py b/sample/chroma_client_query.py
index bcc3331..cafbfcc 100644
--- a/sample/chroma_client_query.py
+++ b/sample/chroma_client_query.py
@@ -81,7 +81,7 @@ def get_all_files(folder_path):
 # # load the embedding model and the chroma server
 # embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-# client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+# client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 
 # id = "g2e"
 # client.delete_collection(id)
@@ -103,7 +103,7 @@ def get_all_files(folder_path):
 # chroma retrieval
 from chromadb.utils import embedding_functions
 embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 collection = client.get_collection("g2e", embedding_function=embedding_model)
 print(collection.count())
@@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
 #     'Content-Type': 'application/json',
 #     'Authorization': "Bearer " + key
 # }
-# url = "http://10.6.44.141:23333/v1/chat/completions"
+# url = "http://192.168.0.200:23333/v1/chat/completions"
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
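The three sample clients above move the Chroma endpoint from 10.6.44.141 to 192.168.0.200. A minimal sketch for checking the relocated server before running them — host, port, and the g2e collection name come from the scripts; the rest is standard chromadb client API:

```python
import chromadb

# Connect to the relocated Chroma server introduced by the hunks above.
client = chromadb.HttpClient(host="192.168.0.200", port=7000)
print(client.heartbeat())  # raises if the new host/port is unreachable

# The query script expects a "g2e" collection to already exist here.
print([c.name for c in client.list_collections()])
```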
diff --git a/sample/chroma_rerank.py b/sample/chroma_rerank.py
index 09be954..3093c85 100644
--- a/sample/chroma_rerank.py
+++ b/sample/chroma_rerank.py
@@ -10,64 +10,64 @@ import time
 # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
+loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
 documents = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
 docs = text_splitter.split_documents(documents)
 print("len(docs)", len(docs))
 ids = ["粤语语料"+str(i) for i in range(len(docs))]
 
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 
-id = "kiki"
-# client.delete_collection(id)
+id = "boss"
+client.delete_collection(id)
 
 # insert vectors (existing ids will be updated)
 db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
 
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
 
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='192.168.0.200', port=7000)
 collection = client.get_collection(id, embedding_function=embedding_model)
 
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
+reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
 
-# while True:
-#     usr_question = input("\n 请输入问题: ")
-#     # query it
-#     time1 = time.time()
-#     results = collection.query(
-#         query_texts=[usr_question],
-#         n_results=10,
-#     )
-#     time2 = time.time()
-#     print("query time: ", time2 - time1)
+while True:
+    usr_question = input("\n 请输入问题: ")
+    # query it
+    time1 = time.time()
+    results = collection.query(
+        query_texts=[usr_question],
+        n_results=10,
+    )
+    time2 = time.time()
+    print("query time: ", time2 - time1)
 
-#     # print("query: ",usr_question)
-#     # print("results: ",print(results["documents"][0]))
+    # print("query: ",usr_question)
+    # print("results: ",print(results["documents"][0]))
 
-#     pairs = [[usr_question, doc] for doc in results["documents"][0]]
-#     # print('\n',pairs)
-#     scores = reranker_model.predict(pairs)
+    pairs = [[usr_question, doc] for doc in results["documents"][0]]
+    # print('\n',pairs)
+    scores = reranker_model.predict(pairs)
 
-#     # re-rank the document order:
-#     print("New Ordering:")
-#     i = 0
-#     final_result = ''
-#     for o in np.argsort(scores)[::-1]:
-#         if i == 3 or scores[o] < 0.5:
-#             break
-#         i += 1
-#         print(o+1)
-#         print("Scores:", scores[o])
-#         print(results["documents"][0][o],'\n')
-#         final_result += results["documents"][0][o] + '\n'
+    # re-rank the document order:
+    print("New Ordering:")
+    i = 0
+    final_result = ''
+    for o in np.argsort(scores)[::-1]:
+        if i == 3 or scores[o] < 0.5:
+            break
+        i += 1
+        print(o+1)
+        print("Scores:", scores[o])
+        print(results["documents"][0][o],'\n')
+        final_result += results["documents"][0][o] + '\n'
 
-#     print("\n final_result: ", final_result)
-#     time3 = time.time()
-#     print("rerank time: ", time3 - time2)
\ No newline at end of file
+    print("\n final_result: ", final_result)
+    time3 = time.time()
+    print("rerank time: ", time3 - time2)
\ No newline at end of file
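The loop uncommented in chroma_rerank.py is a standard retrieve-then-rerank pass: recall the top 10 chunks from Chroma, score every (question, chunk) pair with the bge-reranker cross-encoder, and keep at most three chunks scoring at least 0.5. The same flow, condensed into a function — `collection` and `reranker_model` are assumed to be initialized exactly as in the script:

```python
import numpy as np

def retrieve_and_rerank(question: str, top_k: int = 3, threshold: float = 0.5) -> str:
    # First stage: vector recall from the Chroma collection.
    results = collection.query(query_texts=[question], n_results=10)
    docs = results["documents"][0]
    # Second stage: the cross-encoder scores each (question, doc) pair.
    scores = reranker_model.predict([[question, d] for d in docs])
    best = np.argsort(scores)[::-1][:top_k]
    return "\n".join(docs[i] for i in best if scores[i] >= threshold)
```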
diff --git a/server.py b/server.py
index c5c7990..d32b58d 100644
--- a/server.py
+++ b/server.py
@@ -1,7 +1,7 @@
 from typing import Union
 
 from fastapi import FastAPI, Request, status
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import JSONResponse
 from src.blackbox.blackbox_factory import BlackboxFactory
 from fastapi.middleware.cors import CORSMiddleware
@@ -21,7 +21,6 @@
 injector = Injector()
 blackbox_factory = injector.get(BlackboxFactory)
 
 @app.post("/")
-@app.get("/")
 async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
     if not blackbox_name:
         return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -30,58 +29,3 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
     except ValueError:
         return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
     return await box.fast_api_handler(request)
-
-# @app.get("/audio/{filename}")
-# async def serve_audio(filename: str):
-#     import os
-#     # make sure the file exists
-#     if os.path.exists(filename):
-#         with open(filename, 'rb') as f:
-#             audio_data = f.read()
-#         if filename.endswith(".mp3"):
-#             # for MP3 files, use inline so the browser plays them
-#             return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
-#         elif filename.endswith(".wav"):
-#             # for WAV files, use inline so the browser plays them
-#             return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
-#     else:
-#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
-
-@app.get("/audio/audio_files/{filename}")
-async def serve_audio(filename: str):
-    import os
-    import aiofiles
-    filename = os.path.join("audio_files", filename)
-    # make sure the file exists
-    if os.path.exists(filename):
-        try:
-            # read the file asynchronously with aiofiles
-            async with aiofiles.open(filename, 'rb') as f:
-                audio_data = await f.read()
-
-            # return the right audio response for the file type
-            if filename.endswith(".mp3"):
-                return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"})
-            elif filename.endswith(".wav"):
-                return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"})
-            else:
-                return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
-        except asyncio.CancelledError:
-            # handle the request being cancelled
-            return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST)
-        except Exception as e:
-            # catch any other exception
-            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-    else:
-        return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND)
-
-
-# @app.get("/audio/{filename}")
-# async def serve_audio(filename: str):
-#     import os
-#     file_path = f"audio/{filename}"
-#     # check whether the file exists
-#     if os.path.exists(file_path):
-#         return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav")
-#     else:
-#         return JSONResponse(content={"error": f"{filename} not found"}, status_code=404)
\ No newline at end of file
diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py
index d7ba14e..0bdf5a9 100644
--- a/src/blackbox/blackbox_factory.py
+++ b/src/blackbox/blackbox_factory.py
@@ -62,11 +62,6 @@ def tts_loader():
     from .tts import TTS
     return Injector().get(TTS)
 
-@model_loader(lazy=blackboxConf.lazyloading)
-def chatpipeline_loader():
-    from .chatpipeline import ChatPipeline
-    return Injector().get(ChatPipeline)
-
 @model_loader(lazy=blackboxConf.lazyloading)
 def emotion_loader():
     from .emotion import Emotion
@@ -137,7 +132,6 @@ class BlackboxFactory:
         self.models["audio_to_text"] = audio_to_text_loader
         self.models["asr"] = asr_loader
         self.models["tts"] = tts_loader
-        self.models["chatpipeline"] = chatpipeline_loader
         self.models["sentiment_engine"] = sentiment_loader
         self.models["emotion"] = emotion_loader
         self.models["fastchat"] = fastchat_loader
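With `chatpipeline_loader` gone from the registry, a request that still names the old pipeline should fall through to the factory's unknown-model path, which server.py answers with HTTP 400 — assuming the factory raises ValueError for unregistered names, as the `except ValueError` branch above suggests. A hypothetical smoke test (host and port are illustrative):

```python
import requests

# "chatpipeline" is no longer registered in BlackboxFactory.models.
resp = requests.post("http://localhost:8000/?blackbox_name=chatpipeline", json={})
assert resp.status_code == 400  # {"error": "value error"} from server.py
```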
diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py
index 4973e0d..b7d2527 100644
--- a/src/blackbox/chat.py
+++ b/src/blackbox/chat.py
@@ -60,8 +60,6 @@ class Chat(Blackbox):
         system_prompt = settings.get('system_prompt')
         user_prompt_template = settings.get('user_prompt_template')
         user_stream = settings.get('stream')
-
-        llm_model = "vllm"
 
         if user_context == None:
             user_context = []
@@ -102,16 +100,10 @@ class Chat(Blackbox):
         #user_presence_penalty = 0.8
 
         if user_model_url is None or user_model_url.isspace() or user_model_url == "":
-            if llm_model != "vllm":
-                user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
-            else:
-                user_model_url = "http://10.6.80.94:8000/v1/completions"
+            user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
 
         if user_model_key is None or user_model_key.isspace() or user_model_key == "":
-            if llm_model != "vllm":
-                user_model_key = "YOUR_API_KEY"
-            else:
-                user_model_key = "vllm"
+            user_model_key = "YOUR_API_KEY"
 
         if chroma_embedding_model:
             chroma_response = self.chroma_query(user_question, settings)
@@ -125,10 +117,7 @@ class Chat(Blackbox):
             print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
             user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "。"
         else:
-            if llm_model != "vllm":
-                user_question = user_prompt_template + "问题: " + user_question + "。"
-            else:
-                user_question = user_question
+            user_question = user_prompt_template + "问题: " + user_question + "。"
 
         print(f"1.user_question: {user_question}")
@@ -183,17 +172,10 @@ class Chat(Blackbox):
         else:
             url = user_model_url
             key = user_model_key
-        if llm_model != "vllm":
-            header = {
-                'Content-Type': 'application/json',
-                "Cache-Control": "no-cache",  # disable caching
-            }
-        else:
-            header = {
-                'Content-Type': 'application/json',
-                'Authorization': "Bearer " + key,
-                "Cache-Control": "no-cache",
-            }
+        header = {
+            'Content-Type': 'application/json',
+            "Cache-Control": "no-cache",  # disable caching
+        }
 
         # system_prompt = "# Role: 琪琪,康普可可的代言人。\n\n## Profile:\n**Author**: 琪琪。\n**Language**: 中文。\n**Description**: 琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n\n## Constraints:\n- **严格遵循工作流程**: 严格遵循中设定的工作流程。\n- **无内置知识库** :根据中提供的知识作答,而不是内置知识库,我虽然是知识库专家,但我的知识依赖于外部输入,而不是大模型已有知识。\n- **回复格式**:在进行回复时,不能输出“检索内容” 标签字样,同时也不能直接透露知识片段原文。\n\n## Workflow:\n1. **接收查询**:接收用户的问题。\n2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n3. **提供回答**:\n\n```\n基于检索内容中的知识片段回答用户的问题。回答内容限制总结在50字内。\n请首先判断提供的检索内容与上述问题是否相关。如果相关,直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼 。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。”\n\n```\n## Example:\n\n用户询问:“中国的首都是哪个城市?” 。\n2.1检索知识库,首先检查知识片段,如果检索内容中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。\n2.2如果有知识片段,在做出回复时,只能基于检索内容中的内容进行回答,且不能透露上下文原文,同时也不能出现检索内容的标签字样。\n"
@@ -201,37 +183,23 @@ class Chat(Blackbox):
         prompt_template = [
             {"role": "system", "content": system_prompt}
         ]
 
-        if llm_model != "vllm":
-            chat_inputs={
-                "model": user_model_name,
-                "messages": prompt_template + user_context + [
-                    {
-                        "role": "user",
-                        "content": user_question
-                    }
-                ],
-                "temperature": str(user_temperature),
-                "top_p": str(user_top_p),
-                "n": str(user_n),
-                "max_tokens": str(user_max_tokens),
-                "frequency_penalty": str(user_frequency_penalty),
-                "presence_penalty": str(user_presence_penalty),
-                "stop": str(user_stop),
-                "stream": user_stream,
-            }
-        else:
-            chat_inputs={
-                "model": user_model_name,
-                "prompt": user_question,
-                "temperature": float(user_temperature),
-                "top_p": float(user_top_p),
-                "n": float(user_n),
-                "max_tokens": float(user_max_tokens),
-                "frequency_penalty": float(user_frequency_penalty),
-                "presence_penalty": float( user_presence_penalty),
-                # "stop": user_stop,
-                "stream": user_stream,
-            }
+        chat_inputs={
+            "model": user_model_name,
+            "messages": prompt_template + user_context + [
+                {
+                    "role": "user",
+                    "content": user_question
+                }
+            ],
+            "temperature": str(user_temperature),
+            "top_p": str(user_top_p),
+            "n": str(user_n),
+            "max_tokens": str(user_max_tokens),
+            "frequency_penalty": str(user_frequency_penalty),
+            "presence_penalty": str(user_presence_penalty),
+            "stop": str(user_stop),
+            "stream": user_stream,
+        }
 
         # # get the current timestamp
         # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -284,14 +252,9 @@ class Chat(Blackbox):
         if response_result.get("choices") is None:
             yield JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
         else:
-            if llm_model != "vllm":
-                print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
-                yield fastchat_response.json()["choices"][0]["message"]["content"]
-            else:
-                print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["text"],"\n\n")
-                yield fastchat_response.json()["choices"][0]["text"]
-
-
+            print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
+            yield fastchat_response.json()["choices"][0]["message"]["content"]
+
     async def fast_api_handler(self, request: Request) -> Response:
         try:
             data = await request.json()
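With the vllm branch removed, chat.py always builds the messages-style payload and posts it to the /v1/chat/completions default URL. A sketch of the single remaining request shape — the model name and message contents are placeholders, the URL and headers are the defaults the patch keeps, and note that chat.py deliberately stringifies the sampling parameters:

```python
import requests

chat_inputs = {
    "model": "my-model",  # placeholder; taken from settings in chat.py
    "messages": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "问题: ...。"},
    ],
    # chat.py converts these to strings before sending.
    "temperature": "0.8",
    "top_p": "0.8",
    "max_tokens": "512",
    "stream": False,
}
resp = requests.post(
    "http://10.6.80.75:23333/v1/chat/completions",
    json=chat_inputs,
    headers={"Content-Type": "application/json", "Cache-Control": "no-cache"},
)
print(resp.json()["choices"][0]["message"]["content"])
```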
diff --git a/src/blackbox/chatpipeline.py b/src/blackbox/chatpipeline.py
deleted file mode 100644
index ff3c979..0000000
--- a/src/blackbox/chatpipeline.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import re
-import requests
-import json
-import queue
-import threading
-from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
-
-import os
-import io
-import time
-import requests
-from fastapi import Request, Response, status, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
-from .blackbox import Blackbox
-
-from injector import inject
-from injector import singleton
-
-from ..log.logging_time import logging_time
-import logging
-logger = logging.getLogger(__name__)
-
-import asyncio
-from uuid import uuid4  # used to generate a unique session ID
-
-@singleton
-class ChatPipeline(Blackbox):
-
-    @inject
-    def __init__(self, ) -> None:
-        self.text_queue = queue.Queue()  # text queue
-        self.audio_queue = queue.Queue()  # audio queue
-        self.PUNCTUATION = r'[。!?、,.,?]'  # punctuation marks
-        self.tts_event = threading.Event()  # TTS event
-        self.audio_part_counter = 0  # audio segment counter
-        self.text_part_counter = 0
-        self.audio_dir = "audio_files"  # directory where audio files are stored
-        self.is_last = False
-
-        self.settings = {}  # settings passed in from outside
-        self.lock = threading.Lock()  # create a lock
-
-        if not os.path.exists(self.audio_dir):
-            os.makedirs(self.audio_dir)
-
-    def __call__(self, *args, **kwargs):
-        return self.processing(*args, **kwargs)
-
-    def valid(self, data: any) -> bool:
-        if isinstance(data, bytes):
-            return True
-        return False
-
-    def reset_queues(self):
-        """Clear any data left in the queues."""
-        self.text_queue.queue.clear()  # clear the text queue
-        self.audio_queue.queue.clear()  # clear the audio queue
-
-    def clear_audio_files(self):
-        """Remove every audio file in the audio directory."""
-        for file_name in os.listdir(self.audio_dir):
-            file_path = os.path.join(self.audio_dir, file_name)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-                print(f"Removed old audio file: {file_path}")
-
-    def save_audio(self, audio_data, part_number: int):
-        """Save audio data to a file."""
-        file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav")
-        with open(file_name, 'wb') as f:
-            f.write(audio_data)
-        return file_name
-
-    def chat_stream(self, prompt: str, settings: dict):
-        """Fetch text generated in real time by chat.py and put it on the queue."""
-        url = 'http://10.6.44.141:8000/?blackbox_name=chat'
-        headers = {'Content-Type': 'text/plain',"Cache-Control": "no-cache",}  # disable caching
-        data = {
-            "prompt": prompt,
-            "context": [],
-            "settings": settings
-        }
-        print(f"data_chat: {data}")
-
-        # clear any existing audio files on every run
-        self.clear_audio_files()
-        self.audio_part_counter = 0
-        self.text_part_counter = 0
-        self.is_last = False
-        with self.lock:  # make sure access to settings is thread-safe
-            llm_stream = settings.get("stream")
-        if llm_stream:
-            with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
-                print(f"data_chat1: {data}")
-                complete_message = ""  # accumulates the complete text
-                lines = list(response.iter_lines())  # read all lines into a list first
-                total_lines = len(lines)
-                for i, line in enumerate(lines):
-                    if line:
-                        message = line.decode('utf-8')
-
-                        if message.strip().lower() == "data:":
-                            continue  # skip "data:" lines
-
-                        complete_message += message
-
-                        # if punctuation is present, split into sentences
-                        if re.search(self.PUNCTUATION, complete_message):
-                            sentences = re.split(self.PUNCTUATION, complete_message)
-                            for sentence in sentences[:-1]:
-                                cleaned_sentence = self.filter_invalid_chars(sentence.strip())
-                                if cleaned_sentence:
-                                    print(f"Sending complete sentence: {cleaned_sentence}")
-                                    self.text_queue.put(cleaned_sentence)  # put on the text queue
-                            complete_message = sentences[-1]
-                            self.text_part_counter += 1
-                        # decide whether this is the last sentence
-                        if i == total_lines - 2:  # if this is the last line
-                            self.is_last = True
-                            print(f'2.is_last: {self.is_last}')
-
-                    time.sleep(0.2)
-        else:
-            with requests.post(url, headers=headers, data=json.dumps(data)) as response:
-                print(f"data_chat1: {data}")
-                if response.status_code == 200:
-                    response_json = response.json()
-                    response_content = response_json.get("response")
-                    self.text_queue.put(response_content)
-
-
-    def send_to_tts(self, settings: dict):
-        """Take text from the queue and send it to tts.py for speech synthesis."""
-        url = 'http://10.6.44.141:8000/?blackbox_name=tts'
-        headers = {'Content-Type': 'text/plain', "Cache-Control": "no-cache",}  # disable caching
-        with self.lock:
-            user_stream = settings.get("tts_stream")
-            tts_model_name = settings.get("tts_model_name")
-        print(f"data_tts0: {settings}")
-        while True:
-            try:
-                # fetch one complete sentence from the queue
-                text = self.text_queue.get(timeout=5)
-
-                if text is None:
-                    break
-
-                if not text.strip():
-                    continue
-
-                if tts_model_name == 'sovitstts':
-                    text = self.filter_invalid_chars(text)
-                print(f"data_tts0.1: {settings}")
-                data = {
-                    "settings": settings,
-                    "text": text
-                }
-                print(f"data_tts1: {data}")
-                if user_stream:
-                    # send the request to the TTS service
-                    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
-
-                    if response.status_code == 200:
-                        audio_data = response.content
-                        if isinstance(audio_data, bytes):
-                            self.audio_part_counter += 1  # increment the audio segment counter
-                            file_name = self.save_audio(audio_data, self.audio_part_counter)  # save as a file
-                            print(f"Audio part saved as {file_name}")
-
-                            # put the file name (and whether it is the last message) on the audio queue
-                            self.audio_queue.put(file_name)  # put on the audio queue
-
-                        else:
-                            print(f"Error: Received non-binary data.")
-                    else:
-                        print(f"Failed to send to TTS: {response.status_code}, Text: {text}")
-
-                else:
-                    print(f"data_tts2: {data}")
-                    response = requests.post(url, headers=headers, data=json.dumps(data))
-                    if response.status_code == 200:
-                        self.audio_queue.put(response.content)
-
-                # signal that the next TTS job can run
-                self.tts_event.set()  # with threading.Event(), this notifies waiting threads
-                time.sleep(0.2)
-
-            except queue.Empty:
-                time.sleep(1)
-
-    def filter_invalid_chars(self,text):
-        """Filter out invalid characters (including byte streams)."""
-        invalid_keywords = ["data:", "\n", "\r", "\t", " "]
-
-        if isinstance(text, bytes):
-            text = text.decode('utf-8', errors='ignore')
-
-        for keyword in invalid_keywords:
-            text = text.replace(keyword, "")
-
-        # remove all ASCII letters and symbols (keep Chinese characters, punctuation, etc.)
-        text = re.sub(r'[a-zA-Z]', '', text)
-
-        return text.strip()
-
-
-    @logging_time(logger=logger)
-    def processing(self, text: str, settings: dict) ->str:  #-> io.BytesIO:
-
-        # start the chat-stream thread
-        threading.Thread(target=self.chat_stream, args=(text, settings,), daemon=True).start()
-
-        # start the TTS thread, making sure it finishes the current job before the next TTS runs
-        threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start()
-
-        return {"message": "Chat and TTS processing started"}
-
-    # async helper that waits for an audio segment to be generated
-    async def wait_for_audio(self):
-        while self.audio_queue.empty():
-            await asyncio.sleep(0.2)  # async sleep so the event loop is not blocked
-
-    async def fast_api_handler(self, request: Request) -> Response:
-        try:
-            data = await request.json()
-            text = data.get("text")
-            setting = data.get("settings")
-            user_stream = setting.get("tts_stream")
-            self.is_last = False
-            self.audio_part_counter = 0  # audio segment counter
-            self.text_part_counter = 0
-            self.reset_queues()
-            self.clear_audio_files()
-            print(f"data0: {data}")
-            if text is None:
-                return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
-
-            # call processing, passing the dynamic text argument
-            response_data = self.processing(text, settings=setting)
-            # handle according to whether streaming is enabled
-            if user_stream:
-                # wait until at least one audio segment has been generated
-                await self.wait_for_audio()
-                def audio_stream():
-                    # stream data from the upstream server and send it chunk by chunk
-                    while self.audio_part_counter != 0 and not self.is_last:
-                        audio = self.audio_queue.get()
-                        audio_file = audio
-                        if audio_file:
-                            with open(audio_file, "rb") as f:
-                                print(f"Sending audio file: {audio_file}")
-                                yield f.read()  # send the audio file contents in segments
-                return StreamingResponse(audio_stream(), media_type="audio/wav")
-
-            else:
-                # without streaming, return one complete response or audio file
-                await self.wait_for_audio()
-                file_name = self.audio_queue.get()
-                if file_name:
-                    file_name_json = json.loads(file_name.decode('utf-8'))
-                    # audio_files = []
-                    # while not self.audio_queue.empty():
-                    #     print("9")
-                    #     audio_files.append(self.audio_queue.get())  # fetch the generated audio file names
-                    # return multiple audio files
-                    return JSONResponse(content=file_name_json)
-
-        except Exception as e:
-            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
\ No newline at end of file
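For reference, the core of the deleted pipeline was buffering streamed chat output and cutting it into sentences on CJK/ASCII punctuation before queueing each one for TTS. That splitting step, condensed from the removed file (the regex is copied verbatim):

```python
import re

PUNCTUATION = r'[。!?、,.,?]'  # same pattern the deleted ChatPipeline used

def split_sentences(buffer: str) -> tuple[list[str], str]:
    """Return the finished sentences plus the unfinished remainder."""
    if not re.search(PUNCTUATION, buffer):
        return [], buffer
    parts = re.split(PUNCTUATION, buffer)
    return [p.strip() for p in parts[:-1] if p.strip()], parts[-1]

# One streamed chunk can close two sentences and start a third:
done, rest = split_sentences("你好。今天天气不错,我们")
print(done, rest)  # ['你好', '今天天气不错'] 我们
```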
diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index dff10ce..ba4767a 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -25,7 +25,7 @@ class ChromaQuery(Blackbox):
         self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
         self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
         self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+        self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
 
         self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
@@ -60,7 +60,7 @@ class ChromaQuery(Blackbox):
             chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
 
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "10.6.44.141"
+            chroma_host = "192.168.0.200"
 
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
             chroma_port = "7000"
@@ -72,7 +72,7 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 10
 
         # load client and embedding model from init
-        if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
+        if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             try:
diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py
index f266e0a..c07ba7a 100755
--- a/src/blackbox/chroma_upsert.py
+++ b/src/blackbox/chroma_upsert.py
@@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox):
         # load embedding model
         self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
         # load chroma db
-        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+        self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000)
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
diff --git a/src/blackbox/g2e.py b/src/blackbox/g2e.py
index b4ec980..d43bcc0 100755
--- a/src/blackbox/g2e.py
+++ b/src/blackbox/g2e.py
@@ -23,7 +23,7 @@ class G2E(Blackbox):
         if context == None:
             context = []
         #url = 'http://120.196.116.194:48890/v1'
-        url = 'http://10.6.44.141:23333/v1'
+        url = 'http://192.168.0.200:23333/v1'
         background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众
 年龄:20-35岁
 性别:女性
 地点:一线城市、二线城市
 职业:精英中产、都市白领
 收入水平:中高收入,有一定消费能力
 兴趣和爱好:注重健康,有运动习惯
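One detail worth flagging in chroma_query.py's cached-client check: `re.search(r"192.168.0.200", chroma_host)` treats every dot as a wildcard, so a host such as "192x168x0x200" would also be routed to the cached `client_1`. A stricter equivalent, sketched outside the patch:

```python
import re

DEFAULT_HOST, DEFAULT_PORT = "192.168.0.200", "7000"

def is_default_endpoint(host: str, port: str) -> bool:
    # re.escape makes the dots literal; plain string equality works too.
    return bool(re.fullmatch(re.escape(DEFAULT_HOST), host)) and port == DEFAULT_PORT

print(is_default_endpoint("192.168.0.200", "7000"))   # True
print(is_default_endpoint("192x168x0x200", "7000"))   # False; re.search would say True
```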
diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py
index 7446190..247fd02 100644
--- a/src/blackbox/tts.py
+++ b/src/blackbox/tts.py
@@ -3,20 +3,18 @@ import time
 from ntpath import join
 
 import requests
-from fastapi import Request, Response, status, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
 from .blackbox import Blackbox
 from ..tts.tts_service import TTService
 from ..configuration import MeloConf
 from ..configuration import CosyVoiceConf
-from ..configuration import SovitsConf
 from injector import inject
 from injector import singleton
 import sys,os
 sys.path.append('/Workspace/CosyVoice')
-sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import CosyVoice
 # from cosyvoice.utils.file_utils import load_wav, speed_change
 import soundfile as sf
@@ -31,38 +29,13 @@
 import random
 import torch
 import numpy as np
-from pydub import AudioSegment
-import subprocess
-
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-def convert_wav_to_mp3(wav_filename: str):
-    # check that the input file is a .wav file
-    if not wav_filename.lower().endswith(".wav"):
-        raise ValueError("The input file must be a .wav file")
-    # build the matching .mp3 file path
-    mp3_filename = wav_filename.replace(".wav", ".mp3")
-
-    # if the original MP3 file already exists, remove it first
-    if os.path.exists(mp3_filename):
-        os.remove(mp3_filename)
-
-    # convert with FFmpeg, overwriting the existing MP3 file
-    command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename]  # `-y` automatically overwrites the target file
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-
-    # check whether the conversion succeeded
-    if result.returncode != 0:
-        raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}")
-
-    return mp3_filename
-
-
 @singleton
 class TTS(Blackbox):
     melo_mode: str
@@ -88,11 +61,10 @@ class TTS(Blackbox):
         self.melo_url = ''
         self.melo_mode = melo_config.mode
         self.melotts = None
-        # self.speaker_ids = None
-        self.melo_speaker = None
+        self.speaker_ids = None
         if self.melo_mode == 'local':
             self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
-            self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language]
+            self.speaker_ids = self.melotts.hps.data.spk2id
         else:
             self.melo_url = melo_config.url
         logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@@ -109,52 +81,19 @@ class TTS(Blackbox):
         self.cosyvoicetts = None
         # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
         if self.cosyvoice_mode == 'local':
-            # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
-            self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
-            # self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
-
+            self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
         else:
             self.cosyvoice_url = cosyvoice_config.url
         logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
         print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...')
 
-    @logging_time(logger=logger)
-    def sovits_model_init(self, sovits_config: SovitsConf) -> None:
-        self.sovits_speed = sovits_config.speed
-        self.sovits_device = sovits_config.device
-        self.sovits_language = sovits_config.language
-        self.sovits_speaker = sovits_config.speaker
-        self.sovits_url = sovits_config.url
-        self.sovits_mode = sovits_config.mode
-        self.sovitstts = None
-        self.speaker_ids = None
-
-        self.sovits_text_lang = sovits_config.text_lang
-        self.sovits_ref_audio_path = sovits_config.ref_audio_path
-        self.sovits_prompt_lang = sovits_config.prompt_lang
-        self.sovits_prompt_text = sovits_config.prompt_text
-        self.sovits_text_split_method = sovits_config.text_split_method
-        self.sovits_batch_size = sovits_config.batch_size
-        self.sovits_media_type = sovits_config.media_type
-        self.sovits_streaming_mode = sovits_config.streaming_mode
-
-        if self.sovits_mode == 'local':
-            # self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device)
-            self.sovits_speaker = self.sovitstts.hps.data.spk2id
-        else:
-            self.sovits_url = sovits_config.url
-        logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
-        print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...')
-
     @inject
-    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
+    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
         self.tts_service = TTService("yunfeineo")
         self.melo_model_init(melo_config)
         self.cosyvoice_model_init(cosyvoice_config)
-        self.sovits_model_init(sovits_config)
-        self.audio_dir = "audio_files"  # directory where audio files are stored
+
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -167,19 +106,14 @@ class TTS(Blackbox):
             settings = {}
         user_model_name = settings.get("tts_model_name")
         chroma_collection_id = settings.get("chroma_collection_id")
-        user_stream = settings.get('tts_stream')
-
         print(f"chroma_collection_id: {chroma_collection_id}")
         print(f"tts_model_name: {user_model_name}")
-        if user_stream in [None, ""]:
-            user_stream = True
-
         text = args[0]
         current_time = time.time()
         if user_model_name == 'melotts':
             if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                 if self.melo_mode == 'local':
-                    audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed)
+                    audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
                     f = io.BytesIO()
                     sf.write(f, audio, 44100, format='wav')
                     f.seek(0)
@@ -213,7 +147,7 @@
             elif chroma_collection_id == 'boss':
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(35616313)
-                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -232,8 +166,7 @@
             if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(56056558)
-                    print("*"*90)
-                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
+                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -250,7 +183,7 @@
             elif chroma_collection_id == 'boss':
                 if self.cosyvoice_mode == 'local':
                     set_all_random_seed(35616313)
-                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
                     for i, j in enumerate(audio):
                         f = io.BytesIO()
                         sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -263,65 +196,11 @@
                     }
                     response = requests.post(self.cosyvoice_url, json=message)
                     print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-                    return response.content
-
-        elif user_model_name == 'sovitstts':
-            if chroma_collection_id == 'kiki' or chroma_collection_id is None:
-                if self.sovits_mode == 'local':
-                    set_all_random_seed(56056558)
-                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
-                    for i, j in enumerate(audio):
-                        f = io.BytesIO()
-                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-                        f.seek(0)
-                    print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-                    return f.read()
-                else:
-                    message = {
-                        "text": text,
-                        "text_lang": self.sovits_text_lang,
-                        "ref_audio_path": self.sovits_ref_audio_path,
-                        "prompt_lang": self.sovits_prompt_lang,
-                        "prompt_text": self.sovits_prompt_text,
-                        "text_split_method": self.sovits_text_split_method,
-                        "batch_size": self.sovits_batch_size,
-                        "media_type": self.sovits_media_type,
-                        "streaming_mode": self.sovits_streaming_mode
-                    }
-                    if user_stream:
-                        response = requests.get(self.sovits_url, params=message, stream=True)
-                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-                        return response
-                    else:
-                        response = requests.get(self.sovits_url, params=message)
-                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-                        return response.content
-                    print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-
-
-            # elif chroma_collection_id == 'boss':
-            #     if self.cosyvoice_mode == 'local':
-            #         set_all_random_seed(35616313)
-            #         audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
-            #         for i, j in enumerate(audio):
-            #             f = io.BytesIO()
-            #             sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-            #             f.seek(0)
-            #         print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-            #         return f.read()
-            #     else:
-            #         message = {
-            #             "text": text
-            #         }
-            #         response = requests.post(self.cosyvoice_url, json=message)
-            #         print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-            #         return response.content
-
-
+                    return response.content
         elif user_model_name == 'man':
             if self.cosyvoice_mode == 'local':
                 set_all_random_seed(35616313)
-                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
+                audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5)
                 for i, j in enumerate(audio):
                     f = io.BytesIO()
                     sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -353,73 +232,7 @@
             return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
         text = data.get("text")
         setting = data.get("settings")
-        tts_model_name = setting.get("tts_model_name")
-        user_stream = setting.get("tts_stream")
-
         if text is None:
             return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
         by = self.processing(text, settings=setting)
-        # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
-
-        if user_stream:
-            if by.status_code == 200:
-                print("*"*90)
-                def audio_stream():
-                    # stream data from the upstream server and send it chunk by chunk
-                    for chunk in by.iter_content(chunk_size=1024):
-                        if chunk:
-                            yield chunk
-
-                return StreamingResponse(audio_stream(), media_type="audio/wav")
-            else:
-                raise HTTPException(status_code=by.status_code, detail="请求失败")
-
-        # if user_stream and tts_model_name == 'sovitstts':
-        #     if by.status_code == 200:
-        #         # save the WAV file
-        #         wav_filename = os.path.join(self.audio_dir, 'audio.wav')
-        #         with open(wav_filename, 'wb') as f:
-        #             for chunk in by.iter_content(chunk_size=1024):
-        #                 if chunk:
-        #                     f.write(chunk)
-
-        #         try:
-        #             # check whether the file exists
-        #             if not os.path.exists(wav_filename):
-        #                 return JSONResponse(content={"error": "File not found"}, status_code=404)
-
-        #             # convert WAV to MP3
-        #             mp3_filename = convert_wav_to_mp3(wav_filename)
-
-        #             # return download links for the WAV and MP3 files
-        #             wav_url = f"/audio/{wav_filename}"
-        #             mp3_url = f"/audio/{mp3_filename}"
-        #             return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
-
-        #         except Exception as e:
-        #             return JSONResponse(content={"error": str(e)}, status_code=500)
-
-        #     else:
-        #         raise HTTPException(status_code=by.status_code, detail="请求失败")
-
-
-        else:
-            wav_filename = os.path.join(self.audio_dir, 'audio.wav')
-            with open(wav_filename, 'wb') as f:
-                f.write(by)
-
-            try:
-                # first check whether the file exists
-                if not os.path.exists(wav_filename):
-                    return JSONResponse(content={"error": "File not found"}, status_code=404)
-
-                # convert WAV to MP3
-                mp3_filename = convert_wav_to_mp3(wav_filename)
-
-                # return the download link for the MP3 file
-                wav_url = f"/audio/{wav_filename}"
-                mp3_url = f"/audio/{mp3_filename}"
-                return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200)
-
-            except Exception as e:
-                return JSONResponse(content={"error": str(e)}, status_code=500)
\ No newline at end of file
+        return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
\ No newline at end of file
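With the SoVITS and WAV-to-MP3 paths reverted, tts.py's handler once again returns the synthesized audio directly as an audio/wav attachment. A client-side sketch — host and port are illustrative; the text/settings body matches what fast_api_handler reads:

```python
import requests

resp = requests.post(
    "http://localhost:8000/?blackbox_name=tts",  # illustrative host and port
    json={"text": "你好", "settings": {"tts_model_name": "melotts"}},
)
resp.raise_for_status()
with open("reply.wav", "wb") as f:
    f.write(resp.content)  # served with Content-Disposition: attachment; filename=audio.wav
```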
diff --git a/src/configuration.py b/src/configuration.py
index bac66ea..0c6e0cb 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -82,32 +82,6 @@ class CosyVoiceConf():
         self.language = config.get("cosyvoicetts.language")
         self.speaker = config.get("cosyvoicetts.speaker")
 
-class SovitsConf():
-    mode: str
-    url: str
-    speed: int
-    device: str
-    language: str
-    speaker: str
-
-    @inject
-    def __init__(self, config: Configuration) -> None:
-        self.mode = config.get("sovitstts.mode")
-        self.url = config.get("sovitstts.url")
-        self.speed = config.get("sovitstts.speed")
-        self.device = config.get("sovitstts.device")
-        self.language = config.get("sovitstts.language")
-        self.speaker = config.get("sovitstts.speaker")
-
-        self.text_lang = config.get("sovitstts.text_lang")
-        self.ref_audio_path = config.get("sovitstts.ref_audio_path")
-        self.prompt_lang = config.get("sovitstts.prompt_lang")
-        self.prompt_text = config.get("sovitstts.prompt_text")
-        self.text_split_method = config.get("sovitstts.text_split_method")
-        self.batch_size = config.get("sovitstts.batch_size")
-        self.media_type = config.get("sovitstts.media_type")
-        self.streaming_mode = config.get("sovitstts.streaming_mode")
-
 class SenseVoiceConf():
     mode: str
     url: str
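Removing SovitsConf means no code reads the sovitstts.* section of the configuration any more; the surviving CosyVoice block keeps the same key layout as the deleted class. Values below are placeholders, not the project's real configuration:

```python
# Keys CosyVoiceConf still reads, by analogy with the deleted SovitsConf:
cosyvoice_section = {
    "cosyvoicetts.mode": "local",    # 'local' loads the CosyVoice-300M checkpoint in-process
    "cosyvoicetts.url": "",          # only consulted when mode is not 'local'
    "cosyvoicetts.speed": 1,
    "cosyvoicetts.device": "cuda:0",
    "cosyvoicetts.language": "中文女",
    "cosyvoicetts.speaker": "default",
}
```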