from scipy.stats import norm as dist_model
import numpy as np
import torch
cfg = None       # training config; must provide cfg.epochs and cfg.scale
device = None    # torch.device the model and batches run on
train_dl = None  # DataLoader over seen-class training data: (inputs, labels)
test_dl = None   # DataLoader over test data, which also contains OOD samples
seen_classes = list(range(5))  # labels of the seen (in-distribution) classes
OOD_CLASS_NUMBER = -1          # label assigned to rejected (OOD) predictions
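# A hypothetical setup for illustration (not in the original gist): any object
# exposing `epochs` and `scale` works as cfg, e.g.
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(epochs=10, scale=2.0)
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The gist also leaves calculate_metrics undefined; below is a minimal sketch
# using scikit-learn (an assumption, not the author's implementation) that
# scores accuracy and macro F1 over the seen classes plus the rejected label.
from sklearn.metrics import accuracy_score, f1_score

def calculate_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)  # overall accuracy incl. OOD
    fscore = f1_score(y_true, y_pred, average="macro")  # macro F1 over labels
    return accuracy, fscore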
def fit(prob_pos_X):
    """Fit a Gaussian to a class's positive probabilities, mirrored about 1."""
    # Mirror each probability p about 1 by adding 2 - p: the fitted mean is
    # then pinned at 1, and the std measures how far scores fall below it.
    prob_pos = [p for p in prob_pos_X] + [2 - p for p in prob_pos_X]
    pos_mu, pos_std = dist_model.fit(prob_pos)
    return pos_mu, pos_std
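# Quick sanity check (an illustrative addition, not in the original gist):
# the mirroring pins the fitted mean at exactly 1.0, so only the std varies
# with how tightly a class's positive probabilities cluster near 1.
_demo_mu, _demo_std = fit([0.9, 0.95, 0.99])  # -> (1.0, ~0.065)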
for epoch in range(cfg.epochs):
    model = None  # placeholder: train (or load) the model for this epoch here

    # Collect the model's per-class probabilities on the seen-class training
    # set; the method assumes outputs lie in [0, 1] (e.g. sigmoid scores).
    model.eval()
    seen_train_X_predictions = []
    seen_train_y = []
    for batch in train_dl:
        batch = [elem.to(device) for elem in batch]
        outputs = model(batch[0])
        seen_train_X_predictions.append(outputs.detach())
        seen_train_y.append(batch[1].cpu().numpy())
    seen_train_X_predictions = torch.concat(seen_train_X_predictions, dim=0).cpu().numpy()  # (N, num_classes)
    seen_train_y = np.concatenate(seen_train_y, axis=0)  # (N,)
    # Fit one Gaussian per seen class, using only that class's own samples
    # and only the probability the model assigned to that class.
    mu_stds = []
    for i in range(len(seen_classes)):
        pos_mu, pos_std = fit(seen_train_X_predictions[seen_train_y == i, i])
        mu_stds.append([pos_mu, pos_std])
    # Collect predictions on the test set, which also contains OOD samples.
    test_X_pred = []
    test_y_gt = []
    model.eval()
    for batch in test_dl:
        batch = [elem.to(device) for elem in batch]
        outputs = model(batch[0])
        test_X_pred.append(outputs.detach())
        test_y_gt.append(batch[1].cpu().numpy())
        # Guard against single-sample batches collapsing the batch dimension.
        if len(test_X_pred[-1].shape) == 1:
            test_X_pred[-1] = test_X_pred[-1].unsqueeze(0)
    test_X_pred = torch.concat(test_X_pred, dim=0).cpu().numpy()
    test_y_gt = np.concatenate(test_y_gt, axis=0)
    # Accept or reject each test prediction using a per-class threshold.
    test_y_pred = []  # final model predictions, with OOD rejections
    for p in test_X_pred:
        max_class = np.argmax(p)  # most confident seen class
        max_value = np.max(p)     # its predicted probability
        # Per-class threshold: 1 - scale * std of that class's fitted
        # Gaussian, floored at 0.5.
        threshold = max(0.5, 1. - cfg.scale * mu_stds[max_class][1])
        if max_value > threshold:
            test_y_pred.append(max_class)         # confident enough: accept
        else:
            test_y_pred.append(OOD_CLASS_NUMBER)  # otherwise reject as OOD
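        # Worked example (illustrative values, not from the gist): with
        # cfg.scale = 2.0 and a fitted std of 0.1 for the predicted class,
        # threshold = max(0.5, 1 - 2.0 * 0.1) = 0.8, so a max probability
        # of 0.85 is accepted while 0.75 is rejected as OOD.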
    # Score this epoch: accuracy and F1 over seen classes plus the OOD label.
    accuracy, fscore = calculate_metrics(test_y_gt, test_y_pred)