Skip to content

Instantly share code, notes, and snippets.

@alper111
Last active September 20, 2019 11:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alper111/ef29c006cb35aec03ecbae9866b30e23 to your computer and use it in GitHub Desktop.
Save alper111/ef29c006cb35aec03ecbae9866b30e23 to your computer and use it in GitHub Desktop.
PyTorch implementation of a soft decision tree. All gating values are computed in a single batched step to make better use of the GPU; a recursive formulation might be faster on machines without a GPU.
import torch
class SoftTree(torch.nn.Module):
    """Soft binary decision tree with sigmoid gates.

    The tree has ``2**depth - 1`` internal (gating) nodes and ``2**depth``
    leaves, laid out in heap order: node ``k`` has children ``2k + 1`` and
    ``2k + 2``. Gate ``k`` is ``sigmoid(x @ gw[:, k] + gb[k])``; the left child
    receives ``gate`` times its parent's density and the right child receives
    ``1 - gate`` times it. The prediction is the mixture of per-leaf outputs
    weighted by the leaf densities:

    * ``projection='constant'``: each leaf holds a learned vector (``z``).
    * ``projection='linear'``: each leaf holds an affine map of ``x``
      (weights ``pw``, biases ``pb``).

    Args:
        in_features: size of each input sample (``x`` is ``(batch, in_features)``).
        out_features: size of each output sample.
        depth: number of gating levels (``>= 0``); ``depth=0`` is a single leaf.
        projection: ``'constant'`` or ``'linear'``.
        dropout: dropout rate applied to the gating weight matrix (not to the
            input) when computing node densities.

    Raises:
        ValueError: if ``depth`` is negative or ``projection`` is unknown.
    """

    def __init__(self, in_features, out_features, depth, projection='constant', dropout=0.0):
        super().__init__()
        if depth < 0:
            raise ValueError("depth must be non-negative, got %r" % (depth,))
        if projection not in ('linear', 'constant'):
            raise ValueError(
                "projection must be 'linear' or 'constant', got %r" % (projection,))
        self.proj = projection
        self.depth = depth
        self.in_features = in_features
        self.out_features = out_features
        self.leaf_count = int(2 ** depth)
        self.gate_count = int(self.leaf_count - 1)
        # Initialised as (gate_count, in_features) so Kaiming's default fan_in mode
        # uses in_features, then transposed so ``x @ gw`` gives one logit per gate.
        self.gw = torch.nn.Parameter(
            torch.nn.init.kaiming_normal_(
                torch.empty(self.gate_count, in_features), nonlinearity='sigmoid').t())
        self.gb = torch.nn.Parameter(torch.zeros(self.gate_count))
        # Dropout on the gating *weights* (used in node_densities only).
        self.drop = torch.nn.Dropout(p=dropout)
        if self.proj == 'linear':
            pw = torch.nn.init.kaiming_normal_(
                torch.empty(out_features * self.leaf_count, in_features),
                nonlinearity='linear')
            # Stored as (out_features, in_features, leaf_count).
            self.pw = torch.nn.Parameter(
                pw.view(out_features, self.leaf_count, in_features).permute(0, 2, 1))
            self.pb = torch.nn.Parameter(torch.zeros(out_features, self.leaf_count))
        else:  # 'constant'
            # TODO: find a better initialisation for the leaf values.
            self.z = torch.nn.Parameter(torch.randn(out_features, self.leaf_count))

    def forward(self, x):
        """Return the leaf-density-weighted prediction of shape (batch, out_features).

        Args:
            x: tensor of shape (batch, in_features).
        """
        # (leaf_count, batch): probability of reaching each leaf, per sample.
        leaf_probs = self.node_densities(x)[:, -self.leaf_count:].t()
        if self.proj == 'linear':
            # (out, in, leaf) @ (leaf, batch) -> (batch, out, in): per-sample mixed weights.
            gated_projection = torch.matmul(self.pw, leaf_probs).permute(2, 0, 1)
            # (out, leaf) @ (leaf, batch) -> (batch, out): per-sample mixed bias.
            gated_bias = torch.matmul(self.pb, leaf_probs).permute(1, 0)
            # reshape (not view) so non-contiguous inputs are accepted too.
            projected = torch.matmul(gated_projection, x.reshape(-1, self.in_features, 1))
            return projected[:, :, 0] + gated_bias
        if self.proj == 'constant':
            # (out, leaf) @ (leaf, batch) -> (batch, out).
            return torch.matmul(self.z, leaf_probs).permute(1, 0)
        raise ValueError("unknown projection %r" % (self.proj,))

    def extra_repr(self):
        return "in_features=%d, out_features=%d, depth=%d, projection=%s" % (
            self.in_features,
            self.out_features,
            self.depth,
            self.proj)

    def _gate_probs(self, x, weight):
        """Return sigmoid gate values of all internal nodes, shape (batch, gate_count)."""
        return torch.sigmoid(torch.matmul(x, weight) + self.gb)

    def node_densities(self, x):
        """Return the probability of reaching every node, shape (batch, 2**(depth+1) - 1).

        Columns are in heap order: column 0 is the root (always 1) and the last
        ``leaf_count`` columns are the leaves. Dropout (if any) is applied to the
        gating weights here. The tree is expanded one level at a time, so the
        Python loop runs ``depth`` times rather than once per node, and the
        result inherits the dtype/device of the gate values.
        """
        gatings = self._gate_probs(x, self.drop(self.gw))
        levels = [gatings.new_ones((x.shape[0], 1))]  # the root
        for d in range(self.depth):
            parents = levels[-1]                     # (batch, 2**d): nodes on level d
            first = 2 ** d - 1                       # heap index of level d's first node
            gate = gatings[:, first:first + 2 ** d]  # gates of the level-d nodes
            # Interleave children so node k -> 2k+1 (gate) and 2k+2 (1 - gate).
            children = torch.stack((parents * gate, parents * (1 - gate)), dim=2)
            levels.append(children.flatten(1))
        return torch.cat(levels, dim=1)

    def gatings(self, x):
        """Return gate values (batch, gate_count) without tracking gradients.

        Unlike ``node_densities``, dropout is not applied to the weights here.
        """
        with torch.no_grad():
            return self._gate_probs(x, self.gw)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment