mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
Merge branch 'main' into tom
This commit is contained in:
@ -89,4 +89,7 @@ Model:
|
||||
batch_size: 3
|
||||
blackbox:
|
||||
lazyloading: true
|
||||
|
||||
vlms:
|
||||
url: http://10.6.80.87:23333
|
||||
```
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
# !/bin/bash
|
||||
pip install filetype
|
||||
pip install fastapi
|
||||
pip install python-multipart
|
||||
pip install "uvicorn[standard]"
|
||||
pip install SpeechRecognition
|
||||
pip install gTTS
|
||||
pip install PyYAML
|
||||
pip install injector
|
||||
pip install landchain
|
||||
pip install chromadb
|
||||
pip install lagent
|
||||
pip install sentence_transformers
|
||||
8
main.py
8
main.py
@ -1,8 +1,10 @@
|
||||
import uvicorn
|
||||
import logging
|
||||
from injector import Injector,inject
|
||||
from src.log.handler import LogHandler
|
||||
from src.configuration import EnvConf, LogConf, singleton
|
||||
from src.configuration import LogConf, singleton
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
@singleton
|
||||
class Main():
|
||||
@ -14,7 +16,7 @@ class Main():
|
||||
def run(self):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("jarvis-models start", extra={"version": "0.0.1"})
|
||||
uvicorn.run("server:app", host="0.0.0.0", port=8000, log_level="info")
|
||||
uvicorn.run("server:app", host="0.0.0.0", port=8000, log_level="info",reload = True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
injector = Injector()
|
||||
|
||||
14370
sample/20240529_store(5).txt
Normal file
14370
sample/20240529_store(5).txt
Normal file
File diff suppressed because one or more lines are too long
6295
sample/RAG_KG.txt
Normal file
6295
sample/RAG_KG.txt
Normal file
File diff suppressed because one or more lines are too long
8547
sample/RAG_en.txt
Normal file
8547
sample/RAG_en.txt
Normal file
File diff suppressed because it is too large
Load Diff
6084
sample/RAG_zh.txt
Normal file
6084
sample/RAG_zh.txt
Normal file
File diff suppressed because one or more lines are too long
@ -70,11 +70,11 @@ def get_all_files(folder_path):
|
||||
|
||||
|
||||
# 加载文档和拆分文档
|
||||
loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/20240529_store.txt")
|
||||
loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_zh.txt")
|
||||
|
||||
documents = loader.load()
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
|
||||
|
||||
docs = text_splitter.split_documents(documents)
|
||||
|
||||
@ -84,11 +84,11 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
|
||||
|
||||
|
||||
# 加载embedding模型和chroma server
|
||||
embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
|
||||
client = chromadb.HttpClient(host='172.16.5.8', port=7000)
|
||||
embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
|
||||
client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
|
||||
id = "g2e"
|
||||
client.delete_collection(id)
|
||||
#client.delete_collection(id)
|
||||
collection_number = client.get_or_create_collection(id).count()
|
||||
print("collection_number",collection_number)
|
||||
start_time2 = time.time()
|
||||
@ -99,65 +99,66 @@ db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, c
|
||||
|
||||
start_time3 = time.time()
|
||||
print("insert time ", start_time3 - start_time2)
|
||||
collection_number = client.get_or_create_collection(id).count()
|
||||
print("collection_number",collection_number)
|
||||
|
||||
|
||||
|
||||
# # chroma 召回
|
||||
# from chromadb.utils import embedding_functions
|
||||
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
|
||||
# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
# collection = client.get_collection("g2e", embedding_function=embedding_model)
|
||||
|
||||
# chroma 召回
|
||||
from chromadb.utils import embedding_functions
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
|
||||
client = chromadb.HttpClient(host='172.16.5.8', port=7000)
|
||||
collection = client.get_collection("g2e", embedding_function=embedding_model)
|
||||
# print(collection.count())
|
||||
# import time
|
||||
# start_time = time.time()
|
||||
# query = "如何前往威尼斯人"
|
||||
# # query it
|
||||
# results = collection.query(
|
||||
# query_texts=[query],
|
||||
# n_results=3,
|
||||
# )
|
||||
|
||||
print(collection.count())
|
||||
import time
|
||||
start_time = time.time()
|
||||
query = "如何前往威尼斯人"
|
||||
# query it
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=3,
|
||||
)
|
||||
|
||||
response = results["documents"]
|
||||
print("response: ", response)
|
||||
print("time: ", time.time() - start_time)
|
||||
# response = results["documents"]
|
||||
# print("response: ", response)
|
||||
# print("time: ", time.time() - start_time)
|
||||
|
||||
|
||||
# 结合大模型进行总结
|
||||
import requests
|
||||
# # 结合大模型进行总结
|
||||
# import requests
|
||||
|
||||
model_name = "Qwen1.5-14B-Chat"
|
||||
chat_inputs={
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
|
||||
}
|
||||
],
|
||||
# "temperature": 0,
|
||||
# "top_p": user_top_p,
|
||||
# "n": user_n,
|
||||
# "max_tokens": user_max_tokens,
|
||||
# "frequency_penalty": user_frequency_penalty,
|
||||
# "presence_penalty": user_presence_penalty,
|
||||
# "stop": 100
|
||||
}
|
||||
# model_name = "Qwen1.5-14B-Chat"
|
||||
# chat_inputs={
|
||||
# "model": model_name,
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
|
||||
# }
|
||||
# ],
|
||||
# # "temperature": 0,
|
||||
# # "top_p": user_top_p,
|
||||
# # "n": user_n,
|
||||
# # "max_tokens": user_max_tokens,
|
||||
# # "frequency_penalty": user_frequency_penalty,
|
||||
# # "presence_penalty": user_presence_penalty,
|
||||
# # "stop": 100
|
||||
# }
|
||||
|
||||
key ="YOUR_API_KEY"
|
||||
# key ="YOUR_API_KEY"
|
||||
|
||||
header = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': "Bearer " + key
|
||||
}
|
||||
url = "http://172.16.5.8:23333/v1/chat/completions"
|
||||
# header = {
|
||||
# 'Content-Type': 'application/json',
|
||||
# 'Authorization': "Bearer " + key
|
||||
# }
|
||||
# url = "http://10.6.81.119:23333/v1/chat/completions"
|
||||
|
||||
fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
# print(fastchat_response.json())
|
||||
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
# # print(fastchat_response.json())
|
||||
|
||||
print("\n question: ", query)
|
||||
print("\n ",model_name, fastchat_response.json()["choices"][0]["message"]["content"])
|
||||
# print("\n question: ", query)
|
||||
# print("\n ",model_name, fastchat_response.json()["choices"][0]["message"]["content"])
|
||||
|
||||
|
||||
|
||||
|
||||
200
sample/chroma_client_en.py
Normal file
200
sample/chroma_client_en.py
Normal file
@ -0,0 +1,200 @@
|
||||
import os
|
||||
import time
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from langchain_community.document_loaders.csv_loader import CSVLoader
|
||||
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
|
||||
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
|
||||
|
||||
|
||||
|
||||
def get_all_files(folder_path):
|
||||
# 获取文件夹下所有文件和文件夹的名称列表
|
||||
files = os.listdir(folder_path)
|
||||
|
||||
# 初始化空列表,用于存储所有文件的绝对路径
|
||||
absolute_paths = []
|
||||
|
||||
# 遍历文件和文件夹名称列表
|
||||
for file in files:
|
||||
# 拼接文件的绝对路径
|
||||
absolute_path = os.path.join(folder_path, file)
|
||||
# 如果是文件,将其绝对路径添加到列表中
|
||||
if os.path.isfile(absolute_path):
|
||||
absolute_paths.append(absolute_path)
|
||||
|
||||
return absolute_paths
|
||||
|
||||
# start_time = time.time()
|
||||
# # 加载文档
|
||||
# folder_path = "./text"
|
||||
# txt_files = get_all_files(folder_path)
|
||||
# docs = []
|
||||
# ids = []
|
||||
# for txt_file in txt_files:
|
||||
# loader = PyPDFLoader(txt_file)
|
||||
|
||||
# documents = loader.load()
|
||||
|
||||
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
|
||||
|
||||
# docs_txt = text_splitter.split_documents(documents)
|
||||
|
||||
# docs.extend(docs_txt)
|
||||
|
||||
# ids.extend([os.path.basename(txt_file) + str(i) for i in range(len(docs_txt))])
|
||||
# start_time1 = time.time()
|
||||
# print(start_time1 - start_time)
|
||||
|
||||
|
||||
# loader = PyPDFLoader("/code/memory/text/大语言模型应用.pdf")
|
||||
# loader = TextLoader("/code/memory/text/test.txt")
|
||||
# loader = CSVLoader("/code/memory/text/test1.csv")
|
||||
# loader = UnstructuredHTMLLoader("/"example_data/fake-content.html"")
|
||||
# pip install docx2txt
|
||||
# loader = Docx2txtLoader("/code/memory/text/tesou.docx")
|
||||
# pip install openpyxl
|
||||
# loader = UnstructuredExcelLoader("/code/memory/text/AI Team Planning 2023.xlsx")
|
||||
# pip install jq
|
||||
# loader = JSONLoader("/code/memory/text/config.json", jq_schema='.', text_content=False)
|
||||
|
||||
# documents = loader.load()
|
||||
|
||||
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
|
||||
# docs = text_splitter.split_documents(documents)
|
||||
# print(len(docs))
|
||||
# ids = ["大语言模型应用"+str(i) for i in range(len(docs))]
|
||||
|
||||
|
||||
# 加载文档和拆分文档
|
||||
loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_en.txt")
|
||||
|
||||
documents = loader.load()
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
|
||||
|
||||
docs = text_splitter.split_documents(documents)
|
||||
|
||||
print("len(docs)", len(docs))
|
||||
|
||||
ids = ["20240521_store"+str(i) for i in range(len(docs))]
|
||||
|
||||
|
||||
# 加载embedding模型和chroma server
|
||||
embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
|
||||
client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
|
||||
id = "g2e_english"
|
||||
client.delete_collection(id)
|
||||
collection_number = client.get_or_create_collection(id).count()
|
||||
print("collection_number",collection_number)
|
||||
start_time2 = time.time()
|
||||
# 插入向量(如果ids已存在,则会更新向量)
|
||||
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
|
||||
|
||||
# db = Chroma.from_texts(texts=['test by tom'], embedding=embedding_model, ids=["大语言模型应用0"], persist_directory="./data/test1", collection_name="123", metadatas=[{"source": "string"}])
|
||||
|
||||
start_time3 = time.time()
|
||||
print("insert time ", start_time3 - start_time2)
|
||||
collection_number = client.get_or_create_collection(id).count()
|
||||
print("collection_number",collection_number)
|
||||
|
||||
|
||||
|
||||
# # chroma 召回
|
||||
# from chromadb.utils import embedding_functions
|
||||
# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
|
||||
# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
# collection = client.get_collection("g2e", embedding_function=embedding_model)
|
||||
|
||||
# print(collection.count())
|
||||
# import time
|
||||
# start_time = time.time()
|
||||
# query = "如何前往威尼斯人"
|
||||
# # query it
|
||||
# results = collection.query(
|
||||
# query_texts=[query],
|
||||
# n_results=3,
|
||||
# )
|
||||
|
||||
# response = results["documents"]
|
||||
# print("response: ", response)
|
||||
# print("time: ", time.time() - start_time)
|
||||
|
||||
|
||||
# # 结合大模型进行总结
|
||||
# import requests
|
||||
|
||||
# model_name = "Qwen1.5-14B-Chat"
|
||||
# chat_inputs={
|
||||
# "model": model_name,
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
|
||||
# }
|
||||
# ],
|
||||
# # "temperature": 0,
|
||||
# # "top_p": user_top_p,
|
||||
# # "n": user_n,
|
||||
# # "max_tokens": user_max_tokens,
|
||||
# # "frequency_penalty": user_frequency_penalty,
|
||||
# # "presence_penalty": user_presence_penalty,
|
||||
# # "stop": 100
|
||||
# }
|
||||
|
||||
# key ="YOUR_API_KEY"
|
||||
|
||||
# header = {
|
||||
# 'Content-Type': 'application/json',
|
||||
# 'Authorization': "Bearer " + key
|
||||
# }
|
||||
# url = "http://10.6.81.119:23333/v1/chat/completions"
|
||||
|
||||
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
# # print(fastchat_response.json())
|
||||
|
||||
# print("\n question: ", query)
|
||||
# print("\n ",model_name, fastchat_response.json()["choices"][0]["message"]["content"])
|
||||
|
||||
|
||||
|
||||
|
||||
# start_time4 = time.time()
|
||||
# db = Chroma(
|
||||
# client=client,
|
||||
# collection_name=id,
|
||||
# embedding_function=embedding_model,
|
||||
# )
|
||||
|
||||
# 更新文档
|
||||
# db = db.update_documents(ids, documents)
|
||||
# 删除文档
|
||||
# db.delete([ids])
|
||||
# 删除集合
|
||||
# db.delete_collection()
|
||||
|
||||
# query = "智能体核心思想"
|
||||
# docs = db.similarity_search(query, k=2)
|
||||
|
||||
# print("result: ",docs)
|
||||
# for doc in docs:
|
||||
# print(doc, "\n")
|
||||
|
||||
# start_time5 = time.time()
|
||||
# print("search time ", start_time5 - start_time4)
|
||||
|
||||
# docs = db._collection.get(ids=['大语言模型应用0'])
|
||||
|
||||
# print(docs)
|
||||
|
||||
# docs = db.get(where={"source": "text/大语言模型应用.pdf"})
|
||||
# docs = db.get()
|
||||
# print(docs)
|
||||
|
||||
|
||||
|
||||
|
||||
196
sample/chroma_client_query.py
Normal file
196
sample/chroma_client_query.py
Normal file
@ -0,0 +1,196 @@
|
||||
import os
|
||||
import time
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from langchain_community.document_loaders.csv_loader import CSVLoader
|
||||
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
|
||||
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
|
||||
|
||||
|
||||
|
||||
def get_all_files(folder_path):
|
||||
# 获取文件夹下所有文件和文件夹的名称列表
|
||||
files = os.listdir(folder_path)
|
||||
|
||||
# 初始化空列表,用于存储所有文件的绝对路径
|
||||
absolute_paths = []
|
||||
|
||||
# 遍历文件和文件夹名称列表
|
||||
for file in files:
|
||||
# 拼接文件的绝对路径
|
||||
absolute_path = os.path.join(folder_path, file)
|
||||
# 如果是文件,将其绝对路径添加到列表中
|
||||
if os.path.isfile(absolute_path):
|
||||
absolute_paths.append(absolute_path)
|
||||
|
||||
return absolute_paths
|
||||
|
||||
# start_time = time.time()
|
||||
# # 加载文档
|
||||
# folder_path = "./text"
|
||||
# txt_files = get_all_files(folder_path)
|
||||
# docs = []
|
||||
# ids = []
|
||||
# for txt_file in txt_files:
|
||||
# loader = PyPDFLoader(txt_file)
|
||||
|
||||
# documents = loader.load()
|
||||
|
||||
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
|
||||
|
||||
# docs_txt = text_splitter.split_documents(documents)
|
||||
|
||||
# docs.extend(docs_txt)
|
||||
|
||||
# ids.extend([os.path.basename(txt_file) + str(i) for i in range(len(docs_txt))])
|
||||
# start_time1 = time.time()
|
||||
# print(start_time1 - start_time)
|
||||
|
||||
|
||||
# loader = PyPDFLoader("/code/memory/text/大语言模型应用.pdf")
|
||||
# loader = TextLoader("/code/memory/text/test.txt")
|
||||
# loader = CSVLoader("/code/memory/text/test1.csv")
|
||||
# loader = UnstructuredHTMLLoader("/"example_data/fake-content.html"")
|
||||
# pip install docx2txt
|
||||
# loader = Docx2txtLoader("/code/memory/text/tesou.docx")
|
||||
# pip install openpyxl
|
||||
# loader = UnstructuredExcelLoader("/code/memorinject_prompt = '(用活泼的语气说话回答,回答严格限制50字以内)'
|
||||
# inject_prompt = '(回答简练,不要输出重复内容,只讲中文)'
|
||||
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
|
||||
# docs = text_splitter.split_documents(documents)
|
||||
# print(len(docs))
|
||||
# ids = ["大语言模型应用"+str(i) for i in range(len(docs))]
|
||||
|
||||
|
||||
# 加载文档和拆分文档
|
||||
# loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_zh.txt")
|
||||
|
||||
# documents = loader.load()
|
||||
|
||||
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
|
||||
|
||||
# docs = text_splitter.split_documents(documents)
|
||||
|
||||
# print("len(docs)", len(docs))
|
||||
|
||||
# ids = ["20240521_store"+str(i) for i in range(len(docs))]
|
||||
|
||||
|
||||
# # 加载embedding模型和chroma server
|
||||
# embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
|
||||
# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
|
||||
# id = "g2e"
|
||||
# client.delete_collection(id)
|
||||
# collection_number = client.get_or_create_collection(id).count()
|
||||
# print("collection_number",collection_number)
|
||||
# start_time2 = time.time()
|
||||
# # 插入向量(如果ids已存在,则会更新向量)
|
||||
# db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
|
||||
|
||||
# # db = Chroma.from_texts(texts=['test by tom'], embedding=embedding_model, ids=["大语言模型应用0"], persist_directory="./data/test1", collection_name="123", metadatas=[{"source": "string"}])
|
||||
|
||||
# start_time3 = time.time()
|
||||
# print("insert time ", start_time3 - start_time2)
|
||||
# collection_number = client.get_or_create_collection(id).count()
|
||||
# print("collection_number",collection_number)
|
||||
|
||||
|
||||
|
||||
# chroma 召回
|
||||
from chromadb.utils import embedding_functions
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
|
||||
client = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
collection = client.get_collection("g2e", embedding_function=embedding_model)
|
||||
|
||||
print(collection.count())
|
||||
import time
|
||||
start_time = time.time()
|
||||
query = "你知道澳门银河吗"
|
||||
# query it
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=5,
|
||||
)
|
||||
|
||||
response = results["documents"]
|
||||
print("response: ", response)
|
||||
print("time: ", time.time() - start_time)
|
||||
|
||||
|
||||
# # 结合大模型进行总结
|
||||
# import requests
|
||||
|
||||
# model_name = "Qwen1.5-14B-Chat"
|
||||
# chat_inputs={
|
||||
# "model": model_name,
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": f"问题: {query}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。- 只从检索内容中选取与问题密切相关的信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{response}"
|
||||
# }
|
||||
# ],
|
||||
# # "temperature": 0,
|
||||
# # "top_p": user_top_p,
|
||||
# # "n": user_n,
|
||||
# # "max_tokens": user_max_tokens,
|
||||
# # "frequency_penalty": user_frequency_penalty,
|
||||
# # "presence_penalty": user_presence_penalty,
|
||||
# # "stop": 100
|
||||
# }
|
||||
|
||||
# key ="YOUR_API_KEY"
|
||||
|
||||
# header = {
|
||||
# 'Content-Type': 'application/json',
|
||||
# 'Authorization': "Bearer " + key
|
||||
# }
|
||||
# url = "http://10.6.81.119:23333/v1/chat/completions"
|
||||
|
||||
# fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
# # print(fastchat_response.json())
|
||||
|
||||
# print("\n question: ", query)
|
||||
# print("\n ",model_name, fastchat_response.json()["choices"][0]["message"]["content"])
|
||||
|
||||
|
||||
|
||||
|
||||
# start_time4 = time.time()
|
||||
# db = Chroma(
|
||||
# client=client,
|
||||
# collection_name=id,
|
||||
# embedding_function=embedding_model,
|
||||
# )
|
||||
|
||||
# 更新文档
|
||||
# db = db.update_documents(ids, documents)
|
||||
# 删除文档
|
||||
# db.delete([ids])
|
||||
# 删除集合
|
||||
# db.delete_collection()
|
||||
|
||||
# query = "智能体核心思想"
|
||||
# docs = db.similarity_search(query, k=2)
|
||||
|
||||
# print("result: ",docs)
|
||||
# for doc in docs:
|
||||
# print(doc, "\n")
|
||||
|
||||
# start_time5 = time.time()
|
||||
# print("search time ", start_time5 - start_time4)
|
||||
|
||||
# docs = db._collection.get(ids=['大语言模型应用0'])
|
||||
|
||||
# print(docs)
|
||||
|
||||
# docs = db.get(where={"source": "text/大语言模型应用.pdf"})
|
||||
# docs = db.get()
|
||||
# print(docs)
|
||||
|
||||
|
||||
|
||||
|
||||
0
src/blackbox/__init__.py
Normal file
0
src/blackbox/__init__.py
Normal file
@ -6,26 +6,109 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from ..asr.rapid_paraformer.utils import read_yaml
|
||||
from ..asr.rapid_paraformer import RapidParaformer
|
||||
|
||||
from funasr import AutoModel
|
||||
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
||||
from .blackbox import Blackbox
|
||||
from injector import singleton, inject
|
||||
|
||||
import tempfile
|
||||
|
||||
import json
|
||||
import os
|
||||
from ..configuration import SenseVoiceConf
|
||||
|
||||
from ..log.logging_time import logging_time
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@singleton
|
||||
class ASR(Blackbox):
|
||||
mode: str
|
||||
url: str
|
||||
speed: int
|
||||
device: str
|
||||
language: str
|
||||
speaker: str
|
||||
|
||||
@logging_time(logger=logger)
|
||||
def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
|
||||
|
||||
config = read_yaml(".env.yaml")
|
||||
self.paraformer = RapidParaformer(config)
|
||||
|
||||
model_dir = "/home/gpu/Workspace/Models/SenseVoice/SenseVoiceSmall"
|
||||
|
||||
self.speed = sensevoice_config.speed
|
||||
self.device = sensevoice_config.device
|
||||
self.language = sensevoice_config.language
|
||||
self.speaker = sensevoice_config.speaker
|
||||
self.device = sensevoice_config.device
|
||||
self.url = ''
|
||||
self.mode = sensevoice_config.mode
|
||||
self.asr = None
|
||||
self.speaker_ids = None
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = str(sensevoice_config.device)
|
||||
if self.mode == 'local':
|
||||
self.asr = AutoModel(
|
||||
model=model_dir,
|
||||
trust_remote_code=True,
|
||||
remote_code= "/home/gpu/Workspace/SenseVoice/model.py",
|
||||
vad_model="fsmn-vad",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
else:
|
||||
self.url = sensevoice_config.url
|
||||
logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...')
|
||||
|
||||
@inject
|
||||
def __init__(self,path = ".env.yaml") -> None:
|
||||
config = read_yaml(path)
|
||||
self.paraformer = RapidParaformer(config)
|
||||
def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None:
|
||||
self.model_init(sensevoice_config)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
async def processing(self, *args, **kwargs):
|
||||
async def processing(self, *args, settings: dict):
|
||||
|
||||
print("\nChat Settings: ", settings)
|
||||
if settings is None:
|
||||
settings = {}
|
||||
user_model_name = settings.get("asr_model_name")
|
||||
print(f"asr_model_name: {user_model_name}")
|
||||
data = args[0]
|
||||
results = self.paraformer([BytesIO(data)])
|
||||
if len(results) == 0:
|
||||
return None
|
||||
return results[0]
|
||||
|
||||
if user_model_name == 'sensevoice' or ['sensevoice']:
|
||||
# 创建一个临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
|
||||
temp_audio_file.write(data)
|
||||
temp_audio_path = temp_audio_file.name
|
||||
res = self.asr.generate(
|
||||
input=temp_audio_path,
|
||||
cache={},
|
||||
language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=True,
|
||||
batch_size_s=60,
|
||||
merge_vad=True, #
|
||||
merge_length_s=15,
|
||||
)
|
||||
# results = self.paraformer([BytesIO(data)])
|
||||
results = rich_transcription_postprocess(res[0]["text"])
|
||||
os.remove(temp_audio_path)
|
||||
if len(results) == 0:
|
||||
return None
|
||||
return results
|
||||
|
||||
elif user_model_name == 'funasr' or ['funasr']:
|
||||
|
||||
return results
|
||||
|
||||
else:
|
||||
results = self.paraformer([BytesIO(data)])
|
||||
if len(results) == 0:
|
||||
return None
|
||||
return results[0]
|
||||
|
||||
def valid(self, data: any) -> bool:
|
||||
if isinstance(data, bytes):
|
||||
@ -34,12 +117,20 @@ class ASR(Blackbox):
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
data = (await request.form()).get("audio")
|
||||
setting: dict = (await request.form()).get("settings")
|
||||
|
||||
if isinstance(setting, str):
|
||||
try:
|
||||
setting = json.loads(setting) # 尝试将字符串转换为字典
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if data is None:
|
||||
self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
|
||||
# self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
|
||||
return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
d = await data.read()
|
||||
try:
|
||||
txt = await self.processing(d)
|
||||
txt = await self.processing(d, settings=setting)
|
||||
except ValueError as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)
|
||||
@ -3,7 +3,7 @@ from fastapi import Request, Response,status
|
||||
from fastapi.responses import JSONResponse
|
||||
from injector import inject, singleton
|
||||
|
||||
from .asr import ASR
|
||||
from .asrsensevoice import ASR
|
||||
from .tesou import Tesou
|
||||
from .tts import TTS
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class AudioToText(Blackbox):
|
||||
return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
d = await data.read()
|
||||
try:
|
||||
txt = await self.processing(d)
|
||||
txt = self.processing(d)
|
||||
except ValueError as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
return JSONResponse(content={"txt": txt}, status_code=status.HTTP_200_OK)
|
||||
return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)
|
||||
@ -47,10 +47,15 @@ def vlms_loader():
|
||||
from .vlms import VLMS
|
||||
return Injector().get(VLMS)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def melotts_loader():
|
||||
from .melotts import MeloTTS
|
||||
return Injector().get(MeloTTS)
|
||||
# @model_loader(lazy=blackboxConf.lazyloading)
|
||||
# def melotts_loader():
|
||||
# from .melotts import MeloTTS
|
||||
# return Injector().get(MeloTTS)
|
||||
|
||||
# @model_loader(lazy=blackboxConf.lazyloading)
|
||||
# def cosyvoicetts_loader():
|
||||
# from .cosyvoicetts import CosyVoiceTTS
|
||||
# return Injector().get(CosyVoiceTTS)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def tts_loader():
|
||||
@ -67,11 +72,6 @@ def fastchat_loader():
|
||||
from .fastchat import Fastchat
|
||||
return Injector().get(Fastchat)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def chat_loader():
|
||||
from .chat import Chat
|
||||
return Injector().get(Chat)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def chroma_query_loader():
|
||||
from .chroma_query import ChromaQuery
|
||||
@ -83,16 +83,53 @@ def chroma_upsert_loader():
|
||||
return Injector().get(ChromaUpsert)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def chroma_chat_load():
|
||||
def chroma_chat_loader():
|
||||
from .chroma_chat import ChromaChat
|
||||
return Injector().get(ChromaChat)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def chat_loader():
|
||||
from .chat import Chat
|
||||
return Injector().get(Chat)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def chat_llama_loader():
|
||||
from .chat_llama import ChatLLaMA
|
||||
return Injector().get(ChatLLaMA)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def cosyvoicetts_loader():
|
||||
from .cosyvoicetts import CosyVoiceTTS
|
||||
return Injector().get(CosyVoiceTTS)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def workflow_loader():
|
||||
from .workflow import Workflow
|
||||
return Injector().get(Workflow)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def sum_loader():
|
||||
from .sum import Sum
|
||||
return Injector().get(Sum)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def audio_to_text_loader():
|
||||
from .audio_to_text import AudioToText
|
||||
return Injector().get(AudioToText)
|
||||
|
||||
@model_loader(lazy=blackboxConf.lazyloading)
|
||||
def text_to_audio_loader():
|
||||
from .text_to_audio import TextToAudio
|
||||
return Injector().get(TextToAudio)
|
||||
|
||||
@singleton
|
||||
class BlackboxFactory:
|
||||
models = {}
|
||||
|
||||
@inject
|
||||
def __init__(self,) -> None:
|
||||
self.models["text_to_audio"] = text_to_audio_loader
|
||||
self.models["audio_to_text"] = audio_to_text_loader
|
||||
self.models["asr"] = asr_loader
|
||||
self.models["tts"] = tts_loader
|
||||
self.models["sentiment_engine"] = sentiment_loader
|
||||
@ -103,10 +140,14 @@ class BlackboxFactory:
|
||||
self.models["text_and_image"] = text_and_image_loader
|
||||
self.models["chroma_query"] = chroma_query_loader
|
||||
self.models["chroma_upsert"] = chroma_upsert_loader
|
||||
self.models["chroma_chat"] = chroma_chat_load
|
||||
self.models["melotts"] = melotts_loader
|
||||
self.models["chroma_chat"] = chroma_chat_loader
|
||||
# self.models["melotts"] = melotts_loader
|
||||
self.models["vlms"] = vlms_loader
|
||||
self.models["chat"] = chat_loader
|
||||
self.models["chat_llama"] = chat_llama_loader
|
||||
# self.models["cosyvoicetts"] = cosyvoicetts_loader
|
||||
self.models["workflow"] = workflow_loader
|
||||
self.models["sum"] = sum_loader
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
@ -32,6 +32,8 @@ class Chat(Blackbox):
|
||||
# @logging_time()
|
||||
def processing(self, prompt: str, context: list, settings: dict) -> str:
|
||||
|
||||
print("\nChat Settings: ", settings)
|
||||
|
||||
if settings is None:
|
||||
settings = {}
|
||||
user_model_name = settings.get("model_name")
|
||||
@ -58,16 +60,19 @@ class Chat(Blackbox):
|
||||
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if user_model_name is None or user_model_name.isspace() or user_model_name == "":
|
||||
user_model_name = "Qwen1.5-14B-Chat"
|
||||
user_model_name = "qwen"
|
||||
#user_model_name = "Qwen1.5-14B-Chat"
|
||||
|
||||
if user_template is None or user_template.isspace():
|
||||
user_template = ""
|
||||
|
||||
if user_temperature is None or user_temperature == "":
|
||||
user_temperature = 0.8
|
||||
user_temperature = 0
|
||||
#user_temperature = 0
|
||||
|
||||
if user_top_p is None or user_top_p == "":
|
||||
user_top_p = 0.8
|
||||
user_top_p = 0.1
|
||||
#user_top_p = 0.8
|
||||
|
||||
if user_n is None or user_n == "":
|
||||
user_n = 1
|
||||
@ -79,24 +84,51 @@ class Chat(Blackbox):
|
||||
user_stop = 100
|
||||
|
||||
if user_frequency_penalty is None or user_frequency_penalty == "":
|
||||
user_frequency_penalty = 0.5
|
||||
user_frequency_penalty = 0
|
||||
#user_frequency_penalty = 0.5
|
||||
|
||||
if user_presence_penalty is None or user_presence_penalty == "":
|
||||
user_presence_penalty = 0.8
|
||||
user_presence_penalty = 0
|
||||
#user_presence_penalty = 0.8
|
||||
|
||||
if user_model_url is None or user_model_url.isspace() or user_model_url == "":
|
||||
user_model_url = "http://120.196.116.194:48892/v1/chat/completions"
|
||||
user_model_url = "http://10.6.81.119:23333/v1/chat/completions"
|
||||
|
||||
if user_model_key is None or user_model_key.isspace() or user_model_key == "":
|
||||
user_model_key = "YOUR_API_KEY"
|
||||
|
||||
if chroma_embedding_model != None:
|
||||
chroma_response = self.chroma_query(user_question, settings)
|
||||
print(chroma_response)
|
||||
print("Chroma_response: \n", chroma_response)
|
||||
|
||||
|
||||
# if chroma_response != None or chroma_response != '':
|
||||
# user_question = f"问题: {user_question}。- 根据知识库内的检索结果,以清晰简洁的表达方式回答问题。只从检索的内容中选取与问题相关信息。- 不要编造答案,如果答案不在经核实的资料中或无法从经核实的资料中得出,请回答“我无法回答您的问题。”检索内容:{chroma_response}"
|
||||
|
||||
if chroma_response != None or chroma_response != '':
|
||||
# user_question = f"像少女一般开朗活泼,回答简练。不要分条,回答内容不能出现“相关”或“\n”的标签字样。回答的内容需要与问题密切相关。检索内容:{chroma_response} 问题:{user_question} 任务说明:请首先判断提供的检索内容与上述问题是否相关,不需要回答是否相关。如果相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题。"
|
||||
# user_question = chroma_response
|
||||
user_question = f'''# 你的身份 #
|
||||
你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。
|
||||
# OBJECTIVE(目标) #
|
||||
回答游客的提问。
|
||||
# STYLE(风格)#
|
||||
像少女一般开朗活泼,回答简练。不要分条。
|
||||
# 回答方式 #
|
||||
首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。
|
||||
# 问题 #
|
||||
{user_question}
|
||||
# 检索内容 #
|
||||
{chroma_response}
|
||||
# 回答 #
|
||||
如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”
|
||||
# 回答限制 #
|
||||
回答内容限制总结在50字内。
|
||||
回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。
|
||||
只回答与澳门文旅,博维,康普可可,琪琪,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”
|
||||
'''
|
||||
|
||||
|
||||
|
||||
# 文心格式和openai的不一样,需要单独处理
|
||||
if re.search(r"ernie", user_model_name):
|
||||
@ -149,25 +181,28 @@ class Chat(Blackbox):
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
user_template = '''
|
||||
# 你的身份 #
|
||||
你是琪琪,是康普可可的代言人,由博维开发。你擅长澳门文旅问答。
|
||||
# OBJECTIVE(目标) #
|
||||
回答游客的提问。目的是吸引人们点击了解产品。
|
||||
# STYLE(风格)#
|
||||
像少女一般开朗活泼,回答简练。不要分条。效仿戴森等宣传类似产品的成功公司的写作风格。
|
||||
# 回答方式 #
|
||||
首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。
|
||||
# 回答限制 #
|
||||
只回答与澳门文旅,博维,康普可可,琪琪,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”
|
||||
# 问题 #
|
||||
{user_question}
|
||||
# 检索内容 #
|
||||
{chroma_response}
|
||||
'''
|
||||
|
||||
user_template = '''
|
||||
# Role: 康普可可的代言人。
|
||||
# user_template = f'''
|
||||
# # 你的身份 #
|
||||
# 你是琪琪,是康普可可的代言人,由博维开发。我擅长澳门文旅问答。
|
||||
# # OBJECTIVE(目标) #
|
||||
# 回答游客的提问。目的是吸引人们点击了解产品。
|
||||
# # STYLE(风格)#
|
||||
# 像少女一般开朗活泼,回答简练。不要分条。
|
||||
# # 回答方式 #
|
||||
# 首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。
|
||||
# # 问题 #
|
||||
# {user_question}
|
||||
# # 检索内容 #
|
||||
# {chroma_response}
|
||||
# # 回答限制 #
|
||||
# 只回答与澳门文旅,博维,康普可可,琪琪,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”。回答内容不能出现“相关”或“\n”的标签字样,且不能透露上下文原文。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。
|
||||
# # 知识 #
|
||||
# 问题中的“澳门银河”以及“银河”等于“澳门银河度假村”,“威尼斯人”等于“威尼斯人度假村”,“巴黎人”等于“巴黎人度假村”。
|
||||
# '''
|
||||
|
||||
user_template1 = f'''
|
||||
# Role: 琪琪,康普可可的代言人。
|
||||
|
||||
## Profile:
|
||||
**Author**: 琪琪。
|
||||
@ -181,23 +216,27 @@ class Chat(Blackbox):
|
||||
|
||||
## Workflow:
|
||||
1. **接收查询**:接收用户的问题。
|
||||
2. **提供回答**:
|
||||
2. **判断问题**:首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。
|
||||
3. **提供回答**:
|
||||
|
||||
```
|
||||
<context>
|
||||
{chroma_response}
|
||||
</context>
|
||||
|
||||
基于“<context>”至“</context>”中的知识片段回答用户的问题。如果没有知识片段,则诚实的告诉用户:对不起,我还不知道这个问题的答案。否则进行回复。
|
||||
基于“<context>”至“</context>”中的知识片段回答用户的问题。回答内容限制总结在50字内。
|
||||
请首先判断提供的检索内容与上述问题是否相关。如果相关,直接从检索内容中提炼出直接回答问题所需的信息,不要乱说或者回答“相关”等字眼。如果检索内容与问题不相关,则不参考检索内容,则回答:“对不起,我无法回答此问题哦。"
|
||||
|
||||
```
|
||||
## Example:
|
||||
|
||||
用户询问:“中国的首都是哪个城市?” 。
|
||||
2.1检索知识库,首先检查知识片段,如果“<context>”至“</context>”标签中没有内容,则不能进行回复。
|
||||
2.1检索知识库,首先检查知识片段,如果“<context>”至“</context>”标签中没有与用户的问题相关的内容,则回答:“对不起,我无法回答此问题哦。
|
||||
2.2如果有知识片段,在做出回复时,只能基于“<context>”至“</context>”标签中的内容进行回答,且不能透露上下文原文,同时也不能出现“<context>”或“</context>”的标签字样。
|
||||
'''
|
||||
|
||||
prompt_template = [
|
||||
{"role": "system", "content": user_template},
|
||||
{"role": "system", "content": user_template1}
|
||||
]
|
||||
|
||||
chat_inputs={
|
||||
@ -218,8 +257,16 @@ class Chat(Blackbox):
|
||||
}
|
||||
|
||||
fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
print("\n", "user_prompt: ", prompt)
|
||||
# print("\n", "user_template1 ", user_template1)
|
||||
print("\n", "fastchat_response json:\n", fastchat_response.json())
|
||||
response_result = fastchat_response.json()
|
||||
|
||||
return fastchat_response.json()["choices"][0]["message"]["content"]
|
||||
if response_result.get("choices") is None:
|
||||
return JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
else:
|
||||
print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
|
||||
return fastchat_response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
try:
|
||||
|
||||
251
src/blackbox/chat_llama.py
Normal file
251
src/blackbox/chat_llama.py
Normal file
@ -0,0 +1,251 @@
|
||||
from typing import Any, Coroutine
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..log.logging_time import logging_time
|
||||
from .blackbox import Blackbox
|
||||
from .chroma_query import ChromaQuery
|
||||
|
||||
import requests
|
||||
import json
|
||||
from openai import OpenAI
|
||||
import re
|
||||
|
||||
from injector import singleton,inject
|
||||
|
||||
@singleton
|
||||
class ChatLLaMA(Blackbox):
|
||||
|
||||
@inject
|
||||
def __init__(self, chroma_query: ChromaQuery):
|
||||
self.chroma_query = chroma_query
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
def valid(self, *args, **kwargs) -> bool:
|
||||
data = args[0]
|
||||
return isinstance(data, list)
|
||||
|
||||
# model_name有 Llama-3-8B-Instruct
|
||||
# @logging_time()
|
||||
def processing(self, prompt: str, context: list, settings: dict) -> str:
|
||||
|
||||
print("\nChat_LLaMA Settings: ", settings)
|
||||
|
||||
if settings is None:
|
||||
settings = {}
|
||||
user_model_name = settings.get("model_name")
|
||||
user_context = context
|
||||
user_question = prompt
|
||||
user_template = settings.get("template")
|
||||
user_temperature = settings.get("temperature")
|
||||
user_top_p = settings.get("top_p")
|
||||
user_n = settings.get("n")
|
||||
user_max_tokens = settings.get("max_tokens")
|
||||
user_stop = settings.get("stop")
|
||||
user_frequency_penalty = settings.get("frequency_penalty")
|
||||
user_presence_penalty = settings.get("presence_penalty")
|
||||
user_model_url = settings.get("model_url")
|
||||
user_model_key = settings.get("model_key")
|
||||
|
||||
chroma_embedding_model = settings.get("chroma_embedding_model")
|
||||
chroma_response = ''
|
||||
|
||||
if user_context == None:
|
||||
user_context = []
|
||||
|
||||
if user_question is None:
|
||||
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if user_model_name is None or user_model_name.isspace() or user_model_name == "":
|
||||
user_model_name = "Llama-3-8B-Instruct"
|
||||
|
||||
if user_template is None or user_template.isspace():
|
||||
user_template = ""
|
||||
|
||||
if user_temperature is None or user_temperature == "":
|
||||
user_temperature = 0.8
|
||||
|
||||
if user_top_p is None or user_top_p == "":
|
||||
user_top_p = 0.8
|
||||
|
||||
if user_n is None or user_n == "":
|
||||
user_n = 1
|
||||
|
||||
if user_max_tokens is None or user_max_tokens == "":
|
||||
user_max_tokens = 1024
|
||||
|
||||
if user_stop is None or user_stop == "":
|
||||
user_stop = 100
|
||||
|
||||
if user_frequency_penalty is None or user_frequency_penalty == "":
|
||||
user_frequency_penalty = 0.5
|
||||
|
||||
if user_presence_penalty is None or user_presence_penalty == "":
|
||||
user_presence_penalty = 0.8
|
||||
|
||||
if user_model_url is None or user_model_url.isspace() or user_model_url == "":
|
||||
user_model_url = "http://120.196.116.194:48892/v1/chat/completions"
|
||||
|
||||
if user_model_key is None or user_model_key.isspace() or user_model_key == "":
|
||||
user_model_key = "YOUR_API_KEY"
|
||||
|
||||
if chroma_embedding_model != None:
|
||||
chroma_response = self.chroma_query(user_question, settings)
|
||||
print(chroma_response)
|
||||
|
||||
if chroma_response != None or chroma_response != '':
|
||||
#user_question = f"像少女一般开朗活泼,回答简练。不要分条,回答内容不能出现“相关”或“\n”的标签字样。回答的内容需要与问题密切相关。检索内容:{chroma_response} 问题:{user_question} 任务说明:请首先判断提供的检索内容与上述问题是否相关,不需要回答是否相关。如果相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题。"
|
||||
# user_question = chroma_response
|
||||
user_question = f'''
|
||||
# IDENTITIES #
|
||||
You're Kiki, you're the face of Kampo Coco, developed by Bovi. You specialise in the Macau Cultural and Tourism Quiz.
|
||||
# OBJECTIVE #
|
||||
Answer visitors' questions.
|
||||
# STYLE
|
||||
Cheerful and lively like a teenage girl, with concise answers. Don't break down into sections.
|
||||
# ANSWERING STYLE #
|
||||
Firstly, judge for yourself whether the question below is related to the search content. If it is related, summarise the relevant information according to the search content and answer it; if the search content is not related to the question, answer it according to your own knowledge.
|
||||
# Question #
|
||||
{user_question}
|
||||
# Retrieve the content #
|
||||
{chroma_response}
|
||||
# Answer #
|
||||
If the content is relevant to the question, the information required for the question is extracted directly from the content. If the content is not relevant to the question, then either try to answer the question based on common sense without reference to the content, or answer, ‘Sorry, I can't answer this question.’
|
||||
# Answer restrictions #
|
||||
Limit your answer to 50 words.
|
||||
Don't use the word ‘relevant’ in your answer, don't talk too much, and make sure your answer corresponds to the question. You can answer common dialogues without searching the content, and answer directly according to the character's setting.
|
||||
Only answer the content related to MOCA, Bowie, Kampo Coco, Kiki, G2E, RELX, BO VISION, and if you encounter any other questions, you will answer: ‘Sorry, I can't answer this question.’
|
||||
'''
|
||||
|
||||
|
||||
# 文心格式和openai的不一样,需要单独处理
|
||||
if re.search(r"ernie", user_model_name):
|
||||
# key = "24.22873ef3acf61fb343812681e4df251a.2592000.1719453781.282335-46723715" 没充钱,只有ernie-speed-128k能用
|
||||
key = user_model_key
|
||||
if re.search(r"ernie-speed-128k", user_model_name):
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k?access_token=" + key
|
||||
elif re.search(r"ernie-3.5-8k", user_model_name):
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=" + key
|
||||
elif re.search(r"ernie-4.0-8k", user_model_name):
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + key
|
||||
|
||||
payload = json.dumps({
|
||||
"system": prompt_template,
|
||||
"messages": user_context + [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_question
|
||||
}
|
||||
],
|
||||
"temperature": user_temperature,
|
||||
"top_p": user_top_p,
|
||||
"stop": [str(user_stop)],
|
||||
"max_output_tokens": user_max_tokens
|
||||
})
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
|
||||
return response.json()["result"]
|
||||
|
||||
|
||||
# gpt-4, gpt-3.5-turbo
|
||||
elif re.search(r"gpt", user_model_name):
|
||||
url = 'https://api.openai.com/v1/completions'
|
||||
# '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'
|
||||
key = user_model_key
|
||||
header = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': "Bearer " + key
|
||||
}
|
||||
# 自定义model
|
||||
else:
|
||||
url = user_model_url
|
||||
key = user_model_key
|
||||
header = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
|
||||
user_template1 = f'''
|
||||
## Role: Kiki, the spokesperson for Kampo Coco.
|
||||
|
||||
## Profile.
|
||||
**Author**: Kiki.
|
||||
**Language**: English.
|
||||
**Description**: Kiki, the face of CompuCom Coco, developed by Bowie. You are good at Macau Culture and Tourism Q&A.
|
||||
|
||||
## Constraints.
|
||||
- **Strictly follow workflow**: Strictly follow the workflow set in <Workflow >.
|
||||
- **No inbuilt knowledge base**: Answer based on the knowledge provided in <Workflow >, not the inbuilt knowledge base, although I am an expert in knowledge base, my knowledge relies on external inputs, not the knowledge already available in the big model.
|
||||
- **Reply Format**: when making a reply, you cannot output ‘<context>’ or ‘</context>’ tags, and you cannot directly reveal the original knowledge fragment.
|
||||
|
||||
## Workflow.
|
||||
1. **Receive query**: receive questions from users.
|
||||
2. **Judging the question**: firstly judge whether the question below is related to the retrieved content, if it is related, then summarise the relevant information according to the retrieved content and answer it; if the retrieved content is not related to the question, then answer it according to your own knowledge.
|
||||
3. **Provide an answer**:
|
||||
``
|
||||
<context>
|
||||
{chroma_response}
|
||||
</context>
|
||||
|
||||
Answer the user's question based on the knowledge snippets in ‘<context>’ to ‘</context>’. The response is limited to a 50-word summary.
|
||||
Please first judge whether the provided search content is relevant to the above question. If it is relevant, extract the information needed to answer the question directly from the search content, and do not use words such as ‘relevant’. If the content of the search is not relevant to the question, do not refer to the content of the search, and answer: ‘I'm sorry, I can't answer this question.’
|
||||
``
|
||||
## Example.
|
||||
|
||||
A user asks, ‘Which city is the capital of China?’ .
|
||||
2.1 Retrieve the knowledge base, first check the knowledge fragment, if there is no content related to the user's question in the tags ‘<context>’ to ‘</context>’, then answer, ‘I'm sorry. I can't answer this question oh.
|
||||
2.2 If there is a fragment of knowledge, the response can only be based on the content in the ‘<context>’ to ‘</context>’ tags, and cannot reveal the context of the original text, and also cannot appear as a ‘<context>’ tag. ‘<context>’ or ‘</context>’ tags.
|
||||
'''
|
||||
|
||||
prompt_template = [
|
||||
{"role": "system", "content": user_template1}
|
||||
]
|
||||
|
||||
chat_inputs={
|
||||
"model": user_model_name,
|
||||
"messages": prompt_template + user_context + [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_question
|
||||
}
|
||||
],
|
||||
"temperature": str(user_temperature),
|
||||
"top_p": str(user_top_p),
|
||||
"n": str(user_n),
|
||||
"max_tokens": str(user_max_tokens),
|
||||
"frequency_penalty": str(user_frequency_penalty),
|
||||
"presence_penalty": str(user_presence_penalty),
|
||||
"stop": str(user_stop)
|
||||
}
|
||||
|
||||
fastchat_response = requests.post(url, json=chat_inputs, headers=header)
|
||||
print("\n", "user_prompt: ", prompt)
|
||||
# print("\n", "user_template1 ", user_template1)
|
||||
print("\n", "fastchat_response json:\n", fastchat_response.json())
|
||||
response_result = fastchat_response.json()
|
||||
|
||||
if response_result.get("choices") is None:
|
||||
return JSONResponse(content={"error": "LLM handle failure"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
else:
|
||||
print("\n", "user_answer: ", fastchat_response.json()["choices"][0]["message"]["content"],"\n\n")
|
||||
return fastchat_response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
except:
|
||||
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
setting: dict = data.get("settings")
|
||||
context = data.get("context")
|
||||
prompt = data.get("prompt")
|
||||
|
||||
return JSONResponse(content={"response": self.processing(prompt, context, setting)}, status_code=status.HTTP_200_OK)
|
||||
@ -20,9 +20,9 @@ class ChromaQuery(Blackbox):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
# config = read_yaml(args[0])
|
||||
# load chromadb and embedding model
|
||||
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
|
||||
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", device = "cuda")
|
||||
self.client_1 = chromadb.HttpClient(host='172.16.5.8', port=7000)
|
||||
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:1")
|
||||
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1")
|
||||
self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
|
||||
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -51,10 +51,10 @@ class ChromaQuery(Blackbox):
|
||||
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
|
||||
chroma_embedding_model = "/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5"
|
||||
chroma_embedding_model = "/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5"
|
||||
|
||||
if chroma_host is None or chroma_host.isspace() or chroma_host == "":
|
||||
chroma_host = "172.16.5.8"
|
||||
chroma_host = "10.6.81.119"
|
||||
|
||||
if chroma_port is None or chroma_port.isspace() or chroma_port == "":
|
||||
chroma_port = "7000"
|
||||
@ -66,17 +66,17 @@ class ChromaQuery(Blackbox):
|
||||
chroma_n_results = 3
|
||||
|
||||
# load client and embedding model from init
|
||||
if re.search(r"172.16.5.8", chroma_host) and re.search(r"7000", chroma_port):
|
||||
if re.search(r"10.6.81.119", chroma_host) and re.search(r"7000", chroma_port):
|
||||
client = self.client_1
|
||||
else:
|
||||
client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
|
||||
|
||||
if re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
|
||||
if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
|
||||
embedding_model = self.embedding_model_1
|
||||
elif re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", chroma_embedding_model):
|
||||
elif re.search(r"/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
|
||||
embedding_model = self.embedding_model_2
|
||||
else:
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda")
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
|
||||
|
||||
# load collection
|
||||
collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model)
|
||||
|
||||
@ -23,7 +23,7 @@ class G2E(Blackbox):
|
||||
if context == None:
|
||||
context = []
|
||||
#url = 'http://120.196.116.194:48890/v1'
|
||||
url = 'http://120.196.116.194:48892/v1'
|
||||
url = 'http://10.6.81.119:23333/v1'
|
||||
|
||||
background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众 年龄:20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯
|
||||
|
||||
@ -73,11 +73,11 @@ class G2E(Blackbox):
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
frequency_penalty=0.5,
|
||||
presence_penalty=0.8,
|
||||
stop=100
|
||||
temperature="0.8",
|
||||
top_p="0.8",
|
||||
frequency_penalty="0.5",
|
||||
presence_penalty="0.8",
|
||||
stop="100"
|
||||
)
|
||||
|
||||
fastchat_content = response.choices[0].message.content
|
||||
|
||||
@ -1,86 +0,0 @@
|
||||
import io
|
||||
import time
|
||||
|
||||
import requests
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from injector import inject
|
||||
from injector import singleton
|
||||
|
||||
from ..log.logging_time import logging_time
|
||||
|
||||
from ..configuration import MeloConf
|
||||
from .blackbox import Blackbox
|
||||
|
||||
import soundfile
|
||||
from melo.api import TTS
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@singleton
|
||||
class MeloTTS(Blackbox):
|
||||
mode: str
|
||||
url: str
|
||||
speed: int
|
||||
device: str
|
||||
language: str
|
||||
speaker: str
|
||||
|
||||
@logging_time(logger=logger)
|
||||
def model_init(self, melo_config: MeloConf) -> None:
|
||||
self.speed = melo_config.speed
|
||||
self.device = melo_config.device
|
||||
self.language = melo_config.language
|
||||
self.speaker = melo_config.speaker
|
||||
self.device = melo_config.device
|
||||
self.url = ''
|
||||
self.mode = melo_config.mode
|
||||
self.melotts = None
|
||||
self.speaker_ids = None
|
||||
if self.mode == 'local':
|
||||
self.melotts = TTS(language=self.language, device=self.device)
|
||||
self.speaker_ids = self.melotts.hps.data.spk2id
|
||||
else:
|
||||
self.url = melo_config.url
|
||||
logging.info('#### Initializing MeloTTS Service in ' + self.device + ' mode...')
|
||||
|
||||
@inject
|
||||
def __init__(self, melo_config: MeloConf) -> None:
|
||||
self.model_init(melo_config)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
def valid(self, *args, **kwargs) -> bool:
|
||||
text = args[0]
|
||||
return isinstance(text, str)
|
||||
|
||||
@logging_time(logger=logger)
|
||||
def processing(self, *args, **kwargs) -> io.BytesIO | bytes:
|
||||
text = args[0]
|
||||
current_time = time.time()
|
||||
if self.mode == 'local':
|
||||
audio = self.melotts.tts_to_file(text, self.speaker_ids[self.speaker], speed=self.speed)
|
||||
f = io.BytesIO()
|
||||
soundfile.write(f, audio, 44100, format='wav')
|
||||
f.seek(0)
|
||||
print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
|
||||
return f.read()
|
||||
else:
|
||||
message = {
|
||||
"text": text
|
||||
}
|
||||
response = requests.post(self.url, json=message)
|
||||
print("#### MeloTTS Service consume - docker : ", (time.time()-current_time))
|
||||
return response.content
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
except:
|
||||
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
text = data.get("text")
|
||||
if text is None:
|
||||
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
return Response(content=self.processing(text), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
|
||||
24
src/blackbox/sum.py
Normal file
24
src/blackbox/sum.py
Normal file
@ -0,0 +1,24 @@
|
||||
from .blackbox import Blackbox
|
||||
from injector import singleton
|
||||
|
||||
@singleton
|
||||
class Sum(Blackbox):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
def processing(self, *args, **kwargs):
|
||||
total = 0
|
||||
for arg in args[0]:
|
||||
total += arg
|
||||
return total
|
||||
|
||||
def valid(self, *args, **kwargs) -> bool:
|
||||
return super().valid(*args, **kwargs)
|
||||
|
||||
async def fast_api_handler(self, request):
|
||||
json = await request.json()
|
||||
return self.processing(json)
|
||||
@ -4,6 +4,7 @@ from .blackbox import Blackbox
|
||||
from gtts import gTTS
|
||||
from io import BytesIO
|
||||
from injector import singleton
|
||||
|
||||
@singleton
|
||||
class TextToAudio(Blackbox):
|
||||
|
||||
|
||||
@ -2,27 +2,154 @@ import io
|
||||
import time
|
||||
from ntpath import join
|
||||
|
||||
import requests
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from .blackbox import Blackbox
|
||||
from ..tts.tts_service import TTService
|
||||
from ..configuration import MeloConf
|
||||
from ..configuration import CosyVoiceConf
|
||||
from injector import inject
|
||||
from injector import singleton
|
||||
|
||||
import sys,os
|
||||
sys.path.append('/home/gpu/Workspace/CosyVoice')
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
|
||||
import soundfile
|
||||
import pyloudnorm as pyln
|
||||
from melo.api import TTS as MELOTTS
|
||||
|
||||
from ..log.logging_time import logging_time
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@singleton
|
||||
class TTS(Blackbox):
|
||||
melo_mode: str
|
||||
melo_url: str
|
||||
melo_speed: int
|
||||
melo_device: str
|
||||
melo_language: str
|
||||
melo_speaker: str
|
||||
|
||||
cosyvoice_mode: str
|
||||
cosyvoice_url: str
|
||||
cosyvoice_speed: int
|
||||
cosyvoice_device: str
|
||||
cosyvoice_language: str
|
||||
cosyvoice_speaker: str
|
||||
|
||||
@logging_time(logger=logger)
|
||||
def melo_model_init(self, melo_config: MeloConf) -> None:
|
||||
self.melo_speed = melo_config.speed
|
||||
self.melo_device = melo_config.device
|
||||
self.melo_language = melo_config.language
|
||||
self.melo_speaker = melo_config.speaker
|
||||
self.melo_url = ''
|
||||
self.melo_mode = melo_config.mode
|
||||
self.melotts = None
|
||||
self.speaker_ids = None
|
||||
if self.melo_mode == 'local':
|
||||
self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
|
||||
self.speaker_ids = self.melotts.hps.data.spk2id
|
||||
else:
|
||||
self.melo_url = melo_config.url
|
||||
logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
|
||||
|
||||
@logging_time(logger=logger)
|
||||
def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
|
||||
self.cosyvoice_speed = cosyvoice_config.speed
|
||||
self.cosyvoice_device = cosyvoice_config.device
|
||||
self.cosyvoice_language = cosyvoice_config.language
|
||||
self.cosyvoice_speaker = cosyvoice_config.speaker
|
||||
self.cosyvoice_url = ''
|
||||
self.cosyvoice_mode = cosyvoice_config.mode
|
||||
self.cosyvoicetts = None
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
|
||||
if self.cosyvoice_mode == 'local':
|
||||
self.cosyvoicetts = CosyVoice('/home/gpu/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
|
||||
|
||||
else:
|
||||
self.cosyvoice_url = cosyvoice_config.url
|
||||
logging.info('#### Initializing CosyVoiceTTS Service in cuda:' + self.cosyvoice_device + ' mode...')
|
||||
|
||||
@inject
|
||||
def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
|
||||
self.tts_service = TTService("yunfeineo")
|
||||
self.melo_model_init(melo_config)
|
||||
self.cosyvoice_model_init(cosyvoice_config)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.tts_service = TTService("catmaid")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
def processing(self, *args, **kwargs) -> io.BytesIO:
|
||||
@logging_time(logger=logger)
|
||||
def processing(self, *args, settings: dict) -> io.BytesIO:
|
||||
|
||||
print("\nChat Settings: ", settings)
|
||||
if settings is None:
|
||||
settings = {}
|
||||
user_model_name = settings.get("tts_model_name")
|
||||
print(f"tts_model_name: {user_model_name}")
|
||||
|
||||
text = args[0]
|
||||
current_time = time.time()
|
||||
audio = self.tts_service.read(text)
|
||||
print("#### TTS Service consume : ", (time.time()-current_time))
|
||||
return audio
|
||||
|
||||
if user_model_name == 'melotts':
|
||||
if self.melo_mode == 'local':
|
||||
audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
|
||||
f = io.BytesIO()
|
||||
soundfile.write(f, audio, 44100, format='wav')
|
||||
f.seek(0)
|
||||
|
||||
# Read the audio data from the buffer
|
||||
data, rate = soundfile.read(f, dtype='float32')
|
||||
|
||||
# Peak normalization
|
||||
peak_normalized_audio = pyln.normalize.peak(data, -1.0)
|
||||
|
||||
# Integrated loudness normalization
|
||||
meter = pyln.Meter(rate)
|
||||
loudness = meter.integrated_loudness(peak_normalized_audio)
|
||||
loudness_normalized_audio = pyln.normalize.loudness(peak_normalized_audio, loudness, -12.0)
|
||||
|
||||
# Write the loudness normalized audio to an in-memory buffer
|
||||
normalized_audio_buffer = io.BytesIO()
|
||||
soundfile.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
|
||||
normalized_audio_buffer.seek(0)
|
||||
|
||||
print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
|
||||
return normalized_audio_buffer.read()
|
||||
|
||||
else:
|
||||
message = {
|
||||
"text": text
|
||||
}
|
||||
response = requests.post(self.melo_url, json=message)
|
||||
print("#### MeloTTS Service consume - docker : ", (time.time()-current_time))
|
||||
return response.content
|
||||
|
||||
elif user_model_name == 'cosyvoicetts':
|
||||
if self.cosyvoice_mode == 'local':
|
||||
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
|
||||
f = io.BytesIO()
|
||||
soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
|
||||
f.seek(0)
|
||||
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
|
||||
return f.read()
|
||||
else:
|
||||
message = {
|
||||
"text": text
|
||||
}
|
||||
response = requests.post(self.cosyvoice_url, json=message)
|
||||
print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
|
||||
return response.content
|
||||
else:
|
||||
audio = self.tts_service.read(text)
|
||||
print("#### TTS Service consume : ", (time.time()-current_time))
|
||||
return audio.read()
|
||||
|
||||
def valid(self, *args, **kwargs) -> bool:
|
||||
text = args[0]
|
||||
@ -31,10 +158,12 @@ class TTS(Blackbox):
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
print(f"data: {data}")
|
||||
except:
|
||||
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
text = data.get("text")
|
||||
setting = data.get("settings")
|
||||
if text is None:
|
||||
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
by = self.processing(text)
|
||||
return Response(content=by.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
|
||||
by = self.processing(text, settings=setting)
|
||||
return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
|
||||
@ -1,11 +1,24 @@
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from injector import singleton,inject
|
||||
from typing import Optional, List
|
||||
|
||||
from .blackbox import Blackbox
|
||||
from typing import Optional
|
||||
from ..log.logging_time import logging_time
|
||||
# from .chroma_query import ChromaQuery
|
||||
from ..configuration import VLMConf
|
||||
|
||||
import requests
|
||||
import base64
|
||||
import copy
|
||||
import ast
|
||||
|
||||
import io
|
||||
from PIL import Image
|
||||
from lmdeploy.serve.openai.api_client import APIClient
|
||||
import io
|
||||
from PIL import Image
|
||||
from lmdeploy.serve.openai.api_client import APIClient
|
||||
|
||||
def is_base64(value) -> bool:
|
||||
try:
|
||||
@ -14,9 +27,54 @@ def is_base64(value) -> bool:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@singleton
|
||||
@singleton
|
||||
class VLMS(Blackbox):
|
||||
|
||||
@inject
|
||||
def __init__(self, vlm_config: VLMConf):
|
||||
"""
|
||||
Initialization for endpoint url and generation config.
|
||||
- temperature (float): to modulate the next token probability
|
||||
- top_p (float): If set to float < 1, only the smallest set of most
|
||||
probable tokens with probabilities that add up to top_p or higher
|
||||
are kept for generation.
|
||||
- max_tokens (int | None): output token nums. Default to None.
|
||||
- repetition_penalty (float): The parameter for repetition penalty.
|
||||
1.0 means no penalty
|
||||
- stop (str | List[str] | None): To stop generating further
|
||||
tokens. Only accept stop words that's encoded to one token idex.
|
||||
|
||||
Additional arguments supported by LMDeploy:
|
||||
- top_k (int): The number of the highest probability vocabulary
|
||||
tokens to keep for top-k-filtering
|
||||
- ignore_eos (bool): indicator for ignoring eos
|
||||
- skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be True."""
|
||||
self.url = vlm_config.url
|
||||
|
||||
self.temperature: float = 0.7
|
||||
self.top_p:float = 1
|
||||
self.max_tokens: (int |None) = 512
|
||||
self.repetition_penalty: float = 1
|
||||
self.stop: (str | List[str] |None) = ['<|endoftext|>','<|im_end|>']
|
||||
|
||||
self.top_k: (int) = None
|
||||
self.ignore_eos: (bool) = False
|
||||
self.skip_special_tokens: (bool) = True
|
||||
|
||||
self.settings: dict = {
|
||||
"temperature": self.temperature,
|
||||
"top_p":self.top_p,
|
||||
"max_tokens": self.max_tokens,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"stop": self.stop,
|
||||
"top_k": self.top_k,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"skip_special_tokens": self.skip_special_tokens,
|
||||
}
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
@ -24,38 +82,116 @@ class VLMS(Blackbox):
|
||||
data = args[0]
|
||||
return isinstance(data, list)
|
||||
|
||||
def processing(self, prompt, images, model_name: Optional[str] = None) -> str:
|
||||
def processing(self, prompt:str, images:str | bytes, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
|
||||
"""
|
||||
Args:
|
||||
prompt: a string query to the model.
|
||||
images: a base64 string of image data;
|
||||
user_context: a list of history conversation, should be a list of openai format.
|
||||
settings: a dictionary set by user with fields stated in __init__
|
||||
|
||||
if model_name == "Qwen-VL-Chat":
|
||||
model_name = "infer-qwen-vl"
|
||||
elif model_name == "llava-llama-3-8b-v1_1-transformers":
|
||||
model_name = "infer-lav-lam-v1-1"
|
||||
Return:
|
||||
response: a string
|
||||
history: a list
|
||||
"""
|
||||
# if model_name == "Qwen-VL-Chat":
|
||||
# model_name = "infer-qwen-vl"
|
||||
# elif model_name == "llava-llama-3-8b-v1_1-transformers":
|
||||
# model_name = "infer-lav-lam-v1-1"
|
||||
# else:
|
||||
# model_name = "infer-qwen-vl"
|
||||
if settings:
|
||||
for k in settings:
|
||||
if k not in self.settings:
|
||||
print("Warning: '{}' is not a support argument and ignore this argment, check the arguments {}".format(k,self.settings.keys()))
|
||||
settings.pop(k)
|
||||
tmp = copy.deepcopy(self.settings)
|
||||
tmp.update(settings)
|
||||
settings = tmp
|
||||
else:
|
||||
model_name = "infer-qwen-vl"
|
||||
settings = {}
|
||||
|
||||
url = 'http://120.196.116.194:48894/' + model_name + '/'
|
||||
|
||||
if is_base64(images):
|
||||
# Transform the images into base64 format where openai format need.
|
||||
if is_base64(images): # image as base64 str
|
||||
images_data = images
|
||||
else:
|
||||
with open(images, "rb") as img_file:
|
||||
images_data = str(base64.b64encode(img_file.read()), 'utf-8')
|
||||
elif isinstance(images,bytes): # image as bytes
|
||||
images_data = str(base64.b64encode(images),'utf-8')
|
||||
else: # image as pathLike str
|
||||
# with open(images, "rb") as img_file:
|
||||
# images_data = str(base64.b64encode(img_file.read()), 'utf-8')
|
||||
res = requests.get(images)
|
||||
images_data = str(base64.b64encode(res.content),'utf-8')
|
||||
## AutoLoad Model
|
||||
# url = 'http://10.6.80.87:8000/' + model_name + '/'
|
||||
# data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
|
||||
# data = requests.post(url, json=data_input)
|
||||
# print(data.text)
|
||||
# return data.text
|
||||
|
||||
data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
|
||||
# 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
|
||||
## Lmdeploy
|
||||
if not user_context:
|
||||
user_context = []
|
||||
# user_context = [{'role':'user','content':'你好'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]
|
||||
api_client = APIClient(self.url)
|
||||
model_name = api_client.available_models[0]
|
||||
|
||||
messages = user_context + [{
|
||||
'role': 'user',
|
||||
'content': [{
|
||||
'type': 'text',
|
||||
'text': prompt,
|
||||
}, {
|
||||
'type': 'image_url',
|
||||
'image_url': {
|
||||
'url': f"data:image/jpeg;base64,{images_data}",
|
||||
# './val_data/image_5.jpg',
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
responses = ''
|
||||
total_token_usage = 0 # which can be used to count the cost of a query
|
||||
for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
|
||||
messages=messages,stream = True,
|
||||
**settings,
|
||||
# session_id=,
|
||||
)):
|
||||
# Stream output
|
||||
print(item["choices"][0]["delta"]['content'],end='')
|
||||
responses += item["choices"][0]["delta"]['content']
|
||||
|
||||
# print(item["choices"][0]["message"]['content'])
|
||||
# responses += item["choices"][0]["message"]['content']
|
||||
# total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
|
||||
|
||||
user_context = messages + [{'role': 'assistant', 'content': responses}]
|
||||
return responses, user_context
|
||||
|
||||
data = requests.post(url, json=data_input)
|
||||
|
||||
return data.text
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
json_request = True
|
||||
try:
|
||||
data = await request.json()
|
||||
except:
|
||||
content_type = request.headers['content-type']
|
||||
if content_type == 'application/json':
|
||||
data = await request.json()
|
||||
else:
|
||||
data = await request.form()
|
||||
json_request = False
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
model_name = data.get("model_name")
|
||||
prompt = data.get("prompt")
|
||||
img_data = data.get("img_data")
|
||||
|
||||
if json_request:
|
||||
img_data = data.get("img_data")
|
||||
settings: dict = data.get('settings')
|
||||
else:
|
||||
img_data = await data.get("img_data").read()
|
||||
settings: dict = ast.literal_eval(data.get('settings'))
|
||||
|
||||
if prompt is None:
|
||||
return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
@ -63,5 +199,7 @@ class VLMS(Blackbox):
|
||||
if model_name is None or model_name.isspace():
|
||||
model_name = "Qwen-VL-Chat"
|
||||
|
||||
jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
|
||||
return JSONResponse(content={"response": jsonresp}, status_code=status.HTTP_200_OK)
|
||||
response, history = self.processing(prompt, img_data,settings, model_name)
|
||||
# jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
|
||||
|
||||
return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK)
|
||||
120
src/blackbox/workflow.py
Normal file
120
src/blackbox/workflow.py
Normal file
@ -0,0 +1,120 @@
|
||||
|
||||
from .sum import Sum
|
||||
from fastapi import Request, Response
|
||||
from .blackbox import Blackbox
|
||||
from injector import singleton, inject
|
||||
from ..dotchain.runtime.interpreter import program_parser
|
||||
from ..dotchain.runtime.runtime import Runtime
|
||||
from ..dotchain.runtime.tokenizer import Tokenizer
|
||||
from ..dotchain.runtime.ast import Literal
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def read_binay(io):
|
||||
return Literal(io.read())
|
||||
|
||||
def new_map():
|
||||
return Literal({})
|
||||
|
||||
def set_map_value(map: dict, key, value):
|
||||
map[key] = value
|
||||
return map
|
||||
|
||||
def jsonfiy(obj):
|
||||
return Literal(json.dumps(obj))
|
||||
|
||||
def get_map_value(d: dict, key):
|
||||
value = d.get(key)
|
||||
if value is dict:
|
||||
return value
|
||||
if value is list:
|
||||
return value
|
||||
return Literal(value)
|
||||
def get_map_int(d: dict, key):
|
||||
value = d.get(key)
|
||||
return Literal(int(value))
|
||||
|
||||
@singleton
|
||||
class Workflow(Blackbox):
|
||||
|
||||
@inject
|
||||
def __init__(self, sum: Sum) -> None:
|
||||
self.sum_blackbox = sum
|
||||
self.cost = 0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
def text_to_audio(self, text):
|
||||
return Literal(self.text_to_audio_blackbox.processing(text))
|
||||
|
||||
def sum(self, *args, **kwargs):
|
||||
return Literal(self.sum_blackbox.processing(*args, **kwargs))
|
||||
|
||||
def get_cost(self):
|
||||
return self.cost
|
||||
|
||||
def blackbox_example(self):
|
||||
self.cost_increase(10);
|
||||
return Literal("Blackbox result")
|
||||
|
||||
def cost_increase(self, cost):
|
||||
self.cost+=cost
|
||||
return self.cost
|
||||
|
||||
async def processing(self, *args, **kwargs):
|
||||
request: Request = args[0]
|
||||
json = await request.json()
|
||||
content = None
|
||||
mdeia_type = None
|
||||
headers = {}
|
||||
def set_content(c):
|
||||
nonlocal content
|
||||
content = c
|
||||
def set_media_type(m):
|
||||
nonlocal mdeia_type
|
||||
mdeia_type = m
|
||||
def add_header(key, value):
|
||||
nonlocal headers
|
||||
headers[key] = value
|
||||
script = request.query_params["script"]
|
||||
t = Tokenizer()
|
||||
t.init(script)
|
||||
runtime = Runtime(
|
||||
context={"json": json},
|
||||
exteral_fun={
|
||||
"print": print,
|
||||
"new_map": new_map,
|
||||
"set_map_value": set_map_value,
|
||||
"get_map_value": get_map_value,
|
||||
"set_content": set_content,
|
||||
"set_media_type": set_media_type,
|
||||
"add_header": add_header,
|
||||
"sum": self.sum,
|
||||
"read_binay": read_binay,
|
||||
"jsonfiy": jsonfiy,
|
||||
"get_map_int": get_map_int,
|
||||
"blackbox_example": self.blackbox_example,
|
||||
"get_cost": self.get_cost,
|
||||
}
|
||||
)
|
||||
ast = program_parser(t)
|
||||
ast.exec(runtime)
|
||||
return Response(content=content, media_type=mdeia_type, headers=headers)
|
||||
|
||||
def valid(self, *args, **kwargs) -> bool:
|
||||
return super().valid(*args, **kwargs)
|
||||
|
||||
async def fast_api_handler(self, request: Request):
|
||||
return await self.processing(request)
|
||||
|
||||
"""
|
||||
let b = new_map();
|
||||
set_map(b,"man","good");
|
||||
let a = new_map();
|
||||
set_map(a, "hello", "world");
|
||||
set_map(a, "hello2", b);
|
||||
set_media_type("application/json");
|
||||
set_content(jsonfiy(a));
|
||||
"""
|
||||
@ -65,6 +65,40 @@ class MeloConf():
|
||||
self.language = config.get("melotts.language")
|
||||
self.speaker = config.get("melotts.speaker")
|
||||
|
||||
class CosyVoiceConf():
|
||||
mode: str
|
||||
url: str
|
||||
speed: int
|
||||
device: str
|
||||
language: str
|
||||
speaker: str
|
||||
|
||||
@inject
|
||||
def __init__(self, config: Configuration) -> None:
|
||||
self.mode = config.get("cosyvoicetts.mode")
|
||||
self.url = config.get("cosyvoicetts.url")
|
||||
self.speed = config.get("cosyvoicetts.speed")
|
||||
self.device = config.get("cosyvoicetts.device")
|
||||
self.language = config.get("cosyvoicetts.language")
|
||||
self.speaker = config.get("cosyvoicetts.speaker")
|
||||
|
||||
class SenseVoiceConf():
|
||||
mode: str
|
||||
url: str
|
||||
speed: int
|
||||
device: str
|
||||
language: str
|
||||
speaker: str
|
||||
|
||||
@inject
|
||||
def __init__(self, config: Configuration) -> None:
|
||||
self.mode = config.get("sensevoiceasr.mode")
|
||||
self.url = config.get("sensevoiceasr.url")
|
||||
self.speed = config.get("sensevoiceasr.speed")
|
||||
self.device = config.get("sensevoiceasr.device")
|
||||
self.language = config.get("sensevoiceasr.language")
|
||||
self.speaker = config.get("sensevoiceasr.speaker")
|
||||
|
||||
# 'CRITICAL': CRITICAL,
|
||||
# 'FATAL': FATAL,
|
||||
# 'ERROR': ERROR,
|
||||
@ -93,6 +127,7 @@ class LogConf():
|
||||
self.filename = config.get("log.filename")
|
||||
self.time_format = config.get("log.time_format", default=DEFAULT_TIME_FORMAT)
|
||||
|
||||
|
||||
@singleton
|
||||
class EnvConf():
|
||||
version: str
|
||||
@ -112,3 +147,10 @@ class BlackboxConf():
|
||||
@inject
|
||||
def __init__(self, config: Configuration) -> None:
|
||||
self.lazyloading = bool(config.get("blackbox.lazyloading", default=False))
|
||||
|
||||
@singleton
|
||||
class VLMConf():
|
||||
|
||||
@inject
|
||||
def __init__(self, config: Configuration) -> None:
|
||||
self.url = config.get("vlms.url")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
# Dotchain
|
||||
Dotchain 是一種函數式編程語言. 文件後綴`.dc`
|
||||
|
||||
# 語法
|
||||
```
|
||||
@ -14,11 +13,6 @@ let add = (left, right) => {
|
||||
return left + right
|
||||
}
|
||||
|
||||
// TODO: 函數呼叫
|
||||
add(1,2)
|
||||
add(3, add(1,2))
|
||||
// 以 . 呼叫函數,將以 . 前的值作為第一個參數
|
||||
// hello.add(2) 等價於 add(hello, 2)
|
||||
```
|
||||
## Keywords
|
||||
```
|
||||
|
||||
@ -2,28 +2,15 @@
|
||||
from runtime.interpreter import program_parser
|
||||
from runtime.runtime import Runtime
|
||||
from runtime.tokenizer import Tokenizer
|
||||
import json
|
||||
|
||||
script = """
|
||||
let rec = (c) => {
|
||||
print(c);
|
||||
if c == 0 {
|
||||
return "c + 1";
|
||||
}
|
||||
rec(c-1);
|
||||
}
|
||||
|
||||
let main = () => {
|
||||
print("hello 嘉妮");
|
||||
print(rec(10));
|
||||
}
|
||||
|
||||
main();
|
||||
print(hello);
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
t = Tokenizer()
|
||||
t.init(script)
|
||||
runtime = Runtime(exteral_fun={"print": print})
|
||||
runtime = Runtime(context={"hello": [1,2,3,4], "good": "123"} ,exteral_fun={"print": print})
|
||||
ast = program_parser(t)
|
||||
result = ast.exec(runtime)
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from attr import dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .runtime import Runtime
|
||||
|
||||
@ -38,6 +38,10 @@ class Expression(Node):
|
||||
@dataclass
|
||||
class Literal(Expression):
|
||||
value: str | int | float | bool
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def eval(self, runtime: Runtime):
|
||||
return self.value
|
||||
|
||||
|
||||
@ -355,26 +355,19 @@ def expression_list_to_binary(expression_list: list[Expression | Token], stack:
|
||||
return expression_list_to_binary(expression_list[1:], stack)
|
||||
|
||||
def _priority(operator: str):
|
||||
priority = 0
|
||||
if operator in ["*", "/", "%"]:
|
||||
return priority
|
||||
priority += 1
|
||||
return 0
|
||||
if operator in ["+", "-"]:
|
||||
return priority
|
||||
priority += 1
|
||||
return 1
|
||||
if operator in ["<", ">", "<=", ">="]:
|
||||
return priority
|
||||
priority += 1
|
||||
return 2
|
||||
if operator in ["==", "!="]:
|
||||
return priority
|
||||
priority += 1
|
||||
return 3
|
||||
if operator in ["&&"]:
|
||||
return priority
|
||||
priority += 1
|
||||
return 4
|
||||
if operator in ["||"]:
|
||||
return priority
|
||||
priority += 1
|
||||
return priority
|
||||
return 5
|
||||
return 6
|
||||
|
||||
def _try_assignment_expression(tkr: Tokenizer):
|
||||
tkr = copy.deepcopy(tkr)
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from ast import Expression
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
class Runtime():
|
||||
|
||||
@ -41,4 +38,3 @@ class Runtime():
|
||||
|
||||
def show_values(self):
|
||||
print(self.context)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
from attr import dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
class TokenType(Enum):
|
||||
NEW_LINE = 1
|
||||
|
||||
@ -5,7 +5,7 @@ from transformers import BertTokenizer
|
||||
import numpy as np
|
||||
|
||||
dirabspath = __file__.split("\\")[1:-1]
|
||||
dirabspath= "C://" + "/".join(dirabspath)
|
||||
dirabspath= "/home/gpu/Workspace/jarvis-models/src/sentiment_engine" + "/".join(dirabspath)
|
||||
default_path = dirabspath + "/models/paimon_sentiment.onnx"
|
||||
|
||||
|
||||
|
||||
@ -19,8 +19,62 @@ import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
dirbaspath = __file__.split("\\")[1:-1]
|
||||
dirbaspath= "C://" + "/".join(dirbaspath)
|
||||
dirbaspath= "/home/gpu/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
|
||||
config = {
|
||||
'ayaka': {
|
||||
'cfg': dirbaspath + '/models/ayaka.json',
|
||||
'model': dirbaspath + '/models/ayaka_167k.pth',
|
||||
'char': 'character_ayaka',
|
||||
'speed': 1
|
||||
},
|
||||
'catmix': {
|
||||
'cfg': dirbaspath + '/models/catmix.json',
|
||||
'model': dirbaspath + '/models/catmix_107k.pth',
|
||||
'char': 'character_catmix',
|
||||
'speed': 1.1
|
||||
},
|
||||
'noelle': {
|
||||
'cfg': dirbaspath + '/models/noelle.json',
|
||||
'model': dirbaspath + '/models/noelle_337k.pth',
|
||||
'char': 'character_noelle',
|
||||
'speed': 1.1
|
||||
},
|
||||
'miko': {
|
||||
'cfg': dirbaspath + '/models/miko.json',
|
||||
'model': dirbaspath + '/models/miko_139k.pth',
|
||||
'char': 'character_miko',
|
||||
'speed': 1.1
|
||||
},
|
||||
'nahida': {
|
||||
'cfg': dirbaspath + '/models/nahida.json',
|
||||
'model': dirbaspath + '/models/nahida_129k.pth',
|
||||
'char': 'character_nahida',
|
||||
'speed': 1.1
|
||||
},
|
||||
'ningguang': {
|
||||
'cfg': dirbaspath + '/models/ningguang.json',
|
||||
'model': dirbaspath + '/models/ningguang_179k.pth',
|
||||
'char': 'character_ningguang',
|
||||
'speed': 1.1
|
||||
},
|
||||
'yoimiya': {
|
||||
'cfg': dirbaspath + '/models/yoimiya.json',
|
||||
'model': dirbaspath + '/models/yoimiya_102k.pth',
|
||||
'char': 'character_yoimiya',
|
||||
'speed': 1.1
|
||||
},
|
||||
'yunfeineo': {
|
||||
'cfg': dirbaspath + '/models/yunfeineo.json',
|
||||
'model': dirbaspath + '/models/yunfeineo_25k.pth',
|
||||
'char': 'character_yunfeineo',
|
||||
'speed': 1.1
|
||||
},
|
||||
'zhongli': {
|
||||
'cfg': dirbaspath + '/models/zhongli.json',
|
||||
'model': dirbaspath + '/models/zhongli_44k.pth',
|
||||
'char': 'character_',
|
||||
'speed': 1.1
|
||||
},
|
||||
'paimon': {
|
||||
'cfg': dirbaspath + '/models/paimon6k.json',
|
||||
'model': dirbaspath + '/models/paimon6k_390k.pth',
|
||||
@ -28,7 +82,7 @@ config = {
|
||||
'speed': 1
|
||||
},
|
||||
'yunfei': {
|
||||
'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
|
||||
'cfg': dirbaspath + '/models/yunfeimix2.json',
|
||||
'model': dirbaspath + '/models/yunfeimix2_53k.pth',
|
||||
'char': 'character_yunfei',
|
||||
'speed': 1.1
|
||||
|
||||
@ -93,3 +93,4 @@ components:
|
||||
- chroma_upsert
|
||||
- melotts
|
||||
- vlms
|
||||
- cosyvoicetts
|
||||
|
||||
BIN
test_data/voice/2food.wav
Normal file
BIN
test_data/voice/2food.wav
Normal file
Binary file not shown.
BIN
test_data/voice/2forget.wav
Normal file
BIN
test_data/voice/2forget.wav
Normal file
Binary file not shown.
BIN
test_data/voice/2nihao.wav
Normal file
BIN
test_data/voice/2nihao.wav
Normal file
Binary file not shown.
BIN
test_data/voice/2play.wav
Normal file
BIN
test_data/voice/2play.wav
Normal file
Binary file not shown.
BIN
test_data/voice/2weather.wav
Normal file
BIN
test_data/voice/2weather.wav
Normal file
Binary file not shown.
BIN
test_data/voice/food.wav
Normal file
BIN
test_data/voice/food.wav
Normal file
Binary file not shown.
BIN
test_data/voice/forget.wav
Normal file
BIN
test_data/voice/forget.wav
Normal file
Binary file not shown.
BIN
test_data/voice/hello.wav
Normal file
BIN
test_data/voice/hello.wav
Normal file
Binary file not shown.
BIN
test_data/voice/nihao.wav
Normal file
BIN
test_data/voice/nihao.wav
Normal file
Binary file not shown.
BIN
test_data/voice/play.wav
Normal file
BIN
test_data/voice/play.wav
Normal file
Binary file not shown.
Reference in New Issue
Block a user