feat: asr support base64 input

This commit is contained in:
0Xiao0
2024-11-15 06:16:39 +00:00
parent 8fe010bbbe
commit 902c4d6073

View File

@ -17,7 +17,7 @@ import tempfile
import json import json
import os import os
import opencc import opencc
import base64
from ..configuration import SenseVoiceConf from ..configuration import SenseVoiceConf
@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
converter = opencc.OpenCC('s2t.json') converter = opencc.OpenCC('s2t.json')
def is_base64(value) -> bool:
    """Return True if *value* is canonical, padded standard base64.

    The check is a strict round trip: *value* is decoded with
    ``validate=True`` (characters outside the base64 alphabet, including
    whitespace, are rejected rather than silently dropped) and the decoded
    bytes are re-encoded; the result must equal the input exactly.

    Args:
        value: Candidate text as ``str``, ``bytes`` or ``bytearray``.
            Any other type yields ``False``.

    Returns:
        ``True`` when the input round-trips through decode/encode unchanged,
        ``False`` otherwise. An empty string round-trips and is therefore
        reported as valid (it encodes zero bytes).
    """
    if isinstance(value, str):
        try:
            # Base64 text is pure ASCII; anything else cannot be valid.
            raw = value.encode("ascii")
        except UnicodeEncodeError:
            return False
    elif isinstance(value, (bytes, bytearray)):
        raw = bytes(value)
    else:
        return False

    try:
        decoded = base64.b64decode(raw, validate=True)
    except ValueError:
        # binascii.Error (bad padding / bad characters) subclasses ValueError.
        return False

    # Re-encode (not decode again) and actually return the comparison, so
    # non-canonical input is rejected instead of always reporting True.
    return base64.b64encode(decoded) == raw
@singleton @singleton
class ASR(Blackbox): class ASR(Blackbox):
mode: str mode: str
@ -43,7 +49,7 @@ class ASR(Blackbox):
config = read_yaml(".env.yaml") config = read_yaml(".env.yaml")
self.paraformer = RapidParaformer(config) self.paraformer = RapidParaformer(config)
model_dir = "/Workspace/Models/SenseVoice/SenseVoiceSmall" model_dir = "/model/Voice/SenseVoice/SenseVoiceSmall"
self.speed = sensevoice_config.speed self.speed = sensevoice_config.speed
self.device = sensevoice_config.device self.device = sensevoice_config.device
@ -123,8 +129,23 @@ class ASR(Blackbox):
return False return False
async def fast_api_handler(self, request: Request) -> Response: async def fast_api_handler(self, request: Request) -> Response:
data = (await request.form()).get("audio")
setting: dict = (await request.form()).get("settings") try:
content_type = request.headers['content-type']
if content_type == 'application/json':
data = await request.json()
setting: dict = (await request.json()).get("settings")
else:
data = await request.form()
setting: dict = (await request.form()).get("settings")
except Exception as e:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
data = data.get("audio")
print(f'type: {type(data)}')
print('='*90)
print(data)
print('='*90)
if isinstance(setting, str): if isinstance(setting, str):
try: try:
@ -135,7 +156,15 @@ class ASR(Blackbox):
if data is None: if data is None:
# self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr") # self.logger.warn("asr bag request","type", "fast_api_handler", "api", "asr")
return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
d = await data.read()
# 如果数据是Base64编码的字符串需要先解码为字节数据
if isinstance(data, str):
try:
d = base64.b64decode(data) # Base64解码为字节数据
except base64.binascii.Error:
return JSONResponse(content={"error": "Invalid base64 encoding"}, status_code=status.HTTP_400_BAD_REQUEST)
else:
d = await data.read()
try: try:
txt = await self.processing(d, settings=setting) txt = await self.processing(d, settings=setting)
except ValueError as e: except ValueError as e: