Created
July 15, 2019 10:08
-
-
Save neodelphis/b36fe40c381415033772f4c007669aed to your computer and use it in GitHub Desktop.
Naive convolutional-layer backward pass (`conv_backward_naive`) with stride support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv_backward_naive(dout, cache):
    """
    A naive implementation of the backward pass for a convolutional layer.

    The loops run over images and output positions. For each output position,
    the gradient contributions to all filters and all input channels are
    computed at once with NumPy broadcasting.

    Inputs:
    - dout: Upstream derivatives, of shape (N, F, H', W').
    - cache: A tuple of (x, w, b, conv_param) as in conv_forward_naive, where
      x has shape (N, C, H, W), w has shape (F, C, HH, WW), and conv_param is
      a dict with keys 'pad' (zero-padding added on both sides of the two
      spatial dimensions) and 'stride'.

    Returns a tuple of:
    - dx: Gradient with respect to x, of shape (N, C, H, W)
    - dw: Gradient with respect to w, of shape (F, C, HH, WW)
    - db: Gradient with respect to b, of shape (F,)
    """
    x, w, _, conv_param = cache
    pad = conv_param['pad']
    stride = conv_param['stride']

    N, _, H, W = x.shape
    _, _, HH, WW = w.shape
    _, _, H_out, W_out = dout.shape

    # db: the bias of filter f is added to every output position of every
    # image, so its gradient sums dout over all axes except the filter axis.
    db = np.sum(dout, axis=(0, 2, 3))

    # Zero-pad only the two spatial dimensions of x. dx is accumulated in a
    # padded buffer so that windows touching the border need no clipping.
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
    dxp = np.zeros_like(xp)
    dw = np.zeros_like(w)

    for n in range(N):  # images
        for i in range(H_out):
            top = stride * i
            rows = slice(top, top + HH)
            for j in range(W_out):
                left = stride * j
                cols = slice(left, left + WW)
                # Upstream gradient of output position (i, j), one per filter.
                g = dout[n, :, i, j]  # shape (F,)
                # dw[f] += g[f] * (input window that produced output (i, j))
                dw += g[:, None, None, None] * xp[n, :, rows, cols]
                # dx window += sum over f of g[f] * w[f]
                dxp[n, :, rows, cols] += np.tensordot(g, w, axes=(0, 0))

    # Remove the padding. pad:pad + H is also correct when pad == 0, whereas
    # pad:-pad would then select an empty range.
    dx = dxp[:, :, pad:pad + H, pad:pad + W]
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################
    return dx, dw, db
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment