Skip to content

Instantly share code, notes, and snippets.

@fsarradin
Created November 13, 2020 09:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fsarradin/6d72a66fb3e3d850fa2161b3ccaf61b7 to your computer and use it in GitHub Desktop.
Save fsarradin/6d72a66fb3e3d850fa2161b3ccaf61b7 to your computer and use it in GitHub Desktop.
Arithmetic expression
/**
 * Arithmetic expression representation.
 *
 * An expression may contain only a single variable, written `X` in the textual forms.
 * eg. Mult(Var, Add(Const(2), Var)) represents X * (2 + X)
 */
sealed trait Expression
object Expression {
  /** The (single) variable `X`. */
  case object Var extends Expression
  /** A numeric constant. */
  case class Const(value: Double) extends Expression
  /** The sum `left + right`. */
  case class Add(left: Expression, right: Expression) extends Expression
  /** The product `left * right`. */
  case class Mult(left: Expression, right: Expression) extends Expression

  /**
   * Evaluate an expression for a given value of the variable.
   *
   * eg. eval(Mult(Var, Add(Const(2), Var)))(3) => 15
   *
   * @param expression the expression to evaluate
   * @param variable   the value substituted for every occurrence of `Var`
   * @return the numeric value of the expression
   */
  def eval(expression: Expression)(variable: Double): Double =
    expression match {
      case Var          => variable
      case Const(value) => value
      case Add(l, r)    => eval(l)(variable) + eval(r)(variable)
      case Mult(l, r)   => eval(l)(variable) * eval(r)(variable)
    }

  /**
   * Convert an expression to its postfix (reverse Polish) representation.
   *
   * Tokens are separated by single spaces. The variable is printed as `X`, and
   * whole-valued constants are printed without a fractional part.
   *
   * eg. Mult(Var, Add(Const(2), Var)) => X 2 X + *
   *
   * @param expression the expression to render
   * @return the postfix string
   */
  def toPostfix(expression: Expression): String =
    expression match {
      case Var          => "X"
      case Const(value) => formatConst(value)
      case Add(l, r)    => s"${toPostfix(l)} ${toPostfix(r)} +"
      case Mult(l, r)   => s"${toPostfix(l)} ${toPostfix(r)} *"
    }

  /**
   * Convert an expression to its fully parenthesised prefix representation.
   *
   * eg. Mult(Var, Add(Const(2), Var)) => (* X (+ 2 X))
   *
   * @param expression the expression to render
   * @return the prefix string
   */
  def toPrefix(expression: Expression): String =
    expression match {
      case Var          => "X"
      case Const(value) => formatConst(value)
      case Add(l, r)    => s"(+ ${toPrefix(l)} ${toPrefix(r)})"
      case Mult(l, r)   => s"(* ${toPrefix(l)} ${toPrefix(r)})"
    }

  /**
   * Convert an expression to its infix representation.
   *
   * Only the parentheses needed to preserve precedence are emitted: a sum that is
   * an operand of a product is wrapped. Both operators are associative, so
   * nested sums in sums or products in products need none.
   *
   * eg. Mult(Var, Add(Const(2), Var)) => X * (2 + X)
   *
   * @param expression the expression to render
   * @return the infix string
   */
  def toInfix(expression: Expression): String =
    expression match {
      case Var          => "X"
      case Const(value) => formatConst(value)
      case Add(l, r)    => s"${toInfix(l)} + ${toInfix(r)}"
      case Mult(l, r)   => s"${infixFactor(l)} * ${infixFactor(r)}"
    }

  /** Render an operand of `*`, parenthesising sums so that precedence is preserved. */
  private def infixFactor(expression: Expression): String =
    expression match {
      case sum: Add => s"(${toInfix(sum)})"
      case other    => toInfix(other)
    }

  /**
   * Differentiate an expression with respect to the variable.
   *
   * The result is not simplified.
   *
   * eg. deriv(Mult(Var, Add(Const(2), Var))) =>
   * Add(Mult(Const(1.0), Add(Const(2.0), Var)),
   * Mult(Var, Add(Const(0.0), Const(1.0))))
   *
   * @param expression the expression to differentiate
   * @return the (unsimplified) derivative
   */
  def deriv(expression: Expression): Expression =
    expression match {
      case Var        => Const(1.0)
      case Const(_)   => Const(0.0)
      case Add(l, r)  => Add(deriv(l), deriv(r))
      // product rule: (u * v)' = u' * v + u * v'
      case Mult(u, v) => Add(Mult(deriv(u), v), Mult(u, deriv(v)))
    }

  /**
   * Bonus:
   * Get the polynomial coefficients of the expression, lowest degree first
   * (index i holds the coefficient of X^i).
   *
   * eg.
   * val exp1 = Add(Var, Const(3))
   * polyCoef(Mult(exp1, exp1)) => Vector(9, 6, 1)
   *
   * @param expression the expression to expand
   * @return the coefficient vector (never empty); trailing zeros are not trimmed
   */
  def polyCoef(expression: Expression): Vector[Double] =
    expression match {
      case Var          => Vector(0.0, 1.0)
      case Const(value) => Vector(value)
      case Add(l, r)    => addCoefficients(polyCoef(l), polyCoef(r))
      case Mult(l, r)   => multiplyCoefficients(polyCoef(l), polyCoef(r))
    }

  /** Pointwise sum of two coefficient vectors, padding the shorter one with zeros. */
  private def addCoefficients(a: Vector[Double], b: Vector[Double]): Vector[Double] =
    a.zipAll(b, 0.0, 0.0).map { case (x, y) => x + y }

  /** Polynomial product (discrete convolution) of two coefficient vectors. */
  private def multiplyCoefficients(a: Vector[Double], b: Vector[Double]): Vector[Double] =
    Vector.tabulate(a.length + b.length - 1) { degree =>
      // Only index pairs (i, degree - i) that fall inside both vectors contribute.
      val lo = math.max(0, degree - (b.length - 1))
      val hi = math.min(degree, a.length - 1)
      (lo to hi).map(i => a(i) * b(degree - i)).sum
    }

  /** Print whole-valued constants without ".0" (so 2.0 renders as "2"); others via toString. */
  private def formatConst(value: Double): String =
    // The magnitude bound keeps toLong exact and excludes NaN and infinities.
    if (math.abs(value) < 1e15 && value == math.rint(value)) value.toLong.toString
    else value.toString
}
/**
 * Algebraic simplification of `Expression` trees.
 *
 * NOTE(review): the patterns below use `Add`, `Mult`, `Const`, `Zero`, `One` and `Two`,
 * none of which is defined in this object. Presumably they are the `Expression` case
 * classes plus constants along the lines of `Const(0.0)`, `Const(1.0)`, `Const(2.0)`,
 * brought into scope elsewhere — confirm.
 */
object simplify {
  /**
   * Apply a single round of rewrite rules to `expression`.
   *
   * The first matching rule wins. The rule groups are:
   *   - identities: `0 + e`, `e + 0`, `1 * e`, `e * 1` reduce to (the simplified) `e`;
   *   - constant folding of `Const + Const` and `Const * Const`;
   *   - collecting like terms, e.g. `e + e => 2 * e`, `c * e + e => (c + 1) * e`,
   *     `c1 * e + c2 * e => (c1 + c2) * e`, including one level of nesting in a sum;
   *   - annihilation: `0 * e` and `e * 0` reduce to zero;
   *   - reordering, so that an `X * X` term (optionally scaled by a constant) is
   *     placed to the left of a linear `X` / `c * X` term.
   *
   * Only when no rule matches at the root are the children of an `Add`/`Mult`
   * rewritten. One call therefore performs a partial pass; use [[simplify]] to
   * iterate to a fixed point.
   *
   * @param expression the expression to rewrite
   * @return the expression after one rewrite pass
   */
  def simplifyOneShot(expression: Expression): Expression =
    expression match {
      case Add(Zero, right) => simplifyOneShot(right)
      case Add(left, Zero) => simplifyOneShot(left)
      case Add(Const(x), Const(y)) => Const(x + y)
      case Add(e1, e2) if e1 == e2 => Mult(Two, e1)
      case Add(e1, Mult(Const(c), e2)) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(Mult(Const(c), e1), e2) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(e1, Mult(e2, Const(c))) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(Mult(e1, Const(c)), e2) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(Mult(e1, Const(c1)), Mult(e2, Const(c2))) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(e1, Const(c1)), Mult(Const(c2), e2)) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(Const(c1), e1), Mult(e2, Const(c2))) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(Const(c1), e1), Mult(Const(c2), e2)) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(Const(c1), e1), Add(Mult(Const(c2), e2), e3)) if e1 == e2 => Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Mult(Const(c1), e1), Add(e3, Mult(Const(c2), e2))) if e1 == e2 => Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Add(Mult(e1, Const(c1)), e3), Mult(e2, Const(c2))) if e1 == e2 => Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Add(e3, Mult(e1, Const(c1))), Mult(e2, Const(c2))) if e1 == e2 => Add(e3, Mult(Const(c1 + c2), e1))
      case Add(e1, Add(e2, e3)) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(e1, Add(e3, e2)) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(Add(e1, e3), e2) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(Add(e3, e1), e2) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Mult(Zero, _) | Mult(_, Zero) => Zero
      case Mult(One, right) => simplifyOneShot(right)
      case Mult(left, One) => simplifyOneShot(left)
      case Mult(Const(x), Const(y)) => Const(x * y)
      // Reordering: place the X * X term before the linear term.
      case Add(left @ Var, right @ Mult(Var, Var)) => Add(right, left)
      case Add(left @ Mult(Const(_), Var), right @ Mult(Var, Var)) => Add(right, left)
      case Add(left @ Mult(Const(_), Var), right @ Mult(Const(_), Mult(Var, Var))) => Add(right, left)
      case Add(left @ Var, right @ Mult(Const(_), Mult(Var, Var))) => Add(right, left)
      // No rule applies at the root: rewrite the children instead.
      case Add(left, right) => Add(simplifyOneShot(left), simplifyOneShot(right))
      case Mult(left, right) => Mult(simplifyOneShot(left), simplifyOneShot(right))
      case _ => expression
    }

  /**
   * Repeatedly apply `f`, starting from `x`, until the result no longer changes
   * (compared with `==`).
   *
   * Does not terminate if the iteration never reaches a fixed point.
   *
   * @param f the step function
   * @param x the starting value
   * @return the first value `v` in the iteration with `f(v) == v`
   */
  @tailrec
  def fixPoint[A](f: A => A)(x: A): A = {
    // Compute the step once. The previous form evaluated f(x) twice per iteration,
    // doubling the cost of an expensive f such as a full tree rewrite.
    val next = f(x)
    if (x == next) x
    else fixPoint(f)(next)
  }

  /** Simplify an expression by applying `simplifyOneShot` until it reaches a fixed point. */
  val simplify: Expression => Expression = fixPoint(simplifyOneShot)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment