Skip to content

Instantly share code, notes, and snippets.

@alexandrebarachant
Created May 19, 2020 16:20
Show Gist options
  • Save alexandrebarachant/f1e90aaeeb105d4f8697ae3ff20bc0fc to your computer and use it in GitHub Desktop.
Save alexandrebarachant/f1e90aaeeb105d4f8697ae3ff20bc0fc to your computer and use it in GitHub Desktop.
Confusion matrix and predictions
import json

import matplotlib.pyplot as plt
import numpy as np

from cogdata_service.service import simple_api
from pybmi.modeling.events.metrics import events_confusion
from pybmi.modeling.loading.benchmarks import EventsBenchmarkLoader
from workflow.experiments.utils import partition
def plot_preds(preds, time, alpha=1, ax=None, c='k', events=None, spacing=1.1):
    """Plot one trace per channel, stacked vertically on a shared time axis.

    Channel ``k`` (column ``k`` of ``preds``) is shifted down by
    ``spacing * k`` so traces do not overlap, and a y-tick is placed 0.5
    above each shifted baseline. That is the lane midpoint when values lie in
    [0, 1], which is assumed here (e.g. probabilities or binary labels) —
    TODO confirm for other inputs.

    Parameters
    ----------
    preds : array-like, shape (n_times,) or (n_times, n_channels)
        Values to plot. A 1-D input is treated as a single channel.
    time : array-like, shape (n_times,)
        X-axis values; the axis is labelled 'Time (s)'.
    alpha : float
        Line opacity.
    ax : matplotlib Axes, optional
        Axes to draw into. If None, a new 15x4 figure with a white
        background is created.
    c : color
        Line color, shared by all channels.
    events : sequence of str, optional
        One y-tick label per channel.
    spacing : float
        Vertical distance between consecutive channel baselines. The default
        of 1.1 leaves a small gap between lanes whose values lie in [0, 1].

    Returns
    -------
    matplotlib Axes
        The axes drawn into, so further traces can be overlaid by passing it
        back as ``ax``.
    """
    preds = np.asarray(preds)
    if preds.ndim == 1:
        # Without this, shape[-1] would be n_times: offsets would be applied
        # along time instead of across channels and ax.plot would fail.
        preds = preds[:, np.newaxis]

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=[15, 4])
        fig.patch.set_facecolor('white')

    offsets = spacing * np.arange(preds.shape[-1])
    # Broadcast the per-channel offset over every time sample (row).
    shifted = preds - offsets[np.newaxis]
    ax.plot(time, shifted, c=c, alpha=alpha, lw=1)

    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)

    # One tick per channel, 0.5 above its shifted baseline.
    ax.set_yticks(-offsets + 0.5)
    if events is not None:
        ax.set_yticklabels(events)

    ax.set_xlim(time[0], time[-1])
    ax.set_xlabel('Time (s)')
    return ax
# Evaluate one trained model on its validation data:
#   1. load the model's stored predictions,
#   2. summarise them as an event confusion matrix (heatmap),
#   3. plot predicted traces over the ground-truth event labels.
api = simple_api.API()
model_id = '61e19360-942d-47c5-8c79-f761b56cde1d'
# Threshold passed to events_confusion; presumably the prediction level above
# which an event is counted as detected — confirm in pybmi.
threshold = 0.4

loader = EventsBenchmarkLoader(config={'model_predictions': {'model_id': model_id}})
m = api.get_model(model_id)
config = m['config']
# Only the first entry of the model's validation data config is evaluated.
datas = partition(loader, config['data']['validation'][0])
if not datas:
    # Fail with a clear message instead of np.concatenate's
    # "need at least one array to concatenate" or an IndexError on datas[0].
    raise ValueError(f'No validation data found for model {model_id}')

# Stack all partitions along the first (time) axis.
preds = np.concatenate([data.main['model predictions'] for data in datas], axis=0)
labels = np.concatenate([data.main['events'] for data in datas], axis=0)
time = np.concatenate([data.main['time'] for data in datas], axis=0)
scores = events_confusion(preds, labels, threshold=threshold)

ev_names = json.loads(datas[0].metadata['events_names'])
# Assumes `scores` has one row/column per event plus a final "null" (no event)
# class — TODO confirm against events_confusion.
tick_labels = ev_names + ['null']

fig = plt.figure(figsize=[4, 4])
fig.patch.set_facecolor('white')
# NOTE(review): `sns` (seaborn) is never imported in this file; it was
# presumably imported earlier in the original notebook session. Add
# `import seaborn as sns` to the imports when running this as a script.
sns.heatmap(scores, annot=True, fmt='.0f',
            xticklabels=tick_labels, yticklabels=tick_labels,
            cbar=False, square=True, cmap='Greys')
plt.ylabel('True Event')
plt.xlabel('Predicted Event')

# Predictions in red, overlaid with the ground-truth labels in
# semi-transparent black on the same axes.
ax = plot_preds(preds, time, events=ev_names, c='r')
plot_preds(labels, time, events=ev_names, c='k', ax=ax, alpha=0.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment