This commit is contained in:
superobk
2024-03-27 15:06:30 +08:00
parent ec10081276
commit 0549a033e1
3 changed files with 11 additions and 10 deletions

View File

@ -11,10 +11,11 @@ blackbox_factory = BlackboxFactory()
@app.post("/") @app.post("/")
async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None): async def blackbox(blackbox_name: Union[str, None] = None, request: Request = None):
print(blackbox_name)
if not blackbox_name: if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
try: try:
box = blackbox_factory.create_blackbox(blackbox_name, {}) box = blackbox_factory.create_blackbox(blackbox_name)
except ValueError: except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request) return await box.fast_api_handler(request)

View File

@ -12,18 +12,19 @@ class BlackboxFactory:
def __init__(self) -> None: def __init__(self) -> None:
self.tts = TTS() self.tts = TTS()
self.asr = ASR("./.env.yaml") #self.asr = ASR("./.env.yaml")
self.sentiment = Sentiment() #self.sentiment = Sentiment()
self.sum = SUM() #self.sum = SUM()
self.calculator = Calculator() #self.calculator = Calculator()
self.audio_to_text = AudioToText() #self.audio_to_text = AudioToText()
self.text_to_audio = TextToAudio() #self.text_to_audio = TextToAudio()
self.tesou = Tesou() #self.tesou = Tesou()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
def create_blackbox(self, blackbox_name: str) -> Blackbox: def create_blackbox(self, blackbox_name: str) -> Blackbox:
return self.tts
if blackbox_name == "audio_to_text": if blackbox_name == "audio_to_text":
return self.audio_to_text return self.audio_to_text
if blackbox_name == "text_to_audio": if blackbox_name == "text_to_audio":

View File

@ -14,8 +14,7 @@ class TTS(Blackbox):
'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2] 'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2]
} }
self.tts_service = TTService(*config['catmaid']) self.tts_service = TTService(*config['catmaid'])
super().__init__(config)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)