Skip to content

Instantly share code, notes, and snippets.

@geblanco
Last active February 9, 2021 12:50
Show Gist options
  • Save geblanco/a906dd1f3ccaaa123f9ca28793a30d64 to your computer and use it in GitHub Desktop.
Save geblanco/a906dd1f3ccaaa123f9ca28793a30d64 to your computer and use it in GitHub Desktop.
5a6,8
> import json
>
> import sklearn.metrics as metrics
6a10,11
> from sklearn.metrics import precision_recall_curve
> from dvc.api import make_checkpoint
10c15
< if len(sys.argv) != 3:
---
> if len(sys.argv) != 4:
16a22,24
> matrix_file = os.path.join(input, 'test.pkl')
> scores_file = sys.argv[3]
>
18a27
> max_n_estimators = n_estimators * 4
25a35,40
> with open(matrix_file, 'rb') as fd:
>     test_matrix = pickle.load(fd)
>
> test_labels = test_matrix[:, 1].toarray()
> test_x = test_matrix[:, 2:]
>
30,34c45,62
< clf = RandomForestClassifier(
<     n_estimators=n_estimators,
<     n_jobs=2,
<     random_state=seed
< )
---
> for n_est in range(n_estimators, max_n_estimators+1, 10):
>     clf = RandomForestClassifier(
>         n_estimators=n_est,
>         n_jobs=2,
>         random_state=seed
>     )
>
>     clf.fit(x, labels)
>
>     with open(output, 'wb') as fd:
>         pickle.dump(clf, fd)
>
>     predictions_by_class = clf.predict_proba(test_x)
>     predictions = predictions_by_class[:, 1]
>
>     precision, recall, thresholds = precision_recall_curve(test_labels, predictions)
>
>     auc = metrics.auc(recall, precision)
36c64,65
< clf.fit(x, labels)
---
>     with open(scores_file, 'w') as fd:
>         yaml.dump({'auc': float(auc)}, fd)
38,39c67
< with open(output, 'wb') as fd:
<     pickle.dump(clf, fd)
---
>     make_checkpoint()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment