@devforfu
Last active November 13, 2020 19:40
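import pytorch_lightning as pl
from torchvision import transforms
# create_default_parser, super_seed, create_data_loaders, ToThreeChannels,
# make_trainer_parameters and BasicCNNExperiment are the author's own
# helpers, defined elsewhere and not shown in this gist.


def main():
    # Parse CLI options and seed the random number generators.
    params = create_default_parser(__file__).parse_args()
    super_seed(params.seed)
    # Each sample becomes a tensor, is expanded to three channels, and is
    # normalized per channel from ToTensor's [0, 1] range to [-1, 1].
    loaders, meta = create_data_loaders(dataset_name=params.dataset,
                                        dataset_root=params.dataset_path,
                                        num_workers=params.num_workers,
                                        batch_size=params.batch_size,
                                        sample_transformer=transforms.Compose([
                                            transforms.ToTensor(),
                                            ToThreeChannels(),
                                            transforms.Normalize([0.5]*3, [0.5]*3)
                                        ]))
    params.n_classes = len(meta['classes'])  # output size from dataset metadata
    trainer = pl.Trainer(**make_trainer_parameters(params))
    # Keyword names of that era; PyTorch Lightning 1.4 deprecated
    # `train_dataloader` in favor of `train_dataloaders`.
    trainer.fit(model=BasicCNNExperiment(params),
                train_dataloader=loaders['train'],
                val_dataloaders=loaders['valid'])


if __name__ == '__main__':
    main()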
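The gist does not include `create_default_parser`. Below is a hypothetical reconstruction with `argparse` that defines only the options `main()` actually reads (`seed`, `dataset`, `dataset_path`, `num_workers`, `batch_size`). The flag spellings and default values are illustrative assumptions, not the author's.

```python
import argparse
import os


def create_default_parser(script_path):
    """Hypothetical stand-in for the gist's create_default_parser helper.

    Only the options read by main() are defined; flag names and defaults
    are assumptions made for illustration.
    """
    parser = argparse.ArgumentParser(prog=os.path.basename(script_path))
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--dataset', default='mnist')  # assumed default
    parser.add_argument('--dataset-path', dest='dataset_path', default='./data')
    parser.add_argument('--num-workers', dest='num_workers', type=int, default=0)
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=64)
    return parser
```

For example, `create_default_parser('train.py').parse_args(['--batch-size', '128'])` returns a namespace with `batch_size == 128` and every other field at its default. Explicit `dest` names keep the attribute names identical to the ones `main()` accesses.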
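`super_seed` is not shown either. A common way to write such a "seed everything" helper, sketched here as an assumption rather than the author's code, is to seed every random number generator a training run touches: Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices). NumPy and PyTorch are imported lazily so the sketch still runs where they are not installed.

```python
import os
import random


def super_seed(seed):
    """Hypothetical 'seed everything' helper (not the gist's implementation)."""
    random.seed(seed)
    # Affects only subprocesses started later; the current interpreter's
    # hash randomization is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call without a GPU
    except ImportError:
        pass
```

PyTorch Lightning also ships `pl.seed_everything(seed)`, which covers the same generators and could replace a hand-written helper.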
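`transforms.Normalize(mean, std)` computes `x' = (x - mean[c]) / std[c]` for each channel `c`. For PIL images and `uint8` arrays, `ToTensor()` first scales pixel values from 0..255 into [0, 1]. With mean and std both 0.5 on all three channels, the normalized values therefore fall in [-1, 1]. The plain-Python sketch below reproduces that arithmetic for a single pixel value. `normalize_value` and `to_unit_interval` are hypothetical helpers for illustration, not torchvision functions.

```python
def to_unit_interval(pixel):
    """Scaling ToTensor applies to an 8-bit pixel value (0..255 -> 0.0..1.0)."""
    return pixel / 255.0


def normalize_value(x, mean=0.5, std=0.5):
    """Per-channel arithmetic performed by transforms.Normalize."""
    return (x - mean) / std


for pixel in (0, 128, 255):
    print(pixel, normalize_value(to_unit_interval(pixel)))
```

Pixel 0 maps to -1.0, pixel 255 maps to 1.0, and mid-gray lands just above 0.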