diff --git a/lib/trino-array/TrueraTrinoPlugin.java b/lib/trino-array/TrueraTrinoPlugin.java
new file mode 100644
index 000000000000..372ff013b4c0
--- /dev/null
+++ b/lib/trino-array/TrueraTrinoPlugin.java
@@ -0,0 +1,2 @@
+public class TrueraTrinoPlugin { // NOTE(review): empty default-package stub at the lib/trino-array module root, while a same-named plugin class is added under plugin/trino-truera — presumably an accidental leftover; confirm before merging
+}
diff --git a/plugin/trino-truera/README.md b/plugin/trino-truera/README.md
new file mode 100644
index 000000000000..20a746adafcd
--- /dev/null
+++ b/plugin/trino-truera/README.md
@@ -0,0 +1,19 @@
+# TruEra Trino Extensions
+This module is a custom Trino plugin for TruEra. Currently it just contains one new function (ROC-AUC).
+
+## How do I add to the plugin?
+To get started, read the Trino [developer guide](https://trino.io/docs/current/develop/spi-overview.html).
+
+## How do I test the package?
+1. Compile using `mvn clean install -Dair.check.skip-all=true -DskipTests`.
+2. Look in the `target` folder, which should now contain a packaged ZIP named after the build version, e.g. `trino-truera-406.zip`.
+3. Unzip the package into the Trino plugin directory: `unzip trino-truera-406.zip -d ~/external_dependencies/trino-server/plugin`
+4. Restart Trino (`./service.sh stop trino && ./service.sh start trino`)
+
+You can test the "roc_auc" function with a command like:
+```sql
+SELECT roc_auc(__truera_prediction__, CAST(__truera_label__ as boolean))
+FROM "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_predictions_probits_score" as pred
+JOIN "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_label"
+AS label ON pred.__truera_id__ = label.__truera_id__;
+```
diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml
new file mode 100644
index 000000000000..597de120b385
--- /dev/null
+++ b/plugin/trino-truera/pom.xml
@@ -0,0 +1,106 @@
+
+
+ 4.0.0
+
+ trino-root
+ io.trino
+ 406
+ ../../pom.xml
+
+
+ trino-truera
+ Trino Truera Extensions
+ trino-plugin
+
+
+ ${project.parent.basedir}
+
+
+
+
+ io.trino
+ trino-array
+
+
+ io.trino
+ trino-spi
+ provided
+
+
+ io.airlift
+ slice
+ provided
+
+
+ io.airlift
+ log
+
+
+ org.openjdk.jol
+ jol-core
+ provided
+
+
+ org.testng
+ testng
+ test
+
+
+ io.trino
+ trino-main
+ test
+
+
+ io.trino
+ trino-main
+ test-jar
+ test
+
+
+ io.trino
+ trino-main
+ test
+
+
+ io.trino
+ trino-testing
+ test
+
+
+ io.trino
+ trino-hive
+ test
+
+
+ io.trino
+ trino-iceberg
+ test
+
+
+ io.trino
+ trino-hive-hadoop2
+ test
+
+
+ io.trino
+ trino-main
+
+
+
+
+
+
+
+ com.mycila
+ license-maven-plugin
+
+
+ src/main/**
+
+
+
+
+
+
+
+
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java
new file mode 100644
index 000000000000..5ffbe9ec13a0
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java
@@ -0,0 +1,79 @@
+package io.trino.plugin.truera;
+
+import java.util.Comparator;
+import java.util.stream.IntStream;
+
+/**
+ * Computes the area under the ROC curve (ROC AUC) for boolean labels and real-valued scores.
+ */
+public final class AreaUnderRocCurveAlgorithm {
+    private AreaUnderRocCurveAlgorithm() {
+    }
+
+    /**
+     * Returns the ROC AUC of {@code scores} against {@code labels}.
+     *
+     * <p>Samples are ranked by descending score, samples sharing a score are consumed as one
+     * threshold step, and the curve is integrated with the trapezoid rule over raw
+     * false-positive / true-positive counts before being divided by their product. NaN scores
+     * tie with each other and rank ahead of every other score.
+     *
+     * @param labels ground-truth classes, {@code true} for the positive class
+     * @param scores predicted scores, index-aligned with {@code labels}
+     * @return the AUC, or {@code NaN} when the labels contain only one class or there are no samples
+     * @throws IllegalArgumentException if the two arrays differ in length
+     */
+    public static double computeRocAuc(boolean[] labels, double[] scores) {
+        if (labels.length != scores.length) {
+            throw new IllegalArgumentException(
+                    "labels and scores must have the same length: " + labels.length + " != " + scores.length);
+        }
+
+        // Double.compare (used by comparingDouble) gives NaN a defined place in the ordering
+        int[] sortedIndices = IntStream.range(0, scores.length)
+                .boxed()
+                .sorted(Comparator.comparingDouble((Integer index) -> scores[index]).reversed())
+                .mapToInt(Integer::intValue)
+                .toArray();
+
+        // long counters: with int the product in the final division overflows on large inputs
+        long truePositives = 0;
+        long falsePositives = 0;
+        double area = 0.;
+
+        int i = 0;
+        while (i < sortedIndices.length) {
+            long previousTruePositives = truePositives;
+            long previousFalsePositives = falsePositives;
+            double threshold = scores[sortedIndices[i]];
+            // do/while always consumes the sample that opened the step, so a NaN threshold
+            // (never == to itself) cannot stall the scan
+            do {
+                if (labels[sortedIndices[i]]) {
+                    truePositives++;
+                } else {
+                    falsePositives++;
+                }
+                ++i;
+            } while (i < sortedIndices.length && isTied(threshold, scores[sortedIndices[i]]));
+            area += trapezoidIntegrate(previousFalsePositives, falsePositives, previousTruePositives, truePositives);
+        }
+
+        // If labels only contain one class, AUC is undefined
+        if (truePositives == 0 || falsePositives == 0) {
+            return Double.NaN;
+        }
+
+        return area / ((double) truePositives * falsePositives);
+    }
+
+    /** Scores tie when numerically equal (so 0.0 ties with -0.0) or when both are NaN. */
+    private static boolean isTied(double a, double b) {
+        return a == b || (Double.isNaN(a) && Double.isNaN(b));
+    }
+
+    /** Area of the trapezoid spanning x1..x2 whose parallel sides have lengths y1 and y2. */
+    private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) {
+        return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java
new file mode 100644
index 000000000000..8bb99034098d
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.truera;
+
+import io.trino.plugin.truera.aggregation.ROCAUCAggregation;
+import io.trino.spi.Plugin;
+
+import java.util.Collections;
+import java.util.Set;
+
+/**
+ * Trino plugin entry point for the TruEra extensions; registers the {@code roc_auc} aggregation.
+ */
+public class TrueraTrinoPlugin
+        implements Plugin
+{
+    /** Returns the function classes contributed by this plugin. */
+    @Override
+    public Set<Class<?>> getFunctions()
+    {
+        return Collections.singleton(ROCAUCAggregation.class);
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java
new file mode 100644
index 000000000000..66e8761f8d4e
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java
@@ -0,0 +1,5 @@
+package io.trino.plugin.truera.aggregation;
+
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.BlockBuilderStatus;
+import io.trino.spi.block.PageBuilderStatus;
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java
new file mode 100644
index 000000000000..93572e544854
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java
@@ -0,0 +1,175 @@
+package io.trino.plugin.truera.aggregation;
+
+import io.trino.array.BooleanBigArray;
+import io.trino.array.DoubleBigArray;
+import io.trino.array.IntBigArray;
+import io.trino.array.LongBigArray;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import org.openjdk.jol.info.ClassLayout;
+
+import static io.trino.plugin.truera.AreaUnderRocCurveAlgorithm.computeRocAuc;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Stores the (label, score) pairs of many aggregation groups in shared big arrays. The pairs of
+ * one group form a singly linked chain; every operation acts on the group chosen via
+ * {@link #setGroupId(long)}.
+ */
+public class GroupedRocAucCurve {
+    private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedRocAucCurve.class).instanceSize();
+    private static final int NULL = -1;
+
+    // one entry per group
+    // each entry is the index of the group's most recently added pair in the labels/scores/nextLinks arrays
+    private final LongBigArray headIndices;
+
+    // one entry per double/boolean pair
+    private final BooleanBigArray labels;
+    private final DoubleBigArray scores;
+
+    // the value in nextLinks contains the index of the next value in the chain
+    // a value NULL (-1) indicates it is the last value for the group
+    private final IntBigArray nextLinks;
+
+    // the index of the next free element in the labels/scores/nextLinks arrays
+    private int nextFreeIndex;
+
+    private long currentGroupId = -1;
+
+    /** Creates an empty curve store; a group must be selected before pairs are added. */
+    public GroupedRocAucCurve() {
+        this.headIndices = new LongBigArray(NULL);
+        this.labels = new BooleanBigArray();
+        this.scores = new DoubleBigArray();
+        this.nextLinks = new IntBigArray(NULL);
+        this.nextFreeIndex = 0;
+    }
+
+    /**
+     * Rebuilds a curve from its serialized form.
+     *
+     * @param groupId group that receives the pairs
+     * @param serialized block whose positions are (label, score) rows, as written by {@link #serialize}
+     */
+    public GroupedRocAucCurve(long groupId, Block serialized) {
+        this();
+        this.currentGroupId = groupId;
+
+        requireNonNull(serialized, "serialized block is null");
+        for (int i = 0; i < serialized.getPositionCount(); i++) {
+            Block entryBlock = serialized.getObject(i, Block.class);
+            // field 0 of the row is the label, field 1 the score
+            add(entryBlock, entryBlock, 0, 1);
+        }
+    }
+
+    /**
+     * Writes the current group's pairs as one array entry of (label, score) rows, or null when the
+     * group has no pairs. This is the intermediate-state form read back by
+     * {@link #GroupedRocAucCurve(long, Block)}; use {@link #writeAuc} for the final result.
+     */
+    public void serialize(BlockBuilder out) {
+        if (isCurrentGroupEmpty()) {
+            out.appendNull();
+            return;
+        }
+
+        BlockBuilder arrayBuilder = out.beginBlockEntry();
+        for (int index = (int) headIndices.get(currentGroupId); index != NULL; index = nextLinks.get(index)) {
+            BlockBuilder rowBuilder = arrayBuilder.beginBlockEntry();
+            BooleanType.BOOLEAN.writeBoolean(rowBuilder, labels.get(index));
+            DoubleType.DOUBLE.writeDouble(rowBuilder, scores.get(index));
+            arrayBuilder.closeEntry();
+        }
+        out.closeEntry();
+    }
+
+    /**
+     * Writes the ROC AUC of the current group as a DOUBLE, or null when the group has no pairs or
+     * the AUC is undefined (NaN).
+     */
+    public void writeAuc(BlockBuilder out) {
+        int pairCount = 0;
+        for (int index = (int) headIndices.get(currentGroupId); index != NULL; index = nextLinks.get(index)) {
+            pairCount++;
+        }
+        if (pairCount == 0) {
+            out.appendNull();
+            return;
+        }
+
+        // copy the chain into primitive arrays for the algorithm
+        boolean[] groupLabels = new boolean[pairCount];
+        double[] groupScores = new double[pairCount];
+        int position = 0;
+        for (int index = (int) headIndices.get(currentGroupId); index != NULL; index = nextLinks.get(index)) {
+            groupLabels[position] = labels.get(index);
+            groupScores[position] = scores.get(index);
+            position++;
+        }
+
+        double auc = computeRocAuc(groupLabels, groupScores);
+        if (Double.isNaN(auc)) {
+            out.appendNull();
+        } else {
+            DoubleType.DOUBLE.writeDouble(out, auc);
+        }
+    }
+
+    public long estimatedInMemorySize() {
+        return INSTANCE_SIZE + labels.sizeOf() + scores.sizeOf() + nextLinks.sizeOf() + headIndices.sizeOf();
+    }
+
+    public GroupedRocAucCurve setGroupId(long groupId) {
+        this.currentGroupId = groupId;
+        return this;
+    }
+
+    public long getGroupId() {
+        return this.currentGroupId;
+    }
+
+    /** Adds the pair read from the given block positions to the current group. */
+    public void add(Block labelsBlock, Block scoresBlock, int labelPosition, int scorePosition) {
+        add(BooleanType.BOOLEAN.getBoolean(labelsBlock, labelPosition), DoubleType.DOUBLE.getDouble(scoresBlock, scorePosition));
+    }
+
+    /** Adds one pair to the current group by pushing it onto the front of the group's chain. */
+    public void add(boolean label, double score) {
+        ensureCapacity(currentGroupId + 1);
+
+        labels.set(nextFreeIndex, label);
+        scores.set(nextFreeIndex, score);
+        nextLinks.set(nextFreeIndex, (int) headIndices.get(currentGroupId));
+        // publish the new pair as the head; without this the group would always look empty
+        headIndices.set(currentGroupId, nextFreeIndex);
+        nextFreeIndex++;
+    }
+
+    public void ensureCapacity(long numberOfGroups) {
+        headIndices.ensureCapacity(numberOfGroups);
+        int numberOfValues = nextFreeIndex + 1;
+        labels.ensureCapacity(numberOfValues);
+        scores.ensureCapacity(numberOfValues);
+        nextLinks.ensureCapacity(numberOfValues);
+    }
+
+    /** Copies the pairs of {@code other}'s current group into this curve's current group. */
+    public void addAll(GroupedRocAucCurve other) {
+        other.readAll(this);
+    }
+
+    /** Appends every pair of this curve's current group to {@code to}'s current group. */
+    public void readAll(GroupedRocAucCurve to) {
+        for (int index = (int) headIndices.get(currentGroupId); index != NULL; index = nextLinks.get(index)) {
+            to.add(labels.get(index), scores.get(index));
+        }
+    }
+
+    public boolean isCurrentGroupEmpty() {
+        return headIndices.get(currentGroupId) == NULL;
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java
new file mode 100644
index 000000000000..781821eed918
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.truera.aggregation;
+
+import io.trino.plugin.truera.state.AreaUnderRocCurveState;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.function.AggregationFunction;
+import io.trino.spi.function.CombineFunction;
+import io.trino.spi.function.InputFunction;
+import io.trino.spi.function.OutputFunction;
+import io.trino.spi.function.SqlType;
+import io.trino.spi.type.StandardTypes;
+
+/**
+ * SQL aggregate {@code roc_auc(score DOUBLE, label BOOLEAN)} returning the area under the ROC curve.
+ */
+@AggregationFunction("roc_auc")
+public class ROCAUCAggregation {
+    /** Records one (score, label) observation in the current group. */
+    @InputFunction
+    public static void input(AreaUnderRocCurveState state, @SqlType(StandardTypes.DOUBLE) double score, @SqlType(StandardTypes.BOOLEAN) boolean label) {
+        GroupedRocAucCurve auc = state.get();
+        long startSize = auc.estimatedInMemorySize();
+        auc.add(label, score);
+        state.addMemoryUsage(auc.estimatedInMemorySize() - startSize);
+    }
+
+    /**
+     * Merges {@code otherState} into {@code state}. Pairs are always copied rather than adopting
+     * the other curve object, which would replace the data of every group held by {@code state}.
+     */
+    @CombineFunction
+    public static void combine(AreaUnderRocCurveState state, AreaUnderRocCurveState otherState) {
+        GroupedRocAucCurve other = otherState.get();
+        if (other.isCurrentGroupEmpty()) {
+            return;
+        }
+
+        GroupedRocAucCurve auc = state.get();
+        long startSize = auc.estimatedInMemorySize();
+        auc.addAll(other);
+        state.addMemoryUsage(auc.estimatedInMemorySize() - startSize);
+    }
+
+    /** Emits the AUC of the current group, or null when it is undefined. */
+    @OutputFunction(StandardTypes.DOUBLE)
+    public static void output(AreaUnderRocCurveState state, BlockBuilder out) {
+        state.get().writeAuc(out);
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java
new file mode 100644
index 000000000000..8af8d0ebc8e7
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java
@@ -0,0 +1,7 @@
+package io.trino.plugin.truera.metrics;
+
+/**
+ * Empty placeholder type: it declares no members.
+ */
+public class SingleAUCROC {
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java
new file mode 100644
index 000000000000..8a28a48b4831
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java
@@ -0,0 +1,30 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.plugin.truera.aggregation.GroupedRocAucCurve;
+import io.trino.spi.block.Block;
+import io.trino.spi.function.AccumulatorState;
+import io.trino.spi.function.AccumulatorStateMetadata;
+
+/**
+ * Accumulator state of the ROC AUC aggregation: a holder for a {@link GroupedRocAucCurve}.
+ * Instances come from {@link AreaUnderRocCurveStateFactory} and are serialized by
+ * {@link AreaUnderRocCurveStateSerializer}, as wired up by the annotation below.
+ */
+@AccumulatorStateMetadata(
+        stateSerializerClass = AreaUnderRocCurveStateSerializer.class,
+        stateFactoryClass = AreaUnderRocCurveStateFactory.class)
+public interface AreaUnderRocCurveState
+        extends AccumulatorState
+{
+    /** Loads the state from a block handed over by the state serializer. */
+    void deserialize(Block serialized);
+
+    /** Returns the curve held by this state. */
+    GroupedRocAucCurve get();
+
+    /** Hands a curve to this state. */
+    void set(GroupedRocAucCurve value);
+
+    /** Reports a change in the curve's estimated size; callers pass the difference of two size readings. */
+    void addMemoryUsage(long memory);
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java
new file mode 100644
index 000000000000..38fa9837cc64
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java
@@ -0,0 +1,135 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.plugin.truera.aggregation.GroupedRocAucCurve;
+import io.trino.spi.block.Block;
+import io.trino.spi.function.AccumulatorStateFactory;
+import io.trino.spi.function.GroupedAccumulatorState;
+import org.openjdk.jol.info.ClassLayout;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Creates the single and grouped accumulator states backing the ROC AUC aggregation.
+ */
+public class AreaUnderRocCurveStateFactory implements AccumulatorStateFactory<AreaUnderRocCurveState>
+{
+    @Override
+    public AreaUnderRocCurveState createSingleState() {
+        return new SingleState();
+    }
+
+    @Override
+    public AreaUnderRocCurveState createGroupedState() {
+        return new GroupedState();
+    }
+
+    /** State shared by all groups of a grouped aggregation; the group is selected via {@link #setGroupId}. */
+    public static class GroupedState implements AreaUnderRocCurveState, GroupedAccumulatorState {
+        private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize();
+        private final GroupedRocAucCurve auc = new GroupedRocAucCurve();
+
+        @Override
+        public void setGroupId(long groupId)
+        {
+            auc.setGroupId(groupId);
+        }
+
+        @Override
+        public void ensureCapacity(long size) {
+            auc.ensureCapacity(size);
+        }
+
+        @Override
+        public GroupedRocAucCurve get() {
+            return auc;
+        }
+
+        /**
+         * Copies the pairs of {@code value}'s current group into this state's current group.
+         * The curve object itself is never swapped: it holds the pairs of every group, so adopting
+         * {@code value} would discard all other groups. NOTE(review): this appends, so it only
+         * behaves as a replacement while the current group is empty; callers should invoke it
+         * only in that situation.
+         */
+        @Override
+        public void set(GroupedRocAucCurve value)
+        {
+            requireNonNull(value, "value is null");
+            if (value != auc) {
+                auc.addAll(value);
+            }
+        }
+
+        /** No-op: {@link #getEstimatedSize()} measures the curve directly, so deltas must not be added again. */
+        @Override
+        public void addMemoryUsage(long memory)
+        {
+        }
+
+        /** Loads the serialized (label, score) rows into the current group, leaving other groups untouched. */
+        @Override
+        public void deserialize(Block serialized)
+        {
+            requireNonNull(serialized, "serialized block is null");
+            for (int i = 0; i < serialized.getPositionCount(); i++) {
+                Block entryBlock = serialized.getObject(i, Block.class);
+                auc.add(entryBlock, entryBlock, 0, 1);
+            }
+        }
+
+        @Override
+        public long getEstimatedSize()
+        {
+            return INSTANCE_SIZE + auc.estimatedInMemorySize();
+        }
+    }
+
+    /** State of an ungrouped aggregation; reuses the grouped curve with the fixed group id 0. */
+    public static class SingleState
+            implements AreaUnderRocCurveState
+    {
+        private static final long INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize();
+        private GroupedRocAucCurve auc;
+
+        public SingleState()
+        {
+            auc = new GroupedRocAucCurve();
+
+            // set synthetic, unique group id to use GroupAreaUnderRocCurve from the single state
+            auc.setGroupId(0);
+        }
+
+        @Override
+        public GroupedRocAucCurve get()
+        {
+            return auc;
+        }
+
+        @Override
+        public void deserialize(Block serialized)
+        {
+            this.auc = new GroupedRocAucCurve(0, serialized);
+        }
+
+        @Override
+        public void set(GroupedRocAucCurve value)
+        {
+            auc = value;
+        }
+
+        @Override
+        public void addMemoryUsage(long memory)
+        {
+        }
+
+        @Override
+        public long getEstimatedSize()
+        {
+            long estimatedSize = INSTANCE_SIZE;
+            if (auc != null) {
+                estimatedSize += auc.estimatedInMemorySize();
+            }
+            return estimatedSize;
+        }
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java
new file mode 100644
index 000000000000..5b11d0e0e6d7
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java
@@ -0,0 +1,37 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.function.AccumulatorStateSerializer;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import io.trino.spi.type.Type;
+
+import static io.trino.spi.type.RowType.anonymousRow;
+
+/**
+ * Serializer for {@link AreaUnderRocCurveState}; the intermediate representation is declared as
+ * {@code array(row(boolean, double))}, and the actual encoding is delegated to the state's curve.
+ */
+public class AreaUnderRocCurveStateSerializer implements AccumulatorStateSerializer<AreaUnderRocCurveState>
+{
+    static final ArrayType SERIALIZED_TYPE = new ArrayType(anonymousRow(BooleanType.BOOLEAN, DoubleType.DOUBLE));
+
+    @Override
+    public Type getSerializedType() {
+        return SERIALIZED_TYPE;
+    }
+
+    /** Appends the state's curve to {@code out} via {@code GroupedRocAucCurve.serialize}. */
+    @Override
+    public void serialize(AreaUnderRocCurveState state, BlockBuilder out) {
+        state.get().serialize(out);
+    }
+
+    /** Reads the array entry at {@code index} and hands it to the state to rebuild its curve. */
+    @Override
+    public void deserialize(Block block, int index, AreaUnderRocCurveState state) {
+        state.deserialize(SERIALIZED_TYPE.getObject(block, index));
+    }
+}
diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java
new file mode 100644
index 000000000000..cc1b5c756495
--- /dev/null
+++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java
@@ -0,0 +1,39 @@
+package io.trino.plugin.truera;
+
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+/** Unit tests for {@link AreaUnderRocCurveAlgorithm#computeRocAuc}. */
+public class TestAreaUnderRocCurveAlgorithm {
+    /** Labels of a single class have no defined AUC, so NaN is expected. */
+    @Test
+    public void testComputeAucRocConstantYs() {
+        boolean[] ys = new boolean[3];
+        double[] ysPred = new double[]{-1, 0, 1};
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), Double.NaN);
+    }
+
+    /** When every score is identical, the AUC is 0.5 wherever the single positive sits. */
+    @Test
+    public void testComputeAucRocConstantYsPred() {
+        // Check if first element is only true.
+        boolean[] ys = new boolean[10];
+        ys[0] = true;
+        double[] ysPred = new double[10];
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5);
+        // Check if last element is only true.
+        ys = new boolean[10];
+        ys[ys.length - 1] = true;
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5);
+    }
+
+    /** Regression check of a fixed sample against a fixed expected AUC. */
+    @Test
+    public void testComputeAucRocRandomCase() {
+        boolean[] ys = {false, false, false, true, false, false, true, true, false, true, true, true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, false, true, false, false, true, true, false, true, false, false, false, false, false, true, false, false, false, true, false, true, false, true, true, false, false, false, true, true, true, false, true, false, false, true, true, false, true, false, true, true, false, false, true, false, false, true, true, true, false, false, true, true, false, true, false, true, false, true, true, true, false, true, true, false, true, true, true, true, false, false, false, true, true, false};
+        double[] ysPred = {-0.2286298738575081, -0.26869038039536985, -0.312319527551126, 0.2734721402453152, -0.24223018509401162, 0.31503992286551685, 0.4046447140511742, 0.28792712595231496, 0.23384628640264638, 0.4428378129948476, 0.4543878651258779, -0.01374360133220165, 0.0007349034879994276, 0.09636778467067175, 0.13655154956520688, 0.17454951949942543, 0.4931408863331128, 0.23547728150236757, 0.37226953309841515, -0.15728692842143432, 0.15924354023313247, 0.4286270233907359, -0.4685474933247794, -0.25577981860345134, 0.4681668166617913, 0.11320750870358587, -0.2736801671178942, -0.3864589476069452, -0.399822172488029, 0.015685730721504587, 0.3321140926944811, 0.2530935731647217, -0.3959046404148484, -0.15179547616306643, -0.36416614908162126, -0.29774936545934994, -0.12315200126789438, -0.2644045710260078, -0.44906116392418005, -0.020580493381027964, -0.12273927226482872, -0.056684762439960235, -0.08061360953066143, -0.29603641015837856, -0.18856132377074108, 0.4235267740783706, -0.4606040899799605, 0.42625848143208045, -0.0011082145951951672, -0.10729132271105546, 0.05327228838675735, 0.23256669296146493, 0.44780010851745966, 0.3568537456277838, 0.1805343926823827, -0.21846170998705394, -0.18590956046362084, 0.2509080583778559, -0.22290770310271601, 0.09721402039749849, 0.12656860105942058, -0.4855359797468307, -0.39689924561298107, 0.06657447973146235, 0.2578990291023152, 0.4742641426354459, -0.18982849914795363, 0.2227804823488645, 0.35372747647450187, -0.4543332887713202, -0.2564141959803503, -0.3998731379722754, 0.4302966112027795, 0.20920354786320405, -0.3493332019987293, 0.22594844508851064, 0.3506301294801156, 0.49704339475860027, -0.06293993755501126, 0.3918169521501762, -0.09311136809088782, 0.1600693305289207, -0.15765165599671038, 0.45406593535839646, 0.25178211933304273, -0.10253079227341522, -0.07396783373899074, -0.2779385207339454, 0.39166810505415206, -0.21097405310184147, 0.49742358104577056, 0.2503826842835034, 0.06827961053326548, -0.14270071261942086, 0.2718756995005246, -0.40524883036019543, 0.017045033508361507, 0.38197771242907064, 0.26845558760844856, -0.0965167478978054};
+        double auc = AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred);
+        assertEquals(auc, 0.8187274909963985);
+    }
+}
diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java
new file mode 100644
index 000000000000..95aece0bcab6
--- /dev/null
+++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java
@@ -0,0 +1,99 @@
+package io.trino.plugin.truera;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.Session;
+import io.trino.plugin.hive.HivePlugin;
+import io.trino.plugin.iceberg.IcebergPlugin;
+import io.trino.spi.security.Identity;
+import io.trino.spi.security.SelectedRole;
+import io.trino.testing.AbstractTestQueryFramework;
+import io.trino.testing.DistributedQueryRunner;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.File;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static io.trino.spi.security.SelectedRole.Type.ROLE;
+import static io.trino.testing.TestingSession.testSessionBuilder;
+
+/**
+ * End-to-end test of the {@code roc_auc} aggregation on a distributed query runner with Iceberg
+ * and Hive catalogs that share one file-based metastore directory.
+ */
+@Test(singleThreaded = true)
+public class TestAucRocAggregation extends AbstractTestQueryFramework {
+    public static final String ICEBERG_CATALOG = "iceberg";
+
+    private final ImmutableMap.Builder<String, String> icebergProperties = ImmutableMap.builder();
+    private final Optional<File> metastoreDirectory = Optional.empty();
+
+    @Override
+    protected DistributedQueryRunner createQueryRunner()
+            throws Exception
+    {
+        Session session = testSessionBuilder()
+                .setIdentity(Identity.forUser("hive")
+                        .withConnectorRole("hive", new SelectedRole(ROLE, Optional.of("admin")))
+                        .build())
+                .build();
+
+        DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build();
+
+        queryRunner.installPlugin(new IcebergPlugin());
+        Map<String, String> icebergProperties = new HashMap<>(this.icebergProperties.buildOrThrow());
+        String catalogType = icebergProperties.get("iceberg.catalog.type");
+        Path dataDir = metastoreDirectory.map(File::toPath).orElseGet(() -> queryRunner.getCoordinator().getBaseDataDir().resolve("iceberg_data"));
+        if (catalogType == null) {
+            icebergProperties.put("iceberg.catalog.type", "TESTING_FILE_METASTORE");
+            icebergProperties.put("hive.metastore.catalog.dir", dataDir.toString());
+        }
+
+        queryRunner.createCatalog(ICEBERG_CATALOG, "iceberg", icebergProperties);
+        queryRunner.installPlugin(new HivePlugin());
+        queryRunner.createCatalog("hive", "hive", ImmutableMap.<String, String>builder()
+                .put("hive.metastore", "file")
+                .put("hive.metastore.catalog.dir", dataDir.toString())
+                .put("hive.security", "sql-standard")
+                .buildOrThrow());
+        queryRunner.installPlugin(new TrueraTrinoPlugin());
+
+        return queryRunner;
+    }
+
+    @BeforeClass
+    public void setUp()
+    {
+        assertQuerySucceeds("CREATE SCHEMA hive.test_schema");
+        assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_probits_table (__id__ VARCHAR, __preds__ double)");
+        assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_labels_table (__id__ VARCHAR, __label__ int)");
+        assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_probits_table VALUES ('A', 0.5), ('B', 0.2), ('C', 0.9)");
+        assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_labels_table VALUES ('A', 1), ('B', 0), ('C', 1)");
+    }
+
+    @AfterClass(alwaysRun = true)
+    public void tearDown()
+    {
+        assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_labels_table");
+        assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_probits_table");
+        assertQuerySucceeds("DROP SCHEMA IF EXISTS hive.test_schema");
+    }
+
+    /**
+     * Both positives (scores 0.5 and 0.9) outrank the only negative (score 0.2), so the ranking
+     * is perfect and the AUC must be exactly 1.0.
+     */
+    @Test
+    public void testAucRoc()
+    {
+        assertQuery("SELECT roc_auc(__preds__, CAST(__label__ as boolean)) " +
+                        "FROM iceberg.test_schema.iceberg_probits_table as pred " +
+                        "JOIN iceberg.test_schema.iceberg_labels_table AS label " +
+                        "ON pred.__id__ = label.__id__",
+                "VALUES (CAST(1.0 AS DOUBLE))");
+    }
+}
diff --git a/pom.xml b/pom.xml
index 9864b6b6b4f2..147f5764eed2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -182,6 +182,7 @@
plugin/trino-thrift-testing-server
plugin/trino-tpcds
plugin/trino-tpch
+ plugin/trino-truera
service/trino-proxy
service/trino-verifier
testing/trino-benchmark
@@ -739,6 +740,12 @@
${project.version}
+
+ io.trino
+ trino-truera
+ ${project.version}
+
+
io.trino.benchto