Created
April 13, 2020 16:13
-
-
Save mowolf/858e01906402c1c8a64b33d488ae975f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# this kills the gradient of fake_B_unaligned
# fake_B_unaligned should have the gradient of this alignment operation
# real_B_unaligned should have no gradient
def align_fake(self, margin=70):
    """Warp ``self.fake_B`` back into the unaligned frame of the full real image.

    Builds an affine matrix from ``self.alignment_params`` (keys
    ``"desiredLeftEye"``, ``"eyesCenter"``, ``"angle"``, ``"scale"`` and
    ``"shape"``), warps a clone of ``self.fake_B`` with it, clips the warped
    patch to the bounds of ``self.real_B_unaligned_full`` and stores the
    results on the instance.

    The code indexes batch element 0 and channels 0..2 explicitly, so it
    assumes a batch of one image with at least three channels.

    Args:
        margin: Requested number of context pixels around the inserted patch.
            It is reduced to the largest value available on all four sides
            and never goes below 0.

    Side effects:
        Sets ``self.fake_B_unaligned`` and ``self.real_B_unaligned``, both
        resized to 256 with ``F.interpolate``. Returns ``None``.
    """
    # get params
    desiredLeftEye = [float(self.alignment_params["desiredLeftEye"][0]),
                      float(self.alignment_params["desiredLeftEye"][1])]
    rotation_point = self.alignment_params["eyesCenter"]
    angle = -self.alignment_params["angle"]
    w = self.fake_B.shape[3]

    # get original positions
    m1 = round(w * 0.5)
    # NOTE(review): the y coordinate is derived from desiredLeftEye[0] and the
    # width, while tY below uses desiredLeftEye[1] -- confirm this is intended.
    m2 = round(desiredLeftEye[0] * w)

    # define the scale factor
    scale = 1 / self.alignment_params["scale"]
    width = int(self.alignment_params["shape"][0])
    long_edge_size = width / abs(np.cos(np.deg2rad(self.alignment_params["angle"])))
    w_original = int(scale * long_edge_size)
    h_original = int(scale * long_edge_size)

    # get offset
    tX = w_original * 0.5
    tY = h_original * desiredLeftEye[1]

    # get rotation center
    center = torch.ones(1, 2)
    center[..., 0] = m1
    center[..., 1] = m2

    # compute the transformation matrix and shift it to the target offset
    M = tgm.get_rotation_matrix2d(center, angle, scale).to(self.device)
    M[0, 0, 2] += (tX - m1)
    M[0, 1, 2] += (tY - m2)

    # get insertion point
    x_start = int(rotation_point[0] - (0.5 * w_original))
    # NOTE(review): uses desiredLeftEye[0] for the vertical offset although tY
    # uses desiredLeftEye[1] -- confirm this is intended.
    y_start = int(rotation_point[1] - (desiredLeftEye[0] * h_original))
    h_tensor, w_tensor = self.real_B_unaligned_full.shape[2:]

    # Now apply the transformation to original image
    # clone fake
    fake_B_clone = self.fake_B.clone().requires_grad_()
    # apply warp
    fake_B_warped = tgm.warp_affine(fake_B_clone, M, dsize=(h_original, w_original))
    # clone warped
    self.fake_B_unaligned = fake_B_warped.clone().requires_grad_()

    # make sure warping does not exceed real_B_unaligned_full dimensions
    if y_start < 0:
        fake_B_warped = fake_B_warped[:, :, abs(y_start):h_original, :]
        h_original += y_start
        y_start = 0
    if x_start < 0:
        fake_B_warped = fake_B_warped[:, :, :, abs(x_start):w_original]
        w_original += x_start
        x_start = 0
    if y_start + h_original > h_tensor:
        h_original -= (y_start + h_original - h_tensor)
        fake_B_warped = fake_B_warped[:, :, 0:h_original, :]
    if x_start + w_original > w_tensor:
        w_original -= (x_start + w_original - w_tensor)
        fake_B_warped = fake_B_warped[:, :, :, 0:w_original]

    # create mask that is true where fake_B_warped is 0 in the first three
    # channels. This is the background that is not filled with image after
    # the transformation.
    mask = ((fake_B_warped[0][0] == 0) & (fake_B_warped[0][1] == 0) & (fake_B_warped[0][2] == 0))

    # where mask is True take the pixels of self.real_B_unaligned_full,
    # everywhere else keep the warped fake
    fake_B_filled = torch.where(mask,
                                self.real_B_unaligned_full[:, :, y_start:y_start + h_original,
                                                           x_start:x_start + w_original],
                                fake_B_warped)

    # reinsert into tensor
    self.fake_B_unaligned = self.real_B_unaligned_full.clone().requires_grad_()
    mask = torch.zeros_like(self.fake_B_unaligned, dtype=torch.bool)
    mask[0, :, y_start:y_start + h_original, x_start:x_start + w_original] = True
    self.fake_B_unaligned = self.fake_B_unaligned.masked_scatter(mask, fake_B_filled)

    # cutout tensor: shrink the margin to what is available on every side
    margin = max(
        min(
            y_start - max(0, y_start - margin),
            x_start - max(0, x_start - margin),
            min(y_start + h_original + margin, h_tensor) - y_start - h_original,
            min(x_start + w_original + margin, w_tensor) - x_start - w_original,
        ),
        0
    )
    self.fake_B_unaligned = self.fake_B_unaligned[:, :, y_start - margin:y_start + h_original + margin,
                                                  x_start - margin:x_start + w_original + margin]
    self.real_B_unaligned = self.real_B_unaligned_full[:, :, y_start - margin:y_start + h_original + margin,
                                                       x_start - margin:x_start + w_original + margin]
    self.fake_B_unaligned = F.interpolate(self.fake_B_unaligned, size=256)
    self.real_B_unaligned = F.interpolate(self.real_B_unaligned, size=256)

    # NOTE(review): the two assignments below replace the margin-cropped
    # results computed just above, so the stored tensors are the plain warped
    # patch and the matching real crop -- confirm which variant is wanted.
    self.fake_B_unaligned = F.interpolate(fake_B_warped, size=256)
    self.real_B_unaligned = F.interpolate(self.real_B_unaligned_full[:, :, y_start:y_start + h_original,
                                                                     x_start:x_start + w_original], size=256)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment