diff --git a/common/src/main/java/org/conscrypt/HpkeImpl.java b/common/src/main/java/org/conscrypt/HpkeImpl.java index 406324836..a09f3df43 100644 --- a/common/src/main/java/org/conscrypt/HpkeImpl.java +++ b/common/src/main/java/org/conscrypt/HpkeImpl.java @@ -49,10 +49,15 @@ public void engineInitSender(PublicKey recipientKey, byte[] info, PrivateKey sen byte[] psk, byte[] psk_id) throws InvalidKeyException { checkNotInitialised(); checkArgumentsForBaseModeOnly(senderKey, psk, psk_id); - final byte[] pk = hpkeSuite.getKem().validatePublicKeyTypeAndGetRawKey(recipientKey); + if (recipientKey == null) { + throw new InvalidKeyException("null recipient key"); + } else if (!(recipientKey instanceof OpenSSLX25519PublicKey)) { + throw new InvalidKeyException("Unsupported recipient key class: " + recipientKey.getClass()); + } + final byte[] recipientKeyBytes = ((OpenSSLX25519PublicKey) recipientKey).getU(); final Object[] result = NativeCrypto.EVP_HPKE_CTX_setup_base_mode_sender( - hpkeSuite, pk, info); + hpkeSuite, recipientKeyBytes, info); ctx = (NativeRef.EVP_HPKE_CTX) result[0]; encapsulated = (byte[]) result[1]; } @@ -63,10 +68,15 @@ public void engineInitSenderForTesting(PublicKey recipientKey, byte[] info, checkNotInitialised(); Objects.requireNonNull(sKe); checkArgumentsForBaseModeOnly(senderKey, psk, psk_id); - final byte[] pk = hpkeSuite.getKem().validatePublicKeyTypeAndGetRawKey(recipientKey); + if (recipientKey == null) { + throw new InvalidKeyException("null recipient key"); + } else if (!(recipientKey instanceof OpenSSLX25519PublicKey)) { + throw new InvalidKeyException("Unsupported recipient key class: " + recipientKey.getClass()); + } + final byte[] recipientKeyBytes = ((OpenSSLX25519PublicKey) recipientKey).getU(); final Object[] result = NativeCrypto.EVP_HPKE_CTX_setup_base_mode_sender_with_seed_for_testing( - hpkeSuite, pk, info, sKe); + hpkeSuite, recipientKeyBytes, info, sKe); ctx = (NativeRef.EVP_HPKE_CTX) result[0]; encapsulated = (byte[]) result[1]; } @@ -76,11 +86,20 @@ public void engineInitRecipient(byte[] encapsulated, PrivateKey recipientKey, byte[] info, PublicKey senderKey, byte[] psk, byte[] psk_id) throws InvalidKeyException { checkNotInitialised(); checkArgumentsForBaseModeOnly(senderKey, psk, psk_id); - hpkeSuite.getKem().validateEncapsulatedLength(encapsulated); - final byte[] sk = hpkeSuite.getKem().validatePrivateKeyTypeAndGetRawKey(recipientKey); + Preconditions.checkNotNull(encapsulated, "null encapsulated data"); + if (encapsulated.length != hpkeSuite.getKem().getEncapsulatedLength()) { + throw new InvalidKeyException("Invalid encapsulated length: " + encapsulated.length); + } + + if (recipientKey == null) { + throw new InvalidKeyException("null recipient key"); + } else if (!(recipientKey instanceof OpenSSLX25519PrivateKey)) { + throw new InvalidKeyException("Unsupported recipient key class: " + recipientKey.getClass()); + } + final byte[] recipientKeyBytes = ((OpenSSLX25519PrivateKey) recipientKey).getU(); ctx = (NativeRef.EVP_HPKE_CTX) NativeCrypto.EVP_HPKE_CTX_setup_base_mode_recipient( - hpkeSuite, sk, encapsulated, info); + hpkeSuite, recipientKeyBytes, encapsulated, info); } private void checkArgumentsForBaseModeOnly(Key senderKey, byte[] psk, byte[] psk_id) { @@ -105,7 +124,11 @@ public byte[] engineSeal(byte[] plaintext, byte[] aad) { @Override public byte[] engineExport(int length, byte[] exporterContext) { checkInitialised(); - hpkeSuite.getKdf().validateExportLength(length); + long maxLength = hpkeSuite.getKdf().maxExportLength(); + if (length < 0 || length > maxLength) { + throw new IllegalArgumentException("Export length must be between 0 and " + + maxLength + ", but was " + length); + } return NativeCrypto.EVP_HPKE_CTX_export(ctx, exporterContext, length); } diff --git a/common/src/main/java/org/conscrypt/HpkeSuite.java b/common/src/main/java/org/conscrypt/HpkeSuite.java index a2d8c6660..e0a5f3c8c 100644 --- a/common/src/main/java/org/conscrypt/HpkeSuite.java +++ b/common/src/main/java/org/conscrypt/HpkeSuite.java @@ -16,10 +16,6 @@ package org.conscrypt; -import java.security.InvalidKeyException; -import java.security.PrivateKey; -import java.security.PublicKey; - /** * Holds the KEM, KDF, and AEAD that are used and supported by {@link HpkeContextRecipient} and * {@link HpkeContextSender} defined on RFC 9180. @@ -64,18 +60,11 @@ public final class HpkeSuite { private final AEAD mAead; public HpkeSuite(int kem, int kdf, int aead) { - mKem = convertKem(kem); - mKdf = convertKdf(kdf); - mAead = convertAead(aead); - } - - public HpkeSuite(KEM kem, KDF kdf, AEAD aead) { - mKem = kem; - mKdf = kdf; - mAead = aead; + mKem = KEM.forId(kem); + mKdf = KDF.forId(kdf); + mAead = AEAD.forId(aead); } - public String name() { return String.format("%s/%s/%s", mKem.name(), mKdf.name(), mAead.name()); @@ -86,7 +75,7 @@ public String name() { * * @return kem */ - KEM getKem() { + public KEM getKem() { return mKem; } @@ -95,7 +84,7 @@ KEM getKem() { * * @return kdf */ - KDF getKdf() { + public KDF getKdf() { return mKdf; } @@ -104,64 +93,65 @@ KDF getKdf() { * * @return aead */ - AEAD getAead() { + public AEAD getAead() { return mAead; } /** - * Converts the kem value into its {@link KEM} representation. + * Converts the KEM value into its {@link KEM} representation. * * @param kem value * @return {@link KEM} representation. */ - private KEM convertKem(int kem) { - if (KEM_DHKEM_X25519_HKDF_SHA256 == kem) { - return KEM.DHKEM_X25519_HKDF_SHA256; - } - throw new IllegalArgumentException("KEM " + kem + " not supported."); + @Deprecated // Use KEM.forId() + public KEM convertKem(int kem) { + return KEM.forId(kem); } /** - * Converts the kdf value into its {@link KDF} representation. + * Converts the KDF value into its {@link KDF} representation. * * @param kdf value * @return {@link KDF} representation. */ - private KDF convertKdf(int kdf) { - if (KDF_HKDF_SHA256 == kdf) { - return KDF.HKDF_SHA256; - } - throw new IllegalArgumentException("KDF " + kdf + " not supported."); + @Deprecated // Use KDF.forId() + public KDF convertKdf(int kdf) { + return KDF.forId(kdf); } /** - * Converts the aead value into its {@link AEAD} representation. + * Converts the AEAD value into its {@link AEAD} representation. * * @param aead value * @return {@link AEAD} representation. */ - private AEAD convertAead(int aead) { - switch (aead) { - case AEAD_AES_128_GCM: - return AEAD.AES_128_GCM; - case AEAD_AES_256_GCM: - return AEAD.AES_256_GCM; - case AEAD_CHACHA20POLY1305: - return AEAD.CHACHA20POLY1305; - default: - throw new IllegalArgumentException("AEAD " + aead + " not supported."); - } + @Deprecated // Use AEAD.forId() + public AEAD convertAead(int aead) { + return AEAD.forId(aead); } - enum KEM { - DHKEM_X25519_HKDF_SHA256(/* id= */ 0x0020, /* encLength= */ 32); + /** + * Key Encapsulation Mechanisms (KEMs) + * + * @see + * rfc9180 + */ + public enum KEM { + DHKEM_X25519_HKDF_SHA256( + /* id= */ 0x20, /* nSecret= */ 32, /* nEnc= */ 32, /* nPk= */ 32, /* nSk= */ 32); private final int id; - private final int encLength; + private final int nSecret; + private final int nEnc; + private final int nPk; + private final int nSk; - KEM(int id, int encLength) { + KEM(int id, int nSecret, int nEnc, int nPk, int nSk) { this.id = id; - this.encLength = encLength; + this.nSecret = nSecret; + this.nEnc = nEnc; + this.nPk = nPk; + this.nSk = nSk; } /** @@ -172,89 +162,72 @@ enum KEM { * href="https://www.rfc-editor.org/rfc/rfc9180.html#name-key-encapsulation-mechanism">KEM * ids */ - int getId() { + public int getId() { return id; } /** - * The length in bytes of an encapsulated key produced by this KEM. - * - * @return encapsulated key size in bytes + * Returns the length in bytes of an encapsulated key produced by this KEM. */ - int getEncLength() { - return encLength; + @Deprecated // Use getEncapsulatedLength + public int getnEnc() { + return getEncapsulatedLength(); + } + public int getEncapsulatedLength() { + return nEnc; } /** - * Validates the encapsulated size in bytes matches the {@link KEM} spec. - * - * @param encapsulated encapsulated key produced by the kem - * @see - * expected enc size + * Returns the length in bytes of a KEM shared secret produced by this KEM. */ - void validateEncapsulatedLength(byte[] encapsulated) throws InvalidKeyException { - Preconditions.checkNotNull(encapsulated, "encapsulated"); - final int expectedLength = this.getEncLength(); - if (encapsulated.length != expectedLength) { - throw new InvalidKeyException( - "Expected encapsulated length of " + expectedLength + ", but was " - + encapsulated.length); - } + public int getSecretLength() { + return nSecret; } /** - * Validates the public key type and returns the raw bytes. - * - * @param publicKey alias pk - * @return key in its raw format - * @see expected - * pk size + * Returns the length in bytes of an encoded public key for this KEM. */ - byte[] validatePublicKeyTypeAndGetRawKey(PublicKey publicKey) throws InvalidKeyException { - String error; - if (publicKey == null) { - error = "null public key"; - } else if (!(publicKey instanceof OpenSSLX25519PublicKey)) { - error = "Public key algorithm " + publicKey.getAlgorithm() + " is not supported"; - } else { - return ((OpenSSLX25519PublicKey) publicKey).getU(); - } - throw new InvalidKeyException(error); + public int getPublicKeyLength() { + return nPk; } /** - * Validates the private key type and returns the raw bytes. - * - * @param privateKey alias sk - * @return key in its raw format - * @see expected - * sk size + * Returns The length in bytes of an encoded private key for this KEM. + */ + public int getPrivateKeyLength() { + return nPk; + } + + /** + * Returns the KEM value for a given id. */ - byte[] validatePrivateKeyTypeAndGetRawKey(PrivateKey privateKey) - throws InvalidKeyException { - String error; - if (privateKey == null) { - error = "null private key"; - } else if (!(privateKey instanceof OpenSSLX25519PrivateKey)) { - error = "Private key algorithm " + privateKey.getAlgorithm() + " is not supported"; - } else { - return ((OpenSSLX25519PrivateKey) privateKey).getU(); + public static KEM forId(int id) { + for (KEM kem : values()) { + if (kem.getId() == id) { + return kem; + } } - throw new InvalidKeyException(error); + throw new IllegalArgumentException("Unknown KEM " + id); } } - enum KDF { - HKDF_SHA256(/* id= */ 0x0001, /* hLength= */ 32); + /** + * Key Derivation Functions (KDFs) + * + * @see + * rfc9180 + */ + public enum KDF { + HKDF_SHA256(/* id= */ 0x0001, /* hLength= */ 32, /* hName= */ "HmacSHA256"); private final int id; private final int hLength; + private final String hName; - KDF(int id, int hLength) { + KDF(int id, int hLength, String hName) { this.id = id; this.hLength = hLength; + this.hName = hName; } /** @@ -274,35 +247,72 @@ int getId() { * * @return extract output size in bytes */ - int getHLength() { + int getMacLength() { return hLength; } + @Deprecated // Use getMacLength + public int getHLength() { + return getMacLength(); + } /** - * Validates the secret export size in bytes. The size has a maximum value of 255*Nh bytes. + * Returns the maximum export length that can be supported with this KDF. * - * @param l expected exporter output length * @see secret * export */ - void validateExportLength(int l) { - long upperLimitLength = this.getHLength() * 255L; - if (l < 0 || l > upperLimitLength) { - throw new IllegalArgumentException("Export length (L) must be between 0 and " - + upperLimitLength + ", but was " + l); + long maxExportLength() { + return this.getMacLength() * 255L; + } + + /** + * Name as defined in {@link javax.crypto.Mac}. + * + * @return name of mac algorithm used by the kdf. + */ + @Deprecated // Use getMacName + public String getMacAlgorithmName() { + return getMacName(); + } + public String getMacName() { + return hName; + } + + /** + * Returns the KDF value for a given id. + */ + public static KDF forId(int id) { + for (KDF kdf : values()) { + if (kdf.getId() == id) { + return kdf; + } } + throw new IllegalArgumentException("Unknown KDF " + id); } } - enum AEAD { - AES_128_GCM(/* id= */ 0x0001), - AES_256_GCM(/* id= */ 0x0002), - CHACHA20POLY1305(/* id= */ 0x0003); + /** + * AEAD ciphers. + * + * @see AEAD + * ids + */ + public enum AEAD { + AES_128_GCM(/* id= */ AEAD_AES_128_GCM, /* nk= */ 16, /* nn= */ 12, /* nt= */ 16), + AES_256_GCM(/* id= */ AEAD_AES_256_GCM, /* nk= */ 32, /* nn= */ 12, /* nt= */ 16), + CHACHA20POLY1305(/* id= */ AEAD_CHACHA20POLY1305, /* nk= */ 32, /* nn= */ 12, /* nt= */ 16); private final int id; + private final int nk; + private final int nn; + private final int nt; - AEAD(int id) { + AEAD(int id, int nk, int nn, int nt) { this.id = id; + this.nk = nk; + this.nn = nn; + this.nt = nt; } /** @@ -313,8 +323,64 @@ enum AEAD { * href="https://www.rfc-editor.org/rfc/rfc9180.html#name-authenticated-encryption-wi">AEAD * ids */ - int getId() { + public int getId() { return id; } + /** + * Returns the length in bytes of a key for this algorithm. + * + * @return AEAD Nk + * @see + * AEAD ids + */ + @Deprecated // Use getKeyLength() + public int getNk() { + return getKeyLength(); + } + public int getKeyLength() { + return nk; + } + + /** + * Returns the length in bytes of a nonce for this algorithm. + * + * @return AEAD Nn + * @see + * AEAD ids + */ + @Deprecated // Use getNonceLength() + public int getNn() { + return getNonceLength(); + } + public int getNonceLength() { + return nn; + } + + /** + * Returns the length in bytes of the AEAD authentication tag for this algorithm. + * + * @return AEAD Nt + * @see + * AEAD ids + */ + @Deprecated // Use getTagLength() + public int getNt() { + return nt; + } + public int getTagLength() { + return nt; + } + + /** + * Returns the AEAD value for a given id. + */ + public static AEAD forId(int id) { + for (AEAD aead : values()) { + if (aead.getId() == id) { + return aead; + } + } + throw new IllegalArgumentException("Unknown AEAD " + id); + } } } diff --git a/common/src/test/java/org/conscrypt/HpkeContextRecipientTest.java b/common/src/test/java/org/conscrypt/HpkeContextRecipientTest.java index b429bb416..1d7571721 100644 --- a/common/src/test/java/org/conscrypt/HpkeContextRecipientTest.java +++ b/common/src/test/java/org/conscrypt/HpkeContextRecipientTest.java @@ -187,8 +187,7 @@ public void testExport_lowerEdgeLength() throws Exception { final HpkeContextRecipient ctxRecipient = createDefaultHpkeContextRecipient(DEFAULT_ENC); final byte[] export = ctxRecipient.export(/* length= */ 0, DEFAULT_EXPORTER_CONTEXT); assertNotNull(export); - final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + assertThrows(IllegalArgumentException.class, () -> ctxRecipient.export(/* length= */ -1, DEFAULT_EXPORTER_CONTEXT)); - assertEquals("Export length (L) must be between 0 and 8160, but was -1", e.getMessage()); } } diff --git a/common/src/test/java/org/conscrypt/HpkeContextSenderTest.java b/common/src/test/java/org/conscrypt/HpkeContextSenderTest.java index 55ebf6f9c..26853047d 100644 --- a/common/src/test/java/org/conscrypt/HpkeContextSenderTest.java +++ b/common/src/test/java/org/conscrypt/HpkeContextSenderTest.java @@ -148,9 +148,8 @@ public void testExport_lowerEdgeLength() throws Exception { final byte[] export = ctxSender.export(/* length= */ 0, DEFAULT_EXPORTER_CONTEXT); assertNotNull(enc); assertNotNull(export); - final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + assertThrows(IllegalArgumentException.class, () -> ctxSender.export(/* length= */ -1, DEFAULT_EXPORTER_CONTEXT)); - assertEquals("Export length (L) must be between 0 and 8160, but was -1", e.getMessage()); } @Test diff --git a/common/src/test/java/org/conscrypt/HpkeSuiteTest.java b/common/src/test/java/org/conscrypt/HpkeSuiteTest.java index faf2ad1cc..c538c2f3a 100644 --- a/common/src/test/java/org/conscrypt/HpkeSuiteTest.java +++ b/common/src/test/java/org/conscrypt/HpkeSuiteTest.java @@ -23,6 +23,7 @@ import static org.conscrypt.HpkeSuite.KEM_DHKEM_X25519_HKDF_SHA256; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import org.junit.Test; import org.junit.runner.RunWith; @@ -41,20 +42,20 @@ public void testConstructor_validAlgorithms_noExceptionsThrown() { public void testConstructor_invalidKem_throwsArgumentException() { final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> new HpkeSuite(700, KDF_HKDF_SHA256, AEAD_AES_128_GCM)); - assertEquals("KEM 700 not supported.", e.getMessage()); + assertTrue(e.getMessage().contains("Unknown")); } @Test public void testConstructor_invalidKdf_throwsArgumentException() { final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> new HpkeSuite(KEM_DHKEM_X25519_HKDF_SHA256, 800, AEAD_AES_128_GCM)); - assertEquals("KDF 800 not supported.", e.getMessage()); + assertTrue(e.getMessage().contains("Unknown")); } @Test public void testConstructor_invalidAead_throwsArgumentException() { final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> new HpkeSuite(KEM_DHKEM_X25519_HKDF_SHA256, KDF_HKDF_SHA256, 900)); - assertEquals("AEAD 900 not supported.", e.getMessage()); + assertTrue(e.getMessage().contains("Unknown")); } } diff --git a/openjdk/src/test/java/org/conscrypt/ConscryptOpenJdkSuite.java b/openjdk/src/test/java/org/conscrypt/ConscryptOpenJdkSuite.java index b55ca1be1..a2e77b310 100644 --- a/openjdk/src/test/java/org/conscrypt/ConscryptOpenJdkSuite.java +++ b/openjdk/src/test/java/org/conscrypt/ConscryptOpenJdkSuite.java @@ -91,6 +91,11 @@ DuckTypedPSKKeyManagerTest.class, FileClientSessionCacheTest.class, HostnameVerifierTest.class, + HpkeContextTest.class, + HpkeContextRecipientTest.class, + HpkeContextSenderTest.class, + HpkeSuiteTest.class, + HpkeTestVectorsTest.class, NativeCryptoArgTest.class, NativeCryptoTest.class, NativeRefTest.class,