@dienhoa
Created May 1, 2023 14:31
Transformer for time series with attention layers
# Imports assume the fastai library, which the gist's style implies
# (Module needs no explicit super().__init__(); default_device and Flatten are fastai helpers).
import torch
import torch.nn as nn
from fastai.torch_core import Module, default_device
from fastai.layers import Flatten

class OurTST(Module):
    def __init__(self, c_in, c_out, d_model, seq_len, n_layers, drop_out, fc_dropout):
        self.c_in, self.c_out, self.seq_len = c_in, c_out, seq_len
        # Project each time step from c_in input channels to d_model dimensions
        self.W_P = nn.Linear(c_in, d_model)
        # Learnable positional encoding
        W_pos = torch.empty((seq_len, d_model), device=default_device())
        nn.init.uniform_(W_pos, -0.02, 0.02)
        self.W_pos = nn.Parameter(W_pos, requires_grad=True)
        self.drop_out = nn.Dropout(drop_out)
        # Stack of single-head self-attention layers, each followed by batch norm over the sequence dimension
        self.encoders = nn.ModuleList(nn.MultiheadAttention(embed_dim=d_model, num_heads=1, batch_first=True) for _ in range(n_layers))
        self.norm_layers = nn.ModuleList(nn.BatchNorm1d(seq_len) for _ in range(n_layers))
        # Classification head: flatten [bs, seq_len, d_model] and map to c_out outputs
        self.head = nn.Sequential(
            nn.GELU(),
            Flatten(),
            nn.Dropout(fc_dropout),
            nn.Linear(seq_len * d_model, c_out)
        )
    def forward(self, x):
        o = x.swapaxes(1, 2)                # [bs, c_in, seq_len] -> [bs, seq_len, c_in]
        o = self.W_P(o)                     # [bs, seq_len, c_in] -> [bs, seq_len, d_model]
        o = self.drop_out(o + self.W_pos)   # add positional encoding, then dropout
        for enc, norm in zip(self.encoders, self.norm_layers):
            residual = o
            o = enc(o, o, o)[0]             # self-attention: [bs, seq_len, d_model] -> [bs, seq_len, d_model]
            o = norm(o + residual)          # add residual connection, then apply batch normalization
        o = o.contiguous()                  # ensure a contiguous tensor before flattening in the head
        o = self.head(o)                    # [bs, seq_len, d_model] -> [bs, c_out]
        return o
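
As a quick sanity check, here is a minimal sketch of running the module on random data; the batch size and hyperparameter values (3 input channels, 2 classes, 128-step sequences, d_model of 64, and so on) are arbitrary and only meant to illustrate the expected tensor shapes.

# Hypothetical usage example with dummy data
bs, c_in, seq_len, c_out = 16, 3, 128, 2
model = OurTST(c_in=c_in, c_out=c_out, d_model=64, seq_len=seq_len,
               n_layers=3, drop_out=0.1, fc_dropout=0.3).to(default_device())

x = torch.randn(bs, c_in, seq_len, device=default_device())  # [bs, c_in, seq_len]
y = model(x)
print(y.shape)  # torch.Size([16, 2]) -> [bs, c_out]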