Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Created August 24, 2022 18:58
Show Gist options
  • Save torridgristle/24c8c672e285668d53b3b0efec7cf4db to your computer and use it in GitHub Desktop.
Save torridgristle/24c8c672e285668d53b3b0efec7cf4db to your computer and use it in GitHub Desktop.
Blur an image with a depth map in PyTorch. Splits the map into ranges of values, multiplies the image by the mask for each range, blurs both the masked image and the mask, sums the blurred images and the blurred masks separately, and divides the blurred-image sum by the blurred-mask sum.
# Map convention: 0 selects the start sigma, 1 selects the end sigma.
def map_blur(img, map, s_start=0.375, s_end=8, steps=8):
    """Blur ``img`` with a per-pixel blur strength taken from ``map``.

    Map value 0 selects the start sigma and map value 1 the end sigma. The
    [0, 1] range is split into ``steps`` equal bins. Bin ``s`` uses a sigma
    linearly interpolated from ``s_start`` (first bin) to ``s_end`` (last bin).
    For every bin, the image is multiplied by that bin's 0/1 mask. Both the
    masked image and the mask are then blurred with the bin's sigma, using
    zero padding via ``GaussianBlur_Sigma``. The blurred images and blurred
    masks are summed separately, and the image sum is divided by the mask
    sum. This renormalises pixels near bin boundaries.

    Args:
        img: image tensor, passed through ``GaussianBlur_Sigma``. That
            function uses ``F.conv2d``, so presumably (N, C, H, W); confirm
            with callers.
        map: blur map, broadcastable against ``img``. Values are clamped to
            [0, 1] so that every pixel falls into exactly one bin.
        s_start: sigma for the first bin (map values near 0).
        s_end: sigma for the last bin (map values near 1).
        steps: number of bins; must be at least 1. With a single bin the
            whole image is blurred with ``s_start``.

    Returns:
        The blurred image, with the broadcast shape of ``img`` and ``map``.

    Raises:
        ValueError: if ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    # Out-of-range values would match no bin. Their image content would be
    # dropped, and with no covering bin nearby the final division would be
    # 0/0 (NaN). Clamp so the outermost bins absorb them.
    map = map.clamp(0, 1)
    img_sum = torch.zeros_like(img)
    map_sum = torch.zeros_like(map)
    for s in range(steps):
        # Fraction of the way from s_start to s_end. A single bin uses
        # s_start; s / (steps - 1) would divide by zero there.
        t = s / (steps - 1) if steps > 1 else 0.0
        sigma = t * (s_end - s_start) + s_start
        lo = s / steps
        hi = (s + 1) / steps
        # Bins are half-open [lo, hi); the last bin is closed so that
        # map == 1 is kept.
        if s == steps - 1:
            upper = torch.less_equal(map, hi)
        else:
            upper = torch.less(map, hi)
        mask = torch.logical_and(torch.greater_equal(map, lo), upper).float()
        # Out-of-place accumulation keeps this working when img * mask
        # broadcasts to a larger shape than img or map.
        map_sum = map_sum + GaussianBlur_Sigma(mask, sigma, False, 'zeros')
        img_sum = img_sum + GaussianBlur_Sigma(img * mask, sigma, False, 'zeros')
    return img_sum / map_sum
### Simple gradients to use as maps for testing.
# Linear gradient (alternative test map):
# test_map = torch.linspace(0,1,256).reshape(1,1,1,256) * torch.ones([1,1,256,256])
# Radial gradient: stack the two coordinate grids over [-1, 1] into a
# (1, 2, 256, 256) tensor. indexing='ij' is meshgrid's historical default;
# passing it explicitly avoids the warning newer torch emits when it is omitted.
test_map = torch.stack(
    torch.meshgrid([torch.linspace(-1, 1, 256)] * 2, indexing='ij')
).unsqueeze(0)
# Shift the centre by a random offset in [-1, 1) along each axis.
test_map += torch.rand([1, 2, 1, 1]).mul(2).sub(1)
# Distance from the centre, divided by 2**0.5 because that is the distance at
# the corners when the centre is still at (0, 0).
test_map = test_map.norm(2, 1, True) / (2 ** 0.5)
# Invert so the centre is 1 and the edges are 0 or lower.
test_map = 1 - test_map
# Softplus (beta=8) fades smoothly towards zero instead of going negative,
# since the map is only meant to be used in 0-1. Because softplus(x) > x,
# values in a tiny neighbourhood of the peak end up marginally above 1.
test_map = F.softplus(test_map, 8)
### Simple checker pattern to use as image for testing.
# 256x256 single-channel board of 16x16 squares. A pixel is 0 where
# (row // 16 + col // 16) is even (this includes the top-left square) and 1
# where it is odd. repeat_interleave turns each pixel index into its square
# index; parity of the summed square indices gives the colour.
test_img = (
    torch.arange(16).repeat_interleave(16).reshape(-1, 1)
    + torch.arange(16).repeat_interleave(16).reshape(1, -1)
).remainder(2).to(torch.get_default_dtype()).reshape(1, 1, 256, 256)
### Gaussian blur that takes only sigma and derives the kernel size itself.
def GaussianBlur_Sigma(x, sigma=0.375, allow_even=False, pad_mode='reflect', norm_edges=True):
    """Separable Gaussian blur parameterised only by ``sigma``.

    The kernel width comes from an empirical fit chosen so that the smallest
    kernel value is about 1/255 or lower. The fit holds for sigma up to 20.
    If the predicted half-width is <= 1, the input is returned unchanged.

    Args:
        x: 4-D tensor fed to ``F.conv2d``, presumably (N, C, H, W). Each
            channel is blurred independently (grouped convolution).
        sigma: Gaussian standard deviation in pixels; must be non-negative.
        allow_even: if true, the kernel width may be even, which gives
            asymmetric padding. Otherwise the width is forced odd.
        pad_mode: 'zeros' for zero padding, otherwise any ``F.pad`` mode
            (e.g. 'reflect').
        norm_edges: only used with ``pad_mode='zeros'``. Divides the result
            by a blurred all-ones mask that is zero-padded the same way, so
            edges do not fade towards zero.

    Returns:
        Blurred tensor with the same shape as ``x``, or ``x`` itself when the
        kernel would be trivial.

    Raises:
        ValueError: if ``sigma`` is negative.
    """
    if sigma < 0:
        # A negative float ** 1.2 is complex in Python; fail clearly instead.
        raise ValueError("sigma must be non-negative")
    # Prediction of the input at which exp(-0.5*(t/sigma)**2), divided by its
    # sum over t in [-1024, 1024], drops to 1/255 (fitted up to sigma 20).
    width = ((sigma ** 1.2) * -1.5) + sigma * 4.54
    if width <= 1:
        return x
    if allow_even:
        width = math.ceil(width * 2 + 1)
    else:
        width = math.ceil(width) * 2 + 1

    # Normalised 1-D Gaussian kernel, built the same way as torchvision's
    # private `_get_gaussian_kernel1d`. It is built inline because that
    # helper's module was renamed in newer torchvision. It is created on x's
    # device/dtype rather than a module-level `device` global.
    half = (width - 1) * 0.5
    taps = torch.linspace(-half, half, steps=width)
    pdf = torch.exp(-0.5 * (taps / sigma).pow(2))
    kernel = (pdf / pdf.sum()).reshape(1, 1, 1, width).to(device=x.device, dtype=x.dtype)

    pad = (width - 1) * 0.5
    pad = [math.floor(pad), math.ceil(pad), math.floor(pad), math.ceil(pad)]
    mask = None
    if pad_mode == 'zeros':
        if norm_edges:
            # All-ones image padded with zeros. Blurring it measures how much
            # of each output's kernel lands inside the image.
            ones = torch.ones([1, 1, x.shape[-2], x.shape[-1]], device=x.device, dtype=x.dtype)
            mask = F.pad(ones, pad, 'constant', value=0.0)
        x = F.pad(x, pad, 'constant', value=0.0)
    else:
        # The padded tensor comes from a zero pad, but its values are replaced
        # (under no_grad) with `pad_mode` padding. Kept from the original.
        # NOTE(review): presumably meant to make autograd treat the border as
        # a constant pad while the forward pass sees pad_mode values; confirm.
        x_new = F.pad(x, pad, 'constant', value=0.0)
        with torch.no_grad():
            x_new.data = F.pad(x, pad, pad_mode)
        x = x_new

    channels = x.shape[1]
    # Separable blur: horizontal pass, then vertical pass (transposed kernel).
    x = F.conv2d(x, kernel.expand(channels, 1, -1, -1), stride=1, groups=channels)
    x = F.conv2d(x, kernel.permute(0, 1, 3, 2).expand(channels, 1, -1, -1), stride=1, groups=channels)
    if mask is not None:
        mask = F.conv2d(mask, kernel, stride=1, groups=1)
        mask = F.conv2d(mask, kernel.permute(0, 1, 3, 2), stride=1, groups=1)
        # Epsilon guards against division by an exactly-zero mask value.
        x = x / mask.add(1e-8)
    return x
### Example usage: blur the checker test image using the radial test map.
# Sigma goes from 0.375 (map values near 0) to 8 (map values near 1) over 8 bins.
blur_out = map_blur(test_img, test_map, s_start=0.375, s_end=8, steps=8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment