
@unixpickle
Created August 23, 2017 01:29
Cosine tracker for treeagent
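The comment in `Quality` points to a separate gist for its closed form. The derivation below is consistent with the code. It assumes that a split approximates each sample's gradient by the mean gradient on that sample's side, which is what a least-squares tree leaf predicts. It also assumes `leftSquares`/`rightSquares` hold the sums of squared gradient entries on each side. Notation: $g_i$ is sample $i$'s gradient, $L$ and $R$ are the two sides, $n_L = |L|$, $S_L = \sum_{i\in L} g_i$, $Q_L = \sum_{i\in L}\|g_i\|^2$, and $a_i = S_L/n_L$ for $i \in L$ (likewise for $R$).

$$
\begin{aligned}
\langle g, a\rangle &= \sum_{i\in L} g_i\cdot\frac{S_L}{n_L} + \sum_{i\in R} g_i\cdot\frac{S_R}{n_R} = \frac{\|S_L\|^2}{n_L} + \frac{\|S_R\|^2}{n_R} \equiv N,\\
\|a\|^2 &= n_L\left\|\frac{S_L}{n_L}\right\|^2 + n_R\left\|\frac{S_R}{n_R}\right\|^2 = N,\\
\|g\|^2 &= Q_L + Q_R,\\
\cos\theta &= \frac{\langle g, a\rangle}{\|g\|\,\|a\|} = \frac{N}{\sqrt{(Q_L+Q_R)\,N}}.
\end{aligned}
$$

Here $N$ is the value the code computes as `numerator`, and $(Q_L+Q_R)\,N$ is the squared denominator. Because the same $N$ appears in both places, the score also equals $\sqrt{N/(Q_L+Q_R)}$ whenever it is defined.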
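```go
import "math"

// cosineTracker is a splitTracker for CosineAlgorithm.
type cosineTracker struct {
	mseTracker
}

// Quality returns cos(theta) for the current split, where theta is the
// angle between the actual gradient and the approximated gradient.
// It returns math.Inf(-1) when the angle is undefined because either
// vector is zero.
func (c *cosineTracker) Quality() float64 {
	sums := []smallVec{c.sumTracker.leftSum, c.sumTracker.rightSum}
	sqSums := []float64{c.leftSquares, c.rightSquares}
	counts := []float64{float64(c.leftCount), float64(c.rightCount)}

	// This is a closed form solution for cos(theta), where
	// theta is the angle between the actual gradient and
	// the approximated gradient.
	//
	// See https://gist.github.com/unixpickle/e002210247344aaa025a0601bf355bd8.
	//
	// numerator is sum over sides of |sum|^2/count. It is both the dot
	// product <actual, approx> and the squared norm |approx|^2.
	var numerator float64
	for i, sum := range sums {
		// An empty side contributes nothing to either term; skipping it
		// avoids computing 0/0 (NaN) when a count is zero.
		if counts[i] == 0 {
			continue
		}
		numerator += sum.Dot(sum) / counts[i]
	}

	// |actual|^2 is the total sum of squares across both sides.
	denom := math.Sqrt((sqSums[0] + sqSums[1]) * numerator)
	if denom == 0 {
		return math.Inf(-1)
	}
	return numerator / denom
}
```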