Skip to content

Instantly share code, notes, and snippets.

@ptrelford
Last active December 19, 2015 11:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ptrelford/5946697 to your computer and use it in GitHub Desktop.
Save ptrelford/5946697 to your computer and use it in GitHub Desktop.
Decision Trees - port of Machine Learning in Action from Python to F#
open System.Collections.Generic
module internal Tuple =
    /// Boxes each component of an F# tuple into an obj array, in order, using
    /// reflection; e.g. (1, "a") becomes [| box 1; box "a" |].
    /// Raises ArgumentException (from GetTupleFields) if the input is not a tuple.
    let toArray (tuple: obj) : obj[] =
        Microsoft.FSharp.Reflection.FSharpValue.GetTupleFields tuple
module internal Array =
    /// Returns a new array holding xs without the element at index i.
    /// xs itself is left untouched; the result is built from the slice
    /// before i followed by the slice after i.
    let removeAt i (xs:'a[]) =
        let before = xs.[.. i - 1]
        let after = xs.[i + 1 ..]
        Array.append before after
/// Selects the rows of dataSet whose column `axis` equals `value`
/// (compared with F# `=` on obj), and returns each selected row with
/// that column removed. Row order is preserved.
let splitDataSet(dataSet:obj[][], axis, value) =
    dataSet
    |> Array.filter (fun row -> row.[axis] = value)
    |> Array.map (Array.removeAt axis)
/// Shannon entropy, in bits, of the class labels of dataSet, where the class
/// label of a row is its last element. Returns 0.0 for an empty dataSet.
let calcShannonEnt(dataSet:obj[][]) =
    let total = float dataSet.Length
    // contribution of one class with relative frequency p: -p * log2 p
    let bitsOf p = -p * log p / log 2.0
    dataSet
    |> Seq.countBy (fun row -> row.[row.Length - 1])
    |> Seq.sumBy (fun (_, n) -> bitsOf (float n / total))
/// Index of the feature column whose split yields the largest information gain:
/// entropy of dataSet minus the size-weighted entropy of the subsets produced by
/// splitting on each distinct value of that column. The last column is the class
/// label and is never considered. Ties go to the lowest index (List.maxBy keeps
/// the first maximum). Raises if dataSet is empty or its rows have no feature columns.
let chooseBestFeatureToSplit(dataSet:obj[][]) =
    let featureCount = dataSet.[0].Length - 1
    let baseEntropy = calcShannonEnt(dataSet)
    let total = float dataSet.Length
    let infoGain i =
        let weightedEntropy =
            dataSet
            |> Seq.map (fun row -> row.[i])
            |> Seq.distinct
            |> Seq.sumBy (fun v ->
                let subset = splitDataSet(dataSet, i, v)
                float subset.Length / total * calcShannonEnt(subset))
        baseEntropy - weightedEntropy
    [ for i in 0 .. featureCount - 1 -> i, infoGain i ]
    |> List.maxBy snd
    |> fst
/// Most frequent element of classList. When several values share the top count,
/// the one first seen in classList wins (stable sort over insertion order).
/// Values are compared with the default .NET equality of Dictionary<obj, int>.
let majorityCnt(classList:obj[]) =
    let tally = Dictionary<obj, int>()
    for label in classList do
        match tally.TryGetValue label with
        | true, n -> tally.[label] <- n + 1
        | false, _ -> tally.Add(label, 1)
    tally
    |> Seq.map (fun kv -> kv.Key, kv.Value)
    |> List.ofSeq
    |> List.sortBy (fun (_, n) -> -n)
    |> List.head
    |> fst
/// Name of a feature column.
type Label = string

/// A boxed feature value or class label.
type Value = obj

/// A decision tree. A Leaf carries a class label. A Branch names the feature it
/// tests and holds one (feature value, sub-tree) pair per value it can follow.
type Tree =
    | Leaf of Value
    | Branch of Label * (Value * Tree)[]
/// Recursively builds a decision tree from dataSet.
/// Each row holds feature values followed by its class label in the last column;
/// labels.[i] names feature column i (labels and columns are shrunk in lockstep).
/// Returns:
///  - Leaf(label) when every row has the same class label;
///  - Leaf(majority label) when only the class column is left;
///  - otherwise a Branch on the feature picked by chooseBestFeatureToSplit, with one
///    sub-tree per distinct value of that feature (in order of first appearance).
/// Raises IndexOutOfRangeException if dataSet is empty.
let rec createTree(dataSet:obj[][], labels:string[]) =
    let classList = [| for example in dataSet -> example.[example.Length - 1] |]
    if classList |> Seq.forall ((=) classList.[0]) then
        // all rows agree on the class: nothing left to split
        Leaf(classList.[0])
    elif dataSet.[0].Length = 1 then
        // only the class column remains: fall back to a majority vote
        Leaf(majorityCnt(classList))
    else
        let bestFeat = chooseBestFeatureToSplit(dataSet)
        let bestFeatLabel = labels.[bestFeat]
        // removeAt returns a fresh array and nothing below writes to it, so every
        // sub-tree can share this one array; no per-branch copy is needed.
        let subLabels = labels |> Array.removeAt bestFeat
        let subTrees =
            dataSet
            |> Seq.map (fun example -> example.[bestFeat])
            |> Seq.distinct
            |> Seq.map (fun value ->
                let split = splitDataSet(dataSet, bestFeat, value)
                value, createTree(split, subLabels))
            |> Seq.toArray
        Branch(bestFeatLabel, subTrees)
/// Walks inputTree and returns the class label it assigns to testVec.
/// featLabels.[i] names the feature stored at testVec.[i]; at each Branch the
/// feature is looked up by name and the first child whose value equals
/// (F# `=` on obj) the test vector's value is followed.
/// Raises KeyNotFoundException if a Branch's feature name is not in featLabels,
/// or if no child matches the test vector's value for that feature (e.g. a value
/// never seen while building the tree).
let rec classify(inputTree, featLabels:string[], testVec:obj[]) =
    match inputTree with
    | Leaf(x) -> x
    | Branch(s, xs) ->
        let featIndex =
            match featLabels |> Array.tryFindIndex ((=) s) with
            | Some i -> i
            | None ->
                raise (KeyNotFoundException(sprintf "Feature label '%s' is not present in featLabels" s))
        let featValue = testVec.[featIndex]
        match xs |> Array.tryFind (fun (value, _) -> featValue = value) with
        | Some(_, subTree) -> classify(subTree, featLabels, testVec)
        | None ->
            raise (KeyNotFoundException(sprintf "No branch of feature '%s' matches value %A" s featValue))
/// Toy data set: the columns are "no surfacing", "flippers" and the class label.
let myDat =
    [| (1, 1, "yes"); (1, 1, "yes"); (1, 0, "no"); (0, 1, "no"); (0, 1, "no") |]
    |> Array.map Tuple.toArray

/// Minimal self-check helper: fails with "Failed" when condition is false.
let Assert condition = if not condition then failwith "Failed"

let expected : obj[][] = [| [| 1; "yes" |]; [| 1; "yes" |]; [| 0; "no" |] |]
Assert (splitDataSet(myDat, 0, 1) = expected)
// 2 "yes" vs 3 "no": -(0.4*log2 0.4 + 0.6*log2 0.6) = 0.97095059445...
// A tolerance check pins the value; the previous round(x*100) comparison
// accepted anything in [0.965, 0.975).
Assert (abs (calcShannonEnt(myDat) - 0.9709505945) < 1e-9)
Assert (chooseBestFeatureToSplit(myDat) = 0)
let labels = [| "no surfacing"; "flippers" |]
let myTree = createTree(myDat, labels)
Assert (classify(myTree, labels, [| 1; 0 |]) = box "no")
Assert (classify(myTree, labels, [| 1; 1 |]) = box "yes")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment