diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java index 6d41c662350e..aa6b081375e5 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java @@ -18,11 +18,11 @@ import io.trino.plugin.adb.connector.encode.RowEncoder; import io.trino.plugin.adb.connector.protocol.gpfdist.Context; import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; +import io.trino.plugin.adb.connector.protocol.gpfdist.load.process.GpfdistPageProcessorProvider; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.ContextId; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistLoadMetadata; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -32,31 +32,33 @@ public class WriteContext implements Context { private static final Logger log = Logger.get(WriteContext.class); - private final ContextId id; - private final GpfdistLoadMetadata metadata; - private final ConcurrentLinkedQueue pageProcessors = new ConcurrentLinkedQueue<>(); private final AtomicReference adbQueryException = new AtomicReference<>(); private final AtomicLong completedBytes = new AtomicLong(); private final AtomicLong memoryUsage = new AtomicLong(); - private final AtomicBoolean isReadyForWrite = new AtomicBoolean(false); + private final AtomicReference error = new AtomicReference<>(); + private final ContextId id; + private final GpfdistLoadMetadata metadata; private final RowEncoder rowEncoder; private final DataSize writeBufferSize; - private final AtomicReference error = new AtomicReference<>(); + private final GpfdistPageProcessorProvider pageProcessorProvider; - public WriteContext(GpfdistLoadMetadata metadata, RowEncoder rowEncoder, DataSize writeBufferSize) + public WriteContext(GpfdistLoadMetadata metadata, RowEncoder rowEncoder, DataSize writeBufferSize, + GpfdistPageProcessorProvider pageProcessorProvider) { - this(new ContextId(metadata.getSourceTable()), metadata, rowEncoder, writeBufferSize); + this(new ContextId(metadata.getSourceTable()), metadata, rowEncoder, writeBufferSize, pageProcessorProvider); } public WriteContext(ContextId id, GpfdistLoadMetadata metadata, RowEncoder rowEncoder, - DataSize writeBufferSize) + DataSize writeBufferSize, + GpfdistPageProcessorProvider pageProcessorProvider) { this.id = id; this.metadata = metadata; this.rowEncoder = rowEncoder; this.writeBufferSize = writeBufferSize; + this.pageProcessorProvider = pageProcessorProvider; } @Override @@ -80,11 +82,6 @@ public AtomicReference getAdbQueryException() return adbQueryException; } - public ConcurrentLinkedQueue getPageProcessors() - { - return pageProcessors; - } - public AtomicLong getCompletedBytes() { return completedBytes; @@ -95,11 +92,6 @@ public AtomicLong getMemoryUsage() return memoryUsage; } - public AtomicBoolean getIsReadyForWrite() - { - return isReadyForWrite; - } - public DataSize getWriteBufferSize() { return writeBufferSize; @@ -110,9 +102,15 @@ public AtomicReference getError() return error; } + public GpfdistPageProcessorProvider getPageProcessorProvider() + { + return pageProcessorProvider; + } + @Override public void close() { + ConcurrentLinkedQueue pageProcessors = pageProcessorProvider.getAll(); StringBuilder sb = new StringBuilder(); pageProcessors.forEach(processor -> { try { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java new file mode 100644 index 000000000000..0f1d2da9330b --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; + +import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; + +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import static java.lang.String.format; + +public class GpfdistPageProcessorProvider + implements PageProcessorProvider +{ + private static final long ADB_SEGMENT_WAIT_TIMEOUT = 60000L; + private final ConcurrentLinkedQueue pageProcessors = new ConcurrentLinkedQueue<>(); + private final AtomicBoolean isReadyForProcessing = new AtomicBoolean(false); + private final ReentrantLock lock = new ReentrantLock(); + private final Condition isReadyForProcessingCondition = lock.newCondition(); + + public GpfdistPageProcessorProvider() + { + } + + @Override + public void add(PageProcessor processor) + { + lock.lock(); + try { + pageProcessors.add(processor); + isReadyForProcessing.set(true); + isReadyForProcessingCondition.signalAll(); + } + finally { + lock.unlock(); + } + } + + @Override + public PageProcessor take() + { + lock.lock(); + try { + if (!isReadyForProcessing.get()) { + long startTime = System.currentTimeMillis(); + while (pageProcessors.isEmpty()) { + try { + if (System.currentTimeMillis() - startTime > ADB_SEGMENT_WAIT_TIMEOUT) { + throw new RuntimeException( + format("Timeout :%d ms waiting for segments responses is exceeded", + ADB_SEGMENT_WAIT_TIMEOUT)); + } + isReadyForProcessingCondition.await(); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + PageProcessor pageProcessor = pageProcessors.poll(); + pageProcessors.offer(pageProcessor); + return pageProcessor; + } + finally { + lock.unlock(); + } + } + + @Override + public ConcurrentLinkedQueue getAll() + { + return pageProcessors; + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java index 4de726c4acdf..9719b6ca4744 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java @@ -35,7 +35,6 @@ public class GpfdistPageSink implements ConnectorPageSink { private static final Logger log = Logger.get(GpfdistPageSink.class); - private static final long ADB_SEGMENT_WAIT_TIMEOUT = 60000L; private final ContextManager writeContextManager; private final WriteContext writeContext; private final CompletableFuture queryLoadFuture; @@ -57,10 +56,8 @@ public GpfdistPageSink(ContextManager writeContextManager, public CompletableFuture appendPage(Page page) { pageProcessingFuture = CompletableFuture.runAsync(() -> { - waitForProcessors(); if (writeContext.getAdbQueryException().get() == null) { - PageProcessor pageProcessor = writeContext.getPageProcessors().poll(); - writeContext.getPageProcessors().offer(pageProcessor); + PageProcessor pageProcessor = writeContext.getPageProcessorProvider().take(); pageProcessor.process(page); } else { @@ -70,26 +67,6 @@ public CompletableFuture appendPage(Page page) return pageProcessingFuture; } - private void waitForProcessors() - { - try { - if (!writeContext.getIsReadyForWrite().get()) { - long startTime = System.currentTimeMillis(); - while (writeContext.getPageProcessors().isEmpty()) { - if (System.currentTimeMillis() - startTime > ADB_SEGMENT_WAIT_TIMEOUT) { - throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, - "Timed out after waiting for ${ADB_SEGMENT_WAIT_TIMEOUT} ms for segments"); - } - Thread.sleep(100L); - } - writeContext.getIsReadyForWrite().set(true); - } - } - catch (InterruptedException e) { - throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, e); - } - } - @Override public CompletableFuture> finish() { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java index 793bf35ef927..5bb7d8743aea 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java @@ -78,8 +78,10 @@ public GpfdistPageSinkProvider(@ForBaseJdbc JdbcClient client, this.rowEncoderFactory = rowEncoderFactory; this.externalTableFormatConfigFactory = externalTableFormatConfigFactory; this.loadQueryThreadExecutor = ExecutorServiceProvider.LOAD_DATA_QUERY_EXECUTOR_SERVICE; - Map externalTableQueryFactoryMap = createExternalTableQueryFactories.stream() - .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, Function.identity())); + Map externalTableQueryFactoryMap = + createExternalTableQueryFactories.stream() + .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, + Function.identity())); externalTableCreateQueryFactory = externalTableQueryFactoryMap.get(EXTERNAL_TABLE_TYPE); checkArgument(externalTableCreateQueryFactory != null, "failed to get writable table query factory by externalTableType %s", @@ -107,7 +109,8 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) { - return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, pageSinkId); + return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, + pageSinkId); } private ConnectorPageSink createPageSinkInternal(ConnectorTransactionHandle transactionHandle, @@ -122,7 +125,8 @@ private ConnectorPageSink createPageSinkInternal(ConnectorTransactionHandle tran WriteContext writeContext = new WriteContext( loadMetadata, rowEncoderFactory.create(session, loadMetadata.getDataTypes()), - pluginConfig.getWriteBufferSize()); + pluginConfig.getWriteBufferSize(), + new GpfdistPageProcessorProvider()); DataTransferQueryExecutor loadDataExecutor = new GpfdistLoadDataTransferQueryExecutor(client, session, loadQueryThreadExecutor, diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java new file mode 100644 index 000000000000..cca2f093e9c0 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; + +import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; + +import java.util.concurrent.ConcurrentLinkedQueue; + +public interface PageProcessorProvider +{ + void add(PageProcessor processor); + + PageProcessor take(); + + ConcurrentLinkedQueue getAll(); +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java index 72e4a2a2ec0f..8ed1e85e984f 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java @@ -81,7 +81,8 @@ public GpfdistResource(ContextManager writeContextManager, @GET @Produces("text/plain") @Path("/read/{tableName}") - public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { GpfdistReadableRequest request = GpfdistReadableRequest.create(tableName, headers.getRequestHeaders()); checkArgument(request.getGpProtocol() == GPFDIST_FOR_READ_PROTOCOL_VERSION, @@ -101,7 +102,8 @@ public void get(@PathParam("tableName") String tableName, @Context HttpHeaders h } } - private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, WriteContext writeContext) + private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, + WriteContext writeContext) { int bufferSizeInBytes = Long.valueOf(pluginConfig.getWriteBufferSize().toBytes()).intValue(); try (PipedOutputStream outputStream = new PipedOutputStream(); @@ -110,7 +112,7 @@ private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableReque request, writeContext, new GpfdistPageSerializer(writeContext.getMetadata().getDataTypes(), writeContext.getRowEncoder())); - writeContext.getPageProcessors().add(gpfdistPageProcessor); + writeContext.getPageProcessorProvider().add(gpfdistPageProcessor); asyncResponse.resume(createOkGetResponseBuilder(request) .entity(inputStream) .build()); @@ -139,13 +141,15 @@ private Response.ResponseBuilder createOkGetResponseBuilder(GpfdistReadableReque @POST @Consumes("*/*") @Path("/write/{tableName}") - public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { try { GpfdistWritableRequest request = GpfdistWritableRequest.create(tableName, headers.getRequestHeaders()); log.debug("Received POST request: %s", request); checkArgument(request.getGpProtocol() == GPFDIST_FOR_WRITE_PROTOCOL_VERSION, - format("Gpfdist protocol version %s for write operation is supported", GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); + format("Gpfdist protocol version %s for write operation is supported", + GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); Optional readContextOptional = readContextManager.get(new ContextId(tableName)); if (readContextOptional.isEmpty()) { processNotFoundQueryRequest(tableName, asyncResponse, request); @@ -170,7 +174,8 @@ public void post(@PathParam("tableName") String tableName, InputStream data, @Co } } - private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, GpfdistWritableRequest request) + private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, + GpfdistWritableRequest request) { String errorMessage = "No active query for writeable table: " + tableName; asyncResponse.resume(Response.status(Response.Status.BAD_REQUEST.getStatusCode(), errorMessage) @@ -179,7 +184,8 @@ private static void processNotFoundQueryRequest(String tableName, AsyncResponse log.error("Failed to processed request: %s. " + errorMessage, request); } - private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { InputDataProcessor dataProcessor = inputDataProcessorFactory.create(readContext.getRowDecoder(), readContext.getRowProcessingService()); @@ -191,7 +197,8 @@ private void processInitialRequest(AsyncResponse asyncResponse, ReadContext read log.debug("Request for initial data transferring completed successfully: %s", request); } - private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { executorService.submit(() -> { try { @@ -208,7 +215,8 @@ private void processDataRequest(InputStream data, AsyncResponse asyncResponse, R }); } - private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { GpfdistSegmentRequestProcessor processor = getSegmentProcessor(readContext, request.getSegmentId()); processor.stop(); @@ -221,7 +229,8 @@ private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext rea private GpfdistSegmentRequestProcessor getSegmentProcessor(ReadContext readContext, Integer segmentId) { return Optional.ofNullable(readContext.getSegmentDataProcessors().get(segmentId)) - .orElseThrow(() -> new IllegalStateException("Failed to get segment request processor by segmentId: " + segmentId)); + .orElseThrow(() -> new IllegalStateException( + "Failed to get segment request processor by segmentId: " + segmentId)); } private void failWriteResponse(AsyncResponse asyncResponse, Exception e)