import torch
import torch.nn as nn

# Example training loop w/ a trainable dropout layer.
# Assumes `model`, `optimizer`, `criterion_loss`, `epochs`,
# `train_dataloader`, and `test_dataloader` are defined elsewhere.
for epoch in range(epochs):
    for train_batch, test_batch in zip(train_dataloader, test_dataloader):
        features, target = train_batch
        test_features, test_target = test_batch

        # Update everything except the dropout scales on the train batch.
        unfreeze_all_but_dropout(model)
        optimizer.zero_grad()
        pred = model(features)
        loss = criterion_loss(pred, target)
        loss.backward()
        optimizer.step()

        # Update only the dropout scales on the held-out test batch.
        freeze_all_but_dropout(model)
        optimizer.zero_grad()  # clear stale gradients from the train step
        test_pred = model(test_features)
        loss_test = criterion_loss(test_pred, test_target)
        loss_test.backward()
        optimizer.step()
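
# The loop above calls unfreeze_all_but_dropout / freeze_all_but_dropout but
# does not define them; here is a minimal sketch, assuming the trainable
# dropout parameters live in TrainableDropoutLayer modules (defined below).
def unfreeze_all_but_dropout(model):
    # Every parameter except the dropout scales receives gradients.
    for module in model.modules():
        is_dropout = isinstance(module, TrainableDropoutLayer)
        for param in module.parameters(recurse=False):
            param.requires_grad = not is_dropout

def freeze_all_but_dropout(model):
    # Only the dropout scales receive gradients.
    for module in model.modules():
        is_dropout = isinstance(module, TrainableDropoutLayer)
        for param in module.parameters(recurse=False):
            param.requires_grad = is_dropout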
# Then your model would have a trainable dropout layer as its initial layer that looks like so:
class TrainableDropoutLayer(nn.Module):
    def __init__(self, num_features):
        super(TrainableDropoutLayer, self).__init__()
        self.num_features = num_features
        # One learnable scale per input feature, squashed to (0, 1) in forward().
        self.scale = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        # Soft, differentiable "dropout": scale each feature by sigmoid(scale).
        scale = torch.sigmoid(self.scale)
        return x * scale
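
# A usage sketch of the layer as the model's initial layer; the layer sizes
# below are placeholders, not values from the original gist.
example_model = nn.Sequential(
    TrainableDropoutLayer(num_features=20),
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
)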