@SunDoge
Created February 19, 2019 02:24
Linear Benchmark between MXNet and PyTorch
import torch
from torch import nn as ptnn
import mxnet as mx
from mxnet.gluon import nn as mxnn
from mxnet import nd, initializer
from enum import IntEnum
from time import time
import numpy as np
use_cuda = torch.cuda.is_available()
fmt = ' {:<14} {:<15} {:<12} {:>5}'
mx_ctx = mx.gpu()
class Framework(IntEnum):
    PYTORCH = 1
    MXNET = 2
def get_mxnet_network():
    net = mxnn.HybridSequential()
    # net = mxnn.Sequential()
    with net.name_scope():
        net.add(mxnn.Dense(256, activation="relu"))
        net.add(mxnn.Dense(128, activation="relu"))
        net.add(mxnn.Dense(2))
    net.initialize(init=initializer.Zero(), ctx=mx_ctx)
    return net
def pytorch_weights_init(m):
    # Zero all weights and biases so both networks start from identical (zero) parameters,
    # mirroring initializer.Zero() on the MXNet side
    if isinstance(m, ptnn.Linear):
        ptnn.init.uniform_(m.weight.data, 0, 0)
        ptnn.init.uniform_(m.bias.data, 0, 0)
def get_pytorch_network():
    net = ptnn.Sequential()
    net.add_module('dense1', ptnn.Linear(512, 256))
    net.add_module('relu1', ptnn.ReLU())
    net.add_module('dense2', ptnn.Linear(256, 128))
    net.add_module('relu2', ptnn.ReLU())
    net.add_module('dense3', ptnn.Linear(128, 2))
    net.apply(pytorch_weights_init)
    return net.cuda()
# Wait for computation to finish to make profiling more accurate
def block(framework):
    if framework == Framework.PYTORCH:
        if use_cuda:
            torch.cuda.synchronize()
    elif framework == Framework.MXNET:
        mx.nd.waitall()
def bench(net, dtype, framework):
    np.random.seed(12)

    # Warmup for mxnet shape infer
    if framework == Framework.MXNET:
        x = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).as_in_context(mx_ctx)
        with mx.autograd.record():
            y = net(x)
    elif framework == Framework.PYTORCH:
        x = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).cuda()
        y = net(x)

    block(framework)

    # Timed loop: 1000 iterations, each including input creation, the host-to-device
    # copy and one forward pass
    start = time()
    if framework == Framework.MXNET:
        for i in range(1000):
            x = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).as_in_context(mx_ctx)
            with mx.autograd.record():
                y = net(x)
    elif framework == Framework.PYTORCH:
        for i in range(1000):
            x = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).cuda()
            y = net(x)
    block(framework)
    return time() - start
def report(framework, paradigm, precision, value=None):
    t = '%i' % (value * 1000) if value else '---'
    print(fmt.format(framework, paradigm, '%i bit' % precision, t))
# Input matrices
inputs = {
    'mx_x_32': lambda x: nd.array(x),
    'mx_x_16': lambda x: nd.array(x).astype('float16'),
    'pt_x_32': lambda x: torch.from_numpy(x),
    'pt_x_16': lambda x: torch.from_numpy(x).half(),
}
# mx_x_32 = nd.ones((128, 512, 1), mx_ctx)
# mx_x_16 = mx_x_32.astype('float16')
# pt_x_32 = torch.ones((128, 512, 1)).cuda()
# pt_x_16 = pt_x_32.half()
print()
print(' Device:', 'GPU' if use_cuda else 'CPU')
print('----------------------------------------------------')
print(fmt.format('Framework', 'Paradigm', 'Precision', 'Time'))
print('====================================================')
mx_net = get_mxnet_network()
report('MXNet', 'imperative', 32, bench(mx_net, 'mx_x_32', Framework.MXNET))
mx_net.cast('float16')
report('MXNet', 'imperative', 16, bench(mx_net, 'mx_x_16', Framework.MXNET))
mx_net.cast('float32')
mx_net.hybridize(static_alloc=True)
report('MXNet', 'symbolic', 32, bench(mx_net, 'mx_x_32', Framework.MXNET))
mx_net.cast('float16')
report('MXNet', 'symbolic', 16, bench(mx_net, 'mx_x_16', Framework.MXNET))
pt_net = get_pytorch_network()
report('PyTorch', 'imperative', 32, bench(pt_net, 'pt_x_32', Framework.PYTORCH))
# PyTorch half precision isn't supported on a CPU
pt_16 = bench(pt_net.half(), 'pt_x_16', Framework.PYTORCH) if use_cuda else None
report('PyTorch', 'imperative', 16, pt_16)
print('----------------------------------------------------')

SunDoge commented Feb 19, 2019

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.54                 Driver Version: 396.54                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN Xp            Off  | 00000000:05:00.0 Off |                  N/A |
| 23%   30C    P0    60W / 250W |      0MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
pytorch                   1.0.1           py3.6_cuda9.0.176_cudnn7.4.2_2
mxnet-cu90                1.5.0b20190218            <pip>
 Device: GPU
----------------------------------------------------
 Framework      Paradigm        Precision     Time
====================================================
 MXNet          imperative      32 bit        1522
 MXNet          imperative      16 bit        1891
 MXNet          symbolic        32 bit        1058
 MXNet          symbolic        16 bit        1954
 PyTorch        imperative      32 bit         995
 PyTorch        imperative      16 bit        2921
----------------------------------------------------

szha commented Feb 19, 2019

The comparison does not take into account that MXNet is making a copy while pytorch is not.

MXNet made the choice of not doing a zero-copy for numpy arrays, but instead making a copy of the numpy data. This means that users are free to change the numpy array after passing it into MXNet. PyTorch, on the other hand, chose not to make a copy: it keeps the array alive by incrementing its reference count and reuses its data pointer.
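As a quick illustration of that difference (a minimal sketch of my own, not part of the benchmark):

import numpy as np
import torch
import mxnet as mx

a = np.zeros(3, dtype=np.float32)

t = torch.from_numpy(a)   # zero-copy: the tensor shares the numpy buffer
m = mx.nd.array(a)        # MXNet copies the numpy data into its own NDArray

a[0] = 1.0
print(t[0].item())        # 1.0 -- the torch tensor sees the in-place change
print(m[0].asscalar())    # 0.0 -- the MXNet array does not

h = t.half()              # .half() allocates a new fp16 tensor, i.e. another copy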

This also explains why pytorch fp16 is so much worse than fp32 in your results (.half() has to make a copy). If you exclude the cost of all copying, you'd get something like this (code [1]):

 Device: GPU
----------------------------------------------------
 Framework      Paradigm        Precision     Time
====================================================
 MXNet          imperative      32 bit         532
 MXNet          imperative      16 bit         447
 MXNet          symbolic        32 bit         108
 MXNet          symbolic        16 bit         127
 PyTorch        imperative      32 bit         229
 PyTorch        imperative      16 bit         249
----------------------------------------------------

If you keep the copy from CPU to GPU, you get something like this (code [2]):

 Device: GPU
----------------------------------------------------
 Framework      Paradigm        Precision     Time
====================================================
 MXNet          imperative      32 bit         599
 MXNet          imperative      16 bit         680
 MXNet          symbolic        32 bit         217
 MXNet          symbolic        16 bit         229
 PyTorch        imperative      32 bit         316
 PyTorch        imperative      16 bit         313
----------------------------------------------------

[1]

def bench(net, dtype, framework):
    np.random.seed(12)

    # Warmup for mxnet shape infer
    if framework == Framework.MXNET:
        xm = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).as_in_context(mx_ctx)
        with mx.autograd.record():
            y = net(xm)
    elif framework == Framework.PYTORCH:
        xp = inputs[dtype](np.random.rand(128, 512).astype(np.float32)).cuda()
        y = net(xp)

    block(framework)

    start = time()
    if framework == Framework.MXNET:
        for i in range(1000):
            with mx.autograd.record():
                y = net(xm)
    elif framework == Framework.PYTORCH:
        for i in range(1000):
            y = net(xp)
    block(framework)
    return time() - start

[2]

def bench(net, dtype, framework):
    np.random.seed(12)

    # Warmup for mxnet shape infer
    if framework == Framework.MXNET:
        xm = inputs[dtype](np.random.rand(128, 512).astype(np.float32))
        with mx.autograd.record():
            y = net(xm.as_in_context(mx_ctx))
    elif framework == Framework.PYTORCH:
        xp = inputs[dtype](np.random.rand(128, 512).astype(np.float32))
        y = net(xp.cuda())

    block(framework)

    start = time()
    if framework == Framework.MXNET:
        for i in range(1000):
            with mx.autograd.record():
                y = net(xm.as_in_context(mx_ctx))
    elif framework == Framework.PYTORCH:
        for i in range(1000):
            y = net(xp.cuda())
    block(framework)
    return time() - start


SunDoge commented Feb 19, 2019

Hi, @szha!
Thanks for your reply. That makes sense.
However, when training a model we load immutable data from files all the time, so in real life the copy does become a problem.
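To put a rough number on that, here is a copy-only sketch of mine (same shapes as the benchmark) that times just 1000 fresh-input host-to-device transfers, with no forward pass at all:

import numpy as np
import torch
from time import time

assert torch.cuda.is_available()
torch.cuda.synchronize()
start = time()
for _ in range(1000):
    # fresh numpy batch each iteration, as if it had just been loaded from disk
    x = torch.from_numpy(np.random.rand(128, 512).astype(np.float32)).cuda()
torch.cuda.synchronize()
print('copy-only: %i ms' % ((time() - start) * 1000))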
