Skip to content

Instantly share code, notes, and snippets.

@Alescontrela
Last active June 15, 2018 05:00
Show Gist options
  • Save Alescontrela/c57664c9cd0a9689789363e9820e5d01 to your computer and use it in GitHub Desktop.
Save Alescontrela/c57664c9cd0a9689789363e9820e5d01 to your computer and use it in GitHub Desktop.
def convolutionBackward(dconv_prev, conv_in, filt, s):
    '''
    Backpropagation through a convolutional layer.

    Parameters:
        dconv_prev: gradient of the loss with respect to the convolution
            output, indexed as [filter, out_y, out_x].
        conv_in: the input that was convolved, indexed as [channel, y, x].
        filt: the filters, a 4-D array (n_f, n_c, f, f). Filters are treated
            as square: only the third axis of filt.shape is used as the
            window size.
        s: stride of the convolution (positive integer).

    Returns:
        (dout, dfilt, dbias) where dout has the shape of conv_in, dfilt has
        the shape of filt and dbias has shape (n_f, 1).
    '''
    (n_f, _, f, _) = filt.shape
    # Height and width are read separately so non-square inputs are covered.
    (_, in_h, in_w) = conv_in.shape

    ## initialize derivatives
    dout = np.zeros(conv_in.shape)
    dfilt = np.zeros(filt.shape)
    dbias = np.zeros((n_f, 1))

    for curr_f in range(n_f):
        # Slide the filter over every window that fits entirely inside the
        # input; (out_y, out_x) is the matching position in dconv_prev.
        for out_y, curr_y in enumerate(range(0, in_h - f + 1, s)):
            for out_x, curr_x in enumerate(range(0, in_w - f + 1, s)):
                grad = dconv_prev[curr_f, out_y, out_x]
                # loss gradient of the filter (used to update the filter)
                dfilt[curr_f] += grad * conv_in[:, curr_y:curr_y + f, curr_x:curr_x + f]
                # loss gradient of the input to the convolution operation;
                # overlapping windows accumulate into the same input cells
                dout[:, curr_y:curr_y + f, curr_x:curr_x + f] += grad * filt[curr_f]
        # loss gradient of the bias: one scalar per filter
        dbias[curr_f] = np.sum(dconv_prev[curr_f])

    return dout, dfilt, dbias
def nanargmax(arr):
    '''
    Return the index of the largest non-NaN value in the array.

    The result is a tuple with one entry per dimension of arr (an ordered
    pair for a 2-D array), suitable for indexing arr directly.
    '''
    # Accept any array-like (e.g. nested lists), not only ndarrays.
    arr = np.asarray(arr)
    # np.nanargmax returns a position in the flattened array; convert it
    # back to per-axis coordinates.
    flat_idx = np.nanargmax(arr)
    return np.unravel_index(flat_idx, arr.shape)
def maxpoolBackward(dpool, orig, f, s):
    '''
    Backpropagation through a maxpooling layer. The gradients are passed
    through the indices of greatest value in each pooling window of the
    original input.

    Parameters:
        dpool: gradient of the loss with respect to the pooled output,
            indexed as [channel, out_y, out_x].
        orig: the input that was pooled, indexed as [channel, y, x].
        f: side length of the (square) pooling window.
        s: stride of the pooling operation (positive integer).

    Returns:
        dout: array with the shape of orig; zero everywhere except at the
        position of the largest non-NaN value of each window.
    '''
    # Height and width are read separately so non-square inputs are covered.
    (n_c, in_h, in_w) = orig.shape
    dout = np.zeros(orig.shape)

    for curr_c in range(n_c):
        for out_y, curr_y in enumerate(range(0, in_h - f + 1, s)):
            for out_x, curr_x in enumerate(range(0, in_w - f + 1, s)):
                window = orig[curr_c, curr_y:curr_y + f, curr_x:curr_x + f]
                # index (within the window) of its largest non-NaN value
                (a, b) = np.unravel_index(np.nanargmax(window), window.shape)
                # Accumulate rather than overwrite: when windows overlap
                # (s < f) one input cell can be the maximum of several
                # windows and must receive the sum of their gradients.
                dout[curr_c, curr_y + a, curr_x + b] += dpool[curr_c, out_y, out_x]

    return dout
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment