Merge pull request #1 from BoardWare-Genius/refactor

Refactor
This commit is contained in:
Dan218
2024-03-27 16:44:49 +08:00
committed by GitHub
91 changed files with 61 additions and 34 deletions

View File

@ -4,6 +4,7 @@ from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from src.blackbox.blackbox_factory import BlackboxFactory
import uvicorn
app = FastAPI()
@ -14,7 +15,7 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
try:
box = blackbox_factory.create_blackbox(blackbox_name, {})
box = blackbox_factory.create_blackbox(blackbox_name)
except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request)
@ -22,3 +23,6 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
@app.post("/workflows")
async def workflows(reqest: Request):
print("workflows")
if __name__ == "__main__":
uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")

View File

@ -13,7 +13,9 @@ class ASR(Blackbox):
def __init__(self, *args, **kwargs) -> None:
config = read_yaml(args[0])
self.paraformer = RapidParaformer(config)
super().__init__(config)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
async def processing(self, *args, **kwargs):
data = args[0]

View File

@ -12,7 +12,7 @@ class BlackboxFactory:
def __init__(self) -> None:
self.tts = TTS()
self.asr = ASR("./.env.yaml")
self.asr = ASR(".env.yaml")
self.sentiment = Sentiment()
self.sum = SUM()
self.calculator = Calculator()

View File

@ -3,14 +3,14 @@ from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sentiment_engine.sentiment_engine import SentimentEngine
from ..sentiment_engine.sentiment_engine import SentimentEngine
from .blackbox import Blackbox
class Sentiment(Blackbox):
def __init__(self) -> None:
self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx')
self.engine = SentimentEngine()
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)

View File

@ -23,16 +23,16 @@ class Tesou(Blackbox):
"user_id": id,
"prompt": prompt,
}
print(message)
response = requests.post(url, json=message)
return response
return response.json()
async def fast_api_handler(self, request: Request) -> Response:
try:
data = await request.json()
except:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_id = data.get("id")
user_id = data.get("user_id")
user_prompt = data.get("prompt")
if user_prompt is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -1,20 +1,15 @@
import io
from ntpath import join
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
from tts.tts_service import TTService
from ..tts.tts_service import TTService
class TTS(Blackbox):
def __init__(self, *args, **kwargs) -> None:
config = {
'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1],
'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1],
'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2]
}
self.tts_service = TTService(*config['catmaid'])
super().__init__(config)
self.tts_service = TTService("catmaid")
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)

View File

@ -4,12 +4,17 @@ import onnxruntime
from transformers import BertTokenizer
import numpy as np
dirabspath = __file__.split("\\")[1:-1]
dirabspath= "C://" + "/".join(dirabspath)
default_path = dirabspath + "/models/paimon_sentiment.onnx"
class SentimentEngine():
def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"):
def __init__(self):
logging.info('Initializing Sentiment Engine...')
onnx_model_path = model_path
onnx_model_path = default_path
self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

View File

@ -1,40 +1,61 @@
import io
import sys
import time
sys.path.append('tts/vits')
sys.path.append('src/tts/vits')
import numpy as np
import soundfile
import os
os.environ["PYTORCH_JIT"] = "0"
import torch
import tts.vits.commons as commons
import tts.vits.utils as utils
import src.tts.vits.commons as commons
import src.tts.vits.utils as utils
from tts.vits.models import SynthesizerTrn
from tts.vits.text.symbols import symbols
from tts.vits.text import text_to_sequence
from src.tts.vits.models import SynthesizerTrn
from src.tts.vits.text.symbols import symbols
from src.tts.vits.text import text_to_sequence
import logging
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
dirbaspath = __file__.split("\\")[1:-1]
dirbaspath= "C://" + "/".join(dirbaspath)
config = {
'paimon': {
'cfg': dirbaspath + '/models/paimon6k.json',
'model': dirbaspath + '/models/paimon6k_390k.pth',
'char': 'character_paimon',
'speed': 1
},
'yunfei': {
'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
'model': dirbaspath + '/models/yunfeimix2_53k.pth',
'char': 'character_yunfei',
'speed': 1.1
},
'catmaid': {
'cfg': dirbaspath + '/models/catmix.json',
'model': dirbaspath + '/models/catmix_107k.pth',
'char': 'character_catmaid',
'speed': 1.2
},
}
class TTService():
def __init__(self, cfg, model, char, speed):
logging.info('Initializing TTS Service for %s...' % char)
self.hps = utils.get_hparams_from_file(cfg)
self.speed = speed
def __init__(self, model_name="catmaid"):
cfg = config[model_name]
logging.info('Initializing TTS Service for %s...' % cfg["char"])
self.hps = utils.get_hparams_from_file(cfg["cfg"])
self.speed = cfg["speed"]
self.net_g = SynthesizerTrn(
len(symbols),
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model).cpu()
_ = self.net_g.eval()
_ = utils.load_checkpoint(model, self.net_g, None)
_ = utils.load_checkpoint(cfg["model"], self.net_g, None)
def get_text(self, text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners)

View File

Before

Width:  |  Height:  |  Size: 63 KiB

After

Width:  |  Height:  |  Size: 63 KiB

View File

Before

Width:  |  Height:  |  Size: 35 KiB

After

Width:  |  Height:  |  Size: 35 KiB

View File

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB