@alexbeach-bc
Created February 6, 2024 02:27
torch_demo.py

import os
import typing
from dataclasses import dataclass
from typing import Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch import distributed as dist
from torch import nn, optim
from torchvision import datasets, transforms

import flytekit
from dataclasses_json import dataclass_json
from flytekit import ImageSpec, Resources, task, workflow
from flytekit.types.directory import TensorboardLogs
from flytekit.types.file import PNGImageFile, PythonPickledFile
from flytekitplugins.kfpytorch import PyTorch, Worker
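
# WORLD_SIZE is normally injected into each replica's environment by the
# Kubeflow training operator (along with RANK, MASTER_ADDR, and MASTER_PORT);
# it defaults to 1 for local runs.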
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))

custom_image = ImageSpec(
    name="flyte-kfpytorch-plugin",
    packages=["torch", "torchvision", "flytekitplugins-kfpytorch", "matplotlib", "tensorboardX"],
    registry="xxxxxxxxxxxx",
    cuda="11.2.2",
    cudnn="8",
    python_version="3.10",
)
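# ImageSpec builds and pushes this image when the workflow is registered; the
# registry value above is a placeholder and must point at a registry the
# cluster can pull from.
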
cpu_request = "500m"
mem_request = "4Gi"
gpu_request = "1"
mem_limit = "8Gi"
gpu_limit = "1"


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
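

# A minimal shape sanity check (not part of the original gist): a 28x28 MNIST
# input flows 28 -> 24 -> 12 -> 8 -> 4 through the two conv/pool stages, which
# is why fc1 expects 4 * 4 * 50 input features.
def _check_net_shapes() -> None:
    out = Net()(torch.randn(1, 1, 28, 28))
    assert out.shape == (1, 10)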


def train(model, device, train_loader, optimizer, epoch, writer, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar("loss", loss.item(), niter)


def test(model, device, test_loader, writer, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = float(correct) / len(test_loader.dataset)
    print("\naccuracy={:.4f}\n".format(accuracy))
    writer.add_scalar("accuracy", accuracy, epoch)
    return accuracy


def epoch_step(model, device, train_loader, test_loader, optimizer, epoch, writer, log_interval):
    train(model, device, train_loader, optimizer, epoch, writer, log_interval)
    return test(model, device, test_loader, writer, epoch)


def should_distribute():
    return dist.is_available() and WORLD_SIZE > 1


def is_distributed():
    return dist.is_available() and dist.is_initialized()


@dataclass_json
@dataclass
class Hyperparameters(object):
    """
    Args:
        backend: distributed backend to use (default: gloo)
        sgd_momentum: SGD momentum (default: 0.5)
        seed: random seed (default: 1)
        log_interval: number of batches to wait before logging training status
        batch_size: input batch size for training (default: 64)
        test_batch_size: input batch size for testing (default: 1000)
        epochs: number of epochs to train (default: 10)
        learning_rate: learning rate (default: 0.01)
    """

    backend: str = dist.Backend.GLOO
    sgd_momentum: float = 0.5
    seed: int = 1
    log_interval: int = 10
    batch_size: int = 64
    test_batch_size: int = 1000
    epochs: int = 10
    learning_rate: float = 0.01


TrainingOutputs = typing.NamedTuple(
    "TrainingOutputs",
    epoch_accuracies=typing.List[float],
    model_state=PythonPickledFile,
    logs=TensorboardLogs,
)
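
# Flyte offloads these typed outputs (the pickled model state and the
# TensorBoard log directory) to the configured blob store when the task
# finishes, so downstream tasks and the console can retrieve them.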


# task_config=PyTorch(worker=Worker(replicas=2)) asks the Kubeflow training
# operator to run this task as a PyTorchJob with a master plus two worker
# replicas, each executing this same function.
@task(
    task_config=PyTorch(worker=Worker(replicas=2)),
    retries=2,
    cache=True,
    cache_version="0.1",
    requests=Resources(cpu=cpu_request, mem=mem_request, gpu=gpu_request),
    limits=Resources(mem=mem_limit, gpu=gpu_limit),
    container_image=custom_image,
)
def mnist_pytorch_job(hp: Hyperparameters) -> TrainingOutputs:
    log_dir = os.path.join(flytekit.current_context().working_directory, "logs")
    writer = SummaryWriter(log_dir)

    torch.manual_seed(hp.seed)
    use_cuda = torch.cuda.is_available()  # fall back to CPU when no GPU is visible
    print(f"Use cuda {use_cuda}")
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

    if should_distribute():
        print("Using distributed PyTorch with {} backend".format(hp.backend))
        dist.init_process_group(backend=hp.backend)

    # Load data
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            os.path.join(flytekit.current_context().working_directory, "data"),
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
        ),
        batch_size=hp.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            os.path.join(flytekit.current_context().working_directory, "data"),
            train=False,
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
        ),
        batch_size=hp.test_batch_size,
        shuffle=False,
        **kwargs,
    )

    # Train the model
    model = Net().to(device)
    if is_distributed():
        # DistributedDataParallel handles both CPU and GPU training; the older
        # DistributedDataParallelCPU class was removed from PyTorch.
        model = nn.parallel.DistributedDataParallel(model)
    optimizer = optim.SGD(model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum)
    accuracies = [
        epoch_step(
            model,
            device,
            train_loader,
            test_loader,
            optimizer,
            epoch,
            writer,
            hp.log_interval,
        )
        for epoch in range(1, hp.epochs + 1)
    ]

    # Save the model
    model_file = os.path.join(flytekit.current_context().working_directory, "mnist_cnn.pt")
    torch.save(model.state_dict(), model_file)
    return TrainingOutputs(
        epoch_accuracies=accuracies,
        model_state=PythonPickledFile(model_file),
        logs=TensorboardLogs(log_dir),
    )


@task(container_image=custom_image)
def plot_accuracy(epoch_accuracies: typing.List[float]) -> PNGImageFile:
    plt.plot(epoch_accuracies)
    plt.title("Accuracy")
    plt.ylabel("accuracy")
    plt.xlabel("epoch")
    accuracy_plot = os.path.join(flytekit.current_context().working_directory, "accuracy.png")
    plt.savefig(accuracy_plot)
    return PNGImageFile(accuracy_plot)


@workflow
def pytorch_training_job(
    hp: Hyperparameters = Hyperparameters(epochs=2, batch_size=128),
) -> Tuple[PythonPickledFile, PNGImageFile, TensorboardLogs]:
    accuracies, model, logs = mnist_pytorch_job(hp=hp)
    plot = plot_accuracy(epoch_accuracies=accuracies)
    return model, plot, logs


if __name__ == "__main__":
    model, plot, logs = pytorch_training_job()
    print(f"Model: {model}, plot PNG: {plot}, Tensorboard Log Dir: {logs}")