"""Ingest a text corpus into a remote Chroma collection and prepare reranking.

Performed at module level when this file is run (or imported):
1. Load ``DATA_PATH`` and split it into chunks with LangChain.
2. Embed the chunks with bge-m3 and upsert them into the Chroma collection
   ``COLLECTION_NAME`` on the Chroma HTTP server.
3. Open the same collection through the native chromadb client and load the
   bge-reranker-v2-m3 cross-encoder. ``interactive_query_loop`` uses both of
   these for retrieve-then-rerank queries; it is not started automatically.

The Chroma server can be started with, e.g.:
    chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
NOTE(review): that command listens on port 8000, but CHROMA_PORT below is
7000. Keep the two in sync for your deployment.
"""
import time

import chromadb
import numpy as np
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from sentence_transformers import CrossEncoder

# --- Configuration -----------------------------------------------------------
# Previously used corpus (a Cantonese corpus), loaded with encoding="utf-8":
# DATA_PATH = "/Workspace/chroma_data/粤语语料.txt"
DATA_PATH = "/Workspace/jarvis-models/sample/RAG_zh_kiki.txt"
EMBEDDING_MODEL_PATH = "/Workspace/Models/BAAI/bge-m3"
RERANKER_MODEL_PATH = "/Workspace/Models/BAAI/bge-reranker-v2-m3"
DEVICE = "cuda:0"
CHROMA_HOST = "10.6.44.141"
CHROMA_PORT = 7000
# Name of the Chroma collection (not called `id`, which would shadow the builtin).
COLLECTION_NAME = "kiki"
# Chunk-id prefix ("Cantonese corpus"), left over from the previous corpus.
# It is intentionally unchanged: ids are upserted, so a different prefix would
# add a second copy of every chunk to an already-populated collection instead
# of updating the existing vectors.
ID_PREFIX = "粤语语料"


def interactive_query_loop(chroma_collection, reranker, n_results=10, top_k=3, min_score=0.5):
    """Answer questions from stdin by retrieving from Chroma and reranking.

    For each question:
    - the ``n_results`` nearest chunks are fetched from ``chroma_collection``;
    - each chunk is scored against the question with the cross-encoder
      ``reranker``;
    - at most ``top_k`` chunks scoring at least ``min_score`` are printed,
      highest score first, followed by their concatenation and the timings.

    Not invoked by this script. Call it explicitly, e.g.
    ``interactive_query_loop(collection, reranker_model)``.
    Returns when stdin reaches EOF or Ctrl-C is pressed at the prompt.
    """
    while True:
        try:
            usr_question = input("\n 请输入问题: ")
        except (EOFError, KeyboardInterrupt):
            break

        query_start = time.perf_counter()
        results = chroma_collection.query(query_texts=[usr_question], n_results=n_results)
        query_end = time.perf_counter()
        print("query time: ", query_end - query_start)

        candidates = results["documents"][0]
        if not candidates:
            # Nothing to rerank; avoid calling the cross-encoder with no pairs.
            print("No documents retrieved.")
            continue

        scores = reranker.predict([[usr_question, doc] for doc in candidates])

        # Reorder the retrieved documents by reranker score, highest first.
        print("New Ordering:")
        selected = []
        for rank, idx in enumerate(np.argsort(scores)[::-1]):
            if rank == top_k or scores[idx] < min_score:
                break
            print(idx + 1)  # 1-based position in the original retrieval order
            print("Scores:", scores[idx])
            print(candidates[idx], "\n")
            selected.append(candidates[idx])
        final_result = "".join(doc + "\n" for doc in selected)
        print("\n final_result: ", final_result)
        print("rerank time: ", time.perf_counter() - query_end)


# --- Load and split ----------------------------------------------------------
# Explicit UTF-8 so decoding the Chinese text does not depend on the process
# locale (TextLoader otherwise uses the platform default encoding).
# Assumes the corpus is UTF-8 encoded, like the previous one. Confirm.
loader = TextLoader(DATA_PATH, encoding="utf-8")
documents = loader.load()

# NOTE(review): chunk_size=10 is smaller than most lines, and "\n" is tried
# before "\n\n". In practice each line (or a run of very short lines) becomes
# its own chunk. This is presumably intended as per-line chunking. Confirm.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=True,
    separators=["\n", "\n\n"],
)
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))

ids = [f"{ID_PREFIX}{i}" for i in range(len(docs))]

# --- Embed and upsert --------------------------------------------------------
embedding_model = SentenceTransformerEmbeddings(
    model_name=EMBEDDING_MODEL_PATH,
    model_kwargs={"device": DEVICE},
)
# A single HTTP client serves both the LangChain upsert and the native
# collection handle below.
client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
# To rebuild the collection from scratch instead of upserting:
# client.delete_collection(COLLECTION_NAME)

# Insert vectors (if an id already exists, its vector is updated).
db = Chroma.from_documents(
    documents=docs,
    embedding=embedding_model,
    ids=ids,
    collection_name=COLLECTION_NAME,
    client=client,
)

# --- Query-side handles ------------------------------------------------------
# NOTE(review): this constructs a second bge-m3 model on DEVICE; the LangChain
# wrapper above holds its own copy. If GPU memory is tight, consider an adapter
# that reuses `embedding_model` as the chromadb embedding function.
query_embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=EMBEDDING_MODEL_PATH,
    device=DEVICE,
)
collection = client.get_collection(COLLECTION_NAME, embedding_function=query_embedding_fn)
reranker_model = CrossEncoder(RERANKER_MODEL_PATH, max_length=512, device=DEVICE)