Skip to content

Instantly share code, notes, and snippets.

@kalexmills
Last active December 23, 2019 17:01
Show Gist options
  • Save kalexmills/4e2724ffa5ba7d1f90fdf7f0de9242e0 to your computer and use it in GitHub Desktop.
Save kalexmills/4e2724ffa5ba7d1f90fdf7f0de9242e0 to your computer and use it in GitHub Desktop.
MultiSet data structure in Scala with Cats (Foldable and Monad instances)
package com.niftysoft.gennit.util
import cats._
import cats.implicits._
import scala.annotation.tailrec
/** An immutable multiset (bag): an unordered collection in which an element may occur more
  * than once.
  *
  * Invariant maintained by every operation in this class: each multiplicity stored in `data`
  * is strictly positive. An absent element is a missing key, never a zero or negative count,
  * so `contains`, `equals`, `iterator` and `toString` can read the map directly.
  *
  * @param data map from each distinct element to its number of occurrences
  * @tparam V element type
  */
case class MultiSet[V] private (data: Map[V,Int]) {

  /** Keeps the elements satisfying `f`, with their multiplicities unchanged. */
  def filter(f: V => Boolean): MultiSet[V] =
    new MultiSet(data.filter { case (v, _) => f(v) })

  /** Number of occurrences of `elem`; 0 when it is absent. */
  def multiplicity(elem: V): Int = data.getOrElse(elem, 0)

  /** True when `elem` occurs at least once. */
  def contains(elem: V): Boolean = data.contains(elem)

  /** Multiplies every multiplicity by `factor`.
    *
    * A non-positive `factor` yields the empty multiset: storing zero or negative counts would
    * break the class invariant (e.g. `contains` would still report the elements).
    */
  def mult(factor: Int): MultiSet[V] =
    if (factor <= 0) new MultiSet(Map.empty[V, Int])
    else new MultiSet(data.map { case (v, count) => (v, count * factor) })

  /** Adds `num` occurrences of `elem`.
    *
    * A negative `num` removes occurrences. When the resulting multiplicity is not positive the
    * element is dropped entirely instead of being stored with a zero/negative count.
    */
  def addMany(elem: V, num: Int): MultiSet[V] = {
    val updated = multiplicity(elem) + num
    new MultiSet(if (updated > 0) data.updated(elem, updated) else data - elem)
  }

  /** Removes a single occurrence of `elem`; a no-op when `elem` is absent. */
  def excl(elem: V): MultiSet[V] = addMany(elem, -1)

  /** Removes every occurrence of `elem`. */
  def exclAll(elem: V): MultiSet[V] = new MultiSet(data - elem)

  /** Adds a single occurrence of `elem`. */
  def incl(elem: V): MultiSet[V] = addMany(elem, 1)

  /** Difference against a set; each element of `other` counts once. */
  def diff(other: Set[V]): MultiSet[V] =
    diff(MultiSet(other))

  /** Difference against a sequence; repeated elements of `other` count repeatedly. */
  def diff(other: Seq[V]): MultiSet[V] =
    diff(MultiSet(other: _*))

  /** Multiset difference: each multiplicity is reduced by its multiplicity in `other`, and
    * elements whose count falls to zero or below are removed.
    */
  def diff(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data
        .map { case (v, mul) => (v, mul - other.multiplicity(v)) }
        .filter { case (_, mul) => mul > 0 }
    )

  /** Sum with a set; each element of `other` contributes one occurrence. */
  def sum(other: Set[V]): MultiSet[V] =
    sum(MultiSet(other))

  /** Sum with a sequence; each element of `other` contributes one occurrence per appearance. */
  def sum(other: Seq[V]): MultiSet[V] =
    sum(MultiSet(other: _*))

  /** Multiset sum: multiplicities of common elements are added. */
  def sum(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data.map { case (v, mul) => (v, other.multiplicity(v) + mul) } ++
        (other.data -- data.keySet)
    )

  /** Union with a sequence (see the `MultiSet` overload). */
  def union(other: Seq[V]): MultiSet[V] =
    union(MultiSet(other: _*))

  /** Union with a set (see the `MultiSet` overload). */
  def union(other: Set[V]): MultiSet[V] =
    union(MultiSet(other))

  /** Multiset union: each element takes the larger of its two multiplicities. */
  def union(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data.map { case (v, mul) => (v, math.max(other.multiplicity(v), mul)) } ++
        (other.data -- data.keySet)
    )

  /** Intersection with a sequence (see the `MultiSet` overload). */
  def intersect(other: Seq[V]): MultiSet[V] =
    intersect(MultiSet(other: _*))

  /** Intersection with a set (see the `MultiSet` overload); mirrors the `Set` overloads of
    * `diff`, `sum` and `union`.
    */
  def intersect(other: Set[V]): MultiSet[V] =
    intersect(MultiSet(other))

  /** Multiset intersection: each element takes the smaller of its two multiplicities and
    * disappears when that minimum is zero.
    */
  def intersect(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data
        .map { case (v, mul) => (v, math.min(other.multiplicity(v), mul)) }
        .filter { case (_, mul) => mul > 0 }
    )

  /** Every occurrence of every element, grouped by element. */
  def toList: List[V] = iterator.toList

  /** Yields each element `multiplicity` times, in the iteration order of `data`.
    *
    * `Iterator.fill` emits nothing for a non-positive count, so no element is ever yielded
    * for a zero entry; `next()` on an exhausted iterator throws `NoSuchElementException`.
    */
  def iterator: Iterator[V] =
    data.iterator.flatMap { case (v, mul) => Iterator.fill(mul)(v) }

  /** Two multisets are equal when they hold the same elements with the same multiplicities. */
  override def equals(o: Any): Boolean = o match {
    case that: MultiSet[_] => data == that.data
    case _                 => false
  }

  /** Kept consistent with `equals`, which compares only `data`. */
  override def hashCode(): Int = data.hashCode()

  /** Comma-separated listing of every occurrence, e.g. `"1, 1, 2"`; empty string when empty. */
  override def toString(): String = iterator.mkString(", ")
}
/** Constructors and cats type-class instances for [[MultiSet]]. */
object MultiSet {

