This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SECURITY: a base64-encoded service-account key (with private-key material) was committed inline on this line. Revoke and rotate that key, and supply the credential through the environment instead.
--extra-index-url https://_json_key_base64:${PIP_INDEX_KEY_BASE64}@<index-host-truncated-in-source>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
lightning install quick-start | |
curl https://gist.githubusercontent.com/tchaton/b81c8d8ba0f4dd39a47bfa607d81d6d5/raw/8d9d70573a006d95bdcda8492e798d0771d7e61b/train_script.py > train_script.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
from torch import nn | |
import torchvision.transforms as T | |
from pytorch_lightning import LightningDataModule, LightningModule | |
from pytorch_lightning.utilities.cli import LightningCLI | |
from torch.nn import functional as F | |
from torchmetrics import Accuracy | |
from torchvision.datasets import MNIST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torchvision.transforms as T | |
from pytorch_lightning import LightningDataModule, LightningModule | |
from pytorch_lightning.utilities.cli import LightningCLI | |
from torch.nn import functional as F | |
import torch.nn as nn | |
from torchmetrics import Accuracy | |
from torchvision.datasets import MNIST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Lite(LightningLite):
    """Illustrative ``LightningLite`` subclass calling its cross-process helpers."""

    def run(self):
        """Call the collective helpers; ``...`` marks placeholder arguments."""
        # Transfer and concatenate tensors across processes
        self.all_gather(...)
        # Transfer an object from one process to all the others
        self.broadcast(..., src=...)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from flash.image import ImageClassifier | |
from flash.core.classification import Logits | |
from functools import partial | |
head = nn.Sequential( | |
nn.Linear(512, 512), | |
nn.ReLU(True), | |
nn.Dropout(), | |
nn.Linear(512, 512), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flash.image import ImageClassifier, ImageClassificationData | |
class CIFAR10DataModule(ImageClassificationData):
    """``ImageClassificationData`` subclass that reports a fixed class count."""

    @property
    def num_classes(self) -> int:
        """Return the number of target classes, hard-coded to 10."""
        return 10
dm = CIFAR10DataModule.from_datasets( | |
train_dataset=train_set, | |
test_dataset=test_set, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Loop:
    """Skeleton of a run loop driven by hooks.

    ``reset``, ``on_run_start``, ``done`` and ``advance`` are not defined in
    the visible snippet; they are expected to come from subclasses or from
    parts of the class not shown here.
    """

    def run(self):
        """Reset, fire the start hook, then advance until ``done`` is truthy."""
        self.reset()
        self.on_run_start()
        # ``done`` is re-evaluated before every step, so a loop that is
        # already finished never calls ``advance``.
        while not self.done:
            self.advance()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A fresh TrainingEpochLoop is constructed and run once per epoch.
for epoch in range(num_epochs):
    TrainingEpochLoop(model, optimizer, dataloader).run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FitLoop | |
for epoch in range(max_epochs): | |
# TrainingEpochLoop | |
for batch_idx, batch in enumerate(train_dataloader): | |
# TrainingBatchLoop | |
for split_batch in tbptt_split(batch): | |
if lightning_module.automatic_optimization: |