mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
add chat
This commit is contained in:
103
src/blackbox/chat.py
Normal file
103
src/blackbox/chat.py
Normal file
@ -0,0 +1,103 @@
|
||||
from typing import Any, Coroutine
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from .blackbox import Blackbox
|
||||
|
||||
import requests
|
||||
import json
|
||||
from openai import OpenAI
|
||||
import re
|
||||
|
||||
from injector import singleton
|
||||
@singleton
class Chat(Blackbox):
    """Chat-completion blackbox.

    Forwards a user prompt (plus optional chat history and a system
    template) to an OpenAI-compatible chat-completions endpoint and
    returns the assistant's reply text. Exposed to FastAPI through
    ``fast_api_handler``.
    """

    def __call__(self, *args, **kwargs):
        """Make the instance callable by delegating to ``processing``."""
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        """Return True when the first positional argument is a list
        (the expected shape of a chat context).

        Fix: guard against being called with no positional arguments,
        which previously raised IndexError.
        """
        if not args:
            return False
        return isinstance(args[0], list)

    # Known model_name values: Qwen1.5-14B-Chat, internlm2-chat-20b,
    # plus OpenAI models matching "gpt" (gpt-4, gpt-3.5-turbo).
    def processing(self, model_name, prompt, template, context: list,
                   temperature, top_p, n, max_tokens) -> str:
        """Send one chat-completion request and return the reply content.

        Parameters:
            model_name:  target model; any name containing "gpt" is routed
                         to the OpenAI API, everything else to the
                         self-hosted fastchat endpoint.
            prompt:      the user's question (becomes the last user message).
            template:    system-prompt text (first message in the list).
            context:     prior messages as {"role", "content"} dicts;
                         None is treated as empty history.
            temperature, top_p, n, max_tokens: standard sampling parameters.

        Raises:
            requests.RequestException on network failure / non-2xx status,
            KeyError if the response body lacks the expected structure.
        """
        import os  # local import so file-level imports stay untouched

        if context is None:  # fix: identity comparison, not `== None`
            context = []

        if re.search(r"gpt", model_name):
            # fix: chat-style {"messages": ...} payloads belong on
            # /v1/chat/completions, not the legacy /v1/completions endpoint.
            url = 'https://api.openai.com/v1/chat/completions'
            # fix: never commit API keys to source control — the previous
            # hard-coded key was leaked and must be revoked; read from env.
            key = os.environ.get('OPENAI_API_KEY', 'YOUR_API_KEY')
        else:
            url = 'http://120.196.116.194:48892/v1/chat/completions'
            key = 'YOUR_API_KEY'

        # system template first, then prior history, then the new question
        messages = [{"role": "system", "content": template}] + context + [
            {"role": "user", "content": prompt},
        ]

        chat_inputs = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "stream": False,
        }

        header = {
            'Content-Type': 'application/json',
            'Authorization': "Bearer " + key,
        }

        # fix: a timeout prevents the handler from hanging forever on a
        # dead upstream; raise_for_status surfaces HTTP errors instead of
        # an opaque KeyError while indexing the error body.
        fastchat_response = requests.post(
            url, json=chat_inputs, headers=header, timeout=60,
        )
        fastchat_response.raise_for_status()

        return fastchat_response.json()["choices"][0]["message"]["content"]

    async def fast_api_handler(self, request: Request) -> Response:
        """FastAPI entry point.

        Parses the JSON body, validates ``question``, applies defaults for
        the remaining fields, and returns ``{"response": <model reply>}``.
        Returns 400 on unparsable JSON or a missing question.
        """
        try:
            data = await request.json()
        except Exception:  # fix: never use a bare `except:`
            return JSONResponse(content={"error": "json parse error"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        user_question = data.get("question")
        if user_question is None:
            return JSONResponse(content={"error": "question is required"},
                                status_code=status.HTTP_400_BAD_REQUEST)

        user_model_name = data.get("model_name")
        # blank / whitespace-only model names fall back to the default model
        if user_model_name is None or not user_model_name.strip():
            user_model_name = "Qwen1.5-14B-Chat"

        user_template = data.get("template")
        if user_template is None or user_template.isspace():
            user_template = ""

        # For each sampling parameter, None or "" means "use the default".
        user_temperature = data.get("temperature")
        if user_temperature is None or user_temperature == "":
            user_temperature = 0.7

        user_top_p = data.get("top_p")
        if user_top_p is None or user_top_p == "":
            user_top_p = 1

        user_n = data.get("n")
        if user_n is None or user_n == "":
            user_n = 1

        user_max_tokens = data.get("max_tokens")
        if user_max_tokens is None or user_max_tokens == "":
            user_max_tokens = 1024

        return JSONResponse(
            content={"response": self.processing(
                user_model_name, user_question, user_template,
                data.get("context"), user_temperature, user_top_p,
                user_n, user_max_tokens)},
            status_code=status.HTTP_200_OK)
|
||||
Reference in New Issue
Block a user