Created
November 25, 2018 14:45
-
-
Save devforfu/0503f044ac31abf05cba1cdcc5a29fbe to your computer and use it in GitHub Desktop.
fastai/pytorch parallel workers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Train a ResNet18 model on the doodles dataset, using 50000 samples per category.
""" | |
import sys | |
from fastai import defaults | |
from fastai.vision import create_cnn, get_transforms | |
from fastai.metrics import accuracy | |
from fastai.callbacks import EarlyStoppingCallback, SaveModelCallback, CSVLogger | |
from fastai.vision.data import ImageItemList, imagenet_stats | |
from multiprocessing import cpu_count | |
import torch | |
from torchvision.models import resnet18 | |
from metrics import map3 | |
from utils import Timer | |
def main():
    """Train a ResNet18 classifier on the ``doodles_50000`` image folder.

    Steps, in order:

    1. Select ``cuda:1`` as the default fastai device.
    2. Build the data bunch from ``~/data/doodles_50000`` with labels taken
       from folder names, a random 1% split, default transforms, and
       ImageNet normalization.
    3. Create the learner, attach metrics and callbacks.
    4. Run ``fit_one_cycle(5)``, unfreeze, run ``fit_one_cycle(50)`` and save
       the weights under the name ``resnet18_50000``.

    If a ``RuntimeError`` is raised during training, its message is appended
    to ``error.txt`` in the current working directory and the process exits
    with status 1.
    """
    # Imported locally: the module-level imports do not bring Path into scope,
    # yet it is needed for both the dataset location and the learner path.
    from pathlib import Path

    defaults.device = torch.device('cuda:1')

    dataset = 'doodles_50000'
    # Single source of truth for the values that are both reported below and
    # passed to the data pipeline (previously printed without being defined).
    bs = 300
    img_sz = 224

    print('Loading the dataset...')
    print(f'Batch size: {bs}')
    print(f'Image size: {img_sz}')
    with Timer() as timer:
        data = (ImageItemList.
                from_folder(Path(f'~/data/{dataset}').expanduser()).
                random_split_by_pct(0.01).
                label_from_folder().
                transform(get_transforms(), size=img_sz).
                databunch(bs=bs, num_workers=cpu_count()//2).
                normalize(imagenet_stats))
    print(f'Time to read the data: {timer.verbose()}')

    print('Creating model...')
    learn = create_cnn(data, resnet18, path=Path.home())
    learn.metrics = [accuracy, map3]
    learn.callbacks = [EarlyStoppingCallback(learn), SaveModelCallback(learn), CSVLogger(learn)]
    print('The model is ready! Start training process...')

    with Timer() as timer:
        try:
            learn.fit_one_cycle(5)
            learn.unfreeze()
            learn.fit_one_cycle(50)
            learn.save('resnet18_50000')
        except RuntimeError as e:
            print('The model training was interrupted')
            print('Error dump saved into error.txt')
            # Append rather than overwrite so earlier failures are kept.
            with open('error.txt', 'a') as file:
                file.write(str(e))
            sys.exit(1)
    print(f'Total training time: {timer.verbose()}')
# Start training only when this file is executed as a script; importing the
# module must not kick off a training run.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment