Last active
January 21, 2023 18:29
-
-
Save HandcartCactus/327dd5a61b77ddb2a5d1bde93555670b to your computer and use it in GitHub Desktop.
Computes Edit Distance on Sequences, Allows For Custom Definitions Of "Equality" (Used For Experimental Process Mining)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, Union, Sequence, Protocol, List, Tuple, Any | |
from functools import partial | |
from collections import namedtuple | |
from dataclasses import dataclass | |
import numpy as np | |
class EqualityFunction(Protocol):
    """Signature of a custom equality test used by ``CustomEditDistance``.

    The callable receives one *element* from each sequence (not the whole
    sequences) and returns whether the two elements should count as a match.
    ``s1`` is the element taken from the first sequence, ``s2`` the element
    taken from the second.
    """

    def __call__(self, s1: Any, s2: Any) -> bool:
        ...


@dataclass
class EditDistanceResult:
    """What calling ``CustomEditDistance`` on two sequences returns."""

    # The two input sequences, stored exactly as they were passed in.
    s1: Sequence
    s2: Sequence
    # Minimum number of insertions, deletions and replacements turning s1 into s2.
    edit_distance: int
    # Shape (len(s1) + 1, len(s2) + 1); cell [i, j] is the distance between
    # the prefixes s1[:i] and s2[:j].
    edit_matrix: np.ndarray
    # Shape (len(s1) + 1, len(s2) + 1, 2); cell [i, j] holds the matrix cell
    # that [i, j] was derived from. The origin [0, 0] holds (-1, -1).
    source_matrix: np.ndarray
    # One (s1_idx, s2_idx) pair per edit step: the index of the last element
    # consumed from each sequence so far, -1 while none has been consumed.
    source_trail: List[Tuple[int, int]]
    # One entry per edit step: the shared element when both sides matched,
    # otherwise a two-item list [s1_element, s2_element] with None marking a gap.
    joined_sources: List[Union[List, Any]]


class CustomEditDistance:
    """Basic implementation of edit distance plus a custom equality method"""

    # Return type of ``_process_sequences``; built once instead of on every call.
    _Matrices = namedtuple('Results', ('edit_matrix', 'source_matrix'))

    def __init__(self, equality_fn: Union[EqualityFunction, None] = None):
        """
        Params:
        -------
        equality_fn: Defaults to basic equality (``==``). Can be a callable that
            takes one element from each sequence and returns a bool.
        """
        self.equality_fn = equality_fn

    def _is_equal(self, a, b):
        """Compare two elements with the custom equality function, or ``==`` if none was given."""
        if self.equality_fn is None:
            return a == b
        return self.equality_fn(a, b)

    def _initial_edit_matrix(self, len_s1: int, len_s2: int):
        """Allocate the distance table with its empty-prefix row and column filled in.

        Row 0 / column 0 stand for the empty prefix, so the table has one more
        row and column than the sequences have elements. Turning an empty
        prefix into a prefix of length k costs k edits.
        """
        m = np.zeros((len_s1 + 1, len_s2 + 1), dtype=np.int32)
        m[:, 0] = np.arange(len_s1 + 1)
        m[0, :] = np.arange(len_s2 + 1)
        return m

    def _initial_source_matrix(self, len_s1: int, len_s2: int):
        """Allocate the back-pointer table with the border cells already linked.

        Interior cells start as (-1, -1) and are filled by ``_process_sequences``.
        Border cells can only be reached by walking along the border, so each one
        points at its neighbour closer to the origin; the origin keeps (-1, -1).
        """
        m = np.full((len_s1 + 1, len_s2 + 1, 2), -1, dtype=np.int32)
        # First column: cell (i, 0) comes from (i - 1, 0).
        m[1:, 0, 0] = np.arange(len_s1)
        m[1:, 0, 1] = 0
        # First row: cell (0, j) comes from (0, j - 1).
        m[0, 1:, 0] = 0
        m[0, 1:, 1] = np.arange(len_s2)
        return m

    def _calculate_source(self, s1_idx: int, s2_idx: int, prev_s1_s2: int, prev_s1: int, prev_s2: int):
        """Tracks sources of values, so we can see what index match the algorithm decided was optimal

        Returns the matrix cell the cheapest candidate came from. Ties prefer the
        diagonal (match/replace), then the cell above, then the cell to the left.
        """
        best = min(prev_s1_s2, prev_s1, prev_s2)
        if prev_s1_s2 == best:
            return [s1_idx - 1, s2_idx - 1]
        if prev_s1 == best:
            return [s1_idx - 1, s2_idx]
        return [s1_idx, s2_idx - 1]

    def _process_sequences(self, s1: Sequence, s2: Sequence):
        """Fill the distance and back-pointer tables for ``s1`` and ``s2``.

        Returns a namedtuple ``(edit_matrix, source_matrix)``.
        """
        len_s1, len_s2 = len(s1), len(s2)
        edit_matrix = self._initial_edit_matrix(len_s1=len_s1, len_s2=len_s2)
        source_matrix = self._initial_source_matrix(len_s1=len_s1, len_s2=len_s2)

        # Matrix index k refers to the prefix of length k, i.e. element k - 1.
        for s1_idx in range(1, len_s1 + 1):
            for s2_idx in range(1, len_s2 + 1):
                is_equal = self._is_equal(s1[s1_idx - 1], s2[s2_idx - 1])

                # Possible changes the DP EditDistance Algo tracks
                # Matched indices. Combines "no change" and "element was replaced".
                prev_s1_s2 = edit_matrix[s1_idx - 1, s2_idx - 1] + (0 if is_equal else 1)
                # Element of s1 with no counterpart in s2.
                prev_s1 = edit_matrix[s1_idx - 1, s2_idx] + 1
                # Element of s2 with no counterpart in s1.
                prev_s2 = edit_matrix[s1_idx, s2_idx - 1] + 1

                edit_matrix[s1_idx, s2_idx] = min(prev_s1_s2, prev_s1, prev_s2)
                source_matrix[s1_idx, s2_idx] = self._calculate_source(
                    s1_idx, s2_idx, prev_s1_s2, prev_s1, prev_s2
                )

        return self._Matrices(edit_matrix, source_matrix)

    def _source_trail_from_matrix(self, source_matrix: np.ndarray):
        """What element indices from each sequence were matched up together. Compare/reconstruct what the DP algorithm 'saw'.

        Walks the back-pointers from the bottom-right cell to the origin and
        returns the visited cells in forward order, converted from matrix
        indices to element indices (matrix index minus one). An index of -1
        means no element of that sequence has been consumed yet.
        """
        rows, cols, _ = source_matrix.shape
        cell = (rows - 1, cols - 1)
        trail = []
        # Stop at the origin; the ">" test also terminates on a (-1, -1) pointer.
        while cell[0] > 0 or cell[1] > 0:
            trail.append((cell[0] - 1, cell[1] - 1))
            cell = tuple(int(v) for v in source_matrix[cell])
        trail.reverse()
        return trail

    def _source_comparison(self, s1: Sequence, s2: Sequence, source_trail: List[Tuple[int, int]]):
        """Grab the elements of the sequences that were matched up together. Insert None if an element was inserted in the other sequence.

        A step whose two elements compare equal is collapsed to the ``s1``
        element alone; every other step becomes a two-item list.
        """
        # -1 means "nothing consumed yet", matching the trail's own convention.
        s1_idx_prev, s2_idx_prev = -1, -1
        joined_sources = []
        for s1_idx, s2_idx in source_trail:
            # An index that did not move means that sequence contributed nothing
            # to this step. Tracking this explicitly keeps a genuine None element
            # from being mistaken for a gap.
            took_s1 = s1_idx != s1_idx_prev
            took_s2 = s2_idx != s2_idx_prev
            if took_s1 and took_s2 and self._is_equal(s1[s1_idx], s2[s2_idx]):
                new_link = s1[s1_idx]
            else:
                new_link = [
                    s1[s1_idx] if took_s1 else None,
                    s2[s2_idx] if took_s2 else None,
                ]
            joined_sources.append(new_link)
            s1_idx_prev, s2_idx_prev = s1_idx, s2_idx
        return joined_sources

    def __call__(self, s1: Sequence, s2: Sequence) -> EditDistanceResult:
        """Return an EditDistanceResult with the edit_distance and relevant artefacts

        Empty sequences are accepted: the distance to an empty sequence is the
        length of the other one.
        """
        edit_matrix, source_matrix = self._process_sequences(s1, s2)
        source_trail = self._source_trail_from_matrix(source_matrix)
        joined_sources = self._source_comparison(s1=s1, s2=s2, source_trail=source_trail)
        # The bottom-right cell covers both full sequences; convert the numpy
        # scalar to the plain int the result type declares.
        edit_distance = int(edit_matrix[-1, -1])
        return EditDistanceResult(
            s1=s1, s2=s2, edit_distance=edit_distance, edit_matrix=edit_matrix,
            source_matrix=source_matrix, source_trail=source_trail, joined_sources=joined_sources
        )
def edit_distance_less_than(s1: str, s2: str, thresh: int):
    """Return whether the edit distance between ``s1`` and ``s2`` is strictly below ``thresh``.

    Params:
    -------
    s1, s2: The strings to compare, using plain ``==`` on their characters.
    thresh: Exclusive upper bound on the accepted edit distance.
    """
    # The distance to an empty string is the other string's length, so no
    # matrix needs to be built for that case.
    if not s1 or not s2:
        return max(len(s1), len(s2)) < thresh
    # bool() turns a possible numpy boolean into a plain Python bool.
    return bool(CustomEditDistance()(s1, s2).edit_distance < thresh)
class ApproxStringSeqEditDistance(CustomEditDistance):
    """Edit distance over sequences of strings where near-identical strings count as equal.

    Two strings match when ``edit_distance_less_than`` accepts them, i.e. when
    their own character-level edit distance is strictly below ``thresh``.
    """

    def __init__(self, thresh: int = 0):
        """
        Params:
        -------
        thresh: Exclusive bound on the character-level edit distance for two
            strings to count as equal. Values of 0 or less fall back to plain
            ``==`` comparison.
        """
        equality_fn = partial(edit_distance_less_than, thresh=thresh) if thresh > 0 else None
        # Go through the base initialiser so any state it sets up stays consistent.
        super().__init__(equality_fn=equality_fn)

    def __call__(self, s1: Sequence[str], s2: Sequence[str]) -> EditDistanceResult:
        """Compare two sequences of strings; only narrows the accepted types of the base call."""
        return super().__call__(s1, s2)
def test_approx_str_edit_distance():
    """Check that a typo within the threshold is ignored while a different word is counted.

    The two word lists differ by one misspelling ('hello' / 'hwllo') and one
    replaced word ('you' / 'them'); with a threshold of 3 the expected
    sequence-level edit distance is 1.

    Raises AssertionError on a mismatch so that a test runner reports the
    failure. Returns True on success for callers that use the boolean result.
    """
    orig = ['hello', 'it', 'is', 'nice', 'to', 'meet', 'you']
    one_typo_one_diff = ['hwllo', 'it', 'is', 'nice', 'to', 'meet', 'them']
    distance = ApproxStringSeqEditDistance(3)(orig, one_typo_one_diff).edit_distance
    assert distance == 1, f"expected sequence edit distance 1, got {distance}"
    return True
def joined_source_threshholded_equality(s1:Union[List[str], str], s2:str, thresh:int):
    """Decide whether ``s2`` approximately matches ``s1`` or any variant listed in ``s1``.

    When ``s1`` is a list, each entry that is not None is compared with ``s2``
    via ``edit_distance_less_than`` and the first match wins. Otherwise ``s1``
    itself is compared with ``s2``.
    """
    # Single value: defer directly to the thresholded string comparison.
    if not isinstance(s1, list):
        return edit_distance_less_than(s1=s1, s2=s2, thresh=thresh)

    # List of variants: None entries are gap markers and are skipped.
    for candidate in s1:
        if candidate is None:
            continue
        if edit_distance_less_than(s1=candidate, s2=s2, thresh=thresh):
            return True
    return False
class JoinedSourceMerger(CustomEditDistance):
    """Align a previously joined sequence with a new sequence of strings.

    Entries of the first sequence may be plain strings or lists of variants;
    elements are compared with ``joined_source_threshholded_equality``.
    Calling an instance returns only the ``joined_sources`` list of the result.
    """

    def __init__(self, thresh: int = 0):
        """
        Params:
        -------
        thresh: Exclusive bound on the character-level edit distance for a
            match. Values of 0 or less fall back to plain ``==`` comparison.
        """
        equality_fn = partial(joined_source_threshholded_equality, thresh=thresh) if thresh > 0 else None
        # Go through the base initialiser so any state it sets up stays consistent.
        super().__init__(equality_fn=equality_fn)

    def __call__(self, s1: List[Union[List[str], str]], s2: List[str]) -> List[Union[List, Any]]:
        """Run the alignment and return its ``joined_sources`` rather than the full result object."""
        return super().__call__(s1, s2).joined_sources
Author
HandcartCactus
commented
Jan 21, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment