Created
March 31, 2015 05:28
-
-
Save mathias-brandewinder/d3daebd687f2095de1b1 to your computer and use it in GitHub Desktop.
Conversion to F# of "Gradient Descent Training Using C#"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is a conversion to F# of the C# code presented in
// MSDN Magazine, March 2015, by James McCaffrey:
// https://msdn.microsoft.com/en-us/magazine/dn913188.aspx
open System | |
/// Sum of the element-wise products of two float vectors (dot product).
/// When the arrays differ in length, the surplus trailing elements of the
/// longer one are ignored.
let sumprod (v1:float[]) (v2:float[]) =
    Seq.fold2 (fun total a b -> total + a * b) 0.0 v1 v2
/// Logistic sigmoid function: 1 / (1 + e^(-z)).
let sigmoid z =
    let denominator = 1.0 + exp (-z)
    1.0 / denominator
/// Output of the logistic model for one feature vector: the sigmoid of the
/// weighted sum of the features.
let logistic (weights:float[]) (features:float[]) =
    sigmoid (sumprod weights features)
/// Builds a synthetic data set for binary classification.
/// Draws numFeatures + 1 random weights, then numRows examples. Each feature
/// vector starts with a constant 1.0 (so the first weight acts as the
/// intercept) followed by numFeatures random values; every random value comes
/// from low + (high - low) * NextDouble() with the bounds -10.0 and 10.0.
/// An example gets the value 1.0 when the logistic output under the generated
/// weights exceeds 0.55, and 0.0 otherwise.
/// Returns the generated weights together with the (features, value) pairs.
let makeAllData (numFeatures, numRows, seed) =
    let rnd = Random(seed)
    let uniform (lo, hi) = lo + (hi - lo) * rnd.NextDouble()
    // The weights are drawn before any row so that, for a given seed, the
    // sequence of random draws (weights first, then row by row) is fixed.
    let weights = Array.init (numFeatures + 1) (fun _ -> uniform (-10.0, 10.0))
    let makeRow () =
        let features =
            Array.append
                [| 1.0 |]
                [| for _ in 1 .. numFeatures -> uniform (-10.0, 10.0) |]
        let value = if logistic weights features > 0.55 then 1.0 else 0.0
        features, value
    let dataset = [| for _ in 1 .. numRows -> makeRow () |]
    weights, dataset
/// Returns a randomly permuted copy of data; the input array is not modified.
/// Walks the copy from front to back, swapping each slot with a slot chosen
/// at random from the not-yet-fixed tail (the slot itself included).
let shuffle (rng:Random) (data:_[]) =
    let shuffled = Array.copy data
    let lastIndex = shuffled.Length - 1
    for i = 0 to lastIndex do
        // Random.Next(min, max) excludes max, so j is in i .. Length - 1
        let j = rng.Next(i, shuffled.Length)
        let picked = shuffled.[j]
        shuffled.[j] <- shuffled.[i]
        shuffled.[i] <- picked
    shuffled
/// Shuffles allData with a generator seeded by seed and splits the result
/// into (train, test): the first 80% of the rows (rounded down) go to the
/// training set, the remainder to the test set.
let makeTrainTest (allData:_[], seed) =
    let trainFraction = 0.80 // 80% hard-coded
    let numTrainRows = int (float allData.Length * trainFraction)
    allData
    |> shuffle (Random(seed))
    |> Array.splitAt numTrainRows
/// One example: a feature vector paired with its target value.
type Example = array<float> * float
/// Mean squared error of the logistic model with the supplied weights,
/// averaged over every example in trainData.
let Error (trainData:Example[], weights:float[]) =
    let squaredError (features, desired) =
        let residual = logistic weights features - desired
        residual * residual
    Array.averageBy squaredError trainData
/// Trains a logistic regression classifier by gradient descent, updating the
/// weights after every single example.
///
/// trainData   - examples to learn from; each feature vector is indexed once
///               per weight, so it needs numFeatures + 1 entries.
/// numFeatures - the weight vector has numFeatures + 1 entries, all starting
///               at 0.
/// maxEpochs   - number of passes over the data; zero or a negative value
///               returns the untrained (all-zero) weights.
/// alpha       - learning rate applied to every update.
/// seed        - seed of the generator used to reshuffle the data each epoch.
///
/// Prints the mean squared error every 100 epochs. Returns a function that
/// maps a feature vector to 1.0 when the model output exceeds 0.5 and to 0.0
/// otherwise.
let Train (trainData:Example[], numFeatures, maxEpochs, alpha, seed) =
    let rng = Random(seed)
    // One gradient step for a single example; returns a fresh weight array.
    let update (example:Example) (weights:float[]) =
        let features, target = example
        let computed = logistic weights features
        // The same factor applies to every weight, so compute it once.
        let step = alpha * (target - computed)
        weights
        |> Array.mapi (fun i w -> w + step * features.[i])
    let rec updateWeights (data:Example[]) epoch weights =
        if epoch % 100 = 0 then
            printfn "Epoch: %i, Error: %.2f" epoch (Error (data, weights))
        // ">=" rather than "=": with "=", a negative maxEpochs never matches
        // and the recursion does not stop.
        if epoch >= maxEpochs then weights
        else
            // Visit the examples in a new random order on every pass.
            let data = shuffle rng data
            let weights =
                data
                |> Array.fold (fun w example -> update example w) weights
            updateWeights data (epoch + 1) weights
    // initialize the weights and start the recursive update
    let initialWeights = [| for _ in 1 .. numFeatures + 1 -> 0. |]
    let finalWeights = updateWeights trainData 0 initialWeights
    let classifier (features:float[]) =
        if logistic finalWeights features > 0.5 then 1. else 0.
    classifier
// Demo script: generate synthetic data, split it, then train the classifier.
printfn "Begin Logistic Regression (binary) Classification demo"
printfn "Goal is to demonstrate training using gradient descent"

// synthetic data settings
let numFeatures = 8
let numRows = 10000
let seed = 1

printfn "Generating %i artificial data items with %i features" numRows numFeatures
let trueWeights, allData = makeAllData (numFeatures, numRows, seed)

printfn "Data generation weights:"
for weight in trueWeights do
    printf "%.2f " weight
printfn ""

printfn "Creating train (80%%) and test (20%%) matrices"
let trainData, testData = makeTrainTest (allData, 0)
printfn "Done"

// training settings
let maxEpochs = 1000
let alpha = 0.01
let classifier = Train (trainData, numFeatures, maxEpochs, alpha, 0)
/// Fraction of the examples for which the trained classifier's output equals
/// the example's value.
let accuracy (examples:Example[]) =
    let score (features, expected) =
        if classifier features = expected then 1. else 0.
    examples
    |> Array.map score
    |> Array.average
// Report the accuracy on the training rows, then on the test rows.
printfn "Prediction accuracy on train data: %.4f" (accuracy trainData)
printfn "Prediction accuracy on test data: %.4f" (accuracy testData)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment