// Source: GitHub Gist tixxit/987a9a3130b2b3fa914a by @tixxit, created December 21, 2015 21:59.
// (Surrounding GitHub page text such as "Skip to content" and "Show Gist options" removed.)
/**
 * One stage of a training pipeline over a forest of `Tree[K, V, T]`s.
 *
 * @tparam K feature key type (NOTE(review): inferred from `instance.features` / `Sampler[K]` usage
 *           in the companion object — confirm)
 * @tparam V feature value type (NOTE(review): inferred from `Splitter[V, T]` usage — confirm)
 * @tparam T target type stored in tree leaves
 * @tparam I input type consumed by this step
 * @tparam O output type produced by this step
 */
sealed trait Step[K, V, T, I, O] {

  /**
   * Sequences `that` after this step.
   *
   * `that` must consume this step's output type `O`; the composed step consumes `I` and produces
   * `O2`.
   *
   * @param that the step to run on this step's outputs
   * @return a [[Step.ComposedStep]] of `this` followed by `that`
   */
  def andThen[O2](that: Step[K, V, T, O, O2]): Step[K, V, T, I, O2] =
    // Members of the companion object are not in scope inside the trait body, so qualify the name.
    Step.ComposedStep(this, that)
}
/**
 * Building blocks for [[Step]] pipelines.
 *
 * Every primitive step is a [[Step.LeafSummer]]: an emit function of
 * `(treesByIndex, treeIndex, leafIndex, input)` producing zero or more outputs, paired with a
 * `Semigroup` for those outputs. Steps are chained with `andThen`, which builds a
 * [[Step.ComposedStep]].
 *
 * NOTE(review): the executor that runs these steps is not part of this file. The design assumes it
 * groups emitted outputs per (tree index, leaf index), combines them with the step's semigroup, and
 * feeds the result to the next step — confirm.
 *
 * NOTE(review): `Tree`, `Instance`, `Sampler`, `Stopper`, `Splitter`, `Evaluator`, `Semigroup`,
 * `Max`, `Order`, `SplitNode` and `LeafNode` are referenced but not defined here.
 */
object Step {

  /** Per-leaf emit function: `(treesByIndex, treeIndex, leafIndex, input) => outputs`. */
  type LeafSummerFunction[K, V, T, I, O] =
    (Map[Int, Tree[K, V, T]], Int, Int, I) => TraversableOnce[O]

  /** A complete training pipeline: consumes training instances and produces trees. */
  type TrainingStep[K, V, T] = Step[K, V, T, Instance[K, V, T], Tree[K, V, T]]

  /** A primitive step: an emit function (`apply`) plus the semigroup for its outputs. */
  trait LeafSummer[K, V, T, I, O] extends Step[K, V, T, I, O] {
    def semigroup: Semigroup[O]
    def apply(trees: Map[Int, Tree[K, V, T]], treeIndex: Int, leafIndex: Int, value: I): TraversableOnce[O]
  }

  /** `step1` followed by `step2`; `step1`'s output type `B` is `step2`'s input type. */
  final case class ComposedStep[K, V, T, A, B, C](step1: Step[K, V, T, A, B], step2: Step[K, V, T, B, C])
      extends Step[K, V, T, A, C]

  /**
   * Wraps an emit function and the implicit output semigroup as a [[LeafSummer]].
   *
   * @param f  the per-leaf emit function
   * @param ev semigroup for the emitted outputs; exposed as `semigroup`
   */
  def sumByLeaf[K, V, T, I, O](f: LeafSummerFunction[K, V, T, I, O])(implicit ev: Semigroup[O]): LeafSummer[K, V, T, I, O] =
    new LeafSummer[K, V, T, I, O] {
      def semigroup: Semigroup[O] = ev
      def apply(trees: Map[Int, Tree[K, V, T]], treeIndex: Int, leafIndex: Int, value: I): TraversableOnce[O] =
        f(trees, treeIndex, leafIndex, value)
    }

  /**
   * A [[LeafSummer]] whose input is a training `Instance`.
   *
   * The output semigroup is taken implicitly because [[sumByLeaf]] requires one; a generic `O`
   * has no `Semigroup[O]` in scope otherwise.
   */
  def fromInstance[K, V, T, O](f: LeafSummerFunction[K, V, T, Instance[K, V, T], O])(implicit ev: Semigroup[O]): LeafSummer[K, V, T, Instance[K, V, T], O] =
    sumByLeaf[K, V, T, Instance[K, V, T], O](f)

  /**
   * A [[LeafSummer]] that only needs the tree at the current index: it emits
   * `f(tree, leafIndex, input)`. When `trees` has no entry for the tree index, nothing is emitted.
   *
   * The output semigroup is taken implicitly because [[sumByLeaf]] requires one.
   */
  def fromTree[K, V, T, I, O](f: (Tree[K, V, T], Int, I) => TraversableOnce[O])(implicit ev: Semigroup[O]): LeafSummer[K, V, T, I, O] =
    sumByLeaf[K, V, T, I, O] { (trees, treeIndex, leafIndex, input) =>
      trees.get(treeIndex).iterator.flatMap(f(_, leafIndex, input))
    }

  /**
   * Recomputes leaf targets from the training instances.
   *
   * Stage 1 emits each instance's `target` `count` times, where `count` is
   * `sampler.timesInTrainingSet(...)` for the current tree (presumably the instance's sampling
   * multiplicity for that tree — confirm). Emitted targets are combined with the `Semigroup[T]`
   * from the context bound.
   *
   * Stage 2 sets the leaf's target to the stage-1 result for that leaf.
   */
  def updateTargets[K, V, T: Semigroup](sampler: Sampler[K]): TrainingStep[K, V, T] =
    // Explicit type arguments: the lambdas' parameter types depend on these type parameters.
    fromInstance[K, V, T, T] { (_, treeIndex, _, instance) =>
      val count = sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)
      List.fill(count)(instance.target)
    }.andThen(fromTree[K, V, T, T, Tree[K, V, T]] { (tree, leafIndex, t) =>
      // Imaginary, fast updateLeaf method.
      // NOTE(review): `fromTree` expects a TraversableOnce[Tree]; if `updateLeaf` returns a bare
      // Tree it must be wrapped (e.g. `Iterator.single(...)`). This stage also needs an implicit
      // Semigroup[Tree[K, V, T]] (e.g. from Tree's companion) — confirm one exists.
      tree.updateLeaf(leafIndex)(_.copy(target = t))
    })

  /**
   * Grows the trees by splitting leaves.
   *
   * Stage 1: for each instance sampled into a tree (`count > 0`) whose leaf `stopper.shouldSplit`
   * accepts, emits per-feature split statistics for the features `sampler.includeFeature` keeps.
   * Each statistic is scaled by `count` via `intTimes`.
   *
   * Stage 2: for each leaf, asks `splitter` for candidate splits per feature, scores them with
   * `evaluator`, and keeps the maximum by goodness.
   *
   * Stage 3: replaces the leaf with a `SplitNode` whose children are fresh leaves, one per
   * predicate of the winning split.
   *
   * NOTE(review): stage 1 needs an implicit Semigroup for the per-feature statistics map, and
   * stage 3 needs an implicit Semigroup[Tree[K, V, T]]; neither is visible here — confirm.
   */
  def expand[K, V, T](sampler: Sampler[K], stopper: Stopper[T], splitter: Splitter[V, T], evaluator: Evaluator[V, T]): TrainingStep[K, V, T] =
    // Lambda parameters are annotated so K, V, T can be inferred from the function literal.
    fromInstance { (trees: Map[Int, Tree[K, V, T]], treeIndex: Int, _: Int, instance: Instance[K, V, T]) =>
      for {
        // `.toList` makes the comprehension produce a collection (TraversableOnce) rather than an Option.
        tree <- trees.get(treeIndex).toList
        count = sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)
        if count > 0
        leaf <- tree.leafFor(instance.features)
        if stopper.shouldSplit(leaf.target)
      } yield {
        instance.features.collect { case (k, v) if sampler.includeFeature(k, treeIndex, leaf.index) =>
          k -> splitter.semigroup.intTimes(count, splitter.create(v, instance.target))
        }
      }
    }
      // NOTE(review): the untyped lambda parameters below may need explicit types (or explicit
      // type arguments to sumByLeaf/fromTree) for Scala 2 to infer them — confirm it compiles.
      .andThen(sumByLeaf { (trees, treeIndex, leafIndex, featureStats) =>
        for {
          tree <- trees.get(treeIndex).toList
          leaf <- tree.get(leafIndex).toList
          // flatMap (not map) so each generated element is a single Max candidate rather than a
          // per-feature collection of candidates.
          candidate <- featureStats.flatMap { case (feature, stats) =>
            splitter
              .split(leaf.target, stats)
              .map { rawSplit =>
                val (split, goodness) = evaluator.evaluate(rawSplit)
                Max((feature, split, goodness))
              }
          }
        } yield candidate
      } (Max.semigroup(Order.by(_._3))))
      .andThen(fromTree { (tree, leafIndex, maxSplit) =>
        // Replace the leaf with a split on the winning feature; each predicate gets a fresh leaf.
        val (feature, split, _) = maxSplit.get
        // NOTE(review): as in updateTargets, confirm `updateLeaf`'s return type satisfies the
        // TraversableOnce that `fromTree` expects.
        tree.updateLeaf(leafIndex) { _ =>
          // NOTE(review): every new child leaf is created with index 0; if leaf indices are used
          // as grouping keys (see `leafIndex` / `leaf.index` above), they may need renumbering.
          SplitNode(for {
            (pred, target) <- split.predicates
          } yield (feature, pred, LeafNode(0, target, ())))
        }
      })
}
// End of gist. (GitHub page footer "Sign up for free to join this conversation…" removed.)