You should replace the arguments in pl.Trainer
with Ray Train's implementations.
import pytorch_lightning as pl
+ from ray.train.lightning import (
+     get_devices,
+     prepare_trainer,
+ )
import os
import tempfile
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
accelerate==0.19.0
adal==1.2.7
aiofiles==22.1.0
aiohttp==3.8.5
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.19.0
alabaster==0.7.13
anyio==3.7.1
accelerate==0.19.0
adal==1.2.7
aiofiles==22.1.0
aiohttp==3.8.5
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.19.0
alabaster==0.7.13
anyio==3.7.1
about-time==4.2.1
absl-py==1.4.0
accelerate==0.19.0
adal==1.2.7
aim==3.17.5
aim-ui==3.17.5
aimrecords==0.0.7
aimrocks==0.4.0
aioboto3==11.2.0
aiobotocore==2.5.0
# Minimal Example adapted from https://huggingface.co/docs/transformers/training
import deepspeed
import evaluate
import torch
from datasets import load_dataset
from deepspeed.accelerator import get_accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
import os
import evaluate
import numpy as np
from datasets import load_dataset
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, Checkpoint
from ray.train.torch import TorchTrainer
from transformers import AutoTokenizer
from transformers import (
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(MNISTClassifier, self).__init__()
        self.accuracy = Accuracy()
        # [!] Determine your data augmentation strategy here
        self.batch_size = config["batch_size"]
        self.aug_strategy = config["augmentation_strategy"]