Created
August 6, 2012 16:19
-
-
Save mgood/3276107 to your computer and use it in GitHub Desktop.
Python expression evaluator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An attempt to make a safe evaluator of a subset of Python expressions.
This is mostly a proof of concept for getting feedback; it has not been
thoroughly checked for safety, so use it at your own risk :)
It uses the Python ast module to parse the expression, but all evaluation is
done by walking the AST; the expression is never directly executed by the
Python runtime.
Nose tests are provided below, covering both supported and unsupported
operations.
Known security considerations:
The variables are expected to be simple primitive types. Providing functions
with unsafe effects, or variables whose operator implementations can have
unsafe effects, is obviously unsafe.
Some operations (e.g. ``9 ** 9 ** 9``) may also take a lot of time or memory
and DoS the process.
""" | |
# we use floating-point division by default | |
from __future__ import division | |
import ast | |
import operator | |
# Names every expression can read regardless of the caller's variables.
# Python 2 parses ``True``/``False`` as ordinary names, so they have to be
# supplied as variables; eval_expression merges this mapping over the
# caller's vars, so these entries win on a name clash.
_standard_context = dict((str(flag), flag) for flag in (True, False))
def eval_expression(expr, vars=None, funcs=None):
    """Evaluate a single Python expression without executing it.

    :param expr: source text of one expression. It is parsed with
        ``mode='eval'``, so statements such as ``import foo`` raise
        SyntaxError.
    :param vars: optional mapping of variable names the expression may read.
        ``True`` and ``False`` are always defined and cannot be rebound.
    :param funcs: optional mapping of names the expression may call; calls
        to any other name are rejected.
    :returns: the value of the expression.
    :raises SyntaxError: if *expr* is not a valid expression.
    :raises ValueError: if *expr* uses an unsupported construct, operator,
        variable or function. Exceptions raised by the operators themselves
        or by functions in *funcs* propagate unchanged.
    """
    if vars is None:
        vars = {}
    if funcs is None:
        funcs = {}
    # The standard names are merged last so the caller cannot rebind them.
    vars = dict(vars, **_standard_context)
    tree = ast.parse(expr, mode='eval')
    return AstEvaluator(vars, funcs).visit(tree)
class AstEvaluator(ast.NodeTransformer):
    """Evaluate an expression AST by walking it node by node.

    Each supported node type has a ``visit_<NodeType>`` method below. Any
    other node falls through to :meth:`generic_visit`, which raises
    ValueError. The tree is never compiled or handed to Python's ``eval``.

    Handles the Python 2 AST (literals as ``Num``/``Str``, ``True``/``False``
    as plain names). It also handles Python 3, where literals are parsed as
    ``NameConstant`` (3.4-3.7) or ``Constant`` (3.8+).

    :param variables: mapping of names that ``ast.Name`` nodes may read.
    :param funcs: mapping of names that may be called; calls to any other
        name are rejected.
    """

    # NOTE(review): the tree is only evaluated, never rewritten, so
    # ast.NodeVisitor would be the more natural base class. NodeTransformer
    # is kept so existing isinstance checks keep working.

    def __init__(self, variables, funcs):
        self.variables = variables
        self.funcs = funcs

    # Binary operator node type -> implementation. ``/`` always maps to true
    # division, independent of the interpreter's default.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
        ast.FloorDiv: operator.floordiv,
    }

    unary_ops = {
        ast.Invert: operator.invert,
        ast.Not: operator.not_,
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    # For each boolean operator, the operand truthiness that ends evaluation:
    # ``and`` stops at the first falsy operand, ``or`` at the first truthy one.
    bool_ops = {
        ast.And: False,
        ast.Or: True,
    }

    compare_ops = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        # TODO: decide whether to support Is and IsNot.
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
    }

    # Literal value types accepted from Python 3 constant nodes. None and
    # Ellipsis are rejected, matching Python 2, where ``None`` is an ordinary
    # name that is not in the variables.
    _literal_types = (bool, int, float, complex, str, bytes)

    def find_operator(self, op_map, op):
        """Return the value in *op_map* whose key type *op* is an instance of.

        :raises ValueError: if no key matches.
        """
        for op_type, op_func in op_map.items():
            if isinstance(op, op_type):
                return op_func
        raise ValueError('Unknown operator: %s' % op.__class__.__name__)

    def visit_Expression(self, node):
        # Root node produced by ast.parse(..., mode='eval').
        return self.visit(node.body)

    def visit_BinOp(self, node):
        op_func = self.find_operator(self.binary_ops, node.op)
        left = self.visit(node.left)
        right = self.visit(node.right)
        return op_func(left, right)

    def visit_UnaryOp(self, node):
        op_func = self.find_operator(self.unary_ops, node.op)
        return op_func(self.visit(node.operand))

    def visit_Compare(self, node):
        """Evaluate a possibly chained comparison such as ``a < b < c``.

        Each operand is evaluated at most once. Evaluation stops at the first
        comparison that fails, and the result is True or False.
        """
        left = self.visit(node.left)
        for op_node, comp_node in zip(node.ops, node.comparators):
            op_func = self.find_operator(self.compare_ops, op_node)
            right = self.visit(comp_node)
            if not op_func(left, right):
                return False
            left = right
        return True

    def visit_Name(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        try:
            return self.variables[node.id]
        except KeyError:
            raise ValueError('Unknown variable: %s' % node.id)

    def visit_BoolOp(self, node):
        """Evaluate ``and``/``or`` with Python's semantics.

        Operands are evaluated left to right. Evaluation stops at the first
        operand that decides the result, and that operand itself is returned
        (``0 or 5`` is 5, ``2 and 3`` is 3). If none decides it, the last
        operand is returned. Operands after the deciding one are never
        evaluated.
        """
        stop_when = self.find_operator(self.bool_ops, node.op)
        value = None
        for value_node in node.values:
            value = self.visit(value_node)
            if bool(value) == stop_when:
                return value
        # A BoolOp always has at least two operands, so value is bound here.
        return value

    def visit_Call(self, node):
        """Call a whitelisted function with positional arguments only."""
        if not isinstance(node.func, ast.Name):
            raise ValueError('Only calls to named functions are supported')
        if not isinstance(node.func.ctx, ast.Load):
            raise ValueError('Can only read variables')
        # Keyword arguments are not supported. Before Python 3.5, *args and
        # **kwargs live in separate ``starargs``/``kwargs`` fields. Reject all
        # of them instead of silently dropping them unevaluated.
        if (node.keywords or getattr(node, 'starargs', None)
                or getattr(node, 'kwargs', None)):
            raise ValueError('Keyword and star arguments are not supported')
        try:
            func = self.funcs[node.func.id]
        except KeyError:
            raise ValueError('Unknown function: %s' % node.func.id)
        args = [self.visit(x) for x in node.args]
        return func(*args)

    def visit_IfExp(self, node):
        # Only the branch selected by the test is evaluated.
        if self.visit(node.test):
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)

    def visit_List(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        return [self.visit(x) for x in node.elts]

    def visit_Tuple(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        return tuple(self.visit(x) for x in node.elts)

    def visit_Num(self, node):
        # Numeric literal (Python 2 and Python 3 before 3.8).
        return node.n

    def visit_Str(self, node):
        # String literal (Python 2 and Python 3 before 3.8).
        return node.s

    def visit_Bytes(self, node):
        # Bytes literal on Python 3 before 3.8. Python 2 parses it as Str.
        return node.s

    def visit_NameConstant(self, node):
        # True/False/None on Python 3.4-3.7.
        return self._literal(node.value)

    def visit_Constant(self, node):
        # Every literal on Python 3.8+.
        return self._literal(node.value)

    def _literal(self, value):
        """Return *value* if it is an allowed literal type, else raise ValueError."""
        if isinstance(value, self._literal_types):
            return value
        raise ValueError('Unsupported constant: %r' % (value,))

    def generic_visit(self, node):
        # Reached for every node type without a visit_* method above.
        raise ValueError('Unknown node type: %s' % node.__class__.__name__)
# nosetests | |
from nose.tools import eq_, assert_raises | |
# Expressions the evaluator must accept. check_supported compares each
# result against Python's own eval() of the same text.
supported_expressions = [
    # numeric literals and unary operators
    '1', '-1', '2.5', '-(-1)', '+(1)', '~1',
    # string literals
    '"text"', 'u"foo"',
    # arithmetic and bitwise operators
    '2 + 2', '10 - 5', '2 * 3', '20 / 2', '10 % 2', '3 ** 2',
    '8 << 1', '1 >> 8', '0b11 & 0b10', '0b10 | 0b01', '0b11 ^ 0b01',
    # "/" is true division (from __future__ import division); "//" floors
    '5 / 2', '5 // 2',
    # calls to functions whitelisted in `funcs`
    'abs(-1)', 'max([1, 2])',
    # variables taken from `context`
    'x * 2', '3 - -x', 'x + 10 > y / 2',
    # comparisons and membership tests
    '1 == 1', '1 != 1', '1 < 2', '1 <= 2', '1 > 2', '1 >= 2',
    'x in (1, 2, 3)', 'x not in (1, 2, 3)',
    'x in [1, 2, 3]', 'x not in [1, 2, 3]',
    'x == y / 10',
    # booleans and boolean operators
    'True', 'False', 'not True', 'not False',
    'True or False', 'True and False',
    # conditional expressions
    '1 if True else 2', '1 if False else 2',
    # chained versus parenthesised comparison
    'x < 4 < 3', '(x < 4) < 3',
]

# Variables visible to expressions. `min` is deliberately a plain variable,
# not a whitelisted function, so calling it must fail.
context = {'x': 2, 'y': 20, 'min': min}
# Functions that expressions are allowed to call.
funcs = {'abs': abs, 'max': max}

# Expressions that must raise ValueError or SyntaxError.
unsupported_expressions = [
    # statements cannot be parsed as a single expression
    'import foo',
    # builtins are not exposed as variables
    'str',
    # names missing from `context`
    'foo',
    # `min` is a variable, not a whitelisted function
    'min([1, 2])',
    # attribute access is rejected ...
    'x.foo', 'y.bar',
    # ... on either side of a comparison
    'x > y.foo', 'x.foo > y',
    # unknown names in boolean operands that get evaluated
    'False or bar', 'foo or bar', 'foo or True',
    # lambdas and compound statements
    'lambda: False', 'if True: pass', 'while True: pass', 'for x in y: pass',
    # unknown names in any part of a conditional expression
    'foo if True else 1', '1 if False else foo', '1 if foo else 2',
]
def test_expressions():
    """Nose generator test: yield one (checker, expression) case per entry.

    All supported expressions come first, checked with check_supported, then
    all unsupported ones, checked with check_unsupported.
    """
    suites = (
        (check_supported, supported_expressions),
        (check_unsupported, unsupported_expressions),
    )
    for checker, expressions in suites:
        for expression in expressions:
            yield checker, expression
def check_supported(expr):
    """Assert that eval_expression agrees with Python's own eval() on *expr*."""
    # eval() inserts a '__builtins__' entry into the globals dict it receives.
    # Pass a copy so the shared `funcs` whitelist, which is also handed to
    # eval_expression below, is not mutated.
    expected = eval(expr, dict(funcs), context)
    eq_(expected, eval_expression(expr, vars=context, funcs=funcs))
def check_unsupported(expr):
    """Assert that eval_expression rejects *expr*.

    Rejection means raising ValueError (an unsupported construct, name or
    function) or SyntaxError (text that is not a single expression).
    """
    accepted_errors = (ValueError, SyntaxError)
    assert_raises(accepted_errors, eval_expression, expr,
                  vars=context, funcs=funcs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment