TabNet
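The TabNet class below depends on FeatureTransformer and AttentiveTransformer layers that are not included in this file. What follows is a minimal sketch of compatible implementations, based on the TabNet paper (Arik and Pfister); the GLUBlock helper is a name introduced here, and the exact block structure, the sqrt(0.5)-scaled residuals, and the use of sparsemax from tensorflow_addons are assumptions about what the original companion code looks like, not the author's verbatim layers.

import tensorflow as tf
import tensorflow_addons as tfa  # assumed dependency, used for sparsemax


class GLUBlock(tf.keras.layers.Layer):
    # Fully connected -> batch norm -> gated linear unit (GLU)
    def __init__(self, feature_dim, bn_momentum, fc=None):
        super().__init__()
        self.feature_dim = feature_dim
        # the FC layer may be shared across decision steps
        self.fc = fc if fc is not None else tf.keras.layers.Dense(
            2 * feature_dim, use_bias=False
        )
        self.bn = tf.keras.layers.BatchNormalization(momentum=bn_momentum)

    def call(self, x, training=None):
        x = self.bn(self.fc(x), training=training)
        # GLU: first half of the units gated by a sigmoid of the second half
        return x[:, : self.feature_dim] * tf.nn.sigmoid(x[:, self.feature_dim :])


class FeatureTransformer(tf.keras.layers.Layer):
    def __init__(self, feature_dim, n_total=4, n_shared=2, bn_momentum=0.7, fcs=None):
        super().__init__()
        # reuse the shared FC layers if they are passed in, otherwise create them
        self.shared_fcs = fcs if fcs is not None else [
            tf.keras.layers.Dense(2 * feature_dim, use_bias=False)
            for _ in range(n_shared)
        ]
        shared = [GLUBlock(feature_dim, bn_momentum, fc) for fc in self.shared_fcs]
        step_specific = [
            GLUBlock(feature_dim, bn_momentum) for _ in range(n_total - n_shared)
        ]
        self.blocks = shared + step_specific

    def call(self, x, training=None):
        for i, block in enumerate(self.blocks):
            out = block(x, training=training)
            # residual connection, scaled by sqrt(0.5) to stabilise variance
            x = out if i == 0 else (x + out) * tf.math.sqrt(0.5)
        return x


class AttentiveTransformer(tf.keras.layers.Layer):
    def __init__(self, num_features):
        super().__init__()
        self.fc = tf.keras.layers.Dense(num_features, use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, x, prior_scales, training=None):
        x = self.bn(self.fc(x), training=training)
        # sparsemax yields sparse feature-selection masks that sum to 1
        return tfa.activations.sparsemax(x * prior_scales)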
class TabNet(tf.keras.Model):
    def __init__(
        self,
        num_features,
        feature_dim,
        output_dim,
        n_step=2,
        n_total=4,
        n_shared=2,
        relaxation_factor=1.5,
        bn_epsilon=1e-5,
        bn_momentum=0.7,
        sparsity_coefficient=1e-5,
    ):
        super().__init__()
        self.output_dim, self.num_features = output_dim, num_features
        self.n_step, self.relaxation_factor = n_step, relaxation_factor
        self.sparsity_coefficient = sparsity_coefficient

        # initial batch normalisation applied to the raw input features
        self.bn = tf.keras.layers.BatchNormalization(
            momentum=bn_momentum, epsilon=bn_epsilon
        )

        kwargs = {
            "feature_dim": feature_dim + output_dim,
            "n_total": n_total,
            "n_shared": n_shared,
            "bn_momentum": bn_momentum,
        }

        # the first feature transformer block is built up front so that its
        # shared fully connected layers can be reused by the later blocks
        self.feature_transforms = [FeatureTransformer(**kwargs)]
        self.attentive_transforms = []

        # each decision step consists of a feature transformer (FT)
        # and an attentive transformer (AT)
        for _ in range(n_step):
            self.feature_transforms.append(
                FeatureTransformer(**kwargs, fcs=self.feature_transforms[0].shared_fcs)
            )
            self.attentive_transforms.append(AttentiveTransformer(num_features))

        # final output layer; this head is hard-coded for two-class output
        self.head = tf.keras.layers.Dense(2, activation="softmax", use_bias=False)
    def call(self, features, training=None):
        bs = tf.shape(features)[0]  # batch size
        out_agg = tf.zeros((bs, self.output_dim))  # aggregated decision output
        prior_scales = tf.ones((bs, self.num_features))  # prior scales initialised to 1
        importance = tf.zeros([bs, self.num_features])  # feature importances
        masks = []

        features = self.bn(features, training=training)  # initial batch normalisation
        masked_features = features
        total_entropy = 0.0

        for step_i in range(self.n_step + 1):
            # the (masked) features go through the FT
            x = self.feature_transforms[step_i](masked_features, training=training)

            # the first FT only feeds the first AT; it generates no output
            if step_i > 0:
                # the first half of the FT output goes towards the decision
                out = tf.keras.activations.relu(x[:, : self.output_dim])
                out_agg += out
                # weight this step's mask by the magnitude of its decision output
                scale_agg = tf.reduce_sum(out, axis=1, keepdims=True) / (self.n_step - 1)
                importance += mask_values * scale_agg

            # no feature mask needs to be built for the last step
            if step_i < self.n_step:
                # the second half of the FT output is the input of the AT
                x_for_mask = x[:, self.output_dim :]
                # apply the AT with the current prior scales
                mask_values = self.attentive_transforms[step_i](
                    x_for_mask, prior_scales, training=training
                )
                # relax the prior scales so that features selected in this
                # step can still be picked up by later steps
                prior_scales *= self.relaxation_factor - mask_values
                # mask the input features to enforce sparse feature selection
                masked_features = tf.multiply(mask_values, features)

                # the entropy of the masks penalises dense feature selection
                total_entropy += tf.reduce_mean(
                    tf.reduce_sum(
                        tf.multiply(-mask_values, tf.math.log(mask_values + 1e-15)),
                        axis=1,
                    )
                )

                # keep the per-step selection masks for explainability
                masks.append(tf.expand_dims(tf.expand_dims(mask_values, 0), 3))

        # per-step selection masks
        self.selection_masks = masks
        # final output is produced from the aggregated decision output
        final_output = self.head(out_agg)

        # add the sparsity regularisation loss
        if training:
            loss = total_entropy / (self.n_step - 1)
            self.add_loss(self.sparsity_coefficient * loss)

        return final_output, importance
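A quick smoke test on random data, assuming the TabNet class and its companion layers above are in scope; all dimensions and the batch size are illustrative:

import numpy as np

model = TabNet(num_features=10, feature_dim=16, output_dim=8, n_step=2)

x = np.random.rand(32, 10).astype("float32")
preds, importance = model(x, training=True)

print(preds.shape)                 # (32, 2): two-class softmax probabilities
print(importance.shape)            # (32, 10): per-sample feature importances
print(len(model.selection_masks))  # 2: one selection mask per decision step
print(model.losses)                # contains the sparsity regularisation term

Because call returns both the predictions and the importances, training with a plain model.fit would need a small wrapper or a custom training loop; the sparsity term added via add_loss is exposed through model.losses either way.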