diff --git a/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java new file mode 100644 index 0000000000000..b634218bb4980 --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that takes three inputs and returns zero or more output records. + */ +@FunctionalInterface +public interface FlatMapFunction3<T1, T2, T3, R> extends Serializable { + Iterator<R> call(T1 t1, T2 t2, T3 t3) throws Exception; +} diff --git a/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java new file mode 100644 index 0000000000000..2b76ce9c750f8 --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that takes four inputs and returns zero or more output records.
+ */ +@FunctionalInterface +public interface FlatMapFunction4<T1, T2, T3, T4, R> extends Serializable { + Iterator<R> call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index f90ed8b9e9d8b..4b57fc8708cfd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,6 +21,7 @@ import java.{lang => jl} import java.lang.{Iterable => JIterable} import java.util.{Comparator, Iterator => JIterator, List => JList, Map => JMap} +import scala.annotation.varargs import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -308,20 +309,83 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag) } - /** - * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by - * applying a function to the zipped partitions. Assumes that all the RDDs have the - * *same number of partitions*, but does *not* require them to have the same number - * of elements in each partition. + /** + * Zip this RDD's partitions with one other RDD and return a new RDD by applying a function to + * the zipped partitions. Assumes that both the RDDs have the *same number of partitions*, but + * does *not* require them to have the same number of elements in each partition. */ def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[JIterator[T], JIterator[U], V]): JavaRDD[V] = { - def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { (x: Iterator[T], y: Iterator[U]) => + f.call(x.asJava, y.asJava).asScala } + JavaRDD + .fromRDD(rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) + } + + /** + * Zip this RDD's partitions with two more RDDs and return a new RDD by applying a function to + * the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but + * does *not* require them to have the same number of elements in each partition. + */ + @Since("4.1.0") + def zipPartitions[U1, U2, V]( + other1: JavaRDDLike[U1, _], + other2: JavaRDDLike[U2, _], + f: FlatMapFunction3[JIterator[T], JIterator[U1], JIterator[U2], V]): JavaRDD[V] = { + def fn: (Iterator[T], Iterator[U1], Iterator[U2]) => Iterator[V] = + (t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2]) => + f.call(t.asJava, u1.asJava, u2.asJava).asScala + JavaRDD.fromRDD( + rdd.zipPartitions(other1.rdd, other2.rdd)(fn)( + other1.classTag, + other2.classTag, + fakeClassTag[V]))(fakeClassTag[V]) + } + + /** + * Zip this RDD's partitions with three more RDDs and return a new RDD by applying a function to + * the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but + * does *not* require them to have the same number of elements in each partition.
+ */ + @Since("4.1.0") + def zipPartitions[U1, U2, U3, V]( + other1: JavaRDDLike[U1, _], + other2: JavaRDDLike[U2, _], + other3: JavaRDDLike[U3, _], + f: FlatMapFunction4[JIterator[T], JIterator[U1], JIterator[U2], JIterator[U3], V]) + : JavaRDD[V] = { + def fn: (Iterator[T], Iterator[U1], Iterator[U2], Iterator[U3]) => Iterator[V] = + (t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2], u3: Iterator[U3]) => + f.call(t.asJava, u1.asJava, u2.asJava, u3.asJava).asScala JavaRDD.fromRDD( - rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) + rdd.zipPartitions(other1.rdd, other2.rdd, other3.rdd)(fn)( + other1.classTag, + other2.classTag, + other3.classTag, + fakeClassTag[V]))(fakeClassTag[V]) + } + + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by applying a + * function to the zipped partitions. Assumes that all the RDDs have the *same number of + * partitions*, but does *not* require them to have the same number of elements in each + * partition. + * + * @note + * A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe + * and other `zipPartitions` methods should be preferred. + */ + @Since("4.1.0") + @varargs + def zipPartitions[U, V]( + f: FlatMapFunction[JList[JIterator[U]], V], + others: JavaRDDLike[_, _]*): JavaRDD[V] = { + def fn: Seq[Iterator[_]] => Iterator[V] = + (i: Seq[Iterator[_]]) => f.call(i.map(_.asInstanceOf[Iterator[U]].asJava).asJava).asScala + JavaRDD + .fromRDD(rdd.zipPartitions(others.map(_.rdd): _*)(fn)(fakeClassTag[V]))(fakeClassTag[V]) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 80db818b77e42..911a151d93eea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1028,6 +1028,26 @@ abstract class RDD[T: ClassTag]( zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f) } + /** + * A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe + * and other `zipPartitions` methods should be preferred. + */ + @Since("4.1.0") + def zipPartitions[V: ClassTag](preservesPartitioning: Boolean, rdds: RDD[_]*)( + f: Seq[Iterator[_]] => Iterator[V]): RDD[V] = withScope { + new ZippedPartitionsRDDN(sc, sc.clean(f), this +: rdds, preservesPartitioning) + } + + /** + * A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe + * and other `zipPartitions` methods should be preferred. 
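+ *
+ * A minimal usage sketch (the RDD names here are illustrative; every zipped RDD is assumed to
+ * have the same number of partitions as this one), combining the corresponding partitions of
+ * five RDDs into one record per partition:
+ * {{{
+ *   val merged = rdd1.zipPartitions(rdd2, rdd3, rdd4, rdd5) { iters =>
+ *     Iterator(iters.map(_.mkString(",")).mkString(";"))
+ *   }
+ * }}}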
+ */ + @Since("4.1.0") + def zipPartitions[V: ClassTag](rdds: RDD[_]*)(f: Seq[Iterator[_]] => Iterator[V]): RDD[V] = + withScope { + zipPartitions(preservesPartitioning = false, rdds: _*)(f) + } + // Actions (launch a job to return a value to the user program) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 678a48948a3c1..fbed891ea1115 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -28,7 +28,7 @@ private[spark] class ZippedPartitionsPartition( idx: Int, @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) - extends Partition { + extends Partition { override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) @@ -46,7 +46,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( sc: SparkContext, var rdds: Seq[RDD[_]], preservesPartitioning: Boolean = false) - extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { override val partitioner = if (preservesPartitioning) firstParent[Any].partitioner else None @@ -82,7 +82,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] var rdd1: RDD[A], var rdd2: RDD[B], preservesPartitioning: Boolean = false) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -97,19 +97,19 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] } } -private[spark] class ZippedPartitionsRDD3 - [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( +private[spark] class ZippedPartitionsRDD3[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C], preservesPartitioning: Boolean = false) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), + f( + rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context)) } @@ -123,8 +123,12 @@ private[spark] class ZippedPartitionsRDD3 } } -private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( +private[spark] class ZippedPartitionsRDD4[ + A: ClassTag, + B: ClassTag, + C: ClassTag, + D: ClassTag, + V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], @@ -132,11 +136,12 @@ private[spark] class ZippedPartitionsRDD4 var rdd3: RDD[C], var rdd4: RDD[D], preservesPartitioning: Boolean = false) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] 
= { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), + f( + rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context), rdd4.iterator(partitions(3), context)) @@ -151,3 +156,24 @@ private[spark] class ZippedPartitionsRDD4 f = null } } + +private[spark] class ZippedPartitionsRDDN[V: ClassTag]( + sc: SparkContext, + var f: Seq[Iterator[_]] => Iterator[V], + var rddN: Seq[RDD[_]], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, rddN, preservesPartitioning) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdds.zip(partitions).map { case (rdd, partition) => + rdd.iterator(partition, context) + }) + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + rddN = null + f = null + } +} diff --git a/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java index 276fc34db6218..1753c805606ab 100644 --- a/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java @@ -274,25 +274,45 @@ public void zip() { } @Test - public void zipPartitions() { - JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); - JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); - FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn = - (Iterator<Integer> i, Iterator<String> s) -> { - int sizeI = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - int sizeS = 0; - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS).iterator(); - }; - JavaRDD<Integer> sizes = rdd1.zipPartitions(rdd2, sizesFn); - Assertions.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + public void zipPartitions2() { + JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> zipped = rdd1.zipPartitions( + rdd2, ZipPartitionsFunction.ZIP_PARTITIONS_2_FUNCTION); + Assertions.assertEquals(Arrays.asList("abef", "cdgh"), zipped.collect()); + } + + @Test + public void zipPartitions3() { + JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD<String> zipped = rdd1.zipPartitions( + rdd2, rdd3, ZipPartitionsFunction.ZIP_PARTITIONS_3_FUNCTION); + Assertions.assertEquals(Arrays.asList("abefij", "cdghkl"), zipped.collect()); + } + + @Test + public void zipPartitions4() { + JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2); + JavaRDD<String> zipped = rdd1.zipPartitions( + rdd2, rdd3, rdd4, ZipPartitionsFunction.ZIP_PARTITIONS_4_FUNCTION); + Assertions.assertEquals(Arrays.asList("abefijmn", "cdghklop"), zipped.collect()); + } + + @Test + public void zipPartitionsN() { + JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o",
"p"), 2); + JavaRDD rdd5 = sc.parallelize(Arrays.asList("q", "r", "s", "t"), 2); + JavaRDD zipped = rdd1.zipPartitions( + ZipPartitionsFunction.ZIP_PARTITIONS_N_FUNCTION, rdd2, rdd3, rdd4, rdd5); + Assertions.assertEquals(Arrays.asList("abefijmnqr", "cdghklopst"), zipped.collect()); } @Test @@ -348,4 +368,39 @@ public void collectAsMapWithIntArrayValues() { pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } + + private static class ZipPartitionsFunction + implements FlatMapFunction>, String> { + + private static final ZipPartitionsFunction ZIP_PARTITIONS_N_FUNCTION = + new ZipPartitionsFunction(); + + private static final FlatMapFunction2, Iterator, String> + ZIP_PARTITIONS_2_FUNCTION = + (Iterator i1, Iterator i2) -> + ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2)); + + private static final FlatMapFunction3< + Iterator, Iterator, Iterator, String> + ZIP_PARTITIONS_3_FUNCTION = + (Iterator i1, Iterator i2, Iterator i3) -> + ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3)); + + private static final FlatMapFunction4< + Iterator, Iterator, Iterator, Iterator, String> + ZIP_PARTITIONS_4_FUNCTION = + (Iterator i1, Iterator i2, Iterator i3, Iterator i4) -> + ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3, i4)); + + @Override + public Iterator call(List> iterators) { + StringBuilder stringBuilder = new StringBuilder(); + for (Iterator iterator : iterators) { + while (iterator.hasNext()) { + stringBuilder.append(iterator.next()); + } + } + return Collections.singleton(stringBuilder.toString()).iterator(); + } + } } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 802cb2667cc88..97a01d305e037 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -46,7 +46,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; -import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; @@ -1207,16 +1206,46 @@ public void zip() { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + @Test + public void zipPartitions2() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD zipped = rdd1.zipPartitions( + rdd2, ZipPartitionsFunction.ZIP_PARTITIONS_2_FUNCTION); + assertEquals(Arrays.asList("abef", "cdgh"), zipped.collect()); + } @Test - public void zipPartitions() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); - JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); - FlatMapFunction2, Iterator, Integer> sizesFn = - (i, s) -> Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); + public void zipPartitions3() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD zipped = rdd1.zipPartitions( + rdd2, rdd3, ZipPartitionsFunction.ZIP_PARTITIONS_3_FUNCTION); + assertEquals(Arrays.asList("abefij", "cdghkl"), zipped.collect()); + } - JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); - assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + @Test + public void zipPartitions4() { + JavaRDD rdd1 = 
sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2); + JavaRDD<String> zipped = rdd1.zipPartitions( + rdd2, rdd3, rdd4, ZipPartitionsFunction.ZIP_PARTITIONS_4_FUNCTION); + assertEquals(Arrays.asList("abefijmn", "cdghklop"), zipped.collect()); + } + + @Test + public void zipPartitionsN() { + JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2); + JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2); + JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2); + JavaRDD<String> rdd5 = sc.parallelize(Arrays.asList("q", "r", "s", "t"), 2); + JavaRDD<String> zipped = rdd1.zipPartitions( + ZipPartitionsFunction.ZIP_PARTITIONS_N_FUNCTION, rdd2, rdd3, rdd4, rdd5); + assertEquals(Arrays.asList("abefijmnqr", "cdghklopst"), zipped.collect()); } @Test @@ -1532,4 +1561,35 @@ public void testGetPersistentRDDs() { assertEquals("RDD2", cachedRddsMap.get(1).name()); } + private static class ZipPartitionsFunction + implements FlatMapFunction<List<Iterator<String>>, String> { + + private static final ZipPartitionsFunction ZIP_PARTITIONS_N_FUNCTION = + new ZipPartitionsFunction(); + + private static final FlatMapFunction2<Iterator<String>, Iterator<String>, String> + ZIP_PARTITIONS_2_FUNCTION = + (i1, i2) -> ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2)); + + private static final FlatMapFunction3< + Iterator<String>, Iterator<String>, Iterator<String>, String> + ZIP_PARTITIONS_3_FUNCTION = + (i1, i2, i3) -> ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3)); + + private static final FlatMapFunction4< + Iterator<String>, Iterator<String>, Iterator<String>, Iterator<String>, String> + ZIP_PARTITIONS_4_FUNCTION = + (i1, i2, i3, i4) -> ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3, i4)); + + @Override + public Iterator<String> call(List<Iterator<String>> iterators) { + StringBuilder stringBuilder = new StringBuilder(); + for (Iterator<String> iterator : iterators) { + while (iterator.hasNext()) { + stringBuilder.append(iterator.next()); + } + } + return Collections.singleton(stringBuilder.toString()).iterator(); + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index aecb8b99d0e31..7495dcb43d1ac 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -62,7 +62,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4).toImmutableArraySeq, 2) assert(dups.distinct().count() === 4) - assert(dups.distinct().count() === 4) // Can distinct and count be called without parentheses? + assert(dups.distinct().count() === 4) // Can distinct and count be called without parentheses?
assert(dups.distinct().collect() === dups.distinct().collect()) assert(dups.distinct(2).collect() === dups.distinct().collect()) assert(nums.reduce(_ + _) === 10) @@ -73,20 +73,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) - assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert( + nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) assert(!nums.isEmpty()) assert(nums.max() === 4) assert(nums.min() === 1) val partitionSums = nums.mapPartitions(iter => Iterator(iter.sum)) assert(partitionSums.collect().toList === List(3, 7)) - val partitionSumsWithSplit = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.sum)) + val partitionSumsWithSplit = nums.mapPartitionsWithIndex { case (split, iter) => + Iterator((split, iter.sum)) } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) - val partitionSumsWithIndex = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.sum)) + val partitionSumsWithIndex = nums.mapPartitionsWithIndex { case (split, iter) => + Iterator((split, iter.sum)) } assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) @@ -313,7 +314,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { throw new Exception("injected failure") } }.cache() - val thrown = intercept[Exception]{ + val thrown = intercept[Exception] { rdd.collect() } assert(thrown.getMessage.contains("injected failure")) @@ -324,7 +325,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(empty.count() === 0) assert(empty.collect().length === 0) - val thrown = intercept[UnsupportedOperationException]{ + val thrown = intercept[UnsupportedOperationException] { empty.reduce(_ + _) } assert(thrown.getMessage.contains("empty")) @@ -377,7 +378,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(math.abs(partitions1(1).length - 500) < initialPartitions) assert(repartitioned1.collect() === input) - def testSplitPartitions(input: Seq[Int], initialPartitions: Int, finalPartitions: Int): Unit = { + def testSplitPartitions( + input: Seq[Int], + initialPartitions: Int, + finalPartitions: Int): Unit = { val data = sc.parallelize(input, initialPartitions) val repartitioned = data.repartition(finalPartitions) assert(repartitioned.partitions.length === finalPartitions) @@ -394,7 +398,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } testSplitPartitions(Array.fill(100)(1).toImmutableArraySeq, 10, 20) - testSplitPartitions((Array.fill(10000)(1) ++ Array.fill(10000)(2)).toImmutableArraySeq, 20, 100) + testSplitPartitions( + (Array.fill(10000)(1) ++ Array.fill(10000)(2)).toImmutableArraySeq, + 20, + 100) testSplitPartitions(Array.fill(1000)(1).toImmutableArraySeq, 250, 128) } @@ -407,36 +414,42 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val coalesced1 = data.coalesce(2) assert(coalesced1.collect().toList === (1 to 10).toList) - assert(coalesced1.glom().collect().map(_.toList).toList === - List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) + assert( + coalesced1.glom().collect().map(_.toList).toList === + 
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === - List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === - List(5, 6, 7, 8, 9)) + assert( + coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert( + coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) val coalesced2 = data.coalesce(3) assert(coalesced2.collect().toList === (1 to 10).toList) - assert(coalesced2.glom().collect().map(_.toList).toList === - List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) + assert( + coalesced2.glom().collect().map(_.toList).toList === + List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) val coalesced3 = data.coalesce(10) assert(coalesced3.collect().toList === (1 to 10).toList) - assert(coalesced3.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) + assert( + coalesced3.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) // If we try to coalesce into more partitions than the original RDD, it should just // keep the original number of partitions. val coalesced4 = data.coalesce(20) assert(coalesced4.collect().toList === (1 to 10).toList) - assert(coalesced4.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) + assert( + coalesced4.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) - val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. - asInstanceOf[ShuffledRDD[_, _, _]] != null + val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd + .asInstanceOf[ShuffledRDD[_, _, _]] != null assert(isEquals) // when shuffling, we can increase the number of partitions @@ -452,9 +465,11 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(list3.sorted === Array("a", "b", "c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map { j => "m" + (j % 6) }))) val coalesced1 = data.coalesce(3) - assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + assert( + coalesced1.collect().toList.sorted === (1 to 9).toList, + "Data got *lost* in coalescing") val splits = coalesced1.glom().collect().map(_.toList).toList assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) @@ -465,9 +480,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { // keep the original number of partitions. val coalesced4 = data.coalesce(20) val listOfLists = coalesced4.glom().collect().map(_.toList).toList - val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } - assert(sortedList === (1 to 9). 
- map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + val sortedList = listOfLists.sortWith { (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert( + sortedList === (1 to 9).map { x => List(x) }.toList, + "Tried coalescing 9 partitions to 20 but didn't get 9 back") } test("coalesced RDDs with partial locality") { @@ -481,7 +497,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } })) val coalesced1 = data.coalesce(3) - assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + assert( + coalesced1.collect().toList.sorted === (1 to 9).toList, + "Data got *lost* in coalescing") val splits = coalesced1.glom().collect().map(_.toList).toList assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) @@ -492,9 +510,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { // keep the original number of partitions. val coalesced4 = data.coalesce(20) val listOfLists = coalesced4.glom().collect().map(_.toList).toList - val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } - assert(sortedList === (1 to 9). - map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + val sortedList = listOfLists.sortWith { (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert( + sortedList === (1 to 9).map { x => List(x) }.toList, + "Tried coalescing 9 partitions to 20 but didn't get 9 back") } test("coalesced RDDs with locality, large scale (10K partitions)") { @@ -519,8 +538,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val minLocality = coalesced2.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + - (minLocality * 100.0).toInt + "%") + assert( + minLocality >= 0.90, + "Expected 90% locality but got " + + (minLocality * 100.0).toInt + "%") // test that the groups are load balanced with 100 +/- 20 elements in each val maxImbalance = coalesced2.partitions @@ -533,8 +554,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val minLocality2 = coalesced3.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2 * 100.0).toInt + "%") + assert( + minLocality2 >= 0.90, + "Expected 90% locality for derived RDD but got " + + (minLocality2 * 100.0).toInt + "%") } } @@ -564,13 +587,17 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val coalesced2 = data2.coalesce(partitions) // test that we have 10000 partitions - assert(coalesced2.partitions.length == 10000, "Expected 10000 partitions, but got " + - coalesced2.partitions.length) + assert( + coalesced2.partitions.length == 10000, + "Expected 10000 partitions, but got " + + coalesced2.partitions.length) // test that we have 100 partitions val coalesced3 = data2.coalesce(numMachines * 2) - assert(coalesced3.partitions.length == 100, "Expected 100 partitions, but got " + - coalesced3.partitions.length) + assert( + coalesced3.partitions.length == 100, + "Expected 100 partitions, but got " + + coalesced3.partitions.length) // test that the groups are load balanced with 100 +/- 20 elements in each val maxImbalance3 = 
coalesced3.partitions @@ -584,7 +611,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("coalesced RDDs with locality, fail first pass") { val initialPartitions = 1000 val targetLen = 50 - val couponCount = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt // = 492 + val couponCount = 2 * (math.log(targetLen) * targetLen + targetLen + 0.5).toInt // = 492 val blocks = (1 to initialPartitions).map { i => (i, List(if (i > couponCount) "m2" else "m1")) @@ -597,8 +624,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("zipped RDDs") { val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val zipped = nums.zip(nums.map(_ + 1.0)) - assert(zipped.glom().map(_.toList).collect().toList === - List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) + assert( + zipped.glom().map(_.toList).collect().toList === + List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) intercept[IllegalArgumentException] { nums.zip(sc.parallelize(1 to 4, 1)).collect() @@ -609,6 +637,76 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } } + test("zipPartitions2") { + val data1 = sc.parallelize(1 to 8, 4) + val data2 = sc.parallelize(9 to 16, 4) + val zipped = data1 + .zipPartitions(data2) { case (i1, i2) => + Iterator(i1.mkString(",") + ";" + i2.mkString(",")) + } + .collect() + .toList + assert(zipped === List("1,2;9,10", "3,4;11,12", "5,6;13,14", "7,8;15,16")) + } + + test("zipPartitions3") { + val data1 = sc.parallelize(1 to 8, 4) + val data2 = sc.parallelize(9 to 16, 4) + val data3 = sc.parallelize(17 to 24, 4) + val zipped = data1 + .zipPartitions(data2, data3) { case (i1, i2, i3) => + Iterator(i1.mkString(",") + ";" + i2.mkString(",") + ";" + i3.mkString(",")) + } + .collect() + .toList + assert( + zipped === + List("1,2;9,10;17,18", "3,4;11,12;19,20", "5,6;13,14;21,22", "7,8;15,16;23,24")) + } + + test("zipPartitions4") { + val data1 = sc.parallelize(1 to 8, 4) + val data2 = sc.parallelize(9 to 16, 4) + val data3 = sc.parallelize(17 to 24, 4) + val data4 = sc.parallelize(25 to 32, 4) + val zipped = data1 + .zipPartitions(data2, data3, data4) { case (i1, i2, i3, i4) => + Iterator( + i1.mkString(",") + ";" + i2.mkString(",") + ";" + i3.mkString(",") + ";" + + i4.mkString(",")) + } + .collect() + .toList + assert( + zipped === + List( + "1,2;9,10;17,18;25,26", + "3,4;11,12;19,20;27,28", + "5,6;13,14;21,22;29,30", + "7,8;15,16;23,24;31,32")) + } + + test("zipPartitionsN") { + val data1 = sc.parallelize(1 to 8, 4) + val data2 = sc.parallelize(9 to 16, 4) + val data3 = sc.parallelize(17 to 24, 4) + val data4 = sc.parallelize(25 to 32, 4) + val data5 = sc.parallelize(33 to 40, 4) + val zipped = data1 + .zipPartitions(data2, data3, data4, data5) { is => + Iterator(is.map(_.mkString(",")).mkString(";")) + } + .collect() + .toList + assert( + zipped === + List( + "1,2;9,10;17,18;25,26;33,34", + "3,4;11,12;19,20;27,28;35,36", + "5,6;13,14;21,22;29,30;37,38", + "7,8;15,16;23,24;31,32;39,40")) + } + test("partition pruning") { val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. 
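For reference, a condensed sketch of the generic form that the `zipPartitionsN` test above exercises, here going through the explicit `preservesPartitioning` overload added in RDD.scala. The `left`/`right` names are illustrative only, and the casts show why the fixed-arity overloads are preferred when they apply:

  val left = sc.parallelize(1 to 8, 4)
  val right = sc.parallelize(9 to 16, 4)
  // First argument is preservesPartitioning; element types must be recovered by the caller.
  val sums = left.zipPartitions(false, right) { iters =>
    Iterator(iters.map(_.asInstanceOf[Iterator[Int]].sum).sum)
  }
  // sums.collect() would yield Array(22, 30, 38, 46), one per-partition total per partition.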
@@ -736,48 +834,48 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement = false, num = num) - assert(sample.length === num) // Got exactly num elements - assert(sample.toSet.size === num) // Elements are distinct + assert(sample.length === num) // Got exactly num elements + assert(sample.toSet.size === num) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = false, 20, seed) - assert(sample.length === 20) // Got exactly 20 elements - assert(sample.toSet.size === 20) // Elements are distinct + assert(sample.length === 20) // Got exactly 20 elements + assert(sample.toSet.size === 20) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = false, 100, seed) - assert(sample.length === 100) // Got only 100 elements - assert(sample.toSet.size === 100) // Elements are distinct + assert(sample.length === 100) // Got only 100 elements + assert(sample.toSet.size === 100) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = true, 20, seed) - assert(sample.length === 20) // Got exactly 20 elements + assert(sample.length === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement = true, num = 20) - assert(sample.length === 20) // Got exactly 20 elements + assert(sample.length === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement = true, num = n) - assert(sample.length === n) // Got exactly n elements + assert(sample.length === n) // Got exactly n elements // Chance of getting all distinct elements is astronomically low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = true, n, seed) - assert(sample.length === n) // Got exactly n elements + assert(sample.length === n) // Got exactly n elements // Chance of getting all distinct elements is astronomically low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = true, 2 * n, seed) - assert(sample.length === 2 * n) // Got exactly 2 * n elements + assert(sample.length === 2 * n) // Got exactly 2 * n elements // Chance of getting all distinct elements is still quite low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } @@ -795,7 +893,8 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { for (seed <- 1 to 5) { val splits = data.randomSplit(Array(1.0, 2.0, 3.0), seed) assert(splits.length == 3, "wrong number of splits") - assert(splits.flatMap(_.collect()).sorted.toList == data.collect().toList, + assert( + splits.flatMap(_.collect()).sorted.toList == data.collect().toList, "incomplete or wrong split") val s = splits.map(_.count()) assert(math.abs(s(0) - 100) < 50) // std = 9.13 @@ -806,7 +905,7 @@ 
class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) + sc.runJob(sc.parallelize(1 to 10, 2), { iter: Iterator[Int] => iter.size }, Seq(0, 1, 2)) } } @@ -838,21 +937,15 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("sortByKey with explicit ordering") { - val data = sc.parallelize(Seq("Bob|Smith|50", - "Jane|Smith|40", - "Thomas|Williams|30", - "Karen|Williams|60")) + val data = sc.parallelize( + Seq("Bob|Smith|50", "Jane|Smith|40", "Thomas|Williams|30", "Karen|Williams|60")) - val ageOrdered = Array("Thomas|Williams|30", - "Jane|Smith|40", - "Bob|Smith|50", - "Karen|Williams|60") + val ageOrdered = + Array("Thomas|Williams|30", "Jane|Smith|40", "Bob|Smith|50", "Karen|Williams|60") // last name, then first name - val nameOrdered = Array("Bob|Smith|50", - "Jane|Smith|40", - "Karen|Williams|60", - "Thomas|Williams|30") + val nameOrdered = + Array("Bob|Smith|50", "Jane|Smith|40", "Karen|Williams|60", "Thomas|Williams|30") val parse = (s: String) => { val split = s.split("\\|") @@ -888,8 +981,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val sorted = agged.repartitionAndSortWithinPartitions(partitioner) assert(sorted.partitioner == Some(partitioner)) - assert(sorted.dependencies.nonEmpty && - sorted.dependencies.forall(_.isInstanceOf[OneToOneDependency[_]])) + assert( + sorted.dependencies.nonEmpty && + sorted.dependencies.forall(_.isInstanceOf[OneToOneDependency[_]])) } test("cartesian on empty RDD") { @@ -926,8 +1020,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { sqlState = "54000", parameters = Map( "numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString, - "maxRoundedArrayLength" -> Int.MaxValue.toString) - ) + "maxRoundedArrayLength" -> Int.MaxValue.toString)) } test("intersection") { @@ -1067,8 +1160,8 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { /** * This tests for the pathological condition in which the RDD dependency graph is cyclical. * - * Since RDD is part of the public API, applications may actually implement RDDs that allow - * such graphs to be constructed. In such cases, getNarrowAncestor should not simply hang. + * Since RDD is part of the public API, applications may actually implement RDDs that allow such + * graphs to be constructed. In such cases, getNarrowAncestor should not simply hang. 
*/ test("getNarrowAncestors with cycles") { val rdd1 = new CyclicalDependencyRDD[Int] @@ -1199,11 +1292,15 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { var totalPartitionCount = 0L coalescedHadoopRDD.partitions.foreach(partition => { var splitSizeSum = 0L - partition.asInstanceOf[CoalescedRDDPartition].parents.foreach(partition => { - val split = partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit] - splitSizeSum += split.getLength - totalPartitionCount += 1 - }) + partition + .asInstanceOf[CoalescedRDDPartition] + .parents + .foreach(partition => { + val split = + partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit] + splitSizeSum += split.getLength + totalPartitionCount += 1 + }) assert(splitSizeSum <= maxSplitSize) }) assert(totalPartitionCount == 10) @@ -1213,16 +1310,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val rdd = sc.parallelize(1 to 1000, 10) rdd.cache() - rdd.mapPartitions { iter => - ThreadUtils.runInNewThread("TestThread") { - // Iterate to the end of the input iterator, to cause the CompletionIterator completion to - // fire outside of the task's main thread. - while (iter.hasNext) { - iter.next() + rdd + .mapPartitions { iter => + ThreadUtils.runInNewThread("TestThread") { + // Iterate to the end of the input iterator, to cause the CompletionIterator completion to + // fire outside of the task's main thread. + while (iter.hasNext) { + iter.next() + } + iter } - iter } - }.collect() + .collect() } test("SPARK-27666: Do not release lock while TaskContext already completed") { @@ -1230,20 +1329,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val tid = sc.longAccumulator("threadId") // validate cache rdd.collect() - rdd.mapPartitions { iter => - val t = new Thread(() => { - while (iter.hasNext) { - iter.next() - Thread.sleep(100) - } - }) - t.setDaemon(false) - t.start() - tid.add(t.getId) - Iterator(0) - }.collect() + rdd + .mapPartitions { iter => + val t = new Thread(() => { + while (iter.hasNext) { + iter.next() + Thread.sleep(100) + } + }) + t.setDaemon(false) + t.start() + tid.add(t.getId) + Iterator(0) + } + .collect() val tmx = ManagementFactory.getThreadMXBean - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds)) { // getThreadInfo() will return null after child thread `t` died val t = tmx.getThreadInfo(tid.value) assert(t == null || t.getThreadState == Thread.State.TERMINATED) @@ -1258,17 +1359,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val inputRDD = sc.makeRDD(Range(0, numInputPartitions), numInputPartitions) assert(inputRDD.getNumPartitions == numInputPartitions) - val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => - if (p.index < numCoalescedPartitions) { - Seq(locations(0)) - } else { - Seq(locations(1)) - } - }) + val locationPrefRDD = new LocationPrefRDD( + inputRDD, + { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) - val numPartsPerLocation = coalescedRDD - .getPartitions + val numPartsPerLocation = coalescedRDD.getPartitions .map(coalescedRDD.getPreferredLocations(_).head) .groupBy(identity) .transform((_, v) => v.length) @@ -1343,8 +1445,8 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } /** - * Coalesces partitions based on 
their size assuming that the parent RDD is a [[HadoopRDD]]. - * Took this class out of the test suite to prevent "Task not serializable" exceptions. + * Coalesces partitions based on their size assuming that the parent RDD is a [[HadoopRDD]]. Took + * this class out of the test suite to prevent "Task not serializable" exceptions. */ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Serializable { override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { @@ -1397,7 +1499,8 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria /** Alters the preferred locations of the parent RDD using provided function. */ class LocationPrefRDD[T: ClassTag]( @transient var prev: RDD[T], - val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + val locationPicker: Partition => Seq[String]) + extends RDD[T](prev) { override protected def getPartitions: Array[Partition] = prev.partitions override def compute(partition: Partition, context: TaskContext): Iterator[T] =