// Source: GitHub Gist 7shi/7078516eaa5980334271 (last active August 29, 2015),
// forked from shigemk2/AlgebraCalculation.scala.
// (Surrounding GitHub page chrome — "Skip to content", star/fork/save controls — omitted.)
import scala.language.implicitConversions
/** A symbolic algebraic expression.
  *
  * Cases: [[N]] (exact rational constant), [[Var]] (monomial `a * v^n`),
  * [[Add]] (sum of terms) and [[Mul]] (product of factors). The operators
  * below only build trees; rewriting is done by [[sort]], [[simplify]],
  * [[multiply]], [[expand]] and [[expandAll]].
  */
sealed trait Expr {
  /** Builds the unsimplified sum `this + that`. */
  def +(that: Expr): Expr = Add(List(this, that))
  /** Builds the unsimplified product `this * that`. */
  def *(that: Expr): Expr = Mul(List(this, that))
  /** Adds the integer constant `that`. */
  def +(that: Int): Expr = this + N(that)
  /** Subtracts the integer constant `that` by adding `N(-that)`. */
  def -(that: Int): Expr = this + N(-that)
  /** Multiplies by the integer constant `that`. */
  def *(that: Int): Expr = this * N(that)

  /** Renders the expression in infix notation, e.g. `x^3-x^2-2x+1`.
    *
    * A unit coefficient is omitted (`x`, `-x`); a fractional coefficient is
    * separated from the variable by a space (`1/3 x`). An Add nested in an
    * Add, and an Add or Mul nested in a Mul, is parenthesized.
    */
  override def toString: String = this match {
    case n: N                       => n.toString
    case Var(x, N(1, 1), N(1, 1))   => x
    case Var(x, N(-1, 1), N(1, 1))  => "-" + x
    case Var(x, a, N(1, 1))         => a.rstr(" ") + x
    case Var(x, a, n)               => Var(x, a) + "^" + n
    case Add(Nil)                   => ""
    case Add(List(Add(xs)))         => "(" + Add(xs) + ")"
    case Add(List(x))               => x.toString
    // A negative term prints its own leading "-", so no "+" separator is added.
    case Add(x :: xs) if xs.head.isNeg => Add(List(x)).toString + Add(xs)
    case Add(x :: xs)               => Add(List(x)) + "+" + Add(xs)
    case Mul(Nil)                   => ""
    case Mul(List(Add(xs)))         => "(" + Add(xs) + ")"
    case Mul(List(Mul(xs)))         => "(" + Mul(xs) + ")"
    case Mul(List(x))               => x.toString
    case Mul(x :: xs)               => Mul(List(x)) + "*" + Mul(xs)
  }

  /** Numerically evaluates a variable-free expression.
    *
    * An empty sum evaluates to 0 and an empty product to 1.
    * Throws MatchError if the expression contains a [[Var]].
    */
  def eval: N = (this: @unchecked) match {
    case n: N    => n
    case Add(xs) => xs.map(_.eval).reduceOption(_ + _).getOrElse(N(0))
    case Mul(xs) => xs.map(_.eval).reduceOption(_ * _).getOrElse(N(1))
  }

  /** True for a negative constant or a monomial with a negative coefficient. */
  def isNeg: Boolean = this match {
    case n: N if n < 0              => true
    case Var(_, a, _) if a < 0      => true
    case _                          => false
  }

  /** Ordering used by [[sort]].
    *
    * Two monomials in `x` compare by exponent. A monomial in `x` is never
    * "less than" anything else, and every other pair compares as "less
    * than". This is not a total order.
    */
  def <(that: Expr): Boolean = (this, that) match {
    case (Var("x", _, n1), Var("x", _, n2)) => n1 < n2
    case (Var("x", _, _), _)                => false
    case _                                  => true
  }

  /** Negation of [[<]]. */
  def >=(that: Expr): Boolean = !(this < that)

  /** Reorders the terms of every sum, recursively.
    *
    * Monomials in `x` come first, with higher exponents first; the other
    * terms follow in their original relative order. Products keep their
    * factor order but each factor is sorted. Other expressions are returned
    * unchanged.
    */
  def sort: Expr = this match {
    case Add(xs) =>
      // Quicksort on [[>=]]; the partition order matches the original f(xs1) ++ x ++ f(xs2).
      def qsort(list: List[Expr]): List[Expr] = list match {
        case Nil => Nil
        case pivot :: rest =>
          val (before, after) = rest.partition(_ >= pivot)
          qsort(before) ++ (pivot :: qsort(after))
      }
      // Sort every child exactly once up front, so the pivot is sorted too
      // and nested results do not depend on recursion depth.
      Add(qsort(xs.map(_.sort)))
    case Mul(xs) => Mul(xs.map(_.sort))
    case _       => this
  }

  /** Splices the terms of directly nested sums (recursively) into one flat list. */
  private def flatten(list: List[Expr]): List[Expr] = list match {
    case List()          => List()
    case Add(xs1) :: xs2 => flatten(xs1 ++ xs2)
    case x :: xs         => x :: flatten(xs)
  }

  /** Wraps a term list: 0 for no terms, the term itself for one, otherwise an [[Add]]. */
  private def add(list: List[Expr]): Expr = list match {
    case List()  => N(0)
    case List(x) => x
    case _       => Add(list)
  }

  /** Simplifies a sum.
    *
    * The sum is flattened and sorted. Zero constants and zero-coefficient
    * monomials are dropped. Adjacent constants are added, and adjacent
    * monomials with the same variable and exponent are merged. For a
    * product, each factor is simplified independently; other expressions
    * are returned unchanged.
    */
  def simplify: Expr = this match {
    case Add(xs) =>
      // sort maps an Add to an Add, so only that case can occur here.
      def terms(e: Expr): List[Expr] = (e: @unchecked) match {
        case Add(ts) => ts
      }
      def combine(list: List[Expr]): List[Expr] = list match {
        case List()                     => List()
        case N(0, _) :: rest            => combine(rest)
        case Var(_, N(0, _), _) :: rest => combine(rest)
        case List(t)                    => List(t.simplify)
        case (a1: N) :: (a2: N) :: rest => combine((a1 + a2) :: rest)
        case Var(v1, a1, n1) :: Var(v2, a2, n2) :: rest if v1 == v2 && n1 == n2 =>
          combine(Var(v1, a1 + a2, n1) :: rest)
        case t :: rest                  => t.simplify :: combine(rest)
      }
      add(combine(terms(Add(flatten(xs)).sort)))
    case Mul(xs) => Mul(xs.map(_.simplify))
    case _       => this
  }

  /** Multiplies two expressions one level deep.
    *
    * Constants and same-variable monomials are combined. Sums are
    * distributed term by term. Products and monomials in different
    * variables are concatenated into a [[Mul]].
    */
  def multiply(that: Expr): Expr = (this, that) match {
    case (n1: N, n2: N)                             => n1 * n2
    case (n1: N, Var(x, a2, n2))                    => Var(x, n1 * a2, n2)
    case (Var(x, a1, n1), n2: N)                    => Var(x, a1 * n2, n1)
    case (Var(x, a1, n1), Var(y, a2, n2)) if x == y => Var(x, a1 * a2, n1 + n2)
    case (Var(x, a1, n1), Var(y, a2, n2)) if x != y => Var(x, a1, n1) * Var(y, a2, n2)
    case (Add(xs1), Add(xs2)) => Add((for (x1 <- xs1; x2 <- xs2) yield x1.multiply(x2)))
    case (Add(xs1), x2)       => Add((for (x1 <- xs1) yield x1.multiply(x2)))
    case (x1, Add(xs2))       => Add((for (x2 <- xs2) yield x1.multiply(x2)))
    case (Mul(xs1), Mul(xs2)) => Mul(xs1 ++ xs2)
    case (Mul(xs1), x2)       => Mul(xs1 :+ x2)
    case (x1, Mul(xs2))       => Mul(x1 :: xs2)
  }

  /** Wraps a factor list: 1 for no factors, the factor itself for one, otherwise a [[Mul]]. */
  private def mul(list: List[Expr]): Expr = list match {
    case List()  => N(1)
    case List(x) => x
    case _       => Mul(list)
  }

  /** Performs one expansion step.
    *
    * In a product, the first two factors are combined with [[multiply]]; a
    * lone factor is expanded instead. In a sum, only the first term whose
    * expansion changes it is replaced. Other expressions are returned
    * unchanged.
    */
  def expand: Expr = this match {
    case Mul(xs) =>
      def f(list: List[Expr]): List[Expr] = list match {
        case List()       => List()
        case List(x)      => List(x.expand)
        case x :: y :: xs => x.multiply(y) :: xs
      }
      mul(f(xs))
    case Add(xs) =>
      def f(list: List[Expr]): List[Expr] = list match {
        case List() => List()
        case x :: xs =>
          val x2 = x.expand
          if (x != x2) x2 :: xs else x :: f(xs)
      }
      add(f(xs))
    case _ => this
  }

  /** Applies [[expand]] repeatedly until the result no longer changes (structural equality). */
  def expandAll: Expr = {
    val x2 = expand
    if (this != x2) x2.expandAll else this
  }

  /** Derivative with respect to the variable named `x` (result not simplified).
    *
    * Sums are differentiated term by term and products by the product rule.
    * Constants and monomials in other variables give 0.
    */
  def differentiate(x: String): Expr = this match {
    case Add(ys) => Add((for (y <- ys) yield y.differentiate(x)))
    case Mul(ys) =>
      // Product rule: d(f1*...*fk) = sum over i of f1*...*fi'*...*fk.
      // An empty product is the constant 1, whose derivative is 0.
      if (ys.isEmpty) N(0)
      else Add(ys.zipWithIndex.map { case (y, i) => Mul(ys.updated(i, y.differentiate(x))) })
    case Var(y, a, N(1, 1)) if x == y => a
    case Var(y, a, n) if x == y       => Var(x, a * n, n - 1)
    case Var(_, _, _)                 => N(0)
    case _: N                         => N(0)
  }

  /** Antiderivative with respect to the variable named `x`.
    *
    * A sum is integrated term by term and a constant `C` is appended. A
    * single term gets no constant. A monomial in another variable is
    * multiplied by `x`.
    *
    * NOTE: products ([[Mul]]) are not handled and throw MatchError; expand
    * them first. An exponent of -1 is not supported, because `n + 1` is 0
    * there.
    */
  def integrate(x: String): Expr = this match {
    case Add(ys) => Add((for (y <- ys) yield y.integrate(x)) ++ List(Var("C")))
    case Var(y, a, n) if x == y => Var(x, a / (n + 1), n + 1)
    case Var(y, a, n) if x != y => Var(y, a, n) * Var(x)
    case n: N                   => Var(x, n)
  }
}
// N is derived from Rational with some fixes.
// https://sites.google.com/site/scalajp/home/documentation/scala-by-example/chapter6
/** An exact rational constant `numer/denom`.
  *
  * Results of `+`, `-`, `*` and `/` are normalized by [[reduce]] (lowest
  * terms, positive denominator). The constructor itself does not normalize,
  * so `N(2, 4)` is kept as-is and is not equal to `N(1, 2)`.
  */
case class N(numer: Int, denom: Int = 1) extends Expr with Ordered[N] {
  /** Returns this fraction in lowest terms with a positive denominator.
    * Throws ArithmeticException when numer and denom are both 0.
    */
  def reduce: N = {
    // gcd carrying the sign of `y` (the denominator), so dividing by it
    // always leaves a positive denominator.
    def gcd(x: Int, y: Int): Int = {
      if (x == 0) y
      else if (x < 0) gcd(-x, y)
      else if (y < 0) -gcd(x, -y)
      else gcd(y % x, x)
    }
    val g = gcd(numer, denom)
    N(numer / g, denom / g)
  }

  /** Reduced sum. */
  def +(that: N): N = N(
    numer * that.denom + that.numer * denom,
    denom * that.denom).reduce

  /** Reduced difference. */
  def -(that: N): N = N(
    numer * that.denom - that.numer * denom,
    denom * that.denom).reduce

  /** Reduced product. */
  def *(that: N): N = N(
    numer * that.numer,
    denom * that.denom).reduce

  /** Reduced quotient.
    * Throws ArithmeticException when `that` is zero. Previously this
    * silently produced a zero denominator unless both sides were 0.
    */
  def /(that: N): N = {
    if (that.numer == 0) throw new ArithmeticException("division by zero")
    N(numer * that.denom, denom * that.numer).reduce
  }

  /** Sign of `this - that` (-1, 0 or 1).
    *
    * The sign of a/b - c/d is sign(a*d - c*b) * sign(b) * sign(d). It is
    * computed in Long so that large operands cannot overflow Int and flip
    * the sign.
    */
  def compare(that: N): Int = {
    val cross = numer.toLong * that.denom - that.numer.toLong * denom
    java.lang.Long.signum(cross) * Integer.signum(denom) * Integer.signum(that.denom)
  }

  /** Field-wise equality with another N. An N with denom 1 also equals the matching Int. */
  override def equals(other: Any): Boolean = other match {
    case that: N   => numer == that.numer && denom == that.denom
    case that: Int => numer == that && denom == 1
    case _         => false
  }

  // Consistent with equals: N(k, 1) equals k, so it must hash like k.
  override def hashCode: Int =
    if (denom == 1) numer.## else (numer, denom).##

  /** "numer" when denom is 1, else "numer/denom". */
  override def toString: String = denom match {
    case 1 => numer.toString
    case _ => numer + "/" + denom
  }

  /** Like [[toString]], but appends `s` after a fraction (used to write "1/3 x"). */
  def rstr(s: String): String = denom match {
    case 1 => numer.toString
    case _ => toString + s
  }
}
/** Implicitly lifts an Int to the rational constant `n/1`.
  * NOTE: despite its name, this converts Int => N (not N => Int); the name is kept for callers.
  */
implicit def NToInt(n: Int): N = N(numer = n)
/** A monomial `a * x^n`.
  *
  * @param x name of the variable
  * @param a coefficient, 1 by default
  * @param n exponent, 1 by default
  */
case class Var(
    x: String,
    a: N = N(1),
    n: N = N(1)
) extends Expr
/** Shorthand for a monomial in the variable "x": `x(a, n)` is `a * x^n` and `x()` is plain `x`. */
def x(a: N = N(1), n: N = N(1)): Var =
  Var("x", a = a, n = n)
/** A sum of terms, in the order given. */
case class Add(xs: List[Expr]) extends Expr {
  /** Appends `that` as a trailing term, so `a + b + c` stays one flat Add. */
  override def +(that: Expr): Expr = Add(xs ++ List(that))
}
/** A product of factors, in the order given. */
case class Mul(xs: List[Expr]) extends Expr {
  /** Appends `that` as a trailing factor, so `a * b * c` stays one flat Mul. */
  override def *(that: Expr): Expr = Mul(xs ++ List(that))
}
/** Tiny check helper: prints "[OK] tag" when `v == e`, otherwise prints
  * "[NG] tag" followed by the actual value and the expected value.
  */
def test(tag: String, v: Any, e: Any): Unit =
  if (v != e) {
    println(s"[NG] $tag")
    println(s" value : $v")
    println(s" expected: $e")
  } else {
    println(s"[OK] $tag")
  }
// Script-level checks; each call prints "[OK] tag" or "[NG] tag" plus the mismatching values.
// Evaluation of constant arithmetic: the resulting N is compared against a plain Int.
test("eval 1", (N(1)+1).eval, 1+1)
test("eval 2", (N(2)+3).eval, 2+3)
test("eval 3", (N(5)-3).eval, 5-3)
test("eval 4", (N(3)*4).eval, 3*4)
test("eval 5", (N(1)+N(2)*3).eval, 1+2*3)
test("eval 6", ((N(1)+2)*3).eval, (1+2)*3)
// Rendering: chained + and * stay flat; explicitly nested Add/Mul are parenthesized.
test("str 1", (N(1)+2+3).toString, "1+2+3")
test("str 2", (N(1)-2-3).toString, "1-2-3")
test("str 3", (N(1)*2*3).toString, "1*2*3")
test("str 4", (N(1)+N(2)*3).toString, "1+2*3")
test("str 5", (Add(List(N(1)+2,N(3)))).toString, "(1+2)+3")
test("str 6", ((N(1)+2)*3).toString, "(1+2)*3")
test("str 7", (Mul(List(N(1)*2,N(3)))).toString, "(1*2)*3")
// Structurally identical trees compare equal (case-class equality).
test("equal", N(1)+2, N(1)+2)
// Monomials built with x(coefficient, exponent).
test("x 1", (x()+1).toString, "x+1")
test("x 2", (x(1,3)+x(-1,2)+x(-2)+1).toString, "x^3-x^2-2x+1")
// The term ordering used by sort.
test("xlt 1", x() < x(1,2), true)
test("xlt 2", N(1) < x(), true)
// sort: x-monomials first by descending exponent, then the remaining terms.
test("xsort 1", {
val f = x()+1+x(1,2)
(f.toString, f.sort.toString)
},("x+1+x^2", "x^2+x+1"))
test("xsort 2", {
val f = (N(5)+x(2))*(x()+1+x(1,2))
(f.toString, f.sort.toString)
},("(5+2x)*(x+1+x^2)", "(2x+5)*(x^2+x+1)"))
// simplify: flatten, sort, drop zeros, merge like terms.
test("xsimplify 1", {
val f = x(2)+3+x(4,2)+x()+1+x(1,2)
(f.toString, f.simplify.toString)
},("2x+3+4x^2+x+1+x^2", "5x^2+3x+4"))
test("xsimplify 2", {
val f = (x()+0+x(2))*Add(List(x(1,2),N(1)+x(2,2),N(2)))
(f.toString, f.simplify.toString)
},("(x+0+2x)*(x^2+(1+2x^2)+2)", "3x*(3x^2+3)"))
test("xsimplify 3", {
val f = x()+1+x(0,2)+x()+1+x(-2)-2
(f.toString, f.simplify.toString)
},("x+1+0x^2+x+1-2x-2", "0"))
// multiply: one-level product of constants, monomials and sums.
test("multiply 1", {
val f1 = N(2)
val f2 = N(3)
(f1.toString, f2.toString, f1.multiply(f2).toString)
},("2", "3", "6"))
test("multiply 2", {
val f1 = N(2)
val f2 = x(3,2)
(f1.toString, f2.toString, f1.multiply(f2).toString)
},("2", "3x^2", "6x^2"))
test("multiply 3", {
val f1 = x(2,3)
val f2 = x(3,4)
(f1.toString, f2.toString, f1.multiply(f2).toString)
},("2x^3", "3x^4", "6x^7"))
test("multiply 4", {
val f1 = N(2)
val f2 = x()+x(2,2)+3
(f1.toString, f2.toString, f1.multiply(f2).toString)
},("2", "x+2x^2+3", "2x+4x^2+6"))
test("multiply 5", {
val f1 = x()+1
val f2 = x(2)+3
val f3 = f1.multiply(f2)
val f4 = f3.simplify
(f1.toString, f2.toString, f3.toString, f4.toString)
},("x+1", "2x+3", "2x^2+3x+2x+3", "2x^2+5x+3"))
// expand: a single expansion step (only the first two factors are multiplied).
test("expand 1", {
val f = (x()+1)*(x()+2)*(x()+3)
(f.toString, f.expand.toString)
},("(x+1)*(x+2)*(x+3)", "(x^2+2x+x+2)*(x+3)"))
test("expand 2", {
val f = N(1)+(x()+1)*(x()+2)*(x()+3)
(f.toString, f.expand.toString)
},("1+(x+1)*(x+2)*(x+3)", "1+(x^2+2x+x+2)*(x+3)"))
// expandAll: expand until nothing changes, then simplify.
test("expandAll", {
val f1 = N(1)+(x()+1)*(x()+2)*(x()+3)
val f2 = f1.expandAll
val f3 = f2.simplify
(f1.toString, f2.toString, f3.toString)
},("1+(x+1)*(x+2)*(x+3)",
"1+(x^3+3x^2+2x^2+6x+x^2+3x+2x+6)",
"x^3+6x^2+11x+7"))
// Calculus on polynomials in x (results are not simplified).
test("differentiate", {
val f = x(1,3)+x(1,2)+x()+1
(f.toString, f.differentiate("x").toString)
},("x^3+x^2+x+1", "3x^2+2x+1+0"))
test("integrate", {
val f = x(1,2)+x(2)+1
(f.toString, f.integrate("x").toString)
},("x^2+2x+1", "1/3 x^3+x^2+x+C"))
// (End of gist; GitHub page footer omitted.)