Skip to content

Instantly share code, notes, and snippets.

@tikurahul
Last active August 30, 2020 20:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tikurahul/70277664e7e03a822baa1c131a60b0a9 to your computer and use it in GitHub Desktop.
Save tikurahul/70277664e7e03a822baa1c131a60b0a9 to your computer and use it in GitHub Desktop.
import numpy as np
from fastai2.vision.all import *
from fastai2.distributed import *
def train(*, epochs=20, freeze_epochs=2, base_lr=0.0006, wd=0.01, bs=8,
          device_ids=(0,)):
    """Fine-tune a U-Net with a ResNet-34 encoder on CamVid-tiny.

    Fetches ``URLs.CAMVID_TINY`` with ``untar_data`` and builds
    ``SegmentationDataLoaders``. The mask for ``images/<stem><ext>`` is read
    from ``labels/<stem>_P<ext>``. It then creates a ``unet_learner``,
    converts it with ``.to_fp16()``, and runs ``fine_tune`` inside
    ``learner.parallel_ctx`` with an ``EarlyStoppingCallback``.

    Every keyword argument defaults to the value that used to be hard-coded
    here, so calling ``train()`` with no arguments behaves as before.

    Args:
        epochs: passed as the ``epochs`` argument of ``Learner.fine_tune``.
        freeze_epochs: passed as ``freeze_epochs`` to ``fine_tune``.
        base_lr: passed as ``base_lr`` to ``fine_tune``.
        wd: weight decay, forwarded to ``fine_tune`` as ``wd``.
        bs: batch size for the DataLoaders.
        device_ids: GPU indices given to ``parallel_ctx``. It is converted
            to a list; ``None`` is forwarded unchanged, which presumably
            means "all devices" (confirm against the fastai docs).
    """
    path = untar_data(URLs.CAMVID_TINY)

    def label_func(fn):
        # Mask for images/<stem><ext> lives at labels/<stem>_P<ext>.
        return path/"labels"/f"{fn.stem}_P{fn.suffix}"

    codes = np.loadtxt(path/'codes.txt', dtype=str)
    fnames = get_image_files(path/"images")
    dls = SegmentationDataLoaders.from_label_func(
        path, bs=bs, fnames=fnames, label_func=label_func, codes=codes
    )
    print('Creating Learner')
    learner = unet_learner(dls, resnet34).to_fp16()
    callbacks = [
        EarlyStoppingCallback(min_delta=0.001, patience=5),
    ]
    # The default is an immutable tuple (no shared mutable default).
    # Convert it so parallel_ctx receives a list, as it did originally.
    ids = list(device_ids) if device_ids is not None else None
    with learner.parallel_ctx(device_ids=ids):
        learner.fine_tune(epochs, freeze_epochs=freeze_epochs, wd=wd,
                          base_lr=base_lr, cbs=callbacks)
    print('Done')
if __name__ == "__main__":
    # Script entry point: start training only when run directly, so that
    # importing this module does not trigger a training run.
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment