Skip to content

Instantly share code, notes, and snippets.

@alper111
Created September 20, 2019 11:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alper111/50c461b84a1f46079658d9f1e7811d00 to your computer and use it in GitHub Desktop.
Save alper111/50c461b84a1f46079658d9f1e7811d00 to your computer and use it in GitHub Desktop.
PyTorch implementation of a soft decision tree node. This recursive version is more readable; however, the tensorized version is faster.
import torch
class SoftNode(torch.nn.Module):
    """A node of a soft (differentiable) binary decision tree.

    Internal nodes (``depth > 0``) hold a linear gating unit and two child
    subtrees of depth ``depth - 1``; their output is a sigmoid-weighted
    mixture of the two children's outputs. Leaf nodes (``depth <= 0``) hold
    an "expert" that produces the prediction.

    Args:
        in_features: Size of the last dimension of the input ``x``.
        out_features: Size of the last dimension of the output.
        depth: Number of gating levels below this node; ``depth <= 0``
            makes this node a leaf.
        projection: Leaf expert type. ``"constant"`` learns a vector of size
            ``out_features`` that does not depend on ``x``; ``"linear"``
            applies ``torch.nn.Linear(in_features, out_features)`` to ``x``.

    Raises:
        ValueError: If ``projection`` is not ``"constant"`` or ``"linear"``.
    """

    _PROJECTIONS = ("constant", "linear")

    def __init__(self, in_features, out_features, depth, projection="constant"):
        super().__init__()
        # Validate up front: an unknown projection would otherwise leave the
        # leaves without an ``expert`` attribute, failing only later inside
        # forward() with a confusing AttributeError.
        if projection not in self._PROJECTIONS:
            raise ValueError(
                "projection must be one of {}, got {!r}".format(
                    self._PROJECTIONS, projection
                )
            )
        self.projection = projection
        if depth > 0:
            self.left = SoftNode(in_features, out_features, depth - 1, projection=projection)
            self.right = SoftNode(in_features, out_features, depth - 1, projection=projection)
            # One logit per input: routing weight towards the left child.
            self.gating = torch.nn.Linear(in_features, 1)
            self.expert = None
        else:
            self.left = None
            self.right = None
            self.gating = None
            if projection == "constant":
                self.expert = torch.nn.Parameter(torch.randn(out_features))
            else:  # "linear" (guaranteed by the validation above)
                self.expert = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Evaluate the subtree rooted at this node on ``x``.

        Args:
            x: Input tensor whose last dimension is ``in_features``; any
                leading (batch) dimensions are preserved.

        Returns:
            Tensor of shape ``(*x.shape[:-1], out_features)``.
        """
        if self.gating is not None:
            # g has shape (..., 1) and broadcasts against the children's
            # (..., out_features) outputs.
            g = torch.sigmoid(self.gating(x))
            return g * self.left(x) + (1 - g) * self.right(x)
        if self.projection == "constant":
            # Broadcast the learned vector over x's leading dimensions so a
            # constant leaf has the same output shape as a linear leaf or an
            # internal node (previously a depth-0 tree ignored the batch).
            return self.expert.expand(*x.shape[:-1], -1)
        return self.expert(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment