diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index f2deaba..8597677 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -6,6 +6,10 @@ from .blackbox import Blackbox import chromadb from chromadb.utils import embedding_functions +from ..utils import chroma_setting + +DEFAULT_COLLECTION_ID = "123" + from injector import singleton @singleton class ChromaQuery(Blackbox): @@ -24,7 +28,7 @@ class ChromaQuery(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, question, collection_id) -> str: + def processing(self, question: str, collection_id) -> str: # load or create collection collection = self.client.get_or_create_collection(collection_id, embedding_function=self.embedding_model) @@ -46,14 +50,13 @@ class ChromaQuery(Blackbox): return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) user_question = data.get("question") - user_context = data.get("context") user_collection_id = data.get("collection_id") if user_question is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) - if user_collection_id is None or user_collection_id.isspace(): - user_collection_id = "123" + if user_collection_id is None: + user_collection_id = DEFAULT_COLLECTION_ID return JSONResponse( content={"response": self.processing(user_question, user_collection_id)},