@jostmey
Created September 18, 2022 15:43
##########################################################################################
# Author: Jared L. Ostmeyer
# Date Started: 2020-08-13
# Purpose: Implementation of neural decision trees (NDT) and variants for Keras in TensorFlow
##########################################################################################
from tensorflow.keras.layers import Layer
from tensorflow.keras.initializers import Initializer
import tensorflow as tf
class NDT(Layer):  # Neural decision tree: sigmoid gates softly route probability mass down a binary tree
    def __init__(self, depth, num_trees=1, **kwargs):
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        num_forks = 2**self.depth-1
        logits = tf.reshape(inputs, [ -1, self.num_trees, num_forks ])
        sigmoids = tf.nn.sigmoid(logits)
        trees_flat = tf.ones_like(logits[:,:,0:1])
        j = 0
        for i in range(self.depth):  # Grow the trees one level at a time
            decisions1 = sigmoids[:,:,j:j+2**i]
            decisions = tf.stack([ 1.0-decisions1, decisions1 ], axis=3)
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [ batch, tree, decision, 2 ]
            width = int(trees_flat.shape[2])*2
            trees_flat = tf.reshape(trees, [ -1, self.num_trees, width ])
            j += 2**i
        probabilities = tf.reduce_sum(trees, axis=2)  # [ batch, tree, 2 ]
        probabilities1 = probabilities[:,:,1]  # [ batch, tree ], probability that the leaf-level decision is 1
        return probabilities1
    @staticmethod
    def num_inputs(depth, num_trees=1):  # Number of fork logits the layer expects per example
        return num_trees*(2**depth-1)
class NST(Layer):  # Variant of NDT that scales each level's logits by 1/2**(depth-i-1), so forks near the root use flatter sigmoids
    def __init__(self, depth, num_trees=1, **kwargs):
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        num_forks = 2**self.depth-1
        logits = tf.reshape(inputs, [ -1, self.num_trees, num_forks ])
        trees_flat = tf.ones_like(logits[:,:,0:1])
        j = 0
        for i in range(self.depth):  # Grow the trees one level at a time
            scale = 1.0/(2**(self.depth-i-1))
            decisions1 = tf.nn.sigmoid(scale*logits[:,:,j:j+2**i])
            decisions = tf.stack([ 1.0-decisions1, decisions1 ], axis=3)
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [ batch, tree, decision, 2 ]
            width = int(trees_flat.shape[2])*2
            trees_flat = tf.reshape(trees, [ -1, self.num_trees, width ])
            j += 2**i
        probabilities = tf.reduce_sum(trees, axis=2)  # [ batch, tree, 2 ]
        probabilities1 = probabilities[:,:,1]  # [ batch, tree ]
        return probabilities1
    @staticmethod
    def num_inputs(depth, num_trees=1):
        return num_trees*(2**depth-1)
class NCT(Layer):  # Variant of NDT where the decision at level i is the mean of 2**(depth-1-i) sigmoid gates per fork
    def __init__(self, depth, num_trees=1, **kwargs):
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        width = 2**(self.depth-1)
        logits = tf.reshape(inputs, [ -1, self.num_trees, self.depth, width ])  # [ batch, tree, depth, width ]
        sigmoids = tf.nn.sigmoid(logits)
        trees_flat = tf.ones_like(logits[:,:,0:1,0])
        for i in range(self.depth):  # Grow the trees one level at a time
            sigmoids_pool = tf.reshape(
                sigmoids[:,:,i,:],
                [ -1, self.num_trees, 2**i, int(width/2**i) ]
            )  # [ batch, tree, fork, pool ]
            decisions1 = tf.reduce_mean(sigmoids_pool, axis=3)  # [ batch, tree, fork ]
            decisions = tf.stack([ 1.0-decisions1, decisions1 ], axis=3)  # [ batch, tree, fork, 2 ]
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [ batch, tree, fork, 2 ]
            trees_flat = tf.reshape(
                trees,
                [ -1, self.num_trees, 2*int(trees_flat.shape[2]) ]
            )
        probabilities = tf.reduce_sum(trees, axis=2)  # [ batch, tree, 2 ]
        probabilities1 = probabilities[:,:,1]  # [ batch, tree ]
        return probabilities1
    @staticmethod
    def num_inputs(depth, num_trees=1):
        return num_trees*depth*2**(depth-1)
class ParityFlip(Layer):  # Multiplies each feature by -1 or +1 according to the parity of its index plus the offset
    def __init__(self, offset=0, **kwargs):
        self.offset = offset
        super().__init__(**kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'offset': self.offset
            }
        )
        return config
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        signs = []
        for i in range(int(inputs.shape[1])):
            parity = (i+self.offset)%2
            signs.append(2*parity-1)
        signs = tf.convert_to_tensor(signs, dtype=inputs.dtype)
        signs_expand = tf.expand_dims(signs, axis=0)
        outputs = signs_expand*inputs
        return outputs
class RandomFlip(Layer):  # Multiplies each feature by a fixed, randomly drawn sign of -1 or +1 (not trainable)
    def build(self, input_shape):
        class _ParityInitializer(Initializer):
            def __call__(self, shape, dtype=None):
                parities = tf.round(tf.random.uniform(shape, dtype=dtype))
                signs = 2.0*parities-1.0
                return signs
        self.bias = self.add_weight(
            name='bias',
            shape=input_shape[1:],
            initializer=_ParityInitializer(),
            trainable=False
        )
        super().build(input_shape)
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        signs_expand = tf.expand_dims(self.bias, axis=0)
        outputs = signs_expand*inputs
        return outputs
class Shuffle(Layer):  # Applies a fixed random permutation of the features via a non-trainable permutation matrix
    def build(self, input_shape):
        class _ShuffleInitializer(Initializer):
            def __call__(self, shape, dtype=None):
                eyes = tf.eye(shape[1], dtype=dtype)
                shuffles = tf.random.shuffle(eyes)
                return shuffles
        self.kernel = self.add_weight(
            name='kernel',
            shape=[ int(input_shape[1]), int(input_shape[1]) ],
            initializer=_ShuffleInitializer(),
            trainable=False
        )
        super().build(input_shape)
    def compute_mask(self, inputs, mask=None):
        return mask
    def call(self, inputs, mask=None):
        outputs = tf.matmul(inputs, self.kernel)
        return outputs
jostmey commented Sep 18, 2022

Neural decision trees attempt to combine the natural interpretability of decision tree models with the performance of deep learning models. However, neural decision tree models suffer from an issue similar to, but not identical to, the vanishing gradient problem: in these layers the probability mass reaching a fork is the product of the sigmoid decisions above it, so gradients at deep forks are scaled by path probabilities that can become very small. This gist implements strategies to fix this gradient issue.
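
Below is a minimal usage sketch, not part of the gist itself. The layer sizes, the input dimension, and the placement of Shuffle and RandomFlip between the logit-producing Dense layer and the tree layer are illustrative assumptions; the point is that num_inputs tells you how many logits the preceding layer must produce, and that the tree layer emits one probability per tree. It assumes the classes defined above are in scope.

import tensorflow as tf
from tensorflow.keras.layers import Dense

depth, num_trees = 4, 8                                  # illustrative sizes
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),                         # 20 input features (arbitrary)
    Dense(NDT.num_inputs(depth, num_trees)),             # one logit per fork, per tree
    Shuffle(),                                           # fixed random permutation of the logits
    RandomFlip(),                                        # fixed random sign flips
    NDT(depth, num_trees=num_trees),                     # one probability per tree
    Dense(1, activation='sigmoid'),                      # combine the trees into a single prediction
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

Swapping NDT for NST or NCT only changes how many logits the Dense layer must produce: use the matching num_inputs, since NCT expects num_trees*depth*2**(depth-1) inputs rather than num_trees*(2**depth-1).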
