Skip to content

Instantly share code, notes, and snippets.

@marii-moe
Created January 12, 2021 01:48
Show Gist options
  • Save marii-moe/7374ee9d59400f85b17dc54b7a2c4af0 to your computer and use it in GitHub Desktop.
Save marii-moe/7374ee9d59400f85b17dc54b7a2c4af0 to your computer and use it in GitHub Desktop.
Dataloaders with val_ first.
class FilteredBase:
    def dataloaders(self, bs=64, shuffle=True, n=None, path='.', dl_type=None, dl_kwargs=None,
                    device=None, drop_last=None, **kwargs):
        """Build one data loader per subset and wrap them in ``self._dbunch_type``.

        Subset 0 (the training set) is built with ``dl_type``.
        Subsets ``1 .. self.n_subsets - 1`` (the validation sets) are built with
        ``dl.new`` from that first loader. Keyword arguments prefixed with ``val_``
        (e.g. ``val_bs=128``, ``val_shuffle=True``) are applied only to the validation
        loaders, with the prefix removed. They override those loaders' defaults.

        Args:
            bs: Batch size for every loader (validation loaders can override it via ``val_bs``).
            shuffle: Whether the training loader shuffles.
            n: Passed as ``n`` to the training loader; validation loaders get ``n=None``.
            path: Forwarded to ``self._dbunch_type``.
            dl_type: Loader class for subset 0; defaults to ``self._dl_type``.
            dl_kwargs: Optional list with one kwargs dict per subset. Each dict is merged
                last, so it takes precedence for the corresponding loader.
            device: Passed to the training loader and to ``self._dbunch_type``;
                defaults to ``default_device()``.
            drop_last: Whether the training loader drops its last incomplete batch;
                defaults to ``shuffle``.
            **kwargs: Forwarded to every loader, except keys starting with ``val_``.
                Those go only to the validation loaders, with the prefix stripped.

        Returns:
            ``self._dbunch_type(*loaders, path=path, device=device)``.
        """
        if device is None: device = default_device()
        if dl_kwargs is None: dl_kwargs = [{}] * self.n_subsets
        if dl_type is None: dl_type = self._dl_type
        if drop_last is None: drop_last = shuffle

        prefix = 'val_'
        # Split off the validation-only options. Without this, they would also be
        # forwarded, under their prefixed names, to every loader constructor.
        val_kwargs = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
        kwargs = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}

        # Intended precedence (assuming `merge` lets later dicts win, as fastcore's
        # `merge` does -- confirm): kwargs < defaults < val_ overrides < dl_kwargs[i].
        train_defaults = {'bs': bs, 'shuffle': shuffle, 'drop_last': drop_last, 'n': n, 'device': device}
        dl = dl_type(self.subset(0), **merge(kwargs, train_defaults, dl_kwargs[0]))

        # Reset shuffle/drop_last explicitly. `dl.new` presumably copies these settings
        # from the training loader (fastai's DataLoader.new does). Otherwise validation
        # data would be shuffled and its last partial batch dropped.
        # Users can still opt in via `val_shuffle=True` / `val_drop_last=True`.
        val_defaults = {'bs': bs, 'shuffle': False, 'drop_last': False, 'n': None}
        dls = [dl] + [dl.new(self.subset(i), **merge(kwargs, val_defaults, val_kwargs, dl_kwargs[i]))
                      for i in range(1, self.n_subsets)]
        return self._dbunch_type(*dls, path=path, device=device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment