Skip to content

Commit

Permalink
add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
reetika-roy committed Feb 13, 2023
1 parent 85ecbed commit 1a77659
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 31 deletions.
5 changes: 5 additions & 0 deletions plugin/trino-truera/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
<artifactId>jol-core</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.trino.plugin.truera;

import java.util.Collections;
import java.util.Comparator;
import java.util.stream.IntStream;

public class ROCAUCFunction {
public static double computeRocAuc(boolean[] labels, double[] scores) {
int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted(
Comparator.comparing(i -> scores[i])
).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray();

int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0;
double auc = 0.;

for (int i : sortedIndices) {
if (labels[i]) { currTruePositives++; } else { currFalsePositives++; };
prevTruePositives = currTruePositives;
prevFalsePositives = currFalsePositives;
auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives);
}

// If labels only contain one class, AUC is undefined
if (currTruePositives == 0 || currFalsePositives == 0) {
return Double.POSITIVE_INFINITY;
}

return auc / (currTruePositives * currFalsePositives);
}

private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) {
return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package io.trino.plugin.truera.aggregation;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.IntStream;


import io.trino.array.BooleanBigArray;
import io.trino.array.IntBigArray;
Expand All @@ -17,6 +15,7 @@
import io.trino.spi.type.DoubleType;
import org.openjdk.jol.info.ClassLayout;

import static io.trino.plugin.truera.ROCAUCFunction.computeRocAuc;
import static java.util.Objects.requireNonNull;

public class GroupedRocAucCurve {
Expand Down Expand Up @@ -140,32 +139,4 @@ public void readAll(GroupedRocAucCurve to) {
public boolean isCurrentGroupEmpty() {
return headIndices.get(currentGroupId) == NULL;
}

private static double computeRocAuc(boolean[] labels, double[] scores) {
int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted(
Comparator.comparing(i -> scores[i])
).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray();

int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0;
double auc = 0.;

for (int i : sortedIndices) {
if (labels[i]) { currTruePositives++; } else { currFalsePositives++; };
prevTruePositives = currTruePositives;
prevFalsePositives = currFalsePositives;
auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives);
}

// If labels only contain one class, AUC is undefined
if (currTruePositives == 0 || currFalsePositives == 0) {
return Double.POSITIVE_INFINITY;
}

return auc / (currTruePositives * currFalsePositives);
}

private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) {
return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.trino.plugin.truera;

import java.util.Arrays;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
public class TestROCAUCFunction {
@Test
public void testComputeAucRocWhenUndefined() {

boolean[] testLabels = new boolean[10];
double[] testProbabilities = new double[10];
assertEquals(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities), Double.POSITIVE_INFINITY);
}

@Test
public void testComputeAucRoc() {

boolean[] testLabels = new boolean[10];
testLabels[9] = true;
double[] testProbabilities = new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617,
0.62190361, 0.34917899, 0.88604834, 0.09936481, 0.65903197};
// Arrays.fill(testProbabilities, 1.0);
for (boolean element: testLabels) {
System.out.println(element);
}
for (double element: testProbabilities) {
System.out.println(element);
}
System.out.println("score");
System.out.println(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities));
}

}

0 comments on commit 1a77659

Please sign in to comment.