From d493bd221f61949735fe7c3433618331f639505d Mon Sep 17 00:00:00 2001 From: superobk Date: Fri, 26 Apr 2024 10:51:10 +0800 Subject: [PATCH] feat: configuration --- README.md | 41 ++++++++++++++++++++++++++++ src/blackbox/blackbox_factory.py | 4 ++- src/blackbox/tesou.py | 10 +++++-- src/configuration.py | 47 ++++++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 src/configuration.py diff --git a/README.md b/README.md index a798594..c0dcd9d 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,49 @@ | python | uvicorn | https://www.uvicorn.org/ | pip install "uvicorn[standard]" | | python | SpeechRecognition | https://pypi.org/project/SpeechRecognition/ | pip install SpeechRecognition | | python | gtts | https://pypi.org/project/gTTS/ | pip install gTTS | +| python | PyYAML | https://pypi.org/project/PyYAML/ | pip install PyYAML | +| python | injector | https://github.com/python-injector/injector | pip install injector | + ## Start Dev rh ```bash uvicorn main:app --reload ``` + +## Configuration +```yaml +tesou: + url: http://120.196.116.194:48891/chat/ + +TokenIDConverter: + token_path: src/asr/resources/models/token_list.pkl + unk_symbol: + +CharTokenizer: + symbol_value: + space_symbol: + remove_non_linguistic_symbols: false + +WavFrontend: + cmvn_file: src/asr/resources/models/am.mvn + frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + filter_length_max: -.inf + dither: 0.0 + +Model: + model_path: src/asr/resources/models/model.onnx + use_cuda: false + CUDAExecutionProvider: + device_id: 0 + arena_extend_strategy: kNextPowerOfTwo + cudnn_conv_algo_search: EXHAUSTIVE + do_copy_in_default_stream: true + batch_size: 3 +``` diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 49d931d..d42a76e 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -9,16 +9,18 @@ from .tesou import Tesou from .fastchat import Fastchat from .g2e import G2E from .text_and_image import TextAndImage +from injector import Injector class BlackboxFactory: def __init__(self) -> None: + injector = Injector() self.tts = TTS() self.asr = ASR(".env.yaml") self.sentiment = Sentiment() self.audio_to_text = AudioToText() self.text_to_audio = TextToAudio() - self.tesou = Tesou() + self.tesou = injector.get(Tesou) self.fastchat = Fastchat() self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) self.g2e = G2E() diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index 723b811..a262c21 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -2,11 +2,18 @@ from typing import Any, Coroutine from fastapi import Request, Response, status from fastapi.responses import JSONResponse +from injector import inject +from ..configuration import TesouConf from .blackbox import Blackbox import requests class Tesou(Blackbox): + url: str + + @inject + def __init__(self, tesou_config: TesouConf): + self.url = tesou_config.url def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -17,12 +24,11 @@ class Tesou(Blackbox): # 用户输入的数据格式为:[{"id": "123", "prompt": "叉烧饭,帮我查询叉烧饭的介绍"}] def processing(self, id, prompt) -> str: - url = 'http://120.196.116.194:48891/chat/' message = { "user_id": id, "prompt": prompt, } - response = requests.post(url, json=message) + response = requests.post(self.url, json=message) return response.json() async def fast_api_handler(self, request: Request) -> Response: diff --git a/src/configuration.py b/src/configuration.py new file mode 100644 index 0000000..8c3e6fb --- /dev/null +++ b/src/configuration.py @@ -0,0 +1,47 @@ + +from dataclasses import dataclass +from injector import Injector, inject +import yaml +import sys + +class Configuration(): + + @inject + def __init__(self) -> None: + config_file_path = "" + try: + config_file_path = sys.argv[1] + except: + config_file_path = ".env.yaml" + with open(config_file_path) as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + self.cfg = cfg + + def getDict(self): + return self.cfg + + """ + # yaml 檔中的路徑 get("aaa.bbb.ccc") + aaa: + bbb: + ccc: "hello world" + """ + def get(self, path: str | list[str], cfg: dict = None): + if isinstance(path, str): + if cfg is None: + cfg = self.cfg + return self.get(path.split("."), cfg) + lenght = len(path) + if lenght == 0 or not isinstance(cfg, dict): + return None + if lenght == 1: + return cfg.get(path[0]) + return self.get(path[1:], cfg.get(path[0])) + + +class TesouConf(): + url: str + + @inject + def __init__(self,config: Configuration) -> None: + self.url = config.get("tesou.url")