Created
March 22, 2022 13:18
-
-
Save sofianhamiti/0855ed3d4e525472be5ce1a754cbf22d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example workflow pipeline script for abalone pipeline. | |
. -RegisterModel | |
. | |
Process-> Train -> Evaluate -> Condition . | |
. | |
. -(stop) | |
Implements a get_pipeline(**kwargs) method. | |
""" | |
import os | |
import boto3 | |
import sagemaker | |
import sagemaker.session | |
from sagemaker.estimator import Estimator | |
from sagemaker.inputs import TrainingInput | |
from sagemaker.model_metrics import ( | |
MetricsSource, | |
ModelMetrics, | |
) | |
from sagemaker.processing import ( | |
ProcessingInput, | |
ProcessingOutput, | |
ScriptProcessor, | |
) | |
from sagemaker.sklearn.processing import SKLearnProcessor | |
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo | |
from sagemaker.workflow.condition_step import ( | |
ConditionStep, | |
JsonGet, | |
) | |
from sagemaker.workflow.parameters import ( | |
ParameterInteger, | |
ParameterString, | |
) | |
from sagemaker.workflow.pipeline import Pipeline | |
from sagemaker.workflow.properties import PropertyFile | |
from sagemaker.workflow.steps import ( | |
ProcessingStep, | |
TrainingStep, | |
) | |
from sagemaker.workflow.step_collections import RegisterModel | |
from sagemaker.workflow.retry import ( | |
StepRetryPolicy, | |
StepExceptionTypeEnum, | |
SageMakerJobStepRetryPolicy, | |
SageMakerJobExceptionTypeEnum | |
) | |
# Retry policy applied at the pipeline-step level: re-runs the step when the
# service faults or throttles the request, with exponential backoff.
step_retry_policy = StepRetryPolicy(
    exception_types=[
        StepExceptionTypeEnum.SERVICE_FAULT,
        StepExceptionTypeEnum.THROTTLING,
    ],
    interval_seconds=30,
    backoff_rate=2.0,
    expire_after_mins=240,  # keep trying for 4 hours max
)
# Retry policy applied to the underlying SageMaker job: retries on resource
# limits, and on internal/capacity failure reasons reported by the job.
job_retry_policy = SageMakerJobStepRetryPolicy(
    exception_types=[SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT],
    failure_reason_types=[
        SageMakerJobExceptionTypeEnum.INTERNAL_ERROR,
        SageMakerJobExceptionTypeEnum.CAPACITY_ERROR,
    ],
    interval_seconds=30,
    backoff_rate=2.0,
    expire_after_mins=240,  # keep trying for 4 hours max
)

# Directory containing this script; used to locate preprocess.py / evaluate.py.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
def get_sagemaker_client(region):
    """Return a boto3 SageMaker client for the given region.

    Args:
        region: the AWS region to open the boto3 session in.

    Returns:
        A boto3 client for the "sagemaker" service.
    """
    # NOTE: the original docstring documented a nonexistent ``default_bucket``
    # parameter and claimed a Session return value; this function only builds
    # a regional boto3 client.
    boto_session = boto3.Session(region_name=region)
    return boto_session.client("sagemaker")
def get_session(region, default_bucket):
    """Build a ``sagemaker.session.Session`` bound to *region*.

    Args:
        region: the AWS region to start the session in.
        default_bucket: the S3 bucket used for storing pipeline artifacts.

    Returns:
        A ``sagemaker.session.Session`` instance.
    """
    boto_session = boto3.Session(region_name=region)
    return sagemaker.session.Session(
        boto_session=boto_session,
        sagemaker_client=boto_session.client("sagemaker"),
        sagemaker_runtime_client=boto_session.client("sagemaker-runtime"),
        default_bucket=default_bucket,
    )
def get_pipeline_custom_tags(new_tags, region, sagemaker_project_arn=None):
    """Append the SageMaker project's tags to *new_tags*, best effort.

    Any failure (missing ARN, permissions, network) is logged and swallowed
    so that pipeline creation is never blocked by the tag lookup.

    Returns:
        The (possibly extended) *new_tags* list.
    """
    try:
        client = get_sagemaker_client(region)
        listing = client.list_tags(ResourceArn=sagemaker_project_arn)
        new_tags.extend(listing["Tags"])
    except Exception as e:
        print(f"Error getting project tags: {e}")
    return new_tags
def get_pipeline(
    region,
    sagemaker_project_arn=None,
    role=None,
    default_bucket=None,
    model_package_group_name="AbalonePackageGroup",
    pipeline_name="AbalonePipeline",
    base_job_prefix="Abalone",
):
    """Gets a SageMaker ML Pipeline instance working with abalone data.

    Builds: Process -> Train -> Evaluate -> Condition -> (RegisterModel | stop).

    Args:
        region: AWS region to create and run the pipeline.
        sagemaker_project_arn: project ARN (accepted for interface parity;
            not referenced inside this function).
        role: IAM role to create and run steps and pipeline; resolved from
            the session when not supplied.
        default_bucket: the bucket to use for storing the artifacts
        model_package_group_name: model registry group that the model is
            registered into when the quality condition passes.
        pipeline_name: name given to the created Pipeline.
        base_job_prefix: prefix used for job names and S3 output paths.

    Returns:
        an instance of a pipeline
    """
    sagemaker_session = get_session(region, default_bucket)
    if role is None:
        role = sagemaker.session.get_execution_role(sagemaker_session)
    # parameters for pipeline execution; overridable per pipeline execution
    processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
    processing_instance_type = ParameterString(
        name="ProcessingInstanceType", default_value="ml.m5.xlarge"
    )
    training_instance_type = ParameterString(
        name="TrainingInstanceType", default_value="ml.c5n.2xlarge"
    )
    model_approval_status = ParameterString(
        name="ModelApprovalStatus", default_value="PendingManualApproval"
    )
    # Default dataset is the public SageMaker seedcode abalone CSV for the region.
    input_data = ParameterString(
        name="InputDataUrl",
        default_value=f"s3://sagemaker-servicecatalog-seedcode-{region}/dataset/abalone-dataset.csv",
    )
    # processing step for feature engineering (runs BASE_DIR/preprocess.py)
    sklearn_processor = SKLearnProcessor(
        framework_version="0.23-1",
        instance_type=processing_instance_type,
        instance_count=processing_instance_count,
        base_job_name=f"{base_job_prefix}/sklearn-abalone-preprocess",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    # Emits train/validation/test splits consumed by the later steps; both
    # the step-level and job-level retry policies defined at module scope apply.
    step_process = ProcessingStep(
        name="PreprocessAbaloneData",
        processor=sklearn_processor,
        outputs=[
            ProcessingOutput(output_name="train", source="/opt/ml/processing/train"),
            ProcessingOutput(output_name="validation", source="/opt/ml/processing/validation"),
            ProcessingOutput(output_name="test", source="/opt/ml/processing/test"),
        ],
        code=os.path.join(BASE_DIR, "preprocess.py"),
        job_arguments=["--input-data", input_data],
        retry_policies=[
            step_retry_policy,
            job_retry_policy
        ]
    )
    # training step for generating model artifacts
    model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/AbaloneTrain"
    image_uri = sagemaker.image_uris.retrieve(
        framework="xgboost",
        region=region,
        version="1.0-1",
        py_version="py3",
        instance_type=training_instance_type,
    )
    xgb_train = Estimator(
        image_uri=image_uri,
        instance_type=training_instance_type,
        instance_count=1,
        output_path=model_path,
        base_job_name=f"{base_job_prefix}/abalone-train",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    # NOTE(review): "reg:linear" is the legacy XGBoost alias for
    # "reg:squarederror" — consider updating if the container version allows.
    xgb_train.set_hyperparameters(
        objective="reg:linear",
        num_round=50,
        max_depth=5,
        eta=0.2,
        gamma=4,
        min_child_weight=6,
        subsample=0.7,
        silent=0,
    )
    # Trains on the preprocessing step's train/validation channel outputs.
    step_train = TrainingStep(
        name="TrainAbaloneModel",
        estimator=xgb_train,
        inputs={
            "train": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "train"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
            "validation": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "validation"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
        },
        retry_policies=[
            step_retry_policy,
            job_retry_policy
        ]
    )
    # processing step for evaluation (runs BASE_DIR/evaluate.py in the
    # training image so the model can be deserialized)
    script_eval = ScriptProcessor(
        image_uri=image_uri,
        command=["python3"],
        instance_type=processing_instance_type,
        instance_count=1,
        base_job_name=f"{base_job_prefix}/script-abalone-eval",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    # PropertyFile exposes evaluation.json so the condition step can JsonGet it.
    evaluation_report = PropertyFile(
        name="AbaloneEvaluationReport",
        output_name="evaluation",
        path="evaluation.json",
    )
    # NOTE(review): unlike the process/train steps, no retry_policies are set
    # here — confirm whether that asymmetry is intentional.
    step_eval = ProcessingStep(
        name="EvaluateAbaloneModel",
        processor=script_eval,
        inputs=[
            ProcessingInput(
                source=step_train.properties.ModelArtifacts.S3ModelArtifacts,
                destination="/opt/ml/processing/model",
            ),
            ProcessingInput(
                source=step_process.properties.ProcessingOutputConfig.Outputs[
                    "test"
                ].S3Output.S3Uri,
                destination="/opt/ml/processing/test",
            ),
        ],
        outputs=[
            ProcessingOutput(output_name="evaluation", source="/opt/ml/processing/evaluation"),
        ],
        code=os.path.join(BASE_DIR, "evaluate.py"),
        property_files=[evaluation_report],
    )
    # register model step that will be conditionally executed
    model_metrics = ModelMetrics(
        model_statistics=MetricsSource(
            # Resolve the evaluation output's S3 prefix at definition time.
            s3_uri="{}/evaluation.json".format(
                step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
            ),
            content_type="application/json"
        )
    )
    step_register = RegisterModel(
        name="RegisterAbaloneModel",
        estimator=xgb_train,
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.t2.medium", "ml.m5.large"],
        transform_instances=["ml.m5.large"],
        model_package_group_name=model_package_group_name,
        approval_status=model_approval_status,
        model_metrics=model_metrics,
    )
    # condition step for evaluating model quality and branching execution:
    # register the model only when MSE from evaluation.json is <= 6.0
    cond_lte = ConditionLessThanOrEqualTo(
        left=JsonGet(
            step=step_eval,
            property_file=evaluation_report,
            json_path="regression_metrics.mse.value"
        ),
        right=6.0,
    )
    step_cond = ConditionStep(
        name="CheckMSEAbaloneEvaluation",
        conditions=[cond_lte],
        if_steps=[step_register],
        else_steps=[],  # pipeline simply stops when the quality gate fails
    )
    # pipeline instance (register/condition branches are reachable via step_cond)
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[
            processing_instance_type,
            processing_instance_count,
            training_instance_type,
            model_approval_status,
            input_data,
        ],
        steps=[step_process, step_train, step_eval, step_cond],
        sagemaker_session=sagemaker_session,
    )
    return pipeline
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment