Skip to content

Instantly share code, notes, and snippets.

@rish-16
Last active December 22, 2021 06:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rish-16/4ec691c0f41340f4dc44eb8f51c91bbc to your computer and use it in GitHub Desktop.
Save rish-16/4ec691c0f41340f4dc44eb8f51c91bbc to your computer and use it in GitHub Desktop.
Visualise selected patches from an image for comparison / sanity checks
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify # pip install patchify
def viz(image, selector, pw=10, ph=10):
    '''
    Plot an image next to a copy whose selected patches are blanked out.

    The image is cut into a grid of non-overlapping ``ph x pw`` patches,
    ``selector`` picks some of them, and every picked patch is overwritten
    with ones before both images are shown side-by-side.
    NOTE: a value of 1 renders white for float images scaled to [0, 1];
    for 0-255 integer images it will look almost black.

    Params
    - image: a PyTorch Tensor of shape (H, W, C); H must be divisible by
      ``ph`` and W by ``pw`` (pad or crop beforehand otherwise)
    - selector: callable that receives a tensor of shape (N, C, ph, pw),
      one entry per patch in row-major grid order, and returns the indices
      (ints or a 1-D integer tensor) of the patches to blank out
    - pw, ph: patch width and height in pixels

    Returns
    None | just plots the image and processed image side-by-side

    Raises
    - ValueError: if ``image`` is not 3-D or its height/width is not a
      multiple of the patch size
    '''
    # Work on a CPU, graph-free view so .numpy() is always legal below.
    image = image.detach().cpu()
    if image.dim() != 3:
        raise ValueError(
            f"expected an (H, W, C) image tensor, got shape {tuple(image.shape)}"
        )
    height, width, channels = image.shape
    if height % ph or width % pw:
        raise ValueError(
            f"image of size {height}x{width} cannot be split into {ph}x{pw} "
            "patches; pad or crop it first"
        )
    n_rows, n_cols = height // ph, width // pw

    # (H, W, C) -> (rows, ph, cols, pw, C) -> (rows, cols, C, ph, pw) -> (N, C, ph, pw).
    # A real permute (not a bare reshape) keeps every patch's pixels and
    # channels intact; the clone guarantees the edits below never write
    # through to `image`.
    patches = (
        image.reshape(n_rows, ph, n_cols, pw, channels)
        .permute(0, 2, 4, 1, 3)
        .reshape(n_rows * n_cols, channels, ph, pw)
        .clone(memory_format=torch.contiguous_format)
    )

    # should return K indices; tweak if you have a custom way of getting them
    patch_ids = torch.as_tensor(selector(patches), dtype=torch.long).reshape(-1)
    patches[patch_ids] = 1  # blank every selected patch in one indexed write

    # Invert the layout change above to stitch the patches back into an image.
    reformed_img = (
        patches.reshape(n_rows, n_cols, channels, ph, pw)
        .permute(0, 3, 1, 4, 2)
        .reshape(height, width, channels)
    )

    fig = plt.figure()
    fig.add_subplot(121)
    plt.imshow(image.numpy())
    plt.title("Original")
    fig.add_subplot(122)
    plt.imshow(reformed_img.numpy())
    plt.title("Patches")
    plt.show()
def dummy_clip(patches, out_dim=512):
    '''
    Stand-in for a CLIP-style encoder: project each patch to an embedding.

    A fresh, randomly initialised ``nn.Linear`` is built on every call, so
    the output is only useful for shape/plumbing checks, not as real features.

    Params
    - patches: tensor of shape (N, ...); everything after the first
      dimension is flattened into one feature vector per patch
    - out_dim: embedding size (defaults to 512, the previous hard-coded value)

    Returns
    Tensor of shape (N, out_dim)

    Raises
    - ValueError: if ``patches`` has fewer than two dimensions
    '''
    if patches.dim() < 2:
        raise ValueError(
            f"expected patches of shape (N, ...), got {tuple(patches.shape)}"
        )
    features = patches.flatten(start_dim=1)  # (N, dim); no-op for 2-D input
    layer = nn.Linear(features.shape[1], out_dim)
    # nn.Linear maps over the leading batch dimension, so a single call
    # replaces the per-row loop and the uninitialised output buffer.
    return layer(features)
def dummy_selector(patches, k=10):
    '''
    Pick ``k`` distinct patch indices uniformly at random.

    Params
    - patches: tensor whose first dimension indexes the patches, shape (N, ...)
    - k: number of indices to return; effectively capped at N

    Returns
    1-D int64 tensor of min(k, N) distinct indices in [0, N)
    '''
    # Size the permutation from the input rather than a hard-coded 16*21 grid,
    # so the indices are always valid for the patches actually passed in.
    return torch.randperm(patches.shape[0])[:k]
# Example: viz(img_new, dummy_selector)  # img_new: an (H, W, C) image tensor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment