Skip to content

Instantly share code, notes, and snippets.

@FrancescoSaverioZuppichini
Last active January 3, 2021 08:59
Show Gist options
Save FrancescoSaverioZuppichini/30eabc661c2b1d6445c9a7e0e923d324 to your computer and use it in GitHub Desktop.
class PatchEmbedding(nn.Module):
    """Split an image batch into non-overlapping square patches and embed each one.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` maps every ``patch_size x patch_size`` patch to an
    ``emb_size``-dimensional vector. The resulting spatial grid of patch
    embeddings is then flattened into a sequence (one token per patch).

    Args:
        in_channels: Number of channels in the input images.
        patch_size: Side length, in pixels, of each square patch.
        emb_size: Dimension of each patch embedding.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # subclass; PyTorch relies on it to set up module/parameter bookkeeping.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains:
            # kernel_size == stride == patch_size applies one shared linear map
            # to each non-overlapping patch without an explicit reshape.
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # (b, emb_size, h, w) -> (b, h * w, emb_size): flatten the patch grid
            # into a sequence of patch tokens.
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed the patches of an image batch.

        Args:
            x: Image batch of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape
            (batch, (H // patch_size) * (W // patch_size), emb_size).
            Pixels beyond the last full patch along either spatial axis are
            ignored by the strided convolution.
        """
        return self.projection(x)
# Sanity check: run a default PatchEmbedding (in_channels=3, patch_size=16,
# emb_size=768) on `x` and inspect the output shape. With these defaults the
# result is (batch, (H // 16) * (W // 16), 768).
# NOTE(review): `x` is defined outside this snippet — presumably an image batch
# of shape (batch, 3, H, W); confirm. As a bare expression, the shape is only
# displayed in a notebook/REPL; in a script the value is discarded.
PatchEmbedding()(x).shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment