Skip to content

Instantly share code, notes, and snippets.

@sorindragan
Created December 6, 2018 17:10
Show Gist options
  • Save sorindragan/2b4e3ec007feb233a8b05f91749ca65f to your computer and use it in GitHub Desktop.
Save sorindragan/2b4e3ec007feb233a8b05f91749ca65f to your computer and use it in GitHub Desktop.
feature_computer: Transformer class that computes features before training a model using an sklearn pipeline. train_model: Function that trains an sklearn model and writes it to a pickle file. train_utils: Helper function
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
class FeatureComputer(BaseEstimator, TransformerMixin):
    """Pipeline step that turns raw data lines into a sparse feature matrix.

    Parameters
    ----------
    train_data : iterable
        Reference data that is iterated over while building each feature
        vector. It is stored, unmodified, under the same name as the
        constructor argument because scikit-learn's ``get_params`` /
        ``set_params`` / ``clone`` look the value up by that name.
    """

    def __init__(self, train_data):
        self.train_data = train_data

    @property
    def data(self):
        """Backward-compatible alias for ``train_data``."""
        return self.train_data

    @data.setter
    def data(self, value):
        self.train_data = value

    def fit(self, X, y=None):
        """Return ``self`` unchanged; nothing is learned from ``X`` or ``y``."""
        return self

    def transform(self, X, y=None):
        """Return the sparse feature matrix for ``X`` (see ``process``)."""
        return self.process(X)

    def process(self, data_lines):
        """Build one feature vector per data line and stack them.

        Returns a ``csr_matrix`` created from a ``float32`` array of the
        per-line feature vectors.
        """
        rows = [self.get_feature_vect(line) for line in data_lines]
        return csr_matrix(np.array(rows, dtype=np.float32))

    def get_feature_vect(self, data_line):
        """Return the feature vector (a list) for a single data line.

        NOTE(review): the loop body is still a placeholder, so the returned
        list is always empty.
        """
        feature_vec = []
        for d in self.train_data:
            # calculate features and fill up feature vector
            break
        return feature_vec
import pickle
import os.path
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from train import FeatureComputer
from train_utils import data_extraction
def create_models():
    """Fit a feature-extraction + random-forest pipeline and pickle it.

    Training data is read from ``data.json`` through ``data_extraction``;
    the fitted pipeline is written to ``<section>_model.pkl``.

    NOTE(review): ``label``, ``section`` and ``NGRAMS`` are not defined in
    the visible source -- presumably they are meant to come from an
    enclosing loop or module-level configuration; confirm before running.
    """
    source_path = 'data.json'
    # create x_train, y_train
    # if only one label(section) of data, meaning only one model to train
    # --> use test train split on data
    features, targets = data_extraction(source_path, label)

    # Build the two pipeline stages separately so each one is easy to read.
    featurizer = FeatureComputer(NGRAMS[section])
    forest = RandomForestClassifier(
        n_estimators=100,
        class_weight="balanced",
        max_features=None,
    )
    model = Pipeline([
        ('features', featurizer),
        ('clf', forest),
    ])
    model.fit(features, targets)

    # Persist the whole fitted pipeline (features + classifier) in one file.
    out_path = section + "_model.pkl"
    with open(out_path, 'wb') as sink:
        pickle.dump(model, sink)
import numpy as np
import json
def label_data(path, label_name):
    """Load a JSON file and build binary labels for one label name.

    Parameters
    ----------
    path : str
        Path of the JSON file to read.
    label_name : str
        An entry is labelled 1 when ``label_name in element`` is true for
        it, and 0 otherwise.

    Returns
    -------
    X : numpy.ndarray
        One data element per entry of the loaded JSON (currently a
        placeholder value).
    y : numpy.ndarray
        1-D array of 0/1 labels, aligned with ``X``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path, 'rb') as json_file:
        data = json.load(json_file)

    X = []
    y = []
    for element in data:
        # parse data and create element
        data_elem = 'something'
        y.append(1 if label_name in element else 0)
        X.append(data_elem)

    # The original called ``y.reshape(-1, 1)`` and discarded the result,
    # which had no effect; y is (and always was) returned as a 1-D array.
    return np.array(X), np.array(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment