Skip to content

Instantly share code, notes, and snippets.

@lotz84
Created December 3, 2023 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lotz84/4b8072baf5ecc8491877db210980a82d to your computer and use it in GitHub Desktop.
Save lotz84/4b8072baf5ecc8491877db210980a82d to your computer and use it in GitHub Desktop.
build-depends:
base ^>=4.16.4.0,
lens == 5.2.3,
linear == 1.22,
random == 1.2.1.1,
vector == 0.13.1.0
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeApplications #-}
module Main where
import Data.Complex (Complex(..))
import Data.Foldable (forM_, foldlM)
import Data.List (cycle, foldl')
import Data.Maybe (fromJust)
import GHC.TypeNats (KnownNat, Nat)
import Control.Lens (view, _1)
import qualified Data.Vector as V
import Linear.Conjugate (conjugate)
import Linear.Matrix ((!*))
import Linear.Quaternion (Quaternion(..), _e)
import Linear.V (V(..), fromVector)
import Linear.V3 (V3(..))
import System.Random (randomRIO, initStdGen)
-- | Real numbers.
type R = Double
-- | Complex numbers whose real and imaginary parts are 'R'.
type C = Complex R
-- | "Triple complex" number: a complex value together with two derivative
-- slots, read by 'gradZ' (derivative with respect to z) and 'gradZC'
-- (derivative with respect to conj z).
data T = T C C C deriving Show

-- | The complex value itself.
complex :: T -> C
complex (T value _ _) = value

-- | Derivative with respect to z.
gradZ :: T -> C
gradZ (T _ dz _) = dz

-- | Derivative with respect to conj z.
gradZC :: T -> C
gradZC (T _ _ dzc) = dzc
-- | Arithmetic on 'T': every operation computes the complex value together
-- with its derivatives with respect to z ('gradZ') and conj z ('gradZC').
--
-- For a non-holomorphic primitive f applied to w = h(z), the chain rule used
-- by 'abs' and 'signum' below is
--
-- > d(f . h)/dz      = f_w * h_z      + f_conj(w) * conj (h_conj(z))
-- > d(f . h)/dconj z = f_w * h_conj(z) + f_conj(w) * conj (h_z)
instance Num T where
  (T z z' zc') + (T w w' wc') = T (z + w) (z' + w') (zc' + wc')
  (T z z' zc') * (T w w' wc') = T (z * w) (z' * w + z * w') (zc' * w + z * wc')
  negate (T z z' zc') = T (negate z) (negate z') (negate zc')
  -- f = abs w: f_w = conj w / (2 abs w), f_conj(w) = w / (2 abs w).
  -- Not differentiable at w = 0, where these divisions are 0/0.
  abs (T z z' zc') = T (abs z) w' wc'
    where
      w' = conjugate z / (2 * abs z) * z' + z / (2 * abs z) * conjugate zc'
      wc' = conjugate z / (2 * abs z) * zc' + z / (2 * abs z) * conjugate z'
  -- f = w / abs w: f_w = 1 / (2 r), f_conj(w) = -w^2 / (2 r^3) with r = abs w.
  -- signum is not differentiable at w = 0; zero derivatives are returned there.
  signum (T z z' zc')
    | z == 0 = T (signum z) 0 0
    | otherwise = T (signum z) w' wc'
    where
      r = abs z
      dfdz = recip (2 * r)
      dfdzc = negate (z * z) / (2 * r * r * r)
      w' = dfdz * z' + dfdzc * conjugate zc'
      wc' = dfdz * zc' + dfdzc * conjugate z'
  fromInteger n = T (fromInteger n) 0 0
-- | Division on 'T'.  1/z is holomorphic with derivative -1/z^2, so both
-- derivative slots are scaled by that same factor.
instance Fractional T where
  recip (T z z' zc') = T (recip z) (scale z') (scale zc')
    where
      scale dz = -1 * recip (z * z) * dz
  fromRational r = T (fromRational r) 0 0
-- | Lift a holomorphic function on 'C' to 'T'.  @tangent z dz@ pushes one
-- derivative slot @dz@ (the d/dz or the d/dconj z component) through the
-- function at the point @z@; for a holomorphic function both slots are
-- scaled by its ordinary complex derivative.
liftT :: (C -> C) -> (C -> C -> C) -> T -> T
liftT f tangent (T z z' zc') = T (f z) (tangent z z') (tangent z zc')

-- | Elementary functions on 'T'.  They are holomorphic (away from branch
-- cuts), so each derivative slot is scaled by the complex derivative.
instance Floating T where
  pi = T pi 0 0
  exp = liftT exp (\z dz -> exp z * dz)
  log = liftT log (\z dz -> dz / z)
  sin = liftT sin (\z dz -> cos z * dz)
  cos = liftT cos (\z dz -> -sin z * dz)
  asin = liftT asin (\z dz -> dz / sqrt (1 - z ** 2))
  acos = liftT acos (\z dz -> -dz / sqrt (1 - z ** 2))
  atan = liftT atan (\z dz -> dz / (1 + z ** 2))
  sinh = liftT sinh (\z dz -> cosh z * dz)
  cosh = liftT cosh (\z dz -> sinh z * dz)
  asinh = liftT asinh (\z dz -> dz / sqrt (1 + z ** 2))
  acosh = liftT acosh (\z dz -> dz / sqrt (z ** 2 - 1))
  atanh = liftT atanh (\z dz -> dz / (1 - z ** 2))
-- | An m-by-n matrix, stored as m rows that are each a vector of length n.
type M (m :: Nat) (n :: Nat) a = V m (V n a)

-- | Affine layer: multiply the input vector by the weight matrix, then add
-- the bias vector.
linear :: (KnownNat m, KnownNat n, Num a) => M m n a -> V m a -> V n a -> V m a
linear w b = (+ b) . (w !*)
-- | Types that admit a (logistic) sigmoid activation.
class Activatable a where
  sigmoid :: a -> a

-- | Logistic function 1 / (1 + e^(-x)).
instance Activatable R where
  sigmoid x = recip (1 + exp (negate x))

-- | Split activation: the sigmoid is applied to the real and imaginary
-- parts separately.
instance Activatable C where
  sigmoid = fmap sigmoid
-- | Split sigmoid f(x + iy) = s(x) + i s(y) with Wirtinger derivatives.
--
-- The partial derivatives are df/dx = s'(x) and df/dy = i s'(y), so
--
-- > df/dz      = (df/dx - i df/dy) / 2 = (s'(x) + s'(y)) / 2
-- > df/dconj z = (df/dx + i df/dy) / 2 = (s'(x) - s'(y)) / 2
--
-- Both are real.  They are pushed through the inner function with the
-- non-holomorphic chain rule
-- d(f . h)/dz = f_w h_z + f_conj(w) conj (h_conj(z)), and symmetrically for
-- d/dconj z.
instance Activatable T where
  sigmoid (T z@(x :+ y) z' zc') = T (sigmoid z) w' wc'
    where
      sigmoid' t = sigmoid t * (1 - sigmoid t)
      dfdz = ((sigmoid' x + sigmoid' y) / 2) :+ 0
      dfdzc = ((sigmoid' x - sigmoid' y) / 2) :+ 0
      w' = dfdz * z' + dfdzc * conjugate zc'
      wc' = dfdz * zc' + dfdzc * conjugate z'
-- | Convert a list into a length-indexed vector of length @n@.
--
-- When 'fromVector' rejects the list (its length does not match @n@) this
-- fails with a descriptive 'error' instead of an opaque @fromJust@ failure.
toV :: KnownNat n => [a] -> V n a
toV xs =
  maybe
    (error "toV: list length does not match the vector dimension")
    id
    (fromVector (V.fromList xs))
-- | Convert a list of rows into an m-by-n matrix, converting each row with
-- 'toV' and then the list of rows itself with 'toV'.
toM :: (KnownNat m, KnownNat n) => [[a]] -> M m n a
toM rows = toV (fmap toV rows)
-- | Experiment 1: fit a single complex neuron @sigmoid (w * x + b)@ to four
-- input/target pairs by gradient descent, using 'T' to obtain the
-- derivatives of the squared error, then print the trained model's output
-- for each input.
experiment1 :: IO ()
experiment1 = do
  -- Initialise the parameters randomly in the range ((-1, -1), (1, 1)).
  w <- uncurry (:+) <$> randomRIO ((-1, -1), (1, 1)) :: IO C
  b <- uncurry (:+) <$> randomRIO ((-1, -1), (1, 1)) :: IO C
  let -- Training data: four input/target pairs, cycled to 1000 samples.
      trainingData =
        take 1000 $
          cycle
            [ ((-1) :+ (-1), 1 :+ 0)
            , ((-1) :+ 1, 0 :+ 0)
            , (1 :+ (-1), 1 :+ 1)
            , (1 :+ 1, 0 :+ 1)
            ]
      -- Model: a 1x1 affine layer followed by the sigmoid activation.
      model ws bs xs = fmap sigmoid $ linear ws bs xs
      -- Learning rate.
      learningRate = 0.1 :+ 0
      -- One gradient-descent step on a single (input, target) sample.
      -- The new parameters are forced before being returned: foldl' only
      -- evaluates the outer tuple, so without this every step would stay an
      -- unevaluated thunk until the very end.
      updateParam (pw, pb) (x, y) =
        let -- Lift the parameters into 'T'; the seed (dw, db) selects the
            -- parameter the derivative is taken with respect to.
            lw dw = toM @1 @1 [[T pw dw 0]]
            lb db = toV @1 [T pb db 0]
            lx = toV @1 [T x 0 0]
            ty = T y 0 0
            -- d/dparam of the squared error |y - model x|^2.
            df (dw, db) = gradZ $ abs (ty - view _1 (model (lw dw) (lb db) lx)) ^ 2
            dfdw = df (1, 0)
            dfdb = df (0, 1)
            -- For a real-valued loss, conj (dL/dp) = dL/dconj p is the
            -- steepest-ascent direction, so step against it.
            pw' = pw - learningRate * conjugate dfdw
            pb' = pb - learningRate * conjugate dfdb
         in pw' `seq` pb' `seq` (pw', pb')
      -- Train on the whole data set.
      (w', b') = foldl' updateParam (w, b) trainingData
      mw = toM @1 @1 [[w']]
      vb = toV @1 [b']
  -- Print the trained model's output for each input.
  forM_
    [ ("(-1) :+ (-1) | ", (-1) :+ (-1))
    , ("(-1) :+ 1 | ", (-1) :+ 1)
    , (" 1 :+ (-1) | ", 1 :+ (-1))
    , (" 1 :+ 1 | ", 1 :+ 1)
    ]
    $ \(label, input) ->
      putStrLn $ concat [label, show $ model mw vb (toV @1 [input])]
-- | Quaternions with 'Double' components.
type H = Quaternion Double
-- | The real unit 1 ('e') and the three imaginary units i, j, k.
e, i, j, k :: H
(e, i, j, k) =
  ( Quaternion 1 (V3 0 0 0)
  , imaginary 1 0 0
  , imaginary 0 1 0
  , imaginary 0 0 1
  )
  where
    imaginary x y z = Quaternion 0 (V3 x y z)
-- | Multiply every component of a quaternion by a real scalar.
(*|) :: Double -> H -> H
(*|) a = fmap (* a)
-- | "Five-fold quaternion": a quaternion value together with its derivatives
-- with respect to q, q^i, q^j and q^k (read by 'gradQ', 'gradQi', 'gradQj'
-- and 'gradQk').
data Q = Q H H H H H deriving Show

-- | The quaternion value itself.
quaternion :: Q -> H
quaternion (Q value _ _ _ _) = value

-- | Derivative with respect to q.
gradQ :: Q -> H
gradQ (Q _ d _ _ _) = d

-- | Derivative with respect to q^i.
gradQi :: Q -> H
gradQi (Q _ _ di _ _) = di

-- | Derivative with respect to q^j.
gradQj :: Q -> H
gradQj (Q _ _ _ dj _) = dj

-- | Derivative with respect to q^k.
gradQk :: Q -> H
gradQk (Q _ _ _ _ dk) = dk
-- | Derivative with respect to q^mu for an arbitrary non-zero quaternion mu,
-- reconstructed from the derivatives with respect to q, q^i, q^j and q^k.
--
-- Writing mu = mu_a + mu_b i + mu_c j + mu_d k, the GHR definition
-- df/dq^mu = 1/4 (f_a - f_b i^mu - f_c j^mu - f_d k^mu) rearranges to
--
-- > df/dq^mu = (df/dq * mu_a + df/dq^i * mu_b i + df/dq^j * mu_c j + df/dq^k * mu_d k) * mu^-1
--
-- Each coefficient carries only its own component of mu, so for
-- mu = 1, i, j, k this returns exactly q', qi', qj', qk' respectively.
ghr :: H -> H -> H -> H -> H -> H
ghr mu@(Quaternion mu_a (V3 mu_b mu_c mu_d)) q' qi' qj' qk' =
  (q' * a + qi' * b + qj' * c + qk' * d) * recip mu
  where
    a = mu_a *| e
    b = mu_b *| i
    c = mu_c *| j
    d = mu_d *| k
-- | Chain rule of the GHR calculus.
--
-- The first tuple holds the derivatives of an outer function f with respect
-- to (g, g^i, g^j, g^k); the second holds those of the inner function g with
-- respect to (q, q^i, q^j, q^k).  The result holds the derivatives of the
-- composition with respect to (q, q^i, q^j, q^k).  For the q^nu component,
-- the term coming through f's g^u slot (u in {i, j, k}) uses g's derivative
-- with respect to q^(u nu), conjugated by u.
ghrChainRule :: (H, H, H, H) -> (H, H, H, H) -> (H, H, H, H)
ghrChainRule (dfdq, dfdqi, dfdqj, dfdqk) (q', qi', qj', qk') =
  ( combine q' qi' qj' qk'
  , combine qi' q' qk' qj'
  , combine qj' qk' q' qi'
  , combine qk' qj' qi' q'
  )
  where
    -- d * u * x * u, associated left-to-right.
    term d u x = d * u * x * u
    combine x0 xi xj xk =
      dfdq * x0 - term dfdqi i xi - term dfdqj j xj - term dfdqk k xk
-- | Arithmetic on 'Q': every operation computes the quaternion value together
-- with its GHR derivatives with respect to q, q^i, q^j and q^k.
instance Num Q where
  (Q q q' qi' qj' qk') + (Q p p' pi' pj' pk') =
    Q (q + p) (q' + p') (qi' + pi') (qj' + pj') (qk' + pk')
  -- GHR product rule: d(fg)/dq^mu = f * dg/dq^mu + (df/dq^(g mu)) * g, where
  -- 'ghr' rebuilds df/dq^(g mu) from f's four basis derivatives.
  (Q q q' qi' qj' qk') * (Q p p' pi' pj' pk') = Q (q * p) r' ri' rj' rk'
    where
      r' = q * p' + ghr p q' qi' qj' qk' * p
      ri' = q * pi' + ghr (p * i) q' qi' qj' qk' * p
      rj' = q * pj' + ghr (p * j) q' qi' qj' qk' * p
      rk' = q * pk' + ghr (p * k) q' qi' qj' qk' * p
  negate (Q q q' qi' qj' qk') =
    Q (negate q) (negate q') (negate qi') (negate qj') (negate qk')
  -- The norm is real-valued: its derivative with respect to q is
  -- conj q / (4 abs q), and with respect to q^u it is the involution -u x u
  -- of that value.  Not differentiable at q = 0.
  abs (Q q q' qi' qj' qk') = Q (abs q) p' pi' pj' pk'
    where
      dfdq = conjugate q / (4 * abs q)
      dfdqi = -i * dfdq * i
      dfdqj = -j * dfdq * j
      dfdqk = -k * dfdq * k
      (p', pi', pj', pk') = ghrChainRule (dfdq, dfdqi, dfdqj, dfdqk) (q', qi', qj', qk')
  -- signum q = q / r with r = abs q.  Its GHR derivatives are
  -- df/dq = 3 / (4 r) and df/dq^u = -q (conj q)^u / (4 r^3) for u in
  -- {i, j, k}, where x^u = -u x u.  signum is not differentiable at q = 0;
  -- zero derivatives are returned there.
  signum (Q q q' qi' qj' qk')
    | q == 0 = Q (signum q) 0 0 0 0
    | otherwise = Q (signum q) p' pi' pj' pk'
    where
      r = abs q
      involute u x = negate (u * x * u)
      dfdqU u = negate (q * involute u (conjugate q)) / (4 * r * r * r)
      dfdq = 3 / (4 * r)
      (p', pi', pj', pk') =
        ghrChainRule (dfdq, dfdqU i, dfdqU j, dfdqU k) (q', qi', qj', qk')
  fromInteger n = Q (fromInteger n) 0 0 0 0
-- | Reciprocal on 'Q'.  The derivative of q^-1 with respect to q^mu is taken
-- as -(q^-1) * (Re(q^-1 * mu) * mu^-1), evaluated for mu = 1, i, j, k and
-- combined with the inner derivatives through 'ghrChainRule'.
instance Fractional Q where
  recip (Q q q' qi' qj' qk') = Q (recip q) p' pi' pj' pk'
    where
      partial mu = -(recip q) * (view _e (recip q * mu) *| recip mu)
      (p', pi', pj', pk') =
        ghrChainRule (partial e, partial i, partial j, partial k) (q', qi', qj', qk')
  fromRational r = Q (fromRational r) 0 0 0 0
-- | Split quaternion activation: the sigmoid is applied to each of the four
-- components separately.
instance Activatable H where
  sigmoid (Quaternion s v) = Quaternion (sigmoid s) (fmap sigmoid v)
-- | Split quaternion sigmoid f(q) = s(a) + s(b) i + s(c) j + s(d) k with GHR
-- derivatives.
--
-- The partials are df/da = s'(a), df/db = s'(b) i, df/dc = s'(c) j and
-- df/dd = s'(d) k.  Substituting them into
-- df/dq^mu = 1/4 (df/da - df/db i^mu - df/dc j^mu - df/dd k^mu) makes all
-- four GHR derivatives real:
--
-- > df/dq   = (s'(a) + s'(b) + s'(c) + s'(d)) / 4
-- > df/dq^i = (s'(a) + s'(b) - s'(c) - s'(d)) / 4
-- > df/dq^j = (s'(a) - s'(b) + s'(c) - s'(d)) / 4
-- > df/dq^k = (s'(a) - s'(b) - s'(c) + s'(d)) / 4
instance Activatable Q where
  sigmoid (Q q@(Quaternion a (V3 b c d)) q' qi' qj' qk') = Q (sigmoid q) p' pi' pj' pk'
    where
      sigmoid' t = sigmoid t * (1 - sigmoid t)
      (sa, sb, sc, sd) = (sigmoid' a, sigmoid' b, sigmoid' c, sigmoid' d)
      dfdq = ((sa + sb + sc + sd) / 4) *| e
      dfdqi = ((sa + sb - sc - sd) / 4) *| e
      dfdqj = ((sa - sb + sc - sd) / 4) *| e
      dfdqk = ((sa - sb - sc + sd) / 4) *| e
      (p', pi', pj', pk') = ghrChainRule (dfdq, dfdqi, dfdqj, dfdqk) (q', qi', qj', qk')
-- | Experiment 2: train a 1-2-1 quaternion network (affine layer, split
-- sigmoid, affine layer) by gradient descent, using 'Q' to obtain GHR
-- derivatives of the squared error, then print the trained model's output
-- for every quaternion with components in {-1, 1}.
experiment2 :: IO ()
experiment2 = do
  -- Initialise the seven parameters with random quaternions drawn from the
  -- range ((-1, -1, -1, -1), (1, 1, 1, 1)).
  let randomQuaternion =
        (\(x, y, z, w) -> Quaternion x (V3 y z w))
          <$> randomRIO ((-1, -1, -1, -1), (1, 1, 1, 1)) :: IO H
  w1 <- randomQuaternion
  w2 <- randomQuaternion
  w3 <- randomQuaternion
  w4 <- randomQuaternion
  b1 <- randomQuaternion
  b2 <- randomQuaternion
  b3 <- randomQuaternion
  let -- Every quaternion whose four components lie in {-1, 1}.
      combinations = do
        x <- [-1, 1]
        y <- [-1, 1]
        z <- [-1, 1]
        w <- [-1, 1]
        pure $ Quaternion x (V3 y z w)
      -- Target for input q: Quaternion (-y) (V3 x (-w) z), which equals the
      -- Hamilton product i * q.
      target (Quaternion x (V3 y z w)) = Quaternion (-y) (V3 x (-w) z)
      -- Training data: (input, target) pairs cycled to 10000 samples.
      trainingData = take 10000 $ cycle $ map (\q -> (q, target q)) combinations
      -- Model: 1 -> 2 affine layer, element-wise sigmoid, 2 -> 1 affine layer.
      model wIn wOut bIn bOut xs = linear wOut bOut $ fmap sigmoid $ linear wIn bIn xs
      -- Learning rate.
      learningRate = 0.1 *| e
      -- One gradient-descent step on a single (input, target) sample.  The
      -- new parameters are forced before being returned: foldl' only
      -- evaluates the outer tuple, so otherwise all 10000 steps would pile
      -- up as unevaluated thunks.
      updateParam (pw1, pw2, pw3, pw4, pb1, pb2, pb3) (x, y) =
        let -- Lift the parameters into 'Q'; the seed tuple selects the
            -- parameter the derivative is taken with respect to.
            lw1 dw1 dw2 = toM @2 @1 [[Q pw1 dw1 0 0 0], [Q pw2 dw2 0 0 0]]
            lb1 db1 db2 = toV @2 [Q pb1 db1 0 0 0, Q pb2 db2 0 0 0]
            lw2 dw3 dw4 = toM @1 @2 [[Q pw3 dw3 0 0 0, Q pw4 dw4 0 0 0]]
            lb2 db3 = toV @1 [Q pb3 db3 0 0 0]
            lx = toV @1 [Q x 0 0 0 0]
            ty = Q y 0 0 0 0
            -- Squared error |y - model x|^2, carried as a 'Q'.
            loss (dw1, dw2, dw3, dw4, db1, db2, db3) =
              abs (ty - view _1 (model (lw1 dw1 dw2) (lw2 dw3 dw4) (lb1 db1 db2) (lb2 db3) lx)) ^ 2
            -- Move a parameter against the conjugate of dL/dparam, which
            -- for a real-valued loss is dL/dconj(param).
            step p seed = p - learningRate * conjugate (gradQ (loss seed))
            pw1' = step pw1 (1, 0, 0, 0, 0, 0, 0)
            pw2' = step pw2 (0, 1, 0, 0, 0, 0, 0)
            pw3' = step pw3 (0, 0, 1, 0, 0, 0, 0)
            pw4' = step pw4 (0, 0, 0, 1, 0, 0, 0)
            pb1' = step pb1 (0, 0, 0, 0, 1, 0, 0)
            pb2' = step pb2 (0, 0, 0, 0, 0, 1, 0)
            pb3' = step pb3 (0, 0, 0, 0, 0, 0, 1)
         in pw1' `seq` pw2' `seq` pw3' `seq` pw4' `seq` pb1' `seq` pb2' `seq` pb3'
              `seq` (pw1', pw2', pw3', pw4', pb1', pb2', pb3')
      -- Train on the whole data set.
      (w1', w2', w3', w4', b1', b2', b3') =
        foldl' updateParam (w1, w2, w3, w4, b1, b2, b3) trainingData
      mw1 = toM @2 @1 [[w1'], [w2']]
      mw2 = toM @1 @2 [[w3', w4']]
      vb1 = toV @2 [b1', b2']
      vb2 = toV @1 [b3']
  -- Print each input's four components and the trained model's output.
  forM_ combinations $ \q@(Quaternion x (V3 y z w)) ->
    putStrLn $
      concat
        [ show x, " "
        , show y, " "
        , show z, " "
        , show w, " : "
        , show . view _1 $ model mw1 mw2 vb1 vb2 (toV @1 [q])
        ]
-- | Run the complex-valued experiment, then the quaternion-valued one.
main :: IO ()
main = experiment1 >> experiment2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment