Skip to content

Instantly share code, notes, and snippets.

@amaarora
Created February 22, 2022 10:38
Show Gist options
  • Save amaarora/e4346adde3225645f96ccf22a3267cc1 to your computer and use it in GitHub Desktop.
Save amaarora/e4346adde3225645f96ccf22a3267cc1 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"id": "3344492e",
"metadata": {},
"outputs": [],
"source": [
"# get dataset\n",
"# NOTE: every `!` line runs in its own subshell, so a `cd data` would not carry over to the next command.\n",
"# Download and extract straight into ./data so the files end up in ./data/imagenette2-160/, the paths `train` reads.\n",
"# !mkdir -p data\n",
"# !wget -P data https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz\n",
"# !tar -xvf data/imagenette2-160.tgz -C data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "047564f0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torchvision\n",
"import timm\n",
"import torch.nn as nn\n",
"from tqdm.notebook import tqdm\n",
"import albumentations\n",
"from torchvision import transforms\n",
"import numpy as np \n",
"import os\n",
"\n",
"# set logging\n",
"import logging\n",
"logging.getLogger().setLevel(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dc1f748e",
"metadata": {},
"outputs": [],
"source": [
"# Experiment configuration\n",
"MODEL_NAME = \"resnet34\"  # timm model name passed to `timm.create_model`\n",
"IMG_SIZE = 160  # side length of the square crops produced by the augmentation pipelines\n",
"LR = 1e-4  # Adam learning rate\n",
"EPOCHS = 5"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cf6dc7d2",
"metadata": {},
"outputs": [],
"source": [
"# Training-time pipeline: random crop + horizontal flip, then tensor conversion and per-channel\n",
"# mean/std normalisation (the standard ImageNet statistics)\n",
"train_aug = transforms.Compose([\n",
"    transforms.RandomCrop(IMG_SIZE),          # random IMG_SIZE x IMG_SIZE crop\n",
"    transforms.RandomHorizontalFlip(p=0.5),   # mirror roughly half of the images\n",
"    transforms.ToTensor(),                    # PIL image -> float tensor\n",
"    transforms.Normalize(\n",
"        mean=(0.485, 0.456, 0.406),\n",
"        std=(0.229, 0.224, 0.225),\n",
"    ),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ac158b53",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation-time pipeline: deterministic centre crop, then the same tensor conversion and\n",
"# normalisation as the training pipeline (no random flips)\n",
"val_aug = transforms.Compose([\n",
"    transforms.CenterCrop(IMG_SIZE),          # central IMG_SIZE x IMG_SIZE crop\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(\n",
"        mean=(0.485, 0.456, 0.406),\n",
"        std=(0.229, 0.224, 0.225),\n",
"    ),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d7860098",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Compose(\n",
" RandomCrop(size=(160, 160), padding=None)\n",
" RandomHorizontalFlip(p=0.5)\n",
" ToTensor()\n",
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
" ),\n",
" Compose(\n",
" CenterCrop(size=(160, 160))\n",
" ToTensor()\n",
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
" ))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display both pipelines as a quick sanity check\n",
"(train_aug, val_aug)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7817c33f",
"metadata": {},
"outputs": [],
"source": [
"class CheckpointSaver:\n",
"    \"\"\"Save model weights whenever a validation metric improves, keeping only the best `top_n` files.\"\"\"\n",
"\n",
"    def __init__(self, dirpath, decreasing=True, top_n=5):\n",
"        \"\"\"\n",
"        dirpath: Directory path where to store all model weights (created if missing)\n",
"        decreasing: If decreasing is `True`, then lower metric is better\n",
"        top_n: Total number of models to track based on validation metric value\n",
"        \"\"\"\n",
"        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs\n",
"        os.makedirs(dirpath, exist_ok=True)\n",
"        self.dirpath = dirpath\n",
"        self.top_n = top_n\n",
"        self.decreasing = decreasing\n",
"        self.top_model_paths = []  # list of {'path': ..., 'score': ...}, kept best-first\n",
"        # `np.Inf` was removed in NumPy 2.0; the lowercase `np.inf` works on every version.\n",
"        self.best_metric_val = np.inf if decreasing else -np.inf\n",
"\n",
"    def __call__(self, model, epoch, metric_val):\n",
"        \"\"\"Save `model.state_dict()` as `<ModelClass>_epoch<epoch>.pt` if `metric_val` beats the best so far.\n",
"\n",
"        Afterwards, tracked checkpoints beyond the best `top_n` are deleted via `cleanup`.\n",
"        \"\"\"\n",
"        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')\n",
"        if self.decreasing:\n",
"            save = metric_val < self.best_metric_val\n",
"        else:\n",
"            save = metric_val > self.best_metric_val\n",
"        if save:\n",
"            logging.info(f\"Current metric value {metric_val} better than best {self.best_metric_val}, saving model at {model_path}\")\n",
"            self.best_metric_val = metric_val\n",
"            torch.save(model.state_dict(), model_path)\n",
"            self.top_model_paths.append({'path': model_path, 'score': metric_val})\n",
"            # Best first: ascending scores when lower is better, descending otherwise.\n",
"            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)\n",
"        if len(self.top_model_paths) > self.top_n:\n",
"            self.cleanup()\n",
"\n",
"    def cleanup(self):\n",
"        \"\"\"Delete the checkpoint files ranked below the best `top_n` and stop tracking them.\"\"\"\n",
"        to_remove = self.top_model_paths[self.top_n:]\n",
"        logging.info(f\"Removing extra models.. {to_remove}\")\n",
"        for o in to_remove:\n",
"            os.remove(o['path'])\n",
"        self.top_model_paths = self.top_model_paths[:self.top_n]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e1caf342",
"metadata": {},
"outputs": [],
"source": [
"def train_fn(model, train_data_loader, optimizer, epoch, device='cuda'):\n",
"    \"\"\"Run one training epoch and return `(mean batch loss, learning rate of the first param group)`.\n",
"\n",
"    `epoch` is only used for the progress-bar label, where it is shown as `epoch + 1`.\n",
"    \"\"\"\n",
"    model.train()\n",
"    # Default-configured criterion: build it once instead of once per batch.\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    fin_loss = 0.0\n",
"    tk = tqdm(train_data_loader, desc=\"Epoch\" + \" [TRAIN] \" + str(epoch + 1))\n",
"\n",
"    for t, data in enumerate(tk):\n",
"        inputs = data[0].to(device)\n",
"        targets = data[1].to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        out = model(inputs)\n",
"        loss = criterion(out, targets)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        fin_loss += loss.item()\n",
"        tk.set_postfix(\n",
"            {\n",
"                \"loss\": \"%.6f\" % float(fin_loss / (t + 1)),\n",
"                \"LR\": optimizer.param_groups[0][\"lr\"],\n",
"            }\n",
"        )\n",
"    # Mean of per-batch losses: a smaller final batch is weighted the same as a full one.\n",
"    return fin_loss / len(train_data_loader), optimizer.param_groups[0][\"lr\"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "565dfa4d",
"metadata": {},
"outputs": [],
"source": [
"def eval_fn(model, eval_data_loader, epoch, device='cuda'):\n",
"    \"\"\"Evaluate `model` on `eval_data_loader` without gradients and return the mean batch loss.\n",
"\n",
"    `epoch` is only used for the progress-bar label, where it is shown as `epoch + 1`.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    criterion = nn.CrossEntropyLoss()  # built once, not per batch\n",
"    fin_loss = 0.0\n",
"    tk = tqdm(eval_data_loader, desc=\"Epoch\" + \" [VALID] \" + str(epoch + 1))\n",
"\n",
"    with torch.no_grad():\n",
"        for t, data in enumerate(tk):\n",
"            inputs = data[0].to(device)\n",
"            targets = data[1].to(device)\n",
"            out = model(inputs)\n",
"            loss = criterion(out, targets)\n",
"            fin_loss += loss.item()\n",
"            tk.set_postfix({\"loss\": \"%.6f\" % float(fin_loss / (t + 1))})\n",
"    return fin_loss / len(eval_data_loader)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3aa6a0ac",
"metadata": {},
"outputs": [],
"source": [
"def train(train_dir, test_dir, device='cuda'):\n",
"    \"\"\"Fine-tune a pretrained timm `MODEL_NAME` on ImageFolder data for `EPOCHS` epochs.\n",
"\n",
"    train_dir / test_dir: ImageFolder roots (one sub-directory per class).\n",
"    device: torch device for the model and every batch (defaults to 'cuda', as before).\n",
"    The checkpoint with the lowest validation loss is kept in ./model_weights.\n",
"    \"\"\"\n",
"    train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_aug)\n",
"    eval_dataset = torchvision.datasets.ImageFolder(test_dir, transform=val_aug)\n",
"    train_dataloader = torch.utils.data.DataLoader(\n",
"        train_dataset,\n",
"        batch_size=128,\n",
"        shuffle=True,\n",
"        num_workers=4\n",
"    )\n",
"    eval_dataloader = torch.utils.data.DataLoader(\n",
"        eval_dataset, batch_size=64, num_workers=4\n",
"    )\n",
"\n",
"    # model: size the classifier head to the dataset's classes instead of the\n",
"    # default 1000-way ImageNet head\n",
"    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=len(train_dataset.classes))\n",
"    model = model.to(device)\n",
"\n",
"    # optimizer\n",
"    optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
"\n",
"    # checkpoint saver: lower validation loss is better; keep only the single best checkpoint\n",
"    checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=1)\n",
"    for epoch in range(EPOCHS):\n",
"        avg_loss_train, _ = train_fn(\n",
"            model, train_dataloader, optimizer, epoch, device=device\n",
"        )\n",
"        avg_loss_eval = eval_fn(model, eval_dataloader, epoch, device=device)\n",
"        checkpoint_saver(model, epoch, avg_loss_eval)\n",
"        print(\n",
"            f\"EPOCH = {epoch} | TRAIN_LOSS = {avg_loss_train} | EVAL_LOSS = {avg_loss_eval}\"\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b13a5dd7",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39c36eaaa13041d98fbb66848b68ec36",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [TRAIN] 1: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f8ea52598bb04371967d1454b288f9b5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [VALID] 1: 0%| | 0/62 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Current metric value better than 0.1954631515958857 better than best inf, saving model at ./model_weights/ResNet_epoch0.pt\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH = 0 | TRAIN_LOSS = 1.3544871056502736 | EVAL_LOSS = 0.1954631515958857\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1667f9b528524a368043fbdabfad2abf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [TRAIN] 2: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1c24b86d8aeb4e15a64f4d2c89efc551",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [VALID] 2: 0%| | 0/62 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Current metric value better than 0.15000865837529062 better than best 0.1954631515958857, saving model at ./model_weights/ResNet_epoch1.pt\n",
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch0.pt', 'score': 0.1954631515958857}]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH = 1 | TRAIN_LOSS = 0.11298705174310787 | EVAL_LOSS = 0.15000865837529062\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e34713e9f7344440a4686310476d985f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [TRAIN] 3: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d52768e16114ae79f9204e6fd07731e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [VALID] 3: 0%| | 0/62 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Current metric value better than 0.1338667555208949 better than best 0.15000865837529062, saving model at ./model_weights/ResNet_epoch2.pt\n",
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch1.pt', 'score': 0.15000865837529062}]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH = 2 | TRAIN_LOSS = 0.053521369734930026 | EVAL_LOSS = 0.1338667555208949\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "33f3744f46234b1190d63e05eac2f638",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [TRAIN] 4: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6f17c6f08ac94326b17e699b42d359a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [VALID] 4: 0%| | 0/62 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Current metric value better than 0.12575743053551583 better than best 0.1338667555208949, saving model at ./model_weights/ResNet_epoch3.pt\n",
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch2.pt', 'score': 0.1338667555208949}]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH = 3 | TRAIN_LOSS = 0.04176709730165532 | EVAL_LOSS = 0.12575743053551583\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39e1ee19189b41d3b2390cda491f9399",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [TRAIN] 5: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "48f96c8f0e864b2f9f8d65efcda42ce3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch [VALID] 5: 0%| | 0/62 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH = 4 | TRAIN_LOSS = 0.027766215419900174 | EVAL_LOSS = 0.12607890976810707\n"
]
}
],
"source": [
"# Run the full training loop on the Imagenette (160px) train/val splits\n",
"train(\n",
"    train_dir='./data/imagenette2-160/train/',\n",
"    test_dir='./data/imagenette2-160/val/',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2d2b4bd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"gist": {
"data": {
"description": "reports/How to save all your trained model weights locally after every epoch.ipynb",
"public": true
},
"id": ""
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment