diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index bfee46c..0433ce3 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -82,14 +82,52 @@ class Chat(Blackbox): if user_model_key is None or user_model_key.isspace() or user_model_key == "": user_model_key = "YOUR_API_KEY" + # 文心格式和openai的不一样,需要单独处理 + if re.search(r"ernie", user_model_name): + # key = "24.22873ef3acf61fb343812681e4df251a.2592000.1719453781.282335-46723715" 没充钱,只有ernie-speed-128k能用 + key = user_model_key + if re.search(r"ernie-speed-128k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k?access_token=" + key + elif re.search(r"ernie-3.5-8k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=" + key + elif re.search(r"ernie-4.0-8k", user_model_name): + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + key + + payload = json.dumps({ + "system": prompt_template, + "messages": user_context + [ + { + "role": "user", + "content": user_question + } + ], + "temperature": user_temperature, + "top_p": user_top_p, + "stop": [str(user_stop)], + "max_output_tokens": user_max_tokens + }) + + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + + return response.json()["result"] + + # gpt-4, gpt-3.5-turbo - if re.search(r"gpt", user_model_name): + elif re.search(r"gpt", user_model_name): url = 'https://api.openai.com/v1/completions' - key = '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' + # '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' + key = user_model_key + + # 自定义model else: url = user_model_url key = user_model_key + prompt_template = [ {"role": "system", "content": user_template}, ] diff --git a/src/blackbox/chroma_chat.py b/src/blackbox/chroma_chat.py index 09683bd..65495e6 100755 --- a/src/blackbox/chroma_chat.py +++ b/src/blackbox/chroma_chat.py @@ -33,20 +33,20 @@ class ChromaChat(Blackbox): # chroma_chat settings # { - # "chroma_embedding_model": "bge-large-zh-v1.5", + # "chroma_embedding_model": "/model/Weight/BAAI/bge-large-zh-v1.5", # "chroma_host": "10.6.82.192", # "chroma_port": "8000", - # "chroma_collection_id": "123", + # "chroma_collection_id": "g2e", # "chroma_n_results": 3, # "model_name": "Qwen1.5-14B-Chat", # "context": [], # "template": "", - # "temperature": 0.8, - # "top_p": 0.8, + # "temperature": 0, + # "top_p": 0, # "n": 1, # "max_tokens": 1024, - # "frequency_penalty": 0.5, - # "presence_penalty": 0.8, + # "frequency_penalty": 0, + # "presence_penalty": 0, # "stop": 100, # "model_url": "http://120.196.116.194:48892/v1/chat/completions", # "model_key": "YOUR_API_KEY" @@ -87,5 +87,4 @@ class ChromaChat(Blackbox): setting: dict = data.get("settings") return JSONResponse( - content={"response": self.processing(user_question, user_context, setting)}, - status_code=status.HTTP_200_OK) \ No newline at end of file + content={"response": self.processing(user_question, user_context, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file diff --git a/src/blackbox/chroma_query.py b/src/blackbox/chroma_query.py index 3a90d8d..5376b7e 100755 --- a/src/blackbox/chroma_query.py +++ b/src/blackbox/chroma_query.py @@ -51,7 +51,7 @@ class ChromaQuery(Blackbox): return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": - chroma_embedding_model = "bge-large-zh-v1.5" + chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5" if chroma_host is None or chroma_host.isspace() or chroma_host == "": chroma_host = "10.6.82.192" @@ -71,10 +71,9 @@ class ChromaQuery(Blackbox): else: client = chromadb.HttpClient(host=chroma_host, port=chroma_port) - if re.search(r"bge-large-zh-v1.5", chroma_embedding_model): + if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model): embedding_model = self.embedding_model_1 else: - chroma_embedding_model = "/model/Weight/BAAI/" + chroma_embedding_model embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda") # load collection diff --git a/src/blackbox/chroma_upsert.py b/src/blackbox/chroma_upsert.py index c6e71f9..aad5327 100755 --- a/src/blackbox/chroma_upsert.py +++ b/src/blackbox/chroma_upsert.py @@ -17,7 +17,12 @@ import chromadb import os import tempfile -from ..utils import chroma_setting +import logging +from ..log.logging_time import logging_time +import re + +logger = logging.getLogger +DEFAULT_COLLECTION_ID = "123" from injector import singleton @singleton @@ -26,9 +31,9 @@ class ChromaUpsert(Blackbox): def __init__(self, *args, **kwargs) -> None: # config = read_yaml(args[0]) # load embedding model - self.embedding_model = SentenceTransformerEmbeddings(model_name='/model/Weight/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"}) + self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"}) # load chroma db - self.client = chromadb.HttpClient(host='10.6.82.192', port=8000) + self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -37,30 +42,59 @@ class ChromaUpsert(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, collection_id, file, string, context, setting: chroma_setting) -> str: + @logging_time(logger=logger) + def processing(self, file, string, context: list, settings: dict) -> str: # 用户的操作历史 if context is None: context = [] - context = [ - { - "collection_id": "123", - "action": "query", - "content": "你吃饭了吗", - "answer": "吃了", - }, - { - "collection_id": "123", - "action": "upsert", - "content": "file_name or string", - "answer": "collection 123 has 12472 documents. /tmp/Cheap and Quick:Efficient Vision-Language Instruction Tuning for Large Language Models.pdf ids is 0~111", - }, - ] + # context = [ + # { + # "collection_id": "123", + # "action": "query", + # "content": "你吃饭了吗", + # "answer": "吃了", + # }, + # { + # "collection_id": "123", + # "action": "upsert", + # "content": "file_name or string", + # "answer": "collection 123 has 12472 documents. /tmp/Cheap and Quick:Efficient Vision-Language Instruction Tuning for Large Language Models.pdf ids is 0~111", + # }, + # ] - if collection_id is None and setting.ChromaSetting.collection_ids[0] != []: - collection_id = setting.ChromaSetting.collection_ids[0] + if settings is None: + settings = {} + + # # chroma_query settings + chroma_embedding_model = settings.get("chroma_embedding_model") + chroma_host = settings.get("chroma_host") + chroma_port = settings.get("chroma_port") + chroma_collection_id = settings.get("chroma_collection_id") + + if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "": + chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5" + + if chroma_host is None or chroma_host.isspace() or chroma_host == "": + chroma_host = "10.6.82.192" + + if chroma_port is None or chroma_port.isspace() or chroma_port == "": + chroma_port = "8000" + + if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "": + chroma_collection_id = "g2e" + + # load client and embedding model from init + if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port): + client = self.client_1 else: - collection_id = "123" + client = chromadb.HttpClient(host=chroma_host, port=chroma_port) + + if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model): + embedding_model = self.embedding_model_1 + else: + embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda") + if file is not None: file_type = file.split(".")[-1] @@ -79,7 +113,6 @@ class ChromaUpsert(Blackbox): loader = Docx2txtLoader(file) elif file_type == "xlsx": loader = UnstructuredExcelLoader(file) - documents = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0) @@ -88,21 +121,21 @@ class ChromaUpsert(Blackbox): ids = [str(file)+str(i) for i in range(len(docs))] - Chroma.from_documents(documents=docs, embedding=self.embedding_model, ids=ids, collection_name=collection_id, client=self.client) + Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=chroma_collection_id, client=client) - collection_number = self.client.get_collection(collection_id).count() - response_file = f"collection {collection_id} has {collection_number} documents. {file} ids is 0~{len(docs)-1}" + collection_number = client.get_collection(chroma_collection_id).count() + response_file = f"collection {chroma_collection_id} has {collection_number} documents. {file} ids is 0~{len(docs)-1}" if string is not None: # 生成一个新的id ids_string: 1 # ids = setting.ChromaSetting.string_ids[0] + 1 ids = "1" - Chroma.from_texts(texts=[string], embedding=self.embedding_model, ids=[ids], collection_name=collection_id, client=self.client) + Chroma.from_texts(texts=[string], embedding=embedding_model, ids=[ids], collection_name=chroma_collection_id, client=client) - collection_number = self.client.get_collection(collection_id).count() - response_string = f"collection {collection_id} has {collection_number} documents. {string} ids is {ids}" + collection_number = client.get_collection(chroma_collection_id).count() + response_string = f"collection {chroma_collection_id} has {collection_number} documents. {string} ids is {ids}" if file is not None and string is not None: @@ -116,14 +149,10 @@ class ChromaUpsert(Blackbox): async def fast_api_handler(self, request: Request) -> Response: - user_collection_id = (await request.form()).get("collection_id") user_file = (await request.form()).get("file") user_string = (await request.form()).get("string") - user_context = (await request.form()).get("context") - user_setting = (await request.form()).get("setting") - - if user_collection_id is None and user_setting["collections"] == []: - return JSONResponse(content={"error": "The first creation requires a collection id"}, status_code=status.HTTP_400_BAD_REQUEST) + context = (await request.form()).get("context") + setting: dict = (await request.form()).get("settings") if user_file is None and user_string is None: return JSONResponse(content={"error": "file or string is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -142,5 +171,5 @@ class ChromaUpsert(Blackbox): return JSONResponse( - content={"response": self.processing(user_collection_id, safe_filename, user_string, user_context, user_setting)}, + content={"response": self.processing(safe_filename, user_string, context, setting)}, status_code=status.HTTP_200_OK) \ No newline at end of file