@vfdev-5
Last active January 9, 2019 23:20
Comparing two PyTorch implementations of the same cifar10-fast ResNet by forward-pass time: a graph-based TorchGraph build and a plain nn.Sequential build, each run for 1000 forward passes on a batch of 512 CIFAR-sized (3x32x32) inputs in fp32 and fp16.
# Implementation from https://github.com/davidcpage/cifar10-fast
# Adapted to Python 3.5
# TorchGraph(
# (prep_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (prep_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (prep_relu): ReLU(inplace)
# (layer1_conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer1_relu): ReLU(inplace)
# (layer1_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (layer1_residual_in): Identity()
# (layer1_residual_res1_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer1_residual_res1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer1_residual_res1_relu): ReLU(inplace)
# (layer1_residual_res2_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer1_residual_res2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer1_residual_res2_relu): ReLU(inplace)
# (layer1_residual_add): Add()
# (layer2_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer2_relu): ReLU(inplace)
# (layer2_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (layer3_conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer3_relu): ReLU(inplace)
# (layer3_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (layer3_residual_in): Identity()
# (layer3_residual_res1_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer3_residual_res1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer3_residual_res1_relu): ReLU(inplace)
# (layer3_residual_res2_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (layer3_residual_res2_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (layer3_residual_res2_relu): ReLU(inplace)
# (layer3_residual_add): Add()
# (classifier_pool): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
# (classifier_flatten): Flatten()
# (classifier_linear): Linear(in_features=512, out_features=10, bias=False)
# (classifier_logits): Mul()
# )
from collections import namedtuple, OrderedDict
import time
import torch
from torch import nn
import numpy as np
torch.backends.cudnn.benchmark = True
device = "cuda"
class Identity(nn.Module):
    def forward(self, x): return x


class Mul(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def __call__(self, x):
        return x * self.weight


class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), x.size(1))


class Add(nn.Module):
    def forward(self, x, y): return x + y


class Concat(nn.Module):
    def forward(self, *xs): return torch.cat(xs, 1)


class Correct(nn.Module):
    def forward(self, classifier, target):
        return classifier.max(dim=1)[1] == target


# BatchNorm2d with optional constant init and optional freezing of the affine parameters
def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False, bn_weight_init=None, bn_weight_freeze=False):
    m = nn.BatchNorm2d(num_channels)
    if bn_bias_init is not None:
        m.bias.data.fill_(bn_bias_init)
    if bn_bias_freeze:
        m.bias.requires_grad = False
    if bn_weight_init is not None:
        m.weight.data.fill_(bn_weight_init)
    if bn_weight_freeze:
        m.weight.requires_grad = False
    return m
#####################
## dict utils
#####################
union = lambda *dicts: OrderedDict([(k, v) for d in dicts for (k, v) in d.items()])


def path_iter(nested_dict, pfx=()):
    for name, val in nested_dict.items():
        if isinstance(val, dict):
            yield from path_iter(val, (*pfx, name))
        else:
            yield ((*pfx, name), val)


#####################
## graph building
#####################
sep = '_'
RelativePath = namedtuple('RelativePath', ('parts'))
rel_path = lambda *parts: RelativePath(parts)


# Flatten a nested dict of modules into {flattened_name: (module, [input node names])};
# a node that does not name its inputs explicitly takes the output of the previous node.
def build_graph(net):
    net = OrderedDict(path_iter(net))
    default_inputs = [[('input',)]] + [[k] for k in net.keys()]
    with_default_inputs = lambda vals: (val if isinstance(val, tuple) else (val, default_inputs[idx])
                                        for idx, val in enumerate(vals))
    parts = lambda path, pfx: tuple(pfx) + path.parts if isinstance(path, RelativePath) else (path,) if isinstance(path, str) else path
    return OrderedDict([
        (sep.join((*pfx, name)), (val, [sep.join(parts(x, pfx)) for x in inputs]))
        for (*pfx, name), (val, inputs) in zip(net.keys(), with_default_inputs(net.values()))
    ])
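

# For illustration only (a sketch, not executed here): given the nested dict produced by
# net() below, build_graph flattens keys with `sep` and, unless a node explicitly names its
# inputs, wires it to the node defined just before it. An entry for the residual add looks
# roughly like:
#   'layer1_residual_add': (Add(), ['layer1_residual_in', 'layer1_residual_res2_relu'])
# i.e. the RelativePath inputs declared in residual() are resolved against the enclosing
# 'layer1_residual' prefix.
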
class TorchGraph(nn.Module):
    def __init__(self, net):
        self.graph = build_graph(net)
        # print("Graph:", self.graph)
        super().__init__()
        for n, (v, _) in self.graph.items():
            setattr(self, n, v)

    def forward(self, inputs):
        # print("Forward: ")
        self.cache = OrderedDict(inputs)
        for n, (_, i) in self.graph.items():
            # print(n, _, i)
            self.cache[n] = getattr(self, n)(*[self.cache[x] for x in i])
        return self.cache

    def half(self):
        # Convert all submodules except BatchNorm2d to fp16; BN is kept in fp32 for stability.
        for module in self.children():
            if type(module) is not nn.BatchNorm2d:
                module.half()
        return self
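

# Usage sketch (assumed from the forward() above, not executed here): the forward pass
# returns the cache of every node's output keyed by flattened name, so the final logits
# of this graph would be read as model(batch)['classifier_logits'].
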
def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
    return OrderedDict([
        ('conv', nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False)),
        ('bn', batch_norm(c_out, bn_weight_init=bn_weight_init, **kw)),
        ('relu', nn.ReLU(True)),
    ])


def residual(c, **kw):
    # 'add' sums the block input with the output of res2's ReLU; the RelativePath inputs
    # are resolved against the enclosing prefix by build_graph.
    return OrderedDict([
        ('in', Identity()),
        ('res1', conv_bn(c, c, **kw)),
        ('res2', conv_bn(c, c, **kw)),
        ('add', (Add(), [rel_path('in'), rel_path('res2', 'relu')])),
    ])


def basic_net(channels, weight, pool, **kw):
    return OrderedDict([
        ('prep', conv_bn(3, channels['prep'], **kw)),
        ('layer1', OrderedDict(conv_bn(channels['prep'], channels['layer1'], **kw), pool=pool)),
        ('layer2', OrderedDict(conv_bn(channels['layer1'], channels['layer2'], **kw), pool=pool)),
        ('layer3', OrderedDict(conv_bn(channels['layer2'], channels['layer3'], **kw), pool=pool)),
        ('classifier', OrderedDict([
            ('pool', nn.MaxPool2d(4)),
            ('flatten', Flatten()),
            ('linear', nn.Linear(channels['layer3'], 10, bias=False)),
            ('logits', Mul(weight)),
        ])),
    ])


def net(channels=None, weight=0.125, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer3'), **kw):
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
    n = basic_net(channels, weight, pool, **kw)
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer], **kw)
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)
    return n
torch.manual_seed(12)
model = TorchGraph(net()).to(device)

print("Model")
print("---")
print(model)
print("---")

batch_size = 512
batch = {
    'input': torch.ones((batch_size, 3, 32, 32)).cuda(),
}

model.train(True)
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
elapsed = time.time() - t0
print("Forward pass (f32) time: {}".format(elapsed / n_runs))

torch.manual_seed(12)
model = TorchGraph(net()).to(device).half()

batch_size = 512
batch = {
    'input': torch.ones((batch_size, 3, 32, 32)).cuda().half(),
}

model.train(True)
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
elapsed = time.time() - t0
print("Forward pass (f16) time: {}".format(elapsed / n_runs))
# Same model implementation with nn.Sequential
# Resnet(
# (prep): Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (layer1): Sequential(
# (0): Sequential(
# (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (2): IdentityResidualBlock(
# (conv1): Sequential(
# (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (conv2): Sequential(
# (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# )
# )
# (layer2): Sequential(
# (0): Sequential(
# (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
# (layer3): Sequential(
# (0): Sequential(
# (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (2): IdentityResidualBlock(
# (conv1): Sequential(
# (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# (conv2): Sequential(
# (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace)
# )
# )
# )
# (classifier): Sequential(
# (0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
# (1): Flatten()
# (2): Linear(in_features=512, out_features=10, bias=False)
# )
# )
import time
import torch
import torch.nn as nn
device = "cuda"
def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False, bn_weight_init=None, bn_weight_freeze=False):
    m = nn.BatchNorm2d(num_channels)
    if bn_bias_init is not None:
        m.bias.data.fill_(bn_bias_init)
    if bn_bias_freeze:
        m.bias.requires_grad = False
    if bn_weight_init is not None:
        m.weight.data.fill_(bn_weight_init)
    if bn_weight_freeze:
        m.weight.requires_grad = False
    return m


def conv_bn(in_channels, out_channels, bn_kwargs):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        batch_norm(out_channels, **bn_kwargs),
        nn.ReLU(inplace=True)
    )
class Resnet(nn.Module):
    def __init__(self, bn_kwargs=None, final_weight=0.125):
        super(Resnet, self).__init__()
        bn_kwargs = {} if bn_kwargs is None else bn_kwargs
        self.prep = conv_bn(3, 64, bn_kwargs)
        self.layer1 = nn.Sequential(
            conv_bn(64, 128, bn_kwargs),
            nn.MaxPool2d(kernel_size=2),
            IdentityResidualBlock(128, 128, bn_kwargs)
        )
        self.layer2 = nn.Sequential(
            conv_bn(128, 256, bn_kwargs),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer3 = nn.Sequential(
            conv_bn(256, 512, bn_kwargs),
            nn.MaxPool2d(kernel_size=2),
            IdentityResidualBlock(512, 512, bn_kwargs)
        )
        self.classifier = nn.Sequential(
            nn.MaxPool2d(kernel_size=4),
            Flatten(),
            nn.Linear(512, 10, bias=False)
        )
        # Same role as Mul(weight) in the graph-based version: scale the logits by a constant.
        self.final_weight = final_weight

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.classifier(x)
        x = x * self.final_weight
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), x.size(1))


class IdentityResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_kwargs):
        super(IdentityResidualBlock, self).__init__()
        self.conv1 = conv_bn(in_channels, out_channels, bn_kwargs)
        self.conv2 = conv_bn(out_channels, out_channels, bn_kwargs)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual
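

# Note: IdentityResidualBlock assumes stride 1 and in_channels == out_channels
# (128->128 and 512->512 above), so the identity branch needs no projection.
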
torch.manual_seed(12)
model = Resnet().to(device)

print("Model")
print("---")
print(model)
print("---")

batch_size = 512
batch = torch.ones((batch_size, 3, 32, 32)).cuda()

model.train(True)
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
elapsed = time.time() - t0
print("Forward pass (f32) time: {}".format(elapsed / n_runs))
def model_to_fp16(model):
    # Convert the model to fp16 but keep BatchNorm2d modules in fp32, matching
    # TorchGraph.half() in the graph-based implementation above. Iterating over
    # model.children() here would only see the nn.Sequential containers, so BatchNorm2d
    # submodules are reset to float explicitly.
    model.half()
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.float()
    return model


torch.manual_seed(12)
model = Resnet().to(device)
model = model_to_fp16(model)

batch_size = 512
batch = torch.ones((batch_size, 3, 32, 32)).cuda().half()

model.train(True)
t0 = time.time()
n_runs = 1000
for _ in range(n_runs):
    model(batch)
elapsed = time.time() - t0
print("Forward pass (f16) time: {}".format(elapsed / n_runs))