add asr and tts with settings
main.py (5 changes)

@@ -4,6 +4,9 @@ from injector import Injector,inject
 from src.log.handler import LogHandler
 from src.configuration import EnvConf, LogConf, singleton
 
+import argparse
+
+
 @singleton
 class Main():
 
@@ -14,7 +17,7 @@ class Main():
     def run(self):
         logger = logging.getLogger(__name__)
         logger.info("jarvis-models start", extra={"version": "0.0.1"})
-        uvicorn.run("server:app", host="0.0.0.0", port=8000, log_level="info")
+        uvicorn.run("server:app", host="0.0.0.0", port=8000, log_level="info",reload = True)
 
 if __name__ == "__main__":
     injector = Injector()
(file name not shown)

@@ -70,7 +70,7 @@ def get_all_files(folder_path):
 
 
 # Load and split the documents
-loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_zh.txt")
+loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_zh.txt")
 
 documents = loader.load()
 
@@ -84,11 +84,11 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
 
 
 # Load the embedding model and the chroma server
-embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
+client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 
 id = "g2e"
-client.delete_collection(id)
+#client.delete_collection(id)
 collection_number = client.get_or_create_collection(id).count()
 print("collection_number",collection_number)
 start_time2 = time.time()
@@ -106,8 +106,8 @@ print("collection_number",collection_number)
 
 # # chroma retrieval
 # from chromadb.utils import embedding_functions
-# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
+# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)
 
 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 # 'Content-Type': 'application/json',
 # 'Authorization': "Bearer " + key
 # }
-# url = "http://172.16.5.8:23333/v1/chat/completions"
+# url = "http://10.6.81.119:23333/v1/chat/completions"
 
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
(file name not shown)

@@ -70,7 +70,7 @@ def get_all_files(folder_path):
 
 
 # Load and split the documents
-loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_en.txt")
+loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_en.txt")
 
 documents = loader.load()
 
@@ -84,8 +84,8 @@ ids = ["20240521_store"+str(i) for i in range(len(docs))]
 
 
 # Load the embedding model and the chroma server
-embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5', model_kwargs={"device": "cuda"})
-client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5', model_kwargs={"device": "cuda"})
+client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 
 id = "g2e_english"
 client.delete_collection(id)
@@ -106,8 +106,8 @@ print("collection_number",collection_number)
 
 # # chroma retrieval
 # from chromadb.utils import embedding_functions
-# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-# client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+# embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
+# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 # collection = client.get_collection("g2e", embedding_function=embedding_model)
 
 # print(collection.count())
@@ -152,7 +152,7 @@ print("collection_number",collection_number)
 # 'Content-Type': 'application/json',
 # 'Authorization': "Bearer " + key
 # }
-# url = "http://172.16.5.8:23333/v1/chat/completions"
+# url = "http://10.6.81.119:23333/v1/chat/completions"
 
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
(file name not shown)

@@ -66,7 +66,7 @@ def get_all_files(folder_path):
 
 
 # Load and split the documents
-# loader = TextLoader("/home/administrator/Workspace/jarvis-models/sample/RAG_zh.txt")
+# loader = TextLoader("/home/gpu/Workspace/jarvis-models/sample/RAG_zh.txt")
 
 # documents = loader.load()
 
@@ -80,8 +80,8 @@ def get_all_files(folder_path):
 
 
 # # Load the embedding model and the chroma server
-# embedding_model = SentenceTransformerEmbeddings(model_name='/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
-# client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+# embedding_model = SentenceTransformerEmbeddings(model_name='/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5', model_kwargs={"device": "cuda"})
+# client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 
 # id = "g2e"
 # client.delete_collection(id)
@@ -102,8 +102,8 @@ def get_all_files(folder_path):
 
 # chroma retrieval
 from chromadb.utils import embedding_functions
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-client = chromadb.HttpClient(host='172.16.5.8', port=7000)
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
+client = chromadb.HttpClient(host='10.6.81.119', port=7000)
 collection = client.get_collection("g2e", embedding_function=embedding_model)
 
 print(collection.count())
@@ -148,7 +148,7 @@ print("time: ", time.time() - start_time)
 # 'Content-Type': 'application/json',
 # 'Authorization': "Bearer " + key
 # }
-# url = "http://172.16.5.8:23333/v1/chat/completions"
+# url = "http://10.6.81.119:23333/v1/chat/completions"
 
 # fastchat_response = requests.post(url, json=chat_inputs, headers=header)
 # # print(fastchat_response.json())
(file name not shown)

@@ -6,26 +6,109 @@ from fastapi.responses import JSONResponse
 
 from ..asr.rapid_paraformer.utils import read_yaml
 from ..asr.rapid_paraformer import RapidParaformer
 
+from funasr import AutoModel
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
 from .blackbox import Blackbox
 from injector import singleton, inject
 
+import tempfile
+
+import json
+import os
+from ..configuration import SenseVoiceConf
+
+from ..log.logging_time import logging_time
+import logging
+logger = logging.getLogger(__name__)
+
 @singleton
 class ASR(Blackbox):
+    mode: str
+    url: str
+    speed: int
+    device: str
+    language: str
+    speaker: str
+
+    @logging_time(logger=logger)
+    def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
+
+        config = read_yaml(".env.yaml")
+        self.paraformer = RapidParaformer(config)
+
+        model_dir = "/home/gpu/Workspace/Models/SenseVoice/SenseVoiceSmall"
+
+        self.speed = sensevoice_config.speed
+        self.device = sensevoice_config.device
+        self.language = sensevoice_config.language
+        self.speaker = sensevoice_config.speaker
+        self.device = sensevoice_config.device
+        self.url = ''
+        self.mode = sensevoice_config.mode
+        self.asr = None
+        self.speaker_ids = None
+        # os.environ['CUDA_VISIBLE_DEVICES'] = str(sensevoice_config.device)
+        if self.mode == 'local':
+            self.asr = AutoModel(
+                model=model_dir,
+                trust_remote_code=True,
+                remote_code= "/home/gpu/Workspace/SenseVoice/model.py",
+                vad_model="fsmn-vad",
+                vad_kwargs={"max_single_segment_time": 30000},
+                device=self.device,
+            )
+
+        else:
+            self.url = sensevoice_config.url
+        logging.info('#### Initializing SenseVoiceASR Service in cuda:' + sensevoice_config.device + ' mode...')
+
     @inject
-    def __init__(self,path = ".env.yaml") -> None:
-        config = read_yaml(path)
-        self.paraformer = RapidParaformer(config)
+    def __init__(self, sensevoice_config: SenseVoiceConf, settings: dict) -> None:
+        self.model_init(sensevoice_config)
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
 
-    async def processing(self, *args, **kwargs):
+    async def processing(self, *args, settings: dict):
+
+        print("\nChat Settings: ", settings)
+        if settings is None:
+            settings = {}
+        user_model_name = settings.get("asr_model_name")
+        print(f"asr_model_name: {user_model_name}")
         data = args[0]
-        results = self.paraformer([BytesIO(data)])
-        if len(results) == 0:
-            return None
-        return results[0]
+        if user_model_name == 'sensevoice' or ['sensevoice']:
+            # create a temporary file
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
+                temp_audio_file.write(data)
+                temp_audio_path = temp_audio_file.name
+            res = self.asr.generate(
+                input=temp_audio_path,
+                cache={},
+                language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
+                use_itn=True,
+                batch_size_s=60,
+                merge_vad=True,  #
+                merge_length_s=15,
+            )
+            # results = self.paraformer([BytesIO(data)])
+            results = rich_transcription_postprocess(res[0]["text"])
+            os.remove(temp_audio_path)
+            if len(results) == 0:
+                return None
+            return results
+
+        elif user_model_name == 'funasr' or ['funasr']:
+
+            return results
+
+        else:
+            results = self.paraformer([BytesIO(data)])
+            if len(results) == 0:
+                return None
+            return results[0]
 
     def valid(self, data: any) -> bool:
         if isinstance(data, bytes):
@@ -34,12 +117,20 @@ class ASR(Blackbox):
 
     async def fast_api_handler(self, request: Request) -> Response:
         data = (await request.form()).get("audio")
+        setting: dict = (await request.form()).get("settings")
+
+        if isinstance(setting, str):
+            try:
+                setting = json.loads(setting)  # try to convert the string into a dict
+            except json.JSONDecodeError:
+                return JSONResponse(content={"error": "Invalid settings format"}, status_code=status.HTTP_400_BAD_REQUEST)
+
         if data is None:
-            self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
+            # self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
             return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
         d = await data.read()
         try:
-            txt = await self.processing(d)
+            txt = await self.processing(d, settings=setting)
         except ValueError as e:
             return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
         return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)
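The reworked handler above now accepts an optional multipart "settings" field next to "audio". A minimal client sketch, assuming the blackbox is mounted at an /asr route on the uvicorn server from main.py (the mount point is not shown in this commit); the "audio" and "settings" field names and the {"text": ...} response shape come from fast_api_handler above, and the sample wav is one of the files added under test_data/voice/ in this same commit:

import json
import requests

# Assumed route and port; main.py serves on 0.0.0.0:8000, the /asr path is a guess.
url = "http://localhost:8000/asr"
settings = {"asr_model_name": "sensevoice"}  # key read by processing() above

with open("test_data/voice/nihao.wav", "rb") as f:
    resp = requests.post(url, files={"audio": f}, data={"settings": json.dumps(settings)})
print(resp.json())  # {"text": "..."} on success, {"error": "..."} otherwise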
src/blackbox/asr_bak.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+from io import BytesIO
+from typing import Any, Coroutine
+
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
+
+from ..asr.rapid_paraformer.utils import read_yaml
+from ..asr.rapid_paraformer import RapidParaformer
+from .blackbox import Blackbox
+from injector import singleton, inject
+
+@singleton
+class ASR(Blackbox):
+
+    @inject
+    def __init__(self,path = ".env.yaml") -> None:
+        config = read_yaml(path)
+        self.paraformer = RapidParaformer(config)
+
+    def __call__(self, *args, **kwargs):
+        return self.processing(*args, **kwargs)
+
+    async def processing(self, *args, **kwargs):
+        data = args[0]
+        results = self.paraformer([BytesIO(data)])
+        if len(results) == 0:
+            return None
+        return results[0]
+
+    def valid(self, data: any) -> bool:
+        if isinstance(data, bytes):
+            return True
+        return False
+
+    async def fast_api_handler(self, request: Request) -> Response:
+        data = (await request.form()).get("audio")
+        if data is None:
+            # self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
+            return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
+        d = await data.read()
+        try:
+            txt = await self.processing(d)
+        except ValueError as e:
+            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
+        return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)
src/blackbox/asrsensevoice.py (new file, 103 lines)

@@ -0,0 +1,103 @@
+from io import BytesIO
+from typing import Any, Coroutine
+
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
+
+from funasr import AutoModel
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
+from .blackbox import Blackbox
+from injector import singleton, inject
+
+import tempfile
+
+import os
+from ..configuration import SenseVoiceConf
+
+from ..log.logging_time import logging_time
+import logging
+logger = logging.getLogger(__name__)
+
+@singleton
+class ASR(Blackbox):
+    mode: str
+    url: str
+    speed: int
+    device: str
+    language: str
+    speaker: str
+
+    @logging_time(logger=logger)
+    def model_init(self, sensevoice_config: SenseVoiceConf) -> None:
+
+        model_dir = "/home/gpu/Workspace/Models/SenseVoice/SenseVoiceSmall"
+
+        self.speed = sensevoice_config.speed
+        self.device = sensevoice_config.device
+        self.language = sensevoice_config.language
+        self.speaker = sensevoice_config.speaker
+        self.device = sensevoice_config.device
+        self.url = ''
+        self.mode = sensevoice_config.mode
+        self.asr = None
+        self.speaker_ids = None
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(sensevoice_config.device)
+        if self.mode == 'local':
+            self.asr = AutoModel(
+                model=model_dir,
+                trust_remote_code=True,
+                remote_code= "/home/gpu/Workspace/SenseVoice/model.py",
+                vad_model="fsmn-vad",
+                vad_kwargs={"max_single_segment_time": 30000},
+                device="cuda:0",
+            )
+
+        else:
+            self.url = sensevoice_config.url
+        logging.info('#### Initializing SenseVoiceASR Service in cuda:' + str(sensevoice_config.device) + ' mode...')
+
+    @inject
+    def __init__(self, sensevoice_config: SenseVoiceConf) -> None:
+        self.model_init(sensevoice_config)
+
+    def __call__(self, *args, **kwargs):
+        return self.processing(*args, **kwargs)
+
+    async def processing(self, *args, **kwargs):
+        data = args[0]
+        # create a temporary file
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
+            temp_audio_file.write(data)
+            temp_audio_path = temp_audio_file.name
+        res = self.asr.generate(
+            input=temp_audio_path,
+            cache={},
+            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
+            use_itn=True,
+            batch_size_s=60,
+            merge_vad=True,  #
+            merge_length_s=15,
+        )
+        # results = self.paraformer([BytesIO(data)])
+        results = rich_transcription_postprocess(res[0]["text"])
+        os.remove(temp_audio_path)
+        if len(results) == 0:
+            return None
+        return results
+
+    def valid(self, data: any) -> bool:
+        if isinstance(data, bytes):
+            return True
+        return False
+
+    async def fast_api_handler(self, request: Request) -> Response:
+        data = (await request.form()).get("audio")
+        if data is None:
+            # self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
+            return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
+        d = await data.read()
+        try:
+            txt = await self.processing(d)
+        except ValueError as e:
+            return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
+        return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)
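The new file drives SenseVoiceSmall through funasr's AutoModel. A standalone sketch of that call path, assuming funasr and the local model directory above are available; the generate() arguments and the rich_transcription_postprocess usage mirror the code in this file (the machine-specific remote_code path is omitted), and the wav path is one of the test files added in this commit:

from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Same model directory and VAD settings as model_init() above.
model = AutoModel(
    model="/home/gpu/Workspace/Models/SenseVoice/SenseVoiceSmall",
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device="cuda:0",
)
res = model.generate(input="test_data/voice/nihao.wav", cache={},
                     language="auto", use_itn=True, batch_size_s=60)
print(rich_transcription_postprocess(res[0]["text"]))  # cleaned transcript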
(file name not shown)

@@ -3,7 +3,7 @@ from fastapi import Request, Response,status
 from fastapi.responses import JSONResponse
 from injector import inject, singleton
 
-from .asr import ASR
+from .asrsensevoice import ASR
 from .tesou import Tesou
 from .tts import TTS
 
@@ -47,15 +47,15 @@ def vlms_loader():
     from .vlms import VLMS
     return Injector().get(VLMS)
 
-@model_loader(lazy=blackboxConf.lazyloading)
-def melotts_loader():
-    from .melotts import MeloTTS
-    return Injector().get(MeloTTS)
+# @model_loader(lazy=blackboxConf.lazyloading)
+# def melotts_loader():
+#     from .melotts import MeloTTS
+#     return Injector().get(MeloTTS)
 
-@model_loader(lazy=blackboxConf.lazyloading)
-def cosyvoicetts_loader():
-    from .cosyvoicetts import CosyVoiceTTS
-    return Injector().get(CosyVoiceTTS)
+# @model_loader(lazy=blackboxConf.lazyloading)
+# def cosyvoicetts_loader():
+#     from .cosyvoicetts import CosyVoiceTTS
+#     return Injector().get(CosyVoiceTTS)
 
 @model_loader(lazy=blackboxConf.lazyloading)
 def tts_loader():
@@ -97,11 +97,6 @@ def chat_llama_loader():
     from .chat_llama import ChatLLaMA
     return Injector().get(ChatLLaMA)
 
-@model_loader(lazy=blackboxConf.lazyloading)
-def cosyvoicetts_loader():
-    from .cosyvoicetts import CosyVoiceTTS
-    return Injector().get(CosyVoiceTTS)
-
 @singleton
 class BlackboxFactory:
     models = {}
@@ -119,11 +114,11 @@ class BlackboxFactory:
         self.models["chroma_query"] = chroma_query_loader
         self.models["chroma_upsert"] = chroma_upsert_loader
         self.models["chroma_chat"] = chroma_chat_loader
-        self.models["melotts"] = melotts_loader
+        # self.models["melotts"] = melotts_loader
         self.models["vlms"] = vlms_loader
         self.models["chat"] = chat_loader
         self.models["chat_llama"] = chat_llama_loader
-        self.models["cosyvoicetts"] = cosyvoicetts_loader
+        # self.models["cosyvoicetts"] = cosyvoicetts_loader
 
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
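For context, the loader table above is how blackboxes are resolved by name. A hedged sketch of the lookup, assuming the factory is obtained through the Injector like the loaders themselves (the factory's construction and processing method are not shown in this hunk):

from injector import Injector

factory = Injector().get(BlackboxFactory)  # assumed construction path
vlms = factory.models["vlms"]()            # calls vlms_loader, which builds VLMS via Injector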
(file name not shown)

@@ -92,7 +92,7 @@ class Chat(Blackbox):
         #user_presence_penalty = 0.8
 
         if user_model_url is None or user_model_url.isspace() or user_model_url == "":
-            user_model_url = "http://172.16.5.8:23333/v1/chat/completions"
+            user_model_url = "http://10.6.81.119:23333/v1/chat/completions"
 
         if user_model_key is None or user_model_key.isspace() or user_model_key == "":
             user_model_key = "YOUR_API_KEY"
(file name not shown)

@@ -20,9 +20,9 @@ class ChromaQuery(Blackbox):
     def __init__(self, *args, **kwargs) -> None:
         # config = read_yaml(args[0])
         # load chromadb and embedding model
-        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda")
-        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", device = "cuda")
-        self.client_1 = chromadb.HttpClient(host='172.16.5.8', port=7000)
+        self.embedding_model_1 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device = "cuda:1")
+        self.embedding_model_2 = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", device = "cuda:1")
+        self.client_1 = chromadb.HttpClient(host='10.6.81.119', port=7000)
         # self.client_2 = chromadb.HttpClient(host='10.6.82.192', port=8000)
 
     def __call__(self, *args, **kwargs):
@@ -51,10 +51,10 @@ class ChromaQuery(Blackbox):
             return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
 
         if chroma_embedding_model is None or chroma_embedding_model.isspace() or chroma_embedding_model == "":
-            chroma_embedding_model = "/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5"
+            chroma_embedding_model = "/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5"
 
         if chroma_host is None or chroma_host.isspace() or chroma_host == "":
-            chroma_host = "172.16.5.8"
+            chroma_host = "10.6.81.119"
 
         if chroma_port is None or chroma_port.isspace() or chroma_port == "":
             chroma_port = "7000"
@@ -66,17 +66,17 @@ class ChromaQuery(Blackbox):
             chroma_n_results = 3
 
         # load client and embedding model from init
-        if re.search(r"172.16.5.8", chroma_host) and re.search(r"7000", chroma_port):
+        if re.search(r"10.6.81.119", chroma_host) and re.search(r"7000", chroma_port):
             client = self.client_1
         else:
             client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
 
-        if re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
+        if re.search(r"/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_1
-        elif re.search(r"/home/administrator/Workspace/Models/BAAI/bge-large-en-v1.5", chroma_embedding_model):
+        elif re.search(r"/home/gpu/Workspace/Models/BAAI/bge-small-en-v1.5", chroma_embedding_model):
             embedding_model = self.embedding_model_2
         else:
-            embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda")
+            embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=chroma_embedding_model, device = "cuda:1")
 
         # load collection
         collection = client.get_collection(chroma_collection_id, embedding_function=embedding_model)
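The rewired defaults above point at the chroma server on 10.6.81.119:7000 and the bge models under /home/gpu. A minimal retrieval sketch against those defaults, assuming the "g2e" collection populated by the sample scripts earlier in this commit; collection.query is the standard chromadb API:

import chromadb
from chromadb.utils import embedding_functions

emb = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="/home/gpu/Workspace/Models/BAAI/bge-large-zh-v1.5", device="cuda:1")
client = chromadb.HttpClient(host="10.6.81.119", port=7000)
collection = client.get_collection("g2e", embedding_function=emb)
# Top-3 nearest documents for a query, mirroring chroma_n_results = 3 above.
print(collection.query(query_texts=["KOMBUKIKI"], n_results=3))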
(file name not shown)

@@ -23,7 +23,7 @@ class G2E(Blackbox):
         if context == None:
             context = []
         #url = 'http://120.196.116.194:48890/v1'
-        url = 'http://120.196.116.194:48892/v1'
+        url = 'http://10.6.81.119:23333/v1'
 
         background_prompt = '''KOMBUKIKI是一款茶饮料,目标受众 年龄:20-35岁 性别:女性 地点:一线城市、二线城市 职业:精英中产、都市白领 收入水平:中高收入,有一定消费能力 兴趣和爱好:注重健康,有运动习惯
 
@@ -73,11 +73,11 @@ class G2E(Blackbox):
         response = client.chat.completions.create(
             model=model_name,
             messages=messages,
-            temperature=0.8,
-            top_p=0.8,
-            frequency_penalty=0.5,
-            presence_penalty=0.8,
-            stop=100
+            temperature="0.8",
+            top_p="0.8",
+            frequency_penalty="0.5",
+            presence_penalty="0.8",
+            stop="100"
         )
 
         fastchat_content = response.choices[0].message.content
(file name not shown)

@@ -2,27 +2,154 @@ import io
 import time
 from ntpath import join
 
+import requests
 from fastapi import Request, Response, status
 from fastapi.responses import JSONResponse
 from .blackbox import Blackbox
 from ..tts.tts_service import TTService
+from ..configuration import MeloConf
+from ..configuration import CosyVoiceConf
+from injector import inject
 from injector import singleton
 
+import sys,os
+sys.path.append('/home/gpu/Workspace/CosyVoice')
+from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.utils.file_utils import load_wav
+
+import soundfile
+import pyloudnorm as pyln
+from melo.api import TTS as MELOTTS
+
+from ..log.logging_time import logging_time
+import logging
+logger = logging.getLogger(__name__)
+
 @singleton
 class TTS(Blackbox):
+    melo_mode: str
+    melo_url: str
+    melo_speed: int
+    melo_device: str
+    melo_language: str
+    melo_speaker: str
+
+    cosyvoice_mode: str
+    cosyvoice_url: str
+    cosyvoice_speed: int
+    cosyvoice_device: str
+    cosyvoice_language: str
+    cosyvoice_speaker: str
+
+    @logging_time(logger=logger)
+    def melo_model_init(self, melo_config: MeloConf) -> None:
+        self.melo_speed = melo_config.speed
+        self.melo_device = melo_config.device
+        self.melo_language = melo_config.language
+        self.melo_speaker = melo_config.speaker
+        self.melo_url = ''
+        self.melo_mode = melo_config.mode
+        self.melotts = None
+        self.speaker_ids = None
+        if self.melo_mode == 'local':
+            self.melotts = MELOTTS(language=self.melo_language, device=self.melo_device)
+            self.speaker_ids = self.melotts.hps.data.spk2id
+        else:
+            self.melo_url = melo_config.url
+        logging.info('#### Initializing MeloTTS Service in ' + self.melo_device + ' mode...')
+
+    @logging_time(logger=logger)
+    def cosyvoice_model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
+        self.cosyvoice_speed = cosyvoice_config.speed
+        self.cosyvoice_device = cosyvoice_config.device
+        self.cosyvoice_language = cosyvoice_config.language
+        self.cosyvoice_speaker = cosyvoice_config.speaker
+        self.cosyvoice_url = ''
+        self.cosyvoice_mode = cosyvoice_config.mode
+        self.cosyvoicetts = None
+        # os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
+        if self.cosyvoice_mode == 'local':
+            self.cosyvoicetts = CosyVoice('/home/gpu/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
+
+        else:
+            self.cosyvoice_url = cosyvoice_config.url
+        logging.info('#### Initializing CosyVoiceTTS Service in cuda:' + self.cosyvoice_device + ' mode...')
+
+    @inject
+    def __init__(self, melo_config: MeloConf, cosyvoice_config: CosyVoiceConf, settings: dict) -> None:
+        self.tts_service = TTService("yunfeineo")
+        self.melo_model_init(melo_config)
+        self.cosyvoice_model_init(cosyvoice_config)
 
-    def __init__(self, *args, **kwargs) -> None:
-        self.tts_service = TTService("catmaid")
-
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
 
-    def processing(self, *args, **kwargs) -> io.BytesIO:
+    @logging_time(logger=logger)
+    def processing(self, *args, settings: dict) -> io.BytesIO:
+
+        print("\nChat Settings: ", settings)
+        if settings is None:
+            settings = {}
+        user_model_name = settings.get("tts_model_name")
+        print(f"tts_model_name: {user_model_name}")
+
         text = args[0]
         current_time = time.time()
-        audio = self.tts_service.read(text)
-        print("#### TTS Service consume : ", (time.time()-current_time))
-        return audio
+        if user_model_name == 'melotts':
+            if self.melo_mode == 'local':
+                audio = self.melotts.tts_to_file(text, self.speaker_ids[self.melo_speaker], speed=self.melo_speed)
+                f = io.BytesIO()
+                soundfile.write(f, audio, 44100, format='wav')
+                f.seek(0)
+
+                # Read the audio data from the buffer
+                data, rate = soundfile.read(f, dtype='float32')
+
+                # Peak normalization
+                peak_normalized_audio = pyln.normalize.peak(data, -1.0)
+
+                # Integrated loudness normalization
+                meter = pyln.Meter(rate)
+                loudness = meter.integrated_loudness(peak_normalized_audio)
+                loudness_normalized_audio = pyln.normalize.loudness(peak_normalized_audio, loudness, -12.0)
+
+                # Write the loudness normalized audio to an in-memory buffer
+                normalized_audio_buffer = io.BytesIO()
+                soundfile.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
+                normalized_audio_buffer.seek(0)
+
+                print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
+                return normalized_audio_buffer.read()
+
+            else:
+                message = {
+                    "text": text
+                }
+                response = requests.post(self.melo_url, json=message)
+                print("#### MeloTTS Service consume - docker : ", (time.time()-current_time))
+                return response.content
+
+        elif user_model_name == 'cosyvoicetts':
+            if self.cosyvoice_mode == 'local':
+                audio = self.cosyvoicetts.inference_sft(text, self.cosyvoice_language)
+                f = io.BytesIO()
+                soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
+                f.seek(0)
+                print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
+                return f.read()
+            else:
+                message = {
+                    "text": text
+                }
+                response = requests.post(self.cosyvoice_url, json=message)
+                print("#### CosyVoiceTTS Service consume - docker : ", (time.time()-current_time))
+                return response.content
+        else:
+            audio = self.tts_service.read(text)
+            print("#### TTS Service consume : ", (time.time()-current_time))
+            return audio.read()
 
     def valid(self, *args, **kwargs) -> bool:
         text = args[0]
@@ -31,10 +158,12 @@ class TTS(Blackbox):
     async def fast_api_handler(self, request: Request) -> Response:
         try:
             data = await request.json()
+            print(f"data: {data}")
         except:
             return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
         text = data.get("text")
+        setting = data.get("settings")
         if text is None:
             return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
-        by = self.processing(text)
-        return Response(content=by.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
+        by = self.processing(text, settings=setting)
+        return Response(content=by, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
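The handler now threads a "settings" object from the request JSON into processing(), which dispatches between MeloTTS, CosyVoice, and the default TTService. A client sketch, assuming the blackbox is mounted at a /tts route on the server from main.py (the mount point is not shown in this commit); the "text" and "settings" keys and the raw WAV response come from fast_api_handler above:

import requests

# Assumed route and port; the JSON keys match fast_api_handler above.
resp = requests.post(
    "http://localhost:8000/tts",
    json={"text": "你好", "settings": {"tts_model_name": "melotts"}},
)
with open("reply.wav", "wb") as f:
    f.write(resp.content)  # raw WAV bytes from Response(content=by, ...)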
src/blackbox/tts_bak.py (new file, 40 lines)

@@ -0,0 +1,40 @@
+import io
+import time
+from ntpath import join
+
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
+from .blackbox import Blackbox
+from ..tts.tts_service import TTService
+from injector import singleton
+
+@singleton
+class TTS(Blackbox):
+
+    def __init__(self, *args, **kwargs) -> None:
+        self.tts_service = TTService("yunfeineo")
+
+    def __call__(self, *args, **kwargs):
+        return self.processing(*args, **kwargs)
+
+    def processing(self, *args, **kwargs) -> io.BytesIO:
+        text = args[0]
+        current_time = time.time()
+        audio = self.tts_service.read(text)
+        print("#### TTS Service consume : ", (time.time()-current_time))
+        return audio
+
+    def valid(self, *args, **kwargs) -> bool:
+        text = args[0]
+        return isinstance(text, str)
+
+    async def fast_api_handler(self, request: Request) -> Response:
+        try:
+            data = await request.json()
+        except:
+            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
+        text = data.get("text")
+        if text is None:
+            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
+        by = self.processing(text)
+        return Response(content=by.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
(file name not shown)

@@ -1,11 +1,19 @@
 from fastapi import Request, Response, status
 from fastapi.responses import JSONResponse
-from .blackbox import Blackbox
+from injector import singleton,inject
 from typing import Optional
+
+from .blackbox import Blackbox
+from ..log.logging_time import logging_time
+from .chroma_query import ChromaQuery
+from ..configuration import VLMConf
+
 import requests
 import base64
 
+import io
+from PIL import Image
+from lmdeploy.serve.openai.api_client import APIClient
+
 def is_base64(value) -> bool:
     try:
@@ -14,9 +22,16 @@ def is_base64(value) -> bool:
     except Exception:
         return False
 
+@singleton
 class VLMS(Blackbox):
+
+    @inject
+    def __init__(self, vlm_config: VLMConf):
+        # Chroma database initially set up for RAG for vision model.
+        # It could be expended to history store.
+        # self.chroma_query = chroma_query
+        self.url = vlm_config.url
+
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
 
@@ -26,6 +41,7 @@ class VLMS(Blackbox):
 
     def processing(self, prompt, images, model_name: Optional[str] = None) -> str:
 
+        # Current only Qwen-vl model
         if model_name == "Qwen-VL-Chat":
             model_name = "infer-qwen-vl"
         elif model_name == "llava-llama-3-8b-v1_1-transformers":
@@ -33,19 +49,61 @@ class VLMS(Blackbox):
         else:
             model_name = "infer-qwen-vl"
 
-        url = 'http://120.196.116.194:48894/' + model_name + '/'
+        ## AutoLoad Model
+        # url = 'http://10.6.80.87:8000/' + model_name + '/'
+
         if is_base64(images):
             images_data = images
         else:
+
+            # print("{}Type of image data in form {}".format('#'*20,type(images)))
+            # print("{}Type of image data in form {}".format('#'*20,type(images.file)))
+            # byte_stream = io.BytesIO(images.read())
+            # print("{}Type of image data in form {}".format('#'*20,type(byte_stream)))
+            # roiImg = Image.open(byte_stream)
+            # print("{}Successful {}".format('#'*20,type(roiImg)))
+            # return str(type(byte_stream))
+            # images_data = base64.b64encode(byte_stream)
+
             with open(images, "rb") as img_file:
-                images_data = str(base64.b64encode(img_file.read()), 'utf-8')
+                # images_data = str(base64.b64encode(img_file.read()), 'utf-8')
+                images_data = base64.b64encode(img_file.read())
 
-        data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
+        # data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
 
-        data = requests.post(url, json=data_input)
+        # data = requests.post(url, json=data_input)
+        # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
 
-        return data.text
+        ## Lmdeploy
+        api_client = APIClient(self.url)
+        model_name = api_client.available_models[0]
+        messages = [{
+            'role': 'user',
+            'content': [{
+                'type': 'text',
+                'text': prompt,
+            }, {
+                'type': 'image_url',
+                'image_url': {
+                    'url': f"data:image/jpeg;base64,{images_data}",
+                    # './val_data/image_5.jpg',
+                },
+            }]
+        }
+        ]
+
+        responses = ''
+        for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
+                                                               messages=messages#,stream = True
+                                                               )):
+            # print(item["choices"][0]["message"]['content'])
+            responses += item["choices"][0]["message"]['content']
+
+        return responses
+
+        # return data.text
 
     async def fast_api_handler(self, request: Request) -> Response:
         try:
@@ -63,5 +121,6 @@ class VLMS(Blackbox):
         if model_name is None or model_name.isspace():
             model_name = "Qwen-VL-Chat"
 
-        jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
-        return JSONResponse(content={"response": jsonresp}, status_code=status.HTTP_200_OK)
+        # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
+
+        return JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}, status_code=status.HTTP_200_OK)
(file name not shown)

@@ -82,6 +82,23 @@ class CosyVoiceConf():
         self.language = config.get("cosyvoicetts.language")
         self.speaker = config.get("cosyvoicetts.speaker")
 
+class SenseVoiceConf():
+    mode: str
+    url: str
+    speed: int
+    device: str
+    language: str
+    speaker: str
+
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.mode = config.get("sensevoiceasr.mode")
+        self.url = config.get("sensevoiceasr.url")
+        self.speed = config.get("sensevoiceasr.speed")
+        self.device = config.get("sensevoiceasr.device")
+        self.language = config.get("sensevoiceasr.language")
+        self.speaker = config.get("sensevoiceasr.speaker")
+
 # 'CRITICAL': CRITICAL,
 # 'FATAL': FATAL,
 # 'ERROR': ERROR,
@@ -110,6 +127,7 @@ class LogConf():
         self.filename = config.get("log.filename")
         self.time_format = config.get("log.time_format", default=DEFAULT_TIME_FORMAT)
 
+
 @singleton
 class EnvConf():
     version: str
@@ -129,3 +147,10 @@ class BlackboxConf():
     @inject
     def __init__(self, config: Configuration) -> None:
         self.lazyloading = bool(config.get("blackbox.lazyloading", default=False))
+
+@singleton
+class VLMConf():
+
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.url = config.get("vlms.url")
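SenseVoiceConf reads six keys from the shared Configuration. A hedged sketch of the corresponding .env.yaml entries, written as Python comments because the exact nesting that Configuration.get uses for "sensevoiceasr.mode" is not shown in this commit; the key names come from the class above, all values are illustrative assumptions:

# sensevoiceasr:
#   mode: local        # 'local' loads AutoModel; anything else uses the url
#   url: ""            # remote ASR endpoint when mode is not 'local'
#   speed: 1
#   device: "0"        # also concatenated into the startup log message
#   language: auto
#   speaker: ""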
(file name not shown)

@@ -21,6 +21,60 @@ logging.basicConfig(level=logging.INFO)
 dirbaspath = __file__.split("\\")[1:-1]
 dirbaspath= "/home/gpu/Workspace/jarvis-models/src/tts" + "/".join(dirbaspath)
 config = {
+    'ayaka': {
+        'cfg': dirbaspath + '/models/ayaka.json',
+        'model': dirbaspath + '/models/ayaka_167k.pth',
+        'char': 'character_ayaka',
+        'speed': 1
+    },
+    'catmix': {
+        'cfg': dirbaspath + '/models/catmix.json',
+        'model': dirbaspath + '/models/catmix_107k.pth',
+        'char': 'character_catmix',
+        'speed': 1.1
+    },
+    'noelle': {
+        'cfg': dirbaspath + '/models/noelle.json',
+        'model': dirbaspath + '/models/noelle_337k.pth',
+        'char': 'character_noelle',
+        'speed': 1.1
+    },
+    'miko': {
+        'cfg': dirbaspath + '/models/miko.json',
+        'model': dirbaspath + '/models/miko_139k.pth',
+        'char': 'character_miko',
+        'speed': 1.1
+    },
+    'nahida': {
+        'cfg': dirbaspath + '/models/nahida.json',
+        'model': dirbaspath + '/models/nahida_129k.pth',
+        'char': 'character_nahida',
+        'speed': 1.1
+    },
+    'ningguang': {
+        'cfg': dirbaspath + '/models/ningguang.json',
+        'model': dirbaspath + '/models/ningguang_179k.pth',
+        'char': 'character_ningguang',
+        'speed': 1.1
+    },
+    'yoimiya': {
+        'cfg': dirbaspath + '/models/yoimiya.json',
+        'model': dirbaspath + '/models/yoimiya_102k.pth',
+        'char': 'character_yoimiya',
+        'speed': 1.1
+    },
+    'yunfeineo': {
+        'cfg': dirbaspath + '/models/yunfeineo.json',
+        'model': dirbaspath + '/models/yunfeineo_25k.pth',
+        'char': 'character_yunfeineo',
+        'speed': 1.1
+    },
+    'zhongli': {
+        'cfg': dirbaspath + '/models/zhongli.json',
+        'model': dirbaspath + '/models/zhongli_44k.pth',
+        'char': 'character_',
+        'speed': 1.1
+    },
     'paimon': {
         'cfg': dirbaspath + '/models/paimon6k.json',
         'model': dirbaspath + '/models/paimon6k_390k.pth',
@@ -28,7 +82,7 @@ config = {
         'speed': 1
     },
     'yunfei': {
-        'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
+        'cfg': dirbaspath + '/models/yunfeimix2.json',
         'model': dirbaspath + '/models/yunfeimix2_53k.pth',
         'char': 'character_yunfei',
         'speed': 1.1
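Each entry in the voice table above maps a speaker name to its model config, checkpoint, character key, and speed. A usage sketch, assuming TTService resolves the name against this table; TTService("yunfeineo") and .read(text) both appear in tts.py in this same commit:

from src.tts.tts_service import TTService

svc = TTService("yunfeineo")   # picks cfg/model/char/speed from the config dict above
audio = svc.read("Hello from jarvis-models")  # audio buffer consumed by the TTS blackbox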
New binary files (content not shown):
  test_data/voice/2food.wav
  test_data/voice/2forget.wav
  test_data/voice/2nihao.wav
  test_data/voice/2play.wav
  test_data/voice/2weather.wav
  test_data/voice/food.wav
  test_data/voice/forget.wav
  test_data/voice/hello.wav
  test_data/voice/nihao.wav
  test_data/voice/play.wav