Skip to content

Instantly share code, notes, and snippets.

@jliszka
Created October 31, 2013 03:43
Show Gist options
  • Save jliszka/7244101 to your computer and use it in GitHub Desktop.
Save jliszka/7244101 to your computer and use it in GitHub Desktop.
/** A formal power series in x with Double coefficients.
  *
  * Coefficients are produced on demand by `coeffs` and memoized per instance, so
  * each index is evaluated at most once. Operators return new `Poly`s whose
  * coefficient functions refer back to their operands. A few of them (`/`, `inv`,
  * `**(Double)`, `exp`, `log`) read the constant term eagerly when called.
  *
  * @param coeffs function giving the coefficient of x^n for index n
  */
class Poly(coeffs: Int => Double) {
  // Coefficients already computed by `apply`, keyed by power of x.
  private val memo = scala.collection.mutable.HashMap[Int, Double]()
  // Integer powers already computed by `**(p: Int)`, keyed by exponent.
  private val powMemo = scala.collection.mutable.HashMap[Int, Poly]()
  // Formatter used by `toString` (at most 7 fractional digits).
  private val df = new java.text.DecimalFormat("#.#######")

  /** Coefficient of x^n, computed once and then served from the memo table. */
  def apply(n: Int): Double = memo.getOrElseUpdate(n, this.coeffs(n))

  /** The constant series 1. */
  lazy val I: Poly = new Poly(n => if (n == 0) 1 else 0)

  def +(that: Poly): Poly = new Poly(n => this(n) + that(n))
  def -(that: Poly): Poly = new Poly(n => this(n) - that(n))
  def unary_-(): Poly = new Poly(n => -this(n))

  /** Cauchy product: coefficient n is the compensated sum of this(i) * that(n - i), i = 0..n. */
  def *(that: Poly): Poly = new Poly(n =>
    kahan((0 to n).map(i => this(i) * that(n - i)))
  )

  /** Scales every coefficient by `x`. */
  def *(x: Double): Poly = new Poly(n => this(n) * x)

  /** Series quotient this / that.
    *
    * When both constant terms are 0, a common factor of x is cancelled by shifting
    * both series down one place. This repeats recursively, so two series that are
    * both identically zero recurse without end. Otherwise the result is
    * `this * that.inv`. If `that(0)` is 0 at that point, `inv` divides by zero and
    * the coefficients come out infinite or NaN.
    */
  def /(that: Poly): Poly = {
    if (this(0) == 0 && that(0) == 0) {
      val a = new Poly(n => this(n + 1))
      val b = new Poly(n => that(n + 1))
      a / b
    } else {
      this * that.inv
    }
  }

  /** Divides every coefficient by `x`. */
  def /(x: Double): Poly = new Poly(n => this(n) / x)

  // Kahan (compensated) summation: `carry` holds the low-order bits lost by the
  // previous addition and is subtracted back in before the next one.
  private def kahan(xs: Seq[Double]): Double = {
    val (total, _) = xs.foldLeft((0.0, 0.0)) { case ((acc, carry), x) =>
      val y = x - carry
      val t = acc + y
      (t, (t - acc) - y)
    }
    total
  }

  // n! as a Double. An Int product silently overflows from 13! onward.
  private def factorial(n: Int): Double = (1 to n).foldLeft(1.0)(_ * _)

  // Generalized binomial coefficient C(r, n) = r(r-1)...(r-n+1) / n!, built one
  // factor at a time so neither the numerator nor n! is formed (or overflows) alone.
  private def binomial(r: Double, n: Int): Double =
    (0 until n).foldLeft(1.0)((acc, k) => acc * (r - k) / (k + 1))

  /** Multiplicative inverse 1 / this.
    *
    * With a = this(0) and q = 1 - this/a, 1/this = (1/a) * sum_i q^i. Since q has no
    * constant term, q^i contributes nothing below x^i, so the terms i = 0..n
    * suffice for coefficient n. this(0) must be nonzero; otherwise the
    * coefficients are infinite or NaN.
    */
  def inv: Poly = {
    val a = this(0)
    val q = I - this / a
    new Poly(n => kahan((0 to n).map(i => (q ** i)(n))) / a)
  }

  /** Integer power, memoized per exponent and computed by repeated squaring.
    *
    * Negative exponents are evaluated as powers of `inv`.
    */
  def **(p: Int): Poly = {
    powMemo.getOrElseUpdate(p, {
      if (p == 0) I
      else if (p < 0) {
        // x^p = (1/x)^(-(p+1)) * (1/x). This avoids negating p directly, which
        // would overflow for Int.MinValue.
        val r = inv
        (r ** (-(p + 1))) * r
      } else {
        val half = this ** (p / 2)
        if (p % 2 == 0) half * half else half * half * this
      }
    })
  }

  /** Real power via the generalized binomial series.
    *
    * With a = this(0) and q = this/a - 1, this^r = a^r * sum_i C(r, i) q^i. A
    * negative a with non-integer r makes `math.pow` return NaN.
    */
  def **(r: Double): Poly = {
    val a = this(0)
    val ar = math.pow(a, r)
    val q = this / a - I
    new Poly(n => kahan((0 to n).map(i => binomial(r, i) * (q ** i)(n))) * ar)
  }

  /** Square root, computed as this ** 0.5. */
  def sqrt: Poly = this ** 0.5

  /** exp(this) = e^a * sum_i q^i / i!, where a = this(0) and q = this - a. */
  def exp: Poly = {
    val a = this(0)
    val q = this - I * a
    val eq = new Poly(n => kahan((0 to n).map(i => (q ** i)(n) / factorial(i))))
    eq * math.exp(a)
  }

  // +1 for even n and -1 for odd n.
  private def alternatingSign(n: Int): Double = if (n % 2 == 0) 1 else -1

  /** log(this) = log(a) + sum_{i>=1} (-1)^(i+1) q^i / i, where a = this(0) and q = this/a - 1. */
  def log: Poly = {
    val a = this(0)
    val logA = I * math.log(a)
    val q = this / a - I
    new Poly(n => kahan((1 to n).map(i => -alternatingSign(i) / i * (q ** i)(n)))) + logA
  }

  /** Shows the coefficients of x^0 .. x^10. */
  override def toString = {
    "{ %s, ... }".format((0 to 10).map(i => df.format(this(i))).mkString(", "))
  }

  /** All coefficients as a lazy stream, starting at x^0. */
  def toStream: Stream[Double] = Stream.from(0).map(this.apply)

  /** The first `n` coefficients (x^0 .. x^(n-1)). */
  def take(n: Int): List[Double] = (0 to n - 1).map(this.apply).toList
}
/** The constant series 1: coefficient 1 at x^0 and 0 at every other power. */
object one extends Poly(k => if (k != 0) 0.0 else 1.0)
/** The series x: coefficient 1 at x^1 and 0 at every other power. */
object x extends Poly(k => if (k != 1) 0.0 else 1.0)
/** Lifts an Int to the constant series `i` (the `*(Double)` scaling of `one`). */
implicit def intToPoly(i: Int): Poly = one * i.toDouble
/** Lifts a Double to the constant series `x`: each coefficient of `one`, scaled by `x`. */
implicit def doubleToPoly(x: Double): Poly = new Poly(n => one(n) * x)
/** Derivatives of `f` at the point `c`, read off its Taylor series.
  *
  * `f` is evaluated on the series `c + x`. Coefficient i of the result is
  * f^(i)(c) / i!, so multiplying by i! recovers the i-th derivative. The stream
  * starts at the FIRST derivative f'(c), not at f(c).
  *
  * @param f function built from `Poly` operations
  * @param c point at which to differentiate
  * @return lazily evaluated f'(c), f''(c), f'''(c), ...
  */
def derivatives(f: Poly => Poly, c: Double): Stream[Double] = {
  // i! as a Double. An Int product silently overflows from 13! onward.
  def fact(n: Int): Double = (1 to n).foldLeft(1.0)(_ * _)
  val fc = f(c + x)
  Stream.from(1).map(i => fc(i) * fact(i))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment