Skip to content

Instantly share code, notes, and snippets.

/RBM.hs Secret

Created May 6, 2016 06:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/9c7ec31460a51c2b36864920567cca47 to your computer and use it in GitHub Desktop.
Save anonymous/9c7ec31460a51c2b36864920567cca47 to your computer and use it in GitHub Desktop.
-- RBMの実装
-- 入力データフォーマットは、libsvmと同じ
module RBM2 where
import System.IO ( IOMode (ReadMode, WriteMode)
, Handle
, hIsEOF
, openFile
, hClose
, hPutStrLn)
import System.Random ( randomRIO )
import Control.Exception (throwIO)
import Control.Monad.Trans.Class (lift)
import Data.Array.IO (IOArray)
import Data.Array.MArray ( getBounds
, readArray
, writeArray
, newArray
, newArray_)
import Data.Foldable (forM_)
import Data.Traversable (forM)
import Data.Text ( Text
, splitOn
, pack
, unpack)
import Data.Text.IO ( hGetLine
, readFile)
import qualified Data.ByteString.Lazy as LB
import Data.Aeson ( ToJSON
, FromJSON
, eitherDecode
, parseJSON
, (.:))
import Data.Aeson.Types (Value (Object)
, typeMismatch)
import Iter (fold)
import Math
import Debug.Trace (trace)
-----------------------------------------------------------
-- 設定の定義など
-- | コア数の設定
data MProc = MProc { procNum :: Int -- number of processors
, procReal :: Int
, procModel :: Int
} deriving (Show)
-- | MProcのJSONパース設定
instance FromJSON MProc where
parseJSON (Object v) = MProc
<$> v .: (pack "numProc")
<*> v .: (pack "numReal")
<*> v .: (pack "numModel")
-- A non-Object value is of the wrong type, so fail.
parseJSON v = typeMismatch "MProc" v
-- | 設定データ
data Config = Config { confNV :: Int -- number of visible node
, confNH :: Int -- number of hidden node
, confIVal :: Double -- initial weight
, confEff :: Double -- learning eff
, confNCD :: Int -- number of constructive divergences
, confNTrain :: Int -- number of set of training
, confProc :: Maybe MProc
} deriving (Show)
-- | ConfigのJSONのパース設定
instance FromJSON Config where
parseJSON (Object v) = Config
<$> v .: (pack "numVisible")
<*> v .: (pack "numHidden")
<*> v .: (pack "ival")
<*> v .: (pack "larningEff")
<*> v .: (pack "numCD")
<*> v .: (pack "numNTrain")
<*> v .: (pack "proc")
-- A non-Object value is of the wrong type, so fail.
parseJSON v = typeMismatch "MProc" v
-----------------------------------------------------------
-- 学習するためのコード
-- | 入力データの種類
type Signal = Int
-- | 層の各のノードの値の集合はベクトルとして表現
type Layer = Vector Double
-- | 1つの入力データは、SignalとLayerで表現
type Event = (Signal, Layer)
-- | すべての入力データ
type EventSet = Vector Event
-- | バイアスはベクトルで表現
type Bias = Vector Double
-- | 可視層と隠れ層間の重みは行列で表現
type Weight = Matrix Double
-- | 結合定数
-- | * 可視層のバイアス
-- | * 隠れ層バイアス
-- | * 可視層と隠れ層間の重み
type Coupling = (Bias, Bias, Weight)
-- | 読み込んだファイルの1行からイベントを生成する。
-- | フォーマットは、libsvmのフォーマットと一緒。
-- | 最初のカラムはシグナルとして認識する。
-- | 2番目のカラム以降は各ユニットのビット値(0,1)として認識する。
lineToEvent :: Text -> Event
lineToEvent line = let xs = splitOn (pack " ") line
signal = read (unpack $ head xs) :: Int
in (signal, mkLayer (tail xs))
where
parseValue :: Text -> Double
parseValue text = read $ unpack $ splitOn (pack ":") text !! 1
mkLayer :: [Text] -> Vector Double
mkLayer [] = []
mkLayer (x:xs) = parseValue x : mkLayer xs
-- | ファイルの全てのイベントを読み込み、EventSetを生成する。
-- | ファイルのフォーマットは、libsvmと一緒。
-- | 第1引数は、入力データ。
createEventSet :: String -> IO EventSet
createEventSet input = do
h <- openFile input ReadMode
evs <- getEvents h
hClose h
putStrLn $ "event set has been created"
return evs
where
-- 全てのイベントを読み込む
getEvents :: Handle -> IO EventSet
getEvents h = do
eof <- hIsEOF h
if eof
then return []
else do
line <- hGetLine h
let ev = lineToEvent line
evs <- getEvents h
return $ ev : evs
-- | 初期化された結合定数を生成する。
initCoupling :: Config -> Coupling
initCoupling conf = let nv = confNV conf
nh = confNH conf
iv = confIVal conf
vm = newVector nv 0.0
hm = newVector nh 0.0
im = newMatrix nv nh iv
in (vm, hm, im)
-- | 可視ユニットが1となる条件付き確率を計算する
calcVisibleProp :: Coupling -- 結合定数
-> Vector Double -- hidden nodes
-> Vector Double
calcVisibleProp (bv, bh, w) h = logistic `applyFV` (bv `addVV` (w `dotMV` h))
-- | 隠れユニットが1となる条件付き確率を計算する
calcHideProp :: Coupling -- index of hidden node
-> Vector Double -- visible nodes
-> Vector Double
calcHideProp (bv, bh, w) v = logistic `applyFV` (bh `addVV` (w `dotMtV` v))
-- | Constructive Divergence
consDiv :: Coupling
-> Int -- times to calc
-> Vector Double -- initial value of visible node
-> IO (Vector Double, Vector Double) -- (visible, hidden)
consDiv coupling times v = do
let (bv, bh, w) = coupling
h <- updateHidden coupling v
fold (v, h) [0 .. (times-1)] $ \(v, h) -> \_ -> do
v' <- updateVisible coupling h
h' <- updateHidden coupling v'
return (v', h')
where
updateVisible :: Coupling -> Vector Double -> IO (Vector Double)
updateVisible coupling h = do
forM (calcVisibleProp coupling h) $ \pv -> do
r <- randomRIO (0.0, 1.0)
return $ if r < pv then 1 else 0
updateHidden :: Coupling -> Vector Double -> IO (Vector Double)
updateHidden coupling v = do
forM (calcHideProp coupling v) $ \ph -> do
r <- randomRIO (0.0, 1.0)
return $ if r < ph then 1 else 0
-- | 勾配を計算し、バイアスと重みを更新する。
-- | 1イベント毎に更新する。
-- | 与えられた全てのイベントを使用する。
calcGrad :: Config
-> Coupling
-> EventSet
-> IO (Coupling, Double)
calcGrad conf coupling evs = do
let ncd = confNCD conf
eff = confEff conf
fold (coupling, 0.0) evs $ \((bv, bh, w), _) -> \(s, v) -> do
(v', h') <- consDiv coupling ncd v
let ph = calcHideProp coupling v
ph' = calcHideProp coupling v'
gv = v `subVV` v'
gh = ph `subVV` ph'
gw = (v `mulVV` ph) `subMM` (v' `mulVV` ph')
dv = eff `mulCV` gv
dh = eff `mulCV` gh
dw = eff `mulCM` gw
bv' = bv `addVV` dv
bh' = bh `addVV` dh
w' = w `addMM` dw
div = divV gv + divV gh + divM gw
return ((bv', bh', w'), div)
-- | 設定ファイルを読み込み、Configを生成する。
loadConfig :: String -> IO Config
loadConfig confile = do
mconf <- eitherDecode <$> LB.readFile confile :: IO (Either String Config)
case mconf of
Left msg -> throwIO $ userError $ "ERROR: failed to read config file: " ++ msg
Right conf -> return conf
-- | 結合定数をファイルに保存する。
-- | 第1引数は、出力ファイル名。
-- | 第2引数は、結合定数。
saveCoupling :: String -> Coupling -> IO ()
saveCoupling output (bv, bh, w) = do
h <- openFile output WriteMode
hPutStrLn h $ joinV bv
hPutStrLn h $ joinV bh
forM_ w $ \v -> do
hPutStrLn h $ joinV v
hClose h
where
joinV :: Vector Double -> String
joinV [x] = show x
joinV (x:v) = (show x) ++ " " ++ (joinV v)
-- | ファイルから結合定数を読み込み、Couplingを生成する。
loadCoupling :: String -> IO Coupling
loadCoupling input = do
h <- openFile input ReadMode
line <- hGetLine h
let bv = (read . unpack) `map` splitOn (pack " ") line
line <- hGetLine h
let bh = (read . unpack) `map` splitOn (pack " ") line
w <- readWeight h
hClose h
return (bv, bh, w)
where
readWeight :: Handle -> IO Weight
readWeight h = do
eof <- hIsEOF h
if eof
then return []
else do
line <- hGetLine h
let row = (read . unpack) `map` splitOn (pack " ") line
(row :) <$> readWeight h
-- | 学習を実行する。
learnMain :: [String] -> IO ()
learnMain args = do
let (input : confile : args') = args
conf <- loadConfig confile
putStrLn $ "input file: " ++ input
coupling <- case args' of
[wfile] -> do
putStrLn $ "loading coupling file: " ++ (show wfile)
loadCoupling wfile
_ -> do
putStrLn $ "initializing coupling"
return $ initCoupling conf
evs <- createEventSet input
let ntrain = confNTrain conf
coup <- fold coupling [0..(ntrain-1)] $ \coup -> \i -> do
(coup', div) <- calcGrad conf coup evs
putStrLn $ (show i) ++ ": " ++ (show div)
return coup'
saveCoupling (input ++ ".weight") coup
-----------------------------------------------------------
-- 予測するためのコード
-- | 信号用のエネルギー計算
-- | 隠れ変数について周辺化したボルツマン分布
-- | 「深層学習 人工知能学会」近代科学社 p61 式(2.44)
calcEnergy :: Layer
-> Coupling
-> Double
calcEnergy v (bv, bh, w)
= let ev = dotVV bv v
eh = foldl1 (+) (log `applyFV` (1 `addCV` (exp `applyFV` (bh `addVV` (w `dotMtV` v)))))
in ev + eh
-- | 与えられた1件のデータがどのグループに属するか判定する。
predictLayer :: Layer -> Coupling -> Coupling -> Int
predictLayer v c1 c2
= let e1 = calcEnergy v c1
e2 = calcEnergy v c2
in if e1 > e2 then 1 else 2
-- | 予測する。
predict :: String -> String -> String -> IO ()
predict input wfile1 wfile2 = do
c1 <- loadCoupling wfile1
c2 <- loadCoupling wfile2
evs <- createEventSet input
let (w, t) = foldl (check c1 c2) (0.0, 0.0) evs
putStrLn $ (show w) ++ "/" ++ (show t) ++ "=" ++ (show $ w/t * 100) ++ "%"
where
check :: Coupling
-> Coupling
-> (Double, Double)
-> Event
-> (Double, Double)
check c1 c2 (w, t) (s, l)
= let p = predictLayer l c1 c2
in if p == s then (w+1, t+1) else (w, t+1)
predictMain :: [String] -> IO ()
predictMain args = do
let [input, wfile1, wfile2] = args
predict input wfile1 wfile2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment