Skip to content

Instantly share code, notes, and snippets.

@dsaint31x
Created May 16, 2024 16:54
Show Gist options
  • Save dsaint31x/f18e4f3ee942f00da19ba32b7a8beb75 to your computer and use it in GitHub Desktop.
Save dsaint31x/f18e4f3ee942f00da19ba32b7a8beb75 to your computer and use it in GitHub Desktop.
dl_ex_simple_checkpnt.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMq7ocsn62ZyQTnBQooLG0A",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dsaint31x/f18e4f3ee942f00da19ba32b7a8beb75/dl_ex_simple_checkpnt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A1KUdMU6TakH",
"outputId": "aada3f24-a416-4316-a5a6-d01759665420"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/20, Loss: 0.9961\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 2/20, Loss: 0.9796\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 3/20, Loss: 0.9651\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 4/20, Loss: 0.9526\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 5/20, Loss: 0.9400\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 6/20, Loss: 0.9310\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 7/20, Loss: 0.9175\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 8/20, Loss: 0.9088\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 9/20, Loss: 0.9000\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 10/20, Loss: 0.8908\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 11/20, Loss: 0.8837\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 12/20, Loss: 0.8762\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 13/20, Loss: 0.8696\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 14/20, Loss: 0.8637\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 15/20, Loss: 0.8582\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 16/20, Loss: 0.8517\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 17/20, Loss: 0.8462\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 18/20, Loss: 0.8412\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 19/20, Loss: 0.8368\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Epoch 20/20, Loss: 0.8309\n",
"Model saved to ./checkpoints/best_model.pth\n",
"Checkpoint loaded. Training will resume from epoch 20 with loss 0.8309\n"
]
}
],
"source": [
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Seed the RNG so the synthetic data and the training run are reproducible.
# (The saved notebook output showed a different run length than the code —
# a fixed seed plus Restart-and-Run-All keeps code and output in sync.)
SEED = 42
torch.manual_seed(SEED)


class SimpleModel(nn.Module):
    """Tiny 2-layer MLP (10 -> 20 -> 1, ReLU) used to demonstrate checkpointing."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Synthetic regression dataset: 100 samples, 10 features, scalar target.
x_train = torch.randn(100, 10)
y_train = torch.randn(100, 1)
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Model, loss function, and optimizer.
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Checkpoint directory and path.
# The path is defined BEFORE the training loop: the original bound it only
# inside the save branch, so the post-loop torch.load() would raise
# NameError if that branch were never taken.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')

# Lowest epoch loss seen so far; inf so the first epoch always checkpoints.
best_loss = float('inf')

# Training loop.
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is exact even when
        # the final batch is smaller than batch_size.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # Save a checkpoint whenever the epoch loss improves on the best so far.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, checkpoint_path)
        print(f'Model saved to {checkpoint_path}')

# Reload the best checkpoint, restoring both model and optimizer state
# so training could resume exactly where the best epoch left off.
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']
best_loss = checkpoint['loss']

print(f'Checkpoint loaded. Training will resume from epoch {start_epoch} with loss {best_loss:.4f}')
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment