Skip to content

Instantly share code, notes, and snippets.

@steveash
Last active May 2, 2022 15:16
Show Gist options
  • Save steveash/5426191 to your computer and use it in GitHub Desktop.
Save steveash/5426191 to your computer and use it in GitHub Desktop.
This is an implementation of Optimal String Alignment in Java with some tricks and optimizations. OSA is similar to Damerau–Levenshtein edit distance in that insertions, deletions, substitutions, and transpositions of adjacent characters are all treated as one edit operation. This OSA implementation also takes a "max edit" threshold, which allows you to sk…
package com.github.steveash.util;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.primitives.Shorts.checkedCast;
import static java.lang.Math.abs;
import static java.lang.Math.max;
import java.util.Arrays;
import com.google.common.annotations.VisibleForTesting;
/**
 * Implementation of the OSA which is similar to the Damerau-Levenshtein in that it allows for transpositions to
 * count as a single edit distance, but is not a true metric and can over-estimate the cost because it disallows
 * substrings to be edited more than once. See wikipedia for more discussion on OSA vs DL
 * <p/>
 * See Algorithms on Strings, Trees and Sequences by Dan Gusfield for more information.
 * <p/>
 * Only the diagonal stripe of the dynamic-programming table that lies within {@code threshold} cells of the main
 * diagonal is evaluated, and only three rows of the table are kept at any time: the current row, the row above it
 * and the row two above it (the last one is needed to price transpositions).
 * <p/>
 * This also has a set of thread-local buffers to avoid allocating new rows on each call for short inputs, which
 * might be a premature optimization
 * <p/>
 * @author Steve Ash
 */
public class OptimalStringAlignment {
    /** Capacity of the per-thread row buffers; inputs that need longer rows get freshly allocated buffers. */
    private static final int THREAD_LOCAL_BUFFER_SIZE = 64;

    private static final ThreadLocal<short[]> costLocal = new ThreadLocal<short[]>() {
        @Override
        protected short[] initialValue() {
            return new short[THREAD_LOCAL_BUFFER_SIZE];
        }
    };

    private static final ThreadLocal<short[]> back1Local = new ThreadLocal<short[]>() {
        @Override
        protected short[] initialValue() {
            return new short[THREAD_LOCAL_BUFFER_SIZE];
        }
    };

    private static final ThreadLocal<short[]> back2Local = new ThreadLocal<short[]>() {
        @Override
        protected short[] initialValue() {
            return new short[THREAD_LOCAL_BUFFER_SIZE];
        }
    };

    /**
     * Computes the optimal string alignment distance between two strings, giving up as soon as the distance is
     * known to be larger than {@code threshold}.
     *
     * @param s         the first string; must be shorter than {@link Short#MAX_VALUE} characters
     * @param t         the second string; must be shorter than {@link Short#MAX_VALUE} characters
     * @param threshold the largest distance of interest; must not be negative. A threshold at or above the length
     *                  of the longer string never yields {@code -1}, because the distance cannot exceed that length
     * @return the distance between {@code s} and {@code t}, or {@code -1} if it is greater than {@code threshold}
     * @throws NullPointerException     if {@code s} or {@code t} is null
     * @throws IllegalArgumentException if {@code threshold} is negative or either string is too long
     */
    public static int editDistance(CharSequence s, CharSequence t, int threshold) {
        checkNotNull(s, "cannot measure null strings");
        checkNotNull(t, "cannot measure null strings");
        checkArgument(threshold >= 0, "Threshold must not be negative");
        checkArgument(s.length() < Short.MAX_VALUE, "Cannot take edit distance of strings longer than 32k chars");
        checkArgument(t.length() < Short.MAX_VALUE, "Cannot take edit distance of strings longer than 32k chars");
        // The distance never exceeds the longer length, so clamping the threshold to it cannot change the answer.
        // It also keeps huge thresholds (e.g. Integer.MAX_VALUE meaning "no limit") from being rejected by the
        // narrowing to short; both lengths were checked above, so the clamped value always fits.
        short boundedThreshold = checkedCast(Math.min(threshold, max(s.length(), t.length())));
        if (s.length() + 1 > THREAD_LOCAL_BUFFER_SIZE || t.length() + 1 > THREAD_LOCAL_BUFFER_SIZE) {
            return editDistanceWithNewBuffers(s, t, boundedThreshold);
        }
        short[] cost = costLocal.get();
        short[] back1 = back1Local.get();
        short[] back2 = back2Local.get();
        return editDistanceWithBuffers(s, t, boundedThreshold, back2, back1, cost);
    }

    /**
     * Same computation as {@link #editDistance} but always allocates fresh row buffers. It performs no argument
     * validation and no threshold clamping; callers must pass non-null strings shorter than
     * {@link Short#MAX_VALUE} characters.
     *
     * @param threshold the largest distance of interest
     * @return the distance, or {@code -1} if it is greater than {@code threshold}
     */
    @VisibleForTesting
    static int editDistanceWithNewBuffers(CharSequence s, CharSequence t, short threshold) {
        // editDistanceWithBuffers indexes rows by the shorter string (it swaps the inputs if needed),
        // so rows only need room for that string plus the leading "empty prefix" cell
        int rowLength = Math.min(s.length(), t.length()) + 1;
        short[] back1 = new short[rowLength]; // "up 1" row in table
        short[] back2 = new short[rowLength]; // "up 2" row in table
        short[] cost = new short[rowLength]; // "current cost"
        return editDistanceWithBuffers(s, t, threshold, back2, back1, cost);
    }

    /**
     * Runs the banded OSA dynamic program using the supplied row buffers, which must each hold at least
     * {@code min(s.length(), t.length()) + 1} entries. The buffers are overwritten.
     *
     * @return the distance, or {@code -1} if it is greater than {@code threshold}
     */
    private static int editDistanceWithBuffers(CharSequence s, CharSequence t, short threshold,
                                               short[] back2, short[] back1, short[] cost) {
        short slen = (short) s.length();
        short tlen = (short) t.length();
        // if one string is empty, the edit distance is necessarily the length of the other
        if (slen == 0) {
            return tlen <= threshold ? tlen : -1;
        } else if (tlen == 0) {
            return slen <= threshold ? slen : -1;
        }
        // every insertion/deletion changes the length by one, so a length gap wider than k cannot be bridged
        if (abs(slen - tlen) > threshold) {
            return -1;
        }
        if (slen > tlen) {
            // swap the two strings so rows are indexed by the shorter one, consuming less memory
            CharSequence tmp = s;
            s = t;
            t = tmp;
            slen = tlen;
            tlen = (short) t.length();
        }
        initMemoiseTables(threshold, back2, back1, cost, slen);
        for (short j = 1; j <= tlen; j++) {
            cost[0] = j; // j is the cost of inserting this many characters
            // stripe bounds, computed as ints: j + threshold can exceed Short.MAX_VALUE, and narrowing it to short
            // would wrap negative, produce an empty stripe and silently skip the whole row
            int min = max(1, j - threshold);
            int max = Math.min(slen, j + threshold);
            // cost[] still holds a row from three iterations ago; the cell just left of the stripe is read as the
            // left neighbour of the first stripe cell, so mark it "unreachable"
            if (min > 1) {
                cost[min - 1] = Short.MAX_VALUE;
            }
            iterateOverStripe(s, t, j, cost, back1, back2, min, max);
            // rotate rows: current -> "up 1", "up 1" -> "up 2", and recycle the old "up 2" array for the next row
            short[] tempCost = back2;
            back2 = back1;
            back1 = cost;
            cost = tempCost;
        }
        // after the final rotation the last computed row is in back1;
        // if back1[slen] > k then we exceeded, so return -1
        if (back1[slen] > threshold) {
            return -1;
        }
        return back1[slen];
    }

    /**
     * Fills {@code cost[min..max]} (row {@code j}) from the previous row {@code back1} and the row before it
     * {@code back2}. Substitution, insertion and deletion each cost one; swapping two adjacent characters also
     * costs one and is priced from {@code back2}.
     */
    private static void iterateOverStripe(CharSequence s, CharSequence t, short j,
                                          short[] cost, short[] back1, short[] back2, int min, int max) {
        for (int i = min; i <= max; i++) {
            if (s.charAt(i - 1) == t.charAt(j - 1)) {
                cost[i] = back1[i - 1];
            } else {
                // back1[i - 1] always lies inside the previous row's stripe (or is its column-0 cell), so the
                // minimum is a real cost rather than the MAX_VALUE sentinel and adding one cannot wrap
                cost[i] = (short) (1 + min(cost[i - 1], back1[i], back1[i - 1]));
            }
            if (i >= 2 && j >= 2) {
                // possible transposition to check for
                if ((s.charAt(i - 2) == t.charAt(j - 1)) &&
                        s.charAt(i - 1) == t.charAt(j - 2)) {
                    cost[i] = min(cost[i], (short) (back2[i - 2] + 1));
                }
            }
        }
    }

    /**
     * Prepares the three row buffers for a computation whose rows span {@code [0, slen]}. Row 0 of the table
     * (cell {@code i} holds {@code i}) is written into both {@code back1} and {@code back2} up to the edge of the
     * stripe. Every other cell in {@code [0, slen]} is set to {@link Short#MAX_VALUE}, a sentinel that never wins a
     * minimum, so cells outside the stripe are never mistaken for real costs.
     */
    private static void initMemoiseTables(short threshold, short[] back2, short[] back1, short[] cost, short slen) {
        // initial "starting" values for inserting all the letters
        short boundary = (short) (min(slen, threshold) + 1);
        for (short i = 0; i < boundary; i++) {
            back1[i] = i;
            back2[i] = i;
        }
        // need to make sure that we don't read a default value when looking "up"
        Arrays.fill(back1, boundary, slen + 1, Short.MAX_VALUE);
        Arrays.fill(back2, boundary, slen + 1, Short.MAX_VALUE);
        Arrays.fill(cost, 0, slen + 1, Short.MAX_VALUE);
    }

    /** Returns the smaller of two shorts without widening the result to int. */
    private static short min(short a, short b) {
        return (a <= b ? a : b);
    }

    /** Returns the smallest of three shorts. */
    private static short min(short a, short b, short c) {
        return min(a, min(b, c));
    }
}
import org.junit.Test
import static com.github.steveash.util.OptimalStringAlignment.editDistance
/**
 * Tests for {@code OptimalStringAlignment.editDistance}.
 *
 * @author Steve Ash
 */
class OptimalStringAlignmentTest {
    @Test
    public void shouldBeZeroForEqualStrings() throws Exception {
        assert 0 == editDistance("steve", "steve", 1)
        assert 0 == editDistance("steve", "steve", 0)
        assert 0 == editDistance("steve", "steve", 2)
        assert 0 == editDistance("steve", "steve", 100)
        assert 0 == editDistance("s", "s", 1)
        assert 0 == editDistance("s", "s", 0)
        assert 0 == editDistance("s", "s", 2)
        assert 0 == editDistance("s", "s", 100)
        assert 0 == editDistance("", "", 0)
        assert 0 == editDistance("", "", 1)
        assert 0 == editDistance("", "", 100)
    }

    @Test
    public void shouldBeOneForSingleOperation() throws Exception {
        def a = "steve"
        for (int i = 0; i < a.length(); i++) {
            assertOneOp(new StringBuilder(a).insert(i, 'f'), a)
            assertOneOp(new StringBuilder(a).deleteCharAt(i), a)
            def sb = new StringBuilder(a)
            sb.setCharAt(i, 'x' as char)
            assertOneOp(sb, a)
            // swap characters i - 1 and i; starting at i == 1 also covers the leading pair.
            // No two adjacent characters of "steve" are equal, so every swap yields a different string.
            if (i >= 1) {
                sb = new StringBuilder(a)
                char t = sb.charAt(i - 1)
                sb.setCharAt(i - 1, sb.charAt(i))
                sb.setCharAt(i, t)
                assertOneOp(sb, a)
            }
        }
        // insertion after the last character, which the loop above never reaches
        assertOneOp(new StringBuilder(a).append('f'), a)
    }

    @Test
    public void shouldCountTransposeAsOne() throws Exception {
        assert 3 == editDistance("xxsteve", "steev", 4)
        assert 3 == editDistance("xxsteve", "steev", 3)
        assert 3 == editDistance("steev", "xxsteve", 4)
        assert 3 == editDistance("steev", "xxsteve", 3)
        assert -1 == editDistance("steev", "xxsteve", 2)
        assert 4 == editDistance("xxtseve", "steev", 4)
        assert 5 == editDistance("xxtsevezx", "steevxz", 5)
        assert 6 == editDistance("xxtsevezx", "steevxzpp", 6)
        assert 7 == editDistance("xxtsfevezx", "steevxzpp", 7)
        assert 4 == editDistance("xxtsf", "st", 7)
        assert 4 == editDistance("evezx", "eevxzpp", 7)
    }

    @Test
    public void shouldCountLeadingCharacterTranspositionsAsOne() throws Exception {
        assert 1 == editDistance("rosa", "orsa", 2)
    }

    /** Asserts that {@code a} and {@code b} are exactly one edit apart, in both argument orders. */
    private void assertOneOp(CharSequence a, CharSequence b) {
        assert 1 == editDistance(a, b, 1)
        assert 1 == editDistance(b, a, 1)
        assert 1 == editDistance(a, b, 2)
        assert 1 == editDistance(b, a, 2)
    }

    @Test
    public void shouldShortCutWhenSpecialCase() throws Exception {
        assert 1 == editDistance("s", "", 1)
        assert 1 == editDistance("", "s", 1)
        assert -1 == editDistance("s", "", 0)
        assert -1 == editDistance("", "s", 0)
        assert -1 == editDistance("st", "", 1)
        assert -1 == editDistance("", "st", 1)
        assert -1 == editDistance("steve", "ste", 0)
        assert -1 == editDistance("ste", "steve", 0)
        assert -1 == editDistance("stev", "steve", 0)
        assert -1 == editDistance("ste", "steve", 1)
        assert -1 == editDistance("steve", "ste", 1)
        assert 1 == editDistance("steve", "stev", 1)
        assert 1 == editDistance("stev", "steve", 1)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment