from argparse import ArgumentParser
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
# Command-line switches for the benchmark; each is a boolean flag that
# defaults to False when omitted.
parser = ArgumentParser()
for _switch in ("--pytorch_2", "--compile", "--tensorcores"):
    parser.add_argument(_switch, action="store_true")
args = parser.parse_args()

# Mirror the parsed flags into the module-level names used by the rest of the
# script. NOTE: `compile` shadows the builtin of the same name; it is kept
# because later code reads it.
pytorch_2, compile, tensorcores = args.pytorch_2, args.compile, args.tensorcores
# For tensorcore speedup:
# Set the TF32 switches explicitly in both directions so the run does not
# depend on library defaults: enabled with --tensorcores, disabled otherwise.
torch.backends.cuda.matmul.allow_tf32 = tensorcores
torch.backends.cudnn.allow_tf32 = tensorcores

if pytorch_2 and tensorcores:
    # NOTE(review): the body of this branch was missing from the file. The
    # PyTorch >= 1.12 float32-matmul precision API is the presumed intent;
    # confirm against the original script.
    torch.set_float32_matmul_precision("high")
class MLP(nn.Module):
    """Plain fully connected network.

    The network is a stack of ``nn.Linear`` layers: one input layer
    (``n_in -> n_hidden``), ``n_layers - 1`` hidden layers
    (``n_hidden -> n_hidden``) and one output layer (``n_hidden -> n_out``).
    ``activation`` is applied after every layer except the output layer.

    Args:
        n_in: Size of the last dimension of the input.
        n_out: Size of the last dimension of the output.
        n_hidden: Width of the hidden layers.
        n_layers: Number of activated layers before the output layer.
            Values below 1 behave like 1, because the input layer is
            always created.
        activation: Callable applied element-wise between layers.
    """

    def __init__(self, n_in, n_out, n_hidden=128, n_layers=2, activation=F.relu):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise registering `self.layers` raises an AttributeError.
        super().__init__()
        self.activation = activation
        self.layers = nn.ModuleList([nn.Linear(n_in, n_hidden)])
        self.layers.extend([nn.Linear(n_hidden, n_hidden) for _ in range(n_layers - 1)])
        self.layers.append(nn.Linear(n_hidden, n_out))

    def forward(self, x):
        """Return the network output for ``x`` (no activation on the last layer)."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)
class ResidualConnection(nn.Module):
    """Wrap a module so that its input is added to its output.

    Args:
        module: Module whose output can be added to its input; ``forward``
            computes ``x + module(x)``.
    """

    def __init__(self, module):
        # nn.Module must be initialised before a sub-module can be assigned.
        super().__init__()
        self.module = module

    def forward(self, x):
        """Return ``x + self.module(x)``."""
        return x + self.module(x)
class DeepMLP(nn.Module):
    """Input MLP, a chain of residual MLP blocks, and a linear output layer.

    NOTE(review): the constructor and ``forward`` were partly missing from
    the file. The ``self.layers`` attribute name and the residual wrapping of
    the repeated blocks are reconstructed from the surviving fragments and
    from the otherwise unused ``ResidualConnection`` class; confirm against
    the original.

    Args:
        n_in: Size of the last dimension of the input.
        n_out: Size of the last dimension of the output.
        n_hidden: Width used by every MLP block.
        blocks: Number of residual MLP blocks after the input MLP.
        activation: Callable passed to every ``MLP`` as its activation.
    """

    def __init__(self, n_in, n_out, n_hidden=128, blocks=2, activation=F.relu):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.activation = activation
        residual_blocks = [
            ResidualConnection(
                MLP(n_hidden, n_hidden, n_hidden, n_layers=2, activation=activation)
            )
            for _ in range(blocks)
        ]
        self.layers = nn.Sequential(
            MLP(n_in, n_hidden, n_hidden, n_layers=2, activation=activation),
            *residual_blocks,
            nn.Linear(n_hidden, n_out),
        )

    def forward(self, x):
        """Run ``x`` through the input MLP, the residual blocks and the output layer."""
        return self.layers(x)
device = torch.device("cuda")

# Dataset
# Synthetic regression task: learn the element-wise cosine of inputs drawn
# uniformly from [-10, 10). The tensors are created directly on `device`.
N = 1_000_000
m = 100
X = torch.rand(N, m, device=device) * 20 - 10
y = torch.cos(X)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=1024, shuffle=True)

# Model
n_hidden = 256
model = DeepMLP(m, m, n_hidden=n_hidden, blocks=2, activation=F.relu)
# The batches live on `device`, so the parameters must be moved there too.
# NOTE(review): the right-hand side of this assignment was missing from the
# file; `.to(device)` is the presumed original.
model = model.to(device)

# Optimizer
opt = Adam(model.parameters(), lr=1e-3)
def train(model, X_batch, y_batch, optimizer=None):
    """Run one optimisation step on a batch and return the batch loss.

    Args:
        model: Callable mapping ``X_batch`` to predictions; its parameters
            must be the ones managed by the optimizer.
        X_batch: Input batch.
        y_batch: Target batch, compared to the predictions with MSE.
        optimizer: Optimizer to step. When ``None`` (the default) the
            module-level ``opt`` is used.

    Returns:
        The mean-squared-error loss of this batch as a Python float,
        computed before the parameter update.
    """
    if optimizer is None:
        optimizer = opt
    y_pred = model(X_batch)
    # assert y_pred.shape == y_batch.shape
    loss = F.mse_loss(y_pred, y_batch)
    # Without these three calls the loss is only evaluated and the model
    # never learns.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
if compile:
    train = torch.compile(train, mode="reduce-overhead")

losses = []
times = []

# Training:
for epoch in range(10):
    for X_batch, y_batch in loader:
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring short intervals such as a single training step.
        start = time.perf_counter()
        loss = train(model, X_batch, y_batch)
        end = time.perf_counter()
        # Record the loss as well; the epoch summary below reads `losses`.
        losses.append(loss)
        times.append(end - start)
    # Report the median over the last 100 batches of the epoch.
    print(
        f"Epoch {epoch}: loss={np.median(losses[-100:]):.3f}, timing={np.median(times[-100:]):.3e}"
    )
# a100, with tensorcores, with compilation: 2.47e-3
# a100, with tensorcores, without compilation: 4.60e-3
# a100, without tensorcores, with compilation: 2.47e-3
# a100, without tensorcores, without compilation: 4.51e-3
# h100, with tensorcores, with compilation: 2.37e-3
# h100, with tensorcores, without compilation: 4.42e-3
# h100, without tensorcores, with compilation: 2.37e-3
# h100, without tensorcores, without compilation: 5.08e-3
