Skip to content

Instantly share code, notes, and snippets.

@alper111

alper111/SoftTree.py

Last active Sep 20, 2019
Embed
What would you like to do?
PyTorch implementation of a soft decision tree. All gating calculations are done in a single step in order to make use of the GPU. A recursive definition might be faster on non-GPU machines.
import torch
class SoftTree(torch.nn.Module):
    """Soft decision tree stored as a complete binary tree in heap order.

    Every internal node ``k`` owns a sigmoid gate ``g_k(x)``.  Its first child
    (heap index ``2k + 1``) receives the fraction ``g_k(x)`` of the parent's
    density and its second child (``2k + 2``) receives ``1 - g_k(x)``.  The
    densities of the ``2**depth`` deepest nodes weight the per-leaf outputs.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        depth: Number of gating levels; the tree has ``2**depth`` leaves.
        projection: ``'constant'`` gives every leaf a learned output vector,
            ``'linear'`` gives every leaf its own affine map of the input.
        dropout: Dropout probability applied to the gating weights.

    Raises:
        ValueError: If ``projection`` is not one of the supported modes or
            ``depth`` is negative.
    """

    _PROJECTIONS = ('constant', 'linear')

    def __init__(self, in_features: int, out_features: int, depth: int,
                 projection: str = 'constant', dropout: float = 0.0) -> None:
        super().__init__()
        # Fail early: an unknown mode would otherwise only surface in forward().
        if projection not in self._PROJECTIONS:
            raise ValueError(
                "projection must be one of %s, got %r"
                % (self._PROJECTIONS, projection))
        if depth < 0:
            raise ValueError("depth must be non-negative, got %r" % (depth,))

        self.proj = projection
        self.depth = depth
        self.in_features = in_features
        self.out_features = out_features
        self.leaf_count = int(2**depth)
        self.gate_count = int(self.leaf_count - 1)

        # Gating weights are initialised as (gate_count, in_features) and kept
        # transposed, i.e. (in_features, gate_count), so that x @ gw works.
        self.gw = torch.nn.Parameter(
            torch.nn.init.kaiming_normal_(
                torch.empty(self.gate_count, in_features),
                nonlinearity='sigmoid').t())
        self.gb = torch.nn.Parameter(torch.zeros(self.gate_count))
        # dropout rate for gating weights.
        self.drop = torch.nn.Dropout(p=dropout)

        if self.proj == 'linear':
            # Initialised as (out*leaves, in), stored as (out, in, leaves).
            leaf_weights = torch.nn.init.kaiming_normal_(
                torch.empty(out_features * self.leaf_count, in_features),
                nonlinearity='linear')
            self.pw = torch.nn.Parameter(
                leaf_weights.view(
                    out_features, self.leaf_count, in_features
                ).permute(0, 2, 1))
            self.pb = torch.nn.Parameter(
                torch.zeros(out_features, self.leaf_count))
        else:
            # find a better init for this.
            self.z = torch.nn.Parameter(
                torch.randn(out_features, self.leaf_count))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the leaf-density-weighted output for a batch.

        Args:
            x: Input of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        densities = self.node_densities(x)
        # The last leaf_count columns are the deepest level; -> (leaves, batch)
        leaf_probs = densities[:, -self.leaf_count:].t()
        if self.proj == 'linear':
            # (out, in, leaves) @ (leaves, batch) -> (batch, out, in)
            gated_projection = torch.matmul(
                self.pw, leaf_probs).permute(2, 0, 1)
            # (out, leaves) @ (leaves, batch) -> (batch, out)
            gated_bias = torch.matmul(self.pb, leaf_probs).permute(1, 0)
            projected = torch.matmul(
                gated_projection, x.view(-1, self.in_features, 1))[:, :, 0]
            return projected + gated_bias
        if self.proj == 'constant':
            return torch.matmul(self.z, leaf_probs).permute(1, 0)
        # Only reachable if self.proj was reassigned after construction.
        raise ValueError("unsupported projection: %r" % (self.proj,))

    def extra_repr(self) -> str:
        """Return the constructor arguments shown in the module's repr."""
        return "in_features=%d, out_features=%d, depth=%d, projection=%s" % (
            self.in_features,
            self.out_features,
            self.depth,
            self.proj)

    def node_densities(self, x: torch.Tensor) -> torch.Tensor:
        """Return the density of every tree node for a batch.

        Args:
            x: Input of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, 2**(depth + 1) - 1)`` in heap order:
            column 0 is the root (always 1) and the children of column ``k``
            are columns ``2k + 1`` and ``2k + 2``.  The result has the dtype
            and device of the gate activations, so it stays compatible with
            the module's parameters after ``.double()`` / ``.half()``.
        """
        gate_weights = self.drop(self.gw)
        gate_probs = torch.sigmoid(
            torch.add(torch.matmul(x, gate_weights), self.gb))
        batch = x.shape[0]

        level = gate_probs.new_ones(batch, 1)  # root density
        levels = [level]
        for d in range(1, self.depth + 1):
            parents = 2 ** (d - 1)
            first_gate = parents - 1  # heap index of first node at depth d-1
            gate = gate_probs[:, first_gate:first_gate + parents]
            # Interleave (parent * g, parent * (1 - g)) so each parent's two
            # children sit next to each other, matching heap order.
            level = torch.stack(
                (level * gate, level * (1 - gate)), dim=2
            ).reshape(batch, 2 * parents)
            levels.append(level)
        return torch.cat(levels, dim=1)

    def gatings(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid gate activations for ``x`` without gradients.

        Dropout is not applied here.  The result has shape
        ``(batch, gate_count)``.
        """
        with torch.no_grad():
            return torch.sigmoid(
                torch.add(torch.matmul(x, self.gw), self.gb))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.