Skip to content

Instantly share code, notes, and snippets.

@neodelphis
Created July 15, 2019 10:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save neodelphis/b36fe40c381415033772f4c007669aed to your computer and use it in GitHub Desktop.
Save neodelphis/b36fe40c381415033772f4c007669aed to your computer and use it in GitHub Desktop.
conv backward naive with stride
def conv_backward_naive(dout, cache):
    """
    A naive implementation of the backward pass for a convolutional layer.

    Inputs:
    - dout: Upstream derivatives, of shape (N, F, H', W').
    - cache: A tuple of (x, w, b, conv_param) as in conv_forward_naive, where
      - x has shape (N, C, H, W)
      - w has shape (F, C, HH, WW)
      - b is the bias (only its gradient is computed here; its value is unused)
      - conv_param is a dict with the integer keys 'pad' and 'stride'

    Returns a tuple of:
    - dx: Gradient with respect to x, of shape (N, C, H, W)
    - dw: Gradient with respect to w, of shape (F, C, HH, WW)
    - db: Gradient with respect to b, of shape (F,)

    dx and dw keep the dtype of x and w when that dtype is floating point (or
    complex); for any other dtype (e.g. integers) they are accumulated in
    float64 so that fractional gradient contributions are not truncated.
    """
    # Unpack the cached forward-pass values.
    x, w, _, conv_param = cache
    pad = conv_param['pad']
    stride = conv_param['stride']

    # Dimensions.
    H, W = x.shape[2:]
    _, _, HH, WW = w.shape
    _, _, H_out, W_out = dout.shape

    # db: sum dout (N, F, H', W') over every axis except the filter axis.
    db = np.sum(dout, axis=(0, 2, 3))

    # Accumulate in a floating dtype: integer buffers would silently truncate
    # every partial sum written into them.
    dx_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dw_dtype = w.dtype if np.issubdtype(w.dtype, np.inexact) else np.float64

    # Zero-pad only the two spatial (last) dimensions of x; the gradient is
    # accumulated into a buffer of the same padded shape.
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
    dxp = np.zeros(xp.shape, dtype=dx_dtype)
    dw = np.zeros(w.shape, dtype=dw_dtype)

    # For each output position (i, j), the forward pass multiplied the input
    # window xp[:, :, s*i:s*i+HH, s*j:s*j+WW] with every filter. Backpropagate
    # through that product for all images, filters and channels at once.
    for i in range(H_out):
        rows = slice(stride * i, stride * i + HH)
        for j in range(W_out):
            cols = slice(stride * j, stride * j + WW)
            grad = dout[:, :, i, j]  # (N, F)
            # dw[f, c, k, l] += sum_n grad[n, f] * window[n, c, k, l]
            dw += np.tensordot(grad, xp[:, :, rows, cols], axes=([0], [0]))
            # dxp[n, c, window] += sum_f grad[n, f] * w[f, c, k, l]
            dxp[:, :, rows, cols] += np.tensordot(grad, w, axes=([1], [0]))

    # Remove the padding from dx. Slicing with explicit end indices also
    # handles pad == 0, where `pad:-pad` would yield an empty array.
    dx = dxp[:, :, pad:pad + H, pad:pad + W]
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################
    return dx, dw, db
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment