Files
jarvis-models/sample/chroma_rerank.py
2025-04-03 18:26:05 +08:00

76 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sentence_transformers import CrossEncoder
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time
from pathlib import Path
# Root directory holding the local BAAI models (bge-m3 embedder, bge-reranker-v2-m3).
path = Path("/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI")
# Torch device shared by every model below; change it here to use another GPU or "cpu".
DEVICE = "cuda:0"
# Chroma server address. Start the server with e.g.
#   chroma run --path chroma_db/ --port 7000 --host 0.0.0.0
# (the --port value must match CHROMA_PORT).
CHROMA_HOST = "localhost"
CHROMA_PORT = 7000

# Alternative corpus:
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt", encoding="utf-8")
# Decode explicitly as UTF-8 (like the alternative corpus above) so the text, presumably
# Chinese, does not depend on the platform's default locale encoding.
loader = TextLoader("./RAG_boss.txt", encoding="utf-8")
documents = loader.load()

# Split into very small chunks. chunk_size=10 is a soft limit: a line longer than that is
# presumably kept whole, since no finer separator than a newline is listed.
# NOTE(review): separators are tried in list order, so "\n" always wins over "\n\n" and the
# paragraph break never acts as the primary split -- confirm this ordering is intended.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=True,
    separators=["\n", "\n\n"],
)
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))

# Deterministic, position-based ids so re-running the script overwrites the same entries.
# NOTE(review): the "粤语语料" (Cantonese corpus) prefix looks left over from the alternative
# corpus; it is kept unchanged so entries already stored in the collection keep their ids.
ids = ["粤语语料" + str(i) for i in range(len(docs))]

# LangChain embedder used to embed the chunks at ingestion time.
ingest_embeddings = SentenceTransformerEmbeddings(
    model_name=str(path / "bge-m3"),
    model_kwargs={"device": DEVICE},
)
# A single HTTP client is shared by ingestion and querying.
client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
collection_name = "boss2"
# client.delete_collection(collection_name)
# Insert the vectors (entries whose ids already exist are updated in place).
db = Chroma.from_documents(
    documents=docs,
    embedding=ingest_embeddings,
    ids=ids,
    collection_name=collection_name,
    client=client,
)

# ChromaDB-native embedding function so `collection.query(query_texts=...)` can embed queries.
# NOTE(review): this loads bge-m3 a second time while `db` still references the first copy,
# so presumably two copies of the model occupy DEVICE memory.
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=str(path / "bge-m3"),
    device=DEVICE,
)
collection = client.get_collection(collection_name, embedding_function=embedding_model)
# Cross-encoder that rescores (question, chunk) pairs returned by the vector search.
reranker_model = CrossEncoder(str(path / "bge-reranker-v2-m3"), max_length=512, device=DEVICE)

# Interactive retrieve-then-rerank loop (disabled). When uncommented it fetches the top 10
# vector hits, rescores them with the cross-encoder, and keeps at most 3 hits scoring >= 0.5.
# while True:
#     usr_question = input("\n 请输入问题: ")
#     # query it
#     time1 = time.time()
#     results = collection.query(
#         query_texts=[usr_question],
#         n_results=10,
#     )
#     time2 = time.time()
#     print("query time: ", time2 - time1)
#     # print("query: ",usr_question)
#     # print("results: ", results["documents"][0])
#     pairs = [[usr_question, doc] for doc in results["documents"][0]]
#     # print('\n',pairs)
#     scores = reranker_model.predict(pairs)
#     # Reorder the retrieved documents by reranker score (highest first):
#     print("New Ordering:")
#     i = 0
#     final_result = ''
#     for o in np.argsort(scores)[::-1]:
#         if i == 3 or scores[o] < 0.5:
#             break
#         i += 1
#         print(o+1)
#         print("Scores:", scores[o])
#         print(results["documents"][0][o],'\n')
#         final_result += results["documents"][0][o] + '\n'
#     print("\n final_result: ", final_result)
#     time3 = time.time()
#     print("rerank time: ", time3 - time2)