diff --git a/main.py b/main.py index 3193fdb..94bbe68 100644 --- a/main.py +++ b/main.py @@ -46,4 +46,4 @@ async def workflows(script: Annotated[str, Form()], request: Request=None): ast.exec(dsl_runtime) if __name__ == "__main__": - uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info") + uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info") diff --git a/src/blackbox/audio_chat.py b/src/blackbox/audio_chat.py new file mode 100644 index 0000000..6528a7f --- /dev/null +++ b/src/blackbox/audio_chat.py @@ -0,0 +1,36 @@ +from fastapi import Request, Response,status +from fastapi.responses import JSONResponse + +from .blackbox import Blackbox + +class AudioChat(Blackbox): + + def __init__(self, asr, gpt, tts): + self.asr = asr + self.gpt = gpt + self.tts = tts + + def __call__(self, *args, **kwargs): + return self.processing(*args, **kwargs) + + def valid(self, *args, **kwargs) -> bool : + data = args[0] + if isinstance(data, bytes): + return True + return False + + async def processing(self, *args, **kwargs): + data = args[0] + text = await self.asr(data) + # TODO: ID + text = self.gpt("123", " " + text) + audio = self.tts(text) + return audio + + async def fast_api_handler(self, request: Request) -> Response: + data = (await request.form()).get("audio") + if data is None: + return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST) + d = await data.read() + by = await self.processing(d) + return Response(content=by.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"}) \ No newline at end of file diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 973b8a3..408de8a 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -1,3 +1,4 @@ +from .audio_chat import AudioChat from .sum import SUM from .sentiment import Sentiment from .tts import TTS @@ -19,6 +20,7 @@ class BlackboxFactory: self.audio_to_text = AudioToText() self.text_to_audio = TextToAudio() self.tesou = Tesou() + self.audio_chat = AudioChat(self.asr, self.tesou, self.tts) def __call__(self, *args, **kwargs): return self.processing(*args, **kwargs) @@ -40,4 +42,6 @@ class BlackboxFactory: return self.sum if blackbox_name == "tesou": return self.tesou + if blackbox_name == "audio_chat": + return self.audio_chat raise ValueError("Invalid blockbox type") \ No newline at end of file diff --git a/src/blackbox/tesou.py b/src/blackbox/tesou.py index 6b95d3f..ae547fc 100755 --- a/src/blackbox/tesou.py +++ b/src/blackbox/tesou.py @@ -17,8 +17,7 @@ class Tesou(Blackbox): # 用户输入的数据格式为:[{"id": "123", "prompt": "叉烧饭,帮我查询叉烧饭的介绍"}] def processing(self, id, prompt) -> str: - url = 'http://120.196.116.194:48891/' - + url = 'http://120.196.116.194:48891/chat/' message = { "user_id": id, "prompt": prompt,