Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Last active June 30, 2019 14:40
Show Gist options
  • Save chenyaofo/3b319e233b45f4167ef999d8fbcbdf70 to your computer and use it in GitHub Desktop.
Save chenyaofo/3b319e233b45f4167ef999d8fbcbdf70 to your computer and use it in GitHub Desktop.
An API comparison between PyTorch and MXNet
import numpy
import torch
import torch.nn
import mxnet
def test_forward_template(module_name, data, pt_module, mx_module, is_train=True):
    """Run one forward pass through a PyTorch module and an MXNet symbol and print their difference.

    The PyTorch module's parameters are handed to the MXNet symbol at bind
    time, so both sides compute with the same values.

    Args:
        module_name: Label used in the log line. The value ``"bn"`` also
            selects the binding that supplies beta and the moving statistics.
        data: NumPy array fed to both frameworks as the input.
        pt_module: PyTorch module providing ``weight`` (and, for ``"bn"``,
            ``bias``, ``running_mean`` and ``running_var``).
        mx_module: MXNet symbol whose input variable is named ``"data"`` and
            whose parameters use the ``"module_"`` prefix.
        is_train: Whether both forward passes run in training mode.

    Returns:
        None. The maximum absolute and relative differences are printed.
    """
    mode = "train" if is_train else "validating"
    print(f"Start to test forward process {module_name}, mode = {mode}")

    # Set the mode in both directions. Only ever calling eval() would leave the
    # module stuck in eval mode for a later is_train=True call on the same
    # module, while MXNet would still be run with is_train=True.
    pt_module.train(bool(is_train))
    pt_output = pt_module(torch.from_numpy(data)).detach().numpy()

    def as_mx(tensor):
        """Convert a PyTorch tensor into an MXNet NDArray."""
        return mxnet.nd.array(tensor.data.numpy())

    # Parameters are read after the PyTorch forward pass, matching the original
    # ordering of the PyTorch call and the MXNet bind.
    if module_name == "bn":
        args = {
            "data": mxnet.nd.array(data),
            "module_gamma": as_mx(pt_module.weight),
            "module_beta": as_mx(pt_module.bias),
        }
        aux_states = {
            "module_moving_mean": as_mx(pt_module.running_mean),
            "module_moving_var": as_mx(pt_module.running_var),
        }
        executor = mx_module.bind(mxnet.cpu(), args, aux_states=aux_states)
    else:
        # NOTE(review): the same PyTorch weight is offered under both names,
        # presumably so one branch serves symbols whose parameter is called
        # either "weight" or "gamma"; this relies on bind ignoring the unused
        # key - confirm against the MXNet version in use.
        args = {
            "data": mxnet.nd.array(data),
            "module_weight": as_mx(pt_module.weight),
            "module_gamma": as_mx(pt_module.weight),
        }
        executor = mx_module.bind(mxnet.cpu(), args)

    mx_output = executor.forward(is_train=is_train)[0].asnumpy()

    diff = pt_output - mx_output
    abs_max = numpy.abs(diff).max()
    # The relative error divides by the PyTorch output, so zeros in that
    # output show up as inf/nan in the reported value.
    rel_max = numpy.abs(diff / pt_output).max()
    print(f"The absolute max diff is {abs_max}, the relative max diff is {rel_max}")
# def test_backward_template(module_name, pt_output, mx_output, is_train):
# if is_train:
# pass
# else:
# print(f"Start to test backward process {module_name}, mode = {'train' if is_train else 'validating'}")
# pt_output.sum().backward()
def test_conv2d():
    """Compare a bias-free 3x3 convolution (3 -> 64 channels, padding 1) across both frameworks."""
    sample = numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)

    torch_conv = torch.nn.Conv2d(
        in_channels=3, out_channels=64, kernel_size=3, padding=1, bias=False
    )

    # The symbol is named "module" so its weight is addressed as "module_weight".
    mxnet_conv = mxnet.sym.Convolution(
        data=mxnet.symbol.Variable(name="data"),
        num_filter=64,
        kernel=(3, 3),
        num_group=1,
        stride=(1, 1),
        pad=(1, 1),
        no_bias=True,
        name="module",
    )

    test_forward_template("conv2d", sample, torch_conv, mxnet_conv)
def test_linear():
    """Compare a bias-free fully connected layer (1024 -> 512) across both frameworks."""
    sample = numpy.random.rand(1, 1024).astype(numpy.float32)

    torch_fc = torch.nn.Linear(in_features=1024, out_features=512, bias=False)

    # The symbol is named "module" so its weight is addressed as "module_weight".
    mxnet_fc = mxnet.sym.FullyConnected(
        data=mxnet.symbol.Variable(name="data"),
        num_hidden=512,
        no_bias=True,
        name="module",
    )

    test_forward_template("linear", sample, torch_fc, mxnet_fc)
def test_batch_normalization_2d():
    """Compare 2-D batch normalization over 64 channels, first in train mode and then in eval mode."""
    sample = numpy.random.rand(10, 64, 56, 56).astype(numpy.float32)

    torch_bn = torch.nn.BatchNorm2d(
        num_features=64,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    )

    # NOTE(review): PyTorch is given momentum=0.1 and MXNet momentum=0.9;
    # presumably these are the two frameworks' complementary conventions for
    # the same running-average update - confirm against both API docs.
    mxnet_bn = mxnet.sym.BatchNorm(
        data=mxnet.symbol.Variable(name="data"),
        axis=1,
        eps=1e-5,
        momentum=0.9,
        fix_gamma=False,
        name="module",
    )

    # Same module and symbol for both passes: training first, then inference.
    for training in (True, False):
        test_forward_template("bn", sample, torch_bn, mxnet_bn, is_train=training)
def test_leakyrelu():
    """Compare PyTorch's PReLU with MXNet's LeakyReLU operator configured as act_type="prelu"."""
    sample = numpy.random.rand(1, 1024).astype(numpy.float32)

    torch_prelu = torch.nn.PReLU()

    # The symbol is named "module" so its slope parameter is addressed as "module_gamma".
    mxnet_prelu = mxnet.sym.LeakyReLU(
        data=mxnet.symbol.Variable(name="data"),
        act_type="prelu",
        name="module",
    )

    test_forward_template("prelu", sample, torch_prelu, mxnet_prelu)
if __name__ == '__main__':
    # Run every comparison in a fixed order when executed as a script.
    for run_case in (
        test_conv2d,
        test_linear,
        test_batch_normalization_2d,
        test_leakyrelu,
    ):
        run_case()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment