Merge main into Tom

This commit is contained in:
tom
2025-08-18 17:37:07 +08:00
13 changed files with 524 additions and 192 deletions

View File

@@ -45,14 +45,54 @@ log:
  time_format: "%Y-%m-%d %H:%M:%S"
  filename: "D:/Workspace/Logging/jarvis/jarvis-models.log"
loki:
  url: "https://loki.bwgdi.com/loki/api/v1/push"
  labels:
    app: jarvis
    env: dev
    location: "gdi"
    layer: models
melotts:
  mode: local # or docker
-  url: http://10.6.44.16:18080/convert/tts
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
-  device: 'cuda'
  device: 'cuda:0'
  language: 'ZH'
  speaker: 'ZH'
cosyvoicetts:
  mode: local # or docker
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
  device: 'cuda:0'
  language: '粤语女'
  speaker: 'ZH'
sovitstts:
  mode: docker
  url: http://10.6.80.90:9880/tts
  speed: 0.9
  device: 'cuda:0'
  language: 'ZH'
  speaker: 'ZH'
  text_lang: "yue"
  ref_audio_path: "output/slicer_opt/Ricky-Wong/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav"
  prompt_lang: "yue"
  prompt_text: "你失敗咗點算啊?你而家安安穩穩,點解要咁樣做呢?"
  text_split_method: "cut5"
  batch_size: 1
  media_type: "wav"
  streaming_mode: True
sensevoiceasr:
  mode: local # or docker
  url: http://10.6.44.141:18080/convert/tts
  speed: 0.9
  device: 'cuda:0'
  language: '粤语女'
  speaker: 'ZH'
tesou:
  url: http://120.196.116.194:48891/chat/
@@ -91,5 +131,14 @@ blackbox:
  lazyloading: true
vlms:
-  url: http://10.6.80.87:23333
  urls:
    qwen_vl: http://10.6.80.87:8000
    qwen2_vl: http://10.6.80.87:23333
    qwen2_vl_72b: http://10.6.80.91:23333
path:
  chroma_rerank_embedding_model: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI
  cosyvoice_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/Workspace/CosyVoice
  cosyvoice_model_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/Voice/CosyVoice/pretrained_models
  sensevoice_model_path: /media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/Voice/SenseVoice/SenseVoiceSmall
```
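For reference, a minimal sketch of reading the new `loki` and `path` blocks from `.env.yaml`; the dotted-key helper below is illustrative only (the project resolves these through its own `Configuration` class), while the key names come from the config above.

```python
import yaml

with open(".env.yaml", "r") as f:
    config = yaml.safe_load(f)

# Illustrative helper for dotted lookups such as "path.cosyvoice_model_path";
# not the project's Configuration implementation.
def get(cfg: dict, dotted_key: str, default=None):
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

print(get(config, "loki.labels"))
print(get(config, "path.sensevoice_model_path"))
```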

View File

@@ -10,3 +10,4 @@ langchain==0.1.17
langchain-community==0.0.36
sentence-transformers==2.7.0
openai
python-logging-loki
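The new `python-logging-loki` dependency backs `src/log/loki_config.py` further down; a minimal standalone sketch of the handler it provides (URL and tags here are placeholders):

```python
import logging
import logging_loki

handler = logging_loki.LokiHandler(
    url="https://loki.example.com/loki/api/v1/push",  # placeholder push endpoint
    tags={"app": "jarvis", "env": "dev"},
    version="1",
)
logger = logging.getLogger("demo")
logger.addHandler(handler)
# Per-record labels can be attached through the "tags" key in extra.
logger.error("TTS request failed", extra={"tags": {"layer": "models"}})
```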

View File

@@ -7,20 +7,23 @@ from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import time
from pathlib import Path
path = Path("/media/verachen/e0f7a88c-ad43-4736-8829-4d06e5ed8f4f/model/BAAI")
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
loader = TextLoader("./RAG_boss.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True, separators=['\n', '\n\n'])
docs = text_splitter.split_documents(documents)
print("len(docs)", len(docs))
ids = ["粤语语料"+str(i) for i in range(len(docs))]
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
embedding_model = SentenceTransformerEmbeddings(model_name=str(path / "bge-m3"), model_kwargs={"device": "cuda:0"})
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
client = chromadb.HttpClient(host="localhost", port=7000)
-id = "kiki"
id = "boss2"
# client.delete_collection(id)
# upsert vectors (ids that already exist are updated)
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
@@ -28,13 +31,13 @@ db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, c
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(path / "bge-m3"), device = "cuda:0")
-client = chromadb.HttpClient(host='10.6.44.141', port=7000)
client = chromadb.HttpClient(host='localhost', port=7000)
collection = client.get_collection(id, embedding_function=embedding_model)
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
reranker_model = CrossEncoder(str(path / "bge-reranker-v2-m3"), max_length=512, device = "cuda:0")
# while True:
#     usr_question = input("\n 请输入问题: ")
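A minimal sketch of the query step the commented-out loop above leads into, reusing the `collection` and `reranker_model` created earlier in this script (the question is just an example):

```python
usr_question = "周家俊是谁?"
# Retrieve candidates from Chroma, then let the cross-encoder re-order them.
results = collection.query(query_texts=[usr_question], n_results=5)
docs = results["documents"][0]
scores = reranker_model.predict([(usr_question, d) for d in docs])
best_doc = max(zip(scores, docs), key=lambda x: x[0])[1]
print(best_doc)
```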

View File

@@ -4,8 +4,12 @@ from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse, Response, StreamingResponse
from src.blackbox.blackbox_factory import BlackboxFactory
from src.log.loki_config import LokiLogger
import logging
from fastapi.middleware.cors import CORSMiddleware
from injector import Injector
import yaml
app = FastAPI()
@@ -20,16 +24,64 @@ app.add_middleware(
injector = Injector()
blackbox_factory = injector.get(BlackboxFactory)
with open(".env.yaml", "r") as f:
config = yaml.safe_load(f)
logger = LokiLogger(
# url=config["loki"]["url"],
# username=config["loki"]["username"],
# password=config["loki"]["password"],
tags=config["loki"]["labels"],
logger_name=__name__,
level=logging.DEBUG
).get_logger()
@app.exception_handler(Exception)
async def catch_all_exceptions(request: Request, exc: Exception):
"""
Catch any exception that is not handled inside the ASGI application and log it to Loki.
"""
logger.error(
f"Unhandled exception in ASGI application for path: {request.url.path}, method: {request.method} - {exc}",
exc_info=True,  # must be True to capture the full stack trace
extra={
"path": request.url.path,
"method": request.method,
"error_type": type(exc).__name__,
"error_message": str(exc)
}
)
return JSONResponse(
status_code=500,
content={"message": "Internal Server Error", "detail": "An unexpected error occurred."}
)
@app.post("/") @app.post("/")
@app.get("/") @app.get("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
if not blackbox_name: if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
try: try:
box = blackbox_factory.get_blackbox(blackbox_name) box = blackbox_factory.get_blackbox(blackbox_name)
except ValueError: except ValueError as e:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) logger.error(f"获取 blackbox 失败: {blackbox_name} - {e}", exc_info=True)
return await box.fast_api_handler(request) return JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
try:
response = await box.fast_api_handler(request)
# If the returned status code is >= 400, treat it as an error response and log it
if response.status_code >= 400:
try:
decoded_content = response.body.decode('utf-8', errors='ignore')
logger.error(f"Blackbox 返回错误响应: URL={request.url}, Method={request.method}, Status={response.status_code}, Content={decoded_content}")
except Exception as body_exc:
logger.error(f"Blackbox {blackbox_name} 返回错误响应: URL={request.url}, Method={request.method}, Status={response.status_code}, 无法解析响应内容: {body_exc}")
return response
except Exception as e:
logger.error(f"Blackbox 内部抛出异常: URL={request.url}, Method={request.method}, 异常报错: {type(e).__name__}: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"message": "Internal Server Error", "detail": "An unexpected error occurred."}
)
# @app.get("/audio/{filename}") # @app.get("/audio/{filename}")
# async def serve_audio(filename: str): # async def serve_audio(filename: str):
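For context, a minimal sketch of calling this endpoint and triggering the new error logging; the host, port, blackbox name and payload shape are assumptions, not taken from this diff:

```python
import requests

resp = requests.post(
    "http://localhost:8000/?blackbox_name=chat",                 # assumed local deployment
    json={"question": "你好", "settings": {"stream": False}},     # assumed payload shape
    timeout=60,
)
# Any 4xx/5xx here is what the handlers above log to Loki before responding.
print(resp.status_code, resp.text[:200])
```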

View File

@@ -12,6 +12,9 @@ from funasr.utils.postprocess_utils import rich_transcription_postprocess
from .blackbox import Blackbox
from injector import singleton, inject
from pathlib import Path
from ..configuration import PathConf
import tempfile
import json
@@ -44,12 +47,13 @@ class ASR(Blackbox):
speaker: str
@logging_time(logger=logger)
-def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
def model_init(self, sensevoice_config: SenseVoiceConf, path: PathConf) -> None:
config = read_yaml(".env.yaml")
self.paraformer = RapidParaformer(config)
sense_model_path = Path(path.sensevoice_model_path)
-model_dir = "/model/Voice/SenseVoice/SenseVoiceSmall"
model_dir = str(sense_model_path)
self.speed = sensevoice_config.speed
self.device = sensevoice_config.device
@@ -65,7 +69,7 @@ class ASR(Blackbox):
self.asr = AutoModel(
model=model_dir,
trust_remote_code=True,
-remote_code= "/Workspace/SenseVoice/model.py",
remote_code= "../SenseVoice/model.py",
vad_model="fsmn-vad",
vad_kwargs={"max_single_segment_time": 30000},
device=self.device,
@@ -76,8 +80,8 @@ class ASR(Blackbox):
logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...')
@inject
-def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None:
def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict, path: PathConf) -> None:
-self.model_init(sensevoice_config)
self.model_init(sensevoice_config, path)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
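A minimal sketch of the SenseVoice model the ASR service now loads from `path.sensevoice_model_path`, following standard FunASR usage; the model path and audio file below are placeholders:

```python
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model = AutoModel(
    model="/path/to/SenseVoiceSmall",          # placeholder for path.sensevoice_model_path
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device="cuda:0",
)
res = model.generate(input="sample.wav", language="auto", use_itn=True)
print(rich_transcription_postprocess(res[0]["text"]))
```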

View File

@@ -59,12 +59,13 @@ class Chat(Blackbox):
chroma_embedding_model = settings.get("chroma_embedding_model")
chroma_collection_id = settings.get("chroma_collection_id")
chroma_response = ''
-system_prompt = settings.get('system_prompt')
system_prompt = settings.get('system_prompt', "")
user_prompt_template = settings.get('user_prompt_template')
user_stream = settings.get('stream')
user_websearch = settings.get('websearch')
user_thinking = settings.get('thinking', False)
-llm_model = "vllm"
llm_model = "llm"
if user_context == None:
user_context = []
@@ -106,9 +107,9 @@
if user_model_url is None or user_model_url.isspace() or user_model_url == "":
if llm_model != "vllm":
-user_model_url = "http://10.6.80.75:23333/v1/chat/completions"
user_model_url = "http://localhost:23333/v1/chat/completions"
else:
-user_model_url = "http://10.6.80.94:8000/v1/completions"
user_model_url = "http://localhost:8000/v1/completions"
if user_model_key is None or user_model_key.isspace() or user_model_key == "":
if llm_model != "vllm":
@@ -120,16 +121,16 @@
chroma_response = self.chroma_query(user_question, settings)
print("1.Chroma_response: \n", chroma_response)
-if chroma_response:
if chroma_response and isinstance(chroma_response, str):
if chroma_collection_id == 'boss':
user_prompt_template = "# 你的身份 #\n你是周家俊,澳门博维集团董事长。你擅长脑机接口回答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n成熟稳重,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n # 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与脑机接口澳门博维集团董事长周家俊G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
elif chroma_collection_id == 'g2e' or chroma_collection_id == 'kiki':
user_prompt_template = "# 你的身份 #\n你是琪琪,你是康普可可的代言人,由博维开发。你擅长澳门文旅问答。\n# OBJECTIVE目标 #\n回答游客的提问。\n# STYLE风格#\n像少女一般开朗活泼,回答简练。不要分条。\n# 回答方式 #\n首先自行判断下方问题与检索内容是否相关,若相关则根据检索内容总结概括相关信息进行回答;若检索内容与问题无关,则根据自身知识进行回答。\n# 回答 #\n如果检索内容与问题相关,则直接从检索内容中提炼出问题所需的信息。如果检索内容与问题不相关,则不参考检索内容,直接根据常识尝试回答问题,或者则回答:“对不起,我无法回答此问题哦。”\n# 回答限制 #\n回答内容限制总结在50字内。\n回答内容出不要出现“相关”等字眼,不要乱说或者多说,回答的内容需要与问题对应。常见的对话可以不采用检索内容,根据人物设定,直接进行回答。\n只回答与澳门文旅博维康普可可琪琪G2ERELXBO VISION相关内容若遇到其他提问则回答“对不起我无法回答此问题哦。”"
print(f"user_prompt_template: {type(user_prompt_template)}, user_question: {type(user_question)}, chroma_response: {type(chroma_response)}")
-user_question = user_prompt_template + "问题: " + user_question + "检索内容: " + chroma_response + ""
user_question = user_prompt_template + "\n\n问题: " + user_question + "\n\n检索内容: " + chroma_response
else:
if llm_model != "vllm":
-user_question = user_prompt_template + "问题: " + user_question + ""
user_question = user_prompt_template + "\n\n问题: " + user_question
else:
user_question = user_question
@@ -249,6 +250,7 @@
"presence_penalty": str(user_presence_penalty),
"stop": str(user_stop),
"stream": user_stream,
"chat_template_kwargs": {"enable_thinking": user_thinking},
}
else:
chat_inputs={
@@ -267,6 +269,7 @@
"presence_penalty": float(user_presence_penalty),
# "stop": user_stop,
"stream": user_stream,
"chat_template_kwargs": {"enable_thinking": user_thinking},
}
# # get the current timestamp
@@ -308,6 +311,7 @@
except json.JSONDecodeError:
# print("---- Error in JSON parsing ----")  # print the error message
continue  # keep processing the next chunk until parsing succeeds
# yield "[Response END]"
else:
print("*"*90)

View File

@@ -12,22 +12,29 @@ from ..log.logging_time import logging_time
import re
from sentence_transformers import CrossEncoder
from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf
logger = logging.getLogger
DEFAULT_COLLECTION_ID = "123"
from injector import singleton
@singleton
class ChromaQuery(Blackbox):
def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0])
# load chromadb and embedding model
-self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:0")
-self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:0")
-self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
path = PathConf(Configuration())
self.model_path = Path(path.chroma_rerank_embedding_model)
self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-large-zh-v1.5"), device = "cuda:0")
self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-small-en-v1.5"), device = "cuda:0")
self.embedding_model_3 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=str(self.model_path / "bge-m3"), device = "cuda:0")
self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
# self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
-self.reranker_model_1 = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda")
self.reranker_model_1 = CrossEncoder(str(self.model_path / "bge-reranker-v2-m3"), max_length=512, device = "cuda")
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@@ -57,10 +64,10 @@ class ChromaQuery(Blackbox):
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-chroma_host = "10.6.44.141"
chroma_host = "localhost"
if chroma_port is None or chroma_port.isspace() or chroma_port == "":
chroma_port = "7000"
@@ -72,7 +79,7 @@ class ChromaQuery(Blackbox):
chroma_n_results = 10
# load client and embedding model from init
-if re.search(r"10.6.44.141", chroma_host) and re.search(r"7000", chroma_port):
if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1
else:
try:
@@ -80,11 +87,11 @@ class ChromaQuery(Blackbox):
except:
return JSONResponse(content={"error": "chroma client not found"}, status_code=status.HTTP_400_BAD_REQUEST)
-if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
embedding_model = self.embedding_model_1
-elif re.search(r"/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
elif re.search(str(self.model_path / "bge-small-en-v1.5"), chroma_embedding_model):
embedding_model = self.embedding_model_2
-elif re.search(r"/Workspace/Models/BAAI/bge-m3", chroma_embedding_model):
elif re.search(str(self.model_path / "bge-m3"), chroma_embedding_model):
embedding_model = self.embedding_model_3
else:
try:
@@ -123,7 +130,7 @@ class ChromaQuery(Blackbox):
final_result = str(results["documents"])
if chroma_reranker_model:
-if re.search(r"/Workspace/Models/BAAI/bge-reranker-v2-m3", chroma_reranker_model):
if re.search(str(self.model_path / "bge-reranker-v2-m3"), chroma_reranker_model):
reranker_model = self.reranker_model_1
else:
try:
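A minimal sketch of the rerank step this class performs with `bge-reranker-v2-m3`; the model path and texts are placeholders:

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("/path/to/BAAI/bge-reranker-v2-m3", max_length=512, device="cuda")

query = "澳门有什么好玩的?"
docs = ["澳门旅游塔可以蹦极。", "今天的天气是晴天。"]
scores = reranker.predict([(query, d) for d in docs])
# Highest score first.
ranked = [d for _, d in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)]
print(ranked)
```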

View File

@@ -8,7 +8,7 @@ import requests
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
-from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader, UnstructuredHTMLLoader, JSONLoader, Docx2txtLoader, UnstructuredExcelLoader, UnstructuredPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
@@ -21,6 +21,10 @@ import logging
from ..log.logging_time import logging_time
import re
from pathlib import Path
from ..configuration import Configuration
from ..configuration import PathConf
logger = logging.getLogger
DEFAULT_COLLECTION_ID = "123"
@@ -31,9 +35,13 @@ class ChromaUpsert(Blackbox):
def __init__(self, *args, **kwargs) -> None:
# config = read_yaml(args[0])
# load embedding model
-self.embedding_model_1 = SentenceTransformerEmbeddings(model_name="/Workspace/Models/BAAI/bge-large-zh-v1.5", model_kwargs={"device": "cuda"})
path = PathConf(Configuration())
self.model_path = Path(path.chroma_rerank_embedding_model)
self.embedding_model_1 = SentenceTransformerEmbeddings(model_name=str(self.model_path / "bge-large-zh-v1.5"), model_kwargs={"device": "cuda"})
# load chroma db
-self.client_1 = chromadb.HttpClient(host='10.6.44.141', port=7000)
self.client_1 = chromadb.HttpClient(host='localhost', port=7000)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@@ -51,45 +59,44 @@
# # chroma_query settings
if "settings" in settings:
chroma_embedding_model = settings["settings"].get("chroma_embedding_model")
-chroma_host = settings["settings"].get("chroma_host")
chroma_host = settings["settings"].get("chroma_host", "localhost")
-chroma_port = settings["settings"].get("chroma_port")
chroma_port = settings["settings"].get("chroma_port", "7000")
-chroma_collection_id = settings["settings"].get("chroma_collection_id")
chroma_collection_id = settings["settings"].get("chroma_collection_id", DEFAULT_COLLECTION_ID)
user_chunk_size = settings["settings"].get("chunk_size", 256)
user_chunk_overlap = settings["settings"].get("chunk_overlap", 10)
user_separators = settings["settings"].get("separators", ["\n\n"])
else:
chroma_embedding_model = settings.get("chroma_embedding_model")
-chroma_host = settings.get("chroma_host")
chroma_host = settings.get("chroma_host", "localhost")
-chroma_port = settings.get("chroma_port")
chroma_port = settings.get("chroma_port", "7000")
-chroma_collection_id = settings.get("chroma_collection_id")
chroma_collection_id = settings.get("chroma_collection_id", DEFAULT_COLLECTION_ID)
user_chunk_size = settings.get("chunk_size", 256)
user_chunk_overlap = settings.get("chunk_overlap", 10)
user_separators = settings.get("separators", ["\n\n"])
if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-chroma_embedding_model = "/Workspace/Models/BAAI/bge-large-zh-v1.5"
chroma_embedding_model = str(self.model_path / "bge-large-zh-v1.5")
-if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-chroma_host = "10.6.82.192"
-if chroma_port is None or chroma_port.isspace() or chroma_port == "":
-chroma_port = "8000"
-if chroma_collection_id is None or chroma_collection_id.isspace() or chroma_collection_id == "":
-chroma_collection_id = "g2e"
# load client and embedding model from init
-if re.search(r"10.6.82.192", chroma_host) and re.search(r"8000", chroma_port):
if re.search(r"localhost", chroma_host) and re.search(r"7000", chroma_port):
client = self.client_1
else:
client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
print(f"chroma_embedding_model: {chroma_embedding_model}")
-if re.search(r"/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
if re.search(str(self.model_path / "bge-large-zh-v1.5"), chroma_embedding_model):
embedding_model = self.embedding_model_1
else:
-embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, device = "cuda:0")
embedding_model = SentenceTransformerEmbeddings(model_name=chroma_embedding_model, model_kwargs={"device": "cuda"})
response_file = ''
response_string = ''
if file is not None:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=user_chunk_size, chunk_overlap=user_chunk_overlap, separators=user_separators)
file_type = file.split(".")[-1]
print("file_type: ", file_type)
if file_type == "pdf":
-loader = PyPDFLoader(file)
loader = UnstructuredPDFLoader(file)
elif file_type == "txt":
loader = TextLoader(file)
elif file_type == "csv":
@@ -102,9 +109,10 @@
loader = Docx2txtLoader(file)
elif file_type == "xlsx":
loader = UnstructuredExcelLoader(file)
elif file_type == "md":
loader = UnstructuredMarkdownLoader(file, mode="single", strategy="fast")
documents = loader.load()
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
@@ -148,13 +156,13 @@
if user_text is not None and user_text_ids is None:
return JSONResponse(content={"error": "text_ids is required when text is provided"}, status_code=status.HTTP_400_BAD_REQUEST)
-if user_file is not None:
if user_file is not None and user_file.size != 0:
pdf_bytes = await user_file.read()
custom_filename = user_file.filename
# get the system temp directory path
safe_filename = os.path.join(tempfile.gettempdir(), os.path.basename(custom_filename))
print("file_path", safe_filename)
with open(safe_filename, "wb") as f:
f.write(pdf_bytes)
else:
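A minimal sketch of the per-request chunking the upsert handler now applies, using the same defaults it assumes (chunk_size 256, chunk_overlap 10, `"\n\n"` separator):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=10, separators=["\n\n"])
chunks = splitter.split_text("第一段。\n\n第二段。\n\n第三段。")
print(len(chunks), chunks)
```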

View File

@@ -10,14 +10,15 @@ from ..tts.tts_service import TTService
from ..configuration import MeloConf
from ..configuration import CosyVoiceConf
from ..configuration import SovitsConf
from ..configuration import PathConf
from injector import inject
from injector import singleton
import sys,os
-sys.path.append('/Workspace/CosyVoice')
-sys.path.append('/Workspace/CosyVoice/third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
-# from cosyvoice.utils.file_utils import load_wav, speed_change
from pathlib import Path
import soundfile as sf
import pyloudnorm as pyln
@@ -33,6 +34,7 @@ import numpy as np
from pydub import AudioSegment
import subprocess
import re
def set_all_random_seed(seed):
random.seed(seed)
@@ -99,7 +101,18 @@
print('1.#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
@logging_time(logger=logger)
-def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf, path: PathConf) -> None:
cosy_path = Path(path.cosyvoice_path)
cosy_model_path = Path(path.cosyvoice_model_path)
sys.path.append(str(cosy_path))
sys.path.append(str(cosy_path / "third_party/Matcha-TTS"))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav  # , speed_change
self.cosyvoice_speed = cosyvoice_config.speed
self.cosyvoice_device = cosyvoice_config.device
self.cosyvoice_language = cosyvoice_config.language
@@ -107,12 +120,13 @@
self.cosyvoice_url = ''
self.cosyvoice_mode = cosyvoice_config.mode
self.cosyvoicetts = None
self.prompt_speech_16k = None
# os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
if self.cosyvoice_mode == 'local':
# self.cosyvoicetts = CosyVoice('/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
-self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
# self.cosyvoicetts = CosyVoice('/model/Voice/CosyVoice/pretrained_models/CosyVoice-300M')
-# self.cosyvoicetts = CosyVoice2('/model/Voice/CosyVoice/pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
self.cosyvoicetts = CosyVoice2(str(cosy_model_path / "CosyVoice2-0.5B"), load_jit=True, load_trt=False)
self.prompt_speech_16k = load_wav("./Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav", 16000)
else:
self.cosyvoice_url = cosyvoice_config.url
@@ -149,16 +163,31 @@
@inject
-def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict) -> None:
def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, sovits_config: SovitsConf, settings: dict, path: PathConf) -> None:
self.tts_service = TTService("yunfeineo")
self.melo_model_init(melo_config)
-self.cosyvoice_model_init(cosyvoice_config)
self.cosyvoice_model_init(cosyvoice_config, path)
self.sovits_model_init(sovits_config)
self.audio_dir = "audio_files"  # directory for storing audio files
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def filter_invalid_chars(self,text):
"""过滤无效字符(包括字节流)"""
invalid_keywords = ["data:", "\n", "\r", "\t", " "]
if isinstance(text, bytes):
text = text.decode('utf-8', errors='ignore')
for keyword in invalid_keywords:
text = text.replace(keyword, "")
# Strip all English letters (Chinese text and punctuation are kept)
text = re.sub(r'[a-zA-Z]', '', text)
return text.strip()
@logging_time(logger=logger)
def processing(self, *args, settings: dict) -> io.BytesIO:
@@ -213,7 +242,8 @@
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
-audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
# audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -234,12 +264,44 @@
set_all_random_seed(56056558)
print("*"*90)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
# audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
# for i, j in enumerate(audio):
# f = io.BytesIO()
# sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
# f.seek(0)
# print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
# return f.read()
# print the length and structure of audio
# print(f"Total audio segments: {len(audio)}")
# print(f"Audio data structure: {audio}")
# collect every audio segment's NumPy array in one list
all_audio_data = []
# iterate over each segment and store it in all_audio_data
for i, j in enumerate(audio):
# print(f"Processing segment {i + 1}...")
# print each segment's info to make sure it is correct
# print(f"Segment {i + 1} shape: {j['tts_speech'].shape}")
# convert the audio data straight into a NumPy array
audio_data = j['tts_speech'].cpu().numpy()
# append each segment's audio data to all_audio_data
all_audio_data.append(audio_data[0])  # take the first channel (assumed mono)
# concatenate all segments into one complete audio array
combined_audio_data = np.concatenate(all_audio_data, axis=0)
# write the combined audio into a BytesIO buffer
f = io.BytesIO()
-sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
sf.write(f, combined_audio_data, 22050, format='wav')  # 22050 is the sample rate; adjust if needed
f.seek(0)
# return the merged audio
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-return f.read()
return f.read()  # return the final merged audio data
else:
message = {
"text": text
@@ -250,7 +312,8 @@
elif chroma_collection_id == 'boss':
if self.cosyvoice_mode == 'local':
set_all_random_seed(35616313)
-audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
# audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
audio = self.cosyvoicetts.inference_instruct2(text, '用粤语说这句话', self.prompt_speech_16k, stream=False)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
@@ -266,6 +329,7 @@
return response.content
elif user_model_name == 'sovitstts':
# text = self.filter_invalid_chars(text)
if chroma_collection_id == 'kiki' or chroma_collection_id is None:
if self.sovits_mode == 'local':
set_all_random_seed(56056558)
@@ -286,9 +350,9 @@
"text_split_method": self.sovits_text_split_method,
"batch_size": self.sovits_batch_size,
"media_type": self.sovits_media_type,
-"streaming_mode": self.sovits_streaming_mode
"streaming_mode": user_stream
}
-if user_stream:
if user_stream == True or str(user_stream).lower() == "true":
response = requests.get(self.sovits_url, params=message, stream=True)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response
@@ -299,23 +363,37 @@
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
-# elif chroma_collection_id == 'boss':
-# if self.cosyvoice_mode == 'local':
-# set_all_random_seed(35616313)
-# audio = self.cosyvoicetts.inference_sft(text, '中文男', speed=1.5, stream=False)
-# for i, j in enumerate(audio):
-# f = io.BytesIO()
-# sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
-# f.seek(0)
-# print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
-# return f.read()
-# else:
-# message = {
-# "text": text
-# }
-# response = requests.post(self.cosyvoice_url, json=message)
-# print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
-# return response.content
elif chroma_collection_id == 'boss':
if self.sovits_mode == 'local':
set_all_random_seed(56056558)
audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language, stream=True)
for i, j in enumerate(audio):
f = io.BytesIO()
sf.write(f, j['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
f.seek(0)
print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
return f.read()
else:
message = {
"text": text,
"text_lang": self.sovits_text_lang,
"ref_audio_path": self.sovits_ref_audio_path,
"prompt_lang": self.sovits_prompt_lang,
"prompt_text": self.sovits_prompt_text,
"text_split_method": self.sovits_text_split_method,
"batch_size": self.sovits_batch_size,
"media_type": self.sovits_media_type,
"streaming_mode": user_stream
}
if user_stream == True or str(user_stream).lower() == "true":
response = requests.get(self.sovits_url, params=message, stream=True)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response
else:
response = requests.get(self.sovits_url, params=message)
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
return response.content
print("#### SoVITS Service consume - docker : ", (time.time()-current_time))
elif user_model_name == 'man':
@@ -359,9 +437,12 @@
if text is None:
return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
by = self.processing(text, settings=setting)
-# return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
-if user_stream:
if user_stream not in (True, "True", "true"):
return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
print(f"tts user_stream: {type(user_stream)}")
# import pdb; pdb.set_trace()
if user_stream in (True, "True", "true"):
print(f"tts user_stream22: {user_stream}")
if by.status_code == 200:
print("*"*90)
def audio_stream():
@@ -405,6 +486,7 @@
else:
wav_filename = os.path.join(self.audio_dir, 'audio.wav')
print("8"*90)
with open(wav_filename, 'wb') as f:
f.write(by)
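A minimal sketch of the GPT-SoVITS request the new `boss` branch assembles, with the values taken from the `sovitstts` block in `.env.yaml` above (the endpoint is deployment-specific):

```python
import requests

params = {
    "text": "你好,歡迎嚟到澳門。",
    "text_lang": "yue",
    "ref_audio_path": "output/slicer_opt/Ricky-Wong/Ricky-Wong-3-Mins.wav_0006003840_0006134080.wav",
    "prompt_lang": "yue",
    "prompt_text": "你失敗咗點算啊?你而家安安穩穩,點解要咁樣做呢?",
    "text_split_method": "cut5",
    "batch_size": 1,
    "media_type": "wav",
    "streaming_mode": False,
}
resp = requests.get("http://10.6.80.90:9880/tts", params=params, timeout=120)
with open("audio.wav", "wb") as f:
    f.write(resp.content)
```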

View File

@@ -13,13 +13,19 @@ import requests
import base64
import copy
import ast
import json
import random
from time import time
import io
from PIL import Image
from lmdeploy.serve.openai.api_client import APIClient
from openai import OpenAI
def is_base64(value) -> bool:
try:
@@ -52,14 +58,14 @@
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True."""
self.model_dict = vlm_config.urls
-self.model_url = None
self.available_models = {}
self.temperature: float = 0.7
self.top_p: float = 1
self.max_tokens: (int | None) = 512
self.repetition_penalty: float = 1
self.stop: (str | List[str] | None) = ['<|endoftext|>','<|im_end|>']
-self.top_k: (int) = None
self.top_k: (int) = 40
self.ignore_eos: (bool) = False
self.skip_special_tokens: (bool) = True
@@ -74,7 +80,13 @@
"skip_special_tokens": self.skip_special_tokens,
}
for model, url in self.model_dict.items():
try:
response = requests.get(url+'/health', timeout=3)
if response.status_code == 200:
self.available_models[model] = url
except Exception as e:
pass
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@@ -82,7 +94,7 @@
data = args[0]
return isinstance(data, list)
-def processing(self, prompt:str | None, images:str | bytes | None, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
def processing(self, prompt:str | None, images:str | bytes | None, settings: dict, user_context: List[dict] = None) -> str:
""" """
Args: Args:
prompt: a string query to the model. prompt: a string query to the model.
@ -94,22 +106,29 @@ class VLMS(Blackbox):
response: a string response: a string
history: a list history: a list
""" """
config: dict = {
"lmdeploy_infer":True,
"system_prompt":"",
"vlm_model_name":"",
}
if settings: if settings:
for k in settings: for k in list(settings.keys()):
if k not in self.settings: if k not in self.settings:
print("Warning: '{}' is not a support argument and ignore this argment, check the arguments {}".format(k,self.settings.keys())) print("Warning: '{}' is not a support argument and ignore this argment, check the arguments {}".format(k,self.settings.keys()))
settings.pop(k) config[k] = settings.pop(k)
tmp = copy.deepcopy(self.settings) tmp = copy.deepcopy(self.settings)
tmp.update(settings) tmp.update(settings)
settings = tmp settings = tmp
else: else:
settings = {} settings = {}
config['lmdeploy_infer'] = str(config['lmdeploy_infer']).strip().lower() == 'true'
if not prompt: if not prompt:
prompt = '你是一个辅助机器人请就此图做一个简短的概括性描述包括图中的主体物品及状态不超过50字。' if images else '你好' prompt = '你是一个辅助机器人请就此图做一个简短的概括性描述包括图中的主体物品及状态不超过50字。' if images else '你好'
# Transform the images into base64 format where openai format need. # Transform the images into base64 format where openai url)
if images: if images:
if is_base64(images): # image as base64 str if is_base64(images): # image as base64 str
images_data = images images_data = images
@@ -122,43 +141,14 @@
images_data = str(base64.b64encode(res.content),'utf-8')
else:
images_data = None
## AutoLoad Model
# url = 'http://10.6.80.87:8000/' + model_name + '/'
# data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
# data = requests.post(url, json=data_input)
# print(data.text)
# return data.text
# 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
## Lmdeploy
# if not user_context:
# user_context = []
## Predefine user_context only for testing ## Predefine user_context only for testing
# user_context = [{'role':'user','content':'你好,我叫康康,你是谁?'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}] # user_context = [{'role':'user','content':'你好,我叫康康,你是谁?'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]
# user_context = [{
# 'role': 'user',
# 'content': [{
# 'type': 'text',
# 'text': '图中有什么,请描述一下',
# }, {
# 'type': 'image_url',
# 'image_url': {
# 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
# },
# }]
# },{
# 'role': 'assistant',
# 'content': '图片中主要展示了一只老虎,它正在绿色的草地上休息。草地上有很多可以让人坐下的地方,而且看起来相当茂盛。背景比较模糊,可能是因为老虎的影响,让整个图片的其他部分都变得不太清晰了。'
# }
# ]
-user_context = self.keep_last_k_images(user_context,k = 1)
-if self.model_url is None: self.model_url = self._get_model_url(model_name)
-api_client = APIClient(self.model_url)
-# api_client = APIClient("http://10.6.80.91:23333")
-model_name = api_client.available_models[0]
if not user_context and config['system_prompt']: user_context = [{'role':'system','content': config['system_prompt']}]
user_context = self.keep_last_k_images(user_context,k = 2)
# Reformat input into openai format to request.
if images_data:
messages = user_context + [{
@@ -170,8 +160,6 @@
'type': 'image_url',
'image_url': { # Image two
'url':
-# 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
-# './val_data/image_5.jpg'
f"data:image/jpeg;base64,{images_data}",
},
# },{ # Image one
@@ -194,40 +182,56 @@
responses = ''
total_token_usage = 0 # which can be used to count the cost of a query
model_url = self._get_model_url(config['vlm_model_name'])
if config['lmdeploy_infer']:
api_client = APIClient(model_url)
model_name = api_client.available_models[0]
for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
messages=messages,stream = True,
**settings,
# session_id=,
)):
# Stream output
-print(item["choices"][0]["delta"]['content'],end='\n')
yield item["choices"][0]["delta"]['content']
responses += item["choices"][0]["delta"]['content']
# print(item["choices"][0]["message"]['content'])
# responses += item["choices"][0]["message"]['content']
# total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
else:
api_key = "EMPTY_API_KEY"
api_client = OpenAI(api_key=api_key, base_url=model_url+'/v1')
model_name = api_client.models.list().data[0].id
for item in api_client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True):
yield(item.choices[0].delta.content)
responses += item.choices[0].delta.content
# response = api_client.chat.completions.create(
# model=model_name,
# messages=messages,
# temperature=0.8,
# top_p=0.8)
# print(response.choices[0].message.content)
# return response.choices[0].message.content
user_context = messages + [{'role': 'assistant', 'content': responses}]
self.custom_print(user_context)
-# return responses, user_context
# return responses
def _get_model_url(self,model_name:str | None):
-available_models = {}
-for model, url in self.model_dict.items():
-try:
-response = requests.get(url,timeout=3)
-if response.status_code == 200:
-available_models[model] = url
-except Exception as e:
-# print(e)
-pass
-if not available_models: print("There are no available running models and please check your endpoint urls.")
-if model_name and model_name in available_models:
-return available_models[model_name]
if not self.available_models: print("There are no available running models and please check your endpoint urls.")
if model_name and model_name in self.available_models:
return self.available_models[model_name]
else:
-model = random.choice(list(available_models.keys()))
model = random.choice(list(self.available_models.keys()))
print(f"No such model {model_name}, using {model} instead.") if model_name else print(f"Using random model {model}.")
-return available_models[model]
return self.available_models[model]
def _into_openai_format(self, context:List[list]) -> List[dict]:
"""
@@ -298,7 +302,6 @@
result.append(item)
return result[::-1]
def custom_print(self, user_context: list):
result = []
for item in user_context:
@@ -315,18 +318,23 @@
## TODO: add support for multiple images and support image in form-data format
json_request = True
try:
-content_type = request.headers['content-type']
content_type = request.headers.get('content-type', '')
if content_type == 'application/json':
data = await request.json()
-else:
elif 'multipart/form-data' in content_type:
data = await request.form()
json_request = False
else:
body = await request.body()
data = json.loads(body.decode("utf-8"))
except Exception as e:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
-model_name = data.get("model_name")
prompt = data.get("prompt")
settings: dict = data.get('settings')
context = data.get("context")
if not context:
user_context = []
@@ -337,7 +345,7 @@
else:
return JSONResponse(content={"error": "context format error, should be in format of list or Openai_format"}, status_code=status.HTTP_400_BAD_REQUEST)
-if json_request:
if json_request or 'multipart/form-data' not in content_type:
img_data = data.get("img_data")
else:
img_data = await data.get("img_data").read()
@@ -345,14 +353,12 @@
if prompt is None:
return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
streaming_output = str(settings.get('stream',False)).strip().lower() == 'true' if settings else False
-# if model_name is None or model_name.isspace():
-# model_name = "Qwen-VL-Chat"
-# response,_ = self.processing(prompt, img_data,settings, model_name,user_context=user_context)
-# return StreamingResponse(self.processing(prompt, img_data,settings, model_name,user_context=user_context), status_code=status.HTTP_200_OK)
-return EventSourceResponse(self.processing(prompt, img_data,settings, model_name,user_context=user_context), status_code=status.HTTP_200_OK)
-response, history = self.processing(prompt, img_data,settings, model_name,user_context=user_context)
-# return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
if streaming_output:
# return StreamingResponse(self.processing(prompt, img_data,settings, user_context=user_context), status_code=status.HTTP_200_OK)
return EventSourceResponse(self.processing(prompt, img_data,settings, user_context=user_context), status_code=status.HTTP_200_OK)
else:
# HTTP JsonResponse
output = self.processing(prompt, img_data,settings, user_context=user_context)
response = ''.join([res for res in output])
return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
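Since the handler now returns an `EventSourceResponse` when `settings.stream` is true, a client has to read Server-Sent Events; a minimal sketch (host, port and the `vlms` blackbox name are assumptions, not taken from this diff):

```python
import requests

payload = {
    "prompt": "图中有什么?",
    "img_data": "<base64-encoded image>",   # placeholder
    "settings": {"stream": "true"},
}
with requests.post("http://localhost:8000/?blackbox_name=vlms", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # SSE payload lines are prefixed with "data: ".
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)
```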

View File

@@ -180,3 +180,17 @@ class VLMConf():
@inject
def __init__(self, config: Configuration) -> None:
self.urls = config.get("vlms.urls")
@singleton
class PathConf():
chroma_rerank_embedding_model: str
cosyvoice_path: str
cosyvoice_model_path: str
sensevoice_model_path: str
@inject
def __init__(self, config: Configuration) -> None:
self.chroma_rerank_embedding_model = config.get("path.chroma_rerank_embedding_model")
self.cosyvoice_path = config.get("path.cosyvoice_path")
self.cosyvoice_model_path = config.get("path.cosyvoice_model_path")
self.sensevoice_model_path = config.get("path.sensevoice_model_path")
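A minimal sketch of resolving the new `PathConf` the same way `main.py` resolves `BlackboxFactory`; the import path is assumed from this repository's layout:

```python
from injector import Injector
from src.configuration import PathConf  # assumed module path

injector = Injector()
path = injector.get(PathConf)  # Configuration is constructed and injected automatically
print(path.cosyvoice_model_path)
print(path.sensevoice_model_path)
```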

src/log/loki_config.py Normal file
View File

@@ -0,0 +1,102 @@
import logging
import logging_loki
import os
from requests.auth import HTTPBasicAuth
from typing import Dict, Optional
class LokiLogger:
"""
一个用于配置和获取Loki日志记录器的类。
可以在其他文件里import后直接使用。
用法示例:
from loki_config import LokiLogger
# 获取日志记录器实例 (使用默认配置或环境变量)
loki_logger_instance = LokiLogger().get_logger()
loki_logger_instance.info("这条日志会推送到Loki")
# 或者使用自定义配置
my_loki_logger = LokiLogger(
url="https://your-custom-loki-url.com/loki/api/v1/push",
username="myuser",
password="mypassword",
tags={"app_name": "my-custom-app", "environment": "prod"},
level=logging.DEBUG
).get_logger()
my_loki_logger.debug("这条调试日志也推送到Loki。")
"""
def __init__(self,
url: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
level: int = logging.INFO,
logger_name: str = "app_loki_logger"):
"""
初始化LokiLogger。
Args:
url (str, optional): Loki的推送URL。默认从环境变量LOKI_URL获取
如果不存在则使用"https://loki.bwgdi.com/loki/api/v1/push"
username (str, optional): Loki的认证用户名。默认从环境变量LOKI_USERNAME获取。
password (str, optional): Loki的认证密码。默认从环境变量LOKI_PASSWORD获取。
tags (Dict[str, str], optional): 发送到Loki的默认标签。例如{"application": "my-app"}。
如果未提供,则默认为{"application": "default-app", "source": "python-app"}。
level (int, optional): 日志级别如logging.INFO, logging.DEBUG。默认为logging.INFO。
logger_name (str, optional): 日志记录器的名称。默认为"app_loki_logger"
"""
# 从环境变量获取配置,如果未通过参数提供
self._url = url if url else os.getenv("LOKI_URL", "https://loki.bwgdi.com/loki/api/v1/push")
self._username = username if username else os.getenv("LOKI_USERNAME",'admin')
self._password = password if password else os.getenv("LOKI_PASSWORD",'admin')
self._tags = tags if tags is not None else {"app": "jarvis", "env": "dev", "location": "gdi", "layer": "models"}
# Get or create the logger with the given name
self._logger = logging.getLogger(logger_name)
self._logger.setLevel(level)
for handler in self._logger.handlers[:]:
self._logger.removeHandler(handler)
# Check whether a LokiHandler already exists, to avoid adding it twice and duplicating logs
# Across multiple files or repeated initialisation, the same logger_name returns the same logger instance
if not any(isinstance(h, logging_loki.LokiHandler) for h in self._logger.handlers):
try:
auth = None
if self._username and self._password:
auth = HTTPBasicAuth(self._username, self._password)
loki_handler = logging_loki.LokiHandler(
url=self._url,
tags=self._tags,
version="1", # 通常Loki API版本为1
auth=auth
)
self._logger.addHandler(loki_handler)
# Also add a StreamHandler so log output is visible on the console during local debugging
if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
self._logger.addHandler(logging.StreamHandler())
self._logger.info(f"LokiLogger: 已成功配置Loki日志处理器目标地址{self._url}")
except Exception as e:
# If the Loki configuration fails, make sure a StreamHandler still writes logs to the console
if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
self._logger.addHandler(logging.StreamHandler())
self._logger.error(f"LokiLogger: 配置LokiHandler失败{e}。将回退到控制台日志记录。", exc_info=True)
else:
# If the LokiHandler already exists, only make sure a StreamHandler is present
if not any(isinstance(h, logging.StreamHandler) for h in self._logger.handlers):
self._logger.addHandler(logging.StreamHandler())
self._logger.debug(f"LokiLogger: '{logger_name}' 记录器已配置LokiHandler跳过重新配置。")
def get_logger(self) -> logging.Logger:
"""
Return the configured logger instance.
"""
self._logger.debug("LokiLogger: 获取已配置的日志记录器实例。")
return self._logger
# Optional: if you want a default global LokiLogger instance,
# you can instantiate it here and import it directly from this module elsewhere.
# For example:
# DEFAULT_LOKI_LOGGER_INSTANCE = LokiLogger().get_logger()
# Then in other files: `from loki_config import DEFAULT_LOKI_LOGGER_INSTANCE as logger`

View File

@@ -19,7 +19,7 @@ import logging
logging.basicConfig(level=logging.INFO)
dirbaspath = __file__.split("\\")[1:-1]
-dirbaspath= "/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
dirbaspath= "./src/tts" + "/".join(dirbaspath)
config = {
'ayaka': {
'cfg': dirbaspath + '/models/ayaka.json',