Comparing PyTorch model implementations by forward-pass time
# Implementation from https://github.com/davidcpage/cifar10-fast
# Adapted to Python 3.5
# TorchGraph(
#   (prep_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (prep_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (prep_relu): ReLU(inplace)
#   (layer1_conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer1_relu): ReLU(inplace)
#   (layer1_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (layer1_residual_in): Identity()
#   (layer1_residual_res1_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer1_residual_res1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer1_residual_res1_relu): ReLU(inplace)
#   (layer1_residual_res2_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer1_residual_res2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer1_residual_res2_relu): ReLU(inplace)
#   (layer1_residual_add): Add()
#   (layer2_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer2_relu): ReLU(inplace)
#   (layer2_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (layer3_conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer3_relu): ReLU(inplace)
#   (layer3_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (layer3_residual_in): Identity()
#   (layer3_residual_res1_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer3_residual_res1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer3_residual_res1_relu): ReLU(inplace)
#   (layer3_residual_res2_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (layer3_residual_res2_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (layer3_residual_res2_relu): ReLU(inplace)
#   (layer3_residual_add): Add()
#   (classifier_pool): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
#   (classifier_flatten): Flatten()
#   (classifier_linear): Linear(in_features=512, out_features=10, bias=False)
#   (classifier_logits): Mul()
# )
from collections import namedtuple, OrderedDict
import time

import torch
from torch import nn

torch.backends.cudnn.benchmark = True
device = "cuda"


class Identity(nn.Module):
    def forward(self, x): return x


class Mul(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    # defined as forward (not __call__) so nn.Module's hook machinery still applies
    def forward(self, x):
        return x * self.weight


class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), x.size(1))


class Add(nn.Module):
    def forward(self, x, y): return x + y


class Concat(nn.Module):
    def forward(self, *xs): return torch.cat(xs, 1)


class Correct(nn.Module):
    def forward(self, classifier, target):
        return classifier.max(dim=1)[1] == target
def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False,
               bn_weight_init=None, bn_weight_freeze=False):
    m = nn.BatchNorm2d(num_channels)
    if bn_bias_init is not None:
        m.bias.data.fill_(bn_bias_init)
    if bn_bias_freeze:
        m.bias.requires_grad = False
    if bn_weight_init is not None:
        m.weight.data.fill_(bn_weight_init)
    if bn_weight_freeze:
        m.weight.requires_grad = False
    return m
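

# Usage sketch for batch_norm (illustrative, not part of the benchmark): the
# keyword arguments may be combined freely, e.g. a layer whose scale starts at
# zero and whose bias never receives gradients:
#
#     bn = batch_norm(128, bn_weight_init=0.0, bn_bias_freeze=True)
#     assert not bn.bias.requires_grad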
#####################
## dict utils
#####################

union = lambda *dicts: OrderedDict([
    (k, v) for d in dicts for (k, v) in d.items()
])


def path_iter(nested_dict, pfx=()):
    # Yield (path, leaf) pairs, where path is the tuple of nested keys.
    for name, val in nested_dict.items():
        if isinstance(val, dict):
            yield from path_iter(val, (*pfx, name))
        else:
            yield ((*pfx, name), val)


#####################
## graph building
#####################

sep = '_'
RelativePath = namedtuple('RelativePath', ('parts',))
rel_path = lambda *parts: RelativePath(parts)


def build_graph(net):
    # Flatten the nested dict of modules into an OrderedDict mapping
    # 'path_joined_name' -> (module, [names of input nodes]). A node's input
    # defaults to the node declared just before it ('input' for the first one).
    net = OrderedDict(path_iter(net))
    default_inputs = [[('input',)]] + [[k] for k in net.keys()]
    with_default_inputs = lambda vals: (val if isinstance(val, tuple) else (val, default_inputs[idx])
                                        for idx, val in enumerate(vals))
    parts = lambda path, pfx: tuple(pfx) + path.parts if isinstance(path, RelativePath) else (path,) if isinstance(path, str) else path
    return OrderedDict([
        (sep.join((*pfx, name)), (val, [sep.join(parts(x, pfx)) for x in inputs]))
        for (*pfx, name), (val, inputs) in zip(net.keys(), with_default_inputs(net.values()))
    ])
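

# Illustrative check (not in the original): build_graph turns a nested spec
# into a flat name -> (module, input names) table, e.g. for a two-node net:
#
#     g = build_graph(OrderedDict([
#         ('a', nn.Conv2d(3, 8, 3, padding=1)),
#         ('b', nn.ReLU(True)),
#     ]))
#     # g == OrderedDict([('a', (Conv2d(...), ['input'])),
#     #                   ('b', (ReLU(...),  ['a']))])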
class TorchGraph(nn.Module):
    def __init__(self, net):
        self.graph = build_graph(net)
        super().__init__()
        # Register every graph node as a submodule so that .parameters(),
        # .to(device), printing, etc. work as usual.
        for n, (v, _) in self.graph.items():
            setattr(self, n, v)

    def forward(self, inputs):
        # Evaluate nodes in declaration order, caching every intermediate
        # output; the whole cache (an OrderedDict) is returned.
        self.cache = OrderedDict(inputs)
        for n, (_, i) in self.graph.items():
            self.cache[n] = getattr(self, n)(*[self.cache[x] for x in i])
        return self.cache

    def half(self):
        # Convert everything except BatchNorm to fp16; the graph's modules are
        # flat attributes of self, so children() sees each BatchNorm directly.
        for module in self.children():
            if type(module) is not nn.BatchNorm2d:
                module.half()
        return self
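

# Usage sketch (illustrative, not executed): forward returns the full cache,
# so the logits of the model built below live under the joined node name shown
# in the printout at the top of this file:
#
#     outputs = model({'input': x})
#     logits = outputs['classifier_logits']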
def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
    return OrderedDict([
        ('conv', nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False)),
        ('bn', batch_norm(c_out, bn_weight_init=bn_weight_init, **kw)),
        ('relu', nn.ReLU(True))
    ])


def residual(c, **kw):
    return OrderedDict([
        ('in', Identity()),
        ('res1', conv_bn(c, c, **kw)),
        ('res2', conv_bn(c, c, **kw)),
        ('add', (Add(), [rel_path('in'), rel_path('res2', 'relu')])),
    ])


def basic_net(channels, weight, pool, **kw):
    return OrderedDict([
        ('prep', conv_bn(3, channels['prep'], **kw)),
        ('layer1', OrderedDict(conv_bn(channels['prep'], channels['layer1'], **kw), pool=pool)),
        ('layer2', OrderedDict(conv_bn(channels['layer1'], channels['layer2'], **kw), pool=pool)),
        ('layer3', OrderedDict(conv_bn(channels['layer2'], channels['layer3'], **kw), pool=pool)),
        ('classifier', OrderedDict([
            ('pool', nn.MaxPool2d(4)),
            ('flatten', Flatten()),
            ('linear', nn.Linear(channels['layer3'], 10, bias=False)),
            ('logits', Mul(weight)),
        ]))
    ])


def net(channels=None, weight=0.125, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer3'), **kw):
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
    n = basic_net(channels, weight, pool, **kw)
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer], **kw)
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)
    return n
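

# Variant sketch (illustrative): the builder also supports extra conv blocks
# per layer, e.g. net(extra_layers=('layer2',)) appends a 256->256 conv_bn
# block named 'layer2_extra' after layer2's pooling node.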
torch.manual_seed(12)

model = TorchGraph(net()).to(device)
print("Model")
print("---")
print(model)
print("---")

batch_size = 512
batch = {
    'input': torch.ones((batch_size, 3, 32, 32)).to(device),
}

model.train(True)
torch.cuda.synchronize()  # make sure setup work has finished before timing
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before reading the clock
elapsed = time.time() - t0
print("Forward pass (f32) time: {}".format(elapsed / n_runs))
torch.manual_seed(12)

model = TorchGraph(net()).to(device).half()

batch_size = 512
batch = {
    'input': torch.ones((batch_size, 3, 32, 32)).to(device).half(),
}

model.train(True)
torch.cuda.synchronize()
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
torch.cuda.synchronize()
elapsed = time.time() - t0
print("Forward pass (f16) time: {}".format(elapsed / n_runs))
# Same model implemented with nn.Sequential
# Resnet(
#   (prep): Sequential(
#     (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (2): ReLU(inplace)
#   )
#   (layer1): Sequential(
#     (0): Sequential(
#       (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (2): ReLU(inplace)
#     )
#     (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#     (2): IdentityResidualBlock(
#       (conv1): Sequential(
#         (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#         (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#         (2): ReLU(inplace)
#       )
#       (conv2): Sequential(
#         (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#         (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#         (2): ReLU(inplace)
#       )
#     )
#   )
#   (layer2): Sequential(
#     (0): Sequential(
#       (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (2): ReLU(inplace)
#     )
#     (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   )
#   (layer3): Sequential(
#     (0): Sequential(
#       (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (2): ReLU(inplace)
#     )
#     (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#     (2): IdentityResidualBlock(
#       (conv1): Sequential(
#         (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#         (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#         (2): ReLU(inplace)
#       )
#       (conv2): Sequential(
#         (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#         (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#         (2): ReLU(inplace)
#       )
#     )
#   )
#   (classifier): Sequential(
#     (0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
#     (1): Flatten()
#     (2): Linear(in_features=512, out_features=10, bias=False)
#   )
# )
import time

import torch
import torch.nn as nn

torch.backends.cudnn.benchmark = True  # match the TorchGraph script for a fair comparison
device = "cuda"


def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False,
               bn_weight_init=None, bn_weight_freeze=False):
    m = nn.BatchNorm2d(num_channels)
    if bn_bias_init is not None:
        m.bias.data.fill_(bn_bias_init)
    if bn_bias_freeze:
        m.bias.requires_grad = False
    if bn_weight_init is not None:
        m.weight.data.fill_(bn_weight_init)
    if bn_weight_freeze:
        m.weight.requires_grad = False
    return m
def conv_bn(in_channels, out_channels, bn_kwargs):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        batch_norm(out_channels, **bn_kwargs),
        nn.ReLU(inplace=True)
    )
class Resnet(nn.Module):
    def __init__(self, bn_kwargs=None, final_weight=0.125):
        super(Resnet, self).__init__()
        bn_kwargs = {} if bn_kwargs is None else bn_kwargs
        self.prep = conv_bn(3, 64, bn_kwargs)
        self.layer1 = nn.Sequential(
            conv_bn(64, 128, bn_kwargs),
            nn.MaxPool2d(kernel_size=2),
            IdentityResidualBlock(128, 128, bn_kwargs)
        )
        self.layer2 = nn.Sequential(
            conv_bn(128, 256, bn_kwargs),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer3 = nn.Sequential(
            conv_bn(256, 512, bn_kwargs),
            nn.MaxPool2d(kernel_size=2),
            IdentityResidualBlock(512, 512, bn_kwargs)
        )
        self.classifier = nn.Sequential(
            nn.MaxPool2d(kernel_size=4),
            Flatten(),
            nn.Linear(512, 10, bias=False)
        )
        self.final_weight = final_weight

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.classifier(x)
        x = x * self.final_weight
        return x
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), x.size(1))


class IdentityResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_kwargs):
        super(IdentityResidualBlock, self).__init__()
        self.conv1 = conv_bn(in_channels, out_channels, bn_kwargs)
        self.conv2 = conv_bn(out_channels, out_channels, bn_kwargs)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual
torch.manual_seed(12)

model = Resnet().to(device)
print("Model")
print("---")
print(model)
print("---")

batch_size = 512
batch = torch.ones((batch_size, 3, 32, 32)).to(device)

model.train(True)
torch.cuda.synchronize()  # make sure setup work has finished before timing
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before reading the clock
elapsed = time.time() - t0
print("Forward pass (f32) time: {}".format(elapsed / n_runs))
def model_to_fp16(model):
    # Convert everything to fp16, then flip BatchNorm layers back to fp32
    # (mirrors TorchGraph.half() above, which also keeps BatchNorm in fp32;
    # iterating only model.children() would convert the nested BatchNorms too).
    model.half()
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.float()
    return model


torch.manual_seed(12)

model = Resnet().to(device)
model = model_to_fp16(model)

batch_size = 512
batch = torch.ones((batch_size, 3, 32, 32)).to(device).half()

model.train(True)
torch.cuda.synchronize()
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
torch.cuda.synchronize()
elapsed = time.time() - t0
print("Forward pass (f16) time: {}".format(elapsed / n_runs))