
@BloodAxe
Created August 19, 2020 08:47
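
```python
# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/models/timm.py#L104
# Method of the model class in the linked timm.py. The OUTPUT_PRED_* key constants
# and the submodules used below are defined elsewhere in that repository.
def forward(self, **kwargs):
    # Pick the input tensor out of the keyword arguments.
    x = kwargs[self.input_key]
    # Normalize the raw input before it reaches the backbone.
    x = self.rgb_bn(x)
    # Backbone feature maps (timm models expose forward_features).
    x = self.encoder.forward_features(x)
    # Pool the feature maps into one embedding vector per sample.
    embedding = self.pool(x)
    # Both heads share the embedding; each self.drop call draws its own dropout mask.
    result = {
        OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)),
        OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)),
    }
    # Optionally return the raw (pre-dropout) embedding.
    if self.need_embedding:
        result[OUTPUT_PRED_EMBEDDING] = embedding
    # Optionally add the ArcMargin output computed from the same embedding.
    if self.arc_margin is not None:
        result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)
    return result
```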
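
The method above follows a shared-embedding, multi-head pattern: one pooled embedding feeds two classifiers, each through its own dropout call, and extra outputs are added to the result dict only when enabled. Below is a minimal pure-Python sketch of that pattern, assuming global average pooling for `self.pool` and a single linear layer per head. The real encoder, pooling and classifiers live in the linked repository, and the ArcMargin branch is omitted here. The key strings and the helpers `global_avg_pool`, `linear`, `dropout` and `multi_head_forward` are hypothetical names for illustration only.

```python
import random
from typing import Dict, List, Optional, Tuple

# Hypothetical key strings; the original code uses constants defined in its repository.
OUTPUT_PRED_MODIFICATION_FLAG = "pred_modification_flag"
OUTPUT_PRED_MODIFICATION_TYPE = "pred_modification_type"
OUTPUT_PRED_EMBEDDING = "pred_embedding"

Matrix = List[List[float]]
Head = Tuple[Matrix, List[float]]  # (weight rows, bias)


def global_avg_pool(features: List[Matrix]) -> List[float]:
    """Assumed pooling: average each H x W channel map down to one value per channel."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in features]


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Fully connected layer: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def dropout(x: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p and rescale the rest."""
    if not training or p == 0.0:
        return list(x)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in x]


def multi_head_forward(
    features: List[Matrix],
    flag_head: Head,
    type_head: Head,
    p_drop: float = 0.5,
    training: bool = False,
    need_embedding: bool = False,
    rng: Optional[random.Random] = None,
) -> Dict[str, List[float]]:
    """Shared embedding -> two heads, with an optional embedding output."""
    if rng is None:
        rng = random.Random(0)
    embedding = global_avg_pool(features)
    # Two separate dropout calls, so each head sees an independent mask in training.
    result = {
        OUTPUT_PRED_MODIFICATION_FLAG: linear(dropout(embedding, p_drop, training, rng), *flag_head),
        OUTPUT_PRED_MODIFICATION_TYPE: linear(dropout(embedding, p_drop, training, rng), *type_head),
    }
    if need_embedding:
        # Return the pooled embedding before any dropout.
        result[OUTPUT_PRED_EMBEDDING] = embedding
    return result


# Usage: 2 channels of 2x2 features, a 1-output flag head and an illustrative 4-output type head.
features = [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 2.0], [2.0, 0.0]]]
flag_head = ([[1.0, -1.0]], [0.5])
type_head = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]], [0.0, 0.0, 0.0, 0.0])
out = multi_head_forward(features, flag_head, type_head, need_embedding=True)
print(out)
```

In evaluation mode dropout is the identity, so the embedding here is `[4.0, 1.0]`, the flag head yields `[3.5]` and the type head yields `[4.0, 1.0, 5.0, 0.0]`. Calling dropout once per head, as the original does, means the two heads are regularized with different masks rather than a single shared one.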