Skip to content

Instantly share code, notes, and snippets.

@harupy
Last active November 2, 2022 16:21
Show Gist options
  • Save harupy/7fdf544b78d62720ee67473b1b5a6576 to your computer and use it in GitHub Desktop.
Save harupy/7fdf544b78d62720ee67473b1b5a6576 to your computer and use it in GitHub Desktop.
import libcst as cst
import pathlib
import argparse
import difflib
class AssertMethodTransformer(cst.CSTTransformer):
    """Rewrite ``self.assertXxx(...)`` statements into plain ``assert`` statements.

    Handled methods: ``assertEqual``, ``assertListEqual``, ``assertSequenceEqual``,
    ``assertNotEqual``, ``assertIn``, ``assertLessEqual``, ``assertIsNone``,
    ``assertIsNotNone``, ``assertTrue``, ``assertIsInstance``,
    ``assertCountEqual`` and ``assertAlmostEqual``.

    A call is rewritten only when it is the whole expression statement (an
    ``assert`` cannot appear inside another expression such as a lambda body
    or the right-hand side of an assignment) and when its arguments have the
    expected plain positional shape, optionally followed by a message.
    Every other call is returned untouched.
    """

    # Methods that translate to ``assert <first> <op> <second>``.
    _BINARY_OPERATORS = {
        "assertEqual": cst.Equal,
        "assertListEqual": cst.Equal,
        "assertSequenceEqual": cst.Equal,
        "assertNotEqual": cst.NotEqual,
        "assertIn": cst.In,
        "assertLessEqual": cst.LessThanEqual,
    }
    # Methods that translate to ``assert <arg> <op> None``.
    _NONE_OPERATORS = {
        "assertIsNone": cst.Is,
        "assertIsNotNone": cst.IsNot,
    }
    # Expressions that are only valid as a bare call argument and therefore
    # always need parentheses once they are moved out of the call.
    _ALWAYS_WRAP = (cst.GeneratorExp, cst.NamedExpr)
    # Expressions that bind more loosely than (or chain with) a comparison and
    # would change meaning if used as an unparenthesized comparison operand.
    _WRAP_AS_OPERAND = _ALWAYS_WRAP + (
        cst.IfExp,
        cst.Lambda,
        cst.BooleanOperation,
        cst.Comparison,
    )

    def __init__(self) -> None:
        super().__init__()
        # Value node of the expression statement currently being visited.
        self._statement_value = None

    @staticmethod
    def _assert_method_name(node: cst.Call):
        """Return ``name`` when ``node`` is ``self.<name>(...)``, else ``None``."""
        func = node.func
        if (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "self"
            and isinstance(func.attr, cst.Name)
        ):
            return func.attr.value
        return None

    @staticmethod
    def is_unittest_assert_method(node: cst.Call, name: str) -> bool:
        """Return True when ``node`` is a call of the form ``self.<name>(...)``."""
        found = AssertMethodTransformer._assert_method_name(node)
        return found is not None and found == name

    @staticmethod
    def _parenthesized(expr: cst.BaseExpression) -> cst.BaseExpression:
        """Return ``expr`` wrapped in parentheses unless it already has some."""
        if expr.lpar:
            return expr
        return expr.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])

    @classmethod
    def _standalone(cls, expr: cst.BaseExpression) -> cst.BaseExpression:
        """Make ``expr`` usable as a whole ``assert`` test or message."""
        if isinstance(expr, cls._ALWAYS_WRAP):
            return cls._parenthesized(expr)
        return expr

    @classmethod
    def _operand(cls, expr: cst.BaseExpression) -> cst.BaseExpression:
        """Make ``expr`` usable as one side of a comparison."""
        binds_loosely = isinstance(expr, cls._WRAP_AS_OPERAND) or (
            # ``not a == b`` parses as ``not (a == b)``.
            isinstance(expr, cst.UnaryOperation)
            and isinstance(expr.operator, cst.Not)
        )
        return cls._parenthesized(expr) if binds_loosely else expr

    @classmethod
    def _compare(
        cls,
        left: cst.BaseExpression,
        operator: cst.BaseCompOp,
        right: cst.BaseExpression,
    ) -> cst.Comparison:
        """Build the expression ``<left> <operator> <right>``."""
        return cst.Comparison(
            left=cls._operand(left),
            comparisons=[
                cst.ComparisonTarget(
                    operator=operator,
                    comparator=cls._operand(right),
                )
            ],
        )

    @staticmethod
    def _call(func: cst.BaseExpression, *values: cst.BaseExpression) -> cst.Call:
        """Build ``func(*values)`` from fresh ``Arg`` nodes.

        Fresh nodes are used so that commas and whitespace attached to the
        original arguments are not carried over into the new call.
        """
        return cst.Call(func=func, args=[cst.Arg(value=value) for value in values])

    @staticmethod
    def _split_args(call: cst.Call, operand_count: int):
        """Split the arguments of ``call`` into ``(operands, msg)``.

        ``operands`` is the list of the first ``operand_count`` argument values
        and ``msg`` is the value of the optional message argument (``None``
        when absent). Returns ``None`` when the call does not have that shape:
        wrong arity, ``*``/``**`` unpacking, keyword operands, or a trailing
        keyword argument other than ``msg``.
        """
        args = call.args
        if len(args) not in (operand_count, operand_count + 1):
            return None
        if any(arg.star for arg in args):
            return None
        operands = args[:operand_count]
        if any(arg.keyword is not None for arg in operands):
            return None
        msg = None
        if len(args) > operand_count:
            extra = args[operand_count]
            if extra.keyword is not None and extra.keyword.value != "msg":
                return None
            msg = extra.value
        return [arg.value for arg in operands], msg

    @staticmethod
    def _order_cannot_matter(expr: cst.BaseExpression) -> bool:
        """Return True for a set literal or a list/tuple literal with < 2 elements."""
        if isinstance(expr, cst.Set):
            return True
        return isinstance(expr, (cst.List, cst.Tuple)) and len(expr.elements) < 2

    def _convert(self, name: str, call: cst.Call):
        """Return ``(test, msg)`` for the ``assert`` replacing ``call``.

        Returns ``None`` when ``name`` is not a handled method or the arguments
        of ``call`` do not have a supported shape.
        """
        if name in self._BINARY_OPERATORS:
            split = self._split_args(call, 2)
            if split is None:
                return None
            (left, right), msg = split
            return self._compare(left, self._BINARY_OPERATORS[name](), right), msg

        if name in self._NONE_OPERATORS:
            split = self._split_args(call, 1)
            if split is None:
                return None
            (value,), msg = split
            return self._compare(value, self._NONE_OPERATORS[name](), cst.Name("None")), msg

        if name == "assertTrue":
            split = self._split_args(call, 1)
            if split is None:
                return None
            (value,), msg = split
            return self._standalone(value), msg

        if name == "assertIsInstance":
            split = self._split_args(call, 2)
            if split is None:
                return None
            (obj, cls), msg = split
            return self._call(cst.Name("isinstance"), obj, cls), msg

        if name == "assertCountEqual":
            split = self._split_args(call, 2)
            if split is None:
                return None
            (left, right), msg = split
            # `self.assertCountEqual({1, 2, 3}, x)` can be converted to
            # `assert {1, 2, 3} == x`, and `self.assertCountEqual([], x)` or
            # `self.assertCountEqual([1], x)` to `assert [] == x` or
            # `assert [1] == x`.
            if self._order_cannot_matter(left) or self._order_cannot_matter(right):
                return self._compare(left, cst.Equal(), right), msg
            # `self.assertCountEqual([1, 2], x)` can't be converted to
            # `assert [1, 2] == x` because x can be [2, 1]. We can convert it
            # to either `assert sorted([1, 2]) == sorted(x)` or
            # `assert {1, 2} == set(x)`. We choose the former because the
            # latter allows x to contain duplicates. For example, if x is the
            # return value of `search_experiments`, it should not contain
            # duplicates.
            return (
                self._compare(
                    self._call(cst.Name("sorted"), left),
                    cst.Equal(),
                    self._call(cst.Name("sorted"), right),
                ),
                msg,
            )

        if name == "assertAlmostEqual":
            args = call.args
            if len(args) < 2:
                return None
            if any(arg.star or arg.keyword is not None for arg in args[:2]):
                return None
            first, second = args[0].value, args[1].value
            # Best effort: only the two compared values (and a ``msg=`` keyword,
            # if any) are carried over; other arguments such as ``places`` or
            # ``delta`` are dropped. The generated code references ``pytest``
            # but no import is added to the rewritten file.
            msg = next(
                (
                    arg.value
                    for arg in args[2:]
                    if arg.keyword is not None and arg.keyword.value == "msg"
                ),
                None,
            )
            approx = self._call(
                cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("approx")),
                second,
            )
            return self._compare(first, cst.Equal(), approx), msg

        return None

    def visit_Expr(self, node: cst.Expr) -> None:
        """Remember the value of the expression statement being entered."""
        self._statement_value = node.value

    def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.CSTNode:
        """Replace the statement by the ``assert`` that ``leave_Call`` produced, if any."""
        self._statement_value = None
        replacement = updated_node.value
        if isinstance(replacement, cst.Assert):
            # Swap the whole statement (not just its value) and keep any
            # trailing semicolon that belonged to the original statement.
            return replacement.with_changes(semicolon=updated_node.semicolon)
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        """Convert a supported ``self.assertXxx(...)`` call into ``cst.Assert``.

        Returns ``updated_node`` unchanged when the call is not the value of
        the expression statement being visited, is not a handled method, or
        has an unsupported argument shape.
        """
        if original_node is not self._statement_value:
            return updated_node
        name = self._assert_method_name(updated_node)
        if name is None:
            return updated_node
        # Build from ``updated_node`` so rewrites already applied to nested
        # nodes inside the arguments are kept.
        converted = self._convert(name, updated_node)
        if converted is None:
            return updated_node
        test, msg = converted
        if msg is not None:
            msg = self._standalone(msg)
        return cst.Assert(test=test, msg=msg)
def transform_file(path: pathlib.Path) -> None:
    """Rewrite the unittest-style assertions in ``path`` in place.

    The file is parsed with libcst and run through ``AssertMethodTransformer``.
    When the resulting tree differs from the parsed one, a unified diff is
    printed, the file is overwritten with the new code and a confirmation line
    is printed. When nothing changed the file is not touched.

    Args:
        path: Python source file to transform.
    """
    # Read and write bytes rather than text so the result does not depend on
    # the platform's default encoding or newline translation; libcst works out
    # the encoding from the bytes and ``Module.bytes`` encodes the same way.
    source_tree = cst.parse_module(path.read_bytes())
    modified_tree = source_tree.visit(AssertMethodTransformer())
    if modified_tree.deep_equals(source_tree):
        return

    diff = difflib.unified_diff(
        source_tree.code.splitlines(keepends=True),
        modified_tree.code.splitlines(keepends=True),
    )
    print("".join(diff))
    path.write_bytes(modified_tree.bytes)
    print(f"Transformed {path}")
if __name__ == "__main__":
    # Usage: python a.py $(git ls-files "tests/**/*.py")
    parser = argparse.ArgumentParser(
        description="Rewrite unittest-style self.assert* calls into plain assert statements."
    )
    parser.add_argument(
        "files",
        metavar="FILE",
        nargs="+",
        help="Python source file to rewrite in place",
    )
    args = parser.parse_args()
    # Each file is processed independently and overwritten only if it changes.
    for f in args.files:
        print("Processing", f)
        transform_file(pathlib.Path(f))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment