From 62b37e7f205b0a023cfa98d3509cf2b143dc569a Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Mon, 6 May 2024 02:31:27 +0000 Subject: [PATCH] update chroma query --- src/blackbox/chroma_query.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 19147dd..f2deaba 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -4,26 +4,18 @@ from fastapi import Request, Response, status from fastapi.responses import JSONResponse from .blackbox import Blackbox -import requests -import json - -from langchain.text_splitter import CharacterTextSplitter -from langchain_community.document_loaders import TextLoader, DirectoryLoader import chromadb from chromadb.utils import embedding_functions -# from langchain_community.embeddings.sentence_transformer import ( -# SentenceTransformerEmbeddings, HuggingFaceEmbeddings -# ) - +from injector import singleton +@singleton class ChromaQuery(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load embedding model - # self.embedding_model = embedding_functions.DefaultEmbeddingFunction() self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") - # load chroma db - self.persistent_client = chromadb.PersistentClient(path="./data/test1") + # load chromadb + self.client = chromadb.HttpClient(host='10.6.82.192', port=8000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -32,10 +24,10 @@ class ChromaQuery(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, question, collection_id, context: list) -> str: + def processing(self, question, collection_id) -> str: # load or create collection - collection = persistent_client.get_or_create_collection(collection_id, embedding_function=embedding_model) + collection = self.client.get_or_create_collection(collection_id, embedding_function=self.embedding_model) # query it results = collection.query( @@ -43,8 +35,8 @@ class ChromaQuery(Blackbox): n_results=3, ) - response = results["documents"] + results["ids"] - return results + response = results["documents"] + results["metadatas"] + return response async def fast_api_handler(self, request: Request) -> Response: @@ -64,5 +56,5 @@ class ChromaQuery(Blackbox): user_collection_id = "123" return JSONResponse( - content={"response": self.processing(user_question, user_collection_id, user_context)}, + content={"response": self.processing(user_question, user_collection_id)}, status_code=status.HTTP_200_OK) \ No newline at end of file