Skip to content

Instantly share code, notes, and snippets.

@lpetre
Created January 7, 2022 15:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lpetre/05b729a0891427c2a9def8d1eda3153d to your computer and use it in GitHub Desktop.
Codemod that converts libcst.testing.utils.data_provider to parameterized.expand (and UnitTest base classes to unittest.TestCase)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import ast
from typing import Generator, List, Optional, Sequence, Set, Tuple
import libcst as cst
import libcst.matchers as m
from libcst.codemod import (
CodemodContext,
ContextAwareTransformer,
ContextAwareVisitor,
VisitorBasedCodemodCommand,
)
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
class ConvertDataProviderCommand(VisitorBasedCodemodCommand):
    """Rewrite libcst's in-house test helpers to standard equivalents.

    Two rewrites are performed:

    * A ``UnitTest`` class base is replaced with ``unittest.TestCase``
      (any other bases are kept as-is).
    * A method decorated with ``@data_provider({...})`` is redecorated with
      ``@parameterized.expand((...))``. Each key of the provider dict becomes
      the first argument of its test case, and the method gains a matching
      ``_name: str`` parameter right after ``self``.

    Needed imports are scheduled with ``AddImportsVisitor``; the old
    ``libcst.testing.utils`` names are handed to
    ``RemoveImportsVisitor.remove_unused_import``.
    """

    DESCRIPTION: str = "Converts libcst.testing.utils.data_provider to parameterized.parameterized.expand"

    # Matches a bare ``UnitTest`` base-class argument.
    _UNIT_TEST_BASE = m.Arg(value=m.Name(value="UnitTest"))
    # Matches a ``@data_provider(...)`` decorator.
    _DATA_PROVIDER = m.Decorator(decorator=m.Call(func=m.Name(value="data_provider")))

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Replace a ``UnitTest`` base with ``unittest.TestCase``.

        Classes without a ``UnitTest`` base are returned unchanged.
        """
        if not any(
            m.matches(base, self._UNIT_TEST_BASE) for base in updated_node.bases
        ):
            return updated_node

        AddImportsVisitor.add_needed_import(self.context, "unittest")
        RemoveImportsVisitor.remove_unused_import(
            self.context, "libcst.testing.utils", "UnitTest"
        )
        test_case = cst.Attribute(
            value=cst.Name(value="unittest"),
            attr=cst.Name(value="TestCase"),
        )
        # Only the UnitTest argument is swapped; other bases (mixins, etc.)
        # and the argument's own comma/whitespace are preserved.
        return updated_node.with_changes(
            bases=[
                base.with_changes(value=test_case)
                if m.matches(base, self._UNIT_TEST_BASE)
                else base
                for base in updated_node.bases
            ]
        )

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Convert a ``@data_provider`` method to ``@parameterized.expand``.

        The ``data_provider`` decorator is replaced in place, so any other
        decorators keep their relative order.

        Raises:
            Exception: if the provider is not a single dict literal whose keys
                are string literals and whose values are dict, list or tuple
                literals.
        """
        decorators = list(updated_node.decorators)
        index = next(
            (
                i
                for i, decorator in enumerate(decorators)
                if m.matches(decorator, self._DATA_PROVIDER)
            ),
            None,
        )
        if index is None:
            return updated_node

        provider = cst.ensure_type(decorators[index].decorator, cst.Call)
        # Convert before touching imports so an unsupported provider leaves
        # the import bookkeeping untouched.
        cases, needs_param = self._convert_cases(provider)

        AddImportsVisitor.add_needed_import(
            self.context, "parameterized", "parameterized"
        )
        if needs_param:
            # ``param`` is only referenced by dict-style (keyword) cases.
            AddImportsVisitor.add_needed_import(self.context, "parameterized", "param")
        RemoveImportsVisitor.remove_unused_import(
            self.context, "libcst.testing.utils", "data_provider"
        )

        decorators[index] = decorators[index].with_changes(
            decorator=cst.Call(
                func=cst.Attribute(
                    value=cst.Name(value="parameterized"),
                    attr=cst.Name(value="expand"),
                ),
                args=[cst.Arg(value=cst.Tuple(cases))],
            )
        )
        return updated_node.with_changes(
            params=self._convert_params(updated_node.params),
            decorators=decorators,
        )

    @staticmethod
    def _convert_cases(provider: cst.Call) -> Tuple[List[cst.Element], bool]:
        """Translate a ``data_provider(...)`` call into ``parameterized`` cases.

        Returns the elements of the new ``expand`` tuple, plus a flag telling
        whether any case uses ``param(...)`` (and so needs that import).
        """
        if len(provider.args) != 1:
            raise Exception("Unsupported: data_provider must take exactly one argument")
        data = cst.ensure_type(provider.args[0].value, cst.Dict)

        cases: List[cst.Element] = []
        needs_param = False
        for entry in data.elements:
            # ``**other`` splats have no key and cannot be converted statically.
            item = cst.ensure_type(entry, cst.DictElement)
            name = cst.ensure_type(item.key, cst.SimpleString)
            value = item.value
            if isinstance(value, cst.Dict):
                # Keyword-style case -> param("name", **{...})
                needs_param = True
                case: cst.BaseExpression = cst.Call(
                    func=cst.Name(value="param"),
                    args=[
                        cst.Arg(value=name),
                        cst.Arg(value=value, star="**"),
                    ],
                )
            elif isinstance(value, (cst.List, cst.Tuple)):
                # Positional-style case -> ("name", arg0, arg1, ...)
                case = cst.Tuple([cst.Element(value=name), *value.elements])
            else:
                raise Exception(
                    f"Unsupported data_provider case {name.value}: "
                    "expected a dict, list or tuple literal"
                )
            cases.append(cst.Element(value=case))
        return cases, needs_param

    @staticmethod
    def _convert_params(params: cst.Parameters) -> cst.Parameters:
        """Insert ``_name: str`` after ``self`` and flatten keyword-only params.

        Keyword-only parameters become ordinary positional-or-keyword ones so
        that tuple-style cases, which pass arguments positionally, can supply
        them. ``*args``, ``**kwargs`` and positional-only parameters are kept;
        a bare ``*`` marker is dropped because no keyword-only params remain.
        """
        name_param = cst.Param(
            name=cst.Name(value="_name"),
            annotation=cst.Annotation(annotation=cst.Name(value="str")),
            star="",
        )
        posonly = list(params.posonly_params)
        regular = list(params.params)
        # ``self`` is the first entry of whichever group declares it.
        if posonly:
            posonly.insert(1, name_param)
        else:
            regular.insert(1, name_param)
        regular.extend(params.kwonly_params)

        star_arg = params.star_arg
        if not isinstance(star_arg, cst.Param):
            # Bare ``*`` (ParamStar) or absent: nothing to keep.
            star_arg = cst.MaybeSentinel.DEFAULT
        elif params.kwonly_params:
            # Its comma used to separate it from the keyword-only params that
            # were just moved; let codegen decide whether a comma is needed.
            star_arg = star_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)

        return params.with_changes(
            posonly_params=posonly,
            params=regular,
            star_arg=star_arg,
            kwonly_params=(),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment