Skip to content

Instantly share code, notes, and snippets.

@tov

tov/deriv.rs Secret

Created November 12, 2018 00:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save tov/af73f345710e937ec39a4dbaca4504fe to your computer and use it in GitHub Desktop.
Save tov/af73f345710e937ec39a4dbaca4504fe to your computer and use it in GitHub Desktop.
Rust code to compute the nth derivative of x^x
use std::env;
use std::fmt;
use std::ops::{Add, Mul};
use std::rc::Rc;
#[derive(Clone, Debug)]
pub enum Expr {
Int(i64),
Var(Rc<str>),
Add(Rc<Expr>, Rc<Expr>),
Mul(Rc<Expr>, Rc<Expr>),
Pow(Rc<Expr>, Rc<Expr>),
Ln(Rc<Expr>),
}
fn pown(base: i64, expt: i64) -> i64 {
match expt {
0 => 1,
1 => base,
_ => {
let b = pown(base, expt / 2);
b * b * if expt % 2 == 0 {1} else {base}
}
}
}
fn decompose_add(expr: Expr) -> (i64, Option<Expr>) {
match expr {
Expr::Int(n) => (n, None),
Expr::Add(f, g) => {
match *f {
Expr::Int(n) => (n, Some(g.clone_expr())),
_ => (0, Some(Expr::Add(f, g)))
}
}
_ => (0, Some(expr)),
}
}
fn decompose_mul(expr: Expr) -> (i64, Option<Expr>) {
match expr {
Expr::Int(n) => (n, None),
Expr::Mul(f, g) => {
match *f {
Expr::Int(n) => (n, Some(g.clone_expr())),
_ => (1, Some(Expr::Mul(f, g))),
}
}
_ => (1, Some(expr)),
}
}
impl Expr {
fn clone_expr(&self) -> Self {
Expr::clone(self)
}
}
impl Add<Expr> for Expr {
type Output = Expr;
fn add(self, other: Expr) -> Expr {
use Expr::*;
let (m, fo) = decompose_add(self);
let (n, go) = decompose_add(other);
let mn = m + n;
let fg = match (fo, go) {
(None, None) => return Int(mn),
(Some(f), None) => f,
(None, Some(g)) => g,
(Some(f), Some(g)) => Add(Rc::new(f), Rc::new(g)),
};
if mn == 0 {
fg
} else {
Add(Rc::new(Int(mn)), Rc::new(fg))
}
}
}
impl Mul<Expr> for Expr {
type Output = Expr;
fn mul(self, other: Expr) -> Expr {
use Expr::*;
let (m, fo) = decompose_mul(self);
let (n, go) = decompose_mul(other);
let mn = m * n;
if mn == 0 {
return Int(0);
}
let fg = match (fo, go) {
(None, None) => return Int(mn),
(Some(f), None) => f,
(None, Some(g)) => g,
(Some(f), Some(g)) => Mul(Rc::new(f), Rc::new(g)),
};
if mn == 1 {
fg
} else {
Mul(Rc::new(Int(mn)), Rc::new(fg))
}
}
}
impl Expr {
pub fn pow(self, other: Expr) -> Expr {
use Expr::*;
match (self, other) {
(Int(m), Int(n)) => Int(pown(m, n)),
(_, Int(0)) => Int(1),
(a, Int(1)) => a,
(Int(0), _) => Int(0),
(a, b) => Pow(Rc::new(a), Rc::new(b)),
}
}
pub fn ln(self) -> Expr {
use Expr::*;
match self {
Int(1) => Int(0),
a => Ln(Rc::new(a)),
}
}
pub fn d(&self, x: &str) -> Expr {
use Expr::*;
match *self {
Var(ref y) =>
if *x == **y {
Int(1)
} else {
Int(0)
},
Int(_) => Int(0),
Add(ref f, ref g) => f.d(x) + g.d(x),
Mul(ref f, ref g) =>
f.clone_expr() * g.d(x) + g.clone_expr() * f.d(x),
Pow(ref f, ref g) =>
self.clone_expr() *
(g.clone_expr() * f.d(x) * f.clone_expr().pow(Int(-1)) +
f.clone_expr().ln() * g.d(x)),
Ln(ref f) =>
f.d(x) * f.clone_expr().pow(Int(-1)),
}
}
pub fn count(&self) -> usize {
use Expr::*;
match *self {
Var(_) => 1,
Int(_) => 1,
Add(ref f, ref g) => f.count() + g.count(),
Mul(ref f, ref g) => f.count() + g.count(),
Pow(ref f, ref g) => f.count() + g.count(),
Ln(ref f) => f.count(),
}
}
fn format(&self, f: &mut fmt::Formatter, old_prec: usize) -> fmt::Result {
use Expr::*;
fn bracket<F>(f: &mut fmt::Formatter,
old_prec: usize, new_prec: usize, body: F)
-> fmt::Result
where F: FnOnce(&mut fmt::Formatter) -> fmt::Result {
if new_prec < old_prec { f.write_str("(")?; }
body(f)?;
if new_prec < old_prec { f.write_str(")")?; }
Ok(())
}
match *self {
Var(ref name) => f.write_str(&**name),
Int(n) => write!(f, "{}", n),
Add(ref g, ref h) => bracket(f, old_prec, 1, |f| {
g.format(f, 1)?;
f.write_str(" + ")?;
h.format(f, 1)
}),
Mul(ref g, ref h) => bracket(f, old_prec, 2, |f| {
g.format(f, 2)?;
f.write_str("*")?;
h.format(f, 2)
}),
Pow(ref g, ref h) => bracket(f, old_prec, 3, |f| {
g.format(f, 2)?;
f.write_str("^")?;
h.format(f, 3)
}),
Ln(ref g) => {
f.write_str("ln(")?;
g.format(f, 1)?;
f.write_str(")")
}
}
}
}
impl fmt::Display for Expr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let n = self.count();
if n > 100 {
write!(f, "<<{}>>", n)
} else {
self.format(f, 1)
}
}
}
fn nest<A, F: Fn(&A) -> A>(n: usize, f: F, mut x: A) -> A {
for _ in 0 .. n {
x = f(&x);
}
x
}
fn deriv(f: &Expr) -> Expr {
let d = f.d("x");
println!("D({}) = {}", f, d);
d
}
fn main() {
let x = Expr::Var(Rc::from("x"));
let f = x.clone().pow(x);
let n = env::args().nth(1).expect("n").parse().unwrap();
nest(n, deriv, f);
}
@TeXitoi
Copy link

TeXitoi commented Jan 23, 2019

https://github.com/TeXitoi/deriv-rs

Using typed-arena, I go from 2.3s to 0.9s with 10 as parameter.

@jdh30
Copy link

jdh30 commented Feb 11, 2019

Excellent work, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment