Last active
April 16, 2020 06:26
-
-
Save mowolf/fa1b0c327ddd3add758bac5d990a4638 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kornia
import numpy as np
import torch
import torch.nn.functional as F
def get_M(image=None):
    """Build a mock affine transformation matrix and the pre-transform image size.

    Args:
        image: Optional tensor whose ``shape[2:]`` unpacks to ``(height, width)``.
            When omitted, the module-level ``fake`` tensor is used, which is
            what the original zero-argument call implicitly relied on.

    Returns:
        A tuple ``(w_original, h_original, M)`` where ``w_original`` and
        ``h_original`` are ints (equal to each other, both derived from the
        same edge length) and ``M`` is the matrix returned by
        ``kornia.get_rotation_matrix2d`` with its translation entries
        ``M[0, 0, 2]`` and ``M[0, 1, 2]`` shifted by the computed offset.

    Raises:
        NameError: If ``image`` is omitted and no module-level ``fake`` exists.
    """
    if image is None:
        # Backward-compatible fallback: zero-argument callers depend on the
        # module-level `fake` assigned by the script's main loop.
        image = fake
    _, w = image.shape[2:]

    desired_left_eye = [0.35, 0.35]
    # Rotation angle; passed through np.deg2rad below, so it is in degrees.
    angle = -torch.ones(1)

    # get original positions
    m1 = round(w * 0.5)
    # NOTE(review): the y coordinate of the center is derived from the width
    # and the x fraction (index 0). That is harmless while the image is square
    # and both fractions are equal; confirm whether height / index 1 was meant.
    m2 = round(desired_left_eye[0] * w)

    # fixed mock scale and transformation target width
    scale = torch.tensor([1.1])
    width = 220
    long_edge_size = width / abs(np.cos(np.deg2rad(angle)))

    # sizes of tensor before transformation (same value for both edges)
    w_original = int(scale * long_edge_size)
    h_original = int(scale * long_edge_size)

    # get offset
    tX = w_original * 0.5
    tY = h_original * desired_left_eye[1]

    # get rotation center as a (1, 2) tensor holding (m1, m2)
    center = torch.ones(1, 2)
    center[..., 0] = m1
    center[..., 1] = m2

    # compute the transformation matrix
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    # update transformation matrix with offset
    M[0, 0, 2] += (tX - m1)
    M[0, 1, 2] += (tY - m2)
    return w_original, h_original, M
if __name__ == '__main__':
    # Mock "real" image from which a crop is compared against the warped output.
    real_image = torch.ones(1, 3, 1000, 1000, requires_grad=False)

    # the input to the network is a part of the tensor which was transformed
    # e.g. input = kornia.warp_affine(tensor[:, :, 500:756, 500:756], M, dsize=(h_original, w_original))
    net_input = torch.ones(1, 3, 256, 256, requires_grad=True)

    # get mock weight
    weight = torch.nn.Parameter(torch.FloatTensor([1.234]), requires_grad=True)

    # The loss module is stateless, so build it once instead of per iteration.
    loss = torch.nn.MSELoss()

    # Accumulated optimizer step; the effective weight is `weight + update`.
    update = 0

    # this simulates 10 iterations of the network
    for _ in range(10):
        # set weight
        w1 = weight + update

        # simulated forward pass
        # fake is in my case a not perfect reconstruction of the input
        # NOTE: `fake` must stay a module-level name; get_M() reads it.
        fake = w1 * net_input

        # Trafo Params, with M is the inverse of M used to generate it
        w_original, h_original, M = get_M()

        # apply warp
        fake_warped = kornia.warp_affine(fake, M, dsize=(h_original, w_original))

        # here get the original part of the image that corresponds to the position of fake_warped
        # range is just an example
        fake_unaligned = F.interpolate(fake_warped, size=256)
        real_unaligned = F.interpolate(real_image[:, :, 500:756, 500:756], size=256)

        # Loss
        loss_val = loss(fake_unaligned, real_unaligned)
        print("Loss: " + str(loss_val))

        # Print the indices of non-zero gradient entries reaching each intermediate tensor.
        fake_unaligned.register_hook(lambda grad: print("fake_unaligned " + str(grad.nonzero())))
        fake_warped.register_hook(lambda grad: print("fake_warped " + str(grad.nonzero())))

        # backward pass
        loss_val.backward()

        # basically an optimizer for this one weight (plain gradient descent, step size 1)
        gradient, *_ = weight.grad
        print(f"Gradient of weight w.r.t to L: {gradient}")

        # Accumulate the step explicitly and clear the stored gradient.
        # Previously `weight.grad` was never reset, so the value printed above
        # was the running sum of all past gradients rather than this
        # iteration's gradient. The resulting weight trajectory is unchanged.
        update = update - gradient
        weight.grad = None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment