Skip to content

Instantly share code, notes, and snippets.

@umbertogriffo
Last active February 12, 2020 06:13
Show Gist options
  • Save umbertogriffo/112a02848d8be269f23757c9656df908 to your computer and use it in GitHub Desktop.
Save umbertogriffo/112a02848d8be269f23757c9656df908 to your computer and use it in GitHub Desktop.
DataFrameSuite allows you to check if two DataFrames are equal. You can assert the DataFrames' equality using the method assertDataFrameEquals. When DataFrames contain doubles or Spark MLlib Vectors, you can assert that the DataFrames are approximately equal using the method assertDataFrameApproximateEquals
package test.com.idlike.junit.df
import breeze.numerics.abs
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Column, DataFrame, Row}
/**
* Created by Umberto on 06/02/2017.
*/
object DataFrameSuite {

  /**
   * Compares two [[DataFrame]]s for equality.
   *
   * This approach correctly handles cases where the DataFrames may have
   * duplicate rows, rows in different orders, and/or columns in different orders:
   *   1. check the two schemas are equal (case-insensitive comparison of the rendered schema)
   *   2. check the number of rows is equal
   *   3. check there are no unequal rows (symmetric `except` in both directions)
   *
   * @param a         first DataFrame
   * @param b         second DataFrame
   * @param isRelaxed when true, the DataFrames are compared as-is; when false, both
   *                  sides are normalized (sort + groupBy/count over all columns) so
   *                  that duplicate rows and differing row/column order do not matter
   * @return true when the DataFrames are considered equal
   */
  def assertDataFrameEquals(a: DataFrame, b: DataFrame, isRelaxed: Boolean): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache
      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }
      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }
      // 3. Check there are no unequal rows.
      // Sort the column names so that column order does not matter.
      val aColumns: Array[String] = a.columns
      val bColumns: Array[String] = b.columns
      scala.util.Sorting.quickSort(aColumns)
      scala.util.Sorting.quickSort(bColumns)
      val aSeq: Seq[Column] = aColumns.map(col(_))
      val bSeq: Seq[Column] = bColumns.map(col(_))
      val (aPrime, bPrime) =
        if (isRelaxed) {
          (a, b)
        } else {
          // Normalize both sides so duplicate rows and row order do not matter.
          // Bug fix: b must be sorted by its OWN column list (bSeq); the original
          // used aSeq for both sides.
          (a.sort(aSeq: _*).groupBy(aSeq: _*).count(),
           b.sort(bSeq: _*).groupBy(bSeq: _*).count())
        }
      // The symmetric difference must be empty in both directions.
      if (aPrime.except(bPrime).count() != 0 || bPrime.except(aPrime).count() != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /**
   * Compares two [[DataFrame]]s containing doubles (or floats, decimals, or
   * Spark ML vectors) for approximate equality.
   *   1. check the two schemas are equal
   *   2. check the number of rows is equal
   *   3. check there are no row pairs differing by more than `tol`
   *
   * NOTE(review): rows are paired positionally via `zipWithIndex`, so unlike
   * [[assertDataFrameEquals]] the row order matters here.
   *
   * @param a   first DataFrame
   * @param b   second DataFrame
   * @param tol max acceptable tolerance, should be less than 1.
   * @return true when the DataFrames are considered approximately equal
   */
  def assertDataFrameApproximateEquals(a: DataFrame, b: DataFrame, tol: Double): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache
      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }
      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }
      // 3. Pair rows by position and look for any pair outside the tolerance.
      val aIndexValue = zipWithIndex(a.rdd)
      val bIndexValue = zipWithIndex(b.rdd)
      val unequalRDD = aIndexValue.join(bIndexValue).filter { case (_, (r1, r2)) =>
        !DataFrameSuite.approxEquals(r1, r2, tol)
      }
      // take(1) stops at the first mismatch instead of counting them all.
      if (unequalRDD.take(1).nonEmpty) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /** Re-keys an RDD by its zip index: element -> (index, element). */
  def zipWithIndex[U](rdd: RDD[U]): RDD[(Long, U)] =
    rdd.zipWithIndex().map { case (row, idx) => (idx, row) }

  /**
   * Approximate equality, based on equals from [[Row]].
   *
   * Null positions must match exactly. Float/Double values must agree within
   * `tol` (and agree on NaN-ness); byte arrays are compared element-wise;
   * BigDecimals via compareTo (so scale differences are ignored); ML vectors
   * element-wise within `tol`. Everything else falls back to `!=`.
   *
   * @param r1  first row
   * @param r2  second row
   * @param tol max acceptable numeric tolerance
   * @return true when the rows are considered approximately equal
   */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
    if (r1.length != r2.length) {
      return false
    } else {
      var idx = 0
      val length = r1.length
      while (idx < length) {
        if (r1.isNullAt(idx) != r2.isNullAt(idx))
          return false
        if (!r1.isNullAt(idx)) {
          val o1 = r1.get(idx)
          val o2 = r2.get(idx)
          o1 match {
            case b1: Array[Byte] =>
              if (!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) return false
            case f1: Float =>
              // NaN != NaN under ordinary comparison, so compare NaN-ness first.
              if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(o2.asInstanceOf[Float])) return false
              if (math.abs(f1 - o2.asInstanceOf[Float]) > tol) return false
            case d1: Double =>
              if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(o2.asInstanceOf[Double])) return false
              if (math.abs(d1 - o2.asInstanceOf[Double]) > tol) return false
            case d1: java.math.BigDecimal =>
              // compareTo, not equals: 1.0 and 1.00 are considered equal.
              if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false
            case d1: org.apache.spark.ml.linalg.Vector =>
              val arr1: Array[Double] = d1.toArray
              val arr2: Array[Double] = o2.asInstanceOf[org.apache.spark.ml.linalg.Vector].toArray
              if (arr1.length != arr2.length) return false
              for (i <- arr1.indices) {
                if (math.abs(arr1(i) - arr2(i)) > tol) return false
              }
            case _ =>
              if (o1 != o2) return false
          }
        }
        idx += 1
      }
    }
    true
  }
}
@Arnold1
Copy link

Arnold1 commented Dec 14, 2017

what is "isRelaxed"?

@alexkon
Copy link

alexkon commented Jun 13, 2018

Line 49: "scala.util.Sorting.quickSort(aColumns)" -> "scala.util.Sorting.quickSort(bColumns)"

@giftig
Copy link

giftig commented Sep 24, 2018

Here's a really good book to help you write effective scala: https://www.amazon.co.uk/Scala-Impatient-Cay-S-Horstmann/dp/0134540565/

@umbertogriffo
Copy link
Author

Line 49: "scala.util.Sorting.quickSort(aColumns)" -> "scala.util.Sorting.quickSort(bColumns)"

@alexkon Thank you very much! I've updated it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment