Created
July 27, 2021 04:29
-
-
Save nkthiebaut/a8fe27d11b041bc1e2f1dbe14888d8af to your computer and use it in GitHub Desktop.
Move and unpack SageMaker models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tarfile | |
import subprocess | |
import os | |
import s3fs | |
# Environment variables that must be set before talking to S3.
REQUIRED_AWS_ENV = {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"}


def setup_aws_fs() -> s3fs.S3FileSystem:
    """Validate the AWS environment and return an S3 filesystem object.

    Returns:
        An ``s3fs.S3FileSystem`` created with ``use_ssl=True``.

    Raises:
        EnvironmentError: If any variable listed in ``REQUIRED_AWS_ENV`` is
            absent from the environment, or if ``AWS_DEFAULT_REGION`` is set
            to anything other than ``us-east-1``.
    """
    # Report only the variables that are actually missing, in a stable order,
    # so the message tells the user exactly what to fix.
    missing = sorted(name for name in REQUIRED_AWS_ENV if name not in os.environ)
    if missing:
        raise EnvironmentError(
            f"{missing} environment variables should be defined and are not."
        )

    region = os.environ["AWS_DEFAULT_REGION"]
    if region != "us-east-1":
        raise EnvironmentError(
            "The AWS_DEFAULT_REGION environment variable is "
            f"{region} but it should be 'us-east-1'."
        )

    return s3fs.S3FileSystem(use_ssl=True)
# Download the latest archive of each model from S3, unpack it locally, then
# sync the unpacked tree back to S3.

# Fails fast if the AWS environment is not configured as expected.
filesystem = setup_aws_fs()

BASE_PATH = "s3://my-bucket/folder/folder"
print(f"Base path: {BASE_PATH}")

# NOTE(review): `models` is not defined anywhere in the visible source; it is
# expected to be an iterable of model folder names defined before this loop.
for model in models:
    model_runs = filesystem.ls(f"{BASE_PATH}/{model}")
    if not model_runs:
        # An empty listing would otherwise surface as an opaque IndexError.
        raise FileNotFoundError(f"No model found under {BASE_PATH}/{model}")

    # "Latest" is the lexicographically greatest entry of the listing —
    # presumably the folder names embed a sortable timestamp; TODO confirm.
    latest_model = max(model_runs)
    print(f"Latest model: {latest_model}")

    local_dir = f"./models/{model}"
    os.makedirs(local_dir, exist_ok=True)
    archive_path = f"{local_dir}/model.tar.gz"
    filesystem.download(f"{latest_model}/output/model.tar.gz", archive_path)

    # The context manager closes the archive even if extraction fails.
    # NOTE(review): extractall trusts the archive's member paths; this assumes
    # the bucket only holds archives produced by our own training jobs.
    with tarfile.open(archive_path) as tar:
        tar.extractall(local_dir)

# Upload everything that was unpacked under ./models in a single sync.
output_path = "s3://my-bucket/folder/output"
process = subprocess.run(
    ["aws", "s3", "sync", "./models", output_path],
    stdout=subprocess.PIPE,
    check=True,  # surface a failed sync instead of silently ignoring it
)
output = process.stdout
print(output.decode())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment