Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created February 28, 2024 19:04
Show Gist options
  • Save msaroufim/15a4b97c3f45cead4b2feb90894ed8d3 to your computer and use it in GitHub Desktop.
Save msaroufim/15a4b97c3f45cead4b2feb90894ed8d3 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Pin global defaults: tensors and modules created without an explicit
# dtype/device use single-precision floats on the CPU.
torch.set_default_dtype(torch.float32)
torch.set_default_device("cpu")
class SimpleNet(nn.Module):
    """Minimal network made of a single fully-connected (linear) layer.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 10,
            the width used by this script.
        out_features: Size of the last dimension of the output. Defaults to 1.
    """

    def __init__(self, in_features: int = 10, out_features: int = 1):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer; ``x``'s last dimension must equal ``in_features``."""
        return self.fc(x)
# Model, optimizer and loss are built at module level rather than inside
# main(): the author reported that torch.compile() fails when they are
# constructed inside the compiled function.
net = SimpleNet()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
# NOTE: nn.MSELoss was reportedly unsupported in fp16 on CPU when this was
# written, so an L1 loss is used instead.
criterion = nn.L1Loss()
def main(input: torch.Tensor, num_steps: int = 128):
    """Train the module-level ``net`` for a fixed number of steps on random targets.

    Every step draws a fresh random target, computes ``criterion`` between the
    network output and that target, backpropagates, and applies one
    ``optimizer`` step. ``net``, ``optimizer`` and ``criterion`` are
    module-level globals, not arguments.

    Args:
        input: Batch fed to ``net`` on every step; its last dimension must match
            the network's input width. (The name shadows the builtin but is kept
            so existing keyword callers keep working.)
        num_steps: Number of optimization steps. Defaults to 128, the count
            previously hard-coded here.

    Returns:
        The detached loss tensor from the final step, or ``None`` if
        ``num_steps`` is 0 or negative.
    """
    loss = None
    for _ in range(num_steps):
        optimizer.zero_grad()
        output = net(input)
        # Target shaped like the output, so each row of a batched input gets
        # its own target instead of one (1, 1) value broadcast over the batch.
        target = torch.randn_like(output)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    return None if loss is None else loss.detach()
if __name__ == "__main__":
    # Replace main with its compiled version. fullgraph=True makes
    # torch.compile raise on any graph break instead of splitting the graph.
    # NOTE(review): whether loss.backward() / optimizer.step() inside main()
    # trace without a graph break depends on the installed torch version —
    # confirm before relying on this.
    main = torch.compile(main, fullgraph=True)
    sample_batch = torch.randn(1, 10)
    main(sample_batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment