From 37174413fe3cb0e7233b0483ba431e628cb451af Mon Sep 17 00:00:00 2001 From: verachen <511201264@qq.com> Date: Thu, 9 Jan 2025 11:29:34 +0800 Subject: [PATCH] feat: chat2tts stream --- sample/chroma_client1.py | 6 +- sample/chroma_client_en.py | 6 +- sample/chroma_client_query.py | 6 +- sample/chroma_rerank.py | 76 ++++---- server.py | 57 +++++- src/blackbox/blackbox_factory.py | 6 + src/blackbox/chatpipeline.py | 304 +++++++++++++++++++++++++++++++ src/blackbox/chroma_query.py | 6 +- src/blackbox/chroma_upsert.py | 2 +- src/blackbox/g2e.py | 2 +- src/blackbox/tts.py | 213 ++++++++++++++++++++-- src/configuration.py | 26 +++ 12 files changed, 643 insertions(+), 67 deletions(-) create mode 100644 src/blackbox/chatpipeline.py diff --git a/sample/chroma_client1.py b/sample/chroma_client1.py index edce182..4f798fe 100644 --- a/sample/chroma_client1.py +++ b/sample/chroma_client1.py @@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))] # 加载embedding模型和chroma server embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"}) -client = chromadb.HttpClient(host='192.168.0.200', port=7000) +client = chromadb.HttpClient(host='10.6.44.141', port=7000) id = "g2e" #client.delete_collection(id) @@ -107,7 +107,7 @@ print("collection_number",collection_number) # # chroma 召回 # from chromadb.utils import embedding_functions # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda") -# client = chromadb.HttpClient(host='192.168.0.200', port=7000) +# client = chromadb.HttpClient(host='10.6.44.141', port=7000) # collection = client.get_collection("g2e", embedding_function=embedding_model) # print(collection.count()) @@ -152,7 +152,7 @@ print("collection_number",collection_number) # 'Content-Type': 'application/json', # 'Authorization': "Bearer " + key # } -# url = "http://192.168.0.200:23333/v1/chat/completions" +# url = "http://10.6.44.141:23333/v1/chat/completions" # fastchat_response = requests.post(url, json=chat_inputs, headers=header) # # print(fastchat_response.json()) diff --git a/sample/chroma_client_en.py b/sample/chroma_client_en.py index d3e3de9..8ca506c 100644 --- a/sample/chroma_client_en.py +++ b/sample/chroma_client_en.py @@ -85,7 +85,7 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))] # 加载embedding模型和chroma server embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"}) -client = chromadb.HttpClient(host='192.168.0.200', port=7000) +client = chromadb.HttpClient(host='10.6.44.141', port=7000) id = "g2e_english" client.delete_collection(id) @@ -107,7 +107,7 @@ print("collection_number",collection_number) # # chroma 召回 # from chromadb.utils import embedding_functions # embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda") -# client = chromadb.HttpClient(host='192.168.0.200', port=7000) +# client = chromadb.HttpClient(host='10.6.44.141', port=7000) # collection = client.get_collection("g2e", embedding_function=embedding_model) # print(collection.count()) @@ -152,7 +152,7 @@ print("collection_number",collection_number) # 'Content-Type': 'application/json', # 'Authorization': "Bearer " + key # } -# url = "http://192.168.0.200:23333/v1/chat/completions" +# url = "http://10.6.44.141:23333/v1/chat/completions" # fastchat_response = requests.post(url, json=chat_inputs, headers=header) # # print(fastchat_response.json()) diff --git a/sample/chroma_client_query.py b/sample/chroma_client_query.py index cafbfcc..bcc3331 100644 --- a/sample/chroma_client_query.py +++ b/sample/chroma_client_query.py @@ -81,7 +81,7 @@ def get_all_files(folder_path): # # 加载embedding模型和chroma server # embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"}) -# client = chromadb.HttpClient(host='192.168.0.200', port=7000) +# client = chromadb.HttpClient(host='10.6.44.141', port=7000) # id = "g2e" # client.delete_collection(id) @@ -103,7 +103,7 @@ def get_all_files(folder_path): # chroma 召回 from chromadb.utils import embedding_functions embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda") -client = chromadb.HttpClient(host='192.168.0.200', port=7000) +client = chromadb.HttpClient(host='10.6.44.141', port=7000) collection = client.get_collection("g2e", embedding_function=embedding_model) print(collection.count()) @@ -148,7 +148,7 @@ print("time: ", time.time() - start_time) # 'Content-Type': 'application/json', # 'Authorization': "Bearer " + key # } -# url = "http://192.168.0.200:23333/v1/chat/completions" +# url = "http://10.6.44.141:23333/v1/chat/completions" # fastchat_response = requests.post(url, json=chat_inputs, headers=header) # # print(fastchat_response.json()) diff --git a/sample/chroma_rerank.py b/sample/chroma_rerank.py index 3093c85..09be954 100644 --- a/sample/chroma_rerank.py +++ b/sample/chroma_rerank.py @@ -10,64 +10,64 @@ import time # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8") -loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt") +loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt") documents = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n']) docs = text_splitter.split_documents(documents) print("len(docs)", len(docs)) ids = ["粤语语料"+str(i) for i in range(len(docs))] -embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"}) -client = chromadb.HttpClient(host='192.168.0.200', port=7000) +embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"}) +client = chromadb.HttpClient(host='10.6.44.141', port=7000) -id = "boss" -client.delete_collection(id) +id = "kiki" +# client.delete_collection(id) # 插入向量(如果ids已存在,则会更新向量) db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client) -embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1") +embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0") -client = chromadb.HttpClient(host='192.168.0.200', port=7000) +client = chromadb.HttpClient(host='10.6.44.141', port=7000) collection = client.get_collection(id, embedding_function=embedding_model) -reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1") +reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0") -while True: - usr_question = input("\n 请输入问题: ") - # query it - time1 = time.time() - results = collection.query( - query_texts=[usr_question], - n_results=10, - ) - time2 = time.time() - print("query time: ", time2 - time1) +# while True: +# usr_question = input("\n 请输入问题: ") +# # query it +# time1 = time.time() +# results = collection.query( +# query_texts=[usr_question], +# n_results=10, +# ) +# time2 = time.time() +# print("query time: ", time2 - time1) - # print("query: ",usr_question) - # print("results: ",print(results["documents"][0])) +# # print("query: ",usr_question) +# # print("results: ",print(results["documents"][0])) - pairs = [[usr_question, doc] for doc in results["documents"][0]] - # print('\n',pairs) - scores = reranker_model.predict(pairs) +# pairs = [[usr_question, doc] for doc in results["documents"][0]] +# # print('\n',pairs) +# scores = reranker_model.predict(pairs) - #重新排列文件顺序: - print("New Ordering:") - i = 0 - final_result = '' - for o in np.argsort(scores)[::-1]: - if i == 3 or scores[o] < 0.5: - break - i += 1 - print(o+1) - print("Scores:", scores[o]) - print(results["documents"][0][o],'\n') - final_result += results["documents"][0][o] + '\n' +# #重新排列文件顺序: +# print("New Ordering:") +# i = 0 +# final_result = '' +# for o in np.argsort(scores)[::-1]: +# if i == 3 or scores[o] < 0.5: +# break +# i += 1 +# print(o+1) +# print("Scores:", scores[o]) +# print(results["documents"][0][o],'\n') +# final_result += results["documents"][0][o] + '\n' - print("\n final_result: ", final_result) - time3 = time.time() - print("rerank time: ", time3 - time2) \ No newline at end of file +# print("\n final_result: ", final_result) +# time3 = time.time() +# print("rerank time: ", time3 - time2) \ No newline at end of file diff --git a/server.py b/server.py index d32b58d..60cce03 100644 --- a/server.py +++ b/server.py @@ -1,7 +1,7 @@ from typing import Union from fastapi import FastAPI, Request, status -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response, StreamingResponse from src.blackbox.blackbox_factory import BlackboxFactory from fastapi.middleware.cors import CORSMiddleware @@ -21,6 +21,7 @@ injector = Injector() blackbox_factory = injector.get(BlackboxFactory) @app.post("/") +@app.get("/") async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): if not blackbox_name: return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -29,3 +30,57 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No except ValueError: return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await box.fast_api_handler(request) + +# @app.get("/audio/{filename}") +# async def serve_audio(filename: str): +# import os +# # 确保文件存在 +# if os.path.exists(filename): +# with open(filename, 'rb') as f: +# audio_data = f.read() +# if filename.endswith(".mp3"): +# # 对于 MP3 文件,设置为 inline 以在浏览器中播放 +# return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"}) +# elif filename.endswith(".wav"): +# # 对于 WAV 文件,设置为 inline 以在浏览器中播放 +# return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"}) +# else: +# return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND) + +@app.get("/audio/{filename}") +async def serve_audio(filename: str): + import os + import aiofiles + # 确保文件存在 + if os.path.exists(filename): + try: + # 使用 aiofiles 异步读取文件 + async with aiofiles.open(filename, 'rb') as f: + audio_data = await f.read() + + # 根据文件类型返回正确的音频响应 + if filename.endswith(".mp3"): + return Response(content=audio_data, media_type="audio/mpeg", headers={"Content-Disposition": f"inline; filename={filename}"}) + elif filename.endswith(".wav"): + return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": f"inline; filename={filename}"}) + else: + return JSONResponse(content={"error": "Unsupported audio format"}, status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE) + except asyncio.CancelledError: + # 处理任务被取消的情况 + return JSONResponse(content={"error": "Request was cancelled"}, status_code=status.HTTP_400_BAD_REQUEST) + except Exception as e: + # 捕获其他异常 + return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + else: + return JSONResponse(content={"error": f"{filename} not found"}, status_code=status.HTTP_404_NOT_FOUND) + + +# @app.get("/audio/{filename}") +# async def serve_audio(filename: str): +# import os +# file_path = f"audio/{filename}" +# # 检查文件是否存在 +# if os.path.exists(file_path): +# return StreamingResponse(open(file_path, "rb"), media_type="audio/mpeg" if filename.endswith(".mp3") else "audio/wav") +# else: +# return JSONResponse(content={"error": f"{filename} not found"}, status_code=404) \ No newline at end of file diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 0bdf5a9..d7ba14e 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -62,6 +62,11 @@ def tts_loader(): from .tts import TTS return Injector().get(TTS) +@model_loader(lazy=blackboxConf.lazyloading) +def chatpipeline_loader(): + from .chatpipeline import ChatPipeline + return Injector().get(ChatPipeline) + @model_loader(lazy=blackboxConf.lazyloading) def emotion_loader(): from .emotion import Emotion @@ -132,6 +137,7 @@ class BlackboxFactory: self.models["audio_to_text"] = audio_to_text_loader self.models["asr"] = asr_loader self.models["tts"] = tts_loader + self.models["chatpipeline"] = chatpipeline_loader self.models["sentiment_engine"] = sentiment_loader self.models["emotion"] = emotion_loader self.models["fastchat"] = fastchat_loader diff --git a/src/blackbox/chatpipeline.py b/src/blackbox/chatpipeline.py new file mode 100644 index 0000000..55f6d83 --- /dev/null +++ b/src/blackbox/chatpipeline.py @@ -0,0 +1,304 @@ +import re +import requests +import json +import queue +import threading +from fastapi import FastAPI +from fastapi.responses import StreamingResponse + +import os +import io +import time +import requests +from fastapi import Request, Response, status, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from .blackbox import Blackbox + +from injector import inject +from injector import singleton + +from ..log.logging_time import logging_time +import logging +logger = logging.getLogger(__name__) + +import asyncio +from uuid import uuid4 # 用于生成唯一的会话 ID + +@singleton +class ChatPipeline(Blackbox): + + @inject + def __init__(self, ) -> None: + self.text_queue = queue.Queue() # 文本队列 + self.audio_queue = queue.Queue() # 音频队列 + self.PUNCTUATION = r'[。!?、,.,?]' # 标点符号 + self.tts_event = threading.Event() # TTS 事件 + self.audio_part_counter = 0 # 音频段计数器 + self.text_part_counter = 0 + self.audio_dir = "audio_files" # 存储音频文件的目录 + + if not os.path.exists(self.audio_dir): + os.makedirs(self.audio_dir) + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, data: any) -> bool: + if isinstance(data, bytes): + return True + return False + + def reset_queues(self): + """清空之前的队列数据""" + self.text_queue.queue.clear() # 清空文本队列 + self.audio_queue.queue.clear() # 清空音频队列 + + def clear_audio_files(self): + """清空音频文件夹中的所有音频文件""" + for file_name in os.listdir(self.audio_dir): + file_path = os.path.join(self.audio_dir, file_name) + if os.path.isfile(file_path): + os.remove(file_path) + print(f"Removed old audio file: {file_path}") + + def save_audio(self, audio_data, part_number: int): + """保存音频数据为文件""" + file_name = os.path.join(self.audio_dir, f"audio_part_{part_number}.wav") + with open(file_name, 'wb') as f: + f.write(audio_data) + return file_name + + def chat_stream(self, prompt: str): + """从 chat.py 获取实时生成的文本,并放入队列""" + url = 'http://10.6.44.141:8000/?blackbox_name=chat' + headers = {'Content-Type': 'text/plain'} + data = { + "prompt": prompt, + "context": [], + "settings": { + "stream": True + } + } + + # 每次执行时清空原有音频文件 + self.clear_audio_files() + self.audio_part_counter = 0 + self.text_part_counter = 0 + + with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response: + complete_message = "" # 用于累积完整的文本 + last_message = False # 标记是否是最后一句 + for line in response.iter_lines(): + if line: + message = line.decode('utf-8') + + if message.strip().lower() == "data:": + continue # 跳过"data:"行 + + complete_message += message + + # 如果包含标点符号,拆分成句子 + if re.search(self.PUNCTUATION, complete_message): + sentences = re.split(self.PUNCTUATION, complete_message) + for sentence in sentences[:-1]: + cleaned_sentence = self.filter_invalid_chars(sentence.strip()) + if cleaned_sentence: + print(f"Sending complete sentence: {cleaned_sentence}") + self.text_queue.put({ + 'text': cleaned_sentence, + 'is_last': False # 当前不是最后一句 + }) # 放入文本队列 + complete_message = sentences[-1] + self.text_part_counter += 1 + print(f"***text_part_counter: {self.text_part_counter}") + # 判断是否是最后一句 + print(f'1.last_message: {last_message}') + if not response.iter_lines(): + last_message = True + print(f'2.last_message: {last_message}') + + time.sleep(0.2) + + # 放入最后一句消息 + print(f'---a---') + if complete_message.strip(): + print('---b---') + cleaned_sentence = self.filter_invalid_chars(complete_message.strip()) + if cleaned_sentence: + print(f"Sending last complete sentence: {cleaned_sentence}") + self.text_queue.put({ + 'text': cleaned_sentence, + 'is_last': True # 最后一条消息 + }) + else: + self.text_queue.put({ + 'text': "结束", + 'is_last': True # 最后一条消息 + }) + def send_to_tts(self, settings: dict): + """从队列中获取文本并发送给 tts.py 进行语音合成""" + url = 'http://10.6.44.141:8000/?blackbox_name=tts' + headers = {'Content-Type': 'text/plain'} + user_stream = settings.get("tts_stream") + while True: + try: + # 获取队列中的一个完整句子 + item = self.text_queue.get(timeout=5) + + if item is None: + break + + text = item['text'] + is_last = item['is_last'] # 判断是否是最后一句 + + if not text.strip(): + continue + + text = self.filter_invalid_chars(text) + + data = { + "settings": settings, + "text": text + } + + # if user_stream: + # # 发送请求到 TTS 服务 + # response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) + + # if response.status_code == 200: + # audio_data = response.content + # if isinstance(audio_data, bytes): + # self.audio_part_counter += 1 # 增加音频段计数器 + # file_name = self.save_audio(audio_data, self.audio_part_counter) # 保存为文件 + # print(f"Audio part saved as {file_name}") + + # # 将文件名和是否是最后一条消息放入音频队列 + # self.audio_queue.put({ + # 'file_name': file_name, + # 'is_last': is_last + # }) # 放入音频队列 + + # # 如果是最后一句,执行额外的处理 + # if is_last: + # print("This is the last sentence in the stream.") + # # 你可以在这里添加处理最后一句的额外逻辑,例如: + # # - 通知系统音频合成已完成 + # # - 结束音频流等 + + # else: + # print(f"Error: Received non-binary data.") + # else: + # print(f"Failed to send to TTS: {response.status_code}, Text: {text}") + + # else: + # response = requests.post(url, headers=headers, data=json.dumps(data)) + # if response.status_code == 200: + # print("1"*90) + # audio_data = response.content + # print(audio_data) + response = requests.post(url, headers=headers, data=json.dumps(data)) + if response.status_code == 200: + print("1"*90) + audio_data = response.content + print(audio_data) + + # 通知下一个 TTS 可以执行了 + self.tts_event.set() # 如果是 threading.Event(),就通知等待的线程 + time.sleep(0.2) + + except queue.Empty: + time.sleep(1) + + def filter_invalid_chars(self,text): + """过滤无效字符(包括字节流)""" + invalid_keywords = ["data:", "\n", "\r", "\t", " "] + + if isinstance(text, bytes): + text = text.decode('utf-8', errors='ignore') + + for keyword in invalid_keywords: + text = text.replace(keyword, "") + + # 移除所有英文字母和符号(保留中文、标点等) + text = re.sub(r'[a-zA-Z]', '', text) + + return text.strip() + + + @logging_time(logger=logger) + def processing(self, text: str, settings: dict) ->str:#-> io.BytesIO: + + # 启动聊天流线程 + threading.Thread(target=self.chat_stream, args=(text,), daemon=True).start() + + # 启动 TTS 线程并保证它在执行下一个 TTS 前完成当前任务 + threading.Thread(target=self.send_to_tts, args=(settings,), daemon=True).start() + + return {"message": "Chat and TTS processing started"} + + # 新增异步方法,等待音频片段生成 + async def wait_for_audio(self): + while self.audio_queue.empty(): + await asyncio.sleep(0.2) # 使用异步 sleep,避免阻塞事件循环 + + async def fast_api_handler(self, request: Request) -> Response: + try: + data = await request.json() + text = data.get("text") + setting = data.get("settings") + user_stream = setting.get("tts_stream") + is_last = False + self.audio_part_counter = 0 # 音频段计数器 + self.text_part_counter = 0 + self.reset_queues() + + if text is None: + return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + # 调用 processing 方法,并传递动态的 text 参数 + response_data = self.processing(text, settings=setting) + print('---1---') + # 根据是否启用流式传输进行处理 + if user_stream: + print('---2---') + # 等待至少一个音频片段生成完成 + await self.wait_for_audio() + print('---3---') + def audio_stream(): + print('---4---') + is_last = False + # 从上游服务器流式读取数据并逐块发送 + print(f'111self.audio_part_counter: {self.audio_part_counter}') + print(f'111self.text_part_counter: {self.text_part_counter}') + print(f"111.self.audio_queue.qsize: {self.audio_queue.qsize()}") + print(f'111is_last: {is_last}') + while self.audio_part_counter != 0 and not is_last: + print('---5---') + print(f'1.is_last: {is_last}') + print(f"222.self.audio_queue.qsize: {self.audio_queue.qsize()}") + audio = self.audio_queue.get() + audio_file = audio['file_name'] + is_last = audio['is_last'] + print(f'2.is_last: {is_last}') + if audio_file: + print('---6---') + with open(audio_file, "rb") as f: + print('---7---') + print(f"Sending audio file: {audio_file}") + yield f.read() # 分段发送音频文件内容 + print('---8---') + print('---9---') + return StreamingResponse(audio_stream(), media_type="audio/wav") + + else: + print('---10---') + # 如果没有启用流式传输,可以返回一个完整的响应或音频文件 + audio_files = [] + while not self.audio_queue.empty(): + audio_files.append(self.audio_queue.get()) # 获取生成的音频文件名 + + # 返回多个音频文件 + return JSONResponse(content={"audio_files": audio_files}) + + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) \ No newline at end of file diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index ba4767a..dff10ce 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -25,7 +25,7 @@ class ChromaQuery(Blackbox): self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0") self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0") self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0") - self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000) + self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda") @@ -60,7 +60,7 @@ class ChromaQuery(Blackbox): chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5" if chroma_host is None or chroma_host.isspace() or chroma_host == "": - chroma_host = "192.168.0.200" + chroma_host = "10.6.44.141" if chroma_port is None or chroma_port.isspace() or chroma_port == "": chroma_port = "7000" @@ -72,7 +72,7 @@ class ChromaQuery(Blackbox): chroma_n_results = 10 # load client and embedding model from init - if re.search(r"192.168.0.200", chroma_host) and re.search(r"7000", chroma_port): + if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port): client = self.client_1 else: try: diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py index c07ba7a..f266e0a 100755 --- a/src/blackbox/chroma_upsert.py +++ b/src/blackbox/chroma_upsert.py @@ -33,7 +33,7 @@ class ChromaUpsert(Blackbox): # load embedding model self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"}) # load chroma db - self.client_1 = chromadb.HttpClient(host='192.168.0.200', port=7000) + self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/src/blackbox/g2e.py b/src/blackbox/g2e.py index d43bcc0..b4ec980 100755 --- a/src/blackbox/g2e.py +++ b/src/blackbox/g2e.py @@ -23,7 +23,7 @@ class G2E(Blackbox): if context == None: context = [] #url = 'http://120.196.116.194:48890/v1' - url = 'http://192.168.0.200:23333/v1' + url = 'http://10.6.44.141:23333/v1' background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众 年龄:20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯 diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index 247fd02..ac604f9 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -3,18 +3,20 @@ import time from ntpath import join import requests -from fastapi import Request, Response, status -from fastapi.responses import JSONResponse +from fastapi import Request, Response, status, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse from .blackbox import Blackbox from ..tts.tts_service import TTService from ..configuration import MeloConf from ..configuration import CosyVoiceConf +from ..configuration import SovitsConf from injector import inject from injector import singleton import sys,os sys.path.append('/Workspace/CosyVoice') -from cosyvoice.cli.cosyvoice import CosyVoice +sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS') +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 # from cosyvoice.utils.file_utils import load_wav, speed_change import soundfile as sf @@ -29,13 +31,38 @@ import random import torch import numpy as np +from pydub import AudioSegment +import subprocess + def set_all_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) +def convert_wav_to_mp3(wav_filename: str): + # 检查文件是否是 .wav 格式 + if not wav_filename.lower().endswith(".wav"): + raise ValueError("The input file must be a .wav file") + # 构建对应的 .mp3 文件路径 + mp3_filename = wav_filename.replace(".wav", ".mp3") + + # 如果原 MP3 文件已存在,先删除 + if os.path.exists(mp3_filename): + os.remove(mp3_filename) + + # 使用 FFmpeg 进行转换,直接覆盖原有的 MP3 文件 + command = ['ffmpeg', '-i', wav_filename, '-y', mp3_filename] # `-y` 参数会自动覆盖目标文件 + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + # 检查转换是否成功 + if result.returncode != 0: + raise Exception(f"Error converting {wav_filename} to MP3: {result.stderr.decode()}") + + return mp3_filename + + @singleton class TTS(Blackbox): melo_mode: str @@ -61,10 +88,11 @@ class TTS(Blackbox): self.melo_url = '' self.melo_mode = melo_config.mode self.melotts = None - self.speaker_ids = None + # self.speaker_ids = None + self.melo_speaker = None if self.melo_mode == 'local': self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device) - self.speaker_ids = self.melotts.hps.data.spk2id + self.melo_speaker = self.melotts.hps.data.spk2id[self.melo_language] else: self.melo_url = melo_config.url logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...') @@ -81,18 +109,51 @@ class TTS(Blackbox): self.cosyvoicetts = None # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device) if self.cosyvoice_mode == 'local': - self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M') + # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M') + self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M') + # self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False) + else: self.cosyvoice_url = cosyvoice_config.url logging.info('#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...') print('1.#### Initializing CosyVoiceTTS Service in ' + self.cosyvoice_device + ' mode...') + @logging_time(logger=logger) + def sovits_model_init(self, sovits_config: SovitsConf) -> None: + self.sovits_speed = sovits_config.speed + self.sovits_device = sovits_config.device + self.sovits_language = sovits_config.language + self.sovits_speaker = sovits_config.speaker + self.sovits_url = sovits_config.url + self.sovits_mode = sovits_config.mode + self.sovitstts = None + self.speaker_ids = None + + self.sovits_text_lang = sovits_config.text_lang + self.sovits_ref_audio_path = sovits_config.ref_audio_path + self.sovits_prompt_lang = sovits_config.prompt_lang + self.sovits_prompt_text = sovits_config.prompt_text + self.sovits_text_split_method = sovits_config.text_split_method + self.sovits_batch_size = sovits_config.batch_size + self.sovits_media_type = sovits_config.media_type + self.sovits_streaming_mode = sovits_config.streaming_mode + + if self.sovits_mode == 'local': + # self.sovitsts = MELOTTS(language=self.melo_language, device=self.melo_device) + self.sovits_speaker = self.sovitstts.hps.data.spk2id + else: + self.sovits_url = sovits_config.url + logging.info('#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...') + print('1.#### Initializing SoVITS Service in ' + self.sovits_device + ' mode...') + + @inject - def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None: + def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None: self.tts_service = TTService("yunfeineo") self.melo_model_init(melo_config) self.cosyvoice_model_init(cosyvoice_config) + self.sovits_model_init(sovits_config) def __call__(self, *args, **kwargs): @@ -106,14 +167,19 @@ class TTS(Blackbox): settings = {} user_model_name = settings.get("tts_model_name") chroma_collection_id = settings.get("chroma_collection_id") + user_stream = settings.get('tts_stream') + print(f"chroma_collection_id: {chroma_collection_id}") print(f"tts_model_name: {user_model_name}") + if user_stream in [None, ""]: + user_stream = True + text = args[0] current_time = time.time() if user_model_name == 'melotts': if chroma_collection_id == 'kiki' or chroma_collection_id is None: if self.melo_mode == 'local': - audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed) + audio = self.melotts.tts_to_file(text, self.melo_speaker, speed=self.melo_speed) f = io.BytesIO() sf.write(f, audio, 44100, format='wav') f.seek(0) @@ -147,7 +213,7 @@ class TTS(Blackbox): elif chroma_collection_id == 'boss': if self.cosyvoice_mode == 'local': set_all_random_seed(35616313) - audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5) + audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False) for i, j in enumerate(audio): f = io.BytesIO() sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') @@ -166,7 +232,8 @@ class TTS(Blackbox): if chroma_collection_id == 'kiki' or chroma_collection_id is None: if self.cosyvoice_mode == 'local': set_all_random_seed(56056558) - audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language) + print("*"*90) + audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True) for i, j in enumerate(audio): f = io.BytesIO() sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') @@ -183,7 +250,7 @@ class TTS(Blackbox): elif chroma_collection_id == 'boss': if self.cosyvoice_mode == 'local': set_all_random_seed(35616313) - audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5) + audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False) for i, j in enumerate(audio): f = io.BytesIO() sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') @@ -196,11 +263,63 @@ class TTS(Blackbox): } response = requests.post(self.cosyvoice_url, json=message) print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time)) - return response.content + return response.content + + elif user_model_name == 'sovitstts': + if chroma_collection_id == 'kiki' or chroma_collection_id is None: + if self.sovits_mode == 'local': + set_all_random_seed(56056558) + audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True) + for i, j in enumerate(audio): + f = io.BytesIO() + sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') + f.seek(0) + print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) + return f.read() + else: + message = { + "text": text, + "text_lang": self.sovits_text_lang, + "ref_audio_path": self.sovits_ref_audio_path, + "prompt_lang": self.sovits_prompt_lang, + "prompt_text": self.sovits_prompt_text, + "text_split_method": self.sovits_text_split_method, + "batch_size": self.sovits_batch_size, + "media_type": self.sovits_media_type, + "streaming_mode": self.sovits_streaming_mode + } + if user_stream: + response = requests.get(self.sovits_url, params=message, stream=True) + return response + else: + response = requests.get(self.sovits_url, params=message) + return response.content + print("#### SoVITS Service consume - docker : ", (time.time()-current_time)) + + + # elif chroma_collection_id == 'boss': + # if self.cosyvoice_mode == 'local': + # set_all_random_seed(35616313) + # audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False) + # for i, j in enumerate(audio): + # f = io.BytesIO() + # sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') + # f.seek(0) + # print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time)) + # return f.read() + # else: + # message = { + # "text": text + # } + # response = requests.post(self.cosyvoice_url, json=message) + # print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time)) + # return response.content + + elif user_model_name == 'man': if self.cosyvoice_mode == 'local': set_all_random_seed(35616313) - audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5) + audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False) for i, j in enumerate(audio): f = io.BytesIO() sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav') @@ -232,7 +351,73 @@ class TTS(Blackbox): return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) text = data.get("text") setting = data.get("settings") + tts_model_name = setting.get("tts_model_name") + user_stream = setting.get("tts_stream") + if text is None: return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) by = self.processing(text, settings=setting) - return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"}) \ No newline at end of file + # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"}) + + if user_stream: + if by.status_code == 200: + print("*"*90) + def audio_stream(): + # 从上游服务器流式读取数据并逐块发送 + for chunk in by.iter_content(chunk_size=1024): + if chunk: + yield chunk + + return StreamingResponse(audio_stream(), media_type="audio/wav") + else: + raise HTTPException(status_code=by.status_code, detail="请求失败") + + # if user_stream and tts_model_name == 'sovitstts': + # if by.status_code == 200: + # # 保存 WAV 文件 + # wav_filename = 'audio.wav' + # with open(wav_filename, 'wb') as f: + # for chunk in by.iter_content(chunk_size=1024): + # if chunk: + # f.write(chunk) + + # try: + # # 检查文件是否存在 + # if not os.path.exists(wav_filename): + # return JSONResponse(content={"error": "File not found"}, status_code=404) + + # # 转换 WAV 为 MP3 + # mp3_filename = convert_wav_to_mp3(wav_filename) + + # # 返回 WAV 和 MP3 文件的下载链接 + # wav_url = f"/audio/{wav_filename}" + # mp3_url = f"/audio/{mp3_filename}" + # return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200) + + # except Exception as e: + # return JSONResponse(content={"error": str(e)}, status_code=500) + + # else: + # raise HTTPException(status_code=by.status_code, detail="请求失败") + + + else: + wav_filename = 'audio.wav' + with open(wav_filename, 'wb') as f: + f.write(by) + + try: + # 先检查文件是否存在 + if not os.path.exists(wav_filename): + return JSONResponse(content={"error": "File not found"}, status_code=404) + + # 转换 WAV 为 MP3 + mp3_filename = convert_wav_to_mp3(wav_filename) + + # 返回 MP3 文件的下载链接 + wav_url = f"/audio/{wav_filename}" + mp3_url = f"/audio/{mp3_filename}" + return JSONResponse(content={"wav_url": wav_url, "mp3_url": mp3_url}, status_code=200) + + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) \ No newline at end of file diff --git a/src/configuration.py b/src/configuration.py index 651cb95..f8909a7 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -82,6 +82,32 @@ class CosyVoiceConf(): self.language = config.get("cosyvoicetts.language") self.speaker = config.get("cosyvoicetts.speaker") +class SovitsConf(): + mode: str + url: str + speed: int + device: str + language: str + speaker: str + + @inject + def __init__(self, config: Configuration) -> None: + self.mode = config.get("sovitstts.mode") + self.url = config.get("sovitstts.url") + self.speed = config.get("sovitstts.speed") + self.device = config.get("sovitstts.device") + self.language = config.get("sovitstts.language") + self.speaker = config.get("sovitstts.speaker") + + self.text_lang = config.get("sovitstts.text_lang") + self.ref_audio_path = config.get("sovitstts.ref_audio_path") + self.prompt_lang = config.get("sovitstts.prompt_lang") + self.prompt_text = config.get("sovitstts.prompt_text") + self.text_split_method = config.get("sovitstts.text_split_method") + self.batch_size = config.get("sovitstts.batch_size") + self.media_type = config.get("sovitstts.media_type") + self.streaming_mode = config.get("sovitstts.streaming_mode") + class SenseVoiceConf(): mode: str url: str