@xmodar
Created May 22, 2019 05:26
Change the weights of a Conv2d in PyTorch to absorb the input normalization (mean and std), so the layer accepts inputs in the range [0, 1] directly.
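For reference, here is a short derivation of why the folding works. The notation is introduced here for illustration: $W$ and $b$ are the conv weight and bias, $\mu_c$ and $\sigma_c$ are the per-channel mean and std, and $x_{c,i,j}$ is the raw input inside the receptive field of one output position. It assumes no padding and `groups=1`:

$$
\begin{aligned}
y_o &= b_o + \sum_{c,i,j} W_{o,c,i,j}\,\frac{x_{c,i,j} - \mu_c}{\sigma_c} \\
    &= \underbrace{\Big(b_o - \sum_{c,i,j} \frac{W_{o,c,i,j}}{\sigma_c}\,\mu_c\Big)}_{b'_o}
       + \sum_{c,i,j} \underbrace{\frac{W_{o,c,i,j}}{\sigma_c}}_{W'_{o,c,i,j}}\, x_{c,i,j}
\end{aligned}
$$

So $W'_{o,c,i,j} = W_{o,c,i,j} / \sigma_c$ and $b'_o = b_o - \sum_{c,i,j} W'_{o,c,i,j}\,\mu_c$, which is what `w` and `b` compute in `denormalize_conv2d`. With zero padding, the padded border is 0 in normalized space (a raw value of $\mu_c$) but 0 in raw space after folding, so the identity breaks at the borders.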
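```python
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import normalize


def denormalize_conv2d(weight, bias, mean, std):
    # Fold (x - mean) / std into the weights so the conv accepts raw [0, 1] input.
    # Exact only when the conv has no padding (and groups=1).
    weight = weight.detach()
    bias = weight.new_zeros(weight.shape[0]) if bias is None else bias.detach()
    std = torch.as_tensor(std, dtype=weight.dtype, device=weight.device).view(1, -1, 1, 1)
    mean = torch.as_tensor(mean, dtype=weight.dtype, device=weight.device).view(1, -1, 1, 1)
    w = weight / std  # scale each input channel's kernel by 1 / std
    b = bias - (w * mean).flatten(1).sum(1)  # absorb the -mean / std offset into the bias
    return w, b


mean = [0.4915, 0.4823, 0.4468]
std = [0.2470, 0.2435, 0.2616]
c = nn.Conv2d(3, 16, 4)  # no padding, so the folding is exact
with torch.no_grad():
    c.bias += torch.randn(c.bias.shape)  # random bias so the check is non-trivial
x = torch.rand(1, 3, 10, 10)  # raw input in [0, 1]
# Reference: normalize the input, then apply the original conv
o = F.conv2d(normalize(x[0], mean, std).unsqueeze(0), c.weight, c.bias)
# Folded: apply the denormalized weights directly to the raw input
z = F.conv2d(x, *denormalize_conv2d(c.weight, c.bias, mean, std))
print((o - z).abs().max())  # expected to be tiny: only float32 round-off
```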
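If you want a drop-in layer rather than raw tensors, here is a minimal sketch that assumes a recent PyTorch (1.9 or later, for the `device`/`dtype` constructor arguments). `fold_normalization` is a hypothetical helper name and is not part of PyTorch. The sketch builds a new `nn.Conv2d` from the folded weight and bias. It refuses the cases where folding is not exact, namely padding and grouped convolution:

```python
import torch
from torch import nn


def fold_normalization(conv: nn.Conv2d, mean, std) -> nn.Conv2d:
    """Hypothetical helper: return a new Conv2d that takes raw input x and
    matches conv((x - mean) / std). Only exact for padding=0 and groups=1."""
    if conv.groups != 1 or conv.padding not in ((0, 0), "valid"):
        raise ValueError("folding is only exact for groups=1 and no padding")
    with torch.no_grad():
        w = conv.weight
        m = torch.as_tensor(mean, dtype=w.dtype, device=w.device).view(1, -1, 1, 1)
        s = torch.as_tensor(std, dtype=w.dtype, device=w.device).view(1, -1, 1, 1)
        new_w = w / s  # per-input-channel rescale
        b = conv.bias if conv.bias is not None else w.new_zeros(w.shape[0])
        new_b = b - (new_w * m).flatten(1).sum(1)  # absorb the mean shift
    folded = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=0, dilation=conv.dilation,
        bias=True, device=w.device, dtype=w.dtype,
    )
    with torch.no_grad():
        folded.weight.copy_(new_w)
        folded.bias.copy_(new_b)
    return folded


if __name__ == "__main__":
    torch.manual_seed(0)
    mean = [0.4915, 0.4823, 0.4468]
    std = [0.2470, 0.2435, 0.2616]
    conv = nn.Conv2d(3, 16, 4)
    x = torch.rand(2, 3, 10, 10)  # raw input in [0, 1]
    m = torch.tensor(mean).view(1, -1, 1, 1)
    s = torch.tensor(std).view(1, -1, 1, 1)
    reference = conv((x - m) / s)  # normalize, then the original conv
    folded = fold_normalization(conv, mean, std)
    print(torch.allclose(reference, folded(x), atol=1e-5))
```

Returning a new module leaves the original conv untouched, so both versions can be compared side by side. The folded layer can then stand in for the first conv of a network that was previously fed through a `Normalize` transform.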