Created
May 16, 2024 16:54
-
-
Save dsaint31x/f18e4f3ee942f00da19ba32b7a8beb75 to your computer and use it in GitHub Desktop.
dl_ex_simple_checkpnt.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMq7ocsn62ZyQTnBQooLG0A", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/dsaint31x/f18e4f3ee942f00da19ba32b7a8beb75/dl_ex_simple_checkpnt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "A1KUdMU6TakH", | |
"outputId": "aada3f24-a416-4316-a5a6-d01759665420" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1/20, Loss: 0.9961\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 2/20, Loss: 0.9796\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 3/20, Loss: 0.9651\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 4/20, Loss: 0.9526\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 5/20, Loss: 0.9400\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 6/20, Loss: 0.9310\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 7/20, Loss: 0.9175\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 8/20, Loss: 0.9088\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 9/20, Loss: 0.9000\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 10/20, Loss: 0.8908\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 11/20, Loss: 0.8837\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 12/20, Loss: 0.8762\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 13/20, Loss: 0.8696\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 14/20, Loss: 0.8637\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 15/20, Loss: 0.8582\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 16/20, Loss: 0.8517\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 17/20, Loss: 0.8462\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 18/20, Loss: 0.8412\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 19/20, Loss: 0.8368\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Epoch 20/20, Loss: 0.8309\n", | |
"Model saved to ./checkpoints/best_model.pth\n", | |
"Checkpoint loaded. Training will resume from epoch 20 with loss 0.8309\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"from torch.utils.data import DataLoader, TensorDataset\n", | |
"import os\n", | |
"\n", | |
"# 예제 모델 정의\n", | |
"class SimpleModel(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(SimpleModel, self).__init__()\n", | |
" self.fc1 = nn.Linear(10, 20)\n", | |
" self.relu = nn.ReLU()\n", | |
" self.fc2 = nn.Linear(20, 1)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.fc1(x)\n", | |
" x = self.relu(x)\n", | |
" x = self.fc2(x)\n", | |
" return x\n", | |
"\n", | |
"# 데이터셋 생성\n", | |
"x_train = torch.randn(100, 10)\n", | |
"y_train = torch.randn(100, 1)\n", | |
"dataset = TensorDataset(x_train, y_train)\n", | |
"dataloader = DataLoader(dataset, batch_size=32, shuffle=True)\n", | |
"\n", | |
"# 모델, 손실 함수 및 옵티마이저 초기화\n", | |
"model = SimpleModel()\n", | |
"criterion = nn.MSELoss()\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n", | |
"\n", | |
"# 체크포인트 디렉토리 생성\n", | |
"checkpoint_dir = './checkpoints'\n", | |
"os.makedirs(checkpoint_dir, exist_ok=True)\n", | |
"\n", | |
"# 최상의 손실 값 초기화\n", | |
"best_loss = float('inf')\n", | |
"\n", | |
"# 훈련 루프\n", | |
"num_epochs = 5\n", | |
"\n", | |
"for epoch in range(num_epochs):\n", | |
" model.train()\n", | |
" running_loss = 0.0\n", | |
" for inputs, targets in dataloader:\n", | |
" optimizer.zero_grad()\n", | |
" outputs = model(inputs)\n", | |
" loss = criterion(outputs, targets)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" running_loss += loss.item() * inputs.size(0)\n", | |
"\n", | |
" epoch_loss = running_loss / len(dataloader.dataset)\n", | |
" print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')\n", | |
"\n", | |
" # 손실이 낮아질 때마다 모델 저장\n", | |
" if epoch_loss < best_loss:\n", | |
" best_loss = epoch_loss\n", | |
" checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')\n", | |
" torch.save({\n", | |
" 'epoch': epoch + 1,\n", | |
" 'model_state_dict': model.state_dict(),\n", | |
" 'optimizer_state_dict': optimizer.state_dict(),\n", | |
" 'loss': best_loss,\n", | |
" }, checkpoint_path)\n", | |
" print(f'Model saved to {checkpoint_path}')\n", | |
"\n", | |
"# 저장된 모델 로드\n", | |
"checkpoint = torch.load(checkpoint_path)\n", | |
"model.load_state_dict(checkpoint['model_state_dict'])\n", | |
"optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", | |
"\n", | |
"start_epoch = checkpoint['epoch']\n", | |
"best_loss = checkpoint['loss']\n", | |
"\n", | |
"print(f'Checkpoint loaded. Training will resume from epoch {start_epoch} with loss {best_loss:.4f}')\n" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment