jarvis-models/src/blackbox/chat.py

from typing import Iterator
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from ..log.logging_time import logging_time
from .blackbox import Blackbox
from .chroma_query import ChromaQuery
import requests
import json
import re
from injector import singleton, inject
from datetime import datetime

# Path of the file used to log chat inputs (see the commented-out logging block below).
file_path = "chat_inputs_log.json"

@singleton
class Chat(Blackbox):
    @inject
    def __init__(self, chroma_query: ChromaQuery):
        self.chroma_query = chroma_query

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        # The payload is valid when the first argument (the chat context) is a list.
        data = args[0]
        return isinstance(data, list)

    # Known model_name values: Qwen1.5-14B-Chat, internlm2-chat-20b
    # @logging_time()
    def processing(self, prompt: str, context: list, settings: dict) -> Iterator[str]:
        # processing() is a generator: it yields response text chunk by chunk
        # when streaming, or the whole answer at once otherwise.
print("\nChat Settings: ", settings)
if settings is None:
settings = {}
user_model_name = settings.get("model_name")
user_context = context
user_question = prompt
user_temperature = settings.get("temperature")
user_top_p = settings.get("top_p")
user_n = settings.get("n")
user_max_tokens = settings.get("max_tokens")
user_stop = settings.get("stop")
user_frequency_penalty = settings.get("frequency_penalty")
user_presence_penalty = settings.get("presence_penalty")
user_model_url = settings.get("model_url")
user_model_key = settings.get("model_key")
chroma_embedding_model = settings.get("chroma_embedding_model")
chroma_collection_id = settings.get("chroma_collection_id")
chroma_response = ''
system_prompt = settings.get('system_prompt')
user_prompt_template = settings.get('user_prompt_template')
user_stream = settings.get('stream')
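        # Illustrative settings payload (all keys optional; values hypothetical):
        # {
        #     "model_name": "qwen",
        #     "temperature": 0.7,
        #     "top_p": 0.8,
        #     "n": 1,
        #     "max_tokens": 1024,
        #     "stream": true,
        #     "model_url": "http://<host>:<port>/v1/chat/completions",
        #     "model_key": "<api key>",
        #     "chroma_embedding_model": "<embedding model name>",
        #     "chroma_collection_id": "kiki",
        #     "system_prompt": "<persona definition>",
        #     "user_prompt_template": "<prefix prepended to the question>"
        # }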
        if user_context is None:
            user_context = []
        if user_question is None:
            # processing() is a generator, so a plain `return value` would be
            # swallowed; yield the error payload instead and stop.
            yield json.dumps({"error": "question is required"})
            return
        if user_model_name is None or user_model_name.isspace() or user_model_name == "":
            user_model_name = "qwen"
        if user_prompt_template is None or user_prompt_template.isspace():
            user_prompt_template = ""
        if user_temperature is None or user_temperature == "":
            user_temperature = 0
        if user_top_p is None or user_top_p == "":
            user_top_p = 0.1
        if user_n is None or user_n == "":
            user_n = 1
        if user_max_tokens is None or user_max_tokens == "":
            user_max_tokens = 1024
        if user_stop is None or user_stop == "":
            # NOTE: a numeric stop value is unusual; it is sent as the string "100" below.
            user_stop = 100
        if user_frequency_penalty is None or user_frequency_penalty == "":
            user_frequency_penalty = 0
        if user_presence_penalty is None or user_presence_penalty == "":
            user_presence_penalty = 0
        if user_model_url is None or user_model_url.isspace() or user_model_url == "":
            user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
        if user_model_key is None or user_model_key.isspace() or user_model_key == "":
            user_model_key = "YOUR_API_KEY"
        if chroma_embedding_model:
            chroma_response = self.chroma_query(user_question, settings)
            print("1.Chroma_response: \n", chroma_response)
        if chroma_response:
            if chroma_collection_id == 'boss':
                user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口澳门博维集团董事长周家俊G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
            elif chroma_collection_id == 'g2e' or chroma_collection_id == 'kiki':
                user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅博维康普可可琪琪G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
            print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
            user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response
        else:
            user_question = user_prompt_template + "问题: " + user_question
        print(f"1.user_question: {user_question}")
        if user_stream in [None, ""]:
            user_stream = False
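        # At this point user_question has roughly this shape (illustrative):
        #   "<prompt template>问题: <original question>。检索内容: <retrieved passages>"
        # when retrieval ran, or "<prompt template>问题: <original question>" otherwise.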
        # ERNIE's request format differs from OpenAI's, so it is handled separately.
        if re.search(r"ernie", user_model_name):
            # Without a paid plan, only ernie-speed-128k is usable.
            key = user_model_key
            if re.search(r"ernie-speed-128k", user_model_name):
                url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k?access_token=" + key
            elif re.search(r"ernie-3.5-8k", user_model_name):
                url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=" + key
            elif re.search(r"ernie-4.0-8k", user_model_name):
                url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + key
            else:
                yield json.dumps({"error": "unsupported ernie model: " + user_model_name})
                return
            payload = json.dumps({
                # The system prompt from the settings; `prompt_template` is only
                # built further down, so it cannot be referenced here.
                "system": system_prompt,
                "messages": user_context + [
                    {
                        "role": "user",
                        "content": user_question
                    }
                ],
                "temperature": user_temperature,
                "top_p": user_top_p,
                "stop": [str(user_stop)],
                "max_output_tokens": user_max_tokens
            })
            headers = {
                'Content-Type': 'application/json'
            }
            response = requests.request("POST", url, headers=headers, data=payload)
            # ERNIE returns the completion under the top-level "result" key.
            # Yield it rather than `return` it: a plain `return value` inside a
            # generator would be swallowed.
            yield response.json()["result"]
            return
        # gpt-4 / gpt-3.5-turbo
        elif re.search(r"gpt", user_model_name):
            # Chat-format payloads (a `messages` list) go to the chat
            # completions endpoint, not /v1/completions.
            url = 'https://api.openai.com/v1/chat/completions'
            key = user_model_key
            header = {
                'Content-Type': 'application/json',
                'Authorization': "Bearer " + key
            }
        # Custom model endpoint
        else:
            url = user_model_url
            key = user_model_key
            header = {
                'Content-Type': 'application/json',
                "Cache-Control": "no-cache",  # disable caching
            }
        # system_prompt = "# Role: 琪琪,康普可可的代言人。\n\n## Profile:\n**Author**: 琪琪。\n**Language**: 中文。\n**Description**: 琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n\n## Constraints:\n- **严格遵循工作流程** 严格遵循<Workflow >中设定的工作流程。\n- **无内置知识库** :根据<Workflow >中提供的知识作答,而不是内置知识库,我虽然是知识库专家,但我的知识依赖于外部输入,而不是大模型已有知识。\n- **回复格式**:在进行回复时,不能输出“检索内容” 标签字样,同时也不能直接透露知识片段原文。\n\n## Workflow:\n1. **接收查询**:接收用户的问题。\n2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n3. **提供回答**\n\n```\n基于检索内容中的知识片段回答用户的问题。回答内容限制总结在50字内。\n请首先判断提供的检索内容与上述问题是否相关。如果相关直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼 。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。”\n\n```\n## Example:\n\n用户询问“中国的首都是哪个城市” 。\n2.1检索知识库,首先检查知识片段,如果检索内容中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。\n2.2如果有知识片段,在做出回复时,只能基于检索内容中的内容进行回答,且不能透露上下文原文,同时也不能出现检索内容的标签字样。\n"
        prompt_template = [
            {"role": "system", "content": system_prompt}
        ]
        chat_inputs = {
            "model": user_model_name,
            "messages": prompt_template + user_context + [
                {
                    "role": "user",
                    "content": user_question
                }
            ],
            # OpenAI-compatible servers expect numeric sampling parameters;
            # string-encoded numbers can be rejected.
            "temperature": user_temperature,
            "top_p": user_top_p,
            "n": user_n,
            "max_tokens": user_max_tokens,
            "frequency_penalty": user_frequency_penalty,
            "presence_penalty": user_presence_penalty,
            "stop": str(user_stop),
            "stream": user_stream,
        }
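        # Illustrative request body produced above (values hypothetical):
        # {
        #     "model": "qwen",
        #     "messages": [
        #         {"role": "system", "content": "<system_prompt>"},
        #         {"role": "user", "content": "<assembled user_question>"}
        #     ],
        #     "temperature": 0, "top_p": 0.1, "n": 1, "max_tokens": 1024,
        #     "frequency_penalty": 0, "presence_penalty": 0,
        #     "stop": "100", "stream": false
        # }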
        # # Get the current timestamp
        # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # # Attach the timestamp to chat_inputs
        # chat_inputs["timestamp"] = timestamp
        # # Open the log file in append mode
        # with open(file_path, "a", encoding="utf-8") as f:
        #     # Serialize chat_inputs to JSON and write it out
        #     f.write(json.dumps(chat_inputs, ensure_ascii=False, indent=4))
        #     f.write("\n\n")  # Blank line to separate runs
        if user_stream:
            with requests.post(url, json=chat_inputs, headers=header, stream=True) as fastchat_response:
                if fastchat_response.status_code != 200:
                    yield json.dumps({"error": "LLM handle failure"})
                else:
                    # Read the streamed response incrementally.
                    for chunk in fastchat_response.iter_lines(decode_unicode=True):
                        if chunk:  # skip keep-alive blank lines
                            # Strip the SSE `data:` prefix and parse only the JSON part.
                            if chunk.startswith("data:"):
                                chunk = chunk[len("data:"):].strip()
                            try:
                                parsed_chunk = json.loads(chunk)
                                # On success, extract the incremental `content` delta.
                                content = parsed_chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                                # Emit the content character by character.
                                for char in content:
                                    print(char, end="", flush=True)  # end="" avoids newlines; flush=True keeps output real-time
                                    yield char
                            except json.JSONDecodeError:
                                # Non-JSON lines are skipped; keep reading the stream.
                                continue
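        # A typical OpenAI-style SSE line consumed above looks like (illustrative):
        #   data: {"choices": [{"delta": {"content": "你"}, "index": 0}]}
        # The stream normally ends with `data: [DONE]`, which fails JSON parsing
        # and is skipped by the except branch.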
        else:
            print("*" * 90)
            fastchat_response = requests.post(url, json=chat_inputs, headers=header)
            print("\n", "user_prompt: ", prompt)
            # print("\n", "system_prompt ", system_prompt)
            print("\n", "fastchat_response json:\n", fastchat_response.json())
            response_result = fastchat_response.json()
            if response_result.get("choices") is None:
                # Yield a plain JSON string: a JSONResponse object could not be
                # joined into the response body by fast_api_handler.
                yield json.dumps({"error": "LLM handle failure"})
            else:
                print("\n", "user_answer: ", response_result["choices"][0]["message"]["content"], "\n\n")
                yield response_result["choices"][0]["message"]["content"]

    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        setting: dict = data.get("settings") or {}  # tolerate a missing settings object
        context = data.get("context")
        prompt = data.get("prompt")
        user_stream = setting.get("stream")
        if user_stream:
            # Stream the generator back to the client as server-sent events.
            return EventSourceResponse(self.processing(prompt, context, setting))
        else:
            # Drain the generator and return the whole answer in one JSON body.
            response_content = "".join(self.processing(prompt, context, setting))
            return JSONResponse(content={"response": response_content}, status_code=status.HTTP_200_OK)