Skip to content

Instantly share code, notes, and snippets.

@nahkd123
Created April 2, 2024 17:48
Show Gist options
  • Save nahkd123/5a4226f5965e4beb5d58a4fd73907a52 to your computer and use it in GitHub Desktop.
Save nahkd123/5a4226f5965e4beb5d58a4fd73907a52 to your computer and use it in GitHub Desktop.
450 lines math expression engine
/*
* (c) Tran Huu An 2024. Licensed under MIT license.
*/
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * <p>
 * A "simple" mathematical expression engine. It supports:
 * <ul>
 * <li>the binary operators {@code +}, {@code -}, {@code *} and {@code /}, with
 * {@code *}/{@code /} binding tighter than {@code +}/{@code -} and all of them
 * left-associative;</li>
 * <li>unary {@code +} and {@code -};</li>
 * <li>parentheses;</li>
 * <li>variables and functions taking one or two arguments, resolved through a
 * {@link Context}.</li>
 * </ul>
 * It can be extended for your needs by supplying your own {@link Context}.
 * </p>
 *
 * @author nahkd123
 * @see #compile(String)
 * @see #evaluate(Context)
 * @see #evaluate()
 */
public interface Expression {
    /**
     * <p>
     * Lookup table consulted during evaluation. Every method returns
     * {@code null} by default, meaning "not defined in this context".
     * </p>
     */
    public static interface Context {
        /** @return the value of variable {@code name}, or {@code null} if unknown. */
        default Double getVariable(String name) {
            return null;
        }

        /** @return the one-argument function {@code name}, or {@code null} if unknown. */
        default DoubleUnaryOperator getFunction1(String name) {
            return null;
        }

        /** @return the two-argument function {@code name}, or {@code null} if unknown. */
        default DoubleBinaryOperator getFunction2(String name) {
            return null;
        }
    }

    /**
     * <p>
     * A context that queries its child contexts in order and returns the first
     * non-{@code null} answer.
     * </p>
     */
    public static class MergedContext implements Context {
        private final Context[] contexts;

        /**
         * @param contexts The contexts to query, highest priority first. The array
         *                 is copied, so later changes to it have no effect.
         */
        public MergedContext(Context... contexts) {
            this.contexts = contexts.clone();
        }

        @Override
        public Double getVariable(String name) {
            for (Context context : contexts) {
                Double variable = context.getVariable(name);
                if (variable != null) return variable;
            }
            return null;
        }

        @Override
        public DoubleUnaryOperator getFunction1(String name) {
            for (Context context : contexts) {
                DoubleUnaryOperator function1 = context.getFunction1(name);
                if (function1 != null) return function1;
            }
            return null;
        }

        @Override
        public DoubleBinaryOperator getFunction2(String name) {
            for (Context context : contexts) {
                DoubleBinaryOperator function2 = context.getFunction2(name);
                if (function2 != null) return function2;
            }
            return null;
        }
    }

    /**
     * <p>
     * Context with commonly used constants ({@code pi}, {@code e},
     * {@code true}, {@code false}, {@code random}) and {@link Math} functions.
     * </p>
     */
    public static class UniverseContext implements Context {
        @Override
        public Double getVariable(String name) {
            return switch (name) {
                case "pi", "\u03C0" -> Math.PI;
                case "e" -> Math.E;
                case "true" -> 1d;
                case "false" -> 0d;
                case "random" -> Math.random(); // a fresh value on every lookup
                default -> null;
            };
        }

        @Override
        public DoubleUnaryOperator getFunction1(String name) {
            return switch (name) {
                case "sin" -> Math::sin;
                case "cos" -> Math::cos;
                case "tan" -> Math::tan;
                case "asin" -> Math::asin;
                case "acos" -> Math::acos;
                case "atan" -> Math::atan;
                case "exp" -> Math::exp;
                case "log" -> Math::log;
                case "log10" -> Math::log10;
                case "signum" -> Math::signum;
                case "floor" -> Math::floor;
                case "ceil" -> Math::ceil;
                case "round" -> UniverseContext::round;
                case "sqrt" -> Math::sqrt;
                case "cbrt" -> Math::cbrt;
                case "deg" -> Math::toDegrees;
                case "rad" -> Math::toRadians;
                default -> null;
            };
        }

        @Override
        public DoubleBinaryOperator getFunction2(String name) {
            return switch (name) {
                case "atan2" -> Math::atan2;
                case "pow" -> Math::pow;
                default -> null;
            };
        }

        /**
         * Rounds like {@link Math#round(double)} (ties toward positive infinity)
         * but does not go through {@code long}. Values with magnitude of at least
         * 2^52 are already integers, so they are returned unchanged, as are
         * infinities and NaN. {@code Math.round} would instead clamp them to the
         * {@code long} range, or turn NaN into 0.
         */
        private static double round(double x) {
            return Math.abs(x) < 0x1p52 ? Math.round(x) : x;
        }
    }

    /** Shared instance of {@link UniverseContext}, used by {@link #evaluate()}. */
    public static final UniverseContext UNIVERSE = new UniverseContext();

    /**
     * <p>
     * Thrown when an expression cannot be evaluated, e.g. because it references
     * a variable or function that the {@link Context} does not define.
     * </p>
     */
    public static class EvalException extends RuntimeException {
        @Serial
        private static final long serialVersionUID = 3019515149401799088L;

        public EvalException(String message) {
            super(message);
        }

        public EvalException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * <p>
     * Evaluate the expression, using your own context. You can provide your own
     * functions and variables by implementing the context. Please note that
     * functions like {@code sin()} will not be available unless you provide them
     * yourself.
     * </p>
     *
     * @param context The evaluation context.
     * @return The evaluation result.
     * @throws EvalException if the expression can't be evaluated because of an
     *                       unknown variable or function.
     */
    public double evaluate(Context context) throws EvalException;

    /**
     * <p>
     * Evaluate the expression, using the {@link #UNIVERSE} context. The
     * {@link #UNIVERSE} context contains commonly used functions and variables,
     * like {@code sin()}, {@code random} or {@code pi} (note that {@code random}
     * is a variable, not a function).
     * </p>
     *
     * @return The evaluation result.
     * @throws EvalException if the expression can't be evaluated because of an
     *                       unknown variable or function.
     */
    default double evaluate() throws EvalException {
        return evaluate(UNIVERSE);
    }

    /** A numeric literal. */
    public static record Const(double value) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            return value;
        }
    }

    /** A variable looked up by name in the evaluation context. */
    public static record Variable(String name) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            Double value = context.getVariable(name);
            if (value == null) throw new EvalException("Unknown variable: " + name);
            return value;
        }
    }

    /** The binary arithmetic operators understood by the parser. */
    public static enum Operator {
        ADD("+") {
            @Override
            public double apply(double a, double b) {
                return a + b;
            }
        },
        SUBTRACT("-") {
            @Override
            public double apply(double a, double b) {
                return a - b;
            }
        },
        MULTIPLY("*") {
            @Override
            public double apply(double a, double b) {
                return a * b;
            }
        },
        DIVIDE("/") {
            @Override
            public double apply(double a, double b) {
                return a / b;
            }
        };

        private final String symbol;

        private Operator(String symbol) {
            this.symbol = symbol;
        }

        /** @return the symbol of this operator in source text, e.g. {@code "+"}. */
        public String getSymbol() {
            return symbol;
        }

        /** Applies this operator with {@code a} on the left and {@code b} on the right. */
        public abstract double apply(double a, double b);
    }

    /** {@code a operator b}. Operand {@code a} is evaluated before {@code b}. */
    public static record Operate(Expression a, Expression b, Operator operator) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            return operator.apply(a.evaluate(context), b.evaluate(context));
        }
    }

    /** Call of a one-argument function. The function is looked up before the argument is evaluated. */
    public static record Call1(String name, Expression param1) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            DoubleUnaryOperator function1 = context.getFunction1(name);
            if (function1 == null) throw new EvalException("Unknown function: " + name + "(x)");
            return function1.applyAsDouble(param1.evaluate(context));
        }
    }

    /** Call of a two-argument function. The function is looked up before the arguments are evaluated. */
    public static record Call2(String name, Expression param1, Expression param2) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            DoubleBinaryOperator function2 = context.getFunction2(name);
            if (function2 == null) throw new EvalException("Unknown function: " + name + "(x, y)");
            return function2.applyAsDouble(param1.evaluate(context), param2.evaluate(context));
        }
    }

    /**
     * Tokenizer: a token is either a run of word characters and dots (number or
     * identifier) or a single punctuation character. {@code UNICODE_CHARACTER_CLASS}
     * makes {@code \w} match non-ASCII letters such as pi, so the {@code "\u03C0"}
     * variable of {@link UniverseContext} can actually be written.
     */
    public static final Pattern PATTERN = Pattern.compile("([\\w.]+|[+*/(),-])", Pattern.UNICODE_CHARACTER_CLASS);

    /**
     * Precedence tiers, tightest-binding first. Note that this is a public array
     * and therefore mutable: changing an element changes how every later
     * {@link #compile(String)} call parses.
     */
    @SuppressWarnings("rawtypes")
    public static final Set[] ORDER_OF_OPERATIONS = {
        Set.of(Operator.MULTIPLY, Operator.DIVIDE),
        Set.of(Operator.ADD, Operator.SUBTRACT)
    };

    /**
     * <p>
     * Compile a math expression from a string to an {@link Expression}, which you
     * can {@link #evaluate()} at any time.
     * </p>
     *
     * @param source The math expression in string form.
     * @return The compiled expression.
     * @throws IllegalArgumentException if the expression can't be compiled. This
     *                                  happens when it is empty, contains characters
     *                                  that belong to no token, has unbalanced
     *                                  parentheses, has misplaced operators, or
     *                                  calls a function with a number of arguments
     *                                  other than one or two.
     */
    public static Expression compile(String source) throws IllegalArgumentException {
        Matcher matcher = PATTERN.matcher(source);
        List<Expression> expressions = new ArrayList<>();
        int cursor = 0;
        while (matcher.find()) {
            // Anything the pattern skipped over must be whitespace; otherwise it
            // would be silently dropped from the expression.
            rejectNonWhitespace(source, cursor, matcher.start());
            String token = matcher.group(1);
            expressions.add(isNumber(token) ? new Const(Double.parseDouble(token)) : new $$$Token(token));
            cursor = matcher.end();
        }
        rejectNonWhitespace(source, cursor, source.length());
        return validate(reduce(expressions));
    }

    /**
     * Raw token produced by the tokenizer. It only exists during compilation,
     * and evaluating it always throws.
     */
    public static record $$$Token(String token) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            throw new EvalException("Not parsed");
        }
    }

    /** Throws if {@code source[from, to)} contains anything other than whitespace. */
    private static void rejectNonWhitespace(String source, int from, int to) {
        for (int i = from; i < to;) {
            int codePoint = source.codePointAt(i);
            if (!Character.isWhitespace(codePoint) && !Character.isSpaceChar(codePoint))
                throw new IllegalArgumentException(
                    "Unexpected character '" + Character.toString(codePoint) + "' at index " + i);
            i += Character.charCount(codePoint);
        }
    }

    /** Identifier token: not a number, not punctuation, not an operator. */
    private static boolean isSymbol(String token) {
        if (isNumber(token)) return false;
        if (token.equals("(") || token.equals(")") || token.equals(",")) return false;
        return !isOperatorSymbol(token);
    }

    /** Returns whether the token consists only of ASCII digits and dots. */
    private static boolean isNumber(String token) {
        for (int i = 0; i < token.length(); i++)
            if (token.charAt(i) != '.' && (token.charAt(i) < '0' || token.charAt(i) > '9')) return false;
        return true;
    }

    private static boolean isOperatorSymbol(String token) {
        for (Operator operator : Operator.values()) if (operator.symbol.equals(token)) return true;
        return false;
    }

    private static boolean isOperatorToken(Expression expr) {
        return expr instanceof $$$Token token && isOperatorSymbol(token.token);
    }

    private static boolean isToken(Expression expr, String text) {
        return expr instanceof $$$Token token && token.token.equals(text);
    }

    /**
     * Reduces a token list to a single expression. The list is mutated in the
     * process.
     *
     * @throws IllegalArgumentException if the list is empty or malformed.
     */
    private static Expression reduce(List<Expression> expressions) {
        if (expressions.isEmpty()) throw new IllegalArgumentException("Empty expression");
        while (scan(expressions)) {
            // keep scanning until a pass replaces nothing
        }
        foldUnarySigns(expressions);
        while (expressions.size() > 1) {
            int lastSize = expressions.size();
            for (Set<?> tier : ORDER_OF_OPERATIONS) foldTier(expressions, tier);
            if (lastSize == expressions.size())
                throw new IllegalArgumentException("Parse error: Stuck in infinite loop");
        }
        return expressions.get(0);
    }

    /**
     * Folds every operator of {@code tier}, left to right, into an {@link Operate}
     * node. This makes the operators of one tier left-associative. Operators sit
     * at the odd indices of a well-formed list.
     */
    private static void foldTier(List<Expression> expressions, Set<?> tier) {
        int i = 1;
        while (i < expressions.size() - 1) {
            Operator operator = asOperator(expressions.get(i));
            if (!tier.contains(operator)) {
                i += 2;
                continue;
            }
            Expression left = expressions.get(i - 1);
            Expression right = expressions.get(i + 1);
            expressions.subList(i, i + 2).clear();
            expressions.set(i - 1, new Operate(left, right, operator));
            // Stay at i: it now holds the next operator, and everything to its
            // left has already been checked against this tier.
        }
    }

    /**
     * Collapses unary {@code +}/{@code -} signs into their operand. A sign is
     * unary when it starts the list or directly follows another operator.
     * Walking right to left lets repeated signs such as {@code - -x} nest.
     * Negation is built as {@code -1 * x}, which is exact for every double
     * (including signed zero).
     */
    private static void foldUnarySigns(List<Expression> expressions) {
        for (int i = expressions.size() - 2; i >= 0; i--) {
            if (!(expressions.get(i) instanceof $$$Token sign)) continue;
            boolean negate = sign.token.equals("-");
            if (!negate && !sign.token.equals("+")) continue;
            if (i > 0 && !isOperatorToken(expressions.get(i - 1))) continue; // binary use
            Expression operand = expressions.get(i + 1);
            if (operand instanceof $$$Token) continue; // malformed; reported later
            expressions.remove(i + 1);
            expressions.set(i, negate ? new Operate(new Const(-1), operand, Operator.MULTIPLY) : operand);
        }
    }

    /**
     * One pass over the token list. It turns identifiers into
     * {@link Variable}s, {@code name(...)} into {@link Call1}/{@link Call2}, and
     * parenthesized groups into their reduced sub-expression.
     *
     * @return {@code true} if anything was replaced.
     */
    private static boolean scan(List<Expression> expressions) {
        boolean changed = false;
        for (int i = 0; i < expressions.size(); i++) {
            if (!(expressions.get(i) instanceof $$$Token token)) continue;
            if (isSymbol(token.token)) {
                boolean isCall = i + 1 < expressions.size() && isToken(expressions.get(i + 1), "(");
                if (isCall) {
                    int close = findClosing(expressions, i + 1);
                    Expression call = parseCall(token.token, expressions.subList(i + 2, close));
                    expressions.subList(i + 1, close + 1).clear();
                    expressions.set(i, call);
                } else {
                    expressions.set(i, new Variable(token.token));
                }
                changed = true;
            } else if (token.token.equals("(")) {
                int close = findClosing(expressions, i);
                List<List<Expression>> parts = splitArguments(expressions.subList(i + 1, close));
                if (parts.size() > 1) throw new IllegalArgumentException("Unexpected ','");
                Expression group = reduce(parts.get(0));
                expressions.subList(i + 1, close + 1).clear();
                expressions.set(i, group);
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Returns the index of the {@code ")"} matching the {@code "("} at
     * {@code openIndex}.
     *
     * @throws IllegalArgumentException if there is no matching {@code ")"}.
     */
    private static int findClosing(List<Expression> expressions, int openIndex) {
        int depth = 0;
        for (int j = openIndex + 1; j < expressions.size(); j++) {
            if (!(expressions.get(j) instanceof $$$Token token)) continue;
            if (token.token.equals("(")) {
                depth++;
            } else if (token.token.equals(")")) {
                if (depth == 0) return j;
                depth--;
            }
        }
        throw new IllegalArgumentException("Missing ')'");
    }

    /**
     * Splits the balanced contents of a parenthesis pair at its top-level commas.
     * Each part is returned as a fresh mutable list, so {@link #reduce(List)} can
     * consume it. There is always at least one part, possibly empty.
     */
    private static List<List<Expression>> splitArguments(List<Expression> inner) {
        List<List<Expression>> parts = new ArrayList<>();
        List<Expression> current = new ArrayList<>();
        int depth = 0;
        for (Expression expr : inner) {
            if (expr instanceof $$$Token token) {
                if (token.token.equals("(")) {
                    depth++;
                } else if (token.token.equals(")")) {
                    depth--;
                } else if (token.token.equals(",") && depth == 0) {
                    parts.add(current);
                    current = new ArrayList<>();
                    continue;
                }
            }
            current.add(expr);
        }
        parts.add(current);
        return parts;
    }

    /** Builds a {@link Call1} or {@link Call2} from the tokens between a call's parentheses. */
    private static Expression parseCall(String name, List<Expression> arguments) {
        List<List<Expression>> params = splitArguments(arguments);
        if (params.size() > 2)
            throw new IllegalArgumentException("Unsupported number of function parameters: " + params.size());
        Expression first = reduce(params.get(0));
        return params.size() == 1 ? new Call1(name, first) : new Call2(name, first, reduce(params.get(1)));
    }

    private static Operator asOperator(Expression expr) {
        if (!(expr instanceof $$$Token token)) throw new IllegalArgumentException("Unexpected expression: " + expr);
        return switch (token.token) {
            case "+" -> Operator.ADD;
            case "-" -> Operator.SUBTRACT;
            case "*" -> Operator.MULTIPLY;
            case "/" -> Operator.DIVIDE;
            default -> throw new IllegalArgumentException("Unexpected token: " + token.token);
        };
    }

    /** Ensures no raw {@link $$$Token} survived parsing anywhere in the tree. */
    private static Expression validate(Expression expr) {
        if (expr instanceof $$$Token token) throw new IllegalArgumentException("Unexpected token: " + token.token);
        if (expr instanceof Call1 call1) validate(call1.param1);
        if (expr instanceof Call2 call2) {
            validate(call2.param1);
            validate(call2.param2);
        }
        if (expr instanceof Operate operate) {
            validate(operate.a);
            validate(operate.b);
        }
        return expr;
    }
}
public class ExpressionTest {
    /**
     * Compiles and evaluates a sample formula with the engine. It then prints the
     * same formula computed directly in Java, so the two results can be compared.
     */
    public static void main(String[] args) {
        Expression compiled = Expression.compile("1 + 2 * 32 / (4 + 5 * 6) + sin(pi) + e");
        double engineResult = compiled.evaluate();
        System.out.println(engineResult);

        double javaResult = 1d + 2d * 32d / (4d + 5d * 6d) + Math.sin(Math.PI) + Math.E;
        System.out.println(javaResult);
        // Both lines should print 5.600634769635516
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment