Created
May 30, 2018 20:50
-
-
Save stites/5a1158a7435a71332b3ecdae41edfb2a to your computer and use it in GitHub Desktop.
hasktorch lenet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE AllowAmbiguousTypes #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} | |
module LeNet where | |
import Data.Function ((&)) | |
import GHC.Natural | |
import Numeric.Backprop | |
import Prelude as P | |
import Torch.Double as Torch | |
import qualified ReLU | |
import qualified Torch.Long as Ix | |
import qualified Torch.Double.NN.Conv2d as NN | |
import qualified Torch.Double.NN.Layers as NN | |
import qualified Torch.Double.NN.Activation as NN | |
-- | Trainable parameters of a LeNet-5 network, each held as a backprop
-- 'BVar' in the shared computation scope @s@.
--
-- The @16*5*5@ input width of 'fc1' matches the 16 channels x 5 x 5 map left
-- after two stages of 5x5 convolution (no padding) and 2x2/stride-2 pooling on
-- a 32x32 input, as performed by 'lenet'.
data LeNet s = LeNet
  { conv1 :: BVar s (Conv2d 1 6 5 5)      -- ^ 1 -> 6 channels, 5x5 kernel
  , conv2 :: BVar s (Conv2d 6 16 5 5)     -- ^ 6 -> 16 channels, 5x5 kernel
  , fc1   :: BVar s (Linear (16*5*5) 120) -- ^ flattened features -> 120
  , fc2   :: BVar s (Linear 120 84)       -- ^ 120 -> 84
  , fc3   :: BVar s (Linear 84 10)        -- ^ 84 -> 10 outputs
  }
-- | LeNet-5 forward pass: two convolutional stages ('lenetLayer'), a flatten,
-- and a three-layer fully connected head mapping a 1x32x32 input to 10 values.
--
-- ReLU follows the first two linear layers; the output of the last linear
-- layer is returned as-is (no activation or softmax is applied here).
lenet
  :: forall s
   . Reifies s W
  => Double                     -- ^ @lr@, handed only to the conv stages (presumably a learning rate)
  -> LeNet s                    -- ^ lenet architecture
  -> BVar s (Tensor '[1,32,32]) -- ^ input
  -> BVar s (Tensor '[10])      -- ^ output
lenet lr arch inp =
  ( linear (fc3 arch)            -- fully connected head (read bottom-up)
  . relu . linear (fc2 arch)
  . relu . linear (fc1 arch)
  . flattenBP
  . lenetLayer lr (conv2 arch)   -- convolutional feature extractor
  . lenetLayer lr (conv1 arch)
  ) inp
-- | One LeNet feature stage: 'NN.conv2dMM' with stride 1 and no padding,
-- then 'relu', then 'maxPooling2d' with a 2x2 kernel, stride 2 and no padding.
--
-- @lr@ is passed to the convolution only. NOTE(review): the trailing
-- @SBool 'True@ argument to 'maxPooling2d' looks like a ceil-mode flag —
-- confirm against its definition.
lenetLayer lr conv inp =
  maxPooling2d poolKernel poolStride poolPadding poolFlag
    (relu (NN.conv2dMM convStride convPadding lr conv inp))
  where
    convStride  = Step2d    :: Step2d 1 1
    convPadding = Padding2d :: Padding2d 0 0
    poolKernel  = Kernel2d  :: Kernel2d 2 2
    poolStride  = Step2d    :: Step2d 2 2
    poolPadding = Padding2d :: Padding2d 0 0
    poolFlag    = sing      :: SBool 'True
-------------------------------------------------------------------------------
-- Layer initialization
-- | Allocate a fully connected layer with @i@ inputs and @o@ outputs, with
-- parameters drawn by 'newLayerWithBias' using a fan-in of @i@.
newLinear :: forall o i . KnownNatDim2 i o => IO (Linear i o)
newLinear = fmap Linear (newLayerWithBias fanIn)
  where
    fanIn = natVal (Proxy :: Proxy i)
-- | Allocate a 2d convolution layer (@i@ input channels, @o@ output channels,
-- @kH@ x @kW@ kernel), with parameters drawn by 'newLayerWithBias' using a
-- fan-in of @i * kH * kW@.
newConv2d :: forall o i kH kW . KnownNatDim4 i o kH kW => IO (Conv2d i o kH kW)
newConv2d = fmap Conv2d (newLayerWithBias fanIn)
  where
    -- inputs feeding one output unit: every input channel times the kernel area
    inChannels = natVal (Proxy :: Proxy i)
    kernelArea = natVal (Proxy :: Proxy kH) * natVal (Proxy :: Proxy kW)
    fanIn      = inChannels * kernelArea
-- | Draw a fresh (weight, bias) pair with 'uniform', both bounded by
-- +/- 1 / sqrt n, where @n@ is the layer's fan-in as supplied by the caller.
--
-- A new generator is obtained from 'newRNG' on every call; the weight tensor
-- is drawn before the bias tensor.
newLayerWithBias :: Dimensions2 d d' => Natural -> IO (Tensor d, Tensor d')
newLayerWithBias n = do
  gen <- newRNG
  w <- uniform gen lo hi
  b <- uniform gen lo hi
  pure (w, b)
  where
    hi :: Double
    hi = P.recip (P.sqrt (fromIntegral n))
    lo = P.negate hi
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment