feat: add reranker model in chroma query

This commit is contained in:
tom
2024-10-03 14:16:30 +08:00
parent 8059752437
commit 8cad8abecc

View File

from typing import Any, Coroutine

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
import numpy as np
from .blackbox import Blackbox
import chromadb
from chromadb.utils import embedding_functions
import logging
from ..log.logging_time import logging_time
import re
from sentence_transformers import CrossEncoder

# BUG FIX: the original bound the factory function itself
# (`logger = logging.getLogger`), so `logger.info(...)` would raise
# AttributeError. Bind a real module-level logger instead.
logger = logging.getLogger(__name__)

# Fallback Chroma collection id used when the request supplies none.
DEFAULT_COLLECTION_ID = "123"
@ -24,6 +26,7 @@ class ChromaQuery(Blackbox):
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1") self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1")
self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000) self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
self.reranker_model_1 = CrossEncoder("/home/gpu/Workspace/Models/bge-reranker-v2-m3", max_length=512, device = "cuda")
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
@ -46,6 +49,8 @@ class ChromaQuery(Blackbox):
chroma_port = settings.get("chroma_port") chroma_port = settings.get("chroma_port")
chroma_collection_id = settings.get("chroma_collection_id") chroma_collection_id = settings.get("chroma_collection_id")
chroma_n_results = settings.get("chroma_n_results") chroma_n_results = settings.get("chroma_n_results")
chroma_reranker_model = settings.get("chroma_reranker_model")
chroma_reranker_num = settings.get("chroma_reranker_num")
if usr_question is None: if usr_question is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@ -63,20 +68,26 @@ class ChromaQuery(Blackbox):
chroma_collection_id = "g2e" chroma_collection_id = "g2e"
if chroma_n_results is None or chroma_n_results == "": if chroma_n_results is None or chroma_n_results == "":
chroma_n_results = 3 chroma_n_results = 10
# load client and embedding model from init # load client and embedding model from init
if re.search(r"10.6.81.119", chroma_host) and re.search(r"7000", chroma_port): if re.search(r"10.6.81.119", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1 client = self.client_1
else: else:
client = chromadb.HttpClient(host=chroma_host, port=chroma_port) try:
client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
except:
return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)
if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model): if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
embedding_model = self.embedding_model_1 embedding_model = self.embedding_model_1
elif re.search(r"/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model): elif re.search(r"/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
embedding_model = self.embedding_model_2 embedding_model = self.embedding_model_2
else: else:
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1") try:
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
except:
return JSONResponse(content={"error": "embedding model not found"}, status_code=status.HTTP_400_BAD_REQUEST)
# load collection # load collection
collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model) collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model)
@ -93,18 +104,53 @@ class ChromaQuery(Blackbox):
final_result = '' final_result = ''
if results is not None: # if results is not None:
results_distances = results["distances"][0] # results_distances = results["distances"][0]
#distance越高越不准确 # #distance越高越不准确
top_distance = 0.8 # top_distance = 0.8
for i in range(len(results_distances)): # for i in range(len(results_distances)):
if results_distances[i] < top_distance: # if results_distances[i] < top_distance:
final_result += results["documents"][0][i] # final_result += results["documents"][0][i]
print("\n final_result: ", final_result) # print("\n final_result: ", final_result)
final_result = str(results["documents"])
if chroma_reranker_model:
if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-reranker-v2-m3", chroma_embedding_model):
reranker_model = self.chroma_reranker_model_1
else:
try:
reranker_model = CrossEncoder(chroma_reranker_model, max_length=512, device = "cuda")
except:
return JSONResponse(content={"error": "reranker model not found"}, status_code=status.HTTP_400_BAD_REQUEST)
if chroma_reranker_num:
if chroma_reranker_num > chroma_n_results:
chroma_reranker_num = chroma_n_results
else:
chroma_reranker_num = 5
#对每一对(查询、文档)进行评分
pairs = [[usr_question, doc] for doc in results["documents"][0]]
print('\n',pairs)
scores = reranker_model.predict(pairs)
#重新排列文件顺序:
print("New Ordering:")
i = 0
final_result = ''
for o in np.argsort(scores)[::-1]:
if i == chroma_reranker_num or scores[o] < 0.5:
break
i += 1
print(o+1)
print("Scores:", scores[o])
print(results["documents"][0][o],'\n')
final_result += results["documents"][0][o] + '\n'
return final_result return final_result