-
-
Save lotz84/4b8072baf5ecc8491877db210980a82d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
build-depends:
    base ^>=4.16.4.0,
    lens == 5.2.3,
    linear == 1.22,
    random == 1.2.1.1,
    vector == 0.13.1.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE TypeApplications #-} | |
module Main where | |
import Data.Complex (Complex(..)) | |
import Data.Foldable (forM_, foldlM) | |
import Data.List (cycle, foldl') | |
import Data.Maybe (fromJust) | |
import GHC.TypeNats (KnownNat, Nat) | |
import Control.Lens (view, _1) | |
import qualified Data.Vector as V | |
import Linear.Conjugate (conjugate) | |
import Linear.Matrix ((!*)) | |
import Linear.Quaternion (Quaternion(..), _e) | |
import Linear.V (V(..), fromVector) | |
import Linear.V3 (V3(..)) | |
import System.Random (randomRIO, initStdGen) | |
-- | Real numbers.
type R = Double

-- | Complex numbers over 'R'.
type C = Complex R

-- | Triple complex number: a complex value together with two further
-- complex components, which are exposed as 'gradZ' and 'gradZC'.
data T = T C C C deriving Show

-- | Projections onto the first ('complex'), second ('gradZ') and third
-- ('gradZC') component of a 'T'.
complex, gradZ, gradZC :: T -> C
complex (T value _ _) = value
gradZ (T _ dz _) = dz
gradZC (T _ _ dzc) = dzc
-- | Arithmetic on 'T'. The first component follows ordinary complex
-- arithmetic; the second and third components are propagated with the
-- sum rule, the product rule and, for 'abs', the derivatives of the
-- complex modulus with respect to @z@ and its conjugate.
instance Num T where
  T z dz dzc + T w dw dwc = T (z + w) (dz + dw) (dzc + dwc)
  T z dz dzc * T w dw dwc = T (z * w) (dz * w + z * dw) (dzc * w + z * dwc)
  negate (T z dz dzc) = T (negate z) (negate dz) (negate dzc)
  abs (T z dz dzc) = T (abs z) dOut dcOut
    where
      -- conj z / (2 |z|) and z / (2 |z|); both divide by zero when z == 0.
      viaConj = conjugate z / (2 * abs z)
      viaSelf = z / (2 * abs z)
      dOut = viaConj * dz + viaSelf * conjugate dzc
      dcOut = viaConj * dzc + viaSelf * conjugate dz
  -- 'signum' and 'fromInteger' carry zero in both derivative slots.
  signum (T z _ _) = T (signum z) 0 0
  fromInteger n = T (fromInteger n) 0 0
-- | Reciprocal on 'T': the value is inverted and both derivative slots are
-- scaled by @-1 / z^2@. Rational literals carry zero derivative slots.
instance Fractional T where
  recip (T z dz dzc) = T (recip z) (push dz) (push dzc)
    where
      -- Kept in the original operator order so rounding is unchanged.
      push d = -1 * recip (z * z) * d
  fromRational r = T (fromRational r) 0 0
-- | Elementary functions on 'T'. Each method applies the complex function
-- to the first component and scales both derivative slots by that
-- function's ordinary derivative evaluated at the first component.
instance Floating T where
  pi = T pi 0 0
  exp (T z dz dzc) = T ez (ez * dz) (ez * dzc)
    where
      ez = exp z
  log (T z dz dzc) = T (log z) (dz / z) (dzc / z)
  sin (T z dz dzc) = T (sin z) (slope * dz) (slope * dzc)
    where
      slope = cos z
  cos (T z dz dzc) = T (cos z) (negate (s * dz)) (negate (s * dzc))
    where
      s = sin z
  asin (T z dz dzc) = T (asin z) (dz / root) (dzc / root)
    where
      root = sqrt (1 - z ** 2)
  acos (T z dz dzc) = T (acos z) (negate (dz / root)) (negate (dzc / root))
    where
      root = sqrt (1 - z ** 2)
  atan (T z dz dzc) = T (atan z) (dz / denom) (dzc / denom)
    where
      denom = 1 + z ** 2
  sinh (T z dz dzc) = T (sinh z) (slope * dz) (slope * dzc)
    where
      slope = cosh z
  cosh (T z dz dzc) = T (cosh z) (slope * dz) (slope * dzc)
    where
      slope = sinh z
  asinh (T z dz dzc) = T (asinh z) (dz / root) (dzc / root)
    where
      root = sqrt (1 + z ** 2)
  acosh (T z dz dzc) = T (acosh z) (dz / root) (dzc / root)
    where
      root = sqrt (z ** 2 - 1)
  atanh (T z dz dzc) = T (atanh z) (dz / denom) (dzc / denom)
    where
      denom = 1 - z ** 2
-- | A matrix represented as an outer vector of @m@ inner vectors, each of
-- length @n@.
type M (m :: Nat) (n :: Nat) a = V m (V n a)

-- | Affine map: multiply the matrix by the input vector, then add the bias.
linear :: (KnownNat m, KnownNat n, Num a) => M m n a -> V m a -> V n a -> V m a
linear weights bias input = (weights !* input) + bias
-- | Types that support a sigmoid activation function.
class Activatable a where
  sigmoid :: a -> a

-- | Logistic function @1 / (1 + exp (-x))@ on real numbers.
instance Activatable R where
  sigmoid x = recip (1 + exp (negate x))

-- | Applies the real sigmoid to the real and imaginary parts separately.
instance Activatable C where
  sigmoid = fmap sigmoid
-- | Sigmoid on 'T'. The value is the split sigmoid of the first component
-- (real sigmoid applied to the real and imaginary parts separately); the
-- two derivative slots are propagated with the Wirtinger chain rule.
instance Activatable T where
  sigmoid (T z@(x :+ y) z' zc') = T (sigmoid z) w' wc'
    where
      -- Derivative of the real logistic function.
      sigmoid' t = sigmoid t * (1 - sigmoid t)
      sx = sigmoid' x
      sy = sigmoid' y
      -- For f(x + iy) = sigmoid x + i * sigmoid y we have
      -- df/dx = sx and df/dy = i * sy, hence
      --   df/dz       = (df/dx - i * df/dy) / 2 = (sx + sy) / 2
      --   df/d(conj z) = (df/dx + i * df/dy) / 2 = (sx - sy) / 2
      -- Both are real. Using sx -/+ i*sy here would instead differentiate
      -- the real-valued function 2 * (sigmoid x + sigmoid y).
      dfdz = ((sx + sy) / 2) :+ 0
      dfdzc = ((sx - sy) / 2) :+ 0
      -- Chain rule for an inner function g with slots z' = dg/dw and
      -- zc' = dg/d(conj w): d(conj g)/dw = conj zc', d(conj g)/d(conj w) = conj z'.
      w' = dfdz * z' + dfdzc * conjugate zc'
      wc' = dfdz * zc' + dfdzc * conjugate z'
-- | Convert a list into a fixed-size vector.
--
-- The list length must equal the type-level size @n@. If 'fromVector'
-- rejects the input, this calls 'error' with a message naming the
-- offending length instead of failing inside 'fromJust'.
toV :: KnownNat n => [a] -> V n a
toV xs =
  case fromVector (V.fromList xs) of
    Just v -> v
    Nothing ->
      error $
        "toV: a list of "
          ++ show (length xs)
          ++ " element(s) does not match the type-level vector size"
-- | Convert a list of rows into a matrix by converting every row with
-- 'toV' and then converting the list of rows itself with 'toV'.
toM :: (KnownNat m, KnownNat n) => [[a]] -> M m n a
toM rows = toV (toV <$> rows)
-- | Experiment 1: fit a single complex weight and bias, followed by a
-- sigmoid, to four complex input/target pairs.
--
-- The parameters start at random values with both parts drawn from
-- [-1, 1]. The four samples are cycled for 1000 update steps, and the
-- trained model's output is printed for each of the four inputs.
experiment1 :: IO ()
experiment1 = do
  -- Initialise the parameters with random numbers.
  w <- uncurry (:+) <$> randomRIO ((-1, -1), (1, 1)) :: IO C
  b <- uncurry (:+) <$> randomRIO ((-1, -1), (1, 1)) :: IO C
  let
    -- Training data: (input, target) pairs.
    samples =
      [ ((-1) :+ (-1), 1 :+ 0)
      , ((-1) :+ 1, 0 :+ 0)
      , (1 :+ (-1), 1 :+ 1)
      , (1 :+ 1, 0 :+ 1)
      ]
    trainingData = take 1000 $ cycle samples
    -- The model: an affine map followed by an element-wise sigmoid.
    model mw' vb' x = fmap sigmoid $ linear mw' vb' x
    -- Learning rate.
    learningRate = 0.1 :+ 0
    -- One parameter-update step.
    updateParam (w0, b0) (x, y) =
      let -- Wrap the parameters as 'T' values, seeding the derivative slot.
          mw dw = toM @1 @1 [[T w0 dw 0]]
          vb db = toV @1 [T b0 db 0]
          vx = toV @1 [T x 0 0]
          ty = T y 0 0
          -- Derivative of the squared-error loss for the given seeds.
          df (dw, db) = gradZ $ abs (ty - view _1 (model (mw dw) (vb db) vx)) ^ 2
          dfdw = df (1, 0)
          dfdb = df (0, 1)
          -- Updated parameters.
          wNext = w0 - learningRate * conjugate dfdw
          bNext = b0 - learningRate * conjugate dfdb
       in -- 'foldl'' only forces the tuple constructor, so force the
          -- components here to avoid building a chain of thunks.
          wNext `seq` bNext `seq` (wNext, bNext)
    -- Run the updates over the training data.
    (w', b') = foldl' updateParam (w, b) trainingData
    mw = toM @1 @1 [[w']]
    vb = toV @1 [b']
    -- Labels and inputs for the final report.
    probes =
      [ ("(-1) :+ (-1) | ", (-1) :+ (-1))
      , ("(-1) :+ 1 | ", (-1) :+ 1)
      , (" 1 :+ (-1) | ", 1 :+ (-1))
      , (" 1 :+ 1 | ", 1 :+ 1)
      ]
  -- Print the learned outputs.
  forM_ probes $ \(label, input) ->
    putStrLn $ concat [label, show $ model mw vb (toV @1 [input])]
-- | Quaternions over 'Double'.
type H = Quaternion Double

-- | The real unit 1.
e :: H
e = Quaternion 1 (V3 0 0 0)

-- | The imaginary unit i.
i :: H
i = Quaternion 0 (V3 1 0 0)

-- | The imaginary unit j.
j :: H
j = Quaternion 0 (V3 0 1 0)

-- | The imaginary unit k.
k :: H
k = Quaternion 0 (V3 0 0 1)
-- | Scalar multiplication: multiply every component of the quaternion by
-- the given real number.
(*|) :: Double -> H -> H
factor *| quat = (* factor) <$> quat
-- | Quintuple quaternion: a quaternion value together with four further
-- quaternion components, exposed as 'gradQ', 'gradQi', 'gradQj' and
-- 'gradQk'.
data Q = Q H H H H H deriving Show

-- | Projections onto the five components of a 'Q', in constructor order.
quaternion, gradQ, gradQi, gradQj, gradQk :: Q -> H
quaternion (Q value _ _ _ _) = value
gradQ (Q _ dq _ _ _) = dq
gradQi (Q _ _ dqi _ _) = dqi
gradQj (Q _ _ _ dqj _) = dqj
gradQk (Q _ _ _ _ dqk) = dqk
-- | Compute the derivative with respect to @q^mu@ from the four
-- derivatives with respect to @q@, @q^i@, @q^j@ and @q^k@.
--
-- Each of the four inputs is right-multiplied by a weight built from the
-- components of @mu@; the sum is then right-multiplied by @recip mu@.
ghr :: H -> H -> H -> H -> H -> H
ghr mu@(Quaternion mu_a (V3 mu_b mu_c mu_d)) q' qi' qj' qk' =
  let realPart = mu_a *| e
      weightQ = realPart
      weightI = realPart + mu_b *| i
      weightJ = realPart + mu_c *| j
      weightK = realPart + mu_d *| k
   in (q' * weightQ + qi' * weightI + qj' * weightJ + qk' * weightK) * recip mu
-- | Chain rule of the GHR calculus for a composite function.
--
-- The first tuple holds the outer function's derivatives with respect to
-- @q@, @q^i@, @q^j@, @q^k@; the second holds the inner function's. The
-- result holds the composite's four derivatives in the same order.
ghrChainRule :: (H, H, H, H) -> (H, H, H, H) -> (H, H, H, H)
ghrChainRule (dfdq, dfdqi, dfdqj, dfdqk) (q', qi', qj', qk') =
  (viaQ, viaQi, viaQj, viaQk)
  where
    -- @d * u * x * u@, associated left to right exactly as written.
    sandwich d u x = d * u * x * u
    viaQ = dfdq * q' - sandwich dfdqi i qi' - sandwich dfdqj j qj' - sandwich dfdqk k qk'
    viaQi = dfdq * qi' - sandwich dfdqi i q' - sandwich dfdqj j qk' - sandwich dfdqk k qj'
    viaQj = dfdq * qj' - sandwich dfdqi i qk' - sandwich dfdqj j q' - sandwich dfdqk k qi'
    viaQk = dfdq * qk' - sandwich dfdqi i qj' - sandwich dfdqj j qi' - sandwich dfdqk k q'
-- | Arithmetic on 'Q'. The first component follows ordinary quaternion
-- arithmetic; the four derivative slots are propagated component-wise for
-- sums and negation, via 'ghr' for products, and via 'ghrChainRule' for
-- 'abs'.
instance Num Q where
  Q x dx dxi dxj dxk + Q y dy dyi dyj dyk =
    Q (x + y) (dx + dy) (dxi + dyi) (dxj + dyj) (dxk + dyk)
  Q x dx dxi dxj dxk * Q y dy dyi dyj dyk =
    Q (x * y) (rule y dy) (rule (y * i) dyi) (rule (y * j) dyj) (rule (y * k) dyk)
    where
      -- Product rule: left factor times the right factor's derivative, plus
      -- the left factor's derivative in direction @mu@ times the right factor.
      rule mu dRight = x * dRight + ghr mu dx dxi dxj dxk * y
  negate (Q x dx dxi dxj dxk) =
    Q (negate x) (negate dx) (negate dxi) (negate dxj) (negate dxk)
  abs (Q x dx dxi dxj dxk) = Q (abs x) r ri rj rk
    where
      -- conj x / (4 |x|); divides by zero when x == 0.
      dfdq = conjugate x / (4 * abs x)
      (r, ri, rj, rk) =
        ghrChainRule
          (dfdq, -i * dfdq * i, -j * dfdq * j, -k * dfdq * k)
          (dx, dxi, dxj, dxk)
  -- 'signum' and 'fromInteger' carry zero in all derivative slots.
  signum (Q x _ _ _ _) = Q (signum x) 0 0 0 0
  fromInteger n = Q (fromInteger n) 0 0 0 0
-- | Reciprocal on 'Q': the value is inverted and the derivative slots are
-- obtained by feeding the reciprocal's four directional derivatives into
-- 'ghrChainRule'. Rational literals carry zero derivative slots.
instance Fractional Q where
  recip (Q x dx dxi dxj dxk) = Q inv r ri rj rk
    where
      inv = recip x
      -- Derivative of the reciprocal in direction @mu@: the real part of
      -- @inv * mu@ scales @recip mu@, and the result is multiplied by @-inv@.
      partial mu = -inv * (view _e (inv * mu) *| recip mu)
      (r, ri, rj, rk) =
        ghrChainRule
          (partial e, partial i, partial j, partial k)
          (dx, dxi, dxj, dxk)
  fromRational ratio = Q (fromRational ratio) 0 0 0 0
-- | Applies the real sigmoid to each of the four quaternion components.
instance Activatable H where
  sigmoid = fmap sigmoid
-- | Sigmoid on 'Q'. The value is the component-wise sigmoid of the first
-- component; the four derivative slots are propagated through
-- 'ghrChainRule' using the GHR derivatives of that component-wise function.
instance Activatable Q where
  sigmoid (Q q@(Quaternion x (V3 y z w)) q' qi' qj' qk') = Q (sigmoid q) p' pi' pj' pk'
    where
      -- Derivative of the real logistic function.
      sigmoid' t = sigmoid t * (1 - sigmoid t)
      sa = sigmoid' x
      sb = sigmoid' y
      sc = sigmoid' z
      sd = sigmoid' w
      -- For f(a + bi + cj + dk) = s(a) + s(b) i + s(c) j + s(d) k the
      -- partials are df/da = sa, df/db = sb * i, df/dc = sc * j and
      -- df/dd = sd * k. Substituting into
      --   df/dq^mu = (df/da - df/db * i^mu - df/dc * j^mu - df/dd * k^mu) / 4
      -- gives four real scalars. The relation dfdq^mu = -mu * dfdq * mu only
      -- holds for real-valued f, so it must not be used for this function.
      dfdq = ((sa + sb + sc + sd) / 4) *| e
      dfdqi = ((sa + sb - sc - sd) / 4) *| e
      dfdqj = ((sa - sb + sc - sd) / 4) *| e
      dfdqk = ((sa - sb - sc + sd) / 4) *| e
      (p', pi', pj', pk') = ghrChainRule (dfdq, dfdqi, dfdqj, dfdqk) (q', qi', qj', qk')
-- | Experiment 2: fit a two-layer quaternion network (1 input, 2 hidden
-- units with sigmoid, 1 linear output) to sixteen input/target pairs.
--
-- Inputs are all quaternions with components in {-1, 1}; the target for
-- @(x, y, z, w)@ is @(-y, x, -w, z)@. The seven parameters start at random
-- values with components in [-1, 1]. The samples are cycled for 10000
-- update steps, and the trained model's output is printed for every input.
experiment2 :: IO ()
experiment2 = do
  -- Initialise the parameters with random numbers.
  let randomQuaternion =
        (\(x, y, z, w) -> Quaternion x (V3 y z w))
          <$> randomRIO ((-1, -1, -1, -1), (1, 1, 1, 1)) :: IO H
  w1 <- randomQuaternion
  w2 <- randomQuaternion
  w3 <- randomQuaternion
  w4 <- randomQuaternion
  b1 <- randomQuaternion
  b2 <- randomQuaternion
  b3 <- randomQuaternion
  let
    -- Training data: every sign combination, paired with its target.
    combinations = do
      x <- [-1, 1]
      y <- [-1, 1]
      z <- [-1, 1]
      w <- [-1, 1]
      pure $ Quaternion x (V3 y z w)
    trainingData =
      take 10000 $
        cycle $
          map
            (\q@(Quaternion x (V3 y z w)) -> (q, Quaternion (-y) (V3 x (-w) z)))
            combinations
    -- The model: affine map, element-wise sigmoid, then a second affine map.
    model m1 m2 v1 v2 x = linear m2 v2 $ fmap sigmoid $ linear m1 v1 x
    -- Learning rate.
    learningRate = 0.1 *| e
    -- One parameter-update step.
    updateParam (p1, p2, p3, p4, c1, c2, c3) (x, y) =
      let -- Wrap the parameters as 'Q' values, seeding the derivative slot.
          mw1 dw1 dw2 = toM @2 @1 [[Q p1 dw1 0 0 0], [Q p2 dw2 0 0 0]]
          vb1 db1 db2 = toV @2 [Q c1 db1 0 0 0, Q c2 db2 0 0 0]
          mw2 dw3 dw4 = toM @1 @2 [[Q p3 dw3 0 0 0, Q p4 dw4 0 0 0]]
          vb2 db3 = toV @1 [Q c3 db3 0 0 0]
          vx = toV @1 [Q x 0 0 0 0]
          ty = Q y 0 0 0 0
          -- Squared-error loss for the given derivative seeds.
          loss (dw1, dw2, dw3, dw4, db1, db2, db3) =
            abs (ty - view _1 (model (mw1 dw1 dw2) (mw2 dw3 dw4) (vb1 db1 db2) (vb2 db3) vx)) ^ 2
          dfdw1 = gradQ $ loss (1, 0, 0, 0, 0, 0, 0)
          dfdw2 = gradQ $ loss (0, 1, 0, 0, 0, 0, 0)
          dfdw3 = gradQ $ loss (0, 0, 1, 0, 0, 0, 0)
          dfdw4 = gradQ $ loss (0, 0, 0, 1, 0, 0, 0)
          dfdb1 = gradQ $ loss (0, 0, 0, 0, 1, 0, 0)
          dfdb2 = gradQ $ loss (0, 0, 0, 0, 0, 1, 0)
          dfdb3 = gradQ $ loss (0, 0, 0, 0, 0, 0, 1)
          -- Updated parameters.
          p1n = p1 - learningRate * conjugate dfdw1
          p2n = p2 - learningRate * conjugate dfdw2
          p3n = p3 - learningRate * conjugate dfdw3
          p4n = p4 - learningRate * conjugate dfdw4
          c1n = c1 - learningRate * conjugate dfdb1
          c2n = c2 - learningRate * conjugate dfdb2
          c3n = c3 - learningRate * conjugate dfdb3
       in -- 'foldl'' only forces the tuple constructor, so force the
          -- components here to avoid building a chain of thunks.
          p1n `seq` p2n `seq` p3n `seq` p4n `seq` c1n `seq` c2n `seq` c3n `seq`
            (p1n, p2n, p3n, p4n, c1n, c2n, c3n)
    -- Run the updates over the training data.
    (w1', w2', w3', w4', b1', b2', b3') =
      foldl' updateParam (w1, w2, w3, w4, b1, b2, b3) trainingData
    mw1 = toM @2 @1 [[w1'], [w2']]
    mw2 = toM @1 @2 [[w3', w4']]
    vb1 = toV @2 [b1', b2']
    vb2 = toV @1 [b3']
  -- Print the learned outputs.
  forM_ combinations $ \q@(Quaternion x (V3 y z w)) ->
    putStrLn $
      concat
        [ show x, " "
        , show y, " "
        , show z, " "
        , show w, " : "
        , show . view _1 $ model mw1 mw2 vb1 vb2 (toV @1 [q])
        ]
-- | Run the complex-valued experiment, then the quaternion-valued one.
main :: IO ()
main = experiment1 >> experiment2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment