from typing import Any, Coroutine

import chromadb
from chromadb.utils import embedding_functions
from fastapi import Request, Response, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from injector import singleton

from ..utils import chroma_setting
from .blackbox import Blackbox

# Collection queried when a request does not specify ``collection_id``.
DEFAULT_COLLECTION_ID = "123"

# Defaults used by ChromaQuery; each can be overridden through the matching
# keyword argument of ChromaQuery(...) (see ChromaQuery.__init__).
DEFAULT_EMBEDDING_MODEL = "/model/Weight/BAAI/bge-small-en-v1.5"
DEFAULT_EMBEDDING_DEVICE = "cuda"
DEFAULT_CHROMA_HOST = "10.6.82.192"
DEFAULT_CHROMA_PORT = 8000
# Number of nearest documents returned per question.
DEFAULT_N_RESULTS = 3


@singleton
class ChromaQuery(Blackbox):
    """Answer a question by similarity search against a remote ChromaDB collection.

    The question is embedded with a SentenceTransformer embedding function,
    and the closest documents plus their metadata are returned as a string.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Load the embedding model and connect to the ChromaDB server.

        Keyword Args:
            model_name: Path or name of the SentenceTransformer model
                (default ``DEFAULT_EMBEDDING_MODEL``).
            device: Device the embedding model runs on
                (default ``DEFAULT_EMBEDDING_DEVICE``).
            host: ChromaDB server host (default ``DEFAULT_CHROMA_HOST``).
            port: ChromaDB server port (default ``DEFAULT_CHROMA_PORT``).

        Positional arguments are accepted for interface compatibility and ignored.
        """
        self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=kwargs.get("model_name", DEFAULT_EMBEDDING_MODEL),
            device=kwargs.get("device", DEFAULT_EMBEDDING_DEVICE),
        )
        self.client = chromadb.HttpClient(
            host=kwargs.get("host", DEFAULT_CHROMA_HOST),
            port=kwargs.get("port", DEFAULT_CHROMA_PORT),
        )

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True if the first positional argument is a list.

        NOTE(review): ``processing`` takes a ``str`` question, so this list
        check looks inconsistent with it; confirm what callers pass here.
        """
        data = args[0]
        return isinstance(data, list)

    def processing(
        self,
        question: str,
        collection_id: str = DEFAULT_COLLECTION_ID,
        n_results: int = DEFAULT_N_RESULTS,
    ) -> str:
        """Query ``collection_id`` for the documents closest to ``question``.

        Args:
            question: Text to embed and search for.
            collection_id: Name of the ChromaDB collection to query.
            n_results: Maximum number of documents to return.

        Returns:
            ``str()`` of the query's ``documents`` list concatenated with its
            ``metadatas`` list.
        """
        # get_collection (not get_or_create_collection) is used, so a missing
        # collection is not created here; presumably the client raises
        # instead. Confirm the behaviour for the deployed chromadb version.
        collection = self.client.get_collection(
            collection_id, embedding_function=self.embedding_model
        )
        results = collection.query(query_texts=[question], n_results=n_results)
        return str(results["documents"] + results["metadatas"])

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a request whose JSON body is ``{"question": ..., "collection_id": ...}``.

        ``question`` is required and must be a string. ``collection_id`` is
        optional (it falls back to ``DEFAULT_COLLECTION_ID``) and must be a
        string when given.

        Returns:
            200 with ``{"response": <processing result>}``, or 400 with
            ``{"error": <message>}`` for a malformed request.
        """
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            return JSONResponse(
                content={"error": "json parse error"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        # Valid JSON that is not an object (e.g. a list) has no .get().
        if not isinstance(data, dict):
            return JSONResponse(
                content={"error": "request body must be a JSON object"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        user_question = data.get("question")
        if user_question is None:
            return JSONResponse(
                content={"error": "question is required"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        if not isinstance(user_question, str):
            return JSONResponse(
                content={"error": "question must be a string"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        user_collection_id = data.get("collection_id")
        if user_collection_id is None:
            user_collection_id = DEFAULT_COLLECTION_ID
        elif not isinstance(user_collection_id, str):
            return JSONResponse(
                content={"error": "collection_id must be a string"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        # processing() is synchronous: it makes an HTTP call to ChromaDB and
        # runs the embedding model. Run it in a worker thread so it does not
        # block the event loop.
        response = await run_in_threadpool(
            self.processing, user_question, user_collection_id
        )
        return JSONResponse(
            content={"response": response},
            status_code=status.HTTP_200_OK,
        )