Last active
June 16, 2021 23:45
-
-
Save gwerbin/3e095a63169042de88c80410ceca8705 to your computer and use it in GitHub Desktop.
Sketch of a solution for aligning the matrix columns emitted by scikit-learn classifiers with a known set of labels. See https://github.com/scikit-learn/scikit-learn/issues/12845
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
def fill_array(fillvalue, shape):
    """Create an array of the given shape with every element set to fillvalue.

    Args:
        fillvalue: the value to place in every element. Values that
            ``np.isscalar`` accepts (Python/NumPy numbers, strings, bytes, ...)
            produce an array whose dtype NumPy infers from the value, e.g.
            int for ``0``, float for ``0.0``, complex for ``0j``. Anything else
            (None, lists, arrays, arbitrary objects) produces an object array
            in which every element refers to the *same* fillvalue object, so
            mutating it is visible through every element.
        shape: shape of the output array.

    Returns:
        np.ndarray of the given shape.
    """
    if np.isscalar(fillvalue):
        # Check scalar-ness first: comparing an array-like with ``== 0`` is
        # element-wise and its result cannot be used as an ``if`` condition.
        # np.full infers the dtype from the scalar itself.
        return np.full(shape, fillvalue)
    # Fill one element at a time, to prevent numpy from trying to be too
    # helpful by broadcasting iterables. This is slow, try to avoid it.
    array = np.empty(shape, dtype='O')
    # ravel() of a freshly allocated contiguous array is a view, so writes
    # through array_flat land in array.
    array_flat = array.ravel()
    for i in range(array.size):
        array_flat[i] = fillvalue
    return array
def noduplicated1d(x):
    """Return True if a 1-d array-like contains no duplicate values.

    Args:
        x: 1-d array-like; converted with np.asarray. Duplicate detection is
            delegated to pandas.

    Returns:
        bool: True when every element of x is distinct (an empty input counts
        as distinct), False otherwise.

    Raises:
        ValueError: if x does not have exactly one dimension.
    """
    x = np.asarray(x)
    if x.ndim != 1:
        raise ValueError('x must have exactly 1 dimension.')
    return bool(pd.Series(x).is_unique)
def get_positional_correspondence(
    data_values,
    given_values,
    check_unique=True,
    check_unique_data=True,
    check_unique_given=True,
    check_missing=True,
):
    """Compute the positional correspondence between two sets of unique values.

    This function returns a mapping between the index of each value in the
    "data" values, and the corresponding index of the same value in the "given"
    values.

    The "data" values should be from some "unknown" source, such as a test set.
    The "given" values should be from a "known" set of data, such as a training
    set.

    Example:
        data_values = ['cat', 'dog', 'mouse']
        given_values = ['cat', 'mouse', 'dog', 'human']
        assert get_positional_correspondence(data_values, given_values) == {
            0: 0,  # 'cat' is at index 0 in both "data" and "given"
            1: 2,  # 'dog' is at index 1 in "data" and index 2 in "given"
            2: 1,  # 'mouse' is at index 2 in "data" and index 1 in "given"
        }

    This function does not de-duplicate inputs. The sets of values must already
    be unique. This is because the order of deduplication might be
    implementation-dependent or otherwise unclear to the user, which would be
    counterproductive. Checking uniqueness costs an extra pass over each input.
    Pass check_unique=False to skip both uniqueness checks if you know they
    already hold. Pass check_unique_data=False or check_unique_given=False to
    skip only one of them. WARNING: the result of passing non-unique values is
    undefined and the result will probably be wrong.

    If any element of the "data" set is not in the "given" set, that element
    will be silently dropped from the final mapping. This might be desirable in
    some cases, but it might also be unexpected and could produce incorrect
    results. By default this raises a ValueError instead. If you want the
    dropping behavior, or otherwise want to bypass this check, use
    check_missing=False.

    Args:
        data_values: 1-d array-like of values whose positions are being looked up.
        given_values: 1-d array-like of reference values.
        check_unique: master switch for both uniqueness checks.
        check_unique_data: check that data_values is unique (if check_unique).
        check_unique_given: check that given_values is unique (if check_unique).
        check_missing: raise if some data value does not occur in given_values.

    Values are matched with a dict lookup, so they must be hashable.

    Returns:
        dict mapping each data index (int) to the index (int) of the equal
        value in given_values, with keys in data order. When both inputs are
        unique no two keys share a value, so the mapping can be inverted with
        ``{v: k for k, v in result.items()}``.

    Raises:
        ValueError: if an enabled uniqueness check fails, if an input is not
            1-d while its uniqueness check is enabled, or if check_missing is
            true and some data values are absent from given_values.
    """
    data_values = np.asarray(data_values)
    given_values = np.asarray(given_values)

    # Each uniqueness check runs only when both check_unique and its per-set
    # flag are true, so either one can switch it off.
    if check_unique and check_unique_data:
        if not noduplicated1d(data_values):
            raise ValueError('Data values are not unique.')
    if check_unique and check_unique_given:
        if not noduplicated1d(given_values):
            raise ValueError('Given values are not unique.')

    if check_missing:
        data_in_given = np.isin(data_values, given_values)
        if not data_in_given.all():
            # str() each label: class labels are often ints, and str.join
            # accepts only strings.
            missing_labs = [str(lab) for lab in data_values[~data_in_given].tolist()]
            raise ValueError(
                'Elements of data_values are missing from given_values:\n {}.'.format(
                    '\n '.join(missing_labs)
                )
            )

    # Index each given value's position once, then look up each data value in
    # that index: O(len(data) + len(given)) instead of comparing every pair.
    # Data values absent from given_values are silently skipped. With
    # check_missing=True this cannot happen.
    given_positions = {value: i_given for i_given, value in enumerate(given_values)}
    return {
        i_data: given_positions[value]
        for i_data, value in enumerate(data_values)
        if value in given_positions
    }
def reorder_matrix_columns(matrix, column_mapping, n_output_columns=None, fill_value=0.0):
    """Re-order the columns of a matrix (2-dimensional array).

    This operation is more general than permutation, in that columns can be
    omitted entirely from the output, or the output can contain more columns
    than the input. New columns will be filled with fill_value.

    Args:
        matrix: 2-d array-like of shape (n_rows, n_cols). A 1-d array is
            treated as a single column of shape (n_rows, 1).
        column_mapping: mapping from input column index to output column
            index, e.g. the result of get_positional_correspondence. Input
            columns that are not keys of the mapping are dropped.
        n_output_columns: number of output columns. Defaults to one more than
            the largest output index in column_mapping (0 if it is empty).
        fill_value: value for output columns that receive no input column;
            passed to fill_array.

    Returns:
        Array of shape (n_rows, n_output_columns). If fill_array produces a
        numeric array and matrix is numeric, the output dtype is promoted to
        hold both, so an integer fill_value does not truncate float data.
        Special case: an input with exactly one column (what the original
        code calls a binary problem) is returned as-is, with shape
        (n_rows, 1), without applying the mapping.

    Raises:
        ValueError: if matrix is not 1- or 2-dimensional, or if an explicit
            n_output_columns is too small for the targets in column_mapping.

    Example:
        x = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ])
        y = np.array([
            [0, 1],
            [1, 0],
            [0, 0],
        ])
        mapping = {
            0: 1,  # Move column 0 to column 1
            1: 0,  # Move column 1 to column 0
            # Column 2 is missing, so it is omitted from the output
        }
        assert np.array_equal(reorder_matrix_columns(x, mapping), y)
    """
    matrix = np.asarray(matrix)
    # Promote 1-d input to a single column *before* validating dimensionality,
    # otherwise this branch could never run.
    if matrix.ndim == 1:
        matrix = matrix.reshape((-1, 1))
    if matrix.ndim != 2:
        raise ValueError('matrix must be a 2-dimensional array.')

    # Skip alignment for binary problems
    if matrix.shape[1] == 1:
        return matrix

    if n_output_columns is None:
        # default=-1 makes an empty mapping yield zero output columns.
        n_output_columns = max(column_mapping.values(), default=-1) + 1
    elif column_mapping and max(column_mapping.values()) >= n_output_columns:
        raise ValueError('n_output_columns is too small for the targets in column_mapping.')

    n_row = matrix.shape[0]
    result = fill_array(fill_value, (n_row, n_output_columns))
    # e.g. fill_value=0 yields an int array; copying float columns into it
    # would silently truncate them, so widen to a common numeric dtype.
    if np.issubdtype(result.dtype, np.number) and np.issubdtype(matrix.dtype, np.number):
        result = result.astype(np.result_type(result.dtype, matrix.dtype), copy=False)
    for j_data, j_given in column_mapping.items():
        result[:, j_given] = matrix[:, j_data]
    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment