Skip to content

Instantly share code, notes, and snippets.

@tschuchortdev
Created May 27, 2024 12:45
Show Gist options
  • Save tschuchortdev/3f02c32b4a2ddd3dd2158060b7b3bd6b to your computer and use it in GitHub Desktop.
Save tschuchortdev/3f02c32b4a2ddd3dd2158060b7b3bd6b to your computer and use it in GitHub Desktop.
How to check exhaustivity of a match expression with Scala 3 macros
/** Macro implementation: warns at compile time when the `match` inside `expr` may not cover every
  * case of `T`, then expands to `expr` applied to `self`.
  *
  * `expr` must be a lambda whose body is a single top-level `match` on the lambda's own parameter
  * (`x => x match { ... }` or the `{ case ... }` shorthand). The cases that must be covered come
  * from `m`: a product type is its own single case, and a sum type contributes each element of its
  * `MirroredElemTypes`. An expected case counts as covered when some unguarded `case` pattern covers
  * a supertype of it, or when `expectedCase <:< caseType` evidence can be summoned. Sub-patterns
  * (e.g. the `1` in `Foo(1)`) are ignored and extractors are assumed to succeed for the type they
  * yield, so the check is optimistic for those patterns.
  *
  * @param self the value the match is applied to
  * @param expr the lambda containing the top-level match
  * @param m    mirror of `T`; a sum mirror must have statically known `MirroredElemTypes`
  * @return `'{ $expr($self) }`; uncovered cases only produce a warning, never an error
  */
private def matchExhaustivelyImplImpl[T: Type](
  self: Expr[T],
  expr: Expr[T => Any],
  m: Expr[Mirror.Of[T]]
)(using q: Quotes): Expr[Any] =
  import q.reflect.{*, given}

  // Cases that must be covered. `tupleToTypeReprs` is defined elsewhere in this file; presumably it
  // yields one TypeRepr per element of the `MirroredElemTypes` tuple.
  val expectedCases = m match
    case '{ $m: Mirror.ProductOf[s] } => Seq(TypeRepr.of[T])
    case '{
          type elems <: Tuple;
          $m: Mirror.SumOf[s] { type MirroredElemTypes = `elems` }
        } =>
      tupleToTypeReprs[elems]
    case _ =>
      // Neither shape matched (e.g. MirroredElemTypes is not statically known): report a compile
      // error instead of letting a MatchError crash the compiler.
      report.errorAndAbort(
        s"Cannot determine the cases of ${Type.show[T]}: " +
          "expected a Mirror.ProductOf or a Mirror.SumOf with known MirroredElemTypes",
        m
      )

  // Strips wrappers that can surround the lambda argument: `Inlined` nodes and the synthetic
  // `$asInstanceOf$` cast. Only the tree shape is inspected; `expr` itself is used unchanged below.
  def stripWrappers(term: Term): Term = term match
    case Inlined(_, _, inner)                          => stripWrappers(inner)
    case TypeApply(Select(inner, "$asInstanceOf$"), _) => stripWrappers(inner)
    case other                                         => other

  // The cases of the top-level match. The closure must refer to the lambda's DefDef, and the match
  // scrutinee must be the lambda's own parameter, so that we really check `param => param match {...}`.
  val caseDefs = stripWrappers(expr.asTerm) match
    case Block(
          List(
            DefDef(
              lambdaName,
              List(TermParamClause(List(ValDef(lambdaParamName, _, _)))),
              _,
              Some(Match(Ident(matchVarName), cases))
            )
          ),
          Closure(Ident(closureName), _)
        ) if closureName == lambdaName && matchVarName == lambdaParamName =>
      cases
    case _ => report.errorAndAbort("Must be a lambda with top-level match expression", expr)

  // Types an extractor pattern is assumed to cover, derived from the extractor's result type.
  def unapplyResultTypes(fun: Term): Seq[TypeRepr] = fun.tpe.widenTermRefByName match
    // A MethodType is a regular method taking term parameters, a PolyType is a method taking type
    // parameters, a TypeLambda is a method returning a type and not a value. Unapply's type should be
    // a function with no type parameters (they are already applied in `fun`), with a single value
    // parameter (the match scrutinee) and no currying, thus it should be a MethodType.
    case methodType: MethodType =>
      // True when `tpe` is the extractor's scrutinee parameter (`x.type` in a dependent result type).
      def isScrutinee(tpe: TypeRepr): Boolean =
        methodType.paramTypes.nonEmpty && tpe =:= methodType.param(0)
      methodType.resType.asType match
        // Also matches Some[] and None in an easy way
        case '[Option[tpe]] =>
          TypeRepr.of[tpe] match
            // `Option[x.type & Foo]`: the extractor narrows the scrutinee to `Foo`.
            case AndType(left, right) if isScrutinee(left)  => List(right)
            case AndType(left, right) if isScrutinee(right) => List(left)
            case other                                      => List(other)
        // Extractors that return the matched value directly (e.g. Scala 3 case-class `unapply`).
        case '[tpe] => List(TypeRepr.of[tpe])
    case other =>
      report.errorAndAbort(
        s"Expected type of Unapply function to be MethodType. Was: ${Printer.TypeReprStructure.show(other)}",
        fun.pos
      )

  // Types an unguarded case pattern is assumed to cover.
  def computeMatchedType(caseDefPattern: Tree): Seq[TypeRepr] = caseDefPattern match
    // `_` covers everything the scrutinee can be.
    case Wildcard()             => List(TypeRepr.of[T])
    case Alternatives(patterns) => patterns.flatMap(computeMatchedType)
    // `name @ p` covers exactly what `p` covers.
    case Bind(_, inner)         => computeMatchedType(inner)
    // `_: Foo`: use the tested type itself, which also works for unions and applied types.
    case TypedOrTest(_, tpt)    => List(tpt.tpe)
    case Unapply(fun, _, _)     => unapplyResultTypes(fun)
    // Stable identifiers and literals, e.g. a case object `Foo` or an enum case `Color.Red`,
    // cover their singleton type.
    case term: Term             => List(term.tpe)
    case pattern =>
      report.errorAndAbort(
        "Expected pattern of CaseDef to be a wildcard, Alternatives, Bind, TypedOrTest, Unapply or a term. " +
          s"Was: ${Printer.TreeStructure.show(pattern)}",
        pattern.pos
      )

  // Types covered by the cases, converted to quoted types once for the `'[...]` patterns below.
  // A guarded case may not match, so it covers nothing.
  val coveredTypes = caseDefs
    .filter(_.guard.isEmpty)
    .flatMap(caseDef => computeMatchedType(caseDef.pattern))
    .map(_.asType)

  val uncoveredCases = expectedCases.filterNot { expected =>
    expected.asType match
      case '[expectedCase] =>
        coveredTypes.exists { case '[caseDefType] =>
          (TypeRepr.of[expectedCase] <:< TypeRepr.of[caseDefType])
          || Expr.summon[expectedCase <:< caseDefType].isDefined
        }
  }

  if uncoveredCases.nonEmpty then
    val casesString = uncoveredCases.map(t => "_: " + Printer.TypeReprCode.show(t)).mkString(", ")
    // Mark only the start of the lambda. Both offsets come from `expr`, so the source file does too,
    // and the start is clamped so it can never become negative.
    val lambdaPos = expr.asTerm.pos
    report.warning(
      s"Match may not be exhaustive.\n\nIt would fail on case: $casesString",
      Position(lambdaPos.sourceFile, start = math.max(0, lambdaPos.start - 1), end = lambdaPos.start + 1)
    )

  '{ $expr($self) }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment