Skip to content

Instantly share code, notes, and snippets.

View tchaton's full-sized avatar
👻
Always up for it !

thomas chaton tchaton

👻
Always up for it !
View GitHub Profile
--extra-index-url https://_json_key_base64:<REDACTED_SERVICE_ACCOUNT_KEY — leaked credential removed; rotate the key and supply it via an environment variable or keyring instead>
lightning install quick-start
curl https://gist.githubusercontent.com/tchaton/b81c8d8ba0f4dd39a47bfa607d81d6d5/raw/8d9d70573a006d95bdcda8492e798d0771d7e61b/train_script.py > train_script.py
import os
import torch
from torch import nn
import torchvision.transforms as T
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
from torch.nn import functional as F
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
@tchaton
tchaton / train_script_mnist.py
Last active May 6, 2022 20:37
This script is used to train a CNN on MNIST.
import os
import torch
import torchvision.transforms as T
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
from torch.nn import functional as F
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
@tchaton
tchaton / lite_collectives.py
Created November 5, 2021 09:50
lite_collectives.py
class Lite(LightningLite):
    """Illustrate the collective-communication helpers callable inside ``run``.

    NOTE(review): the ``...`` arguments below are literal placeholders in this
    snippet; substitute real tensors/objects (and, for ``broadcast``, the
    sending process's rank as ``src`` — presumably; confirm against the
    LightningLite docs) before running.
    """

    def run(self):
        # Transfer and concatenate tensors across processes
        self.all_gather(...)
        # Transfer an object from one process to all the others
        self.broadcast(..., src=...)
@tchaton
tchaton / baal_model.py
Last active December 20, 2021 12:29
baal_model.py
from torch import nn
from flash.image import ImageClassifier
from flash.core.classification import Logits
from functools import partial
head = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(512, 512),
@tchaton
tchaton / cifar_datamodule.py
Created November 3, 2021 18:55
cifar_datamodule.py
from flash.image import ImageClassifier, ImageClassificationData
class CIFAR10DataModule(ImageClassificationData):
    """Image-classification data module specialised for CIFAR-10.

    The only override is ``num_classes``; all other behaviour is inherited
    from ``ImageClassificationData``.
    """

    @property
    def num_classes(self):
        """Number of target classes exposed by this data module (ten)."""
        cifar10_class_count = 10
        return cifar10_class_count
dm = CIFAR10DataModule.from_datasets(
train_dataset=train_set,
test_dataset=test_set,
@tchaton
tchaton / loop.py
Created November 2, 2021 15:17
loop.py
class Loop:
    """Skeleton of a loop driver: reset, start hook, then step until done.

    ``reset``, ``on_run_start``, ``done`` and ``advance`` are not defined in
    this class; they must be provided by a subclass or set on the instance.
    """

    def run(self):
        """Run the loop to completion and return ``None``."""
        self.reset()
        self.on_run_start()
        while True:
            # ``done`` is re-checked before every step, so ``advance`` may flip it.
            if self.done:
                break
            self.advance()
@tchaton
tchaton / replace_loop.py
Created November 2, 2021 14:51
replace_loop.py
# Build and run a fresh TrainingEpochLoop once per epoch.
# NOTE(review): ``num_epochs``, ``TrainingEpochLoop``, ``model``, ``optimizer``
# and ``dataloader`` are not defined in this snippet — assumed to come from
# surrounding code. ``epoch`` itself is unused inside the loop body.
for epoch in range(num_epochs):
    TrainingEpochLoop(model, optimizer, dataloader).run()
@tchaton
tchaton / lightning_loops.py
Created November 2, 2021 10:08
lightning_loops.py
# FitLoop
for epoch in range(max_epochs):
# TrainingEpochLoop
for batch_idx, batch in enumerate(train_dataloader):
# TrainingBatchLoop
for split_batch in tbptt_split(batch):
if lightning_module.automatic_optimization: