Skip to content

Instantly share code, notes, and snippets.

@nousr
Last active December 8, 2023 17:37
Show Gist options
  • Save nousr/afb467d28d629e9809f7ab9c183a408c to your computer and use it in GitHub Desktop.
Save nousr/afb467d28d629e9809f7ab9c183a408c to your computer and use it in GitHub Desktop.
"""
Proof of concept "DiM" - nousr
The general structure was "transpiled" from DiT by Meta.
The bidirectional idea comes from DiffuSSM (https://arxiv.org/abs/2311.18257).
"""
import torch
import math
from timm.models.vision_transformer import PatchEmbed
from einops import rearrange
from torch import nn
from mamba_ssm.modules.mamba_simple import Mamba, Block as MambaBlock
def exists(x):
    """Return ``True`` when *x* is anything other than ``None``."""
    is_missing = x is None
    return not is_missing
def default(x, default):
    """Return *x*, falling back to *default* only when *x* is ``None``.

    Falsy-but-present values such as ``0`` or ``""`` are returned as-is.
    """
    if x is None:
        return default
    return x
def modulate(x, shift, scale):
    """Apply an AdaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at position 1 so they
    broadcast over that axis of ``x`` (the token axis, presumably, when ``x``
    is (batch, seq, dim) and the modulation is (batch, dim) — confirm).
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
class AdaLNModulation(nn.Module):
    """AdaLN-Zero conditioning head, as in DiT.

    Maps a conditioning vector through SiLU and a single linear layer to
    ``expansion_factor * hidden_dim`` values, which callers split into
    ``expansion_factor`` modulation vectors (shift / scale / gate, ...).
    The linear layer starts with all-zero weight and bias, so every
    modulation vector is exactly zero until training moves it.
    """

    def __init__(self, hidden_dim, expansion_factor=2) -> None:
        super().__init__()
        projection = nn.Linear(hidden_dim, hidden_dim * expansion_factor, bias=True)
        self.layers = nn.Sequential(nn.SiLU(), projection)
        self.init_weights()

    def init_weights(self):
        """Zero the output projection (weight and bias)."""
        head = self.layers[-1]
        for tensor in (head.weight, head.bias):
            nn.init.constant_(tensor, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the stacked modulation vectors for conditioning ``x``."""
        return self.layers(x)
class FinalLayer(nn.Module):
    """Output head: AdaLN-modulated LayerNorm, then a per-token linear map.

    Each token is projected to ``patch_size * patch_size * out_channels``
    values (one patch worth of output). The projection starts with zero
    weight and bias, so the head outputs zeros for any input at init.
    """

    def __init__(self, hidden_dim, patch_size, out_channels) -> None:
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_dim, values_per_patch, bias=True)
        self.adaLN_modulation = AdaLNModulation(hidden_dim=hidden_dim)
        self.init_weights()

    def init_weights(self):
        """Zero the output projection (weight and bias)."""
        for tensor in (self.linear.weight, self.linear.bias):
            nn.init.constant_(tensor, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Normalize ``x``, modulate it with shift/scale derived from ``c``, project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class MLP(nn.Module):
    """Two-layer perceptron: ``dim -> dim * mult -> (output_dim or dim)`` with SiLU.

    Note that a falsy ``output_dim`` (``None`` or ``0``) falls back to ``dim``.
    """

    def __init__(self, dim, output_dim=None, mult=1):
        super().__init__()
        inner_dim = dim * mult
        final_dim = output_dim or dim
        self.layers = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.SiLU(),
            nn.Linear(inner_dim, final_dim),
        )

    def forward(self, x):
        """Apply the two linear layers with SiLU in between."""
        return self.layers(x)
class ModulatedMambaBlock(nn.Module):
    """Pre-norm Mamba mixer conditioned with DiT-style AdaLN (shift, scale, gate).

    Uses the ``(hidden_states, residual)`` convention of mamba_ssm's ``Block``:
    the block folds the *incoming* hidden states into the residual stream,
    but does not add its own output to it. That add is left to the caller
    (the next block, or the model head after the last block).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.mixer = Mamba(hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        # produces shift, scale and gate -> expansion_factor=3
        self.adaLN_mod = AdaLNModulation(hidden_dim=hidden_dim, expansion_factor=3)

    def forward(self, hidden_states, residual, c):
        """Run one conditioned mixer step.

        Args:
            hidden_states: output of the previous block, or the token
                embeddings for the first block (presumably (batch, seq, dim)).
            residual: running residual stream from the previous block, or
                ``None`` for the first block.
            c: conditioning fed to the AdaLN head (presumably (batch, dim);
                it is chunked along dim 1).

        Returns:
            ``(out, residual)``: the gated mixer output and the updated
            residual stream (float32).
        """
        # Fold the previous output into the residual stream (the first block
        # has none yet). Kept in fp32, the default for the example mamba models.
        # TODO: try adding shift, scale, gate for the residual?
        residual = (
            (hidden_states + residual) if residual is not None else hidden_states
        ).to(torch.float32)

        # shift & scale modulate the normed input; gate scales the mixer output
        shift, scale, gate = self.adaLN_mod(c).chunk(3, dim=1)

        # norm the residual stream to form the mixer input, then modulate it
        hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        hidden_states = modulate(hidden_states, shift, scale)

        hidden_states = self.mixer(hidden_states, inference_params=None)

        # AdaLN-Zero gate. DiT computes x = x + gate * f(x); the leading "x +"
        # is the residual add, which this block defers to its caller, so only
        # the multiplicative gate belongs here. (h + gate * h would equal
        # (1 + gate) * h, which is *not* zero while the AdaLN head is at its
        # zero initialization, defeating AdaLN-Zero.)
        hidden_states = gate.unsqueeze(1) * hidden_states
        return hidden_states, residual
class BidirectionalModulatedMambaBlock(nn.Module):
    """Scan the token sequence in both directions and fuse the results.

    One ``ModulatedMambaBlock`` reads the sequence as-is and another reads it
    reversed along dim 1. The reversed output is flipped back, concatenated
    with the forward output on the last dim, and projected from
    ``2 * hidden_dim`` to ``hidden_dim``. Each direction keeps its own
    residual stream; the backward stream stays in reversed token order.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.mamba_fwd = ModulatedMambaBlock(hidden_dim)
        self.mamba_bwd = ModulatedMambaBlock(hidden_dim)
        self.proj_x = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, x, residual, bwd_residual, c):
        """Return ``(fused_output, forward_residual, backward_residual)``."""
        out_fwd, res_fwd = self.mamba_fwd(x, residual, c)

        # backward direction: reverse tokens, scan, then restore original order
        x_reversed = x.flip(dims=[1])
        out_bwd_reversed, res_bwd = self.mamba_bwd(x_reversed, bwd_residual, c)
        out_bwd = out_bwd_reversed.flip(dims=[1])

        stacked = torch.cat([out_fwd, out_bwd], dim=-1)
        fused = self.proj_x(stacked)
        return fused, res_fwd, res_bwd
class TimestepEmbedder(nn.Module):
    """Embed scalar (possibly fractional) timesteps as ``hidden_size`` vectors.

    Timesteps are first encoded with fixed sinusoids of width
    ``frequency_embedding_size``, then passed through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor: cosines, then sines, then one zero
                 column of padding when ``dim`` is odd.
        """
        # recipe from https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            zero_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_column], dim=-1)
        return embedding

    def forward(self, t):
        """Return the MLP-projected sinusoidal embedding of ``t``."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class DiM(nn.Module):
    """Diffusion backbone ("DiM") built from bidirectional AdaLN-modulated Mamba blocks.

    The input is split into patches by ``PatchEmbed``, processed by ``depth``
    ``BidirectionalModulatedMambaBlock``s conditioned on the timestep
    embedding, mapped back to per-patch values by ``FinalLayer`` and
    reassembled into an image by ``unpatchify``.

    Args:
        hidden_dim: token width used throughout the network.
        image_size: input image side length, forwarded to ``PatchEmbed``.
        patch_size: patch side length.
        in_channels: channels of the input image.
        out_channels: channels of the output image (defaults to ``in_channels``).
        depth: number of bidirectional blocks.
    """

    def __init__(
        self,
        hidden_dim=512,
        image_size=64,
        patch_size=4,
        in_channels=3,
        out_channels=None,
        depth=4,
    ):
        super().__init__()
        # NOTE: attributes for lucidrains' diffusion wrapper
        self.random_or_learned_sinusoidal_cond = True
        self.self_condition = False
        self.in_channels = in_channels
        self.out_channels = default(out_channels, in_channels)
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            image_size, patch_size, in_channels, hidden_dim, bias=True
        )
        self.blocks = nn.ModuleList(
            [BidirectionalModulatedMambaBlock(hidden_dim) for _ in range(depth)]
        )
        self.t_embedder = TimestepEmbedder(hidden_dim)
        self.final_layer = FinalLayer(
            hidden_dim=hidden_dim, patch_size=patch_size, out_channels=self.out_channels
        )
        # Apply the reference initialization (DiT does the same at the end of
        # its constructor).
        self.init_weights()

    def init_weights(self):
        """Initialize weights following the DiT reference recipe.

        * ``nn.Linear`` layers: Xavier-uniform weights and zero bias, except
          the projections owned by ``Mamba`` mixers, which keep Mamba's own
          specialized initialization (e.g. of its dt projection).
        * Patch-embedding conv: initialized as if it were a flat linear layer.
        * Timestep MLP: weights drawn from N(0, 0.02).
        * AdaLN modulation heads and the final projection are re-zeroed last,
          so the blanket Xavier pass does not destroy AdaLN-Zero.
        """
        mamba_owned = {
            id(sub)
            for mixer in self.modules()
            if isinstance(mixer, Mamba)
            for sub in mixer.modules()
        }

        def _basic_init(module):
            if isinstance(module, nn.Linear) and id(module) not in mamba_owned:
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # initialize the patch embed like nn.Linear (instead of nn.Conv2d)
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # initialize timestep embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # restore AdaLN-Zero: _basic_init overwrote these zero-initialized layers
        for module in self.modules():
            if isinstance(module, (AdaLNModulation, FinalLayer)):
                module.init_weights()

    def unpatchify(self, x):
        """Reassemble per-patch predictions into images.

        x: (N, T, patch_size**2 * C), where T = h * w with h == w
        imgs: (N, C, h * patch_size, w * patch_size)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1], "number of tokens must be a perfect square"
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def forward(self, x, t):
        """Forward pass of DiM.

        Args:
            x: image batch consumed by ``PatchEmbed`` (presumably (N, C, H, W)).
            t: one (possibly fractional) timestep per batch element.

        Returns:
            (N, out_channels, h * patch_size, w * patch_size) tensor.
        """
        x = self.x_embedder(x)
        t = self.t_embedder(t)
        residual = None
        bwd_residual = None
        for block in self.blocks:
            x, residual, bwd_residual = block(x, residual, bwd_residual, t)
        # Blocks follow Mamba's (hidden_states, residual) convention: a block's
        # output only joins the residual stream inside the *next* block, so the
        # last block's output must be folded in here (mamba_ssm's MixerModel
        # does the same before its final norm). bwd_residual is, by
        # construction, the same running sum in reversed token order, so adding
        # `residual` alone suffices. Cast back since the stream is float32.
        if residual is not None:
            x = (x + residual).to(dtype=x.dtype)
        x = self.final_layer(x, t)
        return self.unpatchify(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment