feat: configuration

This commit is contained in:
superobk
2024-04-26 10:51:10 +08:00
parent f8f75e11ff
commit d493bd221f
4 changed files with 99 additions and 3 deletions

View File

@ -9,16 +9,18 @@ from .tesou import Tesou
from .fastchat import Fastchat
from .g2e import G2E
from .text_and_image import TextAndImage
from injector import Injector
class BlackboxFactory:
def __init__(self) -> None:
injector = Injector()
self.tts = TTS()
self.asr = ASR(".env.yaml")
self.sentiment = Sentiment()
self.audio_to_text = AudioToText()
self.text_to_audio = TextToAudio()
self.tesou = Tesou()
self.tesou = injector.get(Tesou)
self.fastchat = Fastchat()
self.audio_chat = AudioChat(self.asr, self.tesou, self.tts)
self.g2e = G2E()

View File

@ -2,11 +2,18 @@ from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import inject
from ..configuration import TesouConf
from .blackbox import Blackbox
import requests
class Tesou(Blackbox):
url: str
@inject
def __init__(self, tesou_config: TesouConf):
self.url = tesou_config.url
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
@ -17,12 +24,11 @@ class Tesou(Blackbox):
# 用户输入的数据格式为:[{"id": "123", "prompt": "叉烧饭,帮我查询叉烧饭的介绍"}]
def processing(self, id, prompt) -> str:
url = 'http://120.196.116.194:48891/chat/'
message = {
"user_id": id,
"prompt": prompt,
}
response = requests.post(url, json=message)
response = requests.post(self.url, json=message)
return response.json()
async def fast_api_handler(self, request: Request) -> Response:

47
src/configuration.py Normal file
View File

@ -0,0 +1,47 @@
from dataclasses import dataclass
from injector import Injector, inject
import yaml
import sys
class Configuration():
@inject
def __init__(self) -> None:
config_file_path = ""
try:
config_file_path = sys.argv[1]
except:
config_file_path = ".env.yaml"
with open(config_file_path) as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
self.cfg = cfg
def getDict(self):
return self.cfg
"""
# yaml 檔中的路徑 get("aaa.bbb.ccc")
aaa:
bbb:
ccc: "hello world"
"""
def get(self, path: str | list[str], cfg: dict = None):
if isinstance(path, str):
if cfg is None:
cfg = self.cfg
return self.get(path.split("."), cfg)
lenght = len(path)
if lenght == 0 or not isinstance(cfg, dict):
return None
if lenght == 1:
return cfg.get(path[0])
return self.get(path[1:], cfg.get(path[0]))
class TesouConf():
url: str
@inject
def __init__(self,config: Configuration) -> None:
self.url = config.get("tesou.url")