Skip to content

Instantly share code, notes, and snippets.

@mgood
Created August 6, 2012 16:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mgood/3276107 to your computer and use it in GitHub Desktop.
Save mgood/3276107 to your computer and use it in GitHub Desktop.
Python expression evaluator
"""
An attempt to make a safe evaluator of a subset of Python expressions.
This is mostly a proof of concept for gathering feedback; it has not been
thoroughly checked for safety, so use it at your own risk :)
It uses the Python ast module to parse the expression, but all evaluation is
done by walking the AST; the expression is never executed directly by the
Python runtime.
Nose tests are provided below, covering both supported and unsupported
operations.
Known security considerations:
The variables are expected to be simple primitive types. Providing functions
with unsafe effects, or variables whose operator implementations can have
unsafe effects, is obviously unsafe.
Some operations (for example ``**`` with a huge exponent) may also take a lot
of time or memory and DoS the process.
"""
# we use floating-point division by default
from __future__ import division
import ast
import operator
# Names that every expression can read. eval_expression layers these on top
# of the caller's variables, so they win over same-named caller entries.
# Python 2 parses True/False as ordinary names, so they must be provided here.
_standard_context = dict([
    ('True', True),
    ('False', False),
])
def eval_expression(expr, vars=None, funcs=None):
    """Evaluate a restricted Python expression without executing it.

    The expression is parsed with ``ast.parse`` and then evaluated by walking
    the tree with ``AstEvaluator``.

    :param expr: source text of a single Python expression.
    :param vars: mapping of variable names the expression may read.
    :param funcs: mapping of function names the expression may call.
    :returns: the value of the expression.
    :raises SyntaxError: if ``expr`` is not a single valid expression.
    :raises ValueError: if ``expr`` uses an unsupported construct, an unknown
        variable or an unknown function.

    Exceptions raised by the operators or by the called functions themselves
    (for example ZeroDivisionError) propagate unchanged.
    """
    if vars is None:
        vars = {}
    if funcs is None:
        funcs = {}
    # Copy so the caller's mapping is never modified; the standard names
    # (True/False) take precedence over caller-supplied entries.
    vars = dict(vars, **_standard_context)
    tree = ast.parse(expr, mode='eval')
    # The debug dump of the tree that used to be printed here is gone: a
    # library function must not write to stdout, and the Python 2 print
    # statement was also a syntax error on Python 3.
    return AstEvaluator(vars, funcs).visit(tree)
class AstEvaluator(ast.NodeTransformer):
    """Evaluate an expression tree produced by ``ast.parse(..., mode='eval')``.

    Evaluation happens by walking the tree; it is never compiled or executed
    by the Python runtime. Only node types with a ``visit_<NodeType>`` method
    below are supported. Any other node reaches ``generic_visit``, which
    raises ``ValueError``. ``ast.NodeTransformer`` is used only for its
    ``visit_*`` dispatch. Because ``generic_visit`` is overridden, the tree is
    never modified.

    :param variables: mapping of names that expressions may read.
    :param funcs: mapping of names that expressions may call.
    """

    def __init__(self, variables, funcs):
        self.variables = variables
        self.funcs = funcs

    # ast operator node type -> function implementing it.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        # always true division, even on Python 2
        ast.Div: operator.truediv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
        ast.FloorDiv: operator.floordiv,
    }

    unary_ops = {
        ast.Invert: operator.invert,
        ast.Not: operator.not_,
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    # As in Python, ``and``/``or`` evaluate operands left to right and return
    # the operand at which evaluation stopped (or the last one), not a bool.
    # Each operator maps to the predicate on an operand that stops evaluation.
    bool_ops = {
        ast.And: operator.not_,  # stop at the first falsy operand
        ast.Or: operator.truth,  # stop at the first truthy operand
    }

    compare_ops = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        # TODO: decide whether to support Is and IsNot (currently rejected).
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
    }

    def find_operator(self, op_map, op):
        """Return the function that ``op_map`` associates with operator ``op``.

        :raises ValueError: if ``op`` is not an instance of any key.
        """
        # items() works on both Python 2 and 3 (iteritems is Python 2 only).
        for op_type, op_func in op_map.items():
            if isinstance(op, op_type):
                return op_func
        # Report the node class name rather than an unhelpful object repr.
        raise ValueError('Unknown operator: %s' % op.__class__.__name__)

    def visit_Expression(self, node):
        # Root node produced by ast.parse(..., mode='eval').
        return self.visit(node.body)

    def visit_BinOp(self, node):
        op_func = self.find_operator(self.binary_ops, node.op)
        left = self.visit(node.left)
        right = self.visit(node.right)
        return op_func(left, right)

    def visit_UnaryOp(self, node):
        op_func = self.find_operator(self.unary_ops, node.op)
        return op_func(self.visit(node.operand))

    def visit_Compare(self, node):
        """Evaluate a possibly chained comparison such as ``a < b < c``.

        Each operand is evaluated at most once, and evaluation stops at the
        first comparison that is false.
        """
        left = self.visit(node.left)
        for op_node, comp_node in zip(node.ops, node.comparators):
            op_func = self.find_operator(self.compare_ops, op_node)
            right = self.visit(comp_node)
            if not op_func(left, right):
                return False
            left = right
        return True

    def visit_Name(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        try:
            return self.variables[node.id]
        except KeyError:
            raise ValueError('Unknown variable: %s' % node.id)

    def visit_BoolOp(self, node):
        """Evaluate ``and``/``or`` with short-circuiting, returning an operand.

        Operands after the stopping point are never visited, so they are not
        validated either.
        """
        stops_at = self.find_operator(self.bool_ops, node.op)
        value = None
        for value_node in node.values:
            value = self.visit(value_node)
            if stops_at(value):
                break
        return value

    def visit_Call(self, node):
        """Call a whitelisted function with evaluated arguments.

        Only plain names found in ``self.funcs`` can be called. Positional
        and named keyword arguments are supported. ``*args`` and ``**kwargs``
        unpacking are rejected with ValueError.
        """
        if not isinstance(node.func, ast.Name):
            raise ValueError('Can only call functions by name')
        if not isinstance(node.func.ctx, ast.Load):
            raise ValueError('Can only read variables')
        try:
            func = self.funcs[node.func.id]
        except KeyError:
            raise ValueError('Unknown function: %s' % node.func.id)
        # Python 2 (and Python 3 before 3.5) store *args/**kwargs on the Call
        # node itself.
        if (getattr(node, 'starargs', None) is not None
                or getattr(node, 'kwargs', None) is not None):
            raise ValueError('Argument unpacking is not supported')
        args = [self.visit(x) for x in node.args]
        kwargs = {}
        for keyword in node.keywords:
            # Python 3.5+ represents ``**mapping`` as a keyword without a name.
            if keyword.arg is None:
                raise ValueError('Argument unpacking is not supported')
            kwargs[keyword.arg] = self.visit(keyword.value)
        return func(*args, **kwargs)

    def visit_IfExp(self, node):
        # Only the selected branch is evaluated.
        if self.visit(node.test):
            return self.visit(node.body)
        return self.visit(node.orelse)

    def visit_List(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        return [self.visit(x) for x in node.elts]

    def visit_Tuple(self, node):
        if not isinstance(node.ctx, ast.Load):
            raise ValueError('Can only read variables')
        return tuple(self.visit(x) for x in node.elts)

    # Literal nodes. Python 2 and Python 3 before 3.8 produce Num/Str (and,
    # from 3.4, NameConstant for True/False/None). Python 3.8+ produces
    # Constant for every literal.
    def visit_Num(self, node):
        return node.n

    def visit_Str(self, node):
        return node.s

    def visit_NameConstant(self, node):
        return node.value

    def visit_Constant(self, node):
        return node.value

    def generic_visit(self, node):
        # Any node type without an explicit visit_* method is rejected.
        raise ValueError('Unknown node type: %s' % node.__class__.__name__)
# nosetests
from nose.tools import eq_, assert_raises
# Expressions eval_expression must accept. Each one is checked against
# Python's own eval() with the same variables and functions.
supported_expressions = (
    # numeric literals and unary operators
    ['1', '-1', '2.5', '-(-1)', '+(1)', '~1']
    # strings
    + ['"text"', 'u"foo"']
    # arithmetic and bitwise operators
    + ['2 + 2', '10 - 5', '2 * 3', '20 / 2', '10 % 2', '3 ** 2',
       '8 << 1', '1 >> 8', '0b11 & 0b10', '0b10 | 0b01', '0b11 ^ 0b01']
    # "/" uses "true" division by default
    + ['5 / 2']
    # "//" is "floor" division
    + ['5 // 2']
    # calls to whitelisted functions
    + ['abs(-1)', 'max([1, 2])']
    # variables, comparisons and membership tests
    + ['x * 2', '3 - -x', 'x + 10 > y / 2',
       '1 == 1', '1 != 1', '1 < 2', '1 <= 2', '1 > 2', '1 >= 2',
       'x in (1, 2, 3)', 'x not in (1, 2, 3)',
       'x in [1, 2, 3]', 'x not in [1, 2, 3]',
       'x == y / 10']
    # boolean constants and operators
    + ['True', 'False', 'not True', 'not False',
       'True or False', 'True and False']
    # conditional expressions and chained comparisons
    + ['1 if True else 2', '1 if False else 2',
       'x < 4 < 3', '(x < 4) < 3']
)
# Variables visible to the test expressions. ``min`` is deliberately a plain
# variable here, so the tests can check that it cannot be called.
context = dict(x=2, y=20, min=min)
# Functions the test expressions are allowed to call.
funcs = dict(abs=abs, max=max)
# Expressions eval_expression must reject with ValueError or SyntaxError.
unsupported_expressions = (
    # statements cannot be parsed in 'eval' mode
    ['import foo']
    # built-ins are not available
    + ['str']
    # foo is not in the context
    + ['foo']
    # min is a plain variable, not one of the callable funcs
    + ['min([1, 2])']
    # attribute access is not allowed
    + ['x.foo', 'y.bar']
    # both sides of the comparison should be validated
    + ['x > y.foo', 'x.foo > y']
    # both sides of bool op should be validated
    + ['False or bar', 'foo or bar', 'foo or True']
    # lambdas and statements
    + ['lambda: False', 'if True: pass', 'while True: pass',
       'for x in y: pass']
    # all 3 sections of if/else should be validated
    + ['foo if True else 1', '1 if False else foo', '1 if foo else 2']
)
def test_expressions():
    """Nose test generator: yield one check per listed expression."""
    suites = (
        (check_supported, supported_expressions),
        (check_unsupported, unsupported_expressions),
    )
    for checker, expressions in suites:
        for expression in expressions:
            yield checker, expression
def check_supported(expr):
    """Assert that eval_expression gives the same result as Python's eval."""
    # eval() inserts a '__builtins__' entry into the globals mapping it is
    # given, so pass a copy; otherwise the shared ``funcs`` fixture would be
    # modified and then handed to eval_expression.
    expected = eval(expr, dict(funcs), context)
    eq_(expected, eval_expression(expr, vars=context, funcs=funcs))
def check_unsupported(expr):
    """Assert that eval_expression rejects *expr*.

    Rejection comes either from ``ast.parse`` (SyntaxError, e.g. for
    statements) or from the tree walker (ValueError).
    """
    rejected_with = (ValueError, SyntaxError)
    assert_raises(rejected_with, eval_expression, expr,
                  vars=context, funcs=funcs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment