# Max Pool 2d Unpooling
#
# GitHub gist by @torridgristle (created August 3, 2022):
# https://gist.github.com/torridgristle/ed572d416c9acc9d1495d8bb25fb715d
import math

import torch
import torch.nn.functional as F

# Perform 2-D max pooling and keep the flat indices of the selected maxima.
# `input_tensor` is supplied by the user (not defined in this file).
max_size = 8
max_output, max_indices = F.max_pool2d_with_indices(input_tensor, max_size)

# Pass output_size so unpooling restores the exact input resolution. Without it,
# max_unpool2d infers (H // max_size) * max_size, which is smaller than the input
# whenever H or W is not a multiple of max_size. The saved indices are flat
# positions in the full-size input, so on the smaller grid they would land on the
# wrong pixels or fall past the end (PyTorch then raises an invalid-index error).
original_size = input_tensor.shape[-2:]

# Unpool to a tensor of the original size: the max values sit at their original
# positions and every non-max position is zero.
max_unpool = F.max_unpool2d(
    max_output, max_indices, max_size, max_size, output_size=original_size
)
# Unpool a tensor of ones with the same indices: the result is 1 where the input
# was sampled and 0 elsewhere.
max_mask = F.max_unpool2d(
    torch.ones_like(max_output), max_indices, max_size, max_size, output_size=original_size
)
def DistanceKernel(size=9, *, device=None, dtype=None):
    """Build a sum-normalized kernel whose weights fall off linearly with distance.

    The samples lie on a ``size`` x ``size`` grid taken from
    ``linspace(-1, 1, size + 2)`` with both endpoints dropped, so every sample is
    strictly inside [-1, 1].

    Each weight is ``max(0, 1 - r)``, where ``r`` is the sample's Euclidean
    distance from the grid centre. Samples with ``r >= 1`` (for example, the
    corners of larger kernels) therefore get zero weight, which makes the support
    round. The weights are then divided by their sum, so they add up to 1.

    Args:
        size: Side length of the square kernel, in pixels.
        device: Optional device for the result (default: torch's default device).
        dtype: Optional floating dtype for the result (default: torch's default dtype).

    Returns:
        Tensor of shape ``(1, 1, size, size)``, usable as an ``F.conv2d`` weight.
    """
    coords = torch.linspace(-1, 1, size + 2, device=device, dtype=dtype)[1:-1]
    # indexing="ij" is the historical default and silences the torch>=1.10
    # warning. The kernel is radially symmetric, so either indexing gives the
    # same weights.
    grid = torch.stack(torch.meshgrid(coords, coords, indexing="ij"))  # (2, size, size)
    dist = grid.norm(p=2, dim=0)  # distance of each sample from the centre
    kernel = (1 - dist).relu().reshape(1, 1, size, size)
    return kernel / kernel.sum()
def _custom_unpool_conv(t, kernel, pad):
    """Zero-pad ``t`` and convolve each channel independently with ``kernel``."""
    # Match the kernel to the input so CUDA / float64 tensors work too.
    weight = kernel.to(device=t.device, dtype=t.dtype).expand(t.shape[1], -1, -1, -1)
    padded = F.pad(t, pad, "constant", 0.0)
    return F.conv2d(padded, weight, None, 1, groups=t.shape[1])


def CustomUnpooling(x, mask, width, pow=1, fill_value=0.0):
    """Spread sparse samples over their neighbourhood with a distance-weighted mean.

    Each output pixel is ``conv(x, k) / conv(mask, k)``, where ``k`` is
    ``DistanceKernel(width) ** pow``. It is therefore the weighted average of the
    sampled values of ``x`` inside the kernel window, with only positions marked
    in ``mask`` counted as samples.

    The intended inputs are the ``F.max_unpool2d`` outputs: ``x`` holds the
    pooled values at their original positions, and ``mask`` is 1 at those
    positions.

    Args:
        x: Tensor of shape ``(N, C, H, W)`` (dim 1 is used as the channel count).
            Presumably zero wherever ``mask`` is zero.
        mask: Tensor of shape ``(N, C, H, W)`` or ``(N, 1, H, W)``, 1 at sampled
            positions and 0 elsewhere. Non-floating masks (e.g. bool) are cast
            to ``x.dtype``.
        width: Side length of the square weighting kernel. Odd and even widths
            both keep the spatial size.
        pow: Exponent applied to the kernel. Larger values give sharper,
            more Voronoi-like boundaries between samples.
        fill_value: Value written where no sample lies under the kernel window
            (the weighted mask is 0), instead of the 0/0 NaN. Pass
            ``float("nan")`` to mark such pixels as NaN.

    Returns:
        Tensor of shape ``(N, C, H, W)`` (``x`` and ``mask`` channels broadcast).
    """
    # F.pad order is (left, right, top, bottom). Splitting width - 1 this way
    # keeps the output the same size as the input.
    half = (width - 1) / 2
    pad = [math.floor(half), math.ceil(half), math.floor(half), math.ceil(half)]

    # The output ratio is invariant to a global scale of the kernel. Rescaling
    # to a peak of 1 *before* the power avoids float32 underflow of the small
    # sum-normalized weights at large pow.
    base = DistanceKernel(width)
    kernel = (base / base.max()) ** pow

    if not mask.is_floating_point():
        mask = mask.to(x.dtype)

    x_weighted = _custom_unpool_conv(x, kernel, pad)
    mask_weighted = _custom_unpool_conv(mask, kernel, pad)

    covered = mask_weighted > 0
    # Divide by 1 where nothing is covered, so neither the forward pass nor
    # autograd sees 0/0. Those pixels are then overwritten with fill_value.
    safe_denominator = mask_weighted.masked_fill(~covered, 1.0)
    return (x_weighted / safe_denominator).masked_fill(~covered, fill_value)
# Raising pow sharpens the boundaries between cells; around 8 seems to be a
# sensible upper limit. With a low pow the result gets blurrier as width grows
# and the kernels of neighbouring samples overlap.
# Sparse regions may need a width of almost twice the max-pool kernel size.
smooth_unpool = CustomUnpooling(
    max_unpool,
    max_mask,
    width=16,
    pow=8,
)
# Example outputs: https://imgur.com/a/qGFpSUO. The method spreads each known value
# outward until it meets the region of another known value, so the result looks
# almost like Voronoi cells. In the last example the samples are equally spaced,
# showing that equally spaced inputs produce square cells: the cell shapes depend on
# the locations of the sampled values, not on the values themselves.
# Comments on this gist can be left on GitHub (sign-in required).