Created
November 13, 2020 09:51
-
-
Save fsarradin/6d72a66fb3e3d850fa2161b3ccaf61b7 to your computer and use it in GitHub Desktop.
Arithmetic expression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Arithmetic expression representation.
 *
 * An expression may contain only a single variable, written `X`.
 * e.g. `Mult(Var, Add(Const(2), Var))` represents `X * (2 + X)`.
 */
sealed trait Expression
object Expression {
  /** The (unique) variable `X`. */
  case object Var extends Expression
  /** A numeric constant. */
  case class Const(value: Double) extends Expression
  /** The sum `left + right`. */
  case class Add(left: Expression, right: Expression) extends Expression
  /** The product `left * right`. */
  case class Mult(left: Expression, right: Expression) extends Expression

  // Renders a constant without a trailing ".0" when it is a whole number, so that
  // Const(2) prints as "2" as in the examples below. Other values (fractions, NaN,
  // infinities, huge magnitudes that would overflow a Long) fall back to Double.toString.
  private def formatConst(value: Double): String =
    if (value == math.rint(value) && math.abs(value) < 1e15) value.toLong.toString
    else value.toString

  /**
   * Evaluate an expression for a given value of the variable.
   *
   * e.g. `eval(Mult(Var, Add(Const(2), Var)))(3) => 15`
   *
   * @param expression the expression to evaluate
   * @param variable   the value substituted for `X`
   * @return the numeric value of the expression
   */
  def eval(expression: Expression)(variable: Double): Double =
    expression match {
      case Var               => variable
      case Const(value)      => value
      case Add(left, right)  => eval(left)(variable) + eval(right)(variable)
      case Mult(left, right) => eval(left)(variable) * eval(right)(variable)
    }

  /**
   * Convert an expression to its postfix (reverse Polish) representation.
   *
   * e.g. `Mult(Var, Add(Const(2), Var)) => X 2 X + *`
   */
  def toPostfix(expression: Expression): String =
    expression match {
      case Var               => "X"
      case Const(value)      => formatConst(value)
      case Add(left, right)  => s"${toPostfix(left)} ${toPostfix(right)} +"
      case Mult(left, right) => s"${toPostfix(left)} ${toPostfix(right)} *"
    }

  /**
   * Convert an expression to its fully parenthesized prefix representation.
   *
   * e.g. `Mult(Var, Add(Const(2), Var)) => (* X (+ 2 X))`
   */
  def toPrefix(expression: Expression): String =
    expression match {
      case Var               => "X"
      case Const(value)      => formatConst(value)
      case Add(left, right)  => s"(+ ${toPrefix(left)} ${toPrefix(right)})"
      case Mult(left, right) => s"(* ${toPrefix(left)} ${toPrefix(right)})"
    }

  /**
   * Convert an expression to its infix representation, adding parentheses only
   * where precedence requires them.
   *
   * e.g. `Mult(Var, Add(Const(2), Var)) => X * (2 + X)`
   */
  def toInfix(expression: Expression): String = {
    // A sum used as an operand of a product must be parenthesized; since both
    // operators are associative, no other nesting needs parentheses.
    def factor(operand: Expression): String = operand match {
      case Add(_, _) => s"(${toInfix(operand)})"
      case _         => toInfix(operand)
    }
    expression match {
      case Var               => "X"
      case Const(value)      => formatConst(value)
      case Add(left, right)  => s"${toInfix(left)} + ${toInfix(right)}"
      case Mult(left, right) => s"${factor(left)} * ${factor(right)}"
    }
  }

  /**
   * Differentiate an expression with respect to the variable. The result is not
   * simplified.
   *
   * e.g. `deriv(Mult(Var, Add(Const(2), Var))) =>
   *   Add(Mult(Const(1.0), Add(Const(2.0), Var)),
   *       Mult(Var, Add(Const(0.0), Const(1.0))))`
   */
  def deriv(expression: Expression): Expression =
    expression match {
      case Var              => Const(1.0)
      case Const(_)         => Const(0.0)
      case Add(left, right) => Add(deriv(left), deriv(right))
      // product rule: (u * v)' = u' * v + u * v'
      case Mult(left, right) => Add(Mult(deriv(left), right), Mult(left, deriv(right)))
    }

  /**
   * Bonus:
   * Get the polynomial coefficients of the expression, lowest degree first
   * (index `i` holds the coefficient of `X^i`). Trailing zero coefficients are
   * not removed.
   *
   * e.g.
   *   val exp1 = Add(Var, Const(3))
   *   polyCoef(Mult(exp1, exp1)) => Vector(9, 6, 1)
   */
  def polyCoef(expression: Expression): Vector[Double] =
    expression match {
      case Var          => Vector(0.0, 1.0)
      case Const(value) => Vector(value)
      case Add(left, right) =>
        // pad the shorter polynomial with zeros and add term by term
        polyCoef(left).zipAll(polyCoef(right), 0.0, 0.0).map { case (a, b) => a + b }
      case Mult(left, right) =>
        val a = polyCoef(left)
        val b = polyCoef(right)
        // Cauchy product: coefficient k sums a(i) * b(k - i) over every valid i.
        Vector.tabulate(a.length + b.length - 1) { k =>
          (math.max(0, k - b.length + 1) to math.min(k, a.length - 1))
            .map(i => a(i) * b(k - i))
            .sum
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Algebraic simplification of `Expression` trees by repeated rewriting.
 *
 * NOTE(review): this object uses `Zero`, `One` and `Two` together with the
 * `Expression` constructors (`Add`, `Mult`, `Const`, `Var`), none of which are
 * defined or imported in this file. They are presumably `Const(0)`, `Const(1)`
 * and `Const(2)` plus `import Expression._`, brought into scope elsewhere —
 * confirm.
 */
object simplify {

  /**
   * Applies a single pass of rewrite rules to `expression`.
   *
   * Cases are tried top to bottom and the first match wins, so their order
   * matters: the specific rewrites (neutral/absorbing elements, constant
   * folding, collecting like terms, putting the `X * X` term first) must come
   * before the generic `Add`/`Mult` cases, which only recurse into children.
   * One pass does not necessarily reach a normal form; `simplify` iterates it.
   */
  def simplifyOneShot(expression: Expression): Expression =
    expression match {
      // 0 is neutral for addition
      case Add(Zero, right) => simplifyOneShot(right)
      case Add(left, Zero)  => simplifyOneShot(left)
      // constant folding
      case Add(Const(x), Const(y)) => Const(x + y)
      // e + e = 2 * e
      case Add(e1, e2) if e1 == e2 => Mult(Two, e1)
      // e + c * e = (c + 1) * e, for each operand order
      case Add(e1, Mult(Const(c), e2)) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(Mult(Const(c), e1), e2) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(e1, Mult(e2, Const(c))) if e1 == e2 => Mult(Const(c + 1.0), e1)
      case Add(Mult(e1, Const(c)), e2) if e1 == e2 => Mult(Const(c + 1.0), e1)
      // c1 * e + c2 * e = (c1 + c2) * e, for each operand order
      case Add(Mult(e1, Const(c1)), Mult(e2, Const(c2))) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(e1, Const(c1)), Mult(Const(c2), e2)) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(Const(c1), e1), Mult(e2, Const(c2))) if e1 == e2 => Mult(Const(c1 + c2), e1)
      case Add(Mult(Const(c1), e1), Mult(Const(c2), e2)) if e1 == e2 => Mult(Const(c1 + c2), e1)
      // collect like terms across one level of nested Add
      case Add(Mult(Const(c1), e1), Add(Mult(Const(c2), e2), e3)) if e1 == e2 =>
        Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Mult(Const(c1), e1), Add(e3, Mult(Const(c2), e2))) if e1 == e2 =>
        Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Add(Mult(e1, Const(c1)), e3), Mult(e2, Const(c2))) if e1 == e2 =>
        Add(Mult(Const(c1 + c2), e1), e3)
      case Add(Add(e3, Mult(e1, Const(c1))), Mult(e2, Const(c2))) if e1 == e2 =>
        Add(e3, Mult(Const(c1 + c2), e1))
      case Add(e1, Add(e2, e3)) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(e1, Add(e3, e2)) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(Add(e1, e3), e2) if e1 == e2 => Add(Mult(Two, e1), e3)
      case Add(Add(e3, e1), e2) if e1 == e2 => Add(Mult(Two, e1), e3)
      // 0 is absorbing and 1 is neutral for multiplication
      case Mult(Zero, _) | Mult(_, Zero) => Zero
      case Mult(One, right) => simplifyOneShot(right)
      case Mult(left, One)  => simplifyOneShot(left)
      case Mult(Const(x), Const(y)) => Const(x * y)
      // put the X * X term before the X term
      case Add(left @ Var, right @ Mult(Var, Var)) => Add(right, left)
      case Add(left @ Mult(Const(_), Var), right @ Mult(Var, Var)) => Add(right, left)
      case Add(left @ Mult(Const(_), Var), right @ Mult(Const(_), Mult(Var, Var))) => Add(right, left)
      case Add(left @ Var, right @ Mult(Const(_), Mult(Var, Var))) => Add(right, left)
      // no rewrite applies at this node: simplify the children
      case Add(left, right)  => Add(simplifyOneShot(left), simplifyOneShot(right))
      case Mult(left, right) => Mult(simplifyOneShot(left), simplifyOneShot(right))
      // Var and Const are already in simplest form
      case _ => expression
    }

  /**
   * Applies `f` repeatedly, starting from `x`, until the result no longer
   * changes (compared with `==`), and returns that fixed point.
   *
   * Does not terminate if `f` never reaches a fixed point from `x`.
   */
  @scala.annotation.tailrec
  def fixPoint[A](f: A => A)(x: A): A = {
    // evaluate f once per step; the result is both the equality check and the next input
    val next = f(x)
    if (x == next) x
    else fixPoint(f)(next)
  }

  /** Rewrites an expression with `simplifyOneShot` until nothing changes. */
  val simplify: Expression => Expression = fixPoint(simplifyOneShot)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment