
@rehno-lindeque
Last active February 1, 2024 19:50

PyTorch Cheatsheet: Low effort from scratch layers

  • b: Batch size
  • i: Input features (Linear)
  • o: Output features
  • c_in: Input channels
  • c_out: Output channels
  • n: Input/output length (1D sequence, often number of tokens)
  • n_in: Input length (1D sequence)
  • n_out: Output length (1D sequence)
  • h_in: Input height
  • w_in: Input width
  • h_out: Output height
  • w_out: Output width
  • k: Kernel size (1D)
  • kh: Kernel height
  • kw: Kernel width

Linear layer

X = rearrange(X, 'b i -> b i 1')
W = rearrange(W, 'o i -> 1 i o')  # nn.Linear stores weight as (out_features, in_features)
Y = (X * W).sum(dim=1) # + bias
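A quick numerical sanity check against PyTorch's built-in (a sketch with arbitrary shapes; plain `unsqueeze` stands in for `rearrange` so only torch is needed):

```python
import torch
import torch.nn.functional as F

b, i, o = 4, 3, 5
X = torch.randn(b, i)
W = torch.randn(o, i)  # same layout as nn.Linear.weight

# b i -> b i 1  and  o i -> 1 i o, then broadcast-multiply and reduce over i
Y = (X.unsqueeze(-1) * W.T.unsqueeze(0)).sum(dim=1)

assert torch.allclose(Y, F.linear(X, W), atol=1e-6)
```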

Conv1d layer

unfolded_X = X.unfold(dimension=2, size=kernel_size, step=stride)
unfolded_X = rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')
W = rearrange(W, 'c_out c_in k -> 1 c_out c_in 1 k')
Y = (unfolded_X * W).sum(dim=(2, 4))
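The same unfold-and-reduce pattern can be checked against `F.conv1d` (a sketch with arbitrary shapes; `unsqueeze` replaces `rearrange` so only torch is needed):

```python
import torch
import torch.nn.functional as F

b, c_in, c_out, n, k, stride = 2, 3, 5, 16, 4, 2
X = torch.randn(b, c_in, n)
W = torch.randn(c_out, c_in, k)

unfolded_X = X.unfold(dimension=2, size=k, step=stride)  # b c_in n_out k
# b 1 c_in n_out k  *  1 c_out c_in 1 k, reduce over c_in and k
Y = (unfolded_X.unsqueeze(1) * W.unsqueeze(0).unsqueeze(3)).sum(dim=(2, 4))

assert torch.allclose(Y, F.conv1d(X, W, stride=stride), atol=1e-5)
```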

Including dilation

# Requires unfolding with size=dilation*(kernel_size-1)+1 so the slice spans the dilated receptive field
Y = (unfolded_X[..., ::dilation] * W).sum(dim=(2, 4))
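A worked check of the dilation trick: unfold over the dilated receptive field, then take every `dilation`-th tap (a sketch with arbitrary shapes, using plain tensor ops instead of `rearrange`):

```python
import torch
import torch.nn.functional as F

b, c_in, c_out, n, k, stride, dilation = 2, 3, 5, 20, 3, 1, 2
X = torch.randn(b, c_in, n)
W = torch.randn(c_out, c_in, k)

# Window covers the dilated receptive field, not just k samples
window = dilation * (k - 1) + 1
unfolded_X = X.unfold(2, size=window, step=stride)  # b c_in n_out window
Y = (unfolded_X[..., ::dilation].unsqueeze(1)
     * W.unsqueeze(0).unsqueeze(3)).sum(dim=(2, 4))

assert torch.allclose(Y, F.conv1d(X, W, stride=stride, dilation=dilation), atol=1e-5)
```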

Including groups

# Starting from the raw unfold: unfolded_X has shape b (groups c_in) n_out k
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) n_out k -> b groups 1 c_in n_out k', groups=groups)
W = rearrange(W, '(groups c_out) c_in k -> 1 groups c_out c_in 1 k', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 5))
Y = rearrange(Y, 'b groups c_out n_out -> b (groups c_out) n_out')
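The grouped variant can be verified against `F.conv1d(..., groups=...)` (a sketch with arbitrary shapes; `reshape`/`unsqueeze` stand in for `rearrange`):

```python
import torch
import torch.nn.functional as F

b, groups, n, k, stride = 2, 2, 16, 4, 2
c_in, c_out = 6, 8  # total channels; both divisible by groups
X = torch.randn(b, c_in, n)
W = torch.randn(c_out, c_in // groups, k)

unfolded_X = X.unfold(2, size=k, step=stride)                       # b c_in n_out k
unfolded_X = unfolded_X.reshape(b, groups, c_in // groups, -1, k)   # b g c_in' n_out k
Wg = W.reshape(1, groups, c_out // groups, c_in // groups, 1, k)
Y = (unfolded_X.unsqueeze(2) * Wg).sum(dim=(3, 5))                  # b g c_out' n_out
Y = Y.reshape(b, c_out, -1)                                         # b (g c_out') n_out

assert torch.allclose(Y, F.conv1d(X, W, stride=stride, groups=groups), atol=1e-5)
```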

Conv2d layer

unfolded_X = X.unfold(2, kernel_size[0], stride).unfold(3, kernel_size[1], stride)
unfolded_X = rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')
W = rearrange(W, 'c_out c_in kh kw -> 1 c_out c_in 1 1 kh kw')
Y = (unfolded_X * W).sum(dim=(2, 5, 6))
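The 2D version follows the same recipe with two unfolds; a quick check against `F.conv2d` (a sketch with arbitrary shapes, plain tensor ops instead of `rearrange`):

```python
import torch
import torch.nn.functional as F

b, c_in, c_out, h, w, kh, kw, stride = 2, 3, 4, 10, 12, 3, 3, 1
X = torch.randn(b, c_in, h, w)
W = torch.randn(c_out, c_in, kh, kw)

unfolded_X = X.unfold(2, kh, stride).unfold(3, kw, stride)  # b c_in h_out w_out kh kw
# b 1 c_in h_out w_out kh kw  *  1 c_out c_in 1 1 kh kw, reduce over c_in, kh, kw
Y = (unfolded_X.unsqueeze(1) * W.reshape(1, c_out, c_in, 1, 1, kh, kw)).sum(dim=(2, 5, 6))

assert torch.allclose(Y, F.conv2d(X, W, stride=stride), atol=1e-4)
```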

Including dilation

# Requires unfolding with size=dilation*(kernel_size-1)+1 in each spatial dim
Y = (unfolded_X[..., ::dilation, ::dilation] * W).sum(dim=(2, 5, 6))

Including groups

# Starting from the raw unfold: unfolded_X has shape b (groups c_in) h_out w_out kh kw
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) h_out w_out kh kw -> b groups 1 c_in h_out w_out kh kw', groups=groups)
W = rearrange(W, '(groups c_out) c_in kh kw -> 1 groups c_out c_in 1 1 kh kw', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 6, 7))
Y = rearrange(Y, 'b groups c_out h_out w_out -> b (groups c_out) h_out w_out')
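And the grouped 2D case checked against `F.conv2d(..., groups=...)` (a sketch with arbitrary shapes; `reshape` stands in for `rearrange`):

```python
import torch
import torch.nn.functional as F

b, groups, h, w, kh, kw = 2, 2, 8, 8, 3, 3
c_in, c_out = 4, 6
X = torch.randn(b, c_in, h, w)
W = torch.randn(c_out, c_in // groups, kh, kw)

h_out, w_out = h - kh + 1, w - kw + 1
u = X.unfold(2, kh, 1).unfold(3, kw, 1)                              # b c_in h_out w_out kh kw
u = u.reshape(b, groups, 1, c_in // groups, h_out, w_out, kh, kw)
Wg = W.reshape(1, groups, c_out // groups, c_in // groups, 1, 1, kh, kw)
Y = (u * Wg).sum(dim=(3, 6, 7)).reshape(b, c_out, h_out, w_out)

assert torch.allclose(Y, F.conv2d(X, W, groups=groups), atol=1e-4)
```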

Side-note: pad

X_padded = nn.functional.pad(X, (left, right, top, bottom))
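The pad tuple runs from the last dimension backwards, so for a `b c h w` tensor the first pair is width and the second is height (a small demo with arbitrary padding amounts):

```python
import torch
import torch.nn.functional as F

X = torch.zeros(1, 1, 2, 2)
# (left, right, top, bottom): width is padded first, then height
X_padded = F.pad(X, (1, 2, 3, 4))
assert X_padded.shape == (1, 1, 2 + 3 + 4, 2 + 1 + 2)
```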

Appendix: Matrix Multiplication common variants

a. matmul or @

Assuming that both arguments have the same number of dims and dims ≥ 2:

A = rearrange(A, '... m n -> ... m n 1')
B = rearrange(B, '... n p -> ... 1 n p')
C = (A * B).sum(dim=-2)  # ... m p
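Checking the broadcast-and-sum form against `@` (a sketch with arbitrary batch dims; `unsqueeze` replaces `rearrange`):

```python
import torch

A = torch.randn(2, 3, 4, 5)   # ... m n
B = torch.randn(2, 3, 5, 6)   # ... n p

# ... m n 1  *  ... 1 n p, reduce over n
C = (A.unsqueeze(-1) * B.unsqueeze(-3)).sum(dim=-2)

assert torch.allclose(C, A @ B, atol=1e-5)
```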

b. bmm

A = rearrange(A, 'b m n -> b m n 1')
B = rearrange(B, 'b n p -> b 1 n p')
C = (A * B).sum(dim=-2)  # b m p
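The same check against `torch.bmm` (arbitrary shapes, plain tensor ops):

```python
import torch

A, B = torch.randn(7, 4, 5), torch.randn(7, 5, 6)
C = (A.unsqueeze(-1) * B.unsqueeze(1)).sum(dim=-2)  # b m n 1 * b 1 n p -> b m p
assert torch.allclose(C, torch.bmm(A, B), atol=1e-5)
```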

c. mm

A = rearrange(A, 'm n -> m n 1')
B = rearrange(B, 'n p -> 1 n p')
C = (A * B).sum(dim=-2)  # m p
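And against `torch.mm` (arbitrary shapes, plain tensor ops):

```python
import torch

A, B = torch.randn(4, 5), torch.randn(5, 6)
C = (A.unsqueeze(-1) * B.unsqueeze(0)).sum(dim=-2)  # m n 1 * 1 n p -> m p
assert torch.allclose(C, torch.mm(A, B), atol=1e-5)
```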

Appendix: Reshaping common variants

a. Flatten dimensions

# X = rearrange(X, 'b c1 c2 h w -> b (c1 c2) h w')
X = X.view(X.size(0), -1, *X.shape[-2:])

b. Unflatten dimensions

# X = rearrange(X, 'b (c1 c2) h w -> b c1 c2 h w', c1=c1, c2=c2)
X = X.view(X.size(0), c1, c2, *X.shape[-2:])

c. Permute dimensions

# X = rearrange(X, 'b c h w -> b w h c')
X = X.permute(0, 3, 2, 1)

d. Squeeze dimension

# X = rearrange(X, 'b i 1 -> b i')
X = X.squeeze(-1)

e. Unsqueeze dimension

# X = rearrange(X, 'b i -> b i 1')
X = X.unsqueeze(-1)
X = X[..., None]
rehno-lindeque commented Oct 1, 2023

TODO (WIP / incomplete / wrong / unrefined)

Depthwise separable ConvTranspose2d with kernel_size=(2,2) and groups=4*...

Can be rewritten as Conv2d followed by PixelShuffle:

# Depthwise ConvTranspose2d
nn.Conv2d(
    in_channels=in_channels,
    out_channels=in_channels * 4, # Alternatively: in_channels (to save parameters)
    kernel_size=2,
    bias=True,
),
nn.PixelShuffle(2),
nn.ReLU(),

# Pointwise Conv2d
nn.Conv2d(
    in_channels=in_channels, # Alternatively: in_channels // 4 (to save parameters)
    out_channels=out_channels,
    kernel_size=1,
    bias=True,
),
nn.ReLU(),
nn.ReplicationPad2d(1), # Should technically pad the input, but I prefer this.

More efficient dilation

dilated_X = rearrange(X, 'b c_in (dilation n_in) -> b c_in dilation n_in', dilation=dilation)
unfolded_X = dilated_X.unfold(..)

...

rearrange(Y, 'b c_out dilation n_out -> b c_out (dilation n_out)')

Simple Attention

qkv = linear(tokens)  # b n (qkv i), with qkv = 3
Q, K, V = rearrange(qkv, 'b n (qkv i) -> b qkv n i', qkv=3).unbind(dim=1)

attention_scores = (Q.unsqueeze(-2) * K.unsqueeze(-3)).sum(dim=-1) / sqrt(d)  # b n_q n_k
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
Y = (attention_weights.unsqueeze(-1) * V.unsqueeze(-3)).sum(dim=-2)  # b n_q i

Multihead Attention Mechanism

qkv = linear(tokens)  # b n (qkv h i), with qkv = 3
Q, K, V = rearrange(qkv, 'b n (qkv h i) -> b qkv h n i', qkv=3, h=num_heads).unbind(dim=1)

attention_scores = (Q.unsqueeze(-2) * K.unsqueeze(-3)).sum(dim=-1) / sqrt(d)  # b h n_q n_k
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
Y = (attention_weights.unsqueeze(-1) * V.unsqueeze(-3)).sum(dim=-2)  # b h n_q i
Y = rearrange(Y, 'b h n i -> b n (h i)')
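The broadcast-style attention can be checked against the usual matmul formulation (a sketch with arbitrary shapes and random Q, K, V, skipping the qkv projection):

```python
import torch
import torch.nn.functional as F
from math import sqrt

b, h, n, d = 2, 4, 6, 8
Q, K, V = torch.randn(3, b, h, n, d).unbind(0)

# b h n_q 1 i  *  b h 1 n_k i, reduce over i -> pairwise scores b h n_q n_k
scores = (Q.unsqueeze(-2) * K.unsqueeze(-3)).sum(dim=-1) / sqrt(d)
weights = F.softmax(scores, dim=-1)
Y = (weights.unsqueeze(-1) * V.unsqueeze(-3)).sum(dim=-2)  # b h n_q i

ref = F.softmax(Q @ K.transpose(-2, -1) / sqrt(d), dim=-1) @ V
assert torch.allclose(Y, ref, atol=1e-5)
```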
