-
-
Save sjrd/fdb7c413e04db4d86d78eb654a09c5b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//> using scala "3.3.0" | |
//> using dep "ch.epfl.scala::tasty-query:0.8.1" | |
//> using dep "io.get-coursier:coursier_2.13:2.1.4" | |
import scala.collection.mutable | |
import coursier._ | |
import java.nio.file._ | |
import tastyquery.Contexts | |
import tastyquery.Contexts.* | |
import tastyquery.Exceptions.* | |
import tastyquery.Flags | |
import tastyquery.Flags.Covariant | |
import tastyquery.Symbols.* | |
import tastyquery.Trees.* | |
import tastyquery.Types.{Type, *} | |
import tastyquery.jdk.ClasspathLoaders | |
import findmatchtypes.classSymbolOf | |
object findmatchtypes { | |
/** Classification assigned to each match-type case by `computeCategory`.
  *
  * The declaration order is significant:
  *  - `computeCatForCapturing` compares `ordinal`s and keeps the highest one
  *    it encounters, so later cases win over earlier ones;
  *  - `main` prints the counters in `values` order.
  */
enum CompatCategory:
  case FullyDefined
  case CatchAll
  case SingleLevelClass
  case CovariantNested
  case TypeMemberExtractor
  case TypeMemberStructuralExtractor
  case HKTypeParamTycon
  case Incompatible
/** Categories whose cases are counted but not printed individually by
  * `processMatchTypeTree`.
  */
val FullyCompatibleCats: Set[CompatCategory] =
  Set(
    CompatCategory.FullyDefined,
    CompatCategory.CatchAll,
    CompatCategory.SingleLevelClass
  )
/** Per-category counters, pre-populated with 0 for every `CompatCategory`. */
val stats: mutable.Map[CompatCategory, Int] =
  mutable.Map.from(CompatCategory.values.iterator.map(cat => cat -> 0))
/** Symbols that get special treatment during classification, resolved once
  * against the given `Context` when the instance is created.
  */
final class Magics()(using Context):
  /** Looks up a static type by fully-qualified name, mapping
    * `MemberNotFoundException` to `None`.
    */
  private def getMagicType(fullName: String): Option[TypeSymbol] =
    try
      val found = ctx.findStaticType(fullName)
      Some(found)
    catch
      case _: MemberNotFoundException => None

  /** Applies `op` to every name and keeps the distinct successful results. */
  private def getMagics[A](op: String => Option[A])(fullNames: String*): Set[A] =
    fullNames.flatMap(op).toSet

  /** Type members that `checkTycon` treats like invariant classes. */
  val compileTimeOpsMatchables: Set[TypeSymbol] =
    getMagics(name => getMagicType(name))(
      "scala.compiletime.ops.int.S"
    )
end Magics
/** Accessor for the `Magics` instance available in the current given scope. */
def magics(using m: Magics): Magics = summon[Magics]
/** Entry point: resolves one module, scans it for match types and prints
  * per-category statistics.
  *
  * The single argument is a module ID, either `org::artifact:version`
  * (the `_3` suffix is appended to the artifact name) or
  * `org:artifact:version`.
  *
  * The JDK class library is appended to the classpath as an `rt.jar`. Its
  * location defaults to the Temurin 8 path used originally and can be
  * overridden with the `findmatchtypes.rtjar` system property.
  *
  * @throws IllegalArgumentException
  *   if the argument count is not exactly one or the module ID cannot be
  *   parsed.
  */
def main(args: Array[String]): Unit = {
  require(
    args.length == 1,
    s"Expected exactly one module ID argument (org::artifact:version or org:artifact:version), got ${args.length}"
  )

  val (organization, artifact, version) = args(0).trim() match
    case s"$org::$artifact:$ver" => (org, artifact + "_3", ver)
    case s"$org:$artifact:$ver" => (org, artifact, ver)
    case arg => throw IllegalArgumentException(s"Cannot parse module ID '$arg'")
  println(s"$organization : $artifact : $version")

  // Same default as before; the property only exists so the tool can run on
  // machines where the JDK 8 runtime jar lives elsewhere.
  val rtJarLocation =
    sys.props.getOrElse("findmatchtypes.rtjar", "/usr/lib/jvm/temurin-8-jdk-amd64/jre/lib/rt.jar")
  val rtJar = Paths.get(rtJarLocation)

  val classpath = resolve(organization, artifact, version) :+ rtJar
  println(classpath)
  println("")
  process(classpath)
  println("")

  for cat <- CompatCategory.values do
    println(s"$cat: ${stats(cat)}")

  // Machine-readable summary: "report,<module>,<count per category in declaration order>".
  val counts = CompatCategory.values.toSeq.map(cat => stats(cat).toString)
  println((Seq("report", s"$organization:$artifact:$version") ++ counts).mkString(","))
}
/** Fetches the given module through coursier and returns the paths of the
  * files produced by the fetch.
  */
def resolve(organization: String, artifact: String, version: String): List[Path] = {
  val module = Module(Organization(organization), ModuleName(artifact))
  val fetchedFiles = Fetch().addDependencies(Dependency(module, version)).run()
  fetchedFiles.iterator.map(_.toPath).toList
}
/** Loads the classpath and processes the tree of every top-level class
  * symbol found in its first entry.
  *
  * Only `classpath.entries(0)` is scanned; the remaining entries are loaded
  * so that references can be resolved.
  * NOTE(review): presumably entry 0 is the requested artifact itself — this
  * relies on the ordering of the paths passed in; confirm against the caller.
  */
def process(classpathPaths: List[Path]): Unit = {
  val loaded = ClasspathLoaders.read(classpathPaths)
  val scannedEntry = loaded.entries(0)

  val context = Contexts.init(loaded)
  given Context = context
  val magics = new Magics()
  given Magics = magics

  ctx.findSymbolsByClasspathEntry(scannedEntry).foreach {
    case cls: ClassSymbol =>
      println(cls.fullName.toString())
      cls.tree.foreach(tree => processTree(tree))
    case _ =>
      ()
  }
}
/** Walks `root` and hands every `MatchTypeTree` encountered to
  * `processMatchTypeTree`; all other nodes are ignored.
  */
def processTree(root: Tree)(using Context, Magics): Unit = {
  root.walkTree { node =>
    node match
      case matchTypeTree: MatchTypeTree => processMatchTypeTree(matchTypeTree)
      case _                            => ()
  }
}
/** Classifies every case of the match type denoted by `tree`, bumps the
  * matching counter in `stats`, and prints the cases whose category is not
  * in `FullyCompatibleCats`.
  */
def processMatchTypeTree(tree: MatchTypeTree)(using Context, Magics): Unit = {
  val matchType = tree.toType.asInstanceOf[MatchType]
  matchType.cases.foreach { caze =>
    val cat = computeCategory(caze)
    stats(cat) += 1
    val fullyCompatible = FullyCompatibleCats.contains(cat)
    if !fullyCompatible then
      println(s"$cat: ${caze.paramNames.mkString(", ")} -> ${showDetails(caze.pattern)}")
  }
}
/** Ordered, duplicate-free collection of the type symbols met while rendering
  * a type. It is passed as a given to the `show` overloads, which append to
  * it, and is drained by `showDetails`.
  */
type SymDetailsSet = mutable.ArrayBuffer[TypeSymbol]
/** Renders `tpe`, followed by one extra line per type symbol referenced
  * while rendering (transitively), except classes owned directly by the
  * `scala` package.
  */
def showDetails(tpe: Type)(using Context): String =
  val pending: SymDetailsSet = mutable.ArrayBuffer.empty[TypeSymbol]
  given SymDetailsSet = pending

  val out = new StringBuilder(show(tpe))

  // Index-based loop on purpose: rendering a symbol may append further
  // symbols to `pending`, and those must be rendered as well.
  var next = 0
  while next < pending.size do
    val sym = pending(next)
    val isScalaPackageClass = sym.isClass && sym.owner == defn.scalaPackage
    if !isScalaPackageClass then
      out.append("\n ").append(show(sym))
    next += 1

  out.toString
end showDetails
/** Renders a one-line declaration for a type symbol: a class or trait with
  * its type parameters, a type member with its bounds, or a type parameter
  * with its variance and bounds.
  */
def show(sym: TypeSymbol)(using Context, SymDetailsSet): String =
  // Shared rendering of a type parameter: variance, name, bounds.
  def renderParam(param: TypeParamSymbol): String =
    s"${variance(param)}${param.name}${show(param.bounds)}"

  sym match
    case cls: ClassSymbol =>
      val keyword = if cls.is(Flags.Trait) then "trait" else "class"
      val params = cls.typeParams
      val paramClause =
        if params.isEmpty then ""
        else params.map(renderParam).mkString("[", ", ", "]")
      s"$keyword ${cls.name}$paramClause"
    case member: TypeMemberSymbol =>
      s"type ${member.name}${show(member.bounds)}"
    case param: TypeParamSymbol =>
      renderParam(param)
end show
/** Variance prefix of a type parameter: `"+"`, `"-"`, or `""` when neither
  * variance flag is set.
  */
def variance(sym: TypeParamSymbol)(using Context): String =
  if sym.is(Flags.Covariant) then
    "+"
  else if sym.is(Flags.Contravariant) then
    "-"
  else
    ""
/** Renders a type as text.
  *
  * As a side effect, the symbol of every named type visited is appended to
  * the given `SymDetailsSet` the first time it is seen.
  */
def show(tpe: Type)(using Context, SymDetailsSet): String =
  val seen = summon[SymDetailsSet]
  tpe match
    case named: NamedType =>
      // Record the symbol before rendering the prefix, so that symbols are
      // listed in outermost-first order.
      named.optSymbol match
        case Some(sym: TypeSymbol) if !seen.contains(sym) =>
          seen += sym
        case _ =>
          ()
      named.prefix match
        case NoPrefix        => named.name.toString()
        case qualifier: Type => s"${show(qualifier)}.${named.name}"

    case pkg: PackageRef =>
      pkg.fullyQualifiedName.toString()

    case thisType: ThisType =>
      s"${thisType.cls.name}.this"

    case _: SuperType =>
      "super"

    case constant: ConstantType =>
      constant.value.toString()

    case applied: AppliedType =>
      val tyconStr = show(applied.tycon)
      val argsStr = applied.args.map(arg => show(arg)).mkString("[", ", ", "]")
      tyconStr + argsStr

    case byName: ByNameType =>
      s"=> ${show(byName.resultType)}"

    case lambda: TypeLambda =>
      val params = lambda.paramNames.zip(lambda.paramTypeBounds).map { (name, bounds) =>
        s"$name${show(bounds)}"
      }
      val header = params.mkString("[", ", ", "] =>> ")
      header + show(lambda.resultType)

    case ref: TypeParamRef =>
      ref.paramName.toString()

    case ref: TermParamRef =>
      ref.paramName.toString()

    case annotated: AnnotatedType =>
      s"(${show(annotated.typ)} @${annotated.annotation.tree})"

    case refinement: TypeRefinement =>
      s"(${show(refinement.parent)} { type ${refinement.refinedName}${show(refinement.refinedBounds)}})"

    case refinement: TermRefinement =>
      val keyword = if refinement.isStable then "val" else "def"
      s"(${show(refinement.parent)} { $keyword ${refinement.refinedName}: ${show(refinement.refinedType)}})"

    case rec: RecType =>
      s"{ this#${rec.debugID} => ${show(rec.parent)} }"

    case recThis: RecThis =>
      s"this#${recThis.binders.debugID}"

    case wildcard: WildcardTypeBounds =>
      s"_${show(wildcard.bounds)}"

    case or: OrType =>
      s"(${show(or.first)} | ${show(or.second)})"

    case and: AndType =>
      s"(${show(and.first)} & ${show(and.second)})"

    // No dedicated rendering for these; fall back to their default string form.
    case _: MatchType | _: SkolemType | _: MethodicType | _: CustomTransientGroundType =>
      tpe.toString()
end show
/** Renders type bounds as a suffix: ` >: L <: H` for real bounds, omitting a
  * lower bound of exactly `scala.Nothing` and an upper bound of exactly
  * `scala.Any`, or ` = T` for an alias.
  */
def show(bounds: TypeBounds)(using Context, SymDetailsSet): String =
  bounds match
    case RealTypeBounds(lower, upper) =>
      val lowerPart =
        if isExactlyScalaDot(lower, defn.NothingClass) then ""
        else s" >: ${show(lower)}"
      val upperPart =
        if isExactlyScalaDot(upper, defn.AnyClass) then ""
        else s" <: ${show(upper)}"
      s"$lowerPart$upperPart"
    case TypeAlias(aliased) =>
      s" = ${show(aliased)}"
end show
/** Whether `tpe` is syntactically a reference `scala.<cls>`: a `TypeRef`
  * named like `cls`, prefixed by the `scala` package, whose symbol is `cls`.
  */
private def isExactlyScalaDot(tpe: Type, cls: ClassSymbol)(using Context): Boolean =
  tpe match
    case ref: TypeRef if ref.name == cls.name =>
      ref.prefix match
        case pkg: PackageRef =>
          val inScalaPackage = pkg.symbol == defn.scalaPackage
          inScalaPackage && ref.optSymbol.contains(cls)
        case _ =>
          false
    case _ =>
      false
end isExactlyScalaDot
/** Classifies one match-type case.
  *
  *  - no type parameters: `FullyDefined`;
  *  - applied-type pattern: delegated to `computeCatForCapturing`;
  *  - pattern that is a parameter of the case itself: `CatchAll` when the
  *    case has exactly one parameter, `Incompatible` otherwise;
  *  - anything else: `Incompatible`.
  */
def computeCategory(caze: MatchTypeCase)(using Context, Magics): CompatCategory =
  if caze.paramNames.isEmpty then CompatCategory.FullyDefined
  else
    caze.pattern match
      case _: AppliedType =>
        computeCatForCapturing(caze)
      case ref: TypeParamRef if ref.binders == caze =>
        if caze.paramNames.sizeIs == 1 then CompatCategory.CatchAll
        else CompatCategory.Incompatible
      case _ =>
        CompatCategory.Incompatible
end computeCategory
/** Classifies a case whose pattern is an applied type that captures type
  * parameters.
  *
  * The pattern is walked recursively. The result starts at
  * `SingleLevelClass` and is raised to the highest-ordinal category flagged
  * along the way:
  *  - the kind of each type constructor (see `checkTycon`);
  *  - a capture that appears more than once;
  *  - a capture below the first argument level (`CovariantNested` when the
  *    enclosing position is covariant, `Incompatible` otherwise);
  *  - parameters of the case that are never captured.
  */
def computeCatForCapturing(caze: MatchTypeCase)(using Context, Magics): CompatCategory =
  var worst: CompatCategory = CompatCategory.SingleLevelClass

  // Keeps whichever category has the larger ordinal.
  def flag(cat: CompatCategory): Unit =
    if cat.ordinal > worst.ordinal then worst = cat

  val captured = mutable.Set.empty[TypeParamRef]

  def walk(pattern: Type, level: Int, enclosingCovariant: Boolean, covariant: Boolean): Unit =
    // Recurses into the arguments of a type constructor treated as invariant.
    def walkInvariantArgs(args: List[Type]): Unit =
      args.foreach { arg =>
        walk(arg, level + 1, enclosingCovariant = covariant, covariant = false)
      }

    pattern match
      case applied: AppliedType =>
        checkTycon(applied.tycon) match
          case TyconKind.Class(paramsIsCovariant) =>
            applied.args.lazyZip(paramsIsCovariant).foreach { (arg, argCovariant) =>
              walk(arg, level + 1, enclosingCovariant = covariant, covariant = covariant && argCovariant)
            }
          case TyconKind.TypeMemberExtractor(structural) =>
            flag(
              if structural then CompatCategory.TypeMemberStructuralExtractor
              else CompatCategory.TypeMemberExtractor
            )
            walkInvariantArgs(applied.args)
          case TyconKind.HKTypeParam =>
            flag(CompatCategory.HKTypeParamTycon)
            walkInvariantArgs(applied.args)
          case TyconKind.Unknown =>
            flag(CompatCategory.Incompatible)

      case ref: TypeParamRef if ref.binders == caze =>
        val firstOccurrence = captured.add(ref)
        if !firstOccurrence then
          flag(CompatCategory.Incompatible)
        else if level > 1 then
          flag(
            if enclosingCovariant then CompatCategory.CovariantNested
            else CompatCategory.Incompatible
          )

      case _ =>
        ()
  end walk

  walk(caze.pattern, level = 0, enclosingCovariant = true, covariant = true)

  // Every parameter of the case must have been captured by the pattern.
  if captured != caze.paramRefs.toSet then
    flag(CompatCategory.Incompatible)

  worst
end computeCatForCapturing
/** Result of `checkTycon`: how the type constructor of an applied-type
  * pattern is treated by `computeCatForCapturing`.
  */
enum TyconKind:
  /** A class, or a type member handled like one; one covariance flag per type parameter. */
  case Class(paramsIsCovariant: List[Boolean])

  /** A type lambda of the form `[X] =>> Parent { type Name = X }`; `structural`
    * is true when `Name` is not found as a type member of the parent's class.
    */
  case TypeMemberExtractor(structural: Boolean)

  /** A local type parameter used as type constructor. */
  case HKTypeParam

  /** Anything `checkTycon` does not recognize. */
  case Unknown
/** Determines the kind of the type constructor of an applied-type pattern.
  *
  *  - A `TypeRef` that is an alias is resolved through the alias. Otherwise
  *    its symbol decides: a class yields `Class` with the covariance of each
  *    type parameter, one of `magics.compileTimeOpsMatchables` yields an
  *    all-invariant `Class`, a local type parameter yields `HKTypeParam`.
  *  - A `TypeLambda` whose result applies some constructor to exactly its own
  *    parameters, in order, is classified as that constructor.
  *  - A one-parameter `TypeLambda` whose result is a refinement
  *    `Parent { type Name = <its parameter> }` yields `TypeMemberExtractor`,
  *    provided `Parent` resolves to a class.
  *  - Everything else yields `Unknown`.
  */
def checkTycon(tycon: Type)(using Context, Magics): TyconKind = {
  tycon match
    case tycon: TypeRef =>
      tycon.optAliasedType match
        case Some(alias) =>
          checkTycon(alias)
        case None =>
          tycon.optSymbol match
            case Some(cls: ClassSymbol) =>
              // `cls` is already the class symbol; no need to re-fetch and cast it.
              TyconKind.Class(cls.typeParams.map(_.is(Covariant)))
            case Some(sym: TypeMemberSymbol) if magics.compileTimeOpsMatchables.contains(sym) =>
              // for the purposes of our analysis, this behaves like an invariant class
              val typeParamCount = sym.upperBound match
                case tpe: TypeLambda => tpe.paramNames.size
                case _               => 0
              TyconKind.Class(List.fill(typeParamCount)(false))
            case Some(_: LocalTypeParamSymbol) =>
              TyconKind.HKTypeParam
            case _ =>
              TyconKind.Unknown

    case tycon: TypeLambda =>
      tycon.resultType match
        case result: AppliedType =>
          // Only an eta-expansion `[X, Y] =>> F[X, Y]` is transparent.
          if allSame(result.args, tycon.paramRefs) then checkTycon(result.tycon)
          else TyconKind.Unknown
        case result: TypeRefinement if tycon.paramNames.sizeIs == 1 =>
          result.refinedBounds match
            case TypeAlias(alias) if alias eq tycon.paramRefs.head =>
              classSymbolOf(result.parent) match
                case Some(cls) =>
                  cls.getMember(result.refinedName) match
                    case Some(_: TypeMemberSymbol) => TyconKind.TypeMemberExtractor(structural = false)
                    case _                         => TyconKind.TypeMemberExtractor(structural = true)
                case None =>
                  TyconKind.Unknown
            case _ =>
              TyconKind.Unknown
        case _ =>
          TyconKind.Unknown

    case _ =>
      TyconKind.Unknown
}
/** Whether both lists have the same length and hold the very same object
  * instances (reference equality, `eq`) at every position.
  */
def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean =
  xs.corresponds(ys)((x, y) => x eq y)
/** Class symbol that `tpe` refers to after dealiasing, if it is a reference
  * to a class; `None` otherwise.
  */
def classSymbolOf(tpe: Type)(using Context): Option[ClassSymbol] =
  tpe.dealias match
    case ref: TypeRef if ref.isClass =>
      // Narrow by pattern instead of `.get.asClass`, so an unexpected symbol
      // yields `None` rather than an exception.
      ref.optSymbol.collect { case cls: ClassSymbol => cls }
    case _ =>
      None
end classSymbolOf
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment