Skip to content

Instantly share code, notes, and snippets.

@georgepar
Created May 29, 2019 17:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save georgepar/7a7370bc3ccb399444c9d21fe07a6d80 to your computer and use it in GitHub Desktop.
Save georgepar/7a7370bc3ccb399444c9d21fe07a6d80 to your computer and use it in GitHub Desktop.
import mlflow
import mlflow.pytorch
class MlFlowLogger(object):
    """Thin convenience wrapper around the fluent ``mlflow`` tracking API.

    Constructing an instance immediately starts an MLflow run (via
    ``start``) and logs every extra keyword argument as a run parameter.
    Call ``end`` once the run is finished.
    """

    def __init__(self,
                 uri=None,
                 experiment_name=None,
                 model_path='models',
                 **params):
        """Store the configuration and start a run right away.

        Args:
            uri: Tracking URI handed unchanged to ``mlflow.set_tracking_uri``.
                NOTE(review): ``None`` is forwarded as-is; presumably mlflow
                then falls back to its own default -- confirm for the
                installed mlflow version.
            experiment_name: Name of the experiment to activate. When
                ``None`` no experiment is selected explicitly.
            model_path: Path passed to ``mlflow.pytorch.save_model`` by
                ``log_model``.
            **params: Arbitrary key/value pairs logged as run parameters
                as soon as the run starts.
        """
        self.params = params
        self.experiment_name = experiment_name
        self.run = None
        self.uri = uri
        self.model_path = model_path
        self.start()

    def get_or_set_experiment(self):
        """Activate ``self.experiment_name``, creating it if necessary.

        Prints the current tracking URI first. Does nothing else when no
        experiment name was configured.
        """
        print(mlflow.get_tracking_uri())
        if self.experiment_name is None:
            return
        # ``mlflow.set_experiment`` creates the experiment when it does not
        # exist yet, so a separate ``create_experiment`` call is not needed.
        # The previous create-then-catch-everything approach reported *any*
        # failure (unreachable server, bad URI, ...) as "already exists";
        # going through the single call lets real errors surface as-is.
        mlflow.set_experiment(self.experiment_name)

    @staticmethod
    def log_param(k, v):
        """Log a single parameter ``k`` with value ``v`` to the active run."""
        mlflow.log_param(k, v)

    def log_params(self, params=None):
        """Log every entry of ``params`` as a run parameter.

        Args:
            params: Mapping of parameter names to values. Defaults to the
                keyword arguments captured by the constructor.
        """
        if params is None:
            params = self.params
        for k, v in params.items():
            self.log_param(k, v)

    @staticmethod
    def log_metric(k, v):
        """Log a single metric ``k`` with value ``v`` to the active run."""
        mlflow.log_metric(k, v)

    def log_metrics(self, metrics):
        """Log every entry of the ``metrics`` mapping as a run metric."""
        for k, v in metrics.items():
            self.log_metric(k, v)

    def log_model(self, model):
        """Save ``model`` locally at ``self.model_path``.

        Delegates to ``mlflow.pytorch.save_model``.
        """
        mlflow.pytorch.save_model(model, self.model_path)

    def start(self):
        """Configure tracking, select the experiment and open a run.

        The order matters: the tracking URI must be set before the
        experiment is looked up, and the run must be active before the
        constructor parameters are logged.
        """
        mlflow.set_tracking_uri(self.uri)
        self.get_or_set_experiment()
        self.run = mlflow.start_run()
        self.log_params()

    def end(self):
        """End the currently active MLflow run."""
        mlflow.end_run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment