Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created March 7, 2022 15:13
Show Gist options
  • Save koshian2/5dc8d01efe4f6db6eab948a067005ac2 to your computer and use it in GitHub Desktop.
Save koshian2/5dc8d01efe4f6db6eab948a067005ac2 to your computer and use it in GitHub Desktop.
ResNet50 CIFAR
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchmetrics
import pytorch_lightning as pl
import time
class ResNet50(pl.LightningModule):
    """ResNet-50 image classifier with a 10-class head, trained via Lightning.

    The torchvision ResNet-50 is built without pretrained weights and its
    final ``fc`` layer is replaced by a 2048 -> 10 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet50()
        self.model.fc = nn.Linear(2048, 10)
        # Independent metric objects so training and validation batches
        # never update the same accumulated metric state.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, inputs):
        """Return unnormalized class logits for ``inputs``."""
        return self.model(inputs)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _shared_step(self, batch, metric):
        """Return ``(cross-entropy loss, batch accuracy)`` for one batch.

        ``metric`` is the torchmetrics object to update; each phase passes
        its own so their states stay separate.
        """
        x, y_true = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        acc = metric(y_pred_label, y_true)
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Compute and log training loss/accuracy; return the loss."""
        loss, acc = self._shared_step(train_batch, self.train_acc)
        self.log("train_loss", loss, prog_bar=False, logger=True)
        self.log("train_acc", acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log validation loss/accuracy."""
        # val_acc (not train_acc) keeps validation batches out of the
        # training metric's accumulated state.
        loss, acc = self._shared_step(val_batch, self.val_acc)
        self.log("val_loss", loss, prog_bar=False, logger=True)
        self.log("val_acc", acc, prog_bar=True, logger=True)
class MyDataModule(pl.LightningDataModule):
    """CIFAR-10 train/validation data for Lightning.

    Images are only converted to tensors (``ToTensor``); no augmentation or
    normalization is applied.

    Args:
        data_dir: Directory the CIFAR-10 files are downloaded to / read from.
        batch_size: Batch size for both the train and validation loaders.
        num_workers: Worker processes per ``DataLoader``.
    """

    def __init__(self, data_dir="./data", batch_size=256, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _cifar10(self, train, download):
        """Build one CIFAR-10 split with the ``ToTensor`` transform."""
        return torchvision.datasets.CIFAR10(
            self.data_dir, train=train, download=download,
            transform=torchvision.transforms.ToTensor())

    def prepare_data(self):
        """Download both CIFAR-10 splits.

        The LightningDataModule contract says state assigned in this hook is
        not guaranteed to be visible on other processes, so the dataset
        objects are created in ``setup`` instead.
        """
        self._cifar10(train=True, download=True)
        self._cifar10(train=False, download=True)

    def setup(self, stage=None):
        """Create the train and validation datasets."""
        # prepare_data has already fetched the files, so skip re-downloading.
        self.train_dataset = self._cifar10(train=True, download=False)
        self.val_dataset = self._cifar10(train=False, download=False)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs on GPU index 1 and print elapsed seconds."""
    model = ResNet50()
    cifar = MyDataModule()
    trainer = pl.Trainer(gpus=[1], max_epochs=30)
    # perf_counter is monotonic, so the measurement is immune to
    # wall-clock adjustments during the run.
    start_time = time.perf_counter()
    trainer.fit(model, cifar)
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchmetrics
import pytorch_lightning as pl
import time
class ResNet50(pl.LightningModule):
    """ResNet-50 image classifier with a 10-class head, trained via Lightning.

    The torchvision ResNet-50 is built without pretrained weights and its
    final ``fc`` layer is replaced by a 2048 -> 10 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet50()
        self.model.fc = nn.Linear(2048, 10)
        # Independent metric objects so training and validation batches
        # never update the same accumulated metric state.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, inputs):
        """Return unnormalized class logits for ``inputs``."""
        return self.model(inputs)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _shared_step(self, batch, metric):
        """Return ``(cross-entropy loss, batch accuracy)`` for one batch.

        ``metric`` is the torchmetrics object to update; each phase passes
        its own so their states stay separate.
        """
        x, y_true = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        acc = metric(y_pred_label, y_true)
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Compute and log training loss/accuracy; return the loss."""
        loss, acc = self._shared_step(train_batch, self.train_acc)
        self.log("train_loss", loss, prog_bar=False, logger=True)
        self.log("train_acc", acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log validation loss/accuracy."""
        # val_acc (not train_acc) keeps validation batches out of the
        # training metric's accumulated state.
        loss, acc = self._shared_step(val_batch, self.val_acc)
        self.log("val_loss", loss, prog_bar=False, logger=True)
        self.log("val_acc", acc, prog_bar=True, logger=True)
class MyDataModule(pl.LightningDataModule):
    """CIFAR-10 train/validation data for Lightning.

    Images are only converted to tensors (``ToTensor``); no augmentation or
    normalization is applied.

    Args:
        data_dir: Directory the CIFAR-10 files are downloaded to / read from.
        batch_size: Batch size for both the train and validation loaders.
        num_workers: Worker processes per ``DataLoader``.
    """

    def __init__(self, data_dir="./data", batch_size=256, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _cifar10(self, train, download):
        """Build one CIFAR-10 split with the ``ToTensor`` transform."""
        return torchvision.datasets.CIFAR10(
            self.data_dir, train=train, download=download,
            transform=torchvision.transforms.ToTensor())

    def prepare_data(self):
        """Download both CIFAR-10 splits.

        The LightningDataModule contract says state assigned in this hook is
        not guaranteed to be visible on other processes, so the dataset
        objects are created in ``setup`` instead.
        """
        self._cifar10(train=True, download=True)
        self._cifar10(train=False, download=True)

    def setup(self, stage=None):
        """Create the train and validation datasets."""
        # prepare_data has already fetched the files, so skip re-downloading.
        self.train_dataset = self._cifar10(train=True, download=False)
        self.val_dataset = self._cifar10(train=False, download=False)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs with 16-bit precision on GPU index 1.

    Prints the elapsed training time in seconds.
    """
    model = ResNet50()
    cifar = MyDataModule()
    trainer = pl.Trainer(gpus=[1], max_epochs=30, precision=16)
    # perf_counter is monotonic, so the measurement is immune to
    # wall-clock adjustments during the run.
    start_time = time.perf_counter()
    trainer.fit(model, cifar)
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchmetrics
import time
class ResNet50(nn.Module):
    """Plain ``nn.Module`` wrapper around torchvision's ResNet-50.

    The backbone is created without pretrained weights and its ``fc`` layer
    is swapped for a 2048 -> 10 linear classifier.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50()
        backbone.fc = nn.Linear(2048, 10)
        self.model = backbone

    def forward(self, inputs):
        """Return the class logits computed by the wrapped network."""
        return self.model(inputs)
def _train_one_epoch(model, loader, optim, criterion, accuracy, device):
    """Run one optimization pass over ``loader`` and return its accuracy."""
    model.train()
    accuracy.reset()
    for x, y_true in loader:
        optim.zero_grad()
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y_true)
        accuracy(torch.argmax(y_pred, dim=-1), y_true)
        loss.backward()
        optim.step()
    return accuracy.compute()


@torch.no_grad()
def _evaluate(model, loader, accuracy, device):
    """Return the accuracy of ``model`` on ``loader``.

    Runs under ``torch.no_grad`` so no autograd graph is built during
    evaluation, saving memory and time.
    """
    model.eval()
    accuracy.reset()
    for x, y_true in loader:
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        accuracy(torch.argmax(y_pred, dim=-1), y_true)
    return accuracy.compute()


def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs with a plain PyTorch loop.

    Prints train/validation accuracy after every epoch and the total elapsed
    time in seconds at the end.
    """
    train_dataset = torchvision.datasets.CIFAR10(
        "./data", train=True, download=True,
        transform=torchvision.transforms.ToTensor())
    val_dataset = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True,
        transform=torchvision.transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=256, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, num_workers=4, shuffle=False)

    device = "cuda:1"  # use the second GPU
    model = ResNet50()
    optim = torch.optim.Adam(model.parameters(), 1e-3)
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    accuracy = torchmetrics.Accuracy().to(device)

    # perf_counter is monotonic, so the timing is immune to clock adjustments.
    start_time = time.perf_counter()
    for epoch in range(30):
        train_accuracy = _train_one_epoch(
            model, train_loader, optim, criterion, accuracy, device)
        val_accuracy = _evaluate(model, val_loader, accuracy, device)
        print(f"Epoch: {epoch} / train_acc : {train_accuracy:.2%}, val_acc : {val_accuracy:.2%}")
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment