Skip to content

Instantly share code, notes, and snippets.

@abhishekkrthakur
Created June 11, 2019 20:37
Show Gist options
  • Save abhishekkrthakur/e0fb34f05148483fb68e15b52b70669b to your computer and use it in GitHub Desktop.
Save abhishekkrthakur/e0fb34f05148483fb68e15b52b70669b to your computer and use it in GitHub Desktop.
# Inference-time preprocessing: a deterministic resize + center crop (no random
# augmentation), then tensor conversion and normalization with IMG_MEAN/IMG_STD.
test_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMG_MEAN, IMG_STD),
    ]
)

# Test set built from the sample-submission CSV and the ../input/test/ image
# directory (presumably the CSV lists the test image ids -- confirm against
# CollectionsDatasetTest).
test_dataset = CollectionsDatasetTest(
    root_dir="../input/test/",
    csv_file="../input/sample_submission.csv",
    transform=test_transform,
    image_size=IMAGE_SIZE,
)

# shuffle=False keeps batches in dataset order, so predictions can be aligned
# with dataset rows purely by position.
test_dataset_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=TEST_BATCH_SIZE,
    num_workers=4,
    shuffle=False,
)
# Restore the trained weights and prepare the model for inference.
# map_location remaps every stored tensor onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on a CPU-only machine (plain
# torch.load raises there).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model_ft.load_state_dict(torch.load("model.bin", map_location=device))
model_ft = model_ft.to(device)
# Freeze all parameters: nothing in this script updates the weights.
for param in model_ft.parameters():
    param.requires_grad = False
# Switch layers such as dropout / batch-norm (if present) to inference mode.
model_ft.eval()
# Batched inference over the test set, producing per-class sigmoid scores.
# - Rows are written at a running offset taken from each batch's real length,
#   rather than assuming every batch holds exactly TEST_BATCH_SIZE rows.
# - Model outputs are reshaped to (batch, -1) instead of squeeze()d: squeezing
#   a (B, 1) single-class output gives (B,), which cannot be assigned into the
#   (B, 1) slice and raised a ValueError when NUM_CLASSES == 1.
# - Sigmoid is applied per batch on `device`, avoiding a second numpy -> torch
#   -> numpy round trip; values and the float32 dtype match the old result.
test_preds = np.zeros((len(test_dataset), NUM_CLASSES), dtype=np.float32)
tk0 = tqdm(test_dataset_loader)
row = 0
with torch.no_grad():  # inference only: no autograd graph is needed
    for x_batch in tk0:
        pred = model_ft(x_batch["image"].to(device))
        batch_len = pred.size(0)
        # assumes the model emits NUM_CLASSES logits per image -- TODO confirm
        test_preds[row:row + batch_len, :] = (
            pred.reshape(batch_len, -1).sigmoid().cpu().numpy()
        )
        row += batch_len
# Preserve the original output contract: size-1 axes are dropped (e.g. a
# single-class model yields a 1-D array of scores).
test_preds = test_preds.squeeze()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment