Skip to content

Instantly share code, notes, and snippets.

@stellaraccident
Last active March 28, 2024 03:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stellaraccident/83f91c7316ea668d59e0718e179e2cfd to your computer and use it in GitHub Desktop.
# Copied from: https://github.com/pytorch/examples/blob/main/mnist/main.py
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.export import export
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from shark_turbine import aot
from iree.compiler.ir import Context
from iree import runtime as ireert
from turbine_models.custom_models.sd_inference import utils
from turbine_models.model_runner import vmfbRunner
from torch._functorch.aot_autograd import aot_export_module
class Net(nn.Module):
    """Small MNIST CNN (two conv layers, two dropouts, two linear layers).

    forward() returns per-class log-probabilities (log_softmax over 10
    digit classes).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor):
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head on the flattened feature map.
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over `train_loader`.

    Logs progress every `args.log_interval` batches; when `args.dry_run`
    is set, stops after the first logged batch.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    num_batches = len(train_loader)
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        log_probs = model(inputs)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()
        if step % args.log_interval == 0:
            seen = step * len(inputs)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, dataset_size,
                100. * step / num_batches, batch_loss.item()))
            if args.dry_run:
                break
def train_func(data, target):
    """One training step against the module-level `model`/`optimizer` globals
    (set up by main()); returns (loss, model output).

    NOTE: zero_grad() is deliberately commented out in this experiment, so
    gradients accumulate across calls.
    """
    global args, model, device, optimizer
    data, target = data.to(device), target.to(device)
    #optimizer.zero_grad()
    predictions = model(data)
    step_loss = F.nll_loss(predictions, target)
    step_loss.backward()
    optimizer.step()
    return step_loss, predictions
    #optimizer.step()
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
    """Experiment entry point.

    Sets up standard MNIST training state, then traces one full training
    step (forward + backward + optimizer.step) with aot_export_module,
    wraps it via punktorch.TraningInferenceModule, exports it with
    torch.export, and compiles the resulting MLIR to an IREE vmfb.
    Much of the body is commented-out scaffolding from earlier attempts
    (see the trailing Notes at the bottom of the file).
    """
    # Shared with train_func(), which reads these as module-level globals.
    global args, model, device, optimizer
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()
    torch.manual_seed(args.seed)
    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)
    # Standard MNIST mean/std normalization constants.
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    model = Net().to(device)
    # optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    optimizer = optim.SGD(model.parameters(), lr=5e-5)
    # optimizer = torch.optim.Adam(model.parameters())
    # NOTE(review): scheduler is created but never stepped (the training
    # loop below is commented out).
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    # Peek at one batch to confirm shapes/dtypes before tracing.
    for batch_idx, (data, target) in enumerate(train_loader):
        print(data.shape, data.dtype)
        print(target.shape, target.dtype)
        break
    print("DEBUG T0")

    # Thin nn.Module wrapper so the global-state train_func can be traced.
    class TrainModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, x, y):
            return train_func(x, y)
    train_model = TrainModel()
    # Earlier CompiledModule-based export attempt, disabled by turning it
    # into a (no-op) string literal.
    """class CompiledTraining(CompiledModule):
params = export_parameters(train_model, mutable=True)
def main(
self,
data=AbstractTensor(*[64,1,28,28], dtype=torch.float32),
target=AbstractTensor(*[64], dtype=torch.int64),
):
return jittable(train_func)(
data, target
)
print("DEBUG T1")
#import_to = "INPUT" if args.compile_to == "linalg" else "IMPORT"
inst = CompiledTraining(context=Context(), import_to="IMPORT")
print(str(CompiledModule.get_mlir_module(inst)))"""
    # Example inputs: a batch of 64 MNIST-shaped images plus int64 labels.
    example_args = (torch.randn(64,1,28,28),torch.ones([64], dtype=torch.int64))
    ea = torch.randn(64,1,28,28)
    # Flattened parameter tensors matching TrainNet's layers (conv1 w/b,
    # conv2 w/b, fc1 w/b, fc2 w/b) followed by the image batch and labels.
    # Currently unused (see the commented "EXTENDED ARGS" print below).
    extended_args = (
        torch.randn(32,1,3,3),
        torch.randn(32),
        torch.randn(64,32,3,3),
        torch.randn(64),
        torch.randn(128,9216),
        torch.randn(128),
        torch.randn(10,128),
        torch.randn(10),
        ea,
        torch.ones([64],dtype=torch.int64)
    )

    # Same architecture as Net, but forward() also performs the optimizer
    # step so the entire training step can be exported as one graph.
    # NOTE(review): this reads the enclosing `optimizer`, which was built
    # over `model`'s parameters, not TrainNet's — presumably intentional
    # for this experiment; verify.
    class TrainNet(nn.Module):
        def __init__(self):
            super(TrainNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)
        def forward(self, x: torch.Tensor, target: torch.Tensor):
            optimizer.zero_grad()
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            loss = F.nll_loss(output, target)
            optimizer.step()
            return (loss, output.detach())
    train_net = TrainNet()
    train_net.train()
    original_example_args = (ea, torch.ones([64], dtype=torch.int64))
    # Trace forward and backward jointly; output index 0 is the loss.
    fx_g, signature = aot_export_module(train_net, original_example_args, trace_joint=True, output_loss_index=0)
    print(fx_g)
    print(signature)
    import punktorch
    train_net.requires_grad_(False)
    # Wrap the joint graph so parameters become mutable buffers that each
    # forward() call updates in place.
    bwd_mod = punktorch.TraningInferenceModule(train_net, fx_g, signature)
    example_outputs = bwd_mod.forward(*original_example_args)
    print(example_outputs)
    #print("EXTENDED ARGS:", extended_args)
    a = torch.export.export(bwd_mod, original_example_args)
    print(type(a), a)
    #import IPython; IPython.embed()
    eo = aot.export(a)
    print(eo.mlir_module())
    #torch.export.save(a, "exported_program.pt2")
    print("DEBUG T2")
    #f = open("fixed.mlir")
    #mlir = f.read()
    #f.close()
    # Compile the exported MLIR to an IREE vmfb for local CPU execution.
    vmfb_path = utils.compile_to_vmfb(
        #str(CompiledModule.get_mlir_module(inst)),
        eo.mlir_module(),
        "cpu",
        "x86_64-linux-gnu", # target_triple,
        "",
        "safe_name",
        return_path=True,
        const_expr_hoisting=True
    )
    print("DEBUG T3")
    #runner = vmfbRunner("local-task", vmfb_path)
    print("DEBUG T3")
    # Disabled eager-mode training/eval loop kept for reference.
    #model.train()
    #for epoch in range(1, args.epochs + 1):
    #    for batch_idx, (data, target) in enumerate(train_loader):
    #        pass
    #        #loss = train_func(data, target)
    #        #inputs = [ireert.asdevicearray(runner.config.device, data), ireert.asdevicearray(runner.config.device, target)]
    #        #results = runner.ctx.modules.compiled_training["main"](*inputs)
    #        #optimizer.step()
    #    #train(args, model, device, train_loader, optimizer, epoch)
    #    #test(model, device, test_loader)
    #    #scheduler.step()
    #if args.save_model:
    #    torch.save(model.state_dict(), "mnist_cnn.pt")
    print("DONE!")
# Standard script entry guard.
if __name__ == '__main__':
    main()

# Notes
# fails to inline train_func, always
# optimizer.step complains about mutable tensors
# can't return gradients
# can get loss
import torch
import torch.nn as nn
class TraningInferenceModule(nn.Module):
    """Transforms a functorch traced backwards trace into an inference module.

    This wraps the original training module and each call to forward() will
    read current parameters and apply gradient updates to them.
    """
    def __init__(
        self,
        main_module: nn.Module,
        backwards_trace: nn.Module,
        # String annotation: avoids eagerly dereferencing a private torch
        # submodule that `import torch` is not guaranteed to have imported.
        graph_signature: "torch._functorch._aot_autograd.schemas.GraphSignature",
    ):
        """Snapshot main_module's parameters as mutable (non-persistent) buffers.

        Args:
            main_module: original training module whose parameters are copied.
            backwards_trace: joint forward/backward graph module (e.g. from
                aot_export_module with trace_joint=True).
            graph_signature: signature mapping trace inputs/outputs to
                parameters, user values, and gradients.
        """
        super().__init__()
        # Every original module parameter gets promoted to a buffer and the
        # mapping stashed in the parameter_buffers dict (so we can access
        # by original name).
        # TODO: For best ergonomics, we should create a module structure that
        # mirrors the original so that buffer names have the same name as the
        # original parameters (vs having "." replaced by "_").
        self.param_buffers = nn.Module()
        self.parameter_buffers = {}
        for param_name, param in main_module.named_parameters():
            buffer_name = param_name.replace(".", "_")
            # detach().clone() is the documented way to copy an existing
            # tensor; torch.tensor(param) produces the same detached copy
            # but emits a copy-construct UserWarning.
            buffer_value = param.detach().clone()
            self.parameter_buffers[param_name] = buffer_value
            self.param_buffers.register_buffer(
                buffer_name, buffer_value, persistent=False)
        self.backwards_trace = backwards_trace
        self.graph_signature = graph_signature
    def forward(self, *user_inputs):
        """Run one traced step.

        Returns the user outputs and writes the gradient outputs back into
        the corresponding parameter buffers in place. Raises AssertionError
        for signature features not yet implemented (buffer lifting, mutated
        user inputs).
        """
        gs = self.graph_signature
        bs = gs.backward_signature
        all_params = self.parameter_buffers
        assert all(name in all_params for name in gs.parameters)
        assert len(gs.buffers) == 0, "NYI: Buffer lifting"
        assert len(gs.inputs_to_buffers) == 0, "NYI: Input buffer lifting"
        assert len(gs.buffers_to_mutate) == 0, "NYI: Output buffer lifting"
        assert len(gs.user_inputs_to_mutate) == 0, "NYI: Mutated user inputs"
        # When invoking the functionalized backwards graph, the argument order
        # is:
        #   gs.inputs_to_parameters
        #   ... probably gs.inputs_to_buffers ...
        #   user_inputs
        # The graph signature has an in_spec pytree which specifies some
        # structure but it is unclear exactly how this flows yet.
        parameter_inputs = [all_params[name] for name in gs.inputs_to_parameters.values()]
        trace_inputs = (*parameter_inputs, *user_inputs)
        trace_outputs = list(self.backwards_trace.forward(*trace_inputs))
        # Now process the outputs by making mutations as directed by the
        # graph signature.
        # The order here seems to be the reverse of the input order convention.
        # Because why the ____ not.
        #   user_outputs
        #   ... probably buffers_to_mutate, user_inputs_to_mutate ...
        #   gradients_to_parameters
        #   ... probably gradients_to_user_inputs ...
        def shift_outputs(n: int) -> tuple:
            # Split off the first n remaining outputs: (taken, rest).
            # Reads/requires the enclosing trace_outputs, which the caller
            # rebinds to `rest` after each call.
            shifted = trace_outputs[0:n]
            rest = trace_outputs[n:]
            return shifted, rest
        user_outputs, trace_outputs = shift_outputs(len(gs.user_outputs))
        parameter_updates, trace_outputs = shift_outputs(len(bs.gradients_to_parameters))
        for update, param_name in zip(parameter_updates, bs.gradients_to_parameters.values()):
            param = all_params[param_name]
            # In-place write so the registered buffers see the new values
            # (the original's update[:] slice made a needless extra copy).
            param.copy_(update)
        assert len(trace_outputs) == 0, "Backward trace outputs remaining (likely unimplemented lifting)"
        return user_outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment