diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py
index f4a6f36..cef650e 100644
--- a/src/blackbox/blackbox_factory.py
+++ b/src/blackbox/blackbox_factory.py
@@ -7,8 +7,9 @@ from .blackbox import Blackbox
 from .calculator import Calculator
 from .text_to_audio import TextToAudio
 from .tesou import Tesou
+from .fastchat import Fastchat
 
 
 class BlackboxFactory:
     def __init__(self) -> None:
         self.tts = TTS()
@@ -19,6 +20,7 @@ class BlackboxFactory:
         self.audio_to_text = AudioToText()
         self.text_to_audio = TextToAudio()
         self.tesou = Tesou()
+        self.fastchat = Fastchat()
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -40,4 +42,6 @@ class BlackboxFactory:
             return self.sum
         if blackbox_name == "tesou":
             return self.tesou
+        if blackbox_name == "fastchat":
+            return self.fastchat
         raise ValueError("Invalid blockbox type")
\ No newline at end of file
diff --git a/src/blackbox/fastchat.py b/src/blackbox/fastchat.py
new file mode 100755
index 0000000..5432145
--- /dev/null
+++ b/src/blackbox/fastchat.py
@@ -0,0 +1,70 @@
+from typing import Any, Coroutine
+
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
+from .blackbox import Blackbox
+
+import requests
+import json
+
+class Fastchat(Blackbox):
+
+    def __call__(self, *args, **kwargs):
+        return self.processing(*args, **kwargs)
+
+    def valid(self, *args, **kwargs) -> bool:
+        data = args[0]
+        return isinstance(data, list)
+
+    # Available model_name values: Qwen1.5-14B-Chat, internlm2-chat-20b
+    def processing(self, model_name, prompt, template) -> str:
+        url = 'http://120.196.116.194:48892/v1/chat/completions'
+
+        # history may be an empty list, or the user's previous conversation
+        # turns as a list of {"role": ..., "content": ...} message dicts, e.g.
+        # [{"role": "user", "content": "..."},
+        #  {"role": "assistant", "content": "..."}]
+        history = []
+
+        fastchat_inputs = {
+            "model": model_name,
+            "messages": history + [
+                {
+                    "role": "user",
+                    "content": template + prompt
+                }
+            ]
+        }
+
+        # timeout guards against hanging forever on an unresponsive backend
+        fastchat_response = requests.post(url, json=fastchat_inputs, timeout=60)
+
+        # Record the latest user turn (the single message dict, not the list).
+        user_message = fastchat_inputs["messages"][-1]
+        history.append(user_message)
+
+        assistant_message = fastchat_response.json()["choices"][0]["message"]
+        history.append(assistant_message)
+
+        fastchat_content = assistant_message["content"]
+
+        return fastchat_content
+
+    async def fast_api_handler(self, request: Request) -> Response:
+        try:
+            data = await request.json()
+        except json.JSONDecodeError:
+            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
+
+        user_model_name = data.get("model_name")
+        user_prompt = data.get("prompt")
+        # template sets the LLM's tone, e.g. "Speak like a clown."; it may be
+        # an empty string, a user-defined tone, or one of our preset tones.
+        user_template = data.get("template") or ""
+
+        if user_prompt is None:
+            return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
+        if user_model_name is None:
+            return JSONResponse(content={"error": "model selection is required"}, status_code=status.HTTP_400_BAD_REQUEST)
+        return JSONResponse(content={"Response": self.processing(user_model_name, user_prompt, user_template)}, status_code=status.HTTP_200_OK)
\ No newline at end of file