"""Interactive RAG retrieval demo backed by a remote Chroma server.

The script rebuilds a Chroma collection from a text file, embedding the chunks
with the bge-m3 model at EMBEDDING_MODEL_PATH. It then loops, reading questions
from stdin. For each question it:

1. retrieves candidate chunks from Chroma;
2. rescores them with the cross-encoder at RERANKER_MODEL_PATH;
3. prints the best-scoring chunks and their concatenation (``final_result``).
"""
import time

import numpy as np

# Start a Chroma server with, e.g.:
#   chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# NOTE(review): this example uses port 8000 but CHROMA_PORT below is 7000 --
# confirm which port the server actually listens on.

CORPUS_PATH = "/Workspace/jarvis-models/sample/RAG_boss.txt"
EMBEDDING_MODEL_PATH = "/Workspace/Models/BAAI/bge-m3"
RERANKER_MODEL_PATH = "/Workspace/Models/BAAI/bge-reranker-v2-m3"
DEVICE = "cuda:1"

CHROMA_HOST = "192.168.0.200"
CHROMA_PORT = 7000
COLLECTION_NAME = "boss"
# NOTE(review): "粤语语料" ("Cantonese corpus") looks left over from an earlier
# corpus; kept unchanged so chunk ids stay the same.
ID_PREFIX = "粤语语料"

N_RESULTS = 10  # candidates retrieved from Chroma per question
TOP_K = 3  # maximum passages kept after reranking
MIN_RERANK_SCORE = 0.5  # selection stops at the first score below this


def select_top_passages(scores, documents, top_k=TOP_K, min_score=MIN_RERANK_SCORE):
    """Pick the best reranked passages.

    Passages are visited in descending score order. Selection stops after
    ``top_k`` passages or at the first score below ``min_score``.

    Args:
        scores: One relevance score per document (sequence or numpy array).
        documents: The documents that ``scores`` refer to, in the same order.
        top_k: Maximum number of passages to return.
        min_score: Passages scoring below this value are not returned.

    Returns:
        A list of ``(index, score, document)`` tuples, best first, where
        ``index`` is the position in ``documents``.
    """
    # argsort is ascending; reverse it for best-first order.
    ranked = np.argsort(scores)[::-1]
    selected = []
    for idx in ranked[:top_k]:
        if scores[idx] < min_score:
            # Everything after this scores no higher, so stop here.
            break
        selected.append((int(idx), scores[idx], documents[idx]))
    return selected


def _build_index(client):
    """Replace collection COLLECTION_NAME on *client* with chunks of CORPUS_PATH."""
    # Deferred so the helpers above can be imported without the ML stack.
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.document_loaders import TextLoader
    from langchain_community.embeddings.sentence_transformer import (
        SentenceTransformerEmbeddings,
    )
    from langchain_community.vectorstores import Chroma

    documents = TextLoader(CORPUS_PATH, encoding="utf-8").load()

    # NOTE(review): separators are tried in list order, so "\n" matches before
    # "\n\n" and the latter is effectively unused. With chunk_size=10, chunks
    # are roughly one line each. Confirm this is the intended chunking.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10,
        chunk_overlap=0,
        length_function=len,
        is_separator_regex=True,
        separators=["\n", "\n\n"],
    )
    docs = text_splitter.split_documents(documents)
    print("len(docs)", len(docs))
    ids = [f"{ID_PREFIX}{i}" for i in range(len(docs))]

    # This embedder is local to this function, so it becomes unreferenced once
    # indexing finishes.
    embedding_model = SentenceTransformerEmbeddings(
        model_name=EMBEDDING_MODEL_PATH, model_kwargs={"device": DEVICE}
    )

    # Drop the previous version so stale chunks do not survive a re-index.
    # Deleting a collection that does not exist yet (e.g. on the first run)
    # raises an exception whose type varies across chromadb versions. Report
    # it and continue; real connectivity problems still surface in
    # from_documents below.
    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception as err:
        print(f"delete_collection({COLLECTION_NAME!r}) skipped: {err}")

    # The collection was just deleted, so this is a fresh insert of all chunks.
    Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        ids=ids,
        collection_name=COLLECTION_NAME,
        client=client,
    )


def main():
    """Build the index, then answer questions from stdin until EOF or Ctrl-C."""
    import chromadb
    from chromadb.utils import embedding_functions
    from sentence_transformers import CrossEncoder

    # One HTTP client serves both indexing and querying.
    client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
    _build_index(client)

    # NOTE(review): this loads the bge-m3 model a second time (indexing used the
    # LangChain wrapper). Consider querying with precomputed embeddings instead.
    query_embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_PATH, device=DEVICE
    )
    collection = client.get_collection(
        COLLECTION_NAME, embedding_function=query_embedding_fn
    )
    reranker_model = CrossEncoder(RERANKER_MODEL_PATH, max_length=512, device=DEVICE)

    while True:
        try:
            usr_question = input("\n 请输入问题: ")
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at the prompt: exit instead of a traceback.
            print()
            break
        if not usr_question.strip():
            # A blank question cannot retrieve anything meaningful.
            continue

        t_query = time.perf_counter()
        results = collection.query(query_texts=[usr_question], n_results=N_RESULTS)
        t_retrieved = time.perf_counter()
        print("query time: ", t_retrieved - t_query)

        candidates = results["documents"][0]
        if not candidates:
            # Nothing to rerank; skip the cross-encoder call on an empty batch.
            print("No documents retrieved.")
            continue

        pairs = [[usr_question, doc] for doc in candidates]
        scores = reranker_model.predict(pairs)

        # Print the passages in their reranked order.
        print("New Ordering:")
        selected = select_top_passages(scores, candidates)
        for idx, score, doc in selected:
            print(idx + 1)
            print("Scores:", score)
            print(doc, "\n")
        final_result = "".join(doc + "\n" for _, _, doc in selected)
        print("\n final_result: ", final_result)
        print("rerank time: ", time.perf_counter() - t_retrieved)


if __name__ == "__main__":
    main()