import torch
import torch.nn as nn
import torch.nn.functional as F
# helpers
def make_unit_length(x, epsilon=1e-6):
    # Scale each vector along the last dimension to unit L2 norm;
    # epsilon keeps all-zero vectors from causing a division by zero.
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)
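A small sketch of what the helper does on a batch of row vectors, next to the built-in `torch.nn.functional.normalize` for comparison. The two differ only in how the zero-norm case is guarded: `make_unit_length` divides by `norm + epsilon`, while `F.normalize` divides by `max(norm, eps)`. The example tensor is made up for illustration.

```python
import torch
import torch.nn.functional as F


def make_unit_length(x, epsilon=1e-6):
    # Same helper as above: divide each row by its L2 norm plus a small epsilon.
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)


x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
y = make_unit_length(x)
print(y.norm(dim=-1))              # first row: norm just under 1; all-zero row stays zero (no NaN)
print(F.normalize(x, p=2, dim=-1)) # built-in alternative, nearly identical for non-zero rows
```

Adding epsilon to the norm biases every result slightly below unit length, whereas clamping only kicks in for near-zero vectors; for typical embeddings the difference is negligible.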
# Vanilla CTC and CTC computed via cross-entropy are equal, and so are their gradients. This reformulation makes it easier to experiment with modifications of CTC.
# References on CTC regularization:
# "A Novel Re-weighting Method for Connectionist Temporal Classification", Li et al., https://arxiv.org/abs/1904.10619
# "Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets", Feng et al., https://www.hindawi.com/journals/complexity/2019/9345861/
# "Improved training for online end-to-end speech recognition systems", Kim et al., https://arxiv.org/abs/1711.02212
import torch
import torch.nn.functional as F
## generate example data
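A minimal sketch (written for this note, not the gist's own code) of why the gradients coincide, for a single sequence with blank index 0. The CTC forward pass gives the posterior occupancy `gamma[t, k]`, the probability that frame `t` emits label `k` given the target; it equals minus the gradient of the CTC negative log-likelihood with respect to the log-probabilities. Using the detached `gamma` as soft targets in a frame-wise cross-entropy gives a logit gradient of `softmax - gamma` (each row of `gamma` sums to 1), which is the usual CTC gradient. The names `ctc_nll` and `NEG` are hypothetical, and only gradients are compared: the raw cross-entropy number computed here is not the same scalar as the CTC negative log-likelihood.

```python
import torch
import torch.nn.functional as F

NEG = -1e30  # large finite stand-in for log(0); avoids NaN gradients from logsumexp(-inf, -inf)


def ctc_nll(log_probs, target, blank=0):
    """CTC negative log-likelihood of `target` for one sequence (hypothetical helper).

    log_probs: (T, C) log-probabilities; target: list of label ids without blanks.
    Plain log-domain forward (alpha) recursion, so autograd can differentiate it.
    """
    ext = [blank]  # blank-interleaved labels: [b, y1, b, y2, ..., yL, b]
    for y in target:
        ext += [y, blank]
    S, T = len(ext), log_probs.shape[0]
    neg = log_probs.new_tensor(NEG)

    # t = 0: a path may start in the leading blank or in the first label
    alpha = [log_probs[0, ext[0]]] + [neg] * (S - 1)
    if S > 1:
        alpha[1] = log_probs[0, ext[1]]
    for t in range(1, T):
        new_alpha = []
        for s in range(S):
            prev = [alpha[s]]  # stay in the same state
            if s >= 1:
                prev.append(alpha[s - 1])  # advance by one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                prev.append(alpha[s - 2])  # skip a blank between two different labels
            new_alpha.append(torch.logsumexp(torch.stack(prev), dim=0) + log_probs[t, ext[s]])
        alpha = new_alpha
    # a valid path ends in the last label or in the trailing blank
    ends = alpha[-2:] if S > 1 else alpha
    return -torch.logsumexp(torch.stack(ends), dim=0)


torch.manual_seed(0)
T, C, target = 6, 4, [1, 2, 2]  # label 0 is the blank
logits = torch.randn(T, C, dtype=torch.float64, requires_grad=True)

# (1) vanilla CTC loss and its gradient w.r.t. the logits
nll = ctc_nll(F.log_softmax(logits, dim=-1), target)
grad_ctc, = torch.autograd.grad(nll, logits)

# (2) soft targets: gamma[t, k] = -d nll / d log_probs[t, k]
lp = F.log_softmax(logits, dim=-1).detach().requires_grad_(True)
gamma, = torch.autograd.grad(ctc_nll(lp, target), lp)
gamma = -gamma

# (3) frame-wise cross-entropy against the fixed soft targets gamma
ce = -(gamma * F.log_softmax(logits, dim=-1)).sum()
grad_ce, = torch.autograd.grad(ce, logits)

# sanity check of the hand-written forward pass against PyTorch's built-in CTC loss
ref = F.ctc_loss(F.log_softmax(logits, dim=-1).unsqueeze(1), torch.tensor([target]),
                 torch.tensor([T]), torch.tensor([len(target)]), blank=0, reduction="sum")
print(torch.allclose(nll, ref), torch.allclose(grad_ctc, grad_ce))
```

Because `gamma` is an explicit tensor in this form, a re-weighting or focal-style modification can be sketched by rescaling `gamma` or the per-frame cross-entropy terms before summing.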
# An implementation of "Machine Learning on Sequential Data Using a Recurrent Weighted Average" in PyTorch
# https://arxiv.org/pdf/1703.01253.pdf
#
#
# This is an RNN (recurrent neural network) type that uses a weighted average of values seen in the past,
# rather than a separate running state.
#
# Check the test code at the bottom for a usage example, where you can compare its performance
# against an LSTM and a GRU on a classification task from the paper. It handily beats both the LSTM and
# the GRU :)
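To make "weighted average of values seen in the past" concrete, here is a minimal sketch of the recurrence as we read it from the paper: the output is `h_t = tanh(n_t / d_t)`, where `n_t` and `d_t` are running sums of `z_i * exp(a_i)` and `exp(a_i)`, with `z_i = u(x_i) * tanh(g([x_i, h_{i-1}]))` and attention logits `a_i = a([x_i, h_{i-1}])`. The class name `RWASketch`, the zero-initialised learned start state `s0`, and the running-maximum rescaling are assumptions of this sketch; the rescaling is the standard log-sum-exp trick, added here so that `exp(a)` cannot overflow. It is not the gist's implementation.

```python
import torch
import torch.nn as nn


class RWASketch(nn.Module):
    """Hypothetical minimal Recurrent Weighted Average layer (sketch only).

    h_t = tanh(n_t / d_t), with running sums
      n_t = sum_{i<=t} z_i * exp(a_i),  d_t = sum_{i<=t} exp(a_i),
      z_i = u(x_i) * tanh(g([x_i, h_{i-1}])),  a_i = a([x_i, h_{i-1}]).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.u = nn.Linear(input_size, hidden_size)                # unbounded value path
        self.g = nn.Linear(input_size + hidden_size, hidden_size)  # gate on the value
        self.a = nn.Linear(input_size + hidden_size, hidden_size, bias=False)  # attention logits
        self.s0 = nn.Parameter(torch.zeros(hidden_size))           # learned initial state (assumed)

    def forward(self, x):
        # x: (batch, seq_len, input_size) -> (outputs: (batch, seq_len, hidden_size), last h)
        batch, seq_len, _ = x.shape
        h = torch.tanh(self.s0).expand(batch, -1)
        n = x.new_zeros(batch, self.hidden_size)  # numerator, stored scaled by exp(-a_max)
        d = x.new_zeros(batch, self.hidden_size)  # denominator, same scaling
        a_max = x.new_full((batch, self.hidden_size), float("-inf"))
        outputs = []
        for t in range(seq_len):
            xt = x[:, t]
            xh = torch.cat([xt, h], dim=-1)
            z = self.u(xt) * torch.tanh(self.g(xh))
            a = self.a(xh)
            # Running-max rescaling: same ratio n / d as using exp(a) directly, without overflow
            new_max = torch.maximum(a_max, a)
            rescale = torch.exp(a_max - new_max)
            w = torch.exp(a - new_max)
            n = n * rescale + z * w
            d = d * rescale + w
            a_max = new_max
            h = torch.tanh(n / d)  # the weighted average itself is the state; no separate cell
            outputs.append(h)
        return torch.stack(outputs, dim=1), h


torch.manual_seed(0)
rwa = RWASketch(input_size=3, hidden_size=8)
out, h_last = rwa(torch.randn(2, 5, 3))  # 2 sequences, 5 steps, 3 features each
print(out.shape)  # torch.Size([2, 5, 8])
```

The sums `n` and `d` only ever grow, so every past step keeps a say in the output in proportion to `exp(a_i)`; the loop is written step by step for clarity rather than speed.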