Skip to content

Instantly share code, notes, and snippets.

@ak9250
Forked from antonpaquin/photo2cartoon.patch
Created August 7, 2020 15:03
Show Gist options
  • Save ak9250/5349f5b8a59b5a1d228efe35470e9a14 to your computer and use it in GitHub Desktop.
diff --git a/convert.py b/convert.py
new file mode 100644
index 0000000..266347b
--- /dev/null
+++ b/convert.py
@@ -0,0 +1,71 @@
+import os
+import torch
+from models import UgatitSadalinHourglass
+
+class Namespace:
+    pass
+
+n = Namespace()
+
+n.phase = 'train'
+n.light = True
+
+n.dataset = 'photo2cartoon'
+
+n.iteration = 1000000
+n.batch_size = 1
+n.print_freq = 1000
+n.save_freq = 1000
+n.decay_flag = True
+
+n.lr = 0.0001
+n.adv_weight = 1
+n.cycle_weight = 50
+n.identity_weight = 10
+n.cam_weight = 1000
+n.faceid_weight = 1
+
+n.ch = 32
+n.n_dis = 6
+
+n.img_size = 256
+n.img_ch = 3
+
+n.gpu_ids = [0]
+n.benchmark_flag = False
+n.resume = False
+n.rho_clipper = 1.0
+n.w_clipper = 1.0
+n.pretrained_weights = 'models/photo2cartoon_weights.pt'
+
+
+n.result_dir = './experiment/{}-size{}-ch{}-{}-lr{}-adv{}-cyc{}-id{}-identity{}-cam{}'.format(
+    os.path.basename(__file__)[:-3],
+    n.img_size,
+    n.ch,
+    n.light,
+    n.lr,
+    n.adv_weight,
+    n.cycle_weight,
+    n.faceid_weight,
+    n.identity_weight,
+    n.cam_weight)
+
+m = UgatitSadalinHourglass(n)
+m.build_model()
+
+trace_input = torch.rand([1, 3, 256, 256])
+dynamic_axes = {'image': {0: 'batch', 2: 'height', 3: 'width'}}
+
+torch.onnx.export(
+    model=m.genB2A,
+    args=(trace_input,),
+    f='photo2cartoon_b2a.onnx',
+    input_names=('image',),
+    output_names=('output',),
+    dynamic_axes=dynamic_axes,
+    opset_version=11,
+)
+
+import code
+code.interact(local={**globals(), **locals()})
diff --git a/models/UGATIT_sadalin_hourglass.py b/models/UGATIT_sadalin_hourglass.py
index 8e0fe16..beb1aef 100644
--- a/models/UGATIT_sadalin_hourglass.py
+++ b/models/UGATIT_sadalin_hourglass.py
@@ -1,10 +1,10 @@
import time
+import os
import itertools
from dataset import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from .networks import *
-from utils import *
from glob import glob
from .face_features import FaceFeatures
@@ -44,7 +44,6 @@ class UgatitSadalinHourglass(object):
self.img_size = args.img_size
self.img_ch = args.img_ch
- self.device = f'cuda:{args.gpu_ids[0]}'
self.gpu_ids = args.gpu_ids
self.benchmark_flag = args.benchmark_flag
self.resume = args.resume
@@ -94,30 +93,21 @@ class UgatitSadalinHourglass(object):
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
- self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
- self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
- self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
- self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
-
- self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
- self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
- self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
- self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)
""" Define Generator, Discriminator """
- self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
- self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
- self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
- self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
- self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
- self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
+ self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light)
+ self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light)
+ self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
+ self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
+ self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
+ self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
- self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device)
+ self.facenet = FaceFeatures('models/model_mobilefacenet.pth')
""" Define Loss """
- self.L1_loss = nn.L1Loss().to(self.device)
- self.MSE_loss = nn.MSELoss().to(self.device)
- self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)
+ self.L1_loss = nn.L1Loss()
+ self.MSE_loss = nn.MSELoss()
+ self.BCE_loss = nn.BCEWithLogitsLoss()
""" Trainer """
self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001)
@@ -130,6 +120,16 @@ class UgatitSadalinHourglass(object):
self.Rho_clipper = RhoClipper(0, self.rho_clipper)
self.W_Clipper = WClipper(0, self.w_clipper)
+ if self.pretrained_weights:
+ params = torch.load(self.pretrained_weights, map_location='cpu')
+ self.genA2B.load_state_dict(params['genA2B'])
+ self.genB2A.load_state_dict(params['genB2A'])
+ self.disGA.load_state_dict(params['disGA'])
+ self.disGB.load_state_dict(params['disGB'])
+ self.disLA.load_state_dict(params['disLA'])
+ self.disLB.load_state_dict(params['disLB'])
+ print(" [*] Load {} Success".format(self.pretrained_weights))
+
def train(self):
self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
@@ -145,15 +145,6 @@ class UgatitSadalinHourglass(object):
self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
- if self.pretrained_weights:
- params = torch.load(self.pretrained_weights, map_location=self.device)
- self.genA2B.load_state_dict(params['genA2B'])
- self.genB2A.load_state_dict(params['genB2A'])
- self.disGA.load_state_dict(params['disGA'])
- self.disGB.load_state_dict(params['disGB'])
- self.disLA.load_state_dict(params['disLA'])
- self.disLB.load_state_dict(params['disLB'])
- print(" [*] Load {} Success".format(self.pretrained_weights))
if len(self.gpu_ids) > 1:
self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids)
diff --git a/models/face_features.py b/models/face_features.py
index 8ea3fbe..53de802 100644
--- a/models/face_features.py
+++ b/models/face_features.py
@@ -4,10 +4,9 @@ from .mobilefacenet import MobileFaceNet
class FaceFeatures(object):
- def __init__(self, weights_path, device):
- self.device = device
- self.model = MobileFaceNet(512).to(device)
- self.model.load_state_dict(torch.load(weights_path))
+ def __init__(self, weights_path):
+ self.model = MobileFaceNet(512)
+ self.model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
self.model.eval()
def infer(self, batch_tensor):
diff --git a/models/networks.py b/models/networks.py
index 165aa2f..3bac424 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -363,9 +363,11 @@ class adaLIN(nn.Module):
self.rho.data.fill_(0.9)
def forward(self, input, gamma, beta):
- in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
+ in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
+ in_var = torch.mean((input - in_mean) ** 2, dim=[2, 3], keepdim=True)
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
- ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
+ ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
+ ln_var = torch.mean((input - ln_mean) ** 2, dim=[1, 2, 3], keepdim=True)
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
@@ -385,9 +387,11 @@ class LIN(nn.Module):
self.beta.data.fill_(0.0)
def forward(self, input):
- in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
+ in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
+ in_var = torch.mean((input - in_mean) ** 2, dim=[2, 3], keepdim=True)
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
- ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
+ ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
+ ln_var = torch.mean((input - ln_mean) ** 2, dim=[1, 2, 3], keepdim=True)
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment