Last active
July 31, 2023 09:25
-
-
Save imvladikon/7a61aac848ac1a04763e61503e4a63d3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Simple fuzzy grouping of the list of the dictionaries using any string field and string similarities functions | |
Dependencies: | |
- similarities | |
`pip install rapidfuzz jarowinkler -q` | |
- scipy for connected components (DisjointSet) | |
""" | |
from itertools import combinations, tee | |
from typing import Dict, Optional | |
from jarowinkler import jarowinkler_similarity | |
from rapidfuzz.distance import metrics_cpp | |
from scipy._lib._disjoint_set import DisjointSet | |
from scipy.special import comb | |
from tqdm import tqdm | |
# Pool of normalized string-similarity functions; each maps a pair of strings
# to a score in [0, 1]. pair_similarity takes the best score over the pool.
STRING_FUNCTIONS = [
    getattr(metrics_cpp, _metric_name)
    for _metric_name in (
        "hamming_normalized_similarity",
        "osa_normalized_similarity",
        "jaro_normalized_similarity",
        "indel_normalized_similarity",
        "damerau_levenshtein_normalized_similarity",
        "jaro_winkler_normalized_similarity",
        "lcs_seq_normalized_similarity",
    )
] + [jarowinkler_similarity]
def pair_similarity(name_1, name_2):
    """Return the best similarity score for the pair across STRING_FUNCTIONS."""
    best = None
    for similarity in STRING_FUNCTIONS:
        score = similarity(name_1, name_2)
        if best is None or score > best:
            best = score
    return best
def get_prefix(item, column_name: Optional[str] = None):
    """Blocking key: the first three characters of every whitespace-separated
    token, sorted and concatenated.

    `item` is either a plain string (when column_name is None) or a mapping
    whose `column_name` entry holds the string.
    """
    text = item if column_name is None else item[column_name]
    token_prefixes = sorted(token[:3] for token in text.split())
    return "".join(token_prefixes)
def get_sorting_neighborhood(items, column_name, window_size=50, step=40):
    """Sorted-neighborhood blocking: every `step` positions, yield the window
    of up to 2*window_size surrounding (id, item) pairs.

    `column_name` is unused here but kept so all blocking functions share the
    same signature. With step < window_size, consecutive windows overlap on
    purpose, so neighbors near window edges still get compared.
    """
    pairs = list(items.items())
    n = len(pairs)
    for position in range(0, n, step):
        lo = max(0, position - window_size)
        hi = min(n, position + window_size)
        yield pairs[lo:hi]
def get_prefix_blocks(items, column_name: Optional[str] = None):
    """Prefix blocking: group consecutive items (in dict iteration order) that
    share the same get_prefix key; yield each group as a list of (id, item).

    Fixes over the original: the prefix is seeded from the first *iterated*
    item rather than from items[0] — after the caller re-sorts the dict, key 0
    need not come first, which previously could emit a spurious empty leading
    block. An empty `items` mapping now yields nothing instead of raising
    KeyError(0).
    """
    block = []
    prefix = None  # sentinel: no block started yet
    for i, item in items.items():
        current = get_prefix(item, column_name)
        if block and current != prefix:
            # Prefix changed: flush the finished block and start a new one.
            yield block
            block = []
        block.append((i, item))
        prefix = current
    if block:
        yield block
def get_cartesian_blocks(items, column_name: Optional[str] = None):
    """Degenerate blocking: one block containing every (id, item) pair, so all
    O(n^2) pairs are compared. `column_name` is unused but kept so all
    blocking functions share the same signature."""
    everything = [pair for pair in items.items()]
    yield everything
class FuzzyStringGrouper:
    """Group dictionaries whose chosen string field is fuzzy-similar.

    Items are linked pairwise when their similarity exceeds `threshold`; the
    connected components of the resulting graph (tracked with scipy's
    DisjointSet) become the groups.

    Usage: ``FuzzyStringGrouper(...).for_items(items).by("name")``.
    """

    def __init__(self,
                 blocking="greedy",
                 sim_f=pair_similarity,
                 threshold=0.8,
                 max_connections_per_item=100):
        """
        Args:
            blocking: Candidate-pair generation strategy: "sorting", "prefix",
                "greedy" or "cartesian" ("greedy" and "cartesian" are synonyms
                for comparing all pairs).
            sim_f: Similarity function comparing two strings, returning [0, 1].
            threshold: Similarity strictly above which two items are linked.
            max_connections_per_item: Maximum number of links per item — a
                simple heuristic so connected components do not chain
                everything into one giant cluster.
        """
        self.blocking = blocking
        self.sim_f = sim_f
        self.threshold = threshold
        self.max_connections_per_item = max_connections_per_item
        # Populated by for_items(): maps a 0-based id to the original item.
        self.items: Optional[Dict] = None

    def iter_items_pairs(self, items, blocking, column_name, **kwargs):
        """Yield candidate ((i, item_1), (j, item_2)) pairs for the chosen
        blocking strategy, wrapped in a tqdm progress bar."""
        blocking_factory = {
            "sorting": get_sorting_neighborhood,
            "prefix": get_prefix_blocks,
            "greedy": get_cartesian_blocks,
            "cartesian": get_cartesian_blocks,
        }
        iter_blocks = blocking_factory[blocking](items, column_name=column_name, **kwargs)
        items_pairs = (pair for block in iter_blocks for pair in combinations(block, 2))
        # NOTE: counting via tee buffers every pair in memory just to give tqdm
        # a total; for very large inputs consider dropping the total instead.
        items_pairs, countable = tee(items_pairs)
        total = sum(1 for _ in countable)
        return tqdm(items_pairs, total=total)

    def by(self, column_name, return_only_index=False, **kwargs):
        """Yield groups (lists) of similar items, compared on item[column_name].

        Items are first sorted by their blocking prefix so "sorting" and
        "prefix" blocking see similar strings next to each other.
        `return_only_index` is currently unused (kept for interface stability).
        """
        sorted_items = dict(sorted(self.items.items(),
                                   key=lambda item: get_prefix(item[1][column_name])))
        # Keys are 0..n-1 from for_items(), so they double as DSU elements.
        union_find = DisjointSet(range(len(sorted_items)))
        connections_count = {i: 0 for i in range(len(sorted_items))}
        similarity_cache = {}
        pbar = self.iter_items_pairs(sorted_items, self.blocking, column_name)
        for (i, item_1), (j, item_2) in pbar:
            if union_find.connected(i, j):
                continue
            # Cap the degree of each node. scipy's DisjointSet cannot remove
            # edges, so a saturated item simply stops linking. (The original
            # code then tried to evict the least-similar link via
            # DisjointSet.disconnect / DisjointSet.items, neither of which
            # exists in scipy; that code was unreachable — this guard already
            # skips saturated items — and has been removed.)
            if (connections_count[i] >= self.max_connections_per_item
                    or connections_count[j] >= self.max_connections_per_item):
                continue
            # TODO: add self-similarity matrix values into cache using cdist from rapidfuzz ?
            pair_key = (i, j)
            if pair_key not in similarity_cache:
                similarity_cache[pair_key] = self.sim_f(item_1[column_name],
                                                        item_2[column_name])
            if similarity_cache[pair_key] > self.threshold:
                union_find.merge(i, j)
                connections_count[i] += 1
                connections_count[j] += 1
        for subset in union_find.subsets():
            yield [self.items[i] for i in subset]

    def for_items(self, items):
        """Attach the items to group (any iterable of dicts) and return self
        so calls can be chained: grouper.for_items(data).by("name")."""
        self.items = dict(enumerate(items))
        return self
if __name__ == '__main__':
    # Small smoke-test: city names that should cluster into two groups.
    data = [
        {"id": 1, "name": "New York"},
        {"id": 2, "name": "New York City"},
        {"id": 3, "name": "Los Angeles"},
        {"id": 4, "name": "Los Angeles, CA"},
        {"id": 5, "name": "Los Angeles, California"},
        {"id": 6, "name": "Los Angeles, California, USA"},
        {"id": 7, "name": "New York, United States"},
    ]
    grouper = FuzzyStringGrouper(blocking="sorting",
                                 threshold=0.8,
                                 max_connections_per_item=2)
    # Only print non-singleton groups — singletons found no fuzzy match.
    for group in grouper.for_items(data).by("name"):
        if len(group) > 1:
            print(group)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment