Last active
May 26, 2023 09:08
-
-
Save Ionizing/22a9d9a57da8f49f8a2c9dc4505c7f3f to your computer and use it in GitHub Desktop.
Parse a math-expression string and return it as an `Fn(f64) -> f64` closure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt; | |
use std::fmt::{Debug, Display, Formatter}; | |
use nom::{ | |
branch::alt, | |
bytes::complete::tag, | |
character::complete::multispace0 as multispace, | |
number::complete::double, | |
combinator::map, | |
multi::many0, | |
sequence::{delimited, preceded}, | |
IResult, | |
}; | |
/// Abstract syntax tree for a single-variable arithmetic expression.
///
/// `Var` stands for the free variable `x`; `Value` is a numeric literal.
/// `Paren` records an explicit grouping from the source text so that
/// `Display` can reproduce it.
#[derive(Debug)]
enum Expr {
    Value(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
    Pow(Box<Expr>, Box<Expr>),
    Paren(Box<Expr>),
    Exp(Box<Expr>),
    Sin(Box<Expr>),
    Cos(Box<Expr>),
    Var,
}

impl Expr {
    /// Evaluates the tree with the free variable `x` bound to `var`.
    ///
    /// The `Result` is part of the signature, but no node currently produces
    /// an `Err`; plain `f64` arithmetic applies (e.g. dividing by zero gives an
    /// infinity or NaN, not an error).
    fn eval(&self, var: f64) -> Result<f64, String> {
        use self::Expr::*;

        let result = match self {
            Var => var,
            Value(literal) => *literal,
            Paren(inner) => inner.eval(var)?,
            Exp(inner) => inner.eval(var)?.exp(),
            Sin(inner) => inner.eval(var)?.sin(),
            Cos(inner) => inner.eval(var)?.cos(),
            Add(lhs, rhs) => lhs.eval(var)? + rhs.eval(var)?,
            Sub(lhs, rhs) => lhs.eval(var)? - rhs.eval(var)?,
            Mul(lhs, rhs) => lhs.eval(var)? * rhs.eval(var)?,
            Div(lhs, rhs) => lhs.eval(var)? / rhs.eval(var)?,
            Pow(base, exponent) => base.eval(var)?.powf(exponent.eval(var)?),
        };
        Ok(result)
    }
}

impl Display for Expr {
    /// Renders the tree in infix notation. Binary operators are surrounded by
    /// single spaces; grouping parentheses come only from `Paren` nodes.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        /// Writes `lhs <symbol> rhs`.
        fn binary(f: &mut Formatter<'_>, lhs: &Expr, symbol: &str, rhs: &Expr) -> fmt::Result {
            write!(f, "{} {} {}", lhs, symbol, rhs)
        }

        match self {
            Expr::Var => f.write_str("x"),
            Expr::Value(literal) => write!(f, "{}", literal),
            Expr::Paren(inner) => write!(f, "({})", inner),
            Expr::Exp(inner) => write!(f, "e^{}", inner),
            Expr::Sin(inner) => write!(f, "sin({})", inner),
            Expr::Cos(inner) => write!(f, "cos({})", inner),
            Expr::Add(lhs, rhs) => binary(f, lhs, "+", rhs),
            Expr::Sub(lhs, rhs) => binary(f, lhs, "-", rhs),
            Expr::Mul(lhs, rhs) => binary(f, lhs, "*", rhs),
            Expr::Div(lhs, rhs) => binary(f, lhs, "/", rhs),
            Expr::Pow(lhs, rhs) => binary(f, lhs, "^", rhs),
        }
    }
}
/// Left-associative binary operators collected by the `term` and `expr`
/// parsers and folded into an `Expr` tree by `fold_exprs`.
///
/// Exponentiation is deliberately not listed: `^` is handled by the
/// dedicated power parsers, which build `Expr::Pow` nodes directly.
#[derive(Debug)]
enum Oper {
    Add,
    Sub,
    Mul,
    Div,
}
fn parens(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
delimited(tag("("), map(expr, |e| Expr::Paren(Box::new(e))), tag(")")), | |
multispace | |
)(i) | |
} | |
fn variable(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
map(tag("x"), |_| Expr::Var), | |
multispace, | |
)(i) | |
} | |
fn sin(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
delimited( | |
tag("sin("), map(expr, |e| Expr::Sin(Box::new(e))), tag(")") | |
), | |
multispace, | |
)(i) | |
} | |
fn cos(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
delimited( | |
tag("cos("), map(expr, |e| Expr::Cos(Box::new(e))), tag(")") | |
), | |
multispace, | |
)(i) | |
} | |
fn exp(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
preceded( | |
tag("e^"), map(expr, |e| Expr::Exp(Box::new(e))) | |
), | |
multispace, | |
)(i) | |
} | |
fn fold_exprs(init: Expr, remainder: Vec<(Oper, Expr)>) -> Expr { | |
remainder.into_iter().fold(init, |acc, (oper, expr)| { | |
match oper { | |
Oper::Add => Expr::Add(Box::new(acc), Box::new(expr)), | |
Oper::Sub => Expr::Sub(Box::new(acc), Box::new(expr)), | |
Oper::Mul => Expr::Mul(Box::new(acc), Box::new(expr)), | |
Oper::Div => Expr::Div(Box::new(acc), Box::new(expr)), | |
//Oper::Pow => Expr::Pow(Box::new(acc), Box::new(expr)), | |
} | |
}) | |
} | |
fn pow(i: &str) -> IResult<&str, Expr> { | |
let (i, lhs) = factor(i)?; | |
let (i, rhs) = preceded(tag("^"), factor)(i)?; | |
Ok((i, Expr::Pow(Box::new(lhs), Box::new(rhs)))) | |
} | |
fn number(i: &str) -> IResult<&str, Expr> { | |
delimited( | |
multispace, | |
map(double, |x| Expr::Value(x)), | |
multispace, | |
)(i) | |
} | |
fn factor(i: &str) -> IResult<&str, Expr> { | |
alt(( | |
variable, | |
sin, | |
cos, | |
exp, | |
parens, | |
number, | |
))(i) | |
} | |
fn pow_or_factor(i: &str) -> IResult<&str, Expr> { | |
alt(( | |
pow, | |
factor | |
))(i) | |
} | |
fn term(i: &str) -> IResult<&str, Expr> { | |
let (i, initial) = pow_or_factor(i)?; | |
let (i, remainder) = many0(alt(( | |
|i| { | |
let (i, mul) = preceded(tag("*"), pow_or_factor)(i)?; | |
Ok((i, (Oper::Mul, mul))) | |
}, | |
|i| { | |
let (i, div) = preceded(tag("/"), pow_or_factor)(i)?; | |
Ok((i, (Oper::Div, div))) | |
}, | |
)))(i)?; | |
Ok((i, fold_exprs(initial, remainder))) | |
} | |
fn expr(i: &str) -> IResult<&str, Expr> { | |
let (i, initial) = term(i)?; | |
let (i, remainder) = many0(alt(( | |
|i| { | |
let (i, add) = preceded(tag("+"), term)(i)?; | |
Ok((i, (Oper::Add, add))) | |
}, | |
|i| { | |
let (i, sub) = preceded(tag("-"), term)(i)?; | |
Ok((i, (Oper::Sub, sub))) | |
}, | |
)))(i)?; | |
Ok((i, fold_exprs(initial, remainder))) | |
} | |
fn str2fn(i: &str) -> impl Fn(f64) -> f64 { | |
let func = expr(i).unwrap().1; | |
move |x| func.eval(x).unwrap() | |
} | |
fn main() { | |
let string = "114 + 514 * sin(1919) - cos(810) * 2^8.93 - 810 * x"; | |
let expression = expr(string); | |
let var = 5.0; | |
println!("AST of \"{}\" is \"{:?}\"", string, &expression); | |
println!("Eval of it with var({}) = {:?}", var, expression.unwrap().1.eval(var)); | |
let func = str2fn(string); | |
println!("Eval from produced function: {}", func(var)); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use once_cell::sync::Lazy; | |
use pest_derive::Parser; | |
use pest::Parser; | |
use pest::iterators::Pairs; | |
use pest::pratt_parser::{Assoc, Op, PrattParser}; | |
#[derive(Parser)] | |
#[grammar_inline = r#" | |
value = @{ int ~ ("." ~ ASCII_DIGIT*)? ~ (^"e" ~ int)? } | |
int = { ("+" | "-")? ~ ASCII_DIGIT+ } | |
infix = _{ add | sub | mul | div | pow } | |
add = { "+" } | |
sub = { "-" } | |
mul = { "*" } | |
div = { "/" } | |
pow = { "^" } | |
prefix = _{ neg | exp } | |
neg = { "-" } | |
exp = { "e^" } | |
function = _{ sin | cos | tan } | |
sin = { "sin" } | |
cos = { "cos" } | |
tan = { "tan" } | |
variable = @{ "x" } | |
primary = _{ | |
variable | | |
value | | |
function ~ "(" ~ expr ~ ")" | | |
"(" ~ expr ~ ")" | |
} | |
expr = { | |
prefix? ~ primary ~ (infix ~ prefix? ~ primary)* | |
} | |
program = _{ SOI ~ expr ~ EOI } | |
WHITESPACE = _{ " " } | |
"#] | |
struct Calculator; | |
static PRATT_PARSER: Lazy<PrattParser<Rule>> = Lazy::new(|| { | |
use Rule::*; | |
use Assoc::*; | |
PrattParser::new() | |
.op(Op::infix(add, Left) | Op::infix(sub, Right)) | |
.op(Op::infix(mul, Left) | Op::infix(div, Right)) | |
.op(Op::prefix(sin) | Op::prefix(cos) | Op::prefix(tan)) | |
.op(Op::prefix(neg)) | |
.op(Op::infix(pow, Right)) | |
.op(Op::prefix(exp)) | |
}); | |
/// Operators that can label an `Expr` node.
///
/// `parse_expr` produces the first five for infix rules (stored in
/// `Expr::BinOp`) and the remaining five for prefix rules (stored in
/// `Expr::UnaryOp`).
#[derive(Debug)]
pub enum Operation {
    /// Binary `+`.
    Add,
    /// Binary `-`.
    Sub,
    /// Binary `*`.
    Mul,
    /// Binary `/`.
    Div,
    /// Binary `^`.
    Pow,
    /// Unary minus.
    Neg,
    /// Natural exponential, written `e^`.
    Exp,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
}
#[derive(Debug)] | |
pub enum Expr { | |
Variable, | |
Value(f64), | |
UnaryOp { | |
op: Operation, | |
rhs: Box<Expr>, | |
}, | |
BinOp { | |
lhs: Box<Expr>, | |
op: Operation, | |
rhs: Box<Expr>, | |
}, | |
} | |
pub fn parse_expr(pairs: Pairs<Rule>) -> Expr { | |
PRATT_PARSER | |
.map_primary(|primary| match primary.as_rule() { | |
Rule::variable => Expr::Variable, | |
Rule::value => Expr::Value(primary.as_str().parse::<f64>().unwrap()), | |
Rule::expr => parse_expr(primary.into_inner()), | |
_ => unreachable!("Expr::parse Expected primary expression, found {:?}", primary), | |
}) | |
.map_prefix(|op, rhs| { | |
let op = match op.as_rule() { | |
Rule::neg => Operation::Neg, | |
Rule::exp => Operation::Exp, | |
Rule::sin => Operation::Sin, | |
Rule::cos => Operation::Cos, | |
Rule::tan => Operation::Tan, | |
_ => unreachable!("Expr::parse Expected prefix operator, found {:?}", op), | |
}; | |
Expr::UnaryOp { | |
op, | |
rhs: Box::new(rhs), | |
} | |
}) | |
.map_infix(|lhs, op, rhs| { | |
let op = match op.as_rule() { | |
Rule::add => Operation::Add, | |
Rule::sub => Operation::Sub, | |
Rule::mul => Operation::Mul, | |
Rule::div => Operation::Div, | |
Rule::pow => Operation::Pow, | |
_ => unreachable!("Expr::parse Expected prefix operator, found {:?}", op), | |
}; | |
Expr::BinOp { | |
lhs: Box::new(lhs), | |
op, | |
rhs: Box::new(rhs), | |
} | |
}) | |
.parse(pairs) | |
} | |
impl Expr { | |
fn eval(&self, var: f64) -> f64 { | |
match self { | |
Expr::Variable => var, | |
Expr::Value(x) => *x, | |
Expr::UnaryOp{op, rhs} => { | |
match op { | |
Operation::Neg => - rhs.eval(var), | |
Operation::Exp => rhs.eval(var).exp(), | |
Operation::Sin => rhs.eval(var).sin(), | |
Operation::Cos => rhs.eval(var).cos(), | |
Operation::Tan => rhs.eval(var).tan(), | |
_ => unreachable!() | |
} | |
} | |
Expr::BinOp{lhs, op, rhs} => { | |
match op { | |
Operation::Add => lhs.eval(var) + rhs.eval(var), | |
Operation::Sub => lhs.eval(var) - rhs.eval(var), | |
Operation::Mul => lhs.eval(var) * rhs.eval(var), | |
Operation::Div => lhs.eval(var) / rhs.eval(var), | |
Operation::Pow => lhs.eval(var).powf(rhs.eval(var)), | |
_ => unreachable!() | |
} | |
} | |
} | |
} | |
} | |
fn str2fn(i: &str) -> impl Fn(f64) -> f64 { | |
let mut pairs = Calculator::parse(Rule::program, i).unwrap(); | |
let expr = parse_expr(pairs.next().unwrap().into_inner()); | |
move |x| expr.eval(x) | |
} | |
fn main() { | |
let input = "e^(-x^2)"; | |
let func = str2fn(input); | |
println!("Input string: {}", input); | |
println!(" eval with x=1.14514 : {}", func(1.14514)); | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` with the full grammar, prints the resulting AST and
    /// returns it. Panics with the pest error if parsing fails.
    fn test_fn(input: &str) -> Expr {
        match Calculator::parse(Rule::program, input) {
            Ok(mut pairs) => {
                let ast = parse_expr(pairs.next().unwrap().into_inner());
                println!("Parsed: {:#?}", ast);
                ast
            }
            Err(e) => {
                panic!("Parse failed: {:?}", e)
            }
        }
    }

    /// Asserts that `input` evaluated at `x` matches `expected` to within a
    /// tight relative tolerance.
    fn assert_eval(input: &str, x: f64, expected: f64) {
        let got = test_fn(input).eval(x);
        let tolerance = 1e-12 * expected.abs().max(1.0);
        assert!(
            (got - expected).abs() <= tolerance,
            "{} at x = {}: got {}, expected {}",
            input,
            x,
            got,
            expected
        );
    }

    #[test]
    fn test1() {
        assert_eval("1 + 1", 0.0, 2.0);
    }

    #[test]
    fn test2() {
        assert_eval("sin(1 + 1)", 0.0, 2f64.sin());
    }

    #[test]
    fn test3() {
        assert_eval("1+e^5", 0.0, 1.0 + 5f64.exp());
    }

    #[test]
    fn test4() {
        assert_eval("-cos(sin(1+e^5))", 0.0, -((1.0 + 5f64.exp()).sin().cos()));
    }

    #[test]
    fn test5() {
        assert_eval("-cos(sin(e^x))", 0.7, -(0.7f64.exp().sin().cos()));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment