@cinjon
Created June 11, 2020 14:02
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Assumption: the gist relies on a module-level `device` that is not shown; a typical definition.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test(model, dataloader, config):
    model.eval()
    num_batches = config['num_test_batches']
    running_loss = 0.0
    running_accuracy = 0.0
    with torch.no_grad():
        with tqdm(dataloader, total=num_batches) as pbar:
            for batch_idx, batch in enumerate(pbar):
                # Each batch holds a support ('train') and a query ('test') set per task.
                train_inputs, train_targets = batch['train']
                train_inputs = train_inputs.to(device)
                train_targets = train_targets.to(device)
                test_inputs, test_targets = batch['test']
                test_inputs = test_inputs.to(device)
                test_targets = test_targets.to(device)

                # Embed both sets; the encoder takes a mode flag ('neighbor' vs. 'input').
                train_embeddings = model(train_inputs, 'neighbor')
                test_embeddings = model(test_inputs, 'input')

                # Get prototypes: the mean support embedding of each class.
                # `model.module` implies the model is wrapped (e.g. in nn.DataParallel).
                batch_size, embedding_size = train_embeddings.size(0), train_embeddings.size(-1)
                # Count support examples per class, clamped to at least 1 so the
                # division below never hits zero.
                ones = torch.ones_like(train_targets, dtype=train_embeddings.dtype)
                num_samples = ones.new_zeros(
                    (train_targets.size(0), model.module.num_classes_per_task))
                num_samples.scatter_add_(1, train_targets, ones)
                num_samples.unsqueeze_(-1)
                num_samples = torch.max(num_samples, torch.ones_like(num_samples))
                # Sum the support embeddings per class, then divide by the counts.
                prototypes = train_embeddings.new_zeros(
                    (batch_size, model.module.num_classes_per_task, embedding_size))
                indices = train_targets.unsqueeze(-1).expand_as(train_embeddings)
                prototypes.scatter_add_(1, indices, train_embeddings).div_(num_samples)

                # Get the loss: cross-entropy over negative squared distances
                # from each query embedding to every prototype.
                squared_distances = torch.sum(
                    (prototypes.unsqueeze(2) - test_embeddings.unsqueeze(1)) ** 2, dim=-1)
                loss = F.cross_entropy(-squared_distances, test_targets)

                # Get the accuracy: each query is assigned to its nearest prototype.
                sq_distances = torch.sum(
                    (prototypes.unsqueeze(1) - test_embeddings.unsqueeze(2)) ** 2, dim=-1)
                _, predictions = torch.min(sq_distances, dim=-1)
                accuracy = torch.mean(predictions.eq(test_targets).float())

                # Track running averages and report them on the progress bar.
                running_loss += loss.item()
                accuracy = accuracy.item()
                running_accuracy += accuracy
                avg_loss = running_loss / (batch_idx + 1)
                avg_accuracy = running_accuracy / (batch_idx + 1)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy),
                                 avg_loss='{0:.4f}'.format(avg_loss))
                if batch_idx == num_batches - 1:
                    break

    avg_loss = running_loss / num_batches
    avg_accuracy = running_accuracy / num_batches
    return avg_loss, avg_accuracy
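
The batch dictionary with 'train'/'test' keys matches the layout produced by few-shot meta-dataloaders such as torchmeta's BatchMetaDataLoader, but the encoder's two-argument forward and its num_classes_per_task attribute come from code not shown in the gist. The sketch below is a self-contained smoke test under those assumptions: ToyEncoder, toy_batches, and the config value are hypothetical stand-ins, not part of the original.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical encoder matching the two-argument forward used in test()."""
    def __init__(self, num_classes_per_task=5, embedding_size=16):
        super().__init__()
        self.num_classes_per_task = num_classes_per_task
        self.net = nn.Linear(32, embedding_size)

    def forward(self, inputs, mode):
        # `mode` ('neighbor' or 'input') is ignored here; the real model presumably
        # switches behavior on it.
        return self.net(inputs)

def toy_batches(num_batches, batch_size=4, ways=5, shots=1, test_shots=3):
    # Yields batches shaped like {'train': (inputs, targets), 'test': (inputs, targets)}.
    for _ in range(num_batches):
        yield {
            'train': (torch.randn(batch_size, ways * shots, 32),
                      torch.arange(ways).repeat_interleave(shots).repeat(batch_size, 1)),
            'test': (torch.randn(batch_size, ways * test_shots, 32),
                     torch.arange(ways).repeat_interleave(test_shots).repeat(batch_size, 1)),
        }

# Wrapped in DataParallel so that model.module resolves, as the .module access above suggests.
model = nn.DataParallel(ToyEncoder().to(device))
config = {'num_test_batches': 8}
avg_loss, avg_accuracy = test(model, toy_batches(config['num_test_batches']), config)
print(avg_loss, avg_accuracy)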