Skip to content

Instantly share code, notes, and snippets.

@ssnl
Created July 14, 2018 10:43
Show Gist options
  • Save ssnl/966d302259d11e129f75b2178f961131 to your computer and use it in GitHub Desktop.
Save ssnl/966d302259d11e129f75b2178f961131 to your computer and use it in GitHub Desktop.
view_im2col.py
def _as_pair(value, name, minimum):
    """Normalize an int or a 2-sequence to a tuple of two ints >= ``minimum``."""
    pair = (value, value) if isinstance(value, int) else tuple(value)
    if len(pair) != 2 or any(v < minimum for v in pair):
        raise ValueError(
            "{} must be an int or a pair of ints >= {}, got {!r}".format(
                name, minimum, value))
    return pair


def im2col(inp, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Return all sliding 2-D patches of ``inp`` via ``as_strided`` (im2col).

    Args:
        inp: 4-D tensor of shape ``(N, C, H, W)``.
        kernel_size: ``(KH, KW)``, or a single int used for both dims.
        stride: ``(SH, SW)`` step between neighbouring patches, or an int.
        padding: ``(PH, PW)`` zero padding added on both sides of H and W,
            or an int.
        dilation: ``(DH, DW)`` spacing between kernel taps, or an int.

    Returns:
        Tensor of shape ``(N, OH, OW, C, KH, KW)`` where
        ``OH = (H + 2*PH - DH*(KH-1) - 1) // SH + 1`` (likewise OW), and
        ``out[n, i, j, c, a, b] == padded[n, c, i*SH + a*DH, j*SW + b*DW]``.
        With zero padding the result is a strided view sharing memory with
        ``inp``: overlapping patches alias the same elements, so writes are
        visible through ``inp``. With non-zero padding it views a padded
        copy instead.

    Raises:
        ValueError: if ``inp`` is not 4-D, a size argument is malformed, or
            the dilated kernel does not fit inside the (padded) input.
    """
    if inp.dim() != 4:
        raise ValueError(
            "im2col expects a 4-D (N, C, H, W) input, got {}-D".format(inp.dim()))
    kh, kw = _as_pair(kernel_size, "kernel_size", 1)
    sh, sw = _as_pair(stride, "stride", 1)
    ph, pw = _as_pair(padding, "padding", 0)
    dh, dw = _as_pair(dilation, "dilation", 1)

    if ph or pw:
        # Local import: only the padded path needs torch.nn.functional.
        import torch.nn.functional as F
        # F.pad lists pads from the last dim backwards: (left, right, top, bottom).
        inp = F.pad(inp, (pw, pw, ph, ph))

    n, c, h, w = inp.size()
    tsn, tsc, tsh, tsw = inp.stride()

    # Number of valid top-left positions: the last one must leave room for the
    # dilated kernel extent DH*(KH-1)+1; "+ 1" counts the position at offset 0.
    oh = (h - dh * (kh - 1) - 1) // sh + 1
    ow = (w - dw * (kw - 1) - 1) // sw + 1
    if oh < 1 or ow < 1:
        raise ValueError(
            "kernel_size {} with dilation {} does not fit in (padded) input "
            "of spatial size {}x{}".format((kh, kw), (dh, dw), h, w))

    shape = [n, oh, ow, c, kh, kw]
    strides = [
        tsn,
        tsh * sh,  # moving one output row jumps SH input rows
        tsw * sw,  # moving one output column jumps SW input columns
        tsc,
        tsh * dh,  # moving one kernel row jumps DH input rows
        tsw * dw,  # moving one kernel column jumps DW input columns
    ]
    # Pass the storage offset explicitly so sliced inputs are viewed starting
    # at their own first element, not at the start of the shared storage.
    return inp.as_strided(shape, strides, inp.storage_offset())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment