refactor: chat processing

This commit is contained in:
superobk
2024-05-24 16:16:43 +08:00
parent 28621b7169
commit ab82aa7575

View File

@ -1,4 +1,3 @@
import logging
from typing import Any, Coroutine from typing import Any, Coroutine
from fastapi import Request, Response, status from fastapi import Request, Response, status
@ -14,8 +13,6 @@ import re
from injector import singleton from injector import singleton
logger = logging.getLogger
@singleton @singleton
class Chat(Blackbox): class Chat(Blackbox):
@ -27,15 +24,14 @@ class Chat(Blackbox):
return isinstance(data, list) return isinstance(data, list)
# model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b # model_name有 Qwen1.5-14B-Chat , internlm2-chat-20b
@logging_time(logger=logger) @logging_time()
def processing(self, *args, **kwargs) -> str: def processing(self, prompt: str, context: list, settings: dict) -> str:
settings: dict = args[0]
if settings is None: if settings is None:
settings = {} settings = {}
user_model_name = settings.get("model_name") user_model_name = settings.get("model_name")
user_context = settings.get("context") user_context = context
user_question = settings.get("question") user_question = prompt
user_template = settings.get("template") user_template = settings.get("template")
user_temperature = settings.get("temperature") user_temperature = settings.get("temperature")
user_top_p = settings.get("top_p") user_top_p = settings.get("top_p")
@ -44,7 +40,6 @@ class Chat(Blackbox):
user_stop = settings.get("stop") user_stop = settings.get("stop")
user_frequency_penalty = settings.get("frequency_penalty") user_frequency_penalty = settings.get("frequency_penalty")
user_presence_penalty = settings.get("presence_penalty") user_presence_penalty = settings.get("presence_penalty")
if user_context == None: if user_context == None:
user_context = [] user_context = []
@ -124,5 +119,7 @@ class Chat(Blackbox):
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
setting: dict = data.get("settings") setting: dict = data.get("settings")
context = data.get("context")
return JSONResponse(content={"response": self.processing(setting)}, status_code=status.HTTP_200_OK) prompt = data.get("prompt")
return JSONResponse(content={"response": self.processing(prompt, context, setting)}, status_code=status.HTTP_200_OK)