Skip to content

Instantly share code, notes, and snippets.

@TomasMikula
Created May 19, 2015 05:23
Show Gist options
  • Save TomasMikula/9a1d8d2c18262d013c59 to your computer and use it in GitHub Desktop.
Save TomasMikula/9a1d8d2c18262d013c59 to your computer and use it in GitHub Desktop.
Map a single type in a Coproduct. Lift Function1 to work on Coproducts.
import shapeless._
/**
 * Type class: in coproduct `C`, replace the alternative `A` with `B`.
 *
 * The inherited `apply(c, f)` (from `DepFn2`) turns a value of `C` into a
 * value of the resulting coproduct `Out`, using `f` for the replaced
 * alternative.
 *
 * @tparam C the source coproduct
 * @tparam A the alternative being replaced
 * @tparam B the type that takes its place
 */
trait Replace[C <: Coproduct, A, B] extends DepFn2[C, A => B] {
  /** The coproduct obtained from `C` by substituting `B` for `A`. */
  type Out <: Coproduct

  /** Fixes `f` as the second argument of `apply`, yielding a plain `C => Out`. */
  def lift(f: A => B): C => Out = apply(_, f)
}
/**
 * Lower-priority `Replace` instances.
 *
 * `inr` lives in a parent trait of the `Replace` companion so that, when both
 * `inl` and `inr` are eligible (the replaced type occurs more than once in the
 * coproduct), implicit resolution prefers `inl` instead of reporting an
 * ambiguity. Implicits defined in a derived object outrank those inherited
 * from its parents.
 */
trait LowPriorityReplace {
  /**
   * Recursive case: keep the head `H` and perform the replacement in the tail `T`.
   *
   * @param replace the instance performing the replacement inside `T`
   */
  implicit def inr[H, T <: Coproduct, A, B](
    implicit replace: Replace[T, A, B]
  ): Replace.Aux[H :+: T, A, B, H :+: replace.Out] =
    new Replace[H :+: T, A, B] {
      type Out = H :+: replace.Out
      def apply(c: H :+: T, f: A => B): Out = c match {
        // The head is untouched; it is only re-typed into the new coproduct.
        case Inl(h) => Inl(h)
        // The value lives in the tail, so delegate the replacement to it.
        case Inr(t) => Inr(replace(t, f))
      }
    }
}

/** Instances of [[Replace]], found through the companion's implicit scope. */
object Replace extends LowPriorityReplace {
  /** `Replace` with its output coproduct exposed as the type parameter `D`. */
  type Aux[C <: Coproduct, A, B, D <: Coproduct] = Replace[C, A, B] { type Out = D }

  /**
   * Base case: the head of the coproduct is the replaced type `H1`.
   *
   * A head value is mapped through `f`. A tail value is passed through
   * unchanged, re-wrapped under the new head type `H2`. Only this first
   * occurrence of `H1` is replaced; later occurrences in `T` are left as-is.
   */
  implicit def inl[H1, T <: Coproduct, H2]: Aux[H1 :+: T, H1, H2, H2 :+: T] =
    new Replace[H1 :+: T, H1, H2] {
      type Out = H2 :+: T
      def apply(c: H1 :+: T, f: H1 => H2): Out = c match {
        case Inl(h) => Inl(f(h))
        case Inr(t) => Inr(t)
      }
    }
}
/**
 * Value-class wrapper that adds a `replace` method to a coproduct value.
 *
 * @param c the wrapped coproduct value
 */
case class CoproductExt[C <: Coproduct](c: C) extends AnyVal {
  /**
   * Applies the implicit `Replace[C, A, B]` instance to `c` and `f`, producing
   * a value of the coproduct in which `A` has been replaced by `B`.
   */
  def replace[A, B](f: A => B)(implicit replace: Replace[C, A, B]): replace.Out =
    replace.apply(c, f)
}
/** Syntax import: `import CoproductExt._` enables `coproduct replace f`. */
object CoproductExt {
  import scala.language.implicitConversions

  /**
   * Wraps any coproduct value in [[CoproductExt]].
   *
   * The result type is stated explicitly. Scala 3 requires explicit types on
   * implicit definitions, and an explicit type keeps the conversion's
   * signature stable.
   */
  implicit def coproductExt[C <: Coproduct](c: C): CoproductExt[C] = CoproductExt(c)
}
/**
 * Value-class wrapper that lifts a plain function `A => B` into a function
 * between coproducts.
 *
 * @param f the function to lift
 */
case class Function1Ext[A, B](f: A => B) extends AnyVal {
  /**
   * Lifts `f` to `C => D`. Both coproducts are fixed by the caller, for
   * example through an expected type such as `val g: ISB => IDB = f.lift`.
   */
  def lift[C <: Coproduct, D <: Coproduct](implicit replace: Replace.Aux[C, A, B, D]): C => D =
    replace lift f

  /**
   * Lifts `f` to a function on `C` (given explicitly, e.g. `f.liftTo[ISB]`).
   * The target coproduct is computed by the instance as `replace.Out`.
   */
  def liftTo[C <: Coproduct](implicit replace: Replace[C, A, B]): C => replace.Out =
    replace lift f
}
/** Syntax import: `import Function1Ext._` enables `f.lift` / `f.liftTo[C]`. */
object Function1Ext {
  import scala.language.implicitConversions

  /**
   * Wraps any `A => B` in [[Function1Ext]].
   *
   * The result type is stated explicitly. Scala 3 requires explicit types on
   * implicit definitions, and an explicit type keeps the conversion's
   * signature stable.
   */
  implicit def function1Ext[A, B](f: A => B): Function1Ext[A, B] = Function1Ext(f)
}
/**
 * Demo: turn the `String` alternative of `ISB` into `Double`, in three ways.
 *
 * shapeless coproduct values are the case classes `Inl`/`Inr`, so each
 * `println` shows the full injection, not just the payload.
 */
object Test extends App {
  import CoproductExt._
  import Function1Ext._

  type ISB = Int :+: String :+: Boolean :+: CNil
  type IDB = Int :+: Double :+: Boolean :+: CNil

  // A value inhabiting the String alternative (second position) of ISB.
  val isb = Coproduct[ISB]("foo")
  val f: String => Double = _ => 0.42

  // 1. Map a value directly.
  val idb = isb replace f
  println(idb) // Inr(Inl(0.42))

  // 2. Lift f with both coproducts fixed by the expected type.
  val g: ISB => IDB = f.lift
  println(g(isb)) // Inr(Inl(0.42))

  // 3. Lift f by naming only the source coproduct; the target is computed.
  val h = f.liftTo[ISB]
  println(h(isb)) // Inr(Inl(0.42))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment