Skip to content

Instantly share code, notes, and snippets.

@Ionizing
Last active May 26, 2023 09:08
Show Gist options
  • Save Ionizing/22a9d9a57da8f49f8a2c9dc4505c7f3f to your computer and use it in GitHub Desktop.
Save Ionizing/22a9d9a57da8f49f8a2c9dc4505c7f3f to your computer and use it in GitHub Desktop.
Parse a string into an expression and return it as an `Fn(f64) -> f64` closure
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use nom::{
branch::alt,
bytes::complete::tag,
character::complete::multispace0 as multispace,
number::complete::double,
combinator::map,
multi::many0,
sequence::{delimited, preceded},
IResult,
};
/// Syntax tree produced by the nom-based parser.
#[derive(Debug)]
enum Expr {
    // ---- leaves ----
    /// A numeric literal.
    Value(f64),
    /// The free variable `x`.
    Var,

    // ---- binary operators ----
    /// `lhs + rhs`
    Add(Box<Expr>, Box<Expr>),
    /// `lhs - rhs`
    Sub(Box<Expr>, Box<Expr>),
    /// `lhs * rhs`
    Mul(Box<Expr>, Box<Expr>),
    /// `lhs / rhs`
    Div(Box<Expr>, Box<Expr>),
    /// `base ^ exponent`
    Pow(Box<Expr>, Box<Expr>),

    // ---- unary / grouping ----
    /// An explicitly parenthesised sub-expression, kept so it can be printed back.
    Paren(Box<Expr>),
    /// `e^operand`
    Exp(Box<Expr>),
    /// `sin(operand)`
    Sin(Box<Expr>),
    /// `cos(operand)`
    Cos(Box<Expr>),
}
impl Expr {
    /// Evaluates the tree, substituting `var` for every occurrence of `x`.
    ///
    /// Every arm produces a value, so this currently always returns `Ok`;
    /// binary operands are evaluated left before right.
    fn eval(&self, var: f64) -> Result<f64, String> {
        let value = match self {
            // leaves
            Expr::Var => var,
            Expr::Value(literal) => *literal,
            // unary / grouping
            Expr::Paren(inner) => inner.eval(var)?,
            Expr::Exp(inner) => inner.eval(var)?.exp(),
            Expr::Sin(inner) => inner.eval(var)?.sin(),
            Expr::Cos(inner) => inner.eval(var)?.cos(),
            // binary
            Expr::Add(left, right) => left.eval(var)? + right.eval(var)?,
            Expr::Sub(left, right) => left.eval(var)? - right.eval(var)?,
            Expr::Mul(left, right) => left.eval(var)? * right.eval(var)?,
            Expr::Div(left, right) => left.eval(var)? / right.eval(var)?,
            Expr::Pow(base, exponent) => base.eval(var)?.powf(exponent.eval(var)?),
        };
        Ok(value)
    }
}
impl Display for Expr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
use self::Expr::*;
match &self {
Value(val) => write!(f, "{}", val),
Add(lhs, rhs) => write!(f, "{} + {}", *lhs, *rhs),
Sub(lhs, rhs) => write!(f, "{} - {}", *lhs, *rhs),
Mul(lhs, rhs) => write!(f, "{} * {}", *lhs, *rhs),
Div(lhs, rhs) => write!(f, "{} / {}", *lhs, *rhs),
Pow(lhs, rhs) => write!(f, "{} ^ {}", *lhs, *rhs),
Paren(expr) => write!(f, "({})", *expr),
Exp(expr) => write!(f, "e^{}", *expr),
Sin(expr) => write!(f, "sin({})", *expr),
Cos(expr) => write!(f, "cos({})", *expr),
Var => write!(f, "x")
}
}
}
/// Left-associative binary operators collected by `term`/`expr` and turned
/// into tree nodes by `fold_exprs`. Exponentiation is parsed separately by `pow`.
#[derive(Debug)]
enum Oper {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
}
/// Parses `( expr )` with optional surrounding whitespace, wrapping the
/// result in `Expr::Paren` so the grouping survives printing.
fn parens(i: &str) -> IResult<&str, Expr> {
    map(
        delimited(multispace, delimited(tag("("), expr, tag(")")), multispace),
        |inner| Expr::Paren(Box::new(inner)),
    )(i)
}
/// Parses the free variable `x`, with optional surrounding whitespace.
fn variable(i: &str) -> IResult<&str, Expr> {
    map(delimited(multispace, tag("x"), multispace), |_| Expr::Var)(i)
}
/// Parses `sin( expr )` (whitespace allowed around it) into `Expr::Sin`.
fn sin(i: &str) -> IResult<&str, Expr> {
    let call = delimited(tag("sin("), expr, tag(")"));
    map(delimited(multispace, call, multispace), |arg| {
        Expr::Sin(Box::new(arg))
    })(i)
}
/// Parses `cos( expr )` (whitespace allowed around it) into `Expr::Cos`.
fn cos(i: &str) -> IResult<&str, Expr> {
    let call = delimited(tag("cos("), expr, tag(")"));
    map(delimited(multispace, call, multispace), |arg| {
        Expr::Cos(Box::new(arg))
    })(i)
}
fn exp(i: &str) -> IResult<&str, Expr> {
delimited(
multispace,
preceded(
tag("e^"), map(expr, |e| Expr::Exp(Box::new(e)))
),
multispace,
)(i)
}
/// Folds `(operator, operand)` pairs onto `init` from left to right, producing
/// a left-associative tree: `a op1 b op2 c` becomes `((a op1 b) op2 c)`.
fn fold_exprs(init: Expr, remainder: Vec<(Oper, Expr)>) -> Expr {
    let mut tree = init;
    for (oper, operand) in remainder {
        let (lhs, rhs) = (Box::new(tree), Box::new(operand));
        tree = match oper {
            Oper::Add => Expr::Add(lhs, rhs),
            Oper::Sub => Expr::Sub(lhs, rhs),
            Oper::Mul => Expr::Mul(lhs, rhs),
            Oper::Div => Expr::Div(lhs, rhs),
        };
    }
    tree
}
fn pow(i: &str) -> IResult<&str, Expr> {
let (i, lhs) = factor(i)?;
let (i, rhs) = preceded(tag("^"), factor)(i)?;
Ok((i, Expr::Pow(Box::new(lhs), Box::new(rhs))))
}
/// Parses a floating-point literal (nom's `double`), with optional surrounding
/// whitespace, into `Expr::Value`.
fn number(i: &str) -> IResult<&str, Expr> {
    map(delimited(multispace, double, multispace), Expr::Value)(i)
}
/// Parses one atomic operand: `x`, `sin(..)`, `cos(..)`, `e^..`, `( .. )` or a
/// number. Alternatives are tried in this order and the first success wins.
fn factor(i: &str) -> IResult<&str, Expr> {
    let mut atom = alt((variable, sin, cos, exp, parens, number));
    atom(i)
}
/// Tries `pow` first and falls back to a plain `factor` when no `^` follows.
fn pow_or_factor(i: &str) -> IResult<&str, Expr> {
    let mut power_or_atom = alt((pow, factor));
    power_or_atom(i)
}
fn term(i: &str) -> IResult<&str, Expr> {
let (i, initial) = pow_or_factor(i)?;
let (i, remainder) = many0(alt((
|i| {
let (i, mul) = preceded(tag("*"), pow_or_factor)(i)?;
Ok((i, (Oper::Mul, mul)))
},
|i| {
let (i, div) = preceded(tag("/"), pow_or_factor)(i)?;
Ok((i, (Oper::Div, div)))
},
)))(i)?;
Ok((i, fold_exprs(initial, remainder)))
}
fn expr(i: &str) -> IResult<&str, Expr> {
let (i, initial) = term(i)?;
let (i, remainder) = many0(alt((
|i| {
let (i, add) = preceded(tag("+"), term)(i)?;
Ok((i, (Oper::Add, add)))
},
|i| {
let (i, sub) = preceded(tag("-"), term)(i)?;
Ok((i, (Oper::Sub, sub)))
},
)))(i)?;
Ok((i, fold_exprs(initial, remainder)))
}
/// Parses `i` and returns a closure that evaluates it at a given `x`.
///
/// # Panics
/// Panics if `i` is not a valid expression, or if only a prefix of it could be
/// parsed. `expr` stops at the first token it cannot handle and hands the rest
/// back; ignoring that remainder would silently evaluate a truncated formula
/// (e.g. `"1 + 2 )"` as `1 + 2`).
fn str2fn(i: &str) -> impl Fn(f64) -> f64 {
    let (rest, func) =
        expr(i).unwrap_or_else(|e| panic!("failed to parse expression {:?}: {:?}", i, e));
    assert!(
        rest.is_empty(),
        "unparsed trailing input {:?} in expression {:?}",
        rest,
        i
    );
    // `Expr::eval` has no error-producing arm, so this cannot fire.
    move |x| func.eval(x).expect("Expr::eval never returns Err")
}
/// Demo: parse a fixed formula, print its AST, then evaluate it both directly
/// and through the closure returned by `str2fn`.
fn main() {
    let var = 5.0;
    let string = "114 + 514 * sin(1919) - cos(810) * 2^8.93 - 810 * x";

    let expression = expr(string);
    println!("AST of \"{}\" is \"{:?}\"", string, &expression);

    let (_, ast) = expression.unwrap();
    println!("Eval of it with var({}) = {:?}", var, ast.eval(var));

    let func = str2fn(string);
    println!("Eval from produced function: {}", func(var));
}
use once_cell::sync::Lazy;
use pest_derive::Parser;
use pest::Parser;
use pest::iterators::Pairs;
use pest::pratt_parser::{Assoc, Op, PrattParser};
/// pest parser generated from the inline grammar below.
///
/// * `program` accepts exactly one `expr` spanning the whole input
///   (`SOI ~ expr ~ EOI`), so trailing garbage is a parse error.
/// * An `expr` is a flat sequence `prefix? primary (infix prefix? primary)*`.
///   Precedence and associativity are NOT encoded here; `parse_expr` resolves
///   them afterwards with `PRATT_PARSER`.
/// * `primary` and `function` are silent rules, so `sin(...)` yields a `sin`
///   pair followed by an `expr` pair. `parse_expr` therefore treats function
///   names as prefix operators.
/// * At most one prefix operator may precede a primary, so `--x` and `e^-x` do
///   not parse (write `e^(-x)`). A sign directly on a digit sequence is part of
///   the `value` literal itself (`int` allows a leading `+`/`-`).
/// * Only the space character is `WHITESPACE`; tabs and newlines are rejected.
#[derive(Parser)]
#[grammar_inline = r#"
value = @{ int ~ ("." ~ ASCII_DIGIT*)? ~ (^"e" ~ int)? }
int = { ("+" | "-")? ~ ASCII_DIGIT+ }
infix = _{ add | sub | mul | div | pow }
add = { "+" }
sub = { "-" }
mul = { "*" }
div = { "/" }
pow = { "^" }
prefix = _{ neg | exp }
neg = { "-" }
exp = { "e^" }
function = _{ sin | cos | tan }
sin = { "sin" }
cos = { "cos" }
tan = { "tan" }
variable = @{ "x" }
primary = _{
variable |
value |
function ~ "(" ~ expr ~ ")" |
"(" ~ expr ~ ")"
}
expr = {
prefix? ~ primary ~ (infix ~ prefix? ~ primary)*
}
program = _{ SOI ~ expr ~ EOI }
WHITESPACE = _{ " " }
"#]
struct Calculator;
/// Operator table used by `parse_expr`. Each `.op(...)` call adds a level of
/// strictly higher precedence than the previous one.
///
/// * `+ -` and `* /` are left-associative: `8 - 2 - 1 == 5`, `8 / 4 / 2 == 1`.
///   Declaring `sub`/`div` as `Right` gave `8 - (2 - 1)` and `8 / (4 / 2)`.
/// * Unary `-` and `e^` bind tighter than `*` but looser than `^`:
///   `e^x * 2 == (e^x) * 2`, `-x^2 == -(x^2)` and `e^x^2 == e^(x^2)`.
/// * `^` is right-associative: `2^3^2 == 2^9`.
/// * `sin`/`cos`/`tan` always wrap a parenthesised argument, so they bind
///   tightest: `sin(x)^2 == (sin(x))^2`. At a level below `^`, it meant
///   `sin(x^2)`.
static PRATT_PARSER: Lazy<PrattParser<Rule>> = Lazy::new(|| {
    use Rule::*;
    use Assoc::*;
    PrattParser::new()
        .op(Op::infix(add, Left) | Op::infix(sub, Left))
        .op(Op::infix(mul, Left) | Op::infix(div, Left))
        .op(Op::prefix(neg) | Op::prefix(exp))
        .op(Op::infix(pow, Right))
        .op(Op::prefix(sin) | Op::prefix(cos) | Op::prefix(tan))
});
/// Operator carried by `Expr::UnaryOp` / `Expr::BinOp` nodes.
#[derive(Debug)]
pub enum Operation {
    // ---- binary operators (used in `Expr::BinOp`) ----
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// `lhs ^ rhs`
    Pow,

    // ---- unary operators (used in `Expr::UnaryOp`) ----
    /// `-rhs`
    Neg,
    /// `e^rhs`
    Exp,
    /// `sin(rhs)`
    Sin,
    /// `cos(rhs)`
    Cos,
    /// `tan(rhs)`
    Tan,
}
/// Syntax tree produced by `parse_expr`.
#[derive(Debug)]
pub enum Expr {
    /// The free variable `x`.
    Variable,
    /// A numeric literal.
    Value(f64),
    /// A prefix operator (`Neg`, `Exp`, `Sin`, `Cos` or `Tan`) applied to `rhs`.
    UnaryOp { op: Operation, rhs: Box<Expr> },
    /// An infix operator (`Add`, `Sub`, `Mul`, `Div` or `Pow`) between `lhs` and `rhs`.
    BinOp { lhs: Box<Expr>, op: Operation, rhs: Box<Expr> },
}
/// Builds an [`Expr`] from the pairs inside an `expr` rule.
///
/// The grammar only yields a flat `prefix? primary (infix prefix? primary)*`
/// sequence; `PRATT_PARSER` decides how the operators nest. Nested `expr`
/// pairs (parenthesised groups and function arguments) are handled by
/// recursing.
///
/// # Panics
/// Panics (via `unreachable!`) on a pair the grammar cannot place at that
/// position, and if a `value` token is rejected by `f64::from_str`.
pub fn parse_expr(pairs: Pairs<Rule>) -> Expr {
    PRATT_PARSER
        .map_primary(|primary| match primary.as_rule() {
            Rule::variable => Expr::Variable,
            Rule::value => Expr::Value(
                primary
                    .as_str()
                    .parse::<f64>()
                    .expect("`value` token is not a valid f64 literal"),
            ),
            // Parenthesised group or function argument.
            Rule::expr => parse_expr(primary.into_inner()),
            _ => unreachable!("Expr::parse Expected primary expression, found {:?}", primary),
        })
        .map_prefix(|op, rhs| {
            let op = match op.as_rule() {
                Rule::neg => Operation::Neg,
                Rule::exp => Operation::Exp,
                // Function names arrive as prefix operators (see the grammar).
                Rule::sin => Operation::Sin,
                Rule::cos => Operation::Cos,
                Rule::tan => Operation::Tan,
                _ => unreachable!("Expr::parse Expected prefix operator, found {:?}", op),
            };
            Expr::UnaryOp {
                op,
                rhs: Box::new(rhs),
            }
        })
        .map_infix(|lhs, op, rhs| {
            let op = match op.as_rule() {
                Rule::add => Operation::Add,
                Rule::sub => Operation::Sub,
                Rule::mul => Operation::Mul,
                Rule::div => Operation::Div,
                Rule::pow => Operation::Pow,
                // This is the infix mapper; the message previously said "prefix".
                _ => unreachable!("Expr::parse Expected infix operator, found {:?}", op),
            };
            Expr::BinOp {
                lhs: Box::new(lhs),
                op,
                rhs: Box::new(rhs),
            }
        })
        .parse(pairs)
}
impl Expr {
    /// Evaluates the tree, substituting `var` for every `x`.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) if a `UnaryOp` carries a binary operation or
    /// a `BinOp` carries a unary one; `parse_expr` never builds such nodes.
    fn eval(&self, var: f64) -> f64 {
        match self {
            Expr::Variable => var,
            Expr::Value(literal) => *literal,
            Expr::UnaryOp { op, rhs } => {
                // Pick the function first; the operand is evaluated only once
                // the operator is known to be valid.
                let apply: fn(f64) -> f64 = match op {
                    Operation::Neg => |v| -v,
                    Operation::Exp => f64::exp,
                    Operation::Sin => f64::sin,
                    Operation::Cos => f64::cos,
                    Operation::Tan => f64::tan,
                    _ => unreachable!(),
                };
                apply(rhs.eval(var))
            }
            Expr::BinOp { lhs, op, rhs } => {
                let combine: fn(f64, f64) -> f64 = match op {
                    Operation::Add => |a, b| a + b,
                    Operation::Sub => |a, b| a - b,
                    Operation::Mul => |a, b| a * b,
                    Operation::Div => |a, b| a / b,
                    Operation::Pow => f64::powf,
                    _ => unreachable!(),
                };
                // Arguments are evaluated left to right: lhs, then rhs.
                combine(lhs.eval(var), rhs.eval(var))
            }
        }
    }
}
/// Parses `i` with the [`Calculator`] grammar and returns a closure that
/// evaluates the resulting expression at a given value of `x`.
///
/// # Panics
/// Panics with pest's human-readable error message if `i` is not a complete,
/// valid expression.
fn str2fn(i: &str) -> impl Fn(f64) -> f64 {
    let mut pairs = Calculator::parse(Rule::program, i)
        .unwrap_or_else(|e| panic!("failed to parse expression {:?}:\n{}", i, e));
    // `program` is a silent rule, so the first pair is the `expr` itself.
    let expr_pair = pairs
        .next()
        .expect("a successful `program` parse always yields an `expr` pair");
    let expr = parse_expr(expr_pair.into_inner());
    move |x| expr.eval(x)
}
/// Demo: build a function from a formula string and evaluate it once.
fn main() {
    let input = "e^(-x^2)";
    let x = 1.14514;
    let f = str2fn(input);

    println!("Input string: {}", input);
    println!(" eval with x=1.14514 : {}", f(x));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a whole `program`, prints the AST and returns it.
    ///
    /// Panics (failing the calling test) if the grammar rejects `input`.
    fn test_fn(input: &str) -> Expr {
        match Calculator::parse(Rule::program, input) {
            Ok(mut pairs) => {
                let expr = parse_expr(pairs.next().unwrap().into_inner());
                println!("Parsed: {:#?}", expr);
                expr
            }
            Err(e) => panic!("Parse failed: {:?}", e),
        }
    }

    /// Parses `input`, evaluates it at `x` and checks the result against
    /// `expected` with a tight relative tolerance.
    fn assert_eval(input: &str, x: f64, expected: f64) {
        let got = test_fn(input).eval(x);
        assert!(
            (got - expected).abs() <= 1e-12 * expected.abs().max(1.0),
            "{:?} at x = {}: got {}, expected {}",
            input,
            x,
            got,
            expected
        );
    }

    #[test]
    fn test1() {
        assert_eval("1 + 1", 0.0, 2.0);
    }

    #[test]
    fn test2() {
        assert_eval("sin(1 + 1)", 0.0, 2f64.sin());
    }

    #[test]
    fn test3() {
        assert_eval("1+e^5", 0.0, 1.0 + 5f64.exp());
    }

    #[test]
    fn test4() {
        // Unary minus applies to the whole method chain: -(cos(sin(1 + e^5))).
        assert_eval("-cos(sin(1+e^5))", 0.0, -(1.0 + 5f64.exp()).sin().cos());
    }

    #[test]
    fn test5() {
        assert_eval("-cos(sin(e^x))", 0.5, -0.5f64.exp().sin().cos());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment