@peteflorence
Last active January 16, 2024 14:18
Bilinear interpolation in PyTorch, and benchmarking vs. numpy

Here's a simple implementation of bilinear interpolation on tensors using PyTorch.

I wrote this up since I ended up learning a lot about the options for interpolation in both the numpy and PyTorch ecosystems. Beyond interpolation, it's also a nice case study in how PyTorch can magically put very numpy-like code on the GPU (and, by the way, do autodiff for you too).

For interpolation in PyTorch, this open issue calls for more interpolation features. There is now a nn.functional.grid_sample() feature but at least at first this didn't look like what I needed (but we'll come back to this later).

In particular I wanted to take an image, W x H x C, and sample it many times at different random locations. Note also that this is different than upsampling which exhaustively samples and also doesn't give us flexibility with the precision of sampling.

The implementations: numpy and PyTorch

First let's look at a comparable implementation in numpy which is slightly modified from here.

import numpy as np

def bilinear_interpolate_numpy(im, x, y):
    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    x0 = np.clip(x0, 0, im.shape[1]-1)
    x1 = np.clip(x1, 0, im.shape[1]-1)
    y0 = np.clip(y0, 0, im.shape[0]-1)
    y1 = np.clip(y1, 0, im.shape[0]-1)

    Ia = im[ y0, x0 ]
    Ib = im[ y1, x0 ]
    Ic = im[ y0, x1 ]
    Id = im[ y1, x1 ]

    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    return (Ia.T*wa).T + (Ib.T*wb).T + (Ic.T*wc).T + (Id.T*wd).T
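
For reference, the weighted sum returned above can be written out as follows. This uses the same names as the code and assumes the unclipped case, where x1 = x0 + 1 and y1 = y0 + 1, so the four weights sum to 1:

```latex
\begin{aligned}
f(x, y) &= w_a I_a + w_b I_b + w_c I_c + w_d I_d, \\
w_a &= (x_1 - x)(y_1 - y), & I_a &= \mathrm{im}[y_0, x_0], \\
w_b &= (x_1 - x)(y - y_0), & I_b &= \mathrm{im}[y_1, x_0], \\
w_c &= (x - x_0)(y_1 - y), & I_c &= \mathrm{im}[y_0, x_1], \\
w_d &= (x - x_0)(y - y_0), & I_d &= \mathrm{im}[y_1, x_1].
\end{aligned}
```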

And now here I've converted this implementation to PyTorch:

import torch
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch(im, x, y):
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    
    Ia = im[ y0, x0 ][0]
    Ib = im[ y1, x0 ][0]
    Ic = im[ y0, x1 ][0]
    Id = im[ y1, x1 ][0]
    
    wa = (x1.type(dtype)-x) * (y1.type(dtype)-y)
    wb = (x1.type(dtype)-x) * (y-y0.type(dtype))
    wc = (x-x0.type(dtype)) * (y1.type(dtype)-y)
    wd = (x-x0.type(dtype)) * (y-y0.type(dtype))

    return torch.t((torch.t(Ia)*wa)) + torch.t(torch.t(Ib)*wb) + torch.t(torch.t(Ic)*wc) + torch.t(torch.t(Id)*wd)

Testing for correctness

Bilinear interpolation is very simple but there are a few things that can be easily messed up.

I did a quick comparison for correctness with SciPy's interp2d.

  • Side note: there are actually a ton of interpolation options in SciPy, but none I tested met my criteria of (a) doing bilinear interpolation for high-dimensional spaces and (b) efficiently using gridded data. The ones I tested that were built for many dimensions required me to specify sample points for all of those dimensions (and did trilinear, or other, interpolation). I could get LinearNDInterpolator to do bilinear interpolation for high-dimensional vectors, but this does not meet criterion (b). There's probably a better option but, at any rate, I gave up and went back to my numpy and PyTorch options.
# Also use scipy to check for correctness
import scipy.interpolate
def bilinear_interpolate_scipy(image, x, y):
    x_indices = np.arange(image.shape[1])  # x runs along columns (axis 1)
    y_indices = np.arange(image.shape[0])  # y runs along rows (axis 0)
    interp_func = scipy.interpolate.interp2d(x_indices, y_indices, image, kind='linear')
    return interp_func(x,y)

# Make small sample data that's easy to interpret
image = np.ones((5,5))
image[3,3] = 4
image[3,4] = 3

sample_x, sample_y = np.asarray([3.2]), np.asarray([3.4])

print "numpy result:", bilinear_interpolate_numpy(image, sample_x, sample_y)
print "scipy result:", bilinear_interpolate_scipy(image, sample_x, sample_y)

image = torch.unsqueeze(torch.FloatTensor(image).type(dtype),2)
sample_x = torch.FloatTensor([sample_x]).type(dtype)
sample_y = torch.FloatTensor([sample_y]).type(dtype)

print "torch result:", bilinear_interpolate_torch(image, sample_x, sample_y)

The above gives:

numpy result: [2.68]
scipy result: [2.68]
torch result: 
 2.6800
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]
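
The 2.68 value can also be checked by hand. The snippet below is a small sketch (written for Python 3) that plugs the four corner values of the 5x5 test image into the weight formulas from bilinear_interpolate_numpy. The corner values are read off the test image defined above; nothing else is assumed.

```python
# Hand check of the 2.68 result for the sample point x = 3.2, y = 3.4.
x, y = 3.2, 3.4
x0, x1, y0, y1 = 3, 4, 3, 4

# Corner values of the 5x5 test image, indexed as image[y, x]:
# image[3, 3] = 4, image[4, 3] = 1, image[3, 4] = 3, image[4, 4] = 1
Ia, Ib, Ic, Id = 4.0, 1.0, 3.0, 1.0

wa = (x1 - x) * (y1 - y)  # 0.8 * 0.6 = 0.48
wb = (x1 - x) * (y - y0)  # 0.8 * 0.4 = 0.32
wc = (x - x0) * (y1 - y)  # 0.2 * 0.6 = 0.12
wd = (x - x0) * (y - y0)  # 0.2 * 0.4 = 0.08

result = wa * Ia + wb * Ib + wc * Ic + wd * Id  # 1.92 + 0.32 + 0.36 + 0.08
print(round(result, 6))  # 2.68
```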

High dimensional bilinear interpolation

For the correctness test comparing with scipy, we couldn't do W x H x C interpolation for anything but C=1. Now though, we can do bilinear interpolation in either numpy or torch for arbitrary C:

# Do high dimensional bilinear interpolation in numpy and PyTorch
W, H, C = 25, 25, 7
image = np.random.randn(W, H, C)

num_samples = 4
samples_x, samples_y = np.random.rand(num_samples)*(W-1), np.random.rand(num_samples)*(H-1)

print bilinear_interpolate_numpy(image, samples_x, samples_y)

image = torch.from_numpy(image).type(dtype)
samples_x = torch.FloatTensor([samples_x]).type(dtype)
samples_y = torch.FloatTensor([samples_y]).type(dtype)

print bilinear_interpolate_torch(image, samples_x, samples_y)

You'll find that the above numpy and torch versions give the same result.
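
One way to check that claim programmatically is sketched below. This is not the original test: it assumes Python 3 and a recent PyTorch, and it uses two compact re-statements of the functions above (the names bilinear_numpy and bilinear_torch are hypothetical) that take 1-D coordinate arrays and run on whichever device is available.

```python
import numpy as np
import torch

def bilinear_numpy(im, x, y):
    # Same arithmetic as bilinear_interpolate_numpy; im is indexed as im[y, x].
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1, y1 = x0 + 1, y0 + 1
    x0, x1 = np.clip(x0, 0, im.shape[1] - 1), np.clip(x1, 0, im.shape[1] - 1)
    y0, y1 = np.clip(y0, 0, im.shape[0] - 1), np.clip(y1, 0, im.shape[0] - 1)
    wa, wb = (x1 - x) * (y1 - y), (x1 - x) * (y - y0)
    wc, wd = (x - x0) * (y1 - y), (x - x0) * (y - y0)
    # Broadcast the (N,) weights over the trailing channel axis.
    return (wa[:, None] * im[y0, x0] + wb[:, None] * im[y1, x0]
            + wc[:, None] * im[y0, x1] + wd[:, None] * im[y1, x1])

def bilinear_torch(im, x, y):
    # Device-agnostic variant: dtype and device follow the input tensors.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1, y1 = x0 + 1, y0 + 1
    x0, x1 = x0.clamp(0, im.shape[1] - 1), x1.clamp(0, im.shape[1] - 1)
    y0, y1 = y0.clamp(0, im.shape[0] - 1), y1.clamp(0, im.shape[0] - 1)
    fx0, fx1, fy0, fy1 = (t.to(x.dtype) for t in (x0, x1, y0, y1))
    wa, wb = (fx1 - x) * (fy1 - y), (fx1 - x) * (y - fy0)
    wc, wd = (x - fx0) * (fy1 - y), (x - fx0) * (y - fy0)
    return (wa.unsqueeze(1) * im[y0, x0] + wb.unsqueeze(1) * im[y1, x0]
            + wc.unsqueeze(1) * im[y0, x1] + wd.unsqueeze(1) * im[y1, x1])

rng = np.random.default_rng(0)
rows, cols, C = 25, 25, 7
image = rng.standard_normal((rows, cols, C))
xs = rng.random(4) * (cols - 1)
ys = rng.random(4) * (rows - 1)

device = "cuda" if torch.cuda.is_available() else "cpu"
out_np = bilinear_numpy(image, xs, ys)
out_t = bilinear_torch(torch.from_numpy(image).to(device),
                       torch.from_numpy(xs).to(device),
                       torch.from_numpy(ys).to(device))
same = np.allclose(out_np, out_t.cpu().numpy())
print(same)
```

Everything stays in float64 here so that the comparison is not blurred by float32 rounding.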

Benchmarking: numpy (CPU) vs. pytorch (CPU) vs. pytorch (GPU)

Now we do some simple benchmarking:

# Timing comparison for WxHxC (where C is large for a high dimensional descriptor)
W, H, C = 640, 480, 32
image = np.random.randn(W, H, C)

num_samples = 10000
samples_x, samples_y = np.random.rand(num_samples)*(W-1), np.random.rand(num_samples)*(H-1)

import time

start = time.time()
bilinear_interpolate_numpy(image, samples_x, samples_y)
print "numpy took       ", time.time() - start

dtype = torch.FloatTensor
dtype_long = torch.LongTensor
image = torch.FloatTensor(image).type(dtype)
samples_x = torch.FloatTensor([samples_x]).type(dtype)
samples_y = torch.FloatTensor([samples_y]).type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print "torch on CPU took", time.time() - start 

dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor
image = image.type(dtype)
samples_x = samples_x.type(dtype)
samples_y = samples_y.type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print "torch on GPU took", time.time() - start

On my machine (CPU: 10-core i7-6950X, GPU: GTX 1080) I get the following times (in seconds):

numpy took        0.00756597518921
torch on CPU took 0.12672996521
torch on GPU took 0.000642061233521

Interestingly we have torch on the GPU beating numpy (CPU-only) by about 10x. I'm not sure why torch on the CPU is that slow for this test case. Note that the ratios between these change quite drastically for different W, H, C, num_samples.

Using the available nn.functional.grid_sample()

I ended up figuring out how to use nn.functional.grid_sample(), although it was a slightly odd fit for my needs. (Data needs to be in an N x C x W x H input tensor, samples need to be normalized to [-1,1], and AFAIK the W x H ordering of the samples does not have any meaning; they are completely separate samples.)

It was good practice in using permute, multiple unsqueezes, cat.

import torch.nn.functional
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch_gridsample(image, samples_x, samples_y):
                                                # input image is: W x H x C
    image = image.permute(2,0,1)                # change to:      C x W x H
    image = image.unsqueeze(0)                  # change to:  1 x C x W x H
    samples_x = samples_x.unsqueeze(2)
    samples_x = samples_x.unsqueeze(3)
    samples_y = samples_y.unsqueeze(2)
    samples_y = samples_y.unsqueeze(3)
    samples = torch.cat([samples_x, samples_y],3)
    samples[:,:,:,0] = (samples[:,:,:,0]/(W-1)) # normalize to between  0 and 1
    samples[:,:,:,1] = (samples[:,:,:,1]/(H-1)) # normalize to between  0 and 1
    samples = samples*2-1                       # normalize to between -1 and 1
    return torch.nn.functional.grid_sample(image, samples)

# Correctness test
W, H, C = 5, 5, 1
test_image = torch.ones(W,H,C).type(dtype)
test_image[3,3,:] = 4
test_image[3,4,:] = 3

test_samples_x = torch.FloatTensor([[3.2]]).type(dtype)
test_samples_y = torch.FloatTensor([[3.4]]).type(dtype)

print bilinear_interpolate_torch_gridsample(test_image, test_samples_x, test_samples_y)

# Benchmark
start = time.time()
bilinear_interpolate_torch_gridsample(image, samples_x, samples_y)
print "torch gridsample took ", time.time() - start

My wrapping of grid_sample produces the same bilinear interpolation results and at speeds comparable to our bilinear_interpolate_torch() function:

Variable containing:
(0 ,0 ,.,.) = 
  2.6800
[torch.cuda.FloatTensor of size 1x1x1x1 (GPU 0)]

torch gridsample took  0.000624895095825

Another note about the nn.functional.grid_sample() interface is that it forces the sampled interpolations into a Variable() wrapper even though this seems best left to the programmer to decide.
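
Recent PyTorch releases expose an align_corners argument on grid_sample, and the normalization used above (dividing by size - 1) corresponds to align_corners=True. The sketch below shows how the same wrapper might look against that newer interface. It is an illustration under stated assumptions, not the code that produced the timings here: the function name is hypothetical, the image is laid out as rows x cols x C and indexed as im[y, x], and padding_mode="border" is my choice.

```python
import torch
import torch.nn.functional as F

def sample_with_grid_sample(im, x, y):
    """im: (rows, cols, C) float tensor indexed as im[y, x].
    x, y: (N,) pixel coordinates. Returns an (N, C) tensor."""
    rows, cols, _ = im.shape
    inp = im.permute(2, 0, 1).unsqueeze(0)      # 1 x C x rows x cols
    gx = 2.0 * x / (cols - 1) - 1.0             # pixel coords -> [-1, 1]
    gy = 2.0 * y / (rows - 1) - 1.0
    # grid[..., 0] indexes the last (column) axis, grid[..., 1] the row axis.
    grid = torch.stack([gx, gy], dim=-1).view(1, 1, -1, 2)   # 1 x 1 x N x 2
    out = F.grid_sample(inp, grid, mode="bilinear",
                        padding_mode="border", align_corners=True)
    return out[0, :, 0, :].t()                  # 1 x C x 1 x N -> N x C

im = torch.ones(5, 5, 1)
im[3, 3] = 4
im[3, 4] = 3
val = sample_with_grid_sample(im, torch.tensor([3.2]), torch.tensor([3.4]))
print(val)
```

On the small test image this reproduces the 2.68 value up to float32 rounding.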

Conclusions

  • It's surprisingly easy to convert powerful vectorized numpy code into more-powerful vectorized PyTorch code
  • PyTorch is very fast on the GPU
  • Some of the higher-level features (like nn.functional.grid_sample) are nice, but so too is writing your own tensor manipulations (which can be comparably fast)
@KruskalLin

Thanks for sharing!

@v-prgmr

v-prgmr commented Dec 9, 2019

@peteflorence,

Ia = im[ y0, x0 ][0]
Ib = im[ y1, x0 ][0]
Ic = im[ y0, x1 ][0]
Id = im[ y1, x1 ][0]

Aren't the above lines in the torch implementation, only taking the intensity values at the (0,0) , (0,1), (1,0) , (1,1) for the first coordinate of samples_x and samples_y ?

@sbarratt

Some in this thread might be interested in the torch_interpolations package: https://github.com/sbarratt/torch_interpolations

@chenerg

chenerg commented Nov 20, 2020

dude, this is amazing

@thanhmvu

thanhmvu commented Dec 3, 2020

Kudos to this! Very nice write-up indeed.

One minor detail worth pointing out is that in bilinear_interpolate_numpy, points at or beyond the boundaries will map to interp(im, w-1, h-1) = 0.0, even though I believe the desired result is interp(im, w-1, h-1) = im[w-1, h-1]. For example, I tried sample_x, sample_y = np.asarray([4]), np.asarray([3]) and got numpy result: [0.] scipy result: [3.]

@OrkhanHI

OrkhanHI commented Jan 4, 2021

Thanks for sharing. I have added Python code that supports a batch dimension with zero padding and accepts an align_corners argument.

https://github.com/OrkhanHI/pytorch_grid_sample_python/blob/main/pytorch_grid_sample_python.md

@thomasaarholt

Thanks for sharing!

@ppwwyyxx

ppwwyyxx commented Mar 17, 2022

Anyone looking at this now or in the future should note the flaws in this gist. Listing them here:

  1. The behavior of grid_sample changed in PyTorch a few years ago. The old behavior is imprecise: pytorch/pytorch#20785
  2. Why does this gist claim to get the same results from its bilinear sampling implementation and (old) grid_sample? Because the implementation in this gist is also imprecise in the same way.
  3. The claim in this gist that "speeds (of grid_sample) comparable to our bilinear_interpolate_torch() function:" is also unfounded: PyTorch's CUDA operations execute asynchronously, so they cannot be benchmarked with "time.time()". Not to mention that a benchmark cannot be done by running the operation only once.
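
A timing loop that addresses the third point might look like the sketch below. The helper name benchmark and its parameters are hypothetical. It does warm-up runs, repeats the call, and optionally calls a synchronization function before reading the clock. The example workload is a CPU-only stand-in so the snippet runs without a GPU.

```python
import time

def benchmark(fn, sync=None, warmup=3, repeats=20):
    """Return the mean wall-clock seconds per call of fn().

    sync, if given, is called before starting and before stopping the
    clock so that asynchronously queued work is included in the timing.
    """
    for _ in range(warmup):       # warm-up calls are not timed
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / repeats

# CPU-only stand-in workload, just to show the call pattern.
elapsed = benchmark(lambda: sum(range(10000)))
print(f"{elapsed:.6f} s per call")
```

For the GPU case, one would pass sync=torch.cuda.synchronize and wrap the call to bilinear_interpolate_torch (or the grid_sample wrapper) in the lambda.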

@cbender98

Regarding the boundary problem and this comment by @thanhmvu:

Kudos to this! Very nice write-up indeed.

One minor detail worth pointing out is that in bilinear_interpolate_numpy, points at or beyond the boundaries will map to interp(im, w-1, h-1) = 0.0, even though I believe the desired result is interp(im, w-1, h-1) = im[w-1, h-1]. For example, I tried sample_x, sample_y = np.asarray([4]), np.asarray([3]) and got numpy result: [0.] scipy result: [3.]

The problem with bilinear_interpolate_torch(im, x, y) and bilinear_interpolate_numpy(im, x, y) at the boundaries is that, after clamping, x0 == x1 or y0 == y1 there, so the weighting factors w* all become 0. This can easily be fixed with the following changes:

    # x indexes axis 1 and y indexes axis 0, matching im[y, x] in the gist
    x0 = torch.clamp(x0, 0, im.shape[1]-2)
    x1 = torch.clamp(x1, 1, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-2)
    y1 = torch.clamp(y1, 1, im.shape[0]-1)

This ensures that at least one weighting factor is non-zero for interpolation points on the boundary. Please correct me if my correction is wrong :)
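
A minimal numpy sketch of this boundary fix follows, for illustration only. The function name is hypothetical, and clipping the query coordinates themselves is an extra assumption on my part: without it, points beyond the edge would be extrapolated rather than mapped to the edge value. It needs an image with at least two rows and two columns.

```python
import numpy as np

def bilinear_interpolate_numpy_clamped(im, x, y):
    rows, cols = im.shape[0], im.shape[1]
    # Assumption: out-of-range queries are clipped to the image border.
    x = np.clip(np.asarray(x, dtype=float), 0, cols - 1)
    y = np.clip(np.asarray(y, dtype=float), 0, rows - 1)

    # Keep x0 < x1 and y0 < y1 even on the last column / row.
    x0 = np.clip(np.floor(x).astype(int), 0, cols - 2)
    y0 = np.clip(np.floor(y).astype(int), 0, rows - 2)
    x1, y1 = x0 + 1, y0 + 1

    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    return ((im[y0, x0].T * wa).T + (im[y1, x0].T * wb).T
            + (im[y0, x1].T * wc).T + (im[y1, x1].T * wd).T)

image = np.ones((5, 5))
image[3, 3] = 4
image[3, 4] = 3

on_edge = bilinear_interpolate_numpy_clamped(image, np.asarray([4]), np.asarray([3]))
interior = bilinear_interpolate_numpy_clamped(image, np.asarray([3.2]), np.asarray([3.4]))
print(on_edge, interior)
```

With the 5x5 test image, the point (x=4, y=3) now lands on im[3, 4] = 3 instead of 0, and interior points are unchanged.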
