Files
jarvis-models/src/blackbox/chat.py
2024-05-24 10:41:17 +08:00

103 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
from typing import Any, Coroutine

import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import singleton
from openai import OpenAI

from .blackbox import Blackbox
@singleton
class Chat(Blackbox):
    """Blackbox that sends a chat prompt to a chat-completions HTTP endpoint.

    Models whose name contains ``"gpt"`` are routed to the OpenAI API; every
    other model name is routed to the self-hosted endpoint configured in
    :meth:`processing`.
    """

    def __call__(self, *args, **kwargs):
        """Delegate to :meth:`processing` with the same arguments."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list.

        Returns False (instead of raising ``IndexError``) when no positional
        argument is supplied.
        """
        return bool(args) and isinstance(args[0], list)

    # Known model_name values include Qwen1.5-14B-Chat and internlm2-chat-20b.
    def processing(self, model_name, prompt, template, context: list, temperature, top_p, n, max_tokens) -> str:
        """Request one chat completion and return the first choice's text.

        Args:
            model_name: Model identifier; names containing ``"gpt"`` go to
                the OpenAI API, all others to the self-hosted endpoint.
            prompt: The user message appended after ``context``.
            template: Content of the leading ``system`` message.
            context: Earlier messages placed between the system message and
                the new user message; ``None`` is treated as no history.
            temperature: Forwarded unchanged as ``temperature``.
            top_p: Forwarded unchanged as ``top_p``.
            n: Forwarded unchanged as ``n``.
            max_tokens: Forwarded unchanged as ``max_tokens``.

        Returns:
            ``choices[0].message.content`` from the endpoint's JSON reply.

        Raises:
            requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
            requests.RequestException: On connection-level failures.
        """
        if context is None:
            context = []

        # gpt-4, gpt-3.5-turbo, ...
        if "gpt" in model_name:
            # The payload below uses the chat ``messages`` format and the
            # reply is read from ``choices[0].message``, so the chat
            # completions endpoint (not the legacy /v1/completions) is needed.
            url = 'https://api.openai.com/v1/chat/completions'
            # SECURITY: a credential is committed in source here. It is kept
            # only as a fallback so existing deployments keep working; rotate
            # it and supply OPENAI_API_KEY through the environment instead.
            key = os.environ.get(
                'OPENAI_API_KEY',
                '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',
            )
        else:
            url = 'http://120.196.116.194:48892/v1/chat/completions'
            key = 'YOUR_API_KEY'

        messages = [
            {"role": "system", "content": template},
            *context,
            {"role": "user", "content": prompt},
        ]
        chat_inputs = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "stream": False,
        }
        header = {
            'Content-Type': 'application/json',
            'Authorization': "Bearer " + key,
        }

        # NOTE(review): no timeout is passed, so this call can block for as
        # long as the remote server keeps the connection open.
        fastchat_response = requests.post(url, json=chat_inputs, headers=header)
        # Surface upstream failures as an HTTPError rather than as an opaque
        # KeyError when the error body has no "choices" field.
        fastchat_response.raise_for_status()
        return fastchat_response.json()["choices"][0]["message"]["content"]

    async def fast_api_handler(self, request: Request) -> Response:
        """Handle a chat request and reply with ``{"response": <text>}``.

        The JSON body must be an object. ``question`` is required;
        ``model_name``, ``template``, ``temperature``, ``top_p``, ``n`` and
        ``max_tokens`` fall back to defaults when missing or blank, and
        ``context`` (if given) must be a list.

        Returns:
            A 400 JSON response for malformed input, otherwise a 200 JSON
            response carrying the text returned by :meth:`processing`.
        """
        try:
            data = await request.json()
        except Exception:
            # ``Exception`` (not a bare except) so task cancellation and
            # interpreter-exit signals are not swallowed.
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)

        # Valid JSON that is not an object (e.g. a list) has no ``.get``.
        if not isinstance(data, dict):
            return JSONResponse(content={"error": "request body must be a json object"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        user_model_name = data.get("model_name")
        user_context = data.get("context")
        user_question = data.get("question")
        user_template = data.get("template")
        user_temperature = data.get("temperature")
        user_top_p = data.get("top_p")
        user_n = data.get("n")
        user_max_tokens = data.get("max_tokens")

        if user_question is None:
            return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        # ``processing`` splices the context into the message list, so
        # anything other than a list (or None) would fail there.
        if user_context is not None and not isinstance(user_context, list):
            return JSONResponse(content={"error": "context must be a list"}, status_code=status.HTTP_400_BAD_REQUEST)

        # Missing, blank or non-string values fall back to the defaults.
        if not isinstance(user_model_name, str) or not user_model_name.strip():
            user_model_name = "Qwen1.5-14B-Chat"

        if not isinstance(user_template, str) or user_template.isspace():
            user_template = ""

        if user_temperature is None or user_temperature == "":
            user_temperature = 0.7

        if user_top_p is None or user_top_p == "":
            user_top_p = 1

        if user_n is None or user_n == "":
            user_n = 1

        if user_max_tokens is None or user_max_tokens == "":
            user_max_tokens = 1024

        # NOTE(review): ``processing`` performs a synchronous HTTP call, so
        # this coroutine does not yield while waiting for the reply.
        answer = self.processing(
            user_model_name,
            user_question,
            user_template,
            user_context,
            user_temperature,
            user_top_p,
            user_n,
            user_max_tokens,
        )
        return JSONResponse(content={"response": answer}, status_code=status.HTTP_200_OK)