mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
73 lines
2.7 KiB
Python
73 lines
2.7 KiB
Python
from sentence_transformers import CrossEncoder
|
||
import chromadb
|
||
from chromadb.utils import embedding_functions
|
||
import numpy as np
|
||
from langchain_community.document_loaders import TextLoader
|
||
from langchain_community.vectorstores import Chroma
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
||
import time
|
||
|
||
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0

# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")

# Load the knowledge-base text file.  The explicit encoding restores the intent
# of the commented-out loader above: without it, TextLoader falls back to the
# platform default encoding, which corrupts the Chinese corpus on non-UTF-8
# locales (e.g. Windows cp936).
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt", encoding="utf-8")
documents = loader.load()

# Split coarse-to-fine: paragraphs first, then single lines.  In the original
# order ['\n', '\n\n'] the '\n\n' separator was unreachable, because the
# splitter tries separators in sequence and '\n' already matches everywhere
# '\n\n' does.  The separators are literal strings, so regex matching is
# unnecessary and is turned off.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10,          # very small chunks: roughly one short line of text
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
    separators=["\n\n", "\n"],
)
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))

# One stable id per chunk, so re-running the script upserts instead of
# inserting duplicates.
ids = ["粤语语料" + str(i) for i in range(len(docs))]
|
||
|
||
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
|
||
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
|
||
|
||
id = "kiki"
|
||
# client.delete_collection(id)
|
||
# 插入向量(如果ids已存在,则会更新向量)
|
||
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
|
||
|
||
|
||
|
||
|
||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
|
||
|
||
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
|
||
|
||
collection = client.get_collection(id, embedding_function=embedding_model)
|
||
|
||
reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
|
||
|
||
# while True:
|
||
# usr_question = input("\n 请输入问题: ")
|
||
# # query it
|
||
# time1 = time.time()
|
||
# results = collection.query(
|
||
# query_texts=[usr_question],
|
||
# n_results=10,
|
||
# )
|
||
# time2 = time.time()
|
||
# print("query time: ", time2 - time1)
|
||
|
||
# # print("query: ",usr_question)
|
||
# # print("results: ",print(results["documents"][0]))
|
||
|
||
|
||
# pairs = [[usr_question, doc] for doc in results["documents"][0]]
|
||
# # print('\n',pairs)
|
||
# scores = reranker_model.predict(pairs)
|
||
|
||
# # Reorder the documents by rerank score:
|
||
# print("New Ordering:")
|
||
# i = 0
|
||
# final_result = ''
|
||
# for o in np.argsort(scores)[::-1]:
|
||
# if i == 3 or scores[o] < 0.5:
|
||
# break
|
||
# i += 1
|
||
# print(o+1)
|
||
# print("Scores:", scores[o])
|
||
# print(results["documents"][0][o],'\n')
|
||
# final_result += results["documents"][0][o] + '\n'
|
||
|
||
# print("\n final_result: ", final_result)
|
||
# time3 = time.time()
|
||
# print("rerank time: ", time3 - time2) |