Skip to content

Instantly share code, notes, and snippets.

@AranKomat
Created March 3, 2020 04:04
Show Gist options
  • Save AranKomat/95378ecea8c2feb6e6781976e77aa802 to your computer and use it in GitHub Desktop.
Save AranKomat/95378ecea8c2feb6e6781976e77aa802 to your computer and use it in GitHub Desktop.
class Axial(nn.Module):
    """Learned axial positional embeddings added onto a batch of embeddings.

    The sequence positions are viewed as a grid of shape
    ``config.axial_pos_shape``.  Each grid axis ``ax`` owns one learned table
    of ``config.axial_d_embs[ax]`` features that varies only along that axis
    and is broadcast over the others.  The per-axis tables are concatenated
    along the feature dimension and added to the input.

    Config attributes read:
        axial_d_embs: per-axis feature widths, e.g. ``(128, 384)``; their sum
            should equal ``d_model`` (the last dimension of the input).
        axial_pos_shape: grid shape, e.g. ``(64, 128)``; the product should
            equal the sequence length.
        device: device on which the embedding tables are created.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Normalise to tuples: forward() concatenates these with other tuples,
        # which would raise TypeError if the config supplied lists.
        self._d_embs = tuple(config.axial_d_embs)
        self._shape = tuple(config.axial_pos_shape)

        tables = []
        for ax, d_emb in enumerate(self._d_embs):
            # Size 1 on every grid axis except `ax`, so the table broadcasts
            # over the remaining axes (and over the leading batch dimension).
            grid = [1] * len(self._shape)
            grid[ax] = self._shape[ax]
            table_shape = (1, *grid, d_emb)
            tables.append(
                nn.Parameter(
                    torch.randn(
                        table_shape,
                        dtype=torch.half,
                        device=self.config.device,
                    )
                )
            )
        # A ParameterList (not a plain list) is required for the tables to be
        # registered with the module: otherwise they are absent from
        # parameters() / state_dict() and are not moved by .to().
        self.weights = nn.ParameterList(tables)

    def forward(self, inputs):
        """Return ``inputs`` with the axial positional embeddings added.

        Args:
            inputs: tensor whose first dimension is the batch size and whose
                last dimension matches ``sum(axial_d_embs)``; the dimensions
                in between must hold ``prod(axial_pos_shape)`` positions in
                total (e.g. ``(batch, seqlen, d_model)``).

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        batch = inputs.shape[0]
        lead = tuple(inputs.shape[:-1])
        pieces = []
        for table in self.weights:
            width = table.shape[-1]
            # Broadcast the table over the full grid, then flatten the grid
            # axes so the result lines up with the layout of `inputs`.
            full = table.expand((batch,) + self._shape + (width,))
            pieces.append(full.reshape(lead + (width,)))
        return inputs + torch.cat(pieces, -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment