Skip to content

Instantly share code, notes, and snippets.

@jalola
Last active August 31, 2022 07:54
Show Gist options
  • Star 21 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save jalola/f41278bb27447bed9cd3fb48ec142aec to your computer and use it in GitHub Desktop.
d2s pytorch
class DepthToSpace(nn.Module):
    """Rearrange channel blocks into spatial blocks (TF-style ``depth_to_space``).

    Input is NCHW with ``C == block_size**2 * C_out``. The channel dimension is
    interpreted in TensorFlow order ``(block_row, block_col, C_out)`` — note this
    differs from ``nn.PixelShuffle``, whose order is ``(C_out, block_row,
    block_col)``, so PixelShuffle is NOT a drop-in replacement.

    Output shape: ``(N, C_out, H * block_size, W * block_size)``.
    """

    def __init__(self, block_size):
        super(DepthToSpace, self).__init__()
        self.block_size = block_size
        # Kept as an attribute for backward compatibility with existing callers.
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        """Scatter each group of ``block_size**2`` channels onto a spatial block.

        Args:
            input: tensor of shape ``(N, C, H, W)`` with ``C`` divisible by
                ``block_size**2``.

        Returns:
            Tensor of shape ``(N, C // block_size**2, H * block_size,
            W * block_size)``.

        Raises:
            ValueError: if ``C`` is not divisible by ``block_size**2``.
        """
        bs = self.block_size
        batch, channels, height, width = input.size()
        if channels % self.block_size_sq != 0:
            raise ValueError(
                "channel count %d is not divisible by block_size**2 = %d"
                % (channels, self.block_size_sq)
            )
        out_channels = channels // self.block_size_sq
        # Split channels as (block_row, block_col, out_channel) — TF ordering.
        x = input.view(batch, bs, bs, out_channels, height, width)
        # Reorder to (n, c_out, h, block_row, w, block_col) so the final reshape
        # interleaves each block into the spatial dimensions.
        x = x.permute(0, 3, 4, 1, 5, 2)
        return x.reshape(batch, out_channels, height * bs, width * bs)
class SpaceToDepth(nn.Module):
    """Rearrange spatial blocks into channel blocks (TF-style ``space_to_depth``).

    Inverse of TF-style depth_to_space: each non-overlapping ``block_size x
    block_size`` spatial patch is folded into the channel dimension, in
    TensorFlow order ``(block_row, block_col, C)``.

    Input ``(N, C, H, W)`` with ``H`` and ``W`` divisible by ``block_size``;
    output ``(N, C * block_size**2, H // block_size, W // block_size)``.
    """

    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        # Kept as an attribute for backward compatibility with existing callers.
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        """Fold each ``block_size x block_size`` spatial patch into channels.

        Args:
            input: tensor of shape ``(N, C, H, W)`` with ``H`` and ``W``
                divisible by ``block_size``.

        Returns:
            Tensor of shape ``(N, C * block_size**2, H // block_size,
            W // block_size)``.

        Raises:
            ValueError: if ``H`` or ``W`` is not divisible by ``block_size``.
        """
        bs = self.block_size
        batch, channels, height, width = input.size()
        if height % bs != 0 or width % bs != 0:
            raise ValueError(
                "spatial dims (%d, %d) are not divisible by block_size %d"
                % (height, width, bs)
            )
        out_height = height // bs
        out_width = width // bs
        # Expose the within-block coordinates as their own dimensions:
        # (n, c, h, block_row, w, block_col).
        x = input.view(batch, channels, out_height, bs, out_width, bs)
        # Reorder to (n, block_row, block_col, c, h, w) so the reshape packs
        # channels as (block_row, block_col, C) — TF ordering.
        x = x.permute(0, 3, 5, 1, 2, 4)
        return x.reshape(batch, channels * self.block_size_sq, out_height, out_width)
@ndrplz
Copy link

ndrplz commented Nov 27, 2018

Thanks for the snippet!

I would just add:

    def __call__(self, *args, **kwargs):
        return super(DepthToSpace, self).__call__(*args, **kwargs)

and

    def __call__(self, *args, **kwargs):
        return super(SpaceToDepth, self).__call__(*args, **kwargs)

to DepthToSpace and SpaceToDepth respectively, so that PyCharm does not complain about these classes being non-callable.

Best,
A.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment