Mirror of https://github.com/BoardWare-Genius/jarvis-models.git (synced 2025-12-13 16:53:24 +00:00)
Merge main into Tom
This commit is contained in:
55
README.md
@@ -45,14 +45,54 @@ log:
  time_format: "%Y-%m-%d %H:%M:%S"
  filename: "D:/Workspace/Logging/jarvis/jarvis-models.log"

loki:
  url: "https://loki.bwgdi.com/loki/api/v1/push"
  labels:
    app: jarvis
    env: dev
    location: "gdi"
    layer: models

melotts:
  mode: local # or docker
  url: http://10.6.44.16:18080/convert/tts
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
  device: 'cuda'
  device: 'cuda:0'
  language: 'ZH'
  speaker: 'ZH'

cosyvoicetts:
  mode: local # or docker
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
  device: 'cuda:0'
  language: '粤语女'
  speaker: 'ZH'

sovitstts:
  mode: docker
  url: http://10.6.80.90:9880/tts
  speed: 0.9
  device: 'cuda:0'
  language: 'ZH'
  speaker: 'ZH'
  text_lang: "yue"
  ref_audio_path: "output/slicer_opt/Ricky-Wong/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav"
  prompt_lang: "yue"
  prompt_text: "你失敗咗點算啊?你而家安安穩穩,點解要咁樣做呢?"
  text_split_method: "cut5"
  batch_size: 1
  media_type: "wav"
  streaming_mode: True

sensevoiceasr:
  mode: local # or docker
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
  device: 'cuda:0'
  language: '粤语女'
  speaker: 'ZH'

tesou:
  url: http://120.196.116.194:48891/chat/
@@ -91,5 +131,14 @@ blackbox:
  lazyloading: true

vlms:
  url: http://10.6.80.87:23333
  urls:
    qwen_vl: http://10.6.80.87:8000
    qwen2_vl: http://10.6.80.87:23333
    qwen2_vl_72b: http://10.6.80.91:23333

path:
  chroma_rerank_embedding_model: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI
  cosyvoice_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/Workspace/CosyVoice
  cosyvoice_model_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/Voice/CosyVoice/pretrained_models
  sensevoice_model_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/Voice/SenseVoice/SenseVoiceSmall
```
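
For reference, a minimal sketch of how the new `loki` and `path` sections of `.env.yaml` can be read; the file name and keys come from the README above, everything else (variable names, the print) is illustrative only.

```python
import yaml

# Minimal sketch: load .env.yaml and pull the sections added in this commit.
with open(".env.yaml", "r") as f:
    config = yaml.safe_load(f)

loki_labels = config["loki"]["labels"]                    # e.g. {"app": "jarvis", "env": "dev", ...}
model_root = config["path"]["chroma_rerank_embedding_model"]
print(loki_labels, model_root)
```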
@@ -9,4 +9,5 @@ chromadb==0.5.0
langchain==0.1.17
langchain-community==0.0.36
sentence-transformers==2.7.0
openai
openai
python-logging-loki
@@ -7,20 +7,23 @@ from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time
from pathlib import Path

path = Path("/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI")

# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
loader = TextLoader("./RAG_boss.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]

embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
embedding_model = SentenceTransformerEmbeddings(model_name= str(path / "bge-m3"), model_kwargs={"device": "cuda:0"})
client = chromadb.HttpClient(host="localhost", port=7000)

id = "kiki"
id = "boss2"
# client.delete_collection(id)
# Upsert vectors (if the ids already exist, their vectors are updated)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
@@ -28,13 +31,13 @@ db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, c


embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name= str(path / "bge-m3"), device = "cuda:0")

client = chromadb.HttpClient(host='10.6.44.141', port=7000)
client = chromadb.HttpClient(host='localhost', port=7000)

collection = client.get_collection(id, embedding_function=embedding_model)

reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
reranker_model = CrossEncoder(str(path / "bge-reranker-v2-m3"), max_length=512, device = "cuda:0")

# while True:
#     usr_question = input("\n 请输入问题: ")
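
The query half of this sample script is commented out above. As a sketch only, a query-plus-rerank step could look like the following, reusing the `collection` and `reranker_model` created earlier in the script; the question text and result handling are illustrative assumptions.

```python
# Hypothetical query + rerank step for the collection built above.
usr_question = "澳门有什么好玩的?"  # illustrative question
results = collection.query(query_texts=[usr_question], n_results=5)
docs_found = results["documents"][0]

# The CrossEncoder scores (question, document) pairs; a higher score means more relevant.
scores = reranker_model.predict([[usr_question, d] for d in docs_found])
best_score, best_doc = max(zip(scores, docs_found), key=lambda pair: pair[0])
print(best_score, best_doc)
```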
60
server.py
@@ -4,8 +4,12 @@ from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse, Response, StreamingResponse

from src.blackbox.blackbox_factory import BlackboxFactory
from src.log.loki_config import LokiLogger
import logging

from fastapi.middleware.cors import CORSMiddleware
from injector import Injector
import yaml

app = FastAPI()

@@ -20,16 +24,64 @@ app.add_middleware(
injector = Injector()
blackbox_factory = injector.get(BlackboxFactory)

with open(".env.yaml", "r") as f:
    config = yaml.safe_load(f)
logger = LokiLogger(
    # url=config["loki"]["url"],
    # username=config["loki"]["username"],
    # password=config["loki"]["password"],
    tags=config["loki"]["labels"],
    logger_name=__name__,
    level=logging.DEBUG
).get_logger()


@app.exception_handler(Exception)
async def catch_all_exceptions(request: Request, exc: Exception):
    """
    Catch every exception left unhandled inside the ASGI application and log it to Loki.
    """
    logger.error(
        f"Unhandled exception in ASGI application for path: {request.url.path}, method: {request.method} - {exc}",
        exc_info=True,  # must be True to capture the full stack trace
        extra={
            "path": request.url.path,
            "method": request.method,
            "error_type": type(exc).__name__,
            "error_message": str(exc)
        }
    )
    return JSONResponse(
        status_code=500,
        content={"message": "Internal Server Error", "detail": "An unexpected error occurred."}
    )

@app.post("/")
@app.get("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
    if not blackbox_name:
        return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        return JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
    try:
        box = blackbox_factory.get_blackbox(blackbox_name)
    except ValueError:
        return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
    return await box.fast_api_handler(request)
    except ValueError as e:
        logger.error(f"获取 blackbox 失败: {blackbox_name} - {e}", exc_info=True)
        return JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
    try:
        response = await box.fast_api_handler(request)
        # Check the response status code; anything >= 400 is treated as an error response and logged
        if response.status_code >= 400:
            try:
                decoded_content = response.body.decode('utf-8', errors='ignore')
                logger.error(f"Blackbox 返回错误响应: URL={request.url}, Method={request.method}, Status={response.status_code}, Content={decoded_content}")
            except Exception as body_exc:
                logger.error(f"Blackbox {blackbox_name} 返回错误响应: URL={request.url}, Method={request.method}, Status={response.status_code}, 无法解析响应内容: {body_exc}")
        return response
    except Exception as e:
        logger.error(f"Blackbox 内部抛出异常: URL={request.url}, Method={request.method}, 异常报错: {type(e).__name__}: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"message": "Internal Server Error", "detail": "An unexpected error occurred."}
        )

# @app.get("/audio/{filename}")
# async def serve_audio(filename: str):
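
As a usage sketch of the dispatcher above: the endpoint takes `blackbox_name` as a query parameter and forwards the request body to the matching blackbox handler. The host, port, blackbox name and payload keys below are assumptions based on the handlers elsewhere in this diff, not confirmed defaults.

```python
import requests

# Illustrative client call; "chat" and the payload keys are assumptions.
resp = requests.post(
    "http://localhost:8000/?blackbox_name=chat",
    json={"prompt": "你好", "context": [], "settings": {"stream": False}},
    timeout=60,
)
print(resp.status_code, resp.text)  # responses with status >= 400 are also logged to Loki by the server
```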
@@ -12,6 +12,9 @@ from funasr.utils.postprocess_utils import rich_transcription_postprocess
from .blackbox import Blackbox
from injector import singleton, inject

from pathlib import Path
from ..configuration import PathConf

import tempfile

import json
@@ -44,12 +47,13 @@ class ASR(Blackbox):
    speaker: str

    @logging_time(logger=logger)
    def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
    def model_init(self, sensevoice_config: SenseVoiceConf, path: PathConf) -> None:

        config = read_yaml(".env.yaml")
        self.paraformer = RapidParaformer(config)
        sense_model_path = Path(path.sensevoice_model_path)

        model_dir = "/model/Voice/SenseVoice/SenseVoiceSmall"
        model_dir = str(sense_model_path)

        self.speed = sensevoice_config.speed
        self.device = sensevoice_config.device
@@ -65,7 +69,7 @@ class ASR(Blackbox):
        self.asr = AutoModel(
            model=model_dir,
            trust_remote_code=True,
            remote_code= "/Workspace/SenseVoice/model.py",
            remote_code= "../SenseVoice/model.py",
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
            device=self.device,
@@ -76,8 +80,8 @@ class ASR(Blackbox):
        logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...')

    @inject
    def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None:
        self.model_init(sensevoice_config)
    def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict, path: PathConf) -> None:
        self.model_init(sensevoice_config, path)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)
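
For context, a minimal sketch of how the FunASR `AutoModel` configured in `model_init` is typically invoked on a wav file; `asr_model` stands for the `self.asr` built above, the file path is illustrative, and `rich_transcription_postprocess` is the helper imported at the top of this hunk.

```python
# Hypothetical transcription call using the AutoModel built in model_init above.
res = asr_model.generate(
    input="/tmp/example.wav",   # illustrative path
    cache={},
    language="auto",            # or "zh", "yue", "en"
    use_itn=True,
    batch_size_s=60,
)
text = rich_transcription_postprocess(res[0]["text"])
print(text)
```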
@@ -59,12 +59,13 @@ class Chat(Blackbox):
        chroma_embedding_model = settings.get("chroma_embedding_model")
        chroma_collection_id = settings.get("chroma_collection_id")
        chroma_response = ''
        system_prompt = settings.get('system_prompt')
        system_prompt = settings.get('system_prompt',"")
        user_prompt_template = settings.get('user_prompt_template')
        user_stream = settings.get('stream')
        user_websearch = settings.get('websearch')
        user_thinking = settings.get('thinking', False)

        llm_model = "vllm"
        llm_model = "llm"

        if user_context == None:
            user_context = []
@@ -106,9 +107,9 @@ class Chat(Blackbox):

        if user_model_url is None or user_model_url.isspace() or user_model_url == "":
            if llm_model != "vllm":
                user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
                user_model_url = "http://localhost:23333/v1/chat/completions"
            else:
                user_model_url = "http://10.6.80.94:8000/v1/completions"
                user_model_url = "http://localhost:8000/v1/completions"

        if user_model_key is None or user_model_key.isspace() or user_model_key == "":
            if llm_model != "vllm":
@@ -120,16 +121,16 @@ class Chat(Blackbox):
            chroma_response = self.chroma_query(user_question, settings)
            print("1.Chroma_response: \n", chroma_response)

        if chroma_response:
        if chroma_response and isinstance(chroma_response, str):
            if chroma_collection_id == 'boss':
                user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE(目标) #\n回答游客的提问。\n# STYLE(风格)#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口,澳门博维集团董事长,周家俊,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”"
            elif chroma_collection_id == 'g2e' or chroma_collection_id == 'kiki':
                user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE(目标) #\n回答游客的提问。\n# STYLE(风格)#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅,博维,康普可可,琪琪,G2E,RELX,BO VISION相关内容,若遇到其他提问则回答:“对不起,我无法回答此问题哦。”"
            print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
            user_question = user_prompt_template + "问题: " + user_question + "。检索内容: " + chroma_response + "。"
            user_question = user_prompt_template + "\n\n问题: " + user_question + "\n\n检索内容: " + chroma_response
        else:
            if llm_model != "vllm":
                user_question = user_prompt_template + "问题: " + user_question + "。"
                user_question = user_prompt_template + "\n\n问题: " + user_question
            else:
                user_question = user_question

@@ -249,6 +250,7 @@ class Chat(Blackbox):
                "presence_penalty": str(user_presence_penalty),
                "stop": str(user_stop),
                "stream": user_stream,
                "chat_template_kwargs": {"enable_thinking": user_thinking},
            }
        else:
            chat_inputs={
@@ -267,6 +269,7 @@ class Chat(Blackbox):
                "presence_penalty":float( user_presence_penalty),
                # "stop": user_stop,
                "stream": user_stream,
                "chat_template_kwargs": {"enable_thinking": user_thinking},
            }

        # # Get the current timestamp
@@ -308,6 +311,7 @@ class Chat(Blackbox):
                except json.JSONDecodeError:
                    # print("---- Error in JSON parsing ----")  # print the error message
                    continue  # keep processing the next chunk until parsing succeeds
            # yield "[Response END]"

        else:
            print("*"*90)
@@ -343,4 +347,4 @@ class Chat(Blackbox):
            return EventSourceResponse(self.processing(prompt, context, setting))
        else:
            response_content = "".join(self.processing(prompt, context, setting))
            return JSONResponse(content={"response": response_content}, status_code=status.HTTP_200_OK)
            return JSONResponse(content={"response": response_content}, status_code=status.HTTP_200_OK)
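
The `chat_inputs` dict assembled above is posted to an OpenAI-compatible chat-completions endpoint and the streamed SSE chunks are parsed with `json.JSONDecodeError` retries. A hedged sketch of that request is shown below; the model name, URL and SSE framing are assumptions, only the payload keys (including the new `chat_template_kwargs`) come from this diff.

```python
import json
import requests

# Illustrative request against the default chat-completions URL used above.
chat_inputs = {
    "model": "your-model-name",  # assumption: whatever the serving endpoint exposes
    "messages": [{"role": "user", "content": "你好"}],
    "stream": True,
    "chat_template_kwargs": {"enable_thinking": False},
}
with requests.post("http://localhost:23333/v1/chat/completions",
                   json=chat_inputs, stream=True, timeout=60) as r:
    for line in r.iter_lines():
        if line and line.startswith(b"data: ") and line != b"data: [DONE]":
            try:
                chunk = json.loads(line[len(b"data: "):])
            except json.JSONDecodeError:
                continue  # same recovery strategy as the handler above
            print(chunk["choices"][0]["delta"].get("content", ""), end="")
```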
@@ -12,22 +12,29 @@ from ..log.logging_time import logging_time
import re
from sentence_transformers import CrossEncoder

from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf

logger = logging.getLogger
DEFAULT_COLLECTION_ID = "123"

from injector import singleton
@singleton
class ChromaQuery(Blackbox):

    def __init__(self, *args, **kwargs) -> None:
        # config = read_yaml(args[0])
        # load chromadb and embedding model
        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)

        path = PathConf(Configuration())
        self.model_path = Path(path.chroma_rerank_embedding_model)

        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-large-zh-v1.5"), device = "cuda:0")
        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-small-en-v1.5"), device = "cuda:0")
        self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-m3"), device = "cuda:0")
        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
        # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
        self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
        self.reranker_model_1 = CrossEncoder(str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device = "cuda")

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)
@@ -57,10 +64,10 @@ class ChromaQuery(Blackbox):
            return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")

        if chroma_host is None or chroma_host.isspace() or chroma_host == "":
            chroma_host = "10.6.44.141"
            chroma_host = "localhost"

        if chroma_port is None or chroma_port.isspace() or chroma_port == "":
            chroma_port = "7000"
@@ -72,7 +79,7 @@ class ChromaQuery(Blackbox):
            chroma_n_results = 10

        # load client and embedding model from init
        if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
        if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
            client = self.client_1
        else:
            try:
@@ -80,11 +87,11 @@ class ChromaQuery(Blackbox):
            except:
                return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)

        if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
        if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
            embedding_model = self.embedding_model_1
        elif re.search(r"/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
        elif re.search(str(self.model_path / "bge-small-en-v1.5"), chroma_embedding_model):
            embedding_model = self.embedding_model_2
        elif re.search(r"/Workspace/Models/BAAI/bge-m3", chroma_embedding_model):
        elif re.search(str(self.model_path / "bge-m3"), chroma_embedding_model):
            embedding_model = self.embedding_model_3
        else:
            try:

@@ -123,7 +130,7 @@ class ChromaQuery(Blackbox):
            final_result = str(results["documents"])

        if chroma_reranker_model:
            if re.search(r"/Workspace/Models/BAAI/bge-reranker-v2-m3", chroma_reranker_model):
            if re.search(str(self.model_path / "bge-reranker-v2-m3"), chroma_reranker_model):
                reranker_model = self.reranker_model_1
            else:
                try:
@@ -8,7 +8,7 @@ import requests
import json

from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader, UnstructuredPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
@@ -21,6 +21,10 @@ import logging
from ..log.logging_time import logging_time
import re

from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf

logger = logging.getLogger
DEFAULT_COLLECTION_ID = "123"

@@ -31,9 +35,13 @@ class ChromaUpsert(Blackbox):
    def __init__(self, *args, **kwargs) -> None:
        # config = read_yaml(args[0])
        # load embedding model
        self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})

        path = PathConf(Configuration())
        self.model_path = Path(path.chroma_rerank_embedding_model)

        self.embedding_model_1 = SentenceTransformerEmbeddings(model_name=str(self.model_path / "bge-large-zh-v1.5"), model_kwargs={"device": "cuda"})
        # load chroma db
        self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
        self.client_1 = chromadb.HttpClient(host='localhost', port=7000)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)
@@ -51,45 +59,44 @@ class ChromaUpsert(Blackbox):

        # # chroma_query settings
        if "settings" in settings:
            chroma_embedding_model = settings["settings"].get("chroma_embedding_model")
            chroma_host = settings["settings"].get("chroma_host")
            chroma_port = settings["settings"].get("chroma_port")
            chroma_collection_id = settings["settings"].get("chroma_collection_id")
            chroma_host = settings["settings"].get("chroma_host", "localhost")
            chroma_port = settings["settings"].get("chroma_port", "7000")
            chroma_collection_id = settings["settings"].get("chroma_collection_id", DEFAULT_COLLECTION_ID)
            user_chunk_size = settings["settings"].get("chunk_size", 256)
            user_chunk_overlap = settings["settings"].get("chunk_overlap", 10)
            user_separators = settings["settings"].get("separators", ["\n\n"])
        else:
            chroma_embedding_model = settings.get("chroma_embedding_model")
            chroma_host = settings.get("chroma_host")
            chroma_port = settings.get("chroma_port")
            chroma_collection_id = settings.get("chroma_collection_id")
            chroma_host = settings.get("chroma_host", "localhost")
            chroma_port = settings.get("chroma_port", "7000")
            chroma_collection_id = settings.get("chroma_collection_id", DEFAULT_COLLECTION_ID)
            user_chunk_size = settings.get("chunk_size", 256)
            user_chunk_overlap = settings.get("chunk_overlap", 10)
            user_separators = settings.get("separators", ["\n\n"])

        if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
            chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"

        if chroma_host is None or chroma_host.isspace() or chroma_host == "":
            chroma_host = "10.6.82.192"

        if chroma_port is None or chroma_port.isspace() or chroma_port == "":
            chroma_port = "8000"

        if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
            chroma_collection_id = "g2e"
            chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")

        # load client and embedding model from init
        if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
        if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
            client = self.client_1
        else:
            client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
        print(f"chroma_embedding_model: {chroma_embedding_model}")
        if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
        if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
            embedding_model = self.embedding_model_1
        else:
            embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda:0")
            embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, model_kwargs={"device": "cuda"})

        response_file =''
        response_string = ''
        if file is not None:
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=user_chunk_size, chunk_overlap=user_chunk_overlap, separators=user_separators)

            file_type = file.split(".")[-1]
            print("file_type: ",file_type)
            if file_type == "pdf":
                loader = PyPDFLoader(file)
                loader = UnstructuredPDFLoader(file)
            elif file_type == "txt":
                loader = TextLoader(file)
            elif file_type == "csv":
@@ -102,9 +109,10 @@ class ChromaUpsert(Blackbox):
                loader = Docx2txtLoader(file)
            elif file_type == "xlsx":
                loader = UnstructuredExcelLoader(file)
            elif file_type == "md":
                loader = UnstructuredMarkdownLoader(file, mode="single", strategy="fast")

            documents = loader.load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)

            docs = text_splitter.split_documents(documents)

@@ -148,13 +156,13 @@ class ChromaUpsert(Blackbox):
        if user_text is not None and user_text_ids is None:
            return JSONResponse(content={"error": "text_ids is required when text is provided"}, status_code=status.HTTP_400_BAD_REQUEST)

        if user_file is not None:
        if user_file is not None and user_file.size != 0:
            pdf_bytes = await user_file.read()

            custom_filename = user_file.filename
            # get the system's temporary directory path
            safe_filename = os.path.join(tempfile.gettempdir(), os.path.basename(custom_filename))

            print("file_path", safe_filename)
            with open(safe_filename, "wb") as f:
                f.write(pdf_bytes)
        else:
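
The loader dispatch above grows one `elif` branch per file extension. A table-driven variant is sketched below; it reuses the same loader classes that this hunk imports, and the structure (dict plus helper function) is only an illustration, not code from the repository.

```python
# Sketch of a dict-based loader dispatch equivalent to the if/elif chain above.
LOADERS = {
    "pdf": UnstructuredPDFLoader,
    "txt": TextLoader,
    "csv": CSVLoader,
    "docx": Docx2txtLoader,
    "xlsx": UnstructuredExcelLoader,
    "md": lambda f: UnstructuredMarkdownLoader(f, mode="single", strategy="fast"),
}

def load_documents(file: str):
    file_type = file.split(".")[-1].lower()
    loader_factory = LOADERS.get(file_type)
    if loader_factory is None:
        raise ValueError(f"Unsupported file type: {file_type}")
    return loader_factory(file).load()
```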
@@ -10,14 +10,15 @@ from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
from ..configuration import SovitsConf
from ..configuration import PathConf

from injector import inject
from injector import singleton

import sys,os
sys.path.append('/Workspace/CosyVoice')
sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
# from cosyvoice.utils.file_utils import load_wav, speed_change

from pathlib import Path


import soundfile as sf
import pyloudnorm as pyln
@@ -33,6 +34,7 @@ import numpy as np

from pydub import AudioSegment
import subprocess
import re

def set_all_random_seed(seed):
    random.seed(seed)
@@ -99,7 +101,18 @@ class TTS(Blackbox):
        print('1.#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')

    @logging_time(logger=logger)
    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf, path: PathConf) -> None:

        cosy_path = Path(path.cosyvoice_path)
        cosy_model_path = Path(path.cosyvoice_model_path)

        sys.path.append(str(cosy_path))
        sys.path.append(str(cosy_path / "third_party/Matcha-TTS"))

        from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
        from cosyvoice.utils.file_utils import load_wav  # , speed_change

        self.cosyvoice_speed = cosyvoice_config.speed
        self.cosyvoice_device = cosyvoice_config.device
        self.cosyvoice_language = cosyvoice_config.language
@@ -107,12 +120,13 @@ class TTS(Blackbox):
        self.cosyvoice_url = ''
        self.cosyvoice_mode = cosyvoice_config.mode
        self.cosyvoicetts = None
        self.prompt_speech_16k = None
        # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
        if self.cosyvoice_mode == 'local':
            # self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
            self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
            # self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)

            # self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
            self.cosyvoicetts = CosyVoice2(str(cosy_model_path / "CosyVoice2-0.5B"), load_jit=True, load_trt=False)
            self.prompt_speech_16k = load_wav("./Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav", 16000)

        else:
            self.cosyvoice_url = cosyvoice_config.url
@@ -149,16 +163,31 @@ class TTS(Blackbox):


    @inject
    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict, path: PathConf) -> None:
        self.tts_service = TTService("yunfeineo")
        self.melo_model_init(melo_config)
        self.cosyvoice_model_init(cosyvoice_config)
        self.cosyvoice_model_init(cosyvoice_config, path)
        self.sovits_model_init(sovits_config)
        self.audio_dir = "audio_files"  # directory where audio files are stored

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def filter_invalid_chars(self,text):
        """Filter out invalid characters (including byte streams)."""
        invalid_keywords = ["data:", "\n", "\r", "\t", " "]

        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='ignore')

        for keyword in invalid_keywords:
            text = text.replace(keyword, "")

        # Remove all Latin letters (keep Chinese characters, punctuation, etc.)
        text = re.sub(r'[a-zA-Z]', '', text)

        return text.strip()

    @logging_time(logger=logger)
    def processing(self, *args, settings: dict) -> io.BytesIO:

@@ -213,7 +242,8 @@ class TTS(Blackbox):
            elif chroma_collection_id == 'boss':
                if self.cosyvoice_mode == 'local':
                    set_all_random_seed(35616313)
                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    # audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -234,12 +264,44 @@ class TTS(Blackbox):
                    set_all_random_seed(56056558)
                    print("*"*90)
                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
                    # audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
                    # for i, j in enumerate(audio):
                    #     f = io.BytesIO()
                    #     sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
                    #     f.seek(0)
                    #     print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                    #     return f.read()
                    # Print the length and structure of the audio result
                    # print(f"Total audio segments: {len(audio)}")
                    # print(f"Audio data structure: {audio}")

                    # Create an empty list to hold the NumPy arrays of all audio segments
                    all_audio_data = []

                    # Iterate over every audio segment and store it in the all_audio_data list
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
                        f.seek(0)
                        # print(f"Processing segment {i + 1}...")

                        # Print the info of each segment to make sure it is correct
                        # print(f"Segment {i + 1} shape: {j['tts_speech'].shape}")

                        # Convert the audio data directly into a NumPy array
                        audio_data = j['tts_speech'].cpu().numpy()

                        # Append each segment's audio data to the all_audio_data list
                        all_audio_data.append(audio_data[0])  # take the first channel (assumed mono)

                    # Concatenate the NumPy arrays of all segments into one complete audio array
                    combined_audio_data = np.concatenate(all_audio_data, axis=0)

                    # Write the combined audio data into a BytesIO buffer
                    f = io.BytesIO()
                    sf.write(f, combined_audio_data, 22050, format='wav')  # 22050 is the sample rate; adjust to the actual value if needed
                    f.seek(0)

                    # Return the combined audio
                    print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                    return f.read()
                    return f.read()  # return the final combined audio data
                else:
                    message = {
                        "text": text
@@ -250,7 +312,8 @@ class TTS(Blackbox):
            elif chroma_collection_id == 'boss':
                if self.cosyvoice_mode == 'local':
                    set_all_random_seed(35616313)
                    audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    # audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
                    audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -266,6 +329,7 @@ class TTS(Blackbox):
                    return response.content

        elif user_model_name == 'sovitstts':
            # text = self.filter_invalid_chars(text)
            if chroma_collection_id == 'kiki' or chroma_collection_id is None:
                if self.sovits_mode == 'local':
                    set_all_random_seed(56056558)
@@ -286,9 +350,9 @@ class TTS(Blackbox):
                        "text_split_method": self.sovits_text_split_method,
                        "batch_size": self.sovits_batch_size,
                        "media_type": self.sovits_media_type,
                        "streaming_mode": self.sovits_streaming_mode
                        "streaming_mode": user_stream
                    }
                    if user_stream:
                    if user_stream == True or str(user_stream).lower() == "true":
                        response = requests.get(self.sovits_url, params=message, stream=True)
                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
                        return response
@@ -299,23 +363,37 @@ class TTS(Blackbox):
                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))


            # elif chroma_collection_id == 'boss':
            #     if self.cosyvoice_mode == 'local':
            #         set_all_random_seed(35616313)
            #         audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
            #         for i, j in enumerate(audio):
            #             f = io.BytesIO()
            #             sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
            #             f.seek(0)
            #             print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
            #             return f.read()
            #     else:
            #         message = {
            #             "text": text
            #         }
            #         response = requests.post(self.cosyvoice_url, json=message)
            #         print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
            #         return response.content
            elif chroma_collection_id == 'boss':
                if self.sovits_mode == 'local':
                    set_all_random_seed(56056558)
                    audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
                    for i, j in enumerate(audio):
                        f = io.BytesIO()
                        sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
                        f.seek(0)
                        print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
                        return f.read()
                else:
                    message = {
                        "text": text,
                        "text_lang": self.sovits_text_lang,
                        "ref_audio_path": self.sovits_ref_audio_path,
                        "prompt_lang": self.sovits_prompt_lang,
                        "prompt_text": self.sovits_prompt_text,
                        "text_split_method": self.sovits_text_split_method,
                        "batch_size": self.sovits_batch_size,
                        "media_type": self.sovits_media_type,
                        "streaming_mode": user_stream
                    }
                    if user_stream == True or str(user_stream).lower() == "true":
                        response = requests.get(self.sovits_url, params=message, stream=True)
                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
                        return response
                    else:
                        response = requests.get(self.sovits_url, params=message)
                        print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
                        return response.content
                    print("#### SoVITS Service consume - docker : ", (time.time()-current_time))


        elif user_model_name == 'man':
@@ -359,9 +437,12 @@ class TTS(Blackbox):
        if text is None:
            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        by = self.processing(text, settings=setting)
        # return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})

        if user_stream:
        if user_stream not in (True, "True", "true"):
            return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
        print(f"tts user_stream: {type(user_stream)}")
        # import pdb; pdb.set_trace()
        if user_stream in (True, "True", "true"):
            print(f"tts user_stream22: {user_stream}")
            if by.status_code == 200:
                print("*"*90)
                def audio_stream():
@@ -405,6 +486,7 @@ class TTS(Blackbox):

        else:
            wav_filename = os.path.join(self.audio_dir, 'audio.wav')
            print("8"*90)
            with open(wav_filename, 'wb') as f:
                f.write(by)
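
When `streaming_mode` is on, the SoVITS branch above returns the raw `requests` response and `fast_api_handler` wraps it in an `audio_stream()` generator. A minimal sketch of that proxying pattern, assuming the same GET `/tts` endpoint; the chunk size and the helper name are illustrative.

```python
from fastapi.responses import StreamingResponse
import requests

def stream_sovits(url: str, params: dict) -> StreamingResponse:
    """Sketch: forward a streaming SoVITS /tts response to the caller."""
    upstream = requests.get(url, params=params, stream=True)

    def audio_stream():
        for chunk in upstream.iter_content(chunk_size=4096):  # 4096 is an illustrative chunk size
            if chunk:
                yield chunk

    return StreamingResponse(audio_stream(), media_type="audio/wav")
```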
@@ -13,13 +13,19 @@ import requests
import base64
import copy
import ast

import json

import random
from time import time


import io
from PIL import Image
from lmdeploy.serve.openai.api_client import APIClient

from openai import OpenAI


def is_base64(value) -> bool:
    try:
@@ -52,14 +58,14 @@ class VLMS(Blackbox):
        - skip_special_tokens (bool): Whether or not to remove special tokens
        in the decoding. Default to be True."""
        self.model_dict = vlm_config.urls
        self.model_url = None
        self.available_models = {}
        self.temperature: float = 0.7
        self.top_p:float = 1
        self.max_tokens: (int |None) = 512
        self.repetition_penalty: float = 1
        self.stop: (str | List[str] |None) = ['<|endoftext|>','<|im_end|>']

        self.top_k: (int) = None
        self.top_k: (int) = 40
        self.ignore_eos: (bool) = False
        self.skip_special_tokens: (bool) = True

@@ -74,7 +80,13 @@ class VLMS(Blackbox):
            "skip_special_tokens": self.skip_special_tokens,
        }


        for model, url in self.model_dict.items():
            try:
                response = requests.get(url+'/health',timeout=3)
                if response.status_code == 200:
                    self.available_models[model] = url
            except Exception as e:
                pass
    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

@@ -82,7 +94,7 @@ class VLMS(Blackbox):
        data = args[0]
        return isinstance(data, list)

    def processing(self, prompt:str | None, images:str | bytes | None, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
    def processing(self, prompt:str | None, images:str | bytes | None, settings: dict, user_context: List[dict] = None) -> str:
        """
        Args:
            prompt: a string query to the model.
@@ -94,22 +106,29 @@ class VLMS(Blackbox):
            response: a string
            history: a list
        """

        config: dict = {
            "lmdeploy_infer":True,
            "system_prompt":"",
            "vlm_model_name":"",
        }
        if settings:
            for k in settings:
            for k in list(settings.keys()):
                if k not in self.settings:
                    print("Warning: '{}' is not a supported argument and will be ignored, check the arguments {}".format(k,self.settings.keys()))
                    settings.pop(k)
                    config[k] = settings.pop(k)
            tmp = copy.deepcopy(self.settings)
            tmp.update(settings)
            settings = tmp
        else:
            settings = {}

        config['lmdeploy_infer'] = str(config['lmdeploy_infer']).strip().lower() == 'true'

        if not prompt:
            prompt = '你是一个辅助机器人,请就此图做一个简短的概括性描述,包括图中的主体物品及状态,不超过50字。' if images else '你好'

        # Transform the images into base64 format where openai format need.
        # Transform the images into base64 format where openai url)

        if images:
            if is_base64(images): # image as base64 str
                images_data = images
@@ -122,43 +141,14 @@ class VLMS(Blackbox):
                images_data = str(base64.b64encode(res.content),'utf-8')
        else:
            images_data = None
        ## AutoLoad Model
        # url = 'http://10.6.80.87:8000/' + model_name + '/'
        # data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
        # data = requests.post(url, json=data_input)
        # print(data.text)
        # return data.text

        # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
        ## Lmdeploy
        # if not user_context:
        #     user_context = []

        ## Predefine user_context only for testing
        # user_context = [{'role':'user','content':'你好,我叫康康,你是谁?'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]
        # user_context = [{
        #     'role': 'user',
        #     'content': [{
        #         'type': 'text',
        #         'text': '图中有什么,请描述一下',
        #     }, {
        #         'type': 'image_url',
        #         'image_url': {
        #             'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
        #         },
        #     }]
        # },{
        #     'role': 'assistant',
        #     'content': '图片中主要展示了一只老虎,它正在绿色的草地上休息。草地上有很多可以让人坐下的地方,而且看起来相当茂盛。背景比较模糊,可能是因为老虎的影响,让整个图片的其他部分都变得不太清晰了。'
        # }
        # ]
        ## Predefine user_context only for testing
        # user_context = [{'role':'user','content':'你好,我叫康康,你是谁?'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]

        if not user_context and config['system_prompt']: user_context = [{'role':'system','content': config['system_prompt']}]
        user_context = self.keep_last_k_images(user_context,k = 2)

        user_context = self.keep_last_k_images(user_context,k = 1)
        if self.model_url is None: self.model_url = self._get_model_url(model_name)

        api_client = APIClient(self.model_url)
        # api_client = APIClient("http://10.6.80.91:23333")
        model_name = api_client.available_models[0]
        # Reformat input into openai format to request.
        if images_data:
            messages = user_context + [{
@@ -170,8 +160,6 @@ class VLMS(Blackbox):
                'type': 'image_url',
                'image_url': { # Image two
                    'url':
                    # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
                    # './val_data/image_5.jpg'
                    f"data:image/jpeg;base64,{images_data}",
                },
                # },{ # Image one
@@ -194,40 +182,56 @@ class VLMS(Blackbox):

        responses = ''
        total_token_usage = 0 # which can be used to count the cost of a query
        for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
        model_url = self._get_model_url(config['vlm_model_name'])

        if config['lmdeploy_infer']:
            api_client = APIClient(model_url)
            model_name = api_client.available_models[0]
            for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
                                messages=messages,stream = True,
                                **settings,
                                # session_id=,
                                )):
                # Stream output
                print(item["choices"][0]["delta"]['content'],end='\n')
                yield item["choices"][0]["delta"]['content']
                responses += item["choices"][0]["delta"]['content']
                # Stream output
                yield item["choices"][0]["delta"]['content']
                responses += item["choices"][0]["delta"]['content']

                # print(item["choices"][0]["message"]['content'])
                # responses += item["choices"][0]["message"]['content']
                # total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
                # print(item["choices"][0]["message"]['content'])
                # responses += item["choices"][0]["message"]['content']
                # total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
        else:
            api_key = "EMPTY_API_KEY"
            api_client = OpenAI(api_key=api_key, base_url=model_url+'/v1')
            model_name = api_client.models.list().data[0].id
            for item in api_client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=0.8,
                    top_p=0.8,
                    stream=True):
                yield(item.choices[0].delta.content)
                responses += item.choices[0].delta.content
            # response = api_client.chat.completions.create(
            #     model=model_name,
            #     messages=messages,
            #     temperature=0.8,
            #     top_p=0.8)
            # print(response.choices[0].message.content)
            # return response.choices[0].message.content


        user_context = messages + [{'role': 'assistant', 'content': responses}]
        self.custom_print(user_context)
        # return responses, user_context
        # return responses

    def _get_model_url(self,model_name:str | None):
        available_models = {}
        for model, url in self.model_dict.items():
            try:
                response = requests.get(url,timeout=3)
                if response.status_code == 200:
                    available_models[model] = url
            except Exception as e:
                # print(e)
                pass
        if not available_models: print("There are no available running models and please check your endpoint urls.")
        if model_name and model_name in available_models:
            return available_models[model_name]
        if not self.available_models: print("There are no available running models and please check your endpoint urls.")
        if model_name and model_name in self.available_models:
            return self.available_models[model_name]
        else:
            model = random.choice(list(available_models.keys()))
            model = random.choice(list(self.available_models.keys()))
            print(f"No such model {model_name}, using {model} instead.") if model_name else print(f"Using random model {model}.")
            return available_models[model]
            return self.available_models[model]

    def _into_openai_format(self, context:List[list]) -> List[dict]:
        """
@@ -298,7 +302,6 @@ class VLMS(Blackbox):
            result.append(item)
        return result[::-1]


    def custom_print(self, user_context: list):
        result = []
        for item in user_context:
@@ -315,18 +318,23 @@ class VLMS(Blackbox):
        ## TODO: add support for multiple images and support image in form-data format
        json_request = True
        try:
            content_type = request.headers['content-type']
            content_type = request.headers.get('content-type', '')
            if content_type == 'application/json':
                data = await request.json()
            else:
            elif 'multipart/form-data' in content_type:
                data = await request.form()
                json_request = False
                json_request = False
            else:
                body = await request.body()
                data = json.loads(body.decode("utf-8"))

        except Exception as e:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)

        model_name = data.get("model_name")

        prompt = data.get("prompt")
        settings: dict = data.get('settings')

        context = data.get("context")
        if not context:
            user_context = []
@@ -337,22 +345,20 @@ class VLMS(Blackbox):
        else:
            return JSONResponse(content={"error": "context format error, should be in format of list or Openai_format"}, status_code=status.HTTP_400_BAD_REQUEST)

        if json_request:
            img_data = data.get("img_data")
        if json_request or 'multipart/form-data' not in content_type:
            img_data = data.get("img_data")
        else:
            img_data = await data.get("img_data").read()
        if settings: settings = ast.literal_eval(settings)

        if prompt is None:
            return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)

        # if model_name is None or model_name.isspace():
        #     model_name = "Qwen-VL-Chat"
        # response,_ = self.processing(prompt, img_data,settings, model_name,user_context=user_context)

        # return StreamingResponse(self.processing(prompt, img_data,settings, model_name,user_context=user_context), status_code=status.HTTP_200_OK)
        return EventSourceResponse(self.processing(prompt, img_data,settings, model_name,user_context=user_context), status_code=status.HTTP_200_OK)

        # HTTP JsonResponse
        response, history = self.processing(prompt, img_data,settings, model_name,user_context=user_context)
        # return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
        streaming_output = str(settings.get('stream',False)).strip().lower() == 'true' if settings else False
        if streaming_output:
            # return StreamingResponse(self.processing(prompt, img_data,settings, user_context=user_context), status_code=status.HTTP_200_OK)
            return EventSourceResponse(self.processing(prompt, img_data,settings, user_context=user_context), status_code=status.HTTP_200_OK)
        else:
            # HTTP JsonResponse
            output = self.processing(prompt, img_data,settings, user_context=user_context)
            response = ''.join([res for res in output])
            return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
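
For reference, a sketch of building the base64 data-URL message shape that the handler above sends to the vision model; the image path is illustrative and the prompt text is reused from the commented-out test context in this hunk.

```python
import base64

# Illustrative: turn a local image into the data-URL content block used above.
with open("example.jpg", "rb") as img:  # the path is an assumption
    images_data = base64.b64encode(img.read()).decode("utf-8")

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "图中有什么,请描述一下"},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{images_data}"}},
    ],
}]
```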
@@ -180,3 +180,17 @@ class VLMConf():
    @inject
    def __init__(self, config: Configuration) -> None:
        self.urls = config.get("vlms.urls")


@singleton
class PathConf():
    chroma_rerank_embedding_model: str
    cosyvoice_path: str
    cosyvoice_model_path: str

    @inject
    def __init__(self, config: Configuration) -> None:
        self.chroma_rerank_embedding_model = config.get("path.chroma_rerank_embedding_model")
        self.cosyvoice_path = config.get("path.cosyvoice_path")
        self.cosyvoice_model_path = config.get("path.cosyvoice_model_path")
        self.sensevoice_model_path = config.get("path.sensevoice_model_path")
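
A sketch of resolving the new `PathConf` through the injector, mirroring how server.py obtains `BlackboxFactory`; whether `Configuration` is constructible by the injector without extra bindings is an assumption here, not something this diff confirms.

```python
from injector import Injector

# Sketch: resolve PathConf the same way server.py resolves BlackboxFactory.
injector = Injector()
path_conf = injector.get(PathConf)  # assumes Configuration is auto-constructed
print(path_conf.sensevoice_model_path)
```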
102
src/log/loki_config.py
Normal file
@@ -0,0 +1,102 @@
import logging
import logging_loki
import os
from requests.auth import HTTPBasicAuth
from typing import Dict, Optional
class LokiLogger:
    """
    A class for configuring and obtaining a Loki logger.
    It can be imported and used directly from other files.

    Usage example:
        from loki_config import LokiLogger

        # Get a logger instance (using the default configuration or environment variables)
        loki_logger_instance = LokiLogger().get_logger()
        loki_logger_instance.info("这条日志会推送到Loki!")

        # Or use a custom configuration
        my_loki_logger = LokiLogger(
            url="https://your-custom-loki-url.com/loki/api/v1/push",
            username="myuser",
            password="mypassword",
            tags={"app_name": "my-custom-app", "environment": "prod"},
            level=logging.DEBUG
        ).get_logger()
        my_loki_logger.debug("这条调试日志也推送到Loki。")
    """
    def __init__(self,
                 url: Optional[str] = None,
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 tags: Optional[Dict[str, str]] = None,
                 level: int = logging.INFO,
                 logger_name: str = "app_loki_logger"):
        """
        Initialize the LokiLogger.

        Args:
            url (str, optional): Loki push URL. Defaults to the LOKI_URL environment variable,
                falling back to "https://loki.bwgdi.com/loki/api/v1/push" if it is not set.
            username (str, optional): Loki auth username. Defaults to the LOKI_USERNAME environment variable.
            password (str, optional): Loki auth password. Defaults to the LOKI_PASSWORD environment variable.
            tags (Dict[str, str], optional): Default labels sent to Loki, e.g. {"application": "my-app"}.
                If not provided, defaults to {"application": "default-app", "source": "python-app"}.
            level (int, optional): Log level (e.g. logging.INFO, logging.DEBUG). Defaults to logging.INFO.
            logger_name (str, optional): Name of the logger. Defaults to "app_loki_logger".
        """
        # Read the configuration from environment variables when it is not passed as arguments
        self._url = url if url else os.getenv("LOKI_URL", "https://loki.bwgdi.com/loki/api/v1/push")
        self._username = username if username else os.getenv("LOKI_USERNAME",'admin')
        self._password = password if password else os.getenv("LOKI_PASSWORD",'admin')
        self._tags = tags if tags is not None else {"app": "jarvis", "env": "dev", "location": "gdi", "layer": "models"}

        # Get or create the logger with the given name
        self._logger = logging.getLogger(logger_name)
        self._logger.setLevel(level)

        for handler in self._logger.handlers[:]:
            self._logger.removeHandler(handler)
        # Check whether a LokiHandler already exists, to avoid adding it twice and producing duplicate logs;
        # across multiple files or repeated initialization, the same logger_name can return the same logger instance
        if not any(isinstance(h, logging_loki.LokiHandler) for h in self._logger.handlers):
            try:
                auth = None
                if self._username and self._password:
                    auth = HTTPBasicAuth(self._username, self._password)

                loki_handler = logging_loki.LokiHandler(
                    url=self._url,
                    tags=self._tags,
                    version="1",  # the Loki API version is usually 1
                    auth=auth
                )
                self._logger.addHandler(loki_handler)
                # Also add a StreamHandler so logs remain visible on the console during local debugging
                if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
                    self._logger.addHandler(logging.StreamHandler())
                self._logger.info(f"LokiLogger: 已成功配置Loki日志处理器,目标地址:{self._url}")
            except Exception as e:
                # If the Loki configuration fails, make sure a StreamHandler still writes logs to the console
                if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
                    self._logger.addHandler(logging.StreamHandler())
                self._logger.error(f"LokiLogger: 配置LokiHandler失败:{e}。将回退到控制台日志记录。", exc_info=True)
        else:
            # If a LokiHandler already exists, only make sure a StreamHandler exists as well
            if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
                self._logger.addHandler(logging.StreamHandler())
            self._logger.debug(f"LokiLogger: '{logger_name}' 记录器已配置LokiHandler,跳过重新配置。")


    def get_logger(self) -> logging.Logger:
        """
        Return the configured logger instance.
        """
        self._logger.debug("LokiLogger: 获取已配置的日志记录器实例。")
        return self._logger

# Optional: if you want a default global LokiLogger instance,
# you can instantiate it here and import `loki_logger` directly from this module elsewhere.
# For example:
# DEFAULT_LOKI_LOGGER_INSTANCE = LokiLogger().get_logger()
# Then in other files: `from loki_config import DEFAULT_LOKI_LOGGER_INSTANCE as logger`
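
As a usage note: beyond the default tags configured above, python-logging-loki also picks up per-record labels from the log call's `extra` dict. A small sketch; the extra tag names are illustrative, only the default tags match this commit's configuration.

```python
import logging
from src.log.loki_config import LokiLogger

logger = LokiLogger(tags={"app": "jarvis", "env": "dev"}, level=logging.INFO).get_logger()

# Per-record labels via extra={"tags": ...}; "blackbox" is an illustrative label name.
logger.info("TTS request served", extra={"tags": {"blackbox": "tts"}})
```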
@@ -19,7 +19,7 @@ import logging
logging.basicConfig(level=logging.INFO)

dirbaspath = __file__.split("\\")[1:-1]
dirbaspath= "/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
dirbaspath= "./src/tts" + "/".join(dirbaspath)
config = {
    'ayaka': {
        'cfg': dirbaspath + '/models/ayaka.json',