// Source data, generated by some preprocessing pipeline or read from Kafka
var examples: DataStream[(Input,Target)] = null
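// `Input` and `Target` are left abstract in this gist; a concrete pipeline might use
// shapes like the following (hypothetical) -- a feature vector paired with the label
// to predict.
case class InputSketch(features: Array[Double])
case class TargetSketch(label: Double)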
// Algorithms to train against. An algorithm defines most of the values passed to SageMaker
// when creating training jobs and models. Algorithms can be defined statically or read from
// a live stream (i.e., Kafka). Algorithms have an associated "id" that can be used to train multiple
// algorithms against a single dataset.
var algorithms: DataStream[AlgorithmEvent] = null
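// A sketch of what an `AlgorithmEvent` might carry -- the field names here are assumptions,
// not part of the gist's API. Each event bundles the values SageMaker needs to create a
// training job and a model, under a stable "id" so that a re-published algorithm
// supersedes earlier versions of itself.
case class AlgorithmEventSketch(
  id: String,                           // stable identifier; newer events replace older ones
  trainingImage: String,                // ECR URI of the algorithm's training container
  hyperparameters: Map[String, String], // passed through to the training job
  instanceType: String,                 // e.g. "ml.m5.xlarge"
  instanceCount: Int
)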
// The `split_examples` method uses a `Splitter` to partition an input stream into training and testing datasets.
val (training, testing) = examples.split_examples(new DefaultSplitter[(Input,Target)]())
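// A sketch of the `Splitter` contract implied above, assuming it decides per-example
// whether a record belongs to the training or the testing stream. A default splitter
// might hold out a fixed fraction deterministically by hashing the example, which keeps
// the split stable across restarts. (Hypothetical trait; the real interface may differ.)
trait SplitterSketch[T] extends Serializable {
  def isTraining(example: T): Boolean
}
class DefaultSplitterSketch[T](testFraction: Double = 0.2) extends SplitterSketch[T] {
  override def isTraining(example: T): Boolean = {
    val bucket = ((example.hashCode % 100) + 100) % 100 // stable bucket in [0, 100)
    bucket >= (testFraction * 100).toInt                // e.g. buckets 0-19 -> testing
  }
}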
// Training examples are accumulated into `TrainingBatch` records.
// Training batches describe a set of SageMaker training job channels and make no assumptions about
// the details of batching, aggregation, sampling, etc. Batches have an associated "key" that can be
// used to train multiple datasets against a single model.
val training_batches: DataStream[TrainingBatch[KEY]] = training
  .process(new CustomTrainingAggregation())
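// A sketch of what a `TrainingBatch` might contain -- again, the field names are
// assumptions. The key lets independent datasets train concurrently against the same
// algorithms, and the channels map onto SageMaker training job input channels.
case class TrainingBatchSketch[K](
  key: K,                        // batch key; pairs this batch with algorithms downstream
  channels: Map[String, String]  // channel name -> S3 URI, e.g. "train" -> "s3://bucket/..."
)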
// Trains each batch against each algorithm.
// The most recent version of each algorithm is trained against the most recent training batch for
// each batch key. When a new batch arrives, it is trained against the most recent version of each
// algorithm; when a new algorithm arrives, it is trained against the most recent batch for each
// batch key.
val sagemaker_models: DataStream[SagemakerModel[KEY]] = training_batches.train(
  algorithm_events = algorithms,
  typeinfo = TypeInformation.of(new TypeHint[TrainingBatch[KEY]]() {}) // unfortunate boilerplate to handle generic keys
)
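// The retraining semantics described above ("latest batch x latest algorithm, re-emit
// whenever either side updates") can be sketched as a Flink CoProcessFunction that caches
// the newest element from each side in keyed state. This is an assumption about how
// `train` might work internally, simplified in two ways: it keys both streams identically,
// whereas algorithms would more likely be broadcast to every batch key, and it elides the
// actual SageMaker CreateTrainingJob call and completion polling.
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.util.Collector

class PairLatestSketch extends CoProcessFunction[
    TrainingBatch[String], AlgorithmEvent, (TrainingBatch[String], AlgorithmEvent)] {
  @transient private var latestBatch: ValueState[TrainingBatch[String]] = _
  @transient private var latestAlgorithm: ValueState[AlgorithmEvent] = _

  override def open(parameters: Configuration): Unit = {
    latestBatch = getRuntimeContext.getState(
      new ValueStateDescriptor("latest-batch", classOf[TrainingBatch[String]]))
    latestAlgorithm = getRuntimeContext.getState(
      new ValueStateDescriptor("latest-algorithm", classOf[AlgorithmEvent]))
  }

  // A new batch pairs with the most recent algorithm seen so far.
  override def processElement1(
      batch: TrainingBatch[String],
      ctx: CoProcessFunction[TrainingBatch[String], AlgorithmEvent, (TrainingBatch[String], AlgorithmEvent)]#Context,
      out: Collector[(TrainingBatch[String], AlgorithmEvent)]): Unit = {
    latestBatch.update(batch)
    Option(latestAlgorithm.value).foreach(algo => out.collect((batch, algo)))
  }

  // A new algorithm version pairs with the most recent batch seen so far.
  override def processElement2(
      algo: AlgorithmEvent,
      ctx: CoProcessFunction[TrainingBatch[String], AlgorithmEvent, (TrainingBatch[String], AlgorithmEvent)]#Context,
      out: Collector[(TrainingBatch[String], AlgorithmEvent)]): Unit = {
    latestAlgorithm.update(algo)
    Option(latestBatch.value).foreach(batch => out.collect((batch, algo)))
  }
}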
// Testing examples are accumulated into `TestingBatch` records.
// Testing batches describe the inputs to a SageMaker transform job and a "target" value.
// As with training batches, no assumptions are made about batching, aggregation, sampling, etc.
// Testing batches use the same keying mechanism as training batches.
val testing_batches: DataStream[TestingBatch[KEY,TARGET]] = testing
  .process(new CustomTestingAggregation())
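// A sketch of what a `TestingBatch` might contain -- field names are assumptions. It
// points at the transform job's input and keeps the ground-truth targets so a `Validator`
// can score the transform output against them.
case class TestingBatchSketch[K, T](
  key: K,                 // same keying mechanism as TrainingBatch
  inputLocation: String,  // S3 URI handed to the SageMaker transform job
  targets: Seq[T]         // ground-truth values aligned with the transform input records
)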
// Evaluates the performance of each testing batch / algorithm pair.
// `validate` uses a `Validator` to evaluate the results of a SageMaker transform job and produce a performance
// element. The performance definition is generic and must be provided. The `Validator` interface is initialized
// with a description of the training job and its target values (as defined in the testing batch), then passed
// each SageMaker transform result file before producing the performance value.
val performance: DataStream[PERF] = testing_batches.connect(sagemaker_models)
  .validate(
    validator = new CustomValidator(),
    typeinfo = TypeInformation.of(new TypeHint[TARGET]() {}) // unfortunate boilerplate to handle generic targets
  )
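// A sketch of the `Validator` lifecycle implied by the comment above: initialized with
// the job description and target values, fed each transform result file, then asked for
// the final performance value. (Hypothetical method names; the real interface may differ.)
trait ValidatorSketch[TARGET, PERF] extends Serializable {
  def open(jobDescription: String, targets: Seq[TARGET]): Unit // once per transform job
  def addResultFile(contents: String): Unit                    // once per S3 result file
  def performance(): PERF                                      // produce the final metric
}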