-
-
Save LukasMosser/19c7b1d8e845182e2dfd5a0257d5fde3 to your computer and use it in GitHub Desktop.
Noise layer for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Fails with TypeError: forward() takes exactly 2 arguments (1 given) | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
from noise import WhiteNoise | |
class DCGAN3D_D(nn.Module):
    """3D DCGAN discriminator with a white-noise layer in front.

    Layer stack built in ``self.main``:
      * ``WhiteNoise -> Conv3d(stride 2) -> LeakyReLU``
      * ``n_extra_layers`` stride-1 ``Conv3d/BatchNorm3d/LeakyReLU`` stages
      * stride-2 ``Conv3d/BatchNorm3d/LeakyReLU`` stages that double the
        feature maps while the tracked spatial size is greater than 4
      * a final ``Conv3d(kernel 4)`` to one channel and a ``Sigmoid``

    Args:
        isize: Spatial edge length of the input volume; must be a multiple
            of 16.
        nz: Not used by this module (presumably kept to mirror the
            generator's signature).
        nc: Number of input channels.
        ndf: Number of feature maps produced by the first convolution.
        ngpu: Number of GPUs ``forward`` splits the batch across when the
            input is on CUDA.
        n_extra_layers: Number of extra stride-1 stages.
    """

    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
        super(DCGAN3D_D, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # Insert a WhiteNoise *instance*. Wrapping the class itself in
        # DataParallel and then calling the wrapper with no input raised
        # "TypeError: forward() takes exactly 2 arguments (1 given)".
        # No per-layer DataParallel is needed: forward() already splits the
        # whole of self.main across GPUs.
        main = nn.Sequential(
            WhiteNoise(),
            nn.Conv3d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Module names continue after the three layers above.
        # Integer division tracks the real Conv3d(k=4, s=2, p=1) output size,
        # floor(size / 2); true division could add a spurious extra stage.
        i, csize, cndf = 3, isize // 2, ndf

        # Extra layers
        for _ in range(n_extra_layers):
            main.add_module(str(i), nn.Conv3d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module(str(i + 1), nn.BatchNorm3d(cndf))
            main.add_module(str(i + 2), nn.LeakyReLU(0.2, inplace=True))
            i += 3

        # Pyramid: each stage halves the spatial size and doubles the maps.
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module(str(i),
                            nn.Conv3d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module(str(i + 1), nn.BatchNorm3d(out_feat))
            main.add_module(str(i + 2), nn.LeakyReLU(0.2, inplace=True))
            i += 3
            cndf = out_feat
            csize = csize // 2

        # Expected state size here: K x 4 x 4 x 4 (e.g. for power-of-two isize).
        main.add_module(str(i), nn.Conv3d(cndf, 1, 4, 1, 0, bias=False))
        main.add_module(str(i + 1), nn.Sigmoid())
        self.main = main

    def forward(self, input):
        """Score a batch of volumes.

        Args:
            input: Batch fed to ``self.main`` (presumably shaped
                ``(N, nc, isize, isize, isize)`` — confirm against callers).

        Returns:
            The sigmoid outputs of ``self.main`` reshaped to ``(-1, 1)``.
        """
        if input.is_cuda and self.ngpu > 1:
            # Split the batch across the first ``ngpu`` devices.
            output = nn.parallel.data_parallel(self.main, input,
                                               range(self.ngpu))
        else:
            # CPU or single-GPU: call the network directly instead of going
            # through the DataParallel machinery.
            output = self.main(input)
        return output.view(-1, 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
class WhiteNoise(nn.Module):
    """Add Gaussian white noise to the input while training.

    In training mode the output is ``input + noise`` where ``noise`` is
    drawn fresh on every call from a normal distribution with the given
    ``mean`` and ``std``. In evaluation mode no random noise is drawn; the
    input is only shifted by ``mean``, and returned untouched when
    ``mean == 0``. Because the noise is an additive term independent of the
    input, gradients pass through unchanged.

    The module has no learnable parameters. The original registered the
    noise buffer as an ``nn.Parameter``, which exposed it to optimizers.

    Args:
        mean: Mean of the Gaussian noise.
        std: Standard deviation of the Gaussian noise.
    """

    def __init__(self, mean=0.0, std=0.1):
        super(WhiteNoise, self).__init__()
        self.mean = mean
        self.std = std

    def forward(self, input):
        """Return ``input`` plus noise (training) or plus ``mean`` (eval)."""
        # ``self.training`` is the flag toggled by .train()/.eval();
        # ``self.train`` is a method and can never equal True.
        if self.training:
            # randn_like matches the input's shape, dtype and device.
            noise = torch.randn_like(input) * self.std + self.mean
            return input + noise
        if self.mean != 0:
            return input + self.mean
        return input

    def updateOutput(self, input):
        """Legacy Lua-torch style alias for :meth:`forward`."""
        return self.forward(input)

    def updateGradInput(self, input, gradOutput):
        """Legacy Lua-torch style gradient: the layer is an input shift,
        so the gradient w.r.t. the input equals ``gradOutput``."""
        return gradOutput

    def extra_repr(self):
        # Used by nn.Module.__repr__ -> "WhiteNoise(mean=..., std=...)".
        return "mean={}, std={}".format(self.mean, self.std)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment