# Copied from: https://github.com/pytorch/examples/blob/main/mnist/main.py
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.export import export
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from shark_turbine import aot
from iree.compiler.ir import Context
from iree import runtime as ireert
from turbine_models.custom_models.sd_inference import utils
from turbine_models.model_runner import vmfbRunner
from torch._functorch.aot_autograd import aot_export_module
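
# Overall flow in main() below (experimental):
#   1. Define TrainNet, a copy of Net whose forward() also computes the NLL
#      loss and calls the optimizer, so one training step is a single graph.
#   2. aot_export_module(..., trace_joint=True, output_loss_index=0) traces the
#      joint forward/backward graph and returns it with its GraphSignature.
#   3. punktorch.TraningInferenceModule (second file below) wraps that graph so
#      parameters become mutable buffers and gradient updates are applied on
#      each call.
#   4. torch.export.export -> shark_turbine aot.export produces MLIR, which
#      utils.compile_to_vmfb compiles to an IREE vmfb for CPU.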


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def train_func(data, target):
    global args, model, device, optimizer
    data, target = data.to(device), target.to(device)
    # optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    return loss, output
    # optimizer.step()


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    global args, model, device, optimizer
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    # optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    optimizer = optim.SGD(model.parameters(), lr=5e-5)
    # optimizer = torch.optim.Adam(model.parameters())
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Peek at one batch to confirm the input shapes/dtypes.
    for batch_idx, (data, target) in enumerate(train_loader):
        print(data.shape, data.dtype)
        print(target.shape, target.dtype)
        break
    print("DEBUG T0")

    class TrainModel(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return train_func(x, y)

    train_model = TrainModel()

    # An alternative path using shark_turbine's CompiledModule / jittable API,
    # disabled (kept as a string for reference).
    """class CompiledTraining(CompiledModule):
        params = export_parameters(train_model, mutable=True)

        def main(
            self,
            data=AbstractTensor(*[64, 1, 28, 28], dtype=torch.float32),
            target=AbstractTensor(*[64], dtype=torch.int64),
        ):
            return jittable(train_func)(
                data, target
            )

    print("DEBUG T1")
    # import_to = "INPUT" if args.compile_to == "linalg" else "IMPORT"
    inst = CompiledTraining(context=Context(), import_to="IMPORT")
    print(str(CompiledModule.get_mlir_module(inst)))"""

    example_args = (torch.randn(64, 1, 28, 28), torch.ones([64], dtype=torch.int64))
    ea = torch.randn(64, 1, 28, 28)
    # Tensors shaped like Net's parameters (conv1/conv2/fc1/fc2 weights and
    # biases) followed by an input batch and targets; currently unused except
    # for `ea` (see the commented-out print below).
    extended_args = (
        torch.randn(32, 1, 3, 3),   # conv1.weight
        torch.randn(32),            # conv1.bias
        torch.randn(64, 32, 3, 3),  # conv2.weight
        torch.randn(64),            # conv2.bias
        torch.randn(128, 9216),     # fc1.weight
        torch.randn(128),           # fc1.bias
        torch.randn(10, 128),       # fc2.weight
        torch.randn(10),            # fc2.bias
        ea,                         # input batch
        torch.ones([64], dtype=torch.int64),  # targets
    )

    class TrainNet(nn.Module):
        # A copy of Net with the NLL loss (and the optimizer calls) folded into
        # forward() so that a whole training step is visible to the tracer.
        def __init__(self):
            super(TrainNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x: torch.Tensor, target: torch.Tensor):
            optimizer.zero_grad()
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            loss = F.nll_loss(output, target)
            optimizer.step()
            return (loss, output.detach())

    train_net = TrainNet()
    train_net.train()

    original_example_args = (ea, torch.ones([64], dtype=torch.int64))
    # Trace the joint forward/backward graph; output index 0 (the loss) is the
    # value the backward pass is taken with respect to.
    fx_g, signature = aot_export_module(
        train_net, original_example_args, trace_joint=True, output_loss_index=0)
    print(fx_g)
    print(signature)
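    # The printed GraphSignature describes how the lifted graph's flat
    # argument/result lists map back onto the module: fields like `parameters`,
    # `inputs_to_parameters`, `user_inputs`, `user_outputs`, and (for the joint
    # trace) `backward_signature.gradients_to_parameters`. These are the fields
    # TraningInferenceModule below relies on; exact contents vary by PyTorch
    # version.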

    import punktorch

    # Wrap the joint graph: parameters are promoted to mutable buffers and the
    # gradient updates are applied on each forward() call (see the punktorch
    # module below).
    train_net.requires_grad_(False)
    bwd_mod = punktorch.TraningInferenceModule(train_net, fx_g, signature)
    example_outputs = bwd_mod.forward(*original_example_args)
    print(example_outputs)
    # print("EXTENDED ARGS:", extended_args)

    a = torch.export.export(bwd_mod, original_example_args)
    print(type(a), a)
    # import IPython; IPython.embed()
    eo = aot.export(a)
    print(eo.mlir_module())
    # torch.export.save(a, "exported_program.pt2")
    print("DEBUG T2")
    # f = open("fixed.mlir")
    # mlir = f.read()
    # f.close()

    vmfb_path = utils.compile_to_vmfb(
        # str(CompiledModule.get_mlir_module(inst)),
        eo.mlir_module(),
        "cpu",
        "x86_64-linux-gnu",  # target_triple
        "",
        "safe_name",
        return_path=True,
        const_expr_hoisting=True,
    )
    print("DEBUG T3")
    # runner = vmfbRunner("local-task", vmfb_path)
    print("DEBUG T3")

    # model.train()
    # for epoch in range(1, args.epochs + 1):
    #     for batch_idx, (data, target) in enumerate(train_loader):
    #         pass
    #         # loss = train_func(data, target)
    #         # inputs = [ireert.asdevicearray(runner.config.device, data), ireert.asdevicearray(runner.config.device, target)]
    #         # results = runner.ctx.modules.compiled_training["main"](*inputs)
    #         # optimizer.step()
    #     # train(args, model, device, train_loader, optimizer, epoch)
    #     # test(model, device, test_loader)
    #     # scheduler.step()
    # if args.save_model:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")
    print("DONE!")


if __name__ == '__main__':
    main()


# Notes:
# - fails to inline train_func, always
# - optimizer.step complains about mutable tensors
# - can't return gradients
# - can get loss


# --- Second file in the gist: the `punktorch` module imported above. ---

import torch
import torch.nn as nn


class TraningInferenceModule(nn.Module):
    """Transforms a functorch-traced backwards graph into an inference module.

    This wraps the original training module; each call to forward() reads the
    current parameters and applies gradient updates to them in place. See the
    usage sketch at the end of this file.
    """

    def __init__(
        self,
        main_module: nn.Module,
        backwards_trace: nn.Module,
        graph_signature: torch._functorch._aot_autograd.schemas.GraphSignature,
    ):
        super().__init__()
        # Every original module parameter gets promoted to a buffer and the
        # mapping stashed in the parameter_buffers dict (so we can access it
        # by original name).
        # TODO: For best ergonomics, we should create a module structure that
        # mirrors the original so that buffer names have the same name as the
        # original parameters (vs having "." replaced by "_").
        self.param_buffers = nn.Module()
        self.parameter_buffers = {}
        for param_name, param in main_module.named_parameters():
            buffer_name = param_name.replace(".", "_")
            buffer_value = torch.tensor(param)
            self.parameter_buffers[param_name] = buffer_value
            self.param_buffers.register_buffer(
                buffer_name, buffer_value, persistent=False)
        self.backwards_trace = backwards_trace
        self.graph_signature = graph_signature

    def forward(self, *user_inputs):
        gs = self.graph_signature
        bs = gs.backward_signature
        all_params = self.parameter_buffers
        assert all(name in all_params for name in gs.parameters)
        assert len(gs.buffers) == 0, "NYI: Buffer lifting"
        assert len(gs.inputs_to_buffers) == 0, "NYI: Input buffer lifting"
        assert len(gs.buffers_to_mutate) == 0, "NYI: Output buffer lifting"
        assert len(gs.user_inputs_to_mutate) == 0, "NYI: Mutated user inputs"
        # When invoking the functionalized backwards graph, the argument order
        # is:
        #   gs.inputs_to_parameters
        #   ... probably gs.inputs_to_buffers ...
        #   user_inputs
        # The graph signature has an in_spec pytree which specifies some
        # structure, but it is unclear exactly how this flows yet.
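        # Illustrative only (an assumption based on the MNIST TrainNet above;
        # the real order comes from inputs_to_parameters): the flat input list
        # would be roughly
        #   conv1.weight, conv1.bias, conv2.weight, conv2.bias,
        #   fc1.weight, fc1.bias, fc2.weight, fc2.bias, data, target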
        parameter_inputs = [all_params[name] for name in gs.inputs_to_parameters.values()]
        trace_inputs = (*parameter_inputs, *user_inputs)
        trace_outputs = list(self.backwards_trace.forward(*trace_inputs))
        # Now process the outputs by making mutations as directed by the
        # graph signature.
        # The order here seems to be the reverse of the input order convention.
        # Because why the ____ not.
        #   user_outputs
        #   ... probably buffers_to_mutate, user_inputs_to_mutate ...
        #   gradients_to_parameters
        #   ... probably gradients_to_user_inputs ...
        def shift_outputs(n: int) -> tuple:
            shifted = trace_outputs[0:n]
            rest = trace_outputs[n:]
            return shifted, rest
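        # Similarly illustrative (an assumption for the MNIST case): the flat
        # output list would be roughly
        #   [loss, detached_output,
        #    grad_conv1_weight, grad_conv1_bias, ..., grad_fc2_bias]
        # with the exact order given by gradients_to_parameters.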
        user_outputs, trace_outputs = shift_outputs(len(gs.user_outputs))
        parameter_updates, trace_outputs = shift_outputs(len(bs.gradients_to_parameters))
        for update, param_name in zip(parameter_updates, bs.gradients_to_parameters.values()):
            param = all_params[param_name]
            param.copy_(update[:])
        assert len(trace_outputs) == 0, "Backward trace outputs remaining (likely unimplemented lifting)"
        return user_outputs
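

# Usage sketch (a minimal, hedged example mirroring main.py above; `train_net`
# and `example_args` stand in for the objects built there):
#
#   from torch._functorch.aot_autograd import aot_export_module
#   fx_g, signature = aot_export_module(
#       train_net, example_args, trace_joint=True, output_loss_index=0)
#   train_net.requires_grad_(False)
#   step = TraningInferenceModule(train_net, fx_g, signature)
#   loss, logits = step(*example_args)  # also mutates the promoted parameter buffers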