Skip to content

Instantly share code, notes, and snippets.

@lukauskas
Created August 21, 2018 08:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukauskas/02d68b7c89e21dd6782992f736b61108 to your computer and use it in GitHub Desktop.
Save lukauskas/02d68b7c89e21dd6782992f736b61108 to your computer and use it in GitHub Desktop.
Wrapper around R's precrec package
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr
import rpy2.robjects
import fa2
import gc
import pandas as pd
import numpy as np
# Load the R `precrec` package through rpy2. This runs at import time, so the
# module can only be imported where R and its `precrec` package are available.
precrec = importr('precrec')
# Enable rpy2's pandas conversion layer for the R calls made below.
pandas2ri.activate()
# Make sure your R is configured appropriately
def r_precrec_prc(true, scores):
    """
    Compute precision-recall curve statistics using R's `precrec` package.

    The curve, its area and the partial areas come from `precrec`; a second
    curve that carries decision thresholds is computed with scikit-learn.

    :param true: true labels
    :param scores: scores assigned for class 1 for each of the labels.
    :return: tuple ``(prc_df, auc_prc, partial_prcs, prc_df_with_thresholds)``

        * ``prc_df`` -- DataFrame built from the single `precrec` PRC curve,
          with ``x``/``y`` renamed to ``recall``/``precision`` and
          ``orig_points`` cast to bool.
        * ``auc_prc`` -- the ``aucs`` value of the single ``PRC`` row reported
          by ``precrec.auc``.
        * ``partial_prcs`` -- float DataFrame indexed by ``recall`` (upper
          recall limit, 0.05 to 1.0 in steps of 0.05) with columns ``pAUC``
          and ``spAUC``.
        * ``prc_df_with_thresholds`` -- DataFrame with columns ``precision``,
          ``recall`` and ``threshold`` from scikit-learn, sorted by
          ``threshold``.
    :raises AssertionError: if `precrec` does not report exactly one PRC
        curve / one PRC row where exactly one is expected.
    """
    # The surrounding module never imports this name, so without the local
    # import the sklearn step at the bottom fails with NameError.
    from sklearn.metrics import precision_recall_curve

    r_msmdat = precrec.mmdata(scores, true)
    r_mscurves = precrec.evalmod(r_msmdat)

    # Pick the precision-recall curves out of the evaluation result by name.
    r_prcs = r_mscurves[r_mscurves.names.index('prcs')]
    assert len(r_prcs) == 1
    r_prcs = r_prcs[0]

    prc_df = pd.DataFrame({k: np.array(v) for k, v in r_prcs.items()})
    prc_df.rename(columns={'x': 'recall', 'y': 'precision'}, inplace=True)
    prc_df['orig_points'] = prc_df['orig_points'].astype(bool)

    r_aucs = precrec.auc(r_mscurves)
    r_aucs = pd.DataFrame({k: np.asarray(v) for k, v in r_aucs.items()})
    r_aucs = r_aucs[r_aucs['curvetypes'] == 'PRC']
    assert len(r_aucs) == 1
    auc_prc = r_aucs['aucs'].iloc[0]

    # Compute partial PRC at different points
    partial_prcs = []
    # NOTE: the limits come from a float-step arange, so some index labels
    # carry rounding error (e.g. 0.15000000000000002 rather than 0.15).
    for recall_threshold in np.arange(0.05, 1.01, 0.05):
        partial_prc = precrec.part(r_mscurves,
                                   xlim=rpy2.robjects.r.c(0, recall_threshold))
        r_df = precrec.pauc(partial_prc)
        # Build the frame with rownames/colnames swapped and then transpose it
        # so the named fields (curvetypes, paucs, spaucs) become columns.
        df = pd.DataFrame(np.asarray(r_df),
                          columns=r_df.rownames,
                          index=r_df.colnames).T

        prc = df.query('curvetypes == "PRC"')
        assert len(prc) == 1
        prc = prc.iloc[0]

        row = [recall_threshold, prc['paucs'], prc['spaucs']]
        partial_prcs.append(row)
        del r_df

    partial_prcs = pd.DataFrame(
        partial_prcs, columns=['recall', 'pAUC', 'spAUC']
    ).set_index('recall').astype(float)

    # Cleanup after R
    del r_msmdat, r_prcs, r_mscurves, r_aucs
    gc.collect()

    # Use sklearn to get prc curve with thresholds
    _precision, _recall, _thresholds = precision_recall_curve(true, scores)
    # sklearn returns one threshold fewer than precision/recall points; pad
    # with NaN so the three arrays line up as DataFrame columns.
    _thresholds = np.append(_thresholds, np.nan)

    prc_df_with_thresholds = pd.DataFrame({'precision': _precision,
                                           'recall': _recall,
                                           'threshold': _thresholds})
    prc_df_with_thresholds = prc_df_with_thresholds.sort_values(by='threshold')

    return prc_df, auc_prc, partial_prcs, prc_df_with_thresholds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment