@alexkon
Created June 13, 2018 16:41
DataFrameSuite allows you to check whether two DataFrames are equal. You can assert DataFrame equality using the method assertDataFrameEquals. When the DataFrames contain doubles or Spark MLlib Vectors, you can assert that the DataFrames are approximately equal using the method assertDataFrameApproximateEquals. A short usage sketch follows the code below.
import breeze.numerics.abs
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Column, DataFrame, Row}

/**
  * Originally created by Umberto on 06/02/2017 (https://gist.github.com/umbertogriffo/112a02848d8be269f23757c9656df908). Minor fixes added by alexkon.
  */
object DataFrameSuite {

  /**
    * Compares two [[DataFrame]]s for equality.
    * This approach correctly handles cases where the DataFrames may have duplicate rows,
    * rows in different orders, and/or columns in different orders.
    * 1. Check that the two schemas are equal
    * 2. Check that the row counts are equal
    * 3. Check that there are no unequal rows
    *
    * @param a         DataFrame
    * @param b         DataFrame
    * @param isRelaxed when true, skips the duplicate/row-order normalization and compares the DataFrames as-is
    * @return true if the DataFrames are equal
    */
  def assertDataFrameEquals(a: DataFrame, b: DataFrame, isRelaxed: Boolean): Boolean = {
    try {
      a.rdd.cache()
      b.rdd.cache()

      // 1. Check the equality of the two schemas
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString())) {
        return false
      }

      // 2. Check that the two DataFrames have the same number of rows
      if (a.count() != b.count()) {
        return false
      }

      // 3. Check that there are no unequal rows
      val aColumns: Array[String] = a.columns
      val bColumns: Array[String] = b.columns

      // Sort the column names so that DataFrames with columns in different orders compare equal
      scala.util.Sorting.quickSort(aColumns)
      scala.util.Sorting.quickSort(bColumns)

      val aSeq: Seq[Column] = aColumns.map(col(_))
      val bSeq: Seq[Column] = bColumns.map(col(_))

      var a_prime: DataFrame = null
      var b_prime: DataFrame = null
      if (isRelaxed) {
        a_prime = a
        // a_prime.show()
        b_prime = b
        // b_prime.show()
      } else {
        // Group by all columns and count so that duplicate rows and rows in different
        // orders are handled correctly: each distinct row becomes a (row, multiplicity)
        // pair, so except() also catches differences in duplicate counts.
        a_prime = a.sort(aSeq: _*).groupBy(aSeq: _*).count()
        // a_prime.show()
        b_prime = b.sort(bSeq: _*).groupBy(bSeq: _*).count()
        // b_prime.show()
      }

      val c1: Long = a_prime.except(b_prime).count()
      val c2: Long = b_prime.except(a_prime).count()
      if (c1 != 0 || c2 != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }
  /**
    * Compares two [[DataFrame]]s containing doubles (or floats, decimals, or MLlib Vectors)
    * for approximate equality.
    * 1. Check that the two schemas are equal
    * 2. Check that the row counts are equal
    * 3. Check that there are no unequal rows
    * Note: rows are paired by position (via zipWithIndex), so both DataFrames must
    * already have their rows in the same order.
    *
    * @param tol max acceptable tolerance, should be less than 1.
    */
  def assertDataFrameApproximateEquals(a: DataFrame, b: DataFrame, tol: Double): Boolean = {
    try {
      a.rdd.cache()
      b.rdd.cache()

      // 1. Check the equality of the two schemas
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString())) {
        return false
      }

      // 2. Check that the two DataFrames have the same number of rows
      if (a.count() != b.count()) {
        return false
      }

      // 3. Check that there are no unequal rows: pair rows by index and keep
      // only the pairs that are not approximately equal
      val aIndexValue = zipWithIndex(a.rdd)
      val bIndexValue = zipWithIndex(b.rdd)
      val unequalRDD = aIndexValue.join(bIndexValue).filter { case (idx, (r1, r2)) =>
        !DataFrameSuite.approxEquals(r1, r2, tol)
      }
      if (unequalRDD.take(1).length != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }
  /** Re-keys an RDD by its zipWithIndex index so that two RDDs can be joined positionally. */
  def zipWithIndex[U](rdd: RDD[U]): RDD[(Long, U)] = rdd.zipWithIndex().map { case (row, idx) => (idx, row) }

  /**
    * Approximate equality, based on equals from [[Row]]
    *
    * @param r1  first Row
    * @param r2  second Row
    * @param tol absolute tolerance applied to Float, Double and Vector fields
    * @return true if the Rows are approximately equal
    */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
    if (r1.length != r2.length) {
      return false
    }
    var idx = 0
    val length = r1.length
    while (idx < length) {
      if (r1.isNullAt(idx) != r2.isNullAt(idx))
        return false
      if (!r1.isNullAt(idx)) {
        val o1 = r1.get(idx)
        val o2 = r2.get(idx)
        o1 match {
          case b1: Array[Byte] =>
            if (!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) return false
          case f1: Float =>
            if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(o2.asInstanceOf[Float])) return false
            if (abs(f1 - o2.asInstanceOf[Float]) > tol) return false
          case d1: Double =>
            if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(o2.asInstanceOf[Double])) return false
            if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false
          case d1: java.math.BigDecimal =>
            if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false
          case v1: org.apache.spark.ml.linalg.Vector =>
            val arr1: Array[Double] = v1.toArray
            val arr2: Array[Double] = o2.asInstanceOf[org.apache.spark.ml.linalg.Vector].toArray
            if (arr1.length != arr2.length) return false
            for (i <- arr1.indices) {
              if (abs(arr1(i) - arr2(i)) > tol) return false
            }
          case _ =>
            if (o1 != o2) return false
        }
      }
      idx += 1
    }
    true
  }
}
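
A minimal usage sketch of both methods. This assumes a live SparkSession bound to a val named spark; the column names and values below are hypothetical, chosen only to illustrate the two comparison modes:

// Usage sketch: hypothetical data; assumes an existing SparkSession named spark.
import spark.implicits._

val df1 = Seq((1, 1.0), (2, 2.0)).toDF("id", "score")
val df1Shuffled = Seq((2, 2.0), (1, 1.0)).toDF("id", "score")
val df2 = Seq((1, 1.0), (2, 2.0000001)).toDF("id", "score")

// Exact comparison is row-order-insensitive when isRelaxed = false,
// because the rows are normalized via sort + groupBy + count.
assert(DataFrameSuite.assertDataFrameEquals(df1, df1Shuffled, isRelaxed = false))

// Approximate comparison: Double fields may differ by at most tol.
// Rows are paired by position, so the row orders must already agree.
assert(DataFrameSuite.assertDataFrameApproximateEquals(df1, df2, tol = 1e-6))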