Created July 16, 2023 15:00
madebyollin/ea69bffa92a7092a720438386a11d098
Nearest-neighbor upsampling for bfloat16 tensors, for when F.interpolate / nn.Upsample don't work.
# PyTorch <= 2.0 doesn't support bfloat16 in F.interpolate natively,
# so we have to do things the old-fashioned way.
import torch
import torch.nn as nn


# functional implementation
def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """Upsample x (NCHW) by an integer scale_factor using nearest-neighbor interpolation."""
    s = scale_factor
    n, c, h, w = x.shape
    # add two trailing singleton dims and broadcast them to s x s
    # (expand returns a view, so nothing is copied yet): (N, C, H, W, s, s)
    x = x.reshape(n, c, h, w, 1, 1).expand(n, c, h, w, s, s)
    # move the row-repeat dim next to its row: (N, C, H, s, W, s)
    x = x.transpose(-2, -3)
    # merge (H, s) -> H*s and (W, s) -> W*s; this reshape copies the data
    return x.reshape(n, c, h * s, w * s)


# nn.Module version, usable in place of nn.Upsample(scale_factor=s, mode="nearest")
# for NCHW input and an integer s
class NearestNeighborUpsample(nn.Module):
    def __init__(self, scale_factor: int):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nearest_neighbor_upsample(x, self.scale_factor)
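The reshape/expand/transpose trick above implements the mapping out[n, c, i, j] = x[n, c, i // s, j // s]. As a hedged cross-check, the sketch below writes the same mapping with torch.repeat_interleave (a swapped-in formulation, not the gist's code; the name repeat_upsample is hypothetical) and compares it with F.interpolate(mode="nearest") in float32, where interpolate is supported. I have not verified repeat_interleave's bfloat16 support on every older PyTorch release, so treat the bfloat16 part as an assumption.

```python
import torch
import torch.nn.functional as F


def repeat_upsample(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """Hypothetical alternative: nearest-neighbor upsample of an NCHW tensor.

    Each output pixel (i, j) copies input pixel (i // s, j // s).
    """
    s = scale_factor
    # repeat every row s times along H, then every column s times along W
    return x.repeat_interleave(s, dim=-2).repeat_interleave(s, dim=-1)


if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 5)
    # float32: compare against the built-in nearest-neighbor interpolation
    ref = F.interpolate(x, scale_factor=2, mode="nearest")
    out = repeat_upsample(x, 2)
    print(out.shape, torch.equal(out, ref))
    # bfloat16: the output keeps the input dtype (assumes repeat_interleave supports it)
    print(repeat_upsample(x.bfloat16(), 2).dtype)
```

The same float32 comparison against F.interpolate is a cheap way to sanity-check nearest_neighbor_upsample itself.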
A typical error message when trying to use F.interpolate for bfloat16 in older PyTorch versions is:

The version in this gist avoids that error.
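A different workaround for the same error, distinct from the gist's approach, is to upcast to float32, call F.interpolate, and cast back. For mode="nearest" the round trip is exact, since every bfloat16 value is representable in float32 and nearest-neighbor interpolation only copies values, but the float32 intermediate costs extra memory. A minimal sketch, assuming float32 interpolate is supported on the tensor's device; the helper name upsample_nearest_via_float32 is hypothetical:

```python
import torch
import torch.nn.functional as F


def upsample_nearest_via_float32(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """Hypothetical helper: nearest-neighbor upsample via a float32 round trip.

    Exact for nearest mode: bfloat16 -> float32 is lossless, interpolation only
    copies values, and casting those copies back to bfloat16 is lossless too.
    """
    # assumption: float32 F.interpolate is supported on x's device
    y = F.interpolate(x.float(), scale_factor=scale_factor, mode="nearest")
    return y.to(x.dtype)


if __name__ == "__main__":
    x = torch.randn(1, 2, 3, 3).bfloat16()
    y = upsample_nearest_via_float32(x, 2)
    print(y.dtype, y.shape)
```

The trade-off versus the gist's version is a temporary float32 copy of both the input and the upsampled output, which matters most for large activations.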