From c5050d5c4368cb5b2f7c2adb0f5c9b474da06670 Mon Sep 17 00:00:00 2001 From: 0Xiao0 <511201264@qq.com> Date: Wed, 18 Sep 2024 10:21:18 +0800 Subject: [PATCH] fix: chroma_upsert api --- src/blackbox/chroma_upsert.py | 50 ++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py index aad5327..5e0fdf9 100755 --- a/src/blackbox/chroma_upsert.py +++ b/src/blackbox/chroma_upsert.py @@ -31,9 +31,9 @@ class ChromaUpsert(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load embedding model - self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"}) + self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"}) # load chroma db - self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000) + self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -42,7 +42,7 @@ class ChromaUpsert(Blackbox): data = args[0] return isinstance(data, list) - @logging_time(logger=logger) + # @logging_time(logger=logger) def processing(self, file, string, context: list, settings: dict) -> str: # 用户的操作历史 if context is None: @@ -65,15 +65,21 @@ class ChromaUpsert(Blackbox): if settings is None: settings = {} - + print("\nSettings: ", settings) # # chroma_query settings - chroma_embedding_model = settings.get("chroma_embedding_model") - chroma_host = settings.get("chroma_host") - chroma_port = settings.get("chroma_port") - chroma_collection_id = settings.get("chroma_collection_id") + if "settings" in settings: + chroma_embedding_model = settings["settings"].get("chroma_embedding_model") + chroma_host = settings["settings"].get("chroma_host") + chroma_port = settings["settings"].get("chroma_port") + chroma_collection_id = settings["settings"].get("chroma_collection_id") + else: + chroma_embedding_model = settings.get("chroma_embedding_model") + chroma_host = settings.get("chroma_host") + chroma_port = settings.get("chroma_port") + chroma_collection_id = settings.get("chroma_collection_id") if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": - chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5" + chroma_embedding_model = "/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5" if chroma_host is None or chroma_host.isspace() or chroma_host == "": chroma_host = "10.6.82.192" @@ -89,11 +95,11 @@ class ChromaUpsert(Blackbox): client = self.client_1 else: client = chromadb.HttpClient(host=chroma_host, port=chroma_port) - - if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model): + print(f"chroma_embedding_model: {chroma_embedding_model}") + if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model): embedding_model = self.embedding_model_1 else: - embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda") + embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda:0") if file is not None: @@ -154,6 +160,12 @@ class ChromaUpsert(Blackbox): context = (await request.form()).get("context") setting: dict = (await request.form()).get("settings") + if isinstance(setting, str): + try: + setting = json.loads(setting) # 尝试将字符串转换为字典 + except json.JSONDecodeError: + return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST) + if user_file is None and user_string is None: return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -169,7 +181,13 @@ class ChromaUpsert(Blackbox): else: safe_filename = None - - return JSONResponse( - content={"response": self.processing(safe_filename, user_string, context, setting)}, - status_code=status.HTTP_200_OK) \ No newline at end of file + try: + txt = self.processing(safe_filename, user_string, context, setting) + print(txt) + except ValueError as e: + return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) + return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK) + + # return JSONResponse( + # content={"response": self.processing(safe_filename, user_string, context, setting)}, + # status_code=status.HTTP_200_OK) \ No newline at end of file