Created
December 6, 2023 04:47
-
-
Save madebyollin/69440ecb9805ebd60aeafaf533008a9e to your computer and use it in GitHub Desktop.
Mamba Diffusion (IADB)
- Needed to process the sequence in both directions (e.g. by flipping it every few blocks)
- Larger model worked better
Still needs more training, I guess.
Code with the fixes
class Denoiser(nn.Module):
    """IADB denoiser whose middle stage is a stack of Mamba blocks over flattened pixels.

    Pipeline:
      * ``enc``: the noisy image, concatenated with a broadcast noise-level
        channel, goes through 1x1 convs and ``PixelUnshuffle(2)``. This gives
        ``n_f * 4`` channels at half resolution.
      * ``mid``: the feature map is flattened row-major into a token sequence
        of shape ``(N, H/2 * W/2, n_f * 4)``. It is then run through ``n_b``
        ``Block(n_f * 4, Mamba)`` layers, each taking and returning a
        ``(y, z)`` pair.
      * ``dec``: the encoder features ``x`` are concatenated with the final
        ``y`` and ``z``, then decoded with a 1x1 conv, ``PixelShuffle(2)`` and
        a 1x1 conv back to ``n_io`` channels.

    Because each block scans a 1-D sequence, the token order changes between
    blocks:
      * after every block, the x/y axes of the token grid are swapped;
      * after every 4th block, the whole sequence is reversed.

    Both permutations are their own inverse and commute with each other. Any
    leftover permutation is therefore undone after the last block, so the
    tokens line up with ``x`` again for every ``n_b >= 1``. ``n_b`` no longer
    needs to be a multiple of 8.

    The token grid must be square, i.e. the input image must have ``H == W``.

    Args:
        n_io: number of image channels in and out.
        n_f: base feature width; the blocks operate on ``n_f * 4`` features.
        n_b: number of Mamba blocks; must be >= 1.

    Raises:
        ValueError: if ``n_b < 1``.
    """

    def __init__(self, n_io=Config.channels, n_f=128, n_b=8):
        super().__init__()
        if n_b < 1:
            # forward() needs at least one block to produce the `z` stream.
            raise ValueError(f"n_b must be >= 1, got {n_b}")
        self.enc = nn.Sequential(
            nn.Conv2d(n_io + 1, n_f, 1),
            nn.ReLU(),
            nn.Conv2d(n_f, n_f, 1, bias=False),
            nn.PixelUnshuffle(2),
        )
        self.mid = nn.ModuleList(Block(n_f * 4, Mamba) for _ in range(n_b))
        # Decoder input is cat([x, y, z]) -> 3 * (n_f * 4) = n_f * 12 channels.
        self.dec = nn.Sequential(
            nn.Conv2d(n_f * 12, n_f * 4, 1),
            nn.ReLU(),
            nn.PixelShuffle(2),
            nn.Conv2d(n_f, n_io, 1),
        )

    def transpose_xy(self, *args):
        """Swap the x/y axes of each ``(N, S, C)`` token tensor in *args*.

        The sequence axis ``S`` is read as a row-major ``side x side`` grid,
        so ``S`` must be a perfect square.

        Returns:
            A list with one tensor per argument, each the same shape as its input.

        Raises:
            ValueError: if a sequence length is not a perfect square.
        """
        out = []
        for a in args:
            n, s, c = a.shape
            side = round(s ** 0.5)
            if side * side != s:
                raise ValueError(
                    f"transpose_xy needs a square token grid, got sequence length {s}"
                )
            # reshape (not view) so non-contiguous block outputs are accepted too.
            out.append(a.reshape(n, side, side, c).transpose(1, 2).reshape(n, s, c))
        return out

    def flip_s(self, *args):
        """Reverse the sequence axis (dim 1) of each ``(N, S, C)`` tensor in *args*."""
        return [a.flip(1) for a in args]

    def forward(self, x_noisy, noise_level):
        """Run the denoiser.

        Args:
            x_noisy: noisy image batch; its channel dim is 1 (it is concatenated
                along dim 1 with the noise-level channel).
            noise_level: expanded to ``x_noisy[:, :1].shape``, so it must
                broadcast to ``(N, 1, H, W)`` (presumably ``(N, 1, 1, 1)`` —
                TODO confirm against the caller).

        Returns:
            ``Prediction(denoised, out)``. ``out`` is the raw decoder output and
            ``denoised`` is ``IADB.target_to_denoised(out, x_noisy, noise_level)``,
            detached from the graph.
        """
        x = self.enc(th.cat([x_noisy, noise_level.expand(x_noisy[:, :1].shape)], 1))
        # (N, C, H', W') -> (N, H'*W', C): one token per downsampled pixel, row-major.
        y = x.flatten(2).transpose(-2, -1)
        z = None  # the first block receives no z stream
        n_transposes = n_flips = 0
        for i, mid in enumerate(self.mid):
            y, z = mid(y, z)
            # Make mamba's 1d conv alternate axes: swap which spatial axis is
            # contiguous in the sequence after every block.
            y, z = self.transpose_xy(y, z)
            n_transposes += 1
            if (i + 1) % 4 == 0:
                # Let the network also process the sequence in the other
                # direction by reversing it every 4 blocks.
                y, z = self.flip_s(y, z)
                n_flips += 1
        # Undo any leftover permutation so tokens line up with x's pixel layout.
        # Both ops are involutions that commute, so only their parity matters.
        # (With n_b % 8 == 0 both counts are even and nothing extra runs.)
        if n_transposes % 2:
            y, z = self.transpose_xy(y, z)
        if n_flips % 2:
            y, z = self.flip_s(y, z)
        # (N, S, C) -> (N, C, H', W'), matching x.
        y, z = (t.transpose(-2, -1).reshape(x.shape) for t in (y, z))
        out = self.dec(th.cat([x, y, z], 1))
        return Prediction(IADB.target_to_denoised(out, x_noisy, noise_level).detach(), out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A quick initial test of training an IADB diffusion model using the Mamba building block, à la this tweet.