@KeAWang
Created November 14, 2023 18:42
PyTorch NaN embedder
import torch


class NanWrapper(torch.nn.Module):
    """Wrapper module around a torch Module that handles incoming NaNs."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Masks the entire last dimension (usually the feature/channel dimension) if any element is NaN."""
        mask = ~torch.any(torch.isnan(x), dim=-1, keepdim=True)  # 0 if any NaN in the feature vector, 1 otherwise
        masked_x = torch.where(mask, x, torch.zeros_like(x))  # zero out NaN positions before the module sees them
        fx = self.module(masked_x)
        masked_fx = fx * mask  # zero out the outputs at positions that contained NaNs
        return masked_fx


class ConvLinear(torch.nn.Module):
    """Linear projection implemented as a 1x1 convolution; useful for sequence data."""

    def __init__(self, in_features, out_features, bias=True, channel_last=True):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_features, out_features, kernel_size=1, bias=bias)
        self.channel_last = channel_last

    def forward(self, x):
        assert x.ndim == 3, "Expected input to be (batch_size, seq_len, input_size) or (batch_size, input_size, seq_len)"
        if self.channel_last:
            x = x.transpose(1, 2)  # (batch, seq, features) -> (batch, features, seq) for Conv1d
            x = self.conv(x)
            x = x.transpose(1, 2)  # back to (batch, seq, features)
        else:
            x = self.conv(x)
        return x

# example usage
if __name__ == "__main__":
    nan_embedder1 = NanWrapper(ConvLinear(3, 4))
    nan_embedder2 = NanWrapper(torch.nn.Linear(3, 4))
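    # A minimal sketch of the intended behavior (my addition, not part of the original gist):
    # positions whose feature vector contains a NaN should come out as all zeros,
    # while NaN-free positions are embedded normally.
    x = torch.randn(2, 5, 3)       # (batch_size, seq_len, in_features)
    x[0, 1, 0] = float("nan")      # corrupt one feature at one position
    out1 = nan_embedder1(x)        # ConvLinear path (channel_last input)
    out2 = nan_embedder2(x)        # plain Linear path
    assert not torch.isnan(out1).any() and not torch.isnan(out2).any()
    assert torch.all(out1[0, 1] == 0) and torch.all(out2[0, 1] == 0)
    print(out1.shape, out2.shape)  # both (2, 5, 4)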