Skip to content

Instantly share code, notes, and snippets.

@kvanhoey
Created January 15, 2020 16:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kvanhoey/76d446e9833de5d5738c282444b78819 to your computer and use it in GitHub Desktop.
Save kvanhoey/76d446e9833de5d5738c282444b78819 to your computer and use it in GitHub Desktop.
Function to compute the amount of padding to add before and after the data arrays so that a 'valid' convolution produces the same output size as a 'same' convolution.
import math
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
def compute_amount_padding(L, F, S):
    """Compute the padding that makes a 'valid' convolution behave like 'same'.

    Padding the data with the returned amounts and then applying a 'valid'
    convolution yields ``ceil(L / S)`` outputs, which is the output length
    of a 'same' convolution.

    L: length of data
    F: filter size
    S: stride

    The intended use is S <= F. If S > F a warning is printed, but a usable
    result is still returned: the total padding is clamped at zero, in which
    case a 'valid' convolution on the (possibly unpadded) data already
    produces ``ceil(L / S)`` outputs.

    Return: list of two non-negative integers ``[before, after]`` with the
    amount of padding to add before and after the data, respectively. When
    the total is odd, the extra element goes after the data.
    """
    # Output length produced by a 'same' convolution.
    L_target = math.ceil(L / S)

    # Total padding needed so that a 'valid' convolution reaches L_target.
    # The raw value can only be negative when S > F; clamp it so the result
    # is always a valid padding amount.
    P = max((L_target - 1) * S - L + F, 0)

    # A stride larger than the filter size skips part of the data.
    if S > F:
        print("WARNING in compute_amount_padding: stride is larger then filter size, which makes no sense")

    # Integer split: floor(P / 2) before, ceil(P / 2) after.
    before = P // 2
    return [before, P - before]
# Verify validity for a bunch of combinations.
# For every (data length, filter size, stride) triple, the output shape of a
# "same" convolution is compared with that of an explicitly padded "valid"
# convolution. Only the symbolic output shapes are compared.
data_lengths = range(1, 11)
filter_sizes = range(1, 12)
for length in data_lengths:
    for filter_size in filter_sizes:
        # compute_amount_padding requires stride <= filter size, so the
        # boundary case stride == filter_size is included as well.
        for stride in range(1, filter_size + 1):
            # input tensor
            inputs = keras.Input(shape=(length, length, 3,), name="img_content")

            # "same" convolution
            output_same = keras.layers.Conv2D(
                1, kernel_size=filter_size, strides=stride, padding="same"
            )(inputs)

            # "padding + valid convolution"
            pad = compute_amount_padding(length, filter_size, stride)
            # Pad only the two spatial axes; batch and channel axes are left
            # untouched. The paddings are bound as a default argument so the
            # lambda does not depend on the loop variable being rebound later.
            input_padded = keras.layers.Lambda(
                lambda xi, paddings=([0, 0], pad, pad, [0, 0]): tf.pad(
                    xi, [list(p) for p in paddings], "SYMMETRIC"
                )
            )(inputs)
            output_pad = keras.layers.Conv2D(
                1, kernel_size=filter_size, strides=stride, padding="valid"
            )(input_padded)

            # check if same output size
            condition = output_same.shape[1:3] == output_pad.shape[1:3]
            print(output_same.shape[1:3], output_pad.shape[1:3], condition)
            assert condition, (
                f"shape mismatch for length={length}, "
                f"filter={filter_size}, stride={stride}"
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment