Skip to content

Instantly share code, notes, and snippets.

@isaaccorley
Created May 31, 2023 19:22
Show Gist options
  • Save isaaccorley/92d32c1cd818251f70996ea04ba83d1b to your computer and use it in GitHub Desktop.
Save isaaccorley/92d32c1cd818251f70996ea04ba83d1b to your computer and use it in GitHub Desktop.
PyTorch Lightning KNN Classifier Evaluation Callback
# pip install torch lightning scikit-learn numpy tqdm faissknn
import lightning.pytorch as pl
import numpy as np
import torch
from faissknn import FaissKNNClassifier
from lightning.pytorch.utilities import rank_zero_only
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm
class KNNEval(pl.callbacks.Callback):
    """Periodically evaluate encoder embeddings with a k-nearest-neighbour classifier.

    At the end of every ``check_every_n_epochs``-th training epoch (rank zero
    only), ``pl_module.image_encoder`` embeds every batch of the datamodule's
    train and val dataloaders. A ``FaissKNNClassifier`` is fit on the train
    embeddings and used to predict the val labels. Accuracy plus micro/macro/
    weighted F1, precision and recall are then logged via ``pl_module.log_dict``.

    The reference set is whatever ``datamodule.train_dataloader()`` yields, so
    any shuffling, augmentation or ``drop_last`` configured there applies.

    Args:
        datamodule: Object exposing ``setup``, ``train_dataloader()`` and
            ``val_dataloader()``. Its batches are indexed with ``"image"`` and
            ``"label"``. ``setup("fit")`` is called once, here.
        k: Number of neighbours used by the KNN classifier.
        check_every_n_epochs: Run the evaluation every this many training epochs.
        device: Device string passed to ``FaissKNNClassifier``.
    """

    def __init__(self, datamodule, k=5, check_every_n_epochs=1, device: str = "cuda:0"):
        super().__init__()
        self.datamodule = datamodule
        # Lightning's ``setup`` hook takes the stage as a required argument;
        # "fit" prepares the train and val splits this callback reads.
        self.datamodule.setup("fit")
        self.k = k
        self.check_every_n_epochs = check_every_n_epochs
        self.device = device

    @staticmethod
    def _embed(backbone, dataloader, device):
        """Return ``(embeddings, labels)`` NumPy arrays for every batch in ``dataloader``.

        Embeddings are reshaped to ``(num_samples, -1)`` and cast to float32:
        autocast may yield float16, while FAISS float indexes expect float32.
        Returns empty arrays when the dataloader yields no batches.
        """
        features, labels = [], []
        # device_type stays "cuda" so the context is valid on any torch build;
        # autocast is simply disabled when the model is not on a CUDA device.
        use_amp = device.type == "cuda"
        with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=use_amp):
            for batch in tqdm(dataloader):
                emb = backbone(batch["image"].to(device))
                features.append(emb.reshape(emb.shape[0], -1).float().cpu().numpy())
                labels.append(batch["label"].cpu().numpy())
        if not features:
            return np.empty((0, 0), dtype=np.float32), np.empty((0,), dtype=np.int64)
        return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Embed train/val splits, fit KNN on train, and log val classification metrics."""
        if (trainer.current_epoch + 1) % self.check_every_n_epochs != 0:
            return

        device = pl_module.device
        # Change this attribute to point at your model's feature extractor.
        backbone = pl_module.image_encoder
        was_training = backbone.training
        backbone.eval()
        try:
            x_train, y_train = self._embed(backbone, self.datamodule.train_dataloader(), device)
            x_val, y_val = self._embed(backbone, self.datamodule.val_dataloader(), device)
        finally:
            # Restore the encoder's previous mode (even on error) instead of
            # unconditionally forcing it back into train mode.
            backbone.train(was_training)

        if len(x_train) == 0 or len(x_val) == 0:
            import warnings

            warnings.warn("KNNEval: empty train or val dataloader; skipping KNN evaluation.")
            return

        knn = FaissKNNClassifier(n_neighbors=self.k, device=self.device)
        knn.fit(X=x_train, y=y_train)
        y_pred = knn.predict(x_val)

        metrics = {"val_accuracy": accuracy_score(y_val, y_pred)}
        for average in ("micro", "macro", "weighted"):
            metrics[f"val_f1_{average}"] = f1_score(y_val, y_pred, average=average)
            metrics[f"val_precision_{average}"] = precision_score(y_val, y_pred, average=average)
            metrics[f"val_recall_{average}"] = recall_score(y_val, y_pred, average=average)

        # This hook runs on rank zero only, so cross-rank syncing is not
        # requested (Lightning skips it for rank_zero_only logs anyway).
        pl_module.log_dict(metrics, rank_zero_only=True, on_epoch=True)
@isaaccorley
Copy link
Author

You will have to change line 26 (`backbone = pl_module.image_encoder`) to point to the backbone you're using. In this case, I used `timm.create_model("resnet50", pretrained=True, num_classes=0)` as my `image_encoder`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment