Created
August 21, 2018 08:45
-
-
Save lukauskas/02d68b7c89e21dd6782992f736b61108 to your computer and use it in GitHub Desktop.
Wrapper around R's precrec package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc

import fa2
import numpy as np
import pandas as pd
import rpy2.robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr
from sklearn.metrics import precision_recall_curve
# Requires a working R installation (as seen by rpy2) with the `precrec`
# package installed; `importr` fails at import time otherwise.
precrec = importr('precrec')

# Turn on rpy2's pandas <-> R conversion for the whole process.
pandas2ri.activate()
def r_precrec_prc(true, scores):
    """
    Compute a precision-recall curve (PRC) using R's `precrec` package.

    In addition to the full curve and its area, partial areas under the
    curve are computed for recall limits 0.05, 0.10, ..., 1.00, and a second
    curve annotated with decision thresholds is computed with scikit-learn.

    :param true: true labels
    :param scores: scores assigned for class 1 for each of the labels.
    :return: tuple ``(prc_df, auc_prc, partial_prcs, prc_df_with_thresholds)``:

        * ``prc_df`` -- DataFrame of the precrec curve; precrec's ``x``/``y``
          columns are renamed to ``recall``/``precision`` and ``orig_points``
          is cast to bool.
        * ``auc_prc`` -- area under the precision-recall curve.
        * ``partial_prcs`` -- float DataFrame indexed by ``recall`` (the upper
          recall limit) with columns ``pAUC`` and ``spAUC``.
        * ``prc_df_with_thresholds`` -- DataFrame with columns ``precision``,
          ``recall`` and ``threshold`` sorted by ``threshold``; the final
          curve point has no threshold and is stored as NaN.
    :raises AssertionError: if precrec does not return exactly one PRC
        curve / one PRC AUC row.
    """
    r_msmdat = precrec.mmdata(scores, true)
    r_mscurves = precrec.evalmod(r_msmdat)

    r_prcs = r_mscurves[r_mscurves.names.index('prcs')]
    # A single scores/labels pair was passed in, so one curve is expected.
    assert len(r_prcs) == 1
    r_prcs = r_prcs[0]

    prc_df = pd.DataFrame({k: np.array(v) for k, v in r_prcs.items()})
    prc_df.rename(columns={'x': 'recall', 'y': 'precision'}, inplace=True)
    prc_df['orig_points'] = prc_df['orig_points'].astype(bool)

    r_aucs = precrec.auc(r_mscurves)
    r_aucs = pd.DataFrame({k: np.asarray(v) for k, v in r_aucs.items()})
    # precrec reports several curve types; keep only the PRC row.
    r_aucs = r_aucs[r_aucs['curvetypes'] == 'PRC']
    assert len(r_aucs) == 1
    auc_prc = r_aucs['aucs'].iloc[0]

    # Compute partial PRC at different points
    # NOTE: the limits come from np.arange with a float step, so the
    # resulting `recall` index labels may carry floating-point rounding
    # (e.g. not exactly 0.15); prefer np.isclose over exact label lookup.
    partial_prcs = []
    for recall_threshold in np.arange(0.05, 1.01, 0.05):
        partial_prc = precrec.part(r_mscurves,
                                   xlim=rpy2.robjects.r.c(0, recall_threshold))
        r_df = precrec.pauc(partial_prc)
        # The converted array is transposed relative to the R data frame,
        # hence rownames as columns / colnames as index followed by `.T`.
        df = pd.DataFrame(np.asarray(r_df),
                          columns=r_df.rownames,
                          index=r_df.colnames).T

        prc = df.query('curvetypes == "PRC"')
        assert len(prc) == 1
        prc = prc.iloc[0]

        row = [recall_threshold, prc['paucs'], prc['spaucs']]
        partial_prcs.append(row)

        # Drop the R-side references before the next iteration.
        del r_df, partial_prc

    partial_prcs = pd.DataFrame(partial_prcs, columns=['recall', 'pAUC', 'spAUC']).set_index(
        'recall').astype(float)

    # Cleanup after R
    del r_msmdat, r_prcs, r_mscurves, r_aucs
    gc.collect()

    # Use sklearn to get prc curve with thresholds
    _precision, _recall, _thresholds = precision_recall_curve(true, scores)
    # sklearn returns one fewer threshold than curve points; pad with NaN so
    # the three arrays line up.
    _thresholds = np.append(_thresholds, np.nan)

    prc_df_with_thresholds = pd.DataFrame({'precision': _precision,
                                           'recall': _recall,
                                           'threshold': _thresholds})
    prc_df_with_thresholds = prc_df_with_thresholds.sort_values(by='threshold')

    return prc_df, auc_prc, partial_prcs, prc_df_with_thresholds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment