Created
January 7, 2022 15:38
-
-
Save lpetre/05b729a0891427c2a9def8d1eda3153d to your computer and use it in GitHub Desktop.
Codemod for converting libcst.testing.utils.data_provider to parameterized.expand
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
import argparse | |
import ast | |
from typing import Generator, List, Optional, Sequence, Set, Tuple | |
import libcst as cst | |
import libcst.matchers as m | |
from libcst.codemod import ( | |
CodemodContext, | |
ContextAwareTransformer, | |
ContextAwareVisitor, | |
VisitorBasedCodemodCommand, | |
) | |
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | |
class ConvertDataProviderCommand(VisitorBasedCodemodCommand):
    """Codemod replacing ``libcst.testing.utils`` test helpers.

    * A class whose only base is ``UnitTest`` is rebased onto
      ``unittest.TestCase``.
    * A method whose only decorator is ``@data_provider({...})`` is rewritten
      to ``@parameterized.expand((...))``. Each entry ``"name": {...}``
      becomes ``param("name", **{...})`` and each ``"name": [...]`` or
      ``"name": (...)`` becomes the tuple ``("name", ...)``. Because the name
      is now the first value of every parameter set, a ``_name: str``
      parameter is inserted right after ``self`` to receive it.

    Needed imports are scheduled through ``AddImportsVisitor`` and the
    replaced ``libcst.testing.utils`` names are scheduled for removal (when
    unused) through ``RemoveImportsVisitor``.
    """

    DESCRIPTION: str = "Converts libcst.testing.utils.data_provider to parameterized.parameterized.expand"

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Rebase ``class X(UnitTest)`` onto ``unittest.TestCase``.

        Only classes whose base list is exactly ``[UnitTest]`` are touched;
        every other class is returned unchanged.
        """
        if not m.matches(
            updated_node, m.ClassDef(bases=[m.Arg(value=m.Name(value="UnitTest"))])
        ):
            return updated_node
        AddImportsVisitor.add_needed_import(self.context, "unittest")
        RemoveImportsVisitor.remove_unused_import(
            self.context, "libcst.testing.utils", "UnitTest"
        )
        return updated_node.with_changes(
            bases=[
                cst.Arg(
                    value=cst.Attribute(
                        value=cst.Name(value="unittest"),
                        attr=cst.Name(value="TestCase"),
                    )
                )
            ]
        )

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Rewrite a ``@data_provider({...})`` method to ``@parameterized.expand``.

        Methods not decorated solely with ``@data_provider(...)`` are returned
        unchanged.

        Raises:
            ValueError: if ``data_provider`` is not called with exactly one
                argument, or the data dict holds an entry this codemod cannot
                translate (``**`` unpacking, or a value that is not a dict,
                list or tuple literal).
            Errors from ``cst.ensure_type`` propagate when the argument is not
            a dict literal or a key is not a plain string literal.
        """
        if not m.matches(
            updated_node,
            m.FunctionDef(
                decorators=[
                    m.Decorator(decorator=m.Call(m.Name(value="data_provider")))
                ]
            ),
        ):
            return updated_node

        decorator = cst.ensure_type(updated_node.decorators[0].decorator, cst.Call)
        # Any extra argument has no counterpart in the generated
        # ``parameterized.expand`` call, so refuse rather than drop it silently.
        # (An explicit check instead of ``assert``, which ``-O`` would strip.)
        if len(decorator.args) != 1:
            raise ValueError(
                "data_provider must be called with exactly one argument, got "
                f"{len(decorator.args)}"
            )
        data = cst.ensure_type(decorator.args[0].value, cst.Dict)
        new_arg_elems = [self._convert_data_entry(elem) for elem in data.elements]
        # Rewrite the signature before scheduling imports so that a failure
        # here leaves no import changes behind.
        new_params = self._add_name_param(updated_node.params)

        AddImportsVisitor.add_needed_import(
            self.context, "parameterized", "parameterized"
        )
        # ``param`` is only referenced by keyword-style (dict) entries.
        if any(isinstance(elem.value, cst.Call) for elem in new_arg_elems):
            AddImportsVisitor.add_needed_import(
                self.context, "parameterized", "param"
            )
        RemoveImportsVisitor.remove_unused_import(
            self.context, "libcst.testing.utils", "data_provider"
        )
        return updated_node.with_changes(
            params=new_params,
            decorators=[
                cst.Decorator(
                    decorator=cst.Call(
                        func=cst.Attribute(
                            value=cst.Name(value="parameterized"),
                            attr=cst.Name(value="expand"),
                        ),
                        args=[cst.Arg(value=cst.Tuple(new_arg_elems))],
                    )
                )
            ],
        )

    @staticmethod
    def _convert_data_entry(elem: cst.BaseDictElement) -> cst.Element:
        """Translate one ``"name": value`` data entry into a parameter set.

        ``"name": {...}`` becomes ``param("name", **{...})``; ``"name": [...]``
        and ``"name": (...)`` become the tuple ``("name", ...)``.
        """
        if not isinstance(elem, cst.DictElement):
            raise ValueError("Unsupported: ** unpacking in data_provider data")
        name = cst.ensure_type(elem.key, cst.SimpleString)
        value = elem.value
        if isinstance(value, cst.Dict):
            # Keyword-style entry: the dict is splatted into ``param``.
            return cst.Element(
                value=cst.Call(
                    func=cst.Name(value="param"),
                    args=[
                        cst.Arg(value=name),
                        cst.Arg(value=value, star="**"),
                    ],
                )
            )
        if isinstance(value, (cst.List, cst.Tuple)):
            # Positional-style entry: prepend the name to the values.
            return cst.Element(
                value=cst.Tuple([cst.Element(value=name), *value.elements])
            )
        raise ValueError(f"Unsupported data_provider value: {type(value).__name__}")

    @staticmethod
    def _add_name_param(params: cst.Parameters) -> cst.Parameters:
        """Insert ``_name: str`` after ``self`` and flatten the signature.

        Positional-only and keyword-only parameters become ordinary
        positional-or-keyword parameters, so tuple-style parameter sets can
        supply them positionally. A real ``*args`` and any ``**kwargs`` are
        preserved. A bare ``*`` separator is dropped because no keyword-only
        parameters remain after it.
        """
        flat = [*params.posonly_params, *params.params]
        flat.insert(
            1,
            cst.Param(
                name=cst.Name(value="_name"),
                annotation=cst.Annotation(annotation=cst.Name(value="str")),
                star="",
            ),
        )
        flat.extend(params.kwonly_params)
        star_arg = (
            params.star_arg
            if isinstance(params.star_arg, cst.Param)
            else cst.MaybeSentinel.DEFAULT
        )
        return params.with_changes(
            posonly_params=(),
            posonly_ind=cst.MaybeSentinel.DEFAULT,
            params=flat,
            star_arg=star_arg,
            kwonly_params=(),
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment