From 347eea2f9caf346e1e4b89985ecd3debc5077663 Mon Sep 17 00:00:00 2001
From: ACBBZ
Date: Tue, 28 May 2024 07:06:16 +0000
Subject: [PATCH] refactor: processing of blackbox chroma query and chat

---
 requirements.txt             |  5 +++
 src/blackbox/chroma_chat.py  | 69 ++++++++++++++----------------------
 src/blackbox/chroma_query.py | 19 ++++++----
 test.pdf                     |  0
 4 files changed, 44 insertions(+), 49 deletions(-)
 delete mode 100644 test.pdf

diff --git a/requirements.txt b/requirements.txt
index 1efc43a..58d7037 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,8 @@ uvicorn==0.29.0
 SpeechRecognition==3.10.3
 PyYAML==6.0.1
 injector==0.21.0
+chromadb==0.5.0
+langchain==0.1.17
+langchain-community==0.0.36
+sentence-transformers==2.7.0
+openai
\ No newline at end of file
diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py
index fe3cb3e..09683bd 100755
--- a/src/blackbox/chroma_chat.py
+++ b/src/blackbox/chroma_chat.py
@@ -31,64 +31,47 @@ class ChromaChat(Blackbox):
 
     @logging_time(logger=logger)
     def processing(self, question: str, context: list, settings: dict) -> str:
+        # chroma_chat settings
+        # {
+        #     "chroma_embedding_model": "bge-large-zh-v1.5",
+        #     "chroma_host": "10.6.82.192",
+        #     "chroma_port": "8000",
+        #     "chroma_collection_id": "123",
+        #     "chroma_n_results": 3,
+        #     "model_name": "Qwen1.5-14B-Chat",
+        #     "context": [],
+        #     "template": "",
+        #     "temperature": 0.8,
+        #     "top_p": 0.8,
+        #     "n": 1,
+        #     "max_tokens": 1024,
+        #     "frequency_penalty": 0.5,
+        #     "presence_penalty": 0.8,
+        #     "stop": 100,
+        #     "model_url": "http://120.196.116.194:48892/v1/chat/completions",
+        #     "model_key": "YOUR_API_KEY"
+        # }
+
         if settings is None:
             settings = {}
 
-        # # chat setting
-        user_model_name = settings.get("model_name")
         user_context = context
         user_question = question
-        user_template = settings.get("template")
-        user_temperature = settings.get("temperature")
-        user_top_p = settings.get("top_p")
-        user_n = settings.get("n")
-        user_max_tokens = settings.get("max_tokens")
-        user_stop = settings.get("stop")
-        user_frequency_penalty = settings.get("frequency_penalty")
-        user_presence_penalty = settings.get("presence_penalty")
-
-        # # chroma_query settings
-        chroma_embedding_model = settings.get("chroma_embedding_model")
-        chroma_host = settings.get("chroma_host")
-        chroma_port = settings.get("chroma_port")
-        chroma_collection_id = settings.get("chroma_collection_id")
-        chroma_n_results = settings.get("chroma_n_results")
-
-        if context == None:
-            context = []
+
+        if user_context is None:
+            user_context = []
 
         if user_question is None:
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
 
-        chroma_settings_json={
-            "chroma_embedding_model": chroma_embedding_model,
-            "chroma_host": chroma_host,
-            "chroma_port": chroma_port,
-            "chroma_collection_id": chroma_collection_id,
-            "chroma_n_results": chroma_n_results
-        }
-
         # chroma answer
-        chroma_result = self.chroma_query(user_question, chroma_settings_json)
+        chroma_result = self.chroma_query(user_question, settings)
 
         # chat prompt
-        fast_question = f"问题: {user_question}。根据问题,总结以下内容和来源:{chroma_result}"
-
-        chat_settings_json = {
-            "model_name": user_model_name,
-            "context": user_context,
-            "template": user_template,
-            "temperature": user_temperature,
-            "top_p": user_top_p,
-            "n": user_n,
-            "max_tokens": user_max_tokens,
-            "stop": user_stop,
-            "frequency_penalty": user_frequency_penalty,
-            "presence_penalty": user_presence_penalty
-        }
+        fast_question = f"问题: {user_question}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。只从检索的内容中选取与问题相关信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{chroma_result}"
 
         # chat answer
-        response = self.chat(fast_question, chat_settings_json)
+        response = self.chat(fast_question, user_context, settings)
 
         return response
 
diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py
index 22c8581..3a90d8d 100755
--- a/src/blackbox/chroma_query.py
+++ b/src/blackbox/chroma_query.py
@@ -20,7 +20,7 @@ class ChromaQuery(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda")
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", device = "cuda")
         # self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda")
         self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
@@ -51,7 +51,7 @@ class ChromaQuery(Blackbox):
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "bge-small-en-v1.5"
+            chroma_embedding_model = "bge-large-zh-v1.5"
 
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
             chroma_host = "10.6.82.192"
@@ -60,17 +60,22 @@
             chroma_port = "8000"
 
         if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
-            chroma_collection_id = DEFAULT_COLLECTION_ID
+            chroma_collection_id = "g2e"
 
         if chroma_n_results is None or chroma_n_results == "":
             chroma_n_results = 3
 
-        # load client
+        # load client and embedding model from init
         if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
             client = self.client_1
+        else:
+            client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
 
-        if re.search(r"bge-small-en-v1.5", chroma_embedding_model):
+        if re.search(r"bge-large-zh-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_1
+        else:
+            chroma_embedding_model = "/model/Weight/BAAI/" + chroma_embedding_model
+            embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda")
 
         # load collection
         collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model)
@@ -81,7 +86,9 @@
             n_results=3,
         )
 
-        response = str(results["documents"] + results["metadatas"])
+        # response = str(results["documents"] + results["metadatas"])
+        response = str(results["documents"])
+
         return response
 
diff --git a/test.pdf b/test.pdf
deleted file mode 100644
index e69de29..0000000
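
Usage note (illustrative, not part of the diff): the refactor threads one consolidated `settings` dict through `processing` -> `chroma_query` -> `chat` instead of repacking per-stage `chroma_settings_json` / `chat_settings_json` dicts. The new retrieval prompt, translated: "Question: {user_question}. - Based on the retrieval results from the knowledge base, answer the question clearly and concisely. Select only information relevant to the question from the retrieved content. - Do not fabricate an answer; if the answer is not in the verified material or cannot be derived from it, reply '我无法回答您的问题。' ('I cannot answer your question.'). Retrieved content: {chroma_result}". The old prompt only asked the model to summarize the retrieved content and sources for the question.

A minimal sketch of a call site under the new contract, assuming the keys documented in the settings comment and assuming `ChromaChat` can be constructed without arguments (the project wires blackboxes through `injector`, so the construction here is hypothetical); the question and values are examples only:

    from src.blackbox.chroma_chat import ChromaChat

    # Keys mirror the settings comment added in this patch; values are examples.
    settings = {
        "chroma_embedding_model": "bge-large-zh-v1.5",
        "chroma_host": "10.6.82.192",
        "chroma_port": "8000",
        "chroma_collection_id": "g2e",
        "chroma_n_results": 3,
        "model_name": "Qwen1.5-14B-Chat",
        "temperature": 0.8,
        "top_p": 0.8,
        "max_tokens": 1024,
    }

    blackbox = ChromaChat()  # hypothetical direct construction
    # chroma_query() retrieves supporting passages; chat() answers under the new prompt.
    answer = blackbox.processing(question="...", context=[], settings=settings)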
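Design note on the query path: `chroma_query` now falls back to building a client and embedding function on demand when the requested host/port or model is not the pair preloaded in `__init__`; the added `else` branches close the pre-refactor gap where a non-default host or model left `client` / `embedding_model` unassigned and raised `UnboundLocalError`. A standalone sketch of that dispatch pattern, assuming the same `/model/Weight/BAAI/<name>` weight layout; the helper name `resolve_backend` is hypothetical, and it compares for equality where the patch uses `re.search` (in which the unescaped dots of "10.6.82.192" match any character):

    import chromadb
    from chromadb.utils import embedding_functions

    # Defaults preloaded once, mirroring ChromaQuery.__init__ in the patch.
    DEFAULT_HOST, DEFAULT_PORT = "10.6.82.192", "8000"
    DEFAULT_MODEL = "bge-large-zh-v1.5"
    _client = chromadb.HttpClient(host=DEFAULT_HOST, port=int(DEFAULT_PORT))
    _embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="/model/Weight/BAAI/" + DEFAULT_MODEL, device="cuda")

    def resolve_backend(host: str, port: str, model: str):
        """Reuse the cached client/embedder for the defaults; build fresh ones otherwise."""
        if (host, port) == (DEFAULT_HOST, DEFAULT_PORT):
            client = _client
        else:
            client = chromadb.HttpClient(host=host, port=int(port))
        if model == DEFAULT_MODEL:
            embedder = _embedder
        else:
            # Local weights live under /model/Weight/BAAI/<model name>, as in the patch.
            embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="/model/Weight/BAAI/" + model, device="cuda")
        return client, embedder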