Skip to content

Instantly share code, notes, and snippets.

@Pitometsu
Forked from dacr/gist:4642782
Created August 21, 2018 21:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Pitometsu/eef5ae6583339f658a9e9ef38eb933e8 to your computer and use it in GitHub Desktop.
Save Pitometsu/eef5ae6583339f658a9e9ef38eb933e8 to your computer and use it in GitHub Desktop.
Custom Scala collection examples
import scala.collection._
import scala.collection.mutable.{ArrayBuffer,ListBuffer, Builder}
import scala.collection.generic._
import scala.collection.immutable.VectorBuilder
// ================================ CustomTraversable ==================================
/** Companion/factory for [[CustomTraversable]].
  *
  * Supplies the element builder and an implicit CanBuildFrom whose result type is
  * CustomTraversable, so builder-based operations can produce a CustomTraversable.
  */
object CustomTraversable extends TraversableFactory[CustomTraversable] {
  /** Implicit builder factory; `Coll` is the factory's `CustomTraversable[_]` alias. */
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, CustomTraversable[A]] =
    new GenericCanBuildFrom[A]

  /** Accumulates elements in a ListBuffer, then passes them to the varargs constructor. */
  def newBuilder[A]: Builder[A, CustomTraversable[A]] = {
    val pending = new ListBuffer[A]
    pending.mapResult(collected => new CustomTraversable[A](collected: _*))
  }
}
/** A minimal Traversable whose elements are the constructor's varargs.
  *
  * `foreach` is the only primitive implemented here; every other operation comes
  * from the Traversable / TraversableLike templates. `equals` is not overridden in
  * this class.
  * NOTE(review): Traversable itself defines no structural equality, so instances
  * presumably compare by reference — confirm if value equality is wanted.
  */
class CustomTraversable[A](seq : A*)
  extends Traversable[A]
  with GenericTraversableTemplate[A, CustomTraversable]
  with TraversableLike[A, CustomTraversable[A]] {

  override def companion = CustomTraversable

  /** Applies `f` to every element, in constructor order. */
  override def foreach[U](f: A => U): Unit =
    for (element <- seq) f(element)
}
// ================================ CustomSeq ==================================
/** Companion/factory for [[CustomSeq]].
  *
  * Provides the element builder and the implicit CanBuildFrom whose result type is
  * CustomSeq, so builder-based operations can produce a CustomSeq.
  */
object CustomSeq extends SeqFactory[CustomSeq] {
  /** Implicit builder factory; `Coll` is the factory's `CustomSeq[_]` alias. */
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, CustomSeq[A]] =
    new GenericCanBuildFrom[A]

  /** Collects elements in a ListBuffer and wraps the finished list as varargs. */
  def newBuilder[A]: Builder[A, CustomSeq[A]] = {
    val gathered = new ListBuffer[A]
    gathered.mapResult(items => new CustomSeq[A](items: _*))
  }
}
/** A Seq backed by the constructor's varargs.
  *
  * `iterator`, `apply` and `length` delegate to that varargs sequence. When a
  * CustomSeq is produced by the companion's builder, the varargs wrap the
  * ListBuffer's result. NOTE(review): that is presumably a List, which would make
  * `apply` and `length` linear-time.
  */
class CustomSeq[A](seq : A*)
  extends Seq[A]
  with GenericTraversableTemplate[A, CustomSeq]
  with SeqLike[A, CustomSeq[A]] {

  override def companion = CustomSeq

  def iterator: Iterator[A] = seq.iterator

  /** Element at `idx`; throws IndexOutOfBoundsException when `idx` is outside [0, length). */
  def apply(idx: Int): A =
    if (0 <= idx && idx < length) seq(idx)
    else throw new IndexOutOfBoundsException

  def length: Int = seq.size
}
// ================================ MySeq ==================================
/** Companion/factory for [[MySeq]]: construction helpers, the element builder and the
  * implicit CanBuildFrom whose result type is MySeq.
  */
