This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Expand represents the operation of exploring the | |
-- current active set of nodes. | |
type Expand a = | |
Seq (Node a) -> Node a -> Node a -> Seq (Node a) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package decisiontrees | |
import ( | |
"code.google.com/p/goprotobuf/proto" | |
pb "github.com/ajtulloch/decisiontrees/protobufs" | |
"github.com/golang/glog" | |
"time" | |
) | |
type boostingTreeGenerator struct { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func (b *boostingTreeGenerator) doBoostingRound(e Examples, round int) { | |
startTime := time.Now() | |
defer func() { | |
glog.Infof("Round %v, duration %v", round, time.Now().Sub(startTime)) | |
}() | |
if b.forestConfig.GetStochasticityConfig() != nil { | |
e = e.subsampleExamples(b.forestConfig.GetStochasticityConfig().GetPerRoundSamplingRate()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
Example = namedtuple('Example', ['features', 'label']) | |
def loss(pairs): | |
""" | |
L^2 loss - sum of squared divergence of label from average label | |
""" | |
if not pairs: | |
return 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module MachineLearning.DecisionTrees | |
(LossFunction(..), | |
Examples, | |
trainBoosting, | |
predictForest) where | |
import Data.Function (on) | |
import Data.List (and, sortBy) | |
import Data.Maybe (fromJust, | |
isJust) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
buildTreeAtLevel :: (Examples -> Double) -> PB.SplittingConstraints -> Int -> Examples -> DecisionTree | |
buildTreeAtLevel leafWeight splittingConstraints level examples = | |
if shouldSplit splittingConstraints level examples bestSplit | |
then Branch { | |
_feature=_splitFeature bestSplit | |
, _value=_splitValue bestSplit | |
, _left=recur $ V.takeWhile takePredicate orderedExamples | |
, _right=recur $ V.dropWhile takePredicate orderedExamples | |
} | |
else Leaf (leafWeight examples) where |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func (r *randomForestGenerator) ConstructForest(e Examples) *pb.Forest { | |
result := &pb.Forest{ | |
Trees: make([]*pb.TreeNode, int(r.forestConfig.GetNumWeakLearners())), | |
Rescaling: pb.Rescaling_AVERAGING.Enum(), | |
} | |
wg := sync.WaitGroup{} | |
for i := 0; i < int(r.forestConfig.GetNumWeakLearners()); i++ { | |
wg.Add(1) | |
go func(i int) { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func (b *boostingTreeGenerator) doInfluenceTrimming(e Examples) Examples { | |
lossFunction := b.getLossFunction() | |
by(func(e1, e2 *pb.Example) bool { | |
return lossFunction.GetSampleImportance(e1) < lossFunction.GetSampleImportance(e2) | |
}).Sort(e) | |
// Find cutoff point | |
weightSum := 0.0 | |
for _, ex := range e { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func (b *boostingTreeGenerator) doBoostingRound(e Examples, round int) { | |
startTime := time.Now() | |
defer func() { | |
glog.Infof("Round %v, duration %v", round, time.Now().Sub(startTime)) | |
}() | |
if b.forestConfig.GetStochasticityConfig() != nil { | |
e = e.subsampleExamples(b.forestConfig.GetStochasticityConfig().GetPerRoundSamplingRate()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// lossState is the accumulator that constructLoss allocates empty and then
// fills by calling addExample once per example in the input set.
// NOTE(review): the field meanings are inferred from their names only
// (mean label, sum of squared deviations from that mean, example count);
// addExample is not visible here, so confirm against its implementation.
type lossState struct { | |
averageLabel float64 | |
sumSquaredDivergence float64 | |
numExamples int | |
} |
func constructLoss(e Examples) *lossState { | |
l := &lossState{} | |
for _, ex := range e { | |
l.addExample(ex) |