Remove max scaling limit from local and remote exchange
gaurav8297 authored and sopel39 committed Dec 6, 2023
1 parent c95b362 commit 77cb0d6
Showing 5 changed files with 10 additions and 224 deletions.

Changed file 1 of 5

@@ -53,15 +53,13 @@
  import static io.trino.SystemSessionProperties.getSkewedPartitionMinDataProcessedRebalanceThreshold;
  import static io.trino.operator.InterpretedHashGenerator.createChannelsHashGenerator;
  import static io.trino.operator.exchange.LocalExchangeSink.finishedLocalExchangeSink;
- import static io.trino.operator.output.SkewedPartitionRebalancer.getScaleWritersMaxSkewedPartitions;
  import static io.trino.sql.planner.PartitioningHandle.isScaledWriterHashDistribution;
  import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
  import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
  import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
  import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
  import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
  import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
- import static java.lang.Math.max;
  import static java.util.Objects.requireNonNull;
  import static java.util.function.Function.identity;

@@ -147,12 +145,7 @@ else if (isScaledWriterHashDistribution(partitioning)) {
  bufferCount,
  1,
  writerScalingMinDataProcessed.toBytes(),
- getSkewedPartitionMinDataProcessedRebalanceThreshold(session).toBytes(),
- // Keep the maxPartitionsToRebalance to atleast writer count such that single partition writes do
- // not suffer from skewness and can scale uniformly across all writers. Additionally, note that
- // maxWriterCount is calculated considering memory into account. So, it is safe to set the
- // maxPartitionsToRebalance to maximum number of writers.
- max(getScaleWritersMaxSkewedPartitions(session), bufferCount));
+ getSkewedPartitionMinDataProcessedRebalanceThreshold(session).toBytes());
  LocalExchangeMemoryManager memoryManager = new LocalExchangeMemoryManager(maxBufferedBytes.toBytes());
  sources = IntStream.range(0, bufferCount)
  .mapToObj(i -> new LocalExchangeSource(memoryManager, source -> checkAllSourcesFinished()))

Changed file 2 of 5

@@ -32,7 +32,6 @@
  import java.util.List;
  import java.util.Objects;
  import java.util.concurrent.CopyOnWriteArrayList;
- import java.util.concurrent.atomic.AtomicInteger;
  import java.util.concurrent.atomic.AtomicLong;
  import java.util.concurrent.atomic.AtomicLongArray;
  import java.util.stream.IntStream;
@@ -86,12 +85,10 @@ public class SkewedPartitionRebalancer
  private final int taskBucketCount;
  private final long minPartitionDataProcessedRebalanceThreshold;
  private final long minDataProcessedRebalanceThreshold;
- private final int maxPartitionsToRebalance;

  private final AtomicLongArray partitionRowCount;
  private final AtomicLong dataProcessed;
  private final AtomicLong dataProcessedAtLastRebalance;
- private final AtomicInteger numOfRebalancedPartitions;

  @GuardedBy("this")
  private final long[] partitionDataSize;
@@ -158,12 +155,6 @@ public static int getMaxWritersBasedOnMemory(Session session)
  return (int) ceil((double) getQueryMaxMemoryPerNode(session).toBytes() / getMaxMemoryPerPartitionWriter(session).toBytes());
  }

- public static int getScaleWritersMaxSkewedPartitions(Session session)
- {
- // Set the value of maxSkewedPartitions to scale to 60% of maximum number of writers possible per node.
- return (int) (getMaxWritersBasedOnMemory(session) * 0.60);
- }
-
  public static int getTaskCount(PartitioningScheme partitioningScheme)
  {
  // Todo: Handle skewness if there are more nodes/tasks than the buckets coming from connector
@@ -179,20 +170,17 @@ public SkewedPartitionRebalancer(
  int taskCount,
  int taskBucketCount,
  long minPartitionDataProcessedRebalanceThreshold,
- long maxDataProcessedRebalanceThreshold,
- int maxPartitionsToRebalance)
+ long maxDataProcessedRebalanceThreshold)
  {
  this.partitionCount = partitionCount;
  this.taskCount = taskCount;
  this.taskBucketCount = taskBucketCount;
  this.minPartitionDataProcessedRebalanceThreshold = minPartitionDataProcessedRebalanceThreshold;
  this.minDataProcessedRebalanceThreshold = max(minPartitionDataProcessedRebalanceThreshold, maxDataProcessedRebalanceThreshold);
- this.maxPartitionsToRebalance = maxPartitionsToRebalance;

  this.partitionRowCount = new AtomicLongArray(partitionCount);
  this.dataProcessed = new AtomicLong();
  this.dataProcessedAtLastRebalance = new AtomicLong();
- this.numOfRebalancedPartitions = new AtomicInteger();

  this.partitionDataSize = new long[partitionCount];
  this.partitionDataSizeAtLastRebalance = new long[partitionCount];
@@ -254,9 +242,7 @@ public void rebalance()
  private boolean shouldRebalance(long dataProcessed)
  {
  // Rebalance only when total bytes processed since last rebalance is greater than rebalance threshold.
- // Check if the number of rebalanced partitions is less than maxPartitionsToRebalance.
- return (dataProcessed - dataProcessedAtLastRebalance.get()) >= minDataProcessedRebalanceThreshold
- && numOfRebalancedPartitions.get() < maxPartitionsToRebalance;
+ return (dataProcessed - dataProcessedAtLastRebalance.get()) >= minDataProcessedRebalanceThreshold;
  }

  private synchronized void rebalancePartitions(long dataProcessed)
@@ -412,12 +398,6 @@ private boolean rebalancePartition(
  return false;
  }

- // If the number of rebalanced partitions is less than maxPartitionsToRebalance then assign
- // the partition to the task.
- if (numOfRebalancedPartitions.get() >= maxPartitionsToRebalance) {
- return false;
- }
-
  assignments.add(toTaskBucket);

  int newTaskCount = assignments.size();
@@ -438,8 +418,6 @@ private boolean rebalancePartition(
  minTasks.addOrUpdate(taskBucket, Long.MAX_VALUE - estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id]);
  }

- // Increment the number of rebalanced partitions.
- numOfRebalancedPartitions.incrementAndGet();
  log.debug("Rebalanced partition %s to task %s with taskCount %s", partitionId, toTaskBucket.taskId, assignments.size());
  return true;
  }
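To illustrate the behavioral change, here is a minimal, self-contained sketch of the rebalance trigger after this commit: the only condition for a rebalance is how much data has been processed since the last one, and the old numOfRebalancedPartitions counter and maxPartitionsToRebalance cap are gone. This is not the actual Trino class (the real io.trino.operator.output.SkewedPartitionRebalancer also tracks per-partition data sizes, task buckets, and assignments); field and constructor parameter names follow the diff above, while the class name RebalanceGateSketch and the helper methods addDataProcessed and markRebalanced are invented for this illustration.

import java.util.concurrent.atomic.AtomicLong;

// Standalone sketch of the post-change rebalance gate; only the trigger
// condition visible in the diff above is reproduced here.
class RebalanceGateSketch
{
    private final long minDataProcessedRebalanceThreshold;
    private final AtomicLong dataProcessed = new AtomicLong();
    private final AtomicLong dataProcessedAtLastRebalance = new AtomicLong();

    RebalanceGateSketch(long minPartitionDataProcessedRebalanceThreshold, long maxDataProcessedRebalanceThreshold)
    {
        // As in the diff: the effective threshold is the larger of the two settings.
        this.minDataProcessedRebalanceThreshold =
                Math.max(minPartitionDataProcessedRebalanceThreshold, maxDataProcessedRebalanceThreshold);
    }

    void addDataProcessed(long bytes)
    {
        dataProcessed.addAndGet(bytes);
    }

    boolean shouldRebalance()
    {
        // Post-change condition: rebalance whenever enough data has been processed
        // since the last rebalance. The old "numOfRebalancedPartitions < maxPartitionsToRebalance"
        // check is removed, so scaling is no longer capped at a fixed number of partitions.
        return (dataProcessed.get() - dataProcessedAtLastRebalance.get()) >= minDataProcessedRebalanceThreshold;
    }

    void markRebalanced()
    {
        dataProcessedAtLastRebalance.set(dataProcessed.get());
    }
}

In effect, once the data-processed threshold is crossed, skewed partitions can keep being redistributed across writers and tasks without hitting a fixed limit, which is what the removed test below used to assert against.
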

Changed file 3 of 5

@@ -333,7 +333,6 @@
  import static io.trino.operator.output.SkewedPartitionRebalancer.checkCanScalePartitionsRemotely;
  import static io.trino.operator.output.SkewedPartitionRebalancer.createPartitionFunction;
  import static io.trino.operator.output.SkewedPartitionRebalancer.getMaxWritersBasedOnMemory;
- import static io.trino.operator.output.SkewedPartitionRebalancer.getScaleWritersMaxSkewedPartitions;
  import static io.trino.operator.output.SkewedPartitionRebalancer.getTaskCount;
  import static io.trino.operator.window.pattern.PhysicalValuePointer.CLASSIFIER;
  import static io.trino.operator.window.pattern.PhysicalValuePointer.MATCH_NUMBER;
@@ -382,7 +381,6 @@
  import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons;
  import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
  import static java.lang.Math.ceil;
- import static java.lang.Math.max;
  import static java.lang.Math.min;
  import static java.lang.Math.toIntExact;
  import static java.lang.String.format;
@@ -593,10 +591,7 @@ public LocalExecutionPlan plan(
  taskCount,
  taskBucketCount,
  getWriterScalingMinDataProcessed(taskContext.getSession()).toBytes(),
- getSkewedPartitionMinDataProcessedRebalanceThreshold(taskContext.getSession()).toBytes(),
- // Keep the maxPartitionsToRebalance to atleast task count such that single partition writes do
- // not suffer from skewness and can scale uniformly across all tasks.
- max(getScaleWritersMaxSkewedPartitions(taskContext.getSession()), taskCount)));
+ getSkewedPartitionMinDataProcessedRebalanceThreshold(taskContext.getSession()).toBytes()));
  }
  else {
  partitionFunction = nodePartitioningManager.getPartitionFunction(taskContext.getSession(), partitioningScheme, partitionChannelTypes);

Changed file 4 of 5

@@ -818,115 +818,6 @@ public void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandle
  });
  }

- @Test(dataProvider = "scalingPartitionHandles")
- public void testNoScalingWhenMaxScaledPartitionsPerTaskIsSmall(PartitioningHandle partitioningHandle)
- {
- LocalExchange localExchange = new LocalExchange(
- nodePartitioningManager,
- testSessionBuilder()
- .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
- .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "256MB")
- .build(),
- 4,
- partitioningHandle,
- ImmutableList.of(0),
- TYPES,
- Optional.empty(),
- DataSize.ofBytes(retainedSizeOfPages(2)),
- TYPE_OPERATORS,
- DataSize.of(10, KILOBYTE),
- TOTAL_MEMORY_USED);
-
- run(localExchange, exchange -> {
- assertThat(exchange.getBufferCount()).isEqualTo(4);
- assertExchangeTotalBufferedBytes(exchange, 0);
-
- LocalExchangeSinkFactory sinkFactory = exchange.createSinkFactory();
- sinkFactory.noMoreSinkFactories();
- LocalExchangeSink sink = sinkFactory.createSink();
- assertSinkCanWrite(sink);
- sinkFactory.close();
-
- LocalExchangeSource sourceA = exchange.getNextSource();
- assertSource(sourceA, 0);
-
- LocalExchangeSource sourceB = exchange.getNextSource();
- assertSource(sourceB, 0);
-
- LocalExchangeSource sourceC = exchange.getNextSource();
- assertSource(sourceC, 0);
-
- LocalExchangeSource sourceD = exchange.getNextSource();
- assertSource(sourceD, 0);
-
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(1, 2));
- sink.addPage(createSingleValuePage(1, 2));
-
- // Two partitions are assigned to two different writers
- assertSource(sourceA, 2);
- assertSource(sourceB, 0);
- assertSource(sourceC, 0);
- assertSource(sourceD, 2);
-
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
-
- // partition 0 is assigned to writer B after scaling.
- assertSource(sourceA, 2);
- assertSource(sourceB, 2);
- assertSource(sourceC, 0);
- assertSource(sourceD, 4);
-
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
-
- // partition 0 is assigned to writer A after scaling.
- assertSource(sourceA, 3);
- assertSource(sourceB, 4);
- assertSource(sourceC, 0);
- assertSource(sourceD, 5);
-
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
- sink.addPage(createSingleValuePage(0, 1000));
-
- // partition 0 is assigned to writer C after scaling.
- assertSource(sourceA, 4);
- assertSource(sourceB, 5);
- assertSource(sourceC, 1);
- assertSource(sourceD, 6);
-
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
-
- // partition 1 is assigned to writer B after scaling.
- assertSource(sourceA, 6);
- assertSource(sourceB, 7);
- assertSource(sourceC, 1);
- assertSource(sourceD, 6);
-
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
- sink.addPage(createSingleValuePage(1, 10000));
-
- // no scaling will happen since we have scaled to maximum limit which is the number of writer count.
- assertSource(sourceA, 8);
- assertSource(sourceB, 9);
- assertSource(sourceC, 1);
- assertSource(sourceD, 6);
- });
- }
-
  @Test
  public void testNoScalingWhenNoWriterSkewness()
  {

Changed file 5 of 5: diff not loaded in this view.
