@FrancescoSaverioZuppichini
Created May 2, 2022 10:09
Test
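import math

import torch
import torch.nn as nn
from einops import rearrange

r = 4  # reduction ratio: number of tokens merged into one
channels = 8
x = torch.randn((1, channels, 64, 64))
_, _, h, w = x.shape
# goal: reduce [1, 8, 64, 64] to [1, 8, 32, 32]
x = rearrange(x, "b c h w -> b (h w) c")  # flatten the spatial dims: [1, 4096, 8]
# stack every r consecutive tokens along channels: [1, 1024, 32]
# note: in row-major order these are r horizontally adjacent pixels, not a 2x2 patch
x = rearrange(x, "b (hw r) c -> b hw (c r)", r=r)
reducer = nn.Linear(channels * r, channels)
x = reducer(x)  # project back to `channels`: [1, 1024, 8]
side = math.isqrt(r)  # per-side reduction (sqrt(r)); the original r // 2 only matches this for r = 4
x = rearrange(x, "b (h w) c -> b c h w", h=h // side)  # [1, 8, 32, 32]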