PadLex / torch_conv_layer_to_fully_connected.py
Last active December 7, 2024 14:58 — forked from vvolhejn/torch_conv_layer_to_fully_connected.py
Convert PyTorch convolutional layer to fully connected layer
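To illustrate the equivalence stated in the docstring below, here is a minimal sketch. It takes a different route from the gist's `torch_conv_layer_to_affine`. Because a convolution is affine in its input, the `fc` weights can be recovered by feeding the layer one basis image per input element and subtracting the layer's response to an all-zero image. The helper name `conv2d_to_linear_via_basis`, its `input_hw` parameter, and the example layer sizes are hypothetical. The sketch assumes an unbatched input of shape `(in_channels, H, W)` flattened in PyTorch's default row-major order. It also materializes a `(C*H*W) x (C*H*W)` identity matrix, so it only suits small inputs.

```python
from typing import Tuple

import torch
import torch.nn as nn


def conv2d_to_linear_via_basis(conv: nn.Conv2d, input_hw: Tuple[int, int]) -> nn.Linear:
    """Hypothetical helper: build an nn.Linear equivalent to `conv` for
    unbatched inputs of shape (conv.in_channels, H, W), by probing `conv`
    with basis images (not the gist's own construction)."""
    h, w = input_hw
    n_in = conv.in_channels * h * w
    with torch.no_grad():
        # Affine offset: the conv's response to an all-zero image
        # (all zeros if the conv was created with bias=False).
        zero = torch.zeros(1, conv.in_channels, h, w)
        bias = conv(zero).reshape(-1)                      # (n_out,)
        # Row i of the identity, reshaped, is the image whose flattening is e_i.
        basis = torch.eye(n_in).reshape(n_in, conv.in_channels, h, w)
        responses = conv(basis).reshape(n_in, -1)          # (n_in, n_out)
        # Removing the offset leaves column i of the linear map in row i.
        weight = (responses - bias).T                      # (n_out, n_in)

    fc = nn.Linear(n_in, weight.shape[0])
    with torch.no_grad():
        fc.weight.copy_(weight)
        fc.bias.copy_(bias)
    return fc


# Usage: check both identities from the docstring on a small random layer.
torch.manual_seed(0)
conv = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1)
fc = conv2d_to_linear_via_basis(conv, (5, 7))

x = torch.randn(2, 5, 7)  # unbatched (C, H, W) input
with torch.no_grad():
    # Add and remove a batch dimension so this also runs on older PyTorch versions.
    y_conv = conv(x.unsqueeze(0)).squeeze(0)
    y_fc = fc(torch.flatten(x))

print(torch.allclose(torch.flatten(y_conv), y_fc, atol=1e-5))
print(torch.allclose(y_conv, y_fc.reshape(y_conv.shape), atol=1e-5))
```

The comparison uses `torch.allclose` rather than exact `==`. The convolution and the matrix product accumulate the same floating-point sums in a different order, so the results can differ in the last few bits.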
"""
The function `torch_conv_layer_to_affine` takes a `torch.nn.Conv2d` layer `conv`
and produces an equivalent `torch.nn.Linear` layer `fc`.
Specifically, this means that the following holds for `x` of a valid shape:
    torch.flatten(conv(x)) == fc(torch.flatten(x))
Or equivalently:
    conv(x) == fc(torch.flatten(x)).reshape(conv(x).shape)