Chroma and chat updated

This commit is contained in:
0Xiao0
2024-05-30 12:21:21 +08:00
parent f61a9ad6e6
commit 4579528b9e
4 changed files with 14578 additions and 11 deletions

View File

@@ -5,17 +5,22 @@ from fastapi.responses import JSONResponse
from ..log.logging_time import logging_time
from .blackbox import Blackbox
from .chroma_query import ChromaQuery
import requests
import json
from openai import OpenAI
import re
from injector import singleton
from injector import singleton,inject
@singleton
class Chat(Blackbox):
@inject
def __init__(self, chroma_query: ChromaQuery):
self.chroma_query = chroma_query
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@@ -24,7 +29,7 @@ class Chat(Blackbox):
return isinstance(data, list)
# model_name options include: Qwen1.5-14B-Chat, internlm2-chat-20b
#@logging_time()
# @logging_time()
def processing(self, prompt: str, context: list, settings: dict) -> str:
if settings is None:
@@ -42,6 +47,9 @@ class Chat(Blackbox):
user_presence_penalty = settings.get("presence_penalty")
user_model_url = settings.get("model_url")
user_model_key = settings.get("model_key")
chroma_embedding_model = settings.get("chroma_embedding_model")
chroma_response = ''
if user_context == None:
user_context = []
@@ -82,6 +90,14 @@
if user_model_key is None or user_model_key.isspace() or user_model_key == "":
user_model_key = "YOUR_API_KEY"
if chroma_embedding_model != None:
chroma_response = self.chroma_query(user_question, settings)
print(chroma_response)
if chroma_response != None or chroma_response != '':
user_question = f"问题: {user_question}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。只从检索的内容中选取与问题相关信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{chroma_response}"
# ERNIE (Wenxin) request format differs from OpenAI's and needs separate handling
if re.search(r"ernie", user_model_name):
# key = "24.22873ef3acf61fb343812681e4df251a.2592000.1719453781.282335-46723715" — without paid credits only ernie-speed-128k is usable (NOTE: API key committed in a comment — should be revoked and moved to config)
@@ -132,7 +148,6 @@ class Chat(Blackbox):
header = {
'Content-Type': 'application/json',
}
prompt_template = [
{"role": "system", "content": user_template},

View File

@@ -20,9 +20,9 @@ class ChromaQuery(Blackbox):
def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0])
# load chromadb and embedding model
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-large-zh-v1.5", device = "cuda")
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
# self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/model/Weight/BAAI/bge-small-en-v1.5", device = "cuda")
self.client_1 = chromadb.HttpClient(host='10.6.82.192', port=8000)
self.client_1 = chromadb.HttpClient(host='172.16.5.8', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
def __call__(self, *args, **kwargs):
@@ -32,7 +32,7 @@ class ChromaQuery(Blackbox):
data = args[0]
return isinstance(data, list)
@logging_time(logger=logger)
# @logging_time(logger=logger)
def processing(self, question: str, settings: dict) -> str:
if settings is None:
@@ -51,13 +51,13 @@ class ChromaQuery(Blackbox):
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
chroma_embedding_model = "/model/Weight/BAAI/bge-large-zh-v1.5"
chroma_embedding_model = "/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5"
if chroma_host is None or chroma_host.isspace() or chroma_host == "":
chroma_host = "10.6.82.192"
chroma_host = "172.16.5.8"
if chroma_port is None or chroma_port.isspace() or chroma_port == "":
chroma_port = "8000"
chroma_port = "7000"
if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
chroma_collection_id = "g2e"
@@ -66,12 +66,12 @@ class ChromaQuery(Blackbox):
chroma_n_results = 3
# load client and embedding model from init
if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
if re.search(r"172.16.5.8", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1
else:
client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
if re.search(r"/model/Weight/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
if re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
embedding_model = self.embedding_model_1
else:
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda")