Skip to content

Instantly share code, notes, and snippets.

@alaiacano
Created February 8, 2014 21:36
Show Gist options
  • Save alaiacano/8890718 to your computer and use it in GitHub Desktop.
Save alaiacano/8890718 to your computer and use it in GitHub Desktop.
/**
 * Scalding job that trains a naive Bayes model (via the `GaussianNB` helper — presumably
 * Gaussian naive Bayes, per its name) on part of the iris data set and writes predictions
 * for the held-out samples.
 *
 * Arguments:
 *   --output  path of the TSV the predictions are written to (required).
 *   --input   path of the iris TSV to read (optional; defaults to "iris.tsv", the path
 *             this job previously hard-coded). Expected columns, in order:
 *             id, class, sepalLength, sepalWidth, petalLength, petalWidth.
 *
 * Samples are split deterministically on `id`: ids divisible by 3 form the test set and
 * all other ids form the training set.
 */
class NBTestJob(args: Args) extends Job(args) {
  val output = args("output")

  // Input location is configurable; the default keeps the original behavior.
  val input = args.getOrElse("input", "iris.tsv")

  val iris = Tsv(input, ('id, 'class, 'sepalLength, 'sepalWidth, 'petalLength, 'petalWidth))
    .read

  // The model expects data in "melted" (long) form: one field for the variable names,
  // one for the corresponding values. 'id and 'class are carried onto every melted row,
  // so all four measurements of a sample fall into the same split below.
  val irisMelted = iris
    .unpivot(('sepalLength, 'sepalWidth, 'petalLength, 'petalWidth) -> ('feature, 'score))

  // Split into train/test sets on id. The two predicates are exact complements, so every
  // melted row lands in exactly one of the sets.
  val irisTrain = irisMelted
    .filter('id) { id: Int => (id % 3) != 0 }
    .discard('id)

  // 'class is dropped from the test rows, presumably so the classifier cannot see the
  // true label. NOTE(review): this also means the written output has no ground truth to
  // score predictions against — confirm that is intended.
  val irisTest = irisMelted
    .filter('id) { id: Int => (id % 3) == 0 }
    .discard('class)

  // Build the model from the training set.
  val model = GaussianNB.train(irisTrain)

  // Classify the held-out rows and write the result to `output`.
  val predictions = GaussianNB.classify(irisTest, model)
    .write(Tsv(output))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment