Skip to content

Instantly share code, notes, and snippets.

@petered
Last active July 3, 2022 01:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save petered/8c59ebe02208c9bf470a68cb610f7264 to your computer and use it in GitHub Desktop.
Save petered/8c59ebe02208c9bf470a68cb610f7264 to your computer and use it in GitHub Desktop.
A rough draft of "image warping from a heat map"
import cv2
import numpy as np
def _compute_force_field(heat: np.ndarray, force_constant: float) -> np.ndarray:
    """Return the (H, W, 2) field of (x, y) pull vectors produced by an (H, W) heatmap.

    For every pixel q this evaluates

        force(q) = force_constant * sum_p heat(p) * (p - q) / (|p - q|^2 + 1e-9)

    which is a linear convolution of the heatmap with a fixed kernel, computed
    via zero-padded FFTs in O(HW log HW) instead of an O(HWHW) double loop.
    """
    height, width = heat.shape
    if height == 0 or width == 0:
        return np.zeros((height, width, 2))
    # Offsets u = q - p lie in [-(H-1), H-1] x [-(W-1), W-1]. Padding both axes to
    # (2H-1, 2W-1) makes the FFT's circular convolution equal the linear one.
    pad_h, pad_w = 2 * height - 1, 2 * width - 1
    # Kernel index i stores offset i for i < size, and the wrapped negative
    # offset i - pad for the remaining indices (standard FFT layout).
    off_y = np.arange(pad_h)
    off_y = np.where(off_y < height, off_y, off_y - pad_h)
    off_x = np.arange(pad_w)
    off_x = np.where(off_x < width, off_x, off_x - pad_w)
    ux, uy = np.meshgrid(off_x, off_y)
    # Same epsilon as the direct formula: the u = 0 term becomes 0 / eps = 0.
    dist_sq = ux ** 2 + uy ** 2 + 1e-9
    # Heat at p = q - u pulls q along p - q = -u with strength 1 / |u|.
    kernel_x = -ux / dist_sq
    kernel_y = -uy / dist_sq
    shape = (pad_h, pad_w)
    heat_spectrum = np.fft.rfft2(heat, s=shape)
    force_x = np.fft.irfft2(heat_spectrum * np.fft.rfft2(kernel_x), s=shape)[:height, :width]
    force_y = np.fft.irfft2(heat_spectrum * np.fft.rfft2(kernel_y), s=shape)[:height, :width]
    return force_constant * np.stack([force_x, force_y], axis=2)


def warp_image_with_heatmap(
    src_image: 'BGRImageArray',
    heatmap: 'HeatMapArray',
    *,
    force_constant: float = 100.0,
    verbose: bool = True,
) -> 'BGRImageArray':
    """Rough draft of warping an image with a heatmap.

    We compute a set of "resampled pixel locations" and then use cv2.remap to
    resample the image from these points.

    Think of the heatmap as a grid of point masses. Each heatmap pixel p "pulls"
    each resampled-pixel-point q away from its original location, toward p. The
    pull is proportional to the "heat" of p and to the inverse distance between
    p and q:

        force(q) = force_constant * sum_p heat(p) * (p - q) / |p - q|^2

    The sum is a convolution, so it is evaluated with FFTs in O(HW log HW).
    The naive evaluation is O(HWHW).

    Known limitation: if force_constant is not set appropriately for the image
    size and heatmap magnitude, pixels can overshoot the sink.

    :param src_image: Image whose first two axes are (height, width).
    :param heatmap: Array with exactly height * width elements. It is read in
        row-major (ravel) order as a (height, width) grid of heat values.
    :param force_constant: The "gravitational constant". Higher means more warping.
    :param verbose: If True, print start/finish messages.
    :return: The image resampled by cv2.remap at the displaced pixel locations.
    :raises ValueError: If heatmap does not have exactly height * width elements.
    """
    height, width = src_image.shape[:2]
    heat = np.asarray(heatmap, dtype=np.float64)
    # The previous zip-based loop silently truncated mismatched heatmaps; fail loudly instead.
    if heat.size != height * width:
        raise ValueError(
            f'heatmap has {heat.size} elements, but src_image is {height}x{width} '
            f'(expected exactly {height * width}).')
    heat = heat.reshape(height, width)

    if verbose:
        print(f'Computing force field for image of shape {src_image.shape}...')
    force_field = _compute_force_field(heat, force_constant)
    if verbose:
        print('Done')

    xs, ys = np.meshgrid(np.arange(width), np.arange(height))
    # Separate contiguous float32 maps: output pixel (x, y) samples the source at
    # (x + fx, y + fy).
    map_x = (xs + force_field[:, :, 0]).astype(np.float32)
    map_y = (ys + force_field[:, :, 1]).astype(np.float32)
    distorted = cv2.remap(src_image, map1=map_x, map2=map_y, interpolation=cv2.INTER_LINEAR)
    return distorted
def demo_standalone_image_warp():
    """Warp OpenCV's sample 'lena.jpg' toward two Gaussian hot spots and display it.

    Shows three panels side by side in an OpenCV window for up to 10 seconds:
    the half-size input image, the heatmap (scaled to 0-255 for display), and
    the warped result.

    :raises OSError: If cv2.imread cannot decode the sample image.
    """
    path = cv2.samples.findFile('lena.jpg')
    image = cv2.imread(path)
    if image is None:  # cv2.imread signals failure by returning None, not by raising
        raise OSError(f'cv2.imread could not read {path!r}')
    image = cv2.resize(image, dsize=None, fx=0.5, fy=0.5)

    # Create heatmap from two superimposed gaussians; centres are given as (x, y).
    h, w = image.shape[:2]
    xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    xy_grid = np.concatenate([xs[:, :, None], ys[:, :, None]], axis=2)
    mu1 = 0.55 * w, 0.53 * h
    sig1 = 0.07 * w
    mu2 = 0.2 * w, 0.3 * h
    sig2 = 0.1 * w
    heatmap = (
        np.exp(-((xy_grid - mu1) ** 2).sum(axis=2) / (2 * sig1 ** 2)) / sig1 ** 2
        + np.exp(-((xy_grid - mu2) ** 2).sum(axis=2) / (2 * sig2 ** 2)) / sig2 ** 2
    )

    # Compute the warped image.
    distorted = warp_image_with_heatmap(image, heatmap)

    # Display. Scale to 0-255 before the uint8 cast: casting values in [0, 1]
    # directly truncates almost every pixel to 0, giving a black panel.
    heatmap_u8 = (255 * heatmap / heatmap.max()).astype(np.uint8)
    heatmap_image = np.repeat(heatmap_u8[:, :, None], repeats=3, axis=2)
    display_image = np.hstack((image, heatmap_image, distorted))
    cv2.imshow('Warping', display_image)
    cv2.waitKey(10000)
    cv2.destroyAllWindows()
if __name__ == "__main__":
    # Run the interactive demo only when this file is executed as a script, not on import.
    demo_standalone_image_warp()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment