update chroma and chat

This commit is contained in:
0Xiao0
2024-06-02 15:41:07 +08:00
parent a96e845807
commit 179281f032
9 changed files with 15174 additions and 101 deletions


@@ -0,0 +1,196 @@
import os
import time
import chromadb
from chromadb.config import Settings
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
def get_all_files(folder_path):
    # List every entry (file or directory) directly under folder_path
    files = os.listdir(folder_path)
    # Collect the absolute paths of regular files only (non-recursive)
    absolute_paths = []
    for file in files:
        absolute_path = os.path.join(folder_path, file)
        if os.path.isfile(absolute_path):
            absolute_paths.append(absolute_path)
    return absolute_paths
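# A minimal usage sketch (assuming the "./text" folder used below): get_all_files
# is non-recursive, so one way to restrict it to PDFs is a suffix filter.
# pdf_files = [p for p in get_all_files("./text") if p.endswith(".pdf")]
# print(len(pdf_files), "PDF files found")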
# start_time = time.time()
# # Load and split every document in the folder
# folder_path = "./text"
# txt_files = get_all_files(folder_path)
# docs = []
# ids = []
# for txt_file in txt_files:
#     loader = PyPDFLoader(txt_file)
#     documents = loader.load()
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
#     docs_txt = text_splitter.split_documents(documents)
#     docs.extend(docs_txt)
#     ids.extend([os.path.basename(txt_file) + str(i) for i in range(len(docs_txt))])
# start_time1 = time.time()
# print(start_time1 - start_time)
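# The loop above can also be expressed with the already-imported DirectoryLoader,
# which applies one loader class to every file matching a glob. A sketch,
# assuming the same "./text" folder of PDFs:
# loader = DirectoryLoader("./text", glob="**/*.pdf", loader_cls=PyPDFLoader)
# documents = loader.load()
# docs = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0).split_documents(documents)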
# loader = PyPDFLoader("/code/memory/text/大语言模型应用.pdf")
# loader = TextLoader("/code/memory/text/test.txt")
# loader = CSVLoader("/code/memory/text/test1.csv")
# loader = UnstructuredHTMLLoader("example_data/fake-content.html")
# pip install docx2txt
# loader = Docx2txtLoader("/code/memory/text/tesou.docx")
# pip install openpyxl
# loader = UnstructuredExcelLoader("/code/memor…")
# inject_prompt = '(用活泼的语气说话,回答严格限制50字以内)'  # (speak in a lively tone; answers strictly limited to 50 characters)
# inject_prompt = '(回答简练,不要输出重复内容,只讲中文)'  # (answer concisely, no repeated content, Chinese only)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
# docs = text_splitter.split_documents(documents)
# print(len(docs))
# ids = ["大语言模型应用"+str(i) for i in range(len(docs))]
# Load and split the document
# loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_zh.txt")
# documents = loader.load()
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
# docs = text_splitter.split_documents(documents)
# print("len(docs)", len(docs))
# ids = ["20240521_store"+str(i) for i in range(len(docs))]
# # Load the embedding model and connect to the Chroma server
# embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
# client = chromadb.HttpClient(host='172.16.5.8', port=7000)
# id = "g2e"
# client.delete_collection(id)
# collection_number = client.get_or_create_collection(id).count()
# print("collection_number",collection_number)
# start_time2 = time.time()
# # Insert vectors (if the ids already exist, their vectors are updated instead)
# db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
# # db = Chroma.from_texts(texts=['test by tom'], embedding=embedding_model, ids=["大语言模型应用0"], persist_directory="./data/test1", collection_name="123", metadatas=[{"source": "string"}])
# start_time3 = time.time()
# print("insert time ", start_time3 - start_time2)
# collection_number = client.get_or_create_collection(id).count()
# print("collection_number",collection_number)
# Chroma retrieval
from chromadb.utils import embedding_functions
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device="cuda")
client = chromadb.HttpClient(host='172.16.5.8', port=7000)
collection = client.get_collection("g2e", embedding_function=embedding_model)
print(collection.count())
start_time = time.time()
query = "你知道澳门银河吗"  # ("Do you know Galaxy Macau?")
# Query the collection for the 5 nearest chunks
results = collection.query(
    query_texts=[query],
    n_results=5,
)
response = results["documents"]
print("response: ", response)
print("time: ", time.time() - start_time)
# # Summarize the retrieved chunks with an LLM
# import requests
# model_name = "Qwen1.5-14B-Chat"
# chat_inputs = {
#     "model": model_name,
#     "messages": [
#         {
#             "role": "user",
#             # Prompt (Chinese): answer the question clearly and concisely using
#             # the knowledge-base retrieval results; pick only information closely
#             # related to the question; do not fabricate; if the answer is not in
#             # the verified material, reply "我无法回答您的问题。"
#             # ("I cannot answer your question.")
#             "content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
#         }
#     ],
#     # "temperature": 0,
#     # "top_p": user_top_p,
#     # "n": user_n,
#     # "max_tokens": user_max_tokens,
#     # "frequency_penalty": user_frequency_penalty,
#     # "presence_penalty": user_presence_penalty,
#     # "stop": 100
# }
# key = "YOUR_API_KEY"
# header = {
#     'Content-Type': 'application/json',
#     'Authorization': "Bearer " + key
# }
# url = "http://172.16.5.8:23333/v1/chat/completions"
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# # print(fastchat_response.json())
# print("\n question: ", query)
# print("\n ", model_name, fastchat_response.json()["choices"][0]["message"]["content"])
# start_time4 = time.time()
# db = Chroma(
#     client=client,
#     collection_name=id,
#     embedding_function=embedding_model,
# )
# Update documents in place by id (update_documents returns None)
# db.update_documents(ids, documents)
# Delete documents by id (ids is already a list)
# db.delete(ids)
# Drop the whole collection
# db.delete_collection()
# query = "智能体核心思想"  # ("core ideas of intelligent agents")
# docs = db.similarity_search(query, k=2)
# print("result: ", docs)
# for doc in docs:
#     print(doc, "\n")
# start_time5 = time.time()
# print("search time ", start_time5 - start_time4)
# docs = db._collection.get(ids=['大语言模型应用0'])
# print(docs)
# docs = db.get(where={"source": "text/大语言模型应用.pdf"})
# docs = db.get()
# print(docs)
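# similarity_search returns documents only; similarity_search_with_score is the
# LangChain variant that also returns a distance per hit. A sketch, reusing the
# db handle and query from above:
# for doc, score in db.similarity_search_with_score(query, k=2):
#     print(score, doc.page_content[:80])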