diff --git a/lib/trino-array/TrueraTrinoPlugin.java b/lib/trino-array/TrueraTrinoPlugin.java
new file mode 100644
index 000000000000..372ff013b4c0
--- /dev/null
+++ b/lib/trino-array/TrueraTrinoPlugin.java
@@ -0,0 +1,2 @@
+public class TrueraTrinoPlugin {
+}
diff --git a/plugin/trino-truera/README.md b/plugin/trino-truera/README.md
new file mode 100644
index 000000000000..20a746adafcd
--- /dev/null
+++ b/plugin/trino-truera/README.md
@@ -0,0 +1,19 @@
+# TruEra Trino Extensions
+This module is a custom Trino plugin for TruEra. It currently contains a single function, `roc_auc` (ROC-AUC).
+
+## How do I add to the plugin?
+To get started, read the Trino [dev guide](https://trino.io/docs/current/develop/spi-overview.html#).
+
+## How do I test the package?
+1. Compile using `mvn clean install -Dair.check.skip-all=true -DskipTests`.
+2. Look at the `target` folder, which should now contain a packaged ZIP named after the build version, e.g. `trino-truera-406.zip`.
+3. Unzip the package into the Trino plugins directory: `unzip trino-truera-406.zip -d ~/external_dependencies/trino-server/plugin`
+4. 
Restart Trino (`./service.sh stop trino && ./service.sh start trino`)
+
+You can test the "roc_auc" function with a command like:
+```sql
+SELECT roc_auc(__truera_prediction__, CAST(__truera_label__ as boolean))
+FROM "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_predictions_probits_score" as pred
+JOIN "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_label"
+AS label ON pred.__truera_id__ = label.__truera_id__;
+```
diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml
new file mode 100644
index 000000000000..597de120b385
--- /dev/null
+++ b/plugin/trino-truera/pom.xml
@@ -0,0 +1,106 @@
+
+ 4.0.0
+
+ trino-root
+ io.trino
+ 406
+ ../../pom.xml
+
+
+ trino-truera
+ Trino Truera Extensions
+ trino-plugin
+
+
+ ${project.parent.basedir}
+
+
+
+
+ io.trino
+ trino-array
+
+
+ io.trino
+ trino-spi
+ provided
+
+
+ io.airlift
+ slice
+ provided
+
+
+ io.airlift
+ log
+
+
+ org.openjdk.jol
+ jol-core
+ provided
+
+
+ org.testng
+ testng
+ test
+
+
+ io.trino
+ trino-main
+ test
+
+
+ io.trino
+ trino-main
+ test-jar
+ test
+
+
+ io.trino
+ trino-main
+ test
+
+
+ io.trino
+ trino-testing
+ test
+
+
+ io.trino
+ trino-hive
+ test
+
+
+ io.trino
+ trino-iceberg
+ test
+
+
+ io.trino
+ trino-hive-hadoop2
+ test
+
+
+ io.trino
+ trino-main
+
+
+
+
+
+
+
+ com.mycila
+ license-maven-plugin
+
+
+ src/main/**
+
+
+
+
+
+
+
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java
new file mode 100644
index 000000000000..5ffbe9ec13a0
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java
@@ -0,0 +1,93 @@
+package io.trino.plugin.truera;
+
+import java.util.Comparator;
+import java.util.stream.IntStream;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Computes the area under the ROC curve (ROC-AUC) for binary labels and real-valued scores.
+ */
+public class AreaUnderRocCurveAlgorithm
+{
+    /**
+     * Computes ROC-AUC by sweeping the score threshold from high to low and integrating the
+     * (false positives, true positives) curve with the trapezoid rule.
+     *
+     * <p>Rows with equal scores are consumed as a single step, so ties contribute one diagonal
+     * segment regardless of input order. NaN scores sort above every other value and are treated
+     * as tied with each other.
+     *
+     * @param labels ground-truth classes, {@code true} for the positive class
+     * @param scores predicted scores, one per label; a higher score ranks a row as more positive
+     * @return the AUC in {@code [0, 1]}, or {@code NaN} when the labels contain only one class
+     *         (which includes empty input) and the AUC is therefore undefined
+     * @throws NullPointerException if either array is null
+     * @throws IllegalArgumentException if the arrays differ in length
+     */
+    public static double computeRocAuc(boolean[] labels, double[] scores)
+    {
+        requireNonNull(labels, "labels is null");
+        requireNonNull(scores, "scores is null");
+        if (labels.length != scores.length) {
+            throw new IllegalArgumentException(
+                    "labels and scores must have the same length, got " + labels.length + " and " + scores.length);
+        }
+
+        // Row indices ordered by descending score (Double.compare ordering, so NaN comes first).
+        int[] sortedIndices = IntStream.range(0, scores.length)
+                .boxed()
+                .sorted(Comparator.comparingDouble((Integer index) -> scores[index]).reversed())
+                .mapToInt(Integer::intValue)
+                .toArray();
+
+        // long counters: the final normalisation multiplies them, which overflows int for large groups
+        long truePositives = 0;
+        long falsePositives = 0;
+        double area = 0.0;
+
+        int i = 0;
+        while (i < sortedIndices.length) {
+            long previousTruePositives = truePositives;
+            long previousFalsePositives = falsePositives;
+            double threshold = scores[sortedIndices[i]];
+            // Consume every row tied with the current threshold before adding a curve segment.
+            while (i < sortedIndices.length && isTied(threshold, scores[sortedIndices[i]])) {
+                if (labels[sortedIndices[i]]) {
+                    truePositives++;
+                }
+                else {
+                    falsePositives++;
+                }
+                i++;
+            }
+            area += trapezoidIntegrate(previousFalsePositives, falsePositives, previousTruePositives, truePositives);
+        }
+
+        // If labels only contain one class, AUC is undefined
+        if (truePositives == 0 || falsePositives == 0) {
+            return Double.NaN;
+        }
+
+        // Normalise the raw area (in count units) to the unit square; multiply in double to avoid overflow.
+        return area / ((double) truePositives * falsePositives);
+    }
+
+    /**
+     * Returns whether two scores belong to the same threshold step. Uses {@code ==} so that
+     * {@code 0.0} and {@code -0.0} tie, and additionally ties NaN with NaN; plain {@code ==}
+     * never matches NaN, which would leave the sweep in {@code computeRocAuc} unable to advance.
+     */
+    private static boolean isTied(double left, double right)
+    {
+        return left == right || (Double.isNaN(left) && Double.isNaN(right));
+    }
+
+    /**
+     * Area of the trapezoid spanning {@code x1..x2} with parallel sides {@code y1} and {@code y2}.
+     */
+    private static double trapezoidIntegrate(double x1, double x2, double y1, double y2)
+    {
+        return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java
new file mode 100644
index 000000000000..8bb99034098d
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.truera;
+
+import io.trino.spi.Plugin;
+import io.trino.plugin.truera.aggregation.ROCAUCAggregation;
+
+import java.util.Collections;
+import java.util.Set;
+
+public class TrueraTrinoPlugin
+        implements Plugin
+{
+    @Override
+    public Set<Class<?>> getFunctions()
+    {
+        return Collections.singleton(ROCAUCAggregation.class);
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java
new file mode 100644
index 000000000000..66e8761f8d4e
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java
@@ -0,0 +1,6 @@
+// NOTE(review): this file declares no type and none of its imports are used; it looks like a leftover and is a candidate for removal.
+package io.trino.plugin.truera.aggregation;
+
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.BlockBuilderStatus;
+import io.trino.spi.block.PageBuilderStatus;
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java
new file mode 100644
index 000000000000..93572e544854
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java
@@ -0,0 +1,209 @@
+package io.trino.plugin.truera.aggregation;
+
+import io.airlift.log.Logger;
+import io.trino.array.BooleanBigArray;
+import io.trino.array.DoubleBigArray;
+import io.trino.array.IntBigArray;
+import io.trino.array.LongBigArray;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import org.openjdk.jol.info.ClassLayout;
+
+import static io.trino.plugin.truera.AreaUnderRocCurveAlgorithm.computeRocAuc;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Stores the (label, score) pairs of many aggregation groups in shared big arrays and computes
+ * ROC-AUC for the currently selected group.
+ *
+ * <p>Each group is a singly linked chain of positions: {@code headIndices} holds the first
+ * position of every group and {@code nextLinks} the successor of every position.
+ */
+public class GroupedRocAucCurve
+{
+    private static final Logger log = Logger.get(GroupedRocAucCurve.class);
+
+    private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedRocAucCurve.class).instanceSize();
+    private static final int NULL = -1;
+
+    // one entry per group
+    // each entry is the index of the first element of the group in the labels/scores/nextLinks arrays
+    private final LongBigArray headIndices;
+
+    // one entry per double/boolean pair
+    private final BooleanBigArray labels;
+    private final DoubleBigArray scores;
+
+    // the value in nextLinks contains the index of the next value in the chain
+    // a value NULL (-1) indicates it is the last value for the group
+    private final IntBigArray nextLinks;
+
+    // the index of the next free element in the labels/scores/nextLinks arrays
+    // this is needed to know where to continue adding elements after the arrays are resized
+    private int nextFreeIndex;
+
+    private long currentGroupId = -1;
+
+    public GroupedRocAucCurve()
+    {
+        this.headIndices = new LongBigArray(NULL);
+        this.labels = new BooleanBigArray();
+        this.scores = new DoubleBigArray();
+        this.nextLinks = new IntBigArray(NULL);
+        this.nextFreeIndex = 0;
+    }
+
+    /**
+     * Rebuilds a curve from serialized intermediate state.
+     *
+     * @param groupId group that receives the deserialized pairs
+     * @param serialized array of (label, score) rows, as written by {@link #serializeState}
+     */
+    public GroupedRocAucCurve(long groupId, Block serialized)
+    {
+        this();
+        this.currentGroupId = groupId;
+
+        requireNonNull(serialized, "serialized block is null");
+        for (int i = 0; i < serialized.getPositionCount(); i++) {
+            // each entry is a row block: position 0 holds the label, position 1 the score
+            Block entryBlock = serialized.getObject(i, Block.class);
+            add(entryBlock, entryBlock, 0, 1);
+        }
+    }
+
+    /**
+     * Writes the final result for the current group: its ROC-AUC as a DOUBLE, or NULL when the
+     * group is empty or the AUC is undefined (NaN).
+     */
+    public void serialize(BlockBuilder out)
+    {
+        if (isCurrentGroupEmpty()) {
+            out.appendNull();
+            return;
+        }
+
+        // first pass sizes the primitive arrays, second pass fills them
+        int count = 0;
+        for (int index = headIndex(); index != NULL; index = nextLinks.get(index)) {
+            count++;
+        }
+        boolean[] groupLabels = new boolean[count];
+        double[] groupScores = new double[count];
+        int position = 0;
+        for (int index = headIndex(); index != NULL; index = nextLinks.get(index)) {
+            groupLabels[position] = labels.get(index);
+            groupScores[position] = scores.get(index);
+            position++;
+        }
+        log.debug("Computing ROC-AUC for group %s over %s rows", currentGroupId, count);
+
+        double auc = computeRocAuc(groupLabels, groupScores);
+        if (Double.isNaN(auc)) {
+            out.appendNull();
+        }
+        else {
+            DoubleType.DOUBLE.writeDouble(out, auc);
+        }
+    }
+
+    /**
+     * Writes the raw (label, score) pairs of the current group as one array-of-rows entry, the
+     * form read back by {@link #GroupedRocAucCurve(long, Block)}. Writes NULL for an empty group.
+     */
+    public void serializeState(BlockBuilder out)
+    {
+        if (isCurrentGroupEmpty()) {
+            out.appendNull();
+            return;
+        }
+
+        BlockBuilder arrayBuilder = out.beginBlockEntry();
+        for (int index = headIndex(); index != NULL; index = nextLinks.get(index)) {
+            BlockBuilder rowBuilder = arrayBuilder.beginBlockEntry();
+            BooleanType.BOOLEAN.writeBoolean(rowBuilder, labels.get(index));
+            DoubleType.DOUBLE.writeDouble(rowBuilder, scores.get(index));
+            arrayBuilder.closeEntry();
+        }
+        out.closeEntry();
+    }
+
+    public long estimatedInMemorySize()
+    {
+        return INSTANCE_SIZE + labels.sizeOf() + scores.sizeOf() + nextLinks.sizeOf() + headIndices.sizeOf();
+    }
+
+    public GroupedRocAucCurve setGroupId(long groupId)
+    {
+        this.currentGroupId = groupId;
+        return this;
+    }
+
+    public long getGroupId()
+    {
+        return this.currentGroupId;
+    }
+
+    /**
+     * Appends the pair read from the given block positions to the current group.
+     */
+    public void add(Block labelsBlock, Block scoresBlock, int labelPosition, int scorePosition)
+    {
+        add(BooleanType.BOOLEAN.getBoolean(labelsBlock, labelPosition), DoubleType.DOUBLE.getDouble(scoresBlock, scorePosition));
+    }
+
+    /**
+     * Appends one (label, score) pair to the current group.
+     */
+    public void add(boolean label, double score)
+    {
+        ensureCapacity(currentGroupId + 1);
+
+        labels.set(nextFreeIndex, label);
+        scores.set(nextFreeIndex, score);
+        // push the new element onto the front of the group's chain; the head must be updated,
+        // otherwise the element is unreachable and the group keeps looking empty
+        nextLinks.set(nextFreeIndex, headIndex());
+        headIndices.set(currentGroupId, nextFreeIndex);
+        nextFreeIndex++;
+    }
+
+    public void ensureCapacity(long numberOfGroups)
+    {
+        headIndices.ensureCapacity(numberOfGroups);
+        int numberOfValues = nextFreeIndex + 1;
+        labels.ensureCapacity(numberOfValues);
+        scores.ensureCapacity(numberOfValues);
+        nextLinks.ensureCapacity(numberOfValues);
+    }
+
+    /**
+     * Copies every pair of {@code other}'s current group into this curve's current group.
+     */
+    public void addAll(GroupedRocAucCurve other)
+    {
+        other.readAll(this);
+    }
+
+    /**
+     * Copies every pair of this curve's current group into {@code to}'s current group.
+     */
+    public void readAll(GroupedRocAucCurve to)
+    {
+        for (int index = headIndex(); index != NULL; index = nextLinks.get(index)) {
+            to.add(labels.get(index), scores.get(index));
+        }
+    }
+
+    public boolean isCurrentGroupEmpty()
+    {
+        return headIndices.get(currentGroupId) == NULL;
+    }
+
+    private int headIndex()
+    {
+        return (int) headIndices.get(currentGroupId);
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java
new file mode 100644
index 000000000000..781821eed918
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.truera.aggregation;
+
+import io.trino.plugin.truera.state.AreaUnderRocCurveState;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.function.AggregationFunction;
+import io.trino.spi.function.CombineFunction;
+import io.trino.spi.function.InputFunction;
+import io.trino.spi.function.OutputFunction;
+import io.trino.spi.function.SqlType;
+import io.trino.spi.type.StandardTypes;
+
+/**
+ * SQL aggregate {@code roc_auc(score, label)}: takes a DOUBLE score and a BOOLEAN label per row
+ * and returns the ROC-AUC as a DOUBLE.
+ */
+@AggregationFunction("roc_auc")
+public class ROCAUCAggregation
+{
+    @InputFunction
+    public static void input(AreaUnderRocCurveState state, @SqlType(StandardTypes.DOUBLE) double score, @SqlType(StandardTypes.BOOLEAN) boolean label)
+    {
+        GroupedRocAucCurve auc = state.get();
+        long startSize = auc.estimatedInMemorySize();
+        auc.add(label, score);
+        state.addMemoryUsage(auc.estimatedInMemorySize() - startSize);
+    }
+
+    @CombineFunction
+    public static void combine(AreaUnderRocCurveState state, AreaUnderRocCurveState otherState)
+    {
+        GroupedRocAucCurve other = otherState.get();
+        if (other.isCurrentGroupEmpty()) {
+            return;
+        }
+
+        // Always copy the pairs rather than adopting the other container: swapping the container
+        // of a grouped state would drop the data of every other group it holds.
+        GroupedRocAucCurve auc = state.get();
+        long startSize = auc.estimatedInMemorySize();
+        auc.addAll(other);
+        state.addMemoryUsage(auc.estimatedInMemorySize() - startSize);
+    }
+
+    @OutputFunction(StandardTypes.DOUBLE)
+    public static void output(AreaUnderRocCurveState state, BlockBuilder out)
+    {
+        state.get().serialize(out);
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java
new file mode 100644
index 000000000000..8af8d0ebc8e7
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java
@@ -0,0 +1,8 @@
+package io.trino.plugin.truera.metrics;
+
+/**
+ * NOTE(review): empty placeholder that nothing in this change references; candidate for removal.
+ */
+public class SingleAUCROC
+{
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java
new file mode 100644
index 000000000000..8a28a48b4831
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java
@@ -0,0 +1,26 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.plugin.truera.aggregation.GroupedRocAucCurve;
+import io.trino.spi.block.Block;
+import io.trino.spi.function.AccumulatorState;
+import io.trino.spi.function.AccumulatorStateMetadata;
+
+/**
+ * Accumulator state for {@code roc_auc}: holds the {@link GroupedRocAucCurve} that collects the
+ * (label, score) pairs.
+ */
+@AccumulatorStateMetadata(stateFactoryClass = AreaUnderRocCurveStateFactory.class, stateSerializerClass = AreaUnderRocCurveStateSerializer.class)
+public interface AreaUnderRocCurveState
+        extends AccumulatorState
+{
+    GroupedRocAucCurve get();
+
+    void set(GroupedRocAucCurve value);
+
+    void addMemoryUsage(long memory);
+
+    /**
+     * Replaces the held curve with one rebuilt from serialized intermediate state.
+     */
+    void deserialize(Block serialized);
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java
new file mode 100644
index 000000000000..38fa9837cc64
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java
@@ -0,0 +1,133 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.plugin.truera.aggregation.GroupedRocAucCurve;
+import io.trino.spi.block.Block;
+import io.trino.spi.function.AccumulatorStateFactory;
+import io.trino.spi.function.GroupedAccumulatorState;
+import org.openjdk.jol.info.ClassLayout;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Creates the single and grouped accumulator states for {@code roc_auc}.
+ */
+public class AreaUnderRocCurveStateFactory
+        implements AccumulatorStateFactory<AreaUnderRocCurveState>
+{
+    @Override
+    public AreaUnderRocCurveState createSingleState()
+    {
+        return new SingleState();
+    }
+
+    @Override
+    public AreaUnderRocCurveState createGroupedState()
+    {
+        return new GroupedState();
+    }
+
+    public static class GroupedState
+            implements AreaUnderRocCurveState, GroupedAccumulatorState
+    {
+        private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize();
+        private GroupedRocAucCurve auc;
+
+        public GroupedState()
+        {
+            auc = new GroupedRocAucCurve();
+        }
+
+        @Override
+        public void setGroupId(long groupId)
+        {
+            auc.setGroupId(groupId);
+        }
+
+        @Override
+        public void ensureCapacity(long size)
+        {
+            auc.ensureCapacity(size);
+        }
+
+        @Override
+        public GroupedRocAucCurve get()
+        {
+            return auc;
+        }
+
+        @Override
+        public void set(GroupedRocAucCurve value)
+        {
+            auc = requireNonNull(value, "value is null");
+        }
+
+        @Override
+        public void addMemoryUsage(long memory)
+        {
+            // Intentionally a no-op: getEstimatedSize() measures the curve directly, so also
+            // accumulating deltas here would count every allocation twice.
+        }
+
+        @Override
+        public void deserialize(Block serialized)
+        {
+            // NOTE(review): this swaps in a fresh container positioned on group 0, dropping any
+            // pairs already held for other groups - confirm it is only used on scratch states.
+            this.auc = new GroupedRocAucCurve(0, serialized);
+        }
+
+        @Override
+        public long getEstimatedSize()
+        {
+            return INSTANCE_SIZE + auc.estimatedInMemorySize();
+        }
+    }
+
+    public static class SingleState
+            implements AreaUnderRocCurveState
+    {
+        private static final long INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize();
+        private GroupedRocAucCurve auc;
+
+        public SingleState()
+        {
+            auc = new GroupedRocAucCurve();
+
+            // set synthetic, unique group id to use GroupedRocAucCurve from the single state
+            auc.setGroupId(0);
+        }
+
+        @Override
+        public GroupedRocAucCurve get()
+        {
+            return auc;
+        }
+
+        @Override
+        public void deserialize(Block serialized)
+        {
+            this.auc = new GroupedRocAucCurve(0, serialized);
+        }
+
+        @Override
+        public void set(GroupedRocAucCurve value)
+        {
+            auc = value;
+        }
+
+        @Override
+        public void addMemoryUsage(long memory)
+        {
+        }
+
+        @Override
+        public long getEstimatedSize()
+        {
+            long estimatedSize = INSTANCE_SIZE;
+            if (auc != null) {
+                estimatedSize += auc.estimatedInMemorySize();
+            }
+            return estimatedSize;
+        }
+    }
+}
diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java
new file mode 100644
index 000000000000..5b11d0e0e6d7
--- /dev/null
+++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java
@@ -0,0 +1,41 @@
+package io.trino.plugin.truera.state;
+
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.function.AccumulatorStateSerializer;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import io.trino.spi.type.Type;
+
+import static io.trino.spi.type.RowType.anonymousRow;
+
+/**
+ * Serializes {@code roc_auc} intermediate state as {@code array(row(boolean, double))}: one
+ * (label, score) row per collected pair.
+ */
+public class AreaUnderRocCurveStateSerializer
+        implements AccumulatorStateSerializer<AreaUnderRocCurveState>
+{
+    static final ArrayType SERIALIZED_TYPE = new ArrayType(anonymousRow(BooleanType.BOOLEAN, DoubleType.DOUBLE));
+
+    @Override
+    public Type getSerializedType()
+    {
+        return SERIALIZED_TYPE;
+    }
+
+    @Override
+    public void serialize(AreaUnderRocCurveState state, BlockBuilder out)
+    {
+        // Intermediate state has to carry the raw pairs; GroupedRocAucCurve.serialize() writes the
+        // final DOUBLE result, which does not match SERIALIZED_TYPE and cannot be merged.
+        state.get().serializeState(out);
+    }
+
+    @Override
+    public void deserialize(Block block, int index, AreaUnderRocCurveState state)
+    {
+        state.deserialize(SERIALIZED_TYPE.getObject(block, index));
+    }
+}
diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java
new file mode 100644
index 000000000000..cc1b5c756495
--- /dev/null
+++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java
@@ -0,0 +1,54 @@
+package io.trino.plugin.truera;
+
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+public class TestAreaUnderRocCurveAlgorithm
+{
+    // absolute tolerance when comparing computed AUC values
+    private static final double TOLERANCE = 1e-12;
+
+    @Test
+    public void testComputeAucRocConstantYs()
+    {
+        // every label is false, so only one class is present and the AUC is undefined
+        boolean[] ys = new boolean[3];
+        double[] ysPred = new double[] {-1, 0, 1};
+        assertTrue(Double.isNaN(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred)));
+    }
+
+    @Test
+    public void testComputeAucRocConstantYsPred()
+    {
+        // Check if first element is only true.
+        boolean[] ys = new boolean[10];
+        ys[0] = true;
+        double[] ysPred = new double[10];
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5, TOLERANCE);
+        // Check if last element is only true.
+        ys = new boolean[10];
+        ys[ys.length - 1] = true;
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5, TOLERANCE);
+    }
+
+    @Test
+    public void testComputeAucRocPerfectRanking()
+    {
+        boolean[] ys = {true, false, true};
+        // every positive outranks the negative
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, new double[] {0.5, 0.2, 0.9}), 1.0, TOLERANCE);
+        // negated scores: the negative outranks every positive
+        assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, new double[] {-0.5, -0.2, -0.9}), 0.0, TOLERANCE);
+    }
+
+    @Test
+    public void testComputeAucRocRandomCase()
+    {
+        boolean[] ys = {false, false, false, true, false, false, true, true, false, true, true, true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, false, true, false, false, true, true, false, true, false, false, false, false, false, true, false, false, false, true, false, true, false, true, true, false, false, false, true, true, true, false, true, false, false, true, true, false, true, false, true, true, false, false, true, false, false, true, true, true, false, false, true, true, false, true, false, true, false, true, true, true, false, true, true, false, true, true, true, true, false, false, false, true, true, false};
+        double[] ysPred = {-0.2286298738575081, -0.26869038039536985, -0.312319527551126, 0.2734721402453152, -0.24223018509401162, 0.31503992286551685, 0.4046447140511742, 0.28792712595231496, 0.23384628640264638, 0.4428378129948476, 0.4543878651258779, -0.01374360133220165, 0.0007349034879994276, 0.09636778467067175, 0.13655154956520688, 0.17454951949942543, 0.4931408863331128, 0.23547728150236757, 0.37226953309841515, -0.15728692842143432, 0.15924354023313247, 0.4286270233907359, -0.4685474933247794, -0.25577981860345134, 0.4681668166617913, 0.11320750870358587, -0.2736801671178942, -0.3864589476069452, -0.399822172488029, 0.015685730721504587, 0.3321140926944811, 0.2530935731647217, -0.3959046404148484, -0.15179547616306643, -0.36416614908162126, -0.29774936545934994, -0.12315200126789438, -0.2644045710260078, -0.44906116392418005, -0.020580493381027964, -0.12273927226482872, -0.056684762439960235, -0.08061360953066143, -0.29603641015837856, -0.18856132377074108, 0.4235267740783706, -0.4606040899799605, 0.42625848143208045, -0.0011082145951951672, -0.10729132271105546, 0.05327228838675735, 0.23256669296146493, 0.44780010851745966, 0.3568537456277838, 0.1805343926823827, -0.21846170998705394, -0.18590956046362084, 0.2509080583778559, -0.22290770310271601, 0.09721402039749849, 0.12656860105942058, -0.4855359797468307, -0.39689924561298107, 0.06657447973146235, 0.2578990291023152, 0.4742641426354459, -0.18982849914795363, 0.2227804823488645, 0.35372747647450187, -0.4543332887713202, -0.2564141959803503, -0.3998731379722754, 0.4302966112027795, 0.20920354786320405, -0.3493332019987293, 0.22594844508851064, 0.3506301294801156, 0.49704339475860027, -0.06293993755501126, 0.3918169521501762, -0.09311136809088782, 0.1600693305289207, -0.15765165599671038, 0.45406593535839646, 0.25178211933304273, -0.10253079227341522, -0.07396783373899074, -0.2779385207339454, 0.39166810505415206, -0.21097405310184147, 0.49742358104577056, 0.2503826842835034, 0.06827961053326548, -0.14270071261942086, 0.2718756995005246, -0.40524883036019543, 0.017045033508361507, 0.38197771242907064, 0.26845558760844856, -0.0965167478978054};
+        double auc = AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred);
+        assertEquals(auc, 0.8187274909963985, TOLERANCE);
+    }
+}
diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java
new file mode 100644
index 000000000000..95aece0bcab6
--- /dev/null
+++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java
@@ -0,0 +1,94 @@
+package io.trino.plugin.truera;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.Session;
+import io.trino.plugin.hive.HivePlugin;
+import io.trino.plugin.iceberg.IcebergPlugin;
+import io.trino.spi.security.Identity;
+import io.trino.spi.security.SelectedRole;
+import io.trino.testing.AbstractTestQueryFramework;
+import io.trino.testing.DistributedQueryRunner;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.File;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static io.trino.spi.security.SelectedRole.Type.ROLE;
+import static io.trino.testing.TestingSession.testSessionBuilder;
+
+@Test(singleThreaded = true)
+public class TestAucRocAggregation
+        extends AbstractTestQueryFramework
+{
+    public static final String ICEBERG_CATALOG = "iceberg";
+
+    private final ImmutableMap.Builder<String, String> icebergProperties = ImmutableMap.builder();
+    private final Optional<File> metastoreDirectory = Optional.empty();
+
+    @Override
+    protected DistributedQueryRunner createQueryRunner()
+            throws Exception
+    {
+        Session session = testSessionBuilder()
+                .setIdentity(Identity.forUser("hive")
+                        .withConnectorRole("hive", new SelectedRole(ROLE, Optional.of("admin")))
+                        .build())
+                .build();
+
+        DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build();
+
+        queryRunner.installPlugin(new IcebergPlugin());
+        Map<String, String> icebergProperties = new HashMap<>(this.icebergProperties.buildOrThrow());
+        String catalogType = icebergProperties.get("iceberg.catalog.type");
+        Path dataDir = metastoreDirectory.map(File::toPath).orElseGet(() -> queryRunner.getCoordinator().getBaseDataDir().resolve("iceberg_data"));
+        if (catalogType == null) {
+            icebergProperties.put("iceberg.catalog.type", "TESTING_FILE_METASTORE");
+            icebergProperties.put("hive.metastore.catalog.dir", dataDir.toString());
+        }
+
+        queryRunner.createCatalog(ICEBERG_CATALOG, "iceberg", icebergProperties);
+        queryRunner.installPlugin(new HivePlugin());
+        queryRunner.createCatalog("hive", "hive", ImmutableMap.<String, String>builder()
+                .put("hive.metastore", "file")
+                .put("hive.metastore.catalog.dir", dataDir.toString())
+                .put("hive.security", "sql-standard")
+                .buildOrThrow());
+        queryRunner.installPlugin(new TrueraTrinoPlugin());
+
+        return queryRunner;
+    }
+
+    @BeforeClass
+    public void setUp()
+    {
+        assertQuerySucceeds("CREATE SCHEMA hive.test_schema");
+        assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_probits_table (__id__ VARCHAR, __preds__ double)");
+        assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_labels_table (__id__ VARCHAR, __label__ int)");
+        assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_probits_table VALUES ('A', 0.5), ('B', 0.2), ('C', 0.9)");
+        assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_labels_table VALUES ('A', 1), ('B', 0), ('C', 1)");
+    }
+
+    @AfterClass(alwaysRun = true)
+    public void tearDown()
+    {
+        assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_labels_table");
+        assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_probits_table");
+        assertQuerySucceeds("DROP SCHEMA IF EXISTS hive.test_schema");
+    }
+
+    @Test
+    public void testAucRoc()
+    {
+        // Positives score 0.5 and 0.9, the only negative scores 0.2, so the ranking is perfect.
+        assertQuery("SELECT roc_auc(__preds__, CAST(__label__ as boolean)) " +
+                "FROM iceberg.test_schema.iceberg_probits_table as pred " +
+                "JOIN iceberg.test_schema.iceberg_labels_table AS label " +
+                "ON pred.__id__ = label.__id__",
+                "VALUES (1.0)");
+    }
+}
diff --git a/pom.xml b/pom.xml
index 9864b6b6b4f2..147f5764eed2 100644
--- a/pom.xml
+++ b/pom.xml @@ -182,6 +182,7 @@ plugin/trino-thrift-testing-server plugin/trino-tpcds plugin/trino-tpch + plugin/trino-truera service/trino-proxy service/trino-verifier testing/trino-benchmark @@ -739,6 +740,12 @@ ${project.version} + + io.trino + trino-truera + ${project.version} + + io.trino.benchto