update chat

This commit is contained in:
ACBBZ
2024-05-23 02:52:32 +00:00
committed by superobk
parent e2367a7aa8
commit b42b13d0e6
2 changed files with 26 additions and 18 deletions

View File

@ -1,10 +1,8 @@
from . import melotts
from .audio_chat import AudioChat from .audio_chat import AudioChat
from .sentiment import Sentiment from .sentiment import Sentiment
from .tts import TTS from .tts import TTS
from .asr import ASR from .asr import ASR
from .audio_to_text import AudioToText from .audio_to_text import AudioToText
#from .emotion import Emotion
from .blackbox import Blackbox from .blackbox import Blackbox
# from .text_to_audio import TextToAudio # from .text_to_audio import TextToAudio
# from .tesou import Tesou # from .tesou import Tesou
@ -25,12 +23,10 @@ class BlackboxFactory:
@inject @inject
def __init__(self, def __init__(self,
audio_to_text: AudioToText, audio_to_text: AudioToText,
text_to_audio: TextToAudio,
asr: ASR, asr: ASR,
tts: TTS, tts: TTS,
sentiment_engine: Sentiment, sentiment_engine: Sentiment,
#emotion: Emotion, #emotion: Emotion,
tesou: Tesou,
fastchat: Fastchat, fastchat: Fastchat,
audio_chat: AudioChat, audio_chat: AudioChat,
g2e: G2E, g2e: G2E,
@ -39,23 +35,22 @@ class BlackboxFactory:
#chroma_upsert: ChromaUpsert, #chroma_upsert: ChromaUpsert,
#chroma_chat: ChromaChat, #chroma_chat: ChromaChat,
melotts: MeloTTS, melotts: MeloTTS,
vlms: VLMS) -> None: vlms: VLMS,
chroma_query: ChromaQuery,
chroma_upsert: ChromaUpsert,
chroma_chat: ChromaChat) -> None:
self.models["audio_to_text"] = audio_to_text self.models["audio_to_text"] = audio_to_text
self.models["text_to_audio"] = text_to_audio
self.models["asr"] = asr self.models["asr"] = asr
self.models["tts"] = tts self.models["tts"] = tts
self.models["sentiment_engine"] = sentiment_engine self.models["sentiment_engine"] = sentiment_engine
self.models["tesou"] = tesou
#self.models["emotion"] = emotion #self.models["emotion"] = emotion
self.models["fastchat"] = fastchat self.models["fastchat"] = fastchat
self.models["audio_chat"] = audio_chat self.models["audio_chat"] = audio_chat
self.models["g2e"] = g2e self.models["g2e"] = g2e
self.models["text_and_image"] = text_and_image self.models["text_and_image"] = text_and_image
#self.models["chroma_query"] = chroma_query self.models["chroma_query"] = chroma_query
#self.models["chroma_upsert"] = chroma_upsert self.models["chroma_upsert"] = chroma_upsert
#self.models["chroma_chat"] = chroma_chat self.models["chroma_chat"] = chroma_chat
self.models["melotts"] = melotts
self.models["vlms"] = vlms
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)

View File

@ -21,7 +21,7 @@ class Chat(Blackbox):
return isinstance(data, list) return isinstance(data, list)
# model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b # model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b
def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens) -> str: def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens,stop,frequency_penalty,presence_penalty) -> str:
if context == None: if context == None:
context = [] context = []
@ -49,7 +49,9 @@ class Chat(Blackbox):
"top_p": top_p, "top_p": top_p,
"n": n, "n": n,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"stream": False, "frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"stop": stop
} }
header = { header = {
@ -75,7 +77,9 @@ class Chat(Blackbox):
user_top_p = data.get("top_p") user_top_p = data.get("top_p")
user_n = data.get("n") user_n = data.get("n")
user_max_tokens = data.get("max_tokens") user_max_tokens = data.get("max_tokens")
user_stop = data.get("stop")
user_frequency_penalty = data.get("frequency_penalty")
user_presence_penalty = data.get("presence_penalty")
if user_question is None: if user_question is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@ -87,10 +91,10 @@ class Chat(Blackbox):
user_template = "" user_template = ""
if user_temperature is None or user_temperature == "": if user_temperature is None or user_temperature == "":
user_temperature = 0.7 user_temperature = 0.8
if user_top_p is None or user_top_p == "": if user_top_p is None or user_top_p == "":
user_top_p = 1 user_top_p = 0.8
if user_n is None or user_n == "": if user_n is None or user_n == "":
user_n = 1 user_n = 1
@ -98,6 +102,15 @@ class Chat(Blackbox):
if user_max_tokens is None or user_max_tokens == "": if user_max_tokens is None or user_max_tokens == "":
user_max_tokens = 1024 user_max_tokens = 1024
if user_stop is None or user_stop == "":
user_stop = 100
if user_frequency_penalty is None or user_frequency_penalty == "":
user_frequency_penalty = 0.5
if user_presence_penalty is None or user_presence_penalty == "":
user_presence_penalty = 0.8
return JSONResponse(content={"response": self.processing(user_model_name, user_question, user_template, user_context, return JSONResponse(content={"response": self.processing(user_model_name, user_question, user_template, user_context,
user_temperature, user_top_p, user_n, user_max_tokens)}, status_code=status.HTTP_200_OK) user_temperature, user_top_p, user_n, user_max_tokens,user_stop,user_frequency_penalty,user_presence_penalty)}, status_code=status.HTTP_200_OK)