Mirror of https://github.com/BoardWare-Genius/jarvis-models.git, synced 2025-12-13 16:53:24 +00:00
14353 sample/20240529_store.txt Normal file
File diff suppressed because one or more lines are too long
199 sample/chroma_client1.py Normal file
@@ -0,0 +1,199 @@
import os
import time
import chromadb
from chromadb.config import Settings

from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings, HuggingFaceEmbeddings


def get_all_files(folder_path):
    # List all file and folder names directly under folder_path
    files = os.listdir(folder_path)

    # Collect the absolute paths of the regular files
    absolute_paths = []

    # Walk the name list
    for file in files:
        # Build the absolute path
        absolute_path = os.path.join(folder_path, file)
        # If it is a regular file, keep its absolute path
        if os.path.isfile(absolute_path):
            absolute_paths.append(absolute_path)

    return absolute_paths
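# Note: get_all_files only lists the top level of folder_path. If nested
# folders ever need to be indexed, a recursive variant (a sketch, using the
# standard os.walk) could look like:
#
# def get_all_files_recursive(folder_path):
#     paths = []
#     for root, _dirs, names in os.walk(folder_path):
#         paths.extend(os.path.join(root, name) for name in names)
#     return paths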

# start_time = time.time()
# # Load the documents
# folder_path = "./text"
# txt_files = get_all_files(folder_path)
# docs = []
# ids = []
# for txt_file in txt_files:
#     loader = PyPDFLoader(txt_file)
#     documents = loader.load()
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
#     docs_txt = text_splitter.split_documents(documents)
#     docs.extend(docs_txt)
#     ids.extend([os.path.basename(txt_file) + str(i) for i in range(len(docs_txt))])
# start_time1 = time.time()
# print(start_time1 - start_time)


# Loader examples for other formats:
# loader = PyPDFLoader("/code/memory/text/大语言模型应用.pdf")
# loader = TextLoader("/code/memory/text/test.txt")
# loader = CSVLoader("/code/memory/text/test1.csv")
# loader = UnstructuredHTMLLoader("example_data/fake-content.html")
# pip install docx2txt
# loader = Docx2txtLoader("/code/memory/text/tesou.docx")
# pip install openpyxl
# loader = UnstructuredExcelLoader("/code/memory/text/AI Team Planning 2023.xlsx")
# pip install jq
# loader = JSONLoader("/code/memory/text/config.json", jq_schema='.', text_content=False)

# documents = loader.load()
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
# docs = text_splitter.split_documents(documents)
# print(len(docs))
# ids = ["大语言模型应用" + str(i) for i in range(len(docs))]


# Load and split the document
loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/20240529_store.txt")
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))

ids = ["20240521_store" + str(i) for i in range(len(docs))]


# Load the embedding model and connect to the Chroma server
embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
client = chromadb.HttpClient(host='172.16.5.8', port=7000)

id = "g2e"
# delete_collection raises if the collection does not exist yet, so guard it
try:
    client.delete_collection(id)
except Exception:
    pass
collection_number = client.get_or_create_collection(id).count()
print("collection_number", collection_number)
start_time2 = time.time()
# Insert the vectors (if the ids already exist, their vectors are updated)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)

# db = Chroma.from_texts(texts=['test by tom'], embedding=embedding_model, ids=["大语言模型应用0"], persist_directory="./data/test1", collection_name="123", metadatas=[{"source": "string"}])

start_time3 = time.time()
print("insert time ", start_time3 - start_time2)


# Chroma retrieval
from chromadb.utils import embedding_functions
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device="cuda")
client = chromadb.HttpClient(host='172.16.5.8', port=7000)
collection = client.get_collection("g2e", embedding_function=embedding_model)

print(collection.count())
import time
start_time = time.time()
query = "如何前往威尼斯人"  # "How do I get to The Venetian?"
# query it
results = collection.query(
    query_texts=[query],
    n_results=3,
)

response = results["documents"]
print("response: ", response)
print("time: ", time.time() - start_time)


# Summarize the retrieved content with an LLM
import requests

model_name = "Qwen1.5-14B-Chat"
chat_inputs = {
    "model": model_name,
    "messages": [
        {
            "role": "user",
            "content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
        }
    ],
    # "temperature": 0,
    # "top_p": user_top_p,
    # "n": user_n,
    # "max_tokens": user_max_tokens,
    # "frequency_penalty": user_frequency_penalty,
    # "presence_penalty": user_presence_penalty,
    # "stop": 100
}

key = "YOUR_API_KEY"

header = {
    'Content-Type': 'application/json',
    'Authorization': "Bearer " + key
}
url = "http://172.16.5.8:23333/v1/chat/completions"

fastchat_response = requests.post(url, json=chat_inputs, headers=header)
# print(fastchat_response.json())

print("\n question: ", query)
print("\n ", model_name, fastchat_response.json()["choices"][0]["message"]["content"])


# start_time4 = time.time()
# db = Chroma(
#     client=client,
#     collection_name=id,
#     embedding_function=embedding_model,
# )

# Update documents
# db = db.update_documents(ids, documents)
# Delete documents (ids is already a list, so pass it directly)
# db.delete(ids)
# Delete the collection
# db.delete_collection()

# query = "智能体核心思想"
# docs = db.similarity_search(query, k=2)

# print("result: ", docs)
# for doc in docs:
#     print(doc, "\n")

# start_time5 = time.time()
# print("search time ", start_time5 - start_time4)

# docs = db._collection.get(ids=['大语言模型应用0'])

# print(docs)

# docs = db.get(where={"source": "text/大语言模型应用.pdf"})
# docs = db.get()
# print(docs)
@@ -5,17 +5,22 @@ from fastapi.responses import JSONResponse
 
 from ..log.logging_time import logging_time
 from .blackbox import Blackbox
+from .chroma_query import ChromaQuery
 
 import requests
 import json
 from openai import OpenAI
 import re
 
-from injector import singleton
+from injector import singleton,inject
 
 @singleton
 class Chat(Blackbox):
 
+    @inject
+    def __init__(self, chroma_query: ChromaQuery):
+        self.chroma_query = chroma_query
+
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -24,7 +29,7 @@ class Chat(Blackbox):
         return isinstance(data, list)
 
     # model_name options: Qwen1.5-14B-Chat, internlm2-chat-20b
-    @logging_time()
+    # @logging_time()
     def processing(self, prompt: str, context: list, settings: dict) -> str:
 
         if settings is None:
@@ -42,6 +47,9 @@ class Chat(Blackbox):
         user_presence_penalty = settings.get("presence_penalty")
         user_model_url = settings.get("model_url")
         user_model_key = settings.get("model_key")
+
+        chroma_embedding_model = settings.get("chroma_embedding_model")
+        chroma_response = ''
 
         if user_context == None:
             user_context = []
@@ -82,6 +90,14 @@ class Chat(Blackbox):
         if user_model_key is None or user_model_key.isspace() or user_model_key == "":
             user_model_key = "YOUR_API_KEY"
 
+        if chroma_embedding_model != None:
+            chroma_response = self.chroma_query(user_question, settings)
+            print(chroma_response)
+
+        if chroma_response != None and chroma_response != '':
+            user_question = f"问题: {user_question}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。只从检索的内容中选取与问题相关信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{chroma_response}"
+
+
         # The ERNIE (Wenxin) API format differs from OpenAI's and needs separate handling
         if re.search(r"ernie", user_model_name):
             # key = "24.22873ef3acf61fb343812681e4df251a.2592000.1719453781.282335-46723715"  # account has no credit; only ernie-speed-128k works
@@ -121,12 +137,17 @@ class Chat(Blackbox):
             url = 'https://api.openai.com/v1/completions'
             # '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'
             key = user_model_key
 
+            header = {
+                'Content-Type': 'application/json',
+                'Authorization': "Bearer " + key
+            }
+        # custom model
         else:
             url = user_model_url
             key = user_model_key
 
+            header = {
+                'Content-Type': 'application/json',
+            }
 
         prompt_template = [
             {"role": "system", "content": user_template},
@@ -149,11 +170,6 @@ class Chat(Blackbox):
             "stop": str(user_stop)
         }
 
-        header = {
-            'Content-Type': 'application/json',
-            'Authorization': "Bearer " + key
-        }
-
         fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 
         return fastchat_response.json()["choices"][0]["message"]["content"]
@@ -20,9 +20,9 @@ class ChromaQuery(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", device = "cuda")
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
         # self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda")
-        self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000)
+        self.client_1 = chromadb.HttpClient(host='172.16.5.8', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
 
     def __call__(self, *args, **kwargs):
@@ -32,7 +32,7 @@ class ChromaQuery(Blackbox):
         data = args[0]
         return isinstance(data, list)
 
-    @logging_time(logger=logger)
+    # @logging_time(logger=logger)
     def processing(self, question: str, settings: dict) -> str:
 
         if settings is None:
@@ -51,13 +51,13 @@ class ChromaQuery(Blackbox):
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5"
+            chroma_embedding_model = "/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5"
 
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "10.6.82.192"
+            chroma_host = "172.16.5.8"
 
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
-            chroma_port = "8000"
+            chroma_port = "7000"
 
         if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
             chroma_collection_id = "g2e"
@@ -66,12 +66,12 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 3
 
         # load client and embedding model from init
-        if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
+        if re.search(r"172.16.5.8", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
 
-        if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
+        if re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_1
         else:
             embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda")