Files
jarvis-models/src/blackbox/chroma_upsert.py
2024-10-28 17:38:40 +08:00

193 lines
8.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import requests
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import chromadb
import os
import tempfile
import logging
from ..log.logging_time import logging_time
import re
# Module-level logger. getLogger must be *called* to obtain a Logger instance;
# binding the bare function would make logger.info(...) fail.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced by ChromaUpsert below, which falls back to the
# "g2e" collection instead. Confirm whether other modules import this constant.
DEFAULT_COLLECTION_ID = "123"
from injector import singleton
@singleton
class ChromaUpsert(Blackbox):
    """Blackbox that upserts an uploaded file and/or a raw string into Chroma.

    Text is embedded with a SentenceTransformer model and written to a
    collection on a Chroma HTTP server. The default embedding model and the
    HTTP client for the default server are created once in ``__init__`` and
    reused; any other model or server named in ``settings`` is created on
    every call.
    """

    # Fallbacks used when a setting is missing or blank.
    DEFAULT_EMBEDDING_MODEL = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
    # The default server must be the one the cached client connects to.
    # NOTE(review): the previous defaults named 10.6.82.192:8000 while the
    # cached client pointed at 192.168.0.200:7000, and a substring regex routed
    # requests for the former to the latter. The defaults now name the server
    # that default requests were actually written to.
    DEFAULT_CHROMA_HOST = "192.168.0.200"
    DEFAULT_CHROMA_PORT = "7000"
    DEFAULT_COLLECTION = "g2e"
    EMBEDDING_DEVICE = "cuda"
    # Text splitting parameters for file uploads (chunk_size is in characters).
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 0

    # Lower-cased file extension -> callable that builds a loader for a path.
    _LOADERS = {
        "pdf": PyPDFLoader,
        "txt": TextLoader,
        "csv": CSVLoader,
        "html": UnstructuredHTMLLoader,
        "json": lambda path: JSONLoader(path, jq_schema=".", text_content=False),
        "docx": Docx2txtLoader,
        "xlsx": UnstructuredExcelLoader,
    }

    _log = logging.getLogger(__name__)

    def __init__(self, *args, **kwargs) -> None:
        # Load the default embedding model once and reuse it across calls.
        self.embedding_model_1 = SentenceTransformerEmbeddings(
            model_name=self.DEFAULT_EMBEDDING_MODEL,
            model_kwargs={"device": self.EMBEDDING_DEVICE},
        )
        # Shared client for the default Chroma server.
        self.client_1 = chromadb.HttpClient(
            host=self.DEFAULT_CHROMA_HOST, port=int(self.DEFAULT_CHROMA_PORT)
        )

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list.

        NOTE(review): ``processing`` takes a file path (or None) as its first
        argument, not a list; confirm how the Blackbox base uses ``valid``.
        """
        data = args[0]
        return isinstance(data, list)

    @staticmethod
    def _is_blank(value: Any) -> bool:
        """Return True for None, empty, or whitespace-only values of any type."""
        return value is None or str(value).strip() == ""

    def _resolve_settings(self, settings: dict) -> tuple:
        """Return ``(embedding_model, host, port, collection_id)`` for ``settings``.

        Values may sit at the top level or under a nested ``"settings"`` key;
        missing or blank values are replaced by the class defaults.

        Raises:
            ValueError: if the nested ``"settings"`` value is not a dict.
        """
        source = settings["settings"] if "settings" in settings else settings
        if not isinstance(source, dict):
            raise ValueError("Invalid settings format")

        def pick(key: str, default: str) -> Any:
            value = source.get(key)
            if self._is_blank(value):
                return default
            # Strip stray whitespace from strings; keep e.g. an int port as-is.
            return value.strip() if isinstance(value, str) else value

        return (
            pick("chroma_embedding_model", self.DEFAULT_EMBEDDING_MODEL),
            pick("chroma_host", self.DEFAULT_CHROMA_HOST),
            pick("chroma_port", self.DEFAULT_CHROMA_PORT),
            pick("chroma_collection_id", self.DEFAULT_COLLECTION),
        )

    def _get_client(self, host: Any, port: Any):
        """Return the cached client for the default server, else a new HttpClient."""
        # Exact comparison: the cached client is only valid for its own server.
        if str(host) == self.DEFAULT_CHROMA_HOST and str(port) == self.DEFAULT_CHROMA_PORT:
            return self.client_1
        return chromadb.HttpClient(host=host, port=port)

    def _get_embedding_model(self, model_name: str):
        """Return the cached default embedding model, else load ``model_name``."""
        if model_name == self.DEFAULT_EMBEDDING_MODEL:
            return self.embedding_model_1
        # The device goes through model_kwargs (as in __init__); a bare
        # ``device=`` keyword is not a field of the embeddings class.
        return SentenceTransformerEmbeddings(
            model_name=model_name, model_kwargs={"device": self.EMBEDDING_DEVICE}
        )

    def _build_loader(self, file: str):
        """Return a document loader for ``file`` chosen by its extension.

        Raises:
            ValueError: if the extension is not one of ``_LOADERS``.
        """
        extension = os.path.splitext(file)[1][1:].lower()
        self._log.info("file_type: %s", extension)
        factory = self._LOADERS.get(extension)
        if factory is None:
            raise ValueError(f"Unsupported file type: {extension or '(none)'}")
        return factory(file)

    def _upsert_file(self, file: str, embedding_model, client, collection_id: str) -> str:
        """Load, split, and upsert ``file``; return a summary of the collection."""
        documents = self._build_loader(file).load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE, chunk_overlap=self.CHUNK_OVERLAP
        )
        docs = splitter.split_documents(documents)
        # Chunk ids are "<path><index>"; presumably chosen so that re-uploading
        # the same path replaces its earlier chunks by id.
        ids = [f"{file}{index}" for index in range(len(docs))]
        Chroma.from_documents(
            documents=docs,
            embedding=embedding_model,
            ids=ids,
            collection_name=collection_id,
            client=client,
        )
        count = client.get_collection(collection_id).count()
        return f"collection {collection_id} has {count} documents. {file} ids is 0~{len(docs)-1}"

    def _upsert_string(self, string: str, embedding_model, client, collection_id: str) -> str:
        """Upsert ``string`` as one document; return a summary of the collection."""
        # TODO: generate a fresh id per string. With the fixed id "1" every
        # string upsert overwrites the previously stored string document.
        doc_id = "1"
        Chroma.from_texts(
            texts=[string],
            embedding=embedding_model,
            ids=[doc_id],
            collection_name=collection_id,
            client=client,
        )
        count = client.get_collection(collection_id).count()
        return f"collection {collection_id} has {count} documents. {string} ids is {doc_id}"

    def processing(self, file, string, context: list, settings: dict) -> str:
        """Upsert a file and/or a string into a Chroma collection.

        Args:
            file: Path of a local file to load, split into chunks of up to
                ``CHUNK_SIZE`` characters, and upsert; or None. Supported
                extensions: pdf, txt, csv, html, json, docx, xlsx
                (case-insensitive).
            string: Raw text to upsert as a single document; or None.
            context: The user's action history (example entries carry
                ``collection_id``, ``action``, ``content``, ``answer``).
                Accepted for interface compatibility; not used here.
            settings: Optional dict with ``chroma_embedding_model``,
                ``chroma_host``, ``chroma_port`` and ``chroma_collection_id``,
                either at top level or nested under ``"settings"``. Missing or
                blank values fall back to the class defaults.

        Returns:
            A summary of the collection size after each upsert. When both
            ``file`` and ``string`` are given, the file summary comes first and
            the two are joined by a newline and "and".

        Raises:
            ValueError: if neither ``file`` nor ``string`` is given, the
                settings are malformed, or the file type is unsupported.
        """
        if settings is None:
            settings = {}
        if not isinstance(settings, dict):
            raise ValueError("Invalid settings format")
        if file is None and string is None:
            raise ValueError("file or string is required")

        self._log.info("Settings: %s", settings)
        model_name, host, port, collection_id = self._resolve_settings(settings)
        client = self._get_client(host, port)
        self._log.info("chroma_embedding_model: %s", model_name)
        embedding_model = self._get_embedding_model(model_name)

        responses = []
        if file is not None:
            responses.append(self._upsert_file(file, embedding_model, client, collection_id))
        if string is not None:
            responses.append(self._upsert_string(string, embedding_model, client, collection_id))
        return " \n and ".join(responses)

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a multipart form request to upsert a file and/or a string.

        Form fields: ``file`` (optional upload), ``string`` (optional text),
        ``context`` (passed through to ``processing``), and ``settings``
        (optional JSON object). Returns 200 ``{"text": ...}`` on success and
        400 ``{"error": ...}`` for invalid input.
        """
        form = await request.form()
        user_file = form.get("file")
        user_string = form.get("string")
        context = form.get("context")
        setting = form.get("settings")

        if isinstance(setting, str):
            try:
                setting = json.loads(setting)
            except json.JSONDecodeError:
                return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST)
        if user_file is None and user_string is None:
            return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        temp_path = None
        if user_file is not None:
            # basename() drops any client-supplied directories so the write
            # stays inside the temp directory.
            filename = os.path.basename(getattr(user_file, "filename", None) or "")
            if not filename:
                return JSONResponse(content={"error": "file must be an upload with a filename"}, status_code=status.HTTP_400_BAD_REQUEST)
            payload = await user_file.read()
            # Chunk ids are derived from this path in _upsert_file, so keep it
            # stable ("<tempdir>/<basename>") across uploads of the same file.
            temp_path = os.path.join(tempfile.gettempdir(), filename)
            with open(temp_path, "wb") as handle:
                handle.write(payload)

        try:
            txt = self.processing(temp_path, user_string, context, setting)
        except ValueError as e:
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
        finally:
            # processing() has finished reading the file by now; remove it so
            # uploads do not accumulate. Best effort: a failed delete must not
            # turn a completed upsert into an error.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    self._log.warning("could not remove temporary file %s", temp_path)

        self._log.info("%s", txt)
        return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)