Last active December 23, 2019 17:01
MultiSet data structure in Scala with Cats (Foldable and Monad instances)
package com.niftysoft.gennit.util
import cats._
import cats.implicits._
import scala.annotation.tailrec
/** An immutable multiset (a "bag"): like a `Set`, but it remembers how many times each
  * element occurs.
  *
  * Every element is stored together with its multiplicity in `data`. All operations in
  * this class maintain the invariant that stored multiplicities are strictly positive, so
  * `contains`, `iterator`, `toString`, `equals` and `hashCode` never observe "phantom"
  * entries with a count of zero or less.
  *
  * @param data element -> number of occurrences
  * @tparam V element type
  */
case class MultiSet[V] private (data: Map[V, Int]) {

  /** Keeps only the elements satisfying `f`; kept elements retain their multiplicity. */
  def filter(f: V => Boolean): MultiSet[V] =
    new MultiSet(data.filter { case (v, _) => f(v) })

  /** Number of occurrences of `elem` (0 when absent). */
  def multiplicity(elem: V): Int = data.getOrElse(elem, 0)

  /** True when `elem` occurs at least once. */
  def contains(elem: V): Boolean = multiplicity(elem) > 0

  /** Scales every multiplicity by `factor`.
    *
    * A non-positive `factor` yields the empty multiset rather than storing zero or
    * negative counts.
    */
  def mult(factor: Int): MultiSet[V] =
    if (factor <= 0) new MultiSet(Map.empty[V, Int])
    else new MultiSet(data.map { case (v, n) => (v, n * factor) })

  /** Adds `num` occurrences of `elem`.
    *
    * A negative `num` removes occurrences; when the resulting count is not positive the
    * element is dropped entirely.
    */
  def addMany(elem: V, num: Int): MultiSet[V] = withCount(elem, multiplicity(elem) + num)

  /** Removes a single occurrence of `elem`; returns `this` unchanged when it is absent. */
  def excl(elem: V): MultiSet[V] =
    if (contains(elem)) withCount(elem, multiplicity(elem) - 1) else this

  /** Removes every occurrence of `elem`. */
  def exclAll(elem: V): MultiSet[V] = new MultiSet(data - elem)

  /** Adds a single occurrence of `elem`. */
  def incl(elem: V): MultiSet[V] = addMany(elem, 1)

  /** Multiset difference: each count is reduced by the count in `other`, and elements whose
    * count drops to zero or below are removed. `Set`/`Seq` arguments are counted first.
    */
  def diff(other: Set[V]): MultiSet[V] = diff(countElems(other))
  def diff(other: Seq[V]): MultiSet[V] = diff(countElems(other))
  def diff(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data.iterator
        .map { case (v, n) => (v, n - other.multiplicity(v)) }
        .filter { case (_, n) => n > 0 }
        .toMap)

  /** Additive union: the counts of both operands are added. */
  def sum(other: Set[V]): MultiSet[V] = sum(countElems(other))
  def sum(other: Seq[V]): MultiSet[V] = sum(countElems(other))
  def sum(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(other.data.foldLeft(data) { case (acc, (v, n)) =>
      acc.updated(v, acc.getOrElse(v, 0) + n)
    })

  /** Union: each element keeps the larger of its two counts. */
  def union(other: Seq[V]): MultiSet[V] = union(countElems(other))
  def union(other: Set[V]): MultiSet[V] = union(countElems(other))
  def union(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(other.data.foldLeft(data) { case (acc, (v, n)) =>
      acc.updated(v, math.max(acc.getOrElse(v, 0), n))
    })

  /** Intersection: each element keeps the smaller of its two counts; elements absent from
    * either side are dropped.
    */
  def intersect(other: Seq[V]): MultiSet[V] = intersect(countElems(other))
  def intersect(other: Set[V]): MultiSet[V] = intersect(countElems(other))
  def intersect(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data.iterator
        .map { case (v, n) => (v, math.min(n, other.multiplicity(v))) }
        .filter { case (_, n) => n > 0 }
        .toMap)

  /** All elements, each repeated according to its multiplicity, in `data`'s iteration order. */
  def toList: List[V] = iterator.toList

  /** Iterates every occurrence; an element with multiplicity `n` is produced `n` times.
    * `Iterator.fill` with a non-positive count yields nothing, so stray entries are skipped.
    */
  def iterator: Iterator[V] =
    data.iterator.flatMap { case (v, n) => Iterator.fill(n)(v) }

  // Equality and hashing look only at the positive entries, so a stray zero/negative entry
  // (e.g. injected through the synthesized `copy`) cannot make equal multisets differ.
  override def equals(o: Any): Boolean = o match {
    case that: MultiSet[_] => positiveCounts == that.positiveCounts
    case _                 => false
  }

  override def hashCode(): Int = positiveCounts.hashCode()

  override def toString(): String = iterator.mkString("MultiSet(", ", ", ")")

  /** Rebuilds the multiset with `elem` set to `count`, dropping it when `count <= 0`. */
  private[this] def withCount(elem: V, count: Int): MultiSet[V] =
    if (count > 0) new MultiSet(data.updated(elem, count)) else new MultiSet(data - elem)

  /** Counts the occurrences of each element of `xs`. */
  private[this] def countElems(xs: Iterable[V]): MultiSet[V] =
    new MultiSet(xs.foldLeft(Map.empty[V, Int])((acc, v) => acc.updated(v, acc.getOrElse(v, 0) + 1)))

  /** `data` restricted to strictly positive multiplicities. */
  private def positiveCounts: Map[V, Int] = data.filter { case (_, n) => n > 0 }
}
/** Factories and Cats type-class instances for [[MultiSet]]. */
object MultiSet {

  /** The empty multiset. */
  def apply[A](): MultiSet[A] = new MultiSet(Map.empty[A, Int])

  /** A multiset containing each argument once per occurrence, e.g. `MultiSet(1, 1, 2)`. */
  def apply[A](x: A*): MultiSet[A] =
    new MultiSet(x.foldLeft(Map.empty[A, Int])((acc, a) => addCount(acc, a, 1)))

  /** A multiset containing every element of `x` exactly once. */
  def apply[A](x: Set[A]): MultiSet[A] = new MultiSet(x.iterator.map(a => (a, 1)).toMap)

  /** Adds `n` to the count of `key`. Non-positive `n` is ignored so that no zero or
    * negative multiplicity is ever stored.
    */
  private def addCount[K](acc: Map[K, Int], key: K, n: Int): Map[K, Int] =
    if (n > 0) acc.updated(key, acc.getOrElse(key, 0) + n) else acc

  implicit val functorForMultiset: Functor[MultiSet] = new Functor[MultiSet] {
    // Counts of elements that `f` sends to the same value are summed, not overwritten.
    def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] =
      new MultiSet(fa.data.foldLeft(Map.empty[B, Int]) { case (acc, (a, n)) =>
        addCount(acc, f(a), n)
      })
  }

  implicit val foldableForMultiset: Foldable[MultiSet] = new Foldable[MultiSet] {
    // Folds visit every occurrence, i.e. an element with multiplicity n is seen n times.
    def foldLeft[A, B](fa: MultiSet[A], b: B)(f: (B, A) => B): B =
      fa.iterator.foldLeft(b)(f)

    def foldRight[A, B](fa: MultiSet[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
      Foldable[List].foldRight(fa.toList, lb)(f)
  }

  implicit val monadForMultiset: Monad[MultiSet] = new Monad[MultiSet] {
    def pure[A](x: A): MultiSet[A] = new MultiSet(Map(x -> 1))

    // Bind of the multiset monad: every occurrence of `a` contributes a full copy of
    // `f(a)`, and contributions for the same result are added together.
    def flatMap[A, B](fa: MultiSet[A])(f: A => MultiSet[B]): MultiSet[B] =
      new MultiSet(fa.data.foldLeft(Map.empty[B, Int]) { case (acc, (a, n)) =>
        f(a).data.foldLeft(acc) { case (inner, (b, m)) => addCount(inner, b, m * n) }
      })

    // Stack-safe: pending (value, count) pairs live on an explicit list, and results are
    // accumulated immutably instead of through a discarded `addMany` on a var.
    def tailRecM[A, B](a: A)(f: A => MultiSet[Either[A, B]]): MultiSet[B] = {
      @tailrec
      def go(pending: List[(Either[A, B], Int)], acc: Map[B, Int]): Map[B, Int] =
        pending match {
          case Nil                               => acc
          case (_, count) :: rest if count <= 0  => go(rest, acc)
          case (Right(b), count) :: rest         => go(rest, addCount(acc, b, count))
          case (Left(next), count) :: rest       =>
            // A Left occurring `count` times expands into `count` copies of f(next).
            val expanded = f(next).data.toList.map { case (e, m) => (e, m * count) }
            go(expanded ::: rest, acc)
        }
      new MultiSet(go(f(a).data.toList, Map.empty[B, Int]))
    }
  }
}
package com.niftysoft.gennit.util
import org.scalatest._
import cats.implicits._
/** Behavioural tests for [[MultiSet]]. Empty multisets are created as `MultiSet[Int]()`:
  * a bare `MultiSet()` infers `MultiSet[Nothing]`, on which `addMany(0, 4)` cannot typecheck.
  */
class MultiSetSpec extends FlatSpec with Matchers {

  "MultiSet" should "work on empty set" in {
    val x: MultiSet[Int] = MultiSet[Int]()
    x.contains(1) shouldEqual false
    x.iterator.toList shouldEqual List.empty
  }

  it should "sum arguments to apply correctly" in {
    val x: MultiSet[Int] = MultiSet(1, 1)
    x.multiplicity(1) shouldEqual 2
    x.iterator.toList shouldEqual List(1, 1)
  }

  it should "remove things when excl is called" in {
    val x: MultiSet[Int] = MultiSet[Int]().addMany(1, 2)
    x.excl(1).iterator.toList shouldEqual List(1)
  }

  it should "admit multiple elements" in {
    val x: MultiSet[Int] = MultiSet[Int]().addMany(2, 6)
    x.multiplicity(2) shouldEqual 6
    x.contains(2) shouldEqual true
    x.iterator.toList shouldEqual List(2, 2, 2, 2, 2, 2)
  }

  it should "implement difference" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    x.diff(y).iterator.toList.sorted shouldEqual List(0, 0)
    y.diff(x).iterator.toList.sorted shouldEqual List(1, 1)
  }

  it should "implement sums" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    x.sum(y).iterator.toList.sorted shouldEqual List(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)
  }

  it should "admit unions with sets" in {
    val x = MultiSet[Int]().addMany(0, 4)
    val y = Set(1, 2, 3)
    x.union(y).iterator.toList.sorted shouldEqual List(0, 0, 0, 0, 1, 2, 3)
  }

  it should "implement unions" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    x.union(y).iterator.toList.sorted shouldEqual List(0, 0, 0, 0, 1, 1, 1, 1, 1)
  }

  it should "implement intersections" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    x.intersect(y).iterator.toList.sorted shouldEqual List(0, 0, 1, 1, 1)
  }

  it should "be creatable from sets" in {
    val x = MultiSet(Set(1, 2, 3))
    x.iterator.toList.sorted shouldEqual List(1, 2, 3)
  }

  it should "admit non-deterministic monadic computations" in {
    val doubled = MultiSet(Set(1, 2, 3)).flatMap(n => MultiSet[Int]().addMany(n, 2))
    doubled.iterator.toList.sorted shouldEqual List(1, 1, 2, 2, 3, 3)
    val branched = MultiSet(Set(1, 3, 5)).flatMap(n => MultiSet(n, n + 1))
    branched.iterator.toList.sorted shouldEqual List(1, 2, 3, 4, 5, 6)
  }

  it should "multiply existing elements when asked" in {
    val x = MultiSet(Set('a', 'b', 'c')).mult(3)
    x.iterator.toList.sorted shouldEqual List('a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c')
  }

  it should "allow unary multplication" in {
    val x = MultiSet[Int]().addMany(1, 3)
    x.flatMap(n => MultiSet(n, n)).iterator.toList.sorted shouldEqual List(1, 1, 1, 1, 1, 1)
  }
}
kalexmills commented Dec 15, 2019

Note: this is not a high-performance data structure, but it should be useful on slow paths where MultiSet semantics are needed.

The explicit Functor instance may be faster than the `map` that Monad derives automatically from `flatMap` and `pure`, so I'm keeping it. I haven't seen any problems with implicit resolution yet, though an ambiguity between the two instances could arise in theory.

