From 229501515aa837c30168865ec9be14f112c5104f Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 27 Mar 2024 16:10:46 +0800 Subject: [PATCH] ftemp --- README.md | 2 +- main.py | 7 +++++-- src/asr/asr.py | 4 +++- src/blackbox/blackbox_factory.py | 5 ++--- src/blackbox/tts.py | 8 ++------ tts/tts_service.py | 35 +++++++++++++++++++++++++------- 6 files changed, 41 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 706cd9b..a798594 100644 --- a/README.md +++ b/README.md @@ -13,4 +13,4 @@ Dev rh ```bash uvicorn main:app --reload -``` \ No newline at end of file +``` diff --git a/main.py b/main.py index f9ee155..e8e2764 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ from fastapi import FastAPI, Request, status from fastapi.responses import JSONResponse from src.blackbox.blackbox_factory import BlackboxFactory +import uvicorn app = FastAPI() @@ -11,7 +12,6 @@ blackbox_factory = BlackboxFactory() @app.post("/") async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): - print(blackbox_name) if not blackbox_name: return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) try: @@ -22,4 +22,7 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No @app.post("/workflows") async def workflows(reqest: Request): - print("workflows") \ No newline at end of file + print("workflows") + +if __name__ == "__main__": + uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info") diff --git a/src/asr/asr.py b/src/asr/asr.py index f75db00..fa879f8 100644 --- a/src/asr/asr.py +++ b/src/asr/asr.py @@ -13,7 +13,9 @@ class ASR(Blackbox): def __init__(self, *args, **kwargs) -> None: config = read_yaml(args[0]) self.paraformer = RapidParaformer(config) - super().__init__(config) + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) async def processing(self, *args, **kwargs): data = args[0] diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 60d7db6..ff0b4ea 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -12,8 +12,8 @@ class BlackboxFactory: def __init__(self) -> None: self.tts = TTS() - #self.asr = ASR("./.env.yaml") - #self.sentiment = Sentiment() + self.asr = ASR(".env.yaml") + self.sentiment = Sentiment() #self.sum = SUM() #self.calculator = Calculator() #self.audio_to_text = AudioToText() @@ -24,7 +24,6 @@ class BlackboxFactory: return self.processing(*args, **kwargs) def create_blackbox(self, blackbox_name: str) -> Blackbox: - return self.tts if blackbox_name == "audio_to_text": return self.audio_to_text if blackbox_name == "text_to_audio": diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index 6a11657..eada390 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -1,4 +1,5 @@ import io +from ntpath import join from fastapi import Request, Response, status from fastapi.responses import JSONResponse @@ -8,12 +9,7 @@ from tts.tts_service import TTService class TTS(Blackbox): def __init__(self, *args, **kwargs) -> None: - config = { - 'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1], - 'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1], - 'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2] - } - self.tts_service = TTService(*config['catmaid']) + self.tts_service = TTService("catmaid") def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/tts/tts_service.py b/tts/tts_service.py index 0011315..45eff7f 100644 --- a/tts/tts_service.py +++ b/tts/tts_service.py @@ -1,10 +1,8 @@ import io import sys -import time sys.path.append('tts/vits') -import numpy as np import soundfile import os os.environ["PYTORCH_JIT"] = "0" @@ -21,20 +19,43 @@ import logging logging.getLogger().setLevel(logging.INFO) logging.basicConfig(level=logging.INFO) +dirbaspath = __file__.split("\\")[1:-1] +dirbaspath= "C://" + "/".join(dirbaspath) +config = { + 'paimon': { + 'cfg': dirbaspath + '/models/paimon6k.json', + 'model': dirbaspath + '/models/paimon6k_390k.pth', + 'char': 'character_paimon', + 'speed': 1 + }, + 'yunfei': { + 'cfg': dirbaspath + '/tts/models/yunfeimix2.json', + 'model': dirbaspath + '/models/yunfeimix2_53k.pth', + 'char': 'character_yunfei', + 'speed': 1.1 + }, + 'catmaid': { + 'cfg': dirbaspath + '/models/catmix.json', + 'model': dirbaspath + '/models/catmix_107k.pth', + 'char': 'character_catmaid', + 'speed': 1.2 + }, +} class TTService(): - def __init__(self, cfg, model, char, speed): - logging.info('Initializing TTS Service for %s...' % char) - self.hps = utils.get_hparams_from_file(cfg) - self.speed = speed + def __init__(self, model_name="catmaid"): + cfg = config[model_name] + logging.info('Initializing TTS Service for %s...' % cfg["char"]) + self.hps = utils.get_hparams_from_file(cfg["cfg"]) + self.speed = cfg["speed"] self.net_g = SynthesizerTrn( len(symbols), self.hps.data.filter_length // 2 + 1, self.hps.train.segment_size // self.hps.data.hop_length, **self.hps.model).cpu() _ = self.net_g.eval() - _ = utils.load_checkpoint(model, self.net_g, None) + _ = utils.load_checkpoint(cfg["model"], self.net_g, None) def get_text(self, text, hps): text_norm = text_to_sequence(text, hps.data.text_cleaners)