@leboucletoledo
Last active October 23, 2023 18:54
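Emulations of SpatialDropout1D in PyTorch, based on the TensorFlow implementation. Inputs have shape [batch, time, channels], and entire channels are dropped rather than individual elements. Tested on PyTorch versions < 2.0.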
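import torch

def spatial_dropout1d_v1(x, do_rate=0.1):
    '''Emulate SpatialDropout1D from TF using Dropout1d.

    x has shape [batch, time, channels]; whole channels are zeroed.
    '''
    # a freshly constructed module is in training mode, so dropout is always applied
    do1d = torch.nn.Dropout1d(p=do_rate)
    # convert to [batch, channels, time]
    x = do1d(x.permute(0, 2, 1))
    # back to [batch, time, channels]
    x = x.permute(0, 2, 1)
    return x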
def spatial_dropout1d_v2(x, do_rate=0.1):
    '''Emulate SpatialDropout1D from TF using Dropout2d'''
    do2d = torch.nn.Dropout2d(p=do_rate)
    # add a dummy axis and convert to [batch, channels, 1, time]
    x = x.unsqueeze(2).permute(0, 3, 2, 1)
    # drop whole channels; dropout is applied once, on the 4D tensor
    x = do2d(x)
    # back to [batch, time, channels]
    x = x.permute(0, 3, 2, 1).squeeze(2)
    return x
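
A quick sanity check of the channel-wise behaviour can be sketched as follows. This is not part of the original gist: `spatial_dropout1d` is a hypothetical standalone helper that uses the functional `dropout2d` call instead of a module, and the all-ones input and the rate of 0.5 are assumptions chosen so that surviving values are easy to recognise (they are rescaled by 1 / (1 - p) = 2).

```python
import torch
import torch.nn.functional as F


def spatial_dropout1d(x, do_rate=0.1):
    """Hypothetical helper: drop whole channels of a [batch, time, channels] tensor."""
    # Add a dummy axis and move channels to dim 1: [batch, channels, 1, time]
    y = x.unsqueeze(2).permute(0, 3, 2, 1)
    # dropout2d zeroes entire channels; training=True keeps dropout active
    y = F.dropout2d(y, p=do_rate, training=True)
    # Back to [batch, time, channels]
    return y.permute(0, 3, 2, 1).squeeze(2)


torch.manual_seed(0)
x = torch.ones(4, 10, 8)  # assumed toy input: [batch=4, time=10, channels=8]
out = spatial_dropout1d(x, do_rate=0.5)

# For every (sample, channel) pair, all time steps should share one value
same_across_time = bool((out == out[:, :1, :]).all())
# Values should be either 0 (dropped) or 2 (kept and rescaled)
values = set(out.unique().tolist())
print(out.shape, same_across_time, values)
```

If the mask varied along the time axis, `same_across_time` would be `False`, which would indicate element-wise rather than spatial dropout.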