init ctc decoder
This commit is contained in:
60
ctc.py
Normal file
60
ctc.py
Normal file
@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CTC(torch.nn.Module):
|
||||
"""CTC module.
|
||||
|
||||
Args:
|
||||
odim: dimension of outputs
|
||||
encoder_output_size: number of encoder projection units
|
||||
dropout_rate: dropout rate (0.0 ~ 1.0)
|
||||
reduce: reduce the CTC loss into a scalar
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
odim: int,
|
||||
encoder_output_size: int,
|
||||
dropout_rate: float = 0.0,
|
||||
reduce: bool = True,
|
||||
blank_id: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
eprojs = encoder_output_size
|
||||
self.dropout_rate = dropout_rate
|
||||
self.ctc_lo = torch.nn.Linear(eprojs, odim)
|
||||
self.blank_id = blank_id
|
||||
self.ctc_loss = torch.nn.CTCLoss(reduction="none", blank=blank_id)
|
||||
self.reduce = reduce
|
||||
|
||||
def softmax(self, hs_pad):
|
||||
"""softmax of frame activations
|
||||
|
||||
Args:
|
||||
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
||||
Returns:
|
||||
torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
|
||||
"""
|
||||
return F.softmax(self.ctc_lo(hs_pad), dim=2)
|
||||
|
||||
def log_softmax(self, hs_pad):
|
||||
"""log_softmax of frame activations
|
||||
|
||||
Args:
|
||||
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
||||
Returns:
|
||||
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
|
||||
"""
|
||||
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
|
||||
|
||||
def argmax(self, hs_pad):
|
||||
"""argmax of frame activations
|
||||
|
||||
Args:
|
||||
torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
||||
Returns:
|
||||
torch.Tensor: argmax applied 2d tensor (B, Tmax)
|
||||
"""
|
||||
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
|
||||
Reference in New Issue
Block a user