Files
jarvis-models/src/blackbox/chroma_upsert.py
2025-08-19 11:20:02 +08:00

180 lines
8.0 KiB
Python
Executable File

from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import requests
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader, UnstructuredPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import chromadb
import os
import tempfile
import logging
from ..log.logging_time import logging_time
import re
from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf
# Module-level logger. ``logging.getLogger`` has to be *called* to obtain a
# ``Logger`` instance; binding the bare function would leave ``logger``
# without ``.info()`` / ``.error()`` and unusable as a logger argument.
logger = logging.getLogger(__name__)

# Collection name used when a request does not supply "chroma_collection_id".
DEFAULT_COLLECTION_ID = "123"
from injector import singleton
@singleton
class ChromaUpsert(Blackbox):
    """Blackbox that embeds a file and/or a text snippet and upserts it into Chroma.

    A default embedding model and a default Chroma HTTP client are created
    once in ``__init__``. ``processing`` reuses them when a request targets
    the default model / default server and builds new ones otherwise.
    """

    # Directory name of the default embedding model under the configured model root.
    _DEFAULT_MODEL_DIR = "bge-large-zh-v1.5"
    # Address of the Chroma server for which a client is pre-built in __init__.
    _DEFAULT_HOST = "localhost"
    _DEFAULT_PORT = "7000"

    def __init__(self, *args, **kwargs) -> None:
        # Load the default embedding model from the configured model directory.
        path = PathConf(Configuration())
        self.model_path = Path(path.chroma_rerank_embedding_model)
        self.embedding_model_1 = SentenceTransformerEmbeddings(
            model_name=str(self.model_path / self._DEFAULT_MODEL_DIR),
            model_kwargs={"device": "cuda"},
        )
        # Connect to the default Chroma server.
        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)

    def __call__(self, *args, **kwargs):
        """Forward the call to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list.

        Returns False (instead of raising ``IndexError``) when called without
        positional arguments.
        """
        return bool(args) and isinstance(args[0], list)

    @staticmethod
    def _build_loader(file):
        """Return the document loader matching the extension of ``file``.

        Raises:
            ValueError: if the extension is not one of the supported types.
        """
        # Text after the last dot, lower-cased so "REPORT.PDF" is handled like "report.pdf".
        file_type = file.split(".")[-1].lower()
        print("file_type: ", file_type)
        factories = {
            "pdf": lambda: UnstructuredPDFLoader(file),
            "txt": lambda: TextLoader(file),
            "csv": lambda: CSVLoader(file),
            "html": lambda: UnstructuredHTMLLoader(file),
            "json": lambda: JSONLoader(file, jq_schema='.', text_content=False),
            "docx": lambda: Docx2txtLoader(file),
            "xlsx": lambda: UnstructuredExcelLoader(file),
            "md": lambda: UnstructuredMarkdownLoader(file, mode="single", strategy="fast"),
        }
        factory = factories.get(file_type)
        if factory is None:
            # ValueError is what fast_api_handler turns into an HTTP 400.
            raise ValueError(f"Unsupported file type: {file_type}")
        return factory()

    def processing(self, file, text, text_ids, settings: dict) -> str:
        """Embed ``file`` and/or ``text`` and store the vectors in a Chroma collection.

        Args:
            file: Path of a local file to load, split and upsert, or None.
            text: A single text to upsert, or None.
            text_ids: Id stored with ``text``; ``text`` is only upserted when
                both ``text`` and ``text_ids`` are not None.
            settings: Options, either flat or nested under a "settings" key.
                Recognised keys (with defaults): "chroma_embedding_model"
                (the default model path), "chroma_host" ("localhost"),
                "chroma_port" ("7000"), "chroma_collection_id"
                (``DEFAULT_COLLECTION_ID``), "chunk_size" (256),
                "chunk_overlap" (10), "separators" (["\\n\\n"]).
                None is treated as an empty dict.

        Returns:
            A summary string with the collection's vector count and the ids
            that were written.

        Raises:
            ValueError: if ``file`` has an unsupported extension.
        """
        if settings is None:
            settings = {}

        print("\nSettings: ", settings)

        # Options may arrive wrapped as {"settings": {...}} or as a flat dict.
        options = settings["settings"] if "settings" in settings else settings
        chroma_embedding_model = options.get("chroma_embedding_model")
        chroma_host = options.get("chroma_host", "localhost")
        chroma_port = options.get("chroma_port", "7000")
        chroma_collection_id = options.get("chroma_collection_id", DEFAULT_COLLECTION_ID)
        user_chunk_size = options.get("chunk_size", 256)
        user_chunk_overlap = options.get("chunk_overlap", 10)
        user_separators = options.get("separators", ["\n\n"])

        default_model = str(self.model_path / self._DEFAULT_MODEL_DIR)
        if chroma_embedding_model is None or not chroma_embedding_model.strip():
            chroma_embedding_model = default_model

        # Reuse the client built in __init__ for the default server. The port is
        # normalised to str because JSON settings may carry it as an integer, and
        # exact comparison avoids matching e.g. "notlocalhost" or port "17000".
        targets_default_server = (
            str(chroma_host).strip() == self._DEFAULT_HOST
            and str(chroma_port).strip() == self._DEFAULT_PORT
        )
        if targets_default_server:
            client = self.client_1
        else:
            client = chromadb.HttpClient(host=chroma_host, port=chroma_port)

        print(f"chroma_embedding_model: {chroma_embedding_model}")
        # Literal containment: the path must not be interpreted as a regex pattern.
        if default_model in chroma_embedding_model:
            embedding_model = self.embedding_model_1
        else:
            embedding_model = SentenceTransformerEmbeddings(
                model_name=chroma_embedding_model,
                model_kwargs={"device": "cuda"},
            )

        response_file = ''
        response_string = ''

        if file is not None:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=user_chunk_size,
                chunk_overlap=user_chunk_overlap,
                separators=user_separators,
            )
            loader = self._build_loader(file)
            documents = loader.load()
            docs = text_splitter.split_documents(documents)

            # Ids are "<file path><chunk index>".
            ids = [str(file) + str(i) for i in range(len(docs))]

            Chroma.from_documents(
                documents=docs,
                embedding=embedding_model,
                ids=ids,
                collection_name=chroma_collection_id,
                client=client,
            )
            response_file = f"\n{file} ids is 0~{len(docs)-1}"

        if text is not None and text_ids is not None:
            Chroma.from_texts(
                texts=[text],
                embedding=embedding_model,
                ids=[text_ids],
                collection_name=chroma_collection_id,
                client=client,
            )
            response_string = f"\n{text} ids is {text_ids}"

        # Only the count is needed for the summary; the collection contents are
        # deliberately not fetched or printed here.
        vector_count = client.get_collection(chroma_collection_id).count()
        response_documents_num = f"collection {chroma_collection_id} has {vector_count} vectors."

        return response_documents_num + response_file + response_string

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a multipart form request and run :meth:`processing`.

        Form fields read: "file" (upload), "text", "text_ids" and "settings"
        (a JSON object encoded as a string). Responds with HTTP 400 and an
        "error" message for invalid input or a ``ValueError`` from
        ``processing``; otherwise HTTP 200 with the summary under "text".
        """
        form = await request.form()
        user_file = form.get("file")
        user_text = form.get("text")
        user_text_ids = form.get("text_ids")
        setting = form.get("settings")

        if isinstance(setting, str):
            try:
                # Try to convert the string into a dict.
                setting = json.loads(setting)
            except json.JSONDecodeError:
                return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST)
        # Valid JSON that is not an object (list, number, ...) cannot be used as settings.
        if setting is not None and not isinstance(setting, dict):
            return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST)

        if user_file is None and user_text is None:
            return JSONResponse(content={"error": "file or text is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        if user_text is not None and user_text_ids is None:
            return JSONResponse(content={"error": "text_ids is required when text is provided"}, status_code=status.HTTP_400_BAD_REQUEST)

        # A plain string in the "file" field has no .size / .read() to work with.
        if isinstance(user_file, str):
            return JSONResponse(content={"error": "file must be sent as a file upload"}, status_code=status.HTTP_400_BAD_REQUEST)

        safe_filename = None
        if user_file is not None and user_file.size != 0:
            file_bytes = await user_file.read()
            # basename() drops any directory components supplied by the client.
            base_name = os.path.basename(user_file.filename or "")
            if not base_name:
                return JSONResponse(content={"error": "file name is required"}, status_code=status.HTTP_400_BAD_REQUEST)

            # Store the upload in the system temp directory.
            safe_filename = os.path.join(tempfile.gettempdir(), base_name)
            print("file_path", safe_filename)
            with open(safe_filename, "wb") as f:
                f.write(file_bytes)

        try:
            txt = self.processing(safe_filename, user_text, user_text_ids, setting)
            print(txt)
        except ValueError as e:
            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
        finally:
            # The temp copy is only needed while processing() runs; remove it so
            # uploads do not pile up in the temp directory.
            if safe_filename is not None:
                try:
                    os.remove(safe_filename)
                except OSError:
                    # Best-effort cleanup: a failed delete must not mask the response.
                    pass

        return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)