Skip to content

Instantly share code, notes, and snippets.

@jchapuis
Created May 25, 2021 12:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jchapuis/8d6f60476ebc2b6813ab28daa9216d98 to your computer and use it in GitHub Desktop.
Save jchapuis/8d6f60476ebc2b6813ab28daa9216d98 to your computer and use it in GitHub Desktop.
Generic LCS diffing in Scala (dynamic programming with memoization)
import cats.{ Eq, Show }
import cats.syntax.eq._
import cats.syntax.show._
import cats.instances.int._
import FunctionHelpers._
/** Longest-common-subsequence (LCS) based diffing for arbitrary sequences.
  *
  * Importing `Differ._` adds `diffWith` to any `Seq[T]` for which a cats `Eq[T]` is available.
  */
object Differ {

  /** One step of the edit script that turns a baseline sequence into a revision. */
  sealed trait Diff[T]

  object Diff {

    /** `revision` only exists in the revised sequence. */
    final case class Insert[T](revision: T) extends Diff[T]

    /** `baseline` only exists in the baseline sequence. */
    final case class Delete[T](baseline: T) extends Diff[T]

    /** `baseline` and `revision` compared equal under `Eq[T]`; both sides are retained. */
    final case class Keep[T](baseline: T, revision: T) extends Diff[T]

    /** Renders an insertion as `+x`, a deletion as `-x` and a kept element as `_`. */
    implicit def show[T](implicit elementShow: Show[T]): Show[Diff[T]] =
      Show.show {
        case Diff.Insert(revision) => show"+$revision"
        case Diff.Delete(baseline) => show"-$baseline"
        case Diff.Keep(_, _)       => show"_"
      }
  }

  implicit class RichSeq[T](baseline: Seq[T]) {

    /** Computes an edit script from this sequence (the baseline) to `revision`.
      *
      * @param revision the sequence to compare against
      * @param eq       equality used to decide whether two elements match
      * @return the steps in baseline/revision order; where deleting and inserting are equally
      *         good, deletions are emitted before insertions
      */
    def diffWith(revision: Seq[T])(implicit eq: Eq[T]): Seq[Diff[T]] = {
      // Both phases address elements by index, so materialise the inputs once: this keeps
      // lookups cheap even when the caller hands in a linear sequence such as List.
      val indexedBaseline = baseline.toIndexedSeq
      val indexedRevision = revision.toIndexedSeq
      val matrix          = computeMatrix(indexedBaseline, indexedRevision)
      generateDiff(matrix, indexedBaseline, indexedRevision)
    }

    /** `matrix(i, j)` is the LCS length of the first `i` baseline and first `j` revision elements. */
    private type LCSMatrix = (Int, Int) => Int

    /** Builds the LCS length table as a lazily evaluated, memoized function.
      *
      * The baseline is on the vertical axis (first index), the revision on the horizontal one
      * (second index). Cells are only computed when first requested. Evaluation recurses
      * through the memoized function and is not tail-recursive, so the stack depth can grow
      * with the combined length of both inputs.
      */
    private def computeMatrix(baseline: IndexedSeq[T], revision: IndexedSeq[T])(implicit
        eq: Eq[T]
    ): LCSMatrix = {
      // lazy: lengthFor refers back to the memoized wrapper that is built from lengthFor itself
      lazy val lengthMatrix: LCSMatrix = (lengthFor _).memoize

      def lengthFor(baselineIndex: Int, revisionIndex: Int): Int =
        if (baselineIndex === 0 || revisionIndex === 0)
          0
        else if (baseline(baselineIndex - 1) === revision(revisionIndex - 1))
          lengthMatrix(baselineIndex - 1, revisionIndex - 1) + 1
        else {
          val candidateUp   = lengthMatrix(baselineIndex - 1, revisionIndex)
          val candidateLeft = lengthMatrix(baselineIndex, revisionIndex - 1)
          Math.max(candidateUp, candidateLeft)
        }

      lengthMatrix
    }

    /** Walks the matrix from the bottom-right corner back to the origin and emits the diff. */
    private def generateDiff(matrix: LCSMatrix, baseline: IndexedSeq[T], revision: IndexedSeq[T])(implicit
        eq: Eq[T]
    ): Seq[Diff[T]] = {
      // The walk runs from the end of both sequences towards the start, so prepending each
      // step yields the script in forward order directly, in constant time per step.
      @scala.annotation.tailrec
      def backtrack(row: Int, column: Int, acc: List[Diff[T]]): List[Diff[T]] =
        if (row > 0 && column > 0 && baseline(row - 1) === revision(column - 1))
          backtrack(row - 1, column - 1, Diff.Keep(baseline(row - 1), revision(column - 1)) :: acc)
        else if (column > 0 && (row <= 0 || matrix(row, column - 1) >= matrix(row - 1, column)))
          // moving left consumes a revision element; ties favour this branch
          backtrack(row, column - 1, Diff.Insert(revision(column - 1)) :: acc)
        else if (row > 0)
          // moving up consumes a baseline element
          backtrack(row - 1, column, Diff.Delete(baseline(row - 1)) :: acc)
        else acc

      backtrack(baseline.size, revision.size, Nil)
    }
  }
}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import Differ._
import cats.instances.char._
import cats.syntax.show._
/** Checks `diffWith` on character sequences by comparing the rendered edit script
  * (one `Show` rendering per step, concatenated) with a known-good string.
  */
class DifferSpec extends AnyWordSpec with Matchers {

  /** Diffs two strings character by character and concatenates the rendering of every step. */
  private def renderedDiff(from: String, to: String): String =
    from.toSeq.diffWith(to.toSeq).map(_.show).mkString

  /** Drops the whitespace that is only there to keep the fixtures readable. */
  private def symbols(spaced: String): String =
    spaced.filterNot(_.isWhitespace)

  "Differ" should {
    "produce diff for simple example" in new SimpleFixture {
      renderedDiff(baseline, revision) shouldBe diff
    }
    "produce optimal diff for dna sequence" in new AdnFixture {
      renderedDiff(baseline, revision) shouldBe diff
    }
    "reproduce wikipedia example correctly" in new WikipediaFixture {
      renderedDiff(baseline, revision) shouldBe diff
    }
  }

  /** Two elements removed from the middle of a short sequence. */
  trait SimpleFixture {
    val baseline = symbols("A B C D")
    val revision = symbols("A D")
    val diff     = "_-B-C_"
  }

  /** Nucleotide-style input mixing deletions at both ends with one insertion. */
  trait AdnFixture {
    val baseline = symbols("c t c a t g g a g c")
    val revision = symbols("t c a a t g g a")
    val diff     = "-c__+a_____-g-c"
  }

  /** Input pair labelled as the Wikipedia example by the test name. */
  trait WikipediaFixture {
    val baseline = symbols("X M J Y A U Z")
    val revision = symbols("M Z J A W X U")
    val diff     = "-X_+Z_-Y_+W+X_-Z"
  }
}
/** Extension methods for plain Scala function values. */
object FunctionHelpers {

  /** Adds `memoize` to any two-argument function. */
  @SuppressWarnings(Array("org.wartremover.warts.MutableDataStructures"))
  implicit class RichFunction2[A1, A2, R](f: (A1, A2) => R) {

    /** Wraps `f` in a function that remembers the result for every argument pair it has seen.
      *
      * Each call to `memoize` creates its own cache. The cache is a plain mutable map with no
      * locking, and entries are never removed.
      */
    def memoize: (A1, A2) => R = {
      val results = scala.collection.mutable.Map.empty[(A1, A2), R]
      (first: A1, second: A2) => results.getOrElseUpdate((first, second), f(first, second))
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment