Skip to content

Instantly share code, notes, and snippets.

@jaircastruita
Created March 5, 2021 07:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaircastruita/fe16af157112375043a4115d5aaa88f3 to your computer and use it in GitHub Desktop.
Save jaircastruita/fe16af157112375043a4115d5aaa88f3 to your computer and use it in GitHub Desktop.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Embedding network, placed on the selected device.
# NOTE(review): relies on MildNet being an nn.Module, whose .to() returns the module itself.
model = MildNet().to(device)

# reduction="none" keeps one loss value per triplet, so the training loop can
# drop the zero-loss triplets itself before averaging.
criterion = nn.TripletMarginLoss(margin=0.1, reduction="none")
optimizer = torch.optim.Adam(model.parameters())

# Number of passes over the training data, and how often (in epochs) the
# retrieval evaluation / checkpoint step runs.
n_epochs = 161
print_every = 20

# Per-epoch mean losses, appended to by the training loop.
train_losses, eval_losses = [], []
def _cross_pair_indices(n_anchor, n_positive, device):
    """Return index tensors for every (anchor i, candidate j) pair with i != j.

    The first tensor selects the anchor (and its matching positive), the
    second selects another sample of the batch to act as the negative.
    Pairs are ordered with the anchor index varying slowest, which is the
    order the previous meshgrid-based construction produced.
    """
    a_p_idx = torch.arange(n_anchor, device=device).repeat_interleave(n_positive)
    n_idx = torch.arange(n_positive, device=device).repeat(n_anchor)
    # A sample must not serve as its own negative.
    keep = a_p_idx != n_idx
    return a_p_idx[keep], n_idx[keep]


def _batch_triplet_loss(model, criterion, anchor, positive, device):
    """Embed one (anchor, positive) batch and return its mean active-triplet loss.

    Every other positive in the batch is used as a negative for each anchor.
    Only triplets with a strictly positive loss contribute to the mean, so the
    result is NaN when no triplet in the batch violates the margin.
    """
    anchor, positive = anchor.to(device), positive.to(device)
    anchor_embs = model(anchor)
    positive_embs = model(positive)

    a_p_idx, n_idx = _cross_pair_indices(
        len(anchor), len(positive), anchor_embs.device
    )
    per_triplet = criterion(
        anchor_embs[a_p_idx], positive_embs[a_p_idx], positive_embs[n_idx]
    )
    # criterion is expected to return one value per triplet (reduction="none");
    # the mean of an empty selection is NaN, which callers treat as "no signal".
    return per_triplet[per_triplet > 0].mean()


def _summarize_epoch(batch_losses, split_name):
    """Print the fraction of NaN batch losses and return the NaN-ignoring mean."""
    print(f"NaN's percentage in {split_name} set: {np.isnan(batch_losses).sum() / len(batch_losses)}", flush=True)
    return np.nanmean(batch_losses)


for epoch in range(n_epochs):
    # ---- training pass ----
    model.train()
    running_loss = []
    for anchor, positive in tqdm(train_dl):
        loss = _batch_triplet_loss(model, criterion, anchor, positive, device)
        # A NaN loss means the batch had no active triplets; skip the update
        # but still record the value so the NaN percentage can be reported.
        if not torch.isnan(loss).item():
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        running_loss.append(loss.item())
    train_losses.append(_summarize_epoch(running_loss, "train"))

    # ---- validation pass (no gradients, no parameter updates) ----
    model.eval()
    with torch.no_grad():
        running_loss = [
            _batch_triplet_loss(model, criterion, anchor, positive, device).item()
            for anchor, positive in tqdm(val_dl)
        ]
    eval_losses.append(_summarize_epoch(running_loss, "validation"))

    # ---- periodic retrieval check, checkpoint and progress line ----
    if epoch % print_every == 0:
        print("Calculating embeddings...", flush=True)
        embs = load_embeddings(test_dl, model, device)

        print("Random image retrieval example:", flush=True)
        query = random.randint(0, len(embs)-1)
        image_query = test_ds[query][0]
        knn = retrieve_ktop(model, image_query, embs, k=5)
        plot_k_retrievals(knn, image_query, test_ds, query, k=5)

        print("Calculating retrieval success percentage (hit rate)...", flush=True)
        p_accurate = retrieve_hitrate(model, embs, test_ds)
        print(f"accurate retrievals: {p_accurate * 100}%", flush=True)

        torch.save(model.state_dict(), f"mildnet_224_bal_epoch_{epoch}.pt")
        print(f"epoch iteration: {epoch}/{n_epochs}, train loss: {train_losses[-1]}, evaluation loss: {eval_losses[-1]}", flush=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment