Created
May 29, 2024 21:29
-
-
Save ManfeiBai/ad2134d9e6aabf9b6b6f43dba9e0f725 to your computer and use it in GitHub Desktop.
modified
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_while_loop_tpu_MNIST_inside_loop_with_mutation_in_batchnorm2d(self): | |
xm.mark_step() | |
device = xm.xla_device() | |
torch.set_grad_enabled(False) | |
n_epochs = 3 | |
batch_size_train = 8 | |
batch_size_test = 10 | |
learning_rate = 0.01 | |
momentum = 0.5 | |
log_interval = 10 | |
random_seed = 1 | |
torch.backends.cudnn.enabled = False | |
torch.manual_seed(random_seed) | |
class MNIST(torch.nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2) | |
self.bn1 = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False) | |
# self.bn1 = torch.nn.BatchNorm2d(10) | |
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5) | |
self.bn2 = torch.nn.BatchNorm2d(20, affine=False, track_running_stats=False) | |
# self.bn2 = torch.nn.BatchNorm2d(20) | |
self.fc1 = torch.nn.Linear(500, 50) | |
self.fc2 = torch.nn.Linear(50, 10) | |
self.bnLayersWeights = [] | |
def forward(self, iteri, x, y): | |
def cond_fn(iteri, x, y): | |
return iteri > 0 | |
def body_fn(iteri, x, y): | |
y = F.relu(F.max_pool2d(self.conv1(x), 2)) | |
y = self.bn1(y) | |
y = F.relu(F.max_pool2d(self.conv2(y), 2)) | |
y = self.bn2(y) | |
y = torch.flatten(y, 1) | |
y = F.relu(self.fc1(y)) | |
y = self.fc2(y) | |
return iteri - 1, x.clone(), F.log_softmax(y, dim=1) | |
return while_loop(cond_fn, body_fn, (iteri, x, y)) | |
def forward_compare(self, iteri, x, y): | |
y = F.relu(F.max_pool2d(self.conv1(x), 2)) | |
y = self.bn1(y) | |
y = F.relu(F.max_pool2d(self.conv2(y), 2)) | |
y = self.bn2(y) | |
y = torch.flatten(y, 1) | |
y = F.relu(self.fc1(y)) | |
y = self.fc2(y) | |
return iteri - 1, x.clone(), F.log_softmax(y, dim=1) | |
mnist = MNIST() | |
mnist.to(device) | |
bs=16 | |
l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) | |
l_out = torch.randn(bs, 10, dtype=torch.float32, device=device) | |
iteri = torch.tensor(3, dtype=torch.int64, device=device) | |
print("print and check behavior by exporting the model") | |
ep = torch.export.export(mnist, (iteri, l_in_0, l_out)) | |
ep.module().print_readable() | |
print("after print and check behavior by exporting the model") | |
_, _, res = mnist(iteri, l_in_0, l_out) | |
# _, _, res = mnist(iteri, l_in_0) | |
print("res: ", res) | |
# === expected result for one iteration to be compared since body_fn defined use the same input in each iteration === | |
_, _, expected_res = mnist.forward_compare(iteri, l_in_0, l_out) | |
# _, _, expected_res = mnist.forward_compare(iteri, l_in_0) | |
print("expected_res: ", expected_res) | |
self.assertTrue(torch.all(torch.eq(res, expected_res))) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment