Skip to content

Instantly share code, notes, and snippets.

@danielchalef
Created August 7, 2018 20:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danielchalef/cd9f9cbf4548cafdb7591acbf839e79a to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types._
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
// Train an Amazon SageMaker NTM (Neural Topic Model) on a Spark DataFrame,
// then score the same corpus against the deployed inference endpoint.

// IAM role ARN that SageMaker assumes for training and hosting (placeholder).
val iamRole = "arn:aws:iam::XXXXX"
// ECR image for the SageMaker built-in NTM algorithm (us-west-2 account).
val container = "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1"

// getOrCreate() may construct a new session (side effect) — keep the parens.
val spark = SparkSession.builder.getOrCreate()

// Input corpus (placeholder path). NOTE(review): the SageMaker Spark SDK
// serializes a Vector feature column by default — confirm the parquet schema
// matches what ProtobufRequestRowSerializer expects.
val corpus = spark.read.parquet("s3a://xxxxx.parquet")

// Shape of the endpoint's protobuf response: one Vector of per-topic weights.
val ntmResponseSchema = StructType(Array(StructField("topic_weights", VectorType)))
val ntmProtobufKeys = Some(Seq("topic_weights"))

// Estimator that runs a SageMaker training job and hosts the resulting model.
val ntmEstimator = new SageMakerEstimator(
  trainingImage = container,
  modelImage = container,
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new ProtobufResponseRowDeserializer(
    schema = ntmResponseSchema,
    protobufKeys = ntmProtobufKeys),
  // feature_dim must equal the vocabulary size of the vectorized corpus.
  hyperParameters = Map("num_topics" -> "15", "feature_dim" -> "208194", "epochs" -> "50"),
  sagemakerRole = IAMRole(iamRole),
  trainingInstanceType = "ml.p3.8xlarge",
  trainingInstanceCount = 2,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")

// fit() launches the training job and deploys an endpoint; transform() sends
// each row to that endpoint and appends the deserialized "topic_weights".
val ntmModel = ntmEstimator.fit(corpus)
val results = ntmModel.transform(corpus)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment