import torch
import torch.nn.functional as F


class CTC(torch.nn.Module):
    """CTC output layer: a linear projection to the vocabulary plus a CTC loss.

    Args:
        odim: dimension of outputs (size of the projection's output axis)
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0); stored on the instance only,
            none of the methods defined here apply it
        reduce: reduce the CTC loss into a scalar; stored on the instance only,
            presumably read by loss code that is not part of this view
        blank_id: index of the blank label handed to ``torch.nn.CTCLoss``
        **kwargs: accepted for call-site compatibility and ignored
    """

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        reduce: bool = True,
        blank_id: int = 0,
        **kwargs,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        # Frame-wise projection from encoder features to output units.
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.blank_id = blank_id
        # reduction="none" keeps the per-sequence losses; any reduction is
        # left to the caller (see the ``reduce`` flag).
        self.ctc_loss = torch.nn.CTCLoss(reduction="none", blank=blank_id)
        self.reduce = reduce

    def softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """Return the softmax of the projected frame activations.

        Args:
            hs_pad: encoder output of shape (..., eprojs), typically the
                3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: probabilities of shape (..., odim), typically
                (B, Tmax, odim), normalised over the last axis
        """
        # The projection always writes the output units to the last axis, so
        # normalising over dim=-1 also works for unbatched (Tmax, eprojs) input.
        return F.softmax(self.ctc_lo(hs_pad), dim=-1)

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """Return the log-softmax of the projected frame activations.

        Args:
            hs_pad: encoder output of shape (..., eprojs), typically the
                3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: log-probabilities of shape (..., odim), typically
                (B, Tmax, odim), normalised over the last axis
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=-1)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """Return the index of the highest-scoring output unit per frame.

        Args:
            hs_pad: encoder output of shape (..., eprojs), typically the
                3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: integer indices of shape (...), typically the
                2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=-1)