GAN 训练框架版本v2。
简单的 Conditional GAN 和 Unconditional GAN 可以直接套用。
自定义训练方式可通过继承 TrainerFrame 类并重写部分方法来实现，如示例中的 DCGANTrainer。
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# apply 函数针对net.children(), apply的参数为一个函数,该函数的输入是 Module | |
def weights_init(m): | |
class_name = m.__class__.__name__ | |
if class_name.find("conv") != -1: | |
m.weight.data.normal_(0,0.02) | |
elif class_name.find("norm") != -1: |
import torch | |
import torch.nn.functional as F | |
from torch.nn.functional import conv2d | |
import numpy as np | |
def mse(x, y):
    """Return the mean squared error between two array-likes.

    Both inputs are promoted to float64 before subtracting. Unsigned integer
    images (e.g. uint8) would otherwise wrap around on ``x - y`` and overflow
    again when squared, silently producing a wrong error value.

    Args:
        x, y: array-likes of the same (or broadcastable) shape.

    Returns:
        The mean of the squared element-wise differences, over all elements.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    # np.mean with no axis averages over every element, so no flatten is needed.
    return np.mean(diff ** 2)


def psnr(image1, image2, L=255, eps=1e-6):
    """Return the peak signal-to-noise ratio between two images, in decibels.

    Computes ``10 * log10(L**2 / (mse(image1, image2) + eps))``.

    Args:
        image1, image2: array-likes of the same (or broadcastable) shape.
        L: peak value of the signal range (default 255).
        eps: added to the MSE so identical images yield a large finite value
            instead of dividing by zero.

    Returns:
        The PSNR as a float.
    """
    m = mse(image1, image2)
    return np.log10(L ** 2 / (m + eps)) * 10
import os, sys, json | |
import re | |
class BibtexParser: | |
""" Simple Parser of bibtex | |
""" | |
def __init__(self, path, encoding = "utf-8"): | |
# self.path = os.path.abspath(path) | |
self.path = path | |
self.encoding = encoding |
""" ImageViewer which supports dragging and scaling and | |
fits altered size of windows. | |
""" | |
import os | |
import cv2 | |
import numpy as np | |
from PIL import ImageTk | |
import PIL.Image as Image | |
from copy import deepcopy |