From 76971c87f59dcda8fc50e45a632a0c113e622745 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Mon, 27 May 2024 08:42:58 +0000 Subject: [PATCH 1/4] fix: chroma setting url and key --- src/blackbox/chat.py | 12 +++++- src/blackbox/chroma_chat.py | 72 ++++++++++++++++++++++++++++++------ src/blackbox/chroma_query.py | 70 ++++++++++++++++++++++++++--------- 3 files changed, 124 insertions(+), 30 deletions(-) diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index 369c228..bfee46c 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -40,6 +40,9 @@ class Chat(Blackbox): user_stop = settings.get("stop") user_frequency_penalty = settings.get("frequency_penalty") user_presence_penalty = settings.get("presence_penalty") + user_model_url = settings.get("model_url") + user_model_key = settings.get("model_key") + if user_context == None: user_context = [] @@ -72,15 +75,20 @@ class Chat(Blackbox): if user_presence_penalty is None or user_presence_penalty == "": user_presence_penalty = 0.8 + + if user_model_url is None or user_model_url.isspace() or user_model_url == "": + user_model_url = "http://120.196.116.194:48892/v1/chat/completions" + if user_model_key is None or user_model_key.isspace() or user_model_key == "": + user_model_key = "YOUR_API_KEY" # gpt-4, gpt-3.5-turbo if re.search(r"gpt", user_model_name): url = 'https://api.openai.com/v1/completions' key = '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' else: - url = 'http://120.196.116.194:48892/v1/chat/completions' - key = 'YOUR_API_KEY' + url = user_model_url + key = user_model_key prompt_template = [ {"role": "system", "content": user_template}, diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py index 01685f3..fe3cb3e 100755 --- a/src/blackbox/chroma_chat.py +++ b/src/blackbox/chroma_chat.py @@ -6,7 +6,10 @@ from .blackbox import Blackbox from .chat import Chat from .chroma_query import ChromaQuery +from ..log.logging_time import logging_time +import logging +logger = logging.getLogger DEFAULT_COLLECTION_ID = "123" from injector import singleton,inject @@ -24,19 +27,68 @@ class ChromaChat(Blackbox): def valid(self, *args, **kwargs) -> bool: data = args[0] return isinstance(data, list) + + @logging_time(logger=logger) + def processing(self, question: str, context: list, settings: dict) -> str: + + if settings is None: + settings = {} + + # # chat setting + user_model_name = settings.get("model_name") + user_context = context + user_question = question + user_template = settings.get("template") + user_temperature = settings.get("temperature") + user_top_p = settings.get("top_p") + user_n = settings.get("n") + user_max_tokens = settings.get("max_tokens") + user_stop = settings.get("stop") + user_frequency_penalty = settings.get("frequency_penalty") + user_presence_penalty = settings.get("presence_penalty") + + # # chroma_query settings + chroma_embedding_model = settings.get("chroma_embedding_model") + chroma_host = settings.get("chroma_host") + chroma_port = settings.get("chroma_port") + chroma_collection_id = settings.get("chroma_collection_id") + chroma_n_results = settings.get("chroma_n_results") - def processing(self, question, context: list) -> str: if context == None: context = [] - # load or create collection - collection_id = DEFAULT_COLLECTION_ID + if user_question is None: + return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) - # query it - chroma_result = self.chroma_query(question, collection_id) + chroma_settings_json={ + "chroma_embedding_model": chroma_embedding_model, + "chroma_host": chroma_host, + "chroma_port": chroma_port, + "chroma_collection_id": chroma_collection_id, + "chroma_n_results": chroma_n_results + } - fast_question = "问题: "+ question + "。根据问题,总结以下内容和来源:" + chroma_result - response = self.chat(model_name="Qwen1.5-14B-Chat", prompt=fast_question, template='', context=context, temperature=0.8, top_p=0.8, n=1, max_tokens=1024, stop=100,frequency_penalty=0.5,presence_penalty=0.8) + # chroma answer + chroma_result = self.chroma_query(user_question, chroma_settings_json) + + # chat prompt + fast_question = f"问题: {user_question}。根据问题,总结以下内容和来源:{chroma_result}" + + chat_settings_json = { + "model_name": user_model_name, + "context": user_context, + "template": user_template, + "temperature": user_temperature, + "top_p": user_top_p, + "n": user_n, + "max_tokens": user_max_tokens, + "stop": user_stop, + "frequency_penalty": user_frequency_penalty, + "presence_penalty": user_presence_penalty + } + + # chat answer + response = self.chat(fast_question, chat_settings_json) return response @@ -49,10 +101,8 @@ class ChromaChat(Blackbox): user_question = data.get("question") user_context = data.get("context") - - if user_question is None: - return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) + setting: dict = data.get("settings") return JSONResponse( - content={"response": self.processing(user_question, user_context)}, + content={"response": self.processing(user_question, user_context, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 6e39eae..22c8581 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -6,8 +6,11 @@ from .blackbox import Blackbox import chromadb from chromadb.utils import embedding_functions -from ..utils import chroma_setting +import logging +from ..log.logging_time import logging_time +import re +logger = logging.getLogger DEFAULT_COLLECTION_ID = "123" from injector import singleton @@ -16,10 +19,11 @@ class ChromaQuery(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) - # load embedding model - self.embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") - # load chromadb - self.client = chromadb.HttpClient(host='10.6.82.192', port=8000) + # load chromadb and embedding model + self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") + # self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") + self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000) + # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -28,14 +32,52 @@ class ChromaQuery(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, question: str, collection_id) -> str: + @logging_time(logger=logger) + def processing(self, question: str, settings: dict) -> str: - # load or create collection - collection = self.client.get_collection(collection_id, embedding_function=self.embedding_model) + if settings is None: + settings = {} + + usr_question = question + + # # chroma_query settings + chroma_embedding_model = settings.get("chroma_embedding_model") + chroma_host = settings.get("chroma_host") + chroma_port = settings.get("chroma_port") + chroma_collection_id = settings.get("chroma_collection_id") + chroma_n_results = settings.get("chroma_n_results") + + if usr_question is None: + return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) + + if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": + chroma_embedding_model = "bge-small-en-v1.5" + + if chroma_host is None or chroma_host.isspace() or chroma_host == "": + chroma_host = "10.6.82.192" + + if chroma_port is None or chroma_port.isspace() or chroma_port == "": + chroma_port = "8000" + + if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "": + chroma_collection_id = DEFAULT_COLLECTION_ID + + if chroma_n_results is None or chroma_n_results == "": + chroma_n_results = 3 + + # load client + if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port): + client = self.client_1 + + if re.search(r"bge-small-en-v1.5", chroma_embedding_model): + embedding_model = self.embedding_model_1 + + # load collection + collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model) # query it results = collection.query( - query_texts=[question], + query_texts=[usr_question], n_results=3, ) @@ -50,14 +92,8 @@ class ChromaQuery(Blackbox): return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) user_question = data.get("question") - user_collection_id = data.get("collection_id") - - if user_question is None: - return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) - - if user_collection_id is None: - user_collection_id = DEFAULT_COLLECTION_ID + setting = data.get("settings") return JSONResponse( - content={"response": self.processing(user_question, user_collection_id)}, + content={"response": self.processing(user_question, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file From 347eea2f9caf346e1e4b89985ecd3debc5077663 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Tue, 28 May 2024 07:06:16 +0000 Subject: [PATCH 2/4] refactor: processing of blackbox chroma query and chat --- requirements.txt | 5 +++ src/blackbox/chroma_chat.py | 69 ++++++++++++++---------------------- src/blackbox/chroma_query.py | 19 ++++++---- test.pdf | 0 4 files changed, 44 insertions(+), 49 deletions(-) delete mode 100644 test.pdf diff --git a/requirements.txt b/requirements.txt index 1efc43a..58d7037 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,8 @@ uvicorn==0.29.0 SpeechRecognition==3.10.3 PyYAML==6.0.1 injector==0.21.0 +chromadb==0.5.0 +langchain==0.1.17 +langchain-community==0.0.36 +sentence-transformers==2.7.0 +openai \ No newline at end of file diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py index fe3cb3e..09683bd 100755 --- a/src/blackbox/chroma_chat.py +++ b/src/blackbox/chroma_chat.py @@ -31,64 +31,47 @@ class ChromaChat(Blackbox): @logging_time(logger=logger) def processing(self, question: str, context: list, settings: dict) -> str: + # chroma_chat settings + # { + # "chroma_embedding_model": "bge-large-zh-v1.5", + # "chroma_host": "10.6.82.192", + # "chroma_port": "8000", + # "chroma_collection_id": "123", + # "chroma_n_results": 3, + # "model_name": "Qwen1.5-14B-Chat", + # "context": [], + # "template": "", + # "temperature": 0.8, + # "top_p": 0.8, + # "n": 1, + # "max_tokens": 1024, + # "frequency_penalty": 0.5, + # "presence_penalty": 0.8, + # "stop": 100, + # "model_url": "http://120.196.116.194:48892/v1/chat/completions", + # "model_key": "YOUR_API_KEY" + # } + if settings is None: settings = {} - # # chat setting - user_model_name = settings.get("model_name") user_context = context user_question = question - user_template = settings.get("template") - user_temperature = settings.get("temperature") - user_top_p = settings.get("top_p") - user_n = settings.get("n") - user_max_tokens = settings.get("max_tokens") - user_stop = settings.get("stop") - user_frequency_penalty = settings.get("frequency_penalty") - user_presence_penalty = settings.get("presence_penalty") - # # chroma_query settings - chroma_embedding_model = settings.get("chroma_embedding_model") - chroma_host = settings.get("chroma_host") - chroma_port = settings.get("chroma_port") - chroma_collection_id = settings.get("chroma_collection_id") - chroma_n_results = settings.get("chroma_n_results") - - if context == None: - context = [] + if user_context == None: + user_context = [] if user_question is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) - chroma_settings_json={ - "chroma_embedding_model": chroma_embedding_model, - "chroma_host": chroma_host, - "chroma_port": chroma_port, - "chroma_collection_id": chroma_collection_id, - "chroma_n_results": chroma_n_results - } - # chroma answer - chroma_result = self.chroma_query(user_question, chroma_settings_json) + chroma_result = self.chroma_query(user_question, settings) # chat prompt - fast_question = f"问题: {user_question}。根据问题,总结以下内容和来源:{chroma_result}" - - chat_settings_json = { - "model_name": user_model_name, - "context": user_context, - "template": user_template, - "temperature": user_temperature, - "top_p": user_top_p, - "n": user_n, - "max_tokens": user_max_tokens, - "stop": user_stop, - "frequency_penalty": user_frequency_penalty, - "presence_penalty": user_presence_penalty - } + fast_question = f"问题: {user_question}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。只从检索的内容中选取与问题相关信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{chroma_result}" # chat answer - response = self.chat(fast_question, chat_settings_json) + response = self.chat(fast_question, user_context, settings) return response diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 22c8581..3a90d8d 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -20,7 +20,7 @@ class ChromaQuery(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load chromadb and embedding model - self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") + self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", device = "cuda") # self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda") self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000) # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000) @@ -51,7 +51,7 @@ class ChromaQuery(Blackbox): return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": - chroma_embedding_model = "bge-small-en-v1.5" + chroma_embedding_model = "bge-large-zh-v1.5" if chroma_host is None or chroma_host.isspace() or chroma_host == "": chroma_host = "10.6.82.192" @@ -60,17 +60,22 @@ class ChromaQuery(Blackbox): chroma_port = "8000" if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "": - chroma_collection_id = DEFAULT_COLLECTION_ID + chroma_collection_id = "g2e" if chroma_n_results is None or chroma_n_results == "": chroma_n_results = 3 - # load client + # load client and embedding model from init if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port): client = self.client_1 + else: + client = chromadb.HttpClient(host=chroma_host, port=chroma_port) - if re.search(r"bge-small-en-v1.5", chroma_embedding_model): + if re.search(r"bge-large-zh-v1.5", chroma_embedding_model): embedding_model = self.embedding_model_1 + else: + chroma_embedding_model = "/model/Weight/BAAI/" + chroma_embedding_model + embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda") # load collection collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model) @@ -81,7 +86,9 @@ class ChromaQuery(Blackbox): n_results=3, ) - response = str(results["documents"] + results["metadatas"]) + # response = str(results["documents"] + results["metadatas"]) + response = str(results["documents"]) + return response diff --git a/test.pdf b/test.pdf deleted file mode 100644 index e69de29..0000000 From ca3496237622db884f5f5c089b692ab793274457 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Tue, 28 May 2024 08:23:10 +0000 Subject: [PATCH 3/4] update blackbox chroma and chat --- src/blackbox/chat.py | 42 ++++++++++++++- src/blackbox/chroma_chat.py | 15 +++--- src/blackbox/chroma_query.py | 5 +- src/blackbox/chroma_upsert.py | 99 ++++++++++++++++++++++------------- 4 files changed, 113 insertions(+), 48 deletions(-) diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index bfee46c..0433ce3 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -82,14 +82,52 @@ class Chat(Blackbox): if user_model_key is None or user_model_key.isspace() or user_model_key == "": user_model_key = "YOUR_API_KEY" + # 文心格式和openai的不一样,需要单独处理 + if re.search(r"ernie", user_model_name): + # key = "24.22873ef3acf61fb343812681e4df251a.2592000.1719453781.282335-46723715" 没充钱,只有ernie-speed-128k能用 + key = user_model_key + if re.search(r"ernie-speed-128k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k?access_token=" + key + elif re.search(r"ernie-3.5-8k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=" + key + elif re.search(r"ernie-4.0-8k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + key + + payload = json.dumps({ + "system": prompt_template, + "messages": user_context + [ + { + "role": "user", + "content": user_question + } + ], + "temperature": user_temperature, + "top_p": user_top_p, + "stop": [str(user_stop)], + "max_output_tokens": user_max_tokens + }) + + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + + return response.json()["result"] + + # gpt-4, gpt-3.5-turbo - if re.search(r"gpt", user_model_name): + elif re.search(r"gpt", user_model_name): url = 'https://api.openai.com/v1/completions' - key = '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' + # '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' + key = user_model_key + + # 自定义model else: url = user_model_url key = user_model_key + prompt_template = [ {"role": "system", "content": user_template}, ] diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py index 09683bd..65495e6 100755 --- a/src/blackbox/chroma_chat.py +++ b/src/blackbox/chroma_chat.py @@ -33,20 +33,20 @@ class ChromaChat(Blackbox): # chroma_chat settings # { - # "chroma_embedding_model": "bge-large-zh-v1.5", + # "chroma_embedding_model": "/model/Weight/BAAI/bge-large-zh-v1.5", # "chroma_host": "10.6.82.192", # "chroma_port": "8000", - # "chroma_collection_id": "123", + # "chroma_collection_id": "g2e", # "chroma_n_results": 3, # "model_name": "Qwen1.5-14B-Chat", # "context": [], # "template": "", - # "temperature": 0.8, - # "top_p": 0.8, + # "temperature": 0, + # "top_p": 0, # "n": 1, # "max_tokens": 1024, - # "frequency_penalty": 0.5, - # "presence_penalty": 0.8, + # "frequency_penalty": 0, + # "presence_penalty": 0, # "stop": 100, # "model_url": "http://120.196.116.194:48892/v1/chat/completions", # "model_key": "YOUR_API_KEY" @@ -87,5 +87,4 @@ class ChromaChat(Blackbox): setting: dict = data.get("settings") return JSONResponse( - content={"response": self.processing(user_question, user_context, setting)}, - status_code=status.HTTP_200_OK) \ No newline at end of file + content={"response": self.processing(user_question, user_context, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 3a90d8d..5376b7e 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -51,7 +51,7 @@ class ChromaQuery(Blackbox): return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": - chroma_embedding_model = "bge-large-zh-v1.5" + chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5" if chroma_host is None or chroma_host.isspace() or chroma_host == "": chroma_host = "10.6.82.192" @@ -71,10 +71,9 @@ class ChromaQuery(Blackbox): else: client = chromadb.HttpClient(host=chroma_host, port=chroma_port) - if re.search(r"bge-large-zh-v1.5", chroma_embedding_model): + if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model): embedding_model = self.embedding_model_1 else: - chroma_embedding_model = "/model/Weight/BAAI/" + chroma_embedding_model embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda") # load collection diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py index c6e71f9..aad5327 100755 --- a/src/blackbox/chroma_upsert.py +++ b/src/blackbox/chroma_upsert.py @@ -17,7 +17,12 @@ import chromadb import os import tempfile -from ..utils import chroma_setting +import logging +from ..log.logging_time import logging_time +import re + +logger = logging.getLogger +DEFAULT_COLLECTION_ID = "123" from injector import singleton @singleton @@ -26,9 +31,9 @@ class ChromaUpsert(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load embedding model - self.embedding_model = SentenceTransformerEmbeddings(model_name='/model/Weight/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"}) + self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"}) # load chroma db - self.client = chromadb.HttpClient(host='10.6.82.192', port=8000) + self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -37,30 +42,59 @@ class ChromaUpsert(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, collection_id, file, string, context, setting: chroma_setting) -> str: + @logging_time(logger=logger) + def processing(self, file, string, context: list, settings: dict) -> str: # 用户的操作历史 if context is None: context = [] - context = [ - { - "collection_id": "123", - "action": "query", - "content": "你吃饭了吗", - "answer": "吃了", - }, - { - "collection_id": "123", - "action": "upsert", - "content": "file_name or string", - "answer": "collection 123 has 12472 documents. /tmp/Cheap and Quick:Efficient Vision-Language Instruction Tuning for Large Language Models.pdf ids is 0~111", - }, - ] + # context = [ + # { + # "collection_id": "123", + # "action": "query", + # "content": "你吃饭了吗", + # "answer": "吃了", + # }, + # { + # "collection_id": "123", + # "action": "upsert", + # "content": "file_name or string", + # "answer": "collection 123 has 12472 documents. /tmp/Cheap and Quick:Efficient Vision-Language Instruction Tuning for Large Language Models.pdf ids is 0~111", + # }, + # ] - if collection_id is None and setting.ChromaSetting.collection_ids[0] != []: - collection_id = setting.ChromaSetting.collection_ids[0] + if settings is None: + settings = {} + + # # chroma_query settings + chroma_embedding_model = settings.get("chroma_embedding_model") + chroma_host = settings.get("chroma_host") + chroma_port = settings.get("chroma_port") + chroma_collection_id = settings.get("chroma_collection_id") + + if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": + chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5" + + if chroma_host is None or chroma_host.isspace() or chroma_host == "": + chroma_host = "10.6.82.192" + + if chroma_port is None or chroma_port.isspace() or chroma_port == "": + chroma_port = "8000" + + if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "": + chroma_collection_id = "g2e" + + # load client and embedding model from init + if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port): + client = self.client_1 else: - collection_id = "123" + client = chromadb.HttpClient(host=chroma_host, port=chroma_port) + + if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model): + embedding_model = self.embedding_model_1 + else: + embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda") + if file is not None: file_type = file.split(".")[-1] @@ -79,7 +113,6 @@ class ChromaUpsert(Blackbox): loader = Docx2txtLoader(file) elif file_type == "xlsx": loader = UnstructuredExcelLoader(file) - documents = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0) @@ -88,21 +121,21 @@ class ChromaUpsert(Blackbox): ids = [str(file)+str(i) for i in range(len(docs))] - Chroma.from_documents(documents=docs, embedding=self.embedding_model, ids=ids, collection_name=collection_id, client=self.client) + Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=chroma_collection_id, client=client) - collection_number = self.client.get_collection(collection_id).count() - response_file = f"collection {collection_id} has {collection_number} documents. {file} ids is 0~{len(docs)-1}" + collection_number = client.get_collection(chroma_collection_id).count() + response_file = f"collection {chroma_collection_id} has {collection_number} documents. {file} ids is 0~{len(docs)-1}" if string is not None: # 生成一个新的id ids_string: 1 # ids = setting.ChromaSetting.string_ids[0] + 1 ids = "1" - Chroma.from_texts(texts=[string], embedding=self.embedding_model, ids=[ids], collection_name=collection_id, client=self.client) + Chroma.from_texts(texts=[string], embedding=embedding_model, ids=[ids], collection_name=chroma_collection_id, client=client) - collection_number = self.client.get_collection(collection_id).count() - response_string = f"collection {collection_id} has {collection_number} documents. {string} ids is {ids}" + collection_number = client.get_collection(chroma_collection_id).count() + response_string = f"collection {chroma_collection_id} has {collection_number} documents. {string} ids is {ids}" if file is not None and string is not None: @@ -116,14 +149,10 @@ class ChromaUpsert(Blackbox): async def fast_api_handler(self, request: Request) -> Response: - user_collection_id = (await request.form()).get("collection_id") user_file = (await request.form()).get("file") user_string = (await request.form()).get("string") - user_context = (await request.form()).get("context") - user_setting = (await request.form()).get("setting") - - if user_collection_id is None and user_setting["collections"] == []: - return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST) + context = (await request.form()).get("context") + setting: dict = (await request.form()).get("settings") if user_file is None and user_string is None: return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -142,5 +171,5 @@ class ChromaUpsert(Blackbox): return JSONResponse( - content={"response": self.processing(user_collection_id, safe_filename, user_string, user_context, user_setting)}, + content={"response": self.processing(safe_filename, user_string, context, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file From a049a60cb06c0fc4c5915c5b2e29588b6e674216 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Wed, 29 May 2024 08:39:52 +0000 Subject: [PATCH 4/4] fix: chat string --- src/blackbox/chat.py | 14 +++++++------- src/blackbox/chroma_chat.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index 0433ce3..a0bc1bd 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -140,13 +140,13 @@ class Chat(Blackbox): "content": user_question } ], - "temperature": user_temperature, - "top_p": user_top_p, - "n": user_n, - "max_tokens": user_max_tokens, - "frequency_penalty": user_frequency_penalty, - "presence_penalty": user_presence_penalty, - "stop": user_stop + "temperature": str(user_temperature), + "top_p": str(user_top_p), + "n": str(user_n), + "max_tokens": str(user_max_tokens), + "frequency_penalty": str(user_frequency_penalty), + "presence_penalty": str(user_presence_penalty), + "stop": str(user_stop) } header = { diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py index 65495e6..26c1093 100755 --- a/src/blackbox/chroma_chat.py +++ b/src/blackbox/chroma_chat.py @@ -42,7 +42,7 @@ class ChromaChat(Blackbox): # "context": [], # "template": "", # "temperature": 0, - # "top_p": 0, + # "top_p": 0.1, # "n": 1, # "max_tokens": 1024, # "frequency_penalty": 0,