Skip to content

Instantly share code, notes, and snippets.

@Muhammad4hmed
Created August 11, 2021 19:00
Show Gist options
  • Save Muhammad4hmed/302b18c55e46703be114345497b8fbc7 to your computer and use it in GitHub Desktop.
Save Muhammad4hmed/302b18c55e46703be114345497b8fbc7 to your computer and use it in GitHub Desktop.
class Squeeze(nn.Module):
    """Layer form of ``Tensor.squeeze`` so it can be placed in ``nn.Sequential``.

    Args:
        dims: axis handed to ``Tensor.squeeze`` (removed only if it has
            size 1); the last axis by default.
    """

    def __init__(self, dims=-1):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        axis = self.dims
        squeezed = x.squeeze(axis)
        return squeezed
class AttentionHead(nn.Module):
    """Additive (tanh) attention pooling over dim 1 of the input.

    Each position is scored with ``V(tanh(W(x)))``, the scores are
    softmax-normalised over dim 1, and the score-weighted sum of the *input*
    features is returned.

    Args:
        in_features: size of the last input dimension.
        hidden_dim: width of the intermediate scoring projection ``W``.
        num_targets: unused; kept for signature compatibility.
    """

    def __init__(self, in_features, hidden_dim, num_targets):
        super().__init__()
        self.in_features = in_features
        self.middle_features = hidden_dim
        self.W = nn.Linear(in_features, hidden_dim)
        self.V = nn.Linear(hidden_dim, 1)
        # The pooled vector is a weighted sum of the input features, so its
        # width is in_features; hidden_dim only sizes the scoring MLP.
        self.out_features = in_features

    def forward(self, features, mask=None):
        """Pool ``features`` over dim 1.

        Args:
            features: tensor whose last dim is ``in_features``; dim 1 is the
                axis pooled over (presumably (batch, seq_len, in_features)).
            mask: optional tensor shaped like ``features.shape[:2]``; nonzero /
                True marks positions to attend to. Masked positions get zero
                weight. If every position of a row is masked, that row falls
                back to uniform weights instead of producing NaN.

        Returns:
            Tensor of shape ``features.shape[:1] + (in_features,)``.
        """
        att = torch.tanh(self.W(features))
        score = self.V(att)
        if mask is not None:
            keep = mask.bool().unsqueeze(-1)
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row cannot turn into NaN in the softmax.
            score = score.masked_fill(~keep, torch.finfo(score.dtype).min)
        attention_weights = torch.softmax(score, dim=1)
        return torch.sum(attention_weights * features, dim=1)
class CNNHead(nn.Module):
    """Conv1d over the sequence followed by global max pooling.

    The input is permuted with ``(0, 2, 1)`` so its last axis becomes the
    Conv1d channel axis (presumably input is (batch, seq_len, in_features) —
    TODO confirm with callers). The conv uses no padding, so the sequence
    axis must be at least ``kernel_size`` long.

    Args:
        in_features: number of input channels seen by the conv.
        hidden_dim: number of conv output channels (also ``out_features``).
        kernel_size: conv kernel width.
        num_targets: unused; kept for signature compatibility.
    """

    def __init__(self, in_features, hidden_dim, kernel_size=10, num_targets=1):
        super().__init__()
        conv = nn.Conv1d(in_features, hidden_dim, kernel_size=kernel_size)
        pool = nn.AdaptiveMaxPool1d(1)
        # Squeeze (default: last axis) drops the length-1 axis left by the pool,
        # giving (batch, hidden_dim).
        self.head = nn.Sequential(conv, pool, Squeeze())
        self.out_features = hidden_dim

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)
        return self.head(channels_first)
class LSTMHead(nn.Module):
    """Unidirectional, batch-first LSTM head returning the top layer's final hidden state.

    Args:
        in_features: input feature size per time step.
        hidden_dim: LSTM hidden size (also ``out_features``).
        n_layers: number of stacked LSTM layers.
        num_targets: unused; kept for signature compatibility.
        dropout: dropout between stacked LSTM layers. PyTorch applies it only
            between layers, so it is set to 0 when ``n_layers == 1``. This
            avoids the UserWarning nn.LSTM emits for ``num_layers=1`` with
            ``dropout > 0``, and the output is unchanged either way.
    """

    def __init__(self, in_features, hidden_dim, n_layers, num_targets=1, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            in_features,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=False,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.out_features = hidden_dim

    def forward(self, x):
        # Re-compact the weights into one contiguous buffer (relevant for cuDNN).
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(x)
        # hidden is (n_layers, batch, hidden_dim); keep the last layer's state.
        # NOTE(review): with right-padded batches this is the state after the
        # padding steps — confirm callers do not pad, or pack sequences.
        return hidden[-1]
class TransformerHead(nn.Module):
    """Transformer-encoder head that emits one score per sequence position.

    Takes batch-first input ``(batch, seq_len, in_features)`` and returns
    ``(batch, seq_len)``. ``out_features`` is therefore ``max_length``, the
    padded sequence length the caller feeds in.

    Args:
        in_features: model width ``d_model``; must be divisible by ``nhead``.
        max_length: sequence length of the input (reported as ``out_features``).
        num_layers: number of encoder layers.
        nhead: number of attention heads.
        num_targets: unused; kept for signature compatibility.
    """

    def __init__(self, in_features, max_length, num_layers=1, nhead=8, num_targets=1):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=in_features, nhead=nhead),
            num_layers=num_layers,
        )
        self.row_fc = nn.Linear(in_features, 1)
        self.out_features = max_length

    def forward(self, x):
        # nn.TransformerEncoderLayer defaults to sequence-first (seq, batch, d)
        # input. Swap into that layout and back, so attention runs over the
        # positions of each example instead of across the batch.
        out = self.transformer(x.transpose(0, 1)).transpose(0, 1)
        return self.row_fc(out).squeeze(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment