Skip to content

Instantly share code, notes, and snippets.

@ammojamo
Created March 24, 2021 12:23
Show Gist options
  • Save ammojamo/7ed456d26c5128ee5e7c3e0b3a15c329 to your computer and use it in GitHub Desktop.
Save ammojamo/7ed456d26c5128ee5e7c3e0b3a15c329 to your computer and use it in GitHub Desktop.
Diff algorithms in Kotlin
import java.util.*
import kotlin.collections.ArrayList
import kotlin.collections.HashMap
/**
 * Common base for the list-diff strategies in this file.
 *
 * Every strategy turns a pair of lists into a flat list of [Op]s. In the implementations
 * that follow, [Op.Delete.index] is a position in the first list (`a`) and
 * [Op.Insert.index] is a position in the second list (`b`).
 */
abstract class Diff {

    /**
     * Computes the edit operations that turn [a] into [b].
     */
    abstract fun <T> diff(a: List<T>, b: List<T>): List<Op<T>>

    /** One edit operation in a diff result. */
    sealed class Op<T> {
        /** Insert [value]; [index] locates it in the target list. */
        data class Insert<T>(
            val index: Int,
            val value: T,
        ) : Op<T>()

        /** Remove the element at [index] of the source list. */
        data class Delete<T>(
            val index: Int,
        ) : Op<T>()
    }
}
/**
 * Myers' O(ND) diff using linear-space bisection: a forward search from the start and a
 * reverse search from the end run in lock-step until they overlap on a shared diagonal,
 * and the two sub-ranges on either side of that overlap are diffed recursively.
 *
 * Emitted ops use absolute indices: `Delete(i)` refers to `a[i]`, `Insert(j, v)` to `b[j]`.
 */
object MyerDiff : Diff() {
    override fun <T> diff(a: List<T>, b: List<T>): List<Op<T>> = diff(a, b, 0, 0, a.size, b.size)

    /**
     * Diffs the window `a[a0 until a0 + n]` against `b[b0 until b0 + m]`.
     *
     * Public so other strategies (e.g. a patience diff) can delegate a sub-range to it.
     *
     * @throws IllegalStateException if no overlap is found (should be unreachable).
     */
    fun <T> diff(a: List<T>, b: List<T>, a0: Int, b0: Int, n: Int, m: Int): List<Op<T>> {
        if (n == 0 && m == 0) {
            return emptyList()
        } else if (m == 0 && n > 0) {
            return (a0 until a0 + n).map { i -> Op.Delete<T>(i) }
        } else if (n == 0 && m > 0) {
            return b.slice(b0 until b0 + m).mapIndexed { i, x -> Op.Insert(b0 + i, x) }
        }

        // Offset that lets diagonals k in [-max, max] be stored at index max + k.
        val max = maxOf(n, m) + 2
        // Diagonal k means x - y == k. The end point (n, m) lies on forward diagonal delta.
        val delta = n - m
        val deltaEven = delta % 2 == 0
        var ko = 0 // Forward diagonal on which the two searches overlap
        var dd = -1 // Length of the edit path found (-1 = not found yet)
        // vf[max + k]: furthest x reached on forward diagonal k.
        // vr[max + k]: same for the reverse search, measured from the END of both windows.
        val vf = IntArray(max * 2 + 1)
        val vr = IntArray(max * 2 + 1)

        search@ for (d in 0..max) {
            for (dir in 0..1) {
                val forward = dir == 0
                val v = if (forward) vf else vr
                val vo = if (forward) vr else vf
                for (k in -d..d step 2) {
                    // Step down from diagonal k+1 or right from diagonal k-1, whichever got further.
                    val x0 = if (k == -d || (k != d && v[max + k - 1] < v[max + k + 1])) {
                        v[max + k + 1]
                    } else {
                        v[max + k - 1] + 1
                    }
                    var x = x0
                    var y = x0 - k
                    // Follow the "snake" of equal elements along the diagonal.
                    if (forward) {
                        while (x < n && y < m && a[a0 + x] == b[b0 + y]) {
                            x++
                            y++
                        }
                    } else {
                        while (x < n && y < m && a[a0 + n - x - 1] == b[b0 + m - y - 1]) {
                            x++
                            y++
                        }
                    }
                    v[max + k] = x

                    // Reverse diagonal r corresponds to forward diagonal delta - r (and vice versa).
                    // Only compare against diagonals the other search has actually explored:
                    // reverse has finished round d - 1 when forward runs, forward has finished
                    // round d when reverse runs. Unexplored slots are still 0 and would
                    // report a spurious overlap.
                    if (forward) {
                        if (!deltaEven && k >= delta - (d - 1) && k <= delta + (d - 1) &&
                            vo[max - k + delta] >= n - x0
                        ) {
                            ko = k
                            dd = d * 2 - 1
                            break@search
                        }
                    } else {
                        if (deltaEven && k >= delta - d && k <= delta + d &&
                            vo[max - k + delta] >= n - x0
                        ) {
                            ko = -k + delta
                            dd = d * 2
                            break@search
                        }
                    }
                }
            }
        }

        // The searches overlap on forward diagonal ko: the forward search reached (u, v) and
        // the reverse search reached back to (x, y). The path between them is a snake of length
        // dd, so the windows are split there and each side diffed recursively.
        val u = vf[max + ko]
        val v = u - ko
        val x = n - vr[max - ko + delta]
        val y = x - ko
        return when {
            dd == 0 -> emptyList() // Identical windows
            dd > 1 -> diff(a, b, a0, b0, x, y) + diff(a, b, a0 + u, b0 + v, n - u, m - v)
            // Exactly one edit: an insert if b is longer, otherwise a delete.
            dd == 1 -> if (m > n) {
                listOf(Op.Insert(b0 + y - 1, b[b0 + y - 1]))
            } else {
                listOf(Op.Delete(a0 + x - 1))
            }
            else -> error("failed assertion: dd < 0")
        }
    }
}
/**
 * Baseline strategy that never looks for common elements. It deletes every element of `a`
 * (ascending), then inserts every element of `b` (ascending).
 */
object ReplaceDiff : Diff() {
    override fun <T> diff(a: List<T>, b: List<T>): List<Op<T>> = diff(a, b, 0, 0, a.size, b.size)

    /** Deletes `a[a0 until a0 + n]`, then inserts `b[b0 until b0 + m]`. */
    private fun <T> diff(a: List<T>, b: List<T>, a0: Int, b0: Int, n: Int, m: Int): List<Op<T>> {
        val deletions: List<Op<T>> = (0 until n).map { offset -> Op.Delete<T>(a0 + offset) }
        val insertions: List<Op<T>> = (0 until m).map { offset ->
            val at = b0 + offset
            Op.Insert(at, b[at])
        }
        return deletions + insertions
    }
}
/**
 * Patience diff.
 *
 * The algorithm works as follows:
 * 1. Strip the common prefix and suffix.
 * 2. Anchor on elements that occur exactly once in each remaining window.
 * 3. Keep the longest run of anchors that appear in the same order on both sides, found by
 *    patience sorting.
 * 4. Recursively diff the gaps between anchors.
 *
 * When a window has no unique common elements, it is handed to [MyerDiff].
 *
 * Elements are used as `HashMap` keys, so `T` needs `equals`/`hashCode` consistent with `==`.
 * Emitted ops use absolute indices: `Delete(i)` refers to `a[i]`, `Insert(j, v)` to `b[j]`.
 */
object PatienceDiff : Diff() {
    override fun <T> diff(a: List<T>, b: List<T>): List<Op<T>> = diff(a, b, 0, 0, a.size, b.size)

    /** Diffs the window `a[a0 until a0 + n]` against `b[b0 until b0 + m]`. */
    private fun <T> diff(a: List<T>, b: List<T>, a0: Int, b0: Int, n: Int, m: Int): List<Op<T>> {
        if (n == 0 && m == 0) {
            return emptyList()
        } else if (m == 0 && n > 0) {
            return (a0 until a0 + n).map { i -> Op.Delete<T>(i) }
        } else if (n == 0 && m > 0) {
            return b.slice(b0 until b0 + m).mapIndexed { i, x -> Op.Insert(b0 + i, x) }
        }
        val aEnd = a0 + n // exclusive
        val bEnd = b0 + m // exclusive

        // Common prefix: matched elements need no ops, so just diff what follows.
        // Bounds are absolute (aEnd/bEnd), since a0/b0 may be non-zero in recursive calls.
        run {
            var x = a0
            var y = b0
            while (x < aEnd && y < bEnd && a[x] == b[y]) {
                x++
                y++
            }
            if (x > a0) {
                return diff(a, b, x, y, aEnd - x, bEnd - y)
            }
        }
        // Common suffix
        run {
            var x = aEnd - 1
            var y = bEnd - 1
            while (x >= a0 && y >= b0 && a[x] == b[y]) {
                x--
                y--
            }
            if (x < aEnd - 1) {
                return diff(a, b, a0, b0, x - a0 + 1, y - b0 + 1)
            }
        }

        // Per distinct value: its position in a (first occurrence), its position in b
        // (last occurrence), and how often it occurs in each window.
        class Item(val x: Int) {
            var y = 0
            var countA = 1
            var countB = 0
            var prev: Item? = null // Anchor preceding this one in the increasing run
        }
        val items = HashMap<T, Item>()
        for (x in a0 until aEnd) {
            val existing = items[a[x]]
            if (existing == null) {
                items[a[x]] = Item(x)
            } else {
                existing.countA++
            }
        }
        for (y in b0 until bEnd) {
            val existing = items[b[y]]
            if (existing != null) {
                existing.y = y
                existing.countB++
            }
        }
        val uniques = items.values.filter { it.countA == 1 && it.countB == 1 }.sortedBy { it.y }
        if (uniques.isEmpty()) {
            // No unique common elements: fall back to Myers diff for this window.
            return MyerDiff.diff(a, b, a0, b0, n, m)
        }

        // Patience sort: uniques arrive in b-order. Keep only each pile's top; the tops' x values
        // are strictly increasing, so the leftmost pile whose top has a larger x can be
        // binary-searched.
        val tops = ArrayList<Item>()
        for (item in uniques) {
            var lo = 0
            var hi = tops.size
            while (lo < hi) {
                val mid = (lo + hi) ushr 1
                if (tops[mid].x > item.x) hi = mid else lo = mid + 1
            }
            if (lo > 0) {
                item.prev = tops[lo - 1]
            }
            if (lo == tops.size) tops.add(item) else tops[lo] = item
        }
        // Walk back from the top of the last pile to get the longest common run of anchors.
        val lcs = generateSequence(tops.last()) { it.prev }.toList().asReversed()

        // Diff the gaps between consecutive anchors, starting at this window's origin.
        val result = ArrayList<Op<T>>()
        var x = a0
        var y = b0
        for (item in lcs) {
            result.addAll(diff(a, b, x, y, item.x - x, item.y - y))
            x = item.x + 1
            y = item.y + 1
        }
        result.addAll(diff(a, b, x, y, aEnd - x, bEnd - y))
        return result
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment