@tuliocasagrande
Created September 6, 2020 13:36
import json
import os

import boto3

CLIENT = boto3.client('sagemaker')
SAGEMAKER_ROLE_ARN = os.environ['SAGEMAKER_ROLE_ARN']


class ResourcePending(Exception):
    pass


class ResourceFailed(Exception):
    pass


def _check_job_status(response):
    # Valid Values: InProgress | Completed | Failed | Stopping | Stopped
    if response['ProcessingJobStatus'] in {'InProgress', 'Stopping'}:
        raise ResourcePending
    elif response['ProcessingJobStatus'] in {'Failed', 'Stopped'}:
        raise ResourceFailed(response.get('FailureReason', ''))


def lambda_handler(event, context):
    print('New event:', event)
    job_name = event['ProcessingJobName']

    try:
        # If the job already exists, just fetch its current status
        response = CLIENT.describe_processing_job(ProcessingJobName=job_name)
    except CLIENT.exceptions.ClientError:
        # The job doesn't exist yet: create it, then describe it again
        print('Creating new processing job:', job_name)
        _create_processing_job(event)
        response = CLIENT.describe_processing_job(ProcessingJobName=job_name)

    _check_job_status(response)
    return json.dumps(response, default=str)


def _create_processing_job(event):
    job_name = event['ProcessingJobName']
    image_uri = event['ImageUri']
    entrypoint = event['Entrypoint']
    inputs_config = event['InputsConfig']
    outputs_config = event['OutputsConfig']

    # Optional parameters with defaults
    arguments = event.get('Arguments')
    instance_type = event.get('InstanceType', 'ml.m5.2xlarge')
    instance_count = event.get('InstanceCount', 1)
    volume_size_in_gb = event.get('VolumeSizeInGB', 30)

    # Creating inputs
    inputs = [_make_input(input_) for input_ in inputs_config.values()]

    # Creating outputs
    outputs = [_make_output(output) for output in outputs_config.values()]

    # Create Processing Job
    process_request = {
        'ProcessingJobName': job_name,
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceType': instance_type,
                'InstanceCount': instance_count,
                'VolumeSizeInGB': volume_size_in_gb
            }
        },
        'AppSpecification': {
            'ImageUri': image_uri,
            'ContainerEntrypoint': entrypoint
        },
        'RoleArn': SAGEMAKER_ROLE_ARN,
        'ProcessingInputs': inputs,
        'ProcessingOutputConfig': {
            'Outputs': outputs
        }
    }
    if arguments:
        process_request['AppSpecification']['ContainerArguments'] = arguments

    CLIENT.create_processing_job(**process_request)


def _make_input(input_config):
    # Each input is downloaded to /opt/ml/processing/<InputName> inside the container
    return {
        'InputName': input_config['InputName'],
        'S3Input': {
            'S3Uri': input_config['S3Uri'],
            'LocalPath': '/opt/ml/processing/' + input_config['InputName'],
            'S3DataType': 'S3Prefix',
            'S3InputMode': 'File'
        }
    }


def _make_output(output_config):
    # Each output is written to /opt/ml/processing/<OutputName> and uploaded when the job ends
    return {
        'OutputName': output_config['OutputName'],
        'S3Output': {
            'S3Uri': output_config['S3Uri'],
            'LocalPath': '/opt/ml/processing/' + output_config['OutputName'],
            'S3UploadMode': 'EndOfJob'
        }
    }
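
For reference, below is a hypothetical invocation event shaped to match the keys read by lambda_handler and _create_processing_job. The job name, image URI, bucket paths, and script name are placeholders, not part of the original gist:

# Hypothetical example event; all names below are placeholders
example_event = {
    'ProcessingJobName': 'my-preprocessing-job',
    'ImageUri': '123456789012.dkr.ecr.us-east-1.amazonaws.com/my-processing-image:latest',
    'Entrypoint': ['python3', '/opt/ml/processing/code/preprocess.py'],
    'InputsConfig': {
        'code': {'InputName': 'code', 'S3Uri': 's3://my-bucket/code/'},
        'raw_data': {'InputName': 'raw_data', 'S3Uri': 's3://my-bucket/raw/'}
    },
    'OutputsConfig': {
        'train': {'OutputName': 'train', 'S3Uri': 's3://my-bucket/processed/train/'}
    },
    # Optional keys; the defaults in _create_processing_job apply when omitted
    'Arguments': ['--train-split', '0.8'],
    'InstanceType': 'ml.m5.xlarge'
}

Invoking the handler repeatedly with the same event only creates the processing job once: the first call creates it, and later calls just describe it. While the job is still running the handler raises ResourcePending, and it raises ResourceFailed if the job stopped or failed, so a caller can keep retrying until describe_processing_job reports Completed.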