Skip to content

Instantly share code, notes, and snippets.

@jgarte
Forked from thrau/convert_assert.py
Created June 14, 2022 04:02
Show Gist options
  • Save jgarte/28518c2f8e5cd96e05a058911895fb7f to your computer and use it in GitHub Desktop.
Save jgarte/28518c2f8e5cd96e05a058911895fb7f to your computer and use it in GitHub Desktop.
Script to convert unittest asserts to plain asserts
"""
Script to convert unittest asserts into plain asserts.
Either reads the file from the path passed as first parameter, or reads from stdin if no parameter is given.
"""
import functools
import sys
import libcst as cst
def _parenthesize_if(node, needs_parens):
    """Return *node* wrapped in one pair of parentheses when *needs_parens(node)* holds.

    Nodes that are ``None`` or that already carry parentheses are returned as-is.
    """
    if node is None or node.lpar or not needs_parens(node):
        return node
    return node.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])


def to_assert(left, op=None, right=None, msg=None) -> cst.Assert:
    """Build a plain ``assert`` statement node.

    Three shapes are produced, depending on which arguments are given:

    * ``assert <left>``            when both ``op`` and ``right`` are ``None``
    * ``assert <op> <left>``       when only ``right`` is ``None`` (unary operator)
    * ``assert <left> <op> <right>`` otherwise (single comparison)

    Operands are taken from call arguments, where any expression is valid
    without parentheses. Code generation does not add parentheses on its own,
    so operands whose precedence is too low for their new position are
    parenthesized here; otherwise ``assertFalse(a or b)`` would become
    ``assert not a or b`` and ``assertEqual(a == b, c)`` would become the
    chained comparison ``a == b == c``.

    :param left: expression node for the (left) operand
    :param op: unary or comparison operator node, or ``None``
    :param right: expression node for the right operand, or ``None``
    :param msg: expression node for the assertion message, or ``None``
    :return: the ``cst.Assert`` node
    """

    def unsafe_bare(node):
        # Legal as a lone call argument, but not directly after ``assert``/``,``.
        return isinstance(node, (cst.NamedExpr, cst.GeneratorExp))

    def looser_than_not(node):
        return unsafe_bare(node) or isinstance(
            node, (cst.BooleanOperation, cst.IfExp, cst.Lambda)
        )

    def looser_than_comparison(node):
        if looser_than_not(node) or isinstance(node, cst.Comparison):
            return True
        return isinstance(node, cst.UnaryOperation) and isinstance(node.operator, cst.Not)

    msg = _parenthesize_if(msg, unsafe_bare)
    if op is not None:
        # Callers may pass one shared operator instance for every conversion;
        # clone it so each generated assert owns a distinct node.
        op = op.deep_clone()

    if right is None:
        if op is None:
            return cst.Assert(test=_parenthesize_if(left, unsafe_bare), msg=msg)
        return cst.Assert(
            test=cst.UnaryOperation(
                operator=op,
                expression=_parenthesize_if(left, looser_than_not),
            ),
            msg=msg,
        )

    return cst.Assert(
        test=cst.Comparison(
            left=_parenthesize_if(left, looser_than_comparison),
            comparisons=[
                cst.ComparisonTarget(
                    operator=op,
                    comparator=_parenthesize_if(right, looser_than_comparison),
                ),
            ],
        ),
        msg=msg,
    )
class AssertTransformer(cst.CSTTransformer):
    """Rewrite ``<obj>.assertXxx(...)`` method calls into plain ``assert`` statements.

    Calls are matched purely on the attribute name; the receiver is not
    inspected. Calls whose arguments cannot be mapped onto operands and an
    optional message are left untouched.
    """

    # Assertions taking one operand (plus optional message).
    unary_comps = {
        "assertIsNone": lambda arg, msg: to_assert(arg, cst.Is(), cst.Name("None"), msg=msg),
        "assertIsNotNone": lambda arg, msg: to_assert(arg, cst.IsNot(), cst.Name("None"), msg=msg),
        "assertTrue": lambda arg, msg: to_assert(arg, msg=msg),
        "assertFalse": lambda arg, msg: to_assert(arg, cst.Not(), msg=msg),
    }

    # Assertions taking two operands (plus optional message).
    binary_comps = {
        "assertEqual": functools.partial(to_assert, op=cst.Equal()),
        "assertNotEqual": functools.partial(to_assert, op=cst.NotEqual()),
        "assertIn": functools.partial(to_assert, op=cst.In()),
        "assertNotIn": functools.partial(to_assert, op=cst.NotIn()),
        "assertIs": functools.partial(to_assert, op=cst.Is()),
        "assertIsNot": functools.partial(to_assert, op=cst.IsNot()),
        "assertLess": functools.partial(to_assert, op=cst.LessThan()),
        "assertLessEqual": functools.partial(to_assert, op=cst.LessThanEqual()),
        "assertGreater": functools.partial(to_assert, op=cst.GreaterThan()),
        "assertGreaterEqual": functools.partial(to_assert, op=cst.GreaterThanEqual()),
    }

    @staticmethod
    def _split_args(args, arity):
        """Split call arguments into ``(operands, msg)``.

        Returns ``None`` when the call does not have exactly ``arity``
        positional operands followed by an optional message (positional or
        ``msg=`` keyword), or when it uses ``*``/``**`` unpacking.
        """
        if any(arg.star for arg in args):
            return None
        positional = [arg.value for arg in args if arg.keyword is None]
        keywords = {arg.keyword.value: arg.value for arg in args if arg.keyword is not None}
        if not arity <= len(positional) <= arity + 1:
            return None
        if set(keywords) - {"msg"}:
            return None
        if len(positional) == arity + 1:
            if keywords:
                # Message supplied both positionally and by keyword.
                return None
            return positional[:arity], positional[arity]
        return positional, keywords.get("msg")

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        """Replace a recognised assert-method call with an ``assert`` node.

        NOTE(review): the replacement is a statement node returned in place of
        an expression; this presumably only yields valid code when the call is
        itself a whole expression statement -- confirm for nested usages.
        """
        func = updated_node.func
        if not isinstance(func, cst.Attribute):
            return updated_node

        comp = func.attr.value
        if comp in self.unary_comps:
            arity = 1
        elif comp in self.binary_comps:
            arity = 2
        else:
            return updated_node

        split = self._split_args(updated_node.args, arity)
        if split is None:
            # Unsupported call shape: keep the original call instead of crashing.
            return updated_node
        operands, msg = split

        if arity == 1:
            return self.unary_comps[comp](arg=operands[0], msg=msg)
        left, right = operands
        return self.binary_comps[comp](left=left, right=right, msg=msg)
def main():
    """Convert unittest asserts in a file (first CLI argument) or in stdin.

    The transformed source is written to stdout.
    """
    if len(sys.argv) > 1:
        # Python source defaults to UTF-8; do not depend on the locale encoding.
        with open(sys.argv[1], encoding="utf-8") as fd:
            code = fd.read()
    else:
        code = sys.stdin.read()

    tree = cst.parse_module(code)
    new_code = tree.visit(AssertTransformer()).code
    # The generated code already ends the way the input did; print() would
    # append one more newline on every run.
    sys.stdout.write(new_code)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment