Skip to content

Instantly share code, notes, and snippets.

@Luolc
Last active November 11, 2020 02:40
Show Gist options
  • Save Luolc/d776dee3c26db204d2f7a9958bdc0cee to your computer and use it in GitHub Desktop.
import argparse
import os
import numpy as np
import math
import sys
import pdb
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.datasets.mnist import MNIST
from lenet import LeNet5Half
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import resnet
from PIL import Image
# Command-line configuration for the image-generation script.
parser = argparse.ArgumentParser()
_ARG_SPECS = [
    ('--batch_size', dict(type=int, default=500, help='size of the batches')),
    ('--teacher_dir', dict(type=str, default='cache/models')),
    ('--generator_dir', dict(type=str, default='cache/models/generators')),
    ('--latent_dim', dict(type=int, default=1000, help='dimensionality of the latent space')),
    ('--channels', dict(type=int, default=3, help='number of image channels')),
    ('--img_size', dict(type=int, default=32, help='size of each image dimension')),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)
opt = parser.parse_args()
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image.

    Feature maps start at ``img_size // 4`` and are upsampled twice by a
    factor of 2, so the output is (batch, channels, img_size, img_size).
    Reads the module-level ``opt`` namespace at construction time.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = opt.img_size // 4
        # Project the latent vector to a 128-channel init_size x init_size map.
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks0 = nn.Sequential(
            nn.BatchNorm2d(128),
        )
        # NOTE(review): the positional 0.8 below binds to BatchNorm2d's `eps`,
        # not `momentum` — probably unintended, but preserved because saved
        # checkpoints were trained with it.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(opt.channels, affine=False),
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z``."""
        feat = self.l1(z)
        feat = feat.view(feat.shape[0], 128, self.init_size, self.init_size)
        feat = self.conv_blocks0(feat)
        feat = nn.functional.interpolate(feat, scale_factor=2)  # first x2 upsample
        feat = self.conv_blocks1(feat)
        feat = nn.functional.interpolate(feat, scale_factor=2)  # second x2 upsample
        feat = self.conv_blocks2(feat)
        return feat
# --- Load the saved generator/teacher, generate a batch, label it, de-normalize it ---

# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
generator = torch.load(opt.generator_dir)
generator.eval()

# Sample latent vectors. (Variable is a deprecated no-op wrapper; a plain tensor
# behaves identically in modern PyTorch.)
z = torch.randn(opt.batch_size, opt.latent_dim)

# Fix: the original `opt.teacher_dir + 'teacher'` produced 'cache/modelsteacher'
# with the default dir (missing path separator). os.path.join handles both the
# default and a trailing-slash value correctly.
teacher = torch.load(os.path.join(opt.teacher_dir, 'teacher'))
teacher.eval()

gen_imgs = generator(z)
outputs_T, features_T = teacher(gen_imgs, out_feature=True)
# Hard pseudo-labels: the teacher's arg-max class for each generated image.
pred = outputs_T.data.max(1)[1]

# Invert ImageNet-style normalization y = (x - mean) / std in two steps:
# first multiply by std (via std = 1/s), then add the mean back.
inv_normalize = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                         std=[1., 1., 1.]),
])
gen_imgs = inv_normalize(gen_imgs)
def save_images(images, targets):
    """Write each image of a (B, C, H, W) float batch to images/ as a JPEG.

    Args:
        images: float tensor batch; values expected in [0, 1] after de-normalization
            (out-of-range values are clipped before the uint8 cast).
        targets: 1-D tensor of class ids, one per image, encoded into the file name.
    """
    # Fix: original crashed if the output directory was missing.
    os.makedirs('images', exist_ok=True)
    # 'idx' instead of 'id' — the original shadowed the builtin.
    for idx in range(images.shape[0]):
        class_id = targets[idx].item()
        place_to_store = 'images/s{:03d}_img_id{:03d}.jpg'.format(class_id, idx)
        image_np = images[idx].data.cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
        # Fix: clip before casting — astype(np.uint8) wraps out-of-range floats,
        # turning slight over/undershoot into saturated noise pixels.
        pil_image = Image.fromarray((np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8))
        pil_image.save(place_to_store)
# Persist the de-normalized batch, named by the teacher's predicted class.
save_images(gen_imgs, pred)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment