From 62c3c6852b24fc9f1c9218093dcb6c39091f8970 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Tue, 7 May 2024 01:58:06 +0000 Subject: [PATCH] add chroma upsert --- src/blackbox/blackbox_factory.py | 5 +- src/blackbox/chroma_upsert.py | 140 +++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) create mode 100755 src/blackbox/chroma_upsert.py diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 55ad618..0945cfa 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -9,6 +9,7 @@ from .tesou import Tesou from .fastchat import Fastchat from .g2e import G2E from .text_and_image import TextAndImage +from .chroma_query import ChromaQuery from injector import inject, singleton @singleton @@ -26,7 +27,8 @@ class BlackboxFactory: fastchat: Fastchat, audio_chat: AudioChat, g2e: G2E, - text_and_image:TextAndImage ) -> None: + text_and_image:TextAndImage, + chroma_query: ChromaQuery) -> None: self.models["audio_to_text"] = audio_to_text self.models["text_to_audio"] = text_to_audio self.models["asr"] = asr @@ -37,6 +39,7 @@ class BlackboxFactory: self.models["audio_chat"] = audio_chat self.models["g2e"] = g2e self.models["text_and_image"] = text_and_image + self.models["chroma_query"] = chroma_query def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py new file mode 100755 index 0000000..0a05c35 --- /dev/null +++ b/src/blackbox/chroma_upsert.py @@ -0,0 +1,140 @@ +from typing import Any, Coroutine + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from .blackbox import Blackbox + +import requests +import json + +from langchain_community.document_loaders.csv_loader import CSVLoader +from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader +from langchain_community.vectorstores import Chroma +from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter +from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings +import chromadb + +from injector import singleton +@singleton +class ChromaUpsert(Blackbox): + + def __init__(self, *args, **kwargs) -> None: + # config = read_yaml(args[0]) + # load embedding model + self.embedding_model = SentenceTransformerEmbeddings(model_name='/model/Weight/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"}) + # load chroma db + self.client = chromadb.HttpClient(host='10.6.82.192', port=8000) + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, *args, **kwargs) -> bool: + data = args[0] + return isinstance(data, list) + + def processing(self, collection_id, file, string, context, setting) -> str: + # 用户的操作历史 + if context is None: + context = [] + + context = [ + { + "collection_id": "123", + "action": "query", + "content": "你吃饭了吗" + }, + { + "collection_id": "123", + "action": "upsert", + "content": "file_name or string" + }, + ] + + # 用户的配置文件 每次操作都会更新 + setting = { + # collection_name + "collections": ["123", "collection_id2"], + # 插入的字符串的id,从1开始 + "ids_string": [0, 0], + # 插入的文件的文件名和ids + "ids_file": [ + # collection_id1 插入的文件 和 对应的ids(列表) + { + "file_name1": ["file_name1", ids], + "file_name2": ["file_name2", ["1","2","3","4"]] + }, + # collection_id2的文件和ids + {} + ] + } + + + + if collection_id is None and setting["collections"][0] != []: + collection_id = setting["collections"][0] + else: + collection_id = "123" + + if file is not None: + file_type = file.split(".")[-1] + if file_type == "pdf": + loader = PyPDFLoader(file) + elif file_type == "txt": + loader = TextLoader(file) + elif file_type == "csv": + loader = CSVLoader(file) + elif file_type == "html": + loader = UnstructuredHTMLLoader(file) + elif file_type == "json": + loader = JSONLoader(file, jq_schema='.', text_content=False) + elif file_type == "docx": + loader = Docx2txtLoader(file) + elif file_type == "xlsx": + loader = UnstructuredExcelLoader(file) + + + loader = PyPDFLoader(file) + documents = loader.load() + text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0) + + docs = text_splitter.split_documents(documents) + + ids = [str(file)+str(i) for i in range(len(docs))] + + Chroma.from_documents(documents=docs, embedding=self.embedding_model, ids=ids, collection_name=collection_id, client=self.client) + + if string is not None: + # 生成一个新的id ids_string: 1 + ids = setting['ids_string'][0] + 1 + + Chroma.from_texts(texts=[string], embedding=self.embedding_model, ids=[ids], collection_name=collection_id, client=self.client) + + + collection_number = self.client.get_collection(collection_id).count() + response = f"collection {collection_id} has {collection_number} documents." + + return response + + + + async def fast_api_handler(self, request: Request) -> Response: + try: + data = await request.json() + except: + return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) + + user_collection_id = data.get("collection_id") + user_file = data.get("file") + user_string = data.get("string") + user_context = data.get("context") + user_setting = data.get("setting") + + if user_collection_id is None and user_setting["collections"] == []: + return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST) + + if user_file is None and user_string is None: + return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + return JSONResponse( + content={"response": self.processing(user_collection_id, user_file, user_string, user_context, user_setting)}, + status_code=status.HTTP_200_OK) \ No newline at end of file