mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
8
main.py
8
main.py
@ -4,6 +4,7 @@ from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.blackbox.blackbox_factory import BlackboxFactory
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@ -14,11 +15,14 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
|
||||
if not blackbox_name:
|
||||
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
try:
|
||||
box = blackbox_factory.create_blackbox(blackbox_name, {})
|
||||
box = blackbox_factory.create_blackbox(blackbox_name)
|
||||
except ValueError:
|
||||
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
return await box.fast_api_handler(request)
|
||||
|
||||
@app.post("/workflows")
|
||||
async def workflows(reqest: Request):
|
||||
print("workflows")
|
||||
print("workflows")
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")
|
||||
|
||||
@ -13,7 +13,9 @@ class ASR(Blackbox):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
config = read_yaml(args[0])
|
||||
self.paraformer = RapidParaformer(config)
|
||||
super().__init__(config)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
async def processing(self, *args, **kwargs):
|
||||
data = args[0]
|
||||
|
||||
@ -12,7 +12,7 @@ class BlackboxFactory:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tts = TTS()
|
||||
self.asr = ASR("./.env.yaml")
|
||||
self.asr = ASR(".env.yaml")
|
||||
self.sentiment = Sentiment()
|
||||
self.sum = SUM()
|
||||
self.calculator = Calculator()
|
||||
|
||||
@ -3,14 +3,14 @@ from typing import Any, Coroutine
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from sentiment_engine.sentiment_engine import SentimentEngine
|
||||
from ..sentiment_engine.sentiment_engine import SentimentEngine
|
||||
from .blackbox import Blackbox
|
||||
|
||||
|
||||
class Sentiment(Blackbox):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx')
|
||||
self.engine = SentimentEngine()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
@ -23,16 +23,16 @@ class Tesou(Blackbox):
|
||||
"user_id": id,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
print(message)
|
||||
response = requests.post(url, json=message)
|
||||
return response
|
||||
return response.json()
|
||||
|
||||
async def fast_api_handler(self, request: Request) -> Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
except:
|
||||
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
user_id = data.get("id")
|
||||
user_id = data.get("user_id")
|
||||
user_prompt = data.get("prompt")
|
||||
if user_prompt is None:
|
||||
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@ -1,21 +1,16 @@
|
||||
import io
|
||||
from ntpath import join
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from .blackbox import Blackbox
|
||||
from tts.tts_service import TTService
|
||||
from ..tts.tts_service import TTService
|
||||
|
||||
class TTS(Blackbox):
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
config = {
|
||||
'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1],
|
||||
'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1],
|
||||
'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2]
|
||||
}
|
||||
self.tts_service = TTService(*config['catmaid'])
|
||||
super().__init__(config)
|
||||
|
||||
self.tts_service = TTService("catmaid")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.processing(*args, **kwargs)
|
||||
|
||||
|
||||
@ -4,12 +4,17 @@ import onnxruntime
|
||||
from transformers import BertTokenizer
|
||||
import numpy as np
|
||||
|
||||
dirabspath = __file__.split("\\")[1:-1]
|
||||
dirabspath= "C://" + "/".join(dirabspath)
|
||||
default_path = dirabspath + "/models/paimon_sentiment.onnx"
|
||||
|
||||
|
||||
class SentimentEngine():
|
||||
|
||||
def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"):
|
||||
def __init__(self):
|
||||
|
||||
logging.info('Initializing Sentiment Engine...')
|
||||
onnx_model_path = model_path
|
||||
onnx_model_path = default_path
|
||||
self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
|
||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
|
||||
|
||||
@ -1,40 +1,61 @@
|
||||
import io
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.append('tts/vits')
|
||||
sys.path.append('src/tts/vits')
|
||||
|
||||
import numpy as np
|
||||
import soundfile
|
||||
import os
|
||||
os.environ["PYTORCH_JIT"] = "0"
|
||||
import torch
|
||||
|
||||
import tts.vits.commons as commons
|
||||
import tts.vits.utils as utils
|
||||
import src.tts.vits.commons as commons
|
||||
import src.tts.vits.utils as utils
|
||||
|
||||
from tts.vits.models import SynthesizerTrn
|
||||
from tts.vits.text.symbols import symbols
|
||||
from tts.vits.text import text_to_sequence
|
||||
from src.tts.vits.models import SynthesizerTrn
|
||||
from src.tts.vits.text.symbols import symbols
|
||||
from src.tts.vits.text import text_to_sequence
|
||||
|
||||
import logging
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
dirbaspath = __file__.split("\\")[1:-1]
|
||||
dirbaspath= "C://" + "/".join(dirbaspath)
|
||||
config = {
|
||||
'paimon': {
|
||||
'cfg': dirbaspath + '/models/paimon6k.json',
|
||||
'model': dirbaspath + '/models/paimon6k_390k.pth',
|
||||
'char': 'character_paimon',
|
||||
'speed': 1
|
||||
},
|
||||
'yunfei': {
|
||||
'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
|
||||
'model': dirbaspath + '/models/yunfeimix2_53k.pth',
|
||||
'char': 'character_yunfei',
|
||||
'speed': 1.1
|
||||
},
|
||||
'catmaid': {
|
||||
'cfg': dirbaspath + '/models/catmix.json',
|
||||
'model': dirbaspath + '/models/catmix_107k.pth',
|
||||
'char': 'character_catmaid',
|
||||
'speed': 1.2
|
||||
},
|
||||
}
|
||||
|
||||
class TTService():
|
||||
|
||||
def __init__(self, cfg, model, char, speed):
|
||||
logging.info('Initializing TTS Service for %s...' % char)
|
||||
self.hps = utils.get_hparams_from_file(cfg)
|
||||
self.speed = speed
|
||||
def __init__(self, model_name="catmaid"):
|
||||
cfg = config[model_name]
|
||||
logging.info('Initializing TTS Service for %s...' % cfg["char"])
|
||||
self.hps = utils.get_hparams_from_file(cfg["cfg"])
|
||||
self.speed = cfg["speed"]
|
||||
self.net_g = SynthesizerTrn(
|
||||
len(symbols),
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
**self.hps.model).cpu()
|
||||
_ = self.net_g.eval()
|
||||
_ = utils.load_checkpoint(model, self.net_g, None)
|
||||
_ = utils.load_checkpoint(cfg["model"], self.net_g, None)
|
||||
|
||||
def get_text(self, text, hps):
|
||||
text_norm = text_to_sequence(text, hps.data.text_cleaners)
|
||||
|
Before Width: | Height: | Size: 63 KiB After Width: | Height: | Size: 63 KiB |
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
Reference in New Issue
Block a user