diff --git a/src/main/java/net/snowflake/client/PrivateKeySigner.java b/src/main/java/net/snowflake/client/PrivateKeySigner.java new file mode 100644 index 000000000..4628127fa --- /dev/null +++ b/src/main/java/net/snowflake/client/PrivateKeySigner.java @@ -0,0 +1,17 @@ +package net.snowflake.client; + +import java.security.PublicKey; + +/** Interface for customer signer implementations for key pair authentication. */ +public interface PrivateKeySigner { + /** + * Returns a signature for the given input. + * + *

The signature must be compatible with the "RS256" JWT signing algorithm, a.k.a. + * "RSASSA-PKCS1-v1_5 using SHA-256" + */ + byte[] sign(byte[] input); + + /** Returns the public key associated with the private key used by the sign() method. */ + PublicKey publicKey(); +} diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index f448d086a..dccd53f22 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -47,6 +47,7 @@ public class SFLoginInput { private HttpClientSettingsKey httpClientKey; private String privateKeyFile; private String privateKeyFilePwd; + private String privateKeySignerClass; SFLoginInput() {} @@ -294,6 +295,11 @@ SFLoginInput setPrivateKeyFilePwd(String privateKeyFilePwd) { return this; } + SFLoginInput setPrivateKeySignerClass(String privateKeySignerClass) { + this.privateKeySignerClass = privateKeySignerClass; + return this; + } + String getPrivateKeyFile() { return privateKeyFile; } @@ -302,6 +308,10 @@ String getPrivateKeyFilePwd() { return privateKeyFilePwd; } + String getPrivateKeySignerClass() { + return privateKeySignerClass; + } + public String getApplication() { return application; } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index c831abebf..ae84b660d 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -61,6 +61,7 @@ public class SFSession extends SFBaseSession { private String mfaToken; private String privateKeyFileLocation; private String privateKeyPassword; + private String privateKeySignerClass; private PrivateKey privateKey; /** @@ -334,6 +335,12 @@ public void addSFSessionProperty(String propertyName, Object propertyValue) thro } break; + case PRIVATE_KEY_SIGNER_CLASS: + if (propertyValue != null) { + privateKeySignerClass = (String) propertyValue; + } + break; + default: break; } @@ -450,6 +457,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE)) .setPrivateKeyFilePwd( (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD)) + .setPrivateKeySignerClass( + (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_SIGNER_CLASS)) .setApplication((String) connectionPropertiesMap.get(SFSessionProperty.APPLICATION)) .setServiceName(getServiceName()) .setOCSPMode(getOCSPMode()) @@ -532,7 +541,10 @@ private boolean isSnowflakeAuthenticator() { Map connectionPropertiesMap = getConnectionPropertiesMap(); String authenticator = (String) connectionPropertiesMap.get(SFSessionProperty.AUTHENTICATOR); PrivateKey privateKey = (PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY); - return (authenticator == null && privateKey == null && privateKeyFileLocation == null) + return (authenticator == null + && privateKey == null + && privateKeyFileLocation == null + && privateKeySignerClass == null) || ClientAuthnDTO.AuthenticatorType.SNOWFLAKE.name().equalsIgnoreCase(authenticator); } diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java index 755d07c36..ffd1bf3e9 100644 --- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java +++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java @@ -55,6 +55,7 @@ public enum SFSessionProperty { INJECT_WAIT_IN_PUT("inject_wait_in_put", false, Integer.class), PRIVATE_KEY_FILE("private_key_file", false, String.class), PRIVATE_KEY_FILE_PWD("private_key_file_pwd", false, String.class), + PRIVATE_KEY_SIGNER_CLASS("private_key_signer_class", false, String.class), CLIENT_INFO("snowflakeClientInfo", false, String.class), ALLOW_UNDERSCORES_IN_HOST("allowUnderscoresInHost", false, Boolean.class); diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 94dd19fc3..68aae2ecd 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -221,7 +221,9 @@ private static ClientAuthnDTO.AuthenticatorType getAuthenticator(SFLoginInput lo // authenticator is null, then jdbc will decide authenticator depends on // if privateKey is specified or not. If yes, authenticator type will be // SNOWFLAKE_JWT, otherwise it will use SNOWFLAKE. - return (loginInput.getPrivateKey() != null || loginInput.getPrivateKeyFile() != null) + return (loginInput.getPrivateKey() != null + || loginInput.getPrivateKeyFile() != null + || loginInput.getPrivateKeySignerClass() != null) ? ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT : ClientAuthnDTO.AuthenticatorType.SNOWFLAKE; } @@ -384,6 +386,7 @@ private static SFLoginOutput newSession( loginInput.getPrivateKey(), loginInput.getPrivateKeyFile(), loginInput.getPrivateKeyFilePwd(), + loginInput.getPrivateKeySignerClass(), loginInput.getAccountName(), loginInput.getUserName()); @@ -608,6 +611,7 @@ private static SFLoginOutput newSession( loginInput.getPrivateKey(), loginInput.getPrivateKeyFile(), loginInput.getPrivateKeyFilePwd(), + loginInput.getPrivateKeySignerClass(), loginInput.getAccountName(), loginInput.getUserName()); @@ -1545,12 +1549,18 @@ public static String generateJWTToken( PrivateKey privateKey, String privateKeyFile, String privateKeyFilePwd, + String privateKeySignerClass, String accountName, String userName) throws SFException { SessionUtilKeyPair s = new SessionUtilKeyPair( - privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName); + privateKey, + privateKeyFile, + privateKeyFilePwd, + privateKeySignerClass, + accountName, + userName); return s.issueJwtToken(); } } diff --git a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java index 8c27d2c44..845647c39 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java +++ b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java @@ -11,10 +11,13 @@ import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.crypto.impl.RSASSAProvider; +import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import java.io.IOException; import java.io.StringReader; +import java.lang.reflect.InvocationTargetException; import java.nio.file.Files; import java.nio.file.Paths; import java.security.*; @@ -26,6 +29,7 @@ import javax.crypto.EncryptedPrivateKeyInfo; import javax.crypto.SecretKeyFactory; import javax.crypto.spec.PBEKeySpec; +import net.snowflake.client.PrivateKeySigner; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; @@ -35,6 +39,19 @@ /** Class used to compute jwt token for key pair authentication Created by hyu on 1/16/18. */ class SessionUtilKeyPair { + private static class DelegatingJWSSigner extends RSASSAProvider implements JWSSigner { + private final PrivateKeySigner privateKeySigner; + + DelegatingJWSSigner(PrivateKeySigner privateKeySigner) { + this.privateKeySigner = privateKeySigner; + } + + @Override + public Base64URL sign(JWSHeader jwsHeader, byte[] bytes) { + return Base64URL.encode(privateKeySigner.sign(bytes)); + } + } + static final SFLogger logger = SFLoggerFactory.getLogger(SessionUtilKeyPair.class); // user name in upper case @@ -43,9 +60,9 @@ class SessionUtilKeyPair { // account name in upper case private final String accountName; - private final PrivateKey privateKey; + private final JWSSigner signer; - private PublicKey publicKey = null; + private final PublicKey publicKey; private boolean isFipsMode = false; @@ -63,6 +80,7 @@ class SessionUtilKeyPair { PrivateKey privateKey, String privateKeyFile, String privateKeyFilePwd, + String privateKeySignerClass, String accountName, String userName) throws SFException { @@ -78,33 +96,59 @@ class SessionUtilKeyPair { } } - // if there is both a file and a private key, there is a problem - if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) { - throw new SFException( - ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, - "Cannot have both private key value and private key file."); - } else { - // if privateKeyFile has a value and privateKey is null - this.privateKey = - Strings.isNullOrEmpty(privateKeyFile) - ? privateKey - : extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd); - } - // construct public key from raw bytes - if (this.privateKey instanceof RSAPrivateCrtKey) { - RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) this.privateKey; - RSAPublicKeySpec rsaPublicKeySpec = - new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent()); - + if (!Strings.isNullOrEmpty(privateKeySignerClass)) { try { - this.publicKey = getKeyFactoryInstance().generatePublic(rsaPublicKeySpec); - } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { - throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key"); + PrivateKeySigner privateKeySigner = + (PrivateKeySigner) + Class.forName(privateKeySignerClass).getDeclaredConstructor().newInstance(); + this.signer = new DelegatingJWSSigner(privateKeySigner); + this.publicKey = privateKeySigner.publicKey(); + } catch (ClassNotFoundException e) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + String.format("Could not load class %s.", privateKeySignerClass)); + } catch (InvocationTargetException + | InstantiationException + | IllegalAccessException + | NoSuchMethodException e) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + String.format("Failed to instantiate class %s.", privateKeySignerClass)); + } catch (ClassCastException e) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + String.format( + "%s is not an instance of %s", + privateKeySignerClass, PrivateKeySigner.class.getName())); } } else { - throw new SFException( - ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, - "Use java.security.interfaces.RSAPrivateCrtKey.class for the private key"); + // if there is both a file and a private key, there is a problem + if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + "Cannot have both private key value and private key file."); + } + if (!Strings.isNullOrEmpty(privateKeyFile)) { + privateKey = extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd); + } + // construct public key from raw bytes + if (privateKey instanceof RSAPrivateCrtKey) { + this.signer = new RSASSASigner(privateKey); + RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey; + RSAPublicKeySpec rsaPublicKeySpec = + new RSAPublicKeySpec( + rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent()); + + try { + this.publicKey = getKeyFactoryInstance().generatePublic(rsaPublicKeySpec); + } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { + throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key"); + } + } else { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + "Use java.security.interfaces.RSAPrivateCrtKey.class for the private key"); + } } } @@ -180,10 +224,9 @@ public String issueJwtToken() throws SFException { builder.issuer(iss).subject(sub).issueTime(iat).expirationTime(exp).build(); SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet); - JWSSigner signer = new RSASSASigner(this.privateKey); try { - signedJWT.sign(signer); + signedJWT.sign(this.signer); } catch (JOSEException e) { throw new SFException(e, ErrorCode.FAILED_TO_GENERATE_JWT); } diff --git a/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java b/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java index 608c475fa..cfd3f6f91 100644 --- a/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java +++ b/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java @@ -179,7 +179,12 @@ private TelemetryClient createSessionlessTelemetry() Map parameters = getConnectionParameters(); String jwtToken = SessionUtil.generateJWTToken( - null, privateKeyLocation, null, parameters.get("account"), parameters.get("user")); + null, + privateKeyLocation, + null, + null, + parameters.get("account"), + parameters.get("user")); CloseableHttpClient httpClient = HttpUtil.buildHttpClient(null, null, false); TelemetryClient telemetry =