// Logistic regression
// GitHub gist mathias-brandewinder/0c2313e2f90ac5642a9a, created July 16, 2014 21:08.
/// A dense vector of features or weights.
type Vec = float[]

/// Element-wise product of features X and weights W.
/// Array.map2 raises ArgumentException if the lengths differ.
let f (X:Vec) (W:Vec) =
    Array.map2 (fun xi wi -> xi * wi) X W

/// Logistic (sigmoid) function: 1 / (1 + e^(-z)).
let g z =
    let denominator = 1. + exp (-z)
    1. / denominator

/// Derivative of the sigmoid, written in terms of the sigmoid: g(z) * (1 - g(z)).
let dg z =
    let s = g z
    s * (1. - s)

/// Predicted probability: sigmoid of the dot product of X and W.
let h (X:Vec) (W:Vec) =
    f X W
    |> Array.fold (+) 0.
    |> g
/// One stochastic-gradient step of logistic regression for a single observation.
/// alpha: learning rate; W: current weights; X: features; y: observed target.
/// Returns a new weight vector (W itself is not mutated):
///   w_i + alpha * (y - h X W) * x_i
/// Raises ArgumentException (from Array.map2) if W and X differ in length.
let update (alpha:float) (W:Vec) (X:Vec) (y:float) =
    // The prediction depends only on the pre-update weights, so compute it once
    // rather than once per component inside the map (which was O(n^2)).
    let step = alpha * (y - h X W)
    Array.map2 (fun w x -> w + step * x) W X
/// Training set for the unweighted model: (label, feature vector) pairs.
let data =
    [ (1., [| 1.; 0.; 1. |]);
      (0., [| 0.; 1.; 1. |]);
      (1., [| 1.; 0.; 0. |]);
      (0., [| 0.; 0.; 0. |]);
      (1., [| 1.; 0.; 1. |]);
      (0., [| 0.; 1.; 0. |]) ]

/// Initial weights: one zero per feature.
let w0 = Array.create 3 0.
/// One pass of stochastic gradient descent over `data`, starting from weights `w`.
/// Observations are applied in list order with a fixed learning rate of 0.1.
let logistic data w =
    let rec pass weights remaining =
        match remaining with
        | [] -> weights
        | (y, x) :: rest -> pass (update 0.1 weights x y) rest
    pass w data
/// Weights after three SGD passes over `data`, starting from `w0`.
let w' = w0 |> logistic data |> logistic data |> logistic data

// Print each label next to the model's predicted probability.
// List.iter rather than List.map: printfn returns unit, so mapping would build a
// unit list that is implicitly discarded (warning FS0020) and echoed by FSI.
data |> List.iter (fun (y, x) -> printfn "Real %.2f, Pred %.2f" y (h x w'))
// Weighted version: each observation carries a pair (p, n) of positive/negative weights instead of a single 0/1 label.
/// One SGD step for a weighted observation (p, n): each component becomes
///   w_i + alpha * (p - (p + n) * h X W) * x_i
/// With (p, n) = (1., 0.) or (0., 1.) this equals `update` with y = 1. or y = 0.
/// Raises ArgumentException (from Array.map2) if W and X differ in length.
let upd (alpha:float) (W:Vec) (X:Vec) (p,n) =
    // The prediction uses only the pre-update weights: compute it once, not per component.
    let step = alpha * (p - (p + n) * h X W)
    Array.map2 (fun w x -> w + step * x) W X
/// Same observations as the unweighted set, encoded as ((p, n), features):
/// (1., 0.) for a positive label and (0., 1.) for a negative one.
let data' =
    [ ((1., 0.), [| 1.; 0.; 1. |]);
      ((0., 1.), [| 0.; 1.; 1. |]);
      ((1., 0.), [| 1.; 0.; 0. |]);
      ((0., 1.), [| 0.; 0.; 0. |]);
      ((1., 0.), [| 1.; 0.; 1. |]);
      ((0., 1.), [| 0.; 1.; 0. |]) ]
/// One pass of weighted SGD over `data` (a list of ((p, n), features)),
/// starting from weights `w`, with a fixed learning rate of 0.1.
let logreg data w =
    let step weights (counts, x) = upd 0.1 weights x counts
    List.fold step w data
/// One weighted SGD pass over data'.
let l2 = logreg data'

/// Weights after three passes over data', starting from three zeros.
let weights =
    let start = Array.create 3 0.
    start |> l2 |> l2 |> l2

// Observed positive fraction p / (p + n) next to the predicted probability.
for ((p, n), x) in data' do
    let observed = p / (p + n)
    printfn "Real %.2f, Pred %.2f" observed (h x weights)
// Check on aggregated counts: (p, n) totals over 4-element feature vectors whose first element is always 1.
/// Aggregated observations: ((p, n), features), where p + n is the quantity
/// reported below. Every feature vector starts with 1.
let data'' =
    [ ((20., 2.), [| 1.; 1.; 0.; 1. |]);
      ((5., 10.), [| 1.; 0.; 1.; 1. |]);
      ((50., 5.), [| 1.; 1.; 0.; 0. |]);
      ((3., 10.), [| 1.; 0.; 0.; 0. |]);
      ((10., 6.), [| 1.; 1.; 0.; 1. |]);
      ((3., 90.), [| 1.; 0.; 1.; 0. |]) ]
/// One weighted SGD pass over data''.
let l3 = logreg data''

/// Weights after 99 passes of l3 starting from four zeros.
/// NOTE(review): 99 (not 100) passes reproduces the earlier take-100/index-99
/// result; TODO confirm whether 100 passes was the intent.
let weights' =
    // Apply l3 `count` times directly. This avoids Seq.nth (obsolete since
    // FSharp.Core 4.0) and the extra, unused pass that Seq.unfold computed eagerly.
    let rec applyN count w =
        if count <= 0 then w else applyN (count - 1) (l3 w)
    applyN 99 [| 0.; 0.; 0.; 0. |]

data'' |> List.iter (fun ((p, n), x) ->
    printfn "Qty %i, Real %.2f, Pred %.2f" (int (p + n)) (p / (p + n)) (h x weights'))
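(*
Batch update: each step applies the weighted update to every observation from the same weights, then averages the resulting weight vectors.
*)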
/// A weighted observation: ((p, n), feature vector).
type Obs = (float * float) * Vec

/// Component-wise sum of two vectors (ArgumentException if lengths differ).
let sum (X1:Vec) (X2:Vec) =
    Array.map2 (+) X1 X2
/// Component-wise average of a non-empty sequence of equal-length vectors.
/// Raises ArgumentException if X is empty or the vectors differ in length.
let vecAvg (X:Vec seq) =
    // Materialise once: X may be a lazy pipeline (e.g. Seq.map over upd in batch),
    // and calling Seq.length followed by Seq.reduce would evaluate it twice.
    let vectors = Seq.toArray X
    if vectors.Length = 0 then
        invalidArg "X" "Cannot average an empty sequence of vectors."
    let len = float vectors.Length
    vectors
    |> Array.reduce sum
    |> Array.map (fun tot -> tot / len)
/// One batch step: apply `upd` to the same weights W once per observation,
/// then return the component-wise average of the resulting weight vectors.
let batch (alpha:float) (X:Obs seq) (W:Vec) =
    let stepFrom ((counts, features) : Obs) = upd alpha W features counts
    vecAvg (Seq.map stepFrom X)
/// Batch gradient descent on X starting from all-zero weights.
/// Returns the iterate at position (iters - 1), i.e. the weights after
/// (iters - 1) batch steps; iters = 1 returns the zero start vector.
/// The weight dimension is taken from the first observation's feature vector.
/// Raises ArgumentException if iters < 1 or X is empty.
let batchLog (alpha:float) (X:Obs seq) (iters:int) =
    if iters < 1 then invalidArg "iters" "iters must be at least 1."
    if Seq.isEmpty X then invalidArg "X" "At least one observation is required."
    // Size the start vector from the data instead of hard-coding 4 features.
    let dim = X |> Seq.head |> snd |> Array.length
    // Iterate directly. This avoids Seq.nth (obsolete since FSharp.Core 4.0) and
    // the extra, unused batch step Seq.unfold computed eagerly.
    let rec loop remaining w =
        if remaining = 0 then w else loop (remaining - 1) (batch alpha X w)
    loop (iters - 1) (Array.zeroCreate dim)
/// Batch-trained weights for data'' (learning rate 0.1, 10000 iterations).
let w''' = batchLog 0.1 data'' 10000

// Quantity, observed positive fraction and predicted probability per observation.
for ((p, n), x) in data'' do
    let qty = int (p + n)
    let observed = p / (p + n)
    printfn "Qty %i, Real %.2f, Pred %.2f" qty observed (h x w''')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment