Skip to content

Instantly share code, notes, and snippets.

@tchaton
Created December 6, 2020 20:20
Show Gist options
  • Save tchaton/3e13193c6483aab1a3a637c4a525a86c to your computer and use it in GitHub Desktop.
def get_single_batch(datamodule):
for batch in datamodule.test_dataloader():
return datamodule.gather_data(batch, 0)
def run(args):
    """Train, test, and export a jittable model to TorchScript, then reload it.

    Steps:
      1. Build the datamodule and model from ``args``.
      2. Print the model, call ``model.jittable()``, and print it again, so
         the module tree can be compared before and after the call.
      3. Fit and test with a ``Trainer`` built from ``args``.
      4. Export with ``to_torchscript(method='script')`` to
         ``"model_trace.pt"``, then load that file back and print it.

    Args:
        args: parsed command-line namespace consumed by the instantiate_*
            helpers and ``Trainer.from_argparse_args``.
    """
    dm: LightningDataModule = instantiate_datamodule(args)
    net: LightningModule = instantiate_model(args, dm)

    # Before/after dump around the jittable conversion.
    print(net)
    net.jittable()
    print(net)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(net, dm)
    trainer.test()

    sample = get_single_batch(dm)
    export_path = "model_trace.pt"
    # NOTE(review): Lightning's docs describe example_inputs as used for
    # method='trace'; with 'script' it is presumably ignored — confirm.
    net.to_torchscript(
        file_path=export_path,
        method='script',
        example_inputs=sample,
    )
    # Round-trip check: the exported archive must load back.
    print(torch.jit.load(export_path))
"""
Out:
DNAConvNet(
(val_acc): Accuracy()
(test_acc): Accuracy()
(lin1): Linear(in_features=1433, out_features=128, bias=True)
(convs): ModuleList(
(0): DNAConv(128, heads=8, groups=16)
(1): DNAConv(128, heads=8, groups=16)
)
(lin2): Linear(in_features=128, out_features=7, bias=False)
)
DNAConvNet(
(val_acc): Accuracy()
(test_acc): Accuracy()
(lin1): Linear(in_features=1433, out_features=128, bias=True)
(convs): ModuleList(
(0): DNAConvJittable_5d8941(128, heads=8, groups=16)
(1): DNAConvJittable_5d913d(128, heads=8, groups=16)
)
(lin2): Linear(in_features=128, out_features=7, bias=False)
)
RecursiveScriptModule(
original_name=DNAConvNet
(val_acc): RecursiveScriptModule(original_name=Accuracy)
(test_acc): RecursiveScriptModule(original_name=Accuracy)
(lin1): RecursiveScriptModule(original_name=Linear)
(convs): RecursiveScriptModule(
original_name=ModuleList
(0): RecursiveScriptModule(
original_name=DNAConvJittable_b01eeb
(multi_head): RecursiveScriptModule(
original_name=MultiHead
(lin_q): RecursiveScriptModule(original_name=Linear)
(lin_k): RecursiveScriptModule(original_name=Linear)
(lin_v): RecursiveScriptModule(original_name=Linear)
)
)
(1): RecursiveScriptModule(
original_name=DNAConvJittable_b026ab
(multi_head): RecursiveScriptModule(
original_name=MultiHead
(lin_q): RecursiveScriptModule(original_name=Linear)
(lin_k): RecursiveScriptModule(original_name=Linear)
(lin_v): RecursiveScriptModule(original_name=Linear)
)
)
)
(lin2): RecursiveScriptModule(original_name=Linear)
)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment