Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Last active December 14, 2019 15:47
Show Gist options
  • Save ProGamerGov/684c0953395e66db6ac5fe09d6723a5b to your computer and use it in GitHub Desktop.
Save ProGamerGov/684c0953395e66db6ac5fe09d6723a5b to your computer and use it in GitHub Desktop.
import torch
from PIL import Image
import torchvision.transforms as transforms
def chol(t, Ct, Cs):
    """Cholesky-based colour transfer.

    Maps the mean-centred target features `t` so that their covariance
    Ct is replaced by the source covariance Cs, via ts = Ls @ Lt^-1 @ t
    where Lt, Ls are the lower Cholesky factors of Ct, Cs.

    Args:
        t: (c, n) matrix of mean-centred target features.
        Ct: (c, c) target covariance matrix (positive definite).
        Cs: (c, c) source covariance matrix (positive definite).
    Returns:
        (c, n) transformed feature matrix.
    """
    # torch.cholesky is deprecated (removed in recent PyTorch);
    # torch.linalg.cholesky returns the same lower-triangular factor.
    chol_t = torch.linalg.cholesky(Ct)
    chol_s = torch.linalg.cholesky(Cs)
    ts = torch.mm(torch.mm(chol_s, torch.inverse(chol_t)), t)
    return ts
def sym(t, Ct, Cs):
    """Symmetric colour transfer.

    Transforms the mean-centred target features `t` with the symmetric
    (Monge-Kantorovich-style) mapping built from the matrix square roots
    of the target and source covariances.

    Args:
        t: (c, n) matrix of mean-centred target features.
        Ct: (c, c) target covariance matrix.
        Cs: (c, c) source covariance matrix.
    Returns:
        (c, n) transformed feature matrix.
    """
    Qt = pca(t, Ct)
    Qt_Cs_Qt = torch.mm(torch.mm(Qt, Cs), Qt)
    # torch.symeig was removed from PyTorch; torch.linalg.eigh with
    # UPLO='U' matches the old symeig(..., upper=True) behaviour.
    eva_QtCsQt, eve_QtCsQt = torch.linalg.eigh(Qt_Cs_Qt, UPLO='U')
    # Clamp tiny negative eigenvalues (numerical noise) to zero before the
    # sqrt instead of producing NaNs and scrubbing them afterwards — the
    # resulting diagonal is identical.
    Et_QtCsQt = torch.diagflat(eva_QtCsQt.clamp(min=0).sqrt())
    QtCsQt = torch.mm(torch.mm(eve_QtCsQt, Et_QtCsQt), eve_QtCsQt.T)
    ts = torch.mm(torch.mm(torch.mm(torch.inverse(Qt), QtCsQt), torch.inverse(Qt)), t)
    return ts
def pca(t, Ct):
    """Return the symmetric matrix square root of the covariance Ct.

    Computed via the eigendecomposition Ct = V diag(e) V^T, giving
    Qt = V diag(sqrt(e)) V^T.

    Args:
        t: unused; kept only so the signature stays compatible with
            existing callers that pass the feature matrix.
        Ct: (c, c) symmetric covariance matrix.
    Returns:
        (c, c) matrix square root of Ct.
    """
    # torch.symeig was removed from PyTorch; torch.linalg.eigh with
    # UPLO='U' matches the old symeig(..., upper=True) behaviour.
    eva_t, eve_t = torch.linalg.eigh(Ct, UPLO='U')
    # Clamp negative eigenvalues (numerical noise) to zero before the
    # sqrt instead of producing NaNs and scrubbing them afterwards.
    Et = torch.diagflat(eva_t.clamp(min=0).sqrt())
    Qt = torch.mm(torch.mm(eve_t, Et), eve_t.T)
    return Qt
def getHistogram(tensor, eps):
    """Compute per-channel statistics of an image tensor.

    Args:
        tensor: (w, h, c) image tensor.
        eps: diagonal regularization added to the covariance.
    Returns:
        mu_h: (c,) per-channel mean.
        h: (c, w*h) mean-centred feature matrix.
        Ch: (c, c) regularized covariance of h.
    """
    channels = tensor.size(2)
    mu_h = tensor.mean(0).mean(0)
    # Flatten the spatial dims and put channels first; the element order
    # matches flattening (w, h) row-major per channel.
    centered = (tensor - mu_h).reshape(-1, channels).t()
    cov = torch.mm(centered, centered.t()) / centered.shape[1]
    cov = cov + eps * torch.eye(channels)
    return mu_h, centered, cov
def matchHistogram(target_tensor, source_tensor, eps=1e-5, mode='pca'):
    """Transfer the colour statistics of source_tensor onto target_tensor.

    Args:
        target_tensor: (c, h, w) or (1, c, h, w) image tensor to recolour.
        source_tensor: (c, h, w) or (1, c, h, w) image supplying the target
            colour statistics.
        eps: diagonal regularization added to both covariance matrices.
        mode: matching transform — 'pca', 'sym', or 'chol'.
    Returns:
        Recoloured tensor with the same (c, h, w) layout as target_tensor.
    Raises:
        ValueError: if mode is not one of the supported transforms.
    """
    if target_tensor.dim() == 4:
        target_tensor = target_tensor.squeeze(0)
    if source_tensor.dim() == 4:
        source_tensor = source_tensor.squeeze(0)
    target_tensor = target_tensor.permute(2, 1, 0)  # getHistogram expects w,h,c
    source_tensor = source_tensor.permute(2, 1, 0)  # getHistogram expects w,h,c
    _, t, Ct = getHistogram(target_tensor, eps)
    mu_s, s, Cs = getHistogram(source_tensor, eps)
    # Bug fix: `mode` was an undefined free global in the original code
    # (NameError at runtime); it is now a parameter with a default, so
    # existing callers keep working unchanged.
    if mode == 'pca':
        Qt = pca(t, Ct)
        Qs = pca(s, Cs)
        ts = torch.mm(torch.mm(Qs, torch.inverse(Qt)), t)
    elif mode == 'sym':
        ts = sym(t, Ct, Cs)
    elif mode == 'chol':
        ts = chol(t, Ct, Cs)
    else:
        # Previously an unknown mode left `ts` unbound (UnboundLocalError).
        raise ValueError("mode must be 'pca', 'sym', or 'chol', got %r" % (mode,))
    # Restore the spatial layout, add back the source mean, and undo the
    # initial permute so the output matches the input layout.
    matched_tensor = ts.reshape(*target_tensor.permute(2, 0, 1).shape).permute(1, 2, 0)
    matched_tensor = matched_tensor + mu_s
    matched_tensor = matched_tensor.permute(2, 1, 0)
    return matched_tensor
# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 256] (the code multiplies by 256,
# following the original neural-style convention), convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    """Load an image and convert it to a model-ready tensor.

    Resizes (keeping aspect ratio when image_size is a single number),
    scales values into the 0-256 range, reorders RGB -> BGR, and subtracts
    the BGR mean pixel (Caffe/VGG convention).

    Args:
        image_name: path of the image file to load.
        image_size: target size — either a (h, w) tuple or a single number
            giving the length of the longer side.
    Returns:
        (1, c, h, w) float tensor.
    """
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        scale = float(image_size) / max(image.size)
        image_size = tuple(int(scale * side) for side in (image.height, image.width))
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])
    normalize = transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    return normalize(rgb2bgr(loader(image) * 256)).unsqueeze(0)
# Undo the above preprocessing.
def deprocess(output_tensor):
    """Invert preprocess(): turn a model tensor back into a PIL image.

    Adds the BGR mean pixel back, reorders BGR -> RGB, rescales into
    [0, 1], clamps, and converts to a PIL image.

    Args:
        output_tensor: (1, c, h, w) or (c, h, w) tensor as produced by
            preprocess().
    Returns:
        PIL.Image.Image in RGB.
    """
    add_mean = transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])
    bgr2rgb = transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])
    tensor = bgr2rgb(add_mean(output_tensor.squeeze(0).cpu())) / 256
    tensor = tensor.clamp(0, 1)
    return transforms.ToPILImage()(tensor.cpu())
# Demo: match the colour histogram of out1.png to portrait_1.jpg.
# NOTE(review): matchHistogram reads a global `mode` that is never defined
# in this file, so this script raises NameError as written — confirm the
# intended mode ('pca', 'sym', or 'chol') and define it before running.
img1 = preprocess('out1.png', 512).squeeze(0)
# Resize the source to the target's spatial size (size(1), size(2)) so the
# two feature matrices line up.
img2 = preprocess('portrait_1.jpg', (img1.size(1),img1.size(2))).squeeze(0)
new_img = matchHistogram(img1, img2, 1e-5)
matched_img = deprocess(new_img)
matched_img.save('matched_img.png')
# np.transpose was replaced with torch.permute
# np.dot was replaced with torch.mm
# np.linalg.eigh was replaced with torch.symeig
# np.linalg.inv was replaced with torch.inverse
# np.diag is torch.diagflat
# torch.permute is like torch.transpose only it supports multiple dimensions.
# torch.mm replaces np.dot for the 2-D matrix products here, because torch.dot only supports 1-D tensors
# np.linalg.eigh is like torch.symeig #FIX#: https://stackoverflow.com/questions/58856160/why-do-tensorflow-and-pytorch-gradients-of-the-eigenvalue-decomposition-differ-f
# np.linalg.inv is the same as torch.inverse?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment