Skip to content

Instantly share code, notes, and snippets.

@sjrd
Created June 12, 2023 11:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sjrd/fdb7c413e04db4d86d78eb654a09c5b8 to your computer and use it in GitHub Desktop.
Save sjrd/fdb7c413e04db4d86d78eb654a09c5b8 to your computer and use it in GitHub Desktop.
//> using scala "3.3.0"
//> using dep "ch.epfl.scala::tasty-query:0.8.1"
//> using dep "io.get-coursier:coursier_2.13:2.1.4"
import scala.collection.mutable
import coursier._
import java.nio.file._
import tastyquery.Contexts
import tastyquery.Contexts.*
import tastyquery.Exceptions.*
import tastyquery.Flags
import tastyquery.Flags.Covariant
import tastyquery.Symbols.*
import tastyquery.Trees.*
import tastyquery.Types.{Type, *}
import tastyquery.jdk.ClasspathLoaders
import findmatchtypes.classSymbolOf
object findmatchtypes {
/** Compatibility categories for a single match type case, from most to least compatible.
 *
 *  The declaration order is significant: the analysis keeps the category with the highest
 *  `ordinal` encountered, treating later cases as less compatible.
 */
enum CompatCategory:
  case FullyDefined
  case CatchAll
  case SingleLevelClass
  case CovariantNested
  case TypeMemberExtractor
  case TypeMemberStructuralExtractor
  case HKTypeParamTycon
  case Incompatible
/** Categories whose cases are only counted, not printed in detail. */
val FullyCompatibleCats: Set[CompatCategory] =
  Set(CompatCategory.FullyDefined, CompatCategory.CatchAll, CompatCategory.SingleLevelClass)
/** Number of match type cases seen so far in each category; every category starts at 0. */
val stats: mutable.Map[CompatCategory, Int] =
  mutable.Map.from(CompatCategory.values.iterator.map(category => category -> 0))
/** Looks up compiler-defined "magic" types by fully-qualified name in the current classpath.
 *
 *  Names that cannot be found (lookup throws `MemberNotFoundException`) are silently skipped.
 */
final class Magics()(using Context):
  /** Resolves `fullName` to a static type symbol, or `None` if no such member exists. */
  private def getMagicType(fullName: String): Option[TypeSymbol] =
    try Some(ctx.findStaticType(fullName))
    catch case _: MemberNotFoundException => None

  /** Applies `op` to every name and collects the successful results into a set. */
  private def getMagics[A](op: String => Option[A])(fullNames: String*): Set[A] =
    fullNames.toSet.flatMap(op)

  /** `scala.compiletime.ops` type members that the analysis treats like invariant classes. */
  val compileTimeOpsMatchables: Set[TypeSymbol] =
    getMagics(name => getMagicType(name))(
      "scala.compiletime.ops.int.S",
    )
end Magics
/** Accessor for the contextual [[Magics]] instance. */
def magics(using m: Magics): Magics = summon[Magics]
/** Command-line entry point.
 *
 *  Resolves one module and analyzes the match types in its TASTy. It then prints per-category
 *  counts, followed by a single comma-separated `report,...` line.
 *
 *  @param args exactly one module ID:
 *    - `org::artifact:version` for a Scala 3 module (`_3` is appended to the artifact), or
 *    - `org:artifact:version` (used verbatim).
 *
 *  The JDK `rt.jar` appended to the classpath defaults to a Temurin 8 location. To use another
 *  JDK 8 installation, set either the `findmatchtypes.rtjar` system property or the
 *  `FINDMATCHTYPES_RTJAR` environment variable.
 *
 *  @throws IllegalArgumentException if there is not exactly one argument, or it cannot be parsed
 */
def main(args: Array[String]): Unit = {
  require(args.length == 1, s"expected exactly one module ID argument, got ${args.length}")
  val (organization, artifact, version) = args(0).trim() match
    // Try the `::` form first: the single-colon pattern could also match such an input.
    case s"$org::$artifact:$ver" => (org, artifact + "_3", ver)
    case s"$org:$artifact:$ver"  => (org, artifact, ver)
    case arg                     => throw IllegalArgumentException(s"Cannot parse module ID '$arg'")
  println(s"$organization : $artifact : $version")

  // The JDK's rt.jar is appended so that references to JDK 8 classes can be resolved.
  // Previously hard-coded; now overridable without editing the script.
  val rtJarLocation =
    sys.props.get("findmatchtypes.rtjar")
      .orElse(sys.env.get("FINDMATCHTYPES_RTJAR"))
      .getOrElse("/usr/lib/jvm/temurin-8-jdk-amd64/jre/lib/rt.jar")
  val rtJar = Paths.get(rtJarLocation)
  val classpath = resolve(organization, artifact, version) :+ rtJar
  println(classpath)
  println("")
  process(classpath)
  println("")

  for cat <- CompatCategory.values do
    println(s"$cat: ${stats(cat)}")

  // Machine-readable summary: the module ID, then the count of every category in declaration order.
  val reportFields =
    s"report,$organization:$artifact:$version" +: CompatCategory.values.toList.map(cat => stats(cat).toString)
  println(reportFields.mkString(","))
}
/** Fetches `organization:artifact:version` with coursier and returns the paths of all fetched files.
 *
 *  The returned list includes the dependencies that coursier's `Fetch` resolves.
 */
def resolve(organization: String, artifact: String, version: String): List[Path] = {
  val module = Module(Organization(organization), ModuleName(artifact))
  val fetchedFiles = Fetch().addDependencies(Dependency(module, version)).run()
  fetchedFiles.map(file => file.toPath).toList
}
/** Loads `classpathPaths` and analyzes every top-level class of the first classpath entry.
 *
 *  For each class, its full name is printed before its TASTy tree (if available) is walked
 *  for match types.
 */
def process(classpathPaths: List[Path]): Unit = {
  val loaded = ClasspathLoaders.read(classpathPaths)
  // NOTE(review): only the first entry is analyzed. This assumes coursier lists the requested
  // artifact before its dependencies — confirm.
  val firstEntry = loaded.entries(0)
  val context = Contexts.init(loaded)
  given Context = context
  val magicsInstance = new Magics()
  given Magics = magicsInstance
  for case cls: ClassSymbol <- ctx.findSymbolsByClasspathEntry(firstEntry) do
    println(cls.fullName.toString())
    cls.tree.foreach(classDef => processTree(classDef))
}
/** Runs `walkTree` over `root` and analyzes every `MatchTypeTree` it visits. */
def processTree(root: Tree)(using Context, Magics): Unit = {
  def visit(node: Tree): Unit = node match
    case matchTypeTree: MatchTypeTree => processMatchTypeTree(matchTypeTree)
    case _                            => ()
  root.walkTree(node => visit(node))
}
/** Classifies every case of the match type denoted by `tree` and counts it in `stats`.
 *
 *  Cases whose category is not in `FullyCompatibleCats` are also printed, together with the
 *  declarations of the symbols their pattern refers to.
 */
def processMatchTypeTree(tree: MatchTypeTree)(using Context, Magics): Unit = {
  tree.toType match
    case matchType: MatchType =>
      for caze <- matchType.cases do
        val cat = computeCategory(caze)
        stats(cat) += 1
        if !FullyCompatibleCats.contains(cat) then
          println(s"$cat: ${caze.paramNames.mkString(", ")} -> ${showDetails(caze.pattern)}")
    case other =>
      // Previously an unchecked cast. A non-MatchType here would throw ClassCastException and
      // abort the scan of the whole classpath entry; report it and keep going instead.
      println(s"Skipping match type tree whose type is not a MatchType: $other")
}
/** Worklist of type symbols encountered while rendering a type, in first-encounter order.
 *
 *  Despite the name this is an `ArrayBuffer`, not a set. Insertion order matters, and whoever
 *  appends is expected to check `contains` first.
 */
type SymDetailsSet = scala.collection.mutable.ArrayBuffer[TypeSymbol]

/** Renders `tpe`, followed by one indented line per type symbol it (transitively) refers to.
 *
 *  Rendering a symbol may discover further symbols. These are appended to the worklist and
 *  rendered in turn. Classes owned directly by the `scala` package are not expanded.
 */
def showDetails(tpe: Type)(using Context): String =
  val worklist: SymDetailsSet = mutable.ArrayBuffer.empty[TypeSymbol]
  given SymDetailsSet = worklist
  val rendered = new mutable.StringBuilder(show(tpe))
  var next = 0
  while next < worklist.size do
    val sym = worklist(next)
    val isScalaPackageClass = sym.isClass && sym.owner == defn.scalaPackage
    if !isScalaPackageClass then
      rendered.append("\n ").append(show(sym))
    next += 1
  rendered.toString
end showDetails
/** Renders the declaration header of a type symbol, including bounds.
 *
 *  Handles a class or trait (with its type parameters), a type member, or a type parameter.
 */
def show(sym: TypeSymbol)(using Context, SymDetailsSet): String = sym match
  case cls: ClassSymbol =>
    val keyword = if cls.is(Flags.Trait) then "trait" else "class"
    val renderedParams = cls.typeParams.map(p => s"${variance(p)}${p.name}${show(p.bounds)}")
    val paramClause = if renderedParams.isEmpty then "" else renderedParams.mkString("[", ", ", "]")
    s"$keyword ${cls.name}$paramClause"
  case member: TypeMemberSymbol =>
    s"type ${member.name}${show(member.bounds)}"
  case param: TypeParamSymbol =>
    s"${variance(param)}${param.name}${show(param.bounds)}"
end show
/** The variance marker of a type parameter: `+`, `-`, or the empty string when invariant. */
def variance(sym: TypeParamSymbol)(using Context): String =
  List(Flags.Covariant -> "+", Flags.Contravariant -> "-")
    .collectFirst { case (flag, marker) if sym.is(flag) => marker }
    .getOrElse("")
/** Renders `tpe` as a readable string for the diagnostic output.
 *
 *  Side effect: each `TypeSymbol` reached through a `NamedType` is appended to the contextual
 *  `SymDetailsSet` the first time it is seen. Callers can then print its declaration afterwards.
 */
def show(tpe: Type)(using Context, SymDetailsSet): String = tpe match
case tpe: NamedType =>
// Record the referenced type symbol (only once) so its declaration can be expanded later.
tpe.optSymbol match
case Some(sym: TypeSymbol) if !summon[SymDetailsSet].contains(sym) =>
summon[SymDetailsSet] += sym
case _ =>
()
// Render as `name` when there is no prefix, otherwise as `prefix.name`.
tpe.prefix match
case NoPrefix => tpe.name.toString()
case prefix: Type => show(prefix) + "." + tpe.name
case tpe: PackageRef =>
tpe.fullyQualifiedName.toString()
case tpe: ThisType =>
s"${tpe.cls.name}.this"
case tpe: SuperType =>
"super"
case tpe: ConstantType =>
tpe.value.toString()
case tpe: AppliedType =>
show(tpe.tycon) + tpe.args.map(show(_)).mkString("[", ", ", "]")
case tpe: ByNameType =>
"=> " + show(tpe.resultType)
case tpe: TypeLambda =>
// Rendered as `[X1 <bounds>, ...] =>> result`.
val argStrs = tpe.paramNames.lazyZip(tpe.paramTypeBounds).map { (name, bounds) =>
s"$name${show(bounds)}"
}
argStrs.mkString("[", ", ", "] =>> ") + show(tpe.resultType)
case tpe: TypeParamRef =>
tpe.paramName.toString()
case tpe: TermParamRef =>
tpe.paramName.toString()
case tpe: AnnotatedType =>
s"(${show(tpe.typ)} @${tpe.annotation.tree})"
case tpe: TypeRefinement =>
s"(${show(tpe.parent)} { type ${tpe.refinedName}${show(tpe.refinedBounds)}})"
case tpe: TermRefinement =>
val kw = if tpe.isStable then "val" else "def"
s"(${show(tpe.parent)} { $kw ${tpe.refinedName}: ${show(tpe.refinedType)}})"
// Recursive types are identified by their debugID; RecThis refers back to its binder's debugID.
case tpe: RecType =>
s"{ this#${tpe.debugID} => ${show(tpe.parent)} }"
case tpe: RecThis =>
s"this#${tpe.binders.debugID}"
case tpe: WildcardTypeBounds =>
"_" + show(tpe.bounds)
case tpe: OrType =>
s"(${show(tpe.first)} | ${show(tpe.second)})"
case tpe: AndType =>
s"(${show(tpe.first)} & ${show(tpe.second)})"
// No dedicated rendering for these kinds of types: fall back to their toString.
case _: MatchType | _: SkolemType | _: MethodicType | _: CustomTransientGroundType =>
tpe.toString()
end show
/** Renders type bounds as a declaration suffix.
 *
 *  An alias becomes ` = T`. Otherwise the result is ` >: L` and/or ` <: H`, omitting a lower
 *  bound that is exactly `scala.Nothing` and an upper bound that is exactly `scala.Any`.
 */
def show(bounds: TypeBounds)(using Context, SymDetailsSet): String = bounds match
  case RealTypeBounds(low, high) =>
    // Render the lower bound first: rendering may append to the SymDetailsSet, so order matters.
    val lowerPart = if isExactlyScalaDot(low, defn.NothingClass) then "" else s" >: ${show(low)}"
    val upperPart = if isExactlyScalaDot(high, defn.AnyClass) then "" else s" <: ${show(high)}"
    lowerPart + upperPart
  case TypeAlias(alias) =>
    s" = ${show(alias)}"
end show
/** Whether `tpe` is literally the reference `scala.<cls>`.
 *
 *  That means a `TypeRef` named like `cls`, whose prefix is the `scala` package and which
 *  resolves to `cls`.
 */
private def isExactlyScalaDot(tpe: Type, cls: ClassSymbol)(using Context): Boolean = tpe match
  case ref: TypeRef if ref.name == cls.name =>
    ref.prefix match
      case pkg: PackageRef if pkg.symbol == defn.scalaPackage => ref.optSymbol.contains(cls)
      case _                                                  => false
  case _ =>
    false
end isExactlyScalaDot
/** Assigns a compatibility category to a single match type case.
 *
 *  - Cases that bind no type variables are `FullyDefined`.
 *  - A pattern that is exactly the case's only type variable is `CatchAll`.
 *  - Applied-type patterns are analyzed by `computeCatForCapturing`.
 *  - Anything else is `Incompatible`.
 */
def computeCategory(caze: MatchTypeCase)(using Context, Magics): CompatCategory =
  if caze.paramNames.isEmpty then
    CompatCategory.FullyDefined
  else
    caze.pattern match
      case _: AppliedType =>
        computeCatForCapturing(caze)
      case ref: TypeParamRef if ref.binders == caze && caze.paramNames.sizeIs == 1 =>
        CompatCategory.CatchAll
      case _ =>
        CompatCategory.Incompatible
end computeCategory
/** Classifies a match type case whose pattern is an `AppliedType` binding type variables.
 *
 *  Starts optimistically at `SingleLevelClass` and walks the pattern. Each unsupported construct
 *  downgrades the result through `problem`, and the least compatible category encountered wins.
 *
 *  NOTE(review): `computeCategory` only calls this for cases with at least one type variable
 *  and an applied pattern. Other inputs are untested — confirm before reusing elsewhere.
 */
def computeCatForCapturing(caze: MatchTypeCase)(using Context, Magics): CompatCategory =
var bestCompat: CompatCategory = CompatCategory.SingleLevelClass
// Keeps the worst category seen so far; a higher CompatCategory ordinal is treated as less compatible.
def problem(compat: CompatCategory): Unit =
if compat.ordinal > bestCompat.ordinal then
bestCompat = compat
// Type variables of this case seen so far; each must occur exactly once in the pattern.
val usedTypeParamRefs = mutable.Set.empty[TypeParamRef]
// Parameters of `walk`:
// - level: applied-type nesting depth (the top-level pattern is level 0, its arguments level 1).
// - enclosingCovariant: whether the applied type directly containing `pattern` is in covariant position.
// - covariant: whether `pattern` itself is in covariant position.
def walk(pattern: Type, level: Int, enclosingCovariant: Boolean, covariant: Boolean): Unit = pattern match
case pattern: AppliedType =>
val tyconKind = checkTycon(pattern.tycon)
tyconKind match
case TyconKind.Class(paramsIsCovariant) =>
// An argument stays covariant only if its position and its own type parameter both are.
// NOTE(review): zip drops trailing args if fewer variances than args are reported — confirm this cannot happen.
for (arg, argCovariant) <- pattern.args.zip(paramsIsCovariant) do
walk(arg, level + 1, enclosingCovariant = covariant, covariant && argCovariant)
case TyconKind.TypeMemberExtractor(structural) =>
if structural then
problem(CompatCategory.TypeMemberStructuralExtractor)
else
problem(CompatCategory.TypeMemberExtractor)
// Arguments of a type member extractor are walked as non-covariant positions.
for arg <- pattern.args do
walk(arg, level + 1, enclosingCovariant = covariant, covariant = false)
case TyconKind.HKTypeParam =>
problem(CompatCategory.HKTypeParamTycon)
// Arguments of a higher-kinded type parameter are walked as non-covariant positions.
for arg <- pattern.args do
walk(arg, level + 1, enclosingCovariant = covariant, covariant = false)
case TyconKind.Unknown =>
// Arguments are not walked: the result is already Incompatible.
problem(CompatCategory.Incompatible)
// A type variable bound by this very case (not by some other binder).
case pattern: TypeParamRef if pattern.binders == caze =>
// A type variable used more than once is Incompatible.
if !usedTypeParamRefs.add(pattern) then
problem(CompatCategory.Incompatible)
// Below the top-level applied type's arguments, a variable is only tolerated if its enclosing position is covariant.
else if level > 1 then
if enclosingCovariant then
problem(CompatCategory.CovariantNested)
else
problem(CompatCategory.Incompatible)
// Any other type does not affect the category.
case pattern =>
()
end walk
walk(caze.pattern, level = 0, enclosingCovariant = true, covariant = true)
// Every type variable of the case must appear somewhere in the pattern.
if usedTypeParamRefs != caze.paramRefs.toSet then
problem(CompatCategory.Incompatible)
bestCompat
end computeCatForCapturing
/** Classification of the type constructor of an applied type found in a match type pattern. */
enum TyconKind:
  /** A class, or a type treated like one; one flag per type parameter, `true` if covariant. */
  case Class(paramsIsCovariant: List[Boolean])

  /** A single-parameter type lambda that extracts a type member through a refinement.
   *
   *  `structural` is true when the refined name is not found as a type member of the class.
   */
  case TypeMemberExtractor(structural: Boolean)

  /** A local (higher-kinded) type parameter used as a type constructor. */
  case HKTypeParam

  /** Anything the analysis does not recognize. */
  case Unknown
/** Classifies the type constructor `tycon` of an applied type that appears in a match type pattern.
 *
 *  Type aliases are followed. Type lambdas are only recognized in two forms:
 *  - an eta-expansion `[X1, ..., Xn] =>> F[X1, ..., Xn]`, classified like `F`;
 *  - a single-parameter type-member extractor `[X] =>> C { type T = X }`, where `C` is a class.
 *
 *  Everything else is [[TyconKind.Unknown]].
 */
def checkTycon(tycon: Type)(using Context, Magics): TyconKind = {
  tycon match
    case tycon: TypeRef =>
      tycon.optAliasedType match
        case Some(alias) =>
          checkTycon(alias)
        case None =>
          tycon.optSymbol match
            case Some(cls: ClassSymbol) =>
              // Use the already-matched class symbol instead of re-fetching it with `.get.asClass`.
              TyconKind.Class(cls.typeParams.map(_.is(Covariant)))
            case Some(sym: TypeMemberSymbol) if magics.compileTimeOpsMatchables.contains(sym) =>
              // for the purposes of our analysis, this behaves like an invariant class
              val typeParamCount = sym.upperBound match
                case tpe: TypeLambda => tpe.paramNames.size
                case _               => 0
              TyconKind.Class(List.fill(typeParamCount)(false))
            case Some(_: LocalTypeParamSymbol) =>
              TyconKind.HKTypeParam
            case _ =>
              TyconKind.Unknown
    case tycon: TypeLambda =>
      tycon.resultType match
        case result: AppliedType =>
          // Only a pure eta-expansion is accepted: the args must be exactly the lambda's params, in order.
          if allSame(result.args, tycon.paramRefs) then
            checkTycon(result.tycon)
          else
            TyconKind.Unknown
        case result: TypeRefinement if tycon.paramNames.sizeIs == 1 =>
          result.refinedBounds match
            case TypeAlias(alias) if alias eq tycon.paramRefs.head =>
              classSymbolOf(result.parent) match
                case Some(cls) =>
                  cls.getMember(result.refinedName) match
                    // The refined name is a type member of the class.
                    case Some(_: TypeMemberSymbol) => TyconKind.TypeMemberExtractor(structural = false)
                    // Otherwise the member is only introduced structurally by the refinement.
                    case _ => TyconKind.TypeMemberExtractor(structural = true)
                case None =>
                  TyconKind.Unknown
            case _ =>
              TyconKind.Unknown
        case _ =>
          TyconKind.Unknown
    case _ =>
      TyconKind.Unknown
}
/** Whether `xs` and `ys` have the same length and hold the very same objects (by reference), pairwise. */
def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean =
  xs.corresponds(ys)((left, right) => left eq right)
/** The class that `tpe` refers to after dealiasing, or `None` if it is not a reference to a class. */
def classSymbolOf(tpe: Type)(using Context): Option[ClassSymbol] =
  tpe.dealias match
    case ref: TypeRef if ref.isClass =>
      // Pattern-match the symbol instead of the unchecked `optSymbol.get.asClass`.
      ref.optSymbol.collect { case cls: ClassSymbol => cls }
    case _ =>
      None
end classSymbolOf
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment