Skip to content

Instantly share code, notes, and snippets.

@yanring
Created November 27, 2020 06:41
Show Gist options
  • Save yanring/0b790ec7fe3ca6ca5443481ac2c8cf47 to your computer and use it in GitHub Desktop.
DropBlock2D
import time
import torch
from torch import nn
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.autograd import Function
class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of feature map allows to remove semantic
    information as compared to regular dropout.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop.
        share_channel (bool): if True, one spatial mask is shared by all
            channels; otherwise each channel is masked independently.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
        https://arxiv.org/abs/1810.12890
    """

    def __init__(self, drop_prob, block_size, share_channel=False):
        super(DropBlock2D, self).__init__()
        # drop_prob lives in a buffer so it follows .to()/.cuda(), but it is
        # deliberately excluded from checkpoints (see _save_to_state_dict)
        # so a saved model does not pin the schedule value.
        self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
        self.inited = False       # True once reset()/reset_steps() was called
        self.step_size = 0.0      # per-step increment applied to drop_prob
        self.start_step = 0
        self.nr_steps = 0
        self.block_size = block_size
        self.share_channel = share_channel
        # NOTE(review): forward() calls torch.manual_seed(self.seed), which
        # reseeds the *global* RNG on every training step — this also makes
        # every other random op in the same iteration deterministic.
        self.seed = int(time.time())
        self.with_step = True     # advance the schedule automatically in forward()

    def reset(self):
        """Stop DropBlock by zeroing the drop probability."""
        self.inited = True
        # Fix: `self.drop_prob = 0.0` raises TypeError because nn.Module
        # refuses to rebind a registered buffer to a non-tensor; update the
        # buffer in place instead.
        self.drop_prob.fill_(0.0)

    def reset_steps(self, start_step, nr_steps, start_value=1e-6, stop_value=None):
        """Configure a linear drop_prob ramp from start_value to stop_value.

        When stop_value is None, the drop_prob the module was constructed
        with is used as the target.
        """
        self.inited = True
        stop_value = self.drop_prob.item() if stop_value is None else stop_value
        self.drop_prob[0] = start_value
        self.step_size = (stop_value - start_value) / nr_steps

    def forward(self, x):
        if self.training and self.with_step:
            self.step()
        if not self.training or self.drop_prob == 0.:
            return x
        # expected seed density per position (Eq. 1 of the paper)
        gamma = self._compute_gamma(x)
        # sample the seed mask on the input's device
        torch.manual_seed(self.seed)
        if self.share_channel:
            mask = (torch.rand(*x.shape[2:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0).unsqueeze(0)
        else:
            mask = (torch.rand(*x.shape[1:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0)
        # grow each seed into a block_size x block_size zeroed region
        block_mask, keeped = self._compute_block_mask(mask)
        out = x * block_mask.to(x.dtype)
        # Rescale so the expected activation magnitude is unchanged.
        # Fix: if everything was dropped, keeped == 0 and the division would
        # yield inf, turning the (all-zero) output into NaN; clamping to 1 is
        # safe because out is identically zero in that case.
        out = out * (block_mask.numel() / keeped.clamp(min=1.0))
        return out

    def _compute_block_mask(self, mask):
        """Expand the seed mask into a keep-mask; return (mask, #kept).

        NOTE(review): for an even block_size the pooled mask comes out one
        pixel larger than the input — use odd block sizes (as the paper does).
        """
        block_mask = F.max_pool2d(mask.to(torch.float32),
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        keeped = block_mask.numel() - block_mask.sum()
        block_mask = 1 - block_mask
        return block_mask, keeped

    def _compute_gamma(self, x):
        """Compute gamma (Eq. 1 of the DropBlock paper) for input x."""
        _, c, h, w = x.size()
        gamma = self.drop_prob / (self.block_size ** 2) * (h * w) / \
            ((w - self.block_size + 1) * (h - self.block_size + 1))
        return gamma

    def step(self):
        """Advance the linear drop_prob schedule and refresh the mask seed."""
        assert self.inited
        self.drop_prob += self.step_size
        self.seed = int(time.time())

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Tolerate checkpoints that (by design) carry no drop_prob entry:
        # inject a placeholder so the default loader does not complain.
        drop_prob_key = prefix + 'drop_prob'
        if drop_prob_key not in state_dict:
            state_dict[drop_prob_key] = torch.ones(1, dtype=torch.float32)
        super(DropBlock2D, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Overwrite save method: intentionally keep drop_prob (the only
        buffer) out of the state dict, since it is schedule state."""
        pass

    def extra_repr(self):
        return 'drop_prob={}, step_size={}'.format(self.drop_prob, self.step_size)
def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
    """Install a drop-probability schedule on a DropBlock2D module.

    Intended for use with ``nn.Module.apply`` via ``functools.partial``:

        apply_drop_prob = partial(reset_dropblock, 0, epochs * iters_per_epoch, 0.0, 0.1)
        net.apply(apply_drop_prob)

    Modules that are not DropBlock2D instances are left untouched.
    """
    if not isinstance(m, DropBlock2D):
        return
    m.reset_steps(start_step, nr_steps, start_value, stop_value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment