Skip to content

Instantly share code, notes, and snippets.

@harupy
Created November 3, 2022 23:22
Show Gist options
  • Save harupy/e7656c4c2b8bb5018c7674ddac73a312 to your computer and use it in GitHub Desktop.
Save harupy/e7656c4c2b8bb5018c7674ddac73a312 to your computer and use it in GitHub Desktop.
import argparse
import difflib
import pathlib
import sys

import libcst as cst
class AssertMethodTransformer(cst.CSTTransformer):
    """libcst transformer that moves a literal to the right-hand side of ``==``.

    ``assert 1 == x`` is rewritten as ``assert x == 1``. Only asserts whose
    test is a single ``==`` comparison with a literal on the left are touched:
    strings, (optionally signed) numbers, list/set/tuple/dict displays,
    ``None``, ``True`` and ``False``. Every other assert is returned unchanged.
    """

    # Node types treated as literals when they appear on the left of ``==``.
    _LITERAL_TYPES = (
        cst.SimpleString,
        cst.Integer,
        cst.Float,
        cst.Imaginary,
        cst.List,
        cst.Set,
        cst.Tuple,
        cst.Dict,
    )
    # Builtin constants, which libcst represents as plain ``Name`` nodes.
    _LITERAL_NAMES = frozenset({"None", "True", "False"})
    # Numeric literals that may appear under a unary sign (``-1``, ``+2.5``).
    _NUMERIC_TYPES = (cst.Integer, cst.Float, cst.Imaginary)

    @staticmethod
    def is_unittest_assert_method(node: cst.Call, name: str) -> bool:
        """Return True if *node* is a call of the form ``self.<name>(...)``.

        Not called by anything else in this class.
        """
        func = node.func
        return (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "self"
            and isinstance(func.attr, cst.Name)
            and func.attr.value == name
        )

    @classmethod
    def _is_literal(cls, node: cst.BaseExpression) -> bool:
        """Return True if *node* is a literal that should move to the right of ``==``."""
        if isinstance(node, cls._LITERAL_TYPES):
            return True
        if isinstance(node, cst.Name):
            return node.value in cls._LITERAL_NAMES
        # Signed numbers such as ``-1`` parse as a UnaryOperation, not a literal node.
        return (
            isinstance(node, cst.UnaryOperation)
            and isinstance(node.operator, (cst.Minus, cst.Plus))
            and isinstance(node.expression, cls._NUMERIC_TYPES)
        )

    def leave_Assert(self, original_node: cst.Assert, updated_node: cst.Assert) -> cst.Assert:
        """Swap the operands of ``assert <literal> == <expr>``.

        Args:
            original_node: The assert as it appeared in the source (unused).
            updated_node: The assert after its children were visited.

        Returns:
            ``updated_node`` with the comparison operands swapped, or
            ``updated_node`` unchanged if the pattern does not match.

        The rewrite uses ``with_changes`` rather than building fresh nodes, so
        these survive: the comparison's parentheses, the whitespace after
        ``assert`` and around ``==``, the message comma, any semicolon, and
        the message itself.

        NOTE: swapping the operands also swaps their evaluation order. This
        only matters when a container literal on the left contains calls
        with side effects.
        """
        test = updated_node.test
        if not isinstance(test, cst.Comparison) or len(test.comparisons) != 1:
            return updated_node
        (target,) = test.comparisons
        if not isinstance(target.operator, cst.Equal) or not self._is_literal(test.left):
            return updated_node
        # Keep test.lpar/rpar: dropping them can turn a parenthesised
        # multi-line assert such as ``assert (1 == a +\n b)`` into invalid syntax.
        swapped = test.with_changes(
            left=target.comparator,
            comparisons=[target.with_changes(comparator=test.left)],
        )
        return updated_node.with_changes(test=swapped)
def transform_file(path: pathlib.Path) -> None:
    """Apply ``AssertMethodTransformer`` to the Python file at *path*, in place.

    When the rewrite changes the file, a unified diff and a
    ``Transformed <path>`` line are printed and the file is overwritten.
    Otherwise the file is left untouched.

    The file is read and written as bytes. This lets libcst honour the
    source's encoding declaration and keeps the original line endings.
    ``read_text``/``write_text`` would use the locale's default encoding and
    normalise ``\\r\\n`` line endings.

    Raises:
        libcst.ParserSyntaxError: If the file is not valid Python.
    """
    source_tree = cst.parse_module(path.read_bytes())
    modified_tree = source_tree.visit(AssertMethodTransformer())
    old_code = source_tree.code
    new_code = modified_tree.code
    if new_code == old_code:
        return
    diff = difflib.unified_diff(
        old_code.splitlines(keepends=True),
        new_code.splitlines(keepends=True),
        fromfile=str(path),
        tofile=str(path),
    )
    print("".join(diff))
    # Module.bytes re-encodes with the encoding detected during parsing.
    path.write_bytes(modified_tree.bytes)
    print(f"Transformed {path}")
if __name__ == "__main__":
    # Usage: python a.py $(git ls-files "tests/**/*.py")
    parser = argparse.ArgumentParser(
        description="Rewrite `assert <literal> == <expr>` as `assert <expr> == <literal>` in place."
    )
    parser.add_argument(
        "files",
        metavar="FILE",
        nargs="+",
        type=pathlib.Path,
        help="Python source file to rewrite",
    )
    args = parser.parse_args()
    failed = False
    for f in args.files:
        print("Processing", f)
        try:
            transform_file(f)
        except cst.ParserSyntaxError as exc:
            # Report and continue so one unparsable file does not abort the batch.
            print(f"Skipping {f}: {exc}", file=sys.stderr)
            failed = True
    if failed:
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment