Skip to content

Instantly share code, notes, and snippets.

@mowolf
Created April 13, 2020 16:13
Show Gist options
  • Save mowolf/858e01906402c1c8a64b33d488ae975f to your computer and use it in GitHub Desktop.
Save mowolf/858e01906402c1c8a64b33d488ae975f to your computer and use it in GitHub Desktop.
def align_fake(self, margin=70):
    """Warp the aligned ``self.fake_B`` back into the frame of ``self.real_B_unaligned_full``.

    Inverts the face alignment recorded in ``self.alignment_params``:

    1. ``self.fake_B`` is rotated by ``-angle`` and scaled by ``1 / scale`` about the
       eyes-centre point it was aligned to (``tgm.warp_affine``).
    2. The warped patch is clipped to the bounds of ``self.real_B_unaligned_full``.
    3. Pixels the warp left at exactly 0 in all three channels (uncovered background)
       are filled from the real image.
    4. The patch is pasted into a copy of the full real image.
    5. Matching crops, padded by up to ``margin`` pixels, are cut from the pasted fake
       and from the real image and resized to 256 x 256.

    No ``detach``/``no_grad`` is used. ``self.fake_B_unaligned`` is therefore built
    from ``self.fake_B`` through differentiable ops, so it should carry gradient back
    to it. ``self.real_B_unaligned`` is cut from ``self.real_B_unaligned_full`` only.

    NOTE(review): the original header comment claimed "this kills the gradient of
    fake_B_unaligned" -- nothing in this method detaches; verify with a backward pass.
    NOTE(review): the background mask uses batch element 0 and ``M`` is indexed as
    ``M[0]``, so the method assumes batch size 1.
    NOTE(review): ``alignment_params["angle"]`` / ``["scale"]`` are passed to
    ``tgm.get_rotation_matrix2d``, which presumably requires tensors -- confirm.

    Args:
        margin: Requested context, in pixels, kept around the pasted patch when
            cropping. It is shrunk so the crop stays inside the image, and the same
            value is used on all four sides. A negative value is treated as 0.

    Side effects:
        Sets ``self.fake_B_unaligned`` and ``self.real_B_unaligned`` (spatial size
        256 x 256).
    """
    params = self.alignment_params
    # Normalised vertical position of the eyes in the aligned image.
    eye_y = float(params["desiredLeftEye"][1])
    eyes_center = params["eyesCenter"]  # eyes centre in full-image coordinates
    angle = -params["angle"]  # undo the alignment rotation
    scale = 1 / params["scale"]  # undo the alignment scaling

    h, w = self.fake_B.shape[2:]
    # Eyes centre inside the aligned image: horizontally centred, vertically at
    # eye_y. The y coordinate must use desiredLeftEye[1] to agree with t_y below.
    m1 = round(w * 0.5)
    m2 = round(eye_y * h)

    # Side length of the region the aligned crop covered in the original image.
    width = int(params["shape"][0])
    long_edge_size = width / abs(np.cos(np.deg2rad(params["angle"])))
    w_original = int(scale * long_edge_size)
    h_original = int(scale * long_edge_size)

    # Translation that moves the rotation centre (m1, m2) to
    # (0.5 * w_original, eye_y * h_original) in the warped output.
    t_x = w_original * 0.5
    t_y = h_original * eye_y

    center = torch.tensor([[float(m1), float(m2)]])
    M = tgm.get_rotation_matrix2d(center, angle, scale).to(self.device)
    M[0, 0, 2] += t_x - m1
    M[0, 1, 2] += t_y - m2

    # Top-left corner of the warped patch in the full image: the eyes centre sits
    # at (0.5 * w_original, eye_y * h_original) inside the patch.
    x_start = int(eyes_center[0] - 0.5 * w_original)
    y_start = int(eyes_center[1] - eye_y * h_original)

    fake_B_clone = self.fake_B.clone().requires_grad_()
    fake_B_warped = tgm.warp_affine(fake_B_clone, M, dsize=(h_original, w_original))

    # Clip the patch so it never extends past the full image on any side.
    _, _, h_tensor, w_tensor = self.real_B_unaligned_full.shape
    if y_start < 0:
        fake_B_warped = fake_B_warped[:, :, -y_start:h_original, :]
        h_original += y_start
        y_start = 0
    if x_start < 0:
        fake_B_warped = fake_B_warped[:, :, :, -x_start:w_original]
        w_original += x_start
        x_start = 0
    if y_start + h_original > h_tensor:
        h_original = h_tensor - y_start
        fake_B_warped = fake_B_warped[:, :, :h_original, :]
    if x_start + w_original > w_tensor:
        w_original = w_tensor - x_start
        fake_B_warped = fake_B_warped[:, :, :, :w_original]
    y_end = y_start + h_original
    x_end = x_start + w_original

    # Pixels where all three channels are exactly 0 were not covered by the
    # warped image (background); take them from the real image instead.
    real_patch = self.real_B_unaligned_full[:, :, y_start:y_end, x_start:x_end]
    background = (fake_B_warped[0, :3] == 0).all(dim=0)
    fake_B_filled = torch.where(background, real_patch, fake_B_warped)

    # Paste the patch into a copy of the full real image. Slice assignment into a
    # fresh (non-leaf-requiring-grad) clone is tracked by autograd, so gradient
    # reaches fake_B_filled.
    canvas = self.real_B_unaligned_full.clone()
    canvas[:, :, y_start:y_end, x_start:x_end] = fake_B_filled

    # Largest padding <= margin that keeps the crop inside the image on all sides.
    pad = max(0, min(margin, y_start, x_start, h_tensor - y_end, w_tensor - x_end))
    rows = slice(y_start - pad, y_end + pad)
    cols = slice(x_start - pad, x_end + pad)
    self.fake_B_unaligned = F.interpolate(canvas[:, :, rows, cols], size=256)
    self.real_B_unaligned = F.interpolate(
        self.real_B_unaligned_full[:, :, rows, cols], size=256
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment