import copy
import os
import simplejson as json
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torchvision
import torch.nn.functional as F
import dnnlib
import legacy
import clip
import hashlib

def approach(
    G,
    *,
    num_steps = 100,
    w_avg_samples = 10000,
    initial_learning_rate = 0.02,
    initial_noise_factor = 0.02,
    noise_floor = 0.02,
    psi = 0.8,
    noise_ramp_length = 1.0, # was 0.75
    regularize_noise_weight = 10000, # was 1e5
    seed = 69097,
    noise_opt = True,
    ws = None,
    text = 'a computer generated image',
    device: torch.device
):
    '''
    local_args = dict(locals())
    params = []
    for x in local_args:
        if x != 'G' and x != 'device':
            print(x, ':', local_args[x])
            params.append({x: local_args[x]})
    print(json.dumps(params))
    '''
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)
    lr = initial_learning_rate
    '''
    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    #w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None, truncation_psi=0.8)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    '''
    # derive W from seed
    if ws is None:
        print('Generating w for seed %i' % seed)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        w_samples = G.mapping(z, None, truncation_psi=psi)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    else:
        w_samples = torch.tensor(ws, device=device)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    #w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    w_std = 2  # ~9.9 for the portraits network; compute it properly (see the commented-out W stats above) if it matters
    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    if noise_opt:
        optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w + noise')
    else:
        optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w')

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    # Load the perceptor
    print('Loading perceptor for text:', text)
    perceptor, preprocess = clip.load('ViT-B/32', jit=True)
    perceptor = perceptor.eval()
    tx = clip.tokenize(text)
    whispers = perceptor.encode_text(tx.cuda()).detach().clone()

    # Descend
    for step in range(num_steps):
        # noise schedule
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        # floor
        if w_noise_scale < noise_floor:
            w_noise_scale = noise_floor
        # lr schedule is disabled
        '''
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        '''
        '''
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        '''
        # do G.synthesis
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')
        #save1
        '''
        synth_images_save = (synth_images + 1) * (255/2)
        synth_images_save = synth_images_save.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        PIL.Image.fromarray(synth_images_save, 'RGB').save('project/test1.png')
        '''
        # normalization constants copied from CLIP's preprocess (note: they assume inputs in [0, 1],
        # while G.synthesis outputs roughly [-1, 1], so this is only an approximation)
        nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        into = synth_images
        into = nom(into)
        # scale to CLIP input size (feed the normalized tensor, not the raw synth_images)
        into = torch.nn.functional.interpolate(into, (224, 224), mode='bilinear', align_corners=True)
        # CLIP expects [1, 3, 224, 224], so we should be fine
        glimmers = perceptor.encode_image(into)
        away = -30 * torch.cosine_similarity(whispers, glimmers, dim=-1).mean() # the factor 30 was chosen empirically
        # noise reg, from og projector
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        if noise_opt:
            loss = away + reg_loss * regularize_noise_weight
        else:
            loss = away

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        print(f'step {step+1:>4d}/{num_steps}: loss {float(loss):<5.2f} ', 'lr', lr, f'noise scale: {float(w_noise_scale):<5.6f}', f'away: {float(away / (-30)):<5.6f}')
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Initial image seed', type=int, default=232322, show_default=True)
@click.option('--w', help='Do not use seed but load w from a file', type=str, metavar='FILE')
@click.option('--lr', help='Adam learning rate', type=float, required=False, default=0.02)
@click.option('--psi', help='Truncation psi for initial image', type=float, required=False, default=0.81)
@click.option('--inf', help='Initial noise factor', type=float, required=False, default=0.02)
@click.option('--nf', help='Noise floor', type=float, required=False, default=0.02)
@click.option('--noise-opt', help='Optimize noise vars as well as w', type=bool, required=False, default=True)
@click.option('--text', help='Text prompt', required=False, default='A computer-generated image')
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--save-ws', help='Save intermediate ws', type=bool, default=False, show_default=True)
def run_approach(
    network_pkl: str,
    outdir: str,
    save_video: bool,
    save_ws: bool,
    seed: int,
    num_steps: int,
    text: str,
    lr: float,
    inf: float,
    nf: float,
    w: str,
    psi: float,
    noise_opt: bool
):
    """Descend on a StyleGAN2 w vector using CLIP, steering the image toward a given text prompt.

    Example:

    \b
    python3 approach.py --network network-snapshot-ffhq.pkl --outdir project --num-steps 100 \\
        --text 'an image of a girl with a face resembling Paul Krugman' --psi 0.8 --seed 12345
    """
    #seed = 1
    np.random.seed(1)
    torch.manual_seed(1)
    local_args = dict(locals())
    params = []
    for x in local_args:
        #if x != 'G' and x != 'device':
        #print(x, ':', local_args[x])
        params.append({x: local_args[x]})
    #print(json.dumps(params))
    hashname = str(hashlib.sha1((json.dumps(params)).encode('utf-16be')).hexdigest())
    print('run hash', hashname)

    ws = None
    if w is not None:
        print('loading w from file', w, 'ignoring seed and psi')
        ws = np.load(w)['w']

    # take off
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # approach
    projected_w_steps = approach(
        G,
        num_steps=num_steps,
        device=device,
        initial_learning_rate=lr,
        psi=psi,
        seed=seed,
        initial_noise_factor=inf,
        noise_floor=nf,
        text=text,
        ws=ws,
        noise_opt=noise_opt
    )

    # save video
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/out-{hashname}.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/out-{hashname}.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([synth_image], axis=1))
        video.close()

    # save ws
    if save_ws:
        print('Saving optimization progress ws')
        step = 0
        for projected_w in projected_w_steps:
            np.savez(f'{outdir}/w-{hashname}-{step}.npz', w=projected_w.unsqueeze(0).cpu().numpy())
            step += 1

    # save the result and the final w
    print('Saving finals')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/out-{hashname}.png')
    np.savez(f'{outdir}/w-{hashname}-final.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # save params
    with open(f'{outdir}/params-{hashname}.txt', 'w') as outfile:
        json.dump(params, outfile)

if __name__ == "__main__":
    run_approach()
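
# Usage sketch: the final w saved above can be loaded back with --w (which bypasses --seed/--psi)
# to keep refining toward a different prompt. Filenames below are placeholders; <hash> is the
# run hash printed by the previous run, and the prompt is only an example.
#
#   python3 approach.py --network network-snapshot-ffhq.pkl --outdir project \
#       --w project/w-<hash>-final.npz --text 'a computer generated image of a smiling face'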
#!/usr/bin/env python
# coding: utf-8
import os
os.chdir('/home/jp/Documents/gitWorkspace')
CODE_DIR = 'PTI'
os.chdir(f'./{CODE_DIR}')
import sys
import copy
import pickle
import numpy as np
from PIL import Image
import torch
from configs import paths_config, hyperparameters, global_config
from utils.align_data import pre_process_images
from scripts.run_pti import run_PTI
# from IPython.display import display
from imgcat import imgcat
import matplotlib.pyplot as plt
from scripts.latent_editor_wrapper import LatentEditorWrapper

current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
os.makedirs(save_path, exist_ok=True)
image_dir_name = 'images'

## If set to True, download the desired image from a given URL. If set to False, this assumes
## you have placed your own image in the 'images_original' dir.
use_image_online = True
image_name = '1' # put the .jpg file in the images_original folder (NOT image_original)
use_multi_id_training = False
global_config.device = 'cuda'
paths_config.e4e = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/e4e_ffhq_encode.pt'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed'
# paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/AlfredENeuman24_ADA-torch.pkl'
paths_config.stylegan2_ada_ffhq = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI/'
paths_config.style_clip_pretrained_mappers = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False

os.makedirs(f'./{image_dir_name}_original', exist_ok=True)
os.makedirs(f'./{image_dir_name}_processed', exist_ok=True)
os.chdir(f'./{image_dir_name}_original')
original_image = Image.open(f'{image_name}.jpg')
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
pre_process_images(f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_original')
aligned_image = Image.open(f'/home/jp/Documents/gitWorkspace/PTI/{image_dir_name}_processed/{image_name}.jpeg')
aligned_image.resize((512,512))
# ## Step 5 - Invert images using PTI
# In order to run PTI and use StyleGAN2-ada, the cwd should be the parent of 'torch_utils' and 'dnnlib'.
#
# If use_multi_id_training is set to True and many images are inverted simultaneously,
# activating the locality regularization is recommended to keep the *W* space intact.
#
# If the regularization is activated, increase the number of PTI steps from 350 to at least 450
# via hyperparameters.max_pti_steps (see the sketch right below).
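# A minimal sketch of those two settings, kept commented out since this run uses a single image;
# both attribute names come from PTI's configs.hyperparameters as used/mentioned above:
#
# hyperparameters.use_locality_regularization = True
# hyperparameters.max_pti_steps = 450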
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)

# ## Visualize results
def display_alongside_source_image(images):
    res = np.concatenate([np.array(image) for image in images], axis=1)
    return Image.fromarray(res)

def load_generators(model_id, image_name):
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
        old_G = d['G_ema'].cuda() ## tensor
        old_D = d['D'].eval().requires_grad_(False).cpu()
    with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
        new_G = torch.load(f_new).cuda()
    return old_G, new_G

def export_updated_pickle(new_G, model_id):
    print("Exporting large updated pickle based off new generator and ffhq.pkl")
    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
        d = pickle.load(f)
        old_G = d['G_ema'].cuda() ## tensor
        old_D = d['D'].eval().requires_grad_(False).cpu()
    tmp = {}
    tmp['G_ema'] = old_G.eval().requires_grad_(False).cpu() # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['G'] = new_G.eval().requires_grad_(False).cpu() # copy.deepcopy(new_G).eval().requires_grad_(False).cpu()
    tmp['D'] = old_D
    tmp['training_set_kwargs'] = None
    tmp['augment_pipe'] = None
    with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'wb') as f:
        pickle.dump(tmp, f)

generator_type = paths_config.multi_id_model_type if use_multi_id_training else image_name
old_G, new_G = load_generators(model_id, generator_type)

def plot_syn_images(syn_images):
    for img in syn_images:
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
        plt.axis('off')
        resized_image = Image.fromarray(img, mode='RGB').resize((256,256))
        imgcat(resized_image)
        del img
        del resized_image
        torch.cuda.empty_cache()
# If multi_id_training was used for several images, you can change the w_pivot that is loaded
# (currently index 0 for the current image) and run the visualization code again, reusing the
# same generator on different latent codes (see the commented example below).
w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
w_pivot = torch.load(f'{embedding_dir}/0.pt')
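# Sketch (hypothetical name): for another identity from a multi-ID run, load that image's
# pivot instead, mirroring the path pattern above:
# w_pivot = torch.load(f'{w_path_dir}/{paths_config.pti_results_keyword}/another_image/0.pt')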
old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
print("new_G:", new_G)
print('Upper image is the inversion before Pivotal Tuning and the lower image is the product of pivotal tuning')
plot_syn_images([old_image, new_image])
# ## InterfaceGAN edits
latent_editor = LatentEditorWrapper()
latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, [-2, 2])
# To get different edits (e.g. a younger face or a stronger smile), change the factors passed
# to "get_single_interface_gan_edits". The factors are currently [-2, 2]; you can pass, for
# example, range(-3, 3) (see the sketch below).
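# A sketch of the wider sweep suggested above (same wrapper call, just more factors); kept
# commented out so it does not override the [-2, 2] run already assigned:
# latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, range(-3, 3))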
for direction, factor_and_edit in latents_after_edit.items():
    print(f'Showing {direction} change')
    for latent in factor_and_edit.values():
        old_image = old_G.synthesis(latent, noise_mode='const', force_fp32=True)
        new_image = new_G.synthesis(latent, noise_mode='const', force_fp32=True)
        plot_syn_images([old_image, new_image])
# ## StyleCLIP editing
# ### Download pretrained models
# mappers_base_dir = '/home/jp/Documents/gitWorkspace/PTI/pretrained_models'
# More pretrained mappers can be found at: "https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py"
# Download Afro mapper
# downloader.download_file("1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ", os.path.join(mappers_base_dir, 'afro.pt'))
# Download Mohawk mapper
# downloader.download_file("1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09", os.path.join(mappers_base_dir, 'mohawk.pt'))
# Download e4e encoder, used for the first inversion step instead of the W inversion.
# downloader.download_file("1cUv_reLE6k3604or78EranS7XzuVMWeO", os.path.join(mappers_base_dir, 'e4e_ffhq_encode.pt'))
# ### Use PTI with the e4e backbone for StyleCLIP
# Changing first_inv_type to 'w+' makes PTI use the e4e encoder instead of W inversion in the first step.
hyperparameters.first_inv_type = 'w+'
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
model_id = run_PTI(use_wandb=False, use_multi_id_training=use_multi_id_training)

# ### Apply edit
from scripts.pti_styleclip import styleclip_edit
paths_config.checkpoints_dir = '/home/jp/Documents/gitWorkspace/PTI'
os.chdir('/home/jp/Documents/gitWorkspace/PTI')
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['afro'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['bobcut'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['bowlcut'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['mohawk'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['angry'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['depp'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['purple_hair'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['surprised'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['talor_swift'], use_wandb=False)
styleclip_edit(use_multi_id_G=use_multi_id_training, run_id=model_id, edit_types=['trump'], use_wandb=False)

original_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/e4e/{image_name}_afro.jpg'
new_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/PTI/{image_name}_afro.jpg'
original_styleCLIP = Image.open(original_styleCLIP_path).resize((256,256))
new_styleCLIP = Image.open(new_styleCLIP_path).resize((256,256))
display_alongside_source_image([original_styleCLIP, new_styleCLIP])

original_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/e4e/{image_name}_mohawk.jpg'
new_styleCLIP_path = f'/home/jp/Documents/gitWorkspace/PTI/StyleCLIP_results/{image_dir_name}/{image_name}/PTI/{image_name}_mohawk.jpg'
original_styleCLIP = Image.open(original_styleCLIP_path).resize((256,256))
new_styleCLIP = Image.open(new_styleCLIP_path).resize((256,256))
display_alongside_source_image([original_styleCLIP, new_styleCLIP])

# ## Other methods comparison
# ### Invert image using other methods
from scripts.latent_creators import e4e_latent_creator
from scripts.latent_creators import sg2_latent_creator
from scripts.latent_creators import sg2_plus_latent_creator
# e4e_latent_creator = e4e_latent_creator.E4ELatentCreator()
# e4e_latent_creator.create_latents()
print("INFO:sg2_latent_creator")
sg2_latent_creator = sg2_latent_creator.SG2LatentCreator(projection_steps=600)
sg2_latent_creator.create_latents()
print("INFO:sg2_plus_latent_creator")
sg2_plus_latent_creator = sg2_plus_latent_creator.SG2PlusLatentCreator(projection_steps=1200)
sg2_plus_latent_creator.create_latents()

inversions = {}
sg2_embedding_dir = f'{w_path_dir}/{paths_config.sg2_results_keyword}/{image_name}'
inversions[paths_config.sg2_results_keyword] = torch.load(f'{sg2_embedding_dir}/0.pt')
e4e_embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}'
inversions[paths_config.e4e_results_keyword] = torch.load(f'{e4e_embedding_dir}/0.pt')
sg2_plus_embedding_dir = f'{w_path_dir}/{paths_config.sg2_plus_results_keyword}/{image_name}'
inversions[paths_config.sg2_plus_results_keyword] = torch.load(f'{sg2_plus_embedding_dir}/0.pt')

def get_image_from_w(w, G):
    if len(w.size()) <= 2:
        w = w.unsqueeze(0)
    img = G.synthesis(w, noise_mode='const', force_fp32=True)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()
    return img[0]

def plot_image_from_w(w, G):
    img = get_image_from_w(w, G)
    plt.axis('off')
    resized_image = Image.fromarray(img, mode='RGB').resize((256,256))
    imgcat(resized_image)

for inv_type, latent in inversions.items():
    print(f'Displaying {inv_type} inversion')
    plot_image_from_w(latent, old_G)

print('Displaying PTI inversion')
plot_image_from_w(w_pivot, new_G)
np.savez('projected_w.npz', w=w_pivot.cpu().detach().numpy())
export_updated_pickle(new_G, model_id)
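
# What the last two lines leave on disk: projected_w.npz (the pivot latent) and
# model_<model_id>.pkl (a StyleGAN2-ada-style pickle whose 'G' entry is the tuned generator).
# A minimal sketch of loading them back to regenerate the inversion, assuming the cwd still
# has torch_utils/dnnlib importable as required above:
#
# import pickle, numpy as np, torch
# with open(f'{paths_config.checkpoints_dir}/model_{model_id}.pkl', 'rb') as f:
#     G_tuned = pickle.load(f)['G'].cuda()
# w = torch.from_numpy(np.load('projected_w.npz')['w']).cuda()
# img = G_tuned.synthesis(w, noise_mode='const', force_fp32=True)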