Files
jarvis-models/src/blackbox/chroma_query.py
2025-04-03 18:26:05 +08:00

179 lines
7.3 KiB
Python
Executable File

from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
import numpy as np
from .blackbox import Blackbox
import chromadb
from chromadb.utils import embedding_functions
import logging
from ..log.logging_time import logging_time
import re
from sentence_transformers import CrossEncoder
from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf
# Module-level logger; the original bound the factory itself
# (`logging.getLogger`) instead of calling it, so `logger.info(...)`
# would have raised AttributeError.
logger = logging.getLogger(__name__)
# Fallback collection id used when no collection is configured.
DEFAULT_COLLECTION_ID = "123"
from injector import singleton
@singleton
class ChromaQuery(Blackbox):
    """Blackbox that answers a question by querying a ChromaDB collection.

    Three SentenceTransformer embedding functions, an HTTP client for the
    default local Chroma server and a CrossEncoder reranker are loaded once
    at construction time so per-request processing can reuse them.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Model files live under the configured rerank/embedding directory.
        path = PathConf(Configuration())
        self.model_path = Path(path.chroma_rerank_embedding_model)
        # Pre-loaded embedding functions for the commonly requested models.
        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-large-zh-v1.5"), device="cuda:0")
        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-small-en-v1.5"), device="cuda:0")
        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-m3"), device="cuda:0")
        # Default Chroma server; other host/port pairs connect on demand.
        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
        # Pre-loaded default reranker.
        self.reranker_model_1 = CrossEncoder(
            str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device="cuda")

    def __call__(self, *args, **kwargs):
        """Delegate calls straight to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list.

        NOTE(review): :meth:`processing` takes ``(question: str,
        settings: dict)``; this list check looks inconsistent with that
        contract — confirm against the Blackbox caller before changing.
        """
        data = args[0]
        return isinstance(data, list)

    @staticmethod
    def _is_blank(value) -> bool:
        # A setting counts as unset when it is None, "" or whitespace-only.
        return value is None or str(value).strip() == ""

    def processing(self, question: str, settings: dict) -> str:
        """Embed *question*, query the configured collection and return docs.

        Parameters:
            question: the user question; must not be None.
            settings: optional per-request overrides (embedding model, host,
                port, collection id, result count, reranker model/count).
                May be None.

        Returns:
            The concatenated retrieved (optionally reranked) document text,
            or a ``JSONResponse`` with HTTP 400 when a setting is invalid.
        """
        if settings is None:
            settings = {}
        usr_question = question
        chroma_embedding_model = settings.get("chroma_embedding_model")
        chroma_host = settings.get("chroma_host")
        chroma_port = settings.get("chroma_port")
        chroma_collection_id = settings.get("chroma_collection_id")
        chroma_n_results = settings.get("chroma_n_results")
        chroma_reranker_model = settings.get("chroma_reranker_model")
        chroma_reranker_num = settings.get("chroma_reranker_num")

        if usr_question is None:
            return JSONResponse(content={"error": "question is required"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        # Fill blank/unset settings with defaults.
        if self._is_blank(chroma_embedding_model):
            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
        if self._is_blank(chroma_host):
            chroma_host = "localhost"
        if self._is_blank(chroma_port):
            chroma_port = "7000"
        chroma_port = str(chroma_port)  # tolerate numeric JSON values
        if self._is_blank(chroma_collection_id):
            chroma_collection_id = "g2e"
        if chroma_n_results is None or chroma_n_results == "":
            chroma_n_results = 10
        # Settings may arrive as strings; query/compare code needs an int.
        chroma_n_results = int(chroma_n_results)

        # Reuse the pre-connected client only for the exact default target.
        # (The original used re.search, which also matched e.g. "localhost2"
        # or port "17000".)
        if chroma_host == "localhost" and chroma_port == "7000":
            client = self.client_1
        else:
            try:
                client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
            except Exception:
                return JSONResponse(content={"error": "chroma client not found"},
                                    status_code=status.HTTP_400_BAD_REQUEST)

        # Exact-match lookup of the pre-loaded embedding functions. (The
        # original fed the filesystem path to re.search as a pattern, where
        # "." and "\" are metacharacters — fragile and wrong on Windows.)
        cached_embedders = {
            str(self.model_path / "bge-large-zh-v1.5"): self.embedding_model_1,
            str(self.model_path / "bge-small-en-v1.5"): self.embedding_model_2,
            str(self.model_path / "bge-m3"): self.embedding_model_3,
        }
        embedding_model = cached_embedders.get(chroma_embedding_model)
        if embedding_model is None:
            try:
                embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name=chroma_embedding_model, device="cuda:0")
            except Exception:
                return JSONResponse(content={"error": "embedding model not found"},
                                    status_code=status.HTTP_400_BAD_REQUEST)

        collection = client.get_collection(chroma_collection_id,
                                           embedding_function=embedding_model)
        results = collection.query(
            query_texts=[usr_question],
            n_results=chroma_n_results,
        )

        # Without a reranker, return the raw document payload as text.
        final_result = str(results["documents"])

        if chroma_reranker_model:
            if chroma_reranker_model == str(self.model_path / "bge-reranker-v2-m3"):
                reranker_model = self.reranker_model_1
            else:
                try:
                    reranker_model = CrossEncoder(chroma_reranker_model,
                                                  max_length=512, device="cuda")
                except Exception:
                    return JSONResponse(content={"error": "reranker model not found"},
                                        status_code=status.HTTP_400_BAD_REQUEST)
            if chroma_reranker_num:
                chroma_reranker_num = int(chroma_reranker_num)
                # Cannot keep more documents than were retrieved.
                if chroma_reranker_num > chroma_n_results:
                    chroma_reranker_num = chroma_n_results
            else:
                chroma_reranker_num = 5

            # Score every (question, document) pair, then keep the highest
            # scoring documents in descending order.
            docs = results["documents"][0]
            pairs = [[usr_question, doc] for doc in docs]
            scores = reranker_model.predict(pairs)
            kept = []
            for rank in np.argsort(scores)[::-1]:
                # Stop at the requested count, or once relevance drops below
                # the 0.5 cut-off used by the original implementation.
                if len(kept) == chroma_reranker_num or scores[rank] < 0.5:
                    break
                kept.append(docs[rank] + '\n')
            final_result = ''.join(kept)
        return final_result

    async def fast_api_handler(self, request: Request) -> Response:
        """FastAPI entry point: parse the JSON body and run `processing`."""
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"},
                                status_code=status.HTTP_400_BAD_REQUEST)
        result = self.processing(data.get("question"), data.get("settings"))
        # processing() returns a JSONResponse itself on validation errors;
        # pass that through instead of wrapping a Response object inside a
        # 200 payload (which is not JSON-serializable and hid the error).
        if isinstance(result, Response):
            return result
        return JSONResponse(
            content={"response": result},
            status_code=status.HTTP_200_OK)