Created
May 22, 2019 05:26
-
-
Save xmodar/1b5c33da7369dd8165b43969da8e4aad to your computer and use it in GitHub Desktop.
Fold an input normalization (per-channel mean and std) into the weight and bias of a PyTorch `Conv2d`, so that the layer accepts raw inputs in [0, 1] instead of normalized inputs.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import normalize
def denormalize_conv2d(weight, bias, mean, std):
    """Fold a per-channel input normalization into a conv2d's weight and bias.

    Given a convolution meant to be applied to ``(x - mean) / std``, return
    ``(w, b)`` such that ``conv2d(x, w, b)`` equals
    ``conv2d((x - mean) / std, weight, bias)`` on the raw input ``x``.

    Args:
        weight: 4-D conv weight; dim 1 is matched against the channel
            statistics (assumes ``groups=1`` -- TODO confirm for grouped convs).
        bias: conv bias, or ``None`` for a convolution without bias
            (treated as all zeros).
        mean: per-input-channel mean (sequence, scalar or tensor).
        std: per-input-channel standard deviation (sequence, scalar or tensor).

    Returns:
        ``(w, b)``: tensors detached from autograd, on ``weight``'s device
        and dtype.

    Note:
        The identity is exact only when the convolution uses no zero
        padding. With padding, the zeros padded in normalized space
        correspond to ``x == mean`` in raw space, not to ``x == 0``, so the
        border outputs differ.
    """
    weight = weight.detach()
    # Build the statistics on weight's device/dtype so CUDA or float64
    # weights work without a device mismatch or a silent float32 downcast.
    std = torch.as_tensor(std, dtype=weight.dtype, device=weight.device)
    std = std.detach().view(1, -1, 1, 1)
    mean = torch.as_tensor(mean, dtype=weight.dtype, device=weight.device)
    mean = mean.detach().view(1, -1, 1, 1)
    w = weight / std
    # Subtracting the mean adds a constant offset to every output of channel
    # o: -sum_{c,i,j} w[o, c, i, j] * mean[c]. Fold that offset into the bias.
    shift = (w * mean).flatten(1).sum(1)
    if bias is None:
        return w, -shift
    return w, bias.detach() - shift
# Sanity check: a conv applied to torchvision-normalized inputs must match
# the folded conv (from denormalize_conv2d) applied to raw inputs in [0, 1].
mean = [0.4915, 0.4823, 0.4468]
std = [0.2470, 0.2435, 0.2616]
c = nn.Conv2d(3, 16, 4)  # default padding=0: the fold is exact only without padding
# Presumably perturbed so the bias contribution is clearly non-trivial.
c.bias.data += torch.randn(c.bias.data.shape)
x = torch.rand(1, 3, 10, 10)
# normalize() is applied to the single CHW image x[0], then the batch dim is restored.
o = F.conv2d(normalize(x[0], mean, std).unsqueeze(0), c.weight.data, c.bias.data)
z = F.conv2d(x, *denormalize_conv2d(c.weight, c.bias, mean, std))
# A bare expression is a no-op outside a REPL/notebook; report the error explicitly.
print('max abs difference:', (o - z).abs().max().item())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment