-
-
Save arthur-thuy/9b79d13e0a85185c328eb12b84e1cf87 to your computer and use it in GitHub Desktop.
File for checking Baal and torchmetrics metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""File for checking Baal metrics.""" | |
import os | |
import argparse | |
import random | |
import warnings | |
import time | |
from dataclasses import dataclass | |
from pprint import pprint | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.backends.cudnn as cudnn | |
from torch import optim | |
from torchvision import datasets | |
from torchvision.transforms import transforms | |
from torchmetrics import Accuracy as torchAccuracy # NOTE: torchmetrics | |
from baal.modelwrapper import ModelWrapper | |
from baal.bayesian.dropout import MCDropoutModule | |
from baal.utils.metrics import Accuracy | |
from baal.active import ActiveLearningDataset | |
# Command-line interface of the script. All three options are boolean switches
# that default to False and flip to True when given on the command line.
parser = argparse.ArgumentParser(description="Baal check metrics")
parser.add_argument(
    "--no-cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument(
    # Passing this flag turns test-set evaluation OFF, so the help text says so.
    "--no-test",
    action="store_true",
    default=False,
    help="skip evaluation on the test set",
)
parser.add_argument(
    "--torchmetrics",
    action="store_true",
    default=False,
    help="use torchmetrics accuracy",
)
def main():
    """Train LeNet5Dropout on a small labelled MNIST subset and print Baal metrics.

    Steps, in order: parse the command-line flags, seed the random number
    generators, select the device, build the MNIST train/test datasets, randomly
    label ``config.init_active_size`` training samples, train through Baal's
    ``ModelWrapper`` for ``config.train_epochs`` epochs, optionally evaluate on
    the test set (``config.mc_samples`` is passed as ``average_predictions``),
    and pretty-print the metrics collected by the wrapper.
    """
    args = parser.parse_args()

    # config
    config = ExperimentConfig()

    # seed
    set_seed(config.seed)

    # device
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
        print(f"Use GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        device = torch.device("cpu")
        print("using CPU, this will be slow")

    # data loading
    train_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    eval_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # Use ONE root directory for both splits. The original mixed "./tmp" and
    # "../tmp", which stored the data in two places (one of them outside the
    # working directory).
    data_root = "./tmp"
    train_ds = datasets.MNIST(
        data_root, train=True, download=True, transform=train_transform
    )
    test_ds = datasets.MNIST(
        data_root, train=False, download=True, transform=eval_transform
    )

    # active learning dataset; the pool view uses the evaluation transform
    active_set = ActiveLearningDataset(
        train_ds, pool_specifics={"transform": eval_transform}
    )
    print(f"labelling {config.init_active_size} observations")
    active_set.label_randomly(config.init_active_size)

    # model
    model = LeNet5Dropout()
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # If not done already, you can wrap your model with our MCDropoutModule
    model = MCDropoutModule(model)
    wrapper = ModelWrapper(model, criterion=nn.CrossEntropyLoss())
    wrapper.add_metric(name="accuracy", initializer=lambda: Accuracy())
    if args.torchmetrics:
        # NOTE: the torchmetrics metric is moved to the same device as the model.
        wrapper.add_metric(
            name="torch_accuracy", initializer=lambda: torchAccuracy().to(device)
        )

    # train the model on the currently labelled dataset
    start_training_time = time.time()
    _ = wrapper.train_on_dataset(
        active_set,
        optimizer=optimizer,
        batch_size=config.train_batch_size,
        epoch=config.train_epochs,
        use_cuda=use_cuda,
    )
    train_time = time.time() - start_training_time
    print(
        "Elapsed training time: "
        f"{int(train_time//3600)}:{int(train_time%3600//60)}:{int(train_time%60)}"
    )

    if not args.no_test:
        # predict on test set
        _ = wrapper.test_on_dataset(
            test_ds,
            batch_size=config.eval_batch_size,
            use_cuda=use_cuda,
            average_predictions=config.mc_samples,
        )

    pprint(wrapper.get_metrics())
@dataclass
class ExperimentConfig:
    """Default hyper-parameters for the metric-check experiment.

    Every field has a default, so ``ExperimentConfig()`` yields a complete
    configuration.
    """

    # --- training ---
    lr: float = 0.001  # learning rate handed to the optimizer
    train_batch_size: int = 128
    train_epochs: int = 1
    num_classes: int = 10  # NOTE(review): not read anywhere in this file - confirm use

    # --- evaluation ---
    mc_samples: int = 100  # passed as ``average_predictions`` at test time
    eval_batch_size: int = 512

    # --- active learning ---
    seed: int = 42
    init_active_size: int = 100  # number of samples labelled at random up front
class LeNet5Dropout(nn.Module):
    """Small two-convolution CNN with dropout that outputs class logits.

    Layer order (see ``forward``): conv -> ReLU -> conv -> ReLU -> max-pool ->
    dropout(0.25) -> flatten -> linear -> ReLU -> dropout(0.5) -> linear.
    The first convolution takes a single input channel.
    """

    def __init__(
        self,
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
        maxpool: int = 2,
        num_classes: int = 10,
    ):
        """Build the layers.

        Args:
            num_filters: Number of filters, out channels of the 1st and 2nd conv layers.
            kernel_size: Kernel size of both convolutions.
            dense_layer: Number of units of the hidden dense layer.
            img_rows: Height of the input image.
            img_cols: Width of the input image.
            maxpool: Max pooling size.
            num_classes: Number of output logits (defaults to 10, the previously
                hard-coded value).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size, 1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size, 1)
        self.max_pool2d = nn.MaxPool2d(maxpool)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Each stride-1, unpadded convolution shrinks a side by (kernel_size - 1),
        # then the pooling layer floor-divides the side by its window size. The
        # divisor must be `maxpool` (not a literal 2), otherwise fc1 does not
        # match the flattened feature map whenever maxpool != 2.
        pooled_rows = (img_rows - 2 * kernel_size + 2) // maxpool
        pooled_cols = (img_cols - 2 * kernel_size + 2) // maxpool
        self.fc1 = nn.Linear(num_filters * pooled_rows * pooled_cols, dense_layer)
        self.fc2 = nn.Linear(dense_layer, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits for a batch of images.

        Args:
            x: Input batch; flattened from dimension 1 onwards after pooling, so
                dimension 0 is treated as the batch dimension.

        Returns:
            Output of the last linear layer (no softmax applied).
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2d(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        out = self.fc2(x)
        return out
def set_seed(seed: int) -> None:
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed``), sets cuDNN to deterministic with benchmarking
    off, writes ``PYTHONHASHSEED`` into ``os.environ`` and warns the caller
    about the possible slowdown.

    Args:
        seed: seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn.deterministic, cudnn.benchmark = True, False
    os.environ["PYTHONHASHSEED"] = str(seed)

    slowdown_notice = (
        "You have chosen to seed training. "
        "This will turn on the CUDNN deterministic setting, "
        "which can slow down your training considerably!"
    )
    # stacklevel=2 attributes the warning to the caller of set_seed.
    warnings.warn(slowdown_notice, stacklevel=2)
if __name__ == "__main__":
    # Time the whole run and report it as hours:minutes:seconds (not zero-padded).
    run_started = time.time()
    main()
    total_time = time.time() - run_started
    hours = int(total_time // 3600)
    minutes = int(total_time % 3600 // 60)
    seconds = int(total_time % 60)
    print(f"Elapsed total time: {hours}:{minutes}:{seconds}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment