Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaravind/158418d61bbfc71c117e8aaf1dca28ea to your computer and use it in GitHub Desktop.
Save yaravind/158418d61bbfc71c117e8aaf1dca28ea to your computer and use it in GitHub Desktop.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
/**
 * Pipeline stage that replaces every element of the first column's {@link DenseVector}
 * with its absolute value, leaving all other columns and the schema unchanged.
 *
 * <p>NOTE(review): the vector column is located by position (first column), not by name.
 * This works for the bean-derived {@link LabeledPoint} frames built in {@link #main}
 * (where {@code features} comes first); confirm the column order for other inputs.
 */
class AbsTransformer extends Transformer {

    /** Identifier generated once per instance; Spark expects {@code uid} to stay constant. */
    private final String stableUid = UUID.randomUUID().toString()

    /**
     * Returns a dataset whose first column holds the element-wise absolute value of the
     * original dense vector.
     *
     * <p>A fresh vector and a fresh row are built for every record instead of overwriting
     * the incoming vector's backing array, so input rows are never mutated.
     *
     * @param dataset input rows; the first column must contain a {@link DenseVector}
     * @return rows with the same schema as {@code dataset}
     */
    @Override
    Dataset<Row> transform(Dataset<?> dataset) {
        (dataset as Dataset<GenericRowWithSchema>).map({ value ->
            // Shallow-copy the cells so the incoming row's backing array is left untouched.
            Object[] cells = value.values().clone()
            // DenseVector.values() exposes the vector's backing array; clone before modifying.
            double[] absValues = (cells[0] as DenseVector).values().clone()
            for (int i = 0; i < absValues.length; i++) {
                absValues[i] = Math.abs(absValues[i])
            }
            cells[0] = Vectors.dense(absValues)
            new GenericRowWithSchema(cells, value.schema()) as Row
        }, RowEncoder.apply(dataset.schema()))
    }

    /**
     * Taking the absolute value neither adds, removes nor retypes columns, so the input
     * schema is returned unchanged.
     *
     * @param structType schema of the incoming dataset
     * @return the same schema
     */
    @Override
    StructType transformSchema(StructType structType) {
        structType
    }

    /**
     * This stage declares no params, so there is nothing in {@code paramMap} to apply.
     * The instance itself is returned, which also keeps its {@code uid}.
     *
     * @param paramMap extra params (ignored)
     * @return this transformer
     */
    @Override
    Transformer copy(ParamMap paramMap) {
        this
    }

    /**
     * Returns this instance's identifier. It is generated once at construction so that
     * repeated calls agree.
     *
     * @return the stable unique identifier of this stage
     */
    @Override
    String uid() {
        stableUid
    }

    /**
     * Demo: fits a pipeline of {@code AbsTransformer} followed by {@link LogisticRegression}
     * on a tiny in-memory training set, scores one point and prints its class probabilities.
     */
    public static void main(String[] args) {
        def spark = SparkSession.builder().master('local[*]').getOrCreate()
        try {
            def trainingData = spark.createDataFrame([
                    new LabeledPoint(1.0d, Vectors.dense([18.0d, -25.0d] as double[])),
                    new LabeledPoint(1.0d, Vectors.dense([-15.0d, 20.0d] as double[])),
                    new LabeledPoint(1.0d, Vectors.dense([10.0d, 27.0d] as double[])),
                    new LabeledPoint(0.0d, Vectors.dense([0.0d, 5.0d] as double[])),
                    new LabeledPoint(0.0d, Vectors.dense([0.0d, -6.0d] as double[])),
                    new LabeledPoint(0.0d, Vectors.dense([0.0d, 3.0d] as double[]))
            ], LabeledPoint)
            def stages = [new AbsTransformer(), new LogisticRegression()] as PipelineStage[]
            def pipeLine = new Pipeline().setStages(stages)
            def model = pipeLine.fit(trainingData)
            def data = spark.createDataFrame(
                    [new LabeledPoint(1.0d, Vectors.dense([-20.0d, -20.0d] as double[]))], LabeledPoint)
            def result = model.transform(data)
            result.show()
            // Read LogisticRegression's default output column by name rather than by position.
            def confidence = result.collectAsList().first().getAs('probability')
            println confidence
        } finally {
            // Release the local Spark context even if fitting or scoring fails.
            spark.stop()
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment