Files
jarvis-models/src/blackbox/blackbox_factory.py
2024-04-30 15:44:14 +08:00

47 lines
1.5 KiB
Python

from .audio_chat import AudioChat
from .sentiment import Sentiment
from .tts import TTS
from .asr import ASR
from .audio_to_text import AudioToText
from .blackbox import Blackbox
from .text_to_audio import TextToAudio
from .tesou import Tesou
from .fastchat import Fastchat
from .g2e import G2E
from .text_and_image import TextAndImage
from injector import inject
class BlackboxFactory:
    """Look up blackbox instances by name.

    Every blackbox is passed to the constructor, which is decorated with
    ``@inject`` so the ``injector`` container can supply the arguments.
    Each one is stored under a fixed string key, and ``create_blackbox``
    returns the stored instance for a given key.
    """

    @inject
    def __init__(
        self,
        audio_to_text: AudioToText,
        text_to_audio: TextToAudio,
        asr: ASR,
        tts: TTS,
        sentiment_engine: Sentiment,
        tesou: Tesou,
        fastchat: Fastchat,
        audio_chat: AudioChat,
        g2e: G2E,
        text_and_image: TextAndImage,
    ) -> None:
        """Register each supplied blackbox under its lookup name."""
        # Built per instance rather than as a class-level dict: a mutable
        # class attribute would be shared (and overwritten) by every
        # factory instance.
        self.models = {
            "audio_to_text": audio_to_text,
            "text_to_audio": text_to_audio,
            "asr": asr,
            "tts": tts,
            "sentiment_engine": sentiment_engine,
            "tesou": tesou,
            "fastchat": fastchat,
            "audio_chat": audio_chat,
            "g2e": g2e,
            "text_and_image": text_and_image,
        }

    def __call__(self, *args, **kwargs):
        """Forward the call to ``processing``."""
        return self.processing(*args, **kwargs)

    def processing(self, blackbox_name: str) -> Blackbox:
        """Return the blackbox registered under ``blackbox_name``.

        Exists so that calling the factory instance works; it simply
        delegates to ``create_blackbox``.

        Raises:
            ValueError: If no blackbox is registered under that name.
        """
        return self.create_blackbox(blackbox_name)

    def create_blackbox(self, blackbox_name: str) -> Blackbox:
        """Return the blackbox registered under ``blackbox_name``.

        Despite the name, no new object is created: the instance that was
        passed to the constructor is returned as-is.

        Raises:
            ValueError: If no blackbox is registered under that name.
        """
        model = self.models.get(blackbox_name)
        if model is None:
            raise ValueError(f"Invalid blackbox type: {blackbox_name!r}")
        return model