Skip to content

Instantly share code, notes, and snippets.

@gwerbin
Last active June 16, 2021 23:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gwerbin/3e095a63169042de88c80410ceca8705 to your computer and use it in GitHub Desktop.
Save gwerbin/3e095a63169042de88c80410ceca8705 to your computer and use it in GitHub Desktop.
Sketch of a solution for aligning matrix columns, intended for aligning columns emitted from Scikit-learn classifiers. See https://github.com/scikit-learn/scikit-learn/issues/12845
import numpy as np
import pandas as pd
def fill_array(fillvalue, shape):
    """Create an array of the given shape filled with a specific value.

    Parameters:
        fillvalue: The value to place in every cell.
        shape: The shape of the array to create.

    Returns:
        * For a scalar equal to zero: a zero array of dtype 'int' when
          ``fillvalue`` is an instance of ``int``, otherwise of dtype 'float'.
        * For any other scalar (as judged by ``np.isscalar``): an array built
          with ``np.full``, whose dtype NumPy infers from the value.
        * For anything else (lists, tuples, arrays, ``None``, ...): an
          object-dtype array in which every cell holds ``fillvalue`` itself.
    """
    # Decide scalar-ness *before* comparing with zero: comparing an ndarray
    # with 0 yields an array, whose truth value is ambiguous (or misleading
    # for a one-element array).
    if np.isscalar(fillvalue):
        if fillvalue == 0:
            zero_dtype = 'int' if isinstance(fillvalue, int) else 'float'
            return np.zeros(shape, dtype=zero_dtype)
        return np.full(shape, fillvalue)

    # Fill one element at a time, to prevent numpy from trying to be too
    # helpful by broadcasting iterables. This is slow, try to avoid it.
    array = np.empty(shape, dtype='O')
    array_flat = array.ravel()  # a view: the fresh array is contiguous
    for i in range(array_flat.size):
        array_flat[i] = fillvalue
    return array
def noduplicated1d(x):
    """Report whether a 1-d array-like contains no repeated values.

    Returns True when every element of ``x`` occurs exactly once and False
    otherwise. Raises ValueError when ``x`` is not one-dimensional.
    """
    values = np.asarray(x)
    if values.ndim != 1:
        raise ValueError('x must have exactly 1 dimension.')
    has_repeats = pd.Series(values).duplicated().any()
    return not has_repeats
def get_positional_correspondence(
    data_values,
    given_values,
    check_unique=True,
    check_unique_data=True,
    check_unique_given=True,
    check_missing=True
):
    """Compute the positional correspondence between two sets of unique values.

    This function returns a mapping between the index of each value in the
    "data" values, and the corresponding index of the same value in the
    "given" values.

    The "data" values should be from some "unknown" source, such as a test
    set. The "given" values should be from a "known" set of data, such as a
    training set.

    Example:

        data_values = ['cat', 'dog', 'mouse']
        given_values = ['cat', 'mouse', 'dog', 'human']
        assert get_positional_correspondence(data_values, given_values) == frozenbidict({
            0: 0,  # 'cat' is at index 0 in both "data" and "given"
            1: 2,  # 'dog' is at index 1 in "data" and index 2 in "given"
            2: 1,  # 'mouse' is at index 2 in "data" and index 1 in "given"
        })

    This function does not de-duplicate inputs. The sets of values must
    already be unique. This is because the order of deduplication might be
    implementation-dependent or otherwise unclear to the user, which would be
    counterproductive. Note that checking uniqueness is slow. WARNING: the
    result of passing non-unique values is undefined and the result will
    probably be wrong.

    Parameters:
        data_values: 1-d array-like of hashable values to look up.
        given_values: 1-d array-like of hashable reference values.
        check_unique: When true, both inputs are checked for duplicates,
            regardless of the two flags below.
        check_unique_data: When true, ``data_values`` is checked for
            duplicates. To skip the check, ``check_unique`` must be false too.
        check_unique_given: When true, ``given_values`` is checked for
            duplicates. To skip the check, ``check_unique`` must be false too.
        check_missing: When true, raise if any element of ``data_values`` is
            absent from ``given_values``. When false, such elements are
            silently left out of the returned mapping, which might be
            desirable in some cases but can also produce unexpected results.

    Returns:
        A ``frozenbidict`` mapping index-in-data to index-in-given.

    Raises:
        ValueError: If a uniqueness check or the missing-values check fails.
    """
    # NOTE(review): `bidict` and `frozenbidict` are not imported in the
    # visible part of this file; presumably they come from the third-party
    # "bidict" package -- confirm and import at module level.
    data_values = np.asarray(data_values)
    given_values = np.asarray(given_values)

    if check_unique or check_unique_data:
        if not noduplicated1d(data_values):
            raise ValueError('Data values are not unique.')

    if check_unique or check_unique_given:
        if not noduplicated1d(given_values):
            raise ValueError('Given values are not unique.')

    if check_missing:
        data_in_given = np.isin(data_values, given_values)
        if not data_in_given.all():
            missing_labs = data_values[~data_in_given].tolist()
            # Labels need not be strings, so convert before joining.
            raise ValueError('Elements of data_values are missing from given_values:\n {}.'.format('\n '.join(map(str, missing_labs))))

    # Index the given values once so that each data value is resolved with a
    # single lookup instead of a scan over every (data, given) pair.
    given_positions = {given: i_given for i_given, given in enumerate(given_values)}

    # Any element of 'data_values' that is missing from 'given_values' is
    # dropped from the output here; with check_missing enabled that
    # condition has already been rejected above.
    value_mapping = bidict()
    for i_data, data in enumerate(data_values):
        i_given = given_positions.get(data)
        if i_given is not None:
            value_mapping[i_data] = i_given
    return frozenbidict(value_mapping)
def reorder_matrix_columns(matrix, column_mapping, n_output_columns=None, fill_value=0.0):
    """Re-order the columns of a matrix (2-dimensional array).

    This operation is more general than permutation, in that input columns
    can be left out of the output, and the output can contain more columns
    than the input. Output columns that receive no input column are filled
    with fill_value. The output may not have fewer columns than the input.

    Example:

        x = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        mapping = {
            0: 1,  # Move column 0 to column 1
            1: 0,  # Move column 1 to column 0
            # Column 2 is not mapped, so its values are left out
        }
        y = np.array([
            [0, 1, 0],
            [1, 0, 0],
            [0, 0, 0]  # output column 2 holds only fill_value
        ])
        (reorder_matrix_columns(x, mapping, n_output_columns=3) == y).all()

    Parameters:
        matrix: 2-dimensional array-like.
        column_mapping: Mapping from input column index to output column
            index.
        n_output_columns: Number of columns in the output. Defaults to one
            more than the largest output index in ``column_mapping``.
        fill_value: Value for output cells that no input column is copied
            into; the output array is created from it by ``fill_array``.

    Returns:
        The input itself (as an array) when it has exactly one column;
        otherwise a new array with one row per input row and
        ``n_output_columns`` columns.

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional, or if
            ``n_output_columns`` is smaller than the number of input columns.
    """
    # Accept any array-like (e.g. nested lists), not only ndarrays.
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError('matrix must be a 2-dimensional array.')

    if n_output_columns is None:
        # default=-1 makes an empty mapping mean "zero output columns"
        # instead of failing inside max().
        n_output_columns = max(column_mapping.values(), default=-1) + 1

    n_row, n_col = matrix.shape

    # Skip alignment for binary problems
    if n_col == 1:
        return matrix

    if n_output_columns < n_col:
        raise ValueError('Must have at least as many given labels as data labels')

    result = fill_array(fill_value, (n_row, n_output_columns))
    for j_data, j_given in column_mapping.items():
        result[:, j_given] = matrix[:, j_data]
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment