@phillies
Last active January 16, 2023 15:11
Example of MLflow emitting warnings when storing PyTorch models. The script below trains a small MLP on random data and logs the model to the tracking server after every epoch, which reproduces the warnings.
import logging

import mlflow
import torch
from torch import nn


class MLP(nn.Module):
    """Small fully connected network used to reproduce the logging warnings."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(633, 1024),
            nn.RReLU(),
            nn.Linear(1024, 128),
            nn.RReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        return self.layers(x)


mlflow.set_tracking_uri("http://host.docker.internal:5000")
experiment = mlflow.set_experiment("Test")

with mlflow.start_run():
    # Raise MLflow's log level so all messages emitted by log_model are visible.
    logging.getLogger("mlflow").setLevel(logging.DEBUG)

    torch.manual_seed(1337)
    mlp = MLP()
    mlflow.log_param("torch_seed", 1337)

    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
    mlflow.log_param("lr", 1e-4)

    # Random dummy data; the whole dataset is used as a single batch.
    inputs = torch.rand((1000, 633)).float()
    targets = torch.rand((1000, 2)).float()
    mlflow.log_param("batch_size", 1000)

    # Infer the model signature from example input/output arrays.
    model_signature = mlflow.models.signature.infer_signature(
        inputs.detach().numpy(), targets.detach().numpy()
    )

    # First log_model call: stores the untrained model.
    mlflow.pytorch.log_model(mlp, "model", signature=model_signature)

    for epoch in range(20):
        mlp.train()
        optimizer.zero_grad()
        outputs = mlp(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        # Logging the model again under the same artifact path after each
        # epoch reproduces the warnings this gist is about.
        mlflow.pytorch.log_model(mlp, "model", signature=model_signature)
        print(f"epoch {epoch}, loss {loss.item():.4f}")