Skip to content

Instantly share code, notes, and snippets.

@yatsuta
Created June 1, 2010 01:07
Show Gist options
  • Save yatsuta/420442 to your computer and use it in GitHub Desktop.
Save yatsuta/420442 to your computer and use it in GitHub Desktop.
Derivative system for the Gaussian function and its log-likelihood
import Data.List
import Data.Maybe
-- ****************************************************************
-- * data type definition
-- ****************************************************************
-- | Name of a symbolic variable, e.g. @"mu"@ or @"sigma^2"@.
type VarName = [Char]
-- | Symbolic expression tree.
--
-- 'Sum' and 'Prod' carry, in order: the index variable, the variables that
-- are indexed by it (inside the body each such @x@ is replaced by the
-- indexed variable @x1@, @x2@, ... when the bounds are numeric), the lower
-- bound, the upper bound and the body.
--
-- Constructor order determines the derived 'Ord', which the simplifier
-- uses (via 'sort') to put 'Add' / 'Mult' operands in a fixed order, so
-- do not reorder the constructors.
data Expr
  = Num Double                            -- ^ numeric literal
  | Var VarName                           -- ^ named variable
  | Add [Expr]                            -- ^ n-ary sum
  | Mult [Expr]                           -- ^ n-ary product
  | Minus Expr                            -- ^ negation
  | Recip Expr                            -- ^ reciprocal, 1 / e
  | Pow Expr Expr                         -- ^ base raised to an exponent
  | Exp Expr                              -- ^ natural exponential
  | Log Expr                              -- ^ natural logarithm
  | Sqrt Expr                             -- ^ square root
  | Sum VarName [VarName] Expr Expr Expr  -- ^ indexed sum
  | Prod VarName [VarName] Expr Expr Expr -- ^ indexed product
  deriving (Eq, Ord, Show)
-- ****************************************************************
-- * hasVar
-- ****************************************************************
-- | @expr \`hasVar\` v@ is True when variable @v@ occurs free in @expr@.
--
-- A 'Sum' / 'Prod' binds its index variable and its indexed variables
-- inside the body, so occurrences of those names in the body do not
-- count.  The bounds lie outside that scope and are always inspected:
-- a nested sum whose upper bound is the outer index does depend on it.
-- Every constructor is handled, including the ones 'eval' rewrites away
-- ('Minus', 'Recip', 'Sqrt') and an unevaluated 'Prod'.
hasVar :: Expr -> VarName -> Bool
Num _ `hasVar` _ = False
Var x `hasVar` v = x == v
Add es `hasVar` v = any (`hasVar` v) es
Mult es `hasVar` v = any (`hasVar` v) es
Minus e `hasVar` v = e `hasVar` v
Recip e `hasVar` v = e `hasVar` v
Pow x n `hasVar` v = x `hasVar` v || n `hasVar` v
Exp e `hasVar` v = e `hasVar` v
Log e `hasVar` v = e `hasVar` v
Sqrt e `hasVar` v = e `hasVar` v
Sum i xs begin end body `hasVar` v =
  begin `hasVar` v || end `hasVar` v
    || (v `notElem` (i : xs) && body `hasVar` v)
Prod i xs begin end body `hasVar` v =
  begin `hasVar` v || end `hasVar` v
    || (v `notElem` (i : xs) && body `hasVar` v)
-- ****************************************************************
-- * split...
-- ****************************************************************
-- | True exactly for a bare variable reference.
isVar :: Expr -> Bool
isVar e = case e of
  Var _ -> True
  _ -> False
-- | Separate bare variables from all other expressions, keeping order.
splitVar :: [Expr] -> ([Expr], [Expr])
splitVar exprs = (filter isVar exprs, filter (not . isVar) exprs)
-- | Separate numeric literals from all other expressions, keeping order.
splitNum :: [Expr] -> ([Expr], [Expr])
splitNum = partition (\e -> case e of
                               Num _ -> True
                               _ -> False)
-- | Separate 'Add' nodes from all other expressions, keeping order.
splitAdd :: [Expr] -> ([Expr], [Expr])
splitAdd exprs = (filter isSum exprs, filter (not . isSum) exprs)
  where
    isSum e = case e of
      Add _ -> True
      _ -> False
-- | Separate products that contain at least one bare variable factor from
-- everything else, keeping order.
splitMultVar :: [Expr] -> ([Expr], [Expr])
splitMultVar = partition (\e -> case e of
                                   Mult factors -> any isVar factors
                                   _ -> False)
-- | Separate 'Mult' nodes from all other expressions, keeping order.
splitMult :: [Expr] -> ([Expr], [Expr])
splitMult exprs = (filter isProduct exprs, filter (not . isProduct) exprs)
  where
    isProduct e = case e of
      Mult _ -> True
      _ -> False
-- | Separate powers whose base is a bare variable from everything else,
-- keeping order.
splitPowVar :: [Expr] -> ([Expr], [Expr])
splitPowVar = partition (\e -> case e of
                                  Pow (Var _) _ -> True
                                  _ -> False)
-- | Split expressions into those mentioning any of @vars@ and those
-- mentioning none of them, keeping order.
splitHasVars :: [VarName] -> [Expr] -> ([Expr], [Expr])
splitHasVars vars = partition (\e -> any (hasVar e) vars)
-- ****************************************************************
-- * addNums, multNums
-- ****************************************************************
-- | Add up a list of numeric literals into one literal (0 for none).
-- Every element must be a 'Num'; anything else is a pattern-match error.
addNums :: [Expr] -> Expr
addNums = Num . foldl' plus 0
  where
    plus acc (Num n) = acc + n
-- | Multiply a list of numeric literals into one literal (1 for none).
-- Every element must be a 'Num'; anything else is a pattern-match error.
multNums :: [Expr] -> Expr
multNums = Num . foldl' times 1
  where
    times acc (Num n) = acc * n
-- ****************************************************************
-- * count...
-- ****************************************************************
-- | Pair every bare variable with coefficient 1 for like-term collection.
countVar :: [a] -> [(a, Expr)]
countVar = map (\v -> (v, Num 1))
-- | For each product that contains a bare variable, read it as
-- @v * coefficient@: @v@ is its first variable factor and the coefficient
-- is the product of the remaining factors, simplified under @b@.
-- Non-'Mult' inputs are skipped; a 'Mult' with no variable factor is a
-- pattern-match error (callers pass only the output of 'splitMultVar').
countMultVar :: [Expr] -> [(VarName, Expr)] -> [(Expr, Expr)]
countMultVar multVars b = map splitOffVar [fs | Mult fs <- multVars]
  where
    splitOffVar fs = case break isVar fs of
      (front, v : back) -> (v, eval (Mult (front ++ back)) b)
-- | Turn each power node into a (base, exponent) pair; other nodes are
-- skipped.
countPowVar :: [Expr] -> [(Expr, Expr)]
countPowVar = foldr keep []
  where
    keep (Pow base ex) acc = (base, ex) : acc
    keep _ acc = acc
-- | Merge (variable, count) pairs that share a variable.
--
-- The pairs are sorted so equal variables are adjacent.  The counts of
-- each variable are added and simplified under @b@, and then
-- @rebuild var total@ turns each group back into a single simplified term.
-- Shared by 'sumUpVars' and 'sumUpVars'', which differ only in @rebuild@.
collectCounts :: (Expr -> Expr -> Expr) -> [(Expr, Expr)] -> [(VarName, Expr)] -> [Expr]
collectCounts rebuild counts b =
  [ eval (rebuild var (eval (Add (count : map snd rest)) b)) b
    -- groupBy never yields an empty group, so this pattern always matches.
  | ((var, count) : rest) <- groupBy (\(v, _) (v', _) -> v == v') (sort counts)
  ]

-- | Like terms in a sum: each variable becomes @var * (sum of coefficients)@.
sumUpVars :: [(Expr, Expr)] -> [(VarName, Expr)] -> [Expr]
sumUpVars = collectCounts (\var total -> Mult [var, total])

-- | Like factors in a product: each variable becomes
-- @var ^ (sum of exponents)@.
sumUpVars' :: [(Expr, Expr)] -> [(VarName, Expr)] -> [Expr]
sumUpVars' = collectCounts Pow
-- | The variable @v@ subscripted by the (rounded) index @i@,
-- e.g. @indexedVar "x" 3.0 == Var "x3"@.
indexedVar :: RealFrac a => VarName -> a -> Expr
indexedVar v i = Var (v ++ show subscript)
  where
    subscript = round i :: Integer
-- ****************************************************************
-- * eval...
-- ****************************************************************
-- | Normalise an n-ary sum under bindings @b@.
--
-- Steps, in order:
--   1. evaluate every summand under @b@;
--   2. splice the operands of nested 'Add's into this sum;
--   3. collect like terms: a bare variable counts as @(v, 1)@ and a product
--      containing a variable counts as @(first variable, rest of product)@;
--      coefficients of the same variable are added ('sumUpVars');
--   4. fold every numeric summand into one constant, which is dropped
--      when it is 0;
--   5. return 0 for no terms, the single term for one, otherwise an 'Add'
--      of the sorted terms (presumably to give a canonical order).
--
-- Only the 'Add' constructor is matched; 'eval' dispatches here for it.
evalAdd (Add exprs) b = expr
  where exprs1 = [eval expr b | expr <- exprs]
        (adds, nonAdds) = splitAdd exprs1
        exprs2 = concat [exprs | Add exprs <- adds] ++ nonAdds
        (vars, nonVars) = splitVar exprs2
        varCount = countVar vars
        (multVars, nonMultVars) = splitMultVar nonVars
        multVarCount = countMultVar multVars b
        vars' = sumUpVars (varCount ++ multVarCount) b
        -- Collected terms can themselves reduce to numbers (e.g. x - x),
        -- so they are included in the numeric fold.
        (nums, nonNums) = splitNum $ vars' ++ nonMultVars
        num = addNums nums
        exprs3 = case num of
          Num 0 -> nonNums
          _ -> num : nonNums
        expr = case exprs3 of
          [] -> Num 0
          [expr'] -> expr'
          _ -> Add $ sort exprs3
-- | Normalise an n-ary product under bindings @b@.
--
-- Steps, in order:
--   1. evaluate every factor under @b@;
--   2. splice the operands of nested 'Mult's into this product;
--   3. collect powers of the same variable: a bare variable counts as
--      @(v, 1)@ and @v ^ n@ counts as @(v, n)@; exponents of the same
--      variable are added ('sumUpVars'');
--   4. multiply every numeric factor into one constant.  A constant 0
--      replaces the whole product, and a constant 1 is dropped;
--   5. return 1 for no factors, the single factor for one, otherwise a
--      'Mult' of the sorted factors (presumably to give a canonical order).
--
-- Only the 'Mult' constructor is matched; 'eval' dispatches here for it.
evalMult (Mult exprs) b = expr
  where exprs1 = [eval expr b | expr <- exprs]
        (mults, nonMults) = splitMult exprs1
        exprs2 = concat [exprs | Mult exprs <- mults] ++ nonMults
        (vars, nonVars) = splitVar exprs2
        varCount = countVar vars
        (powVars, nonPowVars) = splitPowVar nonVars
        powVarCount = countPowVar powVars
        vars' = sumUpVars' (varCount ++ powVarCount) b
        -- Collected powers can reduce to numbers (e.g. x * x^(-1) -> 1),
        -- so they are included in the numeric fold.
        (nums, nonNums) = splitNum $ vars' ++ nonPowVars
        num = multNums nums
        exprs3 = case num of
          Num 0 -> [Num 0]
          Num 1 -> nonNums
          _ -> num : nonNums
        expr = case exprs3 of
          [] -> Num 1
          [expr'] -> expr'
          _ -> Mult $ sort exprs3
-- | Sum of a product: keep inside the sum only the factors that mention the
-- index (or an indexed variable), multiply the others outside it, and
-- simplify the result under @b@.
evalSumMult i vars begin end exprs b = eval (Mult (inner : outer)) b
  where
    (dependent, outer) = splitHasVars (i : vars) exprs
    inner = Sum i vars begin end (Mult dependent)
-- | Simplify an indexed sum @Sum i vars begin end body@ under bindings @b@.
--
-- Rules, tried in order:
--   * the summand mentions neither the index nor an indexed variable:
--     multiply it by the term count @end - begin + 1@;
--   * the summand is a product with at least one factor that does not
--     mention them: move those factors outside the sum ('evalSumMult');
--   * the summand is a sum: distribute into a sum of sums;
--   * the bounds are numeric: expand term by term, binding @i@ to each
--     value and each @x@ in @vars@ to its indexed variable (@x1@, @x2@, ...);
--   * otherwise, return the sum with its parts simplified.
--
-- Inside the body, the names bound by the sum (@i@ and @vars@) shadow any
-- outer binding of the same name in @b@.
evalSum :: Expr -> [(VarName, Expr)] -> Expr
evalSum (Sum i vars begin end body) b =
  case (eval begin b, eval end b, eval body bodyEnv) of
    (begin', end', body')
      | not (dependsOnIndex body') ->
          eval (Mult [Add [end', Minus begin', Num 1], body']) b
    -- If every factor depends on the index, evalSumMult would rebuild this
    -- same Sum and recurse forever, so that case falls through instead.
    (begin', end', Mult exprs)
      | not (all dependsOnIndex exprs) ->
          evalSumMult i vars begin' end' exprs b
    (begin', end', Add exprs) ->
      eval (Add [Sum i vars begin' end' expr | expr <- exprs]) b
    -- Index bindings go first so they shadow outer bindings of the same name.
    (Num n1, Num n2, expr') ->
      eval (Add [eval expr' (indexBindings n ++ b) | n <- [n1 .. n2]]) b
    (begin', end', body') -> Sum i vars begin' end' body'
  where
    binders = i : vars
    dependsOnIndex e = any (e `hasVar`) binders
    -- Outer bindings, minus the names this sum binds.
    bodyEnv = [binding | binding@(name, _) <- b, name `notElem` binders]
    indexBindings n = (i, Num n) : zip vars [indexedVar var n | var <- vars]
-- | Simplify an indexed product @Prod i vars begin end body@ under bindings @b@.
--
-- With numeric bounds, the product is expanded factor by factor, binding
-- @i@ to each value and each @x@ in @vars@ to its indexed variable
-- (@x1@, @x2@, ...).  Otherwise the product is returned with its parts
-- simplified.
--
-- There is deliberately no special case for a 'Log' body: a product of
-- logs is not a sum of logs.  The valid rewrite, log of a product = sum of
-- logs, belongs to the 'Log' case of 'eval'.
--
-- Inside the body, the names bound by the product (@i@ and @vars@) shadow
-- any outer binding of the same name in @b@.
evalProd :: Expr -> [(VarName, Expr)] -> Expr
evalProd (Prod i vars begin end body) b =
  case (eval begin b, eval end b, eval body bodyEnv) of
    -- Index bindings go first so they shadow outer bindings of the same name.
    (Num n1, Num n2, factor) ->
      eval (Mult [eval factor (indexBindings n ++ b) | n <- [n1 .. n2]]) b
    (begin', end', factor) -> Prod i vars begin' end' factor
  where
    binders = i : vars
    -- Outer bindings, minus the names this product binds.
    bodyEnv = [binding | binding@(name, _) <- b, name `notElem` binders]
    indexBindings n = (i, Num n) : zip vars [indexedVar var n | var <- vars]
-- ****************************************************************
-- * eval
-- ****************************************************************
-- | Evaluate and simplify an expression under the variable bindings @b@.
--
-- A bound variable is replaced by its binding, which is simplified with no
-- bindings ('simplify'), so one binding cannot refer to another.  'Minus',
-- 'Recip' and 'Sqrt' are rewritten into 'Mult' / 'Pow' form before
-- evaluation.  Within each 'case', alternatives are tried top to bottom,
-- so their order matters.
eval :: Expr -> [(VarName, Expr)] -> Expr
eval (Num n) _ = Num n
eval (Var x) b = case lookup x b of
  Just e -> simplify e
  Nothing -> Var x
eval a@(Add _) b = evalAdd a b
eval m@(Mult _) b = evalMult m b
-- -e is represented as (-1) * e.
eval (Minus expr) b = eval (Mult [Num (-1), expr]) b
-- 1/e is represented as e ^ (-1).
eval (Recip expr) b = eval (Pow expr (Num (-1))) b
eval (Pow expr1 expr2) b =
  case (eval expr1 b, eval expr2 b) of
    -- Any base to the power 0 is 1, including 0^0.
    (_ , Num 0) -> Num 1
    (expr1', Num 1) -> expr1'
    -- (a^c)^e -> a^(c*e).  NOTE(review): not valid for every real a
    -- (e.g. ((-1)^2)^(1/2) is 1, but (-1)^1 is -1).
    (Pow expr1_1' expr1_2', expr2') ->
      eval (Pow expr1_1' (Mult [expr1_2', expr2'])) b
    -- (x*y)^e -> x^e * y^e
    (Mult exprs, expr2') ->
      eval (Mult [Pow expr expr2' | expr <- exprs]) b
    (Num n1, Num n2) -> Num (n1 ** n2)
    (expr1', expr2') -> Pow expr1' expr2'
eval (Exp expr) b = case eval expr b of
  Num n -> Num (exp n)
  -- exp (log x) -> x
  Log expr' -> expr'
  expr' -> Exp expr'
eval (Log expr) b = case eval expr b of
  Num n -> Num (log n)
  -- log (x^e) -> e * log x
  Pow expr1' expr2' ->
    eval (Mult [expr2', Log expr1']) b
  -- log (x*y) -> log x + log y.  NOTE(review): valid only for positive
  -- factors; a Num (-1) factor (from Minus) becomes log (-1), i.e. NaN.
  Mult exprs -> eval (Add [Log e | e <- exprs]) b
  -- log (exp x) -> x
  Exp expr' -> expr'
  -- The log of an indexed product is the indexed sum of the logs.
  Prod i vars begin end body ->
    eval (Sum i vars begin end (Log body)) b
  expr' -> Log expr'
-- sqrt e is represented as e ^ 0.5.
eval (Sqrt expr) b = eval (Pow expr (Num 0.5)) b
eval s@(Sum _ _ _ _ _) b = evalSum s b
eval p@(Prod _ _ _ _ _) b = evalProd p b
-- ****************************************************************
-- * gaussian, bern
-- ****************************************************************
-- | Gaussian density in @x@ with mean @mu@ and variance @sigma^2@:
-- 1 / sqrt (2 * PI * sigma^2) * exp (-(x - mu)^2 / (2 * sigma^2)).
gaussian :: Expr
gaussian = Mult [normaliser, Exp (Minus scaledSquare)]
  where
    variance = Var "sigma^2"
    normaliser = Recip (Sqrt (Mult [Num 2, Var "PI", variance]))
    deviation = Add [Var "x", Minus (Var "mu")]
    scaledSquare = Mult [Recip (Mult [Num 2, variance]), Pow deviation (Num 2)]
-- | 'gaussian' simplified with mean 0 and variance 1, leaving @x@ symbolic.
-- PI is bound to the full-precision 'pi' rather than a truncated literal,
-- so the normalisation constant is accurate.
standardGaussian :: Expr
standardGaussian = eval gaussian [("PI", Num pi),
                                  ("mu", Num 0),
                                  ("sigma^2", Num 1)]
-- | Log of the product of 'gaussian' over n = 1 .. N, with @x@ indexed by n.
logLikelyhood :: Expr
logLikelyhood = Log likelihood
  where
    likelihood = Prod "n" ["x"] (Num 1) (Var "N") gaussian
-- | Bernoulli probability mu^x * (1 - mu)^(1 - x).
bern :: Expr
bern = Mult [Pow mu x, Pow (oneMinus mu) (oneMinus x)]
  where
    mu = Var "mu"
    x = Var "x"
    oneMinus e = Add [Num 1, Minus e]
-- | Log of the product of 'bern' over n = 1 .. N, with @x@ indexed by n.
logLikelyhoodBern :: Expr
logLikelyhoodBern = Log likelihood
  where
    likelihood = Prod "n" ["x"] (Num 1) (Var "N") bern
-- ****************************************************************
-- * simplify
-- ****************************************************************
-- | Simplify an expression without substituting any variables.
simplify :: Expr -> Expr
simplify = flip eval []
-- ****************************************************************
-- * deriv
-- ****************************************************************
-- | Raw symbolic derivative of an expression with respect to variable @v@.
--
-- The result is not simplified; 'deriv' simplifies the input first and
-- the output afterwards.  Every constructor is handled, so calling this
-- directly on unsimplified input (with 'Minus', 'Recip', 'Sqrt') works too.
deriv' :: Expr -> VarName -> Expr
deriv' (Num _) _ = Num 0
deriv' (Var x) v
  | x == v = Num 1
  | otherwise = Num 0
deriv' (Add exprs) v = Add [deriv' expr v | expr <- exprs]
-- The empty product is the constant 1.
deriv' (Mult []) _ = Num 0
deriv' (Mult [expr]) v = deriv' expr v
-- Product rule, peeling off one factor at a time.
deriv' (Mult (e : exprs)) v =
  Add [ Mult [e, deriv' (Mult exprs) v]
      , Mult [deriv' e v, Mult exprs] ]
deriv' (Minus x) v = Minus (deriv' x v)
deriv' (Recip x) v = deriv' (Pow x (Num (-1))) v
deriv' (Sqrt x) v = deriv' (Pow x (Num 0.5)) v
deriv' (Pow x n) v
  -- The exponent depends on v, so the power rule does not apply.  Use
  -- d(x^n) = x^n * (n' * log x + n * x' / x).
  | n `hasVar` v =
      Mult [ Pow x n
           , Add [ Mult [deriv' n v, Log x]
                 , Mult [n, deriv' x v, Recip x] ] ]
  -- Power rule for an exponent that is constant in v.
  | otherwise =
      Mult [ Mult [n, Pow x (Add [n, Minus (Num 1)])]
           , deriv' x v ]
deriv' (Exp x) v = Mult [Exp x, deriv' x v]
deriv' (Log x) v = Mult [Recip x, deriv' x v]
-- Differentiate under the summation sign.  NOTE(review): if v is the index
-- or one of the indexed variables, this differentiates a bound name.
deriv' (Sum i vars begin end body) v = Sum i vars begin end (deriv' body v)
-- Logarithmic derivative: (prod f)' = prod f * sum (f' / f).  Like the Log
-- rule above, it assumes the factors are non-zero.
deriv' (Prod i vars begin end body) v =
  Mult [ Prod i vars begin end body
       , Sum i vars begin end (Mult [Recip body, deriv' body v]) ]
-- | Simplified derivative of @expr@ with respect to the named variable.
deriv :: Expr -> VarName -> Expr
deriv expr = simplify . deriv' (simplify expr)
-- ****************************************************************
-- * solve
-- ****************************************************************
-- | @solve' lhs rhs v@ rewrites the equation @lhs = rhs@ until @lhs@ is the
-- bare variable @v@, and returns the (unsimplified) right-hand side.
--
-- At each step, the first operand of an 'Add' / 'Mult' that mentions @v@
-- is kept on the left, and the rest are moved to the right.  A power is
-- inverted on whichever side (base or exponent) contains @v@.  If the
-- equation cannot be isolated this way, 'error' is called with a message
-- naming the variable.
solve' :: Expr -> Expr -> VarName -> Expr
solve' (Var x) rhs v | x == v = rhs
solve' (Add exprs) rhs v = solve' term (Add [rhs, Minus (Add others)]) v
  where (term, others) = isolateOperand v exprs
solve' (Mult exprs) rhs v = solve' factor (Mult [rhs, Recip (Mult others)]) v
  where (factor, others) = isolateOperand v exprs
solve' (Minus x) rhs v = solve' x (Minus rhs) v
solve' (Recip x) rhs v = solve' x (Recip rhs) v
solve' (Sqrt x) rhs v = solve' x (Pow rhs (Num 2)) v
solve' (Pow x n) rhs v
  -- x^n = r  =>  x = r^(1/n)
  | x `hasVar` v = solve' x (Pow rhs (Recip n)) v
  -- x^n = r with v only in n  =>  n = log r / log x
  | otherwise = solve' n (Mult [Log rhs, Recip (Log x)]) v
solve' (Exp x) rhs v = solve' x (Log rhs) v
solve' (Log x) rhs v = solve' x (Exp rhs) v
solve' lhs _ v = error ("solve: cannot isolate " ++ v ++ " in " ++ show lhs)

-- | Split off the first operand that mentions @v@ and return it together
-- with the remaining operands, in their original order.
isolateOperand :: VarName -> [Expr] -> (Expr, [Expr])
isolateOperand v exprs = case break (`hasVar` v) exprs of
  (before, operand : after) -> (operand, before ++ after)
  (_, []) -> error ("solve: " ++ v ++ " does not occur in any operand")
-- | Solve @lhs = rhs@ for the variable @v@.  Both sides are simplified
-- first, and the resulting expression is simplified afterwards.
solve :: Expr -> Expr -> VarName -> Expr
solve lhs rhs v = simplify (solve' lhs' rhs' v)
  where
    lhs' = simplify lhs
    rhs' = simplify rhs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment