Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to set a custom key pair auth signing implementation #910

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/main/java/net/snowflake/client/PrivateKeySigner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package net.snowflake.client;

import java.security.PublicKey;

/** Interface for customer signer implementations for key pair authentication. */
public interface PrivateKeySigner {
/**
* Returns a signature for the given input.
*
* <p>The signature must be compatible with the "RS256" JWT signing algorithm, a.k.a.
* "RSASSA-PKCS1-v1_5 using SHA-256"
*/
byte[] sign(byte[] input);

/** Returns the public key associated with the private key used by the sign() method. */
PublicKey publicKey();
}
10 changes: 10 additions & 0 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class SFLoginInput {
private HttpClientSettingsKey httpClientKey;
private String privateKeyFile;
private String privateKeyFilePwd;
private String privateKeySignerClass;

SFLoginInput() {}

Expand Down Expand Up @@ -294,6 +295,11 @@ SFLoginInput setPrivateKeyFilePwd(String privateKeyFilePwd) {
return this;
}

SFLoginInput setPrivateKeySignerClass(String privateKeySignerClass) {
this.privateKeySignerClass = privateKeySignerClass;
return this;
}

String getPrivateKeyFile() {
return privateKeyFile;
}
Expand All @@ -302,6 +308,10 @@ String getPrivateKeyFilePwd() {
return privateKeyFilePwd;
}

String getPrivateKeySignerClass() {
return privateKeySignerClass;
}

public String getApplication() {
return application;
}
Expand Down
14 changes: 13 additions & 1 deletion src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class SFSession extends SFBaseSession {
private String mfaToken;
private String privateKeyFileLocation;
private String privateKeyPassword;
private String privateKeySignerClass;
private PrivateKey privateKey;

/**
Expand Down Expand Up @@ -334,6 +335,12 @@ public void addSFSessionProperty(String propertyName, Object propertyValue) thro
}
break;

case PRIVATE_KEY_SIGNER_CLASS:
if (propertyValue != null) {
privateKeySignerClass = (String) propertyValue;
}
break;

default:
break;
}
Expand Down Expand Up @@ -450,6 +457,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE))
.setPrivateKeyFilePwd(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD))
.setPrivateKeySignerClass(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_SIGNER_CLASS))
.setApplication((String) connectionPropertiesMap.get(SFSessionProperty.APPLICATION))
.setServiceName(getServiceName())
.setOCSPMode(getOCSPMode())
Expand Down Expand Up @@ -532,7 +541,10 @@ private boolean isSnowflakeAuthenticator() {
Map<SFSessionProperty, Object> connectionPropertiesMap = getConnectionPropertiesMap();
String authenticator = (String) connectionPropertiesMap.get(SFSessionProperty.AUTHENTICATOR);
PrivateKey privateKey = (PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY);
return (authenticator == null && privateKey == null && privateKeyFileLocation == null)
return (authenticator == null
&& privateKey == null
&& privateKeyFileLocation == null
&& privateKeySignerClass == null)
|| ClientAuthnDTO.AuthenticatorType.SNOWFLAKE.name().equalsIgnoreCase(authenticator);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public enum SFSessionProperty {
INJECT_WAIT_IN_PUT("inject_wait_in_put", false, Integer.class),
PRIVATE_KEY_FILE("private_key_file", false, String.class),
PRIVATE_KEY_FILE_PWD("private_key_file_pwd", false, String.class),
PRIVATE_KEY_SIGNER_CLASS("private_key_signer_class", false, String.class),
CLIENT_INFO("snowflakeClientInfo", false, String.class),
ALLOW_UNDERSCORES_IN_HOST("allowUnderscoresInHost", false, Boolean.class);

Expand Down
14 changes: 12 additions & 2 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ private static ClientAuthnDTO.AuthenticatorType getAuthenticator(SFLoginInput lo
// authenticator is null, then jdbc will decide authenticator depends on
// if privateKey is specified or not. If yes, authenticator type will be
// SNOWFLAKE_JWT, otherwise it will use SNOWFLAKE.
return (loginInput.getPrivateKey() != null || loginInput.getPrivateKeyFile() != null)
return (loginInput.getPrivateKey() != null
|| loginInput.getPrivateKeyFile() != null
|| loginInput.getPrivateKeySignerClass() != null)
? ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT
: ClientAuthnDTO.AuthenticatorType.SNOWFLAKE;
}
Expand Down Expand Up @@ -384,6 +386,7 @@ private static SFLoginOutput newSession(
loginInput.getPrivateKey(),
loginInput.getPrivateKeyFile(),
loginInput.getPrivateKeyFilePwd(),
loginInput.getPrivateKeySignerClass(),
loginInput.getAccountName(),
loginInput.getUserName());

Expand Down Expand Up @@ -608,6 +611,7 @@ private static SFLoginOutput newSession(
loginInput.getPrivateKey(),
loginInput.getPrivateKeyFile(),
loginInput.getPrivateKeyFilePwd(),
loginInput.getPrivateKeySignerClass(),
loginInput.getAccountName(),
loginInput.getUserName());

Expand Down Expand Up @@ -1545,12 +1549,18 @@ public static String generateJWTToken(
PrivateKey privateKey,
String privateKeyFile,
String privateKeyFilePwd,
String privateKeySignerClass,
String accountName,
String userName)
throws SFException {
SessionUtilKeyPair s =
new SessionUtilKeyPair(
privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName);
privateKey,
privateKeyFile,
privateKeyFilePwd,
privateKeySignerClass,
accountName,
userName);
return s.issueJwtToken();
}
}
99 changes: 71 additions & 28 deletions src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.crypto.impl.RSASSAProvider;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.io.StringReader;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.*;
Expand All @@ -26,6 +29,7 @@
import javax.crypto.EncryptedPrivateKeyInfo;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import net.snowflake.client.PrivateKeySigner;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
Expand All @@ -35,6 +39,19 @@
/** Class used to compute jwt token for key pair authentication Created by hyu on 1/16/18. */
class SessionUtilKeyPair {

private static class DelegatingJWSSigner extends RSASSAProvider implements JWSSigner {
private final PrivateKeySigner privateKeySigner;

DelegatingJWSSigner(PrivateKeySigner privateKeySigner) {
this.privateKeySigner = privateKeySigner;
}

@Override
public Base64URL sign(JWSHeader jwsHeader, byte[] bytes) {
return Base64URL.encode(privateKeySigner.sign(bytes));
}
}

static final SFLogger logger = SFLoggerFactory.getLogger(SessionUtilKeyPair.class);

// user name in upper case
Expand All @@ -43,9 +60,9 @@ class SessionUtilKeyPair {
// account name in upper case
private final String accountName;

private final PrivateKey privateKey;
private final JWSSigner signer;

private PublicKey publicKey = null;
private final PublicKey publicKey;

private boolean isFipsMode = false;

Expand All @@ -63,6 +80,7 @@ class SessionUtilKeyPair {
PrivateKey privateKey,
String privateKeyFile,
String privateKeyFilePwd,
String privateKeySignerClass,
String accountName,
String userName)
throws SFException {
Expand All @@ -78,33 +96,59 @@ class SessionUtilKeyPair {
}
}

// if there is both a file and a private key, there is a problem
if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
"Cannot have both private key value and private key file.");
} else {
// if privateKeyFile has a value and privateKey is null
this.privateKey =
Strings.isNullOrEmpty(privateKeyFile)
? privateKey
: extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd);
}
// construct public key from raw bytes
if (this.privateKey instanceof RSAPrivateCrtKey) {
RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) this.privateKey;
RSAPublicKeySpec rsaPublicKeySpec =
new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());

if (!Strings.isNullOrEmpty(privateKeySignerClass)) {
try {
this.publicKey = getKeyFactoryInstance().generatePublic(rsaPublicKeySpec);
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key");
PrivateKeySigner privateKeySigner =
(PrivateKeySigner)
Class.forName(privateKeySignerClass).getDeclaredConstructor().newInstance();
this.signer = new DelegatingJWSSigner(privateKeySigner);
this.publicKey = privateKeySigner.publicKey();
} catch (ClassNotFoundException e) {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
String.format("Could not load class %s.", privateKeySignerClass));
} catch (InvocationTargetException
| InstantiationException
| IllegalAccessException
| NoSuchMethodException e) {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
String.format("Failed to instantiate class %s.", privateKeySignerClass));
} catch (ClassCastException e) {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
String.format(
"%s is not an instance of %s",
privateKeySignerClass, PrivateKeySigner.class.getName()));
}
} else {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
"Use java.security.interfaces.RSAPrivateCrtKey.class for the private key");
// if there is both a file and a private key, there is a problem
if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
"Cannot have both private key value and private key file.");
}
if (!Strings.isNullOrEmpty(privateKeyFile)) {
privateKey = extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd);
}
// construct public key from raw bytes
if (privateKey instanceof RSAPrivateCrtKey) {
this.signer = new RSASSASigner(privateKey);
RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey;
RSAPublicKeySpec rsaPublicKeySpec =
new RSAPublicKeySpec(
rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());

try {
this.publicKey = getKeyFactoryInstance().generatePublic(rsaPublicKeySpec);
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key");
}
} else {
throw new SFException(
ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
"Use java.security.interfaces.RSAPrivateCrtKey.class for the private key");
}
}
}

Expand Down Expand Up @@ -180,10 +224,9 @@ public String issueJwtToken() throws SFException {
builder.issuer(iss).subject(sub).issueTime(iat).expirationTime(exp).build();

SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet);
JWSSigner signer = new RSASSASigner(this.privateKey);

try {
signedJWT.sign(signer);
signedJWT.sign(this.signer);
} catch (JOSEException e) {
throw new SFException(e, ErrorCode.FAILED_TO_GENERATE_JWT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,12 @@ private TelemetryClient createSessionlessTelemetry()
Map<String, String> parameters = getConnectionParameters();
String jwtToken =
SessionUtil.generateJWTToken(
null, privateKeyLocation, null, parameters.get("account"), parameters.get("user"));
null,
privateKeyLocation,
null,
null,
parameters.get("account"),
parameters.get("user"));

CloseableHttpClient httpClient = HttpUtil.buildHttpClient(null, null, false);
TelemetryClient telemetry =
Expand Down