@richardliaw, created April 9, 2021, 23:35
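# Distributed CIFAR-100 training with Ray SGD's TorchTrainer: each worker
# builds a timm ResNet-50, wraps it in DistributedDataParallel via
# self.register(), and trains on CIFAR-100 downloaded to ./cifar100_data.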
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision
from torchvision import transforms
import numpy as np
import os
from PIL import Image
import timm  # imported at module level so it is available inside setup() on the workers
import ray
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd import TorchTrainer


class Cifar100TrainingOperator(TrainingOperator):
    def setup(self, config):
        model = timm.create_model(
            'resnet50', pretrained=False, num_classes=100).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        criterion = torch.nn.CrossEntropyLoss()
        trans = config["trans"]
        batch_size = config["batch_size"]
        scheduler = None  # optionally build a torch.optim.lr_scheduler from `optimizer` here

        # Guard the download with a file lock so workers sharing a node
        # don't race while writing the dataset to disk.
        from filelock import FileLock
        with FileLock("./data.lock"):
            train_dataset = torchvision.datasets.CIFAR100(
                root="./cifar100_data", train=True, download=True, transform=trans)
            val_dataset = torchvision.datasets.CIFAR100(
                root="./cifar100_data", train=False, download=True, transform=trans)
        train_loader = DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

        # Use the worker's rank as DDP's output device (assumes one GPU per worker).
        ddp_args = {"output_device": torch.distributed.get_rank()}

        # register() wraps the model in DistributedDataParallel and hands the
        # components back; only pass the pieces that were actually created.
        if criterion is not None and scheduler is not None:
            self.model, self.optimizer, self.criterion, self.scheduler = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion,
                schedulers=scheduler,
                ddp_args=ddp_args)
        elif criterion is not None:
            self.model, self.optimizer, self.criterion = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion,
                ddp_args=ddp_args)
        elif scheduler is not None:
            self.model, self.optimizer, self.scheduler = self.register(
                models=model,
                optimizers=optimizer,
                schedulers=scheduler,
                ddp_args=ddp_args)
        else:
            self.model, self.optimizer = self.register(
                models=model,
                optimizers=optimizer,
                ddp_args=ddp_args)
        self.criterion = self.criterion.cuda()  # move the loss to GPU (trainer runs with use_gpu=True)

        self.register_data(
            train_loader=train_loader,
            validation_loader=val_loader)

    def train_batch(self, *args, **kwargs):
        # Default per-batch training logic; override here to customize.
        return super().train_batch(*args, **kwargs)


if __name__ == '__main__':
    # ray.init()               # for a local machine
    ray.init(address='auto')   # for an existing cluster

    batch_size = 512
    epochs = 200
    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    trainer = TorchTrainer(
        training_operator_cls=Cifar100TrainingOperator,
        num_workers=4,
        config={"trans": trans, "batch_size": batch_size},
        use_gpu=True,
    )

    train_loss = []
    val_loss = []
    val_acc = []

    import time
    from tabulate import tabulate

    start_time = time.time()
    for i in range(epochs):
        stats = trainer.train(profile=True)
        prof = stats.pop("profile")
        stats.update(prof)
        formatted = tabulate([stats], headers="keys")
        if i > 0:
            # After the first epoch, print only the data row so the header
            # is not repeated every iteration.
            formatted = formatted.split("\n")[-1]
        print(formatted)
        # print("train:", train_info)
        # train_loss.append(train_info['train_loss'])
        # eval_info = trainer.validate()
        # print("eval:", eval_info, '\n')
        # val_loss.append(eval_info['val_loss'])
        # val_acc.append(eval_info['val_accuracy'])