@thomelane
Created September 20, 2018 23:11
[Convolutions on Medium] Used in Medium blog post series #python #convolutions
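```python
def apply_conv(data, kernel, conv):
    """
    Applies the MXNet Gluon convolutional layer `conv` to `data`,
    using `kernel` as the layer's weights.

    Note: modifies `conv` in place (sets `_in_channels` and its weight).

    Args:
        data (NDArray): input data.
        kernel (NDArray): convolution's kernel parameters.
        conv (Block): convolutional layer (e.g. Conv2D or Conv2DTranspose).
    Returns:
        NDArray: output data (after applying convolution).
    """
    # add dimensions for batch and channels if necessary,
    # e.g. a 2D image (H, W) becomes (1, 1, H, W) for Conv2D
    while data.ndim < len(conv.weight.shape):
        data = data.expand_dims(0)
    # add dimensions for channels and in_channels if necessary,
    # e.g. a 2D kernel (kH, kW) becomes (1, 1, kH, kW)
    while kernel.ndim < len(conv.weight.shape):
        kernel = kernel.expand_dims(0)
    # check if transpose convolution: its weight layout is
    # (in_channels, channels, ...) instead of (channels, in_channels, ...)
    if type(conv).__name__.endswith("Transpose"):
        in_channel_idx = 0
    else:
        in_channel_idx = 1
    # initialize the layer, then overwrite its weight with the given kernel
    conv._in_channels = kernel.shape[in_channel_idx]
    conv.initialize()
    conv.weight.set_data(kernel)
    return conv(data)
```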
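Like most deep learning frameworks, Gluon's `Conv2D` computes a cross-correlation: the kernel slides over the input without being flipped. To sanity-check the output of `apply_conv` without MXNet, here is a minimal pure-Python sketch for the simplest case, assuming one input and one output channel, stride 1, no padding and a zero bias. The name `cross_correlate_2d` is hypothetical and not part of MXNet.

```python
def cross_correlate_2d(image, kernel):
    """Valid (no padding), stride-1 2D cross-correlation of nested lists.

    Hypothetical reference helper: each output cell is the sum of the
    element-wise product of the kernel and the image patch under it.
    """
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = ih - kh + 1, iw - kw + 1  # output shrinks by (kernel size - 1)
    out = [[0.0] * ow for _ in range(oh)]
    for i in range(oh):
        for j in range(ow):
            total = 0.0
            for a in range(kh):
                for b in range(kw):
                    total += image[i + a][j + b] * kernel[a][b]
            out[i][j] = total
    return out


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 2],
          [3, 4]]
print(cross_correlate_2d(image, kernel))  # [[37.0, 47.0], [67.0, 77.0]]
```

With MXNet installed, calling `apply_conv(mx.nd.array(image), mx.nd.array(kernel), mx.gluon.nn.Conv2D(channels=1, kernel_size=2))` should produce the same values with shape (1, 1, 2, 2). This relies on the bias being zero, which is the default bias initializer for Gluon's `Conv2D`.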
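The transpose branch in `apply_conv` exists because a transposed convolution stores its weight with the input-channel axis first. The operation itself can be sketched in pure Python for a single channel, stride 1 and no padding (these restrictions are assumptions). Each input value scatters a scaled copy of the kernel into an output that is larger than the input. `transpose_conv_2d` is a hypothetical helper, not an MXNet API.

```python
def transpose_conv_2d(image, kernel):
    """Stride-1, unpadded 2D transposed convolution of nested lists.

    Hypothetical reference helper: every input value adds a copy of the
    kernel, scaled by that value, into the output at its own position.
    """
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = ih + kh - 1, iw + kw - 1  # output grows by (kernel size - 1)
    out = [[0.0] * ow for _ in range(oh)]
    for i in range(ih):
        for j in range(iw):
            for a in range(kh):
                for b in range(kw):
                    out[i + a][j + b] += image[i][j] * kernel[a][b]
    return out


print(transpose_conv_2d([[1, 2], [3, 4]], [[1, 1], [1, 1]]))
# [[1.0, 3.0, 2.0], [4.0, 10.0, 6.0], [3.0, 7.0, 4.0]]
```

This scatter form is the adjoint of an unpadded, stride-1 cross-correlation with the same kernel, which is why transposed convolutions are often described as "running a convolution backwards" to upsample.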