Skip to content

Instantly share code, notes, and snippets.

@UrosOgrizovic
Created April 21, 2022 10:25
Show Gist options
  • Save UrosOgrizovic/68a10fba8e9741471d765b48a8ca6588 to your computer and use it in GitHub Desktop.
Save UrosOgrizovic/68a10fba8e9741471d765b48a8ca6588 to your computer and use it in GitHub Desktop.
Connected components in Python
import enum
@enum.unique
class Connectivity(enum.Enum):
    """
    Adjacency rule used to decide which cells belong to the same component.

    FOUR:  cells are neighbours when they share an edge (up, down, left, right).
    EIGHT: edge neighbours plus the four diagonal neighbours.
    """

    FOUR = 1   # edge-adjacent neighbours only
    EIGHT = 2  # edge-adjacent and diagonal neighbours
def connected_components(matrix, connectivity=Connectivity.FOUR):
    """
    Label the connected components of a 2-D grid.

    Non-zero cells are foreground and zero cells are background. Labelling is
    done with a classic two-pass algorithm: first_pass assigns provisional
    labels and records which labels are equivalent, and second_pass rewrites
    every provisional label to its final component label.

    :param matrix: list of rows (lists) of numbers. It is not modified;
        pad_matrix builds new row lists before any labelling happens.
    :param connectivity: Connectivity.FOUR or Connectivity.EIGHT.
    :return: a new matrix with the same shape as the input, where background
        cells are 0 and each foreground cell holds its component's label.
    """
    padded = pad_matrix(matrix)
    equivalent_components = {}
    first_pass(padded, connectivity, equivalent_components)
    # Drop the one-cell zero border that pad_matrix added on every side.
    labels = [row[1:-1] for row in padded[1:-1]]
    second_pass(labels, equivalent_components)
    return labels
def pad_matrix(matrix):
    """
    Return a copy of matrix surrounded by a one-cell border of zeroes.

    The border lets the labelling passes read neighbours of edge cells
    without first/last row/column bounds checks.

    The input is not modified: every row of the result is a new list. The top
    and bottom border rows are separate list objects, so writing to one never
    affects the other. The row width is taken from the first row. An empty
    matrix is padded to a 2x2 block of zeroes.

    :param matrix: list of equal-length rows.
    :return: new list of rows, two longer and two wider than the input.
    """
    width = len(matrix[0]) if matrix else 0
    padded = [[0] * (width + 2)]
    padded.extend([0, *row, 0] for row in matrix)
    padded.append([0] * (width + 2))
    return padded
def first_pass(matrix, connectivity, equivalent_components):
    """
    Give every non-zero interior cell of a padded matrix a provisional label.

    Cells are visited row by row, left to right. Each cell asks
    determine_component for a label, based only on neighbours that have
    already been visited (above and to the left). If none of those
    neighbours is labelled, the cell gets a fresh label. Label equivalences
    discovered along the way are recorded in equivalent_components, and the
    second pass resolves them.

    :param matrix: matrix padded with a zero border (see pad_matrix);
        modified in place.
    :param connectivity: Connectivity.FOUR or Connectivity.EIGHT.
    :param equivalent_components: dict of label equivalences, updated in place.
    """
    next_label = 1
    # Visit (1, 1) .. (rows - 2, cols - 2); the outermost ring is padding.
    for i in range(1, len(matrix) - 1):
        row, above = matrix[i], matrix[i - 1]
        for j in range(1, len(row) - 1):
            if row[j] == 0:
                continue
            # Only already-visited neighbours hold labels. Cells below and to
            # the right still hold raw input values, which may be any non-zero
            # number, so they must not influence the decision.
            edge_neighbours = [above[j], row[j - 1]]
            diagonal_neighbours = [above[j - 1], above[j + 1]]
            label = determine_component(edge_neighbours, diagonal_neighbours,
                                        connectivity, equivalent_components)
            if label == 0:
                label = next_label
                next_label += 1
            row[j] = label
def is_island(four_conn_comps, eight_conn_comps, connectivity):
    """
    Return True if none of the given neighbour values is non-zero.

    :param four_conn_comps: values of the edge-adjacent neighbours.
    :param eight_conn_comps: values of the diagonal neighbours; only
        consulted when connectivity is Connectivity.EIGHT.
    :param connectivity: Connectivity.FOUR or Connectivity.EIGHT.
    :return: True if the cell has no non-zero neighbour under the given rule.
    """
    # Compare each value with 0 instead of summing. A sum can cancel to zero,
    # or go negative, when the matrix contains negative values.
    if any(value != 0 for value in four_conn_comps):
        return False
    if connectivity == Connectivity.EIGHT and any(value != 0 for value in eight_conn_comps):
        return False
    return True
def determine_component(four_conn_comps, eight_conn_comps, connectivity, equivalent_components):
    """
    Choose the label for a non-island cell and record neighbour equivalences.

    Only positive values in the neighbour lists are treated as labels. If the
    cell touches several different labels, they all belong to one component
    and are merged. equivalent_components is used as a union-find parent map:
    each entry links a label to a smaller equivalent label.

    :param four_conn_comps: labels of the already-visited edge neighbours.
    :param eight_conn_comps: labels of the already-visited diagonal
        neighbours; only consulted when connectivity is Connectivity.EIGHT.
    :param connectivity: Connectivity.FOUR or Connectivity.EIGHT.
    :param equivalent_components: parent map, updated in place.
    :return: the smallest root label among the neighbours, or 0 if no
        neighbour is labelled.
    """
    neighbour_labels = [c for c in four_conn_comps if c > 0]
    if connectivity == Connectivity.EIGHT:
        neighbour_labels += [c for c in eight_conn_comps if c > 0]
    if not neighbour_labels:
        return 0

    def find_root(label):
        # Every link points from a larger label to a smaller one, so this
        # walk always terminates.
        while equivalent_components.get(label, label) != label:
            label = equivalent_components[label]
        return label

    roots = {find_root(label) for label in neighbour_labels}
    target = min(roots)
    # Link the roots, and not only the labels themselves. Simply overwriting
    # a label's existing mapping could drop an equivalence recorded earlier,
    # which would split one component into two.
    for label in roots.union(neighbour_labels):
        if label != target:
            equivalent_components[label] = target
    return target
def second_pass(matrix, equivalent_components):
    """
    Replace every provisional label in matrix with its final component label.

    Each non-zero label is first resolved to the root of its chain in
    equivalent_components (label -> equivalent label), which may be several
    links long. The distinct roots are then renumbered 1, 2, ..., k in
    ascending order, so the final labels have no gaps. Zero (background)
    cells are left untouched. matrix is modified in place.

    :param matrix: list of rows of labels (the first-pass output with the
        padding removed).
    :param equivalent_components: dict mapping a label to an equivalent label.
    """
    def find_root(label):
        # Follow links until reaching a label with no mapping, or one mapped
        # to itself. The seen-set guards against a malformed cyclic mapping.
        seen = set()
        while label not in seen and equivalent_components.get(label, label) != label:
            seen.add(label)
            label = equivalent_components[label]
        return label

    labels = {label for row in matrix for label in row if label != 0}
    root_of = {label: find_root(label) for label in labels}
    # Renumber the roots consecutively, preserving their relative order.
    final_label = {root: new for new, root in enumerate(sorted(set(root_of.values())), start=1)}
    for row in matrix:
        for j, label in enumerate(row):
            if label != 0:
                row[j] = final_label[root_of[label]]
if __name__ == '__main__':
    # Other sample inputs:
    #   [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
    #   [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]
    #   [[0, 1, 1, 0], [0, 0, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]]
    matrix = [[0, 1, 0, 1], [1, 1, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
    # Label the same grid under both adjacency rules, 4-connectivity first.
    for tag, rule in (('4-conn', Connectivity.FOUR), ('8-conn', Connectivity.EIGHT)):
        print(tag, connected_components(matrix, rule))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment