Skip to content

Instantly share code, notes, and snippets.

@talhaanwarch
Last active December 15, 2021 09:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save talhaanwarch/52e2bc8118e2fb411277d81c04970ae7 to your computer and use it in GitHub Desktop.
Save talhaanwarch/52e2bc8118e2fb411277d81c04970ae7 to your computer and use it in GitHub Desktop.
One-cycle learning-rate schedule (OneCycleLR) with PyTorch Lightning
import torchvision
import torch
import torch.nn as nn
from time import time
from torch.optim.lr_scheduler import OneCycleLR
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from torch.utils.data.dataloader import DataLoader
# Preprocessing shared by both MNIST splits: convert each image to a float
# tensor, then standardise its single channel with mean 0.1307 / std 0.3081.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
import torch.nn.functional as F
class Net(nn.Module):
    """Two-convolution image classifier returning 10 unnormalised logits.

    ``forward`` flattens with ``view(-1, 320)``. That lines up with ``fc1``
    only when the second pooled feature map is 20 x 4 x 4, i.e. for
    1 x 28 x 28 inputs (28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in this exact order. It fixes both the
        # parameter order in state_dict and the sequence of random inits.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return class logits for a batch of single-channel images ``x``."""
        # Block 1: conv -> 2x2 max-pool -> ReLU
        feats = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        # Block 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
        dropped = self.conv2_drop(self.conv2(feats))
        feats = F.relu(F.max_pool2d(dropped, kernel_size=2))
        # Head: flatten -> FC + ReLU -> dropout (active only in train mode) -> FC
        hidden = F.relu(self.fc1(feats.view(-1, 320)))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)
class OurModel(LightningModule):
    """LightningModule that trains ``Net`` on MNIST with AdamW + OneCycleLR.

    The torchvision MNIST ``train=True`` split is used for training. The
    ``train=False`` split serves as both the validation and the test set.
    Both splits are downloaded to ``files/`` when the module is constructed.
    """

    def __init__(self):
        super().__init__()
        # architecture
        self.model = Net()
        # hyper-parameters
        # NOTE: OneCycleLR resets the optimizer's lr to max_lr / div_factor
        # when it is constructed, so this initial value is effectively unused.
        self.lr = 1e-6
        self.max_lr = 1e-2
        self.batch_size = 72
        self.numworker = 12
        # Keep these equal to Trainer(max_epochs=..., accumulate_grad_batches=...)
        # so the one-cycle schedule spans exactly the optimizer steps taken.
        self.epochs = 10
        self.accumulate_grad_batches = 8
        self.criterion = nn.CrossEntropyLoss()
        self.train_set = torchvision.datasets.MNIST(
            'files/', train=True, download=True, transform=transform)
        self.val_set = torchvision.datasets.MNIST(
            'files/', train=False, download=True, transform=transform)

    def forward(self, x):
        """Delegate to the wrapped ``Net``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create AdamW plus a OneCycleLR schedule stepped per optimizer step.

        The scheduler is returned in Lightning's config-dict form with
        ``interval='step'``. A bare scheduler would be stepped once per epoch,
        so it would advance only ``epochs`` times out of its
        ``epochs * steps_per_epoch`` total and the LR would barely move.
        """
        import math

        opt = torch.optim.AdamW(params=self.parameters(), lr=self.lr)
        # Both divisions round up. The train loader keeps its last partial
        # batch (drop_last defaults to False). Lightning also steps the
        # optimizer on an epoch's final batch even when the accumulation
        # window is incomplete. Flooring here under-counts the steps, and
        # OneCycleLR raises ValueError once it is stepped past total_steps.
        # NOTE(review): assumes a single device; under DDP each process sees
        # only a shard of the data, so fewer steps per epoch would be taken.
        batches_per_epoch = math.ceil(len(self.train_set) / self.batch_size)
        steps_per_epoch = math.ceil(batches_per_epoch / self.accumulate_grad_batches)
        scheduler = OneCycleLR(
            opt,
            max_lr=self.max_lr,
            epochs=self.epochs,
            steps_per_epoch=steps_per_epoch,
        )
        lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return {'optimizer': opt, 'lr_scheduler': lr_scheduler}

    def _loss(self, batch):
        """Cross-entropy loss of the model on one ``(image, label)`` batch."""
        image, label = batch
        return self.criterion(self(image), label)

    @staticmethod
    def _mean_loss(outputs):
        """Average the per-batch ``'loss'`` entries collected over an epoch."""
        return torch.stack([out['loss'] for out in outputs]).mean().detach()

    def train_dataloader(self):
        """Shuffled batches from the MNIST training split."""
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def training_step(self, batch, batch_idx):
        return {'loss': self._loss(batch)}

    def training_epoch_end(self, outputs):
        # Logged unrounded: rounding to 2 decimals would hide improvements
        # smaller than 0.01 from anything monitoring the metric.
        self.log('train_loss', self._mean_loss(outputs))

    def val_dataloader(self):
        """Batches from the MNIST ``train=False`` split, in fixed order.

        No shuffling: the epoch metric is a mean of per-batch means, so a
        shuffled final partial batch would make ``val_loss`` vary between
        epochs even with identical weights.
        """
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

    def validation_step(self, batch, batch_idx):
        return {'loss': self._loss(batch)}

    def validation_epoch_end(self, outputs):
        # Unrounded so ModelCheckpoint(monitor='val_loss') can register small
        # improvements instead of treating them as ties.
        self.log('val_loss', self._mean_loss(outputs))

    def test_dataloader(self):
        """Batches from the MNIST ``train=False`` split (same data as validation)."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.numworker, pin_memory=True, shuffle=False)

    def test_step(self, batch, batch_idx):
        return {'loss': self._loss(batch)}

    def test_epoch_end(self, outputs):
        self.log('test_loss', self._mean_loss(outputs))
import os

save_name = 'onecycle'
# NOTE(review): used only as the Neptune run name/tag. The network actually
# trained is the small `Net` CNN wrapped by OurModel, not a ResNeSt model.
model_name = 'resnest50d'

# Seed *before* building the model so that Net's weight initialisation is
# reproducible too. Seeding after construction only fixes data order and
# dropout, which defeats `deterministic=True` below.
seed_everything(0)
model = OurModel()

from pytorch_lightning.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
    # The API token comes from the NEPTUNE_API_TOKEN environment variable.
    api_key=os.environ.get('NEPTUNE_API_TOKEN'),
    project="mrtictac96/eye",
    name=model_name,
    tags=[model_name, save_name],
    log_model_checkpoints=False,
)

# NOTE(review): created but not passed to the Trainer's callbacks. With
# patience=10 and max_epochs=10 it could never trigger anyway.
earlystop = EarlyStopping(monitor="val_loss", patience=10, verbose=True)
checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath='checkpoints',
                                      filename='file', save_last=True)
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = Trainer(max_epochs=10,
                  deterministic=True,
                  gpus=-1, precision=16,
                  accumulate_grad_batches=8,
                  enable_progress_bar=False,
                  callbacks=[checkpoint_callback, lr_monitor],
                  logger=neptune_logger,
                  num_sanity_val_steps=0,
                  )

start = time()
trainer.fit(model)
train_time = time() - start
print('training time', train_time)
neptune_logger.run['train_time'].log(train_time)
neptune_logger.experiment.stop()
@talhaanwarch
Copy link
Author

This is what the learning rate looks like:
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment