@bryant1410
Forked from kljensen/onehot_pandas_scikit.py
Last active August 29, 2015 14:07
This function one-hot encodes columns of a pandas DataFrame rather than a NumPy feature matrix. One advantage is that the names of the newly created columns are known (for example e=Boston), so they are easy to identify.
# -*- coding: utf-8 -*-
""" Small script that shows how to do one-hot encoding
of categorical columns in a pandas DataFrame.
See:
http://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.OneHotEncoder.html#sklearn.preprocessing.OneHotEncoder
http://scikit-learn.org/dev/modules/generated/sklearn.feature_extraction.DictVectorizer.html
"""
import random

import numpy
import pandas as pd
from sklearn.feature_extraction import DictVectorizer


def one_hot_dataframe(data, cols, replace=False):
    """ Takes a dataframe and a list of columns that need to be encoded.
    Returns a 3-tuple comprising the data, the vectorized data,
    and the fitted vectorizer.
    """
    vec = DictVectorizer()
    # One dict per row, e.g. {'e': 'Boston', 'f': 'Chrome'};
    # string values become one-hot "column=value" features.
    records = data[cols].to_dict(orient='records')
    vecData = pd.DataFrame(vec.fit_transform(records).toarray())
    # get_feature_names_out() requires scikit-learn >= 1.0
    vecData.columns = vec.get_feature_names_out()
    vecData.index = data.index
    if replace:
        # Drop the original categorical columns before joining the encoded ones
        data = data.drop(cols, axis=1)
    data = data.join(vecData)
    return (data, vecData, vec)


def main():
    # Get a random DataFrame
    df = pd.DataFrame(numpy.random.randn(25, 3), columns=['a', 'b', 'c'])
    # Make some random categorical columns
    df['e'] = [random.choice(('Chicago', 'Boston', 'New York')) for i in range(df.shape[0])]
    df['f'] = [random.choice(('Chrome', 'Firefox', 'Opera', 'Safari')) for i in range(df.shape[0])]
    print('Original DataFrame')
    print('------------------')
    print(df)
    # Vectorize the categorical columns: e & f
    df, _, _ = one_hot_dataframe(df, ['e', 'f'], replace=True)
    print('Encoded DataFrame')
    print('-----------------')
    print(df)


if __name__ == '__main__':
    main()
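For comparison, pandas alone can produce the same kind of columns with pd.get_dummies, without scikit-learn. This is a different technique from the DictVectorizer approach above, and it returns no fitted object, so applying the identical column set to new data is left to you. A minimal sketch on made-up data, with prefix_sep="=" chosen to mimic DictVectorizer's naming:

```python
import pandas as pd

# Made-up sample data: one numeric column and two categorical ones.
df = pd.DataFrame({
    "a": [0.5, -1.2, 0.3],
    "e": ["Boston", "Chicago", "Boston"],
    "f": ["Chrome", "Opera", "Opera"],
})

# prefix_sep="=" gives "e=Boston"-style names; dtype=int yields 0/1 columns
# instead of booleans (the default dtype in pandas >= 2.0).
encoded = pd.get_dummies(df, columns=["e", "f"], prefix_sep="=", dtype=int)
print(encoded)
```

The untouched column a is kept first, followed by the dummy columns for e and then f, each in sorted order.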
Example output
Original DataFrame
------------------
            a         b         c         e        f
0   -0.219222 -0.368154  0.388479  New York    Opera
1    1.879536 -0.033210 -0.099437  New York  Firefox
2    0.909419 -0.498084  0.084163  New York   Safari
3   -0.002199 -0.692806 -0.844436  New York    Opera
4   -0.109549 -0.367305 -0.520999   Chicago  Firefox
5   -0.400515 -1.202466 -1.664337  New York   Chrome
6   -2.241892 -0.888160 -0.332380  New York   Chrome
7   -0.432767 -1.794931  0.975878   Chicago   Chrome
8   -1.401193 -0.478224  0.112729   Chicago   Safari
9   -1.493518  0.584824  0.652820  New York    Opera
10   0.525359 -0.885912  0.474492    Boston  Firefox
11   0.671226 -0.733788  0.272915    Boston   Chrome
12   0.775901 -0.163745  0.628414    Boston    Opera
13  -1.158007 -0.495240  1.183522  New York   Chrome
14  -1.200085  1.083380 -0.692171    Boston   Safari
15   0.872763 -2.119172 -0.169185    Boston   Chrome
16   1.423514 -1.802891 -2.947628    Boston   Safari
17  -0.547940 -0.788654 -1.065005    Boston   Safari
18  -0.380440  2.050783  1.548453  New York  Firefox
19  -0.095913  1.260104  0.196552    Boston    Opera
20  -1.558961  1.240931 -0.165927    Boston   Safari
21   1.111618 -0.309371 -0.803404   Chicago   Chrome
22   0.348182 -1.200900  0.307754  New York  Firefox
23  -0.834901  0.188590 -1.115227  New York   Chrome
24   1.463240 -1.559017  0.954684  New York   Chrome
Encoded DataFrame
-----------------
            a         b         c  e=Boston  e=Chicago  e=New York  f=Chrome  f=Firefox  f=Opera  f=Safari
0   -0.219222 -0.368154  0.388479         0          0           1         0          0        1         0
1    1.879536 -0.033210 -0.099437         0          0           1         0          1        0         0
2    0.909419 -0.498084  0.084163         0          0           1         0          0        0         1
3   -0.002199 -0.692806 -0.844436         0          0           1         0          0        1         0
4   -0.109549 -0.367305 -0.520999         0          1           0         0          1        0         0
5   -0.400515 -1.202466 -1.664337         0          0           1         1          0        0         0
6   -2.241892 -0.888160 -0.332380         0          0           1         1          0        0         0
7   -0.432767 -1.794931  0.975878         0          1           0         1          0        0         0
8   -1.401193 -0.478224  0.112729         0          1           0         0          0        0         1
9   -1.493518  0.584824  0.652820         0          0           1         0          0        1         0
10   0.525359 -0.885912  0.474492         1          0           0         0          1        0         0
11   0.671226 -0.733788  0.272915         1          0           0         1          0        0         0
12   0.775901 -0.163745  0.628414         1          0           0         0          0        1         0
13  -1.158007 -0.495240  1.183522         0          0           1         1          0        0         0
14  -1.200085  1.083380 -0.692171         1          0           0         0          0        0         1
15   0.872763 -2.119172 -0.169185         1          0           0         1          0        0         0
16   1.423514 -1.802891 -2.947628         1          0           0         0          0        0         1
17  -0.547940 -0.788654 -1.065005         1          0           0         0          0        0         1
18  -0.380440  2.050783  1.548453         0          0           1         0          1        0         0
19  -0.095913  1.260104  0.196552         1          0           0         0          0        1         0
20  -1.558961  1.240931 -0.165927         1          0           0         0          0        0         1
21   1.111618 -0.309371 -0.803404         0          1           0         1          0        0         0
22   0.348182 -1.200900  0.307754         0          0           1         0          1        0         0
23  -0.834901  0.188590 -1.115227         0          0           1         1          0        0         0
24   1.463240 -1.559017  0.954684         0          0           1         1          0        0         0