From 0549a033e1518860c8b1d2ef0b8986cba530c0c6 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 27 Mar 2024 15:06:30 +0800 Subject: [PATCH 01/11] temp --- main.py | 3 ++- src/blackbox/blackbox_factory.py | 15 ++++++++------- src/blackbox/tts.py | 3 +-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 01c54cc..f9ee155 100644 --- a/main.py +++ b/main.py @@ -11,10 +11,11 @@ blackbox_factory = BlackboxFactory() @app.post("/") async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): + print(blackbox_name) if not blackbox_name: return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) try: - box = blackbox_factory.create_blackbox(blackbox_name, {}) + box = blackbox_factory.create_blackbox(blackbox_name) except ValueError: return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await box.fast_api_handler(request) diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index f4a6f36..60d7db6 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -12,18 +12,19 @@ class BlackboxFactory: def __init__(self) -> None: self.tts = TTS() - self.asr = ASR("./.env.yaml") - self.sentiment = Sentiment() - self.sum = SUM() - self.calculator = Calculator() - self.audio_to_text = AudioToText() - self.text_to_audio = TextToAudio() - self.tesou = Tesou() + #self.asr = ASR("./.env.yaml") + #self.sentiment = Sentiment() + #self.sum = SUM() + #self.calculator = Calculator() + #self.audio_to_text = AudioToText() + #self.text_to_audio = TextToAudio() + #self.tesou = Tesou() def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) def create_blackbox(self, blackbox_name: str) -> Blackbox: + return self.tts if blackbox_name == "audio_to_text": return self.audio_to_text if blackbox_name == "text_to_audio": diff --git a/src/blackbox/tts.py 
b/src/blackbox/tts.py index f030692..6a11657 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -14,8 +14,7 @@ class TTS(Blackbox): 'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2] } self.tts_service = TTService(*config['catmaid']) - super().__init__(config) - + def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) From 229501515aa837c30168865ec9be14f112c5104f Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 27 Mar 2024 16:10:46 +0800 Subject: [PATCH 02/11] ftemp --- README.md | 2 +- main.py | 7 +++++-- src/asr/asr.py | 4 +++- src/blackbox/blackbox_factory.py | 5 ++--- src/blackbox/tts.py | 8 ++------ tts/tts_service.py | 35 +++++++++++++++++++++++++------- 6 files changed, 41 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 706cd9b..a798594 100644 --- a/README.md +++ b/README.md @@ -13,4 +13,4 @@ Dev rh ```bash uvicorn main:app --reload -``` \ No newline at end of file +``` diff --git a/main.py b/main.py index f9ee155..e8e2764 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ from fastapi import FastAPI, Request, status from fastapi.responses import JSONResponse from src.blackbox.blackbox_factory import BlackboxFactory +import uvicorn app = FastAPI() @@ -11,7 +12,6 @@ blackbox_factory = BlackboxFactory() @app.post("/") async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): - print(blackbox_name) if not blackbox_name: return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) try: @@ -22,4 +22,7 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No @app.post("/workflows") async def workflows(reqest: Request): - print("workflows") \ No newline at end of file + print("workflows") + +if __name__ == "__main__": + uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info") diff --git a/src/asr/asr.py b/src/asr/asr.py index 
f75db00..fa879f8 100644 --- a/src/asr/asr.py +++ b/src/asr/asr.py @@ -13,7 +13,9 @@ class ASR(Blackbox): def __init__(self, *args, **kwargs) -> None: config = read_yaml(args[0]) self.paraformer = RapidParaformer(config) - super().__init__(config) + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) async def processing(self, *args, **kwargs): data = args[0] diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 60d7db6..ff0b4ea 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -12,8 +12,8 @@ class BlackboxFactory: def __init__(self) -> None: self.tts = TTS() - #self.asr = ASR("./.env.yaml") - #self.sentiment = Sentiment() + self.asr = ASR(".env.yaml") + self.sentiment = Sentiment() #self.sum = SUM() #self.calculator = Calculator() #self.audio_to_text = AudioToText() @@ -24,7 +24,6 @@ class BlackboxFactory: return self.processing(*args, **kwargs) def create_blackbox(self, blackbox_name: str) -> Blackbox: - return self.tts if blackbox_name == "audio_to_text": return self.audio_to_text if blackbox_name == "text_to_audio": diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index 6a11657..eada390 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -1,4 +1,5 @@ import io +from ntpath import join from fastapi import Request, Response, status from fastapi.responses import JSONResponse @@ -8,12 +9,7 @@ from tts.tts_service import TTService class TTS(Blackbox): def __init__(self, *args, **kwargs) -> None: - config = { - 'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1], - 'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1], - 'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2] - } - self.tts_service = TTService(*config['catmaid']) + self.tts_service = TTService("catmaid") 
def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/tts/tts_service.py b/tts/tts_service.py index 0011315..45eff7f 100644 --- a/tts/tts_service.py +++ b/tts/tts_service.py @@ -1,10 +1,8 @@ import io import sys -import time sys.path.append('tts/vits') -import numpy as np import soundfile import os os.environ["PYTORCH_JIT"] = "0" @@ -21,20 +19,43 @@ import logging logging.getLogger().setLevel(logging.INFO) logging.basicConfig(level=logging.INFO) +dirbaspath = __file__.split("\\")[1:-1] +dirbaspath= "C://" + "/".join(dirbaspath) +config = { + 'paimon': { + 'cfg': dirbaspath + '/models/paimon6k.json', + 'model': dirbaspath + '/models/paimon6k_390k.pth', + 'char': 'character_paimon', + 'speed': 1 + }, + 'yunfei': { + 'cfg': dirbaspath + '/tts/models/yunfeimix2.json', + 'model': dirbaspath + '/models/yunfeimix2_53k.pth', + 'char': 'character_yunfei', + 'speed': 1.1 + }, + 'catmaid': { + 'cfg': dirbaspath + '/models/catmix.json', + 'model': dirbaspath + '/models/catmix_107k.pth', + 'char': 'character_catmaid', + 'speed': 1.2 + }, +} class TTService(): - def __init__(self, cfg, model, char, speed): - logging.info('Initializing TTS Service for %s...' % char) - self.hps = utils.get_hparams_from_file(cfg) - self.speed = speed + def __init__(self, model_name="catmaid"): + cfg = config[model_name] + logging.info('Initializing TTS Service for %s...' 
% cfg["char"]) + self.hps = utils.get_hparams_from_file(cfg["cfg"]) + self.speed = cfg["speed"] self.net_g = SynthesizerTrn( len(symbols), self.hps.data.filter_length // 2 + 1, self.hps.train.segment_size // self.hps.data.hop_length, **self.hps.model).cpu() _ = self.net_g.eval() - _ = utils.load_checkpoint(model, self.net_g, None) + _ = utils.load_checkpoint(cfg["model"], self.net_g, None) def get_text(self, text, hps): text_norm = text_to_sequence(text, hps.data.text_cleaners) From 2a0c0e047799a2caff459c58981cf24667f69559 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 27 Mar 2024 16:20:12 +0800 Subject: [PATCH 03/11] feat --- src/blackbox/sentiment.py | 4 ++-- src/blackbox/tts.py | 2 +- .../sentiment_engine}/sentiment_engine.py | 9 +++++++-- {tts => src/tts}/tts_service.py | 12 ++++++------ {tts => src/tts}/vits/.dockerignore | 0 {tts => src/tts}/vits/.gitignore | 0 {tts => src/tts}/vits/.vs/ProjectSettings.json | 0 {tts => src/tts}/vits/.vs/VSWorkspaceState.json | 0 {tts => src/tts}/vits/.vs/slnx.sqlite | Bin {tts => src/tts}/vits/.vs/vits/v17/.suo | Bin {tts => src/tts}/vits/Dockerfile | 0 {tts => src/tts}/vits/LICENSE | 0 .../tts}/vits/Libtorch C++ Infer/VITS-LibTorch.cpp | 0 .../tts}/vits/Libtorch C++ Infer/toLibTorch.ipynb | 0 {tts => src/tts}/vits/README.md | 0 {tts => src/tts}/vits/attentions.py | 0 {tts => src/tts}/vits/colab.ipynb | 0 {tts => src/tts}/vits/commons.py | 0 {tts => src/tts}/vits/configs/chinese_base.json | 0 {tts => src/tts}/vits/configs/cjke_base.json | 0 {tts => src/tts}/vits/configs/cjks_base.json | 0 {tts => src/tts}/vits/configs/japanese_base.json | 0 {tts => src/tts}/vits/configs/japanese_base2.json | 0 .../tts}/vits/configs/japanese_ss_base2.json | 0 {tts => src/tts}/vits/configs/korean_base.json | 0 {tts => src/tts}/vits/configs/sanskrit_base.json | 0 .../tts}/vits/configs/shanghainese_base.json | 0 .../tts}/vits/configs/zero_japanese_base2.json | 0 .../tts}/vits/configs/zh_ja_mixture_base.json | 0 {tts => 
src/tts}/vits/data_utils.py | 0 .../tts}/vits/filelists/cjke_train_filelist.txt | 0 .../vits/filelists/cjke_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/cjke_val_filelist.txt | 0 .../vits/filelists/cjke_val_filelist.txt.cleaned | 0 .../tts}/vits/filelists/cjks_train_filelist.txt | 0 .../vits/filelists/cjks_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/cjks_val_filelist.txt | 0 .../vits/filelists/cjks_val_filelist.txt.cleaned | 0 .../tts}/vits/filelists/fox_train_filelist.txt | 0 .../vits/filelists/fox_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/fox_val_filelist.txt | 0 .../vits/filelists/fox_val_filelist.txt.cleaned | 0 .../tts}/vits/filelists/mix_train_filelist.txt | 0 .../vits/filelists/mix_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/mix_val_filelist.txt | 0 .../vits/filelists/mix_val_filelist.txt.cleaned | 0 .../tts}/vits/filelists/sanskrit_train_filelist.txt | 0 .../filelists/sanskrit_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/sanskrit_val_filelist.txt | 0 .../filelists/sanskrit_val_filelist.txt.cleaned | 0 .../filelists/zaonhe_train_filelist.txt.cleaned | 0 .../vits/filelists/zaonhe_val_filelist.txt.cleaned | 0 .../tts}/vits/filelists/zero_train_filelist.txt | 0 .../vits/filelists/zero_train_filelist.txt.cleaned | 0 .../tts}/vits/filelists/zero_val_filelist.txt | 0 .../vits/filelists/zero_val_filelist.txt.cleaned | 0 {tts => src/tts}/vits/inference.ipynb | 0 {tts => src/tts}/vits/losses.py | 0 {tts => src/tts}/vits/mel_processing.py | 0 {tts => src/tts}/vits/models.py | 0 {tts => src/tts}/vits/modules.py | 0 {tts => src/tts}/vits/monotonic_align/__init__.py | 0 {tts => src/tts}/vits/monotonic_align/core.pyx | 0 {tts => src/tts}/vits/monotonic_align/setup.py | 0 {tts => src/tts}/vits/preprocess.py | 0 {tts => src/tts}/vits/requirements.txt | 0 {tts => src/tts}/vits/resources/fig_1a.png | Bin {tts => src/tts}/vits/resources/fig_1b.png | Bin {tts => src/tts}/vits/resources/training.png | Bin {tts => 
src/tts}/vits/text/LICENSE | 0 {tts => src/tts}/vits/text/__init__.py | 0 {tts => src/tts}/vits/text/cantonese.py | 0 {tts => src/tts}/vits/text/cleaners.py | 0 {tts => src/tts}/vits/text/english.py | 0 {tts => src/tts}/vits/text/japanese.py | 0 {tts => src/tts}/vits/text/korean.py | 0 {tts => src/tts}/vits/text/mandarin.py | 0 {tts => src/tts}/vits/text/ngu_dialect.py | 0 {tts => src/tts}/vits/text/sanskrit.py | 0 {tts => src/tts}/vits/text/shanghainese.py | 0 {tts => src/tts}/vits/text/symbols.py | 0 {tts => src/tts}/vits/text/thai.py | 0 {tts => src/tts}/vits/train.py | 0 {tts => src/tts}/vits/train_ms.py | 0 {tts => src/tts}/vits/transforms.py | 0 {tts => src/tts}/vits/utils.py | 0 86 files changed, 16 insertions(+), 11 deletions(-) rename {sentiment_engine => src/sentiment_engine}/sentiment_engine.py (84%) rename {tts => src/tts}/tts_service.py (90%) rename {tts => src/tts}/vits/.dockerignore (100%) rename {tts => src/tts}/vits/.gitignore (100%) rename {tts => src/tts}/vits/.vs/ProjectSettings.json (100%) rename {tts => src/tts}/vits/.vs/VSWorkspaceState.json (100%) rename {tts => src/tts}/vits/.vs/slnx.sqlite (100%) rename {tts => src/tts}/vits/.vs/vits/v17/.suo (100%) rename {tts => src/tts}/vits/Dockerfile (100%) rename {tts => src/tts}/vits/LICENSE (100%) rename {tts => src/tts}/vits/Libtorch C++ Infer/VITS-LibTorch.cpp (100%) rename {tts => src/tts}/vits/Libtorch C++ Infer/toLibTorch.ipynb (100%) rename {tts => src/tts}/vits/README.md (100%) rename {tts => src/tts}/vits/attentions.py (100%) rename {tts => src/tts}/vits/colab.ipynb (100%) rename {tts => src/tts}/vits/commons.py (100%) rename {tts => src/tts}/vits/configs/chinese_base.json (100%) rename {tts => src/tts}/vits/configs/cjke_base.json (100%) rename {tts => src/tts}/vits/configs/cjks_base.json (100%) rename {tts => src/tts}/vits/configs/japanese_base.json (100%) rename {tts => src/tts}/vits/configs/japanese_base2.json (100%) rename {tts => src/tts}/vits/configs/japanese_ss_base2.json (100%) 
rename {tts => src/tts}/vits/configs/korean_base.json (100%) rename {tts => src/tts}/vits/configs/sanskrit_base.json (100%) rename {tts => src/tts}/vits/configs/shanghainese_base.json (100%) rename {tts => src/tts}/vits/configs/zero_japanese_base2.json (100%) rename {tts => src/tts}/vits/configs/zh_ja_mixture_base.json (100%) rename {tts => src/tts}/vits/data_utils.py (100%) rename {tts => src/tts}/vits/filelists/cjke_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/cjke_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/cjke_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/cjke_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/cjks_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/cjks_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/cjks_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/cjks_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/fox_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/fox_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/fox_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/fox_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/mix_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/mix_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/mix_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/mix_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/sanskrit_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/sanskrit_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/sanskrit_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/sanskrit_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/zaonhe_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/zaonhe_val_filelist.txt.cleaned 
(100%) rename {tts => src/tts}/vits/filelists/zero_train_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/zero_train_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/filelists/zero_val_filelist.txt (100%) rename {tts => src/tts}/vits/filelists/zero_val_filelist.txt.cleaned (100%) rename {tts => src/tts}/vits/inference.ipynb (100%) rename {tts => src/tts}/vits/losses.py (100%) rename {tts => src/tts}/vits/mel_processing.py (100%) rename {tts => src/tts}/vits/models.py (100%) rename {tts => src/tts}/vits/modules.py (100%) rename {tts => src/tts}/vits/monotonic_align/__init__.py (100%) rename {tts => src/tts}/vits/monotonic_align/core.pyx (100%) rename {tts => src/tts}/vits/monotonic_align/setup.py (100%) rename {tts => src/tts}/vits/preprocess.py (100%) rename {tts => src/tts}/vits/requirements.txt (100%) rename {tts => src/tts}/vits/resources/fig_1a.png (100%) rename {tts => src/tts}/vits/resources/fig_1b.png (100%) rename {tts => src/tts}/vits/resources/training.png (100%) rename {tts => src/tts}/vits/text/LICENSE (100%) rename {tts => src/tts}/vits/text/__init__.py (100%) rename {tts => src/tts}/vits/text/cantonese.py (100%) rename {tts => src/tts}/vits/text/cleaners.py (100%) rename {tts => src/tts}/vits/text/english.py (100%) rename {tts => src/tts}/vits/text/japanese.py (100%) rename {tts => src/tts}/vits/text/korean.py (100%) rename {tts => src/tts}/vits/text/mandarin.py (100%) rename {tts => src/tts}/vits/text/ngu_dialect.py (100%) rename {tts => src/tts}/vits/text/sanskrit.py (100%) rename {tts => src/tts}/vits/text/shanghainese.py (100%) rename {tts => src/tts}/vits/text/symbols.py (100%) rename {tts => src/tts}/vits/text/thai.py (100%) rename {tts => src/tts}/vits/train.py (100%) rename {tts => src/tts}/vits/train_ms.py (100%) rename {tts => src/tts}/vits/transforms.py (100%) rename {tts => src/tts}/vits/utils.py (100%) diff --git a/src/blackbox/sentiment.py b/src/blackbox/sentiment.py index 0981204..d17af4c 100644 --- 
a/src/blackbox/sentiment.py +++ b/src/blackbox/sentiment.py @@ -3,14 +3,14 @@ from typing import Any, Coroutine from fastapi import Request, Response, status from fastapi.responses import JSONResponse -from sentiment_engine.sentiment_engine import SentimentEngine +from ..sentiment_engine.sentiment_engine import SentimentEngine from .blackbox import Blackbox class Sentiment(Blackbox): def __init__(self) -> None: - self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx') + self.engine = SentimentEngine() def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index eada390..aea74d6 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -4,7 +4,7 @@ from ntpath import join from fastapi import Request, Response, status from fastapi.responses import JSONResponse from .blackbox import Blackbox -from tts.tts_service import TTService +from ..tts.tts_service import TTService class TTS(Blackbox): diff --git a/sentiment_engine/sentiment_engine.py b/src/sentiment_engine/sentiment_engine.py similarity index 84% rename from sentiment_engine/sentiment_engine.py rename to src/sentiment_engine/sentiment_engine.py index acb93d3..bc868fa 100644 --- a/sentiment_engine/sentiment_engine.py +++ b/src/sentiment_engine/sentiment_engine.py @@ -4,12 +4,17 @@ import onnxruntime from transformers import BertTokenizer import numpy as np +dirabspath = __file__.split("\\")[1:-1] +dirabspath= "C://" + "/".join(dirabspath) +default_path = dirabspath + "/models/paimon_sentiment.onnx" + class SentimentEngine(): - def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"): + def __init__(self): + logging.info('Initializing Sentiment Engine...') - onnx_model_path = model_path + onnx_model_path = default_path self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) self.tokenizer = 
BertTokenizer.from_pretrained('bert-base-chinese') diff --git a/tts/tts_service.py b/src/tts/tts_service.py similarity index 90% rename from tts/tts_service.py rename to src/tts/tts_service.py index 45eff7f..938df56 100644 --- a/tts/tts_service.py +++ b/src/tts/tts_service.py @@ -1,19 +1,19 @@ import io import sys -sys.path.append('tts/vits') +sys.path.append('src/tts/vits') import soundfile import os os.environ["PYTORCH_JIT"] = "0" import torch -import tts.vits.commons as commons -import tts.vits.utils as utils +import src.tts.vits.commons as commons +import src.tts.vits.utils as utils -from tts.vits.models import SynthesizerTrn -from tts.vits.text.symbols import symbols -from tts.vits.text import text_to_sequence +from src.tts.vits.models import SynthesizerTrn +from src.tts.vits.text.symbols import symbols +from src.tts.vits.text import text_to_sequence import logging logging.getLogger().setLevel(logging.INFO) diff --git a/tts/vits/.dockerignore b/src/tts/vits/.dockerignore similarity index 100% rename from tts/vits/.dockerignore rename to src/tts/vits/.dockerignore diff --git a/tts/vits/.gitignore b/src/tts/vits/.gitignore similarity index 100% rename from tts/vits/.gitignore rename to src/tts/vits/.gitignore diff --git a/tts/vits/.vs/ProjectSettings.json b/src/tts/vits/.vs/ProjectSettings.json similarity index 100% rename from tts/vits/.vs/ProjectSettings.json rename to src/tts/vits/.vs/ProjectSettings.json diff --git a/tts/vits/.vs/VSWorkspaceState.json b/src/tts/vits/.vs/VSWorkspaceState.json similarity index 100% rename from tts/vits/.vs/VSWorkspaceState.json rename to src/tts/vits/.vs/VSWorkspaceState.json diff --git a/tts/vits/.vs/slnx.sqlite b/src/tts/vits/.vs/slnx.sqlite similarity index 100% rename from tts/vits/.vs/slnx.sqlite rename to src/tts/vits/.vs/slnx.sqlite diff --git a/tts/vits/.vs/vits/v17/.suo b/src/tts/vits/.vs/vits/v17/.suo similarity index 100% rename from tts/vits/.vs/vits/v17/.suo rename to src/tts/vits/.vs/vits/v17/.suo diff --git 
a/tts/vits/Dockerfile b/src/tts/vits/Dockerfile similarity index 100% rename from tts/vits/Dockerfile rename to src/tts/vits/Dockerfile diff --git a/tts/vits/LICENSE b/src/tts/vits/LICENSE similarity index 100% rename from tts/vits/LICENSE rename to src/tts/vits/LICENSE diff --git a/tts/vits/Libtorch C++ Infer/VITS-LibTorch.cpp b/src/tts/vits/Libtorch C++ Infer/VITS-LibTorch.cpp similarity index 100% rename from tts/vits/Libtorch C++ Infer/VITS-LibTorch.cpp rename to src/tts/vits/Libtorch C++ Infer/VITS-LibTorch.cpp diff --git a/tts/vits/Libtorch C++ Infer/toLibTorch.ipynb b/src/tts/vits/Libtorch C++ Infer/toLibTorch.ipynb similarity index 100% rename from tts/vits/Libtorch C++ Infer/toLibTorch.ipynb rename to src/tts/vits/Libtorch C++ Infer/toLibTorch.ipynb diff --git a/tts/vits/README.md b/src/tts/vits/README.md similarity index 100% rename from tts/vits/README.md rename to src/tts/vits/README.md diff --git a/tts/vits/attentions.py b/src/tts/vits/attentions.py similarity index 100% rename from tts/vits/attentions.py rename to src/tts/vits/attentions.py diff --git a/tts/vits/colab.ipynb b/src/tts/vits/colab.ipynb similarity index 100% rename from tts/vits/colab.ipynb rename to src/tts/vits/colab.ipynb diff --git a/tts/vits/commons.py b/src/tts/vits/commons.py similarity index 100% rename from tts/vits/commons.py rename to src/tts/vits/commons.py diff --git a/tts/vits/configs/chinese_base.json b/src/tts/vits/configs/chinese_base.json similarity index 100% rename from tts/vits/configs/chinese_base.json rename to src/tts/vits/configs/chinese_base.json diff --git a/tts/vits/configs/cjke_base.json b/src/tts/vits/configs/cjke_base.json similarity index 100% rename from tts/vits/configs/cjke_base.json rename to src/tts/vits/configs/cjke_base.json diff --git a/tts/vits/configs/cjks_base.json b/src/tts/vits/configs/cjks_base.json similarity index 100% rename from tts/vits/configs/cjks_base.json rename to src/tts/vits/configs/cjks_base.json diff --git 
a/tts/vits/configs/japanese_base.json b/src/tts/vits/configs/japanese_base.json similarity index 100% rename from tts/vits/configs/japanese_base.json rename to src/tts/vits/configs/japanese_base.json diff --git a/tts/vits/configs/japanese_base2.json b/src/tts/vits/configs/japanese_base2.json similarity index 100% rename from tts/vits/configs/japanese_base2.json rename to src/tts/vits/configs/japanese_base2.json diff --git a/tts/vits/configs/japanese_ss_base2.json b/src/tts/vits/configs/japanese_ss_base2.json similarity index 100% rename from tts/vits/configs/japanese_ss_base2.json rename to src/tts/vits/configs/japanese_ss_base2.json diff --git a/tts/vits/configs/korean_base.json b/src/tts/vits/configs/korean_base.json similarity index 100% rename from tts/vits/configs/korean_base.json rename to src/tts/vits/configs/korean_base.json diff --git a/tts/vits/configs/sanskrit_base.json b/src/tts/vits/configs/sanskrit_base.json similarity index 100% rename from tts/vits/configs/sanskrit_base.json rename to src/tts/vits/configs/sanskrit_base.json diff --git a/tts/vits/configs/shanghainese_base.json b/src/tts/vits/configs/shanghainese_base.json similarity index 100% rename from tts/vits/configs/shanghainese_base.json rename to src/tts/vits/configs/shanghainese_base.json diff --git a/tts/vits/configs/zero_japanese_base2.json b/src/tts/vits/configs/zero_japanese_base2.json similarity index 100% rename from tts/vits/configs/zero_japanese_base2.json rename to src/tts/vits/configs/zero_japanese_base2.json diff --git a/tts/vits/configs/zh_ja_mixture_base.json b/src/tts/vits/configs/zh_ja_mixture_base.json similarity index 100% rename from tts/vits/configs/zh_ja_mixture_base.json rename to src/tts/vits/configs/zh_ja_mixture_base.json diff --git a/tts/vits/data_utils.py b/src/tts/vits/data_utils.py similarity index 100% rename from tts/vits/data_utils.py rename to src/tts/vits/data_utils.py diff --git a/tts/vits/filelists/cjke_train_filelist.txt 
b/src/tts/vits/filelists/cjke_train_filelist.txt similarity index 100% rename from tts/vits/filelists/cjke_train_filelist.txt rename to src/tts/vits/filelists/cjke_train_filelist.txt diff --git a/tts/vits/filelists/cjke_train_filelist.txt.cleaned b/src/tts/vits/filelists/cjke_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/cjke_train_filelist.txt.cleaned rename to src/tts/vits/filelists/cjke_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/cjke_val_filelist.txt b/src/tts/vits/filelists/cjke_val_filelist.txt similarity index 100% rename from tts/vits/filelists/cjke_val_filelist.txt rename to src/tts/vits/filelists/cjke_val_filelist.txt diff --git a/tts/vits/filelists/cjke_val_filelist.txt.cleaned b/src/tts/vits/filelists/cjke_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/cjke_val_filelist.txt.cleaned rename to src/tts/vits/filelists/cjke_val_filelist.txt.cleaned diff --git a/tts/vits/filelists/cjks_train_filelist.txt b/src/tts/vits/filelists/cjks_train_filelist.txt similarity index 100% rename from tts/vits/filelists/cjks_train_filelist.txt rename to src/tts/vits/filelists/cjks_train_filelist.txt diff --git a/tts/vits/filelists/cjks_train_filelist.txt.cleaned b/src/tts/vits/filelists/cjks_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/cjks_train_filelist.txt.cleaned rename to src/tts/vits/filelists/cjks_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/cjks_val_filelist.txt b/src/tts/vits/filelists/cjks_val_filelist.txt similarity index 100% rename from tts/vits/filelists/cjks_val_filelist.txt rename to src/tts/vits/filelists/cjks_val_filelist.txt diff --git a/tts/vits/filelists/cjks_val_filelist.txt.cleaned b/src/tts/vits/filelists/cjks_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/cjks_val_filelist.txt.cleaned rename to src/tts/vits/filelists/cjks_val_filelist.txt.cleaned diff --git 
a/tts/vits/filelists/fox_train_filelist.txt b/src/tts/vits/filelists/fox_train_filelist.txt similarity index 100% rename from tts/vits/filelists/fox_train_filelist.txt rename to src/tts/vits/filelists/fox_train_filelist.txt diff --git a/tts/vits/filelists/fox_train_filelist.txt.cleaned b/src/tts/vits/filelists/fox_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/fox_train_filelist.txt.cleaned rename to src/tts/vits/filelists/fox_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/fox_val_filelist.txt b/src/tts/vits/filelists/fox_val_filelist.txt similarity index 100% rename from tts/vits/filelists/fox_val_filelist.txt rename to src/tts/vits/filelists/fox_val_filelist.txt diff --git a/tts/vits/filelists/fox_val_filelist.txt.cleaned b/src/tts/vits/filelists/fox_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/fox_val_filelist.txt.cleaned rename to src/tts/vits/filelists/fox_val_filelist.txt.cleaned diff --git a/tts/vits/filelists/mix_train_filelist.txt b/src/tts/vits/filelists/mix_train_filelist.txt similarity index 100% rename from tts/vits/filelists/mix_train_filelist.txt rename to src/tts/vits/filelists/mix_train_filelist.txt diff --git a/tts/vits/filelists/mix_train_filelist.txt.cleaned b/src/tts/vits/filelists/mix_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/mix_train_filelist.txt.cleaned rename to src/tts/vits/filelists/mix_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/mix_val_filelist.txt b/src/tts/vits/filelists/mix_val_filelist.txt similarity index 100% rename from tts/vits/filelists/mix_val_filelist.txt rename to src/tts/vits/filelists/mix_val_filelist.txt diff --git a/tts/vits/filelists/mix_val_filelist.txt.cleaned b/src/tts/vits/filelists/mix_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/mix_val_filelist.txt.cleaned rename to src/tts/vits/filelists/mix_val_filelist.txt.cleaned diff --git 
a/tts/vits/filelists/sanskrit_train_filelist.txt b/src/tts/vits/filelists/sanskrit_train_filelist.txt similarity index 100% rename from tts/vits/filelists/sanskrit_train_filelist.txt rename to src/tts/vits/filelists/sanskrit_train_filelist.txt diff --git a/tts/vits/filelists/sanskrit_train_filelist.txt.cleaned b/src/tts/vits/filelists/sanskrit_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/sanskrit_train_filelist.txt.cleaned rename to src/tts/vits/filelists/sanskrit_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/sanskrit_val_filelist.txt b/src/tts/vits/filelists/sanskrit_val_filelist.txt similarity index 100% rename from tts/vits/filelists/sanskrit_val_filelist.txt rename to src/tts/vits/filelists/sanskrit_val_filelist.txt diff --git a/tts/vits/filelists/sanskrit_val_filelist.txt.cleaned b/src/tts/vits/filelists/sanskrit_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/sanskrit_val_filelist.txt.cleaned rename to src/tts/vits/filelists/sanskrit_val_filelist.txt.cleaned diff --git a/tts/vits/filelists/zaonhe_train_filelist.txt.cleaned b/src/tts/vits/filelists/zaonhe_train_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/zaonhe_train_filelist.txt.cleaned rename to src/tts/vits/filelists/zaonhe_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/zaonhe_val_filelist.txt.cleaned b/src/tts/vits/filelists/zaonhe_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/zaonhe_val_filelist.txt.cleaned rename to src/tts/vits/filelists/zaonhe_val_filelist.txt.cleaned diff --git a/tts/vits/filelists/zero_train_filelist.txt b/src/tts/vits/filelists/zero_train_filelist.txt similarity index 100% rename from tts/vits/filelists/zero_train_filelist.txt rename to src/tts/vits/filelists/zero_train_filelist.txt diff --git a/tts/vits/filelists/zero_train_filelist.txt.cleaned b/src/tts/vits/filelists/zero_train_filelist.txt.cleaned similarity index 100% rename 
from tts/vits/filelists/zero_train_filelist.txt.cleaned rename to src/tts/vits/filelists/zero_train_filelist.txt.cleaned diff --git a/tts/vits/filelists/zero_val_filelist.txt b/src/tts/vits/filelists/zero_val_filelist.txt similarity index 100% rename from tts/vits/filelists/zero_val_filelist.txt rename to src/tts/vits/filelists/zero_val_filelist.txt diff --git a/tts/vits/filelists/zero_val_filelist.txt.cleaned b/src/tts/vits/filelists/zero_val_filelist.txt.cleaned similarity index 100% rename from tts/vits/filelists/zero_val_filelist.txt.cleaned rename to src/tts/vits/filelists/zero_val_filelist.txt.cleaned diff --git a/tts/vits/inference.ipynb b/src/tts/vits/inference.ipynb similarity index 100% rename from tts/vits/inference.ipynb rename to src/tts/vits/inference.ipynb diff --git a/tts/vits/losses.py b/src/tts/vits/losses.py similarity index 100% rename from tts/vits/losses.py rename to src/tts/vits/losses.py diff --git a/tts/vits/mel_processing.py b/src/tts/vits/mel_processing.py similarity index 100% rename from tts/vits/mel_processing.py rename to src/tts/vits/mel_processing.py diff --git a/tts/vits/models.py b/src/tts/vits/models.py similarity index 100% rename from tts/vits/models.py rename to src/tts/vits/models.py diff --git a/tts/vits/modules.py b/src/tts/vits/modules.py similarity index 100% rename from tts/vits/modules.py rename to src/tts/vits/modules.py diff --git a/tts/vits/monotonic_align/__init__.py b/src/tts/vits/monotonic_align/__init__.py similarity index 100% rename from tts/vits/monotonic_align/__init__.py rename to src/tts/vits/monotonic_align/__init__.py diff --git a/tts/vits/monotonic_align/core.pyx b/src/tts/vits/monotonic_align/core.pyx similarity index 100% rename from tts/vits/monotonic_align/core.pyx rename to src/tts/vits/monotonic_align/core.pyx diff --git a/tts/vits/monotonic_align/setup.py b/src/tts/vits/monotonic_align/setup.py similarity index 100% rename from tts/vits/monotonic_align/setup.py rename to 
src/tts/vits/monotonic_align/setup.py diff --git a/tts/vits/preprocess.py b/src/tts/vits/preprocess.py similarity index 100% rename from tts/vits/preprocess.py rename to src/tts/vits/preprocess.py diff --git a/tts/vits/requirements.txt b/src/tts/vits/requirements.txt similarity index 100% rename from tts/vits/requirements.txt rename to src/tts/vits/requirements.txt diff --git a/tts/vits/resources/fig_1a.png b/src/tts/vits/resources/fig_1a.png similarity index 100% rename from tts/vits/resources/fig_1a.png rename to src/tts/vits/resources/fig_1a.png diff --git a/tts/vits/resources/fig_1b.png b/src/tts/vits/resources/fig_1b.png similarity index 100% rename from tts/vits/resources/fig_1b.png rename to src/tts/vits/resources/fig_1b.png diff --git a/tts/vits/resources/training.png b/src/tts/vits/resources/training.png similarity index 100% rename from tts/vits/resources/training.png rename to src/tts/vits/resources/training.png diff --git a/tts/vits/text/LICENSE b/src/tts/vits/text/LICENSE similarity index 100% rename from tts/vits/text/LICENSE rename to src/tts/vits/text/LICENSE diff --git a/tts/vits/text/__init__.py b/src/tts/vits/text/__init__.py similarity index 100% rename from tts/vits/text/__init__.py rename to src/tts/vits/text/__init__.py diff --git a/tts/vits/text/cantonese.py b/src/tts/vits/text/cantonese.py similarity index 100% rename from tts/vits/text/cantonese.py rename to src/tts/vits/text/cantonese.py diff --git a/tts/vits/text/cleaners.py b/src/tts/vits/text/cleaners.py similarity index 100% rename from tts/vits/text/cleaners.py rename to src/tts/vits/text/cleaners.py diff --git a/tts/vits/text/english.py b/src/tts/vits/text/english.py similarity index 100% rename from tts/vits/text/english.py rename to src/tts/vits/text/english.py diff --git a/tts/vits/text/japanese.py b/src/tts/vits/text/japanese.py similarity index 100% rename from tts/vits/text/japanese.py rename to src/tts/vits/text/japanese.py diff --git a/tts/vits/text/korean.py 
b/src/tts/vits/text/korean.py similarity index 100% rename from tts/vits/text/korean.py rename to src/tts/vits/text/korean.py diff --git a/tts/vits/text/mandarin.py b/src/tts/vits/text/mandarin.py similarity index 100% rename from tts/vits/text/mandarin.py rename to src/tts/vits/text/mandarin.py diff --git a/tts/vits/text/ngu_dialect.py b/src/tts/vits/text/ngu_dialect.py similarity index 100% rename from tts/vits/text/ngu_dialect.py rename to src/tts/vits/text/ngu_dialect.py diff --git a/tts/vits/text/sanskrit.py b/src/tts/vits/text/sanskrit.py similarity index 100% rename from tts/vits/text/sanskrit.py rename to src/tts/vits/text/sanskrit.py diff --git a/tts/vits/text/shanghainese.py b/src/tts/vits/text/shanghainese.py similarity index 100% rename from tts/vits/text/shanghainese.py rename to src/tts/vits/text/shanghainese.py diff --git a/tts/vits/text/symbols.py b/src/tts/vits/text/symbols.py similarity index 100% rename from tts/vits/text/symbols.py rename to src/tts/vits/text/symbols.py diff --git a/tts/vits/text/thai.py b/src/tts/vits/text/thai.py similarity index 100% rename from tts/vits/text/thai.py rename to src/tts/vits/text/thai.py diff --git a/tts/vits/train.py b/src/tts/vits/train.py similarity index 100% rename from tts/vits/train.py rename to src/tts/vits/train.py diff --git a/tts/vits/train_ms.py b/src/tts/vits/train_ms.py similarity index 100% rename from tts/vits/train_ms.py rename to src/tts/vits/train_ms.py diff --git a/tts/vits/transforms.py b/src/tts/vits/transforms.py similarity index 100% rename from tts/vits/transforms.py rename to src/tts/vits/transforms.py diff --git a/tts/vits/utils.py b/src/tts/vits/utils.py similarity index 100% rename from tts/vits/utils.py rename to src/tts/vits/utils.py From 5e807a7e56edd08c5d42f0fd782a4dbb6b024a99 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 27 Mar 2024 16:36:04 +0800 Subject: [PATCH 04/11] feat: tesou --- src/blackbox/blackbox_factory.py | 10 +++++----- src/blackbox/tesou.py | 6 +++--- 2 
files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index ff0b4ea..d2828ca 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -14,11 +14,11 @@ class BlackboxFactory: self.tts = TTS() self.asr = ASR(".env.yaml") self.sentiment = Sentiment() - #self.sum = SUM() - #self.calculator = Calculator() - #self.audio_to_text = AudioToText() - #self.text_to_audio = TextToAudio() - #self.tesou = Tesou() + self.sum = SUM() + self.calculator = Calculator() + self.audio_to_text = AudioToText() + self.text_to_audio = TextToAudio() + self.tesou = Tesou() def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index a81fe9b..d07d898 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -23,16 +23,16 @@ class Tesou(Blackbox): "user_id": id, "prompt": prompt, } - + print(message) response = requests.post(url, json=message) - return response + return response.json() async def fast_api_handler(self, request: Request) -> Response: try: data = await request.json() except: return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) - user_id = data.get("id") + user_id = data.get("user_id") user_prompt = data.get("prompt") if user_prompt is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) From fb941997546186ae182cfc8eec8c24937eca561f Mon Sep 17 00:00:00 2001 From: superobk Date: Thu, 28 Mar 2024 17:31:42 +0800 Subject: [PATCH 05/11] refactor: relocate asr model --- src/{asr => blackbox}/asr.py | 6 +++--- src/blackbox/blackbox_factory.py | 2 +- src/blackbox/tesou.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) rename src/{asr => blackbox}/asr.py (90%) diff --git a/src/asr/asr.py b/src/blackbox/asr.py similarity index 90% rename from src/asr/asr.py rename to src/blackbox/asr.py index 
fa879f8..7ba88fb 100644 --- a/src/asr/asr.py +++ b/src/blackbox/asr.py @@ -4,9 +4,9 @@ from typing import Any, Coroutine from fastapi import Request, Response, status from fastapi.responses import JSONResponse -from .rapid_paraformer.utils import read_yaml -from .rapid_paraformer import RapidParaformer -from ..blackbox.blackbox import Blackbox +from ..asr.rapid_paraformer.utils import read_yaml +from ..asr.rapid_paraformer import RapidParaformer +from .blackbox import Blackbox class ASR(Blackbox): diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index d2828ca..973b8a3 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -1,7 +1,7 @@ from .sum import SUM from .sentiment import Sentiment from .tts import TTS -from ..asr.asr import ASR +from .asr import ASR from .audio_to_text import AudioToText from .blackbox import Blackbox from .calculator import Calculator diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index d07d898..6b95d3f 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -23,7 +23,6 @@ class Tesou(Blackbox): "user_id": id, "prompt": prompt, } - print(message) response = requests.post(url, json=message) return response.json() From 726248246ef5a3296d371abbd778bd17caf0d182 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 3 Apr 2024 14:03:18 +0800 Subject: [PATCH 06/11] Updated main.py and .gitignore --- .gitignore | 3 ++- main.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 869061a..e4c3af3 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ cython_debug/ .DS_Store playground.py .env* -models \ No newline at end of file +models +.idea/ \ No newline at end of file diff --git a/main.py b/main.py index e8e2764..8233a70 100644 --- a/main.py +++ b/main.py @@ -25,4 +25,4 @@ async def workflows(reqest: Request): print("workflows") if __name__ == "__main__": - uvicorn.run("main:app", host="127.0.0.1", port=8000, 
log_level="info") + uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info") From 703bdaec5af2b92e2ba9df09d28473b698dc89b2 Mon Sep 17 00:00:00 2001 From: superobk Date: Mon, 8 Apr 2024 10:29:50 +0800 Subject: [PATCH 07/11] feat: dotchain lang --- main.py | 29 +- src/dotchain/README.md | 30 ++ src/dotchain/main.dc | 16 + src/dotchain/main.py | 29 ++ src/dotchain/runtime/__init__.py | 0 src/dotchain/runtime/ast.py | 384 ++++++++++++++++ src/dotchain/runtime/interpreter.py | 420 ++++++++++++++++++ src/dotchain/runtime/runtime.py | 44 ++ src/dotchain/runtime/tests/__init__.py | 0 .../runtime/tests/test_expression_parser.py | 153 +++++++ src/dotchain/runtime/tests/test_runtime.py | 7 + src/dotchain/runtime/tests/test_tokenizer.py | 151 +++++++ src/dotchain/runtime/tokenizer.py | 259 +++++++++++ 13 files changed, 1518 insertions(+), 4 deletions(-) create mode 100644 src/dotchain/README.md create mode 100644 src/dotchain/main.dc create mode 100644 src/dotchain/main.py create mode 100644 src/dotchain/runtime/__init__.py create mode 100644 src/dotchain/runtime/ast.py create mode 100644 src/dotchain/runtime/interpreter.py create mode 100644 src/dotchain/runtime/runtime.py create mode 100644 src/dotchain/runtime/tests/__init__.py create mode 100644 src/dotchain/runtime/tests/test_expression_parser.py create mode 100644 src/dotchain/runtime/tests/test_runtime.py create mode 100644 src/dotchain/runtime/tests/test_tokenizer.py create mode 100644 src/dotchain/runtime/tokenizer.py diff --git a/main.py b/main.py index e8e2764..3193fdb 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,11 @@ -from typing import Union +from typing import Annotated, Union -from fastapi import FastAPI, Request, status +from fastapi import FastAPI, Request, status, Form from fastapi.responses import JSONResponse +from src.dotchain.runtime.interpreter import program_parser +from src.dotchain.runtime.tokenizer import Tokenizer +from src.dotchain.runtime.runtime import Runtime from 
src.blackbox.blackbox_factory import BlackboxFactory import uvicorn @@ -20,9 +23,27 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await box.fast_api_handler(request) +def read_form_image(request: Request): + async def inner(field: str): + print(field) + return "image" + return inner + +def read_form_text(request: Request): + def inner(field: str): + print(field) + return "text" + return inner + @app.post("/workflows") -async def workflows(reqest: Request): - print("workflows") +async def workflows(script: Annotated[str, Form()], request: Request=None): + dsl_runtime = Runtime(exteral_fun={"print": print, + 'read_form_image': read_form_image(request), + "read_form_text": read_form_text(request)}) + t = Tokenizer() + t.init(script) + ast = program_parser(t) + ast.exec(dsl_runtime) if __name__ == "__main__": uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info") diff --git a/src/dotchain/README.md b/src/dotchain/README.md new file mode 100644 index 0000000..ed2b399 --- /dev/null +++ b/src/dotchain/README.md @@ -0,0 +1,30 @@ +# Dotchain +Dotchain 是一種函數式編程語言. 文件後綴`.dc` + +# 語法 +``` +// 註解 + +// 變量宣告 +let hello = 123 + +// 函數宣告 +let add = (left, right) => { + // 返回值 + return left + right +} + +// TODO: 函數呼叫 +add(1,2) +add(3, add(1,2)) +// 以 . 呼叫函數,將以 . 前的值作為第一個參數 +// hello.add(2) 等價於 add(hello, 2) +``` +## Keywords +``` +let while if else true false +``` + +```bash +python -m unittest +``` \ No newline at end of file diff --git a/src/dotchain/main.dc b/src/dotchain/main.dc new file mode 100644 index 0000000..5100dd0 --- /dev/null +++ b/src/dotchain/main.dc @@ -0,0 +1,16 @@ +// 註解 + +// 變量宣告 +let hello = 123; + +// 函數宣告 +let add = (left, right) => { + // 返回值 + return left + right; +} + +// TODO 函數呼叫 +add(1,2); +add(3, add(1,2)); +// 以 . 呼叫函數,將以 . 
# NOTE(review): this region of the chunk is a git format-patch dump whose
# newlines were collapsed.  Below are the reconstructed contents of the new
# Python modules the patch introduces (src/dotchain/main.py, runtime/ast.py,
# runtime/interpreter.py, runtime/runtime.py), re-indented and cleaned, with
# the bugs called out inline fixed.  The truncated test file at the end of
# the chunk is not reproduced (it runs past the visible region).

# ============================================================================
# src/dotchain/main.py -- demo driver for the Dotchain interpreter.
# ============================================================================

from runtime.interpreter import program_parser
from runtime.runtime import Runtime
from runtime.tokenizer import Tokenizer
import json  # kept from the original file even though it is currently unused

# Demo Dotchain program: recursive countdown plus a main entry point.
script = """
let rec = (c) => {
    print(c);
    if c == 0 {
        return "c + 1";
    }
    rec(c-1);
}

let main = () => {
    print("hello 嘉妮");
    print(rec(10));
}

main();
"""

if __name__ == "__main__":
    t = Tokenizer()
    t.init(script)
    # NOTE(review): "exteral_fun" is a typo for "external_fun", but it is the
    # keyword argument the Runtime API exposes, so it is kept for compatibility.
    runtime = Runtime(exteral_fun={"print": print})
    ast = program_parser(t)
    result = ast.exec(runtime)

# ============================================================================
# src/dotchain/runtime/ast.py -- AST node definitions for Dotchain.
# ============================================================================

from abc import ABC, abstractmethod

from attr import dataclass

from .runtime import Runtime


@dataclass
class ReturnValue():
    # Wrapper that lets block execution distinguish "return X" from an
    # ordinary statement result while unwinding nested blocks.
    value: any


class Node(ABC):
    """Base of every AST node; type() reports the concrete node-class name."""

    def type(self):
        return self.__class__.__name__


@dataclass
class Statement(Node, ABC):
    """A node executed for its effect via exec()."""

    @abstractmethod
    def exec(self, runtime: Runtime):
        # Fixed: dropped a leftover debug print(self) from the abstract body.
        pass

    @abstractmethod
    def dict(self):
        pass


@dataclass
class Expression(Node):
    """A node evaluated for a value via eval()."""

    @abstractmethod
    def eval(self, runtime: Runtime):
        pass

    @abstractmethod
    def dict(self):
        pass


@dataclass
class Literal(Expression):
    value: str | int | float | bool

    def eval(self, runtime: Runtime):
        return self.value

    def dict(self) -> dict:
        return {"type": "Literal", "value": self.value}


@dataclass
class StringLiteral(Literal):
    value: str

    def dict(self) -> dict:
        return {"type": "StringLiteral", "value": self.value}


@dataclass
class IntLiteral(Literal):
    value: int

    def dict(self):
        return {"type": "IntLiteral", "value": self.value}


@dataclass
class FloatLiteral(Literal):
    value: float

    def dict(self):
        return {"type": "FloatLiteral", "value": self.value}


@dataclass
class BoolLiteral(Literal):
    value: bool

    def dict(self):
        # Fixed: the type tag previously said "FloatLiteral" (copy/paste bug).
        return {"type": "BoolLiteral", "value": self.value}


@dataclass
class UnaryExpression(Expression):
    operator: str
    expression: Expression

    def eval(self, runtime: Runtime):
        if self.operator == "-":
            return -self.expression.eval(runtime)
        if self.operator == "!":
            return not self.expression.eval(runtime)
        # "+" (and anything else) is the identity.
        return self.expression.eval(runtime)

    def dict(self):
        return {
            "type": "UnaryExpression",
            "operator": self.operator,
            "argument": self.expression.dict(),
        }


@dataclass
class Program(Statement):
    body: list[Statement]

    def exec(self, runtime: Runtime):
        # A ReturnValue anywhere at top level stops the whole program.
        for statement in self.body:
            result = statement.exec(runtime)
            if isinstance(result, ReturnValue):
                return result

    def dict(self):
        return {
            "type": self.type(),
            "body": [statement.dict() for statement in self.body],
        }


@dataclass
class Identifier(Expression):
    name: str

    def eval(self, runtime: Runtime):
        # Resolves through the enclosing-scope chain.
        return runtime.deep_get_value(self.name)

    def dict(self):
        return {"type": self.type(), "name": self.name}


@dataclass
class Block(Statement):
    body: list[Statement]

    def exec(self, runtime: Runtime):
        # Both `return` and `break` propagate out of the block to the caller.
        for statement in self.body:
            result = statement.exec(runtime)
            if isinstance(result, (ReturnValue, BreakStatement)):
                return result

    def dict(self):
        return {
            "type": "Block",
            "body": [statement.dict() for statement in self.body],
        }


@dataclass
class WhileStatement(Statement):
    test: Expression
    body: Block

    def exec(self, runtime: Runtime):
        while self.test.eval(runtime):
            # Fresh child scope per iteration so loop-local `let`s don't
            # collide with the previous iteration's declarations.
            while_runtime = Runtime(parent=runtime, name="while")
            result = self.body.exec(while_runtime)
            if isinstance(result, ReturnValue):
                return result
            if isinstance(result, BreakStatement):
                return result

    def dict(self):
        return {
            "type": "WhileStatement",
            "test": self.test.dict(),
            "body": self.body.dict(),
        }


@dataclass
class BreakStatement(Statement):

    def exec(self, _: Runtime):
        # Returns itself so enclosing Block/WhileStatement can unwind.
        return self

    def dict(self):
        return {"type": "BreakStatement"}


@dataclass
class ReturnStatement(Statement):
    value: Expression

    def exec(self, runtime: Runtime):
        return ReturnValue(self.value.eval(runtime))

    def dict(self):
        return {"type": "ReturnStatement", "value": self.value.dict()}


@dataclass
class IfStatement(Statement):
    test: Expression
    consequent: Block
    alternate: Block

    def exec(self, runtime: Runtime):
        # The test is evaluated in the OUTER scope; only the chosen branch
        # runs in a fresh child scope.
        if_runtime = Runtime(parent=runtime)
        if self.test.eval(runtime):
            return self.consequent.exec(if_runtime)
        else:
            return self.alternate.exec(if_runtime)

    def dict(self):
        return {
            "type": "IfStatement",
            "test": self.test.dict(),
            "consequent": self.consequent.dict(),
            "alternate": self.alternate.dict(),
        }


@dataclass
class VariableDeclaration(Statement):
    id: Identifier
    value: Expression
    value_type: str = "any"

    def exec(self, runtime: Runtime):
        runtime.declare(self.id.name, self.value.eval(runtime))

    def dict(self):
        return {
            "type": "VariableDeclaration",
            "id": self.id.dict(),
            "value": self.value.dict(),
        }


@dataclass
class Assignment(Statement):
    id: Identifier
    value: Expression

    def exec(self, runtime: Runtime):
        runtime.assign(self.id.name, self.value.eval(runtime))

    def dict(self):
        return {
            "type": "Assignment",
            "id": self.id.dict(),
            "value": self.value.dict(),
        }


@dataclass
class Argument(Expression):
    # NOTE(review): no eval() is defined here even though Expression declares
    # one abstract -- presumably Argument is never evaluated directly; confirm.
    id: Identifier
    value: Expression

    def dict(self):
        return {
            "type": "Argument",
            "id": self.id.dict(),
            "value": self.value.dict(),
        }


@dataclass
class BinaryExpression(Expression):
    left: Expression
    operator: str
    right: Expression

    def eval(self, runtime: Runtime):
        # Both operands are evaluated eagerly; note this means "&&"/"||" do
        # NOT short-circuit (matches the original implementation).
        left = self.left.eval(runtime)
        right = self.right.eval(runtime)
        if self.operator == "+":
            return left + right
        if self.operator == "-":
            return left - right
        if self.operator == "*":
            return left * right
        if self.operator == "/":
            return left / right
        if self.operator == "%":
            return left % right
        if self.operator == "<":
            return left < right
        if self.operator == ">":
            return left > right
        if self.operator == "<=":
            return left <= right
        if self.operator == ">=":
            return left >= right
        if self.operator == "==":
            return left == right
        if self.operator == "!=":
            return left != right
        if self.operator == "&&":
            return left and right
        if self.operator == "||":
            return left or right
        return None

    def dict(self):
        return {
            "type": "BinaryExpression",
            "left": self.left.dict(),
            "operator": self.operator,
            "right": self.right.dict(),
        }


@dataclass
class CallExpression(Expression):
    callee: Identifier
    arguments: list[Expression]

    def exec(self, runtime: Runtime, args: list = None):
        # Fixed: arguments are evaluated exactly once, in the call-site scope.
        # The original re-ran the evaluation loop on every recursive hop up
        # the parent chain, appending duplicates to `args`.
        if args is None:
            args = [argument.eval(runtime) for argument in self.arguments]
        if runtime.has_value(self.callee.name):
            fun: FunEnv = runtime.get_value(self.callee.name)
            return fun.exec(args)
        if runtime.parent is not None:
            return self.exec(runtime.parent, args)
        if self.callee.name in runtime.exteral_fun:
            return runtime.exteral_fun[self.callee.name](*args)

    def eval(self, runtime):
        # Fixed: only unwrap ReturnValue.  External (host) functions return
        # plain Python values; the original did `result.value` unconditionally
        # and raised AttributeError on any non-None host return.
        result = self.exec(runtime)
        if isinstance(result, ReturnValue):
            return result.value
        return result

    def dict(self):
        return {
            "type": "CallExpression",
            "callee": self.callee.dict(),
            "arguments": [argument.dict() for argument in self.arguments],
        }


@dataclass
class Fun(Statement):
    params: list[Identifier]
    body: Block

    def exec(self, runtime: Runtime):
        return self.body.exec(runtime)

    def eval(self, runtime: Runtime):
        # Evaluating a function literal closes over the defining scope.
        return FunEnv(runtime, self)

    def dict(self):
        return {
            "type": "Fun",
            "params": [param.dict() for param in self.params],
            "body": self.body.dict(),
        }


class EmptyStatement(Statement):
    """No-op node used for stray semicolons and empty expressions."""

    def exec(self, _: Runtime):
        return None

    def eval(self, _: Runtime):
        return None

    def dict(self):
        return {"type": "EmptyStatement"}


class FunEnv():
    """A function closure: the Fun node plus the scope it was defined in."""

    def __init__(self, parent: Runtime, body: Fun):
        self.parent = parent
        self.body = body

    def exec(self, args: list):
        # NOTE(review): raises IndexError when fewer args than params are
        # supplied; extra args are silently ignored -- confirm intended.
        fun_runtime = Runtime(parent=self.parent)
        for index, param in enumerate(self.body.params):
            fun_runtime.declare(param.name, args[index])
        return self.body.exec(fun_runtime)

# ============================================================================
# src/dotchain/runtime/interpreter.py -- recursive-descent parser plus a
# shunting-yard style expression parser producing the AST above.
# ============================================================================

from ast import Expression
import copy
from .ast import Assignment, BinaryExpression, Block, BoolLiteral, BreakStatement, CallExpression, EmptyStatement, FloatLiteral, Fun, Identifier, IfStatement, IntLiteral, Program, ReturnStatement, Statement, StringLiteral, UnaryExpression, VariableDeclaration, WhileStatement
from .tokenizer import Token, TokenType, Tokenizer

# Token types after which a +/-/! must be a unary operator (e.g. "(-x", ", -x").
unary_prev_statement = [
    TokenType.COMMENTS,
    TokenType.LEFT_PAREN,
    TokenType.COMMA,
    TokenType.LEFT_BRACE,
    TokenType.RIGHT_BRACE,
    TokenType.SEMICOLON,
    TokenType.LET,
    TokenType.RETURN,
    TokenType.IF,
    TokenType.ELSE,
    TokenType.WHILE,
    TokenType.FOR,
    TokenType.LOGICAL_OPERATOR,
    TokenType.NOT,
    TokenType.ASSIGNMENT,
    TokenType.MULTIPLICATIVE_OPERATOR,
    TokenType.ADDITIVE_OPERATOR,
    TokenType.ARROW,
]

# Token types that terminate the operand of a unary operator.
unary_end_statement = [
    TokenType.MULTIPLICATIVE_OPERATOR,
    TokenType.ADDITIVE_OPERATOR,
    TokenType.LOGICAL_OPERATOR,
]

# Token types that terminate an expression altogether.
end_statement = [
    TokenType.SEMICOLON,
    TokenType.COMMA,
    TokenType.ARROW,
    TokenType.RETURN,
    TokenType.LET,
    TokenType.IF,
    TokenType.ELSE,
    TokenType.WHILE,
    TokenType.FOR,
    TokenType.ASSIGNMENT,
    TokenType.RIGHT_BRACE,
    TokenType.LEFT_BRACE,
]


def program_parser(tkr: Tokenizer):
    """Parse the whole token stream into a Program node."""
    statements = list[Statement]()
    # Fixed: dropped the unused `count` accumulator.
    while True:
        if tkr.token() is None:
            break
        if tkr.token().type == TokenType.SEMICOLON:
            tkr.next()
            continue
        statements.append(statement_parser(tkr))
    return Program(statements)


def if_parser(tkr: Tokenizer):
    """Parse `if cond { } [else if ... | else { }]` into an IfStatement."""
    tkr.eat(TokenType.IF)
    condition = ExpressionParser(tkr).parse()
    block = block_statement(tkr)
    if tkr.type_is(TokenType.ELSE):
        tkr.eat(TokenType.ELSE)
        if tkr.type_is(TokenType.IF):
            # `else if` desugars to an else-block containing a nested if.
            # Fixed: dropped a leftover debug print.
            return IfStatement(condition, block, Block([if_parser(tkr)]))
        return IfStatement(condition, block, block_statement(tkr))
    return IfStatement(condition, block, Block([]))


def while_parser(tkr: Tokenizer):
    """Parse `while cond { ... }` into a WhileStatement."""
    tkr.eat(TokenType.WHILE)
    condition = ExpressionParser(tkr).parse()
    block = block_statement(tkr)
    return WhileStatement(condition, block)


def identifier(tkr: Tokenizer):
    """Consume one IDENTIFIER token, or raise."""
    token = tkr.token()
    if token.type != TokenType.IDENTIFIER:
        raise Exception("Invalid identifier", token)
    tkr.next()
    return Identifier(token.value)


def block_statement(tkr: Tokenizer):
    """Parse `{ statement* }` into a Block."""
    tkr.eat(TokenType.LEFT_BRACE)
    statements = list[Statement]()
    while True:
        if tkr.token() is None:
            raise Exception("Invalid block expression", tkr.token())
        if tkr.tokenType() == TokenType.RIGHT_BRACE:
            tkr.eat(TokenType.RIGHT_BRACE)
            break
        if tkr.tokenType() == TokenType.SEMICOLON:
            tkr.next()
            continue
        statements.append(statement_parser(tkr))
    return Block(statements)


def return_parser(tkr: Tokenizer):
    """Parse `return expr`."""
    tkr.eat(TokenType.RETURN)
    return ReturnStatement(ExpressionParser(tkr).parse())


def statement_parser(tkr: Tokenizer):
    """Dispatch on the current token to the right statement parser.

    Order matters: the assignment lookahead must run before the keyword
    checks so `x = 1` is not parsed as a bare expression.
    """
    token = tkr.token()
    if token is None:
        return EmptyStatement()
    if token.type == TokenType.SEMICOLON:
        tkr.next()
        return EmptyStatement()
    if token.type == TokenType.LET:
        return let_expression_parser(tkr)
    if _try_assignment_expression(tkr):
        return assignment_parser(tkr)
    if token.type == TokenType.IF:
        return if_parser(tkr)
    if token.type == TokenType.WHILE:
        return while_parser(tkr)
    if token.type == TokenType.RETURN:
        return return_parser(tkr)
    if token.type == TokenType.BREAK:
        tkr.eat(TokenType.BREAK)
        return BreakStatement()
    return ExpressionParser(tkr).parse()


def assignment_parser(tkr: Tokenizer):
    """Parse `name = expr`."""
    id = identifier(tkr)
    tkr.eat(TokenType.ASSIGNMENT)
    return Assignment(id, ExpressionParser(tkr).parse())


def let_expression_parser(tkr: Tokenizer):
    """Parse `let name = expr` into a VariableDeclaration."""
    tkr.eat(TokenType.LET)
    token = tkr.token()
    if token.type != TokenType.IDENTIFIER:
        raise Exception("Invalid let statement", token)
    id = identifier(tkr)
    token = tkr.token()
    if token is None:
        raise Exception("Invalid let statement", token)
    if token.type != TokenType.ASSIGNMENT:
        raise Exception("Invalid let statement", token.type)
    tkr.next()
    ast = ExpressionParser(tkr).parse()
    return VariableDeclaration(id, ast)


class ExpressionParser:
    """Shunting-yard style expression parser.

    Operands go on `stack`, operators on `operator_stack`; on completion the
    postfix sequence in `stack` is folded into a BinaryExpression tree.
    """

    def __init__(self, tkr: Tokenizer):
        self.stack = list[Expression | Token]()
        self.operator_stack = list[Token]()
        self.tkr = tkr

    def parse(self, unary=False):
        while not self.is_end():
            token = self.tkr.token()
            if unary and not self.is_unary() and token.type in unary_end_statement:
                break
            if self.is_unary():
                self.push_stack(self.unary_expression_parser())
            elif self._try_fun_expression():
                return self.fun_expression()
            # -(hello x 123) // !(true and false)
            elif unary and token.type == TokenType.LEFT_PAREN:
                self.tkr.next()
                self.push_stack(ExpressionParser(self.tkr).parse())
            elif self._is_operator(token) or token.type in [TokenType.LEFT_PAREN, TokenType.RIGHT_PAREN]:
                self.push_operator_stack(token)
                self.tkr.next()
            else:
                self.push_stack(self.expression_parser())
        self.pop_all()
        return self.expression()

    def expression(self):
        if len(self.stack) == 0:
            return EmptyStatement()
        if len(self.stack) == 1:
            return self.stack[0]
        return expression_list_to_binary(self.stack)

    def expression_parser(self):
        """Parse one primary: literal, identifier, or function call."""
        token = self.tkr.token()
        if token is None:
            return EmptyStatement()
        expression = None
        if token.type == TokenType.INT:
            self.tkr.eat(TokenType.INT)
            expression = IntLiteral(int(token.value))
        elif token.type == TokenType.FLOAT:
            self.tkr.eat(TokenType.FLOAT)
            expression = FloatLiteral(float(token.value))
        elif token.type == TokenType.STRING:
            self.tkr.eat(TokenType.STRING)
            # Strip the surrounding quote characters.
            expression = StringLiteral(token.value[1:-1])
        elif token.type == TokenType.BOOL:
            self.tkr.eat(TokenType.BOOL)
            expression = BoolLiteral(token.value == "true")
        elif token.type == TokenType.IDENTIFIER:
            expression = self.identifier_or_fun_call_parser()
        return expression

    def _try_fun_expression(self):
        return _try_fun_expression(self.tkr)

    def fun_expression(self):
        """Parse `(a, b, ...) => { ... }` into a Fun node."""
        tkr = self.tkr
        tkr.next()
        args = list[Identifier]()
        token_type = tkr.tokenType()
        while token_type != TokenType.RIGHT_PAREN:
            args.append(Identifier(tkr.token().value))
            tkr.next()
            token_type = tkr.tokenType()
            if token_type == TokenType.RIGHT_PAREN:
                break
            tkr.next()
            token_type = tkr.tokenType()
        token_type = tkr.next_token_type()
        if token_type != TokenType.ARROW:
            raise Exception("Invalid fun_expression", tkr.token())
        tkr.next()
        return Fun(args, block_statement(tkr))

    def push_stack(self, expression: Expression | Token):
        self.stack.append(expression)

    def _pop_by_right_paren(self):
        # Pop operators into the output until the matching "(" is discarded.
        token = self.operator_stack.pop()
        if token.type != TokenType.LEFT_PAREN:
            self.push_stack(token)
            self._pop_by_right_paren()

    def pop(self):
        self.push_stack(self.operator_stack.pop())

    def pop_all(self):
        while len(self.operator_stack) > 0:
            self.pop()

    def push_operator_stack(self, token: Token):
        if len(self.operator_stack) == 0:
            self.operator_stack.append(token)
            return
        if token.type == TokenType.LEFT_PAREN:
            self.operator_stack.append(token)
            return
        if token.type == TokenType.RIGHT_PAREN:
            self._pop_by_right_paren()
            return
        top_operator = self.operator_stack[-1]
        if top_operator.type == TokenType.LEFT_PAREN:
            self.operator_stack.append(token)
            return
        # priority is in descending order
        if self._priority(token) >= self._priority(top_operator):
            self.pop()
            self.push_operator_stack(token)
            return
        self.operator_stack.append(token)

    def unary_expression_parser(self):
        token = self.tkr.token()
        self.tkr.next()
        return UnaryExpression(token.value, ExpressionParser(self.tkr).parse(True))

    def identifier_or_fun_call_parser(self):
        id = self.identifier()
        tokenType = self.tkr.tokenType()
        if tokenType == TokenType.LEFT_PAREN:
            return self.fun_call_parser(id)
        return id

    def fun_call_parser(self, id: Identifier):
        """Parse `(arg, arg, ...)` after an identifier into a CallExpression."""
        self.tkr.eat(TokenType.LEFT_PAREN)
        args = list[Expression]()
        while self.tkr.tokenType() != TokenType.RIGHT_PAREN:
            args.append(ExpressionParser(self.tkr).parse())
            if self.tkr.tokenType() == TokenType.COMMA:
                self.tkr.eat(TokenType.COMMA)
        self.tkr.eat(TokenType.RIGHT_PAREN)
        return CallExpression(id, args)

    def identifier(self):
        return identifier(self.tkr)

    def is_unary(self):
        """True when the current +/-/! token must be a unary operator."""
        token = self.tkr.token()
        if not self.unary_operator(token):
            return False
        if token.type == TokenType.NOT:
            return True
        prev_token = self.tkr.get_prev()
        if prev_token is None:
            return True
        if prev_token.type == TokenType.LEFT_PAREN:
            return True
        if prev_token.type in unary_prev_statement:
            return True
        return False

    def unary_operator(self, token: Token):
        if token is None:
            return False
        return token.value in ["+", "-", "!"]

    def _has_brackets(self):
        return TokenType.LEFT_PAREN in map(lambda x: x.type, self.operator_stack)

    def is_end(self):
        token = self.tkr.token()
        if token is None:
            return True
        if token.type == TokenType.SEMICOLON:
            return True
        # An unmatched ")" belongs to an enclosing parser.
        if not self._has_brackets() and token.type == TokenType.RIGHT_PAREN:
            return True
        if token.type in end_statement:
            return True
        return False

    def _is_operator(self, token: Token):
        if token is None:
            return False
        return token.type in [TokenType.ADDITIVE_OPERATOR, TokenType.MULTIPLICATIVE_OPERATOR, TokenType.LOGICAL_OPERATOR, TokenType.NOT]

    def _debug_print_tokens(self):
        print("operator stack:----")
        for token in self.operator_stack:
            print(token)

    def _debug_print_stack(self):
        print("stack:----")
        for expression in self.stack:
            print(expression)

    def _priority(self, token: Token):
        return _priority(token.value)


def expression_list_to_binary(expression_list: list[Expression | Token], stack: list = None):
    """Fold a postfix list of operands/operator-tokens into one expression."""
    if stack is None:
        stack = list()
    if len(expression_list) == 0:
        return stack[0]
    top = expression_list[0]
    if isinstance(top, Token):
        right = stack.pop()
        left = stack.pop()
        return expression_list_to_binary(expression_list[1:], stack + [BinaryExpression(left, top.value, right)])
    else:
        stack.append(top)
        return expression_list_to_binary(expression_list[1:], stack)


def _priority(operator: str):
    """Smaller number == tighter binding; unknown operators bind loosest."""
    priority = 0
    if operator in ["*", "/", "%"]:
        return priority
    priority += 1
    if operator in ["+", "-"]:
        return priority
    priority += 1
    if operator in ["<", ">", "<=", ">="]:
        return priority
    priority += 1
    if operator in ["==", "!="]:
        return priority
    priority += 1
    if operator in ["&&"]:
        return priority
    priority += 1
    if operator in ["||"]:
        return priority
    priority += 1
    return priority


def _try_assignment_expression(tkr: Tokenizer):
    """Lookahead: is the stream at `IDENTIFIER =`?  Works on a deep copy."""
    tkr = copy.deepcopy(tkr)
    token = tkr.token()
    if token is None:
        return False
    if token.type != TokenType.IDENTIFIER:
        return False
    tkr.next()
    token = tkr.token()
    if token is None:
        return False
    if token.type != TokenType.ASSIGNMENT:
        return False
    return True


def _try_fun_expression(_tkr: Tokenizer):
    """Lookahead: is the stream at `(ident, ...) =>`?  Works on a deep copy."""
    tkr = copy.deepcopy(_tkr)
    token = tkr.token()
    if token is None:
        return False
    if token.type != TokenType.LEFT_PAREN:
        return False
    tkr.next()
    token_type = tkr.tokenType()
    while token_type != TokenType.RIGHT_PAREN:
        if token_type == TokenType.IDENTIFIER:
            tkr.next()
            token_type = tkr.tokenType()
            if token_type == TokenType.RIGHT_PAREN:
                break
            if token_type != TokenType.COMMA:
                return False
            tkr.next()
            token_type = tkr.tokenType()
            if token_type == TokenType.RIGHT_PAREN:
                # Trailing comma, e.g. "(a,) =>", is rejected.
                return False
        else:
            return False
    token_type = tkr.next_token_type()
    if token_type != TokenType.ARROW:
        return False
    return True

# ============================================================================
# src/dotchain/runtime/runtime.py -- lexically-scoped variable environment.
# ============================================================================

from ast import Expression

from attr import dataclass


class Runtime():
    """A variable scope; scopes chain to enclosing scopes via `parent`.

    NOTE(review): `exteral_fun` is a typo for "external_fun" but is part of
    the public keyword API (used by main.py), so it is preserved.
    """

    def __init__(self, context=None, parent=None, exteral_fun=None, name=None) -> None:
        self.name = name
        self.parent = parent
        self.context = context if context is not None else dict()
        self.exteral_fun = exteral_fun if exteral_fun is not None else dict()

    def has_value(self, identifier: str) -> bool:
        return identifier in self.context

    def get_value(self, identifier: str):
        return self.context.get(identifier)

    def deep_get_value(self, id: str):
        # Walk outward through enclosing scopes; None if never declared.
        if self.has_value(id):
            return self.get_value(id)
        if self.parent is not None:
            return self.parent.deep_get_value(id)
        return None

    def set_value(self, identifier: str, value):
        self.context[identifier] = value

    def declare(self, identifier: str, value):
        if self.has_value(identifier):
            raise Exception(f"Variable {identifier} is already declared")
        self.set_value(identifier, value)

    def assign(self, identifier: str, value):
        # Assignment targets the nearest enclosing scope that declared the name.
        if self.has_value(identifier):
            self.set_value(identifier, value)
        elif self.parent is not None:
            self.parent.assign(identifier, value)
        else:
            raise Exception(f"Variable {identifier} is not declared")

    def show_values(self):
        print(self.context)
+ t.next() + parser = ExpressionParser(t) + pred = parser.is_unary() + self.assertTrue(pred) + + t.init(", - 123") + t.next() + parser = ExpressionParser(t) + pred = parser.is_unary() + self.assertTrue(pred) + + t.init("* - 123") + t.next() + parser = ExpressionParser(t) + pred = parser.is_unary() + self.assertTrue(pred) + + t.init("* - 123") + parser = ExpressionParser(t) + pred = parser.is_unary() + self.assertFalse(pred) + + def test_expression_parser(self): + t = Tokenizer() + t.init("a") + parser = ExpressionParser(t) + expression = parser.expression_parser() + self.assertIsInstance(expression, Identifier) + + t.init("true") + parser = ExpressionParser(t) + expression = parser.expression_parser() + self.assertIsInstance(expression, BoolLiteral) + self.assertEqual(expression.value, True) + + t.init("false") + parser = ExpressionParser(t) + expression = parser.expression_parser() + self.assertIsInstance(expression, BoolLiteral) + self.assertEqual(expression.value, False) + + t.init("12341") + parser = ExpressionParser(t) + expression = parser.expression_parser() + self.assertEqual(expression.value, 12341) + self.assertIsInstance(expression, IntLiteral) + + t.init("12341.42") + parser = ExpressionParser(t) + expression = parser.expression_parser() + self.assertEqual(expression.value, 12341.42) + self.assertIsInstance(expression, FloatLiteral) + + t.init("hello") + parser = ExpressionParser(t) + expression: Identifier = parser.expression_parser() + self.assertIsInstance(expression, Identifier) + self.assertEqual(expression.name, "hello") + + t.init("print()") + parser = ExpressionParser(t) + expression: CallExpression = parser.expression_parser() + self.assertIsInstance(expression, CallExpression) + self.assertEqual(expression.callee.name, "print") + + t.init("print(1,2,3,hello)") + parser = ExpressionParser(t) + expression: CallExpression = parser.expression_parser() + self.assertIsInstance(expression, CallExpression) + self.assertEqual(expression.callee.name, 
"print") + self.assertEqual(len(expression.arguments), 4) + + def test_binary_expression(self): + t = Tokenizer() + + def test__priority(self): + self.assertEqual(_priority("*"), 0) + self.assertEqual(_priority("/"), 0) + self.assertEqual(_priority("%"), 0) + self.assertEqual(_priority("+"), 1) + self.assertEqual(_priority("-"), 1) + self.assertEqual(_priority(">"), 2) + self.assertEqual(_priority("<"), 2) + self.assertEqual(_priority(">="), 2) + self.assertEqual(_priority("<="), 2) + self.assertEqual(_priority("=="), 3) + self.assertEqual(_priority("!="), 3) + self.assertEqual(_priority("&&"), 4) + self.assertEqual(_priority("||"), 5) \ No newline at end of file diff --git a/src/dotchain/runtime/tests/test_runtime.py b/src/dotchain/runtime/tests/test_runtime.py new file mode 100644 index 0000000..698db5d --- /dev/null +++ b/src/dotchain/runtime/tests/test_runtime.py @@ -0,0 +1,7 @@ + +import unittest + +class TestRuntime(unittest.TestCase): + + def test_eval(self): + self.assertTrue(True) \ No newline at end of file diff --git a/src/dotchain/runtime/tests/test_tokenizer.py b/src/dotchain/runtime/tests/test_tokenizer.py new file mode 100644 index 0000000..63188bf --- /dev/null +++ b/src/dotchain/runtime/tests/test_tokenizer.py @@ -0,0 +1,151 @@ + +import unittest +from runtime.tokenizer import TokenType, Tokenizer,Token + +class TestTokenizer(unittest.TestCase): + + def test_init(self): + t = Tokenizer() + self.assertEqual(t.script, "") + self.assertEqual(t.cursor, 0) + self.assertEqual(t.col, 0) + self.assertEqual(t.row, 0) + + def test_tokenizer(self): + t = Tokenizer() + t.init("a") + self.assertEqual(t.token().value, "a") + self.assertEqual(t.token().type, TokenType.IDENTIFIER) + + t.init("12341") + self.assertEqual(t.token().value, "12341") + self.assertEqual(t.token().type, TokenType.INT) + + t.init("12341.1234124") + self.assertEqual(t.token().value, "12341.1234124") + self.assertEqual(t.token().type, TokenType.FLOAT) + + t.init("false") + 
self.assertEqual(t.token().value, "false") + self.assertEqual(t.token().type, TokenType.BOOL) + + t.init("\"false\"") + self.assertEqual(t.token().value, "\"false\"") + self.assertEqual(t.token().type, TokenType.STRING) + + t.init("helloworld") + self.assertEqual(t.token().value, "helloworld") + self.assertEqual(t.token().type, TokenType.IDENTIFIER) + + t.init("!") + self.assertEqual(t.token().value, "!") + self.assertEqual(t.token().type, TokenType.NOT) + + t.init("==") + self.assertEqual(t.token().value, "==") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("!=") + self.assertEqual(t.token().value, "!=") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("<=") + self.assertEqual(t.token().value, "<=") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init(">=") + self.assertEqual(t.token().value, ">=") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("<") + self.assertEqual(t.token().value, "<") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init(">") + self.assertEqual(t.token().value, ">") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("&&") + self.assertEqual(t.token().value, "&&") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("||") + self.assertEqual(t.token().value, "||") + self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR) + + t.init("=") + self.assertEqual(t.token().value, "=") + self.assertEqual(t.token().type, TokenType.ASSIGNMENT) + + t.init("+") + self.assertEqual(t.token().value, "+") + self.assertEqual(t.token().type, TokenType.ADDITIVE_OPERATOR) + + t.init("-") + self.assertEqual(t.token().value, "-") + self.assertEqual(t.token().type, TokenType.ADDITIVE_OPERATOR) + + t.init("*") + self.assertEqual(t.token().value, "*") + self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR) + + t.init("/") + self.assertEqual(t.token().value, "/") + 
self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR) + + t.init("%") + self.assertEqual(t.token().value, "%") + self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR) + + t.init("(") + self.assertEqual(t.token().value, "(") + self.assertEqual(t.token().type, TokenType.LEFT_PAREN) + + t.init(")") + self.assertEqual(t.token().value, ")") + self.assertEqual(t.token().type, TokenType.RIGHT_PAREN) + + t.init("{") + self.assertEqual(t.token().value, "{") + self.assertEqual(t.token().type, TokenType.LEFT_BRACE) + + t.init("}") + self.assertEqual(t.token().value, "}") + self.assertEqual(t.token().type, TokenType.RIGHT_BRACE) + + def test_init(self): + t = Tokenizer() + script = "a + 9 * ( 3 - 1 ) * 3 + 10 / 2;" + t.init(script) + self.assertEqual(t.script, script) + self.assertEqual(len(t.tokens), 16) + self.assertEqual(t.get_prev(), None) + self.assertEqual(t.token().value, "a") + self.assertEqual(t.get_next().value, "+") + self.assertEqual(t.next().value, "+") + self.assertEqual(t.next().value, "9") + self.assertEqual(t.next().value, "*") + t.prev() + self.assertEqual(t.token().value, "9") + t.prev() + self.assertEqual(t.token().value, "+") + + script = "a + 9" + t.init(script) + self.assertEqual(t.token().type, TokenType.IDENTIFIER) + self.assertEqual(t.next().type, TokenType.ADDITIVE_OPERATOR) + self.assertEqual(t.next().type, TokenType.INT) + self.assertEqual(t.next(), None) + self.assertEqual(t._current_token_index, 3) + self.assertEqual(t.next(), None) + self.assertEqual(t.next(), None) + self.assertEqual(t._current_token_index, 3) + self.assertEqual(t.next(), None) + t.prev() + self.assertEqual(t.token().value, "9") + t.prev() + self.assertEqual(t.token().value, "+") + t.prev() + self.assertEqual(t.token().value, "a") + t.prev() + self.assertEqual(t.token().value, "a") \ No newline at end of file diff --git a/src/dotchain/runtime/tokenizer.py b/src/dotchain/runtime/tokenizer.py new file mode 100644 index 0000000..45235af --- /dev/null 
+++ b/src/dotchain/runtime/tokenizer.py @@ -0,0 +1,259 @@ +import re +from enum import Enum + +from attr import dataclass + +class TokenType(Enum): + NEW_LINE = 1 + SPACE = 2 + COMMENTS = 3 + LEFT_PAREN = 4 + RIGHT_PAREN = 5 + COMMA = 6 + LEFT_BRACE = 7 + RIGHT_BRACE = 8 + SEMICOLON = 9 + LET = 10 + RETURN = 11 + IF = 12 + ELSE = 13 + WHILE = 14 + FOR = 15 + FLOAT = 18 + INT = 19 + IDENTIFIER = 20 + LOGICAL_OPERATOR = 21 + NOT = 22 + ASSIGNMENT = 23 + MULTIPLICATIVE_OPERATOR = 24 + ADDITIVE_OPERATOR = 25 + STRING = 26 + ARROW = 27 + BOOL = 28 + BREAK = 29 + TYPE_DEFINITION = 30 + COLON = 31 + +specs = ( + (re.compile(r"^\n"),TokenType.NEW_LINE), + # Space: + (re.compile(r"^\s"),TokenType.SPACE), + # Comments: + (re.compile(r"^//.*"), TokenType.COMMENTS), + + # Symbols: + (re.compile(r"^\("), TokenType.LEFT_PAREN), + (re.compile(r"^\)"), TokenType.RIGHT_PAREN), + (re.compile(r"^\,"), TokenType.COMMA), + (re.compile(r"^\{"), TokenType.LEFT_BRACE), + (re.compile(r"^\}"), TokenType.RIGHT_BRACE), + (re.compile(r"^;"), TokenType.SEMICOLON), + (re.compile(r"^:"), TokenType.COLON), + (re.compile(r"^=>"), TokenType.ARROW), + + # Keywords: + (re.compile(r"^\blet\b"), TokenType.LET), + (re.compile(r"^\breturn\b"), TokenType.RETURN), + (re.compile(r"^\bif\b"), TokenType.IF), + (re.compile(r"^\belse\b"), TokenType.ELSE), + (re.compile(r"^\bwhile\b"), TokenType.WHILE), + (re.compile(r"^\bfor\b"), TokenType.FOR), + (re.compile(r"^\bbreak\b"), TokenType.BREAK), + + (re.compile(r"^\btrue\b"), TokenType.BOOL), + (re.compile(r"^\bfalse\b"), TokenType.BOOL), + + # Type definition: + (re.compile(r"^\bstring\b"), TokenType.TYPE_DEFINITION), + (re.compile(r"^\bint\b"), TokenType.TYPE_DEFINITION), + (re.compile(r"^\bfloat\b"), TokenType.TYPE_DEFINITION), + (re.compile(r"^\bbool\b"), TokenType.TYPE_DEFINITION), + (re.compile(r"^\bany\b"), TokenType.TYPE_DEFINITION), + + # Floats: + (re.compile(r"^[0-9]+\.[0-9]+"), TokenType.FLOAT), + + # Ints: + (re.compile(r"^[0-9]+"), TokenType.INT), + + 
# Identifiers: + (re.compile(r"^\w+"), TokenType.IDENTIFIER), + + + # Logical operators: + (re.compile(r"^&&"), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^\|\|"), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^=="), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^!="), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^<="), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^>="), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^<"), TokenType.LOGICAL_OPERATOR), + (re.compile(r"^>"), TokenType.LOGICAL_OPERATOR), + + (re.compile(r"^!"), TokenType.NOT), + + # Assignment: + (re.compile(r"^="), TokenType.ASSIGNMENT), + + # Math operators: +, -, *, /: + (re.compile(r"^[*/%]"), TokenType.MULTIPLICATIVE_OPERATOR), + (re.compile(r"^[+-]"), TokenType.ADDITIVE_OPERATOR), + + # Double-quoted strings + # TODO: escape character \" and + (re.compile(r"^\"[^\"]*\""), TokenType.STRING), +) + +@dataclass +class Token: + type: TokenType + value: str + row: int + col: int + col_end: int + cursor: int + + def __str__(self) -> str: + return f"Token({self.type}, {self.value}, row={self.row}, col={self.col}, col_end={self.col_end}, cursor={self.cursor})" + + +class Tokenizer: + + def __init__(self): + self._current_token = None + self.script = "" + self.cursor = 0 + self.col = 0 + self.row = 0 + self._current_token_index = 0 + self.tokens = list[Token]() + self.checkpoint = list[int]() + + def init(self, script: str): + self.checkpoint = list[int]() + self.tokens = list[Token]() + self._current_token_index = 0 + self._current_token = None + self.script = script + self.cursor = 0 + self.col = 0 + self.row = 0 + self._get_next_token() + while self._current_token is not None: + self.tokens.append(self._current_token) + self._get_next_token() + + def checkpoint_push(self): + self.checkpoint.append(self._current_token_index) + + def checkpoint_pop(self): + self._current_token_index = self.checkpoint.pop() + + def next(self): + if self._current_token_index < len(self.tokens): + self._current_token_index += 
1 + return self.token() + + def next_token_type(self): + if self._current_token_index < len(self.tokens): + self._current_token_index += 1 + return self.tokenType() + + def prev(self): + if self._current_token_index > 0: + self._current_token_index -= 1 + return self.token() + + def get_prev(self): + if self._current_token_index == 0: + return None + return self.tokens[self._current_token_index - 1] + + def get_next(self): + if self._current_token_index >= len(self.tokens): + return None + return self.tokens[self._current_token_index + 1] + + def token(self): + if self._current_token_index >= len(self.tokens): + return None + return self.tokens[self._current_token_index] + + def tokenType(self): + if self._current_token_index >= len(self.tokens): + return None + return self.tokens[self._current_token_index].type + + + def _get_next_token(self): + if self._is_eof(): + self._current_token = None + return None + _string = self.script[self.cursor:] + for spec in specs: + tokenValue, offset = self.match(spec[0], _string) + if tokenValue == None: + continue + if spec[1] == TokenType.NEW_LINE: + self.row += 1 + self.col = 0 + return self._get_next_token() + if spec[1] == TokenType.COMMENTS: + return self._get_next_token() + if spec[1] == TokenType.SPACE: + self.col += offset + return self._get_next_token() + if spec[1] == None: + return self._get_next_token() + self._current_token = Token(spec[1],tokenValue, self.cursor, self.row, self.col, self.col + offset) + self.col += offset + return self.get_current_token() + raise Exception("Unknown token: " + _string[0]) + + def _is_eof(self): + return self.cursor == len(self.script) + + def has_more_tokens(self): + return self.cursor < len(self.script) + + def get_current_token(self): + return self._current_token + + def match(self, reg: re, _script): + matched = reg.search(_script) + if matched == None: + return None,0 + self.cursor = self.cursor + matched.span(0)[1] + return matched[0], matched.span(0)[1] + + def eat(self, 
value: str | TokenType): + if isinstance(value, str): + return self.eat_value(value) + if isinstance(value, TokenType): + return self.eat_token_type(value) + + def eat_value(self, value: str): + token = self.token() + if token is None: + raise Exception(f"Expected {value} but got None") + if token.value != value: + raise Exception(f"Expected {value} but got {token.value}") + self.next() + return token + + def eat_token_type(self,tokenType: TokenType): + token = self.token() + if token is None: + raise Exception(f"Expected {tokenType} but got None") + if token.type != tokenType: + raise Exception(f"Expected {tokenType} but got {token.type}") + self.next() + return token + + def type_is(self, tokenType: TokenType): + if self.token() is None: + return False + return self.token().type == tokenType + + def the_rest(self): + return self.tokens[self._current_token_index:] \ No newline at end of file From 41621b44f798065982db7dbf0f67a5f6d2a8c264 Mon Sep 17 00:00:00 2001 From: superobk Date: Mon, 8 Apr 2024 17:21:44 +0800 Subject: [PATCH 08/11] feat: audio chat --- main.py | 2 +- src/blackbox/audio_chat.py | 36 ++++++++++++++++++++++++++++++++ src/blackbox/blackbox_factory.py | 4 ++++ src/blackbox/tesou.py | 3 +-- 4 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 src/blackbox/audio_chat.py diff --git a/main.py b/main.py index 3193fdb..94bbe68 100644 --- a/main.py +++ b/main.py @@ -46,4 +46,4 @@ async def workflows(script: Annotated[str, Form()], request: Request=None): ast.exec(dsl_runtime) if __name__ == "__main__": - uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info") + uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info") diff --git a/src/blackbox/audio_chat.py b/src/blackbox/audio_chat.py new file mode 100644 index 0000000..6528a7f --- /dev/null +++ b/src/blackbox/audio_chat.py @@ -0,0 +1,36 @@ +from fastapi import Request, Response,status +from fastapi.responses import JSONResponse + +from .blackbox import Blackbox + 
+class AudioChat(Blackbox): + + def __init__(self, asr, gpt, tts): + self.asr = asr + self.gpt = gpt + self.tts = tts + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, *args, **kwargs) -> bool : + data = args[0] + if isinstance(data, bytes): + return True + return False + + async def processing(self, *args, **kwargs): + data = args[0] + text = await self.asr(data) + # TODO: ID + text = self.gpt("123", " " + text) + audio = self.tts(text) + return audio + + async def fast_api_handler(self, request: Request) -> Response: + data = (await request.form()).get("audio") + if data is None: + return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST) + d = await data.read() + by = await self.processing(d) + return Response(content=by.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"}) \ No newline at end of file diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 973b8a3..408de8a 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -1,3 +1,4 @@ +from .audio_chat import AudioChat from .sum import SUM from .sentiment import Sentiment from .tts import TTS @@ -19,6 +20,7 @@ class BlackboxFactory: self.audio_to_text = AudioToText() self.text_to_audio = TextToAudio() self.tesou = Tesou() + self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -40,4 +42,6 @@ class BlackboxFactory: return self.sum if blackbox_name == "tesou": return self.tesou + if blackbox_name == "audio_chat": + return self.audio_chat raise ValueError("Invalid blockbox type") \ No newline at end of file diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index 6b95d3f..ae547fc 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -17,8 +17,7 @@ class Tesou(Blackbox): # 用户输入的数据格式为:[{"id": "123", 
"prompt": "叉烧饭,帮我查询叉烧饭的介绍"}] def processing(self, id, prompt) -> str: - url = 'http://120.196.116.194:48891/' - + url = 'http://120.196.116.194:48891/chat/' message = { "user_id": id, "prompt": prompt, From c104ea52b532aa210d6c0249d85ca29f3172e583 Mon Sep 17 00:00:00 2001 From: superobk Date: Tue, 9 Apr 2024 12:00:08 +0800 Subject: [PATCH 09/11] code updated --- main.py | 10 ++++++++++ src/blackbox/asr.py | 2 +- src/blackbox/audio_chat.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 94bbe68..2cb5207 100644 --- a/main.py +++ b/main.py @@ -8,9 +8,19 @@ from src.dotchain.runtime.tokenizer import Tokenizer from src.dotchain.runtime.runtime import Runtime from src.blackbox.blackbox_factory import BlackboxFactory import uvicorn +from fastapi.middleware.cors import CORSMiddleware app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + blackbox_factory = BlackboxFactory() @app.post("/") diff --git a/src/blackbox/asr.py b/src/blackbox/asr.py index 7ba88fb..9c691f3 100644 --- a/src/blackbox/asr.py +++ b/src/blackbox/asr.py @@ -38,4 +38,4 @@ class ASR(Blackbox): txt = await self.processing(d) except ValueError as e: return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) - return JSONResponse(content={"txt": txt}, status_code=status.HTTP_200_OK) \ No newline at end of file + return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK) \ No newline at end of file diff --git a/src/blackbox/audio_chat.py b/src/blackbox/audio_chat.py index 6528a7f..1ab156b 100644 --- a/src/blackbox/audio_chat.py +++ b/src/blackbox/audio_chat.py @@ -33,4 +33,4 @@ class AudioChat(Blackbox): return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST) d = await data.read() by = await self.processing(d) - return Response(content=by.read(), media_type="audio/wav", 
headers={"Content-Disposition": "attachment; filename=audio.wav"}) \ No newline at end of file + return Response(content=by.read(), media_type="audio/x-wav", headers={"Content-Disposition": "attachment; filename=audio.wav"}) \ No newline at end of file From bab055e7d60ea412545248be0ec1b842bf7e0344 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 10 Apr 2024 10:13:36 +0800 Subject: [PATCH 10/11] TTS & cuda processing updated --- cuda.py | 5 +++++ src/blackbox/tts.py | 3 +++ src/tts/tts_service.py | 6 +++--- src/tts/vits/monotonic_align/__init__.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 cuda.py diff --git a/cuda.py b/cuda.py new file mode 100644 index 0000000..159370c --- /dev/null +++ b/cuda.py @@ -0,0 +1,5 @@ +import torch + +print("Torch version:",torch.__version__) + +print("Is CUDA enabled?",torch.cuda.is_available()) \ No newline at end of file diff --git a/src/blackbox/tts.py b/src/blackbox/tts.py index aea74d6..669f273 100644 --- a/src/blackbox/tts.py +++ b/src/blackbox/tts.py @@ -1,4 +1,5 @@ import io +import time from ntpath import join from fastapi import Request, Response, status @@ -16,7 +17,9 @@ class TTS(Blackbox): def processing(self, *args, **kwargs) -> io.BytesIO: text = args[0] + current_time = time.time() audio = self.tts_service.read(text) + print("#### TTS Service consume : ", (time.time()-current_time)) return audio def valid(self, *args, **kwargs) -> bool: diff --git a/src/tts/tts_service.py b/src/tts/tts_service.py index 938df56..e140ff4 100644 --- a/src/tts/tts_service.py +++ b/src/tts/tts_service.py @@ -53,7 +53,7 @@ class TTService(): len(symbols), self.hps.data.filter_length // 2 + 1, self.hps.train.segment_size // self.hps.data.hop_length, - **self.hps.model).cpu() + **self.hps.model).cuda() _ = self.net_g.eval() _ = utils.load_checkpoint(cfg["model"], self.net_g, None) @@ -69,8 +69,8 @@ class TTService(): stn_tst = self.get_text(text, self.hps) with torch.no_grad(): - x_tst = 
stn_tst.cpu().unsqueeze(0) - x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu() + x_tst = stn_tst.cuda().unsqueeze(0) + x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() # tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed) audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][ 0, 0].data.cpu().float().numpy() diff --git a/src/tts/vits/monotonic_align/__init__.py b/src/tts/vits/monotonic_align/__init__.py index 3d7009c..9293c5a 100644 --- a/src/tts/vits/monotonic_align/__init__.py +++ b/src/tts/vits/monotonic_align/__init__.py @@ -1,6 +1,6 @@ import numpy as np import torch -from .monotonic_align.core import maximum_path_c +from .core import maximum_path_c def maximum_path(neg_cent, mask): From 005acc087405c36031045fd9f587ef060851c770 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 10 Apr 2024 10:29:41 +0800 Subject: [PATCH 11/11] feat: tts enable gpu --- src/blackbox/tesou.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index ae547fc..723b811 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -31,7 +31,7 @@ class Tesou(Blackbox): except: return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) user_id = data.get("user_id") - user_prompt = data.get("prompt") + user_prompt = data.get("prompt") if user_prompt is None: return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK) \ No newline at end of file