add blackbox fastchat

ACBBZ
2024-04-10 09:31:16 +00:00
107 changed files with 29145 additions and 50 deletions

3
.gitignore vendored
View File

@ -163,4 +163,5 @@ cython_debug/
.DS_Store .DS_Store
playground.py playground.py
.env* .env*
models models
.idea/

View File

@ -13,4 +13,4 @@
Dev rh Dev rh
```bash ```bash
uvicorn main:app --reload uvicorn main:app --reload
``` ```

5
cuda.py Normal file
View File

@ -0,0 +1,5 @@
import torch
print("Torch version:",torch.__version__)
print("Is CUDA enabled?",torch.cuda.is_available())

45
main.py
View File

@ -1,12 +1,26 @@
from typing import Union from typing import Annotated, Union
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status, Form
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from src.dotchain.runtime.interpreter import program_parser
from src.dotchain.runtime.tokenizer import Tokenizer
from src.dotchain.runtime.runtime import Runtime
from src.blackbox.blackbox_factory import BlackboxFactory from src.blackbox.blackbox_factory import BlackboxFactory
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI() app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
blackbox_factory = BlackboxFactory() blackbox_factory = BlackboxFactory()
@app.post("/") @app.post("/")
@ -14,11 +28,32 @@ async def blackbox(blackbox_name: Union[str, None] = None, request: Request = No
if not blackbox_name: if not blackbox_name:
return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "blackbox_name is required"}, status_code=status.HTTP_400_BAD_REQUEST)
try: try:
box = blackbox_factory.create_blackbox(blackbox_name, {}) box = blackbox_factory.create_blackbox(blackbox_name)
except ValueError: except ValueError:
return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST) return await JSONResponse(content={"error": "value error"}, status_code=status.HTTP_400_BAD_REQUEST)
return await box.fast_api_handler(request) return await box.fast_api_handler(request)
def read_form_image(request: Request):
async def inner(field: str):
print(field)
return "image"
return inner
def read_form_text(request: Request):
def inner(field: str):
print(field)
return "text"
return inner
@app.post("/workflows") @app.post("/workflows")
async def workflows(reqest: Request): async def workflows(script: Annotated[str, Form()], request: Request=None):
print("workflows") dsl_runtime = Runtime(exteral_fun={"print": print,
'read_form_image': read_form_image(request),
"read_form_text": read_form_text(request)})
t = Tokenizer()
t.init(script)
ast = program_parser(t)
ast.exec(dsl_runtime)
if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info")

View File

@ -4,16 +4,18 @@ from typing import Any, Coroutine
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from .rapid_paraformer.utils import read_yaml from ..asr.rapid_paraformer.utils import read_yaml
from .rapid_paraformer import RapidParaformer from ..asr.rapid_paraformer import RapidParaformer
from ..blackbox.blackbox import Blackbox from .blackbox import Blackbox
class ASR(Blackbox): class ASR(Blackbox):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
config = read_yaml(args[0]) config = read_yaml(args[0])
self.paraformer = RapidParaformer(config) self.paraformer = RapidParaformer(config)
super().__init__(config)
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
async def processing(self, *args, **kwargs): async def processing(self, *args, **kwargs):
data = args[0] data = args[0]
@ -36,4 +38,4 @@ class ASR(Blackbox):
txt = await self.processing(d) txt = await self.processing(d)
except ValueError as e: except ValueError as e:
return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": str(e)}, status_code=status.HTTP_400_BAD_REQUEST)
return JSONResponse(content={"txt": txt}, status_code=status.HTTP_200_OK) return JSONResponse(content={"text": txt}, status_code=status.HTTP_200_OK)

View File

@ -0,0 +1,36 @@
from fastapi import Request, Response,status
from fastapi.responses import JSONResponse
from .blackbox import Blackbox
class AudioChat(Blackbox):
def __init__(self, asr, gpt, tts):
self.asr = asr
self.gpt = gpt
self.tts = tts
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def valid(self, *args, **kwargs) -> bool :
data = args[0]
if isinstance(data, bytes):
return True
return False
async def processing(self, *args, **kwargs):
data = args[0]
text = await self.asr(data)
# TODO: ID
text = self.gpt("123", " " + text)
audio = self.tts(text)
return audio
async def fast_api_handler(self, request: Request) -> Response:
data = (await request.form()).get("audio")
if data is None:
return JSONResponse(content={"error": "data is required"}, status_code=status.HTTP_400_BAD_REQUEST)
d = await data.read()
by = await self.processing(d)
return Response(content=by.read(), media_type="audio/x-wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
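
`AudioChat` chains ASR, the Tesou chat service, and TTS, and is registered in the factory below under the name `audio_chat`, so it is reachable through the generic `POST /` route. A minimal client sketch, assuming a local deployment and a WAV file on disk (the `audio` form field matches the handler above):

```python
# Hypothetical client for the audio_chat blackbox, dispatched through the generic
# POST "/" route (blackbox_name is a query parameter in main.py).
import requests

with open("question.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/",
        params={"blackbox_name": "audio_chat"},
        files={"audio": ("question.wav", f, "audio/x-wav")},
    )

# The handler responds with a WAV attachment.
with open("answer.wav", "wb") as out:
    out.write(resp.content)
```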

View File

@ -1,7 +1,8 @@
from .audio_chat import AudioChat
from .sum import SUM from .sum import SUM
from .sentiment import Sentiment from .sentiment import Sentiment
from .tts import TTS from .tts import TTS
from ..asr.asr import ASR from .asr import ASR
from .audio_to_text import AudioToText from .audio_to_text import AudioToText
from .blackbox import Blackbox from .blackbox import Blackbox
from .calculator import Calculator from .calculator import Calculator
@ -13,14 +14,18 @@ class BlackboxFactor:
def __init__(self) -> None: def __init__(self) -> None:
self.tts = TTS() self.tts = TTS()
self.asr = ASR("./.env.yaml") self.asr = ASR(".env.yaml")
self.sentiment = Sentiment() self.sentiment = Sentiment()
self.sum = SUM() self.sum = SUM()
self.calculator = Calculator() self.calculator = Calculator()
self.audio_to_text = AudioToText() self.audio_to_text = AudioToText()
self.text_to_audio = TextToAudio() self.text_to_audio = TextToAudio()
self.tesou = Tesou() self.tesou = Tesou()
self.fastchat = Fastchat() self.fastchat = Fastchat()
self.audio_chat = AudioChat(self.asr, self.tesou, self.tts)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
@ -42,6 +47,11 @@ class BlackboxFactor:
return self.sum return self.sum
if blackbox_name == "tesou": if blackbox_name == "tesou":
return self.tesou return self.tesou
if blackbox_name == "fastchat": if blackbox_name == "fastchat":
return self.fastchat return self.fastchat
if blackbox_name == "audio_chat":
return self.audio_chat
raise ValueError("Invalid blockbox type") raise ValueError("Invalid blockbox type")

View File

@ -3,14 +3,14 @@ from typing import Any, Coroutine
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sentiment_engine.sentiment_engine import SentimentEngine from ..sentiment_engine.sentiment_engine import SentimentEngine
from .blackbox import Blackbox from .blackbox import Blackbox
class Sentiment(Blackbox): class Sentiment(Blackbox):
def __init__(self) -> None: def __init__(self) -> None:
self.engine = SentimentEngine('resources/sentiment_engine/models/paimon_sentiment.onnx') self.engine = SentimentEngine()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)

View File

@ -17,23 +17,21 @@ class Tesou(Blackbox):
# The user's input format is: [{"id": "123", "prompt": "Char siu rice, look up an introduction to char siu rice for me"}] # The user's input format is: [{"id": "123", "prompt": "Char siu rice, look up an introduction to char siu rice for me"}]
def processing(self, id, prompt) -> str: def processing(self, id, prompt) -> str:
url = 'http://120.196.116.194:48891/' url = 'http://120.196.116.194:48891/chat/'
message = { message = {
"user_id": id, "user_id": id,
"prompt": prompt, "prompt": prompt,
} }
response = requests.post(url, json=message) response = requests.post(url, json=message)
return response return response.json()
async def fast_api_handler(self, request: Request) -> Response: async def fast_api_handler(self, request: Request) -> Response:
try: try:
data = await request.json() data = await request.json()
except: except:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
user_id = data.get("id") user_id = data.get("user_id")
user_prompt = data.get("prompt") user_prompt = data.get("prompt")
if user_prompt is None: if user_prompt is None:
return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK) return JSONResponse(content={"Response": self.processing(user_id, user_prompt)}, status_code=status.HTTP_200_OK)

View File

@ -1,27 +1,25 @@
import io import io
import time
from ntpath import join
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from .blackbox import Blackbox from .blackbox import Blackbox
from tts.tts_service import TTService from ..tts.tts_service import TTService
class TTS(Blackbox): class TTS(Blackbox):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
config = { self.tts_service = TTService("catmaid")
'paimon': ['resources/tts/models/paimon6k.json', 'resources/tts/models/paimon6k_390k.pth', 'character_paimon', 1],
'yunfei': ['resources/tts/models/yunfeimix2.json', 'resources/tts/models/yunfeimix2_53k.pth', 'character_yunfei', 1.1],
'catmaid': ['resources/tts/models/catmix.json', 'resources/tts/models/catmix_107k.pth', 'character_catmaid', 1.2]
}
self.tts_service = TTService(*config['catmaid'])
super().__init__(config)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs) return self.processing(*args, **kwargs)
def processing(self, *args, **kwargs) -> io.BytesIO: def processing(self, *args, **kwargs) -> io.BytesIO:
text = args[0] text = args[0]
current_time = time.time()
audio = self.tts_service.read(text) audio = self.tts_service.read(text)
print("#### TTS Service consume : ", (time.time()-current_time))
return audio return audio
def valid(self, *args, **kwargs) -> bool: def valid(self, *args, **kwargs) -> bool:

30
src/dotchain/README.md Normal file
View File

@ -0,0 +1,30 @@
# Dotchain
Dotchain is a functional programming language. Source files use the `.dc` extension.
# Syntax
```
// Comment
// Variable declaration
let hello = 123
// Function declaration
let add = (left, right) => {
    // Return value
    return left + right
}
// TODO: function calls
add(1,2)
add(3, add(1,2))
// Calling a function with `.` passes the value before the `.` as the first argument
// hello.add(2) is equivalent to add(hello, 2)
```
## Keywords
```
let while if else true false
```
## Running the tests
```bash
python -m unittest
```

16
src/dotchain/main.dc Normal file
View File

@ -0,0 +1,16 @@
// Comment
// Variable declaration
let hello = 123;
// Function declaration
let add = (left, right) => {
    // Return value
    return left + right;
}
// TODO: function calls
add(1,2);
add(3, add(1,2));
// Calling a function with `.` passes the value before the `.` as the first argument
// hello.add(2) == add(hello, 2);

29
src/dotchain/main.py Normal file
View File

@ -0,0 +1,29 @@
from runtime.interpreter import program_parser
from runtime.runtime import Runtime
from runtime.tokenizer import Tokenizer
import json
script = """
let rec = (c) => {
print(c);
if c == 0 {
return "c + 1";
}
rec(c-1);
}
let main = () => {
print("hello 嘉妮");
print(rec(10));
}
main();
"""
if __name__ == "__main__":
t = Tokenizer()
t.init(script)
runtime = Runtime(exteral_fun={"print": print})
ast = program_parser(t)
result = ast.exec(runtime)

View File

384
src/dotchain/runtime/ast.py Normal file
View File

@ -0,0 +1,384 @@
from abc import ABC, abstractmethod
from attr import dataclass
from .runtime import Runtime
@dataclass
class ReturnValue():
value: any
class Node(ABC):
def type(self):
return self.__class__.__name__
@dataclass
class Statement(Node, ABC):
@abstractmethod
def exec(self, runtime: Runtime):
print(self)
pass
@abstractmethod
def dict(self):
pass
@dataclass
class Expression(Node):
@abstractmethod
def eval(self, runtime: Runtime):
pass
@abstractmethod
def dict(self):
pass
@dataclass
class Literal(Expression):
value: str | int | float | bool
def eval(self, runtime: Runtime):
return self.value
def dict(self) -> dict:
return {
"type": "Literal",
"value": self.value
}
@dataclass
class StringLiteral(Literal):
value: str
def dict(self) -> dict:
return {
"type": "StringLiteral",
"value": self.value
}
@dataclass
class IntLiteral(Literal):
value: int
def dict(self):
return {
"type": "IntLiteral",
"value": self.value
}
@dataclass
class FloatLiteral(Literal):
value: float
def dict(self):
return {
"type": "FloatLiteral",
"value": self.value
}
@dataclass
class BoolLiteral(Literal):
value: bool
def dict(self):
return {
"type": "FloatLiteral",
"value": self.value
}
@dataclass
class UnaryExpression(Expression):
operator: str
expression: Expression
def eval(self, runtime: Runtime):
if self.operator == "-":
return -self.expression.eval(runtime)
if self.operator == "!":
return not self.expression.eval(runtime)
return self.expression.eval(runtime)
def dict(self):
return {
"type": "UnaryExpression",
"operator": self.operator,
"argument": self.expression.dict()
}
@dataclass
class Program(Statement):
body: list[Statement]
def exec(self, runtime: Runtime):
index = 0
while index < len(self.body):
statement = self.body[index]
result = statement.exec(runtime)
if isinstance(result, ReturnValue):
return result
index += 1
def dict(self):
return {
"type": self.type(),
"body": [statement.dict() for statement in self.body]
}
@dataclass
class Identifier(Expression):
name: str
def eval(self,runtime: Runtime):
return runtime.deep_get_value(self.name)
def dict(self):
return {
"type": self.type(),
"name": self.name
}
@dataclass
class Block(Statement):
body: list[Statement]
def exec(self, runtime: Runtime):
index = 0
while index < len(self.body):
statement = self.body[index]
result = statement.exec(runtime)
if isinstance(result, ReturnValue):
return result
if isinstance(result, BreakStatement):
return result
index += 1
def dict(self):
return {
"type": "Block",
"body": [statement.dict() for statement in self.body]
}
@dataclass
class WhileStatement(Statement):
test: Expression
body: Block
def exec(self, runtime: Runtime):
while self.test.eval(runtime):
while_runtime = Runtime(parent=runtime,name="while")
result = self.body.exec(while_runtime)
if isinstance(result, ReturnValue):
return result
if isinstance(result, BreakStatement):
return result
def dict(self):
return {
"type": "WhileStatement",
"test": self.test.dict(),
"body": self.body.dict()
}
@dataclass
class BreakStatement(Statement):
def exec(self, _: Runtime):
return self
def dict(self):
return {
"type": "BreakStatement"
}
@dataclass
class ReturnStatement(Statement):
value: Expression
def exec(self, runtime: Runtime):
return ReturnValue(self.value.eval(runtime))
def dict(self):
return {
"type": "ReturnStatement",
"value": self.value.dict()
}
@dataclass
class IfStatement(Statement):
test: Expression
consequent: Block
alternate: Block
def exec(self, runtime: Runtime):
if_runtime = Runtime(parent=runtime)
if self.test.eval(runtime):
return self.consequent.exec(if_runtime)
else:
return self.alternate.exec(if_runtime)
def dict(self):
return {
"type": "IfStatement",
"test": self.test.dict(),
"consequent": self.consequent.dict(),
"alternate": self.alternate.dict()
}
@dataclass
class VariableDeclaration(Statement):
id: Identifier
value: Expression
value_type: str = "any"
def exec(self, runtime: Runtime):
runtime.declare(self.id.name, self.value.eval(runtime))
def dict(self):
return {
"type": "VariableDeclaration",
"id": self.id.dict(),
"value": self.value.dict()
}
@dataclass
class Assignment(Statement):
id: Identifier
value: Expression
def exec(self, runtime: Runtime):
runtime.assign(self.id.name, self.value.eval(runtime))
def dict(self):
return {
"type": "Assignment",
"id": self.id.dict(),
"value": self.value.dict()
}
@dataclass
class Argument(Expression):
id: Identifier
value: Expression
def dict(self):
return {
"type": "Argument",
"id": self.id.dict(),
"value": self.value.dict()
}
@dataclass
class BinaryExpression(Expression):
left: Expression
operator: str
right: Expression
def eval(self, runtime: Runtime):
left = self.left.eval(runtime)
right = self.right.eval(runtime)
if self.operator == "+":
return left + right
if self.operator == "-":
return left - right
if self.operator == "*":
return left * right
if self.operator == "/":
return left / right
if self.operator == "%":
return left % right
if self.operator == "<":
return left < right
if self.operator == ">":
return left > right
if self.operator == "<=":
return left <= right
if self.operator == ">=":
return left >= right
if self.operator == "==":
return left == right
if self.operator == "!=":
return left != right
if self.operator == "&&":
return left and right
if self.operator == "||":
return left or right
return None
def dict(self):
return {
"type": "BinaryExpression",
"left": self.left.dict(),
"operator": self.operator,
"right": self.right.dict()
}
@dataclass
class CallExpression(Expression):
callee: Identifier
arguments: list[Expression]
def exec(self, runtime: Runtime, args: list=None):
if args == None:
args = []
for index, argument in enumerate(self.arguments):
args.append(argument.eval(runtime))
if runtime.has_value(self.callee.name):
fun:FunEnv = runtime.get_value(self.callee.name)
return fun.exec(args)
if runtime.parent is not None:
return self.exec(runtime.parent,args)
if self.callee.name in runtime.exteral_fun:
return runtime.exteral_fun[self.callee.name](*args)
def eval(self, runtime):
result = self.exec(runtime)
if result is not None:
return result.value
def dict(self):
return {
"type": "CallExpression",
"callee": self.callee.dict(),
"arguments": [argument.dict() for argument in self.arguments]
}
@dataclass
class Fun(Statement):
params: list[Identifier]
body: Block
def exec(self, runtime: Runtime):
return self.body.exec(runtime)
def eval(self, runtime: Runtime):
return FunEnv(runtime, self)
def dict(self):
return {
"type": "Fun",
"params": [param.dict() for param in self.params],
"body": self.body.dict()
}
class EmptyStatement(Statement):
def exec(self, _: Runtime):
return None
def eval(self, _: Runtime):
return None
def dict(self):
return {
"type": "EmptyStatement"
}
class FunEnv():
def __init__(self, parent: Runtime, body: Fun):
self.parent = parent
self.body = body
def exec(self, args: list):
fun_runtime = Runtime(parent=self.parent)
for index, param in enumerate(self.body.params):
fun_runtime.declare(param.name, args[index])
return self.body.exec(fun_runtime)
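
The node classes can be executed without the parser, which makes the `exec`/`eval` split visible. A minimal sketch, assuming the package is importable from the repo root as `src.dotchain.runtime`:

```python
# Build and run a tiny program by hand: let x = 2; return x * 21;
from src.dotchain.runtime.ast import (
    BinaryExpression,
    Identifier,
    IntLiteral,
    Program,
    ReturnStatement,
    VariableDeclaration,
)
from src.dotchain.runtime.runtime import Runtime

program = Program([
    VariableDeclaration(Identifier("x"), IntLiteral(2)),
    ReturnStatement(BinaryExpression(Identifier("x"), "*", IntLiteral(21))),
])

result = program.exec(Runtime())  # Program.exec returns the ReturnValue wrapper
print(result.value)               # 42
```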

View File

@ -0,0 +1,420 @@
from ast import Expression
import copy
from .ast import Assignment, BinaryExpression, Block, BoolLiteral, BreakStatement, CallExpression, EmptyStatement, FloatLiteral, Fun, Identifier, IfStatement, IntLiteral, Program, ReturnStatement, Statement, StringLiteral, UnaryExpression, VariableDeclaration, WhileStatement
from .tokenizer import Token, TokenType, Tokenizer
unary_prev_statement = [
TokenType.COMMENTS,
TokenType.LEFT_PAREN,
TokenType.COMMA,
TokenType.LEFT_BRACE,
TokenType.RIGHT_BRACE,
TokenType.SEMICOLON,
TokenType.LET,
TokenType.RETURN,
TokenType.IF,
TokenType.ELSE,
TokenType.WHILE,
TokenType.FOR,
TokenType.LOGICAL_OPERATOR,
TokenType.NOT,
TokenType.ASSIGNMENT,
TokenType.MULTIPLICATIVE_OPERATOR,
TokenType.ADDITIVE_OPERATOR,
TokenType.ARROW,
]
unary_end_statement = [
TokenType.MULTIPLICATIVE_OPERATOR,
TokenType.ADDITIVE_OPERATOR,
TokenType.LOGICAL_OPERATOR,
]
end_statement = [
TokenType.SEMICOLON,
TokenType.COMMA,
TokenType.ARROW,
TokenType.RETURN,
TokenType.LET,
TokenType.IF,
TokenType.ELSE,
TokenType.WHILE,
TokenType.FOR,
TokenType.ASSIGNMENT,
TokenType.RIGHT_BRACE,
TokenType.LEFT_BRACE,
]
def program_parser(tkr: Tokenizer):
statements = list[Statement]()
count = 0
while True:
if tkr.token() is None:
break
if tkr.token().type == TokenType.SEMICOLON:
tkr.next()
continue
statement = statement_parser(tkr)
statements.append(statement)
count += 1
return Program(statements)
def if_parser(tkr: Tokenizer):
tkr.eat(TokenType.IF)
condition = ExpressionParser(tkr).parse()
block = block_statement(tkr)
if tkr.type_is(TokenType.ELSE):
tkr.eat(TokenType.ELSE)
if tkr.type_is(TokenType.IF):
print("else if")
return IfStatement(condition, block, Block([if_parser(tkr)]))
return IfStatement(condition, block, block_statement(tkr))
return IfStatement(condition, block, Block([]))
def while_parser(tkr: Tokenizer):
tkr.eat(TokenType.WHILE)
condition = ExpressionParser(tkr).parse()
block = block_statement(tkr)
return WhileStatement(condition, block)
def identifier(tkr: Tokenizer):
token = tkr.token()
if token.type != TokenType.IDENTIFIER:
raise Exception("Invalid identifier", token)
tkr.next()
return Identifier(token.value)
def block_statement(tkr: Tokenizer):
tkr.eat(TokenType.LEFT_BRACE)
statements = list[Statement]()
while True:
if tkr.token() is None:
raise Exception("Invalid block expression", tkr.token())
if tkr.tokenType() == TokenType.RIGHT_BRACE:
tkr.eat(TokenType.RIGHT_BRACE)
break
if tkr.tokenType() == TokenType.SEMICOLON:
tkr.next()
continue
statements.append(statement_parser(tkr))
return Block(statements)
def return_parser(tkr: Tokenizer):
tkr.eat(TokenType.RETURN)
return ReturnStatement(ExpressionParser(tkr).parse())
def statement_parser(tkr: Tokenizer):
token = tkr.token()
if token is None:
return EmptyStatement()
if token.type == TokenType.SEMICOLON:
tkr.next()
return EmptyStatement()
if token.type == TokenType.LET:
return let_expression_parser(tkr)
if _try_assignment_expression(tkr):
return assignment_parser(tkr)
if token.type == TokenType.IF:
return if_parser(tkr)
if token.type == TokenType.WHILE:
return while_parser(tkr)
if token.type == TokenType.RETURN:
return return_parser(tkr)
if token.type == TokenType.BREAK:
tkr.eat(TokenType.BREAK)
return BreakStatement()
return ExpressionParser(tkr).parse()
def assignment_parser(tkr: Tokenizer):
id = identifier(tkr)
tkr.eat(TokenType.ASSIGNMENT)
return Assignment(id, ExpressionParser(tkr).parse())
def let_expression_parser(tkr: Tokenizer):
tkr.eat(TokenType.LET)
token = tkr.token()
if token.type != TokenType.IDENTIFIER:
raise Exception("Invalid let statement", token)
id = identifier(tkr)
token = tkr.token()
if token is None:
raise Exception("Invalid let statement", token)
if token.type != TokenType.ASSIGNMENT:
raise Exception("Invalid let statement", token.type)
tkr.next()
ast = ExpressionParser(tkr).parse()
return VariableDeclaration(id, ast)
class ExpressionParser:
def __init__(self, tkr: Tokenizer):
self.stack = list[Expression | Token]()
self.operator_stack = list[Token]()
self.tkr = tkr
def parse(self, unary = False):
while not self.is_end():
token = self.tkr.token()
if unary and not self.is_unary() and token.type in unary_end_statement:
break
if self.is_unary():
self.push_stack(self.unary_expression_parser())
elif self._try_fun_expression():
return self.fun_expression()
# -(hello x 123) // !(true and false)
elif unary and token.type == TokenType.LEFT_PAREN:
self.tkr.next()
self.push_stack(ExpressionParser(self.tkr).parse())
elif self._is_operator(token) or token.type in [TokenType.LEFT_PAREN, TokenType.RIGHT_PAREN ]:
self.push_operator_stack(token)
self.tkr.next()
else:
self.push_stack(self.expression_parser())
self.pop_all()
return self.expression()
def expression(self):
if len(self.stack) == 0:
return EmptyStatement()
if len(self.stack) == 1:
return self.stack[0]
return expression_list_to_binary(self.stack)
def expression_parser(self):
token = self.tkr.token()
if token is None:
return EmptyStatement()
expression = None
if token.type == TokenType.INT:
self.tkr.eat(TokenType.INT)
expression = IntLiteral(int(token.value))
elif token.type == TokenType.FLOAT:
self.tkr.eat(TokenType.FLOAT)
expression = FloatLiteral(float(token.value))
elif token.type == TokenType.STRING:
self.tkr.eat(TokenType.STRING)
expression = StringLiteral(token.value[1:-1])
elif token.type == TokenType.BOOL:
self.tkr.eat(TokenType.BOOL)
expression = BoolLiteral(token.value == "true")
elif token.type == TokenType.IDENTIFIER:
expression = self.identifier_or_fun_call_parser()
return expression
def _try_fun_expression(self):
return _try_fun_expression(self.tkr)
def fun_expression(self):
tkr = self.tkr
tkr.next()
args = list[Identifier]()
token_type = tkr.tokenType()
while token_type != TokenType.RIGHT_PAREN:
args.append(Identifier(tkr.token().value))
tkr.next()
token_type = tkr.tokenType()
if token_type == TokenType.RIGHT_PAREN:
break
tkr.next()
token_type = tkr.tokenType()
token_type = tkr.next_token_type()
if token_type != TokenType.ARROW:
raise Exception("Invalid fun_expression", tkr.token())
tkr.next()
return Fun(args, block_statement(tkr))
def push_stack(self, expression: Expression | Token):
self.stack.append(expression)
def _pop_by_right_paren(self):
token = self.operator_stack.pop()
if token.type != TokenType.LEFT_PAREN:
self.push_stack(token)
self._pop_by_right_paren()
def pop(self):
self.push_stack(self.operator_stack.pop())
def pop_all(self):
while len(self.operator_stack) > 0:
self.pop()
def push_operator_stack(self, token: Token):
if len(self.operator_stack) == 0:
self.operator_stack.append(token)
return
if token.type == TokenType.LEFT_PAREN:
self.operator_stack.append(token)
return
if token.type == TokenType.RIGHT_PAREN:
self._pop_by_right_paren()
return
top_operator = self.operator_stack[-1]
if top_operator.type == TokenType.LEFT_PAREN:
self.operator_stack.append(token)
return
# priority is in descending order
if self._priority(token) >= self._priority(top_operator):
self.pop()
self.push_operator_stack(token)
return
self.operator_stack.append(token)
def unary_expression_parser(self):
token = self.tkr.token()
self.tkr.next()
return UnaryExpression(token.value, ExpressionParser(self.tkr).parse(True))
def identifier_or_fun_call_parser(self):
id = self.identifier()
tokenType = self.tkr.tokenType()
if tokenType == TokenType.LEFT_PAREN:
return self.fun_call_parser(id)
return id
def fun_call_parser(self, id: Identifier):
self.tkr.eat(TokenType.LEFT_PAREN)
args = list[Expression]()
while self.tkr.tokenType() != TokenType.RIGHT_PAREN:
args.append(ExpressionParser(self.tkr).parse())
if self.tkr.tokenType() == TokenType.COMMA:
self.tkr.eat(TokenType.COMMA)
self.tkr.eat(TokenType.RIGHT_PAREN)
return CallExpression(id, args)
def identifier(self):
return identifier(self.tkr)
def is_unary(self):
token = self.tkr.token()
if not self.unary_operator(token):
return False
if token.type == TokenType.NOT:
return True
prev_token = self.tkr.get_prev()
if prev_token is None:
return True
if prev_token.type == TokenType.LEFT_PAREN:
return True
if prev_token.type in unary_prev_statement:
return True
return False
def unary_operator(self, token: Token):
if token is None:
return False
return token.value in ["+", "-", "!"]
def _has_brackets(self):
return TokenType.LEFT_PAREN in map(lambda x: x.type, self.operator_stack)
def is_end(self):
token = self.tkr.token()
if token is None:
return True
if token.type == TokenType.SEMICOLON:
return True
if not self._has_brackets() and token.type == TokenType.RIGHT_PAREN:
return True
if token.type in end_statement:
return True
return False
def _is_operator(self, token: Token):
if token is None:
return False
return token.type in [TokenType.ADDITIVE_OPERATOR, TokenType.MULTIPLICATIVE_OPERATOR, TokenType.LOGICAL_OPERATOR, TokenType.NOT]
def _debug_print_tokens(self):
print("operator stack:----")
for token in self.operator_stack:
print(token)
def _debug_print_stack(self):
print("stack:----")
for expression in self.stack:
print(expression)
def _priority(self, token: Token):
return _priority(token.value)
def expression_list_to_binary(expression_list: list[Expression | Token], stack: list = None):
if stack is None:
stack = list()
if len(expression_list) == 0:
return stack[0]
top = expression_list[0]
if isinstance(top, Token):
right = stack.pop()
left = stack.pop()
return expression_list_to_binary(expression_list[1:], stack + [BinaryExpression(left, top.value, right)])
else:
stack.append(top)
return expression_list_to_binary(expression_list[1:], stack)
def _priority(operator: str):
priority = 0
if operator in ["*", "/", "%"]:
return priority
priority += 1
if operator in ["+", "-"]:
return priority
priority += 1
if operator in ["<", ">", "<=", ">="]:
return priority
priority += 1
if operator in ["==", "!="]:
return priority
priority += 1
if operator in ["&&"]:
return priority
priority += 1
if operator in ["||"]:
return priority
priority += 1
return priority
def _try_assignment_expression(tkr: Tokenizer):
tkr = copy.deepcopy(tkr)
token = tkr.token()
if token is None:
return False
if token.type != TokenType.IDENTIFIER:
return False
tkr.next()
token = tkr.token()
if token is None:
return False
if token.type != TokenType.ASSIGNMENT:
return False
return True
def _try_fun_expression(_tkr: Tokenizer):
tkr = copy.deepcopy(_tkr)
token = tkr.token()
if token is None:
return False
if token.type != TokenType.LEFT_PAREN:
return False
tkr.next()
token_type = tkr.tokenType()
while token_type != TokenType.RIGHT_PAREN:
if token_type == TokenType.IDENTIFIER:
tkr.next()
token_type = tkr.tokenType()
if token_type == TokenType.RIGHT_PAREN:
break
if token_type != TokenType.COMMA:
return False
tkr.next()
token_type = tkr.tokenType()
if token_type == TokenType.RIGHT_PAREN:
return False
else:
return False
token_type = tkr.next_token_type()
if token_type != TokenType.ARROW:
return False
return True
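
Operator precedence is handled with an operator stack and then flattened into `BinaryExpression` nodes by `expression_list_to_binary`. A minimal sketch of parsing one expression in isolation, under the same import assumption as above:

```python
# Parse a single arithmetic expression and inspect the resulting AST.
from src.dotchain.runtime.interpreter import ExpressionParser
from src.dotchain.runtime.runtime import Runtime
from src.dotchain.runtime.tokenizer import Tokenizer

t = Tokenizer()
t.init("1 + 2 * (3 - 1);")

expr = ExpressionParser(t).parse()
print(expr.dict())           # nested BinaryExpression nodes, precedence already resolved
print(expr.eval(Runtime()))  # 5
```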

View File

@ -0,0 +1,44 @@
from ast import Expression
from attr import dataclass
class Runtime():
def __init__(self, context=None, parent=None, exteral_fun=None, name=None) -> None:
self.name = name
self.parent = parent
self.context = context if context is not None else dict()
self.exteral_fun = exteral_fun if exteral_fun is not None else dict()
def has_value(self, identifier: str) -> bool:
return identifier in self.context
def get_value(self, identifier: str):
return self.context.get(identifier)
def deep_get_value(self, id: str):
if self.has_value(id):
return self.get_value(id)
if self.parent is not None:
return self.parent.deep_get_value(id)
return None
def set_value(self, identifier: str, value):
self.context[identifier] = value
def declare(self, identifier: str, value):
if self.has_value(identifier):
raise Exception(f"Variable {identifier} is already declared")
self.set_value(identifier, value)
def assign(self, identifier: str, value):
if self.has_value(identifier):
self.set_value(identifier, value)
elif self.parent is not None:
self.parent.assign(identifier, value)
else:
raise Exception(f"Variable {identifier} is not declared")
def show_values(self):
print(self.context)
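
`Runtime` forms a chain of scopes: `declare` and `get_value` act on the local context, while `deep_get_value` and `assign` walk up through `parent`. A minimal sketch:

```python
from src.dotchain.runtime.runtime import Runtime

global_scope = Runtime(exteral_fun={"print": print})
global_scope.declare("x", 1)

# Child scopes (created for if/while/function bodies) resolve reads through
# deep_get_value and writes through assign, both of which walk up to the parent.
child = Runtime(parent=global_scope)
child.declare("y", 2)
child.assign("x", 10)

print(child.deep_get_value("x"))    # 10
print(global_scope.get_value("x"))  # 10
print(global_scope.has_value("y"))  # False
```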

View File

View File

@ -0,0 +1,153 @@
import unittest
from runtime.ast import BoolLiteral, CallExpression, FloatLiteral, Identifier, IntLiteral, UnaryExpression
from runtime.interpreter import ExpressionParser, _priority, _try_fun_expression
from runtime.tokenizer import TokenType, Tokenizer,Token
class TestExpressionParser(unittest.TestCase):
def test__try_fun_expression(self):
t = Tokenizer()
t.init("()")
self.assertFalse(_try_fun_expression(t))
t.init("() =>")
self.assertTrue(_try_fun_expression(t))
t.init("(a) =>")
self.assertTrue(_try_fun_expression(t))
t.init("(a,) =>")
self.assertFalse(_try_fun_expression(t))
t.init("(a,b,c,d) =>;")
self.assertTrue(_try_fun_expression(t))
t.init("(a,b,c,true) =>;")
self.assertFalse(_try_fun_expression(t))
t.init("(a,b,c,1.23) =>;")
self.assertFalse(_try_fun_expression(t))
def test_is_unary(self):
t = Tokenizer()
t.init("!")
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init("+")
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init("--123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init("+-123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init(")-123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertFalse(pred)
t.init("=> - 123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init(", - 123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init("* - 123")
t.next()
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertTrue(pred)
t.init("* - 123")
parser = ExpressionParser(t)
pred = parser.is_unary()
self.assertFalse(pred)
def test_expression_parser(self):
t = Tokenizer()
t.init("a")
parser = ExpressionParser(t)
expression = parser.expression_parser()
self.assertIsInstance(expression, Identifier)
t.init("true")
parser = ExpressionParser(t)
expression = parser.expression_parser()
self.assertIsInstance(expression, BoolLiteral)
self.assertEqual(expression.value, True)
t.init("false")
parser = ExpressionParser(t)
expression = parser.expression_parser()
self.assertIsInstance(expression, BoolLiteral)
self.assertEqual(expression.value, False)
t.init("12341")
parser = ExpressionParser(t)
expression = parser.expression_parser()
self.assertEqual(expression.value, 12341)
self.assertIsInstance(expression, IntLiteral)
t.init("12341.42")
parser = ExpressionParser(t)
expression = parser.expression_parser()
self.assertEqual(expression.value, 12341.42)
self.assertIsInstance(expression, FloatLiteral)
t.init("hello")
parser = ExpressionParser(t)
expression: Identifier = parser.expression_parser()
self.assertIsInstance(expression, Identifier)
self.assertEqual(expression.name, "hello")
t.init("print()")
parser = ExpressionParser(t)
expression: CallExpression = parser.expression_parser()
self.assertIsInstance(expression, CallExpression)
self.assertEqual(expression.callee.name, "print")
t.init("print(1,2,3,hello)")
parser = ExpressionParser(t)
expression: CallExpression = parser.expression_parser()
self.assertIsInstance(expression, CallExpression)
self.assertEqual(expression.callee.name, "print")
self.assertEqual(len(expression.arguments), 4)
def test_binary_expression(self):
t = Tokenizer()
def test__priority(self):
self.assertEqual(_priority("*"), 0)
self.assertEqual(_priority("/"), 0)
self.assertEqual(_priority("%"), 0)
self.assertEqual(_priority("+"), 1)
self.assertEqual(_priority("-"), 1)
self.assertEqual(_priority(">"), 2)
self.assertEqual(_priority("<"), 2)
self.assertEqual(_priority(">="), 2)
self.assertEqual(_priority("<="), 2)
self.assertEqual(_priority("=="), 3)
self.assertEqual(_priority("!="), 3)
self.assertEqual(_priority("&&"), 4)
self.assertEqual(_priority("||"), 5)

View File

@ -0,0 +1,7 @@
import unittest
class TestRuntime(unittest.TestCase):
def test_eval(self):
self.assertTrue(True)

View File

@ -0,0 +1,151 @@
import unittest
from runtime.tokenizer import TokenType, Tokenizer,Token
class TestTokenizer(unittest.TestCase):
def test_init(self):
t = Tokenizer()
self.assertEqual(t.script, "")
self.assertEqual(t.cursor, 0)
self.assertEqual(t.col, 0)
self.assertEqual(t.row, 0)
def test_tokenizer(self):
t = Tokenizer()
t.init("a")
self.assertEqual(t.token().value, "a")
self.assertEqual(t.token().type, TokenType.IDENTIFIER)
t.init("12341")
self.assertEqual(t.token().value, "12341")
self.assertEqual(t.token().type, TokenType.INT)
t.init("12341.1234124")
self.assertEqual(t.token().value, "12341.1234124")
self.assertEqual(t.token().type, TokenType.FLOAT)
t.init("false")
self.assertEqual(t.token().value, "false")
self.assertEqual(t.token().type, TokenType.BOOL)
t.init("\"false\"")
self.assertEqual(t.token().value, "\"false\"")
self.assertEqual(t.token().type, TokenType.STRING)
t.init("helloworld")
self.assertEqual(t.token().value, "helloworld")
self.assertEqual(t.token().type, TokenType.IDENTIFIER)
t.init("!")
self.assertEqual(t.token().value, "!")
self.assertEqual(t.token().type, TokenType.NOT)
t.init("==")
self.assertEqual(t.token().value, "==")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("!=")
self.assertEqual(t.token().value, "!=")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("<=")
self.assertEqual(t.token().value, "<=")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init(">=")
self.assertEqual(t.token().value, ">=")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("<")
self.assertEqual(t.token().value, "<")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init(">")
self.assertEqual(t.token().value, ">")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("&&")
self.assertEqual(t.token().value, "&&")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("||")
self.assertEqual(t.token().value, "||")
self.assertEqual(t.token().type, TokenType.LOGICAL_OPERATOR)
t.init("=")
self.assertEqual(t.token().value, "=")
self.assertEqual(t.token().type, TokenType.ASSIGNMENT)
t.init("+")
self.assertEqual(t.token().value, "+")
self.assertEqual(t.token().type, TokenType.ADDITIVE_OPERATOR)
t.init("-")
self.assertEqual(t.token().value, "-")
self.assertEqual(t.token().type, TokenType.ADDITIVE_OPERATOR)
t.init("*")
self.assertEqual(t.token().value, "*")
self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR)
t.init("/")
self.assertEqual(t.token().value, "/")
self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR)
t.init("%")
self.assertEqual(t.token().value, "%")
self.assertEqual(t.token().type, TokenType.MULTIPLICATIVE_OPERATOR)
t.init("(")
self.assertEqual(t.token().value, "(")
self.assertEqual(t.token().type, TokenType.LEFT_PAREN)
t.init(")")
self.assertEqual(t.token().value, ")")
self.assertEqual(t.token().type, TokenType.RIGHT_PAREN)
t.init("{")
self.assertEqual(t.token().value, "{")
self.assertEqual(t.token().type, TokenType.LEFT_BRACE)
t.init("}")
self.assertEqual(t.token().value, "}")
self.assertEqual(t.token().type, TokenType.RIGHT_BRACE)
def test_init(self):
t = Tokenizer()
script = "a + 9 * ( 3 - 1 ) * 3 + 10 / 2;"
t.init(script)
self.assertEqual(t.script, script)
self.assertEqual(len(t.tokens), 16)
self.assertEqual(t.get_prev(), None)
self.assertEqual(t.token().value, "a")
self.assertEqual(t.get_next().value, "+")
self.assertEqual(t.next().value, "+")
self.assertEqual(t.next().value, "9")
self.assertEqual(t.next().value, "*")
t.prev()
self.assertEqual(t.token().value, "9")
t.prev()
self.assertEqual(t.token().value, "+")
script = "a + 9"
t.init(script)
self.assertEqual(t.token().type, TokenType.IDENTIFIER)
self.assertEqual(t.next().type, TokenType.ADDITIVE_OPERATOR)
self.assertEqual(t.next().type, TokenType.INT)
self.assertEqual(t.next(), None)
self.assertEqual(t._current_token_index, 3)
self.assertEqual(t.next(), None)
self.assertEqual(t.next(), None)
self.assertEqual(t._current_token_index, 3)
self.assertEqual(t.next(), None)
t.prev()
self.assertEqual(t.token().value, "9")
t.prev()
self.assertEqual(t.token().value, "+")
t.prev()
self.assertEqual(t.token().value, "a")
t.prev()
self.assertEqual(t.token().value, "a")

View File

@ -0,0 +1,259 @@
import re
from enum import Enum
from attr import dataclass
class TokenType(Enum):
NEW_LINE = 1
SPACE = 2
COMMENTS = 3
LEFT_PAREN = 4
RIGHT_PAREN = 5
COMMA = 6
LEFT_BRACE = 7
RIGHT_BRACE = 8
SEMICOLON = 9
LET = 10
RETURN = 11
IF = 12
ELSE = 13
WHILE = 14
FOR = 15
FLOAT = 18
INT = 19
IDENTIFIER = 20
LOGICAL_OPERATOR = 21
NOT = 22
ASSIGNMENT = 23
MULTIPLICATIVE_OPERATOR = 24
ADDITIVE_OPERATOR = 25
STRING = 26
ARROW = 27
BOOL = 28
BREAK = 29
TYPE_DEFINITION = 30
COLON = 31
specs = (
(re.compile(r"^\n"),TokenType.NEW_LINE),
# Space:
(re.compile(r"^\s"),TokenType.SPACE),
# Comments:
(re.compile(r"^//.*"), TokenType.COMMENTS),
# Symbols:
(re.compile(r"^\("), TokenType.LEFT_PAREN),
(re.compile(r"^\)"), TokenType.RIGHT_PAREN),
(re.compile(r"^\,"), TokenType.COMMA),
(re.compile(r"^\{"), TokenType.LEFT_BRACE),
(re.compile(r"^\}"), TokenType.RIGHT_BRACE),
(re.compile(r"^;"), TokenType.SEMICOLON),
(re.compile(r"^:"), TokenType.COLON),
(re.compile(r"^=>"), TokenType.ARROW),
# Keywords:
(re.compile(r"^\blet\b"), TokenType.LET),
(re.compile(r"^\breturn\b"), TokenType.RETURN),
(re.compile(r"^\bif\b"), TokenType.IF),
(re.compile(r"^\belse\b"), TokenType.ELSE),
(re.compile(r"^\bwhile\b"), TokenType.WHILE),
(re.compile(r"^\bfor\b"), TokenType.FOR),
(re.compile(r"^\bbreak\b"), TokenType.BREAK),
(re.compile(r"^\btrue\b"), TokenType.BOOL),
(re.compile(r"^\bfalse\b"), TokenType.BOOL),
# Type definition:
(re.compile(r"^\bstring\b"), TokenType.TYPE_DEFINITION),
(re.compile(r"^\bint\b"), TokenType.TYPE_DEFINITION),
(re.compile(r"^\bfloat\b"), TokenType.TYPE_DEFINITION),
(re.compile(r"^\bbool\b"), TokenType.TYPE_DEFINITION),
(re.compile(r"^\bany\b"), TokenType.TYPE_DEFINITION),
# Floats:
(re.compile(r"^[0-9]+\.[0-9]+"), TokenType.FLOAT),
# Ints:
(re.compile(r"^[0-9]+"), TokenType.INT),
# Identifiers:
(re.compile(r"^\w+"), TokenType.IDENTIFIER),
# Logical operators:
(re.compile(r"^&&"), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^\|\|"), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^=="), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^!="), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^<="), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^>="), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^<"), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^>"), TokenType.LOGICAL_OPERATOR),
(re.compile(r"^!"), TokenType.NOT),
# Assignment:
(re.compile(r"^="), TokenType.ASSIGNMENT),
# Math operators: +, -, *, /:
(re.compile(r"^[*/%]"), TokenType.MULTIPLICATIVE_OPERATOR),
(re.compile(r"^[+-]"), TokenType.ADDITIVE_OPERATOR),
# Double-quoted strings
# TODO: escape character \" and
(re.compile(r"^\"[^\"]*\""), TokenType.STRING),
)
@dataclass
class Token:
type: TokenType
value: str
row: int
col: int
col_end: int
cursor: int
def __str__(self) -> str:
return f"Token({self.type}, {self.value}, row={self.row}, col={self.col}, col_end={self.col_end}, cursor={self.cursor})"
class Tokenizer:
def __init__(self):
self._current_token = None
self.script = ""
self.cursor = 0
self.col = 0
self.row = 0
self._current_token_index = 0
self.tokens = list[Token]()
self.checkpoint = list[int]()
def init(self, script: str):
self.checkpoint = list[int]()
self.tokens = list[Token]()
self._current_token_index = 0
self._current_token = None
self.script = script
self.cursor = 0
self.col = 0
self.row = 0
self._get_next_token()
while self._current_token is not None:
self.tokens.append(self._current_token)
self._get_next_token()
def checkpoint_push(self):
self.checkpoint.append(self._current_token_index)
def checkpoint_pop(self):
self._current_token_index = self.checkpoint.pop()
def next(self):
if self._current_token_index < len(self.tokens):
self._current_token_index += 1
return self.token()
def next_token_type(self):
if self._current_token_index < len(self.tokens):
self._current_token_index += 1
return self.tokenType()
def prev(self):
if self._current_token_index > 0:
self._current_token_index -= 1
return self.token()
def get_prev(self):
if self._current_token_index == 0:
return None
return self.tokens[self._current_token_index - 1]
def get_next(self):
if self._current_token_index >= len(self.tokens):
return None
return self.tokens[self._current_token_index + 1]
def token(self):
if self._current_token_index >= len(self.tokens):
return None
return self.tokens[self._current_token_index]
def tokenType(self):
if self._current_token_index >= len(self.tokens):
return None
return self.tokens[self._current_token_index].type
def _get_next_token(self):
if self._is_eof():
self._current_token = None
return None
_string = self.script[self.cursor:]
for spec in specs:
tokenValue, offset = self.match(spec[0], _string)
if tokenValue == None:
continue
if spec[1] == TokenType.NEW_LINE:
self.row += 1
self.col = 0
return self._get_next_token()
if spec[1] == TokenType.COMMENTS:
return self._get_next_token()
if spec[1] == TokenType.SPACE:
self.col += offset
return self._get_next_token()
if spec[1] == None:
return self._get_next_token()
self._current_token = Token(spec[1], tokenValue, self.row, self.col, self.col + offset, self.cursor)
self.col += offset
return self.get_current_token()
raise Exception("Unknown token: " + _string[0])
def _is_eof(self):
return self.cursor == len(self.script)
def has_more_tokens(self):
return self.cursor < len(self.script)
def get_current_token(self):
return self._current_token
def match(self, reg: re, _script):
matched = reg.search(_script)
if matched == None:
return None,0
self.cursor = self.cursor + matched.span(0)[1]
return matched[0], matched.span(0)[1]
def eat(self, value: str | TokenType):
if isinstance(value, str):
return self.eat_value(value)
if isinstance(value, TokenType):
return self.eat_token_type(value)
def eat_value(self, value: str):
token = self.token()
if token is None:
raise Exception(f"Expected {value} but got None")
if token.value != value:
raise Exception(f"Expected {value} but got {token.value}")
self.next()
return token
def eat_token_type(self,tokenType: TokenType):
token = self.token()
if token is None:
raise Exception(f"Expected {tokenType} but got None")
if token.type != tokenType:
raise Exception(f"Expected {tokenType} but got {token.type}")
self.next()
return token
def type_is(self, tokenType: TokenType):
if self.token() is None:
return False
return self.token().type == tokenType
def the_rest(self):
return self.tokens[self._current_token_index:]
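
The tokenizer materializes all tokens in `init` and then exposes cursor-style navigation (`token`, `next`, `prev`, `eat`, checkpoints). A minimal sketch:

```python
from src.dotchain.runtime.tokenizer import Tokenizer, TokenType

t = Tokenizer()
t.init("let x = 1 + 2;")  # whitespace and comments are dropped during init()

t.eat(TokenType.LET)                # raises if the current token does not match
name = t.eat(TokenType.IDENTIFIER)
t.eat(TokenType.ASSIGNMENT)

print(name.value)                            # "x"
print([tok.value for tok in t.the_rest()])   # ['1', '+', '2', ';']
```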

View File

@ -4,12 +4,17 @@ import onnxruntime
from transformers import BertTokenizer from transformers import BertTokenizer
import numpy as np import numpy as np
dirabspath = __file__.split("\\")[1:-1]
dirabspath= "C://" + "/".join(dirabspath)
default_path = dirabspath + "/models/paimon_sentiment.onnx"
class SentimentEngine(): class SentimentEngine():
def __init__(self, model_path="resources/sentiment_engine/models/paimon_sentiment.onnx"): def __init__(self):
logging.info('Initializing Sentiment Engine...') logging.info('Initializing Sentiment Engine...')
onnx_model_path = model_path onnx_model_path = default_path
self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) self.ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

View File

@ -1,40 +1,61 @@
import io import io
import sys import sys
import time
sys.path.append('tts/vits') sys.path.append('src/tts/vits')
import numpy as np
import soundfile import soundfile
import os import os
os.environ["PYTORCH_JIT"] = "0" os.environ["PYTORCH_JIT"] = "0"
import torch import torch
import tts.vits.commons as commons import src.tts.vits.commons as commons
import tts.vits.utils as utils import src.tts.vits.utils as utils
from tts.vits.models import SynthesizerTrn from src.tts.vits.models import SynthesizerTrn
from tts.vits.text.symbols import symbols from src.tts.vits.text.symbols import symbols
from tts.vits.text import text_to_sequence from src.tts.vits.text import text_to_sequence
import logging import logging
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
dirbaspath = __file__.split("\\")[1:-1]
dirbaspath= "C://" + "/".join(dirbaspath)
config = {
'paimon': {
'cfg': dirbaspath + '/models/paimon6k.json',
'model': dirbaspath + '/models/paimon6k_390k.pth',
'char': 'character_paimon',
'speed': 1
},
'yunfei': {
'cfg': dirbaspath + '/tts/models/yunfeimix2.json',
'model': dirbaspath + '/models/yunfeimix2_53k.pth',
'char': 'character_yunfei',
'speed': 1.1
},
'catmaid': {
'cfg': dirbaspath + '/models/catmix.json',
'model': dirbaspath + '/models/catmix_107k.pth',
'char': 'character_catmaid',
'speed': 1.2
},
}
class TTService(): class TTService():
def __init__(self, cfg, model, char, speed): def __init__(self, model_name="catmaid"):
logging.info('Initializing TTS Service for %s...' % char) cfg = config[model_name]
self.hps = utils.get_hparams_from_file(cfg) logging.info('Initializing TTS Service for %s...' % cfg["char"])
self.speed = speed self.hps = utils.get_hparams_from_file(cfg["cfg"])
self.speed = cfg["speed"]
self.net_g = SynthesizerTrn( self.net_g = SynthesizerTrn(
len(symbols), len(symbols),
self.hps.data.filter_length // 2 + 1, self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length, self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model).cpu() **self.hps.model).cuda()
_ = self.net_g.eval() _ = self.net_g.eval()
_ = utils.load_checkpoint(model, self.net_g, None) _ = utils.load_checkpoint(cfg["model"], self.net_g, None)
def get_text(self, text, hps): def get_text(self, text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners) text_norm = text_to_sequence(text, hps.data.text_cleaners)
@ -48,8 +69,8 @@ class TTService():
stn_tst = self.get_text(text, self.hps) stn_tst = self.get_text(text, self.hps)
with torch.no_grad(): with torch.no_grad():
x_tst = stn_tst.cpu().unsqueeze(0) x_tst = stn_tst.cuda().unsqueeze(0)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cpu() x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
# tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed) # tp = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][ audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.2, length_scale=self.speed)[0][
0, 0].data.cpu().float().numpy() 0, 0].data.cpu().float().numpy()
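
With this change `TTService` is constructed from a model name and runs inference on CUDA. A usage sketch, assuming the process runs from the repository root (the module appends `src/tts/vits` to `sys.path`) and that `read()` returns an in-memory WAV buffer, as the TTS blackbox and `AudioChat` handler treat it:

```python
# Hypothetical local usage; the return type of read() is an assumption based on
# how the TTS blackbox and AudioChat call .read() on its result.
from src.tts.tts_service import TTService

tts = TTService("catmaid")   # selects an entry from the config dict above
wav = tts.read("你好")
with open("out.wav", "wb") as f:
    f.write(wav.read())
```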

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import torch import torch
from .monotonic_align.core import maximum_path_c from .core import maximum_path_c
def maximum_path(neg_cent, mask): def maximum_path(neg_cent, mask):

(Three image files were also changed; the diff viewer shows only before/after previews, with sizes unchanged at 63 KiB, 35 KiB, and 45 KiB.)

Some files were not shown because too many files have changed in this diff.