Skip to content

Instantly share code, notes, and snippets.

@lzqlzzq
Last active December 22, 2023 11:31
Show Gist options
  • Save lzqlzzq/3979d16976d8d4e4dd6ba5dbe4ba8f54 to your computer and use it in GitHub Desktop.
Save lzqlzzq/3979d16976d8d4e4dd6ba5dbe4ba8f54 to your computer and use it in GitHub Desktop.
A simple and clean multilayer perceptron (MLP) implementation in PyTorch.
from itertools import chain
from typing import Callable, List

import torch
from torch import nn
class MLP(nn.Module):
    """Multilayer perceptron built from a list of layer widths.

    For every consecutive pair ``(in_size, out_size)`` in ``hidden_sizes`` one
    block ``Linear -> act_fn() -> LayerNorm (or Identity) -> Dropout`` is
    created, and the blocks are chained in ``self.layers`` (an
    ``nn.Sequential`` of ``nn.Sequential``).

    Args:
        hidden_sizes: Layer widths including the input and output size, e.g.
            ``[in_dim, h1, h2, out_dim]``. With fewer than two entries no
            blocks are created and ``forward`` returns its input unchanged.
        act_fn: Zero-argument factory (typically an ``nn.Module`` subclass
            such as ``nn.ReLU``) called once per block to build a fresh
            activation module.
        layer_norm: Apply ``nn.LayerNorm(out_size)`` after each activation;
            when ``False`` an ``nn.Identity`` placeholder is used instead so
            submodule indices (and ``state_dict`` keys) stay the same.
        bias: Passed to every ``nn.Linear``.
        dropout: Dropout probability used by each block's ``nn.Dropout``.
        activate_last: Keyword-only. When ``False`` the final block is a bare
            ``nn.Linear`` (no activation, normalization or dropout), which is
            usually what an output layer needs. Defaults to ``True`` to keep
            the original behaviour.
    """

    def __init__(self,
                 hidden_sizes: List[int],
                 act_fn: Callable[[], nn.Module] = nn.ReLU,
                 layer_norm: bool = True,
                 bias: bool = True,
                 dropout: float = .2,
                 *,
                 activate_last: bool = True):
        super().__init__()
        size_pairs = list(zip(hidden_sizes[:-1], hidden_sizes[1:]))
        last_idx = len(size_pairs) - 1
        blocks = []
        for idx, (in_size, out_size) in enumerate(size_pairs):
            # Modules are created in the same order as the original
            # implementation so seeded weight initialization is unchanged.
            linear = nn.Linear(in_features=in_size, out_features=out_size, bias=bias)
            if idx == last_idx and not activate_last:
                # Plain output projection: keep it wrapped in a Sequential so
                # the Linear stays at index 0 like in every other block.
                blocks.append(nn.Sequential(linear))
                continue
            blocks.append(nn.Sequential(
                linear,
                act_fn(),
                nn.LayerNorm(out_size) if layer_norm else nn.Identity(),
                nn.Dropout(dropout)))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every block in order.

        The last dimension of ``x`` must equal ``hidden_sizes[0]`` (required
        by the first ``nn.Linear``); the result's last dimension is
        ``hidden_sizes[-1]``.
        """
        return self.layers(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment