Gist ak9250/5349f5b8a59b5a1d228efe35470e9a14 — save it to your computer or open it in GitHub Desktop.
Note: this file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below; to review it, open the file in an editor that reveals hidden Unicode characters.
diff --git a/convert.py b/convert.py | |
new file mode 100644 | |
index 0000000..266347b | |
--- /dev/null | |
+++ b/convert.py | |
@@ -0,0 +1,71 @@ | |
"""Export the photo2cartoon UGATIT-sadalin-hourglass generator (genB2A) to ONNX.

Builds the model on CPU with the same hyper-parameters used for training,
loads the pretrained weights, and traces genB2A into `photo2cartoon_b2a.onnx`
with dynamic batch/height/width axes.
"""
import os
import torch
from models import UgatitSadalinHourglass


class Namespace:
    """Bare attribute container standing in for an argparse.Namespace."""
    pass


n = Namespace()

# --- run mode -------------------------------------------------------------
n.phase = 'train'
n.light = True  # use the light (low-memory) generator variant

n.dataset = 'photo2cartoon'

# --- training schedule (unused for a pure export, but the model ctor reads them)
n.iteration = 1000000
n.batch_size = 1
n.print_freq = 1000
n.save_freq = 1000
n.decay_flag = True

# --- optimizer / loss weights ---------------------------------------------
n.lr = 0.0001
n.adv_weight = 1
n.cycle_weight = 50
n.identity_weight = 10
n.cam_weight = 1000
n.faceid_weight = 1

# --- network size ---------------------------------------------------------
n.ch = 32
n.n_dis = 6

n.img_size = 256
n.img_ch = 3

# --- runtime --------------------------------------------------------------
n.gpu_ids = [0]
n.benchmark_flag = False
n.resume = False
n.rho_clipper = 1.0
n.w_clipper = 1.0
n.pretrained_weights = 'models/photo2cartoon_weights.pt'

# Result directory name encodes the hyper-parameters. splitext replaces the
# original `[:-3]` slice, which silently broke for any extension that is not
# exactly three characters (e.g. `.pyc`, no extension).
n.result_dir = './experiment/{}-size{}-ch{}-{}-lr{}-adv{}-cyc{}-id{}-identity{}-cam{}'.format(
    os.path.splitext(os.path.basename(__file__))[0],
    n.img_size,
    n.ch,
    n.light,
    n.lr,
    n.adv_weight,
    n.cycle_weight,
    n.faceid_weight,
    n.identity_weight,
    n.cam_weight)

m = UgatitSadalinHourglass(n)
m.build_model()  # constructs the networks and loads n.pretrained_weights

# Dummy input for tracing; batch/height/width are declared dynamic below so
# the exported graph accepts other sizes at inference time.
trace_input = torch.rand([1, 3, 256, 256])
dynamic_axes = {'image': {0: 'batch', 2: 'height', 3: 'width'}}

torch.onnx.export(
    model=m.genB2A,  # generator mapping domain B -> A (per the attribute name)
    args=(trace_input,),
    f='photo2cartoon_b2a.onnx',
    input_names=('image',),
    output_names=('output',),
    dynamic_axes=dynamic_axes,
    opset_version=11,
)

# NOTE(review): interactive shell deliberately left in for post-export
# inspection of `m`; remove before running this script unattended.
import code
code.interact(local={**globals(), **locals()})
diff --git a/models/UGATIT_sadalin_hourglass.py b/models/UGATIT_sadalin_hourglass.py | |
index 8e0fe16..beb1aef 100644 | |
--- a/models/UGATIT_sadalin_hourglass.py | |
+++ b/models/UGATIT_sadalin_hourglass.py | |
@@ -1,10 +1,10 @@ | |
import time | |
+import os | |
import itertools | |
from dataset import ImageFolder | |
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
from .networks import * | |
-from utils import * | |
from glob import glob | |
from .face_features import FaceFeatures | |
@@ -44,7 +44,6 @@ class UgatitSadalinHourglass(object): | |
self.img_size = args.img_size | |
self.img_ch = args.img_ch | |
- self.device = f'cuda:{args.gpu_ids[0]}' | |
self.gpu_ids = args.gpu_ids | |
self.benchmark_flag = args.benchmark_flag | |
self.resume = args.resume | |
@@ -94,30 +93,21 @@ class UgatitSadalinHourglass(object): | |
transforms.ToTensor(), | |
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) | |
]) | |
- self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform) | |
- self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform) | |
- self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform) | |
- self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform) | |
- | |
- self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True) | |
- self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True) | |
- self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False) | |
- self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False) | |
""" Define Generator, Discriminator """ | |
- self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device) | |
- self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device) | |
- self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) | |
- self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) | |
- self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) | |
- self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) | |
+ self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light) | |
+ self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light) | |
+ self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7) | |
+ self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7) | |
+ self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5) | |
+ self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5) | |
- self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device) | |
+ self.facenet = FaceFeatures('models/model_mobilefacenet.pth') | |
""" Define Loss """ | |
- self.L1_loss = nn.L1Loss().to(self.device) | |
- self.MSE_loss = nn.MSELoss().to(self.device) | |
- self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device) | |
+ self.L1_loss = nn.L1Loss() | |
+ self.MSE_loss = nn.MSELoss() | |
+ self.BCE_loss = nn.BCEWithLogitsLoss() | |
""" Trainer """ | |
self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001) | |
@@ -130,6 +120,16 @@ class UgatitSadalinHourglass(object): | |
self.Rho_clipper = RhoClipper(0, self.rho_clipper) | |
self.W_Clipper = WClipper(0, self.w_clipper) | |
+ if self.pretrained_weights: | |
+ params = torch.load(self.pretrained_weights, map_location='cpu') | |
+ self.genA2B.load_state_dict(params['genA2B']) | |
+ self.genB2A.load_state_dict(params['genB2A']) | |
+ self.disGA.load_state_dict(params['disGA']) | |
+ self.disGB.load_state_dict(params['disGB']) | |
+ self.disLA.load_state_dict(params['disLA']) | |
+ self.disLB.load_state_dict(params['disLB']) | |
+ print(" [*] Load {} Success".format(self.pretrained_weights)) | |
+ | |
def train(self): | |
self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() | |
@@ -145,15 +145,6 @@ class UgatitSadalinHourglass(object): | |
self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) | |
self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) | |
- if self.pretrained_weights: | |
- params = torch.load(self.pretrained_weights, map_location=self.device) | |
- self.genA2B.load_state_dict(params['genA2B']) | |
- self.genB2A.load_state_dict(params['genB2A']) | |
- self.disGA.load_state_dict(params['disGA']) | |
- self.disGB.load_state_dict(params['disGB']) | |
- self.disLA.load_state_dict(params['disLA']) | |
- self.disLB.load_state_dict(params['disLB']) | |
- print(" [*] Load {} Success".format(self.pretrained_weights)) | |
if len(self.gpu_ids) > 1: | |
self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids) | |
diff --git a/models/face_features.py b/models/face_features.py | |
index 8ea3fbe..53de802 100644 | |
--- a/models/face_features.py | |
+++ b/models/face_features.py | |
@@ -4,10 +4,9 @@ from .mobilefacenet import MobileFaceNet | |
class FaceFeatures(object): | |
- def __init__(self, weights_path, device): | |
- self.device = device | |
- self.model = MobileFaceNet(512).to(device) | |
- self.model.load_state_dict(torch.load(weights_path)) | |
+ def __init__(self, weights_path): | |
+ self.model = MobileFaceNet(512) | |
+ self.model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) | |
self.model.eval() | |
def infer(self, batch_tensor): | |
diff --git a/models/networks.py b/models/networks.py | |
index 165aa2f..3bac424 100644 | |
--- a/models/networks.py | |
+++ b/models/networks.py | |
@@ -363,9 +363,11 @@ class adaLIN(nn.Module): | |
self.rho.data.fill_(0.9) | |
def forward(self, input, gamma, beta): | |
- in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) | |
+ in_mean = torch.mean(input, dim=[2, 3], keepdim=True) | |
+ in_var = torch.mean((input - in_mean) ** 2, dim=[2, 3], keepdim=True) | |
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) | |
- ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) | |
+ ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True) | |
+ ln_var = torch.mean((input - ln_mean) ** 2, dim=[1, 2, 3], keepdim=True) | |
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) | |
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln | |
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) | |
@@ -385,9 +387,11 @@ class LIN(nn.Module): | |
self.beta.data.fill_(0.0) | |
def forward(self, input): | |
- in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) | |
+ in_mean = torch.mean(input, dim=[2, 3], keepdim=True) | |
+ in_var = torch.mean((input - in_mean) ** 2, dim=[2, 3], keepdim=True) | |
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) | |
- ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) | |
+ ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True) | |
+ ln_var = torch.mean((input - ln_mean) ** 2, dim=[1, 2, 3], keepdim=True) | |
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) | |
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln | |
out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.