Skip to content

Instantly share code, notes, and snippets.

@btd
Created February 11, 2013 09:27
Show Gist options
  • Save btd/4753465 to your computer and use it in GitHub Desktop.
Save btd/4753465 to your computer and use it in GitHub Desktop.
Symbolic computation of derivatives (symbolic differentiation)
package book
/** Worksheet demonstrating symbolic differentiation over a small expression AST.
  *
  * Expressions are built from constants, variables, sums, products and integer
  * powers. [[ExprOps]] supplies simplifying `*`, `+`, unary `-` and `^`
  * operators, and [[diff1]] computes the derivative with respect to one variable.
  */
object diff {
  // The worksheet relies on two implicit conversions (number -> Const,
  // Expression -> ExprOps); enable the language feature explicitly.
  import scala.language.implicitConversions

  /** Root of the expression AST. Sealed so that matches are checked for exhaustiveness. */
  sealed trait Expression

  /** A numeric literal; prints as the underlying value. */
  final case class Const[A: Numeric](value: A) extends Expression {
    override def toString = value.toString
  }

  /** A named variable; prints as its name. */
  final case class Var(name: String) extends Expression {
    override def toString = name
  }

  /** The sum `left + right`. */
  final case class Addition(left: Expression, right: Expression) extends Expression {
    override def toString = "(%s + %s)".format(left, right)
  }

  /** The product `left * right`. */
  final case class Multiplication(left: Expression, right: Expression) extends Expression {
    override def toString = "(%s * %s)".format(left, right)
  }

  /** `expr` raised to the integer power `p`. */
  final case class Pow(expr: Expression, p: Int) extends Expression {
    override def toString = "%s^%d".format(expr, p)
  }

  val x = Var("x")                                //> x : book.diff.Var = x
  val y = Var("y")                                //> y : book.diff.Var = y
  val z = Var("z")                                //> z : book.diff.Var = z
  Multiplication(Const(1), x)                     //> res0: book.diff.Multiplication = (1 * x)

  /** Lifts any numeric value into a [[Const]] so literals can be used where an Expression is expected. */
  implicit def const2Const[A: Numeric](number: A): Expression = Const(number)
                                                  //> const2Const: [A](number: A)(implicit evidence$2: Numeric[A])book.diff.Expres
                                                  //| sion

  /** Simplifying arithmetic operators for an [[Expression]] used as the left operand. */
  class ExprOps(left: Expression) {

    /** Product with algebraic simplification.
      *
      * Rules are tried top to bottom: annihilation by 0, identity 1, merging of
      * powers over the same base, squaring of equal operands, and absorbing a
      * bare base into a power. Anything else becomes a plain [[Multiplication]].
      */
    def *(right: Expression): Expression = (left, right) match {
      case (Const(0), _) => 0
      case (Const(1), r) => r
      case (_, Const(0)) => 0
      case (l, Const(1)) => l
      // Must precede the `l == r` rule: otherwise x^2 * x^2 becomes (x^2)^2
      // instead of x^4.
      case (Pow(e1, p1), Pow(e2, p2)) if e1 == e2 => Pow(e1, p1 + p2)
      case (l, r) if l == r => Pow(l, 2)
      case (Pow(e1, p), e2) if e1 == e2 => Pow(e1, p + 1)
      case (e1, Pow(e2, p)) if e1 == e2 => Pow(e1, p + 1)
      case _ => Multiplication(left, right)
    }

    /** Sum with simplification: drops a zero operand and rewrites `e + e` as `e * 2`. */
    def +(right: Expression): Expression = (left, right) match {
      case (Const(0), r) => r
      case (l, r) if l == r => Multiplication(l, 2)
      case (l, Const(0)) => l
      case _ => Addition(left, right)
    }

    /** Negation, expressed as multiplication by -1. */
    def unary_- : Expression = this * Const(-1)

    /** Raises the operand to a non-negative integer power.
      *
      * `e ^ 0` yields the constant 1 and `e ^ 1` yields `e` itself.
      *
      * @param p the exponent; must be >= 0
      * @throws IllegalArgumentException if `p` is negative
      */
    def ^(p: Int): Expression = {
      require(p >= 0, "Power should be non-negative")
      p match {
        case 0 => 1
        case 1 => left
        case _ => Pow(left, p)
      }
    }
  }

  /** Enriches any [[Expression]] with the operators defined in [[ExprOps]]. */
  implicit def expr2exprOps(expr: Expression): ExprOps = new ExprOps(expr)
                                                  //> expr2exprOps: (expr: book.diff.Expression)book.diff.ExprOps

  /** Symbolic first derivative of `expr` with respect to the variable `v`.
    *
    * Applies the sum, product and power (with chain) rules; results are built
    * through the simplifying operators of [[ExprOps]].
    *
    * @param expr the expression to differentiate
    * @param v    the variable to differentiate by; any other variable is treated as a constant
    * @return the derivative expression
    */
  def diff1(expr: Expression, v: Var): Expression = expr match {
    case Const(_) => 0
    case Var(variable) if variable == v.name => 1
    case Var(_) => 0
    case Addition(l, r) => diff1(l, v) + diff1(r, v)
    // e^0 is the constant 1, so its derivative is 0.
    case Pow(_, 0) => 0
    case Pow(e, p) =>
      // d/dv e^p = e^(p-1) * e' * p. For positive p go through `^`, which
      // normalises e^0 to 1 and e^1 to e, so no Pow(_, 1) or Pow(_, 0) node is
      // produced. `^` rejects negative exponents, so those are built directly.
      val lowered: Expression = if (p > 0) e ^ (p - 1) else Pow(e, p - 1)
      lowered * diff1(e, v) * p
    case Multiplication(l, r) => diff1(l, v) * r + diff1(r, v) * l
  }                                               //> diff1: (expr: book.diff.Expression, v: book.diff.Var)book.diff.Expression

  val exp = ((x * 5) ^ 3) + y                     //> exp : book.diff.Expression = ((x * 5)^3 + y)
  diff1(exp, x)                                   //> res1: book.diff.Expression = (((x * 5)^2 * 5) * 3)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment