Skip to content

Instantly share code, notes, and snippets.

@stites
Created May 30, 2018 20:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stites/5a1158a7435a71332b3ecdae41edfb2a to your computer and use it in GitHub Desktop.
Save stites/5a1158a7435a71332b3ecdae41edfb2a to your computer and use it in GitHub Desktop.
hasktorch lenet — a LeNet-5-style convolutional network built with hasktorch's statically-typed tensors and the backprop library.
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module LeNet where
import Data.Function ((&))
import GHC.Natural
import Numeric.Backprop
import Prelude as P
import Torch.Double as Torch
import qualified ReLU
import qualified Torch.Long as Ix
import qualified Torch.Double.NN.Conv2d as NN
import qualified Torch.Double.NN.Layers as NN
import qualified Torch.Double.NN.Activation as NN
-- | Parameters of a LeNet-5-style network. Every layer is wrapped in a
-- 'BVar' so the whole network can be differentiated with the backprop
-- library ('s' is backprop's Wengert-tape phantom parameter).
data LeNet s = LeNet
  { conv1 :: BVar s (Conv2d 1 6 5 5)      -- ^ first convolution: 1 input plane -> 6 planes, 5x5 kernel
  , conv2 :: BVar s (Conv2d 6 16 5 5)     -- ^ second convolution: 6 planes -> 16 planes, 5x5 kernel
  , fc1   :: BVar s (Linear (16*5*5) 120) -- ^ first fully connected layer over the flattened 16x5x5 features
  , fc2   :: BVar s (Linear 120 84)       -- ^ second fully connected layer
  , fc3   :: BVar s (Linear 84 10)        -- ^ output layer producing 10 class scores
  }
-- | Forward pass of LeNet: two convolution/pool stages, a flatten, then
-- three fully connected layers. The learning rate is threaded through to
-- each layer as hasktorch's NN combinators require.
lenet
  :: forall s
   . Reifies s W
  => Double                     -- ^ learning rate, passed to every layer
  -> LeNet s                    -- ^ lenet architecture
  -> BVar s (Tensor '[1,32,32]) -- ^ input
  -> BVar s (Tensor '[10])      -- ^ output
lenet lr arch inp = linear (fc3 arch) hidden2
  where
    -- convolutional feature extractor
    pooled1 = lenetLayer lr (conv1 arch) inp
    pooled2 = lenetLayer lr (conv2 arch) pooled1
    -- collapse the 16x5x5 feature maps into a flat vector
    flat    = flattenBP pooled2
    -- fully connected head
    hidden1 = relu (linear (fc1 arch) flat)
    hidden2 = relu (linear (fc2 arch) hidden1)
-- | One LeNet convolutional stage: convolution, ReLU, then 2x2 max-pooling
-- with stride 2 and no padding.
--
-- NOTE(review): deliberately left without a type signature so GHC can
-- instantiate it at both layer shapes used in 'lenet' (1->6 and 6->16);
-- the constraint set demanded by 'NN.conv2dMM' and 'maxPooling2d' is
-- inferred. The @(sing :: SBool 'True)@ argument appears to select the
-- pooling output-size mode (presumably ceil mode) -- TODO confirm against
-- the hasktorch maxPooling2d documentation.
lenetLayer lr conv
  = maxPooling2d
      (Kernel2d :: Kernel2d 2 2)     -- 2x2 pooling window
      (Step2d :: Step2d 2 2)         -- stride 2 in both dimensions
      (Padding2d :: Padding2d 0 0)   -- no padding
      (sing :: SBool 'True)
  . relu
  . NN.conv2dMM
      (Step2d :: Step2d 1 1)         -- unit stride
      (Padding2d :: Padding2d 0 0)   -- no padding
      lr conv
-------------------------------------------------------------------------------
-- layer initialization:
-- | Allocate a fully connected layer with weights and bias drawn uniformly
-- from @(-stdv, stdv)@, where stdv is derived from the fan-in @i@.
newLinear :: forall o i . KnownNatDim2 i o => IO (Linear i o)
newLinear = fmap Linear (newLayerWithBias fanIn)
  where
    -- fan-in of a linear layer is its number of input features
    fanIn = natVal (Proxy :: Proxy i)
-- | Allocate a 2d convolution layer with weights and bias drawn uniformly
-- from @(-stdv, stdv)@, where stdv is derived from the layer's fan-in.
newConv2d :: forall o i kH kW . KnownNatDim4 i o kH kW => IO (Conv2d i o kH kW)
newConv2d = fmap Conv2d (newLayerWithBias fanIn)
  where
    -- fan-in of a conv layer: input planes times kernel area
    fanIn = natVal (Proxy :: Proxy i)
          * natVal (Proxy :: Proxy kH)
          * natVal (Proxy :: Proxy kW)
-- | Sample a (weights, bias) pair uniformly from @(-stdv, stdv)@ with
-- @stdv = 1 / sqrt n@, the classic torch default initialization, where
-- @n@ is the layer's fan-in.
newLayerWithBias :: Dimensions2 d d' => Natural -> IO (Tensor d, Tensor d')
newLayerWithBias n = do
  g <- newRNG
  let stdv = 1 / P.sqrt (fromIntegral n) :: Double
  -- draw order matters: weights first, then bias, from the same generator
  weights <- uniform g (-stdv) stdv
  bias    <- uniform g (-stdv) stdv
  pure (weights, bias)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment