Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
type-level broadcast and dimshuffle ops
-- Language extensions required by this module, one per pragma and sorted
-- alphabetically.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TH (DimShuffle, EvalBroadcast) where
import qualified Prelude as P
import Data.Singletons.TH (promote)
import Data.Singletons.Prelude.List
import Data.Singletons.TypeLits
import Data.Singletons.Prelude.Base
import Data.Singletons.Prelude.Bool
import Data.Singletons.Prelude.Eq
import Data.Singletons.Prelude.Maybe
import Data.Singletons.Prelude.Tuple
import Data.Singletons.Prelude.Num
import Data.Singletons.Prelude.Ord
import Data.Maybe
-- | Placeholder equality for 'Nat'.
--
-- Each method is written as the negation of the other, so the two
-- definitions are mutually recursive and neither reaches a base case when
-- evaluated on values.
--
-- NOTE(review): presumably this instance exists only so that the value-level
-- definitions in this module type-check — confirm before relying on it.
instance P.Eq Nat where
  (==) m n = P.not (m P./= n)
  (/=) m n = P.not (m P.== n)
-- | Placeholder ordering for 'Nat'.
--
-- '>' is defined through '<' and '<' through '>', while '<=' and '>=' are
-- defined through '==' together with '>' and '<' respectively; none of the
-- methods bottoms out in a primitive comparison.
--
-- NOTE(review): '>' and '<' both require @not (m /= n)@, i.e. that the two
-- operands are equal, which looks inverted for a strict comparison — confirm
-- whether this is intentional for a stub instance.
instance P.Ord Nat where
  (>) m n = P.not (m P./= n) P.&& P.not (m P.< n)
  (<) m n = P.not (m P./= n) P.&& P.not (m P.> n)
  (<=) m n = (m P.== n) P.|| P.not (m P.> n)
  (>=) m n = (m P.== n) P.|| P.not (m P.< n)
-- | Element at a 0-based position of a list.
--
-- Walks the list one cell at a time, decrementing the position until it
-- reaches zero. Calls 'P.error' if the list runs out before that happens.
(!!) :: [a] -> Nat -> a
(!!) items pos
  = case items of
      [] -> P.error "Data.Singletons.List.!!: index too large"
      (y : ys) -> if pos P.== 0 then y else ys !! (pos P.- 1)
-- | Combine two shapes under right-aligned broadcasting.
--
-- The shorter shape is padded on the left with @Just 1@ until both shapes
-- have the same length; the shapes are then combined position by position.
-- Each output entry is @Just (max m n)@ when both inputs are known, and
-- @Nothing@ as soon as either input is @Nothing@.
--
-- NOTE(review): no check is made that two differing dimensions are actually
-- compatible (one of them being 1); the larger one simply wins.
-- NOTE(review): the module exports @EvalBroadcast@, presumably the promoted
-- form of this function — confirm that this definition sits inside the
-- @promote@ splice.
evalBroadcast :: [Maybe Nat] -> [Maybe Nat] -> [Maybe Nat]
evalBroadcast a b
  = let la = P.length a
        lb = P.length b
    in if la P.== lb
         then P.zipWith maxIfSome a b
         else if la P.> lb
           -- b is shorter: left-pad it with unit dimensions up to a's length.
           then P.zipWith maxIfSome
                  (P.concat [P.replicate (la P.- lb) (Just 1 :: Maybe Nat), b])
                  a
           -- a is shorter: left-pad it with unit dimensions up to b's length.
           else P.zipWith maxIfSome
                  (P.concat [P.replicate (lb P.- la) (Just 1 :: Maybe Nat), a])
                  b
  where
    -- Largest of two known dimensions; unknown if either side is unknown.
    maxIfSome x y
      = case (x, y) of
          (Just m, Just n) -> Just P.$ P.max m n
          _ -> Nothing
-- | Reorder the entries of a list according to a list of source indices.
--
-- The n-th element of the result is the element of the first list found at
-- the 0-based position given by the n-th index, so
-- @dimShuffle [a, b, c] [2, 0, 1]@ is @[c, a, b]@ and the identity
-- permutation returns the list unchanged.
--
-- An empty source list or an empty index list yields @[]@. An index past the
-- end of a non-empty source list fails inside '!!'.
--
-- NOTE(review): the module exports @DimShuffle@, presumably the promoted form
-- of this function — confirm that this definition sits inside the @promote@
-- splice.
dimShuffle :: P.Eq a => [a] -> [Nat] -> [a]
dimShuffle _ [] = []
dimShuffle [] _ = []
dimShuffle (x : xs) (b : bs)
  -- Recurse on the whole source list: dropping its head here would shift the
  -- meaning of every remaining index.
  = ((x : xs) !! b) : dimShuffle (x : xs) bs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment