# Last active August 28, 2021 00:41
import copy
import os
import simplejson as json
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torchvision
import torch.nn.functional as F
import dnnlib
import legacy
import clip
import hashlib
def approach(
    G,
    *,
    num_steps=100,
    w_avg_samples=10000,
    initial_learning_rate=0.02,
    initial_noise_factor=0.02,
    noise_floor=0.02,
    psi=0.8,
    noise_ramp_length=1.0,  # was 0.75
    regularize_noise_weight=10000,  # was 1e5
    seed=69097,
    noise_opt=True,
    ws=None,
    text='a computer generated image',
    device: torch.device = None,
    preview_path=None,
):
    """Optimise a single W latent so that G's output matches a CLIP text prompt.

    The starting latent is either derived from ``seed`` (mapped through ``G.mapping``
    with truncation ``psi``) or taken from ``ws``. Every step adds decaying Gaussian
    noise to the latent, synthesizes an image, and minimises the negative CLIP cosine
    similarity between that image and ``text``. When ``noise_opt`` is set, G's
    ``noise_const`` buffers are optimised too, with a spatial-correlation regularizer.

    Args:
        G: generator exposing ``z_dim``, ``mapping`` (with ``num_ws``) and ``synthesis``.
            It is deep-copied, so the caller's instance is not modified.
        num_steps: number of optimisation steps (one latent is recorded per step).
        w_avg_samples: unused; kept for backward compatibility (the W-statistics pass
            that consumed it had its results overwritten and was removed).
        initial_learning_rate: Adam learning rate (no lr schedule is applied).
        initial_noise_factor: latent-noise scale at step 0, relative to ``w_std``.
        noise_floor: lower bound on the latent-noise scale.
        psi: truncation psi used when deriving the start latent from ``seed``.
        noise_ramp_length: fraction of the run over which the latent noise decays to 0
            (before the floor is applied).
        regularize_noise_weight: weight of the noise-buffer regularizer.
        seed: RNG seed for the start latent (ignored when ``ws`` is given).
        noise_opt: also optimise G's noise buffers.
        ws: optional start latent array; only ``ws[:, :1, :]`` is used.
        text: CLIP text prompt.
        device: device to run on; defaults to the device of G's first parameter.
        preview_path: if set, the current image is saved there as a PNG every step.
            It defaults to None, so no preview is written.

    Returns:
        Tensor of shape ``[num_steps, G.mapping.num_ws, C]``: the noise-free latent
        after every step, broadcast to all W layers.
    """
    if device is None:
        device = next(G.parameters()).device
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)
    lr = initial_learning_rate  # only reported in the log line; no lr schedule is applied

    # Starting latent: map the seed through G, or use the caller-supplied ws.
    if ws is None:
        print('Generating w for seed %i' % seed)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        w_samples = G.mapping(z, None, truncation_psi=psi)
    else:
        w_samples = torch.tensor(ws, device=device)
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    # Fixed stand-in for the W standard deviation (~9.9 for the portraits network; should
    # be computed if using a median). It only scales the latent-noise schedule below.
    w_std = 2

    # Optimisation variables: the latent, plus (optionally) the synthesis noise buffers.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True)  # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)

    if noise_opt:
        optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w + noise')
    else:
        optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w')

    # Randomise the noise buffers; they only need gradients when they are optimised.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = noise_opt

    # Load the perceptor and embed the prompt once.
    print('Loading perceptor for text:', text)
    perceptor, _preprocess = clip.load('ViT-B/32', device=device, jit=True)
    perceptor = perceptor.eval()
    whispers = perceptor.encode_text(clip.tokenize(text).to(device)).detach().clone()
    # Input normalisation constants copied from CLIP's own preprocess.
    clip_normalize = torchvision.transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    # Descend
    for step in range(num_steps):
        # Latent-noise schedule: quadratic decay over noise_ramp_length, clamped at noise_floor.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        if w_noise_scale < noise_floor:
            w_noise_scale = noise_floor

        # Synthesize from the noisy latent, broadcast to every W layer.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws_step = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws_step, noise_mode='const')

        if preview_path is not None:
            with torch.no_grad():
                preview = (synth_images.detach() + 1) * (255 / 2)
                preview = preview.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            PIL.Image.fromarray(preview, 'RGB').save(preview_path)

        # G's output is in [-1, 1]; CLIP's normalisation expects [0, 1] input.
        # The normalised tensor (not the raw synth_images) is what gets resized and encoded.
        into = clip_normalize((synth_images + 1) / 2)
        into = F.interpolate(into, (224, 224), mode='bilinear', align_corners=True)
        glimmers = perceptor.encode_image(into)
        away = -30 * torch.cosine_similarity(whispers, glimmers, dim=-1).mean()  # Dunno why 30 works lol

        loss = away
        if noise_opt:
            # Noise regularisation from the original StyleGAN2 projector: penalise spatial
            # autocorrelation of each noise map at every 2x-downsampled scale down to 8px.
            reg_loss = 0.0
            for v in noise_bufs.values():
                noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
                while True:
                    reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                    reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                    if noise.shape[2] <= 8:
                        break
                    noise = F.avg_pool2d(noise, kernel_size=2)
            loss = away + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        print(f'step {step+1:>4d}/{num_steps}: loss {float(loss):<5.2f} ', 'lr', lr, f'noise scale: {float(w_noise_scale):<5.6f}', f'away: {float(away / (-30)):<5.6f}')

        w_out[step] = w_opt.detach()[0]

        # Re-normalise the noise buffers to zero mean / unit variance.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])
def _synth_image_uint8(G, projected_w):
    """Render one latent taken from approach()'s output with G; return an HxWxC uint8 numpy image."""
    image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    image = (image + 1) * (255 / 2)  # [-1, 1] -> [0, 255]
    return image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()


@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Initial image seed', type=int, default=232322, show_default=True)
@click.option('--w', help='Do not use seed but load w from a file', type=str, metavar='FILE')
@click.option('--lr', help='Adam learning rate', type=float, required=False, default=0.02)
@click.option('--psi', help='Truncation psi for initial image', type=float, required=False, default=0.81)
@click.option('--inf', help='Initial noise factor', type=float, required=False, default=0.02)
@click.option('--nf', help='Noise floor', type=float, required=False, default=0.02)
@click.option('--noise-opt', help='Optimize noise vars as well as w', type=bool, required=False, default=True)
@click.option('--text', help='Text prompt', required=False, default='A computer-generated image')
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--save-ws', help='Save intermediate ws', type=bool, default=False, show_default=True)
def run_approach(
    network_pkl: str,
    outdir: str,
    save_video: bool,
    save_ws: bool,
    seed: int,
    num_steps: int,
    text: str,
    lr: float,
    inf: float,
    nf: float,
    w: str,
    psi: float,
    noise_opt: bool,
):
    """Descend on StyleGAN2 w vector value using CLIP, tuning an image with given text prompt.

    Outputs (all named with a SHA-1 of the CLI arguments) go to OUTDIR: an optional
    progress video, optional per-step w files, the final image, the final w and the
    argument list as JSON.

    \b
    python3 <script>.py --network network-snapshot-ffhq.pkl --outdir project --num-steps 100 \\
        --text 'an image of a girl with a face resembling Paul Krugman' --psi 0.8 --seed 12345
    """
    # Hash the full argument set so every distinct configuration gets its own output names.
    # (Must run before any other local is bound.)
    local_args = dict(locals())
    params = [{name: value} for name, value in local_args.items()]
    hashname = hashlib.sha1(json.dumps(params).encode('utf-16be')).hexdigest()
    print('run hash', hashname)

    ws = None
    if w is not None:
        print('loading w from file', w, 'ignoring seed and psi')
        ws = np.load(w)['w']

    # take off
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    projected_w_steps = approach(
        G,
        num_steps=num_steps,
        device=device,
        initial_learning_rate=lr,
        psi=psi,
        seed=seed,
        initial_noise_factor=inf,
        noise_floor=nf,
        text=text,
        ws=ws,
        noise_opt=noise_opt,
    )

    os.makedirs(outdir, exist_ok=True)

    # save video (the writer is closed by the context manager even on error)
    if save_video:
        video_path = f'{outdir}/out-{hashname}.mp4'
        print(f'Saving optimization progress video "{video_path}"')
        with imageio.get_writer(video_path, mode='I', fps=10, codec='libx264', bitrate='16M') as video:
            for projected_w in projected_w_steps:
                video.append_data(_synth_image_uint8(G, projected_w))

    # save ws, one file per optimisation step
    if save_ws:
        print('Saving optimization progress ws')
        for step, projected_w in enumerate(projected_w_steps):
            np.savez(f'{outdir}/w-{hashname}-{step}.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # save the result and the final w
    print('Saving finals')
    projected_w = projected_w_steps[-1]
    PIL.Image.fromarray(_synth_image_uint8(G, projected_w), 'RGB').save(f'{outdir}/out-{hashname}.png')
    np.savez(f'{outdir}/w-{hashname}-final.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # save params
    with open(f'{outdir}/params-{hashname}.txt', 'w', encoding='utf-8') as outfile:
        json.dump(params, outfile)
if __name__ == "__main__":
    # click parses sys.argv; in its default standalone mode it exits the process when done.
    # NOTE(review): the code after this guard appears to be a separate PTI script pasted
    # into this file, so it only runs on import, never after run_approach().
    run_approach()  # pylint: disable=no-value-for-parameter
#!/usr/bin/env python
# coding: utf-8
import os
import sys
import copy
import pickle
import numpy as np
from PIL import Image
import torch
from configs import paths_config, hyperparameters, global_config
from utils.align_data import pre_process_images
from scripts.run_pti import run_PTI
# from IPython.display import display
from imgcat import imgcat
import matplotlib.pyplot as plt
from scripts.latent_editor_wrapper import LatentEditorWrapper
# NOTE(review): CODE_DIR was used here but never defined in this file; 'PTI' matches the
# repository directory in the hard-coded paths below — confirm.
CODE_DIR = 'PTI'
current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
os.makedirs(save_path, exist_ok=True)

image_dir_name = 'images'
# If True, download the desired image from a URL. If False, assumes you have placed your
# own image in the '<image_dir_name>_original' directory.
# NOTE(review): no download code remains in this script, so this flag is currently unused.
use_image_online = True
image_name = '1'  # put <image_name>.jpg in the images_original folder (NOT image_original)
use_multi_id_training = False
global_config.device = 'cuda'
paths_config.e4e = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed'
# paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/AlfredENeuman24_ADA-torch.pkl'
paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI/'
paths_config.style_clip_pretrained_mappers = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False

os.makedirs(f'./{image_dir_name}_original', exist_ok=True)
os.makedirs(f'./{image_dir_name}_processed', exist_ok=True)
original_image = Image.open(f'./{image_dir_name}_original/{image_name}.jpg')

# Align/crop the raw images. pre_process_images presumably writes the aligned copies
# (as .jpeg) under paths_config.input_data_path, which is where the next line reads from.
pre_process_images(f'./{image_dir_name}_original')
aligned_image = Image.open(f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed/{image_name}.jpeg')

# ## Step 5 - Invert images using PTI
# In order to run PTI and use StyleGAN2-ada, the cwd should be the parent of 'torch_utils' and 'dnnlib'.
# If use_multi_id_training is set to True and many images are inverted simultaneously,
# activating the regularization to keep the *W* space intact is recommended.
# If the regularization is activated, please increase the number of PTI steps from 350 to
# at least 450 using hyperparameters.max_pti_steps.
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
# ## Visualize results
def display_alongside_source_image(images):
    """Join the given images side by side and return the result as one PIL image.

    Each image is converted with ``np.array`` and the arrays are concatenated along
    axis 1 (the width axis for HxWxC image arrays), so all inputs must share a height.
    """
    arrays = list(map(np.array, images))
    combined = np.concatenate(arrays, axis=1)
    return Image.fromarray(combined)
def load_generators(model_id, image_name):
    """Load the original and the PTI-tuned generator for one run, both moved to the GPU.

    Returns ``(old_G, new_G)``. ``old_G`` is the ``'G_ema'`` entry of the pickle at
    ``paths_config.stylegan2_ada_ffhq``. ``new_G`` is the model saved as
    ``model_{model_id}_{image_name}.pt`` in ``paths_config.checkpoints_dir``.

    NOTE(review): both files are unpickled (``pickle.load`` / ``torch.load``), which can
    execute arbitrary code — only point these paths at trusted checkpoints.
    """
    # Only the generator is needed; the discriminator in the same pickle is not touched.
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        old_G = pickle.load(f)['G_ema'].cuda()
    with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
        new_G = torch.load(f_new).cuda()
    return old_G, new_G
def export_updated_pickle(new_G, model_id):
    """Write ``model_{model_id}.pkl`` to ``paths_config.checkpoints_dir``.

    The pickle contains:

    - ``'G_ema'``: the original generator from ``paths_config.stylegan2_ada_ffhq``.
    - ``'G'``: a CPU copy of ``new_G``.
    - ``'D'``: the original discriminator.
    - ``None`` for ``'training_set_kwargs'`` and ``'augment_pipe'``.

    All modules are stored in eval mode with gradients disabled. ``new_G`` itself is not
    modified; a deep copy is exported, so the caller's model stays on its device.

    NOTE(review): loaders that read ``['G_ema']`` (e.g. ``legacy.load_network_pkl(...)['G_ema']``)
    get the *original* generator, not ``new_G`` — confirm this is intended.
    """
    print("Exporting large updated pickle based off new generator and ffhq.pkl")
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
    tmp = {
        # No GPU round-trip needed: these modules end up on the CPU anyway.
        'G_ema': d['G_ema'].eval().requires_grad_(False).cpu(),
        'G': copy.deepcopy(new_G).eval().requires_grad_(False).cpu(),
        'D': d['D'].eval().requires_grad_(False).cpu(),
        'training_set_kwargs': None,
        'augment_pipe': None,
    }
    with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
        pickle.dump(tmp, f)
# Multi-id runs store one shared tuned generator; single-id runs are keyed by image name.
if use_multi_id_training:
    generator_type = paths_config.multi_id_model_type
else:
    generator_type = image_name
old_G, new_G = load_generators(model_id, generator_type)
def plot_syn_images(syn_images):
    """Show the first image of each synthesized batch at 256x256 in the terminal via imgcat.

    Each entry is an NCHW tensor as returned by ``G.synthesis``; values are mapped to
    [0, 255] with ``* 127.5 + 128`` and clamped. Only element ``[0]`` of each batch is shown.
    """
    for syn_image in syn_images:
        img = (syn_image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
        resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
        # Previously the image was built and then deleted without ever being shown.
        imgcat(resized_image)
# If multi_id_training was used for several images,
# you can alter the w_pivot index, which is currently configured to 0, and then re-run
# the visualization code, using the same generator on different latent codes.
w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
w_pivot = torch.load(f'{embedding_dir}/0.pt')  # pivot index 0 (see note above)

old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
print('Upper image is the inversion before Pivotal Tuning and the lower image is the product of pivotal tuning')
plot_syn_images([old_image, new_image])

# ## InterfaceGAN edits
latent_editor = LatentEditorWrapper()
latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, [-2, 2])
# To get different edits, such as a younger face or a bigger smile, change the factors
# passed to "get_single_interface_gan_edits". They are currently [-2, 2]; you can pass,
# for example, range(-3, 3).
for direction, factor_and_edit in latents_after_edit.items():
    print(f'Showing {direction} change')
    for latent in factor_and_edit.values():
        old_image = old_G.synthesis(latent, noise_mode='const', force_fp32=True)
        new_image = new_G.synthesis(latent, noise_mode='const', force_fp32=True)
        plot_syn_images([old_image, new_image])

# ## StyleCLIP editing
# ### Download pretrained models
# mappers_base_dir = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
# More pretrained mappers can be found in the StyleCLIP repository.
# Download Afro mapper
# downloader.download_file("1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ", os.path.join(mappers_base_dir, 'afro.pt'))
# Download Mohawk mapper
# downloader.download_file("1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09", os.path.join(mappers_base_dir, 'mohawk.pt'))
# Download e4e encoder, used for the first inversion step instead of the W inversion.
# downloader.download_file("1cUv_reLE6k3604or78EranS7XzuVMWeO", os.path.join(mappers_base_dir, 'e4e_ffhq_encode.pt'))

# ### Use PTI with e4e backbone for StyleCLIP
# Changing first_inv_type to W+ makes PTI use the e4e encoder instead of W inversion in the first step.
hyperparameters.first_inv_type = 'w+'
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)

# ### Apply edit
from scripts.pti_styleclip import styleclip_edit
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI'
# Each name selects a pretrained StyleCLIP mapper. 'angry' was previously run twice, and
# 'talor_swift' was a typo for the 'taylor_swift' mapper.
STYLECLIP_EDIT_TYPES = ['afro', 'bobcut', 'bowlcut', 'mohawk', 'angry', 'depp',
                        'purple_hair', 'surprised', 'taylor_swift', 'trump']
for edit_type in STYLECLIP_EDIT_TYPES:
    styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=[edit_type], use_wandb=False)

# Show the e4e result next to the PTI result for a couple of the edits.
for shown_edit in ('afro', 'mohawk'):
    original_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/e4e/{image_name}_{shown_edit}.jpg'
    new_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/PTI/{image_name}_{shown_edit}.jpg'
    original_styleCLIP = Image.open(original_styleCLIP_path).resize((256, 256))
    new_styleCLIP = Image.open(new_styleCLIP_path).resize((256, 256))
    imgcat(display_alongside_source_image([original_styleCLIP, new_styleCLIP]))

# ## Other methods comparison
# ### Invert image using other methods
from scripts.latent_creators import e4e_latent_creator
from scripts.latent_creators import sg2_latent_creator
from scripts.latent_creators import sg2_plus_latent_creator

# The e4e creator stays disabled; presumably the w+ PTI run above already produced the e4e latents.
# e4e_creator = e4e_latent_creator.E4ELatentCreator()
# e4e_creator.create_latents()
# The SG2 / SG2+ latents are loaded right below, so they must be created first.
sg2_creator = sg2_latent_creator.SG2LatentCreator(projection_steps=600)
sg2_creator.create_latents()
sg2_plus_creator = sg2_plus_latent_creator.SG2PlusLatentCreator(projection_steps=1200)
sg2_plus_creator.create_latents()

# Load pivot index 0 of each comparison inversion, keyed by result keyword.
inversions = {}
for results_keyword in (paths_config.sg2_results_keyword,
                        paths_config.e4e_results_keyword,
                        paths_config.sg2_plus_results_keyword):
    inversions[results_keyword] = torch.load(f'{w_path_dir}/{results_keyword}/{image_name}/0.pt')
def get_image_from_w(w, G):
    """Synthesize latent ``w`` with ``G`` and return the first image as an HxWxC uint8 array.

    A latent with at most two dimensions gets a leading batch axis before synthesis.
    Pixel values are mapped with ``* 127.5 + 128`` and clamped to [0, 255].
    """
    batched_w = w if w.dim() > 2 else w.unsqueeze(0)
    raw = G.synthesis(batched_w, noise_mode='const', force_fp32=True)
    pixels = raw.permute(0, 2, 3, 1) * 127.5 + 128  # NCHW -> NHWC
    pixels = pixels.clamp(0, 255).to(torch.uint8)
    return pixels.detach().cpu().numpy()[0]
def plot_image_from_w(w, G):
    """Render latent ``w`` with ``G`` and show it at 256x256 in the terminal via imgcat.

    Previously the resized image was built but never shown.
    """
    img = get_image_from_w(w, G)
    resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
    imgcat(resized_image)
# Each baseline inversion goes through the original generator; the PTI pivot goes
# through the tuned one. The pivot latent is then saved to projected_w.npz.
for inv_type, latent in inversions.items():
    print(f'Displaying {inv_type} inversion')
    plot_image_from_w(latent, old_G)

print('Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)

projected_w_array = w_pivot.detach().cpu().numpy()
np.savez('projected_w.npz', w=projected_w_array)
