Skip to content

Instantly share code, notes, and snippets.

@johnynek
Created May 22, 2021 19:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save johnynek/66e70c39c19902d0079aac318253b4be to your computer and use it in GitHub Desktop.
Save johnynek/66e70c39c19902d0079aac318253b4be to your computer and use it in GitHub Desktop.
An attempt to use match types for a safe apply in scala 3
package sizetypes
import scala.compiletime.ops.int
// Match type that unfolds to the union of every Int literal type strictly below A:
// LessThan[3] reduces step by step to 2 | 1 | 0, and LessThan[0] reduces to Nothing,
// so there is no valid index into an empty list.
// NOTE(review): a negative A never reaches the `0` case, so the recursion presumably
// does not terminate for negative arguments -- confirm before relying on it.
type LessThan[A <: Int] =
A match
case 0 => Nothing
case _ =>
int.-[A, 1] | LessThan[int.-[A, 1]]
// Usage example: the literal 2 is ascribed the type LessThan[3].
val two: LessThan[3] = 2
// Intended to split an index into a list of size S[A]: Right(0) for the head position,
// otherwise Left(idx - 1) retyped as an index into the tail (see the sketch below).
// Not implemented yet: the body is `???`, so every call throws NotImplementedError.
def zeroOrDec[A <: Int](idx: LessThan[int.S[A]]): Either[LessThan[A], 0] =
???
/*
Open question: how can this method be implemented?
if idx == 0 then Right(0)
else
// We know 0 < idx < S[A], so 0 <= idx - 1 < A, i.e. idx - 1 is a LessThan[A].
// How can we get Scala to see that?
val prev: LessThan[A] = (idx - 1)
Left(prev)
*/
/**
 * A list parameterised by an Int type `Z` and an element type `A`.
 * The subclasses in `object SList` use `Z` as the element count: `SNil` carries
 * evidence that `Z` is 0 and `Cons` extends `SList[S[Z], A]`.
 */
sealed abstract class SList[Z <: Int, +A] {

  /** Returns the element at `idx`; the parameter type only admits values below `Z`. */
  def apply(idx: LessThan[Z]): A

  /** Prepends `item`, producing a `Cons` whose tail is this list. */
  def ::[A1 >: A](item: A1): SList.Cons[Z, A1] =
    new SList.Cons[Z, A1](item, this)
}
object SList {

  /** Evidence that lets any `F[A]` be converted into an `F[B]`. */
  sealed abstract class IntEv[A <: Int, B <: Int] {
    def subst[F[_ <: Int]](fa: F[A]): F[B]
  }

  object IntEv {
    /** Reflexive evidence: with both sides equal, substitution returns its argument. */
    implicit def reflIntEv[A <: Int]: IntEv[A, A] =
      new IntEv[A, A] {
        def subst[F[_ <: Int]](fa: F[A]): F[A] = fa
      }
  }

  /** The empty list; `ev` carries the evidence relating `Z` to 0. */
  case class SNil[Z <: Int](ev: IntEv[Z, 0]) extends SList[Z, Nothing] {
    def apply(idx: LessThan[Z]): Nothing = {
      // Rewriting Z to 0 turns the index into a LessThan[0], which the match type
      // reduces to Nothing, so the value itself can be returned.
      val impossible: LessThan[0] = ev.subst[LessThan](idx)
      impossible
    }
  }

  /** A non-empty list: `head` followed by a `tail` of size `Z`. */
  case class Cons[Z <: Int, +A](head: A, tail: SList[Z, A]) extends SList[int.S[Z], A] {
    def apply(idx: LessThan[int.S[Z]]): A =
      zeroOrDec(idx) match {
        case Right(_)   => head
        case Left(rest) => tail(rest)
      }
  }

  def empty: SList[0, Nothing] = SNil(IntEv.reflIntEv[0])

  // These vals are evaluated when the object is initialised; `one` goes through
  // Cons.apply and therefore through zeroOrDec.
  val oneTwoThree = 1 :: 2 :: 3 :: empty
  val one = oneTwoThree(0)
}
@note
Copy link

note commented May 22, 2021

From what I understand, what you're trying to achieve assumes that both the list's size and the index you apply are known at compile time. Under those assumptions I came up with a simpler encoding:

import scala.compiletime.ops.int.*
import scala.compiletime.{constValue, ops}

/**
 * A list whose type carries an Int parameter `Z`; `Cons` extends `SList[S[Z], A]`
 * and `SNil` extends `SList[0, Nothing]`, so `Z` tracks the element count.
 */
sealed abstract class SList[Z <: Int, +A] {

  /** Looks up the element at the type-level index `I`; implemented inline by each subclass. */
  inline def apply[I <: Int]: A

  /** Prepends `item`, producing a `Cons` whose tail is this list. */
  def ::[A1 >: A](item: A1): SList.Cons[Z, A1] =
    new SList.Cons[Z, A1](head = item, tail = this)
}

object SList {

  /** The empty list: any lookup is rejected with a compile-time error. */
  case object SNil extends SList[0, Nothing] {
    inline def apply[I <: Int]: Nothing =
      compiletime.error("index exceeded list's size")
  }

  /** A non-empty list: `head` followed by a `tail` of size `Z`. */
  case class Cons[Z <: Int, +A](head: A, tail: SList[Z, A]) extends SList[S[Z], A] {
    // Both conditions are `inline if`s over constValue, so the branch is chosen
    // during inlining: index 0 selects the head, any other in-range index is
    // delegated to the tail with the index decremented at the type level.
    inline def apply[I <: Int]: A =
      inline if (constValue[I] < constValue[S[Z]])
        inline if (constValue[I] == 0) head else tail.apply[I - 1]
      else
        compiletime.error("Wrong index: " + constValue[ToString[I]] + " for list of size: " + constValue[ToString[S[Z]]])
  }

  def empty: SList[0, Nothing] = SNil

  val oneTwoThree = 1 :: 2 :: 3 :: empty

  val one = oneTwoThree.apply[0]
  val two = oneTwoThree.apply[1]
  val three = oneTwoThree.apply[2]
  // val doesNotCompile = oneTwoThree.apply[3] // fails compilation with:
  // Wrong index: 3 for list of size: 3
}

I am not 100% sure that this captures the essence of what you were striving for.

@note
Copy link

note commented May 23, 2021

Another attempt:

import scala.annotation.tailrec
import scala.compiletime.ops.int.*
import scala.compiletime.{constValue, ops}

/**
 * Runtime wrapper for an index into a list of size `Z`.
 *
 * Nothing here prevents arbitrary instances from being created. In its eventual form,
 * construction will have to be restricted to the `lessThanFromInt` conversion, which
 * is the place where the bound is checked.
 */
trait LessThan[Z <: Int]:
  /** The wrapped index. */
  def value: Int
/**
 * Implicit conversion from an Int literal to a `LessThan[Z]`.
 * The bound check `0 <= I < Z` is an `inline if` over constant values, so an
 * out-of-range literal is reported through `compiletime.error`.
 */
inline implicit def lessThanFromInt[I <: Int & Singleton, Z <: Int](v: I): LessThan[Z] =
  inline if (constValue[I] >= 0 && constValue[I] < constValue[Z]) {
    new LessThan[Z] {
      def value: Int = v
    }
  } else {
    compiletime.error("wrong")
  }

/**
 * A list whose type carries an Int parameter `Z`; `Cons` extends `SList[S[Z], A]`
 * and `SNil` extends `SList[0, Nothing]`, so `Z` tracks the element count.
 */
sealed abstract class SList[Z <: Int, +A] {

  /**
   * Returns the element at `index`.
   * The type parameter `V` does not appear anywhere else in the signature.
   */
  def apply[V <: Int](index: LessThan[Z]): A

  /** Lookup by plain Int, used once the typed index has been unwrapped. */
  protected def _apply(index: Int): A

  /** Prepends `item`, producing a `Cons` whose tail is this list. */
  def ::[A1 >: A](item: A1): SList.Cons[Z, A1] =
    new SList.Cons[Z, A1](head = item, tail = this)
}

object SList {

  /** The empty list. */
  case object SNil extends SList[0, Nothing] {
    // Present only to satisfy the interface: lessThanFromInt rejects every literal
    // for Z = 0, so no index is expected to reach this method.
    def apply[V <: Int](index: LessThan[0]): Nothing = ???
    // Consequently this should never be called either.
    protected def _apply(index: Int): Nothing = ???
  }

  /** A non-empty list: `head` followed by a `tail` of size `Z`. */
  case class Cons[Z <: Int, +A](head: A, tail: SList[Z, A]) extends SList[S[Z], A] {

    /** Returns the head for index 0, otherwise looks up `index - 1` in the tail. */
    def apply[V <: Int](index: LessThan[S[Z]]): A =
      if index.value == 0
        then head
        else tail._apply(index.value - 1)

    /**
     * Returns the element `index` positions after this node.
     *
     * The walk advances to the next node on every step. The previous version
     * recursed on `this`, so any index greater than 0 returned this node's own
     * head and a negative index never terminated.
     */
    protected final def _apply(index: Int): A = {
      // Local helper so the traversal can be annotated with @tailrec while moving
      // from node to node.
      @tailrec
      def loop(node: SList[_ <: Int, A], remaining: Int): A =
        node match {
          case Cons(h, t) => if remaining == 0 then h else loop(t, remaining - 1)
          // Walked past the last element: not reachable through a checked LessThan.
          case SNil => ???
        }
      loop(this, index)
    }
  }

  val oneTwoThree: SList[3, Int] = 1 :: 2 :: 3 :: SNil

  val one = oneTwoThree.apply(0)
  val two = oneTwoThree.apply(1)
  val three = oneTwoThree.apply(2)
  // val doesNotCompile = oneTwoThree.apply(3)

  SListExample.genericFun(oneTwoThree, 2)
  // val doesNotCompile = SListExample.genericFun(oneTwoThree, 3)

  // None of those compile:
//  SNil.apply(-1)
//  SNil.apply(0)
//  SNil.apply(1)
}

object SListExample {

  /** Looks up `index` in `ls`; the index type is tied to the list's size parameter `Z`. */
  def genericFun[Z <: Int](ls: SList[Z, Int], index: LessThan[Z]): Int =
    ls(index)
}

@note
Copy link

note commented May 23, 2021

I managed to make your encoding work too (with some changes):

package original

import scala.annotation.tailrec
import scala.compiletime.ops.int.*
import scala.compiletime.constValue

/** Runtime wrapper for an index, tagged with the Int type `Z`. */
trait LessThan[Z <: Int] {

  /** The wrapped index. */
  def value: Int

  /**
   * The predecessor index, tagged with `Z - 1`.
   * Its `value` is computed from this instance's `value` each time it is read.
   */
  def pred: LessThan[Z - 1] = {
    val outer = this
    new LessThan[Z - 1] {
      override def value: Int = outer.value - 1
    }
  }
}

/**
 * Implicit conversion from an Int literal to a `LessThan[Z]`.
 *
 * The bound check `0 <= I < Z` is an `inline if` over constant values, so an
 * out-of-range literal is reported through `scala.compiletime.error`.
 * The lower bound matters: without it a negative literal is accepted, never
 * equals 0 in `zeroOrDec`, and only fails at runtime in `SNil.apply`.
 */
inline implicit def intToLessThanValue[I <: Int & Singleton, Z <: Int](i: I): LessThan[Z] =
  inline if (constValue[I] >= 0 && constValue[I] < constValue[Z]) {
    new LessThan[Z] {
      def value: Int = i
    }
  } else {
    scala.compiletime.error("wrong")
  }

/**
 * Implicit conversion that retags a `LessThan[S[Z] - 1]` as a `LessThan[Z]`,
 * keeping the wrapped value unchanged.
 */
implicit def inference[Z <: Int](i: LessThan[S[Z] - 1]): LessThan[Z] = {
  val source = i
  new LessThan[Z] {
    override def value: Int = source.value
  }
}

/**
 * Splits an index: `Right(0)` when it is 0, otherwise `Left` of its predecessor.
 *
 * When called with a `LessThan[S[Z]]` the left side is a `LessThan[S[Z] - 1]`.
 * Without `def inference`, using that where a `LessThan[Z]` is required failed with
 *   Found:    original.LessThan[compiletime.ops.int.S[Z] - (1 : Int)]
 *   Required: original.LessThan[Z]
 * We know that S[Z] - 1 =:= Z, but the compiler does not deduce it (in principle it
 * perhaps could), so the rule is added by hand as an implicit conversion.
 */
def zeroOrDec[Z <: Int](idx: LessThan[Z]): Either[LessThan[Z - 1], 0] =
  idx.value match {
    case 0 => Right(0)
    case _ => Left(idx.pred)
  }

/**
 * A list parameterised by an Int type `Z` and an element type `A`.
 * The subclasses in `object SList` use `Z` as the element count: `SNil` carries
 * evidence that `Z` is 0 and `Cons` extends `SList[S[Z], A]`.
 */
sealed abstract class SList[Z <: Int, +A] {

  /** Returns the element at `idx`. */
  def apply(idx: LessThan[Z]): A

  /** Prepends `item`, producing a `Cons` whose tail is this list. */
  def ::[A1 >: A](item: A1): SList.Cons[Z, A1] =
    new SList.Cons[Z, A1](item, this)
}

object SList {

  /** Evidence that lets any `F[A]` be converted into an `F[B]`. */
  sealed abstract class IntEv[A <: Int, B <: Int] {
    def subst[F[_ <: Int]](fa: F[A]): F[B]
  }

  object IntEv {
    /** Reflexive evidence: with both sides equal, substitution returns its argument. */
    implicit def reflIntEv[A <: Int]: IntEv[A, A] =
      new IntEv[A, A] {
        def subst[F[_ <: Int]](fa: F[A]): F[A] = fa
      }
  }

  /** The empty list; `ev` carries the evidence relating `Z` to 0. */
  case class SNil[Z <: Int](ev: IntEv[Z, 0]) extends SList[Z, Nothing] {
    def apply(idx: LessThan[Z]): Nothing = {
      // The index is retagged as LessThan[0]. LessThan is a trait in this encoding,
      // so the retagged value is not itself of type Nothing and cannot be returned;
      // the method ends in `???` instead.
      val retagged: LessThan[0] = ev.subst[LessThan](idx)
      ???
    }
  }

  /** A non-empty list: `head` followed by a `tail` of size `Z`. */
  case class Cons[Z <: Int, +A](head: A, tail: SList[Z, A]) extends SList[S[Z], A] {
    def apply(idx: LessThan[S[Z]]): A =
      zeroOrDec(idx) match {
        case Right(_)   => head
        case Left(rest) => tail(rest)
      }
  }

}

/** Runnable demo: builds a three-element list and prints its first and last elements. */
object OriginalExample:
  @main def main(): Unit =
    def empty: SList[0, Nothing] = SList.SNil(SList.IntEv.reflIntEv[0])
    val oneTwoThree = 1 :: 2 :: 3 :: empty

    // The Int literals are turned into typed indices by the implicit conversion.
    val first = oneTwoThree(0)
    val last = oneTwoThree(2)
    // val abc = oneTwoThree(3) // does not compile
    println(s"one: $first")
    println(s"three: $last")

The crucial change is adding def pred: LessThan[Z - 1], and to do that I had to turn LessThan into a trait. Another tricky point was adding the inference conversion.

I don't think your original def zeroOrDec[A <: Int](idx: LessThan[int.S[A]]): Either[LessThan[A], 0] is implementable. That's because idx - 1, even if you somehow got subtraction to work, is of type Int. You can see that in a much simpler example:

val two: 0 | 1 | 2 = 2
val a: 0 | 1 | 2 = two - 1 // does not compile as right side is of type Int

You're not able to define any methods on LessThan as long as it's a type alias, and I think defining a predecessor method with a proper return type is the key to this exercise.

@johnynek
Copy link
Author

Thanks for thinking through this.

I think making something like the original work would be nice, since the LessThan[A] type could be erased to Int.

I added this discussion:
https://contributors.scala-lang.org/t/math-on-union-numeric-union-types/5095/3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment