Last active
November 11, 2020 02:40
-
-
Save Luolc/d776dee3c26db204d2f7a9958bdc0cee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import numpy as np | |
import math | |
import sys | |
import pdb | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader | |
from torchvision import datasets | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch | |
from torchvision.datasets.mnist import MNIST | |
from lenet import LeNet5Half | |
from torchvision.datasets import CIFAR10 | |
from torchvision.datasets import CIFAR100 | |
import resnet | |
from PIL import Image | |
# Command-line options. The resulting ``opt`` namespace is read elsewhere in
# this script (Generator construction and the sampling code).
parser = argparse.ArgumentParser()
for _flag, _spec in (
    ('--batch_size', dict(type=int, default=500, help='size of the batches')),
    # Location of the teacher checkpoint; the loading code appends
    # 'teacher' to this value.
    ('--teacher_dir', dict(type=str, default='cache/models')),
    # Despite the name, this value is passed straight to torch.load, so it
    # must be the path of the generator checkpoint *file*.
    ('--generator_dir', dict(type=str, default='cache/models/generators')),
    ('--latent_dim', dict(type=int, default=1000,
                          help='dimensionality of the latent space')),
    ('--channels', dict(type=int, default=3, help='number of image channels')),
    ('--img_size', dict(type=int, default=32,
                        help='size of each image dimension')),
):
    parser.add_argument(_flag, **_spec)
# Drop the loop variables so they do not linger as module globals.
del _flag, _spec
opt = parser.parse_args()
class Generator(nn.Module):
    """DCGAN-style generator that maps latent vectors to image tensors.

    The latent batch is projected by a linear layer to 128 feature maps of
    side ``img_size // 4``. It is then upsampled twice by a factor of 2
    (``interpolate`` with its default nearest-neighbour mode). The output
    therefore has shape
    ``(N, channels, 4 * (img_size // 4), 4 * (img_size // 4))``, which equals
    ``img_size`` only when ``img_size`` is a multiple of 4.

    The final layer is a non-affine ``BatchNorm2d`` applied after ``Tanh``.
    Outputs are therefore normalised per channel (batch statistics in train
    mode, running statistics in eval mode) rather than bounded to [-1, 1].

    Args:
        img_size: output side length; defaults to ``opt.img_size``.
        latent_dim: length of each latent vector; defaults to
            ``opt.latent_dim``.
        channels: number of output image channels; defaults to
            ``opt.channels``.
    """

    def __init__(self, img_size=None, latent_dim=None, channels=None):
        super(Generator, self).__init__()
        # Fall back to the command-line options so ``Generator()`` behaves
        # exactly as before. Passing all three arguments removes the
        # dependency on the global ``opt``.
        if img_size is None:
            img_size = opt.img_size
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if channels is None:
            channels = opt.channels

        # Layers are created in the original order so parameter
        # initialisation consumes the RNG identically.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks0 = nn.Sequential(
            nn.BatchNorm2d(128),
        )
        # NOTE(review): BatchNorm2d's second positional parameter is ``eps``
        # (default 1e-5), so eps=0.8 here is unusually large and may have
        # been meant as momentum=0.8. The value is kept so existing trained
        # generators behave the same; confirm intent.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(channels, affine=False),
        )

    def forward(self, z):
        """Map latent codes ``z`` of shape ``(N, latent_dim)`` to images."""
        # NOTE: torch.load restores pickled generators without calling
        # __init__. forward must therefore only use attributes that older
        # checkpoints already have (init_size, l1, conv_blocks0/1/2).
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks0(out)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks1(img)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks2(img)
        return img
# Load the trained generator checkpoint. It is a whole pickled nn.Module, so
# the Generator class must be defined under the same name when this runs.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. On PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which rejects full-module pickles; pass weights_only=False there.
generator = torch.load(opt.generator_dir)
generator.eval()

# Sample latent codes on the same device as the generator's weights. A
# checkpoint saved from CUDA is restored onto CUDA, and a CPU z would then
# fail in the first Linear layer. Plain tensors replace the deprecated
# Variable wrapper.
z = torch.randn(opt.batch_size, opt.latent_dim,
                device=next(generator.parameters()).device)

# os.path.join ensures a teacher_dir without a trailing slash (such as the
# default 'cache/models') still resolves to '<teacher_dir>/teacher'.
teacher = torch.load(os.path.join(opt.teacher_dir, 'teacher'))
teacher.eval()

# Inference only: no_grad avoids building an autograd graph for the batch.
with torch.no_grad():
    gen_imgs = generator(z)
    # features_T is the teacher's second output when out_feature=True
    # (presumably intermediate features); it is not used further here.
    outputs_T, features_T = teacher(gen_imgs, out_feature=True)
# Index of the teacher's highest-scoring class for every generated image.
pred = outputs_T.argmax(dim=1)

# Undo x_norm = (x - mean) / std channel-wise. The first Normalize multiplies
# by std and the second adds mean, giving x = x_norm * std + mean.
# NOTE(review): these look like the standard ImageNet mean/std and assume
# exactly 3 channels (opt.channels == 3). Confirm they match the
# normalisation the teacher was trained with.
inv_normalize = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                         std=[1., 1., 1.]),
])
gen_imgs = inv_normalize(gen_imgs)
def save_images(images, targets, out_dir='images'):
    """Write every image of a batch to ``out_dir`` as a JPEG file.

    Files are named ``s{class:03d}_img_id{index:03d}.jpg``. ``class`` is the
    matching entry of ``targets`` and ``index`` is the position in the batch.
    All files land in one flat folder, not per-class sub-folders.

    Args:
        images: batch tensor laid out as (N, C, H, W). Pixel values are
            expected in [0, 1]; anything outside is clipped before saving.
        targets: indexable by batch position, with each entry supporting
            ``.item()`` (e.g. a 1-D tensor of class indices).
        out_dir: destination directory, created if it does not exist.
            Defaults to ``'images'``, the previously hard-coded location.
    """
    os.makedirs(out_dir, exist_ok=True)
    for idx in range(images.shape[0]):
        class_id = targets[idx].item()
        place_to_store = os.path.join(
            out_dir, 's{:03d}_img_id{:03d}.jpg'.format(class_id, idx))
        # CHW -> HWC, the layout PIL expects.
        image_np = images[idx].detach().cpu().numpy().transpose((1, 2, 0))
        # Clip before casting: NumPy's float -> uint8 cast does not saturate,
        # so out-of-range values would typically wrap around into noise.
        pixels = np.clip(image_np * 255, 0, 255).astype(np.uint8)
        if pixels.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; drop the
            # channel axis so single-channel images save as greyscale.
            pixels = pixels[:, :, 0]
        Image.fromarray(pixels).save(place_to_store)
# Write the whole generated batch to disk, tagged with the teacher's
# predicted classes.
save_images(images=gen_imgs, targets=pred)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment