@madebyollin
Created December 6, 2023 04:47
Mamba Diffusion (IADB)
@madebyollin
Author

Quick initial test of training an IADB diffusion model using the Mamba building block a la this tweet.
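
For context, IADB (iterative alpha-(de)blending) trains a network on linear blends of a clean sample and noise. The sketch below is only an illustration in plain Python lists so that it runs without PyTorch. The `blend`, `target`, and `target_to_denoised` helpers are hypothetical stand-ins, and the convention (noise level 0 is clean, 1 is pure noise, and the regression target is noise minus clean) is an assumption; the gist's own `IADB` class is not shown here and may differ.

```python
import random


def blend(x_clean, noise, noise_level):
    # Assumed convention: noise_level = 0 returns the clean sample, 1 returns pure noise.
    return [(1.0 - noise_level) * c + noise_level * n for c, n in zip(x_clean, noise)]


def target(x_clean, noise):
    # Assumed regression target: the direction from the clean sample toward the noise.
    return [n - c for c, n in zip(x_clean, noise)]


def target_to_denoised(pred, x_noisy, noise_level):
    # Invert the blend: x_clean = x_noisy - noise_level * (noise - x_clean).
    return [xn - noise_level * p for xn, p in zip(x_noisy, pred)]


rng = random.Random(0)
x_clean = [0.25, -0.5, 1.0]
noise = [rng.gauss(0.0, 1.0) for _ in x_clean]
x_noisy = blend(x_clean, noise, 0.7)

# With a perfect prediction of the target, the clean sample is recovered exactly
# (up to floating-point error).
recovered = target_to_denoised(target(x_clean, noise), x_noisy, 0.7)
```

Under this assumed convention, a denoised estimate can be read off from any prediction of the target, which is the role `IADB.target_to_denoised` appears to play in the code further down.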

@madebyollin
Author

  1. Needed to process the sequence in both directions (e.g. flip it every few blocks)
  2. Larger model worked better
    image

Still needs more training I guess
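
On the first point above, here is a toy model of why flipping matters. It assumes each block behaves like a strictly causal mixer, where position i can only gather information from positions 0..i. The `causal_block` and `run` helpers are hypothetical and only track which input positions each token has "seen"; they are not Mamba itself.

```python
def causal_block(visible):
    # Toy causal mixer: position i gathers everything positions 0..i already know.
    out, acc = [], set()
    for v in visible:
        acc = acc | v
        out.append(set(acc))
    return out


def run(n_tokens, n_blocks, flip_every=None):
    # visible[i] is the set of input positions the token currently at slot i has seen.
    visible = [{i} for i in range(n_tokens)]
    for b in range(n_blocks):
        visible = causal_block(visible)
        if flip_every and (b + 1) % flip_every == 0:
            visible = visible[::-1]  # reverse the sequence axis
    return visible


no_flip = run(4, 8)                 # token 0 never sees anything but itself
with_flip = run(4, 8, flip_every=4)  # every token ends up seeing all 4 inputs
```

Without the flip, stacking more causal blocks does not help the early tokens; with a flip part-way through, every token can reach every other one.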

@madebyollin
Author

Code with the fixes

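import torch as th
from torch import nn
from mamba_ssm import Mamba
# Block, Config, Prediction and IADB are not defined in this excerpt; Block is presumably
# mamba_ssm.modules.mamba_simple.Block, the others come from the rest of the gist.

class Denoiser(nn.Module):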
    def __init__(self, n_io=Config.channels, n_f=128, n_b=8):
        super().__init__()
        assert n_b % 8 == 0, "Silly flipping logic breaks if n_b is not divisible by 8"
        self.enc = nn.Sequential(nn.Conv2d(n_io + 1, n_f, 1), nn.ReLU(), nn.Conv2d(n_f, n_f, 1, bias=False), nn.PixelUnshuffle(2))
        self.mid = nn.ModuleList(Block(n_f * 4, Mamba) for _ in range(n_b))
        self.dec = nn.Sequential(nn.Conv2d(n_f * 12, n_f * 4, 1), nn.ReLU(), nn.PixelShuffle(2), nn.Conv2d(n_f, n_io, 1))
        
    def transpose_xy(self, *args):
        # swap x/y axes of an N[XY]C tensor (assumes a square grid, i.e. X == Y)
        return [a.view(a.shape[0], int(a.shape[1]**0.5), int(a.shape[1]**0.5), a.shape[2]).transpose(1, 2).reshape(a.shape) for a in args]
    
    def flip_s(self, *args):
        # reverse sequence axis of an NSE tensor
        return [a.flip(1) for a in args]

    def forward(self, x_noisy, noise_level):
        x = self.enc(th.cat([x_noisy, noise_level.expand(x_noisy[:, :1].shape)], 1))
        y = x.flatten(2).transpose(-2, -1)
        z = None
        for i, mid in enumerate(self.mid):
            y, z = mid(y, z)
            # make mamba's 1d conv alternate axes (possible alternative: make mamba use a 2d conv...somehow...)
            y, z = self.transpose_xy(y, z)
            if (i + 1) % 4 == 0:
                # let the network also process the sequence in both directions, by reversing it every 4 layers
                y, z = self.flip_s(y, z)
        y, z = y.transpose(-2, -1).view(x.shape), z.transpose(-2, -1).view(x.shape)
        out = self.dec(th.cat([x, y, z], 1))
        return Prediction(IADB.target_to_denoised(out, x_noisy, noise_level).detach(), out)
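
To see what the transpose/flip schedule does to the scan order, and why the `n_b % 8 == 0` assertion exists, here is a small sketch that tracks pixel indices instead of tensors. It uses plain Python lists so it runs without PyTorch; `transpose_xy`, `flip_s`, and `scan_orders` below are illustrative re-implementations for a square grid, not the methods above.

```python
def transpose_xy(order, side):
    # order is a flat row-major list of length side * side; swap the two grid axes.
    return [order[q * side + p] for p in range(side) for q in range(side)]


def flip_s(order):
    # Reverse the sequence axis.
    return order[::-1]


def scan_orders(side, n_blocks=8, flip_every=4):
    order = list(range(side * side))
    seen = []
    for i in range(n_blocks):
        seen.append(list(order))  # the pixel order that block i scans
        order = transpose_xy(order, side)
        if (i + 1) % flip_every == 0:
            order = flip_s(order)
    return seen, order  # per-block scan orders, plus the order handed to the decoder


seen, final = scan_orders(side=2)
# seen[0] == [0, 1, 2, 3]  row-major, forward
# seen[1] == [0, 2, 1, 3]  column-major, forward
# seen[4] == [3, 2, 1, 0]  row-major, reversed
# seen[5] == [3, 1, 2, 0]  column-major, reversed
# final   == [0, 1, 2, 3]  back in the original order after 8 blocks
```

With 8 blocks the transposes and flips cancel out, so the final `view(x.shape)` lines up with the encoder output. With, say, 4 blocks the sequence would still be reversed at that point, which is presumably what the assertion guards against.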

@madebyollin
Author

Samples after more training
