Last active
March 19, 2021 14:52
-
-
Save awaiskaleem/a52dee3a569102085d145d8571c55b8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import mlflow | |
# Module-scoped logger (named after this module) used to report
# non-fatal tracking failures as warnings.
logger = logging.getLogger(name=__name__)
class BaseMLLogger:
    """
    Base class for tracking experiments.

    This class can be extended to implement custom logging backends like MLFlow,
    Tensorboard, or Sacred. The ``init_experiment`` and ``log_*`` hooks are
    abstract here: they raise ``NotImplementedError`` until a subclass
    overrides them.
    """

    # Class-level switch. Because it lives on the class, setting it via
    # ``cls.disable_logging = True`` affects every instance of that class;
    # backends are expected to skip all tracking calls while it is True.
    disable_logging = False

    def __init__(self, tracking_uri, **kwargs):
        """
        Store the tracking location for later use by ``init_experiment``.

        :param tracking_uri: Location of the tracking backend (e.g. a server URI).
        :param kwargs: Extra options; accepted for caller convenience and
            ignored by the base class.
        """
        # No banner is printed here: the WELCOME_BARN constant previously
        # printed is not defined or imported in this module (NameError).
        self.tracking_uri = tracking_uri

    def init_experiment(self, experiment_name, run_name=None, nested=True):
        """
        Select/create ``experiment_name`` and start a run in the backend.

        Signature matches ``MLFlowLogger.init_experiment``.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    @classmethod
    def log_metrics(cls, metrics, step):
        """
        Record ``metrics`` for training step ``step``.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    @classmethod
    def log_artifacts(cls, dir_path, artifact_path=None):
        """
        Record the files under the local directory ``dir_path`` as artifacts,
        optionally under ``artifact_path`` in the backend. Subclasses define
        the exact semantics.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    @classmethod
    def log_params(cls, params):
        """
        Record the run's (hyper)parameters ``params``.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()
class MLFlowLogger(BaseMLLogger):
    """
    Logger for MLFlow experiment tracking.

    ``init_experiment`` raises if MLFlow reports a ``ConnectionError``.
    The ``log_*`` methods are best-effort: any failure is reported as a
    warning on the module logger instead of being raised. All tracking calls
    become no-ops once ``disable()`` has been called.
    """

    def init_experiment(self, experiment_name, run_name=None, nested=True):
        """
        Point MLFlow at ``self.tracking_uri``, select the experiment and start a run.

        :param experiment_name: Passed to ``mlflow.set_experiment``.
        :param run_name: Optional run name, passed to ``mlflow.start_run``.
        :param nested: Passed to ``mlflow.start_run(nested=...)``.
        :raises Exception: If MLFlow raises ``ConnectionError`` during setup;
            the original error is chained as ``__cause__``.
        """
        if self.disable_logging:
            return
        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(experiment_name)
            mlflow.start_run(run_name=run_name, nested=nested)
        except ConnectionError as e:
            # NOTE(review): this is the builtin ConnectionError. Network errors
            # surfaced by MLFlow's HTTP client may use other exception types
            # and would propagate unchanged; confirm against the mlflow version.
            raise Exception(
                f"MLFlow cannot connect to the remote server at {self.tracking_uri}.\n"
                "MLFlow also supports logging runs locally to files. Set the MLFlowLogger "
                "tracking_uri to an empty string to use that."
            ) from e

    @classmethod
    def log_metrics(cls, metrics, step):
        """
        Log ``metrics`` at ``step`` to the active MLFlow run (best-effort).

        :param metrics: Forwarded to ``mlflow.log_metrics``.
        :param step: Training step, forwarded as ``step=``.
        """
        if cls.disable_logging:
            return
        try:
            mlflow.log_metrics(metrics, step=step)
        except ConnectionError:
            logger.warning("ConnectionError in logging metrics to MLFlow.")
        except Exception as e:
            # Deliberately swallowed: metric logging must not abort the caller.
            logger.warning("Failed to log metrics: %s", e)

    @classmethod
    def log_params(cls, params):
        """
        Log ``params`` to the active MLFlow run (best-effort).

        :param params: Forwarded to ``mlflow.log_params``.
        """
        if cls.disable_logging:
            return
        try:
            mlflow.log_params(params)
        except ConnectionError:
            logger.warning("ConnectionError in logging params to MLFlow")
        except Exception as e:
            # Deliberately swallowed: param logging must not abort the caller.
            logger.warning("Failed to log params: %s", e)

    @classmethod
    def log_artifacts(cls, dir_path, artifact_path=None):
        """
        Log the local directory ``dir_path`` as artifacts (best-effort).

        :param dir_path: Local directory, forwarded to ``mlflow.log_artifacts``.
        :param artifact_path: Optional destination path within the run's artifacts.
        """
        if cls.disable_logging:
            return
        try:
            mlflow.log_artifacts(dir_path, artifact_path)
        except ConnectionError:
            logger.warning("ConnectionError in logging artifacts to MLFlow")
        except Exception as e:
            # Deliberately swallowed: artifact upload must not abort the caller.
            logger.warning("Failed to log artifacts: %s", e)

    @classmethod
    def end_run(cls):
        """End the active MLFlow run. Unlike the ``log_*`` methods, errors propagate."""
        if cls.disable_logging:
            return
        mlflow.end_run()

    @classmethod
    def disable(cls):
        """Turn off all subsequent tracking calls for this logger class."""
        logger.warning("ML Logging is turned off. No parameters, metrics or artifacts will be logged to MLFlow.")
        # Set on the class (not an instance) so existing and future instances see it.
        cls.disable_logging = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment