# PyTorch implementation of the StyleGAN Generator

*by Piotr Bialecki and Thomas Viehmann*

We implement the generator of *T. Karras et al., A Style-Based Generator Architecture for Generative Adversarial Networks* in PyTorch.

StyleGAN's photorealistic faces are an intriguing GAN output. While diving deep into the architecture, we found that a good way to do so is to provide a simple notebook which recreates the StyleGAN generator for use with the pretrained weights.

Also, one can always pick up a few tricks when doing something like this, so let's take a look.

[Karras et al. provide a reference implementation and links to weights, paper, and video](https://github.com/NVlabs/stylegan).

```python
import torch
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
import pickle

import numpy as np

import IPython
```

### Anything new in the linear layer?

Did you just yawn when you saw "linear layer"? Not so fast!

There is a trick that is maybe not as well known as it should be (we would be grateful if anyone pointed out a reference): it is very common to use targeted initialization such as the method of K. He (`torch.nn.init.kaiming_normal_` and `torch.nn.init.kaiming_uniform_`). But if we don't scale the parameter itself and instead multiply by the scaling factor separately in the forward pass, the same factor also scales the gradients and thus, for methods like stochastic gradient descent (SGD), the updates. (For optimizers that remove scaling, such as Adam, one would expect the effect to cancel with the corresponding change in Adam's scaling.)

Vaguely connected, [H. Zhang et al., *Fixup Initialization: Residual Learning Without Normalization*](https://openreview.net/forum?id=H1gsz30cKX) suggest introducing (trainable) scalar multipliers.

So here is the linear layer.
```python
class MyLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""
    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)
```
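As a quick sanity check (our addition, not part of the original code): with `use_wscale=True` the stored weight has a standard deviation of roughly one, and the He factor only enters at runtime through `w_mul`, so the forward pass matches a plain `F.linear` call with the pre-scaled weight.

```python
torch.manual_seed(0)
lin = MyLinear(512, 256, use_wscale=True)
x = torch.randn(4, 512)
# the He scaling lives in w_mul, not in the stored parameter
print(lin.weight.std().item())  # ~1.0
print(torch.allclose(lin(x), F.linear(x, lin.weight * lin.w_mul, lin.bias * lin.b_mul)))  # True
```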
### Convolution Layer

The convolution layer uses the same trick as the linear layer.

As the architecture of StyleGAN prescribes that during upscaling, blurring occurs between the convolution and the bias addition (we didn't look into detail why the two don't seem to commute), we need to provide a mechanism for an intermediate step between them.
For larger resolutions, the authors also propose a fused convolution / upscaling, which is *not* equivalent to the two separate operations. (The "averaging" of the weight isn't actually an average but an addition, which would seem to effectively multiply the weight by four; but we don't know how compatible the two paths would be even without this effect.)

If we need neither of the two, we use the regular convolution with bias.

```python
class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""
    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                 intermediate=None, upscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN; sadly it seems
            # incompatible with the non-fused way.
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # probably applying a conv on w would be more efficient.
            # also this quadruples the weight (average)?!
            w = F.pad(w, (1, 1, 1, 1))
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        if not have_convolution and self.intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x
```
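A shape check we added to see both code paths (run it after the `Upscale2d` helper further below has been defined, since `upscale=True` instantiates it): below an output resolution of 128 pixels the layer takes the separate upscale-then-convolve path, at 128 and above it switches to the fused transposed convolution.

```python
conv = MyConv2d(8, 4, kernel_size=3, use_wscale=True, upscale=True)
print(conv(torch.randn(1, 8, 32, 32)).shape)  # [1, 4, 64, 64]  - separate upscale + conv
print(conv(torch.randn(1, 8, 64, 64)).shape)  # [1, 4, 128, 128] - fused conv_transpose2d
```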
### Noise Layer

The noise layer adds Gaussian noise of learnable standard deviation (and zero mean). The noise itself is per pixel (and per image in the minibatch) but constant over the channels; the learnable standard deviation, on the other hand, is per channel.
As you can see when you feed the same latent into the model several times, the effects are not all that large, but they show in the details.

There is a little trick in the code: if you set `.noise` on the noise layer, you can fix the noise. This is one of the tricks to use when checking against a reference implementation — this way you can get a one-to-one correspondence of outputs.

```python
class NoiseLayer(nn.Module):
    """adds noise. noise is per pixel (constant over channels) with per-channel weight"""
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None and self.noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        elif noise is None:
            # here is a little trick: if you get all the noise layers and set each
            # module's .noise attribute, you can have pre-defined noise.
            # Very useful for analysis
            noise = self.noise
        x = x + self.weight.view(1, -1, 1, 1) * noise
        return x
```
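Here is a minimal sketch of that trick (the helper name is ours): setting a broadcastable zero tensor on every `NoiseLayer` makes repeated forward passes deterministic, which is exactly what you want when comparing against the reference implementation.

```python
def zero_noise(model):
    """Disable the random noise inputs; a (1, 1, 1, 1) tensor broadcasts over H and W."""
    for m in model.modules():
        if isinstance(m, NoiseLayer):
            m.noise = torch.zeros(1, 1, 1, 1)

# e.g. zero_noise(g_all) once the full model below has been defined
```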
### Style Modification layer

In the generator, a style modification layer is used after each (non-affine) instance norm layer. Recall that the instance norm normalizes the mean and standard deviation across pixels separately for each channel (and sample).
So here we put back a mean and a standard deviation — not as arbitrary learnable parameters, but as the output of a linear layer which takes the latent style vector as input. In a way, this is the affine part of the instance norm, but with computed parameters.
The article calls the affine instance norm *Adaptive Instance Norm (AdaIN)*.

```python
class StyleMod(nn.Module):
    def __init__(self, latent_size, channels, use_wscale):
        super(StyleMod, self).__init__()
        self.lin = MyLinear(latent_size,
                            channels * 2,
                            gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        style = self.lin(latent)  # style => [batch_size, n_channels*2]
        shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
        style = style.view(shape)  # [batch_size, 2, n_channels, 1, 1]
        x = x * (style[:, 0] + 1.) + style[:, 1]
        return x
```

### Pixelnorm

Pixelnorm normalizes per pixel across all channels.

Note that the default configuration only uses the pixel norm in `g_mapping`. There it effectively forces the empirical standard deviation of the latent vector to be one.

```python
class PixelNormLayer(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
```
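A small numerical check we added: after the pixel norm, the root mean square across channels is one at every pixel.

```python
x = torch.randn(2, 512, 4, 4)
y = PixelNormLayer()(x)
print(y.pow(2).mean(dim=1).sqrt())  # ~1.0 everywhere
```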
# Upscale and blur layers

StyleGAN has two types of upscaling. The plain one just sets each block of 2x2 pixels to the value of the source pixel, scaling the image by a factor of two — no fancy stuff like bilinear or bicubic interpolation. The alternative — "fused" with the convolution — uses a stride-2 transposed convolution instead. Note from above that the two do not seem to be quite equivalent (at least not with the same parametrisation).
Both of these will have blocky results. To mitigate this, the generator blurs the layer by convolving with the simplest possible smoothing kernel.

```python
class BlurLayer(nn.Module):
    def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # tensors don't support negative-step slicing, so use torch.flip
            kernel = torch.flip(kernel, [2, 3])
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # expand kernel channels
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=x.size(1)
        )
        return x


def upscale2d(x, factor=2, gain=1):
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor != 1:
        shape = x.shape
        x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
        x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
    return x


class Upscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return upscale2d(x, factor=self.factor, gain=self.gain)
```
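As a check we added: this view/expand upscaling is exactly nearest-neighbor interpolation with an integer factor.

```python
x = torch.randn(1, 3, 4, 4)
print(torch.equal(upscale2d(x), F.interpolate(x, scale_factor=2, mode='nearest')))  # True
```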
### Generator Mapping Module

With all these building blocks done, we can actually define the StyleGAN generator.
The first component is the mapping. It's reasonably deep (8 layers), but otherwise a very plain vanilla fully connected network, a.k.a. a multi-layer perceptron. The StyleGAN reference model uses leaky ReLUs, so we do, too.
Note that while we get an 18-row (times 512 features) per-image style matrix, all 18 rows are the same.

We also provide a truncation module pulling the upper layers' latent inputs towards the mean, but we don't activate it, as the mean is not provided with the pretrained weights. We could run the mapping for a while and derive the truncation weights; see the sketch after the code.

```python
class G_mapping(nn.Sequential):
    def __init__(self, nonlinearity='lrelu', use_wscale=True):
        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [
            ('pixel_norm', PixelNormLayer()),
            ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense0_act', act),
            ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense1_act', act),
            ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense2_act', act),
            ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense3_act', act),
            ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense4_act', act),
            ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense5_act', act),
            ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense6_act', act),
            ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense7_act', act)
        ]
        super().__init__(OrderedDict(layers))

    def forward(self, x):
        x = super().forward(x)
        # Broadcast: one (identical) latent row per synthesis layer
        x = x.unsqueeze(1).expand(-1, 18, -1)
        return x


class Truncation(nn.Module):
    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # keep the index tensor on the same device as the input
        do_trunc = (torch.arange(x.size(1), device=x.device) < self.max_layer).view(1, -1, 1)
        return torch.where(do_trunc, interp, x)
```
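Here is the sketch mentioned above for estimating the average dlatent (our own helper; the function name is made up). The result could be passed to `Truncation` as `avg_latent`.

```python
@torch.no_grad()
def estimate_avg_latent(mapping, n_batches=100, batch_size=1000):
    # sample many z, push them through the mapping, and average the dlatents
    acc = torch.zeros(512)
    for _ in range(n_batches):
        z = torch.randn(batch_size, 512)
        acc += mapping(z)[:, 0].sum(0)  # all 18 rows are identical, take the first
    return acc / (n_batches * batch_size)

# e.g.: truncation = Truncation(estimate_avg_latent(G_mapping()))
```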
### Generator Synthesis Blocks

Each block consists of two halves. Each half does the following:
- Upscaling by a factor of two (if it's the first half) and blurring — fused with the convolution for the later layers
- Convolution (halving the channels for the later layers, if it's the first half)
- Noise
- Activation (LeakyReLU in the reference model)
- Optionally pixel norm (**not used** in the reference model)
- Instance norm (optional, but used in the reference model)
- The style modulation (i.e. setting the mean/standard deviation of the outputs after instance norm, see above)

Two of these sequences form a block that typically has `out_channels = in_channels//2` (in the earlier blocks, there are 512 input and 512 output channels) and `output_resolution = input_resolution * 2`.
We combine everything after the convolution into a module called the `LayerEpilogue` (the term is taken from the original code).
Note that the original implementation moves the bias of the convolution after the noise, but those two operations commute, so it doesn't matter.

The first block (4x4 "pixels") doesn't have an input. The result of its first convolution is simply replaced by a (trained) constant. We call it the `InputBlock`, the others `GSynthesisBlock`.
(It might be nicer to do this the other way round, i.e. have the `LayerEpilogue` be the layer and call the conv from it.)

```python
class LayerEpilogue(nn.Module):
    """Things to do at the end of each layer."""
    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            layers.append(('pixel_norm', PixelNormLayer()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None):
        x = self.top_epi(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)
        else:
            assert dlatents_in_slice is None
        return x


class InputBlock(nn.Module):
    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if self.const_input_layer:
            # called 'const' in tf
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        else:
            # tweak gain to match the official implementation of Progressive GAN
            self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4, use_wscale=use_wscale)
        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, dlatents_in_range):
        batch_size = dlatents_in_range.size(0)
        if self.const_input_layer:
            x = self.const.expand(batch_size, -1, -1, -1)
            x = x + self.bias.view(1, -1, 1, 1)
        else:
            x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x


class GSynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        # 2**res x 2**res # res = 3..resolution_log2
        super().__init__()
        if blur_filter:
            blur = BlurLayer(blur_filter)
        else:
            blur = None
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, x, dlatents_in_range):
        x = self.conv0_up(x)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv1(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x
```
# Generator - Synthesis part

Finally, the synthesis part just stacks 9 blocks (input + 8 resolution-doubling ones) followed by a pixelwise (1x1) convolution from 16 channels to RGB (3 channels).
Note that the lower-resolution RGB convolutions don't seem to serve any purpose in the final model. (We don't think they still produce images as they once did during progressive training, but the parameter file contains their weights, so we keep the modules here. The alternative would be to filter out those weights.)
The reference implementation's somewhat convoluted setup in its "recursive" mode mainly serves to provide a single static graph for all stages of the progressive training. It would be interesting to reimplement the full training in PyTorch, making use of dynamic graphs.

```python
class G_synthesis(nn.Module):
    def __init__(self,
                 dlatent_size=512,        # Disentangled latent (W) dimensionality.
                 num_channels=3,          # Number of output color channels.
                 resolution=1024,         # Output resolution.
                 fmap_base=8192,          # Overall multiplier for the number of feature maps.
                 fmap_decay=1.0,          # log2 feature map reduction when doubling the resolution.
                 fmap_max=512,            # Maximum number of feature maps in any layer.
                 use_styles=True,         # Enable style inputs?
                 const_input_layer=True,  # First layer is a learned constant?
                 use_noise=True,          # Enable noise inputs?
                 randomize_noise=True,    # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
                 nonlinearity='lrelu',    # Activation function: 'relu', 'lrelu'
                 use_wscale=True,         # Enable equalized learning rate?
                 use_pixel_norm=False,    # Enable pixelwise feature vector normalization?
                 use_instance_norm=True,  # Enable instance normalization?
                 dtype=torch.float32,     # Data type to use for activations and outputs.
                 blur_filter=[1, 2, 1],   # Low-pass filter to apply when resampling activations. None = no filtering.
                 ):
        super().__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4

        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        num_layers = resolution_log2 * 2 - 2
        num_styles = num_layers if use_styles else 1
        blocks = []
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                blocks.append((name,
                               InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                          use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            else:
                blocks.append((name,
                               GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            last_channels = channels
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
        batch_size = dlatents_in.size(0)
        for i, m in enumerate(self.blocks.values()):
            if i == 0:
                x = m(dlatents_in[:, 2*i:2*i+2])
            else:
                x = m(x, dlatents_in[:, 2*i:2*i+2])
        rgb = self.torgb(x)
        return rgb
```
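To make the channel schedule concrete, here is a small loop we added that prints the feature map count per resolution for the defaults `fmap_base=8192`, `fmap_decay=1.0`, `fmap_max=512`; the last block ends up with the 16 channels that feed the RGB convolution.

```python
def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512):
    return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

for res in range(2, 11):
    print('{s}x{s}: {c} channels'.format(s=2**res, c=nf(res - 1)))
# 4x4 .. 32x32: 512, then 64x64: 256, halving down to 1024x1024: 16
```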
## All done, let's define the model!

```python
g_all = nn.Sequential(OrderedDict([
    ('g_mapping', G_mapping()),
    # ('truncation', Truncation(avg_latent)),
    ('g_synthesis', G_synthesis())
]))
```

### But we need weights. Can we use the pretrained ones?

Yes, we can! The code below converts them from the authors' format. We have already done this for you, and you can get the converted weights from [here](https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/karras2019stylegan-ffhq-1024x1024.for_g_all.pt).

Note that the weights are taken from [the reference implementation](https://github.com/NVlabs/stylegan) distributed by NVIDIA Corporation under the CC BY-NC 4.0 license; as such, the same license applies here.

For completeness, our conversion is below, but you can skip it if you download the PyTorch-ready weights. (It needs the reference implementation and TensorFlow; when we ran it here, it printed a series of TensorFlow deprecation `FutureWarning`s and then failed with a `FileNotFoundError` because the pickle was not at the expected path.)
```python
# this can be run to get the weights, but you need the reference implementation and weights
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib, dnnlib.tflib, pickle, torch, collections

dnnlib.tflib.init_tf()
weights = pickle.load(open('./results/trained_agents/karras2019stylegan-ffhq-1024x1024.pkl', 'rb'))
weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k, v in w.trainables.items()]) for w in weights]
torch.save(weights_pt, './results/trained_agents/karras2019stylegan-ffhq-1024x1024.pt')

# then on the PyTorch side run
state_G, state_D, state_Gs = torch.load('./results/trained_agents/karras2019stylegan-ffhq-1024x1024.pt')

def key_translate(k):
    k = k.lower().split('/')
    if k[0] == 'g_synthesis':
        if not k[1].startswith('torgb'):
            k.insert(1, 'blocks')
        k = '.'.join(k)
        k = (k.replace('const.const', 'const').replace('const.bias', 'bias').replace('const.stylemod', 'epi1.style_mod.lin')
             .replace('const.noise.weight', 'epi1.top_epi.noise.weight')
             .replace('conv.noise.weight', 'epi2.top_epi.noise.weight')
             .replace('conv.stylemod', 'epi2.style_mod.lin')
             .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
             .replace('conv0_up.stylemod', 'epi1.style_mod.lin')
             .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
             .replace('conv1.stylemod', 'epi2.style_mod.lin')
             .replace('torgb_lod0', 'torgb'))
    else:
        k = '.'.join(k)
    return k

def weight_translate(k, w):
    k = key_translate(k)
    if k.endswith('.weight'):
        if w.dim() == 2:
            w = w.t()
        elif w.dim() == 1:
            pass
        else:
            assert w.dim() == 4
            w = w.permute(3, 2, 0, 1)
    return w

# we drop the unused lower-resolution torgb filters
param_dict = {key_translate(k): weight_translate(k, v) for k, v in state_Gs.items() if 'torgb_lod' not in key_translate(k)}

# compare the shapes of our state dict and the converted parameters
sd_shapes = {k: v.shape for k, v in g_all.state_dict().items()}
param_shapes = {k: v.shape for k, v in param_dict.items()}

for k in list(sd_shapes) + list(param_shapes):
    pds = param_shapes.get(k)
    sds = sd_shapes.get(k)
    if pds is None:
        print("sd only", k, sds)
    elif sds is None:
        print("pd only", k, pds)
    elif sds != pds:
        print("mismatch!", k, pds, sds)

g_all.load_state_dict(param_dict, strict=False)  # strict=False: the blur kernels are not in the converted weights
torch.save(g_all.state_dict(), './results/trained_agents/karras2019stylegan-ffhq-1024x1024_all.pt')
```
```python
model = g_all.cuda()
model.eval()
```

(When we ran this, the cell failed with `RuntimeError: CUDA error: out of memory` — most likely because the TensorFlow session from the conversion step was still holding on to the GPU.)
```python
z = torch.rand(1, 512).cuda()
```

```python
import torch
import torchvision

# Export the model (we use z as a dummy input)
torch.onnx.export(model, z, "test_stylegan.onnx", verbose=True)
```
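Finally, a short sketch we added (not in the original gist) for actually sampling a face once the converted weights are loaded and the model sits on the GPU; `torchvision.utils.save_image` writes the result to disk.

```python
with torch.no_grad():
    latents = torch.randn(1, 512, device='cuda')  # randn matches the training-time latent prior better than rand
    img = model(latents)                          # [1, 3, 1024, 1024], roughly in [-1, 1]
    img = (img.clamp(-1, 1) + 1) / 2.0            # map to [0, 1] for saving
torchvision.utils.save_image(img, 'sample.png')
```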