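# Pixel-shuffle reshape/transpose demo, saved from a GitHub gist page.
# (Page text captured along with the code: "Create a gist now" /
# "Instantly share code, notes, and snippets." / "What would you like to do?")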
import math
import numpy as np
# Demo and self-check of a pixel-shuffle rearrangement built from
# reshape + transpose: an array of shape (N, r*r*C, H, W) becomes an array of
# shape (N, C, H*r, W*r).

batchsize = 2
r = 2  # upscaling factor applied to both height and width
out_channels = 3
in_channels = r ** 2 * out_channels  # each output channel consumes r*r input channels
in_height = 3
in_width = 3
out_height = in_height * r
out_width = in_width * r


def pixel_shuffle(x, r):
    """Rearrange a (N, r*r*C, H, W) array into a (N, C, H*r, W*r) array.

    The channel axis is split as ``(i, j, c)`` with ``c`` varying fastest, so
    input channel ``r*C*i + C*j + c`` at pixel ``(h, w)`` lands at output
    position ``(c, h*r + i, w*r + j)``.

    NOTE(review): other libraries' pixel-shuffle ops may order the channel
    axis differently (e.g. ``(c, i, j)``); confirm before mixing weights.

    Args:
        x: 4-D array of shape (N, r*r*C, H, W).
        r: positive integer upscaling factor.

    Returns:
        Array of shape (N, C, H*r, W*r) with the same dtype as ``x``.

    Raises:
        ValueError: if ``r`` < 1, ``x`` is not 4-D, or the channel count of
            ``x`` is not a multiple of ``r*r``.
    """
    if r < 1:
        raise ValueError("upscaling factor must be >= 1, got %r" % (r,))
    if x.ndim != 4:
        raise ValueError("expected a 4-D array, got shape %r" % (x.shape,))
    n, k, h, w = x.shape
    if k % (r * r) != 0:
        raise ValueError(
            "channel count %d is not divisible by r**2 = %d" % (k, r * r))
    c = k // (r * r)
    y = np.reshape(x, (n, r, r, c, h, w))
    # (N, i, j, C, H, W) -> (N, C, H, i, W, j): placing i next to H and j next
    # to W makes the final reshape interleave them into H*r and W*r.
    y = np.transpose(y, (0, 3, 4, 1, 5, 2))
    return np.reshape(y, (n, c, h * r, w * r))


def check_pixel_shuffle(x, y, r):
    """Verify element by element that ``y`` is the pixel shuffle of ``x``.

    Uses the explicit index formula instead of reshape/transpose, so it is an
    independent check of ``pixel_shuffle``.

    Raises:
        AssertionError: at the first mismatching element. It is raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """
    n, c, oh, ow = y.shape
    for b in range(n):
        for k in range(c):
            for h in range(oh):
                for w in range(ow):
                    src_k = c * r * (h % r) + c * (w % r) + k
                    # Integer floor division; exact for non-negative ints.
                    if y[b, k, h, w] != x[b, src_k, h // r, w // r]:
                        raise AssertionError(
                            "mismatch at output index %r" % ((b, k, h, w),))


# Fill the input with its own row-major flat index. Every element is then
# unique, so any misplaced element is detectable, whatever the height/width.
in_map = np.arange(
    batchsize * in_channels * in_height * in_width, dtype=np.int32
).reshape((batchsize, in_channels, in_height, in_width))
print(in_map)

out_map = pixel_shuffle(in_map, r)
print(out_map)

# test
check_pixel_shuffle(in_map, out_map, r)
# (Page text captured along with the code: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")