Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created July 3, 2019 14:54
Show Gist options
  • Save chenyaofo/1e8467caeeeda1182c17ca5978618185 to your computer and use it in GitHub Desktop.
Save chenyaofo/1e8467caeeeda1182c17ca5978618185 to your computer and use it in GitHub Desktop.
The core implementation of "The Shallow End: Empowering Shallower Deep-Convolutional Networks through Auxiliary Outputs"
import torch
import typing
import functools
import torch.nn as nn
def intermediate_output_hook(module, input, output, intermediate_output_store: list):
    """Forward hook that records a module's output into a caller-owned list.

    Intended to be bound with
    ``functools.partial(intermediate_output_hook, intermediate_output_store=...)``
    and passed to ``nn.Module.register_forward_hook``. It returns ``None``,
    so PyTorch leaves the module's output unchanged.

    Args:
        module: the hooked module (not used).
        input: the module's positional inputs (not used).
        output: the module's output; appended to the store unchanged.
        intermediate_output_store: list that receives ``output``.
    """
    captured = intermediate_output_store
    captured.append(output)
    return None
def _check_entrypoints(backbone, entrypoints):
complete_entrypoints = set([name for name, _ in backbone.named_modules()])
expected_entrypoints = set(entrypoints)
if not expected_entrypoints.issubset(complete_entrypoints):
raise ValueError("The entrypoints({}) do not exist in backbone.".format(
",".join(expected_entrypoints - complete_entrypoints)
))
class AuxNetHead(nn.Module):
    """Auxiliary classifier head: global average pool followed by a linear layer.

    Args:
        in_features: channel count of the incoming feature map (dimension 1).
        out_features: number of outputs of the linear layer.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Pool ``input`` to 1x1 spatially, flatten per sample, and classify."""
        pooled = self.adaptive_pool(input)
        batch = pooled.size(0)
        flat = pooled.view(batch, -1)
        return self.fc(flat)
class AuxNet(nn.Module):
    """Wrap a backbone with auxiliary classifiers attached to intermediate modules.

    A forward hook on each named entrypoint module captures that module's
    output during the backbone's forward pass. Each captured feature is fed
    to the auxiliary classifier created for that entrypoint. ``forward``
    returns ``[*aux_outputs, output]``, with auxiliary outputs in the
    iteration order of ``entrypoints_with_channels``.

    Args:
        backbone: network whose submodules are tapped.
        entrypoints_with_channels: maps a qualified submodule name (as produced
            by ``backbone.named_modules()``) to the ``in_features`` value passed
            to ``Classifier`` for that entrypoint.
        num_classes: ``out_features`` passed to every auxiliary classifier.
        Classifier: factory called as ``Classifier(channels, num_classes)``.

    Raises:
        ValueError: if an entrypoint name does not exist in ``backbone``.
    """

    def __init__(self, backbone: nn.Module,
                 entrypoints_with_channels: typing.Dict[str, int],
                 num_classes, Classifier=AuxNetHead):
        super(AuxNet, self).__init__()
        self.backbone = backbone
        self.entrypoints_with_channels = entrypoints_with_channels
        self.num_classes = num_classes
        self.Classifier = Classifier
        self._aux_classifiers = nn.ModuleList()
        # One capture list per entrypoint. Features are matched to classifiers
        # by name rather than by hook firing order, which can differ from the
        # dict order and is shifted by modules that run more than once.
        self._intermediate_outputs = {
            name: [] for name in self.entrypoints_with_channels
        }
        self._hooks = []
        _check_entrypoints(self.backbone, self.entrypoints_with_channels.keys())
        self._register_intermediate_hooks()
        self._create_aux_classifiers()

    def _register_intermediate_hooks(self):
        """Attach a capturing forward hook to every entrypoint module."""
        for name, module in self.backbone.named_modules():
            if name in self._intermediate_outputs:
                self._hooks.append(
                    module.register_forward_hook(
                        functools.partial(
                            intermediate_output_hook,
                            intermediate_output_store=self._intermediate_outputs[name],
                        )
                    )
                )

    def _create_aux_classifiers(self):
        """Create one classifier per entrypoint, in ``entrypoints_with_channels`` order."""
        for name, channels in self.entrypoints_with_channels.items():
            self._aux_classifiers.append(
                self.Classifier(channels, self.num_classes)
            )

    def _clean_intermediate_outputs(self):
        """Empty every capture list in place (the hooks hold references to them)."""
        for store in self._intermediate_outputs.values():
            store.clear()

    def _remove_hooks(self):
        """Detach all registered hooks; safe to call more than once."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def forward(self, *args, **kwargs):
        """Run the backbone and every auxiliary classifier.

        Returns:
            list: ``[aux_out_0, ..., aux_out_{k-1}, backbone_output]``.

        Raises:
            RuntimeError: if an entrypoint module did not run during the
                backbone forward pass.
        """
        self._clean_intermediate_outputs()
        try:
            output = self.backbone(*args, **kwargs)
            aux_outputs = []
            # Both sequences follow entrypoints_with_channels order (see
            # __init__ and _create_aux_classifiers), so this zip pairs by name.
            for (name, store), classifier in zip(self._intermediate_outputs.items(),
                                                 self._aux_classifiers):
                if not store:
                    raise RuntimeError(
                        "Entrypoint '{}' produced no output during the backbone "
                        "forward pass.".format(name)
                    )
                # If the module ran several times, use its latest output.
                aux_outputs.append(classifier(store[-1]))
        finally:
            # Drop references to captured activations so they are not kept
            # alive until the next forward call (or forever after an error).
            self._clean_intermediate_outputs()
        return [*aux_outputs, output]
def autoconfig_auxnet(backbone: nn.Module,
                      entrypoints: typing.Iterable[str],
                      Classifier=AuxNetHead,
                      test_size=(1, 3, 224, 224)):
    """Build an :class:`AuxNet` by probing *backbone* with one random input.

    The backbone runs once in eval mode under ``torch.no_grad()``. The input
    is placed on the device (and, if floating point, in the dtype) of the
    backbone's first parameter. The output size of dimension 1 of each
    entrypoint gives its channel count. The backbone output's second
    dimension gives ``num_classes``. Hooks are always removed and the
    original training mode is always restored, even on error.

    Args:
        backbone: network to wrap.
        entrypoints: qualified submodule names to tap; any iterable, including
            a generator, is accepted.
        Classifier: classifier factory forwarded to :class:`AuxNet`.
        test_size: shape of the random probe input.

    Returns:
        AuxNet: entrypoints ordered by the order in which they first ran
        during the probe.

    Raises:
        ValueError: if an entrypoint does not exist in ``backbone`` (also
            raised when the backbone output is not 2-D).
        RuntimeError: if an entrypoint module never ran during the probe.
    """
    # Materialize once: a generator would otherwise be exhausted by the
    # validation below, and no hooks would be registered.
    entrypoints = list(entrypoints)
    _check_entrypoints(backbone, entrypoints)
    wanted = set(entrypoints)

    # name -> latest output. Dict insertion order records the order in which
    # each entrypoint first fired (re-assigning an existing key keeps its slot).
    captured = {}

    def _make_hook(name):
        def _hook(module, input, output):
            captured[name] = output
        return _hook

    hooks = [
        module.register_forward_hook(_make_hook(name))
        for name, module in backbone.named_modules()
        if name in wanted
    ]

    # Match the backbone's device/dtype so GPU or half-precision models work.
    factory_kwargs = {}
    first_param = next(backbone.parameters(), None)
    if first_param is not None:
        factory_kwargs["device"] = first_param.device
        if first_param.is_floating_point():
            factory_kwargs["dtype"] = first_param.dtype

    training = backbone.training
    backbone.eval()
    try:
        with torch.no_grad():
            output = backbone(torch.rand(test_size, **factory_kwargs))
    finally:
        for hook in hooks:
            hook.remove()
        backbone.train(mode=training)

    missing = [name for name in dict.fromkeys(entrypoints) if name not in captured]
    if missing:
        raise RuntimeError(
            "Entrypoints ({}) produced no output during the probe forward "
            "pass.".format(",".join(missing))
        )

    _, num_classes = output.shape
    entrypoints_with_channels = {
        name: feature.shape[1] for name, feature in captured.items()
    }
    return AuxNet(backbone, entrypoints_with_channels, num_classes, Classifier=Classifier)
class AuxCriterion(object):
    """Apply a single criterion to every prediction in a list.

    Intended for the list returned by :class:`AuxNet` ``forward``
    (auxiliary outputs followed by the main output).

    Args:
        criterion: callable ``criterion(prediction, targets)`` returning a loss.
    """

    def __init__(self, criterion):
        self.criterion = criterion

    def __call__(self, inputs, targets):
        """Return an :class:`AuxLoss` holding one loss per element of *inputs*.

        The losses are built eagerly into a list. ``AuxLoss.backward`` uses
        ``len()`` and indexing, which a generator does not support.
        """
        return AuxLoss([self.criterion(prediction, targets) for prediction in inputs])
class AuxLoss(object):
    """A group of loss tensors that are back-propagated together.

    Args:
        losses: iterable of scalar loss tensors. It is copied into a list,
            so generators are accepted.
    """

    def __init__(self, losses):
        # Materialize: backward() needs len() and repeated access.
        self.losses = list(losses)

    def backward(self):
        """Call ``backward`` on each loss, starting from the last one.

        Every call except the final one (on ``losses[0]``) passes
        ``retain_graph=True``. When the losses share part of the autograd
        graph, the graph must survive until the last backward pass.
        Gradients from all losses accumulate into ``.grad``. An empty loss
        list is a no-op.
        """
        for position, loss in reversed(list(enumerate(self.losses))):
            loss.backward(retain_graph=position != 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment