Last active
August 20, 2019 11:50
-
-
Save berak/43ad415d66b2cdcf5be717a62308ae7d to your computer and use it in GitHub Desktop.
torch dnn batchnorm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# colab - install latest | |
# !pip install opencv-python==4.1.0.25 | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import numpy as np | |
import cv2 | |
def convnet(in_channels, out_channels=64, kernel_size=3, *, padding=1,
            batchnorm=False):
    """Build a Conv2d -> [BatchNorm2d] -> ReLU -> MaxPool2d(2) block.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of filters produced by the convolution.
        kernel_size: square convolution kernel size.
        padding: zero padding added on each spatial side before the
            convolution. The default of 1 keeps the spatial size unchanged
            only for ``kernel_size=3``; pass ``kernel_size // 2`` for other
            odd kernel sizes.
        batchnorm: if True, insert ``BatchNorm2d(out_channels)`` between the
            convolution and the ReLU (the layer this experiment toggles).

    Returns:
        A ``torch.nn.Sequential`` holding the layers above, in that order.
    """
    layers = [
        torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels,
                        out_channels=out_channels, padding=padding),
    ]
    if batchnorm:
        # In train mode BatchNorm normalizes with per-batch statistics; call
        # .eval() on the model before comparing it with an exported
        # (inference) graph, which uses the running statistics instead.
        layers.append(torch.nn.BatchNorm2d(out_channels))
    layers += [
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=2),
    ]
    # With batchnorm=False the submodule indices ("0", "1", "2") match the
    # original three-layer Sequential, so saved state_dicts stay compatible.
    return torch.nn.Sequential(*layers)
# input
torch.manual_seed(7973847)
x = torch.randn((1, 3, 28, 28), dtype=torch.float32)

# torch reference output
tn = convnet(3)
# Switch to inference mode before computing the reference. torch.onnx.export
# exports in eval mode by default, so any BatchNorm layer in the exported
# graph normalizes with its running statistics. A forward pass in train mode
# would instead use per-batch statistics (and update the running buffers),
# so torch and OpenCV would disagree even if both were correct.
tn.eval()
with torch.no_grad():
    y1 = tn(x).numpy()

outname = "min_proto"
torch.onnx.export(tn, x, outname + ".onnx", verbose=False,
                  input_names=["input1"], output_names=["output1"])

# opencv: load the exported ONNX model and run the same input through it
on = cv2.dnn.readNet(outname + ".onnx")
on.setInput(x.numpy())
y2 = on.forward()
print(on.getLayerNames())
# Elements per sample, sum |torch|, sum |opencv|, and sum |torch - opencv|.
print("diff", y1.shape[1] * y1.shape[2] * y1.shape[3],
      np.sum(np.abs(y1)), np.sum(np.abs(y2)), ":",
      np.sum(np.abs(y1 - y2)))

# Optional visual check, Colab only (google.colab is not importable outside
# Colab, so it stays disabled). Uncomment to show channel 1 of the torch
# output, the OpenCV output, and their difference, resized to 128x128 and
# mapped around mid-gray (127 + value * 64).
#
# from google.colab.patches import cv2_imshow
#
# def show(arr):
#     r1 = cv2.resize(arr, (128, 128))
#     cv2_imshow(127 + r1 * 64)
#
# show(y1[0, 1])
# show(y2[0, 1])
# show(y1[0, 1] - y2[0, 1])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
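# recorded output — columns: elements per sample, sum|torch|, sum|opencv|, sum|torch - opencv|
diff 12544 7313.5576 7313.5576 : 0.00060441886 # without BatchNorm
diff 12544 6494.3984 3735.4143 : 2794.75 # with BatchNorm (torch reference computed in train mode, no .eval())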
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment