Last active
October 23, 2023 18:54
-
-
Save leboucletoledo/3d20016b6d8ee79f54505f5b94460f4d to your computer and use it in GitHub Desktop.
Emulations of TensorFlow's SpatialDropout1D in PyTorch, based on the TensorFlow implementation. Requires PyTorch ≥ 1.12 (for `Dropout1d`); tested on PyTorch versions < 2.0.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def spatial_dropout1d_v1(x, do_rate=0.1, training=True):
    """Emulate TensorFlow's ``SpatialDropout1D`` with 1D channel-wise dropout.

    For every sample, each channel is either zeroed at every time step or
    kept and rescaled by ``1 / (1 - do_rate)``. This matches TF's
    ``noise_shape=(batch, 1, channels)``.

    Args:
        x: 3D tensor laid out as ``[batch, time, channels]`` (TF's
            channels-last convention).
        do_rate: Probability of dropping a channel; must lie in ``[0, 1]``.
        training: Apply dropout only when True. The default of True keeps the
            previous always-drop behavior. Pass the calling module's
            ``self.training`` so that evaluation runs are deterministic.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``do_rate`` is outside ``[0, 1]``.
    """
    # dropout1d drops whole channels along dim 1, so move channels there:
    # [batch, time, channels] -> [batch, channels, time].
    # The functional form (unlike a freshly built nn.Dropout1d, which is
    # always in training mode) respects the `training` flag.
    x = torch.nn.functional.dropout1d(
        x.permute(0, 2, 1), p=do_rate, training=training
    )
    # Restore [batch, time, channels].
    return x.permute(0, 2, 1)
def spatial_dropout1d_v2(x, do_rate=0.1, training=True):
    """Emulate TensorFlow's ``SpatialDropout1D`` with 2D channel-wise dropout.

    For every sample, each channel is either zeroed at every time step or
    kept and rescaled by ``1 / (1 - do_rate)``. This matches TF's
    ``noise_shape=(batch, 1, channels)``.

    The previous version applied dropout twice. The first pass ran on a 3D
    ``[batch, channels, time]`` tensor. The second pass then treated that
    tensor as ``[batch, time, channels]``, so it dropped whole time steps and
    returned ``[batch, channels, time]``. Only the 4D pipeline below is kept.

    Args:
        x: 3D tensor laid out as ``[batch, time, channels]`` (TF's
            channels-last convention).
        do_rate: Probability of dropping a channel; must lie in ``[0, 1]``.
        training: Apply dropout only when True. The default of True keeps the
            previous always-drop behavior. Pass the calling module's
            ``self.training`` so that evaluation runs are deterministic.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``do_rate`` is outside ``[0, 1]``.
    """
    # Add a singleton height axis and move channels to dim 1:
    # [batch, time, channels] -> [batch, time, 1, channels]
    #                         -> [batch, channels, 1, time]
    x = x.unsqueeze(2).permute(0, 3, 2, 1)
    # dropout2d zeroes whole [1, time] feature maps, i.e. entire channels.
    x = torch.nn.functional.dropout2d(x, p=do_rate, training=training)
    # Undo the permute and remove the singleton axis:
    # [batch, channels, 1, time] -> [batch, time, 1, channels]
    #                            -> [batch, time, channels]
    return x.permute(0, 3, 2, 1).squeeze(2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment