Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created August 19, 2019 09:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/024e2c247a94c050165942e96a886d27 to your computer and use it in GitHub Desktop.
Save koshian2/024e2c247a94c050165942e96a886d27 to your computer and use it in GitHub Desktop.
ACGAN(5) AnimeFace, 10, resnet
import torch
from torch import nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv-BN-ReLU stages plus an identity skip.

    Both stages keep the channel count ``ch`` and use ``padding=1`` with a
    3x3 kernel, so the branch output has the same shape as the input and the
    two can be summed element-wise.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv1 = self.conv_bn_relu(ch)
        self.conv2 = self.conv_bn_relu(ch)

    def conv_bn_relu(self, ch):
        """Build one ``Conv2d(3x3) -> BatchNorm2d -> ReLU(inplace)`` stage."""
        convolution = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        normalization = nn.BatchNorm2d(ch)
        activation = nn.ReLU(True)
        return nn.Sequential(convolution, normalization, activation)

    def forward(self, inputs):
        """Return ``inputs`` added to the output of the two-stage branch."""
        branch = self.conv1(inputs)
        branch = self.conv2(branch)
        return inputs + branch
class Generator(nn.Module):
    """Convolutional generator built from upsampling stages and residual blocks.

    The network starts with a 1x1 convolution from 110 to 768 channels, runs
    six upsampling stages (factors 4, 2, 2, 2, 2, 2) and ends with a 3x3
    convolution to 3 channels followed by ``Tanh``. The upsampling operator
    used inside every stage is selected by ``upsampling_type``:
    ``"nearest_neighbor"``, ``"transpose_conv"`` or ``"pixel_shuffler"``.
    """

    def __init__(self, upsampling_type):
        assert upsampling_type in ["nearest_neighbor", "transpose_conv", "pixel_shuffler"]
        self.upsampling_type = upsampling_type
        super().__init__()

        # NOTE: the attribute name "inital" is part of the state_dict keys.
        self.inital = nn.Sequential(
            nn.Conv2d(110, 768, 1),
            nn.BatchNorm2d(768),
            nn.ReLU(True),
        )

        # (in_ch, out_ch, upsampling_factor, n_residual_block) per stage;
        # registered in order as conv1 ... conv6.
        stage_specs = [
            (768, 512, 4, 2),
            (512, 256, 2, 2),
            (256, 128, 2, 2),
            (128, 64, 2, 2),
            (64, 32, 2, 2),
            (32, 16, 2, 1),
        ]
        for index, spec in enumerate(stage_specs, start=1):
            setattr(self, f"conv{index}", self.generator_block(*spec))

        self.out = nn.Sequential(
            nn.Conv2d(16, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def generator_block(self, in_ch, out_ch, upsampling_factor, n_residual_block):
        """Build one stage: upsample ``in_ch -> out_ch``, then residual blocks.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            upsampling_factor: Spatial scale factor of the stage.
            n_residual_block: How many ``ResidualBlock(out_ch)`` to append.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        kind = self.upsampling_type
        if kind == "transpose_conv":
            layers = [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=upsampling_factor, stride=upsampling_factor),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        elif kind == "nearest_neighbor":
            layers = [
                nn.UpsamplingNearest2d(scale_factor=upsampling_factor),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        elif kind == "pixel_shuffler":
            # PixelShuffle folds factor**2 channel groups into space, so the
            # 1x1 convolution must emit that many times more channels.
            expanded_ch = out_ch * upsampling_factor ** 2
            layers = [
                nn.Conv2d(in_ch, expanded_ch, kernel_size=1),
                nn.BatchNorm2d(expanded_ch),
                nn.ReLU(True),
                nn.PixelShuffle(upscale_factor=upsampling_factor),
            ]
        else:
            layers = []
        layers.extend(ResidualBlock(out_ch) for _ in range(n_residual_block))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the stem, the six stages in order, and the Tanh output head."""
        x = self.inital(inputs)
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)
        for stage in stages:
            x = stage(x)
        return self.out(x)
class Discriminator(nn.Module):
    """Convolutional discriminator with a real/fake head and a class head.

    Five conv stages (3 -> 32 -> 64 -> 128 -> 256 -> 512 channels) are
    applied; stages 2-5 start with 2x2 average pooling. The final feature map
    is globally average-pooled to a 512-vector which feeds two linear heads:
    ``prob`` (1 logit) and ``classes`` (10 logits).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 32, 2, 1)
        self.conv2 = self.conv_bn_relu(32, 64, 2, 2)
        self.conv3 = self.conv_bn_relu(64, 128, 2, 2)
        self.conv4 = self.conv_bn_relu(128, 256, 2, 2)
        self.conv5 = self.conv_bn_relu(256, 512, 2, 2)
        self.prob = nn.Linear(512, 1)
        self.classes = nn.Linear(512, 10)

    def conv_bn_relu(self, in_ch, out_ch, reps, pooling_size):
        """Build one stage: optional AvgPool2d, then ``reps`` conv units.

        Args:
            in_ch: Input channels of the first convolution.
            out_ch: Output channels of every convolution in the stage.
            reps: Number of ``Conv2d(3x3) -> BatchNorm2d -> LeakyReLU(0.2)``
                units to stack.
            pooling_size: Kernel size of the leading ``AvgPool2d``; no
                pooling layer is added when it is 1 or less.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        layers = []
        if pooling_size > 1:
            layers.append(nn.AvgPool2d(pooling_size))
        for i in range(reps):
            layers.append(nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Return ``(prob, classes)`` logits of shape ``(N, 1)`` and ``(N, 10)``."""
        x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(inputs)))))
        # Global average pooling. For a 128x128 input the feature map is 8x8,
        # so this equals the former fixed ``avg_pool2d(kernel_size=8)``, but
        # it also yields exactly 512 features for other input resolutions.
        x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        prob = self.prob(x)
        classes = self.classes(x)
        return prob, classes
if __name__ == "__main__":
    # Smoke test for the discriminator on a single 3x128x128 input.
    # The previous code called ``summary(...)`` without importing it, which
    # raised NameError; this version relies on torch only.
    model = Discriminator()
    n_params = sum(p.numel() for p in model.parameters())
    model.eval()
    with torch.no_grad():
        prob, classes = model(torch.zeros(1, 3, 128, 128))
    print(model)
    print(f"parameters: {n_params:,}")
    print(f"prob: {tuple(prob.shape)}, classes: {tuple(classes.shape)}")
import torch
from torch import nn
import torchvision
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from models import Generator, Discriminator
import os
import shutil
import pickle
import statistics
import glob
def load_dataset(batch_size):
    """Return a shuffled ``DataLoader`` over the images under ``./thumb10``.

    Before loading, every entry matched by ``thumb/*`` that contains no
    ``*.png`` file is removed with ``shutil.rmtree``. Images are resized to
    128x128, converted to tensors and normalized with mean 0.5 / std 0.5 per
    channel.

    Args:
        batch_size: Batch size passed to the ``DataLoader``.
    """
    # Preprocessing: delete class directories without any PNG image.
    # NOTE(review): the cleanup scans "thumb/" while the dataset is read from
    # "./thumb10" -- confirm that this is intended.
    for class_dir in sorted(glob.glob("thumb/*")):
        png_files = glob.glob(class_dir + "/*.png")
        if not png_files:
            shutil.rmtree(class_dir)

    half = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])
    # thumb10 holds the 10-class subset.
    dataset = torchvision.datasets.ImageFolder(root="./thumb10", transform=preprocess)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=6
    )
def weight_init(layer):
    """Re-initialize convolution weights from N(0, 0.02) in place.

    Meant to be passed to ``nn.Module.apply``. Only layers whose type is
    exactly ``nn.Conv2d`` or ``nn.ConvTranspose2d`` are touched (the check is
    an exact type comparison, so subclasses are not matched); biases and all
    other layers are left unchanged.
    """
    if type(layer) not in (nn.Conv2d, nn.ConvTranspose2d):
        return
    nn.init.normal_(layer.weight, mean=0.0, std=0.02)
class HingeLoss(nn.Module):
    """Hinge adversarial loss for generator and discriminator updates.

    Args:
        batch_size: Unused; accepted only so existing call sites keep working.
        device: Unused; accepted only so existing call sites keep working.

    The loss no longer depends on preallocated buffers, so it works for any
    batch size, logit shape and device.
    """

    def __init__(self, batch_size, device):
        super().__init__()

    def forward(self, logits, condition):
        """Return the scalar hinge loss for ``logits``.

        Args:
            logits: Discriminator real/fake logits (any shape).
            condition: One of ``"gen"``, ``"dis_real"``, ``"dis_fake"``.

        Returns:
            * ``"gen"``: ``-mean(logits)`` (push fakes toward "real").
            * ``"dis_real"``: ``mean(max(0, 1 - logits))``.
            * ``"dis_fake"``: ``mean(max(0, 1 + logits))``.
        """
        assert condition in ["gen", "dis_real", "dis_fake"]
        if condition == "gen":
            # For the generator, reward logits that look real.
            return -torch.mean(logits)
        if condition == "dis_real":
            # -mean(min(logits - 1, 0)) == mean(max(1 - logits, 0))
            return torch.mean(torch.clamp(1.0 - logits, min=0.0))
        # -mean(min(-logits - 1, 0)) == mean(max(1 + logits, 0))
        return torch.mean(torch.clamp(1.0 + logits, min=0.0))
def train(upsampling_type, n_epochs=1):
    """Train the ACGAN and write samples, checkpoints and a loss log.

    Each iteration first updates the generator (hinge loss on the real/fake
    logit plus cross-entropy on the class logits) and then the discriminator
    (mean of the real and fake losses, each hinge + cross-entropy). Outputs
    go to ``anime_acgan_<upsampling_type>/``: one sample grid per epoch,
    model weights every 10 epochs under ``models/``, and ``logs.pkl`` with
    the per-epoch mean losses.

    Args:
        upsampling_type: ``"nearest_neighbor"``, ``"transpose_conv"`` or
            ``"pixel_shuffler"``; forwarded to ``Generator``.
        n_epochs: Number of epochs to run (defaults to the previous
            hard-coded value of 1).
    """
    assert upsampling_type in ["nearest_neighbor", "transpose_conv", "pixel_shuffler"]

    output_dir = "anime_acgan_" + upsampling_type
    # Fall back to CPU instead of crashing when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 128
    n_classes = 10
    latent_dim = 100
    dataloader = load_dataset(batch_size)

    model_G = Generator(upsampling_type)
    model_D = Discriminator()
    model_G.apply(weight_init)
    model_D.apply(weight_init)
    model_G, model_D = model_G.to(device), model_D.to(device)
    if device == "cuda":
        model_G, model_D = torch.nn.DataParallel(model_G), torch.nn.DataParallel(model_D)

    param_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    hinge_loss = HingeLoss(batch_size, device)
    softmax_loss = torch.nn.CrossEntropyLoss()

    # Creates both output_dir and output_dir/models.
    os.makedirs(output_dir + "/models", exist_ok=True)

    # One-hot lookup table on the same device as the labels: indexing a CPU
    # tensor with CUDA indices is rejected by current PyTorch.
    onehot_table = torch.eye(n_classes, device=device)

    result = {"d_loss": [], "g_loss": []}
    for epoch in range(n_epochs):
        log_loss_D, log_loss_G = [], []
        for real_img, real_label in tqdm(dataloader):
            batch_len = len(real_img)
            real_img, real_label = real_img.to(device), real_label.to(device)

            # ---- train G ----
            # Generator input: 100 noise channels + 10 one-hot class channels.
            rand_X = torch.randn(batch_len, latent_dim, 1, 1, device=device)
            label_onehot = onehot_table[real_label].view(batch_len, n_classes, 1, 1)
            rand_X = torch.cat([rand_X, label_onehot], dim=1)
            fake_img = model_G(rand_X)
            # Detached copy reused for the discriminator step and the samples.
            fake_img_tensor = fake_img.detach()
            g_out = model_D(fake_img)
            loss = hinge_loss(g_out[0], "gen")
            loss += softmax_loss(g_out[1], real_label)
            log_loss_G.append(loss.item())
            # backprop
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_G.step()

            # ---- train D ----
            # real batch
            d_out_real = model_D(real_img)
            loss_real = hinge_loss(d_out_real[0], "dis_real")
            loss_real += softmax_loss(d_out_real[1], real_label)
            # fake batch
            d_out_fake = model_D(fake_img_tensor)
            loss_fake = hinge_loss(d_out_fake[0], "dis_fake")
            loss_fake += softmax_loss(d_out_fake[1], real_label)
            loss = (loss_real + loss_fake) / 2.0
            log_loss_D.append(loss.item())
            # backprop
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_D.step()

        # Per-epoch log
        result["d_loss"].append(statistics.mean(log_loss_D))
        result["g_loss"].append(statistics.mean(log_loss_G))
        print(f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}")

        # Map [-1, 1] to [0, 1] explicitly. This matches normalize=True with
        # a (-1, 1) range but does not depend on the ``range`` keyword, which
        # newer torchvision releases renamed to ``value_range``.
        samples = (fake_img_tensor[:36].clamp(-1.0, 1.0) + 1.0) / 2.0
        torchvision.utils.save_image(samples, f"{output_dir}/epoch_{epoch:03}.png", nrow=6, padding=3)

        # Save weights
        if epoch % 10 == 0:
            torch.save(model_G.state_dict(), f"{output_dir}/models/gen_epoch_{epoch:03}.pytorch")
            torch.save(model_D.state_dict(), f"{output_dir}/models/dis_epoch_{epoch:03}.pytorch")

    # Loss log
    with open(output_dir + "/logs.pkl", "wb") as fp:
        pickle.dump(result, fp)
def copy_top10():
    """Copy the 10 class directories with the most PNG files to ``thumb10/``.

    Directories directly under ``thumb/`` are ranked by their number of
    ``*.png`` files (descending) and the first ten are copied, keeping their
    directory names, into ``thumb10/``.

    Raises:
        FileExistsError: If a destination directory already exists
            (raised by ``shutil.copytree``).
    """
    # Ignore stray files: only directories can be copied with copytree.
    class_dirs = [d for d in sorted(glob.glob("thumb/*")) if os.path.isdir(d)]
    png_counts = {d: len(glob.glob(d + "/*.png")) for d in class_dirs}
    top10dirs = sorted(class_dirs, key=png_counts.get, reverse=True)[:10]

    os.makedirs("thumb10", exist_ok=True)
    for src in top10dirs:
        # Build the destination from the directory name. A plain
        # str.replace("thumb", "thumb10") would also rewrite "thumb" inside
        # the class directory name itself.
        shutil.copytree(src, os.path.join("thumb10", os.path.basename(src)))
if __name__ == "__main__":
    # Upsampling variants to train, one full run per entry.
    upsampling_types = ["pixel_shuffler"]
    for upsampling in upsampling_types:
        train(upsampling)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment