ACBBZ
2024-05-21 03:19:54 +00:00
parent b7d789fb04
commit 4f7f64a49a

src/blackbox/chat.py · new file · 103 additions

@@ -0,0 +1,103 @@
import os
import re

import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import singleton

from .blackbox import Blackbox

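# Chat: a Blackbox unit that forwards a prompt (plus optional context and a system
# template) to an OpenAI-compatible chat completions backend and exposes an async
# FastAPI handler for HTTP access.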
@singleton
class Chat(Blackbox):
    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        data = args[0]
        return isinstance(data, list)
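    # processing() builds a chat-completions request from the template, prior context,
    # and new prompt, posts it to the selected backend, and returns the reply text.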
    # Available model_name values include Qwen1.5-14B-Chat and internlm2-chat-20b.
    def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens) -> str:
        if context is None:
            context = []
        # gpt-4 / gpt-3.5-turbo requests go to the OpenAI API; everything else goes to the local server.
        if re.search(r"gpt", model_name):
            # The payload below uses "messages", so the chat completions endpoint is required.
            url = 'https://api.openai.com/v1/chat/completions'
            # Read the key from the environment (variable name assumed) rather than hardcoding it.
            key = os.environ.get('OPENAI_API_KEY', '')
        else:
            url = 'http://120.196.116.194:48892/v1/chat/completions'
            key = 'YOUR_API_KEY'
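        # Build the message list in the OpenAI chat completions format: system template,
        # then any prior context turns, then the new user prompt.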
        prompt_template = [
            {"role": "system", "content": template},
        ]
        chat_inputs = {
            "model": model_name,
            "messages": prompt_template + context + [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "stream": False,
        }
        header = {
            'Content-Type': 'application/json',
            'Authorization': "Bearer " + key
        }
        fastchat_response = requests.post(url, json=chat_inputs, headers=header)
        return fastchat_response.json()["choices"][0]["message"]["content"]
    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        user_model_name = data.get("model_name")
        user_context = data.get("context")
        user_question = data.get("question")
        user_template = data.get("template")
        user_temperature = data.get("temperature")
        user_top_p = data.get("top_p")
        user_n = data.get("n")
        user_max_tokens = data.get("max_tokens")
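        # Fill in defaults for any optional fields that were omitted or left blank.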
        if user_question is None:
            return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        if user_model_name is None or user_model_name.isspace() or user_model_name == "":
            user_model_name = "Qwen1.5-14B-Chat"
        if user_template is None or user_template.isspace():
            user_template = ""
        if user_temperature is None or user_temperature == "":
            user_temperature = 0.7
        if user_top_p is None or user_top_p == "":
            user_top_p = 1
        if user_n is None or user_n == "":
            user_n = 1
        if user_max_tokens is None or user_max_tokens == "":
            user_max_tokens = 1024
        return JSONResponse(
            content={"response": self.processing(user_model_name, user_question, user_template, user_context,
                                                 user_temperature, user_top_p, user_n, user_max_tokens)},
            status_code=status.HTTP_200_OK,
        )
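
# Example request body for fast_api_handler (illustrative only; the route that mounts
# this handler is defined elsewhere, and the values below are assumptions):
# {
#     "model_name": "Qwen1.5-14B-Chat",
#     "question": "Hello, who are you?",
#     "template": "You are a helpful assistant.",
#     "context": [],
#     "temperature": 0.7,
#     "top_p": 1,
#     "n": 1,
#     "max_tokens": 1024
# }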