Mirror of https://github.com/BoardWare-Genius/jarvis-models.git, synced 2025-12-13 16:53:24 +00:00
add chroma upsert
@@ -9,6 +9,7 @@ from .tesou import Tesou
 from .fastchat import Fastchat
 from .g2e import G2E
 from .text_and_image import TextAndImage
+from .chroma_query import ChromaQuery
 from injector import inject, singleton
 
 @singleton
@@ -26,7 +27,8 @@ class BlackboxFactory:
         fastchat: Fastchat,
         audio_chat: AudioChat,
         g2e: G2E,
-        text_and_image:TextAndImage ) -> None:
+        text_and_image:TextAndImage,
+        chroma_query: ChromaQuery) -> None:
         self.models["audio_to_text"] = audio_to_text
         self.models["text_to_audio"] = text_to_audio
         self.models["asr"] = asr
@@ -37,6 +39,7 @@ class BlackboxFactory:
         self.models["audio_chat"] = audio_chat
         self.models["g2e"] = g2e
         self.models["text_and_image"] = text_and_image
+        self.models["chroma_query"] = chroma_query
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
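For context, BlackboxFactory.__call__ above delegates to self.processing, whose body is not part of this diff. A minimal sketch of how such a name-keyed registry is typically dispatched (an assumption, not the project's actual code):

    # Hypothetical sketch: dispatch on a model name key, as the
    # self.models registry above suggests.
    class FactorySketch:
        def __init__(self, models: dict):
            self.models = models           # e.g. {"chroma_query": chroma_query, ...}

        def processing(self, name, *args, **kwargs):
            model = self.models[name]      # look up a registered Blackbox by name
            return model(*args, **kwargs)  # each model is callable via __call__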
|
|||||||
140
src/blackbox/chroma_upsert.py
Executable file
140
src/blackbox/chroma_upsert.py
Executable file
@@ -0,0 +1,140 @@
from typing import Any, Coroutine

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox

import requests
import json

from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import chromadb

from injector import singleton

@singleton
class ChromaUpsert(Blackbox):

    def __init__(self, *args, **kwargs) -> None:
        # config = read_yaml(args[0])
        # load embedding model
        self.embedding_model = SentenceTransformerEmbeddings(model_name='/model/Weight/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
        # load chroma db
        self.client = chromadb.HttpClient(host='10.6.82.192', port=8000)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        data = args[0]
        return isinstance(data, list)

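The constructor above hard-codes the embedding-model path and the Chroma host; the commented-out read_yaml line hints at config-driven setup. A sketch of that variant (the YAML key names embedding_model_path, chroma_host, and chroma_port are assumptions):

    # Sketch only: derive the hard-coded values from a YAML config instead.
    import yaml
    import chromadb
    from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

    def init_from_yaml(config_path: str):
        """Build the embedding model and Chroma client from a YAML config."""
        with open(config_path) as f:
            cfg = yaml.safe_load(f)
        embedding_model = SentenceTransformerEmbeddings(
            model_name=cfg["embedding_model_path"],
            model_kwargs={"device": cfg.get("device", "cuda")})
        client = chromadb.HttpClient(host=cfg["chroma_host"], port=int(cfg["chroma_port"]))
        return embedding_model, client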
    def processing(self, collection_id, file, string, context, setting) -> str:
        # context records the user's operation history, e.g.:
        # [
        #     {"collection_id": "123", "action": "query", "content": "你吃饭了吗"},
        #     {"collection_id": "123", "action": "upsert", "content": "file_name or string"},
        # ]
        if context is None:
            context = []

        # setting is the per-user configuration, updated after every operation
        if setting is None:
            setting = {
                # collection names
                "collections": ["123", "collection_id2"],
                # id counters for inserted strings, one per collection, counting from 1
                "ids_string": [0, 0],
                # file names and ids of inserted files
                "ids_file": [
                    # files inserted into the first collection and their ids (lists)
                    {
                        "file_name1": ["file_name1", ["1", "2"]],
                        "file_name2": ["file_name2", ["1", "2", "3", "4"]],
                    },
                    # files and ids of the second collection
                    {},
                ],
            }

        if collection_id is None:
            # fall back to the first known collection, else a default name
            if setting["collections"]:
                collection_id = setting["collections"][0]
            else:
                collection_id = "123"

        if file is not None:
            file_type = file.split(".")[-1]
            if file_type == "pdf":
                loader = PyPDFLoader(file)
            elif file_type == "txt":
                loader = TextLoader(file)
            elif file_type == "csv":
                loader = CSVLoader(file)
            elif file_type == "html":
                loader = UnstructuredHTMLLoader(file)
            elif file_type == "json":
                loader = JSONLoader(file, jq_schema='.', text_content=False)
            elif file_type == "docx":
                loader = Docx2txtLoader(file)
            elif file_type == "xlsx":
                loader = UnstructuredExcelLoader(file)
            else:
                # unrecognized extension: default to the PDF loader
                loader = PyPDFLoader(file)
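The if/elif chain above maps file extensions to LangChain document loaders. An equivalent table-driven form, as a sketch using only the loaders the file already imports:

    from langchain_community.document_loaders.csv_loader import CSVLoader
    from langchain_community.document_loaders import (PyPDFLoader, TextLoader,
        UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader)

    # Extension -> loader factory; JSONLoader needs extra arguments, hence the lambda.
    LOADERS = {
        "pdf": PyPDFLoader,
        "txt": TextLoader,
        "csv": CSVLoader,
        "html": UnstructuredHTMLLoader,
        "json": lambda f: JSONLoader(f, jq_schema='.', text_content=False),
        "docx": Docx2txtLoader,
        "xlsx": UnstructuredExcelLoader,
    }

    def make_loader(file: str):
        file_type = file.split(".")[-1]
        return LOADERS.get(file_type, PyPDFLoader)(file)  # same PDF fallback as above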
            documents = loader.load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)

            docs = text_splitter.split_documents(documents)

            # one deterministic id per chunk, e.g. "doc.pdf0", "doc.pdf1", ...
            ids = [str(file)+str(i) for i in range(len(docs))]

            Chroma.from_documents(documents=docs, embedding=self.embedding_model, ids=ids, collection_name=collection_id, client=self.client)

        if string is not None:
            # generate a new id for the inserted string (ids_string counts from 1);
            # Chroma ids must be strings, hence str()
            ids = setting['ids_string'][0] + 1

            Chroma.from_texts(texts=[string], embedding=self.embedding_model, ids=[str(ids)], collection_name=collection_id, client=self.client)

        collection_number = self.client.get_collection(collection_id).count()
        response = f"collection {collection_id} has {collection_number} documents."

        return response
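The comments above describe setting as updated after every operation, yet processing never writes it back. A sketch of that bookkeeping, following the documented structure (the update itself is an assumption, not code from the commit):

    def record_upsert(setting: dict, collection_id: str, file=None, ids=None, string=None):
        """Record what was just inserted, mirroring the setting shape documented above."""
        idx = setting["collections"].index(collection_id)
        if file is not None:
            setting["ids_file"][idx][file] = [file, ids]  # chunk ids for this file
        if string is not None:
            setting["ids_string"][idx] += 1               # advance the string id counter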
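To confirm an upsert landed, the same client and embedding model can read the collection back through LangChain's Chroma wrapper (a sketch; the probe text is arbitrary):

    from langchain_community.vectorstores import Chroma

    def verify_upsert(client, embedding_model, collection_id: str, probe: str):
        """Similarity-search the collection that was just written."""
        store = Chroma(collection_name=collection_id,
                       embedding_function=embedding_model,
                       client=client)
        return store.similarity_search(probe, k=2)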
    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)

        user_collection_id = data.get("collection_id")
        user_file = data.get("file")
        user_string = data.get("string")
        user_context = data.get("context")
        user_setting = data.get("setting")

        if user_collection_id is None and (user_setting is None or user_setting["collections"] == []):
            return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST)

        if user_file is None and user_string is None:
            return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        return JSONResponse(
            content={"response": self.processing(user_collection_id, user_file, user_string, user_context, user_setting)},
            status_code=status.HTTP_200_OK)
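A hypothetical client call for the handler above (the mount path /chroma_upsert is an assumption; the diff does not show how routes are registered):

    import requests

    payload = {
        "collection_id": "123",
        "file": None,            # or a server-side path such as "doc.pdf"
        "string": "hello world",
        "context": [],
        "setting": {"collections": ["123"], "ids_string": [0], "ids_file": [{}]},
    }
    r = requests.post("http://localhost:8000/chroma_upsert", json=payload)
    print(r.json())              # {"response": "collection 123 has N documents."}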