Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created April 8, 2018 10:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devforfu/612fb8ff8c9ab496b60aa06bf4542f92 to your computer and use it in GitHub Desktop.
Save devforfu/612fb8ff8c9ab496b60aa06bf4542f92 to your computer and use it in GitHub Desktop.
SGD training snippet for Medium post
def sgd(x_train, y_train, x_valid, y_valid, variance_threshold=0.1):
    """Train a bagged ensemble of SGD logistic-regression models and score it.

    A ``VarianceThreshold`` feature filter is fitted on the training set only
    and then applied to the validation set. The filtered training data is used
    to fit a ``BaggingClassifier`` whose members are elastic-net
    ``SGDClassifier`` models with logistic loss.

    Parameters
    ----------
    x_train, y_train : training features and labels.
    x_valid, y_valid : validation features and labels.
    variance_threshold : float, default 0.1
        Passed to ``VarianceThreshold`` to drop low-variance features.

    Returns
    -------
    tuple
        ``(bagging, train_metrics, valid_metrics)``: the fitted ensemble and
        the ``build_metrics`` results on the training and validation sets.
    """
    import inspect

    # scikit-learn renamed loss='log' to 'log_loss' (1.1) and later removed
    # the old spelling; use whichever name the installed version knows.
    known_losses = getattr(SGDClassifier, 'loss_functions', {})
    loss = 'log_loss' if 'log_loss' in known_losses else 'log'

    # np.shape works for lists, ndarrays, DataFrames and scipy sparse
    # matrices alike, whereas len() raises on sparse matrices.
    n_samples = np.shape(x_train)[0]

    threshold = VarianceThreshold(variance_threshold)
    # NOTE(review): both the base estimator and the ensemble use n_jobs=-1,
    # which may oversubscribe CPUs; consider n_jobs=1 for the base model.
    sgd_classifier = SGDClassifier(
        alpha=1. / n_samples,
        class_weight='balanced',
        loss=loss, penalty='elasticnet',
        fit_intercept=False, tol=0.001, n_jobs=-1)

    # BaggingClassifier's `base_estimator` was renamed to `estimator` (1.2)
    # and later removed; pick the keyword this version's constructor accepts.
    bagging_params = inspect.signature(BaggingClassifier).parameters
    estimator_kw = 'estimator' if 'estimator' in bagging_params else 'base_estimator'
    bagging = BaggingClassifier(
        bootstrap_features=True,
        n_jobs=-1, max_samples=0.5, max_features=0.5,
        **{estimator_kw: sgd_classifier})

    # Fit the feature filter on training data only, to avoid leaking
    # validation statistics into feature selection.
    x_thresh = threshold.fit_transform(x_train)
    bagging.fit(x_thresh, y_train)
    train_metrics = build_metrics(bagging, x_thresh, y_train)

    x_thresh = threshold.transform(x_valid)
    valid_metrics = build_metrics(bagging, x_thresh, y_valid)
    return bagging, train_metrics, valid_metrics
def build_metrics(model, X, y):
    """Score a fitted probabilistic classifier on ``(X, y)``.

    Parameters
    ----------
    model : fitted estimator exposing ``predict_proba``. If it also exposes
        ``classes_`` (as scikit-learn classifiers do), that array is used to
        turn probability columns back into class labels.
    X : features accepted by ``model.predict_proba``.
    y : true labels for the rows of ``X``.

    Returns
    -------
    Predictions
        Namedtuple with fields ``probs`` (``predict_proba`` output),
        ``preds`` (predicted labels), ``loss`` (log loss of ``probs``
        against ``y``) and ``accuracy`` (fraction of ``preds`` equal to ``y``).
    """
    probs = model.predict_proba(X)
    classes = getattr(model, 'classes_', None)

    # argmax gives a column index into predict_proba's output, whose columns
    # follow model.classes_. Map it back to the real label so the comparison
    # with y is correct when labels are not exactly 0..n_classes-1.
    indices = np.argmax(probs, axis=1)
    preds = indices if classes is None else np.asarray(classes)[indices]

    metrics = dict(
        probs=probs,
        preds=preds,
        # Passing the model's classes keeps log_loss working when y (e.g. a
        # validation split) does not contain every class seen in training.
        loss=log_loss(y, probs, labels=classes),
        accuracy=np.mean(preds == np.asarray(y)))
    return namedtuple('Predictions', metrics.keys())(**metrics)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment