add sample/chroma_rerank.py

This commit is contained in:
2024-10-15 15:52:30 +08:00
parent 31e9e56ab3
commit f4b971d2fd

73
sample/chroma_rerank.py Normal file
View File

@ -0,0 +1,73 @@
from sentence_transformers import CrossEncoder
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/home/administrator/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_boss.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]
embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
client = chromadb.HttpClient(host='172.16.4.7', port=7000)
id = "boss"
client.delete_collection(id)
# 插入向量(如果ids已存在则会更新向量)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
client = chromadb.HttpClient(host='172.16.4.7', port=7000)
collection = client.get_collection(id, embedding_function=embedding_model)
reranker_model = CrossEncoder("/home/administrator/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
while True:
usr_question = input("\n 请输入问题: ")
# query it
time1 = time.time()
results = collection.query(
query_texts=[usr_question],
n_results=10,
)
time2 = time.time()
print("query time: ", time2 - time1)
# print("query: ",usr_question)
# print("results: ",print(results["documents"][0]))
pairs = [[usr_question, doc] for doc in results["documents"][0]]
# print('\n',pairs)
scores = reranker_model.predict(pairs)
#重新排列文件顺序:
print("New Ordering:")
i = 0
final_result = ''
for o in np.argsort(scores)[::-1]:
if i == 3 or scores[o] < 0.5:
break
i += 1
print(o+1)
print("Scores:", scores[o])
print(results["documents"][0][o],'\n')
final_result += results["documents"][0][o] + '\n'
print("\n final_result: ", final_result)
time3 = time.time()
print("rerank time: ", time3 - time2)