Last active
November 20, 2019 14:11
A handy script for using AWS SageMaker from a local environment
import botocore.exceptions
from sagemaker.estimator import Estimator
from sagemaker.predictor import RealTimePredictor


def get_predictor(training_job_name: str) -> RealTimePredictor:
    """
    Get a RealTimePredictor from a training job name and return it.

    If an endpoint with the same name as the training job already exists,
    return a RealTimePredictor attached to that endpoint instead.

    Args:
        training_job_name (str): Training job name (or endpoint name)

    Returns:
        RealTimePredictor: The prediction endpoint
    """
    try:
        # Attach to the existing training job and deploy its model.
        estimator = Estimator.attach(training_job_name=training_job_name)
        predictor = estimator.deploy(
            initial_instance_count=1, instance_type="ml.c4.xlarge"
        )
    except botocore.exceptions.ClientError as e:
        error = e.response["Error"]
        if error["Code"] == "ValidationException" and error["Message"].startswith(
            "Cannot create already existing endpoint"
        ):
            # The endpoint already exists, so reuse it instead of failing.
            predictor = RealTimePredictor(training_job_name)
        else:
            # Any other client error is unexpected; re-raise it unchanged.
            raise
    return predictor
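The error check inside `get_predictor` can be pulled out into a small pure function so that it can be exercised without AWS access. The following is only a sketch: `is_existing_endpoint_error` and `EXISTING_ENDPOINT_PREFIX` are hypothetical names, not part of the original script, and the sample responses are hand-written dictionaries in the shape of botocore's `ClientError.response`, not output captured from AWS.

```python
from typing import Any, Mapping

# Message prefix copied from the check in get_predictor.
EXISTING_ENDPOINT_PREFIX = "Cannot create already existing endpoint"


def is_existing_endpoint_error(response: Mapping[str, Any]) -> bool:
    """Return True if the response reports that the endpoint already exists.

    `response` is assumed to look like botocore's ClientError.response:
    {"Error": {"Code": ..., "Message": ...}}.
    """
    error = response.get("Error", {})
    return error.get("Code") == "ValidationException" and str(
        error.get("Message", "")
    ).startswith(EXISTING_ENDPOINT_PREFIX)


# Hand-written sample responses (assumed shapes, for illustration only).
duplicate = {
    "Error": {
        "Code": "ValidationException",
        "Message": "Cannot create already existing endpoint (sample text)",
    }
}
other = {"Error": {"Code": "AccessDeniedException", "Message": "denied"}}

print(is_existing_endpoint_error(duplicate))
print(is_existing_endpoint_error(other))
```

With a helper like this, the `except` branch in `get_predictor` could call `is_existing_endpoint_error(e.response)` instead of inspecting the dictionary inline. Matching on the message prefix is fragile if the service wording changes, so treat it as a convenience for local use rather than a guarantee.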