Skip to content

Instantly share code, notes, and snippets.

@notogawa
Last active November 19, 2016 05:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save notogawa/d32595f2eb79a3cf6bcdd4bd97add3d3 to your computer and use it in GitHub Desktop.
Save notogawa/d32595f2eb79a3cf6bcdd4bd97add3d3 to your computer and use it in GitHub Desktop.
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}
import Data.Singletons
import Data.Singletons.Prelude.Enum
import Data.Singletons.Prelude.List
import Data.Singletons.Prelude.Num
import GHC.TypeLits
-- | An n-dimensional array whose shape is tracked at the type level as a
-- list of 'Nat's. Only the shape singleton is stored; the element type @a@
-- does not appear in the constructor's fields, so it is a phantom parameter.
data NDArray (shape :: [Nat]) a = NDArray (Sing shape) -- contents other than the shape are omitted
-- | Shape-level counterpart of @numpy.reshape@:
-- https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html
--
-- The context requires the product of the dimensions of the old shape to
-- equal the product of the dimensions of the new shape. The input array is
-- ignored at the value level; the result simply carries the new shape
-- singleton.
reshape :: Product from ~ Product to => Sing to -> NDArray from a -> NDArray to a
reshape target _ = NDArray target
-- | A reshape that is accepted: 2 * 3 * 4 and 3 * 8 are both 24.
reshapeable :: NDArray '[2,3,4] a -> NDArray '[3,8] a
reshapeable source = reshape (sing :: Sing '[3,8]) source
-- 型検査でreshape不可 (次元の積が合わない)
-- unreshapeable :: NDArray '[2,3,4] a -> NDArray '[3,3,4] a
-- unreshapeable = reshape sing -- Couldn't match type ‘24’ with ‘36’
-- | Shape-level counterpart of @numpy.dot@:
-- https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html
--
-- The context requires the last dimension of the left operand to equal the
-- second-to-last dimension of the right operand. The result shape is the
-- left shape without its last dimension, followed by the right shape
-- without its last two dimensions, followed by the right shape's last
-- dimension. Only the shape singletons are combined here.
dot ::
  (Last xs ~ Last (Init ys), Num a) =>
  NDArray xs a ->
  NDArray ys a ->
  NDArray (Init xs :++ Init (Init ys) :++ '[Last ys]) a
dot (NDArray lhs) (NDArray rhs) = NDArray (lhsLead %:++ rhsLead %:++ rhsTail)
  where
    -- Left shape with the contracted (last) dimension dropped.
    lhsLead = sInit lhs
    -- Right shape with its last two dimensions dropped.
    rhsLead = sInit (sInit rhs)
    -- Singleton list holding the right shape's last dimension.
    rhsTail = SCons (sLast rhs) SNil
-- | A dot product that is accepted: the last dimension of the left shape
-- (4) matches the second-to-last dimension of the right shape (4).
dottable :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,2] a -> NDArray '[2,3,3,2] a
dottable lhs rhs = lhs `dot` rhs
-- 型検査でdot不可 (結果の型が合わない)
-- undottable1 :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,4] a -> NDArray '[2,3,3,2] a
-- undottable1 = dot -- Couldn't match type ‘4’ with ‘2’
-- 型検査でdot不可 (引数の型が合わない)
-- undottable2 :: Num a => NDArray '[2,3,4] a -> NDArray '[4,3,2] a -> NDArray '[2,3,4,2] a
-- undottable2 = dot -- Couldn't match type ‘4’ with ‘3’
-- | Shape-level counterpart of @numpy.transpose@:
-- https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
--
-- The context requires @axes@, once sorted, to equal
-- @[0 .. Length shape - 1]@. The i-th dimension of the result is the
-- dimension of @shape@ found at index @axes !! i@.
transpose ::
  Sort axes ~ EnumFromTo 0 (Length shape - 1) =>
  Sing axes ->
  NDArray shape a ->
  NDArray (Map ((:!!$$) shape) axes) a
transpose axes (NDArray shape) = NDArray (sMap axisLookup axes)
  where
    -- Singleton for "index into shape", mapped over the requested axes.
    axisLookup = singFun1 (indexSymbol shape) (shape %:!!)

    -- Supplies the proxy argument that singFun1 takes, naming the
    -- partially applied index symbol for the given shape.
    indexSymbol :: Sing (dims :: [Nat]) -> Proxy (Apply (:!!$) dims)
    indexSymbol _ = Proxy
-- | A transpose that is accepted: axes [1,0,2] swap the first two
-- dimensions of a [2,3,4] array, giving [3,2,4].
transposable :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
transposable axes array = transpose axes array
-- transpose不可 (axesにshapeの長さ以上のものが含まれる)
-- untransposable1 :: Sing '[1,0,3] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
-- untransposable1 = transpose -- Couldn't match type ‘3’ with ‘2’
-- transpose不可 (axesに同じ要素が2つ以上含まれる)
-- untransposable2 :: Sing '[1,0,0] -> NDArray '[2,3,4] a -> NDArray '[3,2,2] a
-- untransposable2 = transpose -- Couldn't match type ‘1’ with ‘2’
-- transpose不可 (axesの長さとshapeの長さが一致しない)
-- untransposable3 :: Sing '[1,0] -> NDArray '[2,3,4] a -> NDArray '[3,2] a
-- untransposable3 = transpose -- Couldn't match type ‘'[]’ with ‘'[2]’
-- transpose不可 (結果の型が合わない)
-- untransposable4 :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,5] a
-- untransposable4 = transpose -- Couldn't match type ‘4’ with ‘5’
-- | Shape-level counterpart of @numpy.tensordot@:
-- https://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html
--
-- With 'reshape', 'dot' and 'transpose' available it should be possible to
-- define @tensordot@ in terms of them. Written naively (the way one would
-- like to be able to write it), however, this definition fails type
-- checking and produces roughly 1200 lines of errors. The trailing comments
-- on the bindings below record how many error lines each one accounts for.
-- NOTE(review): presumably GHC cannot derive the equalities demanded by
-- 'transpose', 'reshape' and 'dot' from this context alone -- confirm
-- before attempting a fix.
--
-- @ns@ and @ms@ are the axes of the left and right operand to contract.
-- The context requires both to be duplicate-free and to select the same
-- list of dimensions. The result shape is the uncontracted dimensions of
-- the left operand followed by the uncontracted dimensions of the right.
tensordot :: (Num a, ns ~ Nub ns, ms ~ Nub ms,
              Map ((:!!$$) xs) ns ~ Map ((:!!$$) ys) ms) =>
             NDArray xs a -> NDArray ys a -> (Sing ns, Sing ms) ->
             NDArray (Map ((:!!$$) xs) (EnumFromTo 0 (Length xs - 1) :\\ ns) :++
                      Map ((:!!$$) ys) (EnumFromTo 0 (Length ys - 1) :\\ ms)) a
tensordot x@(NDArray xs) y@(NDArray ys) (ns, ms) = result
  where
    -- All axis indices for a given rank: [0 .. n - 1].
    range n = sEnumFromTo (sing :: Sing 0) (n %:- (sing :: Sing 1))

    -- Axes of each operand that are NOT contracted.
    notinns = range (sLength xs) %:\\ ns
    notinms = range (sLength ys) %:\\ ms

    -- Move the contracted axes to the end of x and to the front of y.
    tx = transpose (notinns %:++ ns) x -- ~130 lines of errors
    ty = transpose (ms %:++ notinms) y -- ~130 lines of errors

    -- Look up the dimensions of a shape at the given axes.
    dimsIn shape = sMap (singFun1 (toProxy shape) (shape %:!!))
      where
        toProxy :: Sing (dims :: [Nat]) -> Proxy (Apply (:!!$) dims)
        toProxy _ = Proxy

    -- Dimensions that survive the contraction, per operand.
    oldxs = dimsIn xs notinns
    oldys = dimsIn ys notinms

    -- Flatten both operands to rank 2 so that a plain 'dot' applies.
    rtx = reshape (SCons (sProduct oldxs) $ SCons (sProduct $ dimsIn xs ns) SNil) tx -- ~280 lines of errors
    rty = reshape (SCons (sProduct $ dimsIn ys ms) $ SCons (sProduct oldys) SNil) ty -- ~280 lines of errors

    -- Restore the surviving dimensions on the rank-2 product.
    result = reshape (oldxs %:++ oldys) (rtx `dot` rty) -- ~400 lines of errors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment