Skip to content

Instantly share code, notes, and snippets.

@pshirshov
Last active October 18, 2017 16:27
Show Gist options
  • Save pshirshov/a58581e5910e33471b6eac38e52b69d8 to your computer and use it in GitHub Desktop.
Save pshirshov/a58581e5910e33471b6eac38e52b69d8 to your computer and use it in GitHub Desktop.
HList and implicit-based specializations
// hlist definition
/** Root of the heterogeneous-list ADT; the cases defined alongside it are HPair and HNil. */
sealed trait HList

/**
 * Non-empty cell: one element followed by the remainder of the list.
 *
 * @tparam A type of the element stored in this cell
 * @tparam B type of the remaining list
 * @param head the element stored in this cell
 * @param tail the remainder of the list
 */
case class HPair[A, B <: HList](head: A, tail: B) extends HList {
  // A tail that is itself an HPair renders the same way, so nested cells
  // print as one flat chain separated by " ::: ".
  override def toString: String = s"$head ::: $tail"
}

/** Terminator that ends a list. */
case object HNil extends HList
// instantiation
/**
 * Construction syntax for lists via the `:::` operator.
 *
 * The operator name ends in `:`, so it is right-associative: in `a ::: b`
 * the method is looked up on `b` and receives `a` as its argument.
 */
object HListFactory {

  /**
   * Lets an arbitrary value act as the last element of a list.
   * `b ::: v` produces the two-element list `b`, `v`, terminated by HNil.
   */
  implicit class ToHListEnd[H](v: H) {
    def :::[T](b: T): HPair[T, HPair[H, HNil.type]] = {
      // Close the list after `v` first, then put `b` in front of it.
      val lastCell: HPair[H, HNil.type] = HPair(v, HNil)
      HPair(b, lastCell)
    }
  }

  /** Prepends `b` to a value that is already a list. */
  implicit class ToHList[H <: HList](v: H) {
    def :::[T](b: T): HPair[T, H] = HPair[T, H](b, v)
  }
}
// Computing the list length by recursing over the list type; per-type specialization is simulated through implicit resolution
/**
 * Type class for list types whose length can be computed.
 *
 * @tparam L the list type an instance handles
 */
trait IsEnumerable[L <: HList] {
/** Returns the number of elements in `l`. */
def length(l: L): Int
}
/** Instances of [[IsEnumerable]] plus the `length` entry point. */
object IsEnumerable {

  /** Base case: the empty list has length zero. */
  implicit def hlistIsEnumerable0: IsEnumerable[HNil.type] =
    new IsEnumerable[HNil.type] {
      override def length(l: HNil.type): Int = 0
    }

  /**
   * Inductive case: a cell is enumerable whenever its tail is, and is one
   * element longer than that tail.
   */
  implicit def hlistIsEnumerableN[H0, T0 <: HList : IsEnumerable]: IsEnumerable[H0 HPair T0] = {
    // The context bound supplies the instance for the tail type.
    val tailInstance = implicitly[IsEnumerable[T0]]
    new IsEnumerable[H0 HPair T0] {
      override def length(l: HPair[H0, T0]): Int = tailInstance.length(l.tail) + 1
    }
  }

  /**
   * Computes the number of elements of `hl` using the instance found for its type.
   *
   * @param hl the list to measure
   * @return the element count
   */
  def length[H <: HList : IsEnumerable](hl: H): Int = {
    val instance = implicitly[IsEnumerable[H]]
    instance.length(hl)
  }
}
// basic list operations, one-way recursion
/**
 * Type class describing the basic decompositions of a list type.
 *
 * @tparam Lst the list type an instance handles
 */
trait IsHPair[Lst <: HList] {
// Type of the first element.
type H
// Type of the list without its first element.
type T <: HList
// Type of the final element.
type L
// Type of the list without its final element.
type I <: HList
/** Returns the first element of `l`. */
def head(l : Lst): H
/** Returns `l` without its first element. */
def tail(l : Lst): T
/** Returns the final element of `l`. */
def last(l : Lst): L
/** Returns `l` without its final element. */
def init(l : Lst): I
/** Compares two lists of the same type; this two-argument form overloads the `==` inherited from Any. */
def ==(l: Lst, other: Lst): Boolean
}
/**
 * Instances of IsHPair and the extension syntax built on them.
 *
 * Both instances defined here are for HPair types; none is provided for HNil.
 */
object IsHPair {
// Refines only H and T. The L and I members are not part of this alias, so
// they stay abstract in the static type of the instances returned below.
type Aux[Lst <: HList, H0, T0 <: HList] = IsHPair[Lst] {type H = H0; type T = T0}
// Base case: a single-element list. Head and last are the same element;
// tail and init are both HNil.
implicit def hlistIsPair0[H0]: Aux[H0 HPair HNil.type, H0, HNil.type] = new IsHPair[H0 HPair HNil.type] {
override type H = H0
override type T = HNil.type
override type L = H0
override type I = HNil.type
override def init(l: HPair[H0, HNil.type]): I = HNil
override def head(l: HPair[H0, T]): H = l.head
// T is HNil.type here, so returning the HNil object is the same as returning l.tail.
override def tail(l: HPair[H0, T]): T = HNil
// Declared as H rather than L; both are H0 in this instance.
override def last(l: HPair[H0, T]): H = l.head
override def ==(l: HPair[H0, T], other: HPair[H0, HNil.type]): Boolean = other.head == l.head
}
// Inductive case: requires an instance for the tail type T0 (context bound)
// and delegates last, init and the tail part of == to it.
implicit def hlistIsPairN[H0, T0 <: HList : IsHPair]: Aux[H0 HPair T0, H0, T0] = new IsHPair[H0 HPair T0] {
override type H = H0
override type T = T0
// L and I are expressed as projections on IsHPair[T0], not on the concrete tail instance.
override type L = IsHPair[T0]#L
override type I = H0 HPair IsHPair[T0]#I
override def head(l: HPair[H0, T0]) = l.head
override def tail(l: HPair[H0, T0]) = l.tail
override def last(l: HPair[H0, T0]): L = implicitly[IsHPair[T0]].last(l.tail)
// Keeps this head and drops the final element from the tail.
override def init(l: HPair[H0, T0]) = HPair(l.head, implicitly[IsHPair[T0]].init(l.tail))
// Heads are compared with the universal ==; the tails are compared through the tail's instance.
override def ==(l: HPair[H0, T0], other: HPair[H0, T0]): Boolean = l.head == other.head && implicitly[IsHPair[T0]].==(l.tail, other.tail)
}
// Extension methods hhead / htail / hlast / hinit for any list type that has an instance.
// NOTE(review): the result types are projections on the unrefined IsHPair[H], so
// callers presumably see abstract types rather than the concrete element/list types - confirm.
implicit class HListPairOps[H <: HList : IsHPair](hl: H) {
type IP = IsHPair[H]
def hhead: IP#H = {
implicitly[IP].head(hl)
}
def htail: IP#T = {
implicitly[IP].tail(hl)
}
def hlast: IP#L = {
implicitly[IP].last(hl)
}
def hinit: IP#I = {
implicitly[IP].init(hl)
}
// NOTE(review): every value already has `==` from Any, so `a == b` presumably never
// triggers this implicit view; it looks reachable only via an explicit
// `new HListPairOps(a) == b`. Confirm, and consider a distinct name if it is meant to be used.
def ==(o: H): Boolean = {
implicitly[IP].==(hl, o)
}
}
}
// joining two lists by replacing HNil in Lst1 with Lst2
/**
 * Type class for concatenating a list of type `Lst1` with a list of type `Lst2`.
 *
 * @tparam Lst1 type of the left-hand list
 * @tparam Lst2 type of the right-hand list
 */
trait IsJoinable[Lst1 <: HList, Lst2 <: HList] {
// Type of the joined list.
type J <: HList
/** Returns the elements of `l` followed by the elements of `other`. */
def join(l: Lst1, other: Lst2): J
}
/**
 * Instances of [[IsJoinable]] and the `hjoin` extension method.
 *
 * Each instance advertises its `J` member in its result type, and `hjoin`
 * returns the dependent type `ev.J`. The static type of a joined list is
 * therefore the concrete HPair chain rather than the abstract projection
 * `IsJoinable[_, _]#J`, so the result can be passed to other type classes
 * that need the concrete list type. The refined types are subtypes of the
 * unrefined ones, so code written against the unrefined types still compiles.
 */
object IsJoinable {
  /** Shorthand for the instance type of a non-empty left-hand list (does not fix `J`). */
  type Aux[H0, T0 <: HList, Lst2 <: HList] = IsJoinable[HPair[H0, T0], Lst2]

  /** Base case: joining the empty list with `other` yields `other` unchanged. */
  implicit def joinableFromHNil[Lst2 <: HList]: IsJoinable[HNil.type, Lst2] { type J = Lst2 } =
    new IsJoinable[HNil.type, Lst2] {
      override type J = Lst2
      override def join(l: HNil.type, other: Lst2): J = other
    }

  /**
   * Inductive case: keeps the head of the left-hand list and joins its tail
   * with `other` through the tail's instance `ev`.
   *
   * @param ev instance that joins the tail type `T0` with `Lst2`
   */
  implicit def joinAux[H0, T0 <: HList, Lst2 <: HList](implicit ev: IsJoinable[T0, Lst2]): Aux[H0, T0, Lst2] { type J = HPair[H0, ev.J] } =
    new IsJoinable[HPair[H0, T0], Lst2] {
      // Path-dependent on `ev` so the exact tail result type is preserved.
      override type J = HPair[H0, ev.J]
      override def join(l: HPair[H0, T0], other: Lst2): J = HPair(l.head, ev.join(l.tail, other))
    }

  /** Adds `hjoin` to every list. */
  implicit class HListJoinOps[Lst1 <: HList](hl: Lst1) {
    /**
     * Concatenates this list with `o`.
     *
     * @param o  the list whose elements follow this list's elements
     * @param ev instance that performs the join for these two list types
     * @return the joined list, typed as the instance's `J`
     */
    def hjoin[Lst2 <: HList](o: Lst2)(implicit ev: IsJoinable[Lst1, Lst2]): ev.J =
      ev.join(hl, o)
  }
}
// TODO: zip, hfold
// Brings the `:::` syntax into scope. The operator name ends in `:`, so it is
// right-associative and is looked up on the right-hand operand.
import HListFactory._
// Single-element list, written with an explicit terminator.
val l0 = 1 ::: HNil
// No explicit HNil: ToHListEnd wraps the rightmost value and terminates the list itself.
val l1 = 1 ::: true
val l2 = 1 ::: true ::: 1.0
// The same list written without and with the explicit terminator.
val l2_1 = 2 ::: true ::: 1.0
val l2_2 = 2 ::: true ::: 1.0 ::: HNil
// These comparisons use the structural equality HPair gets as a case class.
assert(l2_1 == l2_2)
assert(l2 == l2)
// l2 and l2_1 differ in their first element.
assert(l2 != l2_1)
// `length` below is IsEnumerable.length.
import IsEnumerable._
assert(length(l0) == 1)
assert(length(l1) == 2)
assert(length(l2) == 3)
// Brings the hhead / htail / hlast / hinit extension methods into scope.
import IsHPair._
assert(l0.hhead == 1)
assert(l1.hhead == 1)
assert(l2.hhead == 1)
// `:::` binds tighter than `==`, so the right-hand lists are built before comparing.
assert(l0.htail == HNil)
assert(l1.htail == true ::: HNil)
assert(l2.htail == true ::: 1.0 ::: HNil)
assert(l0.hlast == 1)
assert(l1.hlast == true)
assert(l2.hlast == 1.0)
assert(l0.hinit == HNil)
assert(l1.hinit == 1 ::: HNil)
assert(l2.hinit == 1 ::: true ::: HNil)
// Brings the hjoin extension method into scope.
import IsJoinable._
// Joining with HNil on either side leaves the other list unchanged.
assert(l2.hjoin(HNil) == 1 ::: true ::: 1.0)
assert(l2.hjoin(l2_2) == 1 ::: true ::: 1.0 ::: 2 ::: true ::: 1.0)
assert(HNil.hjoin(l2_2) == 2 ::: true ::: 1.0)
// Lists of functions; nothing in the visible source consumes them.
// NOTE(review): presumably groundwork for the `hfold` mentioned in the TODO - confirm.
val f1: Int => String = _.toString
val f2: String => Double = _.toDouble
val f3: Double => String = d => s"=$d"
val lf1 = f1 ::: HNil
val lf2 = f1 ::: f2 ::: f3 ::: HNil
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment