Update blackbox_factory.py

This commit is contained in:
Limbo
2024-05-13 14:43:20 +08:00
committed by GitHub
parent d947c917e5
commit 8d2a4d0673

View File

@ -3,6 +3,7 @@ from .sentiment import Sentiment
from .tts import TTS from .tts import TTS
from .asr import ASR from .asr import ASR
from .audio_to_text import AudioToText from .audio_to_text import AudioToText
from .emotion import Emotion
from .blackbox import Blackbox from .blackbox import Blackbox
from .text_to_audio import TextToAudio from .text_to_audio import TextToAudio
from .tesou import Tesou from .tesou import Tesou
@ -26,6 +27,7 @@ class BlackboxFactory:
asr: ASR, asr: ASR,
tts: TTS, tts: TTS,
sentiment_engine: Sentiment, sentiment_engine: Sentiment,
emotion: Emotion,
tesou: Tesou, tesou: Tesou,
fastchat: Fastchat, fastchat: Fastchat,
audio_chat: AudioChat, audio_chat: AudioChat,
@ -41,6 +43,7 @@ class BlackboxFactory:
self.models["tts"] = tts self.models["tts"] = tts
self.models["sentiment_engine"] = sentiment_engine self.models["sentiment_engine"] = sentiment_engine
self.models["tesou"] = tesou self.models["tesou"] = tesou
self.models["emotion"] = emotion
self.models["fastchat"] = fastchat self.models["fastchat"] = fastchat
self.models["audio_chat"] = audio_chat self.models["audio_chat"] = audio_chat
self.models["g2e"] = g2e self.models["g2e"] = g2e
@ -57,4 +60,4 @@ class BlackboxFactory:
model = self.models.get(blackbox_name) model = self.models.get(blackbox_name)
if model is None: if model is None:
raise ValueError("Invalid blockbox type") raise ValueError("Invalid blockbox type")
return model return model