Created
April 9, 2021 23:35
-
-
Save richardliaw/e5276c9578a74e81124c9fa23a0aa9ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset | |
import torch | |
import torchvision | |
from torchvision import transforms | |
import numpy as np | |
import os | |
from PIL import Image | |
import ray | |
from ray.util.sgd.torch import TrainingOperator | |
from ray.util.sgd import TorchTrainer | |
from torch.utils.data import DataLoader | |
class Cifar100TrainingOperator(TrainingOperator):
    """Ray SGD training operator: timm ResNet-50 on CIFAR-100.

    All hyperparameters are read from ``config`` so that remote Ray
    workers never depend on driver-side module globals.
    """

    def setup(self, config):
        """Build model, optimizer, criterion and data loaders, then register them.

        Config keys:
            trans: torchvision transform applied to both train and val sets.
            batch_size: per-worker batch size (default 512, matching the
                original script's global).
            scheduler: optional LR scheduler to register (default None).
        """
        # Import here so each Ray worker resolves timm when the operator is
        # deserialized, instead of relying on a driver-side global import
        # that does not exist in the worker process.
        import timm

        # Previously these were module-level globals defined under
        # __main__ on the driver only -> NameError on remote workers.
        # Read them from config with backward-compatible defaults.
        batch_size = config.get("batch_size", 512)
        scheduler = config.get("scheduler")
        trans = config["trans"]

        model = timm.create_model(
            "resnet50", pretrained=False, num_classes=100).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        criterion = torch.nn.CrossEntropyLoss()

        # Serialize the dataset download across workers that share one
        # filesystem, so only one of them downloads at a time.
        from filelock import FileLock
        with FileLock("./data.lock"):
            train_dataset = torchvision.datasets.CIFAR100(
                root="./cifar100_data", train=True, download=True,
                transform=trans)
            val_dataset = torchvision.datasets.CIFAR100(
                root="./cifar100_data", train=False, download=True,
                transform=trans)
        train_loader = DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

        ddp_args = {"output_device": torch.distributed.get_rank()}

        # criterion is constructed unconditionally above, so the original
        # "criterion is None" branches were unreachable; only the scheduler
        # is genuinely optional.
        if scheduler is not None:
            (self.model, self.optimizer, self.criterion,
             self.scheduler) = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion,
                schedulers=scheduler,
                ddp_args=ddp_args)
        else:
            self.model, self.optimizer, self.criterion = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion,
                ddp_args=ddp_args)

        # Move the loss to this worker's assigned device rather than
        # hardcoding .cuda(); also keeps CPU-only workers functional.
        self.criterion = self.criterion.to(self.device)
        self.register_data(
            train_loader=train_loader,
            validation_loader=val_loader)

    def train_batch(self, *args, **kwargs):
        # No custom per-batch behavior; delegate to the base operator.
        return super().train_batch(*args, **kwargs)
if __name__ == '__main__':
    import time

    from tabulate import tabulate

    # Use ray.init() for a local run; address='auto' attaches to an
    # existing cluster started with `ray start`.
    ray.init(address='auto')

    batch_size = 512
    epochs = 200
    scheduler = None
    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Ship every hyperparameter through config so the remote workers do
    # not depend on these driver-side globals.
    trainer = TorchTrainer(
        training_operator_cls=Cifar100TrainingOperator,
        num_workers=4,
        config={
            "trans": trans,
            "batch_size": batch_size,
            "scheduler": scheduler,
        },
        use_gpu=True,
    )

    start_time = time.time()
    try:
        for i in range(epochs):
            stats = trainer.train(profile=True)
            # Merge the profiling sub-dict into the flat stats row.
            stats.update(stats.pop("profile"))
            formatted = tabulate([stats], headers="keys")
            if i > 0:
                # After the first epoch, print only the data row so the
                # header is not repeated every epoch.
                formatted = formatted.split("\n")[-1]
            print(formatted)
        print("total training time: %.1fs" % (time.time() - start_time))
    finally:
        # Release the Ray actors/GPUs even if training raises.
        trainer.shutdown()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment