Skip to content

Commit

Permalink
typelevel#787 - attempt covar_pop and kurtosis through tolerances
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Apr 11, 2024
1 parent 66b31e9 commit 80de4f2
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,10 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
TypedColumn[X3[Int, A, B], A],
TypedColumn[X3[Int, A, B], B]
) => TypedAggregate[X3[Int, A, B], Option[Double]],
sparkFun: (Column, Column) => Column
sparkFun: (Column, Column) => Column,
fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
BigDecimal
], Option[BigDecimal]] = identity
)(implicit
encEv: Encoder[(Int, A, B)],
encEv2: Encoder[(Int, Option[Double])],
Expand Down Expand Up @@ -496,7 +499,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
})

// Should be the same
tdBivar.toMap ?= compBivar.collect().toMap
// tdBivar.toMap ?= compBivar.collect().toMap
DoubleBehaviourUtils.compareMaps(
tdBivar.toMap,
compBivar.collect().toMap,
fudger
)
}

def univariatePropTemplate[A: TypedEncoder](
Expand All @@ -505,7 +513,10 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
X2[Int, A],
Option[Double]
],
sparkFun: (Column) => Column
sparkFun: (Column) => Column,
fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
BigDecimal
], Option[BigDecimal]] = identity
)(implicit
encEv: Encoder[(Int, A)],
encEv2: Encoder[(Int, Option[Double])],
Expand Down Expand Up @@ -534,7 +545,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
})

// Should be the same
tdUnivar.toMap ?= compUnivar.collect().toMap
// tdUnivar.toMap ?= compUnivar.collect().toMap
DoubleBehaviourUtils.compareMaps(
tdUnivar.toMap,
compUnivar.collect().toMap,
fudger
)
}

test("corr") {
Expand Down Expand Up @@ -571,7 +587,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
evCanBeDoubleB: CatalystCast[B, Double]
): Prop = bivariatePropTemplate(xs)(
covarPop[A, B, X3[Int, A, B]],
org.apache.spark.sql.functions.covar_pop
org.apache.spark.sql.functions.covar_pop,
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("100"))
)

check(forAll(prop[Double, Double] _))
Expand Down Expand Up @@ -614,7 +631,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
evCanBeDoubleA: CatalystCast[A, Double]
): Prop = univariatePropTemplate(xs)(
kurtosis[A, X2[Int, A]],
org.apache.spark.sql.functions.kurtosis
org.apache.spark.sql.functions.kurtosis,
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("0.1"))
)

check(forAll(prop[Double] _))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package frameless
package functions

import org.scalacheck.Prop
import org.scalacheck.util.Pretty

/**
* Some statistical functions in Spark can result in Double, Double.NaN or Null.
* This tends to break ?= of the property based testing. Use the nanNullHandler function
Expand Down Expand Up @@ -37,6 +40,63 @@ object DoubleBehaviourUtils {
else
BigDecimal.RoundingMode.CEILING
)

/**
 * Compares two maps entry by entry, passing every pair of values through
 * `fudger` before testing them for equality.
 *
 * The property is falsified (with a descriptive label) when the maps differ
 * in size, when a key of `m1` is absent from `m2`, or when the fudged values
 * for a key are not equal.
 *
 * @param m1     left-hand map; its entries drive the comparison
 * @param m2     right-hand map, looked up by the keys of `m1`
 * @param fudger adjustment applied to each (left, right) value pair before
 *               the equality check, e.g. to absorb small numeric differences
 * @return a Prop that holds only when every entry of `m1` matches `m2`
 */
def compareMaps[K](
    m1: Map[K, Option[BigDecimal]],
    m2: Map[K, Option[BigDecimal]],
    fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
      BigDecimal
    ], Option[BigDecimal]]
  ): Prop = {
  val prettyParams = Pretty.Params(0)

  // Renders a key for use in failure labels.
  def showKey(key: K): String = Pretty.pretty[K](key, prettyParams)

  // Renders a (possibly absent) value for use in failure labels.
  def showValue(value: Option[BigDecimal]): String =
    Pretty.pretty[Option[BigDecimal]](value, prettyParams)

  // Checks a single entry of m1 against whatever m2 holds for the same key.
  def checkEntry(key: K, left: Option[BigDecimal]): Prop =
    m2.get(key) match {
      case None =>
        Prop.falsified :| s"Expected key of ${showKey(key)} in right side map"

      case Some(right) =>
        val (fudgedLeft, fudgedRight) = fudger((left, right))
        if (fudgedLeft != fudgedRight)
          Prop.falsified :|
            s"For key of ${showKey(key)} expected ${showValue(fudgedLeft)} got ${showValue(fudgedRight)}"
        else
          Prop.proved
    }

  if (m1.size == m2.size)
    // Start from `passed` and conjoin the per-entry properties.
    m1.foldLeft(Prop.passed) {
      case (acc, (key, left)) => acc && checkEntry(key, left)
    }
  else
    Prop.falsified :| s"Expected map of size ${m1.size} but got ${m2.size}"
}

/**
 * Running covar_pop and kurtosis multiple times gives slightly different
 * results, so values that are "close enough" are treated as equal.
 *
 * When both sides of `p` are defined and their absolute difference is
 * strictly less than `of`, the left value is returned on both sides so that a
 * subsequent equality check succeeds. In every other case (either side
 * missing, or the difference at or above `of`) `p` is returned unchanged.
 *
 * @param p  the (left, right) pair of optional values to compare
 * @param of the exclusive upper bound on the tolerated absolute difference
 * @return either `p` itself or a pair holding the left value twice
 */
def tolerance(
    p: Tuple2[Option[BigDecimal], Option[BigDecimal]],
    of: BigDecimal
  ): Tuple2[Option[BigDecimal], Option[BigDecimal]] =
  p match {
    // Both present and within tolerance: collapse onto the left value.
    case (Some(left), Some(right))
        if (left.max(right) - left.min(right)).abs < of =>
      (Some(left), Some(left))

    // A side is missing, or the values are too far apart: leave untouched.
    case _ =>
      p
  }
}

/** Drop-in conversion for doubles to handle serialization on a cluster */
Expand Down

0 comments on commit 80de4f2

Please sign in to comment.