Last active
October 16, 2016 05:36
-
-
Save lotz84/6c9076e110ff335939ea3e3085128d80 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes #-} | |
module Main where | |
import Control.Monad | |
import Control.Monad.IO.Class (MonadIO, liftIO) | |
import Control.Monad.Morph (hoist, generalize) | |
import Control.Monad.State (State, StateT) | |
import qualified Control.Monad.State as State | |
import Data.List (genericLength) | |
import Data.List.Extra (chunksOf) | |
import Data.Time | |
import Numeric.AD.Mode.Kahn | |
import System.Random.MWC | |
import System.Random.Shuffle | |
-- | Feature vector of a single sample, as a list of coordinates.
type Input = [Double] | |
-- | Model parameters (weights), kept as a flat list.
type Params = [Double] | |
-- | Target value of a single sample.
type Output = Double | |
-- | Scalar used while differentiating: an 'AD' wrapper around a
-- 'Kahn' 'Double' from "Numeric.AD.Mode.Kahn".
type Var s = AD s (Kahn Double) | |
-- | Convert any 'Real' value to a 'Double' by going through its
-- 'Rational' representation.
toDouble :: Real a => a -> Double
toDouble value = fromRational (toRational value)
-- | Generate @n@ labelled samples for a two-dimensional binary
-- classification problem.
--
-- Each input is a pair @[x1, x2]@ whose coordinates are drawn with
-- @uniformR (-1, 1)@; the label is @1@ when @x1 + x2 - 1 > 0@ and @0@
-- otherwise.
--
-- One generator is seeded once and shared by all samples, instead of
-- re-seeding a fresh generator from the system for every sample.
generateTrainingData :: Int -> IO [(Input, Output)]
generateTrainingData n =
  withSystemRandom . asGenIO $ \gen -> replicateM n (sample gen)
  where
    -- Draw one labelled point.
    sample :: GenIO -> IO (Input, Output)
    sample gen = do
      x1 <- uniformR (-1, 1) gen :: IO Double
      x2 <- uniformR (-1, 1) gen :: IO Double
      let label = if boundary x1 x2 > 0 then 1 else 0
      pure ([x1, x2], label)

    -- Signed side of the separating line x + y = 1.
    boundary :: Double -> Double -> Double
    boundary x y = x + y - 1
-- | Draw the three initial model parameters, each with
-- @uniformR (-1, 1)@, from a system-seeded generator.
createInitialParams :: IO Params
createInitialParams = withSystemRandom (asGenIO drawWeights)
  where
    drawWeights :: GenIO -> IO Params
    drawWeights gen = replicateM 3 (uniformR (-1, 1) gen)
-- | Logistic model: the sigmoid of @w1 * x1 + w2 * x2 + w3@.
--
-- Requires exactly three weights and two inputs. Any other shape is a
-- programming error and is reported with an explicit message rather
-- than an anonymous pattern-match failure.
predictor :: [Var s] -> Input -> Var s
predictor [w1, w2, w3] [x1, x2] = 1 / (1 + exp (-phi))
  where
    -- Inputs are constants with respect to differentiation, hence 'auto'.
    phi = w1 * auto x1 + w2 * auto x2 + w3
predictor ws xs =
  error $
    "predictor: expected 3 weights and 2 inputs, got "
      ++ show (length ws)
      ++ " weights and "
      ++ show (length xs)
      ++ " inputs"
-- | Cross-entropy loss of the 'predictor' on one sample:
-- @-t * log y - (1 - t) * log (1 - y)@.
--
-- The prediction is clamped to @[tiny, 1 - tiny]@ before taking
-- logarithms. Without the clamp, a saturated sigmoid (exactly 0 or 1 in
-- floating point) yields @0 * log 0 = NaN@ for a correctly classified
-- sample, which then poisons every sum and gradient it takes part in.
loss :: [Var s] -> (Input, Output) -> Var s
loss w (x, t) =
  negate target * log y + (target - 1) * log (1 - y)
  where
    target = auto t
    y = clamp (predictor w x)
    clamp = max tiny . min (1 - tiny)
    tiny = 1.0e-12
-- | Sum of the per-sample 'loss' over a whole data set, evaluated at
-- the given plain 'Double' parameters.
totalLoss :: [Double] -> [(Input, Output)] -> Double
totalLoss w td =
  let perSample = loss (map auto w)
   in toDouble (sum (map perSample td))
-- | Build a parameter update that makes one shuffled pass over the
-- training data in mini-batches.
--
-- The data is shuffled, split into chunks of @size@ samples, and the
-- supplied update function is applied once per chunk, in order, with
-- the parameters threaded from one chunk to the next. The objective
-- handed to the update function is the sum of 'loss' over the chunk.
miniBatch :: MonadIO m
          => Int                -- batch size
          -> [(Input, Output)]  -- training data
          -> ((forall s. [Var s] -> Var s) -> (Params -> m Params))
          -> Params -> m Params -- function that updates the parameters
miniBatch size trainingData updateWith w = do
  shuffled <- liftIO (shuffleM trainingData)
  foldM step w (chunksOf size shuffled)
  where
    -- Apply the optimizer once, using the loss of a single chunk.
    step params batch = updateWith (batchLoss batch) params

    -- Explicitly polymorphic in @s@ so it can be passed as the rank-2
    -- objective argument.
    batchLoss :: [(Input, Output)] -> [Var s] -> Var s
    batchLoss batch params = sum (map (loss params) batch)
-- | Element-wise addition, subtraction and multiplication of two lists.
-- The result is as long as the shorter argument.
(|+|), (|-|), (|*|) :: Num a => [a] -> [a] -> [a]
(|+|) = zipWith (+)
(|-|) = zipWith (-)
(|*|) = zipWith (*)

-- | Element-wise division of two lists. The result is as long as the
-- shorter argument.
--
-- Like '|*|', this carries an explicit signature: without one the
-- monomorphism restriction would pin the operator to a single element
-- type chosen by its first use in the module.
(|/|) :: Fractional a => [a] -> [a] -> [a]
(|/|) = zipWith (/)

infixl 6 |+|, |-|
infixl 7 |*|, |/|

-- | Multiply every element of a list by a scalar.
(*|) :: Num a => a -> [a] -> [a]
a *| xs = map (a *) xs

infixl 7 *|
-- | One step of plain gradient descent: move each parameter against its
-- gradient component, scaled by the learning rate @alpha@.
gradientDescent :: Double
                -> (forall s. [Var s] -> Var s) -> Params -> Params
gradientDescent alpha f ws = zipWith descend ws (grad f ws)
  where
    descend w g = w - alpha * g
-- | One step of gradient descent with momentum.
--
-- The state is the velocity vector. It is replaced by
-- @gamma * velocity + rate * gradient@ and the new velocity is then
-- subtracted from the parameters.
gdWithMomentum :: Double -> Double -- learning rate, momentum coefficient
               -> (forall s. [Var s] -> Var s) -> Params -> State [Double] Params
gdWithMomentum rate gamma f ws = do
  velocity <- State.gets nextVelocity
  State.put velocity
  pure (ws |-| velocity)
  where
    gradient = grad f ws
    nextVelocity vs = gamma *| vs |+| rate *| gradient
-- | One step of Nesterov's accelerated gradient.
--
-- Same velocity update as momentum, except that the gradient is taken
-- at the look-ahead point @ws - gamma * velocity@ instead of at @ws@.
nag :: Double -> Double -- learning rate, momentum coefficient
    -> (forall s. [Var s] -> Var s) -> Params -> State [Double] Params
nag rate gamma f ws = do
  previous <- State.get
  let lookahead = ws |-| gamma *| previous
      velocity = gamma *| previous |+| rate *| grad f lookahead
  State.put velocity
  pure (ws |-| velocity)
-- | One step of Adagrad.
--
-- The state accumulates the squared gradient per parameter. Each
-- parameter moves by @rate * g / sqrt (accumulated + eps)@.
adagrad :: Double -- learning rate
        -> (forall s. [Var s] -> Var s) -> Params -> State [Double] Params
adagrad rate f ws = do
  accumulated <- State.gets (\gs -> zipWith addSquare gs gradient)
  State.put accumulated
  pure (zipWith3 descend ws gradient accumulated)
  where
    gradient = grad f ws
    eps = 1.0e-8
    addSquare g v = g + v ^ 2
    descend w v g = w - rate * v / sqrt (g + eps)
-- | One step of Adadelta with decay rate @rho@.
--
-- The state holds, per parameter, a pair of running averages: the
-- squared update and the squared gradient. The step is the gradient
-- scaled by @sqrt ((avgSqUpdate + eps) / (avgSqGrad' + eps))@, where
-- @avgSqGrad'@ is the freshly decayed gradient average.
adadelta :: Double
         -> (forall s. [Var s] -> Var s) -> Params -> State [(Double, Double)] Params
adadelta rho f ws = do
  (avgSqUpdate, avgSqGrad) <- State.gets unzip
  let gradient = grad f ws
      keep = 1 - rho
      avgSqGrad' = rho *| avgSqGrad |+| keep *| map (^ 2) gradient
      scale = sqrt <$> ((avgSqUpdate |+| eps) |/| (avgSqGrad' |+| eps))
      step = gradient |*| scale
      avgSqUpdate' = rho *| avgSqUpdate |+| keep *| step |*| step
  State.put (zip avgSqUpdate' avgSqGrad')
  pure (ws |-| step)
  where
    eps = repeat 1.0e-8
-- | One step of RMSprop.
--
-- The state is a running average of the squared gradient, updated with
-- fixed weights 0.9 (old average) and 0.1 (new squared gradient). Each
-- parameter moves by @rate * g / sqrt (average + eps)@.
rmsprop :: Double
        -> (forall s. [Var s] -> Var s) -> Params -> State [Double] Params
rmsprop rate f ws = do
  average <- State.gets decayInto
  State.put average
  pure (ws |-| rate *| gradient |/| (sqrt <$> (average |+| eps)))
  where
    gradient = grad f ws
    decayInto gs = 0.9 *| gs |+| 0.1 *| map (^ 2) gradient
    eps = repeat 1.0e-8
-- | One step of Adam with learning rate @rate@ and decay rates @beta1@
-- (first moment) and @beta2@ (second moment).
--
-- The state is the step counter together with, per parameter, the
-- running first and second moment estimates of the gradient. Both
-- estimates are bias-corrected by @1 - beta ^ step@ before use, and
-- each parameter moves by @rate * mHat / (sqrt vHat + eps)@.
--
-- The denominator is parenthesised explicitly: @\<$\>@ binds more
-- loosely than '|+|', so @sqrt \<$\> vHat |+| eps@ would compute
-- @sqrt (vHat + eps)@ and make the effective epsilon @sqrt eps@.
adam :: Double -> Double -> Double
     -> (forall s. [Var s] -> Var s) -> Params -> State (Int, [(Double, Double)]) Params
adam rate beta1 beta2 f ws = do
  (count, (ms, ds)) <- State.gets (fmap unzip)
  let vs = grad f ws
      count' = count + 1
      -- Exponentially decayed first and second moment estimates.
      ms' = beta1 *| ms |+| (1 - beta1) *| vs
      ds' = beta2 *| ds |+| (1 - beta2) *| map (^ 2) vs
      -- Bias correction for the zero-initialised moments.
      mHat = ms' |/| repeat (1 - beta1 ^ count')
      vHat = ds' |/| repeat (1 - beta2 ^ count')
      deltas = rate *| mHat |/| ((sqrt <$> vHat) |+| eps)
  State.put (count', zip ms' ds')
  pure (ws |-| deltas)
  where
    eps = repeat 1.0e-8
-- | Train the same classifier with each optimizer in turn, starting
-- from the same initial parameters and training data, and print the
-- learned parameters, final total loss and elapsed time for each.
main :: IO ()
main = do
  trainingData <- generateTrainingData 1000
  initial <- createInitialParams
  putStrLn $ "Initial parameters: " ++ show initial
  putStrLn $ "Initial total loss: " ++ show (totalLoss initial trainingData)
  putStrLn ""

  experiment "バッチ勾配降下法" trainingData $ do
    let step = miniBatch (length trainingData) trainingData (\objective w -> pure (gradientDescent 0.01 objective w))
    epochs step initial

  experiment "確率的勾配降下法" trainingData $ do
    let step = miniBatch 1 trainingData (\objective w -> pure (gradientDescent 0.01 objective w))
    epochs step initial

  experiment "ミニバッチ勾配降下法" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> pure (gradientDescent 0.01 objective w))
    epochs step initial

  experiment "Momentum(慣性)" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (gdWithMomentum 0.01 0.9 objective w))
    State.evalStateT (epochs step initial) [0, 0, 0]

  experiment "Nesterovの加速勾配降下法" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (nag 0.01 0.9 objective w))
    State.evalStateT (epochs step initial) [0, 0, 0]

  experiment "Adagrad" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (adagrad 0.01 objective w))
    State.evalStateT (epochs step initial) [0, 0, 0]

  experiment "Adadelta" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (adadelta 0.95 objective w))
    State.evalStateT (epochs step initial) [(0, 0), (0, 0), (0, 0)]

  experiment "RMSprop" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (rmsprop 0.001 objective w))
    State.evalStateT (epochs step initial) [0, 0, 0]

  experiment "Adam" trainingData $ do
    let step = miniBatch 50 trainingData (\objective w -> hoist generalize (adam 0.001 0.9 0.999 objective w))
    State.evalStateT (epochs step initial) (0, [(0, 0), (0, 0), (0, 0)])
  where
    -- Number of epochs.
    epochCount :: Int
    epochCount = 100

    -- Chain the same per-epoch update 'epochCount' times.
    epochs :: Monad m => (Params -> m Params) -> Params -> m Params
    epochs step = foldr (>=>) pure (replicate epochCount step)

    -- Run one training action and report its result. The end timestamp
    -- is taken after the learned parameters are printed, so the
    -- reported time includes that output.
    experiment :: String -> [(Input, Output)] -> IO Params -> IO ()
    experiment title samples training = do
      mapM_ putStrLn [title, replicate 40 '-']
      start <- getCurrentTime
      learned <- training
      putStrLn ("Learned parameters: " ++ show learned)
      end <- getCurrentTime
      putStrLn ("Final total loss: " ++ show (totalLoss learned samples))
      putStrLn ("Time: " ++ show (diffUTCTime end start))
      putStrLn ""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name:                x-sgd
version:             0.1.0.0
build-type:          Simple
cabal-version:       >=1.10

executable app
  hs-source-dirs:      app
  main-is:             Main.hs
  -- -Wall reports incomplete patterns, name shadowing and missing
  -- top-level signatures at compile time.
  ghc-options:         -Wall -threaded -rtsopts -with-rtsopts=-N
  build-depends:       base
                     , time
                     , extra
                     , mtl
                     , mmorph
                     , mwc-random
                     , random-shuffle
                     , ad
  default-language:    Haskell2010
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Initial parameters: [-0.2109002859956821,-0.7472215390933672,0.1279927949794184] | |
Initial total loss: 851.9460485831221 | |
バッチ勾配降下法 | |
---------------------------------------- | |
Learned parameters: [7.341714362217305,7.004903775629043,-7.493096245550326] | |
Final total loss: 53.383585155878194 | |
Time: 6.31818s | |
確率的勾配降下法 | |
---------------------------------------- | |
Learned parameters: [7.258200613808301,6.92452611425834,-7.407920594577997] | |
Final total loss: 53.99831971478062 | |
Time: 3.785089s | |
ミニバッチ勾配降下法 | |
---------------------------------------- | |
Learned parameters: [7.26634189705081,6.9321004662588805,-7.399471377592351] | |
Final total loss: 53.98624370279099 | |
Time: 3.798387s | |
Momentum(慣性) | |
---------------------------------------- | |
Learned parameters: [15.927673309491524,15.500967470448142,-16.032998990400852] | |
Final total loss: 24.99592706328587 | |
Time: 3.849795s | |
Nesterovの加速勾配降下法 | |
---------------------------------------- | |
Learned parameters: [15.922508536026339,15.488071799588825,-15.908699325553314] | |
Final total loss: 25.08924809475623 | |
Time: 3.878567s | |
Adagrad | |
---------------------------------------- | |
Learned parameters: [0.4746285755128946,1.433774226124509e-2,-0.6531383867679561] | |
Final total loss: 470.1318101936902 | |
Time: 3.927045s | |
Adadelta | |
---------------------------------------- | |
Learned parameters: [0.5686946094452934,0.25433904327629,-1.0225396232571826] | |
Final total loss: 379.91595005544406 | |
Time: 3.92515s | |
RMSprop | |
---------------------------------------- | |
Learned parameters: [1.3814916523926128,1.102867413655655,-1.788353626452354] | |
Final total loss: 236.21422191680716 | |
Time: 3.956204s | |
Adam | |
---------------------------------------- | |
Learned parameters: [1.0630007467446652,0.7124753136414661,-1.34321667536241] | |
Final total loss: 296.40278151073596 | |
Time: 3.830224s |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment