Skip to content

Instantly share code, notes, and snippets.

@tchaton
Created December 6, 2020 20:03
Show Gist options
  • Save tchaton/01bf334f65d98f38d45dad45d38db11a to your computer and use it in GitHub Desktop.
def instantiate_datamodule(args):
    """Build the Cora datamodule from parsed command-line arguments.

    Args:
        args: Namespace-like object exposing the attributes ``num_workers``,
            ``batch_size``, ``drop_last``, ``pin_memory`` and ``num_layers``.

    Returns:
        A ``CoraDataset`` configured with those values.
    """
    # Each listed attribute of ``args`` is forwarded unchanged as a keyword
    # argument of the same name. ``num_layers`` is also consumed by
    # instantiate_model, so both objects see the same depth.
    forwarded = (
        "num_workers",
        "batch_size",
        "drop_last",
        "pin_memory",
        "num_layers",
    )
    options = {name: getattr(args, name) for name in forwarded}
    return CoraDataset(**options)
def instantiate_model(args, datamodule):
    """Build a DNAConvNet and attach the datamodule's ``gather_data`` to it.

    Architecture hyper-parameters are taken from ``args``. Dataset-specific
    constructor arguments are taken from ``datamodule.hyper_parameters``.

    Args:
        args: Namespace-like object exposing ``num_layers``,
            ``hidden_channels``, ``heads``, ``groups`` and ``dropout``.
        datamodule: Object exposing ``hyper_parameters`` (a mapping of extra
            keyword arguments for DNAConvNet) and a ``gather_data`` attribute.

    Returns:
        The constructed model, with ``gather_data`` set as an attribute.

    Raises:
        TypeError: If ``datamodule.hyper_parameters`` repeats one of the
            keyword names taken from ``args``. Python rejects duplicate
            keyword arguments in a call.
    """
    architecture = dict(
        num_layers=args.num_layers,
        hidden_channels=args.hidden_channels,
        heads=args.heads,
        groups=args.groups,
        dropout=args.dropout,
    )
    # The two mappings are unpacked separately rather than merged first.
    # That way a clashing key raises TypeError instead of silently
    # overriding a command-line value.
    model = DNAConvNet(**architecture, **datamodule.hyper_parameters)
    # Attach the datamodule's gather_data to the model instance. The model
    # presumably calls it during its steps; confirm in DNAConvNet.
    model.gather_data = datamodule.gather_data
    return model
def run(args):
    """Train a DNAConvNet on the Cora datamodule, then run the test loop.

    Args:
        args: Parsed namespace carrying the Trainer, CoraDataset and
            DNAConvNet options registered in the ``__main__`` block.
    """
    dm: LightningDataModule = instantiate_datamodule(args)
    net: LightningModule = instantiate_model(args, dm)
    lightning_trainer = Trainer.from_argparse_args(args)
    lightning_trainer.fit(net, dm)
    # Called with no arguments, so the test run relies on whatever model and
    # datamodule the trainer kept from fit(). Checkpoint selection depends on
    # the installed Lightning version; NOTE(review): confirm.
    lightning_trainer.test()
if __name__ == "__main__":
    parser = ArgumentParser(description="Pytorch Geometric Example")
    # Each component registers its own options on the shared parser, in this
    # order: Trainer, then the datamodule, then the model.
    for component in (Trainer, CoraDataset, DNAConvNet):
        parser = component.add_argparse_args(parser)
    # The argument list is fixed here instead of being read from sys.argv,
    # so flags given on the real command line are ignored. Edit this list to
    # change the run.
    demo_argv = ["--max_epochs", "1"]
    run(parser.parse_args(demo_argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment