@gngdb · Created May 11, 2021 01:48
Technically an implementation of MLP-Mixer without nn.Linear
(Colab notebook "MLP-Mixer", Python 3, GPU runtime.)

```python
!pip install einops
```

    Requirement already satisfied: einops in /usr/local/lib/python3.7/dist-packages (0.3.0)
This is [lucidrain's implementation](https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py):
```python
from functools import partial

from torch import nn
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    # token mixing runs a kernel-size-1 Conv1d over the patch axis;
    # channel mixing is an ordinary nn.Linear over the feature axis
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * 3, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
```
```python
import torch

model = MLPMixer(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img)  # (1, 1000)
params = model.state_dict()
```
```python
import time

batch = torch.randn(16, 3, 256, 256).cuda()
model = model.cuda()
before = time.time()
for _ in range(10):
    _ = model(batch)
print(f"Execution time {time.time() - before}")
```

    Execution time 0.5921013355255127
Here it is without any Linear modules:
```python
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = None):
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    # a drop-in replacement for nn.Linear: transpose so the feature axis sits in
    # Conv1d's channel slot, apply a kernel-size-1 convolution, transpose back
    def dense(dim_in, dim_out):
        return nn.Sequential(
            Rearrange('b c n -> b n c'),
            nn.Conv1d(dim_in, dim_out, 1),
            Rearrange('b n c -> b c n')
        )
    chan_first = partial(nn.Conv1d, kernel_size = 1)
    chan_last = dense

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        dense((patch_size ** 2) * 3, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        # keep a singleton token axis so the classifier head can also be a Conv1d
        Reduce('b n c -> b () c', 'mean'),
        dense(dim, num_classes),
        Rearrange('b () c -> b c')
    )
```
This cell verifies that it produces the same output:
```python
_model = MLPMixer(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

# copy the weights across in order: a Linear weight is (out, in) while the
# equivalent kernel-size-1 Conv1d weight is (out, in, 1), so 2D weights get a
# trailing singleton axis; biases and LayerNorm parameters copy over directly
_params = {}
keys = [k for k in params]
for k in _model.state_dict():
    p = params[keys.pop(0)]
    _params[k] = p.unsqueeze(-1) if 'weight' in k and p.ndim == 2 else p
_model.load_state_dict(_params)

_pred = _model(img)  # (1, 1000)
assert torch.abs(pred - _pred).max() < 1e-3
```
```python
_model = _model.cuda()
before = time.time()
for _ in range(10):
    _ = _model(batch)
print(f"Execution time {time.time() - before}")
```

    Execution time 0.8272452354431152