From 99ecc45a47e5333b60830e10792ced0d3cee7cd3 Mon Sep 17 00:00:00 2001 From: ACBBZ Date: Thu, 23 May 2024 02:52:32 +0000 Subject: [PATCH] update chat --- src/blackbox/blackbox_factory.py | 56 ++++++++++++++++---------------- src/blackbox/chat.py | 25 ++++++++++---- src/blackbox/g2e.py | 33 +++++++++++-------- 3 files changed, 66 insertions(+), 48 deletions(-) diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 624a760..c9baa6a 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -1,14 +1,14 @@ -# from .audio_chat import AudioChat -# from .sentiment import Sentiment -# from .tts import TTS -# from .asr import ASR -# from .audio_to_text import AudioToText +from .audio_chat import AudioChat +from .sentiment import Sentiment +from .tts import TTS +from .asr import ASR +from .audio_to_text import AudioToText from .blackbox import Blackbox # from .text_to_audio import TextToAudio # from .tesou import Tesou from .fastchat import Fastchat -# from .g2e import G2E -# from .text_and_image import TextAndImage +from .g2e import G2E +from .text_and_image import TextAndImage from .chroma_query import ChromaQuery from .chroma_upsert import ChromaUpsert from .chroma_chat import ChromaChat @@ -20,29 +20,29 @@ class BlackboxFactory: @inject def __init__(self, - # audio_to_text: AudioToText, - # text_to_audio: TextToAudio, - # asr: ASR, - # tts: TTS, - # sentiment_engine: Sentiment, - # tesou: Tesou, + audio_to_text: AudioToText, + text_to_audio: TextToAudio, + asr: ASR, + tts: TTS, + sentiment_engine: Sentiment, + tesou: Tesou, fastchat: Fastchat, - # audio_chat: AudioChat, - # g2e: G2E, - # text_and_image:TextAndImage, + audio_chat: AudioChat, + g2e: G2E, + text_and_image:TextAndImage, chroma_query: ChromaQuery, chroma_upsert: ChromaUpsert, chroma_chat: ChromaChat) -> None: - # self.models["audio_to_text"] = audio_to_text - # self.models["text_to_audio"] = text_to_audio - # self.models["asr"] = asr - # self.models["tts"] = tts - # self.models["sentiment_engine"] = sentiment_engine - # self.models["tesou"] = tesou + self.models["audio_to_text"] = audio_to_text + self.models["text_to_audio"] = text_to_audio + self.models["asr"] = asr + self.models["tts"] = tts + self.models["sentiment_engine"] = sentiment_engine + self.models["tesou"] = tesou self.models["fastchat"] = fastchat - # self.models["audio_chat"] = audio_chat - # self.models["g2e"] = g2e - # self.models["text_and_image"] = text_and_image + self.models["audio_chat"] = audio_chat + self.models["g2e"] = g2e + self.models["text_and_image"] = text_and_image self.models["chroma_query"] = chroma_query self.models["chroma_upsert"] = chroma_upsert self.models["chroma_chat"] = chroma_chat @@ -50,8 +50,8 @@ class BlackboxFactory: def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) - def call_blackbox(self, blackbox_name: str) -> Blackbox: + def get_blackbox(self, blackbox_name: str) -> Blackbox: model = self.models.get(blackbox_name) if model is None: - raise ValueError("Invalid blockbox type") - return model \ No newline at end of file + raise ValueError("Invalid Blackbox Type...") + return model diff --git a/src/blackbox/chat.py b/src/blackbox/chat.py index 316fcc4..0d5448f 100644 --- a/src/blackbox/chat.py +++ b/src/blackbox/chat.py @@ -21,7 +21,7 @@ class Chat(Blackbox): return isinstance(data, list) # model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b - def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens) -> str: + def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens,stop,frequency_penalty,presence_penalty) -> str: if context == None: context = [] @@ -49,7 +49,9 @@ class Chat(Blackbox): "top_p": top_p, "n": n, "max_tokens": max_tokens, - "stream": False, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "stop": stop } header = { @@ -75,7 +77,9 @@ class Chat(Blackbox): user_top_p = data.get("top_p") user_n = data.get("n") user_max_tokens = data.get("max_tokens") - + user_stop = data.get("stop") + user_frequency_penalty = data.get("frequency_penalty") + user_presence_penalty = data.get("presence_penalty") if user_question is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -87,10 +91,10 @@ class Chat(Blackbox): user_template = "" if user_temperature is None or user_temperature == "": - user_temperature = 0.7 + user_temperature = 0.8 if user_top_p is None or user_top_p == "": - user_top_p = 1 + user_top_p = 0.8 if user_n is None or user_n == "": user_n = 1 @@ -98,6 +102,15 @@ class Chat(Blackbox): if user_max_tokens is None or user_max_tokens == "": user_max_tokens = 1024 + if user_stop is None or user_stop == "": + user_stop = 100 + + if user_frequency_penalty is None or user_frequency_penalty == "": + user_frequency_penalty = 0.5 + + if user_presence_penalty is None or user_presence_penalty == "": + user_presence_penalty = 0.8 + return JSONResponse(content={"response": self.processing(user_model_name, user_question, user_template, user_context, - user_temperature, user_top_p, user_n, user_max_tokens)}, status_code=status.HTTP_200_OK) \ No newline at end of file + user_temperature, user_top_p, user_n, user_max_tokens,user_stop,user_frequency_penalty,user_presence_penalty)}, status_code=status.HTTP_200_OK) \ No newline at end of file diff --git a/src/blackbox/g2e.py b/src/blackbox/g2e.py index 8bd5507..416ab72 100755 --- a/src/blackbox/g2e.py +++ b/src/blackbox/g2e.py @@ -19,11 +19,11 @@ class G2E(Blackbox): return isinstance(data, list) # model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b - def processing(self, model_name, prompt, template, context: list) -> str: + def processing(self, model_name, prompt, template, context: list) -> str: if context == None: context = [] - url = 'http://120.196.116.194:48890/v1' - #url = 'http://120.196.116.194:48892/v1' + #url = 'http://120.196.116.194:48890/v1' + url = 'http://120.196.116.194:48892/v1' background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众 年龄:20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯 @@ -42,41 +42,46 @@ class G2E(Blackbox): KOMBUKIKI康普茶价格 内地常规版:25 RMB 澳门常规版:28-29 MOP''' - prompt1 = ''''你是琪琪,活泼的康普茶看板娘,同时你对澳门十分熟悉,是一个澳门旅游专家,请回答任何关于澳门旅游的问题,回答尽量简练明了。 - ''' - inject_prompt = '(用活泼的语气说话回答,回答严格限制50字以内)' + prompt1 = '''你是琪琪,活泼的康普茶看板娘,同时你对澳门十分熟悉,是一个澳门旅游专家,请回答任何关于澳门旅游的问题,回答尽量简练明了。''' + #inject_prompt = '(用活泼的语气说话回答,回答严格限制50字以内)' + inject_prompt = '(回答简练,不要输出重复内容,只讲中文)' - prompt_template = [ - {"role": "system", "content": background_prompt + prompt1}, - ] #prompt_template = [ - # {"role": "system", "content": ''}, + # {"role": "system", "content": background_prompt + prompt1}, #] - + prompt_template = [ + {"role": "system", "content": ''} + ] messages = prompt_template + context + [ { "role": "user", - "content": prompt + inject_prompt + "content": prompt } ] + print("**** History with current prompt input : ****") + print(messages) client = OpenAI( api_key='YOUR_API_KEY', base_url=url ) model_name = client.models.list().data[0].id + #model_name = client.models.list().data[1].id print(model_name) + response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, - # max_tokens = 50 + frequency_penalty=0.5, + presence_penalty=0.8, + stop=100 ) fastchat_content = response.choices[0].message.content - + print("*** Model response: " + fastchat_content + " ***") return fastchat_content async def fast_api_handler(self, request: Request) -> Response: