diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index 4f22d68..ae851e4 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -2,6 +2,7 @@
 from typing import Any, Coroutine
 from fastapi import Request, Response, status
 from fastapi.responses import JSONResponse
+import numpy as np
 
 from .blackbox import Blackbox
 import chromadb
@@ -9,6 +10,7 @@ from chromadb.utils import embedding_functions
 import logging
 from ..log.logging_time import logging_time
 import re
+from sentence_transformers import CrossEncoder
 
 logger = logging.getLogger
 
 DEFAULT_COLLECTION_ID = "123"
@@ -24,6 +26,7 @@ class ChromaQuery(Blackbox):
         self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1")
         self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
+        self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/bge-reranker-v2-m3", max_length=512, device = "cuda")
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -46,6 +49,8 @@ class ChromaQuery(Blackbox):
         chroma_port = settings.get("chroma_port")
         chroma_collection_id = settings.get("chroma_collection_id")
         chroma_n_results = settings.get("chroma_n_results")
+        chroma_reranker_model = settings.get("chroma_reranker_model")
+        chroma_reranker_num = settings.get("chroma_reranker_num")
 
         if usr_question is None:
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -63,20 +68,26 @@ class ChromaQuery(Blackbox):
             chroma_collection_id = "g2e"
 
         if chroma_n_results is None or chroma_n_results == "":
-            chroma_n_results = 3
+            chroma_n_results = 10
 
         # load client and embedding model from init
         if re.search(r"10.6.81.119", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
-            client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
+            try:
+                client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
+            except Exception:
+                return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_1
         elif re.search(r"/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_2
         else:
-            embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
+            try:
+                embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
+            except Exception:
+                return JSONResponse(content={"error": "embedding model not found"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         # load collection
         collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model)
@@ -93,18 +104,53 @@ class ChromaQuery(Blackbox):
 
         final_result = ''
 
-        if results is not None:
-            results_distances = results["distances"][0]
+        # if results is not None:
+        #     results_distances = results["distances"][0]
 
-            #distance越高越不准确
-            top_distance = 0.8
+        #     # a higher distance means a less relevant document
+        #     top_distance = 0.8
 
-            for i in range(len(results_distances)):
-                if results_distances[i] < top_distance:
-                    final_result += results["documents"][0][i]
+        #     for i in range(len(results_distances)):
+        #         if results_distances[i] < top_distance:
+        #             final_result += results["documents"][0][i]
 
-            print("\n final_result: ", final_result)
+        #     print("\n final_result: ", final_result)
+
+        # fall back to the raw query results when no reranker is configured
+        final_result = str(results["documents"])
+
+        if chroma_reranker_model:
+            # reuse the reranker loaded in __init__ when the same model path is requested
+            if re.search(r"/home/gpu/Workspace/Models/bge-reranker-v2-m3", chroma_reranker_model):
+                reranker_model = self.reranker_model_1
+            else:
+                try:
+                    reranker_model = CrossEncoder(chroma_reranker_model, max_length=512, device = "cuda")
+                except Exception:
+                    return JSONResponse(content={"error": "reranker model not found"}, status_code=status.HTTP_400_BAD_REQUEST)
+
+            if chroma_reranker_num:
+                if chroma_reranker_num > chroma_n_results:
+                    chroma_reranker_num = chroma_n_results
+            else:
+                chroma_reranker_num = 5
+
+            # score each (query, document) pair
+            pairs = [[usr_question, doc] for doc in results["documents"][0]]
+            print('\n', pairs)
+            scores = reranker_model.predict(pairs)
+
+            # reorder the documents by score:
+            print("New Ordering:")
+            i = 0
+            final_result = ''
+            for o in np.argsort(scores)[::-1]:
+                if i == chroma_reranker_num or scores[o] < 0.5:
+                    break
+                i += 1
+                print(o + 1)
+                print("Scores:", scores[o])
+                print(results["documents"][0][o], '\n')
+                final_result += results["documents"][0][o] + '\n'
 
         return final_result