diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 4fafb29..72d0262 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -21,7 +21,7 @@ class ChromaQuery(Blackbox): # config = read_yaml(args[0]) # load chromadb and embedding model self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda") - # self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") + self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", device = "cuda") self.client_1 = chromadb.HttpClient(host='172.16.5.8', port=7000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) @@ -73,20 +73,40 @@ class ChromaQuery(Blackbox): if re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model): embedding_model = self.embedding_model_1 + elif re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", chroma_embedding_model): + embedding_model = self.embedding_model_2 else: embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda") # load collection collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model) + print(usr_question) # query it results = collection.query( query_texts=[usr_question], - n_results=3, + n_results=chroma_n_results, ) # response = str(results["documents"] + results["metadatas"]) - response = str(results["documents"]) + # response = str(results["documents"]) + + final_result = '' + + if results is not None: + results_distances = results["distances"][0] + + #distance越高越不准确 + top_distance = 0.8 + + + for i in range(len(results_distances)): + if results_distances[i] < top_distance: + final_result += results["documents"][0][i] + + print("\n final_result: ", final_result) + + return final_result return response