Created
November 27, 2020 06:41
-
-
Save yanring/0b790ec7fe3ca6ca5443481ac2c8cf47 to your computer and use it in GitHub Desktop.
DropBlock2D
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
from torch import nn | |
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU | |
import torch.nn.functional as F | |
from torch.nn.modules.utils import _pair | |
from torch.autograd import Function | |
class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of feature map allows to remove semantic
    information as compared to regular dropout.

    The drop probability can follow a linear warm-up schedule: call
    :meth:`reset_steps` once, then :meth:`step` (or let :meth:`forward`
    do it while ``self.with_step`` is True) advances the schedule.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop
        share_channel (bool): if True a single spatial mask is shared by
            all channels; otherwise every channel draws its own mask.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890
    """

    def __init__(self, drop_prob, block_size, share_channel=False):
        super(DropBlock2D, self).__init__()
        # Buffer so drop_prob moves with the module across devices; it is
        # deliberately excluded from checkpoints (see _save_to_state_dict).
        self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
        self.inited = False           # True once reset()/reset_steps() has been called
        self.step_size = 0.0          # per-step increment of the linear schedule
        self.start_step = 0
        self.nr_steps = 0
        self.block_size = block_size
        self.share_channel = share_channel
        self.stop_value = float(drop_prob)  # schedule target; step() clamps to it
        self.seed = int(time.time())
        self.with_step = True         # if True, forward() advances the schedule

    def reset(self):
        """Stop DropBlock: zero the drop probability and freeze the schedule."""
        self.inited = True
        # BUGFIX: `self.drop_prob = 0.0` raised TypeError, because assigning a
        # non-tensor to a registered buffer is rejected by nn.Module.__setattr__.
        # Update the buffer in place instead.
        self.drop_prob.fill_(0.0)
        # BUGFIX: also freeze the schedule, otherwise a later step() call
        # would silently re-enable DropBlock by re-growing drop_prob.
        self.step_size = 0.0

    def reset_steps(self, start_step, nr_steps, start_value=1e-6, stop_value=None):
        """Arm a linear schedule from start_value to stop_value over nr_steps.

        When stop_value is None the current drop_prob is used as the target.
        """
        self.inited = True
        stop_value = self.drop_prob.item() if stop_value is None else stop_value
        self.start_step = start_step
        self.nr_steps = nr_steps
        self.stop_value = float(stop_value)
        self.drop_prob[0] = start_value
        self.step_size = (stop_value - start_value) / nr_steps

    def forward(self, x):
        if self.training and self.with_step:
            self.step()
        if not self.training or self.drop_prob == 0.:
            return x
        else:
            # Bernoulli rate for mask centers (Eq. 1 of the DropBlock paper).
            gamma = self._compute_gamma(x)
            # NOTE(review): this seeds the *global* torch RNG, which affects
            # every other torch.rand call in the process; kept for backward
            # compatibility with the original behavior.
            torch.manual_seed(self.seed)
            if self.share_channel:
                # One (1, 1, H, W) mask broadcast over batch and channels.
                mask = (torch.rand(*x.shape[2:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0).unsqueeze(0)
            else:
                # One (1, C, H, W) mask broadcast over the batch only.
                mask = (torch.rand(*x.shape[1:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0)
            # Expand dropped centers into block_size x block_size zero blocks.
            block_mask, keeped = self._compute_block_mask(mask)
            out = x * block_mask.to(x.dtype)
            # Rescale so the expected activation magnitude is preserved.
            # BUGFIX: clamp the divisor — if every element were dropped,
            # keeped would be 0 and the scale would become inf.
            out = out * (block_mask.numel() / keeped.clamp(min=1.0))
            return out

    def _compute_block_mask(self, mask):
        """Dilate the center mask into blocks; return (keep mask, #kept)."""
        block_mask = F.max_pool2d(mask.to(torch.float32),
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        # BUGFIX: for even block sizes the padding above produces an output
        # one row/column larger than the input, which made `x * block_mask`
        # fail with a shape mismatch; crop back to the input size.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        keeped = block_mask.numel() - block_mask.sum()
        # Invert: 1 where activations are kept, 0 inside dropped blocks.
        block_mask = 1 - block_mask
        return block_mask, keeped

    def _compute_gamma(self, x):
        """Center-drop rate so the effective drop rate matches drop_prob."""
        _, c, h, w = x.size()
        gamma = self.drop_prob / (self.block_size ** 2) * (h * w) / \
            ((w - self.block_size + 1) * (h - self.block_size + 1))
        return gamma

    def step(self):
        """Advance the linear schedule one step and refresh the mask seed."""
        assert self.inited
        self.drop_prob += self.step_size
        # BUGFIX: the original kept incrementing past the schedule target,
        # letting drop_prob grow without bound; clamp at stop_value (in
        # either direction, so decreasing schedules also work).
        if self.step_size > 0:
            self.drop_prob.clamp_(max=self.stop_value)
        elif self.step_size < 0:
            self.drop_prob.clamp_(min=self.stop_value)
        self.seed = int(time.time())

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # drop_prob is never saved (see _save_to_state_dict), so inject a
        # default before delegating, to tolerate checkpoints without it.
        drop_prob_key = prefix + 'drop_prob'
        if drop_prob_key not in state_dict:
            state_dict[drop_prob_key] = torch.ones(1, dtype=torch.float32)
        super(DropBlock2D, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Skip saving: drop_prob is schedule state, not a learned weight."""
        pass

    def extra_repr(self):
        return 'drop_prob={}, step_size={}'.format(self.drop_prob, self.step_size)
def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
    """Arm the drop-probability schedule of one module, if it is a DropBlock2D.

    Intended for use with ``nn.Module.apply``; any module that is not a
    DropBlock2D instance is left untouched.

    Example:
        from functools import partial
        apply_drop_prob = partial(reset_dropblock, 0, epochs*iters_per_epoch, 0.0, 0.1)
        net.apply(apply_drop_prob)
    """
    if not isinstance(m, DropBlock2D):
        return
    m.reset_steps(start_step, nr_steps, start_value, stop_value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment