
@rehno-lindeque
Last active February 1, 2024 19:50
Personal PyTorch Cheatsheet: Low effort from scratch layers

  • b: Batch size
  • i: Input features (Linear)
  • o: Output features
  • c_in: Input channels
  • c_out: Output channels
  • n: Input/output length (1D sequence, often number of tokens)
  • n_in: Input length (1D sequence)
  • n_out: Output length (1D sequence)
  • h_in: Input height
  • w_in: Input width
  • h_out: Output height
  • w_out: Output width
  • k: Kernel size (1D)
  • kh: Kernel height
  • kw: Kernel width

Linear layer

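from einops import rearrange  # used by every snippet below
X = rearrange(X, 'b i -> b i 1')
W = rearrange(W, 'o i -> 1 i o')  # nn.Linear stores W as (out_features, in_features)
Y = (X * W).sum(dim=1)  # contract over i -> b o; add bias here if the layer has one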

Conv1d layer

unfolded_X = X.unfold(dimension=2, size=kernel_size, step=stride)  # b c_in n_out k
unfolded_X = rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')
W = rearrange(W, 'c_out c_in k -> 1 c_out c_in 1 k')
Y = (unfolded_X * W).sum(dim=(2, 4))  # contract c_in and k -> b c_out n_out

Including dilation

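unfolded_X = X.unfold(2, dilation * (kernel_size - 1) + 1, stride)  # unfold the full dilated span, not just kernel_size
Y = (rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')[..., ::dilation] * W).sum(dim=(2, 4))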

Including groups

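# unfolded_X is the raw X.unfold(...) output (b c_in n_out k); W is (c_out, c_in // groups, k)
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) n_out k -> b groups 1 c_in n_out k', groups=groups)
W = rearrange(W, '(groups c_out) c_in k -> 1 groups c_out c_in 1 k', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 5))  # b groups c_out n_out
Y = rearrange(Y, 'b groups c_out n_out -> b (groups c_out) n_out')  # merge groups back into channels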

Conv2d layer

unfolded_X = X.unfold(2, kernel_size[0], stride).unfold(3, kernel_size[1], stride)
unfolded_X = rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')
W = rearrange(W, 'c_out c_in kh kw -> 1 c_out c_in 1 1 kh kw')
Y = (unfolded_X * W).sum(dim=(2, 5, 6))

Including dilation

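unfolded_X = X.unfold(2, dilation * (kernel_size[0] - 1) + 1, stride).unfold(3, dilation * (kernel_size[1] - 1) + 1, stride)  # unfold the dilated span
Y = (rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')[..., ::dilation, ::dilation] * W).sum(dim=(2, 5, 6))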

Including groups

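# unfolded_X is the raw unfold output (b c_in h_out w_out kh kw); W is (c_out, c_in // groups, kh, kw)
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) h_out w_out kh kw -> b groups 1 c_in h_out w_out kh kw', groups=groups)
W = rearrange(W, '(groups c_out) c_in kh kw -> 1 groups c_out c_in 1 1 kh kw', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 6, 7))  # b groups c_out h_out w_out
Y = rearrange(Y, 'b groups c_out h_out w_out -> b (groups c_out) h_out w_out')  # merge groups back into channels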

Side-note: pad

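import torch.nn as nn
X_padded = nn.functional.pad(X, (left, right, top, bottom))  # last axis first: (left, right) pads w, (top, bottom) pads h; pad before unfolding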
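`unfold` never pads, so any padding has to be applied to X first. This short check uses only PyTorch: it shows the last-axis-first ordering of the pad tuple and that symmetric zero padding followed by an unpadded convolution matches `F.conv2d`'s own `padding=(pad_h, pad_w)` argument. The sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(1, 2, 5, 6)  # (b, c_in, h_in, w_in)
W = torch.randn(3, 2, 3, 3)  # (c_out, c_in, kh, kw)

# The pad tuple starts from the LAST axis: (left, right) pads w, then (top, bottom) pads h
left, right, top, bottom = 1, 1, 2, 2
X_padded = F.pad(X, (left, right, top, bottom))  # constant zeros by default
print(X_padded.shape)  # torch.Size([1, 2, 9, 8])

# Pad first, then convolve without padding == conv2d with padding=(pad_h, pad_w)
Y_manual = F.conv2d(X_padded, W)
Y_builtin = F.conv2d(X, W, padding=(top, left))
print(torch.allclose(Y_manual, Y_builtin, atol=1e-5))
```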

Appendix: Matrix Multiplication common variants

a. matmul or @

Assuming both arguments have the same number of dimensions, and at least 2:

A = rearrange(A, '... m n -> ... m n 1')
B = rearrange(B, '... n p -> ... 1 n p')
C = (A * B).sum(dim=-2)  # ... m p

b. bmm

A = rearrange(A, 'b m n -> b m n 1')
B = rearrange(B, 'b n p -> b 1 n p')
C = (A * B).sum(dim=-2)  # b m p

c. mm

A = rearrange(A, 'm n -> m n 1')
B = rearrange(B, 'n p -> 1 n p')
C = (A * B).sum(dim=-2)  # m p

Appendix: Reshaping common variants

a. Flatten dimensions

# X = rearrange(X, 'b c1 c2 h w -> b (c1 c2) h w')
X = X.view(X.size(0), -1, *X.shape[-2:])
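`view` only works when the axes being merged are laid out compatibly in memory. This small check uses only PyTorch: on a freshly created tensor the `view` matches `Tensor.flatten(1, 2)`, while after a `permute` the same `view` raises and `reshape` (which copies when it must) still works. The sizes are arbitrary.

```python
import torch

X = torch.randn(2, 3, 4, 5, 6)  # (b, c1, c2, h, w)
Y = X.view(X.size(0), -1, *X.shape[-2:])
print(Y.shape)  # torch.Size([2, 12, 5, 6])
print(torch.equal(Y, X.flatten(1, 2)))  # True

# After a permute, the two axes to merge no longer have compatible strides
Xp = X.permute(0, 2, 1, 3, 4)  # (b, c2, c1, h, w), non-contiguous
try:
    Xp.view(Xp.size(0), -1, *Xp.shape[-2:])
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# reshape() returns a view when possible and copies otherwise
Z = Xp.reshape(Xp.size(0), -1, *Xp.shape[-2:])
print(Z.shape)  # torch.Size([2, 12, 5, 6])
```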

b. Unflatten dimensions

# X = rearrange(X, 'b (c1 c2) h w -> b c1 c2 h w', c1=c1, c2=c2)
X = X.view(X.size(0), c1, c2, *X.shape[-2:])

c. Permute dimensions

# X = rearrange(X, 'b c h w -> b w h c')
X = X.permute(0, 3, 2, 1)

d. Squeeze dimension

# X = rearrange(X, 'b i 1 -> b i')
X = X.squeeze(-1)

e. Unsqueeze dimension

# X = rearrange(X, 'b i -> b i 1')
X = X.unsqueeze(-1)  # equivalently: X = X[..., None]