@madebyollin
Created July 16, 2023 15:00
bfloat16 nearest neighbor upsample code, for when F.interpolate / nn.Upsample don't work
# PyTorch <= 2.0 doesn't support bfloat16 F.interpolate natively,
# so we have to do things the old-fashioned way.
import torch
import torch.nn as nn


# functional implementation
def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int):
    """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation."""
    s = scale_factor
    # (N, C, H, W) -> (N, C, H, W, s, s) -> (N, C, H, s, W, s) -> (N, C, H*s, W*s)
    return x.reshape(*x.shape, 1, 1).expand(*x.shape, s, s).transpose(-2, -3).reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))


# module (nn.Module) implementation
class NearestNeighborUpsample(nn.Module):
    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return nearest_neighbor_upsample(x, self.scale_factor)
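
The one-line reshape/expand/transpose chain is easier to follow when split into steps. The sketch below is an illustration rather than part of the gist: it walks a tiny bfloat16 tensor through the same operations and checks the result against F.interpolate run in float32. The input values, the scale factor of 2, and the names step1 to step3 are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F

# A tiny NCHW input: 1 image, 1 channel, 2 rows, 3 columns.
# The values 0..5 are exactly representable in bfloat16.
x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 2, 3).to(torch.bfloat16)
s = 2  # illustrative scale factor

step1 = x.reshape(*x.shape, 1, 1)        # (1, 1, 2, 3, 1, 1): two trailing singleton dims
step2 = step1.expand(*x.shape, s, s)     # (1, 1, 2, 3, 2, 2): each pixel viewed as an s x s tile
step3 = step2.transpose(-2, -3)          # (1, 1, 2, 2, 3, 2): layout is now (N, C, H, s, W, s)
out = step3.reshape(1, 1, 2 * s, 3 * s)  # (1, 1, 4, 6): merge (H, s) and (W, s)

# Reference computed in float32, then compared after casting back up.
ref = F.interpolate(x.float(), scale_factor=float(s), mode="nearest")
assert torch.equal(out.float(), ref)

print(out[0, 0].float())
# rows: [0, 0, 1, 1, 2, 2] twice, then [3, 3, 4, 4, 5, 5] twice
```

Only the final reshape copies data; reshape-to-singleton, expand, and transpose are all views, so no dtype-specific interpolation kernel is involved.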
@madebyollin
Author

A typical error message when trying to use F.interpolate for bfloat16 in older PyTorch versions is:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3931, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'BFloat16'

The version in this gist avoids that error.
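
If the same code has to run on both older and newer PyTorch versions, one option is to try F.interpolate first and fall back to the manual version only when the kernel is missing. This is a minimal sketch, not something the gist provides: the helper name upsample_nearest_with_fallback is hypothetical, and matching on the "not implemented for" substring assumes the error text looks like the traceback above.

```python
import torch
import torch.nn.functional as F


def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    # Same trick as the gist, repeated here so this block runs on its own.
    s = scale_factor
    return (
        x.reshape(*x.shape, 1, 1)
        .expand(*x.shape, s, s)
        .transpose(-2, -3)
        .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
    )


def upsample_nearest_with_fallback(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """Hypothetical helper: prefer F.interpolate, fall back to the manual version."""
    try:
        return F.interpolate(x, scale_factor=float(scale_factor), mode="nearest")
    except RuntimeError as err:
        # Assumption: the missing-kernel error contains this phrase, as in the
        # traceback above. Any other RuntimeError is re-raised unchanged.
        if "not implemented for" not in str(err):
            raise
        return nearest_neighbor_upsample(x, scale_factor)


x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 2, 3).to(torch.bfloat16)
y = upsample_nearest_with_fallback(x, 2)
print(y.shape, y.dtype)
```

Both paths return the same values for integer scale factors, so callers do not need to know which one ran.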
