add chroma upsert

This commit is contained in:
ACBBZ
2024-05-07 01:58:06 +00:00
parent 62b37e7f20
commit 62c3c6852b
2 changed files with 144 additions and 1 deletions

View File

@ -9,6 +9,7 @@ from .tesou import Tesou
from .fastchat import Fastchat from .fastchat import Fastchat
from .g2e import G2E from .g2e import G2E
from .text_and_image import TextAndImage from .text_and_image import TextAndImage
from .chroma_query import ChromaQuery
from injector import inject, singleton from injector import inject, singleton
@singleton @singleton
@ -26,7 +27,8 @@ class BlackboxFactory:
fastchat: Fastchat, fastchat: Fastchat,
audio_chat: AudioChat, audio_chat: AudioChat,
g2e: G2E, g2e: G2E,
text_and_image:TextAndImage ) -> None: text_and_image:TextAndImage,
chroma_query: ChromaQuery) -> None:
self.models["audio_to_text"] = audio_to_text self.models["audio_to_text"] = audio_to_text
self.models["text_to_audio"] = text_to_audio self.models["text_to_audio"] = text_to_audio
self.models["asr"] = asr self.models["asr"] = asr
@ -37,6 +39,7 @@ class BlackboxFactory:
self.models["audio_chat"] = audio_chat self.models["audio_chat"] = audio_chat
self.models["g2e"] = g2e self.models["g2e"] = g2e
self.models["text_and_image"] = text_and_image self.models["text_and_image"] = text_and_image
self.models["chroma_query"] = chroma_query
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)

140
src/blackbox/chroma_upsert.py Executable file
View File

@ -0,0 +1,140 @@
from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import requests
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import chromadb
from injector import singleton
@singleton
class ChromaUpsert(Blackbox):
def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0])
# load embedding model
self.embedding_model = SentenceTransformerEmbeddings(model_name='/model/Weight/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
# load chroma db
self.client = chromadb.HttpClient(host='10.6.82.192', port=8000)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def valid(self, *args, **kwargs) -> bool:
data = args[0]
return isinstance(data, list)
def processing(self, collection_id, file, string, context, setting) -> str:
# 用户的操作历史
if context is None:
context = []
context = [
{
"collection_id": "123",
"action": "query",
"content": "你吃饭了吗"
},
{
"collection_id": "123",
"action": "upsert",
"content": "file_name or string"
},
]
# 用户的配置文件 每次操作都会更新
setting = {
# collection_name
"collections": ["123", "collection_id2"],
# 插入的字符串的id从1开始
"ids_string": [0, 0],
# 插入的文件的文件名和ids
"ids_file": [
# collection_id1 插入的文件 和 对应的ids列表
{
"file_name1": ["file_name1", ids],
"file_name2": ["file_name2", ["1","2","3","4"]]
},
# collection_id2的文件和ids
{}
]
}
if collection_id is None and setting["collections"][0] != []:
collection_id = setting["collections"][0]
else:
collection_id = "123"
if file is not None:
file_type = file.split(".")[-1]
if file_type == "pdf":
loader = PyPDFLoader(file)
elif file_type == "txt":
loader = TextLoader(file)
elif file_type == "csv":
loader = CSVLoader(file)
elif file_type == "html":
loader = UnstructuredHTMLLoader(file)
elif file_type == "json":
loader = JSONLoader(file, jq_schema='.', text_content=False)
elif file_type == "docx":
loader = Docx2txtLoader(file)
elif file_type == "xlsx":
loader = UnstructuredExcelLoader(file)
loader = PyPDFLoader(file)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
ids = [str(file)+str(i) for i in range(len(docs))]
Chroma.from_documents(documents=docs, embedding=self.embedding_model, ids=ids, collection_name=collection_id, client=self.client)
if string is not None:
# 生成一个新的id ids_string: 1
ids = setting['ids_string'][0] + 1
Chroma.from_texts(texts=[string], embedding=self.embedding_model, ids=[ids], collection_name=collection_id, client=self.client)
collection_number = self.client.get_collection(collection_id).count()
response = f"collection {collection_id} has {collection_number} documents."
return response
async def fast_api_handler(self, request: Request) -> Response:
try:
data = await request.json()
except:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_collection_id = data.get("collection_id")
user_file = data.get("file")
user_string = data.get("string")
user_context = data.get("context")
user_setting = data.get("setting")
if user_collection_id is None and user_setting["collections"] == []:
return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST)
if user_file is None and user_string is None:
return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST)
return JSONResponse(
content={"response": self.processing(user_collection_id, user_file, user_string, user_context, user_setting)},
status_code=status.HTTP_200_OK)