diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java
index 19bb7339a..1df2b0eae 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java
@@ -298,7 +298,7 @@ public void closeAll() {
     partitionsToChannel.clear();
 
     StreamingClientProvider.getStreamingClientProviderInstance()
-        .closeClient(this.streamingIngestClient);
+        .closeClient(this.connectorConfig, this.streamingIngestClient);
   }
 
   /**
diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandler.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandler.java
index aa5c4da2c..367f99ce3 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandler.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandler.java
@@ -17,15 +17,8 @@
 
 package com.snowflake.kafka.connector.internal.streaming;
 
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION;
-import static net.snowflake.ingest.utils.ParameterProvider.BLOB_FORMAT_VERSION;
-
 import com.snowflake.kafka.connector.Utils;
 import com.snowflake.kafka.connector.internal.KCLogger;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Properties;
 import java.util.concurrent.atomic.AtomicInteger;
 import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
 import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory;
@@ -35,9 +28,6 @@
 /** This class handles all calls to manage the streaming ingestion client */
 public class StreamingClientHandler {
   private static final KCLogger LOGGER = new KCLogger(StreamingClientHandler.class.getName());
-  private static final String STREAMING_CLIENT_PREFIX_NAME = "KC_CLIENT_";
-  private static final String TEST_CLIENT_NAME = "TEST_CLIENT";
-
   private AtomicInteger createdClientId = new AtomicInteger(0);
 
   /**
@@ -51,40 +41,27 @@ public static boolean isClientValid(SnowflakeStreamingIngestClient client) {
   }
 
   /**
-   * Creates a streaming client from the given config
+   * Creates a streaming client from the given properties
    *
-   * @param connectorConfig The config to create the client
+   * @param streamingClientProperties The properties to create the client
    * @return A newly created client
    */
-  public SnowflakeStreamingIngestClient createClient(Map<String, String> connectorConfig) {
+  public SnowflakeStreamingIngestClient createClient(
+      StreamingClientProperties streamingClientProperties) {
     LOGGER.info("Initializing Streaming Client...");
 
-    // get streaming properties from config
-    Properties streamingClientProps = new Properties();
-    streamingClientProps.putAll(
-        StreamingUtils.convertConfigForStreamingClient(new HashMap<>(connectorConfig)));
-
     try {
-      // Override only if bdec version is explicitly set in config, default to the version set
-      // inside Ingest SDK
-      Map<String, Object> parameterOverrides = new HashMap<>();
-      Optional<String> snowpipeStreamingBdecVersion =
-          Optional.ofNullable(connectorConfig.get(SNOWPIPE_STREAMING_FILE_VERSION));
-      snowpipeStreamingBdecVersion.ifPresent(
-          overriddenValue -> {
-            LOGGER.info("Config is overridden for {} ", SNOWPIPE_STREAMING_FILE_VERSION);
-            parameterOverrides.put(BLOB_FORMAT_VERSION, overriddenValue);
-          });
-
-      String clientName = this.getNewClientName(connectorConfig);
-
       SnowflakeStreamingIngestClient createdClient =
-          SnowflakeStreamingIngestClientFactory.builder(clientName)
-              .setProperties(streamingClientProps)
-              .setParameterOverrides(parameterOverrides)
+          SnowflakeStreamingIngestClientFactory.builder(
+                  streamingClientProperties.clientName + "_" + createdClientId.getAndIncrement())
+              .setProperties(streamingClientProperties.clientProperties)
+              .setParameterOverrides(streamingClientProperties.parameterOverrides)
               .build();
 
-      LOGGER.info("Successfully initialized Streaming Client:{}", clientName);
+      LOGGER.info(
+          "Successfully initialized Streaming Client:{} with properties {}",
+          streamingClientProperties.clientName,
+          streamingClientProperties.getLoggableClientProperties());
 
       return createdClient;
     } catch (SFException ex) {
@@ -115,11 +92,4 @@ public void closeClient(SnowflakeStreamingIngestClient client) {
       LOGGER.error(Utils.getExceptionMessage("Failure closing Streaming client", e));
     }
   }
-
-  private String getNewClientName(Map<String, String> connectorConfig) {
-    return STREAMING_CLIENT_PREFIX_NAME
-        + connectorConfig.getOrDefault(Utils.NAME, TEST_CLIENT_NAME)
-        + "_"
-        + createdClientId.getAndIncrement();
-  }
 }
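A minimal sketch of the reworked createClient entry point, assuming a parsed connector config map (TestUtils.getConfForStreaming() stands in for real settings); the numeric suffix comes from the handler's createdClientId counter, as the handler tests below verify:

    Map<String, String> connectorConfig = TestUtils.getConfForStreaming(); // placeholder config
    StreamingClientHandler handler = new StreamingClientHandler();
    StreamingClientProperties props = new StreamingClientProperties(connectorConfig);

    SnowflakeStreamingIngestClient first = handler.createClient(props); // "KC_CLIENT_<name>_0"
    SnowflakeStreamingIngestClient second = handler.createClient(props); // "KC_CLIENT_<name>_1"
    handler.closeClient(first);
    handler.closeClient(second);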
diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProperties.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProperties.java
new file mode 100644
index 000000000..0e19e6005
--- /dev/null
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProperties.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2023 Snowflake Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.snowflake.kafka.connector.internal.streaming;
+
+import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION;
+import static net.snowflake.ingest.utils.ParameterProvider.BLOB_FORMAT_VERSION;
+
+import com.snowflake.kafka.connector.Utils;
+import com.snowflake.kafka.connector.internal.KCLogger;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import net.snowflake.ingest.utils.Constants;
+
+/**
+ * Object to convert and store properties for {@link
+ * net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient}. This object is used to compare
+ * equality between clients in {@link StreamingClientProvider}.
+ */
+public class StreamingClientProperties {
+  public static final String STREAMING_CLIENT_PREFIX_NAME = "KC_CLIENT_";
+  public static final String DEFAULT_CLIENT_NAME = "DEFAULT_CLIENT";
+
+  private static final KCLogger LOGGER = new KCLogger(StreamingClientProperties.class.getName());
+
+  // contains converted config properties that are loggable (not PII data)
+  public static final List<String> LOGGABLE_STREAMING_CONFIG_PROPERTIES =
+      Stream.of(
+              Constants.ACCOUNT_URL,
+              Constants.ROLE,
+              Constants.USER,
+              StreamingUtils.STREAMING_CONSTANT_AUTHORIZATION_TYPE)
+          .collect(Collectors.toList());
+
+  public final Properties clientProperties;
+  public final String clientName;
+  public final Map<String, Object> parameterOverrides;
+
+  /**
+   * Creates non-null properties, client name and parameter overrides for the {@link
+   * net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient} from the given connectorConfig.
+   * Properties are created by {@link StreamingUtils#convertConfigForStreamingClient(Map)} and are
+   * a subset of the given connector configuration
+   *
+   * @param connectorConfig Given connector configuration. Null configs are treated as an empty map
+   */
+  public StreamingClientProperties(Map<String, String> connectorConfig) {
+    // treat null config as empty config
+    if (connectorConfig == null) {
+      LOGGER.warn(
+          "Creating empty streaming client properties because given connector config was empty");
+      connectorConfig = new HashMap<>();
+    }
+
+    this.clientProperties = StreamingUtils.convertConfigForStreamingClient(connectorConfig);
+
+    this.clientName =
+        STREAMING_CLIENT_PREFIX_NAME
+            + connectorConfig.getOrDefault(Utils.NAME, DEFAULT_CLIENT_NAME);
+
+    // Override only if bdec version is explicitly set in config, default to the version set
+    // inside Ingest SDK
+    this.parameterOverrides = new HashMap<>();
+    Optional<String> snowpipeStreamingBdecVersion =
+        Optional.ofNullable(connectorConfig.get(SNOWPIPE_STREAMING_FILE_VERSION));
+    snowpipeStreamingBdecVersion.ifPresent(
+        overriddenValue -> {
+          LOGGER.info("Config is overridden for {} ", SNOWPIPE_STREAMING_FILE_VERSION);
+          parameterOverrides.put(BLOB_FORMAT_VERSION, overriddenValue);
+        });
+  }
+
+  /**
+   * Gets the loggable properties, see {@link
+   * StreamingClientProperties#LOGGABLE_STREAMING_CONFIG_PROPERTIES}
+   *
+   * @return A formatted string with the loggable properties
+   */
+  public String getLoggableClientProperties() {
+    return this.clientProperties == null || this.clientProperties.isEmpty()
+        ? ""
+        : this.clientProperties.entrySet().stream()
+            .filter(
+                propKvp ->
+                    LOGGABLE_STREAMING_CONFIG_PROPERTIES.stream()
+                        .anyMatch(propKvp.getKey().toString()::equalsIgnoreCase))
+            .collect(Collectors.toList())
+            .toString();
+  }
+
+  /**
+   * Determines equality between StreamingClientProperties by only looking at the parsed
+   * clientProperties. This is used in {@link StreamingClientProvider} to determine equality in
+   * registered clients
+   *
+   * @param other other object to determine equality
+   * @return whether the given object's clientProperties exists and is equal
+   */
+  @Override
+  public boolean equals(Object other) {
+    return other != null
+        && other.getClass().equals(StreamingClientProperties.class)
+        && ((StreamingClientProperties) other).clientProperties.equals(this.clientProperties);
+  }
+
+  /**
+   * Creates the hashcode for this object from the clientProperties. This is used in {@link
+   * StreamingClientProvider} to determine equality in registered clients
+   *
+   * @return the clientProperties' hashcode
+   */
+  @Override
+  public int hashCode() {
+    return this.clientProperties.hashCode();
+  }
+}
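A quick sketch of the PII filtering above, with placeholder values: only keys named in LOGGABLE_STREAMING_CONFIG_PROPERTIES (account URL, role, user, authorization type) survive into the log string, so the private key never appears.

    Map<String, String> config = new HashMap<>();
    config.put(Utils.SF_URL, "https://myaccount.snowflakecomputing.com:443"); // placeholder
    config.put(Utils.SF_USER, "test_kafka");
    config.put(Utils.SF_ROLE, "testrole_kafka");
    config.put(Utils.SF_PRIVATE_KEY, "..."); // never logged

    StreamingClientProperties props = new StreamingClientProperties(config);
    // prints only the loggable entries, e.g. [user=test_kafka, role=testrole_kafka, ...]
    System.out.println(props.getLoggableClientProperties());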
diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
index 05ef856e0..f88c60fe1 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
@@ -23,14 +23,19 @@
 import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
 import com.snowflake.kafka.connector.internal.KCLogger;
 import java.util.Map;
-import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
+import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.Caffeine;
+import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache;
+import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.RemovalCause;
 import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
 
 /**
- * Factory that provides the streaming client(s). There should only be one provider, but it may
- * provide multiple clients if optimizations are disabled - see
- * ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG in the {@link SnowflakeSinkConnectorConfig }
+ * Static factory that provides streaming client(s). If {@link
+ * SnowflakeSinkConnectorConfig#ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG} is disabled then the
+ * provider will always create a new client. If the optimization is enabled, then the provider will
+ * reuse clients when possible by registering clients internally. Since this is a static factory,
+ * clients are reused per Kafka worker node based on their {@link StreamingClientProperties}. This
+ * means that multiple connectors/tasks on the same Kafka worker node with equal {@link
+ * StreamingClientProperties} will use the same client
  */
 public class StreamingClientProvider {
   private static class StreamingClientProviderSingleton {
@@ -47,72 +52,141 @@ public static StreamingClientProvider getStreamingClientProviderInstance() {
     return StreamingClientProviderSingleton.streamingClientProvider;
   }
 
-  /** ONLY FOR TESTING - to get a provider with injected properties */
-  @VisibleForTesting
-  public static StreamingClientProvider getStreamingClientProviderForTests(
-      SnowflakeStreamingIngestClient parameterEnabledClient,
-      StreamingClientHandler streamingClientHandler) {
-    return new StreamingClientProvider(parameterEnabledClient, streamingClientHandler);
-  }
-
-  /** ONLY FOR TESTING - private constructor to inject properties for testing */
-  private StreamingClientProvider(
-      SnowflakeStreamingIngestClient parameterEnabledClient,
-      StreamingClientHandler streamingClientHandler) {
-    this();
-    this.parameterEnabledClient = parameterEnabledClient;
-    this.streamingClientHandler = streamingClientHandler;
+  /**
+   * Builds a threadsafe loading cache to register at most 10,000 streaming clients. It maps each
+   * {@link StreamingClientProperties} to its corresponding {@link SnowflakeStreamingIngestClient}
+   *
+   * @param streamingClientHandler The handler to create clients with
+   * @return A loading cache to register clients
+   */
+  public static LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient>
+      buildLoadingCache(StreamingClientHandler streamingClientHandler) {
+    return Caffeine.newBuilder()
+        .maximumSize(10000) // limit 10,000 clients
+        .evictionListener(
+            (StreamingClientProperties key,
+                SnowflakeStreamingIngestClient client,
+                RemovalCause removalCause) -> {
+              streamingClientHandler.closeClient(client);
+              LOGGER.info(
+                  "Removed registered client {} due to {}",
+                  client.getName(),
+                  removalCause.toString());
+            })
+        .build(streamingClientHandler::createClient);
   }
 
+  /***************************** BEGIN SINGLETON CODE *****************************/
   private static final KCLogger LOGGER = new KCLogger(StreamingClientProvider.class.getName());
 
-  private SnowflakeStreamingIngestClient parameterEnabledClient;
   private StreamingClientHandler streamingClientHandler;
-  private Lock providerLock;
+  private LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient>
+      registeredClients;
 
-  // private constructor for singleton
+  /**
+   * Private constructor to retain singleton
+   *
+   * <p>If the one client optimization is enabled, this creates a threadsafe {@link LoadingCache}
+   * to register created clients based on the corresponding {@link StreamingClientProperties}
+   * built from the given connector configuration. The cache calls streamingClientHandler to
+   * create the client if the requested streaming client properties have not already been loaded
+   * into the cache. When a client is evicted, the cache will try closing the client; however, it
+   * is best to still call closeClient manually, as eviction is executed lazily
+   */
   private StreamingClientProvider() {
     this.streamingClientHandler = new StreamingClientHandler();
-    providerLock = new ReentrantLock(true);
+    this.registeredClients = buildLoadingCache(this.streamingClientHandler);
   }
 
   /**
    * Gets the current client or creates a new one from the given connector config. If client
    * optimization is not enabled, it will create a new streaming client and the caller is
-   * responsible for closing it
+   * responsible for closing it. If the optimization is enabled and the registered client is
+   * invalid, we will try recreating and reregistering the client
    *
    * @param connectorConfig The connector config
    * @return A streaming client
    */
   public SnowflakeStreamingIngestClient getClient(Map<String, String> connectorConfig) {
-    if (Boolean.parseBoolean(
-        connectorConfig.getOrDefault(
-            SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG,
-            Boolean.toString(ENABLE_STREAMING_CLIENT_OPTIMIZATION_DEFAULT)))) {
-      LOGGER.info(
-          "Streaming client optimization is enabled, returning the existing streaming client if"
-              + " valid");
-      this.providerLock.lock();
-      // recreate streaming client if needed
-      if (!StreamingClientHandler.isClientValid(this.parameterEnabledClient)) {
-        LOGGER.error("Current streaming client is invalid, recreating client");
-        this.parameterEnabledClient = this.streamingClientHandler.createClient(connectorConfig);
+    SnowflakeStreamingIngestClient resultClient;
+    StreamingClientProperties clientProperties = new StreamingClientProperties(connectorConfig);
+    final boolean isOptimizationEnabled =
+        Boolean.parseBoolean(
+            connectorConfig.getOrDefault(
+                SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG,
+                Boolean.toString(ENABLE_STREAMING_CLIENT_OPTIMIZATION_DEFAULT)));
+
+    if (isOptimizationEnabled) {
+      resultClient = this.registeredClients.get(clientProperties);
+
+      // refresh if registered client is invalid
+      if (!StreamingClientHandler.isClientValid(resultClient)) {
+        LOGGER.warn(
+            "Registered streaming client is not valid, recreating and registering new client");
+        resultClient = this.streamingClientHandler.createClient(clientProperties);
+        this.registeredClients.put(clientProperties, resultClient);
       }
-      this.providerLock.unlock();
-      return this.parameterEnabledClient;
     } else {
-      LOGGER.info("Streaming client optimization is disabled, creating a new streaming client");
-      return this.streamingClientHandler.createClient(connectorConfig);
+      resultClient = this.streamingClientHandler.createClient(clientProperties);
     }
+
+    LOGGER.info(
+        "Streaming client optimization is {}. Returning client with name: {}",
+        isOptimizationEnabled
+            ? "enabled per worker node, KC will reuse valid clients when possible"
+            : "disabled, KC will create new clients",
+        resultClient.getName());
+
+    return resultClient;
   }
 
   /**
-   * Closes the given client
+   * Closes the given client and deregisters it from the cache if necessary. It will also call
+   * close on the registered client if it exists, which should be the same as the given client,
+   * so the call will no-op.
    *
+   * @param connectorConfig The configuration to deregister from the cache
    * @param client The client to be closed
    */
-  public void closeClient(SnowflakeStreamingIngestClient client) {
-    this.providerLock.lock();
+  public void closeClient(
+      Map<String, String> connectorConfig, SnowflakeStreamingIngestClient client) {
+    StreamingClientProperties clientProperties = new StreamingClientProperties(connectorConfig);
+
+    // invalidate cache
+    SnowflakeStreamingIngestClient registeredClient =
+        this.registeredClients.getIfPresent(clientProperties);
+    if (registeredClient != null) {
+      // invalidations are processed on the next get or in the background, so we still need to
+      // close the client here
+      this.registeredClients.invalidate(clientProperties);
+      this.streamingClientHandler.closeClient(registeredClient);
+    }
+
+    // also close given client in case it is different from registered client. this should no-op
+    // if it is already closed
     this.streamingClientHandler.closeClient(client);
-    this.providerLock.unlock();
+  }
+
+  // TEST ONLY - to get a provider with injected properties
+  @VisibleForTesting
+  public static StreamingClientProvider getStreamingClientProviderForTests(
+      StreamingClientHandler streamingClientHandler,
+      LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> registeredClients) {
+    return new StreamingClientProvider(streamingClientHandler, registeredClients);
+  }
+
+  // TEST ONLY - private constructor to inject properties for testing
+  @VisibleForTesting
+  private StreamingClientProvider(
+      StreamingClientHandler streamingClientHandler,
+      LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> registeredClients) {
+    this();
+    this.streamingClientHandler = streamingClientHandler;
+    this.registeredClients = registeredClients;
+  }
+
+  // TEST ONLY - return the current state of the registered clients
+  @VisibleForTesting
+  public Map<StreamingClientProperties, SnowflakeStreamingIngestClient> getRegisteredClients() {
+    return this.registeredClients.asMap();
   }
 }
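Taken together, the caller-side lifecycle now looks roughly like this (a sketch; the config values are placeholders and error handling is elided):

    Map<String, String> connectorConfig = TestUtils.getConfForStreaming(); // placeholder config

    // with the optimization on, an equal config on this worker reuses the cached client
    SnowflakeStreamingIngestClient client =
        StreamingClientProvider.getStreamingClientProviderInstance().getClient(connectorConfig);

    // ... open channels and ingest rows ...

    // closeAll() now passes the config back so the provider can also evict the cache entry
    StreamingClientProvider.getStreamingClientProviderInstance()
        .closeClient(connectorConfig, client);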
diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingUtils.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingUtils.java
index e5cfd46cb..c9edc1686 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingUtils.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingUtils.java
@@ -19,6 +19,7 @@
 import java.time.Duration;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Properties;
 import java.util.Set;
 import net.snowflake.ingest.utils.Constants;
 import org.apache.kafka.common.config.ConfigException;
@@ -66,34 +67,34 @@ public class StreamingUtils {
   public static final int MAX_RECORD_OVERHEAD_BYTES = DefaultRecord.MAX_RECORD_OVERHEAD;
 
   /* Maps streaming client's property keys to what we got from snowflake KC config file. */
-  public static Map<String, String> convertConfigForStreamingClient(
+  public static Properties convertConfigForStreamingClient(
       Map<String, String> connectorConfig) {
-    Map<String, String> streamingPropertiesMap = new HashMap<>();
+    Properties streamingProperties = new Properties();
 
     connectorConfig.computeIfPresent(
         Utils.SF_URL,
         (key, value) -> {
-          streamingPropertiesMap.put(Constants.ACCOUNT_URL, value);
+          streamingProperties.put(Constants.ACCOUNT_URL, value);
           return value;
         });
 
     connectorConfig.computeIfPresent(
         Utils.SF_ROLE,
         (key, value) -> {
-          streamingPropertiesMap.put(Constants.ROLE, value);
+          streamingProperties.put(Constants.ROLE, value);
           return value;
         });
 
     connectorConfig.computeIfPresent(
         Utils.SF_USER,
         (key, value) -> {
-          streamingPropertiesMap.put(Constants.USER, value);
+          streamingProperties.put(Constants.USER, value);
           return value;
         });
 
     connectorConfig.computeIfPresent(
         Utils.SF_PRIVATE_KEY,
         (key, value) -> {
-          streamingPropertiesMap.put(Constants.PRIVATE_KEY, value);
+          streamingProperties.put(Constants.PRIVATE_KEY, value);
           return value;
         });
 
@@ -101,11 +102,12 @@ public static Map<String, String> convertConfigForStreamingClient(
     connectorConfig.computeIfPresent(
         Utils.PRIVATE_KEY_PASSPHRASE,
         (key, value) -> {
           if (!value.isEmpty()) {
-            streamingPropertiesMap.put(Constants.PRIVATE_KEY_PASSPHRASE, value);
+            streamingProperties.put(Constants.PRIVATE_KEY_PASSPHRASE, value);
           }
           return value;
         });
-    return streamingPropertiesMap;
+
+    return streamingProperties;
   }
 
   /* Returns true if sf connector config has error.tolerance = ALL */
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java
index 71acb037e..9a1ca79e3 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java
@@ -3,9 +3,11 @@
 import static com.snowflake.kafka.connector.internal.streaming.TopicPartitionChannel.NO_OFFSET_TOKEN_REGISTERED_IN_SNOWFLAKE;
 
 import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
+import com.snowflake.kafka.connector.Utils;
 import com.snowflake.kafka.connector.dlq.InMemoryKafkaRecordErrorReporter;
 import com.snowflake.kafka.connector.internal.SchematizationTestUtils;
 import com.snowflake.kafka.connector.internal.SnowflakeConnectionService;
+import com.snowflake.kafka.connector.internal.SnowflakeConnectionServiceFactory;
 import com.snowflake.kafka.connector.internal.SnowflakeErrors;
 import com.snowflake.kafka.connector.internal.SnowflakeSinkService;
 import com.snowflake.kafka.connector.internal.SnowflakeSinkServiceFactory;
@@ -1166,4 +1168,92 @@ private void createNonNullableColumn(String tableName, String colName) {
       throw SnowflakeErrors.ERROR_2007.getException(e);
     }
   }
+
+  // note this test relies on testrole_kafka and testrole_kafka_1 roles being granted to
+  // test_kafka user
+  @Test
+  public void testStreamingIngest_multipleChannel_distinctClients() throws Exception {
+    // create cat and dog configs and partitions
+    // one client optimization is enabled, but two clients should be created because the configs
+    // differ
+    String catTopic = "catTopic_" + TestUtils.randomTableName();
+    Map<String, String> catConfig = TestUtils.getConfForStreaming();
+    SnowflakeSinkConnectorConfig.setDefaultValues(catConfig);
+    catConfig.put(SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG, "true");
+    catConfig.put(Utils.SF_OAUTH_CLIENT_ID, "1");
+    catConfig.put(Utils.NAME, catTopic);
+
+    String dogTopic = "dogTopic_" + TestUtils.randomTableName();
+    Map<String, String> dogConfig = TestUtils.getConfForStreaming();
+    SnowflakeSinkConnectorConfig.setDefaultValues(dogConfig);
+    dogConfig.put(SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG, "true");
+    dogConfig.put(Utils.SF_OAUTH_CLIENT_ID, "2");
+    dogConfig.put(Utils.NAME, dogTopic);
+
+    // setup connection and create tables
+    TopicPartition catTp = new TopicPartition(catTopic, 0);
+    SnowflakeConnectionService catConn =
+        SnowflakeConnectionServiceFactory.builder().setProperties(catConfig).build();
+    catConn.createTable(catTopic);
+
+    TopicPartition dogTp = new TopicPartition(dogTopic, 1);
+    SnowflakeConnectionService dogConn =
+        SnowflakeConnectionServiceFactory.builder().setProperties(dogConfig).build();
+    dogConn.createTable(dogTopic);
+
+    // create the sink services
+    SnowflakeSinkService catService =
+        SnowflakeSinkServiceFactory.builder(
+                catConn, IngestionMethodConfig.SNOWPIPE_STREAMING, catConfig)
+            .setRecordNumber(1)
+            .setErrorReporter(new InMemoryKafkaRecordErrorReporter())
+            .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(catTp)))
+            .addTask(catTopic, catTp) // Internally calls startTask
+            .build();
+
+    SnowflakeSinkService dogService =
+        SnowflakeSinkServiceFactory.builder(
+                dogConn, IngestionMethodConfig.SNOWPIPE_STREAMING, dogConfig)
+            .setRecordNumber(1)
+            .setErrorReporter(new InMemoryKafkaRecordErrorReporter())
+            .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(dogTp)))
+            .addTask(dogTopic, dogTp) // Internally calls startTask
+            .build();
+
+    // create records
+    final int catRecordCount = 9;
+    final int dogRecordCount = 3;
+
+    List<SinkRecord> catRecords =
+        TestUtils.createJsonStringSinkRecords(0, catRecordCount, catTp.topic(), catTp.partition());
+    List<SinkRecord> dogRecords =
+        TestUtils.createJsonStringSinkRecords(0, dogRecordCount, dogTp.topic(), dogTp.partition());
+
+    // insert records
+    catService.insert(catRecords);
+    dogService.insert(dogRecords);
+
+    // check data was ingested
+    TestUtils.assertWithRetry(() -> catService.getOffset(catTp) == catRecordCount, 20, 20);
+    TestUtils.assertWithRetry(() -> dogService.getOffset(dogTp) == dogRecordCount, 20, 20);
+
+    // verify two clients were created
+    assert StreamingClientProvider.getStreamingClientProviderInstance()
+        .getRegisteredClients()
+        .containsKey(new StreamingClientProperties(catConfig));
+    assert StreamingClientProvider.getStreamingClientProviderInstance()
+        .getRegisteredClients()
+        .containsKey(new StreamingClientProperties(dogConfig));
+
+    // close services
+    catService.closeAll();
+    dogService.closeAll();
+
+    // verify both clients were closed
+    assert !StreamingClientProvider.getStreamingClientProviderInstance()
+        .getRegisteredClients()
+        .containsKey(new StreamingClientProperties(catConfig));
+    assert !StreamingClientProvider.getStreamingClientProviderInstance()
+        .getRegisteredClients()
+        .containsKey(new StreamingClientProperties(dogConfig));
+  }
 }
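The cache-key behavior the test above depends on, in miniature: the key is the converted client Properties, so configs that differ in a mapped connection setting (the URL here, as a placeholder) produce unequal StreamingClientProperties and therefore distinct cache entries.

    Map<String, String> catConfig = TestUtils.getConfForStreaming();
    catConfig.put(Utils.SF_URL, "https://cat.snowflakecomputing.com:443"); // placeholder URLs
    Map<String, String> dogConfig = TestUtils.getConfForStreaming();
    dogConfig.put(Utils.SF_URL, "https://dog.snowflakecomputing.com:443");

    // different ACCOUNT_URL entries -> unequal keys -> two registered clients
    assert !new StreamingClientProperties(catConfig)
        .equals(new StreamingClientProperties(dogConfig));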
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientConcurrencyTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientConcurrencyTest.java
index 567298725..03a3e35e8 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientConcurrencyTest.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientConcurrencyTest.java
@@ -70,7 +70,8 @@ public void setup() {
     this.streamingClientHandler = Mockito.spy(StreamingClientHandler.class);
     this.streamingClientProvider =
         StreamingClientProvider.getStreamingClientProviderForTests(
-            null, this.streamingClientHandler);
+            this.streamingClientHandler,
+            StreamingClientProvider.buildLoadingCache(this.streamingClientHandler));
 
     this.getClientFuturesTeardown = new ArrayList<>();
     this.closeClientFuturesTeardown = new ArrayList<>();
@@ -117,12 +118,18 @@ public void testMultipleGetAndClose() throws Exception {
     getClient1Futures.add(this.callGetClientThread(task1Executor, task1Latch, clientConfig1));
     closeClient1Futures.add(
         this.callCloseClientThread(
-            task1Executor, task1Latch, getClient1Futures.get(getClient1Futures.size() - 1).get()));
+            task1Executor,
+            task1Latch,
+            clientConfig1,
+            getClient1Futures.get(getClient1Futures.size() - 1).get()));
     getClient1Futures.add(this.callGetClientThread(task1Executor, task1Latch, clientConfig1));
     createClientCount++;
     closeClient1Futures.add(
         this.callCloseClientThread(
-            task1Executor, task1Latch, getClient1Futures.get(getClient1Futures.size() - 1).get()));
+            task1Executor,
+            task1Latch,
+            clientConfig1,
+            getClient1Futures.get(getClient1Futures.size() - 1).get()));
 
     // task2: get client, close client x3, get client, close client
     CountDownLatch task2Latch = new CountDownLatch(7);
@@ -134,18 +141,30 @@ public void testMultipleGetAndClose() throws Exception {
     getClient2Futures.add(this.callGetClientThread(task2Executor, task2Latch, clientConfig1));
     closeClient2Futures.add(
         this.callCloseClientThread(
-            task2Executor, task2Latch, getClient2Futures.get(getClient2Futures.size() - 1).get()));
+            task2Executor,
+            task2Latch,
+            clientConfig2,
+            getClient2Futures.get(getClient2Futures.size() - 1).get()));
     closeClient2Futures.add(
         this.callCloseClientThread(
-            task2Executor, task2Latch, getClient2Futures.get(getClient2Futures.size() - 1).get()));
+            task2Executor,
+            task2Latch,
+            clientConfig2,
+            getClient2Futures.get(getClient2Futures.size() - 1).get()));
     closeClient2Futures.add(
         this.callCloseClientThread(
-            task2Executor, task2Latch, getClient2Futures.get(getClient2Futures.size() - 1).get()));
+            task2Executor,
+            task2Latch,
+            clientConfig2,
+            getClient2Futures.get(getClient2Futures.size() - 1).get()));
     getClient2Futures.add(this.callGetClientThread(task2Executor, task2Latch, clientConfig1));
     createClientCount++;
     closeClient2Futures.add(
         this.callCloseClientThread(
-            task2Executor, task2Latch, getClient2Futures.get(getClient2Futures.size() - 1).get()));
+            task2Executor,
+            task2Latch,
+            clientConfig2,
+            getClient2Futures.get(getClient2Futures.size() - 1).get()));
 
     // task3: get client, close client
     CountDownLatch task3Latch = new CountDownLatch(3);
@@ -157,18 +176,30 @@ public void testMultipleGetAndClose() throws Exception {
     getClient3Futures.add(this.callGetClientThread(task3Executor, task3Latch, clientConfig1));
     closeClient3Futures.add(
         this.callCloseClientThread(
-            task3Executor, task3Latch, getClient3Futures.get(getClient3Futures.size() - 1).get()));
+            task3Executor,
+            task3Latch,
+            clientConfig3,
+            getClient3Futures.get(getClient3Futures.size() - 1).get()));
 
     // add final close to each task, kicking the threads off
     closeClient1Futures.add(
         this.callCloseClientThread(
-            task1Executor, task1Latch, getClient1Futures.get(getClient1Futures.size() - 1).get()));
+            task1Executor,
+            task1Latch,
+            clientConfig1,
+            getClient1Futures.get(getClient1Futures.size() - 1).get()));
     closeClient2Futures.add(
         this.callCloseClientThread(
-            task2Executor, task2Latch, getClient2Futures.get(getClient2Futures.size() - 1).get()));
+            task2Executor,
+            task2Latch,
+            clientConfig2,
+            getClient2Futures.get(getClient2Futures.size() - 1).get()));
     closeClient3Futures.add(
         this.callCloseClientThread(
-            task3Executor, task3Latch, getClient3Futures.get(getClient3Futures.size() - 1).get()));
+            task3Executor,
+            task3Latch,
+            clientConfig3,
+            getClient3Futures.get(getClient3Futures.size() - 1).get()));
 
     task1Latch.await();
     task2Latch.await();
@@ -183,7 +214,7 @@ public void testMultipleGetAndClose() throws Exception {
     Mockito.verify(
             this.streamingClientHandler,
             Mockito.times(this.enableClientOptimization ? createClientCount : totalGetCount))
-        .createClient(Mockito.anyMap());
+        .createClient(Mockito.any());
     Mockito.verify(this.streamingClientHandler, Mockito.times(totalCloseCount))
         .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class));
   }
@@ -213,7 +244,7 @@ public void testGetClientConcurrency() throws Exception {
     Mockito.verify(
             this.streamingClientHandler,
             Mockito.times(this.enableClientOptimization ? 1 : numGetClientCalls))
-        .createClient(Mockito.anyMap());
+        .createClient(Mockito.any());
   }
 
   @Test
@@ -229,7 +260,7 @@ public void testCloseClientConcurrency() throws Exception {
     // start closeClient threads
     List<Future<SnowflakeStreamingIngestClient>> futures = new ArrayList<>();
     for (int i = 0; i < numCloseClientCalls; i++) {
-      futures.add(this.callCloseClientThread(executorService, latch, client));
+      futures.add(this.callCloseClientThread(executorService, latch, clientConfig, client));
     }
 
     // wait for closeClient to complete
@@ -240,8 +271,16 @@ public void testCloseClientConcurrency() throws Exception {
       Assert.assertFalse(StreamingClientHandler.isClientValid(future.get()));
     }
 
-    // Verify that closeClient() was called every time
-    Mockito.verify(this.streamingClientHandler, Mockito.times(numCloseClientCalls))
+    // Verify that closeClient() was called at least once per close thread
+    Mockito.verify(this.streamingClientHandler, Mockito.atLeast(numCloseClientCalls))
+        .closeClient(client);
+
+    // Verify that closeClient() was called at most twice per close thread. Because the
+    // LoadingCache's invalidation happens async, we can't really expect an exact number of calls.
+    // The extra closeClient calls will no-op
+    Mockito.verify(
+            this.streamingClientHandler,
+            Mockito.atMost(numCloseClientCalls * (this.enableClientOptimization ? 2 : 1)))
         .closeClient(client);
   }
@@ -261,11 +300,12 @@ private Future<SnowflakeStreamingIngestClient> callGetClientThread(
   private Future<SnowflakeStreamingIngestClient> callCloseClientThread(
       ExecutorService executorService,
       CountDownLatch countDownLatch,
+      Map<String, String> config,
       SnowflakeStreamingIngestClient client) {
     Future<SnowflakeStreamingIngestClient> future =
         executorService.submit(
             () -> {
-              streamingClientProvider.closeClient(client);
+              streamingClientProvider.closeClient(config, client);
               countDownLatch.countDown();
             });
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandlerTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandlerTest.java
index 1a35c9589..07658ecf2 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandlerTest.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientHandlerTest.java
@@ -39,24 +39,39 @@ public void setup() {
   }
 
   @Test
-  public void testCreateClient() {
-    SnowflakeStreamingIngestClient client =
-        this.streamingClientHandler.createClient(this.connectorConfig);
+  public void testCreateClient() throws Exception {
+    SnowflakeStreamingIngestClient client1 =
+        this.streamingClientHandler.createClient(
+            new StreamingClientProperties(this.connectorConfig));
 
     // verify valid client against config
-    assert !client.isClosed();
-    assert client.getName().contains(this.connectorConfig.get(Utils.NAME));
+    assert !client1.isClosed();
+    assert client1.getName().contains(this.connectorConfig.get(Utils.NAME) + "_0");
+
+    // create another client, confirm that the name was incremented
+    SnowflakeStreamingIngestClient client2 =
+        this.streamingClientHandler.createClient(
+            new StreamingClientProperties(this.connectorConfig));
+
+    // verify valid client against config
+    assert !client2.isClosed();
+    assert client2.getName().contains(this.connectorConfig.get(Utils.NAME) + "_1");
+
+    // cleanup
+    client1.close();
+    client2.close();
   }
 
   @Test
   public void testCreateClientException() {
     // invalidate the config
-    this.connectorConfig.remove(Utils.SF_ROLE);
+    this.connectorConfig.remove(Utils.SF_PRIVATE_KEY); // private key is required
 
     try {
-      this.streamingClientHandler.createClient(this.connectorConfig);
+      this.streamingClientHandler.createClient(new StreamingClientProperties(this.connectorConfig));
     } catch (ConnectException ex) {
       assert ex.getCause().getClass().equals(SFException.class);
+      throw ex;
     }
   }
 
@@ -66,7 +81,7 @@ public void testCreateClientInvalidBdecVersion() {
     this.connectorConfig.put(SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION, "1");
 
     // test create
-    this.streamingClientHandler.createClient(this.connectorConfig);
+    this.streamingClientHandler.createClient(new StreamingClientProperties(this.connectorConfig));
   }
 
   @Test
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientPropertiesTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientPropertiesTest.java
new file mode 100644
index 000000000..98c9cc175
--- /dev/null
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientPropertiesTest.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2023 Snowflake Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.snowflake.kafka.connector.internal.streaming;
+
+import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWPIPE_STREAMING_FILE_VERSION;
+import static com.snowflake.kafka.connector.internal.streaming.StreamingClientProperties.DEFAULT_CLIENT_NAME;
+import static com.snowflake.kafka.connector.internal.streaming.StreamingClientProperties.LOGGABLE_STREAMING_CONFIG_PROPERTIES;
+import static com.snowflake.kafka.connector.internal.streaming.StreamingClientProperties.STREAMING_CLIENT_PREFIX_NAME;
+import static net.snowflake.ingest.utils.ParameterProvider.BLOB_FORMAT_VERSION;
+
+import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
+import com.snowflake.kafka.connector.Utils;
+import com.snowflake.kafka.connector.internal.TestUtils;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Properties;
+import org.junit.Test;
+
+public class StreamingClientPropertiesTest {
+
+  @Test
+  public void testGetValidProperties() {
+    String overrideValue = "overrideValue";
+
+    // setup config with all loggable properties and parameter override
+    Map<String, String> connectorConfig = TestUtils.getConfForStreaming();
+    connectorConfig.put(Utils.NAME, "testName");
+    connectorConfig.put(Utils.SF_URL, "testUrl");
+    connectorConfig.put(Utils.SF_ROLE, "testRole");
+    connectorConfig.put(Utils.SF_USER, "testUser");
+    connectorConfig.put(Utils.SF_AUTHENTICATOR, Utils.SNOWFLAKE_JWT);
+    connectorConfig.put(SNOWPIPE_STREAMING_FILE_VERSION, overrideValue);
+
+    Properties expectedProps = StreamingUtils.convertConfigForStreamingClient(connectorConfig);
+    String expectedClientName = STREAMING_CLIENT_PREFIX_NAME + connectorConfig.get(Utils.NAME);
+    Map<String, Object> expectedParameterOverrides = new HashMap<>();
+    expectedParameterOverrides.put(BLOB_FORMAT_VERSION, overrideValue);
+
+    // test get properties
+    StreamingClientProperties resultProperties = new StreamingClientProperties(connectorConfig);
+
+    // verify
+    assert resultProperties.clientProperties.equals(expectedProps);
+    assert resultProperties.clientName.equals(expectedClientName);
+    assert resultProperties.parameterOverrides.equals(expectedParameterOverrides);
+
+    // verify only loggable props are returned
+    String loggableProps = resultProperties.getLoggableClientProperties();
+    for (Object key : expectedProps.keySet()) {
+      Object value = expectedProps.get(key);
+      if (LOGGABLE_STREAMING_CONFIG_PROPERTIES.stream()
+          .anyMatch(key.toString()::equalsIgnoreCase)) {
+        assert loggableProps.contains(
+            Utils.formatString("{}={}", key.toString(), value.toString()));
+      } else {
+        assert !loggableProps.contains(key.toString())
+            && !loggableProps.contains(value.toString());
+      }
+    }
+  }
+
+  @Test
+  public void testGetInvalidProperties() {
+    StreamingClientProperties nullProperties = new StreamingClientProperties(null);
+    StreamingClientProperties emptyProperties = new StreamingClientProperties(new HashMap<>());
+
+    assert nullProperties.equals(emptyProperties);
+    assert nullProperties.clientName.equals(STREAMING_CLIENT_PREFIX_NAME + DEFAULT_CLIENT_NAME);
+    assert nullProperties.getLoggableClientProperties().equals("");
+  }
+
+  @Test
+  public void testStreamingClientPropertiesEquality() {
+    Map<String, String> config1 = TestUtils.getConfForStreaming();
+    config1.put(Utils.NAME, "catConnector");
+    config1.put(SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS, "100");
+
+    Map<String, String> config2 = TestUtils.getConfForStreaming();
+    config2.put(Utils.NAME, "dogConnector");
+    config2.put(SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS, "1000000");
+
+    // get properties
+    StreamingClientProperties prop1 = new StreamingClientProperties(config1);
+    StreamingClientProperties prop2 = new StreamingClientProperties(config2);
+
+    assert prop1.equals(prop2);
+    assert prop1.hashCode() == prop2.hashCode();
+  }
+}
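The converse case pinned down by the equality test above: connector name and buffer sizing are not part of the converted client Properties, so two connectors differing only in those fields share one cache key (names are placeholders).

    Map<String, String> config1 = TestUtils.getConfForStreaming();
    config1.put(Utils.NAME, "catConnector");
    Map<String, String> config2 = TestUtils.getConfForStreaming();
    config2.put(Utils.NAME, "dogConnector");

    StreamingClientProperties key1 = new StreamingClientProperties(config1);
    StreamingClientProperties key2 = new StreamingClientProperties(config2);
    assert key1.equals(key2) && key1.hashCode() == key2.hashCode();
    assert !key1.clientName.equals(key2.clientName); // names differ, the cache key does not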
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java
new file mode 100644
index 000000000..037a6c5f7
--- /dev/null
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2023 Snowflake Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.snowflake.kafka.connector.internal.streaming;
+
+import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
+import com.snowflake.kafka.connector.Utils;
+import com.snowflake.kafka.connector.internal.TestUtils;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache;
+import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.mockito.Mockito;
+
+@RunWith(Parameterized.class)
+public class StreamingClientProviderIT {
+  private final boolean enableClientOptimization;
+  private final Map<String, String> clientConfig = TestUtils.getConfForStreaming();
+
+  @Parameterized.Parameters(name = "enableClientOptimization: {0}")
+  public static Collection<Object[]> input() {
+    return Arrays.asList(new Object[][] {{true}, {false}});
+  }
+
+  public StreamingClientProviderIT(boolean enableClientOptimization) {
+    this.enableClientOptimization = enableClientOptimization;
+    this.clientConfig.put(
+        SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG,
+        String.valueOf(this.enableClientOptimization));
+  }
+
+  @Test
+  public void testGetMultipleClients() throws Exception {
+    String validRegisteredClientName = "openRegisteredClient";
+    String invalidRegisteredClientName = "closedRegisteredClient";
+    String validUnregisteredClientName = "openUnregisteredClient";
+    StreamingClientHandler clientCreator = new StreamingClientHandler();
+
+    // setup registered valid client
+    Map<String, String> validRegisteredClientConfig = new HashMap<>(this.clientConfig);
+    validRegisteredClientConfig.put(Utils.NAME, validRegisteredClientName);
+    validRegisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "0");
+    StreamingClientProperties validRegisteredClientProps =
+        new StreamingClientProperties(validRegisteredClientConfig);
+    SnowflakeStreamingIngestClient validRegisteredClient =
+        clientCreator.createClient(validRegisteredClientProps);
+
+    // setup registered invalid client
+    Map<String, String> invalidRegisteredClientConfig = new HashMap<>(this.clientConfig);
+    invalidRegisteredClientConfig.put(Utils.NAME, invalidRegisteredClientName);
+    invalidRegisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "1");
+    StreamingClientProperties invalidRegisteredClientProps =
+        new StreamingClientProperties(invalidRegisteredClientConfig);
+    SnowflakeStreamingIngestClient invalidRegisteredClient =
+        clientCreator.createClient(invalidRegisteredClientProps);
+    invalidRegisteredClient.close();
+
+    // setup unregistered valid client
+    Map<String, String> validUnregisteredClientConfig = new HashMap<>(this.clientConfig);
+    validUnregisteredClientConfig.put(Utils.NAME, validUnregisteredClientName);
+    validUnregisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "2");
+    StreamingClientProperties validUnregisteredClientProps =
+        new StreamingClientProperties(validUnregisteredClientConfig);
+    SnowflakeStreamingIngestClient validUnregisteredClient =
+        clientCreator.createClient(validUnregisteredClientProps);
+
+    // inject registered clients
+    StreamingClientHandler streamingClientHandlerSpy =
+        Mockito.spy(StreamingClientHandler.class); // use this to verify behavior
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> registeredClients =
+        StreamingClientProvider.buildLoadingCache(streamingClientHandlerSpy);
+    registeredClients.put(validRegisteredClientProps, validRegisteredClient);
+    registeredClients.put(invalidRegisteredClientProps, invalidRegisteredClient);
+
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            streamingClientHandlerSpy, registeredClients);
+
+    assert streamingClientProvider.getRegisteredClients().size() == 2;
+
+    // test 1: get registered valid client optimization returns existing client
+    SnowflakeStreamingIngestClient resultValidRegisteredClient =
+        streamingClientProvider.getClient(validRegisteredClientConfig);
+
+    assert StreamingClientHandler.isClientValid(resultValidRegisteredClient);
+    assert resultValidRegisteredClient.getName().contains("_0");
+    assert this.enableClientOptimization
+        == resultValidRegisteredClient.equals(validRegisteredClient);
+    Mockito.verify(streamingClientHandlerSpy, Mockito.times(this.enableClientOptimization ? 0 : 1))
+        .createClient(validRegisteredClientProps);
+    assert streamingClientProvider.getRegisteredClients().size() == 2;
+
+    // test 2: get registered invalid client creates new client regardless of optimization
+    SnowflakeStreamingIngestClient resultInvalidRegisteredClient =
+        streamingClientProvider.getClient(invalidRegisteredClientConfig);
+
+    assert StreamingClientHandler.isClientValid(resultInvalidRegisteredClient);
+    assert resultInvalidRegisteredClient
+        .getName()
+        .contains("_" + (this.enableClientOptimization ? 0 : 1));
+    assert !resultInvalidRegisteredClient.equals(invalidRegisteredClient);
+    Mockito.verify(streamingClientHandlerSpy, Mockito.times(1))
+        .createClient(invalidRegisteredClientProps);
+    assert streamingClientProvider.getRegisteredClients().size() == 2;
+
+    // test 3: get unregistered valid client creates and registers new client with optimization
+    SnowflakeStreamingIngestClient resultValidUnregisteredClient =
+        streamingClientProvider.getClient(validUnregisteredClientConfig);
+
+    assert StreamingClientHandler.isClientValid(resultValidUnregisteredClient);
+    assert resultValidUnregisteredClient
+        .getName()
+        .contains("_" + (this.enableClientOptimization ? 1 : 2));
+    assert !resultValidUnregisteredClient.equals(validUnregisteredClient);
+    Mockito.verify(streamingClientHandlerSpy, Mockito.times(1))
+        .createClient(validUnregisteredClientProps);
+    assert streamingClientProvider.getRegisteredClients().size()
+        == (this.enableClientOptimization ? 3 : 2);
+
+    // verify streamingClientHandler behavior
+    Mockito.verify(streamingClientHandlerSpy, Mockito.times(this.enableClientOptimization ? 2 : 3))
+        .createClient(Mockito.any());
+
+    // test 4: get all clients multiple times and verify optimization doesn't create new clients
+    List<SnowflakeStreamingIngestClient> gotClientList = new ArrayList<>();
+
+    for (int i = 0; i < 5; i++) {
+      gotClientList.add(streamingClientProvider.getClient(validRegisteredClientConfig));
+      gotClientList.add(streamingClientProvider.getClient(invalidRegisteredClientConfig));
+      gotClientList.add(streamingClientProvider.getClient(validUnregisteredClientConfig));
+    }
+
+    List<SnowflakeStreamingIngestClient> distinctClients =
+        gotClientList.stream().distinct().collect(Collectors.toList());
+    assert distinctClients.size() == (this.enableClientOptimization ? 3 : gotClientList.size());
+    Mockito.verify(
+            streamingClientHandlerSpy,
+            Mockito.times(this.enableClientOptimization ? 2 : 3 + gotClientList.size()))
+        .createClient(Mockito.any());
+    assert streamingClientProvider.getRegisteredClients().size()
+        == (this.enableClientOptimization ? 3 : 2);
+
+    // close all clients
+    validRegisteredClient.close();
+    invalidRegisteredClient.close();
+    validUnregisteredClient.close();
+    resultValidRegisteredClient.close();
+    resultInvalidRegisteredClient.close();
+    resultValidUnregisteredClient.close();
+    distinctClients.stream()
+        .forEach(
+            client -> {
+              try {
+                client.close();
+              } catch (Exception e) {
+                // do nothing
+              }
+            });
+  }
+}
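The injection seam these tests lean on, in isolation: seeding the provider with a prebuilt cache entry (a sketch; config and existingClient stand in for an existing config map and an open client).

    StreamingClientHandler handler = Mockito.spy(StreamingClientHandler.class);
    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> cache =
        StreamingClientProvider.buildLoadingCache(handler);
    cache.put(new StreamingClientProperties(config), existingClient); // pre-register

    StreamingClientProvider provider =
        StreamingClientProvider.getStreamingClientProviderForTests(handler, cache);
    // with the optimization on, getClient(config) should return existingClient
    // (assuming it is still open) without a createClient call on the handler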
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java
index 5a9d12f96..02d25f001 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java
@@ -17,18 +17,14 @@
 
 package com.snowflake.kafka.connector.internal.streaming;
 
-import static com.snowflake.kafka.connector.internal.streaming.StreamingClientProvider.getStreamingClientProviderForTests;
-
 import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
 import com.snowflake.kafka.connector.Utils;
 import com.snowflake.kafka.connector.internal.TestUtils;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.Map;
+import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache;
 import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -36,20 +32,8 @@
 @RunWith(Parameterized.class)
 public class StreamingClientProviderTest {
-  // NOTE: use the following clients where possible so we don't leak clients - these will be
-  // closed after each test
-  private SnowflakeStreamingIngestClient client1;
-  private SnowflakeStreamingIngestClient client2;
-  private SnowflakeStreamingIngestClient client3;
-  private SnowflakeStreamingIngestClient validClient;
-  private SnowflakeStreamingIngestClient invalidClient;
-
-  private Map<String, String> clientConfig1;
-  private Map<String, String> clientConfig2;
-
-  private StreamingClientProvider streamingClientProvider;
-  private StreamingClientHandler streamingClientHandler;
-  private boolean enableClientOptimization;
+  private final boolean enableClientOptimization;
+  private final Map<String, String> clientConfig = TestUtils.getConfForStreaming();
 
   @Parameterized.Parameters(name = "enableClientOptimization: {0}")
   public static Collection<Object[]> input() {
@@ -58,139 +42,235 @@ public static Collection<Object[]> input() {
 
   public StreamingClientProviderTest(boolean enableClientOptimization) {
     this.enableClientOptimization = enableClientOptimization;
-  }
-
-  @Before
-  public void setup() {
-    // setup fresh configs
-    this.clientConfig1 = TestUtils.getConfForStreaming();
-    this.clientConfig1.put(
+    this.clientConfig.put(
         SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG,
-        this.enableClientOptimization + "");
-    this.clientConfig2 = new HashMap<>(this.clientConfig1);
-
-    this.clientConfig1.put(Utils.NAME, "client1");
-    this.clientConfig2.put(Utils.NAME, "client2");
-
-    this.streamingClientHandler = Mockito.spy(StreamingClientHandler.class);
-    this.streamingClientProvider =
-        StreamingClientProvider.getStreamingClientProviderForTests(
-            null, this.streamingClientHandler);
-  }
-
-  @After
-  public void tearDown() {
-    this.streamingClientHandler.closeClient(this.client1);
-    this.streamingClientHandler.closeClient(this.client2);
-    this.streamingClientHandler.closeClient(this.client3);
-    this.streamingClientHandler.closeClient(this.validClient);
-    this.streamingClientHandler.closeClient(this.invalidClient);
+        String.valueOf(this.enableClientOptimization));
   }
 
   @Test
   public void testFirstGetClient() {
-    // test actual provider
-    this.client1 = this.streamingClientProvider.getClient(this.clientConfig1);
+    // setup mock client and handler
+    SnowflakeStreamingIngestClient clientMock = Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(clientMock.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
+    Mockito.when(mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig)))
+        .thenReturn(clientMock);
+
+    // test provider gets new client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, StreamingClientProvider.buildLoadingCache(mockClientHandler));
+    SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig);
 
     // verify - should create a client regardless of optimization
-    assert StreamingClientHandler.isClientValid(this.client1);
-    assert this.client1.getName().contains(this.clientConfig1.get(Utils.NAME));
-    Mockito.verify(this.streamingClientHandler, Mockito.times(1)).createClient(this.clientConfig1);
+    assert client.equals(clientMock);
+    assert client.getName().contains(this.clientConfig.get(Utils.NAME));
+    Mockito.verify(mockClientHandler, Mockito.times(1))
+        .createClient(new StreamingClientProperties(this.clientConfig));
   }
 
   @Test
   public void testGetInvalidClient() {
-    Map<String, String> invalidClientConfig = new HashMap<>(this.clientConfig1);
-    invalidClientConfig.put(Utils.NAME, "invalid client");
-
-    Map<String, String> validClientConfig = new HashMap<>(this.clientConfig1);
-    validClientConfig.put(Utils.NAME, "valid client");
-
-    // setup invalid client
-    this.invalidClient = Mockito.mock(SnowflakeStreamingIngestClient.class);
-    Mockito.when(this.invalidClient.isClosed()).thenReturn(true);
-    StreamingClientProvider injectedProvider =
-        getStreamingClientProviderForTests(this.invalidClient, this.streamingClientHandler);
-
-    // test: getting invalid client with valid config
-    this.validClient = injectedProvider.getClient(validClientConfig);
-
-    // verify: created valid client
-    assert StreamingClientHandler.isClientValid(this.validClient);
-    assert this.validClient.getName().contains(validClientConfig.get(Utils.NAME));
-    assert !this.validClient.getName().contains(invalidClientConfig.get(Utils.NAME));
-    Mockito.verify(this.streamingClientHandler, Mockito.times(1)).createClient(validClientConfig);
-
-    // verify: invalid client was closed, depending on optimization
-    Mockito.verify(this.invalidClient, Mockito.times(this.enableClientOptimization ? 1 : 0))
-        .isClosed();
+    // setup handler, invalid mock client and valid returned client
+    SnowflakeStreamingIngestClient mockInvalidClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockInvalidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockInvalidClient.isClosed()).thenReturn(true);
+
+    SnowflakeStreamingIngestClient mockValidClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockValidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockValidClient.isClosed()).thenReturn(false);
+
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
+    Mockito.when(mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig)))
+        .thenReturn(mockValidClient);
+
+    // inject invalid client into provider
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> mockRegisteredClients =
+        Mockito.mock(LoadingCache.class);
+    Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig)))
+        .thenReturn(mockInvalidClient);
+
+    // test provider gets new client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, mockRegisteredClients);
+    SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig);
+
+    // verify - returned client is valid even though we injected an invalid client
+    assert client.equals(mockValidClient);
+    assert !client.equals(mockInvalidClient);
+    assert client.getName().contains(this.clientConfig.get(Utils.NAME));
+    assert !client.isClosed();
+    Mockito.verify(mockClientHandler, Mockito.times(1))
+        .createClient(new StreamingClientProperties(this.clientConfig));
   }
 
   @Test
   public void testGetExistingClient() {
-    // test
-    this.client1 = this.streamingClientProvider.getClient(this.clientConfig1);
-    this.client2 = this.streamingClientProvider.getClient(this.clientConfig2);
-    this.client3 = this.streamingClientProvider.getClient(this.clientConfig1);
-
-    // verify: clients are valid
-    assert StreamingClientHandler.isClientValid(client1);
-    assert StreamingClientHandler.isClientValid(client2);
-    assert StreamingClientHandler.isClientValid(client3);
-
-    // verify: clients should be the same if optimization is enabled
-    if (this.enableClientOptimization) {
-      assert client1.getName().equals(client2.getName());
-      assert client1.getName().equals(client3.getName());
-      assert client1.getName().contains(this.clientConfig1.get(Utils.NAME));
-
-      Mockito.verify(this.streamingClientHandler, Mockito.times(1))
-          .createClient(this.clientConfig1);
-    } else {
-      // client 1 and 3 are created from the same config, but will have different names
-      assert !client1.getName().equals(client2.getName());
-      assert !client2.getName().equals(client3.getName());
-      assert !client1.getName().equals(client3.getName());
-
-      assert client1.getName().contains(this.clientConfig1.get(Utils.NAME));
-      assert client2.getName().contains(this.clientConfig2.get(Utils.NAME));
-      assert client3.getName().contains(this.clientConfig1.get(Utils.NAME));
-
-      Mockito.verify(this.streamingClientHandler, Mockito.times(2))
-          .createClient(this.clientConfig1);
-      Mockito.verify(this.streamingClientHandler, Mockito.times(1))
-          .createClient(this.clientConfig2);
+    // setup existing client, handler and inject to registeredClients
+    SnowflakeStreamingIngestClient mockExistingClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockExistingClient.isClosed()).thenReturn(false);
+
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
+
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> mockRegisteredClients =
+        Mockito.mock(LoadingCache.class);
+    Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig)))
+        .thenReturn(mockExistingClient);
+
+    // if optimization is disabled, we will create new client regardless of registeredClients
+    if (!this.enableClientOptimization) {
+      Mockito.when(
+              mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig)))
+          .thenReturn(mockExistingClient);
     }
+
+    // test getting client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, mockRegisteredClients);
+    SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig);
+
+    // verify client and expected client creation
+    assert client.equals(mockExistingClient);
+    assert client.getName().equals(this.clientConfig.get(Utils.NAME));
+    Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 0 : 1))
+        .createClient(new StreamingClientProperties(this.clientConfig));
   }
 
   @Test
-  public void testCloseClients() throws Exception {
-    this.client1 = Mockito.mock(SnowflakeStreamingIngestClient.class);
+  public void testGetClientMissingConfig() {
+    // remove one client opt from config
+    this.clientConfig.remove(
+        SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG);
+
+    // setup existing client, handler and inject to registeredClients
+    SnowflakeStreamingIngestClient mockExistingClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockExistingClient.isClosed()).thenReturn(false);
 
-    // test closing all clients
-    StreamingClientProvider injectedProvider =
-        getStreamingClientProviderForTests(this.client1, this.streamingClientHandler);
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
 
-    injectedProvider.closeClient(this.client1);
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> mockRegisteredClients =
+        Mockito.mock(LoadingCache.class);
+    Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig)))
+        .thenReturn(mockExistingClient);
 
-    // verify: if optimized, there should only be one closeClient() call
-    Mockito.verify(this.streamingClientHandler, Mockito.times(1)).closeClient(this.client1);
+    // test getting client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, mockRegisteredClients);
+    SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig);
+
+    // verify returned existing client since removing the optimization should default to true
+    assert client.equals(mockExistingClient);
+    assert client.getName().equals(this.clientConfig.get(Utils.NAME));
+    Mockito.verify(mockClientHandler, Mockito.times(0))
+        .createClient(new StreamingClientProperties(this.clientConfig));
   }
 
   @Test
-  public void testGetClientMissingConfig() {
-    this.clientConfig1.remove(
-        SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG);
+  public void testCloseClients() throws Exception {
+    // setup valid existing client and handler
+    SnowflakeStreamingIngestClient mockExistingClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockExistingClient.isClosed()).thenReturn(false);
+
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
+    Mockito.doCallRealMethod()
+        .when(mockClientHandler)
+        .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class));
+
+    // inject existing client in for optimization
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> mockRegisteredClients =
+        Mockito.mock(LoadingCache.class);
+    if (this.enableClientOptimization) {
+      Mockito.when(
+              mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig)))
+          .thenReturn(mockExistingClient);
+    }
+
+    // test closing valid client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, mockRegisteredClients);
+    streamingClientProvider.closeClient(this.clientConfig, mockExistingClient);
+
+    // verify existing client was closed; with the optimization on, both the given client and the
+    // registered client get a close call
+    Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 2 : 1))
+        .closeClient(mockExistingClient);
+    Mockito.verify(mockExistingClient, Mockito.times(this.enableClientOptimization ? 2 : 1))
+        .close();
+  }
+
+  @Test
+  public void testCloseInvalidClient() throws Exception {
+    // setup invalid existing client and handler
+    SnowflakeStreamingIngestClient mockInvalidClient =
+        Mockito.mock(SnowflakeStreamingIngestClient.class);
+    Mockito.when(mockInvalidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME));
+    Mockito.when(mockInvalidClient.isClosed()).thenReturn(true);
+
+    StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class);
+    Mockito.doCallRealMethod()
+        .when(mockClientHandler)
+        .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class));
+
+    // inject invalid existing client in for optimization
+    LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> mockRegisteredClients =
+        Mockito.mock(LoadingCache.class);
+    if (this.enableClientOptimization) {
+      Mockito.when(
+              mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig)))
+          .thenReturn(mockInvalidClient);
+    }
+
+    // test closing invalid client
+    StreamingClientProvider streamingClientProvider =
+        StreamingClientProvider.getStreamingClientProviderForTests(
+            mockClientHandler, mockRegisteredClients);
+    streamingClientProvider.closeClient(this.clientConfig, mockInvalidClient);
+
+    // verify handler close client no-op and client did not need to call close
+    Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 2 : 1))
+        .closeClient(mockInvalidClient);
+    Mockito.verify(mockInvalidClient, Mockito.times(0)).close();
+  }
 
-    // test actual provider
-    this.client1 = this.streamingClientProvider.getClient(this.clientConfig1);
-    this.client2 = this.streamingClientProvider.getClient(this.clientConfig1);
-
-    // Since it is enabled by default, we should only create one client.
- assert this.client1.getName().equals(this.client2.getName()); + @Test + public void testCloseUnregisteredClient() throws Exception { + // setup valid existing client and handler + SnowflakeStreamingIngestClient mockUnregisteredClient = + Mockito.mock(SnowflakeStreamingIngestClient.class); + Mockito.when(mockUnregisteredClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); + Mockito.when(mockUnregisteredClient.isClosed()).thenReturn(false); + + StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); + Mockito.doCallRealMethod() + .when(mockClientHandler) + .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class)); + + // ensure no clients are registered + LoadingCache mockRegisteredClients = + Mockito.mock(LoadingCache.class); + Mockito.when( + mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig))) + .thenReturn(null); + + // test closing valid client + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderForTests( + mockClientHandler, mockRegisteredClients); + streamingClientProvider.closeClient(this.clientConfig, mockUnregisteredClient); - assert StreamingClientHandler.isClientValid(this.client1); - assert this.client1.getName().contains(this.clientConfig1.get(Utils.NAME)); - Mockito.verify(this.streamingClientHandler, Mockito.times(1)).createClient(this.clientConfig1); + // verify unregistered client was closed + Mockito.verify(mockClientHandler, Mockito.times(1)).closeClient(mockUnregisteredClient); + Mockito.verify(mockUnregisteredClient, Mockito.times(1)).close(); } }
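
A note on the stubbing pattern these tests lean on: Mockito matches a stubbed call by comparing the recorded argument with equals(), so stubbing createClient(new StreamingClientProperties(this.clientConfig)) and later verifying with another fresh StreamingClientProperties instance only works because the class defines equality over its properties (it exists precisely to compare clients in StreamingClientProvider). Below is a minimal sketch of that mechanism, using a hypothetical Key class and Handler interface in place of the connector's types; it is not connector code:

    import static org.mockito.Mockito.*;

    import java.util.Objects;

    public class EqualityMatchingSketch {
      // Hypothetical stand-in for StreamingClientProperties: equality is field-based.
      static final class Key {
        final String name;

        Key(String name) {
          this.name = name;
        }

        @Override
        public boolean equals(Object other) {
          return other instanceof Key && Objects.equals(((Key) other).name, this.name);
        }

        @Override
        public int hashCode() {
          return Objects.hash(this.name);
        }
      }

      // Hypothetical stand-in for StreamingClientHandler.
      interface Handler {
        String createClient(Key key);
      }

      public static void main(String[] args) {
        Handler handler = mock(Handler.class);
        when(handler.createClient(new Key("KC_CLIENT"))).thenReturn("client");

        // A different-but-equal Key instance still hits the stub and the verification,
        // just like the fresh StreamingClientProperties instances in the tests above.
        assert "client".equals(handler.createClient(new Key("KC_CLIENT")));
        verify(handler, times(1)).createClient(new Key("KC_CLIENT"));
      }
    }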
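The verification counts across testCloseClients, testCloseInvalidClient, and testCloseUnregisteredClient together pin down the provider's close path: with optimization enabled it also looks up and closes the client registered for the config (hence times(2) against the same mock instance), the handler's real closeClient() no-ops on an already-closed client, and a missing registry entry leaves only the caller's client to close. A sketch consistent with those expectations follows; it is inferred from the tests, not taken from StreamingClientProvider's source, and the field names and the unshaded Caffeine import are assumptions:

    package com.snowflake.kafka.connector.internal.streaming;

    import com.github.benmanes.caffeine.cache.LoadingCache; // possibly shaded in the ingest SDK
    import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
    import java.util.Map;
    import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;

    class CloseClientSketch {
      private final StreamingClientHandler clientHandler; // assumed field name
      private final LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient>
          registeredClients; // assumed field name

      CloseClientSketch(
          StreamingClientHandler clientHandler,
          LoadingCache<StreamingClientProperties, SnowflakeStreamingIngestClient> registeredClients) {
        this.clientHandler = clientHandler;
        this.registeredClients = registeredClients;
      }

      public void closeClient(Map<String, String> config, SnowflakeStreamingIngestClient client) {
        // a missing property defaults to enabled, matching testGetClientMissingConfig
        boolean optimizationEnabled =
            Boolean.parseBoolean(
                config.getOrDefault(
                    SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG,
                    "true"));

        if (optimizationEnabled) {
          // also close the registered client; in the tests it is the same mock instance as
          // the given client, which is why closeClient()/close() are verified with times(2)
          SnowflakeStreamingIngestClient registered =
              this.registeredClients.getIfPresent(new StreamingClientProperties(config));
          if (registered != null) {
            this.clientHandler.closeClient(registered);
          }
        }

        // the caller's client is always closed; the handler itself no-ops on invalid
        // (already closed) clients, so testCloseInvalidClient expects zero close() calls
        this.clientHandler.closeClient(client);
      }
    }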