  /** The empty multiset. */
  def apply[A](): MultiSet[A] = new MultiSet(Map.empty[A, Int])

  /** A multiset of the given elements; an element passed k times gets multiplicity k. */
  def apply[A](x: A*): MultiSet[A] =
    new MultiSet(x.groupBy(identity).map { case (v, occurrences) => (v, occurrences.length) })

  /** A multiset containing each element of `x` exactly once. */
  def apply[A](x: Set[A]): MultiSet[A] = new MultiSet(x.map(v => v -> 1).toMap)

  /** Maps each distinct element through `f`, adding together the multiplicities of elements
    * that `f` sends to the same result. A plain `Map.map` would keep only one of the colliding
    * entries and silently lose occurrences.
    */
  private def mapSummingCollisions[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] =
    fa.data.foldLeft(new MultiSet[B](Map.empty)) { case (acc, (a, mul)) => acc.addMany(f(a), mul) }

  implicit val functorForMultiset: Functor[MultiSet] = new Functor[MultiSet] {
    def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] = mapSummingCollisions(fa)(f)
  }

  implicit val foldableForMultiset: Foldable[MultiSet] = new Foldable[MultiSet] {
    /** Folds over every occurrence (an element with multiplicity k is visited k times). */
    def foldLeft[A, B](fa: MultiSet[A], b: B)(f: (B, A) => B): B =
      fa.iterator.foldLeft(b)(f)

    // Delegate to the List instance so the fold stays lazy in Eval (allowing short-circuiting
    // in exists/find/etc.) instead of eagerly applying `f` to every element.
    def foldRight[A, B](fa: MultiSet[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
      Foldable[List].foldRight(fa.toList, lb)(f)
  }

  implicit val monadForMultiset: Monad[MultiSet] = new Monad[MultiSet] {
    def pure[A](x: A): MultiSet[A] = new MultiSet(Map(x -> 1))

    // Same semantics as the derived flatMap(fa)(a => pure(f(a))), without the intermediate sets.
    override def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] = mapSummingCollisions(fa)(f)

    /** Bag-monad bind: each distinct element `a` with multiplicity `mul` contributes
      * `f(a)` scaled by `mul`, and all contributions are SUMMED (not unioned), so no
      * occurrence is lost when results overlap.
      */
    def flatMap[A, B](fa: MultiSet[A])(f: A => MultiSet[B]): MultiSet[B] =
      fa.data.foldLeft(new MultiSet[B](Map.empty)) { case (acc, (a, mul)) => acc.sum(f(a).mult(mul)) }

    /** Stack-safe recursion: a worklist of (Either, count) pairs is drained depth-first.
      * `Right(b)` with count `mul` adds `mul` occurrences of `b`; `Left(next)` with count `mul`
      * expands to `f(next)` scaled by `mul`, matching the multiplicity rule of `flatMap`.
      */
    def tailRecM[A, B](a: A)(f: A => MultiSet[Either[A, B]]): MultiSet[B] = {
      @tailrec
      def go(pending: List[(Either[A, B], Int)], acc: MultiSet[B]): MultiSet[B] =
        pending match {
          case Nil                        => acc
          case (Right(b), mul) :: rest    => go(rest, acc.addMany(b, mul))
          case (Left(next), mul) :: rest  => go(f(next).mult(mul).data.toList ::: rest, acc)
        }
      go(f(a).data.toList, new MultiSet[B](Map.empty))
    }
  }
}
package com.niftysoft.gennit.util
import org.scalatest._
import cats.implicits._
/** Behavioural specification for [[MultiSet]].
  *
  * Empty multisets that are immediately extended (`MultiSet[Int]().addMany(...)`) carry an
  * explicit element type: in receiver position Scala 2 infers `MultiSet()` as
  * `MultiSet[Nothing]`, and the invariant `addMany(elem: V, ...)` would then reject an `Int`.
  */
class MultiSetSpec extends FlatSpec with Matchers {
  "MultiSet" should "work on empty set" in {
    val x: MultiSet[Int] = MultiSet()
    x.contains(1) shouldEqual (false)
    x.iterator.toList shouldEqual (List.empty)
  }

  it should "sum arguments to apply correctly" in {
    val x: MultiSet[Int] = MultiSet(1, 1)
    x.multiplicity(1) shouldEqual (2)
    x.iterator.toList shouldEqual (List(1, 1))
  }

  it should "remove things when excl is called" in {
    val x: MultiSet[Int] = MultiSet[Int]().addMany(1, 2)
    x.excl(1).iterator.toList shouldEqual (List(1))
  }

  it should "admit multiple elements" in {
    val x: MultiSet[Int] = MultiSet[Int]().addMany(2, 6)
    x.multiplicity(2) shouldEqual (6)
    x.contains(2) shouldEqual (true)
    x.iterator.toList shouldEqual (List(2, 2, 2, 2, 2, 2))
  }

  it should "implement difference" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xdiffy = x.diff(y)
    val ydiffx = y.diff(x)
    xdiffy.iterator.toList.sorted shouldEqual (List(0, 0))
    ydiffx.iterator.toList.sorted shouldEqual (List(1, 1))
  }

  it should "implement sums" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xplusy = x.sum(y)
    xplusy.iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
  }

  it should "admit unions with sets" in {
    val x = MultiSet[Int]().addMany(0, 4)
    val y = Set(1, 2, 3)
    x.union(y).iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 1, 2, 3))
  }

  it should "implement unions" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xuy = x.union(y)
    xuy.iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 1, 1, 1, 1, 1))
  }

  it should "implement intersections" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xny = x.intersect(y)
    xny.iterator.toList.sorted shouldEqual (List(0, 0, 1, 1, 1))
  }

  it should "be creatable from sets" in {
    val x = MultiSet(Set(1, 2, 3))
    x.iterator.toList.sorted shouldEqual (List(1, 2, 3))
  }

  it should "admit non-deterministic monadic computations" in {
    // Two independent fixtures instead of one reassigned `var`.
    val first = MultiSet(Set(1, 2, 3))
    first.flatMap(x => MultiSet[Int]().addMany(x, 2)).iterator.toList.sorted shouldEqual (List(1, 1, 2, 2, 3, 3))
    val second = MultiSet(Set(1, 3, 5))
    second.flatMap(x => MultiSet(x, x + 1)).iterator.toList.sorted shouldEqual (List(1, 2, 3, 4, 5, 6))
  }

  it should "multiply existing elements when asked" in {
    val x = MultiSet(Set('a', 'b', 'c')).mult(3)
    x.iterator.toList.sorted shouldEqual (List('a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'))
  }

  it should "allow unary multplication" in {
    val x = MultiSet[Int]().addMany(1, 3)
    x.flatMap(x => MultiSet(x, x)).iterator.toList.sorted shouldEqual (List(1, 1, 1, 1, 1, 1))
  }
}
@kalexmills
Copy link
Author

kalexmills commented Dec 15, 2019

Note: this is not a high-performance data structure, but it should be useful on slow paths where MultiSet semantics are needed.

It's possible that the explicit Functor implementation is faster than the `map` that cats derives from the Monad instance (via `flatMap` and `pure`), so I'm leaving it in. I have yet to see problems with implicit resolution, though I suppose having both instances in scope could cause them in theory.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment