Skip to content

Instantly share code, notes, and snippets.

@agucova
Last active June 13, 2023 23:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save agucova/cae477b3e7d913487b9849e0c9560075 to your computer and use it in GitHub Desktop.
Save agucova/cae477b3e7d913487b9849e0c9560075 to your computer and use it in GitHub Desktop.
Automated libCST refactor for squigglepy
import libcst as cst
from libcst.codemod import CodemodTest, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
# Maps the string tag compared against a distribution's ``.type`` attribute
# to the name of the class that the codemod substitutes into ``isinstance``.
DISTS: dict[str, str] = {
    "base": "BaseDistribution",
    "complex": "ComplexDistribution",
    "const": "ConstantDistribution",
    "uniform": "UniformDistribution",
    "norm": "NormalDistribution",
    "lognorm": "LognormalDistribution",
    "binomial": "BinomialDistribution",
    "beta": "BetaDistribution",
    "bernoulli": "BernoulliDistribution",
    "discrete": "DiscreteDistribution",
    "tdist": "TDistribution",
    "log_tdist": "LogTDistribution",
    "triangular": "TriangularDistribution",
    "poisson": "PoissonDistribution",
    "chisquare": "ChiSquareDistribution",
    "exponential": "ExponentialDistribution",
    "gamma": "GammaDistribution",
    "pareto": "ParetoDistribution",
    "mixture": "MixtureDistribution",
}
class DistributionTypeRefactorCommand(VisitorBasedCodemodCommand):
    """Rewrite string comparisons on a ``.type`` attribute into ``isinstance`` checks.

    ``x.type == "beta"`` (or ``is``) becomes ``isinstance(x, BetaDistribution)``;
    ``x.type != "beta"`` (or ``is not``) becomes ``not isinstance(x, BetaDistribution)``.
    The class name is looked up in ``DISTS`` and an import of it from
    ``.distributions`` is scheduled through ``AddImportsVisitor``.
    """

    DESCRIPTION = """
    Refactor expressions of the form 'prior.type == "beta"' to 'isinstance(prior, BetaDistribution)'
    """

    def leave_Comparison(
        self, original_node: cst.Comparison, updated_node: cst.Comparison
    ) -> cst.BaseExpression:
        """Replace a matching ``<expr>.type <op> "<name>"`` comparison.

        Args:
            original_node: The comparison as it appeared before child transforms.
            updated_node: The comparison after child transforms were applied.

        Returns:
            ``updated_node`` unchanged when the comparison does not match the
            pattern; otherwise an ``isinstance(...)`` call, wrapped in ``not``
            when the original operator was ``!=`` or ``is not``.
        """
        # Chained comparisons (a == b == c) are out of scope.
        if len(updated_node.comparisons) != 1:
            return updated_node

        attribute = updated_node.left
        target = updated_node.comparisons[0]

        # The left-hand side must be an attribute access ending in ``.type``.
        if not isinstance(attribute, cst.Attribute):
            return updated_node
        if not isinstance(attribute.attr, cst.Name) or attribute.attr.value != "type":
            return updated_node

        # Equality-style operators map to isinstance; their negations must map
        # to ``not isinstance`` so the rewritten condition keeps its meaning.
        operator = target.operator
        if isinstance(operator, (cst.Equal, cst.Is)):
            negate = False
        elif isinstance(operator, (cst.NotEqual, cst.IsNot)):
            negate = True
        else:
            return updated_node

        # The right-hand side must be a plain string literal.
        comparator = target.comparator
        if not isinstance(comparator, cst.SimpleString):
            return updated_node

        dist_key = comparator.evaluated_value
        if dist_key not in DISTS:
            print(f"Unsupported distribution: {dist_key}")
            return updated_node

        # We have a match: build ``isinstance(<expr>, <DistributionClass>)``.
        dist = DISTS[dist_key]
        isinstance_call = cst.Call(
            func=cst.Name("isinstance"),
            args=[
                cst.Arg(attribute.value),
                cst.Arg(cst.Name(dist)),
            ],
        )

        # Ensure the distribution class is imported in the rewritten module.
        AddImportsVisitor.add_needed_import(self.context, ".distributions", dist)

        if negate:
            # ``not`` binds more loosely than a comparison, so carry over any
            # parentheses that wrapped the original comparison.
            return cst.UnaryOperation(
                operator=cst.Not(),
                expression=isinstance_call,
                lpar=updated_node.lpar,
                rpar=updated_node.rpar,
            )
        return isinstance_call
class TestDistributionTypeRefactorCommand(CodemodTest):
    """Fixture-based checks for ``DistributionTypeRefactorCommand``."""

    TRANSFORM = DistributionTypeRefactorCommand

    def test_noop(self) -> None:
        """Code without a ``.type`` comparison is left as it is."""
        unchanged = """
            from squigglepy import BaseDistribution
            prior = BaseDistribution()
        """
        self.assertCodemod(unchanged, unchanged)

    def test_simple(self) -> None:
        """A bare ``.type == "base"`` expression becomes an ``isinstance`` call."""
        source = """
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            prior.type == "base"
        """
        # NOTE(review): the expected output repeats the BaseDistribution
        # import; presumably the existing relative import is not recognised
        # when the new one is scheduled -- confirm this is intended.
        expected = """
            from .distributions import BaseDistribution
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            isinstance(prior, BaseDistribution)
        """
        self.assertCodemod(source, expected)

    def test_if(self) -> None:
        """The comparison is rewritten when it is the condition of an ``if``."""
        source = """
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            if prior.type == "beta":
                print("hello")
        """
        expected = """
            from .distributions import BaseDistribution
            from .distributions import BetaDistribution
            prior = BaseDistribution()
            if isinstance(prior, BetaDistribution):
                print("hello")
        """
        self.assertCodemod(source, expected)

    def test_auto_import(self) -> None:
        """The distribution class import is added when the module lacks it."""
        source = """
            import squigglepy as sq
            posterior = sq.norm(0, 1)
            if posterior.type == "norm":
                print("hello")
        """
        expected = """
            import squigglepy as sq
            from .distributions import NormalDistribution
            posterior = sq.norm(0, 1)
            if isinstance(posterior, NormalDistribution):
                print("hello")
        """
        self.assertCodemod(source, expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment