"""Blackbox that answers a question by querying a Chroma vector store and
optionally reranking the retrieved documents with a cross-encoder."""

import logging
import re
from pathlib import Path
from typing import Any, Coroutine

import chromadb
import numpy as np
from chromadb.utils import embedding_functions
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import singleton
from sentence_transformers import CrossEncoder

from .blackbox import Blackbox
from ..configuration import Configuration, PathConf
from ..log.logging_time import logging_time

# BUG FIX: the original assigned the *function object* logging.getLogger here
# instead of calling it, so every logger.<method>() call would have raised
# AttributeError.
logger = logging.getLogger(__name__)

DEFAULT_COLLECTION_ID = "123"

# Rerank hits below this cross-encoder score are discarded.
_RERANK_SCORE_FLOOR = 0.5
# Number of reranked documents kept when the caller does not specify one.
_DEFAULT_RERANK_NUM = 5


def _is_blank(value: Any) -> bool:
    """Return True when a settings value is absent or an empty/whitespace string.

    More robust than the original ``v is None or v.isspace() or v == ""``,
    which raised AttributeError on non-string values.
    """
    return value is None or (isinstance(value, str) and value.strip() == "")


@singleton
class ChromaQuery(Blackbox):
    """Query a ChromaDB collection and return the matching documents as text.

    Three embedding models, a default HTTP client (localhost:7000) and a
    default cross-encoder reranker are loaded once at construction so that
    per-request calls do not pay model start-up cost; request ``settings``
    may override any of them.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Root directory holding all embedding / reranker model folders
        # (resolved from the project configuration).
        path = PathConf(Configuration())
        self.model_path = Path(path.chroma_rerank_embedding_model)
        # Pre-load the three supported embedding models on the GPU.
        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-large-zh-v1.5"), device="cuda:0"
        )
        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-small-en-v1.5"), device="cuda:0"
        )
        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=str(self.model_path / "bge-m3"), device="cuda:0"
        )
        # Default Chroma server; a different host/port in settings creates a
        # fresh client per request instead.
        self.client_1 = chromadb.HttpClient(host="localhost", port=7000)
        self.reranker_model_1 = CrossEncoder(
            str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device="cuda"
        )

    def __call__(self, *args, **kwargs):
        """Delegate invocation straight to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Pipeline payload check: the first positional argument must be a list."""
        return isinstance(args[0], list)

    def _resolve_client(self, host: str, port: str):
        """Return the cached default client for localhost:7000, else a new one.

        The original matched host/port with ``re.search``, which treated the
        values as regex patterns; plain equality expresses the intent.
        """
        if host == "localhost" and str(port) == "7000":
            return self.client_1
        return chromadb.HttpClient(host=host, port=port)

    def _resolve_embedding(self, model_name: str):
        """Map a model path/name onto a pre-loaded embedding function.

        Falls back to loading the named model on demand. Substring matching
        replaces the original ``re.search`` on a filesystem path, which broke
        on regex metacharacters and Windows path separators.
        """
        preloaded = {
            "bge-large-zh-v1.5": self.embedding_model_1,
            "bge-small-en-v1.5": self.embedding_model_2,
            "bge-m3": self.embedding_model_3,
        }
        for key, model in preloaded.items():
            if key in model_name:
                return model
        return embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name, device="cuda:0"
        )

    def _rerank(self, question: str, documents: list, reranker_model, top_k: int) -> str:
        """Score (question, doc) pairs and keep up to ``top_k`` documents.

        Documents are ordered best-score-first; anything scoring below
        ``_RERANK_SCORE_FLOOR`` is dropped. Each kept document is followed by
        a newline, matching the original output format.
        """
        pairs = [[question, doc] for doc in documents]
        scores = reranker_model.predict(pairs)
        kept = []
        # argsort ascending, reversed -> indices of best-scoring docs first.
        for idx in np.argsort(scores)[::-1]:
            if len(kept) == top_k or scores[idx] < _RERANK_SCORE_FLOOR:
                break
            logger.debug("rerank #%d score=%s doc=%s", idx + 1, scores[idx], documents[idx])
            kept.append(documents[idx] + "\n")
        return "".join(kept)

    # @logging_time(logger=logger)
    def processing(self, question: str, settings: dict) -> str:
        """Query a Chroma collection for ``question`` and return document text.

        Args:
            question: the user question; required.
            settings: optional overrides — ``chroma_embedding_model``,
                ``chroma_host``, ``chroma_port``, ``chroma_collection_id``,
                ``chroma_n_results``, ``chroma_reranker_model``,
                ``chroma_reranker_num``.

        Returns:
            The concatenated (optionally reranked) documents as a string, or
            a 400 ``JSONResponse`` describing the configuration error.
        """
        if settings is None:
            settings = {}
        if question is None:
            return JSONResponse(
                content={"error": "question is required"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        embedding_name = settings.get("chroma_embedding_model")
        host = settings.get("chroma_host")
        port = settings.get("chroma_port")
        collection_id = settings.get("chroma_collection_id")
        n_results = settings.get("chroma_n_results")
        reranker_name = settings.get("chroma_reranker_model")
        reranker_num = settings.get("chroma_reranker_num")

        # Fill in defaults for anything the caller left blank.
        if _is_blank(embedding_name):
            embedding_name = str(self.model_path / "bge-large-zh-v1.5")
        if _is_blank(host):
            host = "localhost"
        if _is_blank(port):
            port = "7000"
        if _is_blank(collection_id):
            collection_id = "g2e"
        if n_results is None or n_results == "":
            n_results = 10
        # JSON clients may send numbers as strings; coerce once here so the
        # query call and the reranker comparison below both see ints.
        n_results = int(n_results)

        try:
            client = self._resolve_client(host, port)
        except Exception:
            logger.exception("failed to connect to chroma at %s:%s", host, port)
            return JSONResponse(
                content={"error": "chroma client not found"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        try:
            embedding_model = self._resolve_embedding(embedding_name)
        except Exception:
            logger.exception("failed to load embedding model %s", embedding_name)
            return JSONResponse(
                content={"error": "embedding model not found"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        collection = client.get_collection(collection_id, embedding_function=embedding_model)
        logger.debug("question: %s", question)
        results = collection.query(query_texts=[question], n_results=n_results)

        # Without a reranker the raw document list is stringified, matching
        # the original behavior.
        final_result = str(results["documents"])
        if reranker_name:
            if "bge-reranker-v2-m3" in reranker_name:
                reranker_model = self.reranker_model_1
            else:
                try:
                    reranker_model = CrossEncoder(reranker_name, max_length=512, device="cuda")
                except Exception:
                    logger.exception("failed to load reranker %s", reranker_name)
                    return JSONResponse(
                        content={"error": "reranker model not found"},
                        status_code=status.HTTP_400_BAD_REQUEST,
                    )
            # Coerce and cap: never keep more reranked docs than were retrieved.
            reranker_num = int(reranker_num) if reranker_num else _DEFAULT_RERANK_NUM
            reranker_num = min(reranker_num, n_results)
            final_result = self._rerank(
                question, results["documents"][0], reranker_model, reranker_num
            )
        return final_result

    async def fast_api_handler(self, request: Request) -> Response:
        """FastAPI endpoint: parse ``{"question": ..., "settings": ...}`` JSON
        from the request body and return the processing result."""
        try:
            data = await request.json()
        except Exception:
            # Malformed body -> explicit 400 instead of an unhandled 500.
            return JSONResponse(
                content={"error": "json parse error"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        return JSONResponse(
            content={"response": self.processing(data.get("question"), data.get("settings"))},
            status_code=status.HTTP_200_OK,
        )