Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save satylogin/559f7efa41ebb6728ea1da154fa10a81 to your computer and use it in GitHub Desktop.
Save satylogin/559f7efa41ebb6728ea1da154fa10a81 to your computer and use it in GitHub Desktop.
import com.amazonaws.services.sagemaker.sparksdk.SageMakerModel
import java.util.UUID
// Serve a pre-trained TensorFlow model, packaged as a tarball in S3, behind a
// SageMaker endpoint that Spark can call through the SageMaker Spark SDK.

// Coordinates of the TensorFlow Serving inference container image.
// PROCESSOR_TYPE picks the image flavour; keep it consistent with the
// endpointInstanceType passed below.
val REGION: String = "us-east-1"
val TFS_VERSION: String = "2.1.0"
val PROCESSOR_TYPE: String = "gpu"

// NOTE(review): 763104351884 is the ECR registry account hosting AWS Deep Learning
// Container images in many, but not all, regions. Confirm it for REGION if that changes.
val IMAGE: String =
  s"763104351884.dkr.ecr.$REGION.amazonaws.com/tensorflow-inference:$TFS_VERSION-$PROCESSOR_TYPE"

// S3 location of the packaged model artifact (placeholder bucket/path).
val MODEL_PATH: String = "s3://bucket/path/model.tar.gz"

// Random suffix so every run of this script gets a distinct uid.
val id: String = UUID.randomUUID().toString

val model = {
  // NOTE(review): neither class is imported in this snippet; presumably they are
  // defined earlier in the session. featureLength = 10 presumably is the width of
  // each request's feature vector. Confirm against the model's input signature.
  val rowSerializer = new TFRequestRowSerializer(featureLength = 10)
  val rowDeserializer = new TFResponseRowDeserializer()

  // NOTE(review): with the SDK's default endpoint creation policy this call may
  // provision a (billable) endpoint immediately. Confirm before running.
  SageMakerModel.fromModelS3Path(
    endpointInstanceType = "ml.p3.2xlarge",
    endpointInitialInstanceCount = 1,
    modelImage = IMAGE,
    modelPath = MODEL_PATH,
    uid = s"saty-nn-test-$id",
    requestRowSerializer = rowSerializer,
    responseRowDeserializer = rowDeserializer,
    // Placeholder: this is not an ARN. Replace with the ARN of an IAM role
    // that SageMaker can assume.
    modelExecutionRoleARN = "aws role for sagemaker"
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment