update chroma query

This commit is contained in:
ACBBZ
2024-05-06 02:31:27 +00:00
parent 173fa41386
commit 62b37e7f20

View File

@ -4,26 +4,18 @@ from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import requests
import json
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader, DirectoryLoader
import chromadb
from chromadb.utils import embedding_functions
# from langchain_community.embeddings.sentence_transformer import (
# SentenceTransformerEmbeddings, HuggingFaceEmbeddings
# )
from injector import singleton
@singleton
class ChromaQuery(Blackbox):
def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0])
# load embedding model
# self.embedding_model = embedding_functions.DefaultEmbeddingFunction()
self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda")
# load chromadb
self.persistent_client = chromadb.PersistentClient(path="./data/test1")
self.client = chromadb.HttpClient(host='10.6.82.192', port=8000)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@ -32,10 +24,10 @@ class ChromaQuery(Blackbox):
data = args[0]
return isinstance(data, list)
def processing(self, question, collection_id, context: list) -> str:
def processing(self, question, collection_id) -> str:
# load or create collection
collection = persistent_client.get_or_create_collection(collection_id, embedding_function=embedding_model)
collection = self.client.get_or_create_collection(collection_id, embedding_function=self.embedding_model)
# query it
results = collection.query(
@ -43,8 +35,8 @@ class ChromaQuery(Blackbox):
n_results=3,
)
response = results["documents"] + results["ids"]
return results
response = results["documents"] + results["metadatas"]
return response
async def fast_api_handler(self, request: Request) -> Response:
@ -64,5 +56,5 @@ class ChromaQuery(Blackbox):
user_collection_id = "123"
return JSONResponse(
content={"response": self.processing(user_question, user_collection_id, user_context)},
content={"response": self.processing(user_question, user_collection_id)},
status_code=status.HTTP_200_OK)