Skip to content

Commit

Permalink
typelevel#787 - tolerance on map members and on vectors for cluster runs
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Apr 11, 2024
1 parent 80de4f2 commit a89542e
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
evCanBeDoubleB: CatalystCast[B, Double]
): Prop = bivariatePropTemplate(xs)(
covarSamp[A, B, X3[Int, A, B]],
org.apache.spark.sql.functions.covar_samp
org.apache.spark.sql.functions.covar_samp,
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("10"))
)

check(forAll(prop[Double, Double] _))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package frameless
package functions

import org.scalacheck.Prop
import org.scalacheck.Prop.AnyOperators
import org.scalacheck.util.Pretty
import shapeless.{ Lens, OpticDefns }

/**
* Some statistical functions in Spark can result in Double, Double.NaN or Null.
Expand All @@ -14,6 +16,8 @@ import org.scalacheck.util.Pretty
*/
object DoubleBehaviourUtils {

val dp5 = BigDecimal(0.00001)

// Mapping with this function is needed because spark uses Double.NaN for some semantics in the
// correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN
private val nanHandler: Double => Option[Double] = value =>
Expand Down Expand Up @@ -41,6 +45,45 @@ object DoubleBehaviourUtils {
BigDecimal.RoundingMode.CEILING
)

import shapeless._

/**
 * Compares two equally-sized sequences element-wise, applying the same
 * numeric tolerance `of` to every lensed BigDecimal field before the
 * comparison. Each lens in `fudgers` is paired with a tolerance-based
 * adjustment and the pair is handed to [[compareVectors]].
 */
def tolerantCompareVectors[K, CC[X] <: Seq[X]](
    v1: CC[K],
    v2: CC[K],
    of: BigDecimal
  )(fudgers: Seq[OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]]]
  ): Prop = {
  // Attach the shared tolerance to every lens, then defer to the
  // general pairwise comparison.
  val withTolerance = fudgers.map { lensF =>
    (
      lensF,
      (pair: (Option[BigDecimal], Option[BigDecimal])) => tolerance(pair, of)
    )
  }
  compareVectors(v1, v2)(withTolerance)
}

/**
 * Element-wise comparison of two sequences after "fudging" selected
 * BigDecimal fields of each element. For every pair of corresponding
 * elements, each `(lens, adjust)` in `fudgers` reads the lensed field from
 * both sides, applies the adjustment (e.g. a tolerance), and writes the
 * adjusted values back before the final equality check.
 *
 * @param v1      expected sequence
 * @param v2      actual sequence
 * @param fudgers lenses into an element's Option[BigDecimal] fields, each
 *                paired with an adjustment over the (left, right) values
 * @return a falsified Prop with a size message when lengths differ,
 *         otherwise the `?=` comparison of the fudged sequences
 */
def compareVectors[K, CC[X] <: Seq[X]](
    v1: CC[K],
    v2: CC[K]
  )(fudgers: Seq[
      (OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]],
        Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
            BigDecimal
          ], Option[BigDecimal]]
      )
    ]
  ): Prop =
  if (v1.size != v2.size)
    Prop.falsified :| {
      "Expected Seq of size " + v1.size + " but got " + v2.size
    }
  else {
    // Apply every fudger to each (left, right) pair of elements.
    val fudged = v1.zip(v2).map { p =>
      fudgers.foldLeft(p) { (curr, nf) =>
        val theLens = nf._1(lens[K])
        val pair = (theLens.get(curr._1), theLens.get(curr._2))
        val (nl, nr) = nf._2(pair)
        (theLens.set(curr._1)(nl), theLens.set(curr._2)(nr))
      }
    }
    // Compare pairwise directly. The previous `.toMap` detour deduplicated
    // equal left-hand elements, silently dropping comparisons for sequences
    // containing repeated rows and masking potential mismatches.
    val (lefts, rights) = fudged.unzip
    lefts.toVector ?= rights.toVector
  }

def compareMaps[K](
m1: Map[K, Option[BigDecimal]],
m2: Map[K, Option[BigDecimal]],
Expand Down Expand Up @@ -97,11 +140,25 @@ object DoubleBehaviourUtils {
p
}
}

import shapeless._

/**
 * Builds a pairwise adjustment function for values of type X: it reads the
 * lensed BigDecimal field from both values, applies the tolerance `of` via
 * [[tolerance]], and writes the adjusted values back, so that nearly-equal
 * numbers on either side compare equal.
 *
 * @param lensf selects the Option[BigDecimal] field to adjust
 * @param of    the tolerance to apply to the selected field
 */
def tl[X](
    lensf: OpticDefns.RootLens[X] => Lens[X, Option[BigDecimal]],
    of: BigDecimal
  ): (X, X) => (X, X) = { (left: X, right: X) =>
  val field = lensf(lens[X])
  val adjusted = tolerance((field.get(left), field.get(right)), of)
  (field.set(left)(adjusted._1), field.set(right)(adjusted._2))
}

}

/**
 * Drop-in conversion of an aggregate result of type A to an
 * Option[BigDecimal], used so double values computed on a cluster (where
 * rounding/serialization may differ per node) can be compared with a
 * tolerance. NOTE(review): the None cases (e.g. NaN/null handling) depend on
 * the instances in the companion object — confirm there.
 */
trait ToDecimal[A] {
  def truncate(a: A): Option[BigDecimal]

}

object ToDecimal {
Expand Down
50 changes: 38 additions & 12 deletions dataset/src/test/scala/frameless/ops/CubeTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package frameless
package ops

import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors }
import frameless.functions.ToDecimal
import frameless.functions.aggregate._
import org.scalacheck.Prop
Expand Down Expand Up @@ -249,10 +250,22 @@ class CubeTests extends TypedDatasetSuite {
)
.sortBy(t => (t._1, t._2, t._3))

(framelessSumBC ?= sparkSumBC)
.&&(framelessSumBCB ?= sparkSumBCB)
.&&(framelessSumBCBC ?= sparkSumBCBC)
.&&(framelessSumBCBCB ?= sparkSumBCBCB)
(tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3)))
.&&(
tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)(
Seq(l => l._3)
)
)
.&&(
tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)(
Seq(l => l._3, l => l._5)
)
)
.&&(
tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)(
Seq(l => l._3, l => l._5)
)
)
}

check(forAll(prop[String, Long, Double, Long, Double] _))
Expand All @@ -265,7 +278,7 @@ class CubeTests extends TypedDatasetSuite {
C: TypedEncoder,
D: TypedEncoder,
OutC: TypedEncoder: Numeric,
OutD: TypedEncoder: Numeric
OutD: TypedEncoder: Numeric: ToDecimal
](data: List[X4[A, B, C, D]]
)(implicit
summableC: CatalystSummable[C, OutC],
Expand All @@ -277,12 +290,15 @@ class CubeTests extends TypedDatasetSuite {
val C = dataset.col[C]('c)
val D = dataset.col[D]('d)

val toDecOpt = implicitly[ToDecimal[OutD]].truncate _

val framelessSumByAB = dataset
.cube(A, B)
.agg(sum(C), sum(D))
.collect()
.run()
.toVector
.map(row => row.copy(_4 = toDecOpt(row._4)))
.sortBy(x => (x._1, x._2))

val sparkSumByAB = dataset.dataset
Expand All @@ -295,12 +311,14 @@ class CubeTests extends TypedDatasetSuite {
Option(row.getAs[A](0)),
Option(row.getAs[B](1)),
row.getAs[OutC](2),
row.getAs[OutD](3)
toDecOpt(row.getAs[OutD](3))
)
)
.sortBy(x => (x._1, x._2))

framelessSumByAB ?= sparkSumByAB
tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)(
Seq(l => l._4)
)
}

check(forAll(prop[Byte, Int, Long, Double, Long, Double] _))
Expand Down Expand Up @@ -470,11 +488,19 @@ class CubeTests extends TypedDatasetSuite {
)
.sortBy(t => (t._2, t._1, t._3))

(framelessSumC ?= sparkSumC) &&
(framelessSumCC ?= sparkSumCC) &&
(framelessSumCCC ?= sparkSumCCC) &&
(framelessSumCCCC ?= sparkSumCCCC) &&
(framelessSumCCCCC ?= sparkSumCCCCC)
(tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) &&
(tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)(
Seq(l => l._3, l => l._4)
)) &&
(tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5)
)) &&
(tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5, l => l._6)
)) &&
(tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7)
))
}

check(forAll(prop[String, Long, Double, Double] _))
Expand Down
44 changes: 34 additions & 10 deletions dataset/src/test/scala/frameless/ops/RollupTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package frameless
package ops

import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors }
import frameless.functions.ToDecimal
import frameless.functions.aggregate._
import org.scalacheck.Prop
Expand Down Expand Up @@ -239,10 +240,23 @@ class RollupTests extends TypedDatasetSuite {
)
.sortBy(identity)

(framelessSumBC ?= sparkSumBC)
.&&(framelessSumBCB ?= sparkSumBCB)
.&&(framelessSumBCBC ?= sparkSumBCBC)
.&&(framelessSumBCBCB ?= sparkSumBCBCB)
(tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3)))
.&&(
tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)(
Seq(l => l._3)
)
)
.&&(
tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)(
Seq(l => l._3, l => l._5)
)
)
.&&(
tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)(
Seq(l => l._3, l => l._5)
)
)

}

check(forAll(prop[String, Long, Double, Long, Double] _))
Expand Down Expand Up @@ -293,7 +307,9 @@ class RollupTests extends TypedDatasetSuite {
)
.sortBy(t => (t._2, t._1, t._3, t._4))

framelessSumByAB ?= sparkSumByAB
tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)(
Seq(l => l._4)
)
}

check(forAll(prop[Byte, Int, Long, Double, Long, Double] _))
Expand Down Expand Up @@ -462,11 +478,19 @@ class RollupTests extends TypedDatasetSuite {
)
.sortBy(t => (t._2, t._1, t._3))

(framelessSumC ?= sparkSumC) &&
(framelessSumCC ?= sparkSumCC) &&
(framelessSumCCC ?= sparkSumCCC) &&
(framelessSumCCCC ?= sparkSumCCCC) &&
(framelessSumCCCCC ?= sparkSumCCCCC)
(tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) &&
(tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)(
Seq(l => l._3, l => l._4)
)) &&
(tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5)
)) &&
(tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5, l => l._6)
)) &&
(tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)(
Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7)
))
}

check(forAll(prop[String, Long, Double, Double] _))
Expand Down

0 comments on commit a89542e

Please sign in to comment.