From f76231485c5f8013aa94a6e779aa8b5114aa5306 Mon Sep 17 00:00:00 2001 From: Dan Chen Date: Fri, 26 Apr 2024 17:53:28 +0800 Subject: [PATCH] feat: inject --- src/blackbox/asr.py | 4 +-- src/blackbox/audio_chat.py | 8 +++++- src/blackbox/blackbox_factory.py | 48 ++++++++++++++++---------------- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/blackbox/asr.py b/src/blackbox/asr.py index 9c691f3..ad7ca81 100644 --- a/src/blackbox/asr.py +++ b/src/blackbox/asr.py @@ -10,8 +10,8 @@ from .blackbox import Blackbox class ASR(Blackbox): - def __init__(self, *args, **kwargs) -> None: - config = read_yaml(args[0]) + def __init__(self, path = ".env.yaml") -> None: + config = read_yaml(path) self.paraformer = RapidParaformer(config) def __call__(self, *args, **kwargs): diff --git a/src/blackbox/audio_chat.py b/src/blackbox/audio_chat.py index 1ab156b..1131c22 100644 --- a/src/blackbox/audio_chat.py +++ b/src/blackbox/audio_chat.py @@ -1,11 +1,17 @@ from fastapi import Request, Response,status from fastapi.responses import JSONResponse +from injector import inject + +from blackbox.asr import ASR +from blackbox.tesou import Tesou +from blackbox.tts import TTS from .blackbox import Blackbox class AudioChat(Blackbox): - def __init__(self, asr, gpt, tts): + @inject + def __init__(self, asr: ASR, gpt: Tesou, tts: TTS): self.asr = asr self.gpt = gpt self.tts = tts diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index aeea922..d6ac563 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -9,33 +9,33 @@ from .tesou import Tesou from .fastchat import Fastchat from .g2e import G2E from .text_and_image import TextAndImage -from injector import Injector +from injector import Injector, inject class BlackboxFactory: models = {} - - def __init__(self) -> None: - injector = Injector() - self.models["audio_to_text"] = AudioToText() - self.models["text_to_audio"] = TextToAudio() - self.models["asr"] = ASR(".env.yaml") - self.models["tts"] = TTS() - self.models["sentiment_engine"] = Sentiment() - self.models["tesou"] = injector.get(Tesou) - self.models["fastchat"] = Fastchat() - self.models["audio_chat"] = AudioChat(self.models["asr"], self.models["tesou"], self.models["tts"]) - self.models["g2e"] = G2E() - self.models["text_and_image"] = TextAndImage() - # self.tts = TTS() - # self.asr = ASR(".env.yaml") - # self.sentiment = Sentiment() - # self.audio_to_text = AudioToText() - # self.text_to_audio = TextToAudio() - # self.tesou = injector.get(Tesou) - # self.fastchat = Fastchat() - # self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) - # self.g2e = G2E() - # self.text_and_image = TextAndImage() + + @inject + def __init__(self, + audio_to_text: AudioToText, + text_to_audio: TextToAudio, + asr: ASR, + tts: TTS, + sentiment_engine: Sentiment, + tesou: Tesou, + fastchat: Fastchat, + audio_chat: AudioChat, + g2e: G2E, + text_and_image:TextAndImage ) -> None: + self.models["audio_to_text"] = audio_to_text + self.models["text_to_audio"] = text_to_audio + self.models["asr"] = asr + self.models["tts"] = tts + self.models["sentiment_engine"] = sentiment_engine + self.models["tesou"] = tesou + self.models["fastchat"] = fastchat + self.models["audio_chat"] = audio_chat + self.models["g2e"] = g2e + self.models["text_and_image"] = text_and_image def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs)