Created
September 18, 2022 15:43
-
-
Save jostmey/f5018f6803df4c6faecd3bca64e5829c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
########################################################################################## | |
# Author: Jared L. Ostmeyer | |
# Date Started: 2020-08-13 | |
# Purpose: Implementation of neural decision trees (NDT) and variants for Keras in TensorFlow | |
########################################################################################## | |
from tensorflow.keras.layers import * | |
from tensorflow.keras.initializers import Initializer | |
import tensorflow as tf | |
class NDT(Layer):
    """Neural decision tree (NDT) layer.

    The inputs are interpreted as the fork logits of ``num_trees`` soft
    binary decision trees, each of depth ``depth``.  Every fork routes
    probability mass to its left/right child through a sigmoid.  The
    layer returns one value per tree in [0, 1]: the total probability
    mass that arrives at the deepest level through a right-hand
    (sigmoid) branch.

    Use ``NDT.num_inputs(depth, num_trees)`` to size the upstream layer
    that produces the logits.
    """

    def __init__(self, depth, num_trees=1, **kwargs):
        """
        Args:
            depth: int, number of decision levels per tree (>= 1).
            num_trees: int, number of trees evaluated in parallel.
        """
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)

    def get_config(self):
        # Record constructor arguments so Keras (de)serialization can
        # re-create the layer.
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        num_forks = 2**self.depth - 1  # forks in a complete binary tree
        logits = tf.reshape(inputs, [-1, self.num_trees, num_forks])
        sigmoids = tf.nn.sigmoid(logits)
        # Probability mass at each node of the current level; the root
        # starts with all of the mass.
        trees_flat = tf.ones_like(logits[:, :, 0:1])
        j = 0  # offset of the current level's forks inside `logits`
        for i in range(self.depth):  # grow the trees level by level
            level_size = 2**i  # number of forks at level i
            decisions1 = sigmoids[:, :, j:j + level_size]
            # Left branch takes (1 - p), right branch takes p.
            decisions = tf.stack([1.0 - decisions1, decisions1], axis=3)
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [batch, tree, fork, 2]
            # Width after splitting is exactly 2*level_size, so no static
            # shape lookup is needed (works with dynamic shapes too).
            trees_flat = tf.reshape(trees, [-1, self.num_trees, 2*level_size])
            j += level_size
        # Sum over the deepest level's forks; component 1 is the mass
        # that arrived through a right-hand branch.
        probabilities = tf.reduce_sum(trees, axis=2)  # [batch, tree, 2]
        probabilities1 = probabilities[:, :, 1]  # [batch, tree]
        return probabilities1

    @staticmethod
    def num_inputs(depth, num_trees=1):
        """Required input units: one logit per fork per tree."""
        return num_trees*(2**depth - 1)
class NST(Layer):
    """Scaled variant of the NDT layer.

    Identical to ``NDT`` except that the logits of level ``i`` are
    multiplied by ``1 / 2**(depth - i - 1)`` before the sigmoid, so
    shallower levels see smaller effective logits — presumably to
    counteract the gradient attenuation the NDT suffers (see the file
    header); confirm against the author's write-up.

    Use ``NST.num_inputs(depth, num_trees)`` to size the upstream layer
    that produces the logits.
    """

    def __init__(self, depth, num_trees=1, **kwargs):
        """
        Args:
            depth: int, number of decision levels per tree (>= 1).
            num_trees: int, number of trees evaluated in parallel.
        """
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)

    def get_config(self):
        # Record constructor arguments so Keras (de)serialization can
        # re-create the layer.
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        num_forks = 2**self.depth - 1  # forks in a complete binary tree
        logits = tf.reshape(inputs, [-1, self.num_trees, num_forks])
        # Probability mass at each node of the current level; the root
        # starts with all of the mass.
        trees_flat = tf.ones_like(logits[:, :, 0:1])
        j = 0  # offset of the current level's forks inside `logits`
        for i in range(self.depth):  # grow the trees level by level
            level_size = 2**i  # number of forks at level i
            # Per-level scaling: deepest level uses scale 1.0, the root
            # uses 1/2**(depth-1).
            scale = 1.0/(2**(self.depth - i - 1))
            decisions1 = tf.nn.sigmoid(scale*logits[:, :, j:j + level_size])
            # Left branch takes (1 - p), right branch takes p.
            decisions = tf.stack([1.0 - decisions1, decisions1], axis=3)
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [batch, tree, fork, 2]
            # Width after splitting is exactly 2*level_size.
            trees_flat = tf.reshape(trees, [-1, self.num_trees, 2*level_size])
            j += level_size
        # Sum over the deepest level's forks; component 1 is the mass
        # that arrived through a right-hand branch.
        probabilities = tf.reduce_sum(trees, axis=2)  # [batch, tree, 2]
        probabilities1 = probabilities[:, :, 1]  # [batch, tree]
        return probabilities1

    @staticmethod
    def num_inputs(depth, num_trees=1):
        """Required input units: one logit per fork per tree."""
        return num_trees*(2**depth - 1)
class NCT(Layer):
    """Pooled variant of the NDT layer.

    Each tree level receives a full row of ``2**(depth-1)`` sigmoid
    units; level ``i`` partitions that row into ``2**i`` equal pools
    and uses the *mean* of each pool as the fork's routing probability.
    Otherwise the tree growth and the returned value match ``NDT``.

    Use ``NCT.num_inputs(depth, num_trees)`` to size the upstream layer
    that produces the logits.
    """

    def __init__(self, depth, num_trees=1, **kwargs):
        """
        Args:
            depth: int, number of decision levels per tree (>= 1).
            num_trees: int, number of trees evaluated in parallel.
        """
        self.depth = depth
        self.num_trees = num_trees
        super().__init__(**kwargs)

    def get_config(self):
        # Record constructor arguments so Keras (de)serialization can
        # re-create the layer.
        config = super().get_config().copy()
        config.update(
            {
                'depth': self.depth,
                'num_trees': self.num_trees
            }
        )
        return config

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        width = 2**(self.depth - 1)  # sigmoid units per level
        logits = tf.reshape(
            inputs, [-1, self.num_trees, self.depth, width]
        )  # [batch, tree, depth, width]
        sigmoids = tf.nn.sigmoid(logits)
        # Probability mass at each node of the current level; the root
        # starts with all of the mass.
        trees_flat = tf.ones_like(logits[:, :, 0:1, 0])
        for i in range(self.depth):
            # Partition level i's units into 2**i equal pools; use
            # integer floor division for the pool size.
            sigmoids_pool = tf.reshape(
                sigmoids[:, :, i, :],
                [-1, self.num_trees, 2**i, width // 2**i]
            )  # [batch, tree, fork, pool]
            decisions1 = tf.reduce_mean(sigmoids_pool, axis=3)  # [batch, tree, fork]
            # Left branch takes (1 - p), right branch takes p.
            decisions = tf.stack([1.0 - decisions1, decisions1], axis=3)  # [batch, tree, fork, 2]
            trees = tf.expand_dims(trees_flat, axis=3)*decisions  # [batch, tree, fork, 2]
            # Width after splitting level i is exactly 2**(i+1).
            trees_flat = tf.reshape(trees, [-1, self.num_trees, 2**(i + 1)])
        # Sum over the deepest level's forks; component 1 is the mass
        # that arrived through a right-hand branch.
        probabilities = tf.reduce_sum(trees, axis=2)  # [batch, tree, 2]
        probabilities1 = probabilities[:, :, 1]  # [batch, tree]
        return probabilities1

    @staticmethod
    def num_inputs(depth, num_trees=1):
        """Required input units: a full row of units per level per tree."""
        return num_trees*depth*2**(depth - 1)
class ParityFlip(Layer):
    """Flips the sign of every other feature.

    Feature ``i`` is multiplied by ``-1`` when ``(i + offset)`` is even
    and by ``+1`` when it is odd.  ``offset`` selects which of the two
    alternating patterns is used.  Requires a statically known feature
    dimension (``inputs.shape[1]``).
    """

    def __init__(self, offset=0, **kwargs):
        """
        Args:
            offset: int, shifts the alternation pattern by this many
                positions (only its parity matters).
        """
        self.offset = offset
        super().__init__(**kwargs)

    def get_config(self):
        # Record constructor arguments so Keras (de)serialization can
        # re-create the layer.
        config = super().get_config().copy()
        config.update(
            {
                'offset': self.offset
            }
        )
        return config

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        num_features = int(inputs.shape[1])
        # 2*parity - 1 maps parity 0 -> -1 and parity 1 -> +1.
        signs = tf.convert_to_tensor(
            [2*((i + self.offset) % 2) - 1 for i in range(num_features)],
            dtype=inputs.dtype
        )
        # Broadcast the per-feature signs over the batch dimension.
        signs_expand = tf.expand_dims(signs, axis=0)
        outputs = signs_expand*inputs
        return outputs
class RandomFlip(Layer):
    """Multiplies each feature by a fixed, randomly chosen sign.

    One sign (+1 or -1) per feature is drawn when the layer is built,
    stored in a non-trainable weight, and applied identically to every
    batch thereafter.
    """

    def build(self, input_shape):
        class _RandomSignInitializer(Initializer):
            # Draw uniform values in [0, 1), round to {0, 1}, then map
            # 0 -> -1 and 1 -> +1.
            def __call__(self, shape, dtype=None):
                parities = tf.round(tf.random.uniform(shape, dtype=dtype))
                return 2.0*parities - 1.0

        self.bias = self.add_weight(
            name='bias',
            shape=input_shape[1:],
            initializer=_RandomSignInitializer(),
            trainable=False
        )
        super().build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        # Broadcast the per-feature signs over the batch dimension.
        return tf.expand_dims(self.bias, axis=0)*inputs
class Shuffle(Layer):
    """Applies a fixed random permutation to the feature axis.

    A permutation matrix is sampled once at build time and stored as a
    non-trainable kernel; ``call`` right-multiplies the inputs by it.
    """

    def build(self, input_shape):
        class _PermutationInitializer(Initializer):
            def __call__(self, shape, dtype=None):
                # Shuffling the rows of an identity matrix yields a
                # random permutation matrix.
                identity = tf.eye(shape[1], dtype=dtype)
                return tf.random.shuffle(identity)

        num_features = int(input_shape[1])
        self.kernel = self.add_weight(
            name='kernel',
            shape=[num_features, num_features],
            initializer=_PermutationInitializer(),
            trainable=False
        )
        super().build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # Masking is passed through unchanged.
        return mask

    def call(self, inputs, mask=None):
        # [batch, n] x [n, n] -> [batch, n] with features permuted.
        return tf.matmul(inputs, self.kernel)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Neural decision trees attempt to combine the natural interpretability of decision tree models with the performance of deep learning models. However, neural decision tree models suffer from an issue that is similar to, but not identical to, the vanishing gradient problem. This gist implements strategies for fixing this gradient issue.