Skip to content

Instantly share code, notes, and snippets.

@kahartma
Created July 10, 2017 14:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kahartma/37b97c096a3dd141d19257ad6e2d73b8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"from torch import optim\n",
"\n",
"import torch.autograd as autograd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class WGAN_I_Discriminator(nn.Module):\n",
"\tdef __init__(self):\n",
"\t\tsuper(WGAN_I_Discriminator, self).__init__()\n",
"\n",
"\t\tself.did_init_train = False\n",
"\n",
"\tdef train_init(self,alpha=0.001,betas=(0,0.9)):\n",
"\t\tself.loss = torch.nn.L1Loss()\n",
"\t\tself.optimizer = optim.Adam(self.parameters(),lr=alpha,betas=betas)\n",
"\t\tself.did_init_train = True\n",
"\n",
"\tdef train_batch(self, batch_real, batch_fake, lambd=10):\n",
"\t\tif not self.did_init_train:\n",
"\t\t\tself.train_init()\n",
"\n",
"\t\t# Reset gradients\n",
"\t\tself.optimizer.zero_grad()\n",
"\n",
"\t\t# Compute output and loss\n",
"\t\tfx_real = self.forward(batch_real)\n",
"\t\tloss_real = fx_real.mean()\n",
"\n",
"\t\tfx_fake = self.forward(batch_fake)\n",
"\t\tloss_fake = fx_fake.mean()\n",
"\n",
"\n",
"\t\t#dreist geklaut von\n",
"\t\t# https://github.com/caogang/wgan-gp/blob/master/gan_toy.py\n",
"\t\t# gradients = autograd.grad(outputs=fx_comb, inputs=interpolates,\n",
"\t\t# \t\t\t\t\t grad_outputs=grad_ones,\n",
"\t\t# \t\t\t\t\t create_graph=True, retain_graph=True, only_inputs=True)[0]\n",
"\t\t# gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambd\n",
"\t\tloss_penalty = self.calc_gradient_penalty(batch_real, batch_fake,lambd)\n",
"\n",
"\t\tloss = loss_fake - loss_real + loss_penalty\n",
"\n",
"\t\t# Backprop gradient\n",
"\t\tloss.backward()\n",
"\n",
"\t\t# Update parameters\n",
"\t\tself.optimizer.step()\n",
"\n",
"\t\tself.optimizer.zero_grad()\n",
"\n",
"\t\treturn loss.data[0] # return loss\n",
"\n",
"\n",
"\tdef calc_gradient_penalty(self, real_data, fake_data,lambd):\n",
"\t\talpha = torch.rand(real_data.size(0), 1,1,1)\n",
"\t\talpha = alpha.expand(real_data.size())\n",
"\t\talpha = alpha\n",
"\n",
"\t\tinterpolates = alpha * real_data.data + ((1 - alpha) * fake_data.data)\n",
"\n",
"\t\tinterpolates = interpolates\n",
"\t\tinterpolates = Variable(interpolates, requires_grad=True)\n",
"\n",
"\t\tdisc_interpolates = self(interpolates)\n",
"\n",
"\t\tgradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,\n",
"\t\t\t\t\t\t\t\t grad_outputs=torch.ones(disc_interpolates.size()),\n",
"\t\t\t\t\t\t\t\t create_graph=True, retain_graph=True, only_inputs=True)[0]\n",
"\n",
"\t\tgradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambd \n",
"\t\t#for grad in gradients[1:]:\n",
"\t\t#\tgradient_penalty += 0 * grad.view(-1)[0]\n",
"\n",
"\t\treturn gradient_penalty\n",
"\n",
"\n",
"class WGAN_I_Generator(nn.Module):\n",
"\tdef __init__(self):\n",
"\t\tsuper(WGAN_I_Generator, self).__init__()\n",
"\n",
"\t\tself.did_init_train = False\n",
"\n",
"\tdef train_init(self,alpha=0.001,betas=(0,0.9)):\n",
"\t\tself.loss = None\n",
"\t\tself.optimizer = optim.Adam(self.parameters(),lr=alpha,betas=betas)\n",
"\t\tself.did_init_train = True\n",
"\n",
"\tdef train_batch(self, batch_noise, discriminator):\n",
"\t\tif not self.did_init_train:\n",
"\t\t\tself.train_init()\n",
"\n",
"\t\tnoise = batch_noise\n",
"\n",
"\t\t# Reset gradients\n",
"\t\tself.optimizer.zero_grad()\n",
"\n",
"\t\t# Generate and discriminate\n",
"\t\tgen = self.forward(noise)\n",
"\t\tprint 'generated'\n",
"\t\tdisc = discriminator(gen)\n",
"\t\tloss = -disc.mean()\n",
"\n",
"\t\t# Backprop gradient\n",
"\t\tloss.backward()\n",
"\n",
"\t\t# Update parameters\n",
"\t\tself.optimizer.step()\n",
"\n",
"\t\tself.optimizer.zero_grad()\n",
"\n",
"\t\treturn loss.data[0] # return loss"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class G(WGAN_I_Generator):\n",
"\tdef __init__(self):\n",
"\t\tsuper(G, self).__init__()\n",
"\t\tself.project = nn.Sequential(\n",
"\t\t\tnn.Linear(1,28*28),\n",
"\t\t\tnn.Tanh()\n",
"\t\t\t)\n",
"\n",
"\t\tself.did_init_train = False\n",
"\n",
"\tdef forward(self, inp):\n",
"\t\tx = self.project(inp)\n",
"\t\tx = x.view(x.size(0),1,28,28)\n",
"\t\treturn x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class D1(WGAN_I_Discriminator):\n",
"\tdef __init__(self):\n",
"\t\tsuper(D1, self).__init__()\n",
"\t\tself.features = nn.Sequential(\n",
"\t\t\tnn.Conv2d(1,3,3),\n",
"\t\t\tnn.Conv2d(3,3,2,stride=2),\n",
"\t\t\tnn.LeakyReLU(negative_slope=0.2),\n",
"\t\t\tnn.Conv2d(3,3,1),\n",
"\t\t\tnn.LeakyReLU(negative_slope=0.2)\n",
"\t\t\t)\n",
"\t\tself.classification = nn.Sequential(\n",
"\t\t\tnn.Linear(3*13*13,1)\n",
"\t\t\t)\n",
"\n",
"\t\tself.did_init_train = False\n",
"\n",
"\tdef forward(self, inp):\n",
"\t\tprint inp\n",
"\t\tinp = inp.view(inp.size(0),1,28,28)\n",
"\t\tx = self.features(inp)\n",
"\t\tprint x.size()\n",
"\t\tx = x.view(x.size(0),3*13*13)\n",
"\t\tprint x\n",
"\t\tx = self.classification(x)\n",
"\t\t#x = (x*x)/x\n",
"\t\tprint x\n",
"\t\treturn x\n",
" \n",
"class D2(WGAN_I_Discriminator):\n",
"\tdef __init__(self):\n",
"\t\tsuper(D2, self).__init__()\n",
"\t\tself.features = nn.Sequential(\n",
"\t\t\tnn.Conv2d(1,3,3),\n",
"\t\t\tnn.Conv2d(3,3,2),\n",
"\t\t\tnn.LeakyReLU(negative_slope=0.2),\n",
"\t\t\tnn.Conv2d(3,3,1),\n",
"\t\t\tnn.LeakyReLU(negative_slope=0.2)\n",
"\t\t\t)\n",
"\t\tself.classification = nn.Sequential(\n",
"\t\t\tnn.Linear(3*25*25,1)\n",
"\t\t\t)\n",
"\n",
"\t\tself.did_init_train = False\n",
"\n",
"\tdef forward(self, inp):\n",
"\t\tprint inp\n",
"\t\tinp = inp.view(inp.size(0),1,28,28)\n",
"\t\tx = self.features(inp)\n",
"\t\tprint x.size()\n",
"\t\tx = x.view(x.size(0),3*25*25)\n",
"\t\tprint x\n",
"\t\tx = self.classification(x)\n",
"\t\t#x = (x*x)/x\n",
"\t\tprint x\n",
"\t\treturn x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def train(gen,disc):\n",
" batch_size = 5\n",
"\n",
" ind = np.random.randint(28*28,size=(batch_size,1))\n",
" batch = np.zeros((batch_size,28*28))\n",
" batch[:,ind] = 1\n",
" batch = batch.reshape((batch_size,1,28,28))\n",
" X_batch = Variable(torch.from_numpy(batch).float(),requires_grad=False)\n",
" noise = np.random.uniform(-1,1,(batch_size,1))\n",
" noise = Variable(torch.from_numpy(noise).float(),volatile=True)\n",
"\n",
" fake_batch = Variable(gen.forward(noise).data,requires_grad=False)\n",
"\n",
" loss_d = disc.train_batch(X_batch,fake_batch)\n",
"\n",
" noise = np.random.uniform(-1,1,(batch_size,1))\n",
" noise = Variable(torch.from_numpy(noise).float(),requires_grad=False)\n",
"\n",
" loss_g = gen.train_batch(noise,disc)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 13, 13])\n",
"Variable containing:\n",
" 0.4414 0.4414 0.4414 ... 0.3586 0.3586 0.3586\n",
" 0.4414 0.4414 0.4414 ... 0.3586 0.3586 0.3586\n",
" 0.4414 0.4414 0.4414 ... 0.3586 0.3586 0.3586\n",
" 0.4414 0.4414 0.4414 ... 0.3586 0.3586 0.3586\n",
" 0.4414 0.4414 0.4414 ... 0.3586 0.3586 0.3586\n",
"[torch.FloatTensor of size 5x507]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" 8.2887\n",
" 8.2887\n",
" 8.2887\n",
" 8.2887\n",
" 8.2887\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" 0.0543 -0.7153 -0.5160 ... 0.4735 -0.7164 0.8703\n",
" 0.5354 0.5323 -0.1328 ... 0.1906 -0.3890 -0.4585\n",
" 0.3904 -0.2583 -0.3826 ... 0.3853 0.4926 -0.5895\n",
" ... ⋱ ... \n",
" 0.4613 0.4360 0.1483 ... 0.3889 -0.0371 0.6373\n",
" 0.2166 0.3023 -0.6837 ... -0.7028 -0.3547 -0.8479\n",
" 0.4830 -0.0631 0.5583 ... -0.0454 0.4298 -0.0724\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" -0.2529 -0.1734 -0.0816 ... -0.0701 -0.2074 0.6213\n",
" 0.4080 0.6716 -0.2989 ... 0.3909 -0.7981 -0.7031\n",
" 0.6493 -0.3965 -0.7459 ... 0.3729 0.7814 -0.6155\n",
" ... ⋱ ... \n",
" 0.6018 -0.1482 -0.1588 ... -0.0197 -0.2622 0.6513\n",
" 0.5241 -0.1004 -0.7786 ... -0.7421 -0.6843 -0.6513\n",
" -0.0552 -0.3759 0.5408 ... -0.4487 0.5794 0.3645\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0.0725 -0.7353 -0.5366 ... 0.4995 -0.7355 0.8786\n",
" 0.5422 0.5230 -0.1228 ... 0.1782 -0.3547 -0.4409\n",
" 0.3723 -0.2498 -0.3544 ... 0.3860 0.4697 -0.5879\n",
" ... ⋱ ... \n",
" 0.4522 0.4647 0.1659 ... 0.4099 -0.0236 0.6364\n",
" 0.1964 0.3240 -0.6772 ... -0.7004 -0.3307 -0.8554\n",
" 0.5087 -0.0438 0.5593 ... -0.0199 0.4202 -0.0987\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" 0.2171 -0.8570 -0.6808 ... 0.6779 -0.8529 0.9298\n",
" 0.5948 0.4433 -0.0408 ... 0.0760 -0.0480 -0.2864\n",
" 0.2165 -0.1798 -0.1051 ... 0.3919 0.2626 -0.5752\n",
" ... ⋱ ... \n",
" 0.3751 0.6611 0.3039 ... 0.5642 0.0856 0.6296\n",
" 0.0278 0.4863 -0.6210 ... -0.6800 -0.1225 -0.9049\n",
" 0.6838 0.1127 0.5674 ... 0.1849 0.3386 -0.3040\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" -0.4044 0.2149 0.1825 ... -0.3702 0.1638 0.3769\n",
" 0.3306 0.7323 -0.3828 ... 0.4874 -0.8987 -0.7929\n",
" 0.7492 -0.4653 -0.8535 ... 0.3661 0.8683 -0.6291\n",
" ... ⋱ ... \n",
" 0.6658 -0.4503 -0.3174 ... -0.2486 -0.3752 0.6588\n",
" 0.6522 -0.3144 -0.8191 ... -0.7615 -0.7973 -0.4785\n",
" -0.3560 -0.5199 0.5311 ... -0.6178 0.6477 0.5577\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 13, 13])\n",
"Variable containing:\n",
" 0.4369 0.4373 0.4381 ... 0.2969 0.3493 0.3412\n",
" 0.4377 0.4275 0.4342 ... 0.2881 0.3508 0.3765\n",
" 0.4370 0.4378 0.4384 ... 0.2974 0.3494 0.3362\n",
" 0.4377 0.4516 0.4498 ... 0.2999 0.3517 0.2957\n",
" 0.4391 0.4240 0.4323 ... 0.2860 0.3522 0.3744\n",
"[torch.FloatTensor of size 5x507]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" 8.2505\n",
" 9.1394\n",
" 8.1786\n",
" 8.1919\n",
" 9.1422\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" 0.0325 -0.4281 -0.3088 ... 0.2833 -0.4287 0.5208\n",
" 0.3204 0.3186 -0.0795 ... 0.1141 -0.2328 -0.2744\n",
" 0.2336 -0.1546 -0.2290 ... 0.2306 0.2948 -0.3528\n",
" ... ⋱ ... \n",
" 0.2761 0.2609 0.0887 ... 0.2327 -0.0222 0.3814\n",
" 0.1296 0.1809 -0.4091 ... -0.4206 -0.2123 -0.5074\n",
" 0.2891 -0.0378 0.3341 ... -0.0272 0.2572 -0.0433\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" -0.1359 -0.0932 -0.0438 ... -0.0377 -0.1114 0.3337\n",
" 0.2191 0.3608 -0.1605 ... 0.2100 -0.4287 -0.3777\n",
" 0.3488 -0.2130 -0.4007 ... 0.2003 0.4197 -0.3306\n",
" ... ⋱ ... \n",
" 0.3233 -0.0796 -0.0853 ... -0.0106 -0.1409 0.3499\n",
" 0.2815 -0.0539 -0.4183 ... -0.3986 -0.3676 -0.3498\n",
" -0.0297 -0.2019 0.2905 ... -0.2410 0.3112 0.1958\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0.0454 -0.4602 -0.3359 ... 0.3127 -0.4604 0.5500\n",
" 0.3394 0.3274 -0.0768 ... 0.1115 -0.2220 -0.2760\n",
" 0.2330 -0.1564 -0.2218 ... 0.2416 0.2940 -0.3680\n",
" ... ⋱ ... \n",
" 0.2831 0.2909 0.1038 ... 0.2566 -0.0148 0.3984\n",
" 0.1229 0.2028 -0.4239 ... -0.4384 -0.2070 -0.5354\n",
" 0.3184 -0.0274 0.3501 ... -0.0125 0.2630 -0.0618\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" 0.2141 -0.8453 -0.6716 ... 0.6687 -0.8414 0.9172\n",
" 0.5868 0.4373 -0.0402 ... 0.0750 -0.0473 -0.2825\n",
" 0.2136 -0.1774 -0.1036 ... 0.3865 0.2590 -0.5674\n",
" ... ⋱ ... \n",
" 0.3700 0.6522 0.2998 ... 0.5565 0.0844 0.6211\n",
" 0.0274 0.4797 -0.6126 ... -0.6708 -0.1208 -0.8926\n",
" 0.6745 0.1112 0.5597 ... 0.1824 0.3340 -0.2999\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" -0.2112 0.1122 0.0953 ... -0.1933 0.0855 0.1968\n",
" 0.1727 0.3824 -0.1999 ... 0.2545 -0.4693 -0.4140\n",
" 0.3912 -0.2429 -0.4457 ... 0.1911 0.4534 -0.3285\n",
" ... ⋱ ... \n",
" 0.3477 -0.2351 -0.1657 ... -0.1298 -0.1959 0.3440\n",
" 0.3405 -0.1642 -0.4277 ... -0.3977 -0.4163 -0.2498\n",
" -0.1859 -0.2715 0.2773 ... -0.3226 0.3382 0.2912\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 13, 13])\n",
"Variable containing:\n",
" 0.4384 0.4386 0.4391 ... 0.3217 0.3531 0.3482\n",
" 0.4390 0.4345 0.4371 ... 0.3207 0.3545 0.3779\n",
" 0.4383 0.4389 0.4392 ... 0.3203 0.3529 0.3446\n",
" 0.4377 0.4514 0.4497 ... 0.3007 0.3517 0.2965\n",
" 0.4398 0.4332 0.4365 ... 0.3207 0.3553 0.3768\n",
"[torch.FloatTensor of size 5x507]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" 8.4410\n",
" 9.0147\n",
" 8.4003\n",
" 8.1756\n",
" 9.2071\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"generated\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" -2.0899e-02 -6.1942e-01 -4.2454e-01 ... 3.5738e-01 -6.2571e-01 8.3005e-01\n",
" 5.0659e-01 5.6921e-01 -1.7381e-01 ... 2.4101e-01 -5.1894e-01 -5.2725e-01\n",
" 4.6155e-01 -2.9277e-01 -4.9143e-01 ... 3.8233e-01 5.7969e-01 -5.9586e-01\n",
" ... ⋱ ... \n",
" 4.9779e-01 3.0861e-01 7.4796e-02 ... 2.9769e-01 -9.2509e-02 6.4068e-01\n",
" 2.9780e-01 2.0959e-01 -7.0917e-01 ... -7.1267e-01 -4.4862e-01 -8.1279e-01\n",
" 3.6868e-01 -1.4212e-01 5.5416e-01 ... -1.4957e-01 4.6857e-01 3.6823e-02\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" -5.0101e-01 4.6172e-01 3.5818e-01 ... -5.4845e-01 4.0837e-01 1.5882e-01\n",
" 2.7247e-01 7.6965e-01 -4.3928e-01 ... 5.4982e-01 -9.3915e-01 -8.4149e-01\n",
" 8.0475e-01 -5.1122e-01 -9.0273e-01 ... 3.6115e-01 9.0948e-01 -6.3856e-01\n",
" ... ⋱ ... \n",
" 7.0631e-01 -6.1996e-01 -4.2123e-01 ... -3.9802e-01 -4.4981e-01 6.6406e-01\n",
" 7.2583e-01 -4.5103e-01 -8.4377e-01 ... -7.7465e-01 -8.5447e-01 -3.2547e-01\n",
" -5.3624e-01 -6.0775e-01 5.2404e-01 ... -7.1221e-01 6.9100e-01 6.6756e-01\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 5.4838e-02 -7.1587e-01 -5.1660e-01 ... 4.7424e-01 -7.1700e-01 8.7056e-01\n",
" 5.3562e-01 5.3203e-01 -1.3249e-01 ... 1.9024e-01 -3.8805e-01 -4.5799e-01\n",
" 3.8984e-01 -2.5803e-01 -3.8182e-01 ... 3.8533e-01 4.9192e-01 -5.8946e-01\n",
" ... ⋱ ... \n",
" 4.6106e-01 4.3689e-01 1.4881e-01 ... 3.8948e-01 -3.6716e-02 6.3723e-01\n",
" 2.1598e-01 3.0291e-01 -6.8348e-01 ... -7.0272e-01 -3.5400e-01 -8.4808e-01\n",
" 4.8380e-01 -6.2564e-02 5.5836e-01 ... -4.4652e-02 4.2956e-01 -7.3126e-02\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" -3.0316e-01 -4.9362e-02 3.3377e-03 ... -1.7034e-01 -9.0072e-02 5.5210e-01\n",
" 3.8380e-01 6.9216e-01 -3.2629e-01 ... 4.2282e-01 -8.3735e-01 -7.3484e-01\n",
" 6.8427e-01 -4.1901e-01 -7.8610e-01 ... 3.7070e-01 8.1360e-01 -6.1985e-01\n",
" ... ⋱ ... \n",
" 6.2319e-01 -2.5113e-01 -2.1085e-01 ... -9.4302e-02 -2.9930e-01 6.5372e-01\n",
" 5.6827e-01 -1.7084e-01 -7.9236e-01 ... -7.4843e-01 -7.2509e-01 -6.0150e-01\n",
" -1.5538e-01 -4.2446e-01 5.3772e-01 ... -5.0742e-01 6.0221e-01 4.3101e-01\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" 1.0318e-01 -7.6638e-01 -5.7015e-01 ... 5.4163e-01 -7.6523e-01 8.9163e-01\n",
" 5.5362e-01 5.0702e-01 -1.0572e-01 ... 1.5708e-01 -2.9437e-01 -4.1032e-01\n",
" 3.4111e-01 -2.3536e-01 -3.0502e-01 ... 3.8725e-01 4.2958e-01 -5.8531e-01\n",
" ... ⋱ ... \n",
" 4.3661e-01 5.1113e-01 1.9551e-01 ... 4.4470e-01 -7.5892e-04 6.3501e-01\n",
" 1.6173e-01 3.6000e-01 -6.6605e-01 ... -6.9619e-01 -2.8902e-01 -8.6741e-01\n",
" 5.5006e-01 -1.1006e-02 5.6104e-01 ... 2.3364e-02 4.0363e-01 -1.4294e-01\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 13, 13])\n",
"Variable containing:\n",
" 0.4367 0.4349 0.4368 ... 0.2974 0.3507 0.3566\n",
" 0.4400 0.4225 0.4316 ... 0.2894 0.3536 0.3754\n",
" 0.4368 0.4372 0.4379 ... 0.2994 0.3510 0.3423\n",
" 0.4379 0.4259 0.4339 ... 0.2892 0.3527 0.3780\n",
" 0.4370 0.4386 0.4394 ... 0.3005 0.3515 0.3289\n",
"[torch.FloatTensor of size 5x507]\n",
"\n",
"Variable containing:\n",
" 0.0934\n",
" 0.0992\n",
" 0.0900\n",
" 0.1002\n",
" 0.0878\n",
"[torch.FloatTensor of size 5x1]\n",
"\n"
]
}
],
"source": [
"train(G(),disc=D1())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" ... ⋱ ... \n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
" 0 0 0 ... 0 0 0\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 25, 25])\n",
"Variable containing:\n",
" 0.0831 0.0831 0.0831 ... 0.1096 0.1096 0.1096\n",
" 0.0831 0.0831 0.0831 ... 0.1096 0.1096 0.1096\n",
" 0.0831 0.0831 0.0831 ... 0.1096 0.1096 0.1096\n",
" 0.0831 0.0831 0.0831 ... 0.1096 0.1096 0.1096\n",
" 0.0831 0.0831 0.0831 ... 0.1096 0.1096 0.1096\n",
"[torch.FloatTensor of size 5x1875]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" -8.1986\n",
" -8.1986\n",
" -8.1986\n",
" -8.1986\n",
" -8.1986\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" -0.3062 0.6279 -0.0862 ... 0.8130 -0.2229 0.6384\n",
" 0.3776 0.5711 -0.8317 ... -0.6858 -0.6892 0.0281\n",
" -0.8529 0.3308 0.8555 ... 0.7805 -0.5374 0.7383\n",
" ... ⋱ ... \n",
" -0.8116 -0.8118 0.8022 ... 0.2282 -0.6979 0.2036\n",
" 0.5507 0.5026 0.4619 ... 0.2841 0.4306 -0.5585\n",
" 0.1682 -0.7257 0.2229 ... 0.4527 -0.7851 0.6203\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" -0.0771 0.6437 0.0904 ... 0.7471 -0.1233 0.5834\n",
" 0.2574 0.5758 -0.7474 ... -0.7041 -0.5262 -0.0979\n",
" -0.7506 0.5585 0.7684 ... 0.6497 -0.4813 0.6666\n",
" ... ⋱ ... \n",
" -0.7444 -0.7538 0.7003 ... 0.4310 -0.5093 0.3517\n",
" 0.3219 0.3901 0.6051 ... 0.3583 0.1989 -0.3522\n",
" -0.1288 -0.7456 0.0583 ... 0.2993 -0.7108 0.5442\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0.2901 0.6674 0.3532 ... 0.6043 0.0378 0.4850\n",
" 0.0526 0.5830 -0.5448 ... -0.7310 -0.1720 -0.2884\n",
" -0.4736 0.7938 0.5423 ... 0.3340 -0.3844 0.5235\n",
" ... ⋱ ... \n",
" -0.5986 -0.6321 0.4584 ... 0.6755 -0.0878 0.5513\n",
" -0.1147 0.1883 0.7690 ... 0.4660 -0.2028 0.0451\n",
" -0.5370 -0.7744 -0.2034 ... 0.0270 -0.5524 0.4043\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" -0.0960 0.6424 0.0764 ... 0.7530 -0.1313 0.5880\n",
" 0.2673 0.5754 -0.7553 ... -0.7027 -0.5411 -0.0879\n",
" -0.7606 0.5425 0.7767 ... 0.6621 -0.4859 0.6728\n",
" ... ⋱ ... \n",
" -0.7505 -0.7590 0.7098 ... 0.4160 -0.5269 0.3404\n",
" 0.3421 0.3996 0.5948 ... 0.3525 0.2186 -0.3704\n",
" -0.1052 -0.7441 0.0717 ... 0.3123 -0.7174 0.5507\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" -0.6530 0.5956 -0.4056 ... 0.8986 -0.4024 0.7288\n",
" 0.5766 0.5620 -0.9266 ... -0.6478 -0.8750 0.2667\n",
" -0.9502 -0.2102 0.9449 ... 0.9178 -0.6339 0.8404\n",
" ... ⋱ ... \n",
" -0.8985 -0.8906 0.9162 ... -0.2089 -0.8954 -0.1054\n",
" 0.8255 0.6785 0.1081 ... 0.1308 0.7458 -0.8147\n",
" 0.6356 -0.6834 0.5031 ... 0.6837 -0.8827 0.7396\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 25, 25])\n",
"Variable containing:\n",
" 1.5147e-01 1.0613e-01 1.2389e-01 ... 1.1655e-01 1.2373e-01 4.5109e-02\n",
" 1.7359e-01 9.4683e-02 1.2308e-01 ... 1.2920e-01 1.2182e-01 5.7374e-02\n",
" 1.9538e-01 9.4103e-02 1.1978e-01 ... 1.3478e-01 1.1749e-01 8.3565e-02\n",
" 1.7204e-01 9.5412e-02 1.2318e-01 ... 1.2892e-01 1.2201e-01 5.6234e-02\n",
" 9.9869e-02 1.1396e-01 1.2426e-01 ... 8.6980e-02 1.2402e-01 3.4973e-02\n",
"[torch.FloatTensor of size 5x1875]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" -8.1560\n",
" -8.2090\n",
" -7.7423\n",
" -8.2413\n",
" -7.6836\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" -0.0561 0.1150 -0.0158 ... 0.1489 -0.0408 0.1169\n",
" 0.0691 0.1046 -0.1523 ... -0.1256 -0.1262 0.0051\n",
" -0.1562 0.0606 0.1567 ... 0.1429 -0.0984 0.1352\n",
" ... ⋱ ... \n",
" -0.1486 -0.1487 0.1469 ... 0.0418 -0.1278 0.0373\n",
" 0.1008 0.0920 0.0846 ... 0.0520 0.0788 -0.1023\n",
" 0.0308 -0.1329 0.0408 ... 0.0829 -0.1438 0.1136\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" -0.0349 0.2912 0.0409 ... 0.3380 -0.0558 0.2639\n",
" 0.1165 0.2605 -0.3382 ... -0.3186 -0.2381 -0.0443\n",
" -0.3396 0.2527 0.3476 ... 0.2939 -0.2177 0.3016\n",
" ... ⋱ ... \n",
" -0.3368 -0.3411 0.3168 ... 0.1950 -0.2304 0.1591\n",
" 0.1456 0.1765 0.2738 ... 0.1621 0.0900 -0.1593\n",
" -0.0583 -0.3373 0.0264 ... 0.1354 -0.3216 0.2462\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0.2494 0.5738 0.3037 ... 0.5196 0.0325 0.4169\n",
" 0.0453 0.5013 -0.4684 ... -0.6285 -0.1479 -0.2480\n",
" -0.4072 0.6825 0.4662 ... 0.2871 -0.3305 0.4501\n",
" ... ⋱ ... \n",
" -0.5147 -0.5435 0.3941 ... 0.5808 -0.0755 0.4740\n",
" -0.0986 0.1619 0.6611 ... 0.4007 -0.1744 0.0388\n",
" -0.4617 -0.6658 -0.1748 ... 0.0232 -0.4750 0.3476\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" -0.0142 0.0947 0.0113 ... 0.1110 -0.0194 0.0867\n",
" 0.0394 0.0848 -0.1113 ... -0.1036 -0.0798 -0.0130\n",
" -0.1121 0.0800 0.1145 ... 0.0976 -0.0716 0.0992\n",
" ... ⋱ ... \n",
" -0.1106 -0.1119 0.1046 ... 0.0613 -0.0777 0.0502\n",
" 0.0504 0.0589 0.0877 ... 0.0520 0.0322 -0.0546\n",
" -0.0155 -0.1097 0.0106 ... 0.0460 -0.1057 0.0812\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" -0.3383 0.3086 -0.2101 ... 0.4655 -0.2084 0.3775\n",
" 0.2987 0.2911 -0.4800 ... -0.3356 -0.4533 0.1382\n",
" -0.4922 -0.1089 0.4895 ... 0.4754 -0.3284 0.4353\n",
" ... ⋱ ... \n",
" -0.4654 -0.4614 0.4746 ... -0.1082 -0.4638 -0.0546\n",
" 0.4276 0.3515 0.0560 ... 0.0677 0.3864 -0.4220\n",
" 0.3293 -0.3540 0.2606 ... 0.3542 -0.4573 0.3831\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 25, 25])\n",
"Variable containing:\n",
" 0.0952 0.0873 0.0993 ... 0.1108 0.1173 0.0978\n",
" 0.1114 0.0883 0.1083 ... 0.1197 0.1186 0.0860\n",
" 0.1747 0.0882 0.1165 ... 0.1321 0.1173 0.0872\n",
" 0.0957 0.0849 0.0960 ... 0.1127 0.1162 0.1017\n",
" 0.0918 0.1053 0.1100 ... 0.0979 0.1201 0.0709\n",
"[torch.FloatTensor of size 5x1875]\n",
"\n",
"Variable containing:\n",
"1.00000e-02 *\n",
" -7.7284\n",
" -7.4459\n",
" -7.5042\n",
" -7.7754\n",
" -7.2036\n",
"[torch.FloatTensor of size 5x1]\n",
"\n",
"generated\n",
"Variable containing:\n",
"(0 ,0 ,.,.) = \n",
" -0.2203 0.6340 -0.0180 ... 0.7896 -0.1849 0.6179\n",
" 0.3324 0.5729 -0.8027 ... -0.6930 -0.6324 -0.0207\n",
" -0.8190 0.4256 0.8262 ... 0.7360 -0.5163 0.7123\n",
" ... ⋱ ... \n",
" -0.7878 -0.7910 0.7671 ... 0.3101 -0.6332 0.2624\n",
" 0.4692 0.4609 0.5208 ... 0.3132 0.3457 -0.4848\n",
" 0.0541 -0.7335 0.1603 ... 0.3959 -0.7586 0.5922\n",
" ⋮ \n",
"\n",
"(1 ,0 ,.,.) = \n",
" 0.4087 0.6757 0.4377 ... 0.5399 0.0958 0.4460\n",
" -0.0232 0.5856 -0.4484 ... -0.7402 -0.0257 -0.3526\n",
" -0.3357 0.8466 0.4302 ... 0.1910 -0.3472 0.4627\n",
" ... ⋱ ... \n",
" -0.5328 -0.5782 0.3462 ... 0.7400 0.0824 0.6115\n",
" -0.2701 0.1103 0.8116 ... 0.5019 -0.3385 0.1915\n",
" -0.6466 -0.7840 -0.2927 ... -0.0744 -0.4822 0.3482\n",
" ⋮ \n",
"\n",
"(2 ,0 ,.,.) = \n",
" 0.5663 0.6880 0.5536 ... 0.4272 0.1834 0.3827\n",
" -0.1388 0.5896 -0.2781 ... -0.7538 0.1992 -0.4449\n",
" -0.0946 0.9040 0.2296 ... -0.0432 -0.2878 0.3602\n",
" ... ⋱ ... \n",
" -0.4177 -0.4843 0.1537 ... 0.8176 0.3317 0.6917\n",
" -0.4819 -0.0118 0.8635 ... 0.5538 -0.5210 0.3991\n",
" -0.7738 -0.7981 -0.4203 ... -0.2265 -0.3611 0.2571\n",
" ⋮ \n",
"\n",
"(3 ,0 ,.,.) = \n",
" -0.6398 0.5973 -0.3914 ... 0.8954 -0.3941 0.7248\n",
" 0.5681 0.5625 -0.9234 ... -0.6497 -0.8690 0.2555\n",
" -0.9475 -0.1839 0.9421 ... 0.9136 -0.6296 0.8364\n",
" ... ⋱ ... \n",
" -0.8953 -0.8876 0.9125 ... -0.1880 -0.8895 -0.0902\n",
" 0.8166 0.6712 0.1271 ... 0.1385 0.7347 -0.8061\n",
" 0.6183 -0.6856 0.4911 ... 0.6745 -0.8791 0.7346\n",
" ⋮ \n",
"\n",
"(4 ,0 ,.,.) = \n",
" -0.4089 0.6199 -0.1721 ... 0.8395 -0.2706 0.6634\n",
" 0.4328 0.5688 -0.8631 ... -0.6765 -0.7510 0.0901\n",
" -0.8877 0.1995 0.8864 ... 0.8277 -0.5635 0.7685\n",
" ... ⋱ ... \n",
" -0.8385 -0.8357 0.8402 ... 0.1188 -0.7665 0.1265\n",
" 0.6412 0.5527 0.3802 ... 0.2462 0.5287 -0.6412\n",
" 0.3071 -0.7154 0.3002 ... 0.5201 -0.8151 0.6541\n",
"[torch.FloatTensor of size 5x1x28x28]\n",
"\n",
"torch.Size([5, 3, 25, 25])\n",
"Variable containing:\n",
" 1.6241e-01 9.8943e-02 1.2214e-01 ... 1.2023e-01 1.2195e-01 4.7683e-02\n",
" 1.9938e-01 9.8590e-02 1.1673e-01 ... 1.3575e-01 1.1490e-01 9.1535e-02\n",
" 2.0071e-01 1.0678e-01 1.1405e-01 ... 1.3898e-01 1.1325e-01 1.0587e-01\n",
" 1.0286e-01 1.1230e-01 1.2154e-01 ... 8.5690e-02 1.2290e-01 3.3641e-02\n",
" 1.4082e-01 1.0897e-01 1.2234e-01 ... 1.0656e-01 1.2304e-01 3.9299e-02\n",
"[torch.FloatTensor of size 5x1875]\n",
"\n",
"Variable containing:\n",
"-0.1610\n",
"-0.1399\n",
"-0.1315\n",
"-0.1537\n",
"-0.1601\n",
"[torch.FloatTensor of size 5x1]\n",
"\n"
]
}
],
"source": [
"train(G(),disc=D2())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment