Created March 3, 2020 04:04.
Save AranKomat/95378ecea8c2feb6e6781976e77aa802 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Axial(nn.Module):
    """Axial (factorized) learned positional embedding, added to the inputs.

    The flattened position index is viewed as a row-major index into the grid
    ``config.axial_pos_shape``.  One learned table is kept per grid axis; the
    embedding of a position is the concatenation, along the feature
    dimension, of each axis's vector for that position's coordinate on that
    axis.  Tables are initialised from a standard normal distribution.

    Config attributes read:
        axial_d_embs: per-axis embedding widths, e.g. ``(128, 384)``.  Their
            sum must equal the last dimension of the inputs (d_model), and
            there must be exactly one width per entry of ``axial_pos_shape``.
        axial_pos_shape: grid factorising the maximum sequence length, e.g.
            ``(64, 128)``; its product is the longest supported sequence.
        device: optional; device the tables are created on (PyTorch's
            default device when absent).
        axial_dtype: optional; dtype of the tables (``torch.half`` when
            absent).

    The tables are registered parameters (``weights.0``, ``weights.1``, ...),
    so they are trained, moved by ``.to()`` and stored in ``state_dict``.

    Raises:
        ValueError: if ``axial_d_embs`` and ``axial_pos_shape`` differ in length.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._d_embs = tuple(config.axial_d_embs)
        # Tuple so that ``(1,) + self._shape`` works even if config holds a list.
        self._shape = tuple(config.axial_pos_shape)
        if len(self._d_embs) != len(self._shape):
            raise ValueError(
                f"axial_d_embs has {len(self._d_embs)} entries but "
                f"axial_pos_shape has {len(self._shape)}; they must match"
            )
        dtype = getattr(config, "axial_dtype", torch.half)
        device = getattr(config, "device", None)
        # nn.ParameterList (not a plain list): a plain list hides the tables
        # from .parameters()/optimizers, .to() and state_dict.
        self.weights = nn.ParameterList()
        for ax, d_emb in enumerate(self._d_embs):
            # Table shape (1, 1, ..., shape[ax], ..., 1, d_emb): size 1 on every
            # grid axis except ``ax`` so it broadcasts over the other axes.
            ax_shape = [1] * len(self._shape)
            ax_shape[ax] = self._shape[ax]
            ax_shape = (1,) + tuple(ax_shape) + (d_emb,)
            ax_emb = torch.zeros(ax_shape, dtype=dtype, device=device).normal_(0, 1)
            self.weights.append(nn.Parameter(ax_emb))

    def forward(self, inputs):
        """Return ``inputs`` plus the axial positional embedding.

        Args:
            inputs: tensor laid out as (batch, *positions, d_model), e.g.
                (batch, seq_len, d_model) or (batch, *axial_pos_shape, d_model);
                d_model must equal ``sum(axial_d_embs)``.  Positions are taken
                in row-major order; a sequence shorter than
                ``prod(axial_pos_shape)`` uses the leading positions.

        Raises:
            ValueError: if the inputs hold more positions than
                ``prod(axial_pos_shape)``.
        """
        total = torch.Size(self._shape).numel()
        # Broadcast each per-axis table over the whole grid and flatten it.
        # Batch stays 1: the final addition broadcasts over the batch instead
        # of materialising one copy of the embedding per example.
        pos = torch.cat(
            [
                w.expand((1,) + self._shape + (w.shape[-1],)).reshape(1, total, w.shape[-1])
                for w in self.weights
            ],
            dim=-1,
        )
        position_dims = inputs.shape[1:-1]
        n_pos = position_dims.numel()
        if n_pos > total:
            raise ValueError(
                f"inputs have {n_pos} positions but axial_pos_shape "
                f"{self._shape} supports at most {total}"
            )
        pos = pos[:, :n_pos].reshape((1,) + tuple(position_dims) + (pos.shape[-1],))
        return inputs + pos
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.