From 95d6c3785c9c7fba4065636ec8915a11112096ae Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Tue, 7 May 2024 02:17:46 +0000 Subject: [PATCH] add chroma chat --- src/blackbox/chroma_chat.py | 52 +++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100755 src/blackbox/chroma_chat.py diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py new file mode 100755 index 0000000..c1e47dd --- /dev/null +++ b/src/blackbox/chroma_chat.py @@ -0,0 +1,52 @@ +from typing import Any, Coroutine + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from .blackbox import Blackbox + +from injector import singleton +@singleton +class ChromaChat(Blackbox): + + def __init__(self, fastchat, chroma_query): + self.fastchat = fastchat + self.chroma_query = chroma_query + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, *args, **kwargs) -> bool: + data = args[0] + return isinstance(data, list) + + def processing(self, question, context) -> str: + + # load or create collection + if context is None: + collection_id = "123" + else: + collection_id = context["collections"][0] + # query it + chroma_result = self.chroma_query(question, collection_id) + + fast_question = "问题: "+ question + "。根据问题,总结以下内容:" + chroma_result + response = self.fastchat(fast_question) + + return response + + + async def fast_api_handler(self, request: Request) -> Response: + try: + data = await request.json() + except: + return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) + + user_question = data.get("question") + user_context = data.get("context") + + if user_question is None: + return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + return JSONResponse( + content={"response": self.processing(user_question, user_context)}, + status_code=status.HTTP_200_OK) \ No newline at end of file