From 773200a8216c587e88b5d81ce3fb646a0a5dd2ad Mon Sep 17 00:00:00 2001 From: Divya Gopinath Date: Wed, 25 Jan 2023 14:34:03 -0800 Subject: [PATCH 01/10] Initial custom func --- lib/trino-array/TrueraTrinoPlugin.java | 2 + plugin/trino-truera/pom.xml | 30 +++++++++++++ .../plugin/truera/TrueraTrinoPlugin.java | 30 +++++++++++++ .../truera/aggregation/DivAggregation.java | 45 +++++++++++++++++++ .../aggregation/LongAndDoubleState.java | 28 ++++++++++++ pom.xml | 6 +++ 6 files changed, 141 insertions(+) create mode 100644 lib/trino-array/TrueraTrinoPlugin.java create mode 100644 plugin/trino-truera/pom.xml create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java diff --git a/lib/trino-array/TrueraTrinoPlugin.java b/lib/trino-array/TrueraTrinoPlugin.java new file mode 100644 index 000000000000..372ff013b4c0 --- /dev/null +++ b/lib/trino-array/TrueraTrinoPlugin.java @@ -0,0 +1,2 @@ +public class TrueraTrinoPlugin { +} diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml new file mode 100644 index 000000000000..bf79bddca383 --- /dev/null +++ b/plugin/trino-truera/pom.xml @@ -0,0 +1,30 @@ + + + 4.0.0 + + trino-root + io.trino + 389 + ../../pom.xml + + + trino-truera + + Trino Truera Extensions + trino-plugin + + + ${project.parent.basedir} + + + + + io.trino + trino-spi + provided + + + + diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java new file mode 100644 index 000000000000..5ba344bafd07 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.truera; + +import io.trino.spi.Plugin; +import io.trino.plugin.truera.aggregation.DivAggregation; + +import java.util.Collections; +import java.util.Set; + +public class TrueraTrinoPlugin + implements Plugin +{ + @Override + public Set> getFunctions() + { + return Collections.singleton(DivAggregation.class); + } +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java new file mode 100644 index 000000000000..f90a32768860 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.truera.aggregation; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.*; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.StandardTypes; + +@AggregationFunction("roc_auc") +public class DivAggregation { + @InputFunction + public static void input(LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double prediction, @SqlType(StandardTypes.DOUBLE) double label) { + state.setLong(state.getLong() + 1); + state.setDouble(state.getDouble() + (prediction * label)); + } + + @CombineFunction + public static void combine(LongAndDoubleState state, LongAndDoubleState otherState) { + state.setLong(state.getLong() + otherState.getLong()); + state.setDouble(state.getDouble() + otherState.getDouble()); + } + + @OutputFunction(StandardTypes.DOUBLE) + public static void output(LongAndDoubleState state, BlockBuilder out) { + long count = state.getLong(); + if (count == 0) { + out.appendNull(); + } else { + double value = state.getDouble(); + DoubleType.DOUBLE.writeDouble(out, value / count); + } + } +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java new file mode 100644 index 000000000000..485ed05f7556 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.truera.aggregation; + +import io.trino.spi.function.AccumulatorState; + +public interface LongAndDoubleState + extends AccumulatorState +{ + long getLong(); + + void setLong(long value); + + double getDouble(); + + void setDouble(double value); +} diff --git a/pom.xml b/pom.xml index 580ba5065144..271e8a1c4873 100644 --- a/pom.xml +++ b/pom.xml @@ -722,6 +722,12 @@ ${project.version} + + io.trino + trino-truera + ${project.version} + + io.trino.benchto From 21ae628a1daffcaae8032a2decb8f5e76cb9aba2 Mon Sep 17 00:00:00 2001 From: Divya Gopinath Date: Wed, 1 Feb 2023 23:22:42 -0800 Subject: [PATCH 02/10] Try to refactor for AUC --- plugin/trino-truera/README.md | 19 ++ plugin/trino-truera/pom.xml | 14 ++ .../plugin/truera/TrueraTrinoPlugin.java | 4 +- .../aggregation/AUCBlockBuilderStatus.java | 5 + .../truera/aggregation/DivAggregation.java | 45 ----- .../aggregation/GroupedRocAucCurve.java | 173 ++++++++++++++++++ .../aggregation/LongAndDoubleState.java | 28 --- .../truera/aggregation/ROCAUCAggregation.java | 57 ++++++ .../truera/state/AreaUnderRocCurveState.java | 18 ++ .../state/AreaUnderRocCurveStateFactory.java | 130 +++++++++++++ .../AreaUnderRocCurveStateSerializer.java | 32 ++++ 11 files changed, 450 insertions(+), 75 deletions(-) create mode 100644 plugin/trino-truera/README.md create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java delete mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java delete mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java diff --git a/plugin/trino-truera/README.md b/plugin/trino-truera/README.md new file mode 100644 index 000000000000..20a746adafcd --- /dev/null +++ b/plugin/trino-truera/README.md @@ -0,0 +1,19 @@ +# TruEra Trino Extensions +This module is a custom Trino plugin for TruEra. Currently it just contains one new function (ROC-AUC). + +## How do I add to the plugin? +To get started, read the Trino [dev guide](https://trino.io/docs/current/develop/spi-overview.html#) + +## How do I test the package? +1. Compile using `mvn clean install -Dair.check.skip-all=true -DskipTests`. +2. Look at the `target` folder which should now have a packaged ZIP like: `trino-truera-389.zip` +3. Unzip the package into the trino plugins directory: `unzip trino-truera-389.zip -d ~/external_dependencies/trino-server/plugin` +4. Restart Trino (`./service.sh stop trino && ./service.sh start trino`) + +You can test the "roc_auc" function with a command like: +```sql +SELECT roc_auc(__truera_prediction__, CAST(__truera_label__ as boolean)) +FROM "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_predictions_probits_score" as pred +JOIN "iceberg"."tablestore"."a83db5d335ab494590cd7ada132707ad_label" +AS label ON pred.__truera_id__ = label.__truera_id__; +``` diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml index bf79bddca383..72f5f8e6c212 100644 --- a/plugin/trino-truera/pom.xml +++ b/plugin/trino-truera/pom.xml @@ -25,6 +25,20 @@ trino-spi provided + + io.trino + trino-array + + + io.airlift + slice + provided + + + org.openjdk.jol + jol-core + provided + diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java index 5ba344bafd07..8bb99034098d 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/TrueraTrinoPlugin.java @@ -14,7 +14,7 @@ package io.trino.plugin.truera; import io.trino.spi.Plugin; -import io.trino.plugin.truera.aggregation.DivAggregation; +import io.trino.plugin.truera.aggregation.ROCAUCAggregation; import java.util.Collections; import java.util.Set; @@ -25,6 +25,6 @@ public class TrueraTrinoPlugin @Override public Set> getFunctions() { - return Collections.singleton(DivAggregation.class); + return Collections.singleton(ROCAUCAggregation.class); } } diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java new file mode 100644 index 000000000000..66e8761f8d4e --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/AUCBlockBuilderStatus.java @@ -0,0 +1,5 @@ +package io.trino.plugin.truera.aggregation; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.PageBuilderStatus; diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java deleted file mode 100644 index f90a32768860..000000000000 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/DivAggregation.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.truera.aggregation; - -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.*; -import io.trino.spi.type.DoubleType; -import io.trino.spi.type.StandardTypes; - -@AggregationFunction("roc_auc") -public class DivAggregation { - @InputFunction - public static void input(LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double prediction, @SqlType(StandardTypes.DOUBLE) double label) { - state.setLong(state.getLong() + 1); - state.setDouble(state.getDouble() + (prediction * label)); - } - - @CombineFunction - public static void combine(LongAndDoubleState state, LongAndDoubleState otherState) { - state.setLong(state.getLong() + otherState.getLong()); - state.setDouble(state.getDouble() + otherState.getDouble()); - } - - @OutputFunction(StandardTypes.DOUBLE) - public static void output(LongAndDoubleState state, BlockBuilder out) { - long count = state.getLong(); - if (count == 0) { - out.appendNull(); - } else { - double value = state.getDouble(); - DoubleType.DOUBLE.writeDouble(out, value / count); - } - } -} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java new file mode 100644 index 000000000000..de3840c90d2a --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java @@ -0,0 +1,173 @@ +package io.trino.plugin.truera.aggregation; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.IntStream; + +import io.trino.array.BooleanBigArray; +import io.trino.array.IntBigArray; +import io.trino.array.LongBigArray; +import io.trino.array.DoubleBigArray; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import org.openjdk.jol.info.ClassLayout; + +import static java.util.Objects.requireNonNull; + +public class GroupedRocAucCurve { + + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedRocAucCurve.class).instanceSize(); + private static final int NULL = -1; + + // one entry per group + // each entry is the index of the first elements of the group in the labels/scores/nextLinks arrays + private final LongBigArray headIndices; + + // one entry per double/boolean pair + private final BooleanBigArray labels; + private final DoubleBigArray scores; + + // the value in nextLinks contains the index of the next value in the chain + // a value NULL (-1) indicates it is the last value for the group + private final IntBigArray nextLinks; + + // the index of the next free element in the labels/scores/nextLinks arrays + // this is needed to be able to know where to continue adding elements when after the arrays are resized + private int nextFreeIndex; + + private long currentGroupId = -1; + + public GroupedRocAucCurve() { + this.headIndices = new LongBigArray(NULL); + this.labels = new BooleanBigArray(); + this.scores = new DoubleBigArray(); + this.nextLinks = new IntBigArray(NULL); + this.nextFreeIndex = 0; + } + + public GroupedRocAucCurve(long groupId, Block serialized) { + this(); + this.currentGroupId = groupId; + + requireNonNull(serialized, "serialized block is null"); + for (int i = 0; i < serialized.getPositionCount(); i++) { + Block entryBlock = serialized.getObject(i, Block.class); + add(entryBlock, entryBlock, 0, 1); + } + } + + public void serialize(BlockBuilder out) { + if (isCurrentGroupEmpty()) { + out.appendNull(); + return; + } + + // retrieve scores + labels + List labelList = new ArrayList<>(); + List scoreList = new ArrayList<>(); + + int currentIndex = (int) headIndices.get(currentGroupId); + while (currentIndex != NULL) { + labelList.add(labels.get(currentIndex)); + scoreList.add(scores.get(currentIndex)); + currentIndex = nextLinks.get(currentIndex); + } + + // convert lists to primitive arrays + boolean[] labels = new boolean[labelList.size()]; + for (int i = 0; i < labels.length; i++) { + labels[i] = labelList.get(i); + } + double[] scores = scoreList.stream().mapToDouble(Double::doubleValue).toArray(); + + // compute + return + double auc = computeRocAuc(labels, scores); + if (Double.isFinite(auc)) { + DoubleType.DOUBLE.writeDouble(out, auc); + } else { + out.appendNull(); + } + } + + public long estimatedInMemorySize() { + return INSTANCE_SIZE + labels.sizeOf() + scores.sizeOf() + nextLinks.sizeOf() + headIndices.sizeOf(); + } + + public GroupedRocAucCurve setGroupId(long groupId) { + this.currentGroupId = groupId; + return this; + } + + public void add(Block labelsBlock, Block scoresBlock, int labelPosition, int scorePosition) { + ensureCapacity(currentGroupId + 1); + + labels.set(nextFreeIndex, BooleanType.BOOLEAN.getBoolean(labelsBlock, labelPosition)); + scores.set(nextFreeIndex, DoubleType.DOUBLE.getDouble(scoresBlock, scorePosition)); + nextLinks.set(nextFreeIndex, (int) headIndices.get(currentGroupId)); + nextFreeIndex++; + } + + public void ensureCapacity(long numberOfGroups) { + headIndices.ensureCapacity(numberOfGroups); + int numberOfValues = nextFreeIndex + 1; + labels.ensureCapacity(numberOfValues); + scores.ensureCapacity(numberOfValues); + nextLinks.ensureCapacity(numberOfValues); + } + + public void addAll(GroupedRocAucCurve other) { + other.readAll(this); + } + + public void readAll(GroupedRocAucCurve to) { + int currentIndex = (int) headIndices.get(currentGroupId); + while (currentIndex != NULL) { + BlockBuilder labelBlockBuilder = BooleanType.BOOLEAN.createBlockBuilder(null, 0); + BooleanType.BOOLEAN.writeBoolean(labelBlockBuilder, labels.get(currentIndex)); + BlockBuilder scoreBlockBuilder = DoubleType.DOUBLE.createBlockBuilder(null, 0); + DoubleType.DOUBLE.writeDouble(scoreBlockBuilder, scores.get(currentIndex)); + + to.add(labelBlockBuilder.build(), scoreBlockBuilder.build(), 0, 0); + currentIndex = nextLinks.get(currentIndex); + } + } + + public boolean isCurrentGroupEmpty() { + return headIndices.get(currentGroupId) == NULL; + } + + private static double computeRocAuc(boolean[] labels, double[] scores) { + int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted( + Comparator.comparing(i -> scores[i]) + ).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray(); + + int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0; + double auc = 0.; + + for (int i : sortedIndices) { + if (labels[i]) { currTruePositives++; } else { currFalsePositives++; }; + prevTruePositives = currTruePositives; + prevFalsePositives = currFalsePositives; + auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives); + } + + // If labels only contain one class, AUC is undefined + if (currTruePositives == 0 || currFalsePositives == 0) { + return Double.POSITIVE_INFINITY; + } + + return auc / (currTruePositives * currFalsePositives); + } + + private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) { + return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2 + } + +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java deleted file mode 100644 index 485ed05f7556..000000000000 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/LongAndDoubleState.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.truera.aggregation; - -import io.trino.spi.function.AccumulatorState; - -public interface LongAndDoubleState - extends AccumulatorState -{ - long getLong(); - - void setLong(long value); - - double getDouble(); - - void setDouble(double value); -} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java new file mode 100644 index 000000000000..81e94c13d0a2 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.truera.aggregation; + +import io.trino.plugin.truera.state.AreaUnderRocCurveState; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.function.*; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.StandardTypes; + +@AggregationFunction("roc_auc") +public class ROCAUCAggregation { + @InputFunction + public static void input(AreaUnderRocCurveState state, @SqlType(StandardTypes.DOUBLE) double score, @SqlType(StandardTypes.BOOLEAN) boolean label) { + GroupedRocAucCurve auc = state.get(); + BlockBuilder labelBlockBuilder = BooleanType.BOOLEAN.createBlockBuilder(null, 0); + BooleanType.BOOLEAN.writeBoolean(labelBlockBuilder, label); + BlockBuilder scoreBlockBuilder = DoubleType.DOUBLE.createBlockBuilder(null, 0); + DoubleType.DOUBLE.writeDouble(scoreBlockBuilder, score); + + long startSize = auc.estimatedInMemorySize(); + auc.add(labelBlockBuilder.build(), scoreBlockBuilder.build(), 0, 0); + state.addMemoryUsage(auc.estimatedInMemorySize() - startSize); + } + + @CombineFunction + public static void combine(AreaUnderRocCurveState state, AreaUnderRocCurveState otherState) { + if (!state.get().isCurrentGroupEmpty() && !otherState.get().isCurrentGroupEmpty()) { + GroupedRocAucCurve auc = state.get(); + long startSize = auc.estimatedInMemorySize(); + auc.addAll(otherState.get()); + state.addMemoryUsage(auc.estimatedInMemorySize() - startSize); + } + else if (state.get().isCurrentGroupEmpty()) { + state.set(otherState.get()); + } + } + + @OutputFunction(StandardTypes.DOUBLE) + public static void output(AreaUnderRocCurveState state, BlockBuilder out) { + GroupedRocAucCurve auc = state.get(); + auc.serialize(out); + } +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java new file mode 100644 index 000000000000..8a28a48b4831 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveState.java @@ -0,0 +1,18 @@ +package io.trino.plugin.truera.state; + +import io.trino.plugin.truera.aggregation.GroupedRocAucCurve; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; + +@AccumulatorStateMetadata(stateFactoryClass = AreaUnderRocCurveStateFactory.class, stateSerializerClass = AreaUnderRocCurveStateSerializer.class) +public interface AreaUnderRocCurveState extends AccumulatorState +{ + GroupedRocAucCurve get(); + + void set(GroupedRocAucCurve value); + + void addMemoryUsage(long memory); + + void deserialize(Block serialized); +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java new file mode 100644 index 000000000000..9d0b911944f7 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java @@ -0,0 +1,130 @@ +package io.trino.plugin.truera.state; + +import io.trino.plugin.truera.aggregation.GroupedRocAucCurve; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.GroupedAccumulatorState; +import org.openjdk.jol.info.ClassLayout; + +import static java.util.Objects.requireNonNull; + +public class AreaUnderRocCurveStateFactory implements AccumulatorStateFactory +{ + @Override + public AreaUnderRocCurveState createSingleState() { + return new SingleState(); + } + + @Override + public AreaUnderRocCurveState createGroupedState() { + return new GroupedState(); + } + + public static class GroupedState implements AreaUnderRocCurveState, GroupedAccumulatorState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); + private GroupedRocAucCurve auc; + private long size; + + public GroupedState() { + auc = new GroupedRocAucCurve(); + } + + @Override + public void setGroupId(long groupId) + { + auc.setGroupId(groupId); + } + + @Override + public void ensureCapacity(long size) { + auc.ensureCapacity(size); + } + + @Override + public GroupedRocAucCurve get() { + return auc; + } + + @Override + public void set(GroupedRocAucCurve value) + { + requireNonNull(value, "value is null"); + + GroupedRocAucCurve previous = get(); + if (previous != null) { + size -= previous.estimatedInMemorySize(); + } + + auc = value; + size += value.estimatedInMemorySize(); + } + + @Override + public void addMemoryUsage(long memory) + { + size += memory; + } + + @Override + public void deserialize(Block serialized) + { + this.auc = new GroupedRocAucCurve(0, serialized); + } + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + size + auc.estimatedInMemorySize(); + } + } + + public static class SingleState + implements AreaUnderRocCurveState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); + private GroupedRocAucCurve auc; + + public SingleState() + { + auc = new GroupedRocAucCurve(); + + // set synthetic, unique group id to use GroupAreaUnderRocCurve from the single state + auc.setGroupId(0); + } + + @Override + public GroupedRocAucCurve get() + { + return auc; + } + + @Override + public void deserialize(Block serialized) + { + this.auc = new GroupedRocAucCurve(0, serialized); + } + + @Override + public void set(GroupedRocAucCurve value) + { + auc = value; + } + + @Override + public void addMemoryUsage(long memory) + { + } + + @Override + public long getEstimatedSize() + { + long estimatedSize = INSTANCE_SIZE; + if (auc != null) { + estimatedSize += auc.estimatedInMemorySize(); + } + return estimatedSize; + } + } + + +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java new file mode 100644 index 000000000000..078096fe119f --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java @@ -0,0 +1,32 @@ +package io.trino.plugin.truera.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.Type; + +import static io.trino.spi.type.RowType.anonymousRow; + +public class AreaUnderRocCurveStateSerializer implements AccumulatorStateSerializer +{ + private static final ArrayType SERIALIZED_TYPE = new ArrayType(anonymousRow(BooleanType.BOOLEAN, DoubleType.DOUBLE)); + + @Override + public Type getSerializedType() { + return SERIALIZED_TYPE; + } + + @Override + public void serialize(AreaUnderRocCurveState state, BlockBuilder out) { + state.get().serialize(out); + } + + @Override + public void deserialize(Block block, int index, AreaUnderRocCurveState state) { + state.deserialize(SERIALIZED_TYPE.getObject(block, index)); + } + +} From 85ecbed5b27b250a5003ca34bb1d50f6a00096b4 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Fri, 10 Feb 2023 15:25:59 -0800 Subject: [PATCH 03/10] fixes to run stuff --- plugin/trino-truera/pom.xml | 27 ++++++++++++++----- .../aggregation/GroupedRocAucCurve.java | 2 -- .../truera/aggregation/ROCAUCAggregation.java | 1 - .../state/AreaUnderRocCurveStateFactory.java | 1 + .../AreaUnderRocCurveStateSerializer.java | 1 + pom.xml | 1 + 6 files changed, 23 insertions(+), 10 deletions(-) diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml index 72f5f8e6c212..6d50dceaafcf 100644 --- a/plugin/trino-truera/pom.xml +++ b/plugin/trino-truera/pom.xml @@ -1,7 +1,5 @@ - + 4.0.0 trino-root @@ -11,7 +9,6 @@ trino-truera - Trino Truera Extensions trino-plugin @@ -22,12 +19,12 @@ io.trino - trino-spi - provided + trino-array io.trino - trino-array + trino-spi + provided io.airlift @@ -41,4 +38,20 @@ + + + + + com.mycila + license-maven-plugin + + + src/main/** + + + + + + + diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java index de3840c90d2a..fb2dfb42544c 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java @@ -13,8 +13,6 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.BlockBuilderStatus; -import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.type.BooleanType; import io.trino.spi.type.DoubleType; import org.openjdk.jol.info.ClassLayout; diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java index 81e94c13d0a2..781821eed918 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/ROCAUCAggregation.java @@ -15,7 +15,6 @@ import io.trino.plugin.truera.state.AreaUnderRocCurveState; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.function.*; import io.trino.spi.type.BooleanType; import io.trino.spi.type.DoubleType; diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java index 9d0b911944f7..1e85c146f81c 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java @@ -1,5 +1,6 @@ package io.trino.plugin.truera.state; +import io.trino.plugin.truera.state.AreaUnderRocCurveState; import io.trino.plugin.truera.aggregation.GroupedRocAucCurve; import io.trino.spi.block.Block; import io.trino.spi.function.AccumulatorStateFactory; diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java index 078096fe119f..a3c002ce9ee6 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java @@ -1,5 +1,6 @@ package io.trino.plugin.truera.state; +import io.trino.plugin.truera.state.AreaUnderRocCurveState; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; diff --git a/pom.xml b/pom.xml index a8025e6ae78f..0b2896b61b6a 100644 --- a/pom.xml +++ b/pom.xml @@ -178,6 +178,7 @@ plugin/trino-thrift-testing-server plugin/trino-tpcds plugin/trino-tpch + plugin/trino-truera service/trino-proxy service/trino-verifier testing/trino-benchmark From 1a776590811599ed976fb3537ca7f506706f99a3 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Mon, 13 Feb 2023 11:38:41 -0800 Subject: [PATCH 04/10] add testing --- plugin/trino-truera/pom.xml | 5 +++ .../trino/plugin/truera/ROCAUCFunction.java | 34 +++++++++++++++++++ .../aggregation/GroupedRocAucCurve.java | 33 ++---------------- .../plugin/truera/TestROCAUCFunction.java | 33 ++++++++++++++++++ 4 files changed, 74 insertions(+), 31 deletions(-) create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java create mode 100644 plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml index 6d50dceaafcf..225616853fdc 100644 --- a/plugin/trino-truera/pom.xml +++ b/plugin/trino-truera/pom.xml @@ -36,6 +36,11 @@ jol-core provided + + org.testng + testng + test + diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java new file mode 100644 index 000000000000..7db6c412b1dd --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java @@ -0,0 +1,34 @@ +package io.trino.plugin.truera; + +import java.util.Collections; +import java.util.Comparator; +import java.util.stream.IntStream; + +public class ROCAUCFunction { + public static double computeRocAuc(boolean[] labels, double[] scores) { + int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted( + Comparator.comparing(i -> scores[i]) + ).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray(); + + int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0; + double auc = 0.; + + for (int i : sortedIndices) { + if (labels[i]) { currTruePositives++; } else { currFalsePositives++; }; + prevTruePositives = currTruePositives; + prevFalsePositives = currFalsePositives; + auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives); + } + + // If labels only contain one class, AUC is undefined + if (currTruePositives == 0 || currFalsePositives == 0) { + return Double.POSITIVE_INFINITY; + } + + return auc / (currTruePositives * currFalsePositives); + } + + private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) { + return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2 + } +} diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java index fb2dfb42544c..116ba436cb98 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java @@ -1,10 +1,8 @@ package io.trino.plugin.truera.aggregation; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.List; -import java.util.stream.IntStream; + import io.trino.array.BooleanBigArray; import io.trino.array.IntBigArray; @@ -17,6 +15,7 @@ import io.trino.spi.type.DoubleType; import org.openjdk.jol.info.ClassLayout; +import static io.trino.plugin.truera.ROCAUCFunction.computeRocAuc; import static java.util.Objects.requireNonNull; public class GroupedRocAucCurve { @@ -140,32 +139,4 @@ public void readAll(GroupedRocAucCurve to) { public boolean isCurrentGroupEmpty() { return headIndices.get(currentGroupId) == NULL; } - - private static double computeRocAuc(boolean[] labels, double[] scores) { - int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted( - Comparator.comparing(i -> scores[i]) - ).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray(); - - int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0; - double auc = 0.; - - for (int i : sortedIndices) { - if (labels[i]) { currTruePositives++; } else { currFalsePositives++; }; - prevTruePositives = currTruePositives; - prevFalsePositives = currFalsePositives; - auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives); - } - - // If labels only contain one class, AUC is undefined - if (currTruePositives == 0 || currFalsePositives == 0) { - return Double.POSITIVE_INFINITY; - } - - return auc / (currTruePositives * currFalsePositives); - } - - private static double trapezoidIntegrate(double x1, double x2, double y1, double y2) { - return (y1 + y2) * Math.abs(x2 - x1) / 2; // (base1 + base2) * height / 2 - } - } diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java new file mode 100644 index 000000000000..379bb9d6a04a --- /dev/null +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java @@ -0,0 +1,33 @@ +package io.trino.plugin.truera; + +import java.util.Arrays; +import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; +public class TestROCAUCFunction { + @Test + public void testComputeAucRocWhenUndefined() { + + boolean[] testLabels = new boolean[10]; + double[] testProbabilities = new double[10]; + assertEquals(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities), Double.POSITIVE_INFINITY); + } + + @Test + public void testComputeAucRoc() { + + boolean[] testLabels = new boolean[10]; + testLabels[9] = true; + double[] testProbabilities = new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617, + 0.62190361, 0.34917899, 0.88604834, 0.09936481, 0.65903197}; +// Arrays.fill(testProbabilities, 1.0); + for (boolean element: testLabels) { + System.out.println(element); + } + for (double element: testProbabilities) { + System.out.println(element); + } + System.out.println("score"); + System.out.println(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities)); + } + +} From 266e1f0f1d3e4429b35c58c298453533086b0c39 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Mon, 13 Feb 2023 13:35:56 -0800 Subject: [PATCH 05/10] fix sorting --- .../java/io/trino/plugin/truera/ROCAUCFunction.java | 4 ++-- .../io/trino/plugin/truera/TestROCAUCFunction.java | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java index 7db6c412b1dd..1e7dc728956f 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java @@ -7,8 +7,8 @@ public class ROCAUCFunction { public static double computeRocAuc(boolean[] labels, double[] scores) { int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted( - Comparator.comparing(i -> scores[i]) - ).sorted(Collections.reverseOrder()).mapToInt(i->i).toArray(); + Comparator.comparing(i -> scores[i], Comparator.reverseOrder()) + ).mapToInt(i->i).toArray(); int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0; double auc = 0.; diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java index 379bb9d6a04a..d08f0fdb7490 100644 --- a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java @@ -20,12 +20,12 @@ public void testComputeAucRoc() { double[] testProbabilities = new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617, 0.62190361, 0.34917899, 0.88604834, 0.09936481, 0.65903197}; // Arrays.fill(testProbabilities, 1.0); - for (boolean element: testLabels) { - System.out.println(element); - } - for (double element: testProbabilities) { - System.out.println(element); - } +// for (boolean element: testLabels) { +// System.out.println(element); +// } +// for (double element: testProbabilities) { +// System.out.println(element); +// } System.out.println("score"); System.out.println(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities)); } From 61899d105ffca54aab2816f40daddc58381ce520 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Mon, 13 Feb 2023 14:56:45 -0800 Subject: [PATCH 06/10] debugging --- .../trino/plugin/truera/ROCAUCFunction.java | 21 ++++++++++++---- .../plugin/truera/TestROCAUCFunction.java | 24 ++++++++----------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java index 1e7dc728956f..29df393020a2 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java @@ -1,6 +1,5 @@ package io.trino.plugin.truera; -import java.util.Collections; import java.util.Comparator; import java.util.stream.IntStream; @@ -10,14 +9,26 @@ public static double computeRocAuc(boolean[] labels, double[] scores) { Comparator.comparing(i -> scores[i], Comparator.reverseOrder()) ).mapToInt(i->i).toArray(); - int currTruePositives = 0, currFalsePositives = 0, prevTruePositives =0, prevFalsePositives = 0; +// for (int element: sortedIndices) { +// System.out.println(element); +// } + + int currTruePositives = 0, currFalsePositives = 0; double auc = 0.; for (int i : sortedIndices) { - if (labels[i]) { currTruePositives++; } else { currFalsePositives++; }; - prevTruePositives = currTruePositives; - prevFalsePositives = currFalsePositives; + int prevTruePositives = currTruePositives; + int prevFalsePositives = currFalsePositives; + if (labels[i]) { currTruePositives++; } else { currFalsePositives++; } +// System.out.println("FP"); +// System.out.println(prevFalsePositives); +// System.out.println(currFalsePositives); +// System.out.println("TP"); +// System.out.println(prevTruePositives); +// System.out.println(currTruePositives); auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives); +// System.out.println("auc"); +// System.out.println(auc); } // If labels only contain one class, AUC is undefined diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java index d08f0fdb7490..bef3d56e51ac 100644 --- a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java @@ -6,7 +6,6 @@ public class TestROCAUCFunction { @Test public void testComputeAucRocWhenUndefined() { - boolean[] testLabels = new boolean[10]; double[] testProbabilities = new double[10]; assertEquals(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities), Double.POSITIVE_INFINITY); @@ -14,20 +13,17 @@ public void testComputeAucRocWhenUndefined() { @Test public void testComputeAucRoc() { - - boolean[] testLabels = new boolean[10]; - testLabels[9] = true; - double[] testProbabilities = new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617, + boolean[] testLabels1 = new boolean[10]; + testLabels1[9] = true; + double[] testProbabilities1= new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617, 0.62190361, 0.34917899, 0.88604834, 0.09936481, 0.65903197}; -// Arrays.fill(testProbabilities, 1.0); -// for (boolean element: testLabels) { -// System.out.println(element); -// } -// for (double element: testProbabilities) { -// System.out.println(element); -// } - System.out.println("score"); - System.out.println(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities)); + assertEquals(String.format("%.2f",ROCAUCFunction.computeRocAuc(testLabels1, testProbabilities1)), "0.67"); + + boolean[] testLabels2 = new boolean[10]; + testLabels2[9] = true; + double[] testProbabilities2 = new double[10]; + Arrays.fill(testProbabilities2, 1.0); + assertEquals(ROCAUCFunction.computeRocAuc(testLabels2, testProbabilities2), "0.5"); } } From bb719a9c75316816217630411933a6a4d0d400a2 Mon Sep 17 00:00:00 2001 From: David Kurokawa Date: Thu, 16 Feb 2023 09:58:21 -0800 Subject: [PATCH 07/10] Fix AUC calculation. --- .../trino/plugin/truera/ROCAUCFunction.java | 27 ++++++-------- .../plugin/truera/TestROCAUCFunction.java | 37 +++++++++++-------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java index 29df393020a2..8a8a15362562 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java @@ -9,31 +9,28 @@ public static double computeRocAuc(boolean[] labels, double[] scores) { Comparator.comparing(i -> scores[i], Comparator.reverseOrder()) ).mapToInt(i->i).toArray(); -// for (int element: sortedIndices) { -// System.out.println(element); -// } - int currTruePositives = 0, currFalsePositives = 0; double auc = 0.; - for (int i : sortedIndices) { + int i = 0; + while (i < sortedIndices.length) { int prevTruePositives = currTruePositives; int prevFalsePositives = currFalsePositives; - if (labels[i]) { currTruePositives++; } else { currFalsePositives++; } -// System.out.println("FP"); -// System.out.println(prevFalsePositives); -// System.out.println(currFalsePositives); -// System.out.println("TP"); -// System.out.println(prevTruePositives); -// System.out.println(currTruePositives); + double currentScore = scores[sortedIndices[i]]; + while (i < sortedIndices.length && currentScore == scores[sortedIndices[i]]) { + if (labels[sortedIndices[i]]) { + currTruePositives++; + } else { + currFalsePositives++; + } + ++i; + } auc += trapezoidIntegrate(prevFalsePositives, currFalsePositives, prevTruePositives, currTruePositives); -// System.out.println("auc"); -// System.out.println(auc); } // If labels only contain one class, AUC is undefined if (currTruePositives == 0 || currFalsePositives == 0) { - return Double.POSITIVE_INFINITY; + return Double.NaN; } return auc / (currTruePositives * currFalsePositives); diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java index bef3d56e51ac..3518dfccb4e4 100644 --- a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java @@ -4,26 +4,33 @@ import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; public class TestROCAUCFunction { + @Test - public void testComputeAucRocWhenUndefined() { - boolean[] testLabels = new boolean[10]; - double[] testProbabilities = new double[10]; - assertEquals(ROCAUCFunction.computeRocAuc(testLabels, testProbabilities), Double.POSITIVE_INFINITY); + public void testComputeAucRocConstantYs() { + boolean[] ys = new boolean[3]; + double[] ysPred = new double[]{-1, 0, 1}; + assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), Double.NaN); } @Test - public void testComputeAucRoc() { - boolean[] testLabels1 = new boolean[10]; - testLabels1[9] = true; - double[] testProbabilities1= new double[]{0.21206135, 0.97905249, 0.6460657 , 0.83698787, 0.40314617, - 0.62190361, 0.34917899, 0.88604834, 0.09936481, 0.65903197}; - assertEquals(String.format("%.2f",ROCAUCFunction.computeRocAuc(testLabels1, testProbabilities1)), "0.67"); + public void testComputeAucRocConstantYsPred() { + // Check if first element is only true. + boolean[] ys = new boolean[10]; + ys[0] = true; + double[] ysPred = new double[10]; + assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), 0.5); + // Check if last element is only true. + ys = new boolean[10]; + ys[ys.length - 1] = true; + assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), 0.5); + } - boolean[] testLabels2 = new boolean[10]; - testLabels2[9] = true; - double[] testProbabilities2 = new double[10]; - Arrays.fill(testProbabilities2, 1.0); - assertEquals(ROCAUCFunction.computeRocAuc(testLabels2, testProbabilities2), "0.5"); + @Test + public void testComputeAucRocRandomCase() { + boolean[] ys = {false, false, false, true, false, false, true, true, false, true, true, true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, false, true, false, false, true, true, false, true, false, false, false, false, false, true, false, false, false, true, false, true, false, true, true, false, false, false, true, true, true, false, true, false, false, true, true, false, true, false, true, true, false, false, true, false, false, true, true, true, false, false, true, true, false, true, false, true, false, true, true, true, false, true, true, false, true, true, true, true, false, false, false, true, true, false}; + double[] ysPred = {-0.2286298738575081, -0.26869038039536985, -0.312319527551126, 0.2734721402453152, -0.24223018509401162, 0.31503992286551685, 0.4046447140511742, 0.28792712595231496, 0.23384628640264638, 0.4428378129948476, 0.4543878651258779, -0.01374360133220165, 0.0007349034879994276, 0.09636778467067175, 0.13655154956520688, 0.17454951949942543, 0.4931408863331128, 0.23547728150236757, 0.37226953309841515, -0.15728692842143432, 0.15924354023313247, 0.4286270233907359, -0.4685474933247794, -0.25577981860345134, 0.4681668166617913, 0.11320750870358587, -0.2736801671178942, -0.3864589476069452, -0.399822172488029, 0.015685730721504587, 0.3321140926944811, 0.2530935731647217, -0.3959046404148484, -0.15179547616306643, -0.36416614908162126, -0.29774936545934994, -0.12315200126789438, -0.2644045710260078, -0.44906116392418005, -0.020580493381027964, -0.12273927226482872, -0.056684762439960235, -0.08061360953066143, -0.29603641015837856, -0.18856132377074108, 0.4235267740783706, -0.4606040899799605, 0.42625848143208045, -0.0011082145951951672, -0.10729132271105546, 0.05327228838675735, 0.23256669296146493, 0.44780010851745966, 0.3568537456277838, 0.1805343926823827, -0.21846170998705394, -0.18590956046362084, 0.2509080583778559, -0.22290770310271601, 0.09721402039749849, 0.12656860105942058, -0.4855359797468307, -0.39689924561298107, 0.06657447973146235, 0.2578990291023152, 0.4742641426354459, -0.18982849914795363, 0.2227804823488645, 0.35372747647450187, -0.4543332887713202, -0.2564141959803503, -0.3998731379722754, 0.4302966112027795, 0.20920354786320405, -0.3493332019987293, 0.22594844508851064, 0.3506301294801156, 0.49704339475860027, -0.06293993755501126, 0.3918169521501762, -0.09311136809088782, 0.1600693305289207, -0.15765165599671038, 0.45406593535839646, 0.25178211933304273, -0.10253079227341522, -0.07396783373899074, -0.2779385207339454, 0.39166810505415206, -0.21097405310184147, 0.49742358104577056, 0.2503826842835034, 0.06827961053326548, -0.14270071261942086, 0.2718756995005246, -0.40524883036019543, 0.017045033508361507, 0.38197771242907064, 0.26845558760844856, -0.0965167478978054}; + double auc = ROCAUCFunction.computeRocAuc(ys, ysPred); + assertEquals(auc, 0.8187274909963985); } } From 98536a604c41b2e4773a8a14321bd7a7c1a9efb8 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Thu, 6 Apr 2023 16:01:12 -0700 Subject: [PATCH 08/10] get test up for plugin --- .../plugin/geospatial/TestKdbTreeCasts.java | 1 + plugin/trino-truera/pom.xml | 42 ++++++++++++++++++- ...n.java => AreaUnderRocCurveAlgorithm.java} | 6 ++- .../aggregation/GroupedRocAucCurve.java | 18 +++++--- .../state/AreaUnderRocCurveStateFactory.java | 5 +-- .../AreaUnderRocCurveStateSerializer.java | 3 +- ...va => TestAreaUnderRocCurveAlgorithm.java} | 11 +++-- 7 files changed, 68 insertions(+), 18 deletions(-) rename plugin/trino-truera/src/main/java/io/trino/plugin/truera/{ROCAUCFunction.java => AreaUnderRocCurveAlgorithm.java} (86%) rename plugin/trino-truera/src/test/java/io/trino/plugin/truera/{TestROCAUCFunction.java => TestAreaUnderRocCurveAlgorithm.java} (90%) diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java index 8c5303b830bb..77a6a7f00cb8 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java @@ -37,6 +37,7 @@ public void registerFunctions() public void test() { String kdbTreeJson = makeKdbTreeJson(); + System.out.println(kdbTreeJson); assertFunction(format("typeof(cast('%s' AS KdbTree))", kdbTreeJson), VARCHAR, "KdbTree"); assertFunction(format("typeof(cast('%s' AS KDBTree))", kdbTreeJson), VARCHAR, "KdbTree"); assertFunction(format("typeof(cast('%s' AS kdbTree))", kdbTreeJson), VARCHAR, "KdbTree"); diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml index 225616853fdc..a221b1cbf141 100644 --- a/plugin/trino-truera/pom.xml +++ b/plugin/trino-truera/pom.xml @@ -4,7 +4,7 @@ trino-root io.trino - 389 + 406 ../../pom.xml @@ -31,6 +31,10 @@ slice provided + + io.airlift + log + org.openjdk.jol jol-core @@ -41,6 +45,42 @@ testng test + + io.trino + trino-main + test + + + io.trino + trino-main + test-jar + test + + + io.trino + trino-main + test + + + io.trino + trino-testing + test + + + io.trino + trino-hive + test + + + io.trino + trino-iceberg + test + + + io.trino + trino-hive-hadoop2 + test + diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java similarity index 86% rename from plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java rename to plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java index 8a8a15362562..5ffbe9ec13a0 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/ROCAUCFunction.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/AreaUnderRocCurveAlgorithm.java @@ -1,10 +1,14 @@ package io.trino.plugin.truera; import java.util.Comparator; +import io.airlift.log.Logger; import java.util.stream.IntStream; -public class ROCAUCFunction { +public class AreaUnderRocCurveAlgorithm { + private static final Logger log = Logger.get(AreaUnderRocCurveAlgorithm.class); public static double computeRocAuc(boolean[] labels, double[] scores) { + log.info("cow"); + log.info("compute", labels.toString(), scores.toString()); int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted( Comparator.comparing(i -> scores[i], Comparator.reverseOrder()) ).mapToInt(i->i).toArray(); diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java index 116ba436cb98..93572e544854 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/aggregation/GroupedRocAucCurve.java @@ -1,5 +1,6 @@ package io.trino.plugin.truera.aggregation; +import io.airlift.log.Logger; import java.util.ArrayList; import java.util.List; @@ -15,12 +16,13 @@ import io.trino.spi.type.DoubleType; import org.openjdk.jol.info.ClassLayout; -import static io.trino.plugin.truera.ROCAUCFunction.computeRocAuc; +import static io.trino.plugin.truera.AreaUnderRocCurveAlgorithm.computeRocAuc; import static java.util.Objects.requireNonNull; public class GroupedRocAucCurve { + private static final Logger log = Logger.get(GroupedRocAucCurve.class); - private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedRocAucCurve.class).instanceSize(); + private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedRocAucCurve.class).instanceSize(); private static final int NULL = -1; // one entry per group @@ -83,13 +85,15 @@ public void serialize(BlockBuilder out) { labels[i] = labelList.get(i); } double[] scores = scoreList.stream().mapToDouble(Double::doubleValue).toArray(); + log.info("cow2"); + log.info("compute", labels.toString(), scores.toString()); // compute + return double auc = computeRocAuc(labels, scores); - if (Double.isFinite(auc)) { - DoubleType.DOUBLE.writeDouble(out, auc); - } else { + if (Double.isNaN(auc)) { out.appendNull(); + } else { + DoubleType.DOUBLE.writeDouble(out, auc); } } @@ -102,6 +106,10 @@ public GroupedRocAucCurve setGroupId(long groupId) { return this; } + public long getGroupId() { + return this.currentGroupId; + } + public void add(Block labelsBlock, Block scoresBlock, int labelPosition, int scorePosition) { ensureCapacity(currentGroupId + 1); diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java index 1e85c146f81c..38fa9837cc64 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateFactory.java @@ -1,6 +1,5 @@ package io.trino.plugin.truera.state; -import io.trino.plugin.truera.state.AreaUnderRocCurveState; import io.trino.plugin.truera.aggregation.GroupedRocAucCurve; import io.trino.spi.block.Block; import io.trino.spi.function.AccumulatorStateFactory; @@ -22,7 +21,7 @@ public AreaUnderRocCurveState createGroupedState() { } public static class GroupedState implements AreaUnderRocCurveState, GroupedAccumulatorState { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); + private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); private GroupedRocAucCurve auc; private long size; @@ -82,7 +81,7 @@ public long getEstimatedSize() public static class SingleState implements AreaUnderRocCurveState { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); + private static final long INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); private GroupedRocAucCurve auc; public SingleState() diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java index a3c002ce9ee6..5b11d0e0e6d7 100644 --- a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/state/AreaUnderRocCurveStateSerializer.java @@ -1,6 +1,5 @@ package io.trino.plugin.truera.state; -import io.trino.plugin.truera.state.AreaUnderRocCurveState; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; @@ -13,7 +12,7 @@ public class AreaUnderRocCurveStateSerializer implements AccumulatorStateSerializer { - private static final ArrayType SERIALIZED_TYPE = new ArrayType(anonymousRow(BooleanType.BOOLEAN, DoubleType.DOUBLE)); + static final ArrayType SERIALIZED_TYPE = new ArrayType(anonymousRow(BooleanType.BOOLEAN, DoubleType.DOUBLE)); @Override public Type getSerializedType() { diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java similarity index 90% rename from plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java rename to plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java index 3518dfccb4e4..cc1b5c756495 100644 --- a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestROCAUCFunction.java +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAreaUnderRocCurveAlgorithm.java @@ -1,15 +1,14 @@ package io.trino.plugin.truera; -import java.util.Arrays; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; -public class TestROCAUCFunction { +public class TestAreaUnderRocCurveAlgorithm { @Test public void testComputeAucRocConstantYs() { boolean[] ys = new boolean[3]; double[] ysPred = new double[]{-1, 0, 1}; - assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), Double.NaN); + assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), Double.NaN); } @Test @@ -18,18 +17,18 @@ public void testComputeAucRocConstantYsPred() { boolean[] ys = new boolean[10]; ys[0] = true; double[] ysPred = new double[10]; - assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), 0.5); + assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5); // Check if last element is only true. ys = new boolean[10]; ys[ys.length - 1] = true; - assertEquals(ROCAUCFunction.computeRocAuc(ys, ysPred), 0.5); + assertEquals(AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred), 0.5); } @Test public void testComputeAucRocRandomCase() { boolean[] ys = {false, false, false, true, false, false, true, true, false, true, true, true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, false, true, false, false, true, true, false, true, false, false, false, false, false, true, false, false, false, true, false, true, false, true, true, false, false, false, true, true, true, false, true, false, false, true, true, false, true, false, true, true, false, false, true, false, false, true, true, true, false, false, true, true, false, true, false, true, false, true, true, true, false, true, true, false, true, true, true, true, false, false, false, true, true, false}; double[] ysPred = {-0.2286298738575081, -0.26869038039536985, -0.312319527551126, 0.2734721402453152, -0.24223018509401162, 0.31503992286551685, 0.4046447140511742, 0.28792712595231496, 0.23384628640264638, 0.4428378129948476, 0.4543878651258779, -0.01374360133220165, 0.0007349034879994276, 0.09636778467067175, 0.13655154956520688, 0.17454951949942543, 0.4931408863331128, 0.23547728150236757, 0.37226953309841515, -0.15728692842143432, 0.15924354023313247, 0.4286270233907359, -0.4685474933247794, -0.25577981860345134, 0.4681668166617913, 0.11320750870358587, -0.2736801671178942, -0.3864589476069452, -0.399822172488029, 0.015685730721504587, 0.3321140926944811, 0.2530935731647217, -0.3959046404148484, -0.15179547616306643, -0.36416614908162126, -0.29774936545934994, -0.12315200126789438, -0.2644045710260078, -0.44906116392418005, -0.020580493381027964, -0.12273927226482872, -0.056684762439960235, -0.08061360953066143, -0.29603641015837856, -0.18856132377074108, 0.4235267740783706, -0.4606040899799605, 0.42625848143208045, -0.0011082145951951672, -0.10729132271105546, 0.05327228838675735, 0.23256669296146493, 0.44780010851745966, 0.3568537456277838, 0.1805343926823827, -0.21846170998705394, -0.18590956046362084, 0.2509080583778559, -0.22290770310271601, 0.09721402039749849, 0.12656860105942058, -0.4855359797468307, -0.39689924561298107, 0.06657447973146235, 0.2578990291023152, 0.4742641426354459, -0.18982849914795363, 0.2227804823488645, 0.35372747647450187, -0.4543332887713202, -0.2564141959803503, -0.3998731379722754, 0.4302966112027795, 0.20920354786320405, -0.3493332019987293, 0.22594844508851064, 0.3506301294801156, 0.49704339475860027, -0.06293993755501126, 0.3918169521501762, -0.09311136809088782, 0.1600693305289207, -0.15765165599671038, 0.45406593535839646, 0.25178211933304273, -0.10253079227341522, -0.07396783373899074, -0.2779385207339454, 0.39166810505415206, -0.21097405310184147, 0.49742358104577056, 0.2503826842835034, 0.06827961053326548, -0.14270071261942086, 0.2718756995005246, -0.40524883036019543, 0.017045033508361507, 0.38197771242907064, 0.26845558760844856, -0.0965167478978054}; - double auc = ROCAUCFunction.computeRocAuc(ys, ysPred); + double auc = AreaUnderRocCurveAlgorithm.computeRocAuc(ys, ysPred); assertEquals(auc, 0.8187274909963985); } From 736031d1d677bfef0d387ff5d80acebc8dc3a78f Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Wed, 17 May 2023 14:28:52 -0700 Subject: [PATCH 09/10] test --- .../plugin/geospatial/TestKdbTreeCasts.java | 1 - .../plugin/truera/metrics/SingleAUCROC.java | 2 + .../plugin/truera/TestAucRocAggregation.java | 89 +++++++++++++++++++ 3 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java create mode 100644 plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java index 77a6a7f00cb8..8c5303b830bb 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeCasts.java @@ -37,7 +37,6 @@ public void registerFunctions() public void test() { String kdbTreeJson = makeKdbTreeJson(); - System.out.println(kdbTreeJson); assertFunction(format("typeof(cast('%s' AS KdbTree))", kdbTreeJson), VARCHAR, "KdbTree"); assertFunction(format("typeof(cast('%s' AS KDBTree))", kdbTreeJson), VARCHAR, "KdbTree"); assertFunction(format("typeof(cast('%s' AS kdbTree))", kdbTreeJson), VARCHAR, "KdbTree"); diff --git a/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java new file mode 100644 index 000000000000..8af8d0ebc8e7 --- /dev/null +++ b/plugin/trino-truera/src/main/java/io/trino/plugin/truera/metrics/SingleAUCROC.java @@ -0,0 +1,2 @@ +package io.trino.plugin.truera.metrics;public class SingleAUCROC { +} diff --git a/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java new file mode 100644 index 000000000000..95aece0bcab6 --- /dev/null +++ b/plugin/trino-truera/src/test/java/io/trino/plugin/truera/TestAucRocAggregation.java @@ -0,0 +1,89 @@ +package io.trino.plugin.truera; + +import static io.trino.spi.security.SelectedRole.Type.ROLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import io.trino.plugin.hive.HivePlugin; + +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.plugin.iceberg.IcebergPlugin; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SelectedRole; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import java.io.File; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +@Test(singleThreaded = true) +public class TestAucRocAggregation extends AbstractTestQueryFramework { + private ImmutableMap.Builder icebergProperties = ImmutableMap.builder(); + private Optional metastoreDirectory = Optional.empty(); + public static final String ICEBERG_CATALOG = "iceberg"; + @Override + protected DistributedQueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setIdentity(Identity.forUser("hive") + .withConnectorRole("hive", new SelectedRole(ROLE, Optional.of("admin"))) + .build()) + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build(); + + queryRunner.installPlugin(new IcebergPlugin()); + Map icebergProperties = new HashMap<>(this.icebergProperties.buildOrThrow()); + String catalogType = icebergProperties.get("iceberg.catalog.type"); + Path dataDir = metastoreDirectory.map(File::toPath).orElseGet(() -> queryRunner.getCoordinator().getBaseDataDir().resolve("iceberg_data")); + if (catalogType == null) { + icebergProperties.put("iceberg.catalog.type", "TESTING_FILE_METASTORE"); + icebergProperties.put("hive.metastore.catalog.dir", dataDir.toString()); + } + + queryRunner.createCatalog(ICEBERG_CATALOG, "iceberg", icebergProperties); + queryRunner.installPlugin(new HivePlugin()); + queryRunner.createCatalog("hive", "hive", ImmutableMap.builder() + .put("hive.metastore", "file") + .put("hive.metastore.catalog.dir", dataDir.toString()) + .put("hive.security", "sql-standard") + .buildOrThrow()); + queryRunner.installPlugin(new TrueraTrinoPlugin()); + + return queryRunner; + } + + @BeforeClass + public void setUp() + { + assertQuerySucceeds("CREATE SCHEMA hive.test_schema"); + assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_probits_table (__id__ VARCHAR, __preds__ double)"); + assertQuerySucceeds("CREATE TABLE iceberg.test_schema.iceberg_labels_table (__id__ VARCHAR, __label__ int)"); + assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_probits_table VALUES ('A', 0.5), ('B', 0.2), ('C', 0.9)"); + assertQuerySucceeds("INSERT INTO iceberg.test_schema.iceberg_labels_table VALUES ('A', 1), ('B', 0), ('C', 1)"); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_labels_table"); + assertQuerySucceeds("DROP TABLE IF EXISTS iceberg.test_schema.iceberg_probits_table"); + assertQuerySucceeds("DROP SCHEMA IF EXISTS hive.test_schema"); + } + + @Test + public void testAucRoc() + { + assertQuery("SELECT roc_auc(__preds__, CAST(__label__ as boolean)) " + + "FROM iceberg.test_schema.iceberg_probits_table as pred " + + "JOIN iceberg.test_schema.iceberg_labels_table AS label " + + "ON pred.__id__ = label.__id__", + "VALUES (NULL)"); + } + +} From ada43f3e8b1246f5c0138b4e660eadefe375a448 Mon Sep 17 00:00:00 2001 From: Reetika Roy Date: Wed, 17 May 2023 14:29:20 -0700 Subject: [PATCH 10/10] pom --- plugin/trino-truera/pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/plugin/trino-truera/pom.xml b/plugin/trino-truera/pom.xml index a221b1cbf141..597de120b385 100644 --- a/plugin/trino-truera/pom.xml +++ b/plugin/trino-truera/pom.xml @@ -81,6 +81,10 @@ trino-hive-hadoop2 test + + io.trino + trino-main +