Skip to content

Instantly share code, notes, and snippets.

@douglaspsteen
Created August 29, 2020 22:24
Show Gist options
  • Select an option

  • Save douglaspsteen/62e86b20b6bb730f0cea31a1bcaedbe2 to your computer and use it in GitHub Desktop.

Select an option

Save douglaspsteen/62e86b20b6bb730f0cea31a1bcaedbe2 to your computer and use it in GitHub Desktop.
# Self-training (pseudo-labeling) loop.
#
# Each round fits a LogisticRegression on the current training set, scores it
# on train and test, then moves every unlabeled row whose predicted class
# probability exceeds PROB_THRESHOLD into the training set, using the model's
# own prediction as its label. The loop ends when a round yields no such rows
# or when no unlabeled rows remain.
#
# Reads (defined elsewhere): X_train, y_train, X_test, y_test, X_unlabeled.
# Rebinds: X_train, y_train, X_unlabeled (grown / shrunk each round).
# Produces: iterations, train_f1s, test_f1s, pseudo_labels, clf, high_prob.

# Minimum class probability for a prediction to be accepted as a pseudo-label
PROB_THRESHOLD = 0.99

# Initiate iteration counter
iterations = 0

# Containers to hold f1_scores and # of pseudo-labels (one entry per round)
train_f1s = []
test_f1s = []
pseudo_labels = []

# Non-empty placeholder so the loop condition is true on the first pass
high_prob = [1]

# Loop will run until there are no more high-probability pseudo-labels
while len(high_prob) > 0:
    # Fit classifier and make train/test predictions
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)
    y_hat_train = clf.predict(X_train)
    y_hat_test = clf.predict(X_test)

    # Calculate and print iteration # and f1 scores, and store f1 scores
    train_f1 = f1_score(y_train, y_hat_train)
    test_f1 = f1_score(y_test, y_hat_test)
    print(f"Iteration {iterations}")
    print(f"Train f1: {train_f1}")
    print(f"Test f1: {test_f1}")
    train_f1s.append(train_f1)
    test_f1s.append(test_f1)

    # Every unlabeled row has already been absorbed: scikit-learn rejects
    # zero-row input to predict/predict_proba, so stop here. A zero-count
    # round is still recorded so pseudo_labels stays the same length as the
    # f1 lists, and high_prob is left as an empty frame as on a normal exit.
    if len(X_unlabeled) == 0:
        high_prob = pd.DataFrame(columns=['preds', 'prob_0', 'prob_1'])
        pseudo_labels.append(0)
        iterations += 1
        break

    # Generate predictions and probabilities for unlabeled data
    print("Now predicting labels for unlabeled data...")
    pred_probs = clf.predict_proba(X_unlabeled)
    preds = clf.predict(X_unlabeled)

    # Store predictions and probabilities in a dataframe aligned to the
    # unlabeled rows' index.
    # NOTE(review): columns 0/1 of predict_proba are taken as the two classes
    # in clf.classes_ order -- assumes a binary target; confirm labels.
    df_pred_prob = pd.DataFrame(
        {
            'preds': preds,
            'prob_0': pred_probs[:, 0],
            'prob_1': pred_probs[:, 1],
        },
        index=X_unlabeled.index,
    )

    # Separate predictions above the probability threshold
    # (first-column rows first, then second-column rows)
    high_prob = pd.concat(
        [
            df_pred_prob.loc[df_pred_prob['prob_0'] > PROB_THRESHOLD],
            df_pred_prob.loc[df_pred_prob['prob_1'] > PROB_THRESHOLD],
        ],
        axis=0,
    )
    print(f"{len(high_prob)} high-probability predictions added to training data.")
    pseudo_labels.append(len(high_prob))

    # Add pseudo-labeled data to training data
    X_train = pd.concat([X_train, X_unlabeled.loc[high_prob.index]], axis=0)
    y_train = pd.concat([y_train, high_prob.preds])

    # Drop pseudo-labeled instances from unlabeled data
    X_unlabeled = X_unlabeled.drop(index=high_prob.index)
    print(f"{len(X_unlabeled)} unlabeled instances remaining.\n")

    # Update iteration counter
    iterations += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment