Skip to content

Instantly share code, notes, and snippets.

@HandcartCactus
Last active January 21, 2023 18:29
Show Gist options
  • Save HandcartCactus/327dd5a61b77ddb2a5d1bde93555670b to your computer and use it in GitHub Desktop.
Save HandcartCactus/327dd5a61b77ddb2a5d1bde93555670b to your computer and use it in GitHub Desktop.
Computes Edit Distance on Sequences and Allows For Custom Definitions Of "Equality" (Used For Experimental Process Mining)
from typing import Callable, Union, Sequence, Protocol, List, Tuple, Any
from functools import partial
from collections import namedtuple
from dataclasses import dataclass
import numpy as np
class EqualityFunction(Protocol):
    """Structural type for a custom equality callback.

    The callback is handed one element taken from each of the two sequences
    being compared and returns True when those two elements should be treated
    as the same.
    """

    # The arguments are individual elements, which need not be sequences
    # themselves, hence ``Any``. The names ``s1``/``s2`` are kept so keyword
    # calls keep working.
    def __call__(self, s1: Any, s2: Any) -> bool:
        ...
@dataclass
class EditDistanceResult:
    """What calling the class on two sequences returns."""

    # The two sequences that were compared, as passed in.
    s1: Sequence
    s2: Sequence
    # Final edit distance (the bottom-right entry of ``edit_matrix``).
    edit_distance: int
    # Dynamic-programming cost matrix.
    edit_matrix: np.ndarray
    # For each cell of ``edit_matrix``, the (row, col) of the cell its value
    # was derived from; -1 marks "no source".
    source_matrix: np.ndarray
    # (s1 index, s2 index) pairs along the chosen alignment, start to end.
    source_trail: List[Tuple[int, int]]
    # One entry per trail step: the element itself when both sides matched,
    # otherwise a two-item list [s1 element or None, s2 element or None].
    joined_sources: List[Union[List, Any]]
class CustomEditDistance:
    """Edit (Levenshtein) distance between two sequences with a pluggable equality test.

    Insertions, deletions and replacements each cost 1. Two elements count as
    the same when ``equality_fn`` says so, or when ``a == b`` if no function
    was supplied. Calling an instance returns an ``EditDistanceResult`` with
    the distance plus the matrices and alignment that produced it.

    Index convention shared by every artefact: cell ``(i, j)`` describes the
    prefixes ``s1[:i + 1]`` and ``s2[:j + 1]``. Index ``-1`` stands for the
    empty prefix, so ``(0, -1)`` means "``s1[0]`` consumed, nothing of ``s2``
    consumed yet".
    """

    # Lightweight container for the two matrices produced by the DP pass.
    _Matrices = namedtuple('Results', ('edit_matrix', 'source_matrix'))

    def __init__(self, equality_fn: "Union[EqualityFunction, None]" = None):
        """
        Params:
        -------
        equality_fn: Defaults to basic equality (``==``). Can be any callable
            that takes one element from each sequence and returns a bool.
        """
        self.equality_fn = equality_fn

    def _is_equal(self, a, b):
        """Compare two elements with the custom function, or ``==`` if none was given."""
        if self.equality_fn is None:
            return a == b
        return self.equality_fn(a, b)

    def _initial_edit_matrix(self, len_s1: int, len_s2: int):
        """Return a (len_s1, len_s2) int32 matrix of zeros whose first column and first row count up from 0."""
        m = np.zeros((len_s1, len_s2), dtype=np.int32)
        m[:, 0] = np.arange(len_s1)
        m[0, :] = np.arange(len_s2)
        return m

    def _initial_source_matrix(self, len_s1: int, len_s2: int):
        """Return a (len_s1, len_s2, 2) int32 matrix filled with -1 ("no source")."""
        m = np.ones((len_s1, len_s2, 2), dtype=np.int32) * -1
        return m

    def _calculate_source(self, s1_idx: int, s2_idx: int, prev_s1_s2: int, prev_s1: int, prev_s2: int):
        """Tracks sources of values, so we can see what index match the algorithm decided was optimal.

        Returns the ``[s1 index, s2 index]`` of the neighbouring cell that
        supplied the cheapest candidate cost. Ties prefer the diagonal, then
        the cell with the smaller s1 index, then the one with the smaller s2
        index.
        """
        if prev_s1_s2 <= prev_s1 and prev_s1_s2 <= prev_s2:
            return [s1_idx - 1, s2_idx - 1]
        if prev_s1 <= prev_s1_s2 and prev_s1 <= prev_s2:
            return [s1_idx - 1, s2_idx]
        if prev_s2 <= prev_s1_s2 and prev_s2 <= prev_s1:
            return [s1_idx, s2_idx - 1]
        return [-1, -1]

    def _process_sequences(self, s1: Sequence, s2: Sequence):
        """Run the dynamic-programming pass and return ``(edit_matrix, source_matrix)``.

        ``edit_matrix[i, j]`` is the edit distance between ``s1[:i + 1]`` and
        ``s2[:j + 1]``; ``source_matrix[i, j]`` is the cell that value came
        from (an index of -1 refers to the empty prefix). Both have one
        row per element of ``s1`` and one column per element of ``s2``.
        """
        len_s1, len_s2 = len(s1), len(s2)
        # The DP needs an extra leading row and column for the empty prefix,
        # otherwise the first element of each sequence is never compared.
        # padded[i + 1, j + 1] holds the cost for element indices (i, j).
        padded = self._initial_edit_matrix(len_s1=len_s1 + 1, len_s2=len_s2 + 1)
        source_matrix = self._initial_source_matrix(len_s1=len_s1, len_s2=len_s2)
        for s1_idx in range(len_s1):
            for s2_idx in range(len_s2):
                is_equal = self._is_equal(s1[s1_idx], s2[s2_idx])
                # Possible changes the DP EditDistance Algo tracks
                # Matched indices. Combines "no change" and "element was replaced".
                prev_s1_s2 = padded[s1_idx, s2_idx] + (0 if is_equal else 1)
                # s1 element has no counterpart in s2.
                prev_s1 = padded[s1_idx, s2_idx + 1] + 1
                # s2 element has no counterpart in s1.
                prev_s2 = padded[s1_idx + 1, s2_idx] + 1
                padded[s1_idx + 1, s2_idx + 1] = min(prev_s1_s2, prev_s1, prev_s2)
                source_matrix[s1_idx, s2_idx] = self._calculate_source(
                    s1_idx, s2_idx, prev_s1_s2, prev_s1, prev_s2
                )
        # Drop the empty-prefix border so the public matrix stays (len_s1, len_s2).
        return self._Matrices(padded[1:, 1:], source_matrix)

    def _source_trail_from_matrix(self, source_matrix: np.ndarray):
        """What element indices from each sequence were matched up together. Compare/reconstruct what the DP algorithm 'saw'.

        Walks back from the last cell to the empty prefix ``(-1, -1)`` and
        returns the visited ``(s1 index, s2 index)`` pairs in forward order,
        without the ``(-1, -1)`` terminator.
        """
        len_s1, len_s2, _ = source_matrix.shape
        i, j = len_s1 - 1, len_s2 - 1
        reversed_trail = []
        while (i, j) != (-1, -1):
            reversed_trail.append((i, j))
            if i == -1:
                # Nothing of s1 consumed: the only way back is along s2.
                j -= 1
            elif j == -1:
                # Nothing of s2 consumed: the only way back is along s1.
                i -= 1
            else:
                i, j = (int(v) for v in source_matrix[i, j])
        return reversed_trail[::-1]

    def _source_comparison(self, s1: Sequence, s2: Sequence, source_trail: List[Tuple[int, int]]):
        """Grab the elements of the sequences that were matched up together. Insert None if an element was inserted in the other sequence.

        Each trail step yields either the s1 element alone (both sides
        advanced and the elements are equal) or a two-item list
        ``[s1 element or None, s2 element or None]``.
        """
        # -1 means "no element consumed yet", matching the trail's convention.
        s1_idx_prev, s2_idx_prev = -1, -1
        joined_sources = []
        for s1_idx, s2_idx in source_trail:
            new_link = [None, None]
            # An index that did not move means that side contributed nothing here.
            if s1_idx != s1_idx_prev:
                new_link[0] = s1[s1_idx]
            if s2_idx != s2_idx_prev:
                new_link[1] = s2[s2_idx]
            if all(el is not None for el in new_link) and self._is_equal(new_link[0], new_link[1]):
                new_link = new_link[0]
            joined_sources.append(new_link)
            s1_idx_prev, s2_idx_prev = s1_idx, s2_idx
        return joined_sources

    def __call__(self, s1: Sequence, s2: Sequence) -> "EditDistanceResult":
        """Return an EditDistanceResult with the edit_distance and relevant artefacts.

        Empty sequences are accepted: the distance to an empty sequence is the
        length of the other one.
        """
        edit_matrix, source_matrix = self._process_sequences(s1, s2)
        source_trail = self._source_trail_from_matrix(source_matrix)
        joined_sources = self._source_comparison(s1=s1, s2=s2, source_trail=source_trail)
        if edit_matrix.size:
            edit_distance = int(edit_matrix[-1, -1])
        else:
            # At least one sequence is empty, so every element of the other is an edit.
            edit_distance = len(s1) + len(s2)
        return EditDistanceResult(
            s1=s1, s2=s2, edit_distance=edit_distance, edit_matrix=edit_matrix,
            source_matrix=source_matrix, source_trail=source_trail, joined_sources=joined_sources
        )
def edit_distance_less_than(s1: str, s2: str, thresh: int) -> bool:
    """Return True when the edit distance between ``s1`` and ``s2`` is strictly below ``thresh``.

    The distance is computed by ``CustomEditDistance`` with its default
    (``==``) element comparison.
    """
    distance = CustomEditDistance()(s1, s2).edit_distance
    # bool() so callers get a plain Python bool even if the distance is a NumPy scalar.
    return bool(distance < thresh)
class ApproxStringSeqEditDistance(CustomEditDistance):
    """Edit distance over sequences of strings where near-identical strings count as equal.

    Two strings are treated as equal when their own edit distance is strictly
    below ``thresh``. With ``thresh <= 0`` plain ``==`` is used instead.
    """

    def __init__(self, thresh: int = 0):
        """
        Params:
        -------
        thresh: Strings fewer than this many edits apart are considered equal.
        """
        equality_fn = partial(edit_distance_less_than, thresh=thresh) if thresh > 0 else None
        # Go through the base initializer rather than setting the attribute by hand.
        super().__init__(equality_fn=equality_fn)

    def __call__(self, s1: Sequence[str], s2: Sequence[str]):
        """Compare two sequences of strings; see ``CustomEditDistance.__call__``."""
        return super().__call__(s1, s2)
def test_approx_str_edit_distance():
    """Check that a one-character typo is forgiven while a different word still counts.

    Returns True when the distance is exactly 1. The result is returned, not
    asserted, so the caller has to check it.
    """
    orig = ['hello', 'it', 'is', 'nice', 'to', 'meet', 'you']
    one_typo_one_diff = ['hwllo', 'it', 'is', 'nice', 'to', 'meet', 'them']
    result = ApproxStringSeqEditDistance(3)(orig, one_typo_one_diff)
    return result.edit_distance == 1
def joined_source_threshholded_equality(s1: Union[List[str], str], s2: str, thresh: int) -> bool:
    """Return True when ``s2`` is within ``thresh`` edits of ``s1`` or of any string inside ``s1``.

    ``s1`` may be a single string or a list of strings; ``None`` entries in
    the list are skipped. "Within" means ``edit_distance_less_than`` returns
    True, i.e. the distance is strictly below ``thresh``.
    """
    if isinstance(s1, list):
        return any(
            edit_distance_less_than(s1=candidate, s2=s2, thresh=thresh)
            for candidate in s1
            if candidate is not None
        )
    return bool(edit_distance_less_than(s1=s1, s2=s2, thresh=thresh))
class JoinedSourceMerger(CustomEditDistance):
    """Align a list of already-joined sources with a new list of strings.

    Elements of ``s1`` may be strings or lists of strings. An element matches
    a string of ``s2`` when ``joined_source_threshholded_equality`` says so for
    the given ``thresh``. With ``thresh <= 0`` plain ``==`` is used instead.
    Calling an instance returns only the ``joined_sources`` of the comparison.
    """

    def __init__(self, thresh: int = 0):
        """
        Params:
        -------
        thresh: Strings fewer than this many edits apart are considered equal.
        """
        equality_fn = partial(joined_source_threshholded_equality, thresh=thresh) if thresh > 0 else None
        # Go through the base initializer rather than setting the attribute by hand.
        super().__init__(equality_fn=equality_fn)

    def __call__(self, s1: List[Union[List[str], str]], s2: List[str]):
        """Run the comparison and return its ``joined_sources`` list."""
        return super().__call__(s1, s2).joined_sources
@HandcartCactus
Copy link
Author

from custom_edit_dist import CustomEditDistance

def equal_if_casefold(s1, s2) -> bool:
    """Return True when the two strings are identical after case folding."""
    folded_first = s1.casefold()
    folded_second = s2.casefold()
    return folded_first == folded_second

# Case-insensitive comparison: the letters all match, leaving the trailing "." as the single edit.
uncased_edit_distance = CustomEditDistance(equality_fn=equal_if_casefold)
assert uncased_edit_distance(s1="HELLO", s2="hello.").edit_distance == 1
from custom_edit_dist import ApproxStringSeqEditDistance

# A reference sentence and a variant with one typo ('hwllo') and one different word ('them').
orig = [
    'hello', 'it', 'is', 'nice', 'to', 'meet', 'you',
]
one_typo_one_diff = [
    'hwllo', 'it', 'is', 'nice', 'to', 'meet', 'them',
]

# Words fewer than 3 edits apart are treated as equal, so the typo is forgiven
# and the expected distance is 1.
edit_distance_strseq_typo = ApproxStringSeqEditDistance(thresh=3)
assert edit_distance_strseq_typo(s1=orig, s2=one_typo_one_diff).edit_distance == 1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment