Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created November 25, 2018 14:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devforfu/0503f044ac31abf05cba1cdcc5a29fbe to your computer and use it in GitHub Desktop.
Save devforfu/0503f044ac31abf05cba1cdcc5a29fbe to your computer and use it in GitHub Desktop.
fastai/pytorch parallel workers
"""
Training ResNet18 model on 50000 samples per category.
"""
import sys
from multiprocessing import cpu_count
from pathlib import Path

from fastai import defaults
from fastai.callbacks import EarlyStoppingCallback, SaveModelCallback, CSVLogger
from fastai.metrics import accuracy
from fastai.vision import create_cnn, get_transforms
from fastai.vision.data import ImageItemList, imagenet_stats
import torch
from torchvision.models import resnet18

from metrics import map3
from utils import Timer
def main():
    """Train a ResNet18 classifier on the ``doodles_50000`` image folder.

    Steps, in order:

    1. Make ``cuda:1`` the default fastai device.
    2. Build the data bunch from ``~/data/doodles_50000``: items are read
       from the folder, split with ``random_split_by_pct(0.01)``, labelled
       from folder names, transformed with ``get_transforms()`` at size
       ``img_sz``, batched with ``bs`` items per batch, and normalized with
       ``imagenet_stats``.
    3. Create the learner with ``create_cnn`` (model files go under the
       user's home directory) and attach the ``accuracy``/``map3`` metrics
       and the early-stopping, save-model and CSV-logging callbacks.
    4. Run ``fit_one_cycle(5)``, unfreeze, run ``fit_one_cycle(50)`` and
       save the result as ``resnet18_50000``.

    Raises:
        SystemExit: with status 1 when training raises ``RuntimeError``;
            the error message is appended to ``error.txt`` in the current
            working directory first.
    """
    defaults.device = torch.device('cuda:1')
    dataset = 'doodles_50000'

    # Single source of truth for the values that are both logged below and
    # passed to the data pipeline, so the log cannot drift from what is used.
    bs = 300
    img_sz = 224

    print('Loading the dataset...')
    print(f'Batch size: {bs}')
    print(f'Image size: {img_sz}')

    # NOTE(review): Timer is a project helper; presumably verbose() returns a
    # human-readable elapsed time once the block has exited -- confirm.
    with Timer() as timer:
        data = (
            ImageItemList
            .from_folder(Path(f'~/data/{dataset}').expanduser())
            .random_split_by_pct(0.01)
            .label_from_folder()
            .transform(get_transforms(), size=img_sz)
            # Half of the CPUs are given to the data-loading workers.
            .databunch(bs=bs, num_workers=cpu_count() // 2)
            .normalize(imagenet_stats)
        )
    print(f'Time to read the data: {timer.verbose()}')

    print('Creating model...')
    learn = create_cnn(data, resnet18, path=Path.home())
    learn.metrics = [accuracy, map3]
    learn.callbacks = [
        EarlyStoppingCallback(learn),
        SaveModelCallback(learn),
        CSVLogger(learn),
    ]
    print('The model is ready! Start training process...')

    with Timer() as timer:
        try:
            learn.fit_one_cycle(5)
            learn.unfreeze()
            learn.fit_one_cycle(50)
            learn.save('resnet18_50000')
        except RuntimeError as e:
            print('The model training was interrupted')
            print('Error dump saved into error.txt')
            with open('error.txt', 'a') as file:
                # The file is opened in append mode; terminate each entry
                # with a newline so successive failures do not run together.
                file.write(f'{e}\n')
            sys.exit(1)
    print(f'Total training time: {timer.verbose()}')
# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment