Skip to content

Instantly share code, notes, and snippets.

@shigemk2
Created July 12, 2015 04:05
Show Gist options
  • Save shigemk2/817e6c2267ec08929cf6 to your computer and use it in GitHub Desktop.
Save shigemk2/817e6c2267ec08929cf6 to your computer and use it in GitHub Desktop.
/** Integer expression tree: constants, variable terms, sums and products. */
sealed trait Expr
/** Integer constant `n`. */
case class N(n: Int) extends Expr
/** Term with coefficient `a` and exponent `n` in the variable named `x`, i.e. `a * x^n` (rendered by `str` as e.g. `3x^2`). */
case class Var(x: String, a: Int, n: Int) extends Expr
/** Sum of the child expressions. Note: the field `n` holds the children, not a number. */
case class Add(n: Expr*) extends Expr
/** Product of the child expressions. Note: the field `n` holds the children, not a number. */
case class Mul(n: Expr*) extends Expr
/** Shorthand constructor for the term `a * x^n` in the variable named "x". */
def x(a: Int, n: Int): Var = Var("x", a, n)
/** Evaluates an expression to an Int.
  *
  * @param e   the expression to evaluate
  * @param env values for variables; defaults to empty, so constant-only
  *            expressions evaluate exactly as before
  * @return the integer value (Int arithmetic, may overflow like `+`/`*`)
  * @throws IllegalArgumentException if a variable is not bound in `env`
  *         or a term has a negative exponent
  */
def eval(e: Expr, env: Map[String, Int] = Map.empty): Int = e match {
  case N(value) => value
  case Var(name, a, n) =>
    // Previously any Var fell through an @unchecked match and crashed with an opaque MatchError.
    val v = env.getOrElse(name, throw new IllegalArgumentException(s"unbound variable: $name"))
    require(n >= 0, s"negative exponent not supported: $n")
    a * Iterator.fill(n)(v).product
  case Add(xs @ _*) => xs.map(eval(_, env)).sum
  case Mul(xs @ _*) => xs.map(eval(_, env)).product
}
/** Renders an expression as infix text.
  *
  * Variable terms print as `x`, `-x`, `3x`, `3x^2`. Summands that are
  * themselves sums are parenthesised, and no `+` is emitted before a
  * negative summand (`isneg`). Factors that are sums or products are
  * parenthesised. Empty sums and products render as "".
  */
def str(e: Expr): String = {
  // One summand: a nested sum keeps its parentheses.
  def term(t: Expr): String = t match {
    case Add(inner @ _*) => "(" + str(Add(inner: _*)) + ")"
    case other => str(other)
  }
  // One factor: nested sums and products keep their parentheses.
  def factor(f: Expr): String = f match {
    case Add(inner @ _*) => "(" + str(Add(inner: _*)) + ")"
    case Mul(inner @ _*) => "(" + str(Mul(inner: _*)) + ")"
    case other => str(other)
  }
  e match {
    case N(value) => value.toString
    case Var(name, coef, pow) =>
      val base = coef match {
        case 1 => name
        case -1 => "-" + name
        case c => c.toString + name
      }
      if (pow == 1) base else base + "^" + pow.toString
    case Add() | Mul() => ""
    case Add(first, rest @ _*) =>
      term(first) + rest.map(t => (if (isneg(t)) "" else "+") + term(t)).mkString
    case Mul(first, rest @ _*) =>
      (first +: rest).map(factor).mkString("*")
  }
}
/** True for a negative constant or a variable term with a negative coefficient; false otherwise. */
def isneg(e: Expr): Boolean = e match {
  case N(value) => value < 0
  case Var(_, coef, _) => coef < 0
  case _ => false
}
/** Ordering predicate used by `xsort`.
  *
  * Two x-terms compare by exponent (`x` is below `y` when its power is
  * smaller). An x-term is never below a non-x-term. In every other
  * combination (including two non-x-terms) the result is true.
  */
def xlt(x: Expr, y: Expr): Boolean = {
  // Exponent of a term in the variable "x"; None for anything else.
  def power(e: Expr): Option[Int] = e match {
    case Var("x", _, p) => Some(p)
    case _ => None
  }
  (power(x), power(y)) match {
    case (Some(p), Some(q)) => p < q
    case (Some(_), None) => false
    case _ => true
  }
}
/** Recursively orders the terms of every sum.
  *
  * Within a sum, x-terms come first by descending power, followed by the
  * remaining terms (see `xlt`). Products keep their factor order but
  * have each factor sorted. Other expressions are returned unchanged.
  */
def xsort(xs: Expr): Expr = xs match {
  case Add(terms @ _*) =>
    // Quicksort: terms "not below" the pivot go before it, terms below go after.
    def qsort(ts: List[Expr]): List[Expr] = ts match {
      case Nil => Nil
      case pivot :: rest =>
        val (below, notBelow) = rest.partition(t => xlt(t, pivot))
        qsort(notBelow) ++ (pivot :: qsort(below))
    }
    // Sort every child exactly once up front. Previously the first pivot was
    // never sorted, and the other children were re-sorted at each partition level.
    Add(qsort(terms.toList.map(xsort)): _*)
  case Mul(factors @ _*) => Mul(factors.map(xsort): _*)
  case other => other
}
/** Splices the children of every (arbitrarily nested) `Add` into the list; other elements, including products, are kept as-is. */
def flatten(xs: List[Expr]): List[Expr] = xs.flatMap {
  case Add(inner @ _*) => flatten(inner.toList)
  case other => List(other)
}
/** Builds a sum from terms: no terms gives `N(0)`, one term is returned as-is, otherwise `Add(terms...)`. */
def add(xs: List[Expr]): Expr =
  if (xs.isEmpty) N(0)
  else if (xs.tail.isEmpty) xs.head
  else Add(xs: _*)
/** Simplifies an expression.
  *
  * Sums:
  *  - Every summand is simplified first, so sums hidden inside one-factor
  *    products (as produced by `expandAll`) are exposed.
  *  - Nested sums are flattened and the terms are ordered with `xsort`.
  *  - Adjacent constants, and adjacent x-terms of equal power, are merged.
  *  - Zero constants and zero-coefficient terms are dropped.
  *  - The result is built with `add`: no terms gives `N(0)`, one term is returned alone.
  *
  * Products: every factor is simplified and the result is built with `mul`
  * (no factors gives `N(1)`, one factor is returned alone).
  * Anything else is returned unchanged.
  */
def xsimplify(xs: Expr): Expr = xs match {
  case Add(terms @ _*) =>
    // xsort of an Add always yields an Add.
    val sorted: List[Expr] = (xsort(Add(flatten(terms.toList.map(xsimplify)): _*)): @unchecked) match {
      case Add(ts @ _*) => ts.toList
    }
    // Sorting makes mergeable terms adjacent; the elements are already simplified.
    def merge(ts: List[Expr]): List[Expr] = ts match {
      case Nil => Nil
      case N(0) :: rest => merge(rest)
      case Var(_, 0, _) :: rest => merge(rest)
      case N(a1) :: N(a2) :: rest => merge(N(a1 + a2) :: rest)
      case Var("x", a1, n1) :: Var("x", a2, n2) :: rest if n1 == n2 => merge(x(a1 + a2, n1) :: rest)
      case t :: rest => t :: merge(rest)
    }
    add(merge(sorted))
  // Unwrap one-factor products so an enclosing sum can flatten and merge their contents.
  case Mul(factors @ _*) => mul(factors.toList.map(xsimplify))
  case other => other
}
/** Multiplies two expressions symbolically.
  *
  * - Constants and terms in the same variable are combined: coefficients
  *   multiply and exponents add.
  * - Terms in different variables become a `Mul`.
  * - Sums are distributed: the left operand is the outer loop.
  * - Otherwise the operands are joined into a single `Mul`.
  */
def multiply(xs1: Expr, xs2: Expr): Expr = (xs1, xs2) match {
  case (N(p), N(q)) => N(p * q)
  case (N(k), Var(name, c, e)) => Var(name, k * c, e)
  case (Var(name, c, e), N(k)) => Var(name, c * k, e)
  case (Var(n1, c1, e1), Var(n2, _, e2)) if n1 != n2 => Mul(xs1, xs2)
  case (Var(n1, c1, e1), Var(_, c2, e2)) => Var(n1, c1 * c2, e1 + e2)
  case (Add(ls @ _*), Add(rs @ _*)) => Add(ls.flatMap(l => rs.map(r => multiply(l, r))): _*)
  case (Add(ls @ _*), r) => Add(ls.map(multiply(_, r)): _*)
  case (l, Add(rs @ _*)) => Add(rs.map(multiply(l, _)): _*)
  case (Mul(ls @ _*), Mul(rs @ _*)) => Mul(ls ++ rs: _*)
  case (Mul(ls @ _*), r) => Mul(ls :+ r: _*)
  case (l, Mul(rs @ _*)) => Mul(l +: rs: _*)
}
/** Builds a product from factors: no factors gives `N(1)`, one factor is returned as-is, otherwise `Mul(factors...)`. */
def mul(xs: List[Expr]): Expr =
  if (xs.isEmpty) N(1)
  else if (xs.tail.isEmpty) xs.head
  else Mul(xs: _*)
/** Performs one expansion step.
  *
  * - Product with two or more factors: the first two are multiplied together.
  * - Product with one factor: that factor is expanded.
  * - Sum: the first summand that changes under `expand` is replaced by its
  *   expansion; the remaining summands are left untouched.
  * - Anything else is returned unchanged.
  */
def expand(xs: Expr): Expr = xs match {
  case Mul(factors @ _*) =>
    val step: List[Expr] = factors.toList match {
      case Nil => Nil
      case only :: Nil => List(expand(only))
      case a :: b :: rest => multiply(a, b) :: rest
    }
    Mul(step: _*)
  case Add(terms @ _*) =>
    // Compute expand(t) once per summand. The old guard pair evaluated it
    // twice at every level, which is exponential in nesting depth.
    def firstStep(ts: List[Expr]): List[Expr] = ts match {
      case Nil => Nil
      case t :: rest =>
        val expanded = expand(t)
        if (expanded != t) expanded :: rest else t :: firstStep(rest)
    }
    Add(firstStep(terms.toList): _*)
  case other => other
}
/** Applies `expand` repeatedly until the expression stops changing (a fixed point). */
def expandAll(x: Expr): Expr = {
  // Evaluate expand once per iteration (the old guards re-ran it up to three
  // times), and make the loop a guaranteed tail call.
  @scala.annotation.tailrec
  def loop(cur: Expr): Expr = {
    val next = expand(cur)
    if (next == cur) cur else loop(next)
  }
  loop(x)
}
// Self-checks: each line prints `true` when the computed value equals the expected one, `false` otherwise.
// eval on constant sums and products.
println(eval(Add(N(1),N(2))) == 1+2)
println(eval(Add(N(2),N(3))) == 2+3)
println(eval(Add(N(5),N(-3))) == 5-3)
println(eval(Mul(N(3),N(4))) == 3*4)
println(eval(Add(N(1),Mul(N(2),N(3)))) == 1+2*3)
println(eval(Mul(Add(N(1),N(2)),N(3))) == (1+2)*3)
// str rendering: no "+" before negative summands; nested sums/products are parenthesised.
println(str(Add(N(1),N(2),N(3))) == "1+2+3")
println(str(Add(N(1),N(-2),N(-3))) == "1-2-3")
println(str(Mul(N(1),N(2),N(3))) == "1*2*3")
println(str(Add(N(1),Mul(N(2),N(3)))) == "1+2*3")
// NOTE(review): identical to the check two lines above.
println(str(Mul(N(1),N(2),N(3))) == "1*2*3")
println(str(Add(Add(N(1),N(2)),N(3))) == "(1+2)+3")
println(str(Mul(Add(N(1),N(2)),N(3))) == "(1+2)*3")
println(str(Mul(Mul(N(1),N(2)),N(3))) == "(1*2)*3")
// Case-class structural equality.
println(Add(N(1),N(2)) == Add(N(1),N(2)))
// Rendering of variable terms.
println(str(Add(x(1,1),N(1))) == "x+1")
println(str(Add(x(1,3),x(-1,2),x(-2,1),N(1))) == "x^3-x^2-2x+1")
// xsort: descending powers of x, constants last.
val f = Mul(Add(N(5),x(2,1)),Add(x(1,2),x(1,1),N(1),x(3,3)))
println(str(f) == "(5+2x)*(x^2+x+1+3x^3)")
println(str(xsort(f)) == "(2x+5)*(3x^3+x^2+x+1)")
// xsimplify: merging like terms, dropping zeros, flattening nested sums.
val g1 = Add(x(2,1),N(3),x(4,2),x(1,1),N(1),x(1,2))
println(str(g1) == "2x+3+4x^2+x+1+x^2")
println(str(xsimplify(g1)) == "5x^2+3x+4")
val g2 = Mul(Add(x(1,1),N(0),x(2,1)),Add(x(1,2),Add(N(1),x(2,2)),N(2)))
println(str(g2) == "(x+0+2x)*(x^2+(1+2x^2)+2)")
println(str(xsimplify(g2)) == "3x*(3x^2+3)")
val g3 = Add(x(1,1),N(1),x(0,2),x(1,1),N(1),x(-2,1),N(-2))
println(str(g3) == "x+1+0x^2+x+1-2x-2")
println(str(xsimplify(g3)) == "0")
// multiply: constants, terms, and distribution over sums.
println(str(N(2)) == "2")
println(str(N(3)) == "3")
println(str(multiply(N(2), N(3))) == "6")
// NOTE(review): repeats an earlier str(N(2)) check.
println(str(N(2)) == "2")
println(str(x(3,2)) == "3x^2")
println(str(multiply(N(2), x(3,2))) == "6x^2")
println(str(x(2,3)) == "2x^3")
println(str(x(3,4)) == "3x^4")
println(str(multiply(x(2,3), x(3,4))) == "6x^7")
// NOTE(review): repeats an earlier str(N(2)) check.
println(str(N(2)) == "2")
println(str(Add(x(1,1),x(2,2),N(3))) == "x+2x^2+3")
println(str(multiply(N(2), Add(x(1,1),x(2,2),N(3)))) == "2x+4x^2+6")
println(str(Add(x(1,1),N(1))) == "x+1")
println(str(Add(x(2,1),N(3))) == "2x+3")
println(str(multiply(Add(x(1,1),N(1)),Add(x(2,1),N(3)))) == "2x^2+3x+2x+3")
println(str(xsimplify(multiply(Add(x(1,1),N(1)),Add(x(2,1),N(3))))) == "2x^2+5x+3")
// expand (one step), expandAll (to a fixed point), and simplification of the expanded form.
println(str(Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3)))) == "(x+1)*(x+2)*(x+3)")
println(str(expand(Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3))))) == "(x^2+2x+x+2)*(x+3)")
println(str(Add(N(1),Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3))))) == "1+(x+1)*(x+2)*(x+3)")
println(str(expand((Add(N(1),Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3))))))) == "1+(x^2+2x+x+2)*(x+3)")
// NOTE(review): identical to the check two lines above.
println(str(Add(N(1),Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3))))) == "1+(x+1)*(x+2)*(x+3)")
println(str(expandAll(Add(N(1),Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3)))))) == "1+(x^3+3x^2+2x^2+6x+x^2+3x+2x+6)")
println(str(xsimplify(expandAll((Add(N(1),Mul(Add(x(1,1),N(1)),Add(x(1,1),N(2)),Add(x(1,1),N(3)))))))) == "x^3+6x^2+11x+7")
println(str(xsimplify(Add(N(1),Add(x(1,3),x(3,2),x(2,2),x(6,1),x(1,2),x(3,1),x(2,1),N(6))))) == "x^3+6x^2+11x+7")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment