TabNet
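The model below builds on FeatureTransformer and AttentiveTransformer blocks that are defined in the other parts of this gist. For a self-contained file, here is a minimal sketch of those two components based on the TabNet paper (Arik & Pfister), written against the interface the model expects: the constructor arguments passed via kargs, a shared_fcs attribute, an optional fcs argument for reusing the shared blocks, and an attentive transformer called with the prior scales. The GLUBlock helper and the use of tensorflow_addons sparsemax are assumptions for illustration, not the author's original definitions.

import tensorflow as tf
import tensorflow_addons as tfa  # provides sparsemax


class GLUBlock(tf.keras.layers.Layer):
    """Dense -> BatchNorm -> gated linear unit, the basic FT building block."""

    def __init__(self, feature_dim, bn_momentum=0.7):
        super().__init__()
        self.feature_dim = feature_dim
        self.fc = tf.keras.layers.Dense(feature_dim * 2, use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization(momentum=bn_momentum)

    def call(self, x, training=None):
        x = self.bn(self.fc(x), training=training)
        # GLU: first half of the units gated by a sigmoid over the second half
        return x[:, : self.feature_dim] * tf.nn.sigmoid(x[:, self.feature_dim :])


class FeatureTransformer(tf.keras.layers.Layer):
    """n_total GLU blocks, the first n_shared of which are shared across steps."""

    def __init__(self, feature_dim, n_total=4, n_shared=2, bn_momentum=0.7, fcs=None):
        super().__init__()
        # reuse the shared blocks if they are passed in, otherwise create them
        if fcs is not None:
            self.shared_fcs = fcs
        else:
            self.shared_fcs = [GLUBlock(feature_dim, bn_momentum) for _ in range(n_shared)]
        self.specific_fcs = [
            GLUBlock(feature_dim, bn_momentum) for _ in range(n_total - n_shared)
        ]

    def call(self, x, training=None):
        blocks = self.shared_fcs + self.specific_fcs
        x = blocks[0](x, training=training)
        for block in blocks[1:]:
            # residual connection scaled by sqrt(0.5), as in the paper
            x = (x + block(x, training=training)) * tf.math.sqrt(0.5)
        return x


class AttentiveTransformer(tf.keras.layers.Layer):
    """Produces the sparse feature-selection mask for one decision step."""

    def __init__(self, num_features, bn_momentum=0.7):
        super().__init__()
        self.fc = tf.keras.layers.Dense(num_features, use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization(momentum=bn_momentum)

    def call(self, x, prior_scales, training=None):
        x = self.bn(self.fc(x), training=training)
        # prior scales down-weight features already used in earlier steps;
        # sparsemax makes the mask sparse and sum to 1 across features
        return tfa.activations.sparsemax(x * prior_scales)

The gist's TabNet model itself follows; it only depends on the constructor signatures, the shared_fcs attribute, and the call conventions shown above.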
import tensorflow as tf


class TabNet(tf.keras.Model):
    def __init__(
        self,
        num_features,
        feature_dim,
        output_dim,
        n_step=2,
        n_total=4,
        n_shared=2,
        relaxation_factor=1.5,
        bn_epsilon=1e-5,
        bn_momentum=0.7,
        sparsity_coefficient=1e-5,
    ):
        super(TabNet, self).__init__()
        self.output_dim, self.num_features = output_dim, num_features
        self.n_step, self.relaxation_factor = n_step, relaxation_factor
        self.sparsity_coefficient = sparsity_coefficient

        # initial batch normalisation applied to the raw input features
        self.bn = tf.keras.layers.BatchNormalization(
            momentum=bn_momentum, epsilon=bn_epsilon
        )

        kargs = {
            "feature_dim": feature_dim + output_dim,
            "n_total": n_total,
            "n_shared": n_shared,
            "bn_momentum": bn_momentum,
        }

        # the first feature transformer block is built first so that its shared
        # fully connected blocks can be reused by the transformers of later steps
        self.feature_transforms = [FeatureTransformer(**kargs)]
        self.attentive_transforms = []

        # each decision step consists of a feature transformer (FT) and an
        # attentive transformer (AT)
        for i in range(n_step):
            self.feature_transforms.append(
                FeatureTransformer(**kargs, fcs=self.feature_transforms[0].shared_fcs)
            )
            self.attentive_transforms.append(
                AttentiveTransformer(num_features)
            )

        # final output layer (here a two-class softmax head)
        self.head = tf.keras.layers.Dense(2, activation="softmax", use_bias=False)

    def call(self, features, training=None):
        bs = tf.shape(features)[0]                        # batch size
        out_agg = tf.zeros((bs, self.output_dim))         # aggregated decision output
        prior_scales = tf.ones((bs, self.num_features))   # prior scales initialised as 1s
        importance = tf.zeros([bs, self.num_features])    # aggregated feature importances
        masks = []

        features = self.bn(features, training=training)   # initial batch normalisation
        masked_features = features
        total_entropy = 0.0

        for step_i in range(self.n_step + 1):
            # (masked) features go through the FT
            x = self.feature_transforms[step_i](
                masked_features, training=training
            )

            # the first FT is not used to generate output
            if step_i > 0:
                # first half of the FT output goes towards the decision
                out = tf.keras.activations.relu(x[:, : self.output_dim])
                out_agg += out
                # weight this step's mask by the magnitude of its decision output
                scale_agg = tf.reduce_sum(out, axis=1, keepdims=True) / (self.n_step - 1)
                importance += mask_values * scale_agg

            # no need to build the feature mask for the last step
            if step_i < self.n_step:
                # second half of the FT output goes as input to the AT
                x_for_mask = x[:, self.output_dim :]

                # apply the AT with the prior scales
                mask_values = self.attentive_transforms[step_i](
                    x_for_mask, prior_scales, training=training
                )

                # recalculate the prior scales so that already-used features
                # are less likely to be selected again
                prior_scales *= self.relaxation_factor - mask_values

                # multiply the normalised features by the attention mask to enforce sparsity
                masked_features = tf.multiply(mask_values, features)

                # entropy is used to penalise the amount of sparsity in feature selection
                total_entropy += tf.reduce_mean(
                    tf.reduce_sum(
                        tf.multiply(-mask_values, tf.math.log(mask_values + 1e-15)),
                        axis=1,
                    )
                )

                # keep the mask values of each step for later explainability
                masks.append(tf.expand_dims(tf.expand_dims(mask_values, 0), 3))

        # per-step selection masks
        self.selection_masks = masks

        # final output is produced from the aggregated decision output
        final_output = self.head(out_agg)

        # add the sparsity regularisation loss
        if training:
            loss = total_entropy / (self.n_step - 1)
            self.add_loss(self.sparsity_coefficient * loss)

        return final_output, importance
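
A quick smoke test of the model, assuming the component sketches above (or the author's original FeatureTransformer and AttentiveTransformer) are in scope; the shapes and hyperparameters here are arbitrary. The call returns both class probabilities and per-sample feature importances, so a custom training loop or a thin wrapper around the tuple output is needed before using model.fit.

num_features = 10
model = TabNet(num_features=num_features, feature_dim=16, output_dim=8, n_step=2)

# random batch of "tabular" data, purely for illustration
X = tf.random.uniform((32, num_features))

probs, importance = model(X, training=True)
print(probs.shape)                  # (32, 2)  class probabilities from the softmax head
print(importance.shape)             # (32, 10) per-sample feature importances
print(len(model.selection_masks))   # 2, one selection mask per decision step
print(model.losses)                 # contains the sparsity penalty added in call()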