Skip to content

Instantly share code, notes, and snippets.

@EtienneT
Created February 24, 2019 03:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EtienneT/c07994bc96e9fad7a30a89cb9f20bc6b to your computer and use it in GitHub Desktop.
Save EtienneT/c07994bc96e9fad7a30a89cb9f20bc6b to your computer and use it in GitHub Desktop.
#%%
from fastai.tabular import *
from fastai.vision import *
from fastai.metrics import *
def _maybe_add_crop_pad(tfms):
    "Return `tfms` (a [train, valid] pair of lists) with `crop_pad()` prepended to every list that lacks it."
    assert is_listy(tfms) and len(tfms) == 2, "Please pass a list of two lists of transforms (train and valid)."
    # Collect the transform names of each list first, then decide per list whether crop_pad is already there.
    names_per_list = [{t.__name__ for t in tfm_list} for tfm_list in tfms]
    return [tfm_list if 'crop_pad' in names else [crop_pad()] + tfm_list
            for tfm_list, names in zip(tfms, names_per_list)]
def _prep_tfm_kwargs(tfms, size, resize_method:ResizeMethod=None):
    """Normalize `tfms` and choose a resize method.

    `tfms` defaults to an empty [train, valid] pair. When `resize_method` is None it becomes
    `ResizeMethod.SQUISH` for a list-like `size` and `ResizeMethod.CROP` otherwise. For crop/pad
    methods a `crop_pad` transform is ensured in both lists. Returns `(tfms, resize_method)`.
    """
    if tfms is None:
        tfms = [[], []]
    if resize_method is None:
        resize_method = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
    # NOTE(review): `<= 2` relies on the ResizeMethod enum ordering (presumably CROP and PAD are the first two) — confirm.
    if resize_method <= 2:
        tfms = _maybe_add_crop_pad(tfms)
    return tfms, resize_method
class ImageTabularDataBunch(ImageDataBunch):
    "`ImageDataBunch` built from a dataframe whose rows hold an image path plus tabular columns."
    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, valid_idx, size, ds_tfms:Optional[TfmList]=None,
                dl_tfms:Optional[Collection[Callable]]=None, procs=None, cat_names:OptStrList=None,
                cont_names:OptStrList=None, test_df=None, bs:int=64, val_bs:int=None,
                num_workers:int=defaults.cpus, label_delim:str=None, fn_col:IntsOrStrs=0,
                label_col:IntsOrStrs=1, device:torch.device=None, collate_fn:Callable=data_collate,
                no_check:bool=False, **kwargs:Any)->'ImageTabularDataBunch':
        """Create an `ImageTabularDataBunch` from `df`.

        - `fn_col`: column(s) (name or position) holding the image paths, relative to `path`.
        - `label_col`, `label_delim`: forwarded to `label_from_df`.
        - `valid_idx`: indices of the rows that go to the validation set (passed to `split_by_idx`).
        - `cat_names`, `cont_names`, `procs`: tabular settings forwarded to `ImageTabularList.from_df`.
          When `cont_names` is None it defaults to every column of `df` that is not in `cat_names`,
          `fn_col` or `label_col`, in dataframe column order.
        - `ds_tfms`, `size`: image transforms and target size; a `crop_pad` transform is added for crop/pad resizing.
        - `test_df`: not supported yet; a warning is emitted when it is given and it is otherwise ignored.
        - `**kwargs`: accepted for signature compatibility and currently ignored.
        """
        import warnings
        cat_names = ifnone(cat_names, []).copy()
        if cont_names is None:
            def _names(cols):
                "Column names for `cols`, which may be given as names and/or positional indices."
                return {df.columns[c] if isinstance(c, int) else c for c in listify(cols)}
            # Use only arguments of this call (not a module global) and keep dataframe order: a `set`
            # difference would make the order of continuous features non-deterministic. The image-path and
            # label columns are never continuous features.
            excluded = set(cat_names) | _names(fn_col) | _names(label_col)
            cont_names = [c for c in df.columns if c not in excluded]
        procs = listify(procs)
        src = (ImageTabularList.from_df(df, path=path, cols=fn_col, cat_names=cat_names, cont_names=cont_names, procs=procs)
               .split_by_idx(valid_idx)
               .label_from_df(label_delim=label_delim, cols=label_col))
        if test_df is not None:
            # TODO: add the test set once an ImageTabularList can be built for it (previously dropped silently).
            warnings.warn("`test_df` is not supported by ImageTabularDataBunch.from_df yet and is ignored.")
        ds_tfms, resize_method = _prep_tfm_kwargs(ds_tfms, size)
        src = src.transform(tfms=ds_tfms, size=size, resize_method=resize_method)
        return src.databunch(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers, collate_fn=collate_fn,
                             device=device, no_check=no_check)
class ImageTabular(ItemBase):
    "Item pairing an image with its tabular row; only the image is transformed and exposed as `data`."
    def __init__(self, img, tabular):
        self.img, self.tabular = img, tabular
        # Only the image tensor is used as the item's data for now; the tabular part is carried along.
        self.data = self.img.data
    def apply_tfms(self, tfms, **kwargs):
        "Apply `tfms` to the image only, refresh `data` from the result, and return this same item."
        transformed = self.img.apply_tfms(tfms, **kwargs)
        self.img = transformed
        self.data = transformed.data
        return self
class ImageTabularList(ItemList):
    "`ItemList` whose item `i` pairs the image from column `cols` with row `i` of the same dataframe."
    _item_cls=ImageTabular
    _bunch=ImageTabularDataBunch
    def __init__(self, items, path, cols, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, *args, **kwargs):
        super().__init__(items, path, *args, **kwargs)
        # Both inner lists are built from the same dataframe (`self.xtra`), so index `i` is assumed to
        # refer to the same row in each.
        self.tabularList = TabularList.from_df(df=self.xtra, cat_names=cat_names, cont_names=cont_names, procs=procs)
        self.imageList = ImageItemList.from_df(df=self.xtra, path=path, cols=cols)
        self.cols,self.cat_names,self.cont_names,self.procs = cols,cat_names,cont_names,procs
        # Sub-lists created through fastai's `ItemList.new` (e.g. the train/valid split) are rebuilt from the
        # attributes named in `copy_new`. The tabular settings must be listed too; otherwise those sub-lists
        # recreate their inner TabularList with cat_names/cont_names/procs = None.
        self.copy_new += ['cols', 'cat_names', 'cont_names', 'procs']
    def get(self, i):
        "Return an `ImageTabular` made of the image and the tabular row for item `i`."
        img = self.imageList.get(i)
        line = self.tabularList.get(i)
        return self._item_cls(img, line)
    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':
        "Create a list over the rows of `df` (copied); image paths come from column(s) `cols`, relative to `path`."
        return cls(items=range(len(df)), path=path, cols=cols, cat_names=cat_names, cont_names=cont_names, procs=procs, xtra=df.copy(), **kwargs)
#%%
# Name of the label column.
dep_var = 'Passed'
# NOTE(review): `pictures` (a DataFrame) and `path` are not defined in this file — presumably created
# in an earlier interactive session; confirm before running this script standalone.
cat_names = ['StoreID', 'AnsweredBy']
cat_names += [col for col in pictures if col.startswith('Performed')]
# 'PerformedElapsed' matches the prefix above but is listed in cont_names below, so drop it here.
cat_names.remove('PerformedElapsed')
cont_names = ['MaximumScore', 'PerformedElapsed', 'MinimumPassingScore']
procs = [FillMissing, Categorify, Normalize]
np.random.seed(42)
# Hold out 20% of the rows for validation. `replace=False` is required: numpy's default samples with
# replacement, which yields duplicate indices and a validation set smaller than intended.
valid_idx = np.random.choice(len(pictures), round(len(pictures)*.2), replace=False)
valid_idx
#%%
bs = 64
np.random.seed(42)
# Build the image+tabular databunch (224px images, ImageNet normalization) from the same dataframe.
data = ImageTabularDataBunch.from_df(
    path=path,
    df=pictures,
    valid_idx=valid_idx,
    fn_col='PicturePath',
    label_col='Passed',
    ds_tfms=get_transforms(),
    size=224,
    bs=bs,
    cat_names=cat_names,
    cont_names=cont_names,
    procs=procs,
).normalize(imagenet_stats)
# ResNet-34 CNN learner tracking accuracy.
learn = create_cnn(data, models.resnet34, metrics=accuracy)
#%%
learn.lr_find()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment