-
-
Save anonymous/9c7ec31460a51c2b36864920567cca47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- RBMの実装 | |
-- 入力データフォーマットは、libsvmと同じ | |
module RBM2 where | |
import Control.Exception (throwIO)
import Data.Foldable (forM_, foldl')
import Data.Traversable (forM)
import Debug.Trace (trace)
import System.IO ( IOMode (ReadMode, WriteMode)
                 , Handle
                 , hIsEOF
                 , openFile
                 , hClose
                 , hPutStrLn
                 , withFile
                 )

import Control.Monad.Trans.Class (lift)
import Data.Aeson ( ToJSON
                  , FromJSON
                  , eitherDecode
                  , parseJSON
                  , (.:)
                  , (.:?)
                  )
import Data.Aeson.Types ( Value (Object)
                        , typeMismatch
                        )
import Data.Array.IO (IOArray)
import Data.Array.MArray ( getBounds
                         , readArray
                         , writeArray
                         , newArray
                         , newArray_
                         )
import qualified Data.ByteString.Lazy as LB
import Data.Text ( Text
                 , splitOn
                 , pack
                 , unpack
                 )
import Data.Text.IO ( hGetLine
                    , readFile
                    )
import System.Random ( randomRIO )

import Iter (fold)
import Math
----------------------------------------------------------- | |
-- 設定の定義など | |
-- | Processor (core) count settings.
--
-- NOTE(review): these values are parsed from the config file but are not
-- read anywhere else in the visible part of this module.
data MProc = MProc
  { procNum   :: Int  -- ^ number of processors
  , procReal  :: Int
  , procModel :: Int
  } deriving (Show)
-- | Decode an 'MProc' from a JSON object with the keys "numProc",
-- "numReal" and "numModel".
instance FromJSON MProc where
  parseJSON (Object obj) = do
    nProc  <- obj .: pack "numProc"
    nReal  <- obj .: pack "numReal"
    nModel <- obj .: pack "numModel"
    return MProc { procNum = nProc, procReal = nReal, procModel = nModel }
  -- Anything other than a JSON object cannot describe an MProc.
  parseJSON other = typeMismatch "MProc" other
-- | Training settings, decoded from the JSON configuration file.
data Config = Config
  { confNV     :: Int          -- ^ number of visible nodes
  , confNH     :: Int          -- ^ number of hidden nodes
  , confIVal   :: Double       -- ^ initial weight value (passed to newMatrix)
  , confEff    :: Double       -- ^ learning rate
  , confNCD    :: Int          -- ^ number of Gibbs steps per contrastive-divergence estimate
  , confNTrain :: Int          -- ^ number of training passes over the event set
  , confProc   :: Maybe MProc  -- ^ optional processor settings
  } deriving (Show)
-- | Decode a 'Config' from a JSON object.
--
-- Expected keys: "numVisible", "numHidden", "ival", "larningEff" (sic: the
-- misspelled key is kept so existing config files keep loading), "numCD",
-- "numNTrain", and the optional "proc".
instance FromJSON Config where
  parseJSON (Object v) = Config
    <$> v .: pack "numVisible"
    <*> v .: pack "numHidden"
    <*> v .: pack "ival"
    <*> v .: pack "larningEff"
    <*> v .: pack "numCD"
    <*> v .: pack "numNTrain"
    -- 'confProc' is a Maybe: accept a missing key (or null) as Nothing.
    <*> v .:? pack "proc"
  -- A non-Object value is of the wrong type, so fail (naming the right type).
  parseJSON v = typeMismatch "Config" v
----------------------------------------------------------- | |
-- 学習するためのコード | |
-- | Label of an input sample (the first column of an input line).
type Signal = Int

-- | Values of the nodes of one layer, held as a vector.
type Layer = Vector Double

-- | Bias terms of one layer, one entry per node.
type Bias = Vector Double

-- | Weights between the visible and the hidden layer, held as a matrix.
type Weight = Matrix Double

-- | One input sample: its label together with its visible-layer values.
type Event = (Signal, Layer)

-- | Every input sample read from a file.
type EventSet = Vector Event

-- | Coupling parameters of the RBM, in this order:
--
--   * visible-layer bias
--   * hidden-layer bias
--   * weights between the visible and hidden layers
type Coupling = (Bias, Bias, Weight)
-- | Build one 'Event' from a single line of the input file.
--
-- The line follows the libsvm layout: the first column is the signal (label),
-- and every following column is an @index:value@ pair whose value is taken as
-- the bit (0 or 1) of the next visible unit.
--
-- Leading, trailing and repeated spaces are tolerated. An empty line or a
-- column without a ':' raises an error naming the problem.
--
-- NOTE(review): the index part of each pair is ignored and values are taken
-- in column order, so every unit must be listed (dense input). Sparse libsvm
-- files that omit zero-valued features would be misread.
lineToEvent :: Text -> Event
lineToEvent line =
  case filter (/= pack "") (splitOn (pack " ") line) of
    (label : cols) -> (read (unpack label) :: Int, map parseValue cols)
    []             -> error "lineToEvent: empty line"
  where
    -- Extract the value part of an "index:value" column.
    parseValue :: Text -> Double
    parseValue col =
      case splitOn (pack ":") col of
        (_ : val : _) -> read (unpack val)
        _             -> error ("lineToEvent: malformed column: " ++ unpack col)
-- | Read every event from a libsvm-formatted file and build an 'EventSet'.
--
-- The argument is the path of the input file. Lines consisting only of
-- spaces are skipped. The file handle is released even if reading fails.
createEventSet :: String -> IO EventSet
createEventSet input = do
  evs <- withFile input ReadMode readEvents
  putStrLn "event set has been created"
  return evs
  where
    -- Read the remaining lines of the handle, one event per non-blank line.
    readEvents :: Handle -> IO EventSet
    readEvents h = do
      eof <- hIsEOF h
      if eof
        then return []
        else do
          line <- hGetLine h
          rest <- readEvents h
          return $ if isBlank line then rest else lineToEvent line : rest
    -- A line is blank when splitting on spaces yields only empty fields.
    isBlank :: Text -> Bool
    isBlank = all (== pack "") . splitOn (pack " ")
-- | Build the starting coupling from the configuration.
--
-- Both bias vectors are created with @newVector n 0.0@ and the weight matrix
-- with @newMatrix nv nh ival@.
--
-- NOTE(review): if 'newMatrix' fills every entry with the same value, all
-- hidden units start out identical; confirm that training breaks this
-- symmetry.
initCoupling :: Config -> Coupling
initCoupling conf = (visibleBias, hiddenBias, weights)
  where
    nVisible    = confNV conf
    nHidden     = confNH conf
    visibleBias = newVector nVisible 0.0
    hiddenBias  = newVector nHidden 0.0
    weights     = newMatrix nVisible nHidden (confIVal conf)
-- | Conditional probability that each visible unit is 1 given the hidden
-- layer: 'logistic' applied through 'applyFV' to @bv + W h@ (via 'dotMV').
calcVisibleProp :: Coupling       -- ^ coupling (visible bias, hidden bias, weights)
                -> Vector Double  -- ^ hidden nodes
                -> Vector Double
calcVisibleProp (visBias, _hidBias, weights) hidden =
  applyFV logistic (addVV visBias (dotMV weights hidden))
-- | Conditional probability that each hidden unit is 1 given the visible
-- layer: 'logistic' applied through 'applyFV' to the hidden bias plus
-- @dotMtV w v@.
calcHideProp :: Coupling       -- ^ coupling (visible bias, hidden bias, weights)
             -> Vector Double  -- ^ visible nodes
             -> Vector Double
calcHideProp (_visBias, hidBias, weights) visible =
  applyFV logistic (addVV hidBias (dotMtV weights visible))
-- | Contrastive divergence sampling.
--
-- Starting from the visible values @v@, first samples the hidden layer once.
-- It then runs @times@ rounds of Gibbs sampling: a new visible sample from the
-- current hidden one, then a new hidden sample from that visible one. Returns
-- the final (visible, hidden) pair.
consDiv :: Coupling
        -> Int                                -- ^ number of Gibbs rounds
        -> Vector Double                      -- ^ initial value of visible nodes
        -> IO (Vector Double, Vector Double)  -- ^ (visible, hidden)
consDiv coupling times v = do
  h0 <- sampleBits (calcHideProp coupling v)
  fold (v, h0) [0 .. (times - 1)] $ \(_, h) _ -> do
    v' <- sampleBits (calcVisibleProp coupling h)
    h' <- sampleBits (calcHideProp coupling v')
    return (v', h')
  where
    -- One Bernoulli draw per unit: 1 with the given probability, else 0.
    sampleBits :: Vector Double -> IO (Vector Double)
    sampleBits probs = forM probs $ \p -> do
      r <- randomRIO (0.0, 1.0)
      return $ if r < p then 1 else 0
-- | One training pass over the event set, updating after every event.
--
-- For each event the gradient is estimated by contrastive divergence using
-- the coupling as already updated by the preceding events. The biases and
-- weights are then moved by @confEff conf@ times that gradient. All events
-- are used.
--
-- Returns the final coupling and the value
-- @divV gv + divV gh + divM gw@ computed for the last event of the pass.
-- (The 0.0 seed is presumably returned for an empty set.)
calcGrad :: Config
         -> Coupling
         -> EventSet
         -> IO (Coupling, Double)
calcGrad conf coupling evs = do
  let ncd = confNCD conf
      eff = confEff conf
  fold (coupling, 0.0) evs $ \(current, _) (_, v) -> do
    let (bv, bh, w) = current
    -- Sample with the *current* coupling (not the one this pass started
    -- with), so every event's update sees the previous events' updates.
    (v', _) <- consDiv current ncd v
    let ph      = calcHideProp current v
        ph'     = calcHideProp current v'
        gv      = v `subVV` v'
        gh      = ph `subVV` ph'
        gw      = (v `mulVV` ph) `subMM` (v' `mulVV` ph')
        bv'     = bv `addVV` (eff `mulCV` gv)
        bh'     = bh `addVV` (eff `mulCV` gh)
        w'      = w `addMM` (eff `mulCM` gw)
        measure = divV gv + divV gh + divM gw
    return ((bv', bh', w'), measure)
-- | Read the JSON configuration file at the given path and decode it into a
-- 'Config'. A decoding failure is thrown as a 'userError' carrying the
-- decoder's message.
loadConfig :: String -> IO Config
loadConfig confile = LB.readFile confile >>= either failWith return . decodeConfig
  where
    decodeConfig :: LB.ByteString -> Either String Config
    decodeConfig = eitherDecode
    failWith msg = throwIO $ userError $ "ERROR: failed to read config file: " ++ msg
-- | Save a coupling to a text file.
--
-- The first argument is the output file name; the second is the coupling.
-- Line 1 holds the visible bias, line 2 the hidden bias, and each following
-- line one row of the weight matrix, with values separated by single spaces.
-- An empty vector is written as an empty line instead of crashing. The
-- handle is closed even if writing fails.
saveCoupling :: String -> Coupling -> IO ()
saveCoupling output (bv, bh, w) =
  withFile output WriteMode $ \h -> do
    hPutStrLn h $ joinV bv
    hPutStrLn h $ joinV bh
    forM_ w $ \row -> hPutStrLn h $ joinV row
  where
    -- Render a vector as space-separated 'show'n values (total on []).
    joinV :: Vector Double -> String
    joinV = unwords . map show
-- | Read a coupling from a file in the layout written by 'saveCoupling'.
--
-- Line 1 is the visible bias and line 2 the hidden bias. Every remaining
-- non-blank line is one row of the weight matrix. Extra spaces between
-- values are tolerated, and an empty bias line yields an empty vector. The
-- handle is closed even if reading fails.
loadCoupling :: String -> IO Coupling
loadCoupling input = withFile input ReadMode $ \h -> do
  bv <- parseRow <$> hGetLine h
  bh <- parseRow <$> hGetLine h
  w  <- readWeight h
  return (bv, bh, w)
  where
    -- Split a line on spaces, drop empty fields, and read each value.
    parseRow :: Text -> Vector Double
    parseRow = map (read . unpack) . filter (/= pack "") . splitOn (pack " ")
    -- Read weight rows until end of file, skipping blank lines.
    readWeight :: Handle -> IO Weight
    readWeight h = do
      eof <- hIsEOF h
      if eof
        then return []
        else do
          row  <- parseRow <$> hGetLine h
          rest <- readWeight h
          return $ case row of
            [] -> rest
            _  -> row : rest
-- | Run training.
--
-- Arguments: @input confile [couplingFile]@. @input@ is the event file and
-- @confile@ the JSON configuration. When exactly one extra argument is given,
-- it is loaded as the starting coupling; otherwise the coupling is built from
-- the configuration. After 'confNTrain' passes of 'calcGrad' (printing the
-- pass index and the returned value after each), the coupling is saved to
-- @input ++ ".weight"@. Fewer than two arguments raises a usage error.
learnMain :: [String] -> IO ()
learnMain args = case args of
  (input : confile : args') -> run input confile args'
  _ -> throwIO $ userError "learnMain: usage: <input-file> <config-file> [coupling-file]"
  where
    run input confile args' = do
      conf <- loadConfig confile
      putStrLn $ "input file: " ++ input
      coupling <- case args' of
        [wfile] -> do
          putStrLn $ "loading coupling file: " ++ show wfile
          loadCoupling wfile
        _ -> do
          putStrLn "initializing coupling"
          return $ initCoupling conf
      evs <- createEventSet input
      let ntrain = confNTrain conf
      coup <- fold coupling [0 .. (ntrain - 1)] $ \coup i -> do
        (coup', measure) <- calcGrad conf coup evs
        putStrLn $ show i ++ ": " ++ show measure
        return coup'
      saveCoupling (input ++ ".weight") coup
----------------------------------------------------------- | |
-- 予測するためのコード | |
-- | Score of a visible vector under one model, used for classification.
--
-- This is the Boltzmann distribution with the hidden units marginalised out,
-- as in "Deep Learning" (Japanese Society for Artificial Intelligence, ed.),
-- Kindai Kagaku-sha, p. 61, eq. (2.44):
--
-- @dotVV bv v + sum_j softplus (bh_j + (dotMtV w v)_j)@, where
-- @softplus x = log (1 + exp x)@.
--
-- This is the negative free energy, so a larger value means the model gives
-- @v@ a higher unnormalised probability. Softplus is evaluated so that it
-- never overflows for large activations, and the hidden sum is 0 when there
-- are no hidden units.
calcEnergy :: Layer
           -> Coupling
           -> Double
calcEnergy v (bv, bh, w) = visibleTerm + hiddenTerm
  where
    visibleTerm = dotVV bv v
    hiddenTerm  = sum (map softplus (bh `addVV` (w `dotMtV` v)))
    -- log (1 + exp x); for x > 0 rewritten as x + log (1 + exp (-x)) so that
    -- exp cannot overflow to Infinity.
    softplus :: Double -> Double
    softplus x
      | x > 0     = x + log (1 + exp (negate x))
      | otherwise = log (1 + exp x)
-- | Decide which of two models a single visible vector belongs to.
--
-- Returns 1 when 'calcEnergy' scores the vector strictly higher under the
-- first coupling, and 2 otherwise (ties go to 2).
predictLayer :: Layer -> Coupling -> Coupling -> Int
predictLayer v c1 c2
  | scoreUnder c1 > scoreUnder c2 = 1
  | otherwise                     = 2
  where
    scoreUnder = calcEnergy v
-- | Classify every event of the input file with two trained couplings and
-- print the accuracy as @correct/total=percent%@.
--
-- Arguments: the input file, then the coupling files for signal 1 and for
-- signal 2. An event counts as correct when its signal equals the value
-- returned by 'predictLayer' (1 or 2). The counters are accumulated strictly,
-- so large inputs do not build up a chain of thunks. With no events, the
-- percentage is 0/0, i.e. NaN.
predict :: String -> String -> String -> IO ()
predict input wfile1 wfile2 = do
  c1 <- loadCoupling wfile1
  c2 <- loadCoupling wfile2
  evs <- createEventSet input
  let (w, t) = foldl' (check c1 c2) (0.0, 0.0) evs
  putStrLn $ (show w) ++ "/" ++ (show t) ++ "=" ++ (show $ w/t * 100) ++ "%"
  where
    check :: Coupling
          -> Coupling
          -> (Double, Double)
          -> Event
          -> (Double, Double)
    check c1 c2 (w, t) (s, l) =
      let w' = if predictLayer l c1 c2 == s then w + 1 else w
          t' = t + 1
      -- foldl' only forces the pair itself, so force both counters here.
      in w' `seq` t' `seq` (w', t')
-- | Entry point for prediction.
--
-- Expects exactly three arguments: the input file, then the coupling files
-- for signal 1 and signal 2. Any other argument count raises a usage error.
predictMain :: [String] -> IO ()
predictMain args = case args of
  [input, wfile1, wfile2] -> predict input wfile1 wfile2
  _ -> throwIO $ userError "predictMain: usage: <input-file> <coupling-file-1> <coupling-file-2>"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment