Skip to content

Instantly share code, notes, and snippets.

Last active January 5, 2018 16:02
Show Gist options
  • Save Tosainu/02784b8e1233158436e623633a2b50b5 to your computer and use it in GitHub Desktop.
Save Tosainu/02784b8e1233158436e623633a2b50b5 to your computer and use it in GitHub Desktop.
{-# LANGUAGE FlexibleContexts #-}
module Main where
import Control.Monad (replicateM)
import Graphics.Rendering.Chart.Backend.Cairo
import Graphics.Rendering.Chart.Easy hiding ((<.>))
import Graphics.Rendering.Chart.Gtk
import Numeric.GSL.Minimization
import Numeric.LinearAlgebra
import System.Random
-- | Sample a pair component-wise: the first coordinate is drawn from the
-- first components' range, the second from the second components' range,
-- threading the generator through both draws. 'random' does the same with
-- each component's unbounded 'random' instead of leaving the method
-- undefined.
--
-- NOTE(review): this is an orphan instance. Newer releases of the random
-- package appear to ship their own @Random (a, b)@ instance; if the version in
-- use does, this one is a duplicate and should be deleted — confirm.
instance (Random x, Random y) => Random (x, y) where
  randomR ((x1, y1), (x2, y2)) gen1 =
    let (x, gen2) = randomR (x1, x2) gen1
        (y, gen3) = randomR (y1, y2) gen2
    in ((x, y), gen3)
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x, y), gen3)
-- | The standard logistic function, @1 / (1 + exp (-x))@.
--
-- In floating point it saturates: for 'Double' it returns exactly 1 at
-- x = 1000 and exactly 0 at x = -1000.
sigmoid :: Floating a => a -> a
sigmoid x = recip (1 + exp (negate x))
-- | Apply 'sigmoid' to every element of an hmatrix 'Container'. In this
-- program it is applied to the vector produced by @x #> w@.
--
-- Despite the prime, this is NOT the derivative of 'sigmoid'. It is the
-- element-wise lift of 'sigmoid' via 'cmap'.
sigmoid' :: (Container c e, Floating e) => c e -> c e
sigmoid' elems = cmap sigmoid elems
-- | Fit a logistic-regression classifier to two synthetic 2-D point clouds
-- and plot the samples together with the learned decision boundary.
--
-- * Class C1 (target 1): 50 points drawn with 'randomR' from the box
--   x1 in [-3.75, -2.25], x2 in [2.25, 3.75].
-- * Class C2 (target 0): 25 points from x1 in [-1.75, -0.25],
--   x2 in [1.25, 2.75], plus 25 points from x1 in [2.25, 3.75],
--   x2 in [-3.75, -2.25].
--
-- The weight vector @w = [w0, w1, w2]@ (bias first) is fitted by GSL steepest
-- descent on the cross-entropy loss, starting from [1, 1, 1].
--
-- Output:
--
-- * prints the optimiser's path matrix (second result of 'minimizeVD');
-- * prints the final weights;
-- * writes the scatter plot and boundary line to @logistic.svg@.
main :: IO ()
main = do
  c1  <- sampleBox 50 ((-3.75, 2.25), (-2.25, 3.75))
  c21 <- sampleBox 25 ((-1.75, 1.25), (-0.25, 2.75))
  c22 <- sampleBox 25 ((2.25, -3.75), (3.75, -2.25))
  let c2 = c21 ++ c22
      xs = c1 ++ c2
      n  = length xs
      -- Design matrix: one row [1, x1, x2] per sample. The leading 1 pairs
      -- with the bias weight w0.
      mx = (n >< 3) (concatMap (\(x1, x2) -> [1.0, x1, x2]) xs)
      -- Targets in the same order as xs: 1 for C1 samples, 0 for C2 samples.
      mt = vector (replicate (length c1) 1.0 ++ replicate (length c2) 0.0)
      -- Predicted probabilities sigmoid (X w), one per sample.
      f x w = sigmoid' (x #> w)
      -- Cross-entropy loss  -(y . log p) - ((1 - y) . log (1 - p)).
      -- NOTE(review): if any prediction saturates to exactly 0 or 1, the
      -- logs become -Infinity and the loss becomes Infinity/NaN.
      ll x y w =
        let p = f x w
        in negate (y <.> log p) - ((1 - y) <.> log (1 - p))
      -- Gradient of that loss with respect to w:  X^T (p - y).
      dll x y w = tr x #> (f x w - y)
      -- minimizeVD arguments: precision 0.01, at most 3000 iterations,
      -- first trial step 0.001, tolerance 0.001.
      (mw, optPath) =
        minimizeVD SteepestDescent 1.0e-2 3000 1.0e-3 1.0e-3
          (ll mx mt) (dll mx mt) (vector [1, 1, 1])
      -- Decision boundary: sigmoid (w0 + w1*x1 + w2*x2) = 0.5 exactly when
      -- w0 + w1*x1 + w2*x2 = 0, i.e. x2 = -(w0 + w1*x1) / w2.
      -- Yields +/-Infinity or NaN when w2 == 0 (a vertical boundary).
      boundaryX2 x1 = negate ((mw ! 0) + (mw ! 1) * x1) / (mw ! 2)
      yhat = [(x1, boundaryX2 x1) | x1 <- [-5.0, 5.0]]
  print optPath
  putStrLn $ "w~ = " ++ show mw
  let fo = def { _fo_size = (480, 480)
               , _fo_format = SVG
               }
  toFile fo "logistic.svg" $ do
    layout_x_axis . laxis_generate .= scaledAxis def (-4, 4)
    layout_x_axis . laxis_title    .= "x1"
    layout_y_axis . laxis_generate .= scaledAxis def (-4, 4)
    layout_y_axis . laxis_title    .= "x2"
    setShapes [PointShapeCircle, PointShapeCircle]
    plot (points "C1" c1)
    plot (points "C2" c2)
    plot (line "y" [yhat])

-- | Draw @k@ points with 'randomR' from the axis-aligned box given by its
-- (lower-left, upper-right) corners, using the global generator. Relies on
-- the pair 'Random' instance to draw each coordinate from its own range.
sampleBox :: Int -> ((Double, Double), (Double, Double)) -> IO [(Double, Double)]
sampleBox k box = replicateM k (getStdRandom (randomR box))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment