Skip to content

Instantly share code, notes, and snippets.

@tonyday567
Created May 24, 2018 04:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tonyday567/0a72d9769e7fc5450784b6539dd95fa3 to your computer and use it in GitHub Desktop.
Save tonyday567/0a72d9769e7fc5450784b6539dd95fa3 to your computer and use it in GitHub Desktop.
numhask-backprop
#!/usr/bin/env stack
-- stack --install-ghc runghc --resolver lts-11.9 --package backprop-0.2.2.0 --package numhask-prelude-0.0.4.1 --package numhask-0.2.1.0 -- -Wall -O2
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module NumHask.Backprop where
import NumHask.Prelude as NH
import qualified Numeric.Backprop as IBP
import Numeric.Backprop.Explicit as BP
-- | Thin wrapper around a numeric type @a@. Every numhask class listed in
-- the deriving clause is lifted from @a@ by newtype deriving.
newtype NH a = NH
  { unnh :: a
  }
  deriving
    ( Eq, Ord
    , AdditiveMagma, AdditiveAssociative, AdditiveCommutative
    , AdditiveUnital, AdditiveIdempotent, Additive
    , AdditiveInvertible, AdditiveGroup
    , MultiplicativeMagma, MultiplicativeUnital, MultiplicativeAssociative
    , MultiplicativeCommutative, MultiplicativeIdempotent, Multiplicative
    , MultiplicativeInvertible, MultiplicativeGroup
    , Distribution, Semiring, Ring, CRing
    , StarSemiring, KleeneAlgebra, InvolutiveRing
    , Semifield, Field, ExpField, QuotientField
    , UpperBoundedField, LowerBoundedField, TrigField
    , Signed, Integral, ToInteger, FromInteger
    )
-- Normed, Metric, Epsilon
-- | A backprop variable over an 'NH'-wrapped number.
--
-- The field is itself constrained: the wrapped 'BVar' can only be used where
-- @Additive a@, @MultiplicativeUnital a@ and @Reifies s W@ hold.
newtype BVarNH s a = BVarNH
  { unNH :: (Additive a, MultiplicativeUnital a, Reifies s W) => BVar s (NH a)
  }
  deriving
    ( AdditiveAssociative, AdditiveCommutative, AdditiveIdempotent
    , MultiplicativeAssociative, MultiplicativeCommutative
    , MultiplicativeIdempotent
    , Distribution, Semiring, Ring, CRing
    , Semifield, Field, KleeneAlgebra
    )
-- | Equality is delegated to the wrapped 'BVar's.
instance (Eq a, Additive a, MultiplicativeUnital a, Reifies s W) => Eq (BVarNH s a) where
  BVarNH l == BVarNH r = l == r
-- | Ordering is delegated to the wrapped 'BVar's. Only '>=' and '<=' are
-- given explicitly; the remaining 'Ord' methods use the class defaults.
instance (Ord a, Additive a, MultiplicativeUnital a, Reifies s W) => Ord (BVarNH s a) where
  BVarNH l >= BVarNH r = l >= r
  BVarNH l <= BVarNH r = l <= r
-- QuotientField, UpperBoundedField, LowerBoundedField, Integral, ToInteger, FromInteger
-- * Backprop instance for a NH wrapped number
-- | 'Backprop' for an 'NH'-wrapped number, expressed with the numhask
-- primitives: the argument is ignored for 'zero' and 'one', and gradients
-- are accumulated with numhask addition.
instance (Additive a, MultiplicativeUnital a) => Backprop (NH a) where
  zero = const NH.zero
  one = const NH.one
  add x y = x NH.+ y
-- * operators
-- | Addition as an 'Op': the forward value is @x `plus` y@ and the incoming
-- gradient is passed unchanged to both arguments.
plusOp :: AdditiveMagma a => Op '[a, a] a
plusOp = op2 forward
  where
    forward x y = (plus x y, \g -> (g, g))
-- | Negation as an 'Op': both the forward value and the gradient are negated.
negateOp :: AdditiveInvertible a => Op '[a] a
negateOp = op1 forward
  where
    forward x = (negate x, \g -> negate g)
-- | Multiplication as an 'Op': the gradient for each argument is the other
-- argument multiplied by the incoming gradient.
timesOp :: MultiplicativeMagma a => Op '[a, a] a
timesOp = op2 forward
  where
    forward x y = (times x y, \g -> (times y g, times x g))
-- | Reciprocal as an 'Op': the gradient is @-g / (x * x)@.
recipOp :: (AdditiveInvertible a, MultiplicativeGroup a) => Op '[a] a
recipOp = op1 forward
  where
    forward x = (recip x, \g -> negate g / (x * x))
-- | 'sign' as an 'Op': the gradient is always zero, whatever the incoming
-- gradient is.
signOp :: (Signed a, AdditiveUnital a) => Op '[a] a
signOp = op1 forward
  where
    forward x = (sign x, \_ -> NH.zero)
-- | 'abs' as an 'Op': the incoming gradient is multiplied by @sign x@.
absOp :: Signed a => Op '[a] a
absOp = op1 forward
  where
    forward x = (abs x, \g -> times g (sign x))
-- | Kleene 'star' as an 'Op'.
--
-- From the defining identity @star x = one + x * star x@ the derivative is
-- @star x * star x@, so the incoming gradient is scaled by that factor.
-- (The previous version applied 'plus'' to the incoming gradient, which does
-- not depend on @x@ at all.) Assumes commutative multiplication, like the
-- other operators in this module.
starOp :: StarSemiring a => Op '[a] a
starOp = op1 $ \x ->
  let s = star x
   in (s, \g -> g `times` (s `times` s))
-- | 'plus'' (@x * star x@) as an 'Op'.
--
-- With @star x = one + x * star x@, the derivative of @x * star x@ is
-- @star x * star x@, so the incoming gradient is scaled by that factor.
-- (The previous version scaled by a single @star x@.) Assumes commutative
-- multiplication, like the other operators in this module.
plus'Op :: StarSemiring a => Op '[a] a
plus'Op = op1 $ \x ->
  let s = star x
   in (plus' x, \g -> g `times` (s `times` s))
-- | 'adj' as an 'Op': the gradient is the 'adj' of the incoming gradient.
adjOp :: InvolutiveRing a => Op '[a] a
adjOp = op1 forward
  where
    forward x = (adj x, \g -> adj g)
-- | 'exp' as an 'Op': the gradient is @exp x * g@.
expOp :: ExpField a => Op '[a] a
expOp = op1 forward
  where
    forward x = (exp x, \g -> exp x * g)
-- | 'log' as an 'Op': the gradient is @g / x@.
logOp :: ExpField a => Op '[a] a
logOp = op1 forward
  where
    forward x = (log x, \g -> g / x)
-- | 'sin' as an 'Op': the gradient is @g * cos x@.
sinOp :: TrigField a => Op '[a] a
sinOp = op1 forward
  where
    forward x = (sin x, \g -> g * cos x)
-- | 'cos' as an 'Op': the gradient is @g * negate (sin x)@.
cosOp :: TrigField a => Op '[a] a
cosOp = op1 forward
  where
    forward x = (cos x, \g -> g * negate (sin x))
-- | 'asin' as an 'Op': the gradient is @g / sqrt (one - x * x)@.
asinOp :: (ExpField a, TrigField a) => Op '[a] a
asinOp = op1 forward
  where
    forward x = (asin x, \g -> g / sqrt (NH.one - x * x))
-- | 'acos' as an 'Op': the gradient is @negate g / sqrt (one - x * x)@.
acosOp :: (ExpField a, TrigField a) => Op '[a] a
acosOp = op1 forward
  where
    forward x = (acos x, \g -> negate g / sqrt (NH.one - x * x))
-- | 'atan' as an 'Op': the gradient is @g / (x * x + one)@.
atanOp :: TrigField a => Op '[a] a
atanOp = op1 forward
  where
    forward x = (atan x, \g -> g / (x * x + NH.one))
-- | 'sinh' as an 'Op': the gradient is @g * cosh x@.
sinhOp :: TrigField a => Op '[a] a
sinhOp = op1 forward
  where
    forward x = (sinh x, \g -> g * cosh x)
-- | 'cosh' as an 'Op': the gradient is @g * sinh x@.
coshOp :: TrigField a => Op '[a] a
coshOp = op1 forward
  where
    forward x = (cosh x, \g -> g * sinh x)
-- | 'tanh' as an 'Op': the gradient is @g / (cosh x ** 2)@, with the
-- exponent built as @one + one@.
tanhOp :: (TrigField a, ExpField a) => Op '[a] a
tanhOp = op1 forward
  where
    forward x = (tanh x, \g -> g / (cosh x ** (NH.one + NH.one)))
-- | 'asinh' as an 'Op': the gradient is @g / sqrt (x * x + one)@.
asinhOp :: (TrigField a, ExpField a) => Op '[a] a
asinhOp = op1 forward
  where
    forward x = (asinh x, \g -> g / sqrt (x * x + NH.one))
-- | 'acosh' as an 'Op': the gradient is @g / sqrt (x * x - one)@.
acoshOp :: (TrigField a, ExpField a) => Op '[a] a
acoshOp = op1 forward
  where
    forward x = (acosh x, \g -> g / sqrt (x * x - NH.one))
-- | 'atanh' as an 'Op': the gradient is @g / (one - x * x)@.
atanhOp :: TrigField a => Op '[a] a
atanhOp = op1 forward
  where
    forward x = (atanh x, \g -> g / (NH.one - x * x))
-- | Field type for a record parameterised by a wrapper: with 'Identity' the
-- field is the bare type, with any other wrapper it is the wrapped type.
type family HKD f a where
  HKD Identity x = x
  HKD g x = g x
-- | A pair of an 'Integer' field and an @a@ field, each passed through 'HKD'
-- so the same record can hold either bare or wrapped values.
data FracType' a f = FT
  { ftI :: HKD f Integer
  , ftR :: HKD f a
  }
  deriving (Generic)
-- | 'FracType'' at 'Identity': by the first 'HKD' equation its fields are a
-- plain 'Integer' and a plain @a@.
type FracType a = FracType' a Identity
-- | 'Backprop' for the plain-valued record. No methods are given, so the
-- class defaults are used.
instance Backprop a => Backprop (FracType' a Identity)
-- | 'properFraction' as an 'Op', returning both parts packed in a 'FracType'.
--
-- The gradient function inspects the /incoming/ gradient record: it yields
-- 'nan' when its 'ftR' component is zero, and otherwise the 'ftI' component
-- converted with 'fromInteger'.
--
-- NOTE(review): this gradient looks like a placeholder. It never looks at
-- @x@ and does not pass the fractional-part gradient through. Confirm the
-- intended rule before relying on it.
properFractionOp :: (UpperBoundedField a, FromInteger a, QuotientField a) => Op '[a] (FracType a)
properFractionOp = op1 $ \x ->
  let (whole, frac) = properFraction x
   in ( FT whole frac
      , \(FT i r) ->
          if r == NH.zero
            then nan
            else fromInteger i
      )
-- | FIXME: unimplemented stub.
--
-- Splits the record variable with 'IBP.splitBV' and matches its 'FT'
-- constructor, but then evaluates to 'undefined'. The commented-out intent
-- was to return the two components as @(i, r)@.
extractInteger :: (Reifies s W, Backprop a) => BVar s (FracType a) -> (Integer, BVar s a)
extractInteger (IBP.splitBV -> FT _ _) = undefined -- (i, r)
-- numhask classes
-- | Addition of backprop variables: 'plusOp' lifted with 'liftOp2'.
instance AdditiveMagma (BVarNH s a) where
  plus (BVarNH l) (BVarNH r) =
    BVarNH (liftOp2 addFunc addFunc zeroFunc plusOp l r)
-- | The additive unit is a constant variable holding the 'NH' zero.
--
-- It must be built with 'constVar'. Writing @zero = NH.zero@ here would
-- resolve @NH.zero@ at type @BVarNH s a@, i.e. to this very method, and
-- loop forever.
instance AdditiveUnital (BVarNH s a) where
  zero = BVarNH (constVar NH.zero)
-- | No methods are given for this instance.
instance Additive (BVarNH s a)
-- | Negation of a backprop variable: 'negateOp' lifted with 'liftOp1'.
instance AdditiveInvertible a => AdditiveInvertible (BVarNH s a) where
  negate (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc negateOp v)
-- | No methods are given for this instance.
instance AdditiveInvertible a => AdditiveGroup (BVarNH s a)
-- | Multiplication of backprop variables: 'timesOp' lifted with 'liftOp2'.
instance MultiplicativeMagma (BVarNH s a) where
  times (BVarNH l) (BVarNH r) =
    BVarNH (liftOp2 addFunc addFunc zeroFunc timesOp l r)
-- | The multiplicative unit is a constant variable holding the 'NH' one.
--
-- It must be built with 'constVar'. Writing @one = NH.one@ here would
-- resolve @NH.one@ at type @BVarNH s a@, i.e. to this very method, and
-- loop forever.
instance MultiplicativeUnital (BVarNH s a) where
  one = BVarNH (constVar NH.one)
-- | No methods are given for this instance.
instance Multiplicative (BVarNH s a)
-- | Reciprocal of a backprop variable: 'recipOp' lifted with 'liftOp1'.
instance
  (AdditiveInvertible a, MultiplicativeGroup a) =>
  MultiplicativeInvertible (BVarNH s a)
  where
  recip (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc recipOp v)
-- | No methods are given for this instance.
instance
  (AdditiveInvertible a, MultiplicativeGroup a) =>
  MultiplicativeGroup (BVarNH s a)
-- | 'sign' and 'abs' on backprop variables, via 'signOp' and 'absOp'.
instance Signed a => Signed (BVarNH s a) where
  sign (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc signOp v)
  abs (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc absOp v)
-- | 'plus'' and 'star' on backprop variables, via 'plus'Op' and 'starOp'.
instance StarSemiring a => StarSemiring (BVarNH s a) where
  plus' (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc plus'Op v)
  star (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc starOp v)
-- | 'adj' on a backprop variable: 'adjOp' lifted with 'liftOp1'.
instance InvolutiveRing a => InvolutiveRing (BVarNH s a) where
  adj (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc adjOp v)
-- | 'exp' and 'log' on backprop variables, via 'expOp' and 'logOp'. Other
-- 'ExpField' methods are not given here.
instance
  (ExpField a, AdditiveInvertible a, MultiplicativeGroup a) =>
  ExpField (BVarNH s a)
  where
  exp (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc expOp v)
  log (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc logOp v)
-- | Trigonometric and hyperbolic functions on backprop variables, each the
-- matching @*Op@ lifted with 'liftOp1'.
--
-- 'pi' is a constant variable built with 'constVar'. Writing @pi = NH.pi@
-- here would resolve @NH.pi@ at type @BVarNH s a@, i.e. to this very
-- method, and loop forever.
--
-- NOTE(review): 'tanhOp' is defined in this file but @tanh@ is not given
-- here. Confirm whether it should be wired in.
instance
  ( Reifies s W -- fixme: why was this needed here?
  , ExpField a
  , TrigField a
  , AdditiveInvertible a
  , MultiplicativeGroup a
  ) =>
  TrigField (BVarNH s a)
  where
  pi = BVarNH (constVar NH.pi)
  sin (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc sinOp v)
  cos (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc cosOp v)
  asin (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc asinOp v)
  acos (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc acosOp v)
  atan (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc atanOp v)
  sinh (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc sinhOp v)
  cosh (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc coshOp v)
  asinh (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc asinhOp v)
  acosh (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc acoshOp v)
  atanh (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc atanhOp v)
-- | 'properFraction' on a backprop variable: lifts 'properFractionOp', then
-- splits the result with 'extractInteger' and re-wraps the second component.
--
-- NOTE: 'extractInteger' is currently an 'undefined' stub in this file, so
-- forcing the result of this method fails at runtime.
instance
  ( UpperBoundedField a
  , FromInteger a
  , Reifies s W
  , QuotientField a
  , AdditiveInvertible a
  , MultiplicativeGroup a
  ) =>
  QuotientField (BVarNH s a)
  where
  properFraction (BVarNH v) =
    case extractInteger (liftOp1 addFunc zeroFunc properFractionOp v) of
      (whole, frac) -> (whole, BVarNH frac)
-- | Entry point for running the file as a stack script; it performs no work.
main :: IO ()
main = pure ()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment