from sentence_transformers import CrossEncoder
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time

# Sample: load a text corpus into a remote Chroma collection, then run an
# interactive loop that retrieves the top vector-search hits and reranks
# them with a cross-encoder, keeping only high-scoring passages.
#
# Start the Chroma server first, e.g.:
#   chroma run --path chroma_db/ --port 8000 --host 0.0.0.0

# --- Configuration (edit for your environment) ---------------------------
CORPUS_PATH = "/home/administrator/Workspace/jarvis-models/sample/RAG_boss.txt"
EMBED_MODEL_PATH = "/home/administrator/Workspace/Models/BAAI/bge-m3"
RERANK_MODEL_PATH = "/home/administrator/Workspace/Models/BAAI/bge-reranker-v2-m3"
CHROMA_HOST = "172.16.4.7"
CHROMA_PORT = 7000
COLLECTION_NAME = "boss"  # was `id`, which shadowed the builtin
DEVICE = "cuda:1"
N_RESULTS = 10        # candidates fetched from vector search before reranking
TOP_K = 3             # keep at most this many reranked passages
SCORE_THRESHOLD = 0.5  # drop passages the reranker scores below this


def build_collection(client):
    """Split the corpus into chunks and (re)load them into the Chroma collection.

    Deletes any existing collection of the same name first, so repeated runs
    start from a clean state. Chunk ids are stable ("粤语语料0", "粤语语料1", ...),
    so re-ingesting with the same ids would update rather than duplicate.
    """
    loader = TextLoader(CORPUS_PATH)
    documents = loader.load()
    # NOTE(review): chunk_size=10 is very small — presumably intentional for
    # this sample's line-level corpus; confirm before reusing elsewhere.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10,
        chunk_overlap=0,
        length_function=len,
        is_separator_regex=True,
        separators=['\n', '\n\n'],
    )
    docs = text_splitter.split_documents(documents)
    print("len(docs)", len(docs))
    ids = ["粤语语料" + str(i) for i in range(len(docs))]

    embedder = SentenceTransformerEmbeddings(
        model_name=EMBED_MODEL_PATH, model_kwargs={"device": DEVICE})

    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception:
        # Best-effort: on the first run the collection does not exist yet and
        # delete_collection raises; that must not abort ingestion.
        pass

    Chroma.from_documents(
        documents=docs,
        embedding=embedder,
        ids=ids,
        collection_name=COLLECTION_NAME,
        client=client,
    )


def query_loop(client):
    """Interactive loop: read a question, vector-search, rerank, print results.

    Runs until interrupted (Ctrl-C / EOF). For each question it prints the
    vector-search latency, the reranked ordering with scores, the concatenated
    surviving passages, and the rerank latency.
    """
    embed_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBED_MODEL_PATH, device=DEVICE)
    collection = client.get_collection(COLLECTION_NAME, embedding_function=embed_fn)
    reranker = CrossEncoder(RERANK_MODEL_PATH, max_length=512, device=DEVICE)

    while True:
        usr_question = input("\n 请输入问题: ")

        time1 = time.time()
        results = collection.query(
            query_texts=[usr_question],
            n_results=N_RESULTS,
        )
        time2 = time.time()
        print("query time: ", time2 - time1)

        # Single query text -> candidates live in the first documents list.
        candidates = results["documents"][0]
        pairs = [[usr_question, doc] for doc in candidates]
        scores = reranker.predict(pairs)

        # Re-order candidates by descending cross-encoder score; stop after
        # TOP_K passages or when the score falls below the threshold.
        print("New Ordering:")
        kept = 0
        final_result = ''
        for o in np.argsort(scores)[::-1]:
            if kept == TOP_K or scores[o] < SCORE_THRESHOLD:
                break
            kept += 1
            print(o + 1)
            print("Scores:", scores[o])
            print(candidates[o], '\n')
            final_result += candidates[o] + '\n'

        print("\n final_result: ", final_result)
        time3 = time.time()
        print("rerank time: ", time3 - time2)


def main():
    # One shared connection (the original opened two identical ones).
    client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
    build_collection(client)
    query_loop(client)


if __name__ == "__main__":
    main()