object MySeq {
  /** Builds a MySeq containing `bases` in the given order. */
  def apply[Base](bases: Base*): MySeq[Base] = fromSeq(bases)

  /** Copies `buf` into a freshly allocated ArrayBuffer and wraps it in a MySeq.
    *
    * The copy is done with a single traversal (`++=`). Indexed access (`buf(i)`)
    * would be quadratic on linear sequences such as List.
    */
  def fromSeq[Base](buf: Seq[Base]): MySeq[Base] = {
    val array = new ArrayBuffer[Base](buf.size)
    array ++= buf
    new MySeq[Base](array)
  }

  /** Builder that accumulates into an ArrayBuffer and finishes via [[fromSeq]]. */
  def newBuilder[Base]: Builder[Base, MySeq[Base]] =
    new ArrayBuffer[Base]().mapResult(collected => fromSeq(collected))

  /** CanBuildFrom producing a MySeq regardless of the source collection.
    *
    * The `From` type parameter is unused. It is kept so that existing explicit type
    * applications (e.g. `canBuildFrom[Int, Any]`) keep compiling.
    */
  implicit def canBuildFrom[Base,From]: CanBuildFrom[MySeq[_], Base, MySeq[Base]] =
    new CanBuildFrom[MySeq[_], Base, MySeq[Base]] {
      def apply(): Builder[Base, MySeq[Base]] = newBuilder
      def apply(from: MySeq[_]): Builder[Base, MySeq[Base]] = newBuilder
    }
}
/** An IndexedSeq backed by an ArrayBuffer.
  *
  * The constructor is protected; instances are created through the companion's
  * `apply` / `fromSeq`. `newBuilder` is overridden so template operations that rely
  * on it (such as `filter`) produce a MySeq again.
  */
class MySeq[Base] protected (buffer: ArrayBuffer[Base])
  extends IndexedSeq[Base]
  with IndexedSeqLike[Base, MySeq[Base]] {

  /** Same-type builder used by the IndexedSeqLike template. */
  override protected[this] def newBuilder: Builder[Base, MySeq[Base]] =
    MySeq.newBuilder[Base]

  /** Element at `idx`; throws IndexOutOfBoundsException when `idx` is outside [0, length). */
  def apply(idx: Int): Base =
    if (idx >= 0 && idx < length) buffer(idx)
    else throw new IndexOutOfBoundsException

  def length: Int = buffer.length
}
// ================================ NamedSeq ==================================
/** Companion/factory for [[NamedSeq]].
  *
  * Every builder here is tied to a name, which the resulting NamedSeq carries.
  */
object NamedSeq {
  /** Builds a NamedSeq called `name` containing `bases` in the given order. */
  def apply[Base](name: String, bases: Base*): NamedSeq[Base] = fromSeq(name, bases)

  /** Copies `buf` into a freshly allocated ArrayBuffer and wraps it as a NamedSeq
    * called `name`.
    *
    * The copy is done with a single traversal (`++=`). Indexed access (`buf(i)`)
    * would be quadratic on linear sequences such as List.
    */
  def fromSeq[Base](name: String, buf: Seq[Base]): NamedSeq[Base] = {
    val array = new ArrayBuffer[Base](buf.size)
    array ++= buf
    new NamedSeq[Base](name, array)
  }

  /** Builder that accumulates into an ArrayBuffer and yields a NamedSeq called `name`. */
  def newBuilder[Base](name: String): Builder[Base, NamedSeq[Base]] =
    new ArrayBuffer[Base]().mapResult(collected => fromSeq(name, collected))

  /** CanBuildFrom for NamedSeq.
    *
    * When the source collection is supplied (`apply(from)`), its name is reused.
    * Otherwise the result is named "default".
    */
  implicit def canBuildFrom[Base]: CanBuildFrom[NamedSeq[_], Base, NamedSeq[Base]] =
    new CanBuildFrom[NamedSeq[_], Base, NamedSeq[Base]] {
      def apply(): Builder[Base, NamedSeq[Base]] = newBuilder("default")
      def apply(from: NamedSeq[_]): Builder[Base, NamedSeq[Base]] =
        newBuilder(from.name)
    }
}
/** An IndexedSeq backed by an ArrayBuffer that also carries a `name`.
  *
  * `newBuilder` passes the name along, so template operations that rely on it
  * (such as `filter`) keep the name. This class does not override `equals`;
  * equality is inherited from the Seq hierarchy.
  * NOTE(review): that inherited equality is presumably element-wise and ignores
  * `name` — confirm this is intended.
  */
class NamedSeq[Base] protected (
    val name: String,
    buffer: ArrayBuffer[Base])
  extends IndexedSeq[Base] with IndexedSeqLike[Base, NamedSeq[Base]] {

  /** Same-type builder that preserves this sequence's name. */
  override protected[this] def newBuilder: Builder[Base, NamedSeq[Base]] =
    NamedSeq.newBuilder[Base](name)

  /** Element at `idx`; throws IndexOutOfBoundsException when `idx` is outside [0, length). */
  def apply(idx: Int): Base =
    if (idx >= 0 && idx < length) buffer(idx)
    else throw new IndexOutOfBoundsException

  def length: Int = buffer.length

  /** Renders as `NamedSeq(<name> : <e1>, <e2>, ...)`. */
  override def toString(): String = s"NamedSeq($name : ${mkString(", ")})"
}
// ============================= CustomVector ===================================
/** Companion/factory for [[CustomVector]]: construction helpers, a VectorBuilder-based
  * element builder and the implicit CanBuildFrom whose result type is CustomVector.
  */
object CustomVector {
  /** Builds a CustomVector containing `bases` in the given order. */
  def apply[Base](bases: Base*): CustomVector[Base] = {
    val asVector = bases.toVector
    fromSeq(asVector)
  }

  /** Wraps an existing Vector without copying it. */
  def fromSeq[Base](buf: Vector[Base]): CustomVector[Base] =
    new CustomVector[Base](buf)

  /** Builder that accumulates into a VectorBuilder and wraps the resulting Vector. */
  def newBuilder[Base]: Builder[Base, CustomVector[Base]] =
    new VectorBuilder[Base]().mapResult(built => fromSeq(built))

  /** CanBuildFrom producing a CustomVector regardless of the source collection.
    * The `From` type parameter is unused.
    */
  implicit def canBuildFrom[Base, From]:
      CanBuildFrom[CustomVector[_], Base, CustomVector[Base]] =
    new CanBuildFrom[CustomVector[_], Base, CustomVector[Base]] {
      def apply(): Builder[Base, CustomVector[Base]] = newBuilder[Base]
      def apply(from: CustomVector[_]): Builder[Base, CustomVector[Base]] = apply()
    }
}
/** An IndexedSeq backed by an immutable Vector.
  *
  * The constructor is protected; instances are created through the companion.
  * `newBuilder` is overridden so template operations that rely on it (such as
  * `filter`) produce a CustomVector again.
  */
class CustomVector[Base] protected (buffer: Vector[Base])
  extends IndexedSeq[Base]
  with IndexedSeqLike[Base, CustomVector[Base]] {

  /** Same-type builder used by the IndexedSeqLike template. */
  override protected[this] def newBuilder: Builder[Base, CustomVector[Base]] =
    CustomVector.newBuilder[Base]

  /** Element at `idx`; throws IndexOutOfBoundsException when `idx` is outside [0, length). */
  def apply(idx: Int): Base =
    if (idx >= 0 && idx < length) buffer(idx)
    else throw new IndexOutOfBoundsException

  def length: Int = buffer.length
}
// NOW THE TEST CASES :
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("CustomTraversable test") {
  val l = CustomTraversable(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // CustomTraversable does not override equals, so contents are compared via toList
  // and the concrete collection type via the runtime class name.
  l should not equal (List(1, 2, 3, 4))
  l.toList should equal (List(1, 2, 3, 4))
  (l ++ c).toList should equal (List(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString).toList should equal (List("1", "2", "3", "4"))
  l.map(_.toString).getClass.getName should include ("CustomTraversable")
  l.filter(_ > 2).toList should equal (List(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomTraversable")
  l.reduce(_ + _) should equal (10)
}
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("CustomSeq test") {
  val l = CustomSeq(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // Seq equality is element-wise, so a CustomSeq equals a List with the same
  // elements; what distinguishes it is its runtime type.
  l should equal (List(1, 2, 3, 4))
  l should not be a [List[_]]
  (l :+ 8) should equal (CustomSeq(1, 2, 3, 4, 8))
  (l ++ c) should equal (CustomSeq(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString) should equal (CustomSeq("1", "2", "3", "4"))
  l.map(_.toString) should not be an [IndexedSeq[_]]
  l.map(_.toString).getClass.getName should include ("CustomSeq")
  l.filter(_ > 2) should equal (CustomSeq(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomSeq")
  l.reduce(_ + _) should equal (10)
}
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("CustomVector test") {
  val l = CustomVector(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // Seq equality is element-wise, so a CustomVector equals any Seq with the same
  // elements; what distinguishes it is its runtime type.
  l should equal (List(1, 2, 3, 4))
  l should not be a [List[_]]
  (l :+ 8) should equal (CustomVector(1, 2, 3, 4, 8))
  (l ++ c) should equal (CustomVector(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString) should equal (CustomVector("1", "2", "3", "4"))
  l.map(_.toString) should not be a [Vector[_]]
  l.map(_.toString).getClass.getName should include ("CustomVector")
  l.filter(_ > 2) should equal (CustomVector(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomVector")
  l.reduce(_ + _) should equal (10)
}
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("MySeq test") {
  val cs = MySeq("1", "2", "3")
  info(cs.toString)
  cs should equal (MySeq("1", "2", "3"))
  val scs = cs.map(_.toInt)
  info(scs.toString)
  scs should equal (MySeq(1, 2, 3))
  scs.getClass.getName should include ("MySeq")
}
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("NamedSeq test") {
  val cs = NamedSeq("toto", "1", "2", "3")
  info(cs.toString)
  cs should equal (NamedSeq("toto", "1", "2", "3"))
  cs.name should equal ("toto")
  val scs = cs.map(_.toInt)
  info(scs.toString)
  scs should equal (NamedSeq("toto", 1, 2, 3))
  // NamedSeq equality is element-wise (inherited from Seq), so the name is
  // checked directly rather than through equality.
  scs.name should equal ("toto")
  scs.name should not equal ("tata")
  scs.getClass.getName should include ("NamedSeq")
}
// Assertions use `should equal`. The form `x should be equals (y)` parses as
// `(x should be).equals(y)`, a discarded Boolean, and so asserts nothing.
test("NamedSeq && MySeq combined test") {
  val cs = MySeq(5, 6, 7, 8)
  val scs = cs.filter(_ > 6)
  val ncs = NamedSeq("myseq", 1, 2, 3, 4)
  val nscs = ncs.filter(_ > 2)
  (nscs :+ 10) should equal (NamedSeq("myseq", 3, 4, 10))
  (nscs ++ scs) should equal (NamedSeq("myseq", 3, 4, 7, 8))
  // Equality ignores the name, so name propagation is asserted explicitly.
  (nscs ++ scs).name should equal ("myseq")
  (nscs ++ scs).name should not equal ("trucmuche")
  (nscs ++ scs).getClass.getName should include ("NamedSeq")
  (scs ++ nscs) should equal (MySeq(7, 8, 3, 4))
  (scs ++ nscs).getClass.getName should include ("MySeq")
  nscs.map(_ + 1) should equal (NamedSeq("myseq", 4, 5))
  nscs.map(_.toString) should equal (NamedSeq("myseq", "3", "4"))
  scs.map(_.toString) should equal (MySeq("7", "8"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment