Created
February 24, 2019 03:39
-
-
Save EtienneT/c07994bc96e9fad7a30a89cb9f20bc6b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%%
from fastai.tabular import *
from fastai.vision import *
from fastai.metrics import *
def _maybe_add_crop_pad(tfms):
    """Return the two transform lists, each starting with `crop_pad()` if it had none.

    `tfms` must be a list-like of exactly two transform lists (train and valid).
    A list that already contains a transform whose `__name__` is 'crop_pad' is
    passed through untouched; otherwise a new list with `crop_pad()` prepended
    is produced. The input lists themselves are never mutated.
    """
    assert is_listy(tfms) and len(tfms) == 2, "Please pass a list of two lists of transforms (train and valid)."
    # Gather every transform name first so a malformed transform fails before
    # any `crop_pad()` instance is created.
    names_per_set = [[tfm.__name__ for tfm in tfm_set] for tfm_set in tfms]
    patched = []
    for tfm_set, names in zip(tfms, names_per_set):
        if 'crop_pad' in names:
            patched.append(tfm_set)
        else:
            patched.append([crop_pad()] + tfm_set)
    return patched
def _prep_tfm_kwargs(tfms, size, resize_method:ResizeMethod=None):
    """Fill in defaults for the transform lists and the resize method.

    - `tfms` defaults to two empty lists (train and valid).
    - `resize_method` defaults to `ResizeMethod.SQUISH` when `size` is
      list-like, and to `ResizeMethod.CROP` otherwise (including `size=None`).

    Returns the tuple `(tfms, resize_method)`.
    """
    if tfms is None:
        tfms = [[], []]
    size_is_listy = size is not None and is_listy(size)
    fallback = ResizeMethod.SQUISH if size_is_listy else ResizeMethod.CROP
    if resize_method is None:
        resize_method = fallback
    # Resize methods with value <= 2 get a `crop_pad` transform ensured in both
    # lists (presumably CROP and PAD in fastai's enum -- confirm against fastai).
    if resize_method <= 2:
        tfms = _maybe_add_crop_pad(tfms)
    return tfms, resize_method
class ImageTabularDataBunch(ImageDataBunch):
    """`ImageDataBunch` whose items are built by `ImageTabularList` (image + tabular row)."""

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, valid_idx, size, ds_tfms:Optional[TfmList]=None,
                dl_tfms:Optional[Collection[Callable]]=None, procs=None,
                cat_names:OptStrList=None, cont_names:OptStrList=None, test_df=None, bs:int=64, val_bs:int=None,
                num_workers:int=defaults.cpus, label_delim:str=None,
                fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, device:torch.device=None,
                collate_fn:Callable=data_collate, no_check:bool=False, **kwargs:Any)->'ImageTabularDataBunch':
        """Build a data bunch from `df`, splitting by `valid_idx` and labelling from `label_col`.

        Parameters
        ----------
        path : root path handed to `ImageTabularList.from_df`.
        df : one row per sample; holds the image filename column(s) `fn_col`,
            the label column(s) `label_col` and the tabular columns.
        valid_idx : row indices that go to the validation set.
        size, ds_tfms : image size and (train, valid) transform lists; defaults
            are filled in by `_prep_tfm_kwargs`.
        cat_names, cont_names, procs : tabular configuration. When `cont_names`
            is None it defaults to every column of `df` that is not categorical,
            not a label column and not a filename column, in `df` column order.
        test_df : accepted for signature compatibility but currently ignored
            (no test set is attached).
        **kwargs : accepted but currently not forwarded anywhere.

        The remaining arguments are passed straight to `.databunch(...)`.
        """
        cat_names = ifnone(cat_names, []).copy()
        if cont_names is None:
            # Derive the default from the arguments rather than from a
            # module-level `dep_var` global, which may not exist. Integer column
            # references are resolved to their column names first.
            reserved = {df.columns[c] if isinstance(c, int) else c
                        for c in listify(fn_col) + listify(label_col)}
            reserved.update(cat_names)
            cont_names = [c for c in df.columns if c not in reserved]
        procs = listify(procs)
        src = (ImageTabularList.from_df(df, path=path, cols=fn_col, cat_names=cat_names,
                                        cont_names=cont_names, procs=procs)
                               .split_by_idx(valid_idx)
                               .label_from_df(label_delim=label_delim, cols=label_col))
        # TODO: attach `test_df` as a test set; it is not used yet.
        ds_tfms, resize_method = _prep_tfm_kwargs(ds_tfms, size)
        src = src.transform(tfms=ds_tfms, size=size, resize_method=resize_method)
        return src.databunch(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers,
                             collate_fn=collate_fn, device=device, no_check=no_check)
class ImageTabular(ItemBase):
    """Item pairing an image with its tabular row.

    NOTE: only the image's tensor is exposed through `data`; the tabular part
    is kept on the item as `tabular` but is not included in `data`.
    """

    def __init__(self, img, tabular):
        self.img, self.tabular = img, tabular
        self.data = img.data

    def apply_tfms(self, tfms, **kwargs):
        """Apply `tfms` to the image part only, refresh `data`, and return `self`."""
        transformed = self.img.apply_tfms(tfms, **kwargs)
        self.img = transformed
        self.data = transformed.data
        return self
class ImageTabularList(ItemList):
    """`ItemList` whose items combine an image and a tabular row taken from one DataFrame.

    The DataFrame is read from `self.xtra`; an `ImageItemList` and a
    `TabularList` are built over it and queried by the same index in `get`.
    """
    _item_cls=ImageTabular
    _bunch=ImageTabularDataBunch

    def __init__(self, items, path, cols, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, *args, **kwargs):
        """Build the image and tabular sub-lists over `self.xtra`.

        `cols` names the image filename column(s); `cat_names`, `cont_names`
        and `procs` configure the tabular sub-list.
        """
        super().__init__(items, path, *args, **kwargs)
        self.tabularList = TabularList.from_df(df=self.xtra, cat_names=cat_names, cont_names=cont_names, procs=procs)
        self.imageList = ImageItemList.from_df(df=self.xtra, path=path, cols=cols)
        self.cols = cols
        # Keep the tabular configuration on the instance and register it in
        # `copy_new`, so copies of this list are rebuilt with the same columns
        # and procs instead of falling back to the `None` defaults.
        self.cat_names, self.cont_names, self.procs = cat_names, cont_names, procs
        self.copy_new += ['cols', 'cat_names', 'cont_names', 'procs']

    def get(self, i):
        """Return the `ImageTabular` item built from entry `i` of both sub-lists."""
        img = self.imageList.get(i)
        line = self.tabularList.get(i)
        return self._item_cls(img, line)

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':
        """Create a list with one item per row of `df`.

        Items are the row positions `range(len(df))`; a copy of `df` is passed
        as `xtra` so the caller's frame is not shared with the list.
        """
        return cls(items=range(len(df)), path=path, cols=cols, cat_names=cat_names, cont_names=cont_names, procs=procs, xtra=df.copy(), **kwargs)
#%%
# Column configuration. `pictures` and `path` are expected to be defined before
# this cell runs (presumably a pandas DataFrame and the image root -- confirm).
dep_var = 'Passed'
# Categorical inputs: the two id columns plus every `Performed*` column except
# `PerformedElapsed`, which is used as a continuous variable instead.
cat_names = ['StoreID', 'AnsweredBy']
cat_names += [col for col in pictures if col.startswith('Performed')]
cat_names.remove('PerformedElapsed')
cont_names = ['MaximumScore', 'PerformedElapsed', 'MinimumPassingScore']
procs = [FillMissing, Categorify, Normalize]

# Hold out 20% of the rows for validation. `replace=False` is required:
# sampling with replacement (the default) can pick the same row several times,
# which duplicates validation items and shrinks the real hold-out fraction.
np.random.seed(42)
n_valid = round(len(pictures)*.2)
valid_idx = np.random.choice(len(pictures), n_valid, replace=False)
valid_idx
#%%
bs = 64
np.random.seed(42)
# Image-only baseline, kept for reference:
# data = ImageDataBunch.from_df(path=path, df=pictures, fn_col='PicturePath', label_col='Passed', ds_tfms=get_transforms(), size=224, bs=bs).normalize(imagenet_stats)
data = ImageTabularDataBunch.from_df(
    path=path,
    df=pictures,
    valid_idx=valid_idx,
    fn_col='PicturePath',
    label_col=dep_var,
    ds_tfms=get_transforms(),
    size=224,
    bs=bs,
    cat_names=cat_names,
    cont_names=cont_names,
    procs=procs,
).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet34, metrics=accuracy)
#%%
# Run the learning-rate finder on the freshly created learner.
learn.lr_find()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment