Skip to content

Instantly share code, notes, and snippets.

@ahancock1
Created October 10, 2020 14:05
Show Gist options
  • Save ahancock1/31d3778b4fb826eb79e6c111336bb798 to your computer and use it in GitHub Desktop.
Save ahancock1/31d3778b4fb826eb79e6c111336bb798 to your computer and use it in GitHub Desktop.
Half-Space Trees (HST) classifier: a .NET C# implementation
/// <summary>
/// Tunable parameters consumed by <see cref="Classifier"/>.
/// </summary>
public class ClassifierSettings
{
    /// <summary>
    /// Number of fitted samples that make up one window; the classifier
    /// rolls its mass counters over every time this many samples have been seen.
    /// </summary>
    public int WindowSize { get; set; } = 250;

    /// <summary>Number of trees held in the ensemble.</summary>
    public int Estimators { get; set; } = 25;

    /// <summary>Depth at which tree construction stops and a leaf is emitted.</summary>
    public int MaxDepth { get; set; } = 15;

    /// <summary>
    /// Number of feature dimensions the trees may split on. The default of 0
    /// leaves no dimension to choose from, so callers are expected to set it.
    /// </summary>
    public int Features { get; set; } = 0;

    /// <summary>Lower bound of the initial range assigned to every feature.</summary>
    public double MinLimit { get; set; } = 0.0;

    /// <summary>Upper bound of the initial range assigned to every feature.</summary>
    public double MaxLimit { get; set; } = 1.0;
}
/// <summary>
/// Streaming anomaly scorer backed by an ensemble of randomly generated
/// binary trees. Each internal node splits one feature at a random value;
/// every node keeps two mass counters that are updated by <see cref="Fit"/>
/// and read by <see cref="Score"/>.
/// </summary>
public class Classifier
{
    // One root per estimator; populated lazily on the first call to Fit.
    private readonly Node[] _nodes;

    // Fixed seed, so the generated trees are reproducible between runs.
    private readonly Random _random = new Random(42);

    private readonly ClassifierSettings _settings;

    // Number of samples passed to Fit so far.
    private int _count;

    /// <summary>
    /// Creates a classifier; the trees themselves are built on the first <see cref="Fit"/>.
    /// </summary>
    /// <param name="settings">Ensemble configuration. Must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="settings"/> is null.</exception>
    public Classifier(ClassifierSettings settings)
    {
        if (settings == null)
        {
            throw new ArgumentNullException(nameof(settings));
        }

        _settings = settings;
        _nodes = new Node[_settings.Estimators];
    }

    /// <summary>
    /// Recursively builds a subtree rooted at <paramref name="depth"/>.
    /// </summary>
    /// <param name="limits">
    /// Per-feature working ranges. They are narrowed while descending and are
    /// restored to their entry state before this method returns.
    /// </param>
    /// <param name="depth">Depth of the node being created (root is 0).</param>
    /// <returns>The root of the newly built subtree.</returns>
    private Node Build(Limit[] limits, int depth)
    {
        if (depth == _settings.MaxDepth)
        {
            return new Node
            {
                Type = NodeType.External,
                Depth = depth
            };
        }

        // Pick the split feature, weighting each one by the width of its current range.
        var feature = _random.NextChoice(
            Enumerable.Range(0, _settings.Features)
                .ToDictionary(i => i, i => limits[i].Max - limits[i].Min));

        var limit = limits[feature];
        var min = limit.Min;
        var max = limit.Max;

        // Keep the split point away from the edges of the range.
        const double padding = 0.15d;
        var value = _random.NextDouble(
            min + padding * (max - min),
            max - padding * (max - min));

        // Left half-space: [min, value].
        limit.Max = value;
        var left = Build(limits, depth + 1);

        // Right half-space: [value, max].
        limit.Max = max;
        limit.Min = value;
        var right = Build(limits, depth + 1);

        // Restore the shared range. Limit instances are shared across the whole
        // recursion, so leaving Min at the split value would shrink the range
        // seen by every subtree built after this call returns.
        limit.Min = min;

        return new Node
        {
            Left = left,
            Right = right,
            Type = NodeType.Internal,
            Depth = depth,
            Feature = feature,
            Value = value
        };
    }

    /// <summary>
    /// Builds every tree of the ensemble, each starting from the full
    /// [MinLimit, MaxLimit] range on every feature.
    /// </summary>
    private void Initialise()
    {
        var features = _settings.Features;

        for (var i = 0; i < _settings.Estimators; i++)
        {
            var limits = Enumerable.Range(0, features)
                .Select(_ =>
                    new Limit
                    {
                        Min = _settings.MinLimit,
                        Max = _settings.MaxLimit
                    })
                .ToArray();

            _nodes[i] = Build(limits, 0);
        }
    }

    /// <summary>
    /// Feeds one sample into the ensemble, updating the mass counters of every
    /// node on the sample's path in each tree. The trees are built on the first call.
    /// </summary>
    /// <param name="data">Sample whose <c>Features</c> are indexed by node feature.</param>
    public void Fit(IData data)
    {
        if (_count == 0)
        {
            Initialise();
        }

        foreach (var root in _nodes)
        {
            foreach (var node in Path(root, n =>
                data.Features[n.Feature] < n.Value))
            {
                // During the first window the sample also feeds the profile
                // that Score reads, so scoring works before the first rollover.
                if (_count < _settings.WindowSize)
                {
                    node.RightMass++;
                }

                node.LeftMass++;
            }
        }

        _count++;

        // Window complete: the freshly collected profile becomes the one
        // used for scoring and collection starts again from zero.
        if (_count % _settings.WindowSize == 0)
        {
            foreach (var root in _nodes)
            {
                Iterate(root, n =>
                {
                    n.RightMass = n.LeftMass;
                    n.LeftMass = 0;
                });
            }
        }
    }

    /// <summary>
    /// Walks from <paramref name="node"/> downwards, yielding each visited node.
    /// </summary>
    /// <param name="node">Starting node; a null start yields nothing.</param>
    /// <param name="evaluator">Returns true to descend left, false to descend right.</param>
    private IEnumerable<Node> Path(Node node, Func<Node, bool> evaluator)
    {
        while (node != null)
        {
            yield return node;
            node = evaluator.Invoke(node) ? node.Left : node.Right;
        }
    }

    /// <summary>
    /// Applies <paramref name="action"/> to <paramref name="node"/> and, for
    /// internal nodes, to every node beneath it (pre-order).
    /// </summary>
    private void Iterate(Node node, Action<Node> action)
    {
        action.Invoke(node);

        if (node.Type == NodeType.External)
        {
            return;
        }

        Iterate(node.Left, action);
        Iterate(node.Right, action);
    }

    /// <summary>
    /// Scores one sample against the current scoring profile.
    /// </summary>
    /// <param name="data">Sample whose <c>Features</c> are indexed by node feature.</param>
    /// <returns>
    /// <c>1 - score / max</c>, where <c>score</c> accumulates depth-weighted mass
    /// along the sample's path in every tree; less mass gives a larger result.
    /// Returns <see cref="double.NaN"/> when no sample has been fitted yet.
    /// </returns>
    public double Score(IData data)
    {
        // Nothing fitted: no trees exist and the normaliser would be zero.
        // Make the resulting NaN explicit instead of relying on 0/0.
        if (_count == 0)
        {
            return double.NaN;
        }

        var size = Math.Min(_settings.WindowSize, _count);
        var max = _settings.Estimators * size *
                  (Math.Pow(2, _settings.MaxDepth + 1) - 1);

        // Stop descending once a node holds less than 10% of a window.
        var limit = 0.1d * _settings.WindowSize;

        var score = 0d;
        foreach (var root in _nodes)
        {
            foreach (var node in Path(root, n =>
                data.Features[n.Feature] < n.Value))
            {
                score += node.RightMass * Math.Pow(2, node.Depth);

                if (node.RightMass < limit)
                {
                    break;
                }
            }
        }

        return 1 - score / max;
    }

    /// <summary>Mutable working range of a single feature during tree construction.</summary>
    private class Limit
    {
        public double Min { get; set; }
        public double Max { get; set; }
    }

    /// <summary>A tree node together with its two mass counters.</summary>
    private class Node
    {
        public NodeType Type { get; set; }
        public Node Left { get; set; }
        public Node Right { get; set; }

        // Mass collected during the current window. Despite the name it is
        // incremented for every node on a sample's path, whichever way the path turns.
        public int LeftMass { get; set; }

        // Mass profile read by Score; replaced by LeftMass at each window rollover.
        public int RightMass { get; set; }

        // Index of the feature this node splits on (internal nodes only).
        public int Feature { get; set; }

        // Split threshold: samples with a smaller feature value go left.
        public double Value { get; set; }

        public int Depth { get; set; }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment