@yangchenyun
Last active May 20, 2023 05:06
convolution_tensor
# This op targets a needle-style autodiff framework: TensorOp, the NDArray
# methods pad/strides/as_strided/compact, and the dilate/flip/transpose/conv
# helpers are assumed to be provided by the surrounding module.
from typing import Optional


class Conv(TensorOp):
    """2D convolution in NHWC layout: build a strided im2col view of the
    input, then contract it against the kernel with a single matrix
    multiply. A plain-NumPy sketch of the same trick follows the class."""

    def __init__(self, stride: Optional[int] = 1, padding: Optional[int] = 0):
        self.stride = stride or 1
        self.padding = padding or 0

    def compute(self, A, B):
        # A: input (N, H, W, C_in); B: kernel (K, K, C_in, C_out)
        N, H, W, C_in = A.shape
        K, _, _, C_out = B.shape
        P = self.padding
        S = self.stride
        A_pad = A.pad(axes=((0, 0), (P, P), (P, P), (0, 0)))
        Ns, Hs, Ws, Cs = A_pad.strides
        # Strided view: one K x K x C_in receptive field per output location,
        # created without copying the padded input.
        conv_strides = (Ns, Hs * S, Ws * S, Hs, Ws, Cs)
        conv_shape = (N, (H + 2 * P - K) // S + 1, (W + 2 * P - K) // S + 1)
        inner_dim = K * K * C_in
        out = A_pad.as_strided(conv_shape + (K, K, C_in), conv_strides).compact()
        # Flatten the inner dimensions and multiply by the flattened kernel.
        out = out.reshape((out.size // inner_dim, inner_dim)) @ B.compact().reshape((inner_dim, C_out))
        out = out.reshape(conv_shape + (C_out,))
        return out
    def gradient(self, out_grad, node):
        Z, W = node.inputs
        N, Hz, Wz, C_in = Z.shape
        _, Ho, Wo, _ = out_grad.shape
        K, _, _, C_out = W.shape
        revP = K - 1 - self.padding
        if self.stride > 1:
            # Undo the stride by dilating out_grad with zeros.
            out_grad = dilate(out_grad, (1, 2), self.stride - 1)
        # Back-calculate the expected spatial dimensions.
        H_g = (Hz - 1) + K - 2 * revP
        W_g = (Wz - 1) + K - 2 * revP
        assert H_g == out_grad.shape[1]
        assert W_g == out_grad.shape[2]
        # TODO: a slice operator is still missing; it is needed when the
        # input size is odd (which we avoid here).
        # dZ: perform a full convolution of out_grad with the flipped kernel:
        #   - flip the two kernel dimensions
        #   - swap C_in and C_out: W: K,K,C_in,C_out -> K,K,C_out,C_in
        fW = flip(W, (0, 1))
        dZ = conv(out_grad, transpose(fW, (2, 3)), padding=revP)
        assert dZ.shape == Z.shape
        # dW: convolve the permuted input with the permuted output gradient:
        #   - Z: N,H,W,C_in -> C_in,H,W,N, treating N as input channels
        #   - out_grad: N,H,W,C_out -> W,H,N,C_out -> H,W,N,C_out, treating
        #     H,W as the kernel window
        #   - tW: C_in,K,K,C_out -> K,K,C_in,C_out (keep the order of the
        #     two kernel dimensions)
        tZ = transpose(Z, (0, 3))
        tOut_grad = transpose(transpose(out_grad, (0, 2)), (0, 1))
        tW = conv(tZ, tOut_grad, padding=self.padding)  # same padding as the forward pass
        dW = transpose(transpose(tW, (0, 2)), (0, 1))
        assert dW.shape == W.shape
        return dZ, dW
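
For reference, here is a minimal plain-NumPy sketch of the same strided im2col trick used in compute, checked against a direct sliding-window implementation. It assumes nothing beyond numpy; conv2d_im2col and conv2d_naive are illustrative names, not part of the framework above.

import numpy as np
from numpy.lib.stride_tricks import as_strided

def conv2d_im2col(A, B, stride=1, padding=0):
    # A: (N, H, W, C_in); B: (K, K, C_in, C_out)
    N, H, W, C_in = A.shape
    K, _, _, C_out = B.shape
    A_pad = np.pad(A, ((0, 0), (padding, padding), (padding, padding), (0, 0)))
    Ns, Hs, Ws, Cs = A_pad.strides
    Ho = (H + 2 * padding - K) // stride + 1
    Wo = (W + 2 * padding - K) // stride + 1
    # One K x K x C_in receptive field per output location, no copy.
    view = as_strided(A_pad, (N, Ho, Wo, K, K, C_in),
                      (Ns, Hs * stride, Ws * stride, Hs, Ws, Cs))
    out = view.reshape(N * Ho * Wo, K * K * C_in) @ B.reshape(K * K * C_in, C_out)
    return out.reshape(N, Ho, Wo, C_out)

def conv2d_naive(A, B, stride=1, padding=0):
    N, H, W, C_in = A.shape
    K, _, _, C_out = B.shape
    A_pad = np.pad(A, ((0, 0), (padding, padding), (padding, padding), (0, 0)))
    Ho = (H + 2 * padding - K) // stride + 1
    Wo = (W + 2 * padding - K) // stride + 1
    out = np.zeros((N, Ho, Wo, C_out))
    for i in range(Ho):
        for j in range(Wo):
            patch = A_pad[:, i * stride:i * stride + K, j * stride:j * stride + K, :]
            out[:, i, j, :] = patch.reshape(N, -1) @ B.reshape(-1, C_out)
    return out

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 8, 8, 3))
B = rng.standard_normal((3, 3, 3, 4))
assert np.allclose(conv2d_im2col(A, B, stride=2, padding=1),
                   conv2d_naive(A, B, stride=2, padding=1))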
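
The same setting can also verify the two identities behind gradient: dZ is a full convolution of the (dilated) output gradient with the spatially flipped, channel-swapped kernel, and dW is a convolution of the permuted input with the permuted output gradient. This sketch reuses conv2d_im2col and the A, B tensors from above; conv_grads is an illustrative name, and the results are checked against central finite differences.

def conv_grads(A, B, out_grad, stride=1, padding=0):
    K = B.shape[0]
    revP = K - 1 - padding
    if stride > 1:
        # Scatter out_grad onto a zero grid (needle-style dilation, with
        # trailing zeros so the sizes work out for even input sizes).
        N, Ho, Wo, C_out = out_grad.shape
        g = np.zeros((N, Ho * stride, Wo * stride, C_out))
        g[:, ::stride, ::stride, :] = out_grad
        out_grad = g
    # dZ: full convolution with the flipped, channel-swapped kernel.
    fB = np.flip(B, axis=(0, 1)).transpose(0, 1, 3, 2)  # K,K,C_out,C_in
    dZ = conv2d_im2col(out_grad, fB, stride=1, padding=revP)
    # dW: treat N as channels and H,W as the kernel window.
    tZ = A.transpose(3, 1, 2, 0)                        # C_in,H,W,N
    tG = out_grad.transpose(1, 2, 0, 3)                 # H,W,N,C_out
    dW = conv2d_im2col(tZ, tG, stride=1, padding=padding).transpose(1, 2, 0, 3)
    return dZ, dW

stride, padding = 2, 1
out = conv2d_im2col(A, B, stride, padding)
dZ, dW = conv_grads(A, B, np.ones_like(out), stride, padding)

def loss(A, B):
    return conv2d_im2col(A, B, stride, padding).sum()

eps = 1e-5
dA = np.zeros_like(A); dA[0, 3, 2, 1] = eps
assert abs((loss(A + dA, B) - loss(A - dA, B)) / (2 * eps) - dZ[0, 3, 2, 1]) < 1e-5
dB = np.zeros_like(B); dB[1, 2, 0, 3] = eps
assert abs((loss(A, B + dB) - loss(A, B - dB)) / (2 * eps) - dW[1, 2, 0, 3]) < 1e-5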