Skip to content

Instantly share code, notes, and snippets.

@danfunk
Created January 20, 2022 13:59
Show Gist options
  • Save danfunk/020ce41f5848fc8280e87dc14918a2a5 to your computer and use it in GitHub Desktop.
Save danfunk/020ce41f5848fc8280e87dc14918a2a5 to your computer and use it in GitHub Desktop.
import ast
import logging
from RestrictedPython import safe_globals, compile_restricted
from SpiffWorkflow.bpmn.PythonScriptEngine import PythonScriptEngine
from SpiffWorkflow.bpmn.specs.ScriptTask import ScriptTask
from SpiffWorkflow.exceptions import WorkflowTaskExecException
# Module-level logger; ScriptExecutor.visit emits a debug record for each node it visits.
logger = logging.getLogger(__name__)
class ScriptParser(ast.NodeVisitor):
    """Walk a syntax tree, recording imports and function & method calls.

    The recorded names can be validated against a potential execution
    environment with :meth:`validate`. Recorded names accumulate across
    visits; call :meth:`clear` before reusing a parser for another script.
    """

    def __init__(self):
        self.clear()

    def visit_Import(self, node):
        """Record each module named in an ``import`` statement."""
        self.imports.extend(alias.name for alias in node.names)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record each name brought in by a ``from ... import`` statement."""
        self.imports.extend(alias.name for alias in node.names)
        self.generic_visit(node)

    def visit_Call(self, node):
        """Record the callee: ``f(...)`` goes to functions, ``x.m(...)`` to methods."""
        if isinstance(node.func, ast.Name):
            self.functions.append(node.func.id)
        elif isinstance(node.func, ast.Attribute):
            self.methods.append(node.func.attr)
        # Keep walking so calls nested inside arguments are recorded too.
        self.generic_visit(node)

    def validate(self, node, globals, locals):
        """Visit *node* and check that every bare-name function it calls exists.

        :param node: the AST to inspect.
        :param globals: fallback namespace consulted when a name is not in *locals*.
        :param locals: namespace consulted first.
        :raises Exception: if any called function resolves to ``None`` (or is
            missing); each missing name is listed once, in first-seen order.
        """
        self.visit(node)
        # dict.fromkeys de-duplicates while preserving call order, so a function
        # called several times is reported only once.
        invalid = [
            call
            for call in dict.fromkeys(self.functions)
            if locals.get(call, globals.get(call)) is None
        ]
        if invalid:
            raise Exception(f"The following functions are not available: {', '.join(invalid)}")

    def clear(self):
        """Forget everything recorded by previous visits."""
        self.imports = []
        self.functions = []
        self.methods = []
        self.nodes = []
class ScriptExecutor(ast.NodeTransformer):
    """Evaluate a restricted expression tree made of calls, names, attributes and constants.

    Function calls (including ``*args`` / ``**kwargs`` unpacking), keyword
    arguments, attribute lookups, variable names and constants are evaluated
    directly. Any other node type is handed to ``NodeTransformer.generic_visit``
    and comes back as an AST node, which means this probably won't work if the
    script contains anything but function calls and references to values.
    """

    # Sentinel distinguishing "name is not bound" from "name is bound to None".
    _MISSING = object()

    def __init__(self, globals, locals):
        # Names are resolved in ``locals`` first, then in ``globals``.
        self.globals = globals
        self.locals = locals

    def visit(self, node):
        """Visit *node* and return its evaluated value (``None`` is a valid result)."""
        rv = super().visit(node)
        # Lazy %-style arguments: rv is only formatted when DEBUG is enabled.
        logger.debug("%s (fields=%s) -> %s", node.__class__.__name__, ", ".join(node._fields), rv)
        return rv

    def visit_Call(self, node):
        """Evaluate the callee and its arguments, then return the call's result."""
        func = self.visit(node.func)
        args = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):
                # f(*xs): splice the evaluated iterable into the positional args.
                args.extend(self.visit(arg.value))
            else:
                args.append(self.visit(arg))
        kwargs = {}
        for keyword in node.keywords:
            name, value = self.visit(keyword)
            if name is None:
                # f(**mapping): the keyword node has arg=None for dict unpacking.
                kwargs.update(value)
            else:
                kwargs[name] = value
        return func(*args, **kwargs)

    def visit_keyword(self, node):
        """Return ``(name, value)`` for a keyword argument; name is None for ``**``."""
        return (node.arg, self.visit(node.value))

    def visit_Attribute(self, node):
        """Return ``getattr`` of the evaluated base object, without mutating the tree."""
        # getattr (not obj.__getattribute__) works for class attributes such as
        # SomeClass.method and honours __getattr__ fallbacks.
        return getattr(self.visit(node.value), node.attr)

    def visit_Name(self, node):
        """Resolve a name in locals, then globals.

        :raises NameError: if the name is bound in neither namespace.
        """
        value = self.locals.get(node.id, self._MISSING)
        if value is self._MISSING:
            value = self.globals.get(node.id, self._MISSING)
        if value is self._MISSING:
            raise NameError(f"name '{node.id}' is not defined")
        return value

    def visit_Constant(self, node):
        """Return the literal value of a constant."""
        return node.value
class PythonFormatAwareScriptEngine(PythonScriptEngine):
    """Script engine that can evaluate "function-call" scripts without ``exec``.

    Script tasks whose ``script_format`` is ``'text/python;type=function'`` must
    contain exactly one statement of the form ``name = function(...)``. The
    called functions are checked with :class:`ScriptParser`, the call is
    evaluated with :class:`ScriptExecutor`, and the result is stored in the task
    data under ``name``. Every other script goes to :meth:`execute_in_container`.
    """

    # script_format value that selects the restricted function-call path.
    FUNCTION_FORMAT = 'text/python;type=function'

    def __init__(self, globals=None):
        """Create the engine.

        :param globals: namespace in which names are looked up after the task
            data; defaults to RestrictedPython's ``safe_globals``.
        """
        super().__init__()
        self.globals = globals if globals is not None else safe_globals

    def execute(self, task, script, data, external_methods=None):
        """Run *script* for *task*, choosing the evaluation path by script format.

        :param external_methods: forwarded to the base engine for regular
            scripts; not used by the function-call path.
        """
        spec = task.task_spec
        # getattr: a ScriptTask spec without a script_format attribute falls back
        # to the regular engine rather than raising AttributeError.
        if isinstance(spec, ScriptTask) and getattr(spec, 'script_format', None) == self.FUNCTION_FORMAT:
            self.execute_function_call(task, script, data)
        else:
            self.execute_in_container(task, script, data, external_methods)

    def execute_function_call(self, task, script, data):
        """Evaluate a ``name = function(...)`` script and store the result in *data*.

        :raises WorkflowTaskExecException: if the script does not parse, is not a
            single assignment of a call result to a plain name, calls a function
            that is not available, or fails while being evaluated.
        """
        try:
            tree = ast.parse(script)
        except SyntaxError as err:
            raise WorkflowTaskExecException(task, f"Syntax error in script: {err}", err) from err

        # Check format: exactly one statement, `name = call(...)`. Tuple,
        # attribute, subscript and chained targets cannot map to one data key.
        statement = tree.body[0] if len(tree.body) == 1 else None
        if (statement is None
                or not self.is_assignment_from_function(statement)
                or len(statement.targets) != 1
                or not isinstance(statement.targets[0], ast.Name)):
            raise WorkflowTaskExecException(
                task,
                "This must come in the format of an assignment from a function, "
                "i.e. x = y(...), only one expression is allowed.")

        # Verify that the called functions are available. validate() walks the
        # tree itself, so no separate parser.visit() is needed.
        parser = ScriptParser()
        try:
            parser.validate(tree, self.globals, data)
        except Exception as exc:
            raise WorkflowTaskExecException(task, str(exc), exc) from exc

        target = statement.targets[0].id
        executor = ScriptExecutor(self.globals, data)
        try:
            result = executor.visit(statement.value)
        except Exception as err:
            detail = err.args[0] if err.args else err.__class__.__name__
            raise WorkflowTaskExecException(task, detail, err) from err
        data[target] = result

    def execute_in_container(self, task, script, data, external_methods=None):
        """Run *script* with the base PythonScriptEngine.

        This would be replaced with an actual container or something similar.
        """
        # Only pass external_methods when given, so the base call is unchanged
        # for callers that never supply it.
        if external_methods is None:
            super().execute(task, script, data)
        else:
            super().execute(task, script, data, external_methods)

    def is_assignment_from_function(self, node):
        """Return True if *node* is an ``ast.Assign`` whose value is a call."""
        return isinstance(node, ast.Assign) and isinstance(node.value, ast.Call)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment