diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java
index 80c4481091fa..48ee31d60d3a 100644
--- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java
+++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java
@@ -82,6 +82,7 @@
 import java.util.function.Predicate;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Throwables.throwIfUnchecked;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
@@ -656,6 +657,26 @@ public TaskInfo failTask(TaskId taskId, Throwable failure)
         return tasks.getUnchecked(taskId).failed(failure);
     }
 
+    /**
+     * Removes the entry of a finished task from {@code tasks}.
+     * <p>
+     * Does nothing when {@code taskId} is not present in {@code tasks}.
+     *
+     * @throws IllegalStateException if the task is present but its state is not {@link TaskState#FINISHED}
+     */
+    public void cleanupTask(TaskId taskId)
+    {
+        requireNonNull(taskId, "taskId is null");
+        SqlTask sqlTask = tasks.getIfPresent(taskId);
+        if (sqlTask == null) {
+            return;
+        }
+        // read the state once so that the check and the error message always agree
+        TaskState taskState = sqlTask.getTaskState();
+        checkState(taskState == TaskState.FINISHED, "cleanup called for task %s which is in state %s", taskId, taskState);
+        tasks.unsafeInvalidate(taskId);
+    }
+
     @VisibleForTesting
     void removeOldTasks()
     {
diff --git a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java
index 97ac4d980aa8..5f5e0832690a 100644
--- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java
+++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java
@@ -328,6 +328,17 @@ public TaskInfo failTask(
         return taskManager.failTask(taskId, failTaskRequest.getFailureInfo().toException());
     }
 
+    /**
+     * Forwards a cleanup request for {@code taskId} to {@code taskManager}; being a void resource method it answers 204 No Content on success.
+     */
+    @POST
+    @Path("{taskId}/cleanup")
+    public void cleanupTask(@PathParam("taskId") TaskId taskId)
+    {
+        requireNonNull(taskId, "taskId is null");
+        taskManager.cleanupTask(taskId);
+    }
+
     @GET
     @Path("{taskId}/results/{bufferId}/{token}")
     @Produces(TRINO_PAGES)
diff --git
a/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java index b643151d6288..5faccc2724b1 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java @@ -66,6 +66,7 @@ class ContinuousTaskStatusFetcher private final Supplier spanBuilderFactory; private final RequestErrorTracker errorTracker; private final RemoteTaskStats stats; + private final RemoteTaskCleaner remoteTaskCleaner; @GuardedBy("this") private boolean running; @@ -84,7 +85,8 @@ public ContinuousTaskStatusFetcher( Supplier spanBuilderFactory, Duration maxErrorDuration, ScheduledExecutorService errorScheduledExecutor, - RemoteTaskStats stats) + RemoteTaskStats stats, + RemoteTaskCleaner remoteTaskCleaner) { requireNonNull(initialTaskStatus, "initialTaskStatus is null"); @@ -102,6 +104,7 @@ public ContinuousTaskStatusFetcher( this.errorTracker = new RequestErrorTracker(taskId, initialTaskStatus.getSelf(), maxErrorDuration, errorScheduledExecutor, "getting task status"); this.stats = requireNonNull(stats, "stats is null"); + this.remoteTaskCleaner = requireNonNull(remoteTaskCleaner, "remoteTaskCleaner is null"); } public synchronized void start() @@ -121,6 +124,7 @@ public synchronized void stop() future.cancel(true); future = null; } + remoteTaskCleaner.markTaskStatusFetcherStopped(taskStatus.get().getState()); } private synchronized void scheduleNextRequest() diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java index a4413631209e..4ac1c18a2de6 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java +++ 
b/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java @@ -61,6 +61,7 @@ class DynamicFiltersFetcher private final RequestErrorTracker errorTracker; private final RemoteTaskStats stats; private final DynamicFilterService dynamicFilterService; + private final RemoteTaskCleaner remoteTaskCleaner; @GuardedBy("this") private long dynamicFiltersVersion = INITIAL_DYNAMIC_FILTERS_VERSION; @@ -83,7 +84,8 @@ public DynamicFiltersFetcher( Duration maxErrorDuration, ScheduledExecutorService errorScheduledExecutor, RemoteTaskStats stats, - DynamicFilterService dynamicFilterService) + DynamicFilterService dynamicFilterService, + RemoteTaskCleaner remoteTaskCleaner) { this.taskId = requireNonNull(taskId, "taskId is null"); this.taskUri = requireNonNull(taskUri, "taskUri is null"); @@ -99,6 +101,8 @@ public DynamicFiltersFetcher( this.errorTracker = new RequestErrorTracker(taskId, taskUri, maxErrorDuration, errorScheduledExecutor, "getting dynamic filter domains"); this.stats = requireNonNull(stats, "stats is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + + this.remoteTaskCleaner = requireNonNull(remoteTaskCleaner, "remoteTaskCleaner is null"); } public synchronized void start() @@ -124,6 +128,7 @@ public synchronized void updateDynamicFiltersVersionAndFetchIfNecessary(long new private synchronized void stop() { running = false; + remoteTaskCleaner.markDynamicFilterFetcherStopped(); } @VisibleForTesting diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index 13eb1e7e63c0..76ed5c316fe0 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -322,6 +322,13 @@ public HttpRemoteTask( TaskInfo initialTask = createInitialTask(taskId, location, nodeId, 
this.speculative.get(), pipelinedBufferStates, new TaskStats(DateTime.now(), null));
 
+        RemoteTaskCleaner remoteTaskCleaner = new RemoteTaskCleaner(
+                taskId,
+                location,
+                httpClient,
+                errorScheduledExecutor,
+                () -> createSpanBuilder("remote-task-cleaner", span));
+
         this.dynamicFiltersFetcher = new DynamicFiltersFetcher(
                 this::fatalUnacknowledgedFailure,
                 taskId,
@@ -334,7 +341,8 @@ public HttpRemoteTask(
                 maxErrorDuration,
                 errorScheduledExecutor,
                 stats,
-                dynamicFilterService);
+                dynamicFilterService,
+                remoteTaskCleaner);
 
         this.taskStatusFetcher = new ContinuousTaskStatusFetcher(
                 this::fatalUnacknowledgedFailure,
@@ -347,12 +355,14 @@ public HttpRemoteTask(
                 () -> createSpanBuilder("task-status", span),
                 maxErrorDuration,
                 errorScheduledExecutor,
-                stats);
+                stats,
+                remoteTaskCleaner);
 
         RetryPolicy retryPolicy = getRetryPolicy(session);
         this.taskInfoFetcher = new TaskInfoFetcher(
                 this::fatalUnacknowledgedFailure,
                 taskStatusFetcher,
+                remoteTaskCleaner,
                 initialTask,
                 httpClient,
                 () -> createSpanBuilder("task-info", span),
diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/RemoteTaskCleaner.java b/core/trino-main/src/main/java/io/trino/server/remotetask/RemoteTaskCleaner.java
new file mode 100644
index 000000000000..086a400f7dbc
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/server/remotetask/RemoteTaskCleaner.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.server.remotetask;
+
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import io.airlift.http.client.HttpClient;
+import io.airlift.http.client.Request;
+import io.airlift.http.client.StatusResponseHandler.StatusResponse;
+import io.airlift.log.Logger;
+import io.opentelemetry.api.trace.SpanBuilder;
+import io.trino.execution.TaskId;
+import io.trino.execution.TaskState;
+
+import java.net.URI;
+import java.util.concurrent.Executor;
+import java.util.function.Supplier;
+
+import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
+import static io.airlift.http.client.Request.Builder.preparePost;
+import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Sends a single best-effort {@code POST <taskUri>/cleanup} request for a remote task
+ * once its task status, task info and dynamic filter fetchers have all reported that
+ * they stopped and the state reported by the status fetcher is {@link TaskState#FINISHED}.
+ */
+public class RemoteTaskCleaner
+{
+    private static final Logger log = Logger.get(RemoteTaskCleaner.class);
+
+    private final TaskId taskId;
+    private final URI taskUri;
+    private final HttpClient httpClient;
+    private final Executor executor;
+    private final Supplier<SpanBuilder> spanBuilderFactory;
+
+    @GuardedBy("this")
+    private boolean taskStatusFetcherStopped;
+
+    @GuardedBy("this")
+    private boolean taskInfoFetcherStopped;
+
+    @GuardedBy("this")
+    private boolean dynamicFilterFetcherStopped;
+
+    // state handed over by the task status fetcher when it stopped; null until then
+    @GuardedBy("this")
+    private TaskState taskState;
+
+    public RemoteTaskCleaner(TaskId taskId, URI taskUri, HttpClient httpClient, Executor executor, Supplier<SpanBuilder> spanBuilderFactory)
+    {
+        this.taskId = requireNonNull(taskId, "taskId is null");
+        this.taskUri = requireNonNull(taskUri, "taskUri is null");
+        this.httpClient = requireNonNull(httpClient, "httpClient is null");
+        this.executor = requireNonNull(executor, "executor is null");
+        this.spanBuilderFactory = requireNonNull(spanBuilderFactory, "spanBuilderFactory is null");
+    }
+
+    /**
+     * Records that the task status fetcher stopped, together with the task state it passes in.
+     * Only the first call has any effect.
+     */
+    public synchronized void markTaskStatusFetcherStopped(TaskState taskState)
+    {
+        if (taskStatusFetcherStopped) {
+            return;
+        }
+        taskStatusFetcherStopped = true;
+        this.taskState = taskState;
+        cleanupIfReady();
+    }
+
+    /**
+     * Records that the task info fetcher stopped. Only the first call has any effect.
+     */
+    public synchronized void markTaskInfoFetcherStopped()
+    {
+        if (taskInfoFetcherStopped) {
+            return;
+        }
+        taskInfoFetcherStopped = true;
+        cleanupIfReady();
+    }
+
+    /**
+     * Records that the dynamic filter fetcher stopped. Only the first call has any effect.
+     */
+    public synchronized void markDynamicFilterFetcherStopped()
+    {
+        if (dynamicFilterFetcherStopped) {
+            return;
+        }
+        dynamicFilterFetcherStopped = true;
+        cleanupIfReady();
+    }
+
+    @GuardedBy("this")
+    private void cleanupIfReady()
+    {
+        if (taskState != TaskState.FINISHED) {
+            // we do not perform early cleanup if task did not finish successfully.
+            // other workers may still reach out for the results; and we have no control over that.
+            return;
+        }
+        if (taskStatusFetcherStopped && taskInfoFetcherStopped && dynamicFilterFetcherStopped) {
+            scheduleCleanupRequest();
+        }
+    }
+
+    /**
+     * Submits the cleanup request to {@code executor}; any failure is only logged.
+     */
+    private void scheduleCleanupRequest()
+    {
+        executor.execute(
+                () -> {
+                    try {
+                        Request request = preparePost()
+                                .setUri(uriBuilderFrom(taskUri)
+                                        .appendPath("/cleanup")
+                                        .build())
+                                .setSpanBuilder(spanBuilderFactory.get())
+                                .build();
+
+                        StatusResponse response = httpClient.execute(request, createStatusResponseHandler());
+                        // the cleanup endpoint is a void resource method, which JAX-RS answers with 204 No Content,
+                        // so treat every 2xx status as success instead of requiring exactly 200
+                        int statusCode = response.getStatusCode();
+                        if (statusCode / 100 != 2) {
+                            log.warn("Failed to cleanup task %s: %s", taskId, statusCode);
+                        }
+                    }
+                    catch (RuntimeException e) {
+                        // cleanup is best-effort; log here because executor.execute gives the failure no caller to report to
+                        log.warn(e, "Failed to cleanup task %s", taskId);
+                    }
+                });
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java
index af3af8e8d202..8bb7170140b8 100644
--- a/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java
+++ b/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java
@@ -67,6 +67,7 @@ public class TaskInfoFetcher
     private final TaskId taskId;
     private final Consumer onFail;
     private final ContinuousTaskStatusFetcher taskStatusFetcher;
+    private final RemoteTaskCleaner remoteTaskCleaner;
     private final StateMachine taskInfo;
     private final StateMachine> finalTaskInfo;
     private final JsonCodec taskInfoCodec;
@@ -100,6
+101,7 @@ public class TaskInfoFetcher
     public TaskInfoFetcher(
             Consumer onFail,
             ContinuousTaskStatusFetcher taskStatusFetcher,
+            RemoteTaskCleaner remoteTaskCleaner,
             TaskInfo initialTask,
             HttpClient httpClient,
             Supplier spanBuilderFactory,
@@ -120,6 +122,7 @@ public TaskInfoFetcher(
         this.taskId = initialTask.taskStatus().getTaskId();
         this.onFail = requireNonNull(onFail, "onFail is null");
         this.taskStatusFetcher = requireNonNull(taskStatusFetcher, "taskStatusFetcher is null");
+        this.remoteTaskCleaner = requireNonNull(remoteTaskCleaner, "remoteTaskCleaner is null");
         this.taskInfo = new StateMachine<>("task " + taskId, executor, initialTask);
         this.finalTaskInfo = new StateMachine<>("task-" + taskId, executor, Optional.empty());
         this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
@@ -163,6 +166,7 @@ private synchronized void stop()
         if (scheduledFuture != null) {
             scheduledFuture.cancel(true);
         }
+        remoteTaskCleaner.markTaskInfoFetcherStopped();
     }
 
     /**
diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java
index 0b63590dffd3..a0df9a3205f7 100644
--- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java
+++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java
@@ -217,6 +217,8 @@ public void testRegular()
         poll(() -> remoteTask.getTaskStatus().getState().isDone());
         poll(() -> remoteTask.getTaskInfo().taskStatus().getState().isDone());
 
+        assertEventually(new Duration(5, SECONDS), () -> assertThat(testingTaskResource.isCleanupCalled()).isTrue());
+
         httpRemoteTaskFactory.stop();
     }
 
@@ -590,6 +592,9 @@ private void runTest(FailureScenario failureScenario)
                 .describedAs(format("TaskStatus is not in a done state: %s", remoteTask.getTaskStatus()))
                 .isTrue();
 
+        // explicit cleanup not done for failed tasks
+        assertThat(testingTaskResource.isCleanupCalled()).isFalse();
+
         ErrorCode actualErrorCode = getOnlyElement(remoteTask.getTaskStatus().getFailures()).getErrorCode();
         switch (failureScenario) {
             case TASK_MISMATCH:
@@ -787,6 +792,7 @@ public static class TestingTaskResource
         private long createOrUpdateCounter;
         private long dynamicFiltersFetchCounter;
         private long dynamicFiltersSentCounter;
+        private boolean cleanupCalled;
         private final List dynamicFiltersFetchRequests = new ArrayList<>();
 
         public TestingTaskResource(AtomicLong lastActivityNanos, FailureScenario failureScenario)
@@ -902,6 +908,22 @@ public synchronized TaskInfo deleteTask(
             return buildTaskInfo();
         }
 
+        /**
+         * Records that the cleanup endpoint was invoked; the {@code taskId} argument is not inspected.
+         */
+        @POST
+        @Path("{taskId}/cleanup")
+        public synchronized void cleanupTask(@PathParam("taskId") TaskId taskId)
+        {
+            cleanupCalled = true;
+        }
+
+        // synchronized like cleanupTask so that a flag set while handling the request is visible to the thread polling it
+        public synchronized boolean isCleanupCalled()
+        {
+            return cleanupCalled;
+        }
+
         public void setInitialTaskInfo(TaskInfo initialTaskInfo)
         {
             this.initialTaskInfo = initialTaskInfo;