@GallagherCommaJack
Last active August 10, 2021 08:37
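
CrossAttentionModConv2d modulates a 2D feature map with per-pixel scales and shifts computed by cross-attending from spatial positions to a context sequence stored in state['cross'].
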
import torch
from einops import rearrange
from torch import nn


class CrossAttentionModConv2d(nn.Module):
    def __init__(self, state, ch, d_context, ch_q=None, d_v=None, n_head=1):
        super().__init__()
        assert ch % n_head == 0
        self.state = state
        self.n_head = n_head
        self.ch = ch
        self.d_context = d_context
        self.ch_q = ch_q or self.ch
        self.d_v = d_v or self.d_context
        # Queries come from the feature map; keys and values from the context.
        self.q_proj = nn.Conv2d(ch, self.ch_q, 1)
        self.k_proj = nn.Linear(self.d_context, self.ch_q)
        self.v_proj = nn.Linear(self.d_context, self.d_v)
        # The attention output is projected to per-channel scales and shifts.
        self.scale_proj = nn.Conv2d(self.d_v, self.ch, 1)
        self.shift_proj = nn.Conv2d(self.d_v, self.ch, 1)

    def forward(self, input, return_attns=False):
        n, c, h, w = input.shape
        n, s, d = self.state['cross'].shape
        q_per_head = self.ch_q // self.n_head
        v_per_head = self.d_v // self.n_head
        # Split heads: queries range over spatial positions, keys/values over context tokens.
        q = self.q_proj(input).view([n, self.n_head, q_per_head, h * w]).transpose(2, 3)
        k = self.k_proj(self.state['cross']).view([n, s, self.n_head, q_per_head]).transpose(1, 2)
        v = self.v_proj(self.state['cross']).view([n, s, self.n_head, v_per_head]).transpose(1, 2)
        assert q.shape == torch.Size([n, self.n_head, h * w, q_per_head])
        assert k.shape == torch.Size([n, self.n_head, s, q_per_head])
        assert v.shape == torch.Size([n, self.n_head, s, v_per_head])
        # Scaled dot-product attention from pixels to context tokens.
        attn = (q @ k.transpose(2, 3) / q_per_head ** 0.5).softmax(3)
        assert attn.shape == torch.Size([n, self.n_head, h * w, s])
        # Merge heads and fold the attended values back into a feature map.
        y = rearrange(attn @ v, 'n n_h (h w) v -> n (n_h v) h w',
                      n=n, n_h=self.n_head, h=h, w=w, v=v_per_head)
        # FiLM-style modulation of the input by the attention-derived features.
        scales = self.scale_proj(y)
        shifts = self.shift_proj(y)
        out = input * scales + shifts
        if return_attns:
            return out, attn
        else:
            return out
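
A minimal usage sketch. The tensor sizes and the batch of two are assumptions for illustration, not from the gist; the only requirement the class imposes is that state['cross'] holds a [batch, seq, d_context] tensor before forward is called.

# Hypothetical smoke test: 8x8 feature map with 16 channels, a context
# sequence of 5 tokens with dimension 32, and 4 attention heads.
state = {'cross': torch.randn(2, 5, 32)}
mod = CrossAttentionModConv2d(state, ch=16, d_context=32, n_head=4)
x = torch.randn(2, 16, 8, 8)
out, attn = mod(x, return_attns=True)
print(out.shape)   # torch.Size([2, 16, 8, 8])
print(attn.shape)  # torch.Size([2, 4, 64, 5])

Note that the context is read from the mutable state dict rather than passed as a forward argument, so the caller is responsible for refreshing state['cross'] before each forward pass.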