import copy
import os
import simplejson as json
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torchvision
import torch.nn.functional as F
import dnnlib
import legacy
import clip
import hashlib
def approach(
    G,
    *,
    num_steps = 100,
    w_avg_samples = 10000,
    initial_learning_rate = 0.02,
    initial_noise_factor = 0.02,
    noise_floor = 0.02,
    psi = 0.8,
    noise_ramp_length = 1.0, # was 0.75
    regularize_noise_weight = 10000, # was 1e5
    seed = 69097,
    noise_opt = True,
    ws = None,
    text = 'a computer generated image',
    device: torch.device
):
    '''
    local_args = dict(locals())
    params = []
    for x in local_args:
        if x != 'G' and x != 'device':
            print(x, ':', local_args[x])
            params.append({x: local_args[x]})
    print(json.dumps(params))
    '''
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)
    lr = initial_learning_rate

    '''
    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    #w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None, truncation_psi=0.8) # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    '''
    # Derive W from the seed (or use the provided ws).
    if ws is None:
        print('Generating w for seed %i' % seed)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        w_samples = G.mapping(z, None, truncation_psi=psi)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    else:
        w_samples = torch.tensor(ws, device=device)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    #w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    w_std = 2  # ~9.9 for the portraits network; should be computed from w samples rather than hard-coded
    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True)  # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    if noise_opt:
        optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w + noise')
    else:
        optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w')

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    # Load the perceptor
    print('Loading perceptor for text:', text)
    perceptor, preprocess = clip.load('ViT-B/32', jit=True)
    perceptor = perceptor.eval()
    tx = clip.tokenize(text)
    whispers = perceptor.encode_text(tx.cuda()).detach().clone()
    # Descend
    for step in range(num_steps):
        # noise schedule
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        # floor
        if w_noise_scale < noise_floor:
            w_noise_scale = noise_floor

        # lr schedule is disabled
        '''
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        '''
        '''
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        '''

        # do G.synthesis
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        #save1
        '''
        synth_images_save = (synth_images + 1) * (255/2)
        synth_images_save = synth_images_save.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        PIL.Image.fromarray(synth_images_save, 'RGB').save('project/test1.png')
        '''
        nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        into = synth_images
        into = nom(into)  # normalization copied from the CLIP preprocess; note the interpolate below reads synth_images directly, so it is effectively bypassed (didn't seem to matter either way)
        # scale to CLIP input size
        into = torch.nn.functional.interpolate(synth_images, (224, 224), mode='bilinear', align_corners=True)
        # CLIP expects [1, 3, 224, 224], so we should be fine
        glimmers = perceptor.encode_image(into)
        away = -30 * torch.cosine_similarity(whispers, glimmers, dim=-1).mean()  # Dunno why 30 works
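        # Note: Adam is largely insensitive to a constant rescaling of the loss, so the -30 mainly
        # sets how strongly the CLIP term weighs against the noise regularization added below.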
        # noise reg, from og projector
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        if noise_opt:
            loss = away + reg_loss * regularize_noise_weight
        else:
            loss = away

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        print(f'step {step+1:>4d}/{num_steps}: loss {float(loss):<5.2f} ', 'lr', lr, f'noise scale: {float(w_noise_scale):<5.6f}', f'away: {float(away / (-30)):<5.6f}')
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])
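
# A minimal sketch of calling approach() directly from Python, mirroring what run_approach()
# does below (the pickle filename here is illustrative, not a file shipped with this gist):
#
#   device = torch.device('cuda')
#   with dnnlib.util.open_url('network-snapshot-ffhq.pkl') as fp:
#       G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)
#   w_steps = approach(G, num_steps=100, text='a computer generated image', device=device)
#   final_w = w_steps[-1].unsqueeze(0)               # [1, num_ws, C]
#   img = G.synthesis(final_w, noise_mode='const')   # float image in [-1, 1], NCHW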
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Initial image seed', type=int, default=232322, show_default=True)
@click.option('--w', help='Do not use seed but load w from a file', type=str, metavar='FILE')
@click.option('--lr', help='Adam learning rate', type=float, required=False, default=0.02)
@click.option('--psi', help='Truncation psi for initial image', type=float, required=False, default=0.81)
@click.option('--inf', help='Initial noise factor', type=float, required=False, default=0.02)
@click.option('--nf', help='Noise floor', type=float, required=False, default=0.02)
@click.option('--noise-opt', help='Optimize noise vars as well as w', type=bool, required=False, default=True)
@click.option('--text', help='Text prompt', required=False, default='A computer-generated image')
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--save-ws', help='Save intermediate ws', type=bool, default=False, show_default=True)
def run_approach(
    network_pkl: str,
    outdir: str,
    save_video: bool,
    save_ws: bool,
    seed: int,
    num_steps: int,
    text: str,
    lr: float,
    inf: float,
    nf: float,
    w: str,
    psi: float,
    noise_opt: bool
):
"""Descend on StyleGAN2 w vector value using CLIP, tuning an image with given text prompt.
Example:
\b
python3 approach.py --network network-snapshot-ffhq.pkl --outdir project --num-steps 100 \\
--text 'an image of a girl with a face resembling Paul Krugman' --psi 0.8 --seed 12345
"""
    #seed = 1
    np.random.seed(1)
    torch.manual_seed(1)

    local_args = dict(locals())
    params = []
    for x in local_args:
        #if x != 'G' and x != 'device':
        #print(x,':',local_args[x])
        params.append({x: local_args[x]})
    #print(json.dumps(params))
    hashname = str(hashlib.sha1(json.dumps(params).encode('utf-16be')).hexdigest())
    print('run hash', hashname)

    ws = None
    if w is not None:
        print('loading w from file', w, 'ignoring seed and psi')
        ws = np.load(w)['w']
    # take off
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # approach
    projected_w_steps = approach(
        G,
        num_steps=num_steps,
        device=device,
        initial_learning_rate=lr,
        psi=psi,
        seed=seed,
        initial_noise_factor=inf,
        noise_floor=nf,
        text=text,
        ws=ws,
        noise_opt=noise_opt
    )
    # save video
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/out-{hashname}.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/out-{hashname}.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([synth_image], axis=1))
        video.close()

    # save ws
    if save_ws:
        print('Saving optimization progress ws')
        step = 0
        for projected_w in projected_w_steps:
            np.savez(f'{outdir}/w-{hashname}-{step}.npz', w=projected_w.unsqueeze(0).cpu().numpy())
            step += 1
    # save the result and the final w
    print('Saving finals')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/out-{hashname}.png')
    np.savez(f'{outdir}/w-{hashname}-final.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # save params
    with open(f'{outdir}/params-{hashname}.txt', 'w') as outfile:
        json.dump(params, outfile)


if __name__ == "__main__":
    run_approach()
#!/usr/bin/env python
# coding: utf-8
import os
os.chdir('/home/jp/Documents/gitWorkspace')
CODE_DIR = 'PTI'
os.chdir(f'./{CODE_DIR}')
import sys
import copy
import pickle
import numpy as np
from PIL import Image
import torch
from configs import paths_config, hyperparameters, global_config
from utils.align_data import pre_process_images
from scripts.run_pti import run_PTI
# from IPython.display import display
from imgcat import imgcat
import matplotlib.pyplot as plt
from scripts.latent_editor_wrapper import LatentEditorWrapper
current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
os.makedirs(save_path, exist_ok=True)
image_dir_name = 'images'
## If True, download the desired image from a given URL. If False, it is assumed you have placed
## your own image in the 'images_original' dir.
use_image_online = True
image_name = '1'  # put the .jpg file in the images_original folder (not image_original)
use_multi_id_training = False
global_config.device = 'cuda'
paths_config.e4e = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/e4e_ffhq_encode.pt'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed'
# paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/AlfredENeuman24_ADA-torch.pkl'
paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI/'
paths_config.style_clip_pretrained_mappers = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False
os.makedirs(f'./{image_dir_name}_original', exist_ok=True)
os.makedirs(f'./{image_dir_name}_processed', exist_ok=True)
os.chdir(f'./{image_dir_name}_original')
original_image = Image.open(f'{image_name}.jpg')
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
pre_process_images(f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_original')
aligned_image = Image.open(f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed/{image_name}.jpeg')
aligned_image.resize((512,512))
# ## Step 5 - Invert images using PTI
# In order to run PTI and use StyleGAN2-ada, the cwd should be the parent of 'torch_utils' and 'dnnlib'.
#
# In case use_multi_id_training is set to True and many images are inverted simultaneously,
# activating the regularization to keep the *W* space intact is recommended.
#
# If the regularization is activated, increase the number of PTI steps from 350 to at least 450
# via hyperparameters.max_pti_steps; a commented sketch of these settings follows.
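# A commented sketch of those multi-image settings (assumption: these attribute names follow the
# PTI repo's configs/hyperparameters.py; they are not used in this single-image run):
#
#   use_multi_id_training = True
#   hyperparameters.use_locality_regularization = True   # keep the W space intact
#   hyperparameters.max_pti_steps = 450                   # raised from the default 350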
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
# ## Visualize results
def display_alongside_source_image(images):
    res = np.concatenate([np.array(image) for image in images], axis=1)
    return Image.fromarray(res)
def load_generators(model_id, image_name):
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
        old_G = d['G_ema'].cuda()  ## tensor
        old_D = d['D'].eval().requires_grad_(False).cpu()
    with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
        new_G = torch.load(f_new).cuda()
    return old_G, new_G
def export_updated_pickle(new_G, model_id):
    print("Exporting large updated pickle based off new generator and ffhq.pkl")
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
        old_G = d['G_ema'].cuda()  ## tensor
        old_D = d['D'].eval().requires_grad_(False).cpu()

    tmp = {}
    tmp['G_ema'] = old_G.eval().requires_grad_(False).cpu()  # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['G'] = new_G.eval().requires_grad_(False).cpu()      # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['D'] = old_D
    tmp['training_set_kwargs'] = None
    tmp['augment_pipe'] = None

    with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
        pickle.dump(tmp, f)
generator_type = paths_config.multi_id_model_type if use_multi_id_training else image_name
old_G, new_G = load_generators(model_id, generator_type)
def plot_syn_images(syn_images):
    for img in syn_images:
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
        plt.axis('off')
        resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
        imgcat(resized_image)
        del img
        del resized_image
        torch.cuda.empty_cache()
# If multi_id_training was used for several images, you can change the w_pivot index (currently 0)
# and rerun the visualization code to use the same generator on different latent codes; see the
# commented sketch after the load below.
w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
w_pivot = torch.load(f'{embedding_dir}/0.pt')
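# A commented sketch of loading a different pivot from a multi-id run (assumption: pivots are
# stored as <index>.pt alongside the 0.pt loaded above):
#   w_pivot = torch.load(f'{embedding_dir}/1.pt')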
old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)
new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)
print("new_G:",new_G)
print('Upper image is the inversion before Pivotal Tuning and the lower image is the product of pivotal tuning')
plot_syn_images([old_image, new_image])
# ## InterfaceGAN edits
latent_editor = LatentEditorWrapper()
latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, [-2, 2])
# To get different edits (such as a younger face, or making the face smile more), change the factors
# passed to "get_single_interface_gan_edits". Currently the factors are [-2, 2]; you can pass, for
# example, range(-3, 3), as sketched below.
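# A commented sketch with a wider factor range (get_single_interface_gan_edits is the same
# LatentEditorWrapper method used above):
#   latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, list(range(-3, 3)))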
for direction, factor_and_edit in latents_after_edit.items():
    print(f'Showing {direction} change')
    for latent in factor_and_edit.values():
        old_image = old_G.synthesis(latent, noise_mode='const', force_fp32=True)
        new_image = new_G.synthesis(latent, noise_mode='const', force_fp32=True)
        plot_syn_images([old_image, new_image])
# ## StyleCLIP editing
# ### Download pretrained models
# mappers_base_dir = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
# More pretrained mappers can be found at: "https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py"
# Download Afro mapper
# downloader.download_file("1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ", os.path.join(mappers_base_dir, 'afro.pt'))
# Download Mohawk mapper
# downloader.download_file("1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09", os.path.join(mappers_base_dir, 'mohawk.pt'))
# Download the e4e encoder, used for the first inversion step instead of the W inversion.
# downloader.download_file("1cUv_reLE6k3604or78EranS7XzuVMWeO", os.path.join(mappers_base_dir, 'e4e_ffhq_encode.pt'))
# ### Use PTI with the e4e backbone for StyleCLIP
# Changing first_inv_type to 'w+' makes PTI use the e4e encoder instead of the W inversion in the first step.
hyperparameters.first_inv_type = 'w+'
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)
# ### Apply edit
from scripts.pti_styleclip import styleclip_edit
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI'
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['afro'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['bobcut'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['bowlcut'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['mohawk'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['angry'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['depp'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['purple_hair'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['surprised'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['talor_swift'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types = ['trump'], use_wandb=False)
original_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/e4e/{image_name}_afro.jpg'
new_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/PTI/{image_name}_afro.jpg'
original_styleCLIP = Image.open(original_styleCLIP_path).resize((256,256))
new_styleCLIP = Image.open(new_styleCLIP_path).resize((256,256))
display_alongside_source_image([original_styleCLIP, new_styleCLIP])
original_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/e4e/{image_name}_mohawk.jpg'
new_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/PTI/{image_name}_mohawk.jpg'
original_styleCLIP = Image.open(original_styleCLIP_path).resize((256,256))
new_styleCLIP = Image.open(new_styleCLIP_path).resize((256,256))
display_alongside_source_image([original_styleCLIP, new_styleCLIP])
# ## Other methods comparison
# ### Invert image using other methods
from scripts.latent_creators import e4e_latent_creator
from scripts.latent_creators import sg2_latent_creator
from scripts.latent_creators import sg2_plus_latent_creator
# e4e_latent_creator = e4e_latent_creator.E4ELatentCreator()
# e4e_latent_creator.create_latents()
print("INFO:sg2_latent_creator")
sg2_latent_creator = sg2_latent_creator.SG2LatentCreator(projection_steps = 600)
sg2_latent_creator.create_latents()
print("INFO:sg2_plus_latent_creator")
sg2_plus_latent_creator = sg2_plus_latent_creator.SG2PlusLatentCreator(projection_steps = 1200)
sg2_plus_latent_creator.create_latents()
inversions = {}
sg2_embedding_dir = f'{w_path_dir}/{paths_config.sg2_results_keyword}/{image_name}'
inversions[paths_config.sg2_results_keyword] = torch.load(f'{sg2_embedding_dir}/0.pt')
e4e_embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}'
inversions[paths_config.e4e_results_keyword] = torch.load(f'{e4e_embedding_dir}/0.pt')
sg2_plus_embedding_dir = f'{w_path_dir}/{paths_config.sg2_plus_results_keyword}/{image_name}'
inversions[paths_config.sg2_plus_results_keyword] = torch.load(f'{sg2_plus_embedding_dir}/0.pt')
def get_image_from_w(w, G):
    if len(w.size()) <= 2:
        w = w.unsqueeze(0)
    img = G.synthesis(w, noise_mode='const', force_fp32=True)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()
    return img[0]
def plot_image_from_w(w, G):
    img = get_image_from_w(w, G)
    plt.axis('off')
    resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
    imgcat(resized_image)
for inv_type, latent in inversions.items():
    print(f'Displaying {inv_type} inversion')
    plot_image_from_w(latent, old_G)
print(f'Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)
np.savez(f'projected_w.npz', w=w_pivot.cpu().detach().numpy())
export_updated_pickle(new_G,model_id)