From f14c36d77ab6368a592f572e21aaf50f391c6ff9 Mon Sep 17 00:00:00 2001 From: superobk Date: Wed, 20 Mar 2024 09:42:08 +0800 Subject: [PATCH] feat: sentiment engine --- sentiment_engine/sentiment_engine.py | 29 ++++++++++++++++++++++++++ src/blackbox/blackbox_factory.py | 4 ++++ src/blackbox/sentiment.py | 31 ++++++++++++++++++++++++++++ tts/tts_service.py | 4 ---- 4 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 sentiment_engine/sentiment_engine.py create mode 100644 src/blackbox/sentiment.py diff --git a/sentiment_engine/sentiment_engine.py b/sentiment_engine/sentiment_engine.py new file mode 100644 index 0000000..acb93d3 --- /dev/null +++ b/sentiment_engine/sentiment_engine.py @@ -0,0 +1,29 @@ +import logging + +import onnxruntime +from transformers import BertTokenizer +import numpy as np + + +class SentimentEngine(): + + def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"): + logging.info('Initializing Sentiment Engine...') + onnx_model_path = model_path + self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) + self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') + + def infer(self, text): + tokens = self.tokenizer(text, return_tensors="np") + input_dict = { + "input_ids": tokens["input_ids"], + "attention_mask": tokens["attention_mask"], + } + # Convert input_ids and attention_mask to int64 + input_dict["input_ids"] = input_dict["input_ids"].astype(np.int64) + input_dict["attention_mask"] = input_dict["attention_mask"].astype(np.int64) + logits = self.ort_session.run(["logits"], input_dict)[0] + probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) + predicted = np.argmax(probabilities, axis=1)[0] + logging.info(f'Sentiment Engine Infer: {predicted}') + return predicted diff --git a/src/blackbox/blackbox_factory.py b/src/blackbox/blackbox_factory.py index 31a1e7b..4262b47 100644 --- a/src/blackbox/blackbox_factory.py +++ b/src/blackbox/blackbox_factory.py @@ -1,3 +1,4 @@ +from .sentiment import Sentiment from .tts import TTS from ..asr.asr import ASR from .audio_to_text import AudioToText @@ -11,6 +12,7 @@ class BlackboxFactory: def __init__(self) -> None: self.tts = TTS() self.asr = ASR("./.env.yaml") + self.sentiment = Sentiment() def create_blackbox(self, blackbox_name: str, blackbox_config: dict) -> Blackbox: if blackbox_name == "audio_to_text": @@ -23,4 +25,6 @@ class BlackboxFactory: return self.asr if blackbox_name == "tts": return self.tts + if blackbox_name == "sentiment_engine": + return self.sentiment raise ValueError("Invalid blockbox type") \ No newline at end of file diff --git a/src/blackbox/sentiment.py b/src/blackbox/sentiment.py new file mode 100644 index 0000000..4f12156 --- /dev/null +++ b/src/blackbox/sentiment.py @@ -0,0 +1,31 @@ +from typing import Any, Coroutine + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse + +from sentiment_engine.sentiment_engine import SentimentEngine +from .blackbox import Blackbox + + +class Sentiment(Blackbox): + + def __init__(self) -> None: + self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx') + + def valid(self, data: any) -> bool: + return isinstance(data, str) + + def processing(self, text: any) -> int: + return int(self.engine.infer(text)) + + async def fast_api_handler(self, request) -> Response: + try: + data = await request.json() + except: + return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) + text = data.get("text") + if text is None: + return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST) + sentiment = self.processing(text) + return JSONResponse(content={"sentiment": sentiment }, status_code=status.HTTP_200_OK) + \ No newline at end of file diff --git a/tts/tts_service.py b/tts/tts_service.py index ea06b26..0011315 100644 --- a/tts/tts_service.py +++ b/tts/tts_service.py @@ -22,10 +22,6 @@ logging.getLogger().setLevel(logging.INFO) logging.basicConfig(level=logging.INFO) -from pydub import AudioSegment - - - class TTService(): def __init__(self, cfg, model, char, speed):