Skip to content

Instantly share code, notes, and snippets.

@marcovivero
Created August 5, 2015 20:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcovivero/e17188228a92f52f54fc to your computer and use it in GitHub Desktop.
Save marcovivero/e17188228a92f52f54fc to your computer and use it in GitHub Desktop.
/**
 * Splits a DataFrame into `numSplits` random parts, stratified by label.
 *
 * For every distinct value found in `labelCol`, the rows carrying that value
 * are randomly divided into `numSplits` equally weighted pieces. The k-th
 * piece of every label is then unioned together to form the k-th entry of
 * [[dataSplits]].
 *
 * All per-label pieces are marked with `persist` when the instance is
 * constructed; call [[unpersist]] once the splits are no longer needed.
 *
 * @param data      the DataFrame to split
 * @param labelCol  name of the label column; its values are read with
 *                  `Row.getDouble`, so the column is expected to hold Doubles
 * @param numSplits number of splits to produce; must be positive
 * @throws IllegalArgumentException if `numSplits` is not positive
 */
class StratifiedSplits (data : DataFrame, labelCol : String, numSplits : Int) extends Serializable {

  require(numSplits > 0, s"numSplits must be positive, got $numSplits")

  // Distinct label values, collected to the driver.
  private val labels : Seq[Double] =
    data.select(labelCol).distinct.map(row => row.getDouble(0)).collect

  // Equal weight for every split. This does not depend on the label, so it is
  // built once instead of once per label.
  private val weights : Array[Double] = Array.fill(numSplits)(1.0 / numSplits)

  // One array of `numSplits` randomly sampled pieces per distinct label.
  // No seed is passed to randomSplit, so the split differs between runs.
  private val dataFrames : Seq[Array[DataFrame]] = labels.map { label =>
    data.filter(data(labelCol) === label).randomSplit(weights)
  }

  dataFrames.foreach(pieces => pieces.foreach(piece => piece.persist()))

  /**
   * The stratified splits: element k is the union of the k-th piece of every
   * label. When `data` contains no labels, each element is an empty DataFrame
   * derived from `data` (`data.limit(0)`) rather than a failure.
   */
  val dataSplits : Seq[DataFrame] = (0 until numSplits).map { k =>
    dataFrames
      .map(pieces => pieces(k))
      // reduceOption: `reduce` would throw on an empty label set.
      .reduceOption((a, b) => a.unionAll(b))
      .getOrElse(data.limit(0))
  }

  /** Calls `unpersist` on every per-label piece that was persisted at construction. */
  def unpersist() : Unit =
    dataFrames.foreach(pieces => pieces.foreach(piece => piece.unpersist()))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment