@hhbyyh
Last active February 24, 2017 18:19
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.fpm
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
* Common params for FPGrowth and FPGrowthModel
*/
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
/**
* Minimal support level of the frequent pattern, in the range [0.0, 1.0]. Any pattern that
* appears more than (minSupport * size-of-the-dataset) times will be output.
* Default: 0.3
* @group param
*/
@Since("2.2.0")
val minSupport: DoubleParam = new DoubleParam(this, "minSupport",
"the minimal support level of a frequent pattern",
ParamValidators.inRange(0.0, 1.0))
setDefault(minSupport -> 0.3)
/** @group getParam */
@Since("2.2.0")
def getMinSupport: Double = $(minSupport)
/**
* Number of partitions (>= 1) used by parallel FP-growth. By default the param is not set, and
* the partition number of the input dataset is used.
* @group expertParam
*/
@Since("2.2.0")
val numPartitions: IntParam = new IntParam(this, "numPartitions",
"Number of partitions used by parallel FP-growth", ParamValidators.gtEq[Int](1))
/** @group expertGetParam */
@Since("2.2.0")
def getNumPartitions: Int = $(numPartitions)
/**
* Minimal confidence for generating association rules.
* Note that minConfidence has no effect during fitting.
* Default: 0.8
* @group param
*/
@Since("2.2.0")
val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence",
"minimal confidence for generating Association Rule",
ParamValidators.inRange(0.0, 1.0))
setDefault(minConfidence -> 0.8)
/** @group getParam */
@Since("2.2.0")
def getMinConfidence: Double = $(minConfidence)
/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
@Since("2.2.0")
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputType = schema($(featuresCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
}
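// Illustration (not from the original source): for an input schema where the features column
// is array<string>, the returned schema simply appends a prediction column of array<string>;
// a non-array features column fails the require check above.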
}
/**
* :: Experimental ::
* A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
* <a href="http://dx.doi.org/10.1145/1454008.1454027">Li et al., PFP: Parallel FP-Growth for Query
* Recommendation</a>. PFP distributes computation in such a way that each worker executes an
* independent group of mining tasks. The FP-Growth algorithm is described in
* <a href="http://dx.doi.org/10.1145/335191.335372">Han et al., Mining frequent patterns without
* candidate generation</a>. Note that null values in the feature column are ignored during fit().
*
* @see <a href="http://en.wikipedia.org/wiki/Association_rule_learning">
* Association rule learning (Wikipedia)</a>
*/
@Since("2.2.0")
@Experimental
class FPGrowth @Since("2.2.0") (
@Since("2.2.0") override val uid: String)
extends Estimator[FPGrowthModel] with FPGrowthParams with DefaultParamsWritable {
@Since("2.2.0")
def this() = this(Identifiable.randomUID("fpgrowth"))
/** @group setParam */
@Since("2.2.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)
/** @group expertSetParam */
@Since("2.2.0")
def setNumPartitions(value: Int): this.type = set(numPartitions, value)
/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
/** @group setParam */
@Since("2.2.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
@Since("2.2.0")
override def fit(dataset: Dataset[_]): FPGrowthModel = {
transformSchema(dataset.schema, logging = true)
genericFit(dataset)
}
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
val data = dataset.select($(featuresCol))
val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
if (isSet(numPartitions)) {
mllibFP.setNumPartitions($(numPartitions))
}
val parentModel = mllibFP.run(items)
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
val schema = StructType(Seq(
StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
}
@Since("2.2.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
@Since("2.2.0")
override def copy(extra: ParamMap): FPGrowth = defaultCopy(extra)
}
@Since("2.2.0")
object FPGrowth extends DefaultParamsReadable[FPGrowth] {
@Since("2.2.0")
override def load(path: String): FPGrowth = super.load(path)
}
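/*
 * A minimal usage sketch (not part of the original gist; assumes a SparkSession named `spark`
 * and the default "features" column holding string arrays):
 *
 *   val dataset = spark.createDataFrame(Seq(
 *     (0, Array("a", "b", "c")),
 *     (1, Array("a", "b")),
 *     (2, Array("b", "c")),
 *     (3, Array("a", "c"))
 *   )).toDF("id", "features")
 *
 *   val model = new FPGrowth()
 *     .setMinSupport(0.5)
 *     .setMinConfidence(0.6)
 *     .fit(dataset)
 *
 *   model.freqItemsets.show()       // DataFrame("items", "freq")
 *   model.transform(dataset).show() // appends the "prediction" column
 */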
/**
* :: Experimental ::
* Model fitted by FPGrowth.
*
* @param freqItemsets frequent itemsets stored as a DataFrame("items"[Seq], "freq"[Long])
*/
@Since("2.2.0")
@Experimental
class FPGrowthModel private[ml] (
@Since("2.2.0") override val uid: String,
@transient val freqItemsets: DataFrame)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
/** @group setParam */
@Since("2.2.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/**
* Get association rules fitted by AssociationRules using the minConfidence. Returns a DataFrame
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are Array[T] and "confidence" is Double.
*/
@Since("2.2.0")
@transient lazy val associationRules: DataFrame = {
val freqItems = freqItemsets
AssociationRules.getAssociationRulesFromFP(freqItems, "items", "freq", $(minConfidence))
}
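/*
 * Usage sketch (assumes the fitted `model` from the earlier example):
 *
 *   model.associationRules.show() // DataFrame("antecedent", "consequent", "confidence")
 */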
/**
* The transform method first generates the association rules according to the frequent itemsets.
* Then, for each association rule, it examines the input items against the antecedent and
* collects the consequents as the prediction. The prediction column has the same data type as
* the input column (Array[T]).
* Note that internally it uses a Cartesian join and may exhaust memory for large datasets. Null
* values in the feature column are treated as empty sets.
*/
@Since("2.2.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
genericTransform(dataset)
}
private def genericTransform[T](dataset: Dataset[_]): DataFrame = {
// use index to perform the join and aggregateByKey, and keep the original order after join.
val indexToItems = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[T](0))
.zipWithIndex().map(_.swap)
val rulesRDD = associationRules.select("antecedent", "consequent").rdd
.map(r => (r.getSeq[T](0), r.getSeq[T](1)))
val indexToConsequents = indexToItems.cartesian(rulesRDD).map {
case ((id, items), (antecedent, consequent)) =>
val consequents = if (items != null) {
val itemSet = items.toSet
if (antecedent.forall(itemSet.contains)) {
consequent.filterNot(itemSet.contains)
} else {
Seq.empty
}
} else {
Seq.empty
}
(id, consequents)
}.aggregateByKey(new ArrayBuffer[T])((ar, seq) => ar ++= seq, (ar, seq) => ar ++= seq)
.map { case (index, cons) => (index, cons.distinct) }
// Zip the original rows with the same zipWithIndex ids used for indexToItems so the join keys match.
val rowAndConsequents = dataset.toDF().rdd.zipWithIndex().map(_.swap)
.join(indexToConsequents)
.map(_._2).map(t => Row.merge(t._1, Row(t._2)))
val mergedSchema = dataset.schema.add(StructField($(predictionCol),
dataset.schema($(featuresCol)).dataType, dataset.schema($(featuresCol)).nullable))
dataset.sparkSession.createDataFrame(rowAndConsequents, mergedSchema)
}
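/*
 * Illustration of the matching logic above (hypothetical values, not output captured from this
 * code): given a rule with antecedent [a] and consequent [b], an input row with items [a, c]
 * contains the whole antecedent, so b is added to that row's prediction; consequents already
 * present in the items are filtered out.
 */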
private def genericTransform2[T](dataset: Dataset[_]): DataFrame = {
val itemsRDD = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[T](0)).distinct()
val rulesRDD = associationRules.rdd.map(r => (r.getSeq[T](0), r.getSeq[T](1)))
val itemsWithConsequents = itemsRDD.cartesian(rulesRDD).map {
case ((items), (antecedent, consequent)) =>
val itemSet = items.toSet
val consequents = if (antecedent.forall(itemSet.contains(_))) consequent else Seq.empty
(items, consequents)
}.aggregateByKey(new ArrayBuffer[T])(
(ar, seq) => ar ++= seq, (ar, seq) => ar ++= seq)
.map (cols => Row(cols._1, cols._2))
val dt = dataset.schema($(featuresCol)).dataType
val fields = Array($(featuresCol), $(predictionCol))
.map(fieldName => StructField(fieldName, dt, nullable = true))
val schema = StructType(fields)
val mapping = dataset.sparkSession.createDataFrame(itemsWithConsequents, schema)
dataset.join(mapping, $(featuresCol))
}
private def genericTransform3[T](dataset: Dataset[_]): DataFrame = {
// use unique id to perform the join and aggregateByKey
val itemsRDD = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[T](0))
.distinct().zipWithUniqueId().map(_.swap).cache()
val rulesRDD = associationRules.rdd.map(r => (r.getSeq[T](0), r.getSeq[T](1)))
val itemsWithConsequents = itemsRDD.cartesian(rulesRDD).map {
case ((id, items), (antecedent, consequent)) =>
val itemSet = items.toSet
val consequents = if (antecedent.forall(itemSet.contains(_))) consequent else Seq.empty
(id, consequents)
}.aggregateByKey(new ArrayBuffer[T])(
(ar, seq) => ar ++= seq, (ar, seq) => ar ++= seq)
val mappingRDD = itemsRDD.join(itemsWithConsequents)
.map { case (id, (items, consequent)) => (items, consequent) }
.map (cols => Row(cols._1, cols._2))
val dt = dataset.schema($(featuresCol)).dataType
val fields = Array($(featuresCol), $(predictionCol))
.map(fieldName => StructField(fieldName, dt, nullable = true))
val schema = StructType(fields)
val mapping = dataset.sparkSession.createDataFrame(mappingRDD, schema)
dataset.join(mapping, $(featuresCol))
}
private def genericTransform4[T: Manifest](dataset: Dataset[_]): DataFrame = {
val rules = associationRules.rdd.map(r =>
(r.getSeq[Int](0), r.getSeq[Int](1))
).collect()
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[Int]) => brRules.value.flatMap( r =>
if (r._1.forall(items.contains(_))) r._2 else Seq.empty[Int]
).distinct)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
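// Note on the draft variants above: genericTransform joins rules against every row through a
// Cartesian product keyed by row index; genericTransform2/3 first deduplicate the item sets and
// join the mapping back on the features column; genericTransform4 collects the rules to the
// driver and broadcasts them into a UDF, but as written it is specialized to Int items.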
@Since("2.2.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
@Since("2.2.0")
override def copy(extra: ParamMap): FPGrowthModel = {
val copied = new FPGrowthModel(uid, freqItemsets)
copyValues(copied, extra).setParent(this.parent)
}
@Since("2.2.0")
override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
}
@Since("2.2.0")
object FPGrowthModel extends MLReadable[FPGrowthModel] {
@Since("2.2.0")
override def read: MLReader[FPGrowthModel] = new FPGrowthModelReader
@Since("2.2.0")
override def load(path: String): FPGrowthModel = super.load(path)
/** [[MLWriter]] instance for [[FPGrowthModel]] */
private[FPGrowthModel]
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.freqItemsets.write.parquet(dataPath)
}
}
private class FPGrowthModelReader extends MLReader[FPGrowthModel] {
/** Checked against metadata when loading model */
private val className = classOf[FPGrowthModel].getName
override def load(path: String): FPGrowthModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val model = new FPGrowthModel(metadata.uid, frequentItems)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
private[fpm] object AssociationRules {
/**
* Computes the association rules with confidence above minConfidence.
* @param dataset DataFrame("items", "freq") containing frequent itemsets obtained from
* algorithms like [[FPGrowth]].
* @param itemsCol column name for frequent itemsets
* @param freqCol column name for frequent itemsets count
* @param minConfidence minimum confidence for the result association rules
* @return a DataFrame("antecedent", "consequent", "confidence") containing the association
* rules.
*/
def getAssociationRulesFromFP[T: ClassTag](
dataset: Dataset[_],
itemsCol: String,
freqCol: String,
minConfidence: Double): DataFrame = {
val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
.map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
val rows = new MLlibAssociationRules()
.setMinConfidence(minConfidence)
.run(freqItemSetRdd)
.map(r => Row(r.antecedent, r.consequent, r.confidence))
val dt = dataset.schema(itemsCol).dataType
val schema = StructType(Seq(
StructField("antecedent", dt, nullable = false),
StructField("consequent", dt, nullable = false),
StructField("confidence", DoubleType, nullable = false)))
val rules = dataset.sparkSession.createDataFrame(rows, schema)
rules
}
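// For reference (standard association-rule definition, not specific to this code): the
// confidence of a rule X => Y is freq(X union Y) / freq(X); e.g. if {a, b} appears 3 times
// and {a} appears 4 times, the rule a => b has confidence 0.75.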
}