Merge pull request #1 from BoardWare-Genius/refactor

Refactor
This commit is contained in:
Dan218
2024-03-27 16:44:49 +08:00
committed by GitHub
91 changed files with 61 additions and 34 deletions

View File

@ -13,4 +13,4 @@
Dev rh Dev rh
```bash ```bash
uvicorn main:app --reload uvicorn main:app --reload
``` ```

View File

@ -4,6 +4,7 @@ from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from src.blackbox.blackbox_factory import BlackboxFactory from src.blackbox.blackbox_factory import BlackboxFactory
import uvicorn
app = FastAPI() app = FastAPI()
@ -14,11 +15,14 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
if not blackbox_name: if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
try: try:
box = blackbox_factory.create_blackbox(blackbox_name, {}) box = blackbox_factory.create_blackbox(blackbox_name)
except ValueError: except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request) return await box.fast_api_handler(request)
@app.post("/workflows") @app.post("/workflows")
async def workflows(reqest: Request): async def workflows(reqest: Request):
print("workflows") print("workflows")
if __name__ == "__main__":
uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")

View File

@ -13,7 +13,9 @@ class ASR(Blackbox):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
config = read_yaml(args[0]) config = read_yaml(args[0])
self.paraformer = RapidParaformer(config) self.paraformer = RapidParaformer(config)
super().__init__(config)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
async def processing(self, *args, **kwargs): async def processing(self, *args, **kwargs):
data = args[0] data = args[0]

View File

@ -12,7 +12,7 @@ class BlackboxFactory:
def __init__(self) -> None: def __init__(self) -> None:
self.tts = TTS() self.tts = TTS()
self.asr = ASR("./.env.yaml") self.asr = ASR(".env.yaml")
self.sentiment = Sentiment() self.sentiment = Sentiment()
self.sum = SUM() self.sum = SUM()
self.calculator = Calculator() self.calculator = Calculator()

View File

@ -3,14 +3,14 @@ from typing import Any, Coroutine
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sentiment_engine.sentiment_engine import SentimentEngine from ..sentiment_engine.sentiment_engine import SentimentEngine
from .blackbox import Blackbox from .blackbox import Blackbox
class Sentiment(Blackbox): class Sentiment(Blackbox):
def __init__(self) -> None: def __init__(self) -> None:
self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx') self.engine = SentimentEngine()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)

View File

@ -23,16 +23,16 @@ class Tesou(Blackbox):
"user_id": id, "user_id": id,
"prompt": prompt, "prompt": prompt,
} }
print(message)
response = requests.post(url, json=message) response = requests.post(url, json=message)
return response return response.json()
async def fast_api_handler(self, request: Request) -> Response: async def fast_api_handler(self, request: Request) -> Response:
try: try:
data = await request.json() data = await request.json()
except: except:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_id = data.get("id") user_id = data.get("user_id")
user_prompt = data.get("prompt") user_prompt = data.get("prompt")
if user_prompt is None: if user_prompt is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -1,21 +1,16 @@
import io import io
from ntpath import join
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from .blackbox import Blackbox from .blackbox import Blackbox
from tts.tts_service import TTService from ..tts.tts_service import TTService
class TTS(Blackbox): class TTS(Blackbox):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
config = { self.tts_service = TTService("catmaid")
'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1],
'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1],
'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2]
}
self.tts_service = TTService(*config['catmaid'])
super().__init__(config)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)

View File

@ -4,12 +4,17 @@ import onnxruntime
from transformers import BertTokenizer from transformers import BertTokenizer
import numpy as np import numpy as np
dirabspath = __file__.split("\\")[1:-1]
dirabspath= "C://" + "/".join(dirabspath)
default_path = dirabspath + "/models/paimon_sentiment.onnx"
class SentimentEngine(): class SentimentEngine():
def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"): def __init__(self):
logging.info('Initializing Sentiment Engine...') logging.info('Initializing Sentiment Engine...')
onnx_model_path = model_path onnx_model_path = default_path
self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

View File

@ -1,40 +1,61 @@
import io import io
import sys import sys
import time
sys.path.append('tts/vits') sys.path.append('src/tts/vits')
import numpy as np
import soundfile import soundfile
import os import os
os.environ["PYTORCH_JIT"] = "0" os.environ["PYTORCH_JIT"] = "0"
import torch import torch
import tts.vits.commons as commons import src.tts.vits.commons as commons
import tts.vits.utils as utils import src.tts.vits.utils as utils
from tts.vits.models import SynthesizerTrn from src.tts.vits.models import SynthesizerTrn
from tts.vits.text.symbols import symbols from src.tts.vits.text.symbols import symbols
from tts.vits.text import text_to_sequence from src.tts.vits.text import text_to_sequence
import logging import logging
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
dirbaspath = __file__.split("\\")[1:-1]
dirbaspath= "C://" + "/".join(dirbaspath)
config = {
'paimon': {
'cfg': dirbaspath + '/models/paimon6k.json',
'model': dirbaspath + '/models/paimon6k_390k.pth',
'char': 'character_paimon',
'speed': 1
},
'yunfei': {
'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
'model': dirbaspath + '/models/yunfeimix2_53k.pth',
'char': 'character_yunfei',
'speed': 1.1
},
'catmaid': {
'cfg': dirbaspath + '/models/catmix.json',
'model': dirbaspath + '/models/catmix_107k.pth',
'char': 'character_catmaid',
'speed': 1.2
},
}
class TTService(): class TTService():
def __init__(self, cfg, model, char, speed): def __init__(self, model_name="catmaid"):
logging.info('Initializing TTS Service for %s...' % char) cfg = config[model_name]
self.hps = utils.get_hparams_from_file(cfg) logging.info('Initializing TTS Service for %s...' % cfg["char"])
self.speed = speed self.hps = utils.get_hparams_from_file(cfg["cfg"])
self.speed = cfg["speed"]
self.net_g = SynthesizerTrn( self.net_g = SynthesizerTrn(
len(symbols), len(symbols),
self.hps.data.filter_length // 2 + 1, self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length, self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model).cpu() **self.hps.model).cpu()
_ = self.net_g.eval() _ = self.net_g.eval()
_ = utils.load_checkpoint(model, self.net_g, None) _ = utils.load_checkpoint(cfg["model"], self.net_g, None)
def get_text(self, text, hps): def get_text(self, text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners) text_norm = text_to_sequence(text, hps.data.text_cleaners)

View File

Before

Width:  |  Height:  |  Size: 63 KiB

After

Width:  |  Height:  |  Size: 63 KiB

View File

Before

Width:  |  Height:  |  Size: 35 KiB

After

Width:  |  Height:  |  Size: 35 KiB

View File

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB