diff --git a/sample/chroma_rerank.py b/sample/chroma_rerank.py
index 09be954..af36c37 100644
--- a/sample/chroma_rerank.py
+++ b/sample/chroma_rerank.py
@@ -7,20 +7,23 @@
 from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 import time
+from pathlib import Path
+
+path = Path("/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI")
 
 # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
+loader = TextLoader("./RAG_boss.txt")
 documents = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
 docs = text_splitter.split_documents(documents)
 print("len(docs)", len(docs))
 ids = ["粤语语料"+str(i) for i in range(len(docs))]
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name=str(path / "bge-m3"), model_kwargs={"device": "cuda:0"})
+client = chromadb.HttpClient(host="localhost", port=7000)
 
-id = "kiki"
+id = "boss2"
 # client.delete_collection(id)
 # Insert the vectors (if the ids already exist, their vectors are updated)
 db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
@@ -28,13 +31,13 @@ db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, c
 
 
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(path / "bge-m3"), device = "cuda:0")
 
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+client = chromadb.HttpClient(host='localhost', port=7000)
 collection = client.get_collection(id, embedding_function=embedding_model)
 
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
+reranker_model = CrossEncoder(str(path / "bge-reranker-v2-m3"), max_length=512, device = "cuda:0")
 
 # while True:
 #     usr_question = input("\n 请输入问题: ")
 
 
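Reviewer note: this sample script now pins the BAAI model root to a machine-specific mount (`/media/verachen/...`), which will break on any other checkout. A minimal sketch of a more portable default, assuming a hypothetical MODEL_ROOT environment variable that is not part of this change:

    import os
    from pathlib import Path

    # Fall back to the hard-coded mount only when MODEL_ROOT is not set.
    model_root = os.environ.get(
        "MODEL_ROOT",
        "/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model",
    )
    path = Path(model_root) / "BAAI"
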
vad_model="fsmn-vad", vad_kwargs={"max_single_segment_time": 30000}, device=self.device, @@ -76,8 +80,8 @@ class ASR(Blackbox): logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...') @inject - def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None: - self.model_init(sensevoice_config) + def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict, path: PathConf) -> None: + self.model_init(sensevoice_config, path) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index 4a84761..e62beb5 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -64,7 +64,7 @@ class Chat(Blackbox): user_stream = settings.get('stream') user_websearch = settings.get('websearch') - llm_model = "vllm" + llm_model = "llm" if user_context == None: user_context = [] @@ -106,9 +106,9 @@ class Chat(Blackbox): if user_model_url is None or user_model_url.isspace() or user_model_url == "": if llm_model != "vllm": - user_model_url = "http://10.6.80.75:23333/v1/chat/completions" + user_model_url = "http://localhost:23333/v1/chat/completions" else: - user_model_url = "http://10.6.80.94:8000/v1/completions" + user_model_url = "http://localhost:8000/v1/completions" if user_model_key is None or user_model_key.isspace() or user_model_key == "": if llm_model != "vllm": diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index dff10ce..f501283 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -12,22 +12,29 @@ from ..log.logging_time import logging_time import re from sentence_transformers import CrossEncoder +from pathlib import Path +from ..configuration import Configuration +from ..configuration import PathConf + logger = logging.getLogger DEFAULT_COLLECTION_ID = "123" from injector import singleton @singleton class ChromaQuery(Blackbox): - def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load chromadb and embedding model - self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0") - self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0") - self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0") - self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000) + + path = PathConf(Configuration()) + self.model_path = Path(path.chroma_rerank_embedding_model) + + self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-large-zh-v1.5"), device = "cuda:0") + self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-small-en-v1.5"), device = "cuda:0") + self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-m3"), device = "cuda:0") + self.client_1 = chromadb.HttpClient(host='localhost', port=7000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) - self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda") + self.reranker_model_1 = CrossEncoder(str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device = "cuda") def __call__(self, 
diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index dff10ce..f501283 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -12,22 +12,29 @@
 from ..log.logging_time import logging_time
 import re
 from sentence_transformers import CrossEncoder
+from pathlib import Path
+from ..configuration import Configuration
+from ..configuration import PathConf
+
 logger = logging.getLogger
 DEFAULT_COLLECTION_ID = "123"
 from injector import singleton
 
 @singleton
 class ChromaQuery(Blackbox):
-
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
-        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
-        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+
+        path = PathConf(Configuration())
+        self.model_path = Path(path.chroma_rerank_embedding_model)
+
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-large-zh-v1.5"), device = "cuda:0")
+        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-small-en-v1.5"), device = "cuda:0")
+        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-m3"), device = "cuda:0")
+        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
-        self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
+        self.reranker_model_1 = CrossEncoder(str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device = "cuda")
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -57,10 +64,10 @@ class ChromaQuery(Blackbox):
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
+            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
 
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "10.6.44.141"
+            chroma_host = "localhost"
 
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
             chroma_port = "7000"
@@ -72,7 +79,7 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 10
 
         # load client and embedding model from init
-        if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
+        if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             try:
@@ -80,11 +87,11 @@ class ChromaQuery(Blackbox):
             except:
                 return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)
 
-        if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
+        if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
             embedding_model = self.embedding_model_1
-        elif re.search(r"/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
+        elif re.search(str(self.model_path / "bge-small-en-v1.5"), chroma_embedding_model):
             embedding_model = self.embedding_model_2
-        elif re.search(r"/Workspace/Models/BAAI/bge-m3", chroma_embedding_model):
+        elif re.search(str(self.model_path / "bge-m3"), chroma_embedding_model):
             embedding_model = self.embedding_model_3
         else:
             try:
@@ -123,7 +130,7 @@ class ChromaQuery(Blackbox):
             final_result = str(results["documents"])
 
         if chroma_reranker_model:
-            if re.search(r"/Workspace/Models/BAAI/bge-reranker-v2-m3", chroma_reranker_model):
+            if re.search(str(self.model_path / "bge-reranker-v2-m3"), chroma_reranker_model):
                 reranker_model = self.reranker_model_1
             else:
                 try:
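Reviewer note: the new checks pass `str(self.model_path / "bge-large-zh-v1.5")` to `re.search` as the pattern, so the `.` in version suffixes matches any character, and a path containing regex metacharacters could even fail to compile. The same applies to the check in chroma_upsert.py below. A sketch of the safer comparisons (the paths here are illustrative):

    import re
    from pathlib import Path

    model_path = Path("/models/BAAI")
    requested = "/models/BAAI/bge-large-zh-v1.5"

    # Escape the literal path before using it as a pattern...
    matches = re.search(re.escape(str(model_path / "bge-large-zh-v1.5")), requested)

    # ...or drop regex semantics entirely and compare paths directly.
    same = Path(requested) == model_path / "bge-large-zh-v1.5"
    print(bool(matches), same)
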
"/Workspace/Models/BAAI/bge-large-zh-v1.5" + chroma_embedding_model = model_name=str(self.model_path / "bge-large-zh-v1.5") if chroma_host is None or chroma_host.isspace() or chroma_host == "": - chroma_host = "10.6.82.192" + chroma_host = "localhost" if chroma_port is None or chroma_port.isspace() or chroma_port == "": - chroma_port = "8000" + chroma_port = "7000" if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "": chroma_collection_id = "g2e" # load client and embedding model from init - if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port): + if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port): client = self.client_1 else: client = chromadb.HttpClient(host=chroma_host, port=chroma_port) print(f"chroma_embedding_model: {chroma_embedding_model}") - if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model): + if re.search((self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model): embedding_model = self.embedding_model_1 else: embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda:0") diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index 3228daf..b4d5b78 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -10,14 +10,15 @@ from ..tts.tts_service import TTService from ..configuration import MeloConf from ..configuration import CosyVoiceConf from ..configuration import SovitsConf +from ..configuration import PathConf + from injector import inject from injector import singleton import sys,os -sys.path.append('/Workspace/CosyVoice') -sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS') -from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 -from cosyvoice.utils.file_utils import load_wav#, speed_change + +from pathlib import Path + import soundfile as sf import pyloudnorm as pyln @@ -100,7 +101,18 @@ class TTS(Blackbox): print('1.#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...') @logging_time(logger=logger) - def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None: + def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf, path: PathConf) -> None: + + cosy_path = Path(path.cosyvoice_path) + cosy_model_path = Path(path.cosyvoice_model_path) + + sys.path.append(str(cosy_path)) + sys.path.append(str(cosy_path / "third_party/Matcha-TTS")) + + from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 + from cosyvoice.utils.file_utils import load_wav#, speed_change + + self.cosyvoice_speed = cosyvoice_config.speed self.cosyvoice_device = cosyvoice_config.device self.cosyvoice_language = cosyvoice_config.language @@ -113,8 +125,8 @@ class TTS(Blackbox): if self.cosyvoice_mode == 'local': # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M') # self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M') - self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False) - self.prompt_speech_16k = load_wav('/Workspace/jarvis-models/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav', 16000) + self.cosyvoicetts = CosyVoice2(str (cosy_model_path / "CosyVoice2-0.5B"), load_jit=True, load_trt=False) + self.prompt_speech_16k = load_wav("./Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav", 16000) else: self.cosyvoice_url = cosyvoice_config.url @@ -151,10 +163,10 @@ class TTS(Blackbox): @inject - def __init__(self, melo_config: 
diff --git a/src/configuration.py b/src/configuration.py
index bac66ea..e726d69 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -180,3 +180,18 @@ class VLMConf():
     @inject
     def __init__(self, config: Configuration) -> None:
         self.urls = config.get("vlms.urls")
+
+
+@singleton
+class PathConf():
+    chroma_rerank_embedding_model: str
+    cosyvoice_path: str
+    cosyvoice_model_path: str
+    sensevoice_model_path: str
+
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.chroma_rerank_embedding_model = config.get("path.chroma_rerank_embedding_model")
+        self.cosyvoice_path = config.get("path.cosyvoice_path")
+        self.cosyvoice_model_path = config.get("path.cosyvoice_model_path")
+        self.sensevoice_model_path = config.get("path.sensevoice_model_path")
\ No newline at end of file
diff --git a/src/tts/tts_service.py b/src/tts/tts_service.py
index 39214b5..8b7c1ab 100644
--- a/src/tts/tts_service.py
+++ b/src/tts/tts_service.py
@@ -19,7 +19,7 @@
 import logging
 logging.basicConfig(level=logging.INFO)
 dirbaspath = __file__.split("\\")[1:-1]
-dirbaspath= "/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
+dirbaspath= "./src/tts" + "/".join(dirbaspath)
 config = {
     'ayaka': {
         'cfg': dirbaspath + '/models/ayaka.json',
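Reviewer note: `ChromaQuery` and `ChromaUpsert` build `PathConf(Configuration())` by hand, which sidesteps the `@singleton`/`@inject` wiring that `ASR` and `TTS` rely on and creates a second `Configuration` instance. A sketch of resolving through the container instead, assuming `Configuration` is constructible by (or bound in) the injector; the import path follows this repo's `src/` layout:

    from injector import Injector

    from src.configuration import Configuration, PathConf

    container = Injector()
    path_conf = container.get(PathConf)   # honours @singleton: one shared instance
    assert path_conf is container.get(PathConf)
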