Created
March 7, 2022 15:13
-
-
Save koshian2/5dc8d01efe4f6db6eab948a067005ac2 to your computer and use it in GitHub Desktop.
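ResNet50 on CIFAR-10: training-time comparison of PyTorch Lightning (default and 16-bit precision) and plain PyTorch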
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import torchvision | |
import torchmetrics | |
import pytorch_lightning as pl | |
import time | |
class ResNet50(pl.LightningModule):
    """ResNet-50 classifier with a 10-way output head, trained via Lightning.

    The torchvision ResNet-50 is built with default arguments (no weights are
    requested), and its final fully connected layer is replaced by a
    2048 -> 10 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet50()
        # Replace the classification head: 2048 input features -> 10 classes.
        self.model.fc = nn.Linear(2048, 10)
        # One metric object per stage so training and validation accuracy
        # never share accumulated state.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, inputs):
        """Return the raw class logits produced by the wrapped ResNet-50."""
        return self.model(inputs)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, train_batch, batch_idx):
        """Compute cross-entropy on one batch, log loss and batch accuracy.

        Args:
            train_batch: ``(inputs, integer class labels)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        x, y_true = train_batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        acc = self.train_acc(y_pred_label, y_true)
        self.log("train_loss", loss, prog_bar=False, logger=True)
        self.log("train_acc", acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute validation loss and accumulate validation accuracy.

        Args:
            val_batch: ``(inputs, integer class labels)`` pair.
            batch_idx: Index of the batch (unused).
        """
        x, y_true = val_batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        # Bug fix: this previously updated self.train_acc, polluting the
        # training metric with validation data and leaving val_acc unused.
        self.val_acc.update(y_pred_label, y_true)
        self.log("val_loss", loss, prog_bar=False, logger=True)
        # Logging the Metric object lets Lightning compute accuracy over the
        # whole validation epoch and reset the metric afterwards.
        self.log("val_acc", self.val_acc, prog_bar=True, logger=True)
class MyDataModule(pl.LightningDataModule):
    """CIFAR-10 train/validation splits converted to tensors with ``ToTensor``.

    Args:
        data_dir: Directory CIFAR-10 is downloaded to and read from.
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of DataLoader worker processes per loader.
    """

    def __init__(self, data_dir="./data", batch_size=256, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Download both CIFAR-10 splits; deliberately assigns no state.

        Lightning may run this hook on a single process only, so the dataset
        objects are created in ``setup`` instead.
        """
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train and validation datasets from the downloaded files."""
        transform = torchvision.transforms.ToTensor()
        self.train_dataset = torchvision.datasets.CIFAR10(
            self.data_dir, train=True, download=False, transform=transform)
        self.val_dataset = torchvision.datasets.CIFAR10(
            self.data_dir, train=False, download=False, transform=transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        if self.train_dataset is None:  # e.g. used outside a Trainer
            self.setup("fit")
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the test split used for validation."""
        if self.val_dataset is None:  # e.g. used outside a Trainer
            self.setup("fit")
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs and print elapsed seconds.

    Uses GPU index 1 (``gpus=[1]``) and default (32-bit) precision. The
    printed value is the wall-clock time spent inside ``trainer.fit``.
    """
    model = ResNet50()
    cifar = MyDataModule()
    trainer = pl.Trainer(gpus=[1], max_epochs=30)
    # perf_counter is monotonic and high-resolution; time.time can jump if the
    # system clock is adjusted during a long training run.
    start_time = time.perf_counter()
    trainer.fit(model, cifar)
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import torchvision | |
import torchmetrics | |
import pytorch_lightning as pl | |
import time | |
class ResNet50(pl.LightningModule):
    """ResNet-50 classifier with a 10-way output head, trained via Lightning.

    The torchvision ResNet-50 is built with default arguments (no weights are
    requested), and its final fully connected layer is replaced by a
    2048 -> 10 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet50()
        # Replace the classification head: 2048 input features -> 10 classes.
        self.model.fc = nn.Linear(2048, 10)
        # One metric object per stage so training and validation accuracy
        # never share accumulated state.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, inputs):
        """Return the raw class logits produced by the wrapped ResNet-50."""
        return self.model(inputs)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, train_batch, batch_idx):
        """Compute cross-entropy on one batch, log loss and batch accuracy.

        Args:
            train_batch: ``(inputs, integer class labels)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        x, y_true = train_batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        acc = self.train_acc(y_pred_label, y_true)
        self.log("train_loss", loss, prog_bar=False, logger=True)
        self.log("train_acc", acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute validation loss and accumulate validation accuracy.

        Args:
            val_batch: ``(inputs, integer class labels)`` pair.
            batch_idx: Index of the batch (unused).
        """
        x, y_true = val_batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        # Bug fix: this previously updated self.train_acc, polluting the
        # training metric with validation data and leaving val_acc unused.
        self.val_acc.update(y_pred_label, y_true)
        self.log("val_loss", loss, prog_bar=False, logger=True)
        # Logging the Metric object lets Lightning compute accuracy over the
        # whole validation epoch and reset the metric afterwards.
        self.log("val_acc", self.val_acc, prog_bar=True, logger=True)
class MyDataModule(pl.LightningDataModule):
    """CIFAR-10 train/validation splits converted to tensors with ``ToTensor``.

    Args:
        data_dir: Directory CIFAR-10 is downloaded to and read from.
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of DataLoader worker processes per loader.
    """

    def __init__(self, data_dir="./data", batch_size=256, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Download both CIFAR-10 splits; deliberately assigns no state.

        Lightning may run this hook on a single process only, so the dataset
        objects are created in ``setup`` instead.
        """
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train and validation datasets from the downloaded files."""
        transform = torchvision.transforms.ToTensor()
        self.train_dataset = torchvision.datasets.CIFAR10(
            self.data_dir, train=True, download=False, transform=transform)
        self.val_dataset = torchvision.datasets.CIFAR10(
            self.data_dir, train=False, download=False, transform=transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        if self.train_dataset is None:  # e.g. used outside a Trainer
            self.setup("fit")
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the test split used for validation."""
        if self.val_dataset is None:  # e.g. used outside a Trainer
            self.setup("fit")
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs with 16-bit precision.

    Uses GPU index 1 (``gpus=[1]``) and ``precision=16``, then prints the
    wall-clock seconds spent inside ``trainer.fit``.
    """
    model = ResNet50()
    cifar = MyDataModule()
    trainer = pl.Trainer(gpus=[1], max_epochs=30, precision=16)
    # perf_counter is monotonic and high-resolution; time.time can jump if the
    # system clock is adjusted during a long training run.
    start_time = time.perf_counter()
    trainer.fit(model, cifar)
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
import torchvision | |
import torchmetrics | |
import time | |
class ResNet50(nn.Module):
    """Plain PyTorch wrapper around torchvision's ResNet-50 with a 10-class head."""

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50()
        # Swap the stock classifier for a 2048 -> 10 linear layer.
        backbone.fc = nn.Linear(2048, 10)
        self.model = backbone

    def forward(self, inputs):
        """Return the class logits computed by the wrapped backbone."""
        return self.model(inputs)
def main():
    """Train ResNet50 on CIFAR-10 for 30 epochs with a hand-written loop.

    Prints train/validation accuracy after every epoch and, at the end, the
    total wall-clock seconds spent in the training and validation loops.
    """
    train_dataset = torchvision.datasets.CIFAR10(
        "./data", train=True, download=True,
        transform=torchvision.transforms.ToTensor())
    val_dataset = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True,
        transform=torchvision.transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=256, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, num_workers=4, shuffle=False)

    device = "cuda:1"  # use the second GPU (device index 1)
    # Move the model before constructing the optimizer, as the torch.optim
    # docs recommend, so the optimizer is built over the on-device parameters.
    model = ResNet50().to(device)
    optim = torch.optim.Adam(model.parameters(), 1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    accuracy = torchmetrics.Accuracy().to(device)

    # perf_counter is monotonic; time.time can jump with clock adjustments.
    start_time = time.perf_counter()
    for epoch in range(30):
        model.train()
        accuracy.reset()
        for x, y_true in train_loader:
            optim.zero_grad()
            x, y_true = x.to(device), y_true.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y_true)
            # update() only accumulates state; the per-batch value that calling
            # the metric would compute was never used.
            accuracy.update(torch.argmax(y_pred, dim=-1), y_true)
            loss.backward()
            optim.step()
        train_accuracy = accuracy.compute()

        accuracy.reset()
        model.eval()
        # Validation needs no gradients; without no_grad every batch would
        # record an autograd graph, wasting GPU memory and time.
        with torch.no_grad():
            for x, y_true in val_loader:
                x, y_true = x.to(device), y_true.to(device)
                y_pred = model(x)
                accuracy.update(torch.argmax(y_pred, dim=-1), y_true)
        val_accuracy = accuracy.compute()
        print(f"Epoch: {epoch} / train_acc : {train_accuracy:.2%}, val_acc : {val_accuracy:.2%}")
    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment