Skip to content

Instantly share code, notes, and snippets.

@sanchezg
Forked from jmansilla/onehot_transformer.py
Created August 26, 2016 16:03
Show Gist options
  • Save sanchezg/0013d9cec441e1c35ffd8837725f7167 to your computer and use it in GitHub Desktop.
Save sanchezg/0013d9cec441e1c35ffd8837725f7167 to your computer and use it in GitHub Desktop.
import numpy as np
class OneHotTransformer:
    """One-hot encode items by the value a user-supplied function maps them to.

    ``fit`` learns the distinct values ``func`` produces over the training
    items. ``transform`` then turns every item into a 0/1 row with one column
    per learned value, plus a trailing "unseen" column that is set for any
    value not observed during ``fit``.

    Attributes (``seen`` is available after ``fit``):
        f: The feature function passed to the constructor.
        seen: Learned values in sorted order, followed by a private sentinel
            object standing for "unseen"; ``len(seen)`` is the row width.
    """

    def __init__(self, func):
        """Store *func*, the callable mapping an item to a hashable category."""
        self.f = func

    def fit(self, X, y=None):
        """Learn the categories ``self.f`` produces over *X*.

        Args:
            X: Iterable of items to learn from.
            y: Ignored (presumably kept so the signature matches
                scikit-learn-style ``fit(X, y=None)`` estimators).

        Returns:
            self, so calls can be chained: ``t.fit(X).transform(X)``.
        """
        unseen = object()  # private sentinel: never equal to any real value
        seen = {self.f(x) for x in X}
        try:
            ordered = sorted(seen)
        except TypeError:
            # Mutually unorderable values (e.g. None mixed with str) cannot be
            # sorted directly; use a deterministic type-then-repr order
            # instead of failing.
            ordered = sorted(seen, key=lambda v: (type(v).__name__, repr(v)))
        self.seen = ordered + [unseen]
        # value -> column index, so each lookup in transform_one is O(1)
        # instead of a linear scan over self.seen.
        self._index = {value: i for i, value in enumerate(ordered)}
        return self

    def transform(self, X):
        """One-hot encode every item in *X*.

        Args:
            X: Iterable of items to encode.

        Returns:
            A 2-D numpy integer array of shape ``(n_items, len(self.seen))``.
            An empty *X* yields shape ``(0, len(self.seen))`` rather than a
            shapeless ``(0,)`` array.
        """
        rows = [self.transform_one(x) for x in X]
        if not rows:
            return np.zeros((0, len(self.seen)), dtype=int)
        return np.array(rows)

    def transform_one(self, x):
        """Return the one-hot row (a list of 0/1 ints) for a single item *x*.

        The row has ``len(self.seen)`` entries, exactly one of which is 1.
        Values not learned during ``fit`` (including unhashable ones) set the
        final "unseen" column.
        """
        result = [0] * len(self.seen)
        value = self.f(x)
        try:
            column = self._index.get(value, -1)
        except TypeError:
            # An unhashable value cannot have been learned (fit stores values
            # in a set), so it is reported as unseen rather than raising.
            column = -1
        result[column] = 1
        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment