A segmentation model using an MLP mixer. Code from @lucidrains
from functools import partial

from torch import nn
from einops.layers.torch import Rearrange


class PreNormResidual(nn.Module):
    """Applies LayerNorm, then fn, with a residual connection."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Two-layer MLP with GELU; `dense` selects token- vs channel-mixing."""
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout),
    )


def MLPSegmentationMixer(image_size, channels, patch_size, dim, depth,
                         out_channels=3, expansion_factor=4, dropout=0.):
    assert (image_size[0] % patch_size) == 0 and (image_size[1] % patch_size) == 0, \
        'image dimensions must be divisible by the patch size'
    h, w = image_size[0] // patch_size, image_size[1] // patch_size
    num_patches = h * w
    # Token mixing uses a 1x1 Conv1d (mixes across patches); channel mixing uses Linear.
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

    return nn.Sequential(
        # Split the image into non-overlapping patches and flatten each patch.
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),  # token mixing
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last)),           # channel mixing
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        # Project each patch embedding back to per-pixel logits and reassemble the image.
        nn.Linear(dim, (patch_size ** 2) * out_channels),
        Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', c=out_channels, h=h, p1=patch_size, p2=patch_size),
    )
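A minimal usage sketch, assuming a 256x256 RGB input and hypothetical hyperparameters (patch_size=16, dim=512, depth=8) chosen only for illustration; the model returns a map with the same spatial size as the input and out_channels logits per pixel.

import torch

# Hypothetical hyperparameters, for illustration only.
model = MLPSegmentationMixer(
    image_size=(256, 256),  # (height, width)
    channels=3,             # RGB input
    patch_size=16,
    dim=512,
    depth=8,
    out_channels=3,         # per-pixel output channels (e.g. class logits)
)

x = torch.randn(2, 3, 256, 256)  # batch of 2 images
logits = model(x)
print(logits.shape)              # torch.Size([2, 3, 256, 256])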