Skip to content

Instantly share code, notes, and snippets.

@yeiichi
Created July 25, 2024 05:13
Show Gist options
  • Save yeiichi/663c39c96574cd8e0d552fd238a89336 to your computer and use it in GitHub Desktop.
Save yeiichi/663c39c96574cd8e0d552fd238a89336 to your computer and use it in GitHub Desktop.
Multi-hot-encode the target column data and inner join the resultant table with the original table.
#!/usr/bin/env python3
import re
from pathlib import Path
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
def add_multi_hot_to_df(df_in, target_col: str):
    """Multi-hot-encode the target column and join the result to the table.

    Each cell of ``target_col`` is turned into a list of labels: brackets and
    quotes are stripped, then the text is split on commas and/or whitespace,
    so ``"['foo', 'bar']"``, ``"foo, bar"`` and ``"foo bar"`` all yield
    ``['foo', 'bar']``.  The labels are then binarized with scikit-learn's
    ``MultiLabelBinarizer`` and the indicator columns are inner-joined to the
    table on its index.

    Args:
        df_in: Source DataFrame.  It is not modified; the work is done on the
            copy returned by ``fillna``.
        target_col: Name of the column holding the labels.

    Returns:
        A new DataFrame containing the columns of ``df_in`` (with
        ``target_col`` converted to lists of labels) plus one indicator
        column per distinct label, in the order of ``mlb.classes_``.

    Notes:
        Missing values in *every* column are filled with the placeholder
        ``'void'`` before encoding, and ``'void'`` is replaced by ``''`` in
        the returned frame.  Consequently a genuine ``'void'`` value in the
        data is blanked as well, and missing target cells produce a
        ``'void'`` label column.

    Reference:
        https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html
    """
    # Safeguard for the MultiLabelBinarizer process: no NaN may reach the
    # tokenizer below.
    filled = df_in.fillna('void')

    # Characters of a stringified list that are not part of a label.
    strip_pattern = re.compile(r'[\[\]\'\"]')
    # Labels are separated by commas and/or whitespace.  Splitting on
    # whitespace alone would leave the comma glued to the label ('foo,').
    split_pattern = re.compile(r'[,\s]+')

    def to_labels(cell):
        # str() lets non-string cells (numbers, already-parsed lists) through
        # instead of raising TypeError inside the regex.
        text = strip_pattern.sub('', str(cell))
        return [token for token in split_pattern.split(text) if token]

    # "['foo', 'bar']": str -> 'foo, bar': str -> ['foo', 'bar']: list
    filled[target_col] = filled[target_col].apply(to_labels)

    mlb = MultiLabelBinarizer()
    mat = mlb.fit_transform(filled[target_col])  # Multi-hot MATrix
    # Reuse the source index so the index join below pairs every row with
    # its own encoding even when the input index is not 0..n-1.
    dex = pd.DataFrame(mat, columns=mlb.classes_, index=filled.index)

    # noinspection PyArgumentEqualDefault
    merged = pd.merge(filled, dex, how='inner',
                      left_index=True, right_index=True)
    return merged.replace('void', '')
if __name__ == '__main__':
    # Interactive entry point: ask for the CSV, then for the column to encode.
    csv_path = Path(input('\033[93mSource CSV? >> \033[0m'))
    source_frame = pd.read_csv(csv_path)
    column_name = input('\033[93mTarget column? >> \033[0m')

    # The result is written to the current working directory, named after
    # the source file's stem.
    output_name = f'{csv_path.stem}_multi_hot.csv'
    encoded_frame = add_multi_hot_to_df(source_frame, column_name)
    encoded_frame.to_csv(output_name, index=False)

    print('\033[93mDONE!\033[0m')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment