From 1c690ddafa8376c55cbc5b7a7a750200abfbe2a6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 23 Jan 2016 00:34:55 -0800 Subject: [PATCH] [SPARK-12933][SQL] Initial implementation of Count-Min sketch This PR adds an initial implementation of count min sketch, contained in a new module spark-sketch under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1]. As required by the [design doc][2], spark-sketch should have no external dependency. Two classes, `Murmur3_x86_32` and `Platform` are copied to spark-sketch from spark-unsafe for hashing facilities. They'll also be used in the upcoming bloom filter implementation. The following features will be added in future follow-up PRs: - Serialization support - DataFrame API integration [1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java [2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf Author: Cheng Lian Closes #10851 from liancheng/count-min-sketch. 
--- common/sketch/pom.xml | 42 +++ .../spark/util/sketch/CountMinSketch.java | 132 +++++++++ .../spark/util/sketch/CountMinSketchImpl.java | 268 ++++++++++++++++++ .../spark/util/sketch/Murmur3_x86_32.java | 126 ++++++++ .../apache/spark/util/sketch/Platform.java | 172 +++++++++++ .../util/sketch/CountMinSketchSuite.scala | 112 ++++++++ dev/sparktestsupport/modules.py | 12 + pom.xml | 1 + project/SparkBuild.scala | 39 ++- 9 files changed, 892 insertions(+), 12 deletions(-) create mode 100644 common/sketch/pom.xml create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml new file mode 100644 index 0000000000000..67723fa421ab1 --- /dev/null +++ b/common/sketch/pom.xml @@ -0,0 +1,42 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sketch_2.10 + jar + Spark Project Sketch + http://spark.apache.org/ + + sketch + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java new file mode 100644 index 0000000000000..21b161bc74ae0 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in + * sub-linear space. Currently, supported data types include: + * <ul> + * <li>{@link Byte}</li> + * <li>{@link Short}</li> + * <li>{@link Integer}</li> + * <li>{@link Long}</li> + * <li>{@link String}</li> + * </ul> + * Each {@link CountMinSketch} is initialized with a random seed, and a pair + * of parameters: + *
<ol> + *
<li>relative error (or {@code eps}), and</li> + *
<li>confidence (or {@code delta})</li> + * </ol>
+ * Suppose you want to estimate the number of times an element {@code x} has appeared in a data + * stream so far. With probability {@code delta}, the estimate of this frequency is within the + * range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the + * total count of items that have appeared in the data stream so far. + * + * Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array + * with depth {@code d} and width {@code w}, where + * <ul> + * <li>{@code d = ceil(-log(1 - confidence) / log(2))}</li> + * <li>{@code w = ceil(2 / eps)}</li> + * </ul> + * + * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details, + * including proofs of the estimates and error bounds used in this implementation. + * + * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. + */ +abstract public class CountMinSketch { + /** + * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. + */ + public abstract double relativeError(); + + /** + * Returns the confidence (or {@code delta}) of this {@link CountMinSketch}. + */ + public abstract double confidence(); + + /** + * Depth of this {@link CountMinSketch}. + */ + public abstract int depth(); + + /** + * Width of this {@link CountMinSketch}. + */ + public abstract int width(); + + /** + * Total count of items added to this {@link CountMinSketch} so far. + */ + public abstract long totalCount(); + + /** + * Adds 1 to {@code item}. + */ + public abstract void add(Object item); + + /** + * Adds {@code count} to {@code item}. + */ + public abstract void add(Object item, long count); + + /** + * Returns the estimated frequency of {@code item}. + */ + public abstract long estimateCount(Object item); + + /** + * Merges another {@link CountMinSketch} with this one in place. + * + * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed + * can be merged.
+ */ + public abstract CountMinSketch mergeInPlace(CountMinSketch other); + + /** + * Writes out this {@link CountMinSketch} to an output stream in binary format. + */ + public abstract void writeTo(OutputStream out); + + /** + * Reads in a {@link CountMinSketch} from an input stream. + */ + public static CountMinSketch readFrom(InputStream in) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** + * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random + * {@code seed}. + */ + public static CountMinSketch create(int depth, int width, int seed) { + return new CountMinSketchImpl(depth, width, seed); + } + + /** + * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, + * and random {@code seed}. + */ + public static CountMinSketch create(double eps, double confidence, int seed) { + return new CountMinSketchImpl(eps, confidence, seed); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java new file mode 100644 index 0000000000000..e9fdbe3a86862 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import java.util.Random; + +class CountMinSketchImpl extends CountMinSketch { + public static final long PRIME_MODULUS = (1L << 31) - 1; + + private int depth; + private int width; + private long[][] table; + private long[] hashA; + private long totalCount; + private double eps; + private double confidence; + + public CountMinSketchImpl(int depth, int width, int seed) { + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + initTablesWith(depth, width, seed); + } + + public CountMinSketchImpl(double eps, double confidence, int seed) { + // 2/w = eps ; w = 2/eps + // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) + this.eps = eps; + this.confidence = confidence; + this.width = (int) Math.ceil(2 / eps); + this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2)); + initTablesWith(depth, width, seed); + } + + private void initTablesWith(int depth, int width, int seed) { + this.table = new long[depth][width]; + this.hashA = new long[depth]; + Random r = new Random(seed); + // We're using a linear hash functions + // of the form (a*x+b) mod p. + // a,b are chosen independently for each hash function. + // However we can set b = 0 as all it does is shift the results + // without compromising their uniformity or independence with + // the other hashes. 
+ for (int i = 0; i < depth; ++i) { + hashA[i] = r.nextInt(Integer.MAX_VALUE); + } + } + + @Override + public double relativeError() { + return eps; + } + + @Override + public double confidence() { + return confidence; + } + + @Override + public int depth() { + return depth; + } + + @Override + public int width() { + return width; + } + + @Override + public long totalCount() { + return totalCount; + } + + @Override + public void add(Object item) { + add(item, 1); + } + + @Override + public void add(Object item, long count) { + if (item instanceof String) { + addString((String) item, count); + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + addLong(longValue, count); + } + } + + private void addString(String item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + int[] buckets = getHashBuckets(item, depth, width); + + for (int i = 0; i < depth; ++i) { + table[i][buckets[i]] += count; + } + + totalCount += count; + } + + private void addLong(long item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + for (int i = 0; i < depth; ++i) { + table[i][hash(item, i)] += count; + } + + totalCount += count; + } + + private int hash(long item, int count) { + long hash = hashA[count] * item; + // A super fast way of computing x mod 2^p-1 + // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf + // page 149, right after Proposition 7. 
+ hash += hash >> 32; + hash &= PRIME_MODULUS; + // Doing "%" after (int) conversion is ~2x faster than %'ing longs. + return ((int) hash) % width; + } + + private static int[] getHashBuckets(String key, int hashCount, int max) { + byte[] b; + try { + b = key.getBytes("UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + return getHashBuckets(b, hashCount, max); + } + + private static int[] getHashBuckets(byte[] b, int hashCount, int max) { + int[] result = new int[hashCount]; + int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); + int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1); + for (int i = 0; i < hashCount; i++) { + result[i] = Math.abs((hash1 + i * hash2) % max); + } + return result; + } + + @Override + public long estimateCount(Object item) { + if (item instanceof String) { + return estimateCountForStringItem((String) item); + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + return estimateCountForLongItem(longValue); + } + } + + private long estimateCountForLongItem(long item) { + long res = Long.MAX_VALUE; + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][hash(item, i)]); + } + return res; + } + + private long estimateCountForStringItem(String item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + + @Override + public CountMinSketch mergeInPlace(CountMinSketch other) { + if (other == 
null) { + throw new CMSMergeException("Cannot merge null estimator"); + } + + if (!(other instanceof CountMinSketchImpl)) { + throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + if (this.depth != that.depth) { + throw new CMSMergeException("Cannot merge estimators of different depth"); + } + + if (this.width != that.width) { + throw new CMSMergeException("Cannot merge estimators of different width"); + } + + if (!Arrays.equals(this.hashA, that.hashA)) { + throw new CMSMergeException("Cannot merge estimators of different seed"); + } + + for (int i = 0; i < this.table.length; ++i) { + for (int j = 0; j < this.table[i].length; ++j) { + this.table[i][j] = this.table[i][j] + that.table[i][j]; + } + } + + this.totalCount += that.totalCount; + + return this; + } + + @Override + public void writeTo(OutputStream out) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + protected static class CMSMergeException extends RuntimeException { + public CMSMergeException(String message) { + super(message); + } + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java new file mode 100644 index 0000000000000..3d1f28bcb911e --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure +// spark-sketch has no external dependencies. +final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + return hashInt(input, seed); + } + + public static int hashInt(int input, int seed) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. 
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return h1; + } + + public int hashLong(long input) { + return hashLong(input, seed); + } + + public static int hashLong(long input, int seed) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java new file mode 100644 index 
0000000000000..75d6a6beec408 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no +// external dependencies. 
+final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + 
} + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. 
+ */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala new file mode 100644 index 0000000000000..ec5b4eddeca0d --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite + private val epsOfTotalCount = 0.0001 + + private val confidence = 0.99 + + private val seed = 42 + + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"accuracy - $typeName") { + val r = new Random() + + val numAllItems = 1000000 + val allItems = Array.fill(numAllItems)(itemGenerator(r)) + + val numSamples = numAllItems / 10 + val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + + val exactFreq = { + val sampledItems = sampledItemIndices.map(allItems) + sampledItems.groupBy(identity).mapValues(_.length.toLong) + } + + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + + val probCorrect = { + val numErrors = allItems.map { item => + val count = exactFreq.getOrElse(item, 0L) + val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems + if (ratio > epsOfTotalCount) 1 else 0 + }.sum + + 1D - numErrors.toDouble / numAllItems + } + + assert( + probCorrect > confidence, + s"Confidence not reached: required $confidence, reached $probCorrect" + ) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + val r = new Random() + val numToMerge = 5 + val numItemsPerSketch = 100000 + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { + itemGenerator(r) + } + + val sketches = perSketchItems.map { items => + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + items.foreach(sketch.add) + sketch + } + + val mergedSketch = sketches.reduce(_ mergeInPlace _) + + val expectedSketch = { + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + 
perSketchItems.foreach(_.foreach(sketch.add)) + sketch + } + + perSketchItems.foreach { + _.foreach { item => + assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item)) + } + } + } + } + + def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + testAccuracy[T](typeName)(itemGenerator) + testMergeInPlace[T](typeName)(itemGenerator) + } + + testItemType[Byte]("Byte") { _.nextInt().toByte } + + testItemType[Short]("Short") { _.nextInt().toShort } + + testItemType[Int]("Int") { _.nextInt() } + + testItemType[Long]("Long") { _.nextLong() } + + testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } +} diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe58ea2e0e78..032c0616edb1e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -113,6 +113,18 @@ def contains_file(self, filename): ) +sketch = Module( + name="sketch", + dependencies=[], + source_file_regexes=[ + "common/sketch/", + ], + sbt_test_goals=[ + "sketch/test" + ] +) + + graphx = Module( name="graphx", dependencies=[], diff --git a/pom.xml b/pom.xml index f08642f606788..fb7750602c425 100644 --- a/pom.xml +++ b/pom.xml @@ -86,6 +86,7 @@ + common/sketch tags core graphx diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3927b88fb0bf6..4224a65a822b8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,13 +34,24 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = - Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "network-common", "network-shuffle", "streaming", 
"streaming-flume-sink", - "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver" + ).map(ProjectRef(buildLocation, _)) + + val streamingProjects@Seq( + streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt, + streamingTwitter, streamingZeromq + ) = Seq( + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka", + "streaming-mqtt", "streaming-twitter", "streaming-zeromq" + ).map(ProjectRef(buildLocation, _)) + + val allProjects@Seq( + core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + ) = Seq( + "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "test-tags", "sketch" + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, streamingKinesisAsl, dockerIntegrationTests) = @@ -232,11 +243,15 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove streamingAkka from this list after 2.0.0 - allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach { - x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) - } + // TODO: remove streamingAkka and sketch from this list after 2.0.0 + allProjects.filterNot { x => + Seq( + spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, + unsafe, streamingAkka, testTags, sketch + ).contains(x) + }.foreach { x => + 
enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Unsafe settings */ enable(Unsafe.settings)(unsafe)