"""Chat-completion proxy blackbox.

Forwards a user question (plus optional conversation context) either to the
OpenAI chat-completions API (for ``gpt*`` models) or to a local
FastChat-style endpoint, and returns the assistant's reply text.
"""

import os
import re

import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import singleton

from .blackbox import Blackbox


@singleton
class Chat(Blackbox):
    """Blackbox that answers a prompt via a chat-completion backend."""

    # gpt-4 / gpt-3.5-turbo go to OpenAI's *chat* completions endpoint.
    # (The previous URL, /v1/completions, rejects the "messages" payload
    # format this class sends — it is the legacy text-completion API.)
    OPENAI_URL = "https://api.openai.com/v1/chat/completions"
    # Everything else (e.g. Qwen1.5-14B-Chat, internlm2-chat-20b) goes to
    # the local OpenAI-compatible server.
    LOCAL_URL = "http://120.196.116.194:48892/v1/chat/completions"

    def __call__(self, *args, **kwargs):
        """Make the blackbox callable; delegates to :meth:`processing`."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list
        (the expected shape of a conversation ``context``)."""
        data = args[0]
        return isinstance(data, list)

    def processing(self, model_name, prompt, template, context: list,
                   temperature, top_p, n, max_tokens) -> str:
        """Send one chat-completion request and return the reply text.

        Args:
            model_name: backend model id; anything matching "gpt" is routed
                to OpenAI, everything else to the local server.
            prompt: the user's question (final "user" message).
            template: system-prompt text (first "system" message).
            context: prior messages as a list of role/content dicts, or None.
            temperature / top_p / n / max_tokens: sampling parameters,
                passed through verbatim.

        Raises:
            requests.HTTPError: if the backend answers with an error status.
        """
        if context is None:
            context = []

        if re.search(r"gpt", model_name):
            url = self.OPENAI_URL
            # SECURITY: the key must come from the environment; a real key
            # was previously hardcoded here and must be considered leaked
            # (rotate it).
            key = os.environ.get("OPENAI_API_KEY", "")
        else:
            url = self.LOCAL_URL
            key = "YOUR_API_KEY"  # local server ignores the key

        # system template, then prior turns, then the new user question.
        messages = (
            [{"role": "system", "content": template}]
            + context
            + [{"role": "user", "content": prompt}]
        )

        chat_inputs = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "stream": False,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + key,
        }

        # Timeout so a hung backend cannot block the caller forever.
        response = requests.post(url, json=chat_inputs, headers=headers,
                                 timeout=120)
        response.raise_for_status()  # surface upstream errors explicitly
        return response.json()["choices"][0]["message"]["content"]

    @staticmethod
    def _or_default(value, default):
        """Treat None and "" as "use the default"; keep any other value."""
        return default if value is None or value == "" else value

    async def fast_api_handler(self, request: Request) -> Response:
        """FastAPI endpoint: validate the JSON body, fill in defaults, and
        return ``{"response": <reply>}``.

        Responses:
            400 — body is not valid JSON, or "question" is missing.
            502 — the chat backend could not be reached / answered an error.
            200 — success.
        """
        try:
            data = await request.json()
        except Exception:  # malformed or non-JSON body
            return JSONResponse(content={"error": "json parse error"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        user_question = data.get("question")
        if user_question is None:
            return JSONResponse(content={"error": "question is required"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        user_model_name = data.get("model_name")
        if not user_model_name or user_model_name.isspace():
            user_model_name = "Qwen1.5-14B-Chat"

        user_template = data.get("template")
        if user_template is None or user_template.isspace():
            user_template = ""

        user_context = data.get("context")
        user_temperature = self._or_default(data.get("temperature"), 0.7)
        user_top_p = self._or_default(data.get("top_p"), 1)
        user_n = self._or_default(data.get("n"), 1)
        user_max_tokens = self._or_default(data.get("max_tokens"), 1024)

        # NOTE(review): processing() performs blocking network I/O inside an
        # async handler; consider fastapi.concurrency.run_in_threadpool if
        # this endpoint sees concurrent load — TODO confirm.
        try:
            answer = self.processing(user_model_name, user_question,
                                     user_template, user_context,
                                     user_temperature, user_top_p,
                                     user_n, user_max_tokens)
        except requests.RequestException:
            return JSONResponse(content={"error": "upstream chat service error"},
                                status_code=status.HTTP_502_BAD_GATEWAY)

        return JSONResponse(content={"response": answer},
                            status_code=status.HTTP_200_OK)