Skip to content

Instantly share code, notes, and snippets.

@PaulEmmanuelSotir
Created June 8, 2020 20:47
Show Gist options
  • Save PaulEmmanuelSotir/10264acbea550776dd23582e9af734a3 to your computer and use it in GitHub Desktop.
Save PaulEmmanuelSotir/10264acbea550776dd23582e9af734a3 to your computer and use it in GitHub Desktop.
Logs special tags with Git repository information to MLflow. This turns out to be useful when running experiments without the `mlflow run` CLI, because it gives runs better source information in the MLflow web UI.
import re
import logging
import git
import mlflow
import configparser
def _to_https_remote_url(url: str) -> str:
    """Normalize a Git remote URL into a browsable HTTPS repository URL.

    - SCP-style SSH remotes (``git@host:owner/repo``) become ``https://host/owner/repo``.
    - Credentials embedded in HTTP(S) remotes (``https://user:token@host/...``) are removed
      so that they are not persisted into MLflow tags/params.
    - A single trailing ``.git`` suffix is removed.

    An empty string is returned unchanged.
    """
    url = re.sub(r'^git@([.\w-]+):', r'https://\1/', url)
    url = re.sub(r'^(https?://)[^/@]+@', r'\1', url)
    # Not `url.rstrip('.git')`: rstrip removes any trailing run of the characters '.', 'g', 'i', 't',
    # which mangles repository names such as 'my-digit' -> 'my-d'.
    if url.endswith('.git'):
        url = url[:-len('.git')]
    return url


def _git_user_tag(repo) -> str:
    """Return the Git ``user.name``/``user.email`` of *repo* formatted as a single user string.

    Returns ``'name <email>'``, ``'name'``, ``'email'`` or ``''`` depending on which values are configured.
    """
    config_reader = repo.config_reader()
    config_reader.read()
    user = config_reader.get_value('user', 'name', default=None)
    email = config_reader.get_value('user', 'email', default=None)
    if user:
        return str(user) + (f' <{email}>' if email else '')
    return str(email) if email else ''


def mlflow_log_repo_tags(default_project_name: str, entry_point: str = __file__, path_in_repository: str = __file__):
    """Set MLflow system tags describing the Git repository the code is running from.

    MLflow sets these tags itself only when a run is started through the ``mlflow run`` CLI
    (see `mlflow.projects._create_run <https://www.mlflow.org/docs/latest/_modules/mlflow/projects.html>`_),
    not when a run is started from code with ``mlflow.start_run``. Calling this function fills in
    the source name/type, entry point, repository URL, branch, commit and user so that the MLflow
    web UI can display them. It also logs a ``commit_url`` param when a Git remote is known.

    Git lookups are best-effort: if the repository cannot be found or queried, a warning is logged
    and only the non-Git tags are set.

    Note: ``mlflow.log_param``/``mlflow.set_tags`` create a new active run when none is active.
    Depending on your needs, return early when ``mlflow.active_run() is None``.

    Args:
        default_project_name: Source-name tag value used when no Git remote URL is available.
        entry_point: Value of the project entry-point tag (defaults to this module's file).
        path_in_repository: A path inside the Git repository; parent directories are searched for
            the repository root (defaults to this module's file).
    """
    tag_keys = mlflow.utils.mlflow_tags
    tags = {
        tag_keys.MLFLOW_SOURCE_NAME: default_project_name,
        tag_keys.MLFLOW_SOURCE_TYPE: mlflow.entities.SourceType.to_string(mlflow.entities.SourceType.PROJECT),
        tag_keys.MLFLOW_PROJECT_ENTRY_POINT: entry_point,
    }
    try:
        repo = git.Repo(path_in_repository, search_parent_directories=True)
        if 'origin' in repo.remotes:
            raw_url = repo.remote().url  # `Repo.remote()` defaults to the remote named 'origin'
        elif repo.remotes:
            raw_url = repo.remotes[0].url
        else:
            raw_url = ''
        git_repo_url = _to_https_remote_url(raw_url)
        commit_sha = repo.head.commit.hexsha
        if git_repo_url:
            # A commit link is only meaningful when there is a remote to point at
            mlflow.log_param('commit_url', git_repo_url + f'/commit/{commit_sha}/')
        # MLFLOW_SOURCE_NAME is also set to the repository URL so that the MLflow web UI can parse it and
        # render commit and source hyperlinks (MLflow only supported GitHub URLs when this was written)
        tags.update({
            tag_keys.MLFLOW_SOURCE_NAME: git_repo_url or default_project_name,
            tag_keys.MLFLOW_GIT_REPO_URL: git_repo_url,
            tag_keys.MLFLOW_GIT_COMMIT: commit_sha,
        })
        # `Repo.active_branch` raises TypeError on a detached HEAD (e.g. CI checkouts of a single commit);
        # TypeError is not handled below, so omit the branch tag in that case instead of crashing.
        if not repo.head.is_detached:
            tags[tag_keys.MLFLOW_GIT_BRANCH] = repo.active_branch.name
        # Attribute the run to the Git user instead of the system user, when a Git user is configured
        git_user = _git_user_tag(repo)
        if git_user:
            tags[tag_keys.MLFLOW_USER] = git_user
    except (ImportError, OSError, ValueError, KeyError, git.GitError, configparser.Error) as e:
        logging.warning(f'Failed to import Git or to get repository informations. Error: {e}')
    mlflow.set_tags(tags)
if __name__ == '__main__':
    # Manual smoke test when this file is executed directly as a script
    mlflow_log_repo_tags(default_project_name='TestProject')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment