mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
add sample/chroma_rerank.py
73 sample/chroma_rerank.py Normal file
@@ -0,0 +1,73 @@
from sentence_transformers import CrossEncoder
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time

# Start a local Chroma server with:
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0

# Load the source document and split it into small chunks.
# loader = TextLoader("/home/administrator/Workspace/chroma_data/粤语语料.txt", encoding="utf-8")
loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_boss.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=True,
    separators=['\n', '\n\n'],
)
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料" + str(i) for i in range(len(docs))]

# Embed the chunks with bge-m3 and write them to a remote Chroma collection.
embedding_model = SentenceTransformerEmbeddings(
    model_name='/home/administrator/Workspace/Models/BAAI/bge-m3',
    model_kwargs={"device": "cuda:1"},
)
client = chromadb.HttpClient(host='172.16.4.7', port=7000)

collection_name = "boss"
try:
    client.delete_collection(collection_name)
except Exception:
    # The collection may not exist yet on the first run.
    pass
# Insert the vectors (if the ids already exist, the vectors are updated).
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids,
                           collection_name=collection_name, client=client)

# Query the collection directly through the Chroma client, then rerank the hits.
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="/home/administrator/Workspace/Models/BAAI/bge-m3", device="cuda:1")
client = chromadb.HttpClient(host='172.16.4.7', port=7000)
collection = client.get_collection(collection_name, embedding_function=embedding_function)

reranker_model = CrossEncoder("/home/administrator/Workspace/Models/BAAI/bge-reranker-v2-m3",
                              max_length=512, device="cuda:1")

while True:
    usr_question = input("\n 请输入问题: ")  # prompt: "Please enter a question: "

    # Vector search: retrieve the top 10 candidate chunks.
    time1 = time.time()
    results = collection.query(
        query_texts=[usr_question],
        n_results=10,
    )
    time2 = time.time()
    print("query time: ", time2 - time1)
    # print("query: ", usr_question)
    # print("results: ", results["documents"][0])

    # Score every (question, chunk) pair with the cross-encoder reranker.
    pairs = [[usr_question, doc] for doc in results["documents"][0]]
    # print('\n', pairs)
    scores = reranker_model.predict(pairs)

    # Re-order the chunks by reranker score; keep at most 3 with score >= 0.5.
    print("New Ordering:")
    i = 0
    final_result = ''
    for o in np.argsort(scores)[::-1]:
        if i == 3 or scores[o] < 0.5:
            break
        i += 1
        print(o + 1)
        print("Scores:", scores[o])
        print(results["documents"][0][o], '\n')
        final_result += results["documents"][0][o] + '\n'

    print("\n final_result: ", final_result)
    time3 = time.time()
    print("rerank time: ", time3 - time2)