-
-
Save ProGamerGov/684c0953395e66db6ac5fe09d6723a5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from PIL import Image | |
import torchvision.transforms as transforms | |
def chol(t, Ct, Cs):
    """Map centred target colours ``t`` using Cholesky factors of the covariances.

    With lower-triangular factors ``Ct = Lt Lt^T`` and ``Cs = Ls Ls^T``, this
    returns ``Ls Lt^{-1} t``.  Its covariance ``ts ts^T / N`` equals
    ``Ls Lt^{-1} (t t^T / N) Lt^{-T} Ls^T``, i.e. ``Cs`` whenever
    ``Ct == t t^T / N``.

    Args:
        t: centred target colours, one row per channel, shape (C, N).
        Ct: target covariance matrix, shape (C, C), positive-definite.
        Cs: source covariance matrix, shape (C, C), positive-definite.

    Returns:
        Tensor of shape (C, N): the transformed target colours.

    Raises:
        RuntimeError (``torch.linalg.LinAlgError`` in recent PyTorch) if
        ``Ct`` or ``Cs`` is not positive-definite.
    """
    chol_t = torch.linalg.cholesky(Ct)
    chol_s = torch.linalg.cholesky(Cs)
    # Solve Lt @ x = t instead of forming Lt^{-1} explicitly: same result,
    # better conditioned, and no dense inverse of a triangular matrix.
    whitened = torch.linalg.solve_triangular(chol_t, t, upper=False)
    return torch.mm(chol_s, whitened)
def sym(t, Ct, Cs):
    """Map centred target colours ``t`` with the symmetric covariance transform.

    Computes ``A = Ct^{-1/2} (Ct^{1/2} Cs Ct^{1/2})^{1/2} Ct^{-1/2}`` and
    returns ``A @ t``.  ``A`` is symmetric and satisfies ``A Ct A = Cs``.

    Args:
        t: centred target colours, one row per channel, shape (C, N).
        Ct: target covariance matrix, shape (C, C).
        Cs: source covariance matrix, shape (C, C).

    Returns:
        Tensor of shape (C, N): the transformed target colours.
    """
    Qt = pca(t, Ct)  # Ct^{1/2}
    Qt_Cs_Qt = torch.mm(torch.mm(Qt, Cs), Qt)
    eva, eve = torch.linalg.eigh(Qt_Cs_Qt, UPLO='U')
    # Round-off can push eigenvalues of this PSD matrix slightly below zero;
    # clamp them so sqrt yields 0 rather than NaN.
    sqrt_eva = torch.diag(eva.clamp(min=0).sqrt())
    QtCsQt_sqrt = torch.mm(torch.mm(eve, sqrt_eva), eve.T)
    Qt_inv = torch.inverse(Qt)  # needed on both sides; compute it once
    return torch.mm(torch.mm(torch.mm(Qt_inv, QtCsQt_sqrt), Qt_inv), t)
def pca(t, Ct):
    """Return the symmetric positive-semidefinite square root of ``Ct``.

    Computes ``V diag(sqrt(max(w, 0))) V^T``, where ``w, V`` are the
    eigenvalues/eigenvectors of ``Ct``, so the result ``Q`` satisfies
    ``Q @ Q == Ct`` for a positive-semidefinite ``Ct``.

    Args:
        t: unused; kept so existing callers (matchHistogram, sym) need not change.
        Ct: symmetric covariance matrix, shape (C, C).  Only its upper
            triangle is read.

    Returns:
        Tensor of shape (C, C).
    """
    eva_t, eve_t = torch.linalg.eigh(Ct, UPLO='U')
    # Negative eigenvalues (round-off, or a non-PSD input) are clamped to 0 so
    # the square root is 0 instead of NaN.
    Et = torch.diag(eva_t.clamp(min=0).sqrt())
    return torch.mm(torch.mm(eve_t, Et), eve_t.T)
def getHistogram(tensor, eps):
    """Return the per-channel mean, centred colours and covariance of an image.

    Channels are read from the last dimension of ``tensor``; the first two
    dimensions are treated as pixel positions (matchHistogram passes a
    (W, H, C) tensor).

    Args:
        tensor: 3-D image tensor with channels last.
        eps: ridge value added to the covariance diagonal to keep it invertible.

    Returns:
        mu_h: per-channel mean, shape (C,).
        h: centred colours, one row per channel, shape (C, N), where N is the
            product of the first two dimensions.
        Ch: ``h @ h.T / N + eps * I``, shape (C, C).
    """
    mu_h = tensor.mean(dim=(0, 1))
    h = (tensor - mu_h).permute(2, 0, 1).reshape(tensor.size(2), -1)
    n_pixels = h.shape[1]
    # Build the ridge on h's dtype/device so GPU and half-precision inputs work.
    ridge = eps * torch.eye(h.shape[0], dtype=h.dtype, device=h.device)
    Ch = torch.mm(h, h.T) / n_pixels + ridge
    return mu_h, h, Ch
def matchHistogram(target_tensor, source_tensor, eps=1e-5, mode='pca'):
    """Recolour ``target_tensor`` so its colour mean/covariance match ``source_tensor``.

    The target's colours are centred, linearly mapped so that their covariance
    matches the source's, and then shifted by the source's mean colour.

    Args:
        target_tensor: image to recolour, (C, H, W) or (1, C, H, W).
        source_tensor: image whose colour statistics are copied, same layout.
        eps: ridge added to the diagonal of both covariance matrices.
        mode: linear colour map to use:
            'pca'  -> ``Cs^{1/2} Ct^{-1/2}`` (default),
            'sym'  -> symmetric map computed by ``sym``,
            'chol' -> Cholesky-factor map computed by ``chol``.

    Returns:
        The recoloured target with the batch dimension removed, (C, H, W).

    Raises:
        ValueError: if ``mode`` is not 'pca', 'sym' or 'chol'.
    """
    if mode not in ('pca', 'sym', 'chol'):
        raise ValueError(f"mode must be 'pca', 'sym' or 'chol', got {mode!r}")
    if target_tensor.dim() == 4:
        target_tensor = target_tensor.squeeze(0)
    if source_tensor.dim() == 4:
        source_tensor = source_tensor.squeeze(0)
    # getHistogram expects channels last: (C, H, W) -> (W, H, C).
    target_tensor = target_tensor.permute(2, 1, 0)
    source_tensor = source_tensor.permute(2, 1, 0)
    # Only the source mean is needed: the target is centred, remapped, then
    # shifted onto the source mean.
    _, t, Ct = getHistogram(target_tensor, eps)
    mu_s, s, Cs = getHistogram(source_tensor, eps)
    if mode == 'pca':
        Qt = pca(t, Ct)
        Qs = pca(s, Cs)
        ts = torch.mm(torch.mm(Qs, torch.inverse(Qt)), t)
    elif mode == 'sym':
        ts = sym(t, Ct, Cs)
    else:  # mode == 'chol'
        ts = chol(t, Ct, Cs)
    # ts is (C, W*H) in the pixel order produced by getHistogram.
    width, height, channels = target_tensor.shape
    matched_tensor = ts.reshape(channels, width, height).permute(1, 2, 0)  # (W, H, C)
    matched_tensor = matched_tensor + mu_s
    return matched_tensor.permute(2, 1, 0)  # back to (C, H, W)
# Preprocess an image before passing it to a model.
# We rescale from [0, 1] to the 0-255 pixel range (the code multiplies by 256,
# mirroring the division by 256 in deprocess), convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    """Load an image file and convert it into a BGR, mean-subtracted tensor.

    Args:
        image_name: path of the image file to open.
        image_size: either a tuple passed straight to ``transforms.Resize``
            (``(height, width)``), or a number giving the target length of the
            image's longest side; the aspect ratio is kept and each side is
            truncated to an int.

    Returns:
        Tensor of shape (1, 3, H, W): pixel values multiplied by 256, channels
        in BGR order, with the mean pixel [103.939, 116.779, 123.68]
        subtracted.
    """
    # Context manager closes the file handle deterministically; convert()
    # returns a new, fully loaded copy that outlives the handle.
    with Image.open(image_name) as img:
        image = img.convert('RGB')
    if not isinstance(image_size, tuple):
        scale = float(image_size) / max(image.size)
        image_size = tuple(int(scale * x) for x in (image.height, image.width))
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    normalize = transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    tensor = loader(image) * 256
    tensor = tensor[torch.LongTensor([2, 1, 0])]  # RGB -> BGR
    return normalize(tensor).unsqueeze(0)
# Undo the preprocessing above: add the mean pixel back, convert BGR to RGB, and rescale to [0, 1].
def deprocess(output_tensor):
    """Turn a BGR, mean-subtracted, x256 image tensor back into a PIL image.

    Args:
        output_tensor: image tensor, optionally with a leading batch
            dimension of size 1 (removed with ``squeeze(0)``).

    Returns:
        A PIL image built from the RGB tensor after dividing by 256 and
        clamping to [0, 1].
    """
    # A negative mean with unit std makes Normalize *add* the mean pixel back.
    add_mean = transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])
    bgr_to_rgb = torch.LongTensor([2, 1, 0])
    restored = add_mean(output_tensor.squeeze(0).cpu())
    restored = restored[bgr_to_rgb] / 256
    restored.clamp_(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(restored.cpu())
if __name__ == '__main__':
    # Recolour 'out1.png' (longest side resized to 512) so its colour
    # statistics match 'portrait_1.jpg', resized to the same height/width.
    # Guarded so that importing this module does not read or write files.
    img1 = preprocess('out1.png', 512).squeeze(0)
    img2 = preprocess('portrait_1.jpg', (img1.size(1), img1.size(2))).squeeze(0)
    new_img = matchHistogram(img1, img2, 1e-5)
    matched_img = deprocess(new_img)
    matched_img.save('matched_img.png')
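# Porting notes (NumPy -> PyTorch):
# np.transpose was replaced with Tensor.permute.
# np.dot was replaced with torch.mm.
# np.linalg.eigh corresponds to torch.symeig in older PyTorch and to torch.linalg.eigh in current PyTorch (torch.symeig has been removed).
# np.linalg.inv was replaced with torch.inverse.
# np.diag (applied to a 1-D array) was replaced with torch.diagflat.
# Tensor.permute generalizes torch.transpose: it reorders any number of dimensions at once, whereas torch.transpose swaps exactly two.
# torch.mm is used instead of torch.dot because torch.dot only accepts 1-D tensors, while np.dot also multiplies 2-D matrices.
# FIX: eigendecomposition gradients can differ between frameworks; see https://stackoverflow.com/questions/58856160/why-do-tensorflow-and-pytorch-gradients-of-the-eigenvalue-decomposition-differ-f
# torch.inverse computes the same matrix inverse as np.linalg.inv (in current PyTorch it is an alias of torch.linalg.inv).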
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment