update chroma query

This commit is contained in:
ACBBZ
2024-05-07 06:24:06 +00:00
parent 2138ab0653
commit a6b2d6c3e7

View File

@ -6,6 +6,10 @@ from .blackbox import Blackbox
import chromadb import chromadb
from chromadb.utils import embedding_functions from chromadb.utils import embedding_functions
from ..utils import chroma_setting
DEFAULT_COLLECTION_ID = "123"
from injector import singleton from injector import singleton
@singleton @singleton
class ChromaQuery(Blackbox): class ChromaQuery(Blackbox):
@ -24,7 +28,7 @@ class ChromaQuery(Blackbox):
data = args[0] data = args[0]
return isinstance(data, list) return isinstance(data, list)
def processing(self, question, collection_id) -> str: def processing(self, question: str, collection_id) -> str:
# load or create collection # load or create collection
collection = self.client.get_or_create_collection(collection_id, embedding_function=self.embedding_model) collection = self.client.get_or_create_collection(collection_id, embedding_function=self.embedding_model)
@ -46,14 +50,13 @@ class ChromaQuery(Blackbox):
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_question = data.get("question") user_question = data.get("question")
user_context = data.get("context")
user_collection_id = data.get("collection_id") user_collection_id = data.get("collection_id")
if user_question is None: if user_question is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
if user_collection_id is None or user_collection_id.isspace(): if user_collection_id is None:
user_collection_id = "123" user_collection_id = DEFAULT_COLLECTION_ID
return JSONResponse( return JSONResponse(
content={"response": self.processing(user_question, user_collection_id)}, content={"response": self.processing(user_question, user_collection_id)},