Files
jarvis-models/src/blackbox/chroma_upsert.py
2024-05-24 10:41:17 +08:00

146 lines
5.9 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import requests
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import chromadb
import os
import tempfile
from ..utils import chroma_setting
from injector import singleton
@singleton
class ChromaUpsert(Blackbox):
    """Blackbox that embeds files or raw strings and upserts them into a Chroma collection.

    The sentence-transformer embedding model and the Chroma HTTP client are
    created once in ``__init__``. ``processing`` loads/splits the input,
    writes it with ``Chroma.from_documents`` / ``Chroma.from_texts`` and
    reports the collection's document count afterwards.
    """

    # Defaults for the values that used to be hard-coded in __init__; each can
    # be overridden with the keyword argument of the same (lower-case) name.
    DEFAULT_MODEL_NAME = "/model/Weight/BAAI/bge-small-en-v1.5"
    DEFAULT_DEVICE = "cuda"
    DEFAULT_HOST = "10.6.82.192"
    DEFAULT_PORT = 8000
    # Collection used when neither the caller nor the setting names one.
    DEFAULT_COLLECTION_ID = "123"
    # Chunking parameters applied to uploaded files.
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 0
    # Lower-cased file extension -> callable(path) returning a LangChain loader.
    _LOADERS = {
        "pdf": PyPDFLoader,
        "txt": TextLoader,
        "csv": CSVLoader,
        "html": UnstructuredHTMLLoader,
        "json": lambda path: JSONLoader(path, jq_schema=".", text_content=False),
        "docx": Docx2txtLoader,
        "xlsx": UnstructuredExcelLoader,
    }

    def __init__(self, *args, **kwargs) -> None:
        """Load the embedding model and connect to the Chroma server.

        Optional keyword arguments: ``model_name``, ``device``, ``host``,
        ``port``; each defaults to the matching ``DEFAULT_*`` class constant.
        """
        self.embedding_model = SentenceTransformerEmbeddings(
            model_name=kwargs.get("model_name", self.DEFAULT_MODEL_NAME),
            model_kwargs={"device": kwargs.get("device", self.DEFAULT_DEVICE)},
        )
        self.client = chromadb.HttpClient(
            host=kwargs.get("host", self.DEFAULT_HOST),
            port=kwargs.get("port", self.DEFAULT_PORT),
        )

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list."""
        return bool(args) and isinstance(args[0], list)

    def processing(self, collection_id, file, string, context, setting: chroma_setting) -> str:
        """Upsert a local file and/or a raw string into a Chroma collection.

        Args:
            collection_id: Target collection name. When None, the first id
                configured in ``setting`` is used, falling back to
                ``DEFAULT_COLLECTION_ID``.
            file: Path of a local file to load, split and upsert, or None.
            string: Raw text to upsert as a single document, or None.
            context: The user's operation history. Accepted for interface
                compatibility; it is not used here.
            setting: Either an object exposing ``ChromaSetting.collection_ids``
                or a dict with a ``"collections"`` list; consulted only when
                ``collection_id`` is None.

        Returns:
            A summary line per upserted input, joined with ``" \\n and "``
            when both ``file`` and ``string`` are given; None when neither is.

        Raises:
            ValueError: ``file`` has an unsupported extension or yields no text.
        """
        collection_id = self._resolve_collection_id(collection_id, setting)
        responses = []
        if file is not None:
            responses.append(self._upsert_file(collection_id, file))
        if string is not None:
            responses.append(self._upsert_string(collection_id, string))
        if not responses:
            return None
        return " \n and ".join(responses)

    @staticmethod
    def _configured_collection_ids(setting):
        """Return the collection ids listed in ``setting``, or an empty list."""
        if setting is None:
            return []
        if isinstance(setting, dict):
            # Shape assumed from the HTTP handler's validation:
            # {"collections": [...]} — TODO confirm against the client.
            ids = setting.get("collections")
        else:
            ids = getattr(getattr(setting, "ChromaSetting", None), "collection_ids", None)
        return ids if ids else []

    def _resolve_collection_id(self, collection_id, setting):
        """Pick the caller's id, else the first configured id, else the default."""
        if collection_id is not None:
            return collection_id
        configured = self._configured_collection_ids(setting)
        if configured and configured[0] != []:
            return configured[0]
        return self.DEFAULT_COLLECTION_ID

    def _upsert_file(self, collection_id, file):
        """Load ``file``, split it into chunks, upsert them and return a summary line."""
        file_type = file.split(".")[-1].lower()
        print("file_type: ", file_type)
        make_loader = self._LOADERS.get(file_type)
        if make_loader is None:
            raise ValueError(
                f"unsupported file type {file_type!r}; expected one of {sorted(self._LOADERS)}"
            )
        documents = make_loader(file).load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE, chunk_overlap=self.CHUNK_OVERLAP
        )
        docs = text_splitter.split_documents(documents)
        if not docs:
            raise ValueError(f"no text could be extracted from {file}")
        # IDs are the file path followed by the chunk index, so uploading the
        # same path again reuses the chunk ids written previously.
        ids = [str(file) + str(i) for i in range(len(docs))]
        Chroma.from_documents(
            documents=docs,
            embedding=self.embedding_model,
            ids=ids,
            collection_name=collection_id,
            client=self.client,
        )
        collection_number = self.client.get_collection(collection_id).count()
        return f"collection {collection_id} has {collection_number} documents. {file} ids is 0~{len(docs)-1}"

    def _upsert_string(self, collection_id, string):
        """Upsert ``string`` as a single document and return a summary line."""
        # NOTE(review): every string is written under the fixed id "1", so each
        # string upsert replaces the previous one. The original TODO intended
        # to generate a fresh id (e.g. setting.ChromaSetting.string_ids[0] + 1).
        string_id = "1"
        Chroma.from_texts(
            texts=[string],
            embedding=self.embedding_model,
            ids=[string_id],
            collection_name=collection_id,
            client=self.client,
        )
        collection_number = self.client.get_collection(collection_id).count()
        return f"collection {collection_id} has {collection_number} documents. {string} ids is {string_id}"

    @staticmethod
    def _parse_setting(raw):
        """Decode a form ``setting`` field: JSON strings are parsed, other values pass through.

        Returns None for a string that is not valid JSON, so an unreadable
        setting is treated the same as a missing one.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return None
        return raw

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a multipart form upsert request.

        Form fields: ``collection_id``, ``file`` (upload), ``string``,
        ``context`` and ``setting`` (JSON). Returns 400 for missing or invalid
        input, otherwise 200 with ``{"response": <summary>}``.
        """
        form = await request.form()
        # An empty collection_id field is treated as absent.
        user_collection_id = form.get("collection_id") or None
        user_file = form.get("file")
        user_string = form.get("string")
        user_context = form.get("context")
        user_setting = self._parse_setting(form.get("setting"))

        if user_collection_id is None and not self._configured_collection_ids(user_setting):
            return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST)
        if user_file is None and user_string is None:
            return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        safe_filename = None
        if user_file is not None:
            if not hasattr(user_file, "read"):
                return JSONResponse(content={"error": "file must be an uploaded file"}, status_code=status.HTTP_400_BAD_REQUEST)
            # basename() strips any directory components a client put in the filename.
            filename = os.path.basename(user_file.filename or "")
            if not filename:
                return JSONResponse(content={"error": "uploaded file must have a filename"}, status_code=status.HTTP_400_BAD_REQUEST)
            # Write the upload into the system temp directory so the loaders can read it by path.
            safe_filename = os.path.join(tempfile.gettempdir(), filename)
            file_bytes = await user_file.read()
            with open(safe_filename, "wb") as f:
                f.write(file_bytes)

        try:
            result = self.processing(user_collection_id, safe_filename, user_string, user_context, user_setting)
        except ValueError as err:
            # Raised for unsupported file types / files with no extractable text.
            return JSONResponse(content={"error": str(err)}, status_code=status.HTTP_400_BAD_REQUEST)
        return JSONResponse(
            content={"response": result},
            status_code=status.HTTP_200_OK)