From 99c9cf83a424373b6a7619860bb853155bbe2c95 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Mon, 29 Apr 2024 02:55:51 +0000 Subject: [PATCH] add chroma_query.py --- src/blackbox/chroma_query.py | 68 ++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100755 src/blackbox/chroma_query.py diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py new file mode 100755 index 0000000..19147dd --- /dev/null +++ b/src/blackbox/chroma_query.py @@ -0,0 +1,68 @@ +from typing import Any, Coroutine + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from .blackbox import Blackbox + +import requests +import json + +from langchain.text_splitter import CharacterTextSplitter +from langchain_community.document_loaders import TextLoader, DirectoryLoader +import chromadb +from chromadb.utils import embedding_functions +# from langchain_community.embeddings.sentence_transformer import ( +# SentenceTransformerEmbeddings, HuggingFaceEmbeddings +# ) + +class ChromaQuery(Blackbox): + + def __init__(self, *args, **kwargs) -> None: + # config = read_yaml(args[0]) + # load embedding model + # self.embedding_model = embedding_functions.DefaultEmbeddingFunction() + self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") + # load chroma db + self.persistent_client = chromadb.PersistentClient(path="./data/test1") + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, *args, **kwargs) -> bool: + data = args[0] + return isinstance(data, list) + + def processing(self, question, collection_id, context: list) -> str: + + # load or create collection + collection = persistent_client.get_or_create_collection(collection_id, embedding_function=embedding_model) + + # query it + results = collection.query( + query_texts=[question], + n_results=3, + ) + + response = results["documents"] + results["ids"] + return results + + + async def fast_api_handler(self, request: Request) -> Response: + try: + data = await request.json() + except: + return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) + + user_question = data.get("question") + user_context = data.get("context") + user_collection_id = data.get("collection_id") + + if user_question is None: + return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + if user_collection_id is None or user_collection_id.isspace(): + user_collection_id = "123" + + return JSONResponse( + content={"response": self.processing(user_question, user_collection_id, user_context)}, + status_code=status.HTTP_200_OK) \ No newline at end of file