@kgilpin
Last active June 25, 2024 14:57
scikit-learn-12471

Fixes: scikit-learn/scikit-learn#12470

Title: Fix OneHotEncoder to Safely Handle String Categories When handle_unknown='ignore'

Problem: OneHotEncoder in scikit-learn can raise a ValueError in the transform method when handle_unknown='ignore' is set and the categories are strings. During transform, unknown entries are replaced with OneHotEncoder.categories_[i][0] (the first category). The error occurs when that category is longer than the fixed-width string dtype of the array being transformed allows: the replacement is truncated on assignment, the truncated value is typically not a known category, and subsequent array operations fail.

Analysis: The root cause is how NumPy stores strings. A NumPy string array has a fixed-width dtype (for example <U5) sized to its longest element, and assigning a longer string into it silently truncates the value. When handle_unknown='ignore' is used, unknown categories are overwritten by a known category from the categories_ array. If that known category is longer than the item width of the array it is written into, it is truncated, and the truncated value eventually triggers the ValueError.
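
The truncation described above can be reproduced with NumPy alone. The following is a minimal sketch, not the encoder's actual code: the sample values and the names categories, X, valid_mask and X_replaced are assumptions chosen for illustration.

```python
import numpy as np

# Categories as they might look after fit (illustrative values).
categories = np.array(["11111111", "22", "333", "4444"])

# Data passed to transform: only short strings, so NumPy picks a narrow
# fixed-width dtype (5 characters wide here).
X = np.array(["55555", "22"])

# Mark entries that are known categories; "55555" is unknown.
valid_mask = np.isin(X, categories)

# Overwrite unknown entries with the first category.
X_replaced = X.copy()
X_replaced[~valid_mask] = categories[0]

# The 8-character replacement was silently cut to the array's width.
print(X_replaced[0])  # 11111

# The truncated value is no longer a known category.
still_known = np.isin(X_replaced, categories)
print(still_known[0])  # False
```

The second check is where an encoder that expects every entry to be a known category would fail.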

Proposed Changes:

  1. Locate and Modify the OneHotEncoder Code:

    • We need to adjust the OneHotEncoder implementation to ensure that the arrays used for transformation are appropriately sized to handle the data being inserted.
  2. Modify the _transform Method in OneHotEncoder:

    • Locate the _transform method in the sklearn/preprocessing/_encoders.py file.
    • Change the handling of unknown categories to first compare the string width of the array with the length of the replacement category. If the replacement does not fit, cast the array of string elements to object dtype before assigning.

Proposed Changes

File: sklearn/preprocessing/_encoders.py

  1. Import Necessary Utilities:

    • Import NumPy as np if it is not already imported.
  2. Modify the _transform Method:

    • Add a check to see if the dtype of the array can sufficiently handle the replacement category.
    • Cast the array to object dtype if necessary.

By making these changes, we ensure that the replacement category string can fit into the transformed array without truncating, thus avoiding the ValueError.
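
The check-and-cast step proposed above could look roughly like the sketch below. The helper replace_unknown and its parameters are hypothetical and are not part of scikit-learn; the actual _transform method is structured differently, so this only illustrates the idea of widening the array before the assignment.

```python
import numpy as np


def replace_unknown(Xi, valid_mask, replacement):
    """Return a copy of Xi with invalid entries set to `replacement`.

    Hypothetical helper for illustration only, not scikit-learn API.
    """
    Xi = np.asarray(Xi)
    repl_dtype = np.asarray(replacement).dtype

    # Fixed-width string arrays ("U" = unicode, "S" = bytes) truncate
    # longer values on assignment, so widen them when needed.
    too_narrow = Xi.dtype.kind in ("U", "S") and (
        repl_dtype.kind != Xi.dtype.kind
        or repl_dtype.itemsize > Xi.dtype.itemsize
    )
    if too_narrow:
        Xi = Xi.astype(object)  # astype returns a new array
    else:
        Xi = Xi.copy()  # never modify the caller's data

    Xi[~valid_mask] = replacement
    return Xi


categories = np.array(["11111111", "22", "333", "4444"])
X = np.array(["55555", "22"])
valid_mask = np.isin(X, categories)

out = replace_unknown(X, valid_mask, categories[0])
print(out[0])  # 11111111
```

Casting to object dtype trades some memory and speed for correctness; the cast only happens when the replacement would not fit, so arrays that are already wide enough keep their original dtype.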

Summary of Changes

  • Update the _transform method to check the size of elements and cast arrays to object dtype if the replacement category string exceeds the array’s allowable string length.

This will prevent errors when unknown string categories are handled during transformation with handle_unknown='ignore'.
