style: add path to yaml

0Xiao0
2025-04-03 18:10:32 +08:00
parent 1ee838327e
commit eb9ec3c0bf
8 changed files with 98 additions and 49 deletions

View File

@@ -7,20 +7,23 @@ from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 import time
+from pathlib import Path
+path = Path("/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI")
 # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
+loader = TextLoader("./RAG_boss.txt")
 documents = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
 docs = text_splitter.split_documents(documents)
 print("len(docs)", len(docs))
 ids = ["粤语语料"+str(i) for i in range(len(docs))]
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
-id = "kiki"
+embedding_model = SentenceTransformerEmbeddings(model_name=str(path / "bge-m3"), model_kwargs={"device": "cuda:0"})
+client = chromadb.HttpClient(host="localhost", port=7000)
+id = "boss2"
 # client.delete_collection(id)
 # Insert vectors (ids that already exist will have their vectors updated)
 db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
@@ -28,13 +31,13 @@ db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, c
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(path / "bge-m3"), device = "cuda:0")
+client = chromadb.HttpClient(host='localhost', port=7000)
 collection = client.get_collection(id, embedding_function=embedding_model)
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
+reranker_model = CrossEncoder(str(path / "bge-reranker-v2-m3"), max_length=512, device = "cuda:0")
 # while True:
 #     usr_question = input("\n 请输入问题: ")

View File

@@ -12,6 +12,9 @@ from funasr.utils.postprocess_utils import rich_transcription_postprocess
 from .blackbox import Blackbox
 from injector import singleton, inject
+from pathlib import Path
+from ..configuration import PathConf
 import tempfile
 import json
@@ -44,12 +47,13 @@ class ASR(Blackbox):
     speaker: str
     @logging_time(logger=logger)
-    def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
+    def model_init(self, sensevoice_config: SenseVoiceConf, path: PathConf) -> None:
         config = read_yaml(".env.yaml")
         self.paraformer = RapidParaformer(config)
+        sense_model_path = Path(path.sensevoice_model_path)
-        model_dir = "/model/Voice/SenseVoice/SenseVoiceSmall"
+        model_dir = str(sense_model_path)
         self.speed = sensevoice_config.speed
         self.device = sensevoice_config.device
@@ -65,7 +69,7 @@ class ASR(Blackbox):
         self.asr = AutoModel(
             model=model_dir,
             trust_remote_code=True,
-            remote_code= "/Workspace/SenseVoice/model.py",
+            remote_code= "../SenseVoice/model.py",
             vad_model="fsmn-vad",
             vad_kwargs={"max_single_segment_time": 30000},
             device=self.device,
@@ -76,8 +80,8 @@ class ASR(Blackbox):
         logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...')
     @inject
-    def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None:
-        self.model_init(sensevoice_config)
+    def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict, path: PathConf) -> None:
+        self.model_init(sensevoice_config, path)
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)

View File

@@ -64,7 +64,7 @@ class Chat(Blackbox):
         user_stream = settings.get('stream')
         user_websearch = settings.get('websearch')
-        llm_model = "vllm"
+        llm_model = "llm"
         if user_context == None:
             user_context = []
@@ -106,9 +106,9 @@ class Chat(Blackbox):
         if user_model_url is None or user_model_url.isspace() or user_model_url == "":
             if llm_model != "vllm":
-                user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
+                user_model_url = "http://localhost:23333/v1/chat/completions"
             else:
-                user_model_url = "http://10.6.80.94:8000/v1/completions"
+                user_model_url = "http://localhost:8000/v1/completions"
         if user_model_key is None or user_model_key.isspace() or user_model_key == "":
             if llm_model != "vllm":

View File

@@ -12,22 +12,29 @@ from ..log.logging_time import logging_time
 import re
 from sentence_transformers import CrossEncoder
+from pathlib import Path
+from ..configuration import Configuration
+from ..configuration import PathConf
 logger = logging.getLogger
 DEFAULT_COLLECTION_ID = "123"
 from injector import singleton
 @singleton
 class ChromaQuery(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
-        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
-        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+        path = PathConf(Configuration())
+        self.model_path = Path(path.chroma_rerank_embedding_model)
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-large-zh-v1.5"), device = "cuda:0")
+        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-small-en-v1.5"), device = "cuda:0")
+        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-m3"), device = "cuda:0")
+        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
-        self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
+        self.reranker_model_1 = CrossEncoder(str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device = "cuda")
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -57,10 +64,10 @@ class ChromaQuery(Blackbox):
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
+            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "10.6.44.141"
+            chroma_host = "localhost"
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
             chroma_port = "7000"
@@ -72,7 +79,7 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 10
         # load client and embedding model from init
-        if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
+        if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             try:
@@ -80,11 +87,11 @@ class ChromaQuery(Blackbox):
             except:
                 return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)
-        if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
+        if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
             embedding_model = self.embedding_model_1
-        elif re.search(r"/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
+        elif re.search(str(self.model_path / "bge-small-en-v1.5"), chroma_embedding_model):
             embedding_model = self.embedding_model_2
-        elif re.search(r"/Workspace/Models/BAAI/bge-m3", chroma_embedding_model):
+        elif re.search(str(self.model_path / "bge-m3"), chroma_embedding_model):
             embedding_model = self.embedding_model_3
         else:
             try:
@@ -123,7 +130,7 @@ class ChromaQuery(Blackbox):
             final_result = str(results["documents"])
         if chroma_reranker_model:
-            if re.search(r"/Workspace/Models/BAAI/bge-reranker-v2-m3", chroma_reranker_model):
+            if re.search(str(self.model_path / "bge-reranker-v2-m3"), chroma_reranker_model):
                 reranker_model = self.reranker_model_1
             else:
                 try:

View File

@@ -21,6 +21,10 @@ import logging
 from ..log.logging_time import logging_time
 import re
+from pathlib import Path
+from ..configuration import Configuration
+from ..configuration import PathConf
 logger = logging.getLogger
 DEFAULT_COLLECTION_ID = "123"
@@ -31,9 +35,13 @@ class ChromaUpsert(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load embedding model
-        self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
+        path = PathConf(Configuration())
+        self.model_path = Path(path.chroma_rerank_embedding_model)
+        self.embedding_model_1 = SentenceTransformerEmbeddings(model_name=str(self.model_path / "bge-large-zh-v1.5"), model_kwargs={"device": "cuda"})
         # load chroma db
-        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
+        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -79,24 +87,24 @@ class ChromaUpsert(Blackbox):
         chroma_collection_id = settings.get("chroma_collection_id")
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
+            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "10.6.82.192"
+            chroma_host = "localhost"
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
-            chroma_port = "8000"
+            chroma_port = "7000"
         if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
             chroma_collection_id = "g2e"
         # load client and embedding model from init
-        if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
+        if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
         print(f"chroma_embedding_model: {chroma_embedding_model}")
-        if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
+        if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
             embedding_model = self.embedding_model_1
         else:
             embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda:0")

View File

@@ -10,14 +10,15 @@ from ..tts.tts_service import TTService
 from ..configuration import MeloConf
 from ..configuration import CosyVoiceConf
 from ..configuration import SovitsConf
+from ..configuration import PathConf
 from injector import inject
 from injector import singleton
 import sys,os
-sys.path.append('/Workspace/CosyVoice')
-sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
-from cosyvoice.utils.file_utils import load_wav#, speed_change
+from pathlib import Path
 import soundfile as sf
 import pyloudnorm as pyln
@@ -100,7 +101,18 @@ class TTS(Blackbox):
         print('1.#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
     @logging_time(logger=logger)
-    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
+    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf, path: PathConf) -> None:
+        cosy_path = Path(path.cosyvoice_path)
+        cosy_model_path = Path(path.cosyvoice_model_path)
+        sys.path.append(str(cosy_path))
+        sys.path.append(str(cosy_path / "third_party/Matcha-TTS"))
+        from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+        from cosyvoice.utils.file_utils import load_wav#, speed_change
         self.cosyvoice_speed = cosyvoice_config.speed
         self.cosyvoice_device = cosyvoice_config.device
         self.cosyvoice_language = cosyvoice_config.language
@@ -113,8 +125,8 @@ class TTS(Blackbox):
         if self.cosyvoice_mode == 'local':
             # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
             # self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
-            self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
-            self.prompt_speech_16k = load_wav('/Workspace/jarvis-models/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav', 16000)
+            self.cosyvoicetts = CosyVoice2(str(cosy_model_path / "CosyVoice2-0.5B"), load_jit=True, load_trt=False)
+            self.prompt_speech_16k = load_wav("./Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav", 16000)
         else:
             self.cosyvoice_url = cosyvoice_config.url
@@ -151,10 +163,10 @@ class TTS(Blackbox):
     @inject
-    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
+    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict, path: PathConf) -> None:
         self.tts_service = TTService("yunfeineo")
         self.melo_model_init(melo_config)
-        self.cosyvoice_model_init(cosyvoice_config)
+        self.cosyvoice_model_init(cosyvoice_config, path)
         self.sovits_model_init(sovits_config)
         self.audio_dir = "audio_files"  # directory for storing audio files
@@ -336,7 +348,7 @@ class TTS(Blackbox):
                 "text_split_method": self.sovits_text_split_method,
                 "batch_size": self.sovits_batch_size,
                 "media_type": self.sovits_media_type,
-                "streaming_mode": self.sovits_streaming_mode
+                "streaming_mode": user_stream
             }
             if user_stream == True or str(user_stream).lower() == "true":
                 response = requests.get(self.sovits_url, params=message, stream=True)
@@ -409,10 +421,11 @@ class TTS(Blackbox):
         if text is None:
             return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
         by = self.processing(text, settings=setting)
-        # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
+        if user_stream not in (True, "True", "true"):
+            return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
         print(f"tts user_stream: {type(user_stream)}")
-        # import pdb; pdb.set_trace()
-        if user_stream == True or str(user_stream).lower() == "true":
+        if user_stream in (True, "True", "true"):
            print(f"tts user_stream22: {user_stream}")
            if by.status_code == 200:
                print("*"*90)

View File

@@ -180,3 +180,17 @@ class VLMConf():
     @inject
     def __init__(self, config: Configuration) -> None:
         self.urls = config.get("vlms.urls")
+@singleton
+class PathConf():
+    chroma_rerank_embedding_model: str
+    cosyvoice_path: str
+    cosyvoice_model_path: str
+    sensevoice_model_path: str
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.chroma_rerank_embedding_model = config.get("path.chroma_rerank_embedding_model")
+        self.cosyvoice_path = config.get("path.cosyvoice_path")
+        self.cosyvoice_model_path = config.get("path.cosyvoice_model_path")
+        self.sensevoice_model_path = config.get("path.sensevoice_model_path")
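
For reference, a minimal sketch of the path section that PathConf expects to find in the YAML config (.env.yaml), assuming Configuration resolves dotted keys such as path.cosyvoice_path against nested YAML mappings. The directory values are placeholders for illustration only, not the paths used in the actual deployment:

path:
  chroma_rerank_embedding_model: /models/BAAI                  # parent dir joined with bge-m3, bge-large-zh-v1.5, bge-small-en-v1.5, bge-reranker-v2-m3
  cosyvoice_path: /opt/CosyVoice                               # CosyVoice checkout appended to sys.path (plus third_party/Matcha-TTS)
  cosyvoice_model_path: /models/CosyVoice/pretrained_models    # parent dir joined with CosyVoice2-0.5B
  sensevoice_model_path: /models/SenseVoice/SenseVoiceSmall    # full path to the SenseVoiceSmall model dir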

View File

@@ -19,7 +19,7 @@ import logging
 logging.basicConfig(level=logging.INFO)
 dirbaspath = __file__.split("\\")[1:-1]
-dirbaspath= "/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
+dirbaspath= "./src/tts" + "/".join(dirbaspath)
 config = {
     'ayaka': {
         'cfg': dirbaspath + '/models/ayaka.json',