diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index d42a76e..aeea922 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -12,42 +12,36 @@ from .text_and_image import TextAndImage from injector import Injector class BlackboxFactory: - + models = {} + def __init__(self) -> None: injector = Injector() - self.tts = TTS() - self.asr = ASR(".env.yaml") - self.sentiment = Sentiment() - self.audio_to_text = AudioToText() - self.text_to_audio = TextToAudio() - self.tesou = injector.get(Tesou) - self.fastchat = Fastchat() - self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) - self.g2e = G2E() - self.text_and_image = TextAndImage() + self.models["audio_to_text"] = AudioToText() + self.models["text_to_audio"] = TextToAudio() + self.models["asr"] = ASR(".env.yaml") + self.models["tts"] = TTS() + self.models["sentiment_engine"] = Sentiment() + self.models["tesou"] = injector.get(Tesou) + self.models["fastchat"] = Fastchat() + self.models["audio_chat"] = AudioChat(self.models["asr"], self.models["tesou"], self.models["tts"]) + self.models["g2e"] = G2E() + self.models["text_and_image"] = TextAndImage() + # self.tts = TTS() + # self.asr = ASR(".env.yaml") + # self.sentiment = Sentiment() + # self.audio_to_text = AudioToText() + # self.text_to_audio = TextToAudio() + # self.tesou = injector.get(Tesou) + # self.fastchat = Fastchat() + # self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) + # self.g2e = G2E() + # self.text_and_image = TextAndImage() def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) def create_blackbox(self, blackbox_name: str) -> Blackbox: - if blackbox_name == "audio_to_text": - return self.audio_to_text - if blackbox_name == "text_to_audio": - return self.text_to_audio - if blackbox_name == "asr": - return self.asr - if blackbox_name == "tts": - return self.tts - if blackbox_name == "sentiment_engine": - return self.sentiment - if blackbox_name == "tesou": - return self.tesou - if blackbox_name == "fastchat": - return self.fastchat - if blackbox_name == "audio_chat": - return self.audio_chat - if blackbox_name == "g2e": - return self.g2e - if blackbox_name == 'text_and_image': - return self.text_and_image - raise ValueError("Invalid blockbox type") \ No newline at end of file + model = self.models.get(blackbox_name) + if model is None: + raise ValueError("Invalid blockbox type") + return model \ No newline at end of file