Files
jarvis-models/src/blackbox/chroma_query.py
2024-05-24 10:41:17 +08:00

63 lines
2.1 KiB
Python
Executable File

from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
import chromadb
from chromadb.utils import embedding_functions
from ..utils import chroma_setting
# Collection name queried by ChromaQuery.fast_api_handler when the request
# body does not include a "collection_id".
DEFAULT_COLLECTION_ID = "123"
from injector import singleton
@singleton
class ChromaQuery(Blackbox):
    """Blackbox that answers a question by querying a Chroma collection.

    Query text is embedded with a SentenceTransformer embedding function and
    the nearest documents (plus their metadata) are fetched from a Chroma
    server through ``chromadb.HttpClient``.
    """

    # Defaults preserved from the original hard-coded configuration. Each one
    # can be overridden by passing the matching keyword argument to __init__.
    _DEFAULT_MODEL_NAME = "/model/Weight/BAAI/bge-small-en-v1.5"
    _DEFAULT_DEVICE = "cuda"
    _DEFAULT_HOST = "10.6.82.192"
    _DEFAULT_PORT = 8000

    def __init__(self, *args, **kwargs) -> None:
        """Load the embedding model and create the Chroma HTTP client.

        Optional keyword args (default to the class constants above):
            model_name: path or name of the SentenceTransformer model.
            device: device the embedding model is loaded on.
            host: Chroma server host.
            port: Chroma server port.

        Positional args are accepted for interface compatibility and ignored.
        """
        # NOTE(review): reading config from a file (args[0]) is disabled in
        # this code; settings come from kwargs or the defaults above.
        model_name = kwargs.get("model_name", self._DEFAULT_MODEL_NAME)
        device = kwargs.get("device", self._DEFAULT_DEVICE)
        host = kwargs.get("host", self._DEFAULT_HOST)
        port = kwargs.get("port", self._DEFAULT_PORT)

        # Embedding function Chroma uses to turn query text into vectors.
        self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name, device=device
        )
        # Client for the remote Chroma server.
        self.client = chromadb.HttpClient(host=host, port=port)

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list.

        NOTE(review): processing() takes a ``str`` question, so this list
        check looks inconsistent with it -- confirm what Blackbox callers pass.
        """
        # Calling with no positional arguments used to raise IndexError.
        return bool(args) and isinstance(args[0], list)

    def processing(self, question: str, collection_id, n_results: int = 3) -> str:
        """Return the nearest matches for *question* in a Chroma collection.

        Args:
            question: text to search for.
            collection_id: name of an existing Chroma collection.
            n_results: number of nearest results to request (default 3).

        Returns:
            ``str()`` of the concatenated ``documents`` and ``metadatas``
            lists from the Chroma query result.
        """
        # get_collection looks up an existing collection only (Chroma's
        # creating variant is get_or_create_collection). Creating on demand
        # would let arbitrary request ids spawn empty collections.
        collection = self.client.get_collection(
            collection_id, embedding_function=self.embedding_model
        )
        results = collection.query(
            query_texts=[question],
            n_results=n_results,
        )
        # The two result lists are concatenated and stringified as-is, so
        # callers receive a Python repr string rather than structured JSON.
        return str(results["documents"] + results["metadatas"])

    @staticmethod
    def _bad_request(message: str) -> JSONResponse:
        """Build a 400 JSON response carrying ``{"error": message}``."""
        return JSONResponse(
            content={"error": message},
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    async def fast_api_handler(self, request: Request) -> Response:
        """Answer a JSON body ``{"question": str, "collection_id": optional}``.

        Returns 400 for an unparsable body, a body that is not a JSON object,
        or a missing / non-string ``question``. Otherwise returns 200 with
        ``{"response": <processing() result>}``. A missing ``collection_id``
        falls back to DEFAULT_COLLECTION_ID.
        """
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError. A bare ``except:`` would also swallow
            # asyncio.CancelledError and KeyboardInterrupt.
            return self._bad_request("json parse error")

        # Valid JSON can still be a list/number/string, which has no .get().
        if not isinstance(data, dict):
            return self._bad_request("json body must be an object")

        user_question = data.get("question")
        if user_question is None:
            return self._bad_request("question is required")
        if not isinstance(user_question, str):
            return self._bad_request("question must be a string")

        user_collection_id = data.get("collection_id")
        if user_collection_id is None:
            user_collection_id = DEFAULT_COLLECTION_ID

        # NOTE(review): processing() runs synchronously (embedding + HTTP call
        # to Chroma) and blocks the event loop while it runs. Consider
        # offloading it (e.g. fastapi.concurrency.run_in_threadpool) once the
        # model and client are confirmed safe to use from worker threads.
        return JSONResponse(
            content={"response": self.processing(user_question, user_collection_id)},
            status_code=status.HTTP_200_OK,
        )