Created December 13, 2019 21:02
-
-
Save Mustufain/62441b12a504dfc71b8bfa0dbf5e2915 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def backward(self, dZ):
    """
    Backward propagation for a convolution layer.

    Parameters:
    dZ -- gradient of the cost with respect to the output of the conv layer (Z),
          numpy array of shape (m, n_H, n_W, n_C)

    Returns:
    dA_prev -- gradient of the cost with respect to the input of the conv
               layer (A_prev), numpy array of shape (m, n_H_prev, n_W_prev, n_C_prev)
    [dW, db] -- dW: gradient of the cost with respect to the weights (W),
                    numpy array of shape (f, f, n_C_prev, n_C)
                db: gradient of the cost with respect to the biases (b),
                    numpy array of shape (1, 1, 1, n_C)
    """
    # NOTE(review): reseeding the global RNG has no effect on this computation;
    # kept only because callers may rely on the RNG state it leaves behind.
    np.random.seed(self.seed)

    W = self.params[0]
    pad = self.pad
    m, n_H_prev, n_W_prev, n_C_prev = self.A_prev.shape
    _, n_H, n_W, _ = dZ.shape

    dW = np.zeros(W.shape)
    db = np.zeros(self.params[1].shape)

    # Work in padded coordinates so every window returned by get_corners
    # indexes valid memory; the padding is cropped off again at the end.
    A_prev_pad = self.zero_pad(self.A_prev, pad)
    dA_prev_pad = self.zero_pad(np.zeros(self.A_prev.shape), pad)

    for i in range(m):
        a_prev_pad = A_prev_pad[i]
        da_prev_pad = dA_prev_pad[i]
        for h in range(n_H):
            for w in range(n_W):
                # Assumes get_corners returns window bounds in padded
                # coordinates (TODO confirm against its definition).
                vert_start, vert_end, horiz_start, horiz_end = self.get_corners(
                    h, w, self.filter_size, self.stride)
                # The input window does not depend on the output channel,
                # so extract it once per (h, w).
                a_slice_prev = a_prev_pad[
                    vert_start:vert_end, horiz_start:horiz_end, :]
                # Upstream gradients for all n_C output channels at (i, h, w).
                g = dZ[i, h, w, :]
                # sum_c W[..., c] * g[c]  ->  shape (f, f, n_C_prev)
                da_prev_pad[
                    vert_start:vert_end, horiz_start:horiz_end, :] += W @ g
                # dW[..., c] += a_slice_prev * g[c] for every c at once.
                dW += a_slice_prev[..., np.newaxis] * g
                # db has shape (1, 1, 1, n_C); g broadcasts over the leading axes.
                db += g

    # Crop with explicit end indices: slicing with `pad:-pad` would yield an
    # empty range when pad == 0.
    dA_prev = dA_prev_pad[
        :, pad:pad + n_H_prev, pad:pad + n_W_prev, :].copy()

    assert(dA_prev.shape == (m, n_H_prev, n_W_prev, n_C_prev))
    return dA_prev, [dW, db]